mirror of
https://github.com/lukaszraczylo/go-telegram.git
synced 2026-06-09 23:04:05 +00:00
refactor(api): auto-inject discriminator value via generated MarshalJSON
Sealed-interface union variants now hardcode their wire discriminator
inside a generated MarshalJSON method instead of forcing callers to set
the field on every struct literal. Drops a class of silent-rejection
bugs where a typo in the discriminator slipped past the type checker
and through to Telegram, which then rejected the request with no
Go-side signal.
The discriminator field stays exported so incoming-message decoding,
type switches and debugging still see it. MarshalJSON wraps via a
function-local type alias and emits an outer field with the same json
tag; encoding/json (and goccy/go-json) resolve the outer field as the
shallower one and override whatever the caller wrote.
99 variants get MarshalJSON. 7 are skipped because their unions
dispatch structurally rather than by a string field: Message and
InaccessibleMessage (MaybeInaccessibleMessage, dispatched on date),
and the InputMessageContent family (InputTextMessageContent,
InputLocationMessageContent, InputVenueMessageContent,
InputContactMessageContent, InputInvoiceMessageContent — Telegram
identifies these by the presence of message_text / latitude /
phone_number / title etc.).
Discriminator extraction lives in the emitter (cmd/genapi/emitter.go).
Resolution: knownDiscriminators reverse-lookup for the 13 auto-decode
unions, then doc-string analysis ("must be X" / "always “X”")
of the variant's first required string field for marker-only unions
(BotCommandScope, InputMedia, InputPaidMedia, InputProfilePhoto,
InputStoryContent, InputPollMedia, InputPollOptionMedia,
InlineQueryResult, PassportElementError). Variants the emitter cannot
resolve a discriminator for are skipped silently rather than emitting
broken code.
Internal call-site cleanups: 4 manual discriminator assignments
removed (api/unionparam_test.go,
dispatch/filters/message/message_test.go, examples/inline/main.go ×2).
Regression tests added in api/marshaljson_variants_test.go covering
type-keyed variants, source-keyed variants, the override-user-typo
guarantee, round-trip preservation through UnmarshalChatMember, the
no-discriminator InputMessageContent path, and ride-along of
non-discriminator fields.
regen-from-fixture is deterministic across two consecutive runs;
go test -race / go vet / staticcheck all clean.
This commit is contained in:
+171
-2
@@ -8,12 +8,22 @@ import (
|
||||
"go/format"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"text/template"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||
)
|
||||
|
||||
// Discriminator-value extractors. The curly form ("always “X”") is
|
||||
// authoritative because Telegram quotes wire literals with curly quotes
|
||||
// throughout the docs; the bare form ("must be X") is the looser
|
||||
// non-quoted variant used for BotCommandScope, InputMedia, etc.
|
||||
var (
|
||||
discCurlyRE = regexp.MustCompile(`(?:must be|always)\s+“([^”]+)”`)
|
||||
discBareRE = regexp.MustCompile(`must be\s+([A-Za-z0-9_]+)(?:[\s.,]|$)`)
|
||||
)
|
||||
|
||||
//go:embed types.tmpl
|
||||
var typesTmpl string
|
||||
|
||||
@@ -165,11 +175,154 @@ type emitter struct {
|
||||
api *spec.API
|
||||
outDir string
|
||||
enums *enumPlan
|
||||
// variantDiscs maps a concrete variant type name (e.g.
|
||||
// "BotCommandScopeAllPrivateChats") to its discriminator wire-field
|
||||
// + value. Populated once at construction; consulted by the types
|
||||
// template to emit per-variant MarshalJSON that hardcodes the
|
||||
// discriminator so callers don't have to set it by hand.
|
||||
variantDiscs map[string]variantDiscriminator
|
||||
}
|
||||
|
||||
func newEmitter(api *spec.API, outDir string) *emitter {
|
||||
knownInterfaceTypes = buildUnionTypeSet(api)
|
||||
return &emitter{api: api, outDir: outDir, enums: planEnums(api)}
|
||||
return &emitter{
|
||||
api: api,
|
||||
outDir: outDir,
|
||||
enums: planEnums(api),
|
||||
variantDiscs: variantDiscriminators(api),
|
||||
}
|
||||
}
|
||||
|
||||
// variantDiscriminator describes the JSON field+value that identifies a
|
||||
// concrete variant of a sealed-interface union on the wire.
|
||||
type variantDiscriminator struct {
|
||||
JSONField string // wire field name, e.g. "type" or "source"
|
||||
GoField string // Go struct field name, e.g. "Type" or "Source"
|
||||
Value string // the wire value, e.g. "all_private_chats"
|
||||
}
|
||||
|
||||
// variantDiscriminators returns variantTypeName → discriminator for every
|
||||
// concrete struct that participates in a sealed-interface union and has
|
||||
// a string-typed first field whose doc fixes its value (the canonical
|
||||
// "must be X" / "always “X”" patterns Telegram uses).
|
||||
//
|
||||
// Resolution order:
|
||||
//
|
||||
// 1. knownDiscriminators reverse-lookup (the 13 auto-decode unions).
|
||||
// This guarantees parity with UnmarshalXxx dispatch for the unions
|
||||
// that round-trip through the library.
|
||||
// 2. Doc-string analysis of the variant's first field, for marker-only
|
||||
// unions (BotCommandScope, InputMedia, etc.) where the IR has no
|
||||
// explicit discriminator metadata.
|
||||
//
|
||||
// Variants whose first field has no discriminator hint (Message,
|
||||
// InaccessibleMessage, the InputMessageContent family) are omitted —
|
||||
// the caller writes the dispatching fields directly and Telegram
|
||||
// identifies them structurally.
|
||||
func variantDiscriminators(api *spec.API) map[string]variantDiscriminator {
|
||||
out := make(map[string]variantDiscriminator, 128)
|
||||
|
||||
// Pass 1: reverse-lookup from knownDiscriminators.
|
||||
for _, ds := range knownDiscriminators {
|
||||
if ds.Field == "" {
|
||||
continue
|
||||
}
|
||||
for value, variant := range ds.Variants {
|
||||
out[variant] = variantDiscriminator{
|
||||
JSONField: ds.Field,
|
||||
Value: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the set of every variant type referenced by any OneOf so we
|
||||
// can scan only those (avoids matching free-text "must be" prose in
|
||||
// non-variant types like Message).
|
||||
variantSet := make(map[string]bool, 128)
|
||||
for _, t := range api.Types {
|
||||
for _, v := range t.OneOf {
|
||||
variantSet[v] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: doc-parse for variants without a known discriminator.
|
||||
for _, t := range api.Types {
|
||||
if !variantSet[t.Name] {
|
||||
continue
|
||||
}
|
||||
if _, ok := out[t.Name]; ok {
|
||||
// Pass-1 already provided the wire value; we still need
|
||||
// the Go field name (mirrors the JSON field but with
|
||||
// proper case). Resolve from t.Fields by JSONName match.
|
||||
disc := out[t.Name]
|
||||
for _, f := range t.Fields {
|
||||
if f.JSONName == disc.JSONField {
|
||||
disc.GoField = f.Name
|
||||
out[t.Name] = disc
|
||||
break
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
disc, ok := extractVariantDiscriminator(t)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out[t.Name] = disc
|
||||
}
|
||||
|
||||
// Drop entries we couldn't resolve a Go field for (defensive — every
|
||||
// pass-1 hit should have matched, but better to skip than emit
|
||||
// broken code referencing an unknown field name).
|
||||
for name, d := range out {
|
||||
if d.GoField == "" {
|
||||
delete(out, name)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// extractVariantDiscriminator inspects the first field of a variant
|
||||
// struct and returns its discriminator if the field is a required
|
||||
// string whose doc nails the value via "must be X" or "always “X”".
|
||||
// Returns (zero, false) when no clear discriminator is present.
|
||||
func extractVariantDiscriminator(t spec.TypeDecl) (variantDiscriminator, bool) {
|
||||
if len(t.Fields) == 0 {
|
||||
return variantDiscriminator{}, false
|
||||
}
|
||||
f := t.Fields[0]
|
||||
if !f.Required || f.Type.Kind != spec.KindPrimitive || f.Type.Name != "string" {
|
||||
return variantDiscriminator{}, false
|
||||
}
|
||||
value := parseDiscriminatorDoc(f.Doc)
|
||||
if value == "" {
|
||||
return variantDiscriminator{}, false
|
||||
}
|
||||
return variantDiscriminator{
|
||||
JSONField: f.JSONName,
|
||||
GoField: f.Name,
|
||||
Value: value,
|
||||
}, true
|
||||
}
|
||||
|
||||
// parseDiscriminatorDoc extracts the wire-level discriminator value
|
||||
// from a field doc string. Handles both Telegram phrasings:
|
||||
//
|
||||
// - "Scope type, must be all_private_chats" (bare token)
|
||||
// - "Type of the message origin, always “user”" (curly-quoted)
|
||||
//
|
||||
// Returns "" when no discriminator is present.
|
||||
func parseDiscriminatorDoc(doc string) string {
|
||||
// Curly-quoted form takes priority: "must be “X”" or "always “X”".
|
||||
if m := discCurlyRE.FindStringSubmatch(doc); len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
// Bare-token form: "must be <ident>" terminated by end-of-string,
|
||||
// punctuation, or whitespace.
|
||||
if m := discBareRE.FindStringSubmatch(doc); len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// knownInterfaceTypes is the full set of sealed-interface union type names
|
||||
@@ -182,7 +335,7 @@ var knownInterfaceTypes = map[string]bool{}
|
||||
|
||||
// emitTypes renders types.gen.go.
|
||||
func (e *emitter) emitTypes() error {
|
||||
t, err := template.New("types").Funcs(funcs(e.enums)).Parse(typesTmpl)
|
||||
t, err := template.New("types").Funcs(funcsWithDiscs(e.enums, e.variantDiscs)).Parse(typesTmpl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse types.tmpl: %w", err)
|
||||
}
|
||||
@@ -218,6 +371,22 @@ func loadAPI(path string) (*spec.API, error) {
|
||||
return &api, nil
|
||||
}
|
||||
|
||||
// funcsWithDiscs returns the shared FuncMap with the variant
|
||||
// discriminator helpers bound to discs. types.tmpl uses
|
||||
// variantDiscFor/variantHasDisc to emit per-variant MarshalJSON that
|
||||
// hardcodes the wire discriminator value.
|
||||
func funcsWithDiscs(plan *enumPlan, discs map[string]variantDiscriminator) template.FuncMap {
|
||||
fm := funcs(plan)
|
||||
fm["variantHasDisc"] = func(name string) bool {
|
||||
_, ok := discs[name]
|
||||
return ok
|
||||
}
|
||||
fm["variantDiscField"] = func(name string) string { return discs[name].JSONField }
|
||||
fm["variantDiscGoField"] = func(name string) string { return discs[name].GoField }
|
||||
fm["variantDiscValue"] = func(name string) string { return discs[name].Value }
|
||||
return fm
|
||||
}
|
||||
|
||||
// funcs is the FuncMap shared across templates. plan is the resolved
|
||||
// enum plan; pass nil only in unit tests that don't exercise enums.
|
||||
func funcs(plan *enumPlan) template.FuncMap {
|
||||
|
||||
@@ -84,6 +84,23 @@ func UnmarshalMaybeInaccessibleMessage(data []byte) (MaybeInaccessibleMessage, e
|
||||
type {{.Name}} struct {
|
||||
{{range .Fields}}{{docComment .Doc}}{{goField $td.Name .}}
|
||||
{{end}}}
|
||||
{{if variantHasDisc .Name}}
|
||||
// MarshalJSON encodes {{.Name}} with the discriminator field
|
||||
// "{{variantDiscField .Name}}" forced to {{printf "%q" (variantDiscValue .Name)}}.
|
||||
// The hardcoded value frees callers from setting {{variantDiscGoField .Name}} by hand —
|
||||
// any user-supplied value on the struct literal is overridden so a typo
|
||||
// can't slip through to Telegram.
|
||||
func (v *{{.Name}}) MarshalJSON() ([]byte, error) {
|
||||
type alias {{.Name}}
|
||||
return json.Marshal(&struct {
|
||||
{{variantDiscGoField .Name}} string `json:"{{variantDiscField .Name}}"`
|
||||
*alias
|
||||
}{
|
||||
{{variantDiscGoField .Name}}: {{printf "%q" (variantDiscValue .Name)}},
|
||||
alias: (*alias)(v),
|
||||
})
|
||||
}
|
||||
{{end}}
|
||||
{{$unionFields := unionFields .}}{{if $unionFields}}
|
||||
// UnmarshalJSON decodes {{.Name}} by dispatching union-typed fields
|
||||
// ({{range $i, $u := $unionFields}}{{if $i}}, {{end}}{{$u.Field.Name}}{{end}}) through their concrete UnmarshalXxx helpers.
|
||||
|
||||
Reference in New Issue
Block a user