refactor(api): auto-inject discriminator value via generated MarshalJSON

Sealed-interface union variants now hardcode their wire discriminator
inside a generated MarshalJSON method instead of forcing callers to set
the field on every struct literal. Drops a class of silent-rejection
bugs where a typo in the discriminator slipped past the type checker
and through to Telegram, which then rejected the request with no
Go-side signal.

The discriminator field stays exported so incoming-message decoding,
type switches and debugging still see it. MarshalJSON wraps the value in
a function-local defined type (`type alias T`, which has T's fields but
none of its methods, so encoding does not recurse into MarshalJSON) and
emits an outer field with the same json tag; encoding/json (and
goccy/go-json) resolve the outer field as the shallower one, so it
overrides whatever the caller wrote.

99 variants get MarshalJSON. 7 are skipped because their unions
dispatch structurally rather than by a string field: Message and
InaccessibleMessage (MaybeInaccessibleMessage, dispatched on date),
and the InputMessageContent family (InputTextMessageContent,
InputLocationMessageContent, InputVenueMessageContent,
InputContactMessageContent, InputInvoiceMessageContent — Telegram
identifies these by the presence of message_text / latitude /
phone_number / title etc.).

Discriminator extraction lives in the emitter (cmd/genapi/emitter.go).
Resolution: knownDiscriminators reverse-lookup for the 13 auto-decode
unions, then doc-string analysis ("must be X" / "always “X”")
of the variant's first required string field for marker-only unions
(BotCommandScope, InputMedia, InputPaidMedia, InputProfilePhoto,
InputStoryContent, InputPollMedia, InputPollOptionMedia,
InlineQueryResult, PassportElementError). Variants the emitter cannot
resolve a discriminator for are skipped silently rather than emitting
broken code.

Internal call-site cleanups: 4 manual discriminator assignments
removed (api/unionparam_test.go,
dispatch/filters/message/message_test.go, examples/inline/main.go ×2).
Regression tests added in api/marshaljson_variants_test.go covering
type-keyed variants, source-keyed variants, the override-user-typo
guarantee, round-trip preservation through UnmarshalChatMember, the
no-discriminator InputMessageContent path, and ride-along of
non-discriminator fields.

regen-from-fixture is deterministic across two consecutive runs;
go test -race / go vet / staticcheck all clean.
This commit is contained in:
2026-05-09 19:27:33 +01:00
parent 6ab80c27e1
commit 370c9c0802
8 changed files with 3182 additions and 321 deletions
+103
View File
@@ -0,0 +1,103 @@
package api
import (
"testing"
json "github.com/goccy/go-json"
"github.com/stretchr/testify/require"
)
// TestMarshalJSON_TypeDiscriminator_AutoInjected verifies the generated
// MarshalJSON hardcodes the wire discriminator for a Type-keyed variant
// even when the caller leaves the field zero.
func TestMarshalJSON_TypeDiscriminator_AutoInjected(t *testing.T) {
scope := &BotCommandScopeAllPrivateChats{}
got, err := json.Marshal(scope)
require.NoError(t, err)
require.JSONEq(t, `{"type":"all_private_chats"}`, string(got))
}
// TestMarshalJSON_SourceDiscriminator_AutoInjected covers a variant whose
// discriminator is not the "type" field: PassportElementErrorDataField is
// expected to carry "source":"data" on the wire while its own Type field
// marshals as ordinary payload.
func TestMarshalJSON_SourceDiscriminator_AutoInjected(t *testing.T) {
	const want = `{"source":"data","type":"personal_details","field_name":"first_name","data_hash":"abc123","message":"bad data"}`

	// Source is deliberately left unset; only the payload fields are filled.
	elemErr := &PassportElementErrorDataField{
		Type:      PassportElementErrorDataFieldTypePersonalDetails,
		FieldName: "first_name",
		DataHash:  "abc123",
		Message:   "bad data",
	}

	encoded, err := json.Marshal(elemErr)
	require.NoError(t, err)
	require.JSONEq(t, want, string(encoded))
}
// TestMarshalJSON_UserSuppliedDiscriminator_Overridden pins the safety
// guarantee of the generated MarshalJSON: whatever the caller wrote into
// the discriminator field (a typo, a stale value) is replaced by the
// hardcoded wire value, so a wrong discriminator never reaches Telegram.
func TestMarshalJSON_UserSuppliedDiscriminator_Overridden(t *testing.T) {
	const want = `{"type":"all_private_chats"}`

	mistyped := &BotCommandScopeAllPrivateChats{Type: "wrong"}

	encoded, err := json.Marshal(mistyped)
	require.NoError(t, err)
	require.JSONEq(t, want, string(encoded))
}
// TestMarshalJSON_RoundTrip marshals a ChatMemberLeft whose Status was
// never set, decodes the bytes through UnmarshalChatMember, and checks
// that the result is again a *ChatMemberLeft with the expected status
// constant and the same user ID and first name.
//
// ChatMember is used because it has a generated UnmarshalChatMember
// dispatcher, so the injected discriminator is what routes the decode.
func TestMarshalJSON_RoundTrip(t *testing.T) {
	sent := &ChatMemberLeft{
		User: User{ID: 42, IsBot: false, FirstName: "alice"},
	}

	wire, err := json.Marshal(sent)
	require.NoError(t, err)

	decoded, err := UnmarshalChatMember(wire)
	require.NoError(t, err)

	received, isLeft := decoded.(*ChatMemberLeft)
	require.True(t, isLeft, "expected *ChatMemberLeft, got %T", decoded)

	require.Equal(t, ChatMemberLeftStatusLeft, received.Status)
	require.Equal(t, sent.User.ID, received.User.ID)
	require.Equal(t, sent.User.FirstName, received.User.FirstName)
}
// TestMarshalJSON_InputMessageContent_NoDiscriminator confirms that
// variants of InputMessageContent (the structurally-dispatched union
// Telegram identifies by field presence, not by a "type" field) do
// NOT get an injected discriminator. Their fields ride out as-is.
func TestMarshalJSON_InputMessageContent_NoDiscriminator(t *testing.T) {
content := &InputTextMessageContent{
MessageText: "hello world",
}
got, err := json.Marshal(content)
require.NoError(t, err)
// No "type" field should appear; just message_text.
require.JSONEq(t, `{"message_text":"hello world"}`, string(got))
}
// TestMarshalJSON_NonDiscriminatorMembers_RidealongUnchanged verifies
// that fields other than the discriminator still marshal under their own
// json tags when the variant goes through the generated MarshalJSON.
// Media and Caption are set and must appear next to the injected
// "type":"photo"; every field left unset must be absent, because the
// expected document is compared in full.
func TestMarshalJSON_NonDiscriminatorMembers_RidealongUnchanged(t *testing.T) {
	const want = `{"type":"photo","media":"https://example.com/photo.jpg","caption":"look"}`

	photo := &InputMediaPhoto{
		Media:   "https://example.com/photo.jpg",
		Caption: "look",
	}

	encoded, err := json.Marshal(photo)
	require.NoError(t, err)
	require.JSONEq(t, want, string(encoded))
}
+1584
View File
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -44,7 +44,7 @@ func TestSetMyCommands_BotCommandScope_NoPointerToInterface(t *testing.T) {
// is `BotCommandScope` (interface), not `*BotCommandScope`. // is `BotCommandScope` (interface), not `*BotCommandScope`.
ok, err := SetMyCommands(context.Background(), bot, &SetMyCommandsParams{ ok, err := SetMyCommands(context.Background(), bot, &SetMyCommandsParams{
Commands: []BotCommand{{Command: "start", Description: "begin"}}, Commands: []BotCommand{{Command: "start", Description: "begin"}},
Scope: &BotCommandScopeAllPrivateChats{Type: "all_private_chats"}, Scope: &BotCommandScopeAllPrivateChats{},
}) })
require.NoError(t, err) require.NoError(t, err)
require.True(t, ok) require.True(t, ok)
+171 -2
View File
@@ -8,12 +8,22 @@ import (
"go/format" "go/format"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"sort" "sort"
"text/template" "text/template"
"github.com/lukaszraczylo/go-telegram/internal/spec" "github.com/lukaszraczylo/go-telegram/internal/spec"
) )
// discCurlyRE captures a discriminator value written in curly quotes
// after "must be" or "always", e.g. `always “user”` yields "user". This
// is the authoritative form: Telegram's docs quote wire literals with
// curly quotes.
var discCurlyRE = regexp.MustCompile(`(?:must be|always)\s+“([^”]+)”`)

// discBareRE captures an unquoted identifier after "must be", the looser
// phrasing used for BotCommandScope, InputMedia, etc. The identifier must
// be followed by whitespace, '.', ',' or the end of the string.
//
// NOTE(review): this pattern also matches ordinary prose such as
// "must be one of ..." (capturing "one"), so it is only safe on doc
// strings already known to describe a discriminator field.
var discBareRE = regexp.MustCompile(`must be\s+([A-Za-z0-9_]+)(?:[\s.,]|$)`)
//go:embed types.tmpl //go:embed types.tmpl
var typesTmpl string var typesTmpl string
@@ -165,11 +175,154 @@ type emitter struct {
api *spec.API api *spec.API
outDir string outDir string
enums *enumPlan enums *enumPlan
// variantDiscs maps a concrete variant type name (e.g.
// "BotCommandScopeAllPrivateChats") to its discriminator wire-field
// + value. Populated once at construction; consulted by the types
// template to emit per-variant MarshalJSON that hardcodes the
// discriminator so callers don't have to set it by hand.
variantDiscs map[string]variantDiscriminator
} }
func newEmitter(api *spec.API, outDir string) *emitter { func newEmitter(api *spec.API, outDir string) *emitter {
knownInterfaceTypes = buildUnionTypeSet(api) knownInterfaceTypes = buildUnionTypeSet(api)
return &emitter{api: api, outDir: outDir, enums: planEnums(api)} return &emitter{
api: api,
outDir: outDir,
enums: planEnums(api),
variantDiscs: variantDiscriminators(api),
}
}
// variantDiscriminator is the field/value pair that identifies one
// concrete variant of a sealed-interface union on the wire.
type variantDiscriminator struct {
	// JSONField is the wire field name, e.g. "type" or "source".
	JSONField string
	// GoField is the Go struct field name, e.g. "Type" or "Source".
	GoField string
	// Value is the wire value, e.g. "all_private_chats".
	Value string
}
// variantDiscriminators builds the variant-type-name → discriminator map
// used to generate per-variant MarshalJSON methods.
//
// A variant's discriminator comes from one of two sources:
//
//   - knownDiscriminators: every (value, variant) pair of an entry with a
//     non-empty Field seeds the map with the JSON field and wire value.
//     The Go field name is then filled in from the variant's own field
//     whose JSONName equals that JSON field.
//   - extractVariantDiscriminator: for union variants not seeded above,
//     the discriminator is parsed from the doc of the first field.
//
// Only types named in some OneOf list are considered, so "must be" prose
// on non-variant types is never scanned. Variants that end up without a
// Go field name (seeded but not matched, or never visited) are removed
// rather than risk generating code that names a missing field.
func variantDiscriminators(api *spec.API) map[string]variantDiscriminator {
	resolved := make(map[string]variantDiscriminator, 128)

	// Seed from knownDiscriminators; GoField is resolved further down.
	for _, known := range knownDiscriminators {
		if known.Field != "" {
			for wireValue, variantName := range known.Variants {
				resolved[variantName] = variantDiscriminator{
					JSONField: known.Field,
					Value:     wireValue,
				}
			}
		}
	}

	// Collect every type name that appears as a member of any union.
	isVariant := make(map[string]bool, 128)
	for _, decl := range api.Types {
		for _, member := range decl.OneOf {
			isVariant[member] = true
		}
	}

	for _, decl := range api.Types {
		if !isVariant[decl.Name] {
			continue
		}

		seeded, wasSeeded := resolved[decl.Name]
		if !wasSeeded {
			// No known discriminator: fall back to the field docs.
			if parsed, found := extractVariantDiscriminator(decl); found {
				resolved[decl.Name] = parsed
			}
			continue
		}

		// Seeded entry: attach the Go name of the first field whose JSON
		// name matches the discriminator's wire field.
		for _, field := range decl.Fields {
			if field.JSONName == seeded.JSONField {
				seeded.GoField = field.Name
				resolved[decl.Name] = seeded
				break
			}
		}
	}

	// Defensive sweep: an entry without a Go field cannot be emitted.
	for variantName, disc := range resolved {
		if disc.GoField == "" {
			delete(resolved, variantName)
		}
	}
	return resolved
}
// extractVariantDiscriminator derives a discriminator from the first
// field of t. It succeeds only when that field is required, is the
// primitive type "string", and parseDiscriminatorDoc finds a value in its
// doc. In every other case, including a type with no fields, it returns
// the zero variantDiscriminator and false.
func extractVariantDiscriminator(t spec.TypeDecl) (variantDiscriminator, bool) {
	var none variantDiscriminator

	if len(t.Fields) == 0 {
		return none, false
	}

	first := t.Fields[0]
	isRequiredString := first.Required &&
		first.Type.Kind == spec.KindPrimitive &&
		first.Type.Name == "string"
	if !isRequiredString {
		return none, false
	}

	wireValue := parseDiscriminatorDoc(first.Doc)
	if wireValue == "" {
		return none, false
	}

	found := variantDiscriminator{
		JSONField: first.JSONName,
		GoField:   first.Name,
		Value:     wireValue,
	}
	return found, true
}
// parseDiscriminatorDoc extracts the wire-level discriminator value
// from a field doc string. Handles both Telegram phrasings:
//
// - "Scope type, must be all_private_chats" (bare token)
// - "Type of the message origin, always “user”" (curly-quoted)
//
// Returns "" when no discriminator is present.
func parseDiscriminatorDoc(doc string) string {
// Curly-quoted form takes priority: "must be “X”" or "always “X”".
if m := discCurlyRE.FindStringSubmatch(doc); len(m) == 2 {
return m[1]
}
// Bare-token form: "must be <ident>" terminated by end-of-string,
// punctuation, or whitespace.
if m := discBareRE.FindStringSubmatch(doc); len(m) == 2 {
return m[1]
}
return ""
} }
// knownInterfaceTypes is the full set of sealed-interface union type names // knownInterfaceTypes is the full set of sealed-interface union type names
@@ -182,7 +335,7 @@ var knownInterfaceTypes = map[string]bool{}
// emitTypes renders types.gen.go. // emitTypes renders types.gen.go.
func (e *emitter) emitTypes() error { func (e *emitter) emitTypes() error {
t, err := template.New("types").Funcs(funcs(e.enums)).Parse(typesTmpl) t, err := template.New("types").Funcs(funcsWithDiscs(e.enums, e.variantDiscs)).Parse(typesTmpl)
if err != nil { if err != nil {
return fmt.Errorf("parse types.tmpl: %w", err) return fmt.Errorf("parse types.tmpl: %w", err)
} }
@@ -218,6 +371,22 @@ func loadAPI(path string) (*spec.API, error) {
return &api, nil return &api, nil
} }
// funcsWithDiscs extends the shared FuncMap from funcs(plan) with four
// helpers that read from discs, keyed by variant type name:
//
//   - variantHasDisc:     whether discs has an entry for the name
//   - variantDiscField:   the entry's JSON field name
//   - variantDiscGoField: the entry's Go field name
//   - variantDiscValue:   the entry's wire value
//
// The three string helpers return "" for a name that has no entry.
func funcsWithDiscs(plan *enumPlan, discs map[string]variantDiscriminator) template.FuncMap {
	helpers := funcs(plan)

	// A missing key yields the zero variantDiscriminator.
	lookup := func(variant string) variantDiscriminator {
		return discs[variant]
	}

	helpers["variantHasDisc"] = func(variant string) bool {
		_, found := discs[variant]
		return found
	}
	helpers["variantDiscField"] = func(variant string) string {
		return lookup(variant).JSONField
	}
	helpers["variantDiscGoField"] = func(variant string) string {
		return lookup(variant).GoField
	}
	helpers["variantDiscValue"] = func(variant string) string {
		return lookup(variant).Value
	}
	return helpers
}
// funcs is the FuncMap shared across templates. plan is the resolved // funcs is the FuncMap shared across templates. plan is the resolved
// enum plan; pass nil only in unit tests that don't exercise enums. // enum plan; pass nil only in unit tests that don't exercise enums.
func funcs(plan *enumPlan) template.FuncMap { func funcs(plan *enumPlan) template.FuncMap {
+17
View File
@@ -84,6 +84,23 @@ func UnmarshalMaybeInaccessibleMessage(data []byte) (MaybeInaccessibleMessage, e
type {{.Name}} struct { type {{.Name}} struct {
{{range .Fields}}{{docComment .Doc}}{{goField $td.Name .}} {{range .Fields}}{{docComment .Doc}}{{goField $td.Name .}}
{{end}}} {{end}}}
{{if variantHasDisc .Name}}
// MarshalJSON encodes {{.Name}} with the discriminator field
// "{{variantDiscField .Name}}" forced to {{printf "%q" (variantDiscValue .Name)}}.
// The hardcoded value frees callers from setting {{variantDiscGoField .Name}} by hand —
// any user-supplied value on the struct literal is overridden so a typo
// can't slip through to Telegram.
func (v *{{.Name}}) MarshalJSON() ([]byte, error) {
type alias {{.Name}}
return json.Marshal(&struct {
{{variantDiscGoField .Name}} string `json:"{{variantDiscField .Name}}"`
*alias
}{
{{variantDiscGoField .Name}}: {{printf "%q" (variantDiscValue .Name)}},
alias: (*alias)(v),
})
}
{{end}}
{{$unionFields := unionFields .}}{{if $unionFields}} {{$unionFields := unionFields .}}{{if $unionFields}}
// UnmarshalJSON decodes {{.Name}} by dispatching union-typed fields // UnmarshalJSON decodes {{.Name}} by dispatching union-typed fields
// ({{range $i, $u := $unionFields}}{{if $i}}, {{end}}{{$u.Field.Name}}{{end}}) through their concrete UnmarshalXxx helpers. // ({{range $i, $u := $unionFields}}{{if $i}}, {{end}}{{$u.Field.Name}}{{end}}) through their concrete UnmarshalXxx helpers.
+1 -1
View File
@@ -109,7 +109,7 @@ func TestIsForward(t *testing.T) {
// ForwardOrigin is a MessageOrigin interface; set via a concrete type. // ForwardOrigin is a MessageOrigin interface; set via a concrete type.
f := msgfilter.IsForward() f := msgfilter.IsForward()
m := msg("fwd") m := msg("fwd")
m.ForwardOrigin = &api.MessageOriginUser{Type: "user"} m.ForwardOrigin = &api.MessageOriginUser{}
require.True(t, f(m)) require.True(t, f(m))
require.False(t, f(msg("no fwd"))) require.False(t, f(msg("no fwd")))
require.False(t, f(nil)) require.False(t, f(nil))
+1305 -315
View File
File diff suppressed because it is too large Load Diff
-2
View File
@@ -35,7 +35,6 @@ func main() {
// Echo the query as article results. // Echo the query as article results.
results := []api.InlineQueryResult{ results := []api.InlineQueryResult{
&api.InlineQueryResultArticle{ &api.InlineQueryResultArticle{
Type: "article",
ID: "echo", ID: "echo",
Title: "Echo: " + q.Query, Title: "Echo: " + q.Query,
InputMessageContent: &api.InputTextMessageContent{ InputMessageContent: &api.InputTextMessageContent{
@@ -43,7 +42,6 @@ func main() {
}, },
}, },
&api.InlineQueryResultArticle{ &api.InlineQueryResultArticle{
Type: "article",
ID: "upper", ID: "upper",
Title: "UPPER: " + strings.ToUpper(q.Query), Title: "UPPER: " + strings.ToUpper(q.Query),
InputMessageContent: &api.InputTextMessageContent{ InputMessageContent: &api.InputTextMessageContent{