mirror of
https://github.com/lukaszraczylo/go-telegram.git
synced 2026-06-05 22:43:59 +00:00
Initial release of go-telegram
A fully-generated, strongly-typed Go client for the Telegram Bot API. * 176 methods + 301 types generated from Bot API v10.0 * 1408 auto-generated tests (8 scenarios per method) * Typed unions throughout — no 'any' in the public surface * Pluggable HTTP transport and JSON codec (default goccy/go-json) * Built-in retry middleware honouring Telegram's retry_after * Generic dispatcher with filters and conversation handlers * Self-verifying codegen pipeline (regen → audit → emit → run tests) * 14 example bots covering common patterns
This commit is contained in:
@@ -0,0 +1,749 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"github.com/goccy/go-json"
|
||||
"go/format"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"text/template"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||
)
|
||||
|
||||
//go:embed types.tmpl
|
||||
var typesTmpl string
|
||||
|
||||
//go:embed methods.tmpl
|
||||
var methodsTmpl string
|
||||
|
||||
//go:embed enums.tmpl
|
||||
var enumsTmpl string
|
||||
|
||||
//go:embed tests.tmpl
|
||||
var testsTmpl string
|
||||
|
||||
// runtimeTypes lists types that are intentionally hand-coded and must not be
|
||||
// emitted by the code generator. Skipping them prevents collisions between
|
||||
// generated and hand-coded definitions.
|
||||
var runtimeTypes = map[string]bool{
|
||||
"InputFile": true,
|
||||
"ResponseParameters": true,
|
||||
"ChatID": true,
|
||||
"MessageOrBool": true,
|
||||
}
|
||||
|
||||
// discriminatorSpec describes how to decode a sealed-interface union by
|
||||
// peeking at a single JSON field.
|
||||
type discriminatorSpec struct {
|
||||
Field string // JSON field name to peek at
|
||||
Variants map[string]string // discriminator value → concrete Go type name
|
||||
}
|
||||
|
||||
// knownDiscriminators maps parent union name → discriminator spec.
|
||||
// Used by the template helpers hasDiscriminator / discriminatorField /
|
||||
// discriminatorMap to emit UnmarshalXxx helpers.
|
||||
var knownDiscriminators = map[string]discriminatorSpec{
|
||||
"ChatMember": {
|
||||
Field: "status",
|
||||
Variants: map[string]string{
|
||||
"creator": "ChatMemberOwner",
|
||||
"administrator": "ChatMemberAdministrator",
|
||||
"member": "ChatMemberMember",
|
||||
"restricted": "ChatMemberRestricted",
|
||||
"left": "ChatMemberLeft",
|
||||
"kicked": "ChatMemberBanned",
|
||||
},
|
||||
},
|
||||
"MessageOrigin": {
|
||||
Field: "type",
|
||||
Variants: map[string]string{
|
||||
"user": "MessageOriginUser",
|
||||
"hidden_user": "MessageOriginHiddenUser",
|
||||
"chat": "MessageOriginChat",
|
||||
"channel": "MessageOriginChannel",
|
||||
},
|
||||
},
|
||||
"ReactionType": {
|
||||
Field: "type",
|
||||
Variants: map[string]string{
|
||||
"emoji": "ReactionTypeEmoji",
|
||||
"custom_emoji": "ReactionTypeCustomEmoji",
|
||||
"paid": "ReactionTypePaid",
|
||||
},
|
||||
},
|
||||
"PaidMedia": {
|
||||
Field: "type",
|
||||
Variants: map[string]string{
|
||||
"preview": "PaidMediaPreview",
|
||||
"photo": "PaidMediaPhoto",
|
||||
"video": "PaidMediaVideo",
|
||||
},
|
||||
},
|
||||
"BackgroundType": {
|
||||
Field: "type",
|
||||
Variants: map[string]string{
|
||||
"fill": "BackgroundTypeFill",
|
||||
"wallpaper": "BackgroundTypeWallpaper",
|
||||
"pattern": "BackgroundTypePattern",
|
||||
"chat_theme": "BackgroundTypeChatTheme",
|
||||
},
|
||||
},
|
||||
"BackgroundFill": {
|
||||
Field: "type",
|
||||
Variants: map[string]string{
|
||||
"solid": "BackgroundFillSolid",
|
||||
"gradient": "BackgroundFillGradient",
|
||||
"freeform_gradient": "BackgroundFillFreeformGradient",
|
||||
},
|
||||
},
|
||||
"ChatBoostSource": {
|
||||
Field: "source",
|
||||
Variants: map[string]string{
|
||||
"premium": "ChatBoostSourcePremium",
|
||||
"gift_code": "ChatBoostSourceGiftCode",
|
||||
"giveaway": "ChatBoostSourceGiveaway",
|
||||
},
|
||||
},
|
||||
"RevenueWithdrawalState": {
|
||||
Field: "type",
|
||||
Variants: map[string]string{
|
||||
"pending": "RevenueWithdrawalStatePending",
|
||||
"succeeded": "RevenueWithdrawalStateSucceeded",
|
||||
"failed": "RevenueWithdrawalStateFailed",
|
||||
},
|
||||
},
|
||||
"TransactionPartner": {
|
||||
Field: "type",
|
||||
Variants: map[string]string{
|
||||
"fragment": "TransactionPartnerFragment",
|
||||
"user": "TransactionPartnerUser",
|
||||
"telegram_ads": "TransactionPartnerTelegramAds",
|
||||
"telegram_api": "TransactionPartnerTelegramApi",
|
||||
"other": "TransactionPartnerOther",
|
||||
},
|
||||
},
|
||||
"MenuButton": {
|
||||
Field: "type",
|
||||
Variants: map[string]string{
|
||||
"commands": "MenuButtonCommands",
|
||||
"web_app": "MenuButtonWebApp",
|
||||
"default": "MenuButtonDefault",
|
||||
},
|
||||
},
|
||||
"OwnedGift": {
|
||||
Field: "type",
|
||||
Variants: map[string]string{
|
||||
"regular": "OwnedGiftRegular",
|
||||
"unique": "OwnedGiftUnique",
|
||||
},
|
||||
},
|
||||
"StoryAreaType": {
|
||||
Field: "type",
|
||||
Variants: map[string]string{
|
||||
"location": "StoryAreaTypeLocation",
|
||||
"suggested_reaction": "StoryAreaTypeSuggestedReaction",
|
||||
"link": "StoryAreaTypeLink",
|
||||
"weather": "StoryAreaTypeWeather",
|
||||
"unique_gift": "StoryAreaTypeUniqueGift",
|
||||
},
|
||||
},
|
||||
// MaybeInaccessibleMessage uses an integer discriminator (date field).
|
||||
// Variants is nil — the standard template block is skipped; a
|
||||
// hand-coded UnmarshalMaybeInaccessibleMessage is emitted instead.
|
||||
"MaybeInaccessibleMessage": {
|
||||
Field: "",
|
||||
Variants: nil,
|
||||
},
|
||||
}
|
||||
|
||||
// emitter renders Go source from a spec.API IR.
|
||||
type emitter struct {
|
||||
api *spec.API
|
||||
outDir string
|
||||
}
|
||||
|
||||
func newEmitter(api *spec.API, outDir string) *emitter {
|
||||
return &emitter{api: api, outDir: outDir}
|
||||
}
|
||||
|
||||
// emitTypes renders types.gen.go.
|
||||
func (e *emitter) emitTypes() error {
|
||||
t, err := template.New("types").Funcs(funcs()).Parse(typesTmpl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse types.tmpl: %w", err)
|
||||
}
|
||||
filtered := *e.api
|
||||
filtered.Types = nil
|
||||
for _, typ := range e.api.Types {
|
||||
if !runtimeTypes[typ.Name] {
|
||||
filtered.Types = append(filtered.Types, typ)
|
||||
}
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if execErr := t.Execute(&buf, &filtered); execErr != nil {
|
||||
return fmt.Errorf("execute types.tmpl: %w", execErr)
|
||||
}
|
||||
src, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
// Surface the unformatted output so debugging is possible.
|
||||
return fmt.Errorf("gofmt types.gen.go: %w\n--- unformatted ---\n%s", err, buf.String())
|
||||
}
|
||||
return os.WriteFile(filepath.Join(e.outDir, "types.gen.go"), src, 0o600)
|
||||
}
|
||||
|
||||
// loadAPI reads and decodes the IR JSON.
|
||||
func loadAPI(path string) (*spec.API, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var api spec.API
|
||||
if err := json.Unmarshal(data, &api); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &api, nil
|
||||
}
|
||||
|
||||
// funcs is the FuncMap shared across templates.
|
||||
func funcs() template.FuncMap {
|
||||
return template.FuncMap{
|
||||
"goType": goType,
|
||||
"goField": goField,
|
||||
"docComment": docComment,
|
||||
"isOptional": func(f spec.Field) bool { return !f.Required },
|
||||
"not": func(b bool) bool { return !b },
|
||||
"title": title,
|
||||
"isFileField": isFileField,
|
||||
"fileCheck": fileCheck,
|
||||
"multipartFieldEntry": multipartFieldEntry,
|
||||
"multipartFileEntry": multipartFileEntry,
|
||||
"returnGoType": returnGoType,
|
||||
// discriminator helpers for types.tmpl
|
||||
"hasDiscriminator": func(name string) bool { s, ok := knownDiscriminators[name]; return ok && len(s.Variants) > 0 },
|
||||
"isSealedUnionReturn": func(tr spec.TypeRef) bool {
|
||||
if tr.Kind != spec.KindNamed {
|
||||
return false
|
||||
}
|
||||
s, ok := knownDiscriminators[tr.Name]
|
||||
return ok && len(s.Variants) > 0
|
||||
},
|
||||
"isMaybeInaccessibleMessage": func(name string) bool { return name == "MaybeInaccessibleMessage" },
|
||||
"discriminatorField": func(name string) string { return knownDiscriminators[name].Field },
|
||||
"discriminatorMap": func(name string) map[string]string { return knownDiscriminators[name].Variants },
|
||||
// union-field helpers for per-struct UnmarshalJSON emission
|
||||
"unionFields": unionFieldsOf,
|
||||
"isArrayUnion": func(tr spec.TypeRef) bool { return hasUnionElem(tr) },
|
||||
"unionTypeName": func(tr spec.TypeRef) string { name, _ := unionTypeFor(tr); return name },
|
||||
}
|
||||
}
|
||||
|
||||
// title upper-cases the first byte of s (ASCII only — all Telegram method names are ASCII).
|
||||
func title(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
r := s[0]
|
||||
if r >= 'a' && r <= 'z' {
|
||||
r = r - 'a' + 'A'
|
||||
}
|
||||
return string(r) + s[1:]
|
||||
}
|
||||
|
||||
// isFileField reports whether the field carries an InputFile.
|
||||
func isFileField(f spec.Field) bool {
|
||||
return mentionsInputFileTr(f.Type)
|
||||
}
|
||||
|
||||
func mentionsInputFileTr(tr spec.TypeRef) bool {
|
||||
switch tr.Kind {
|
||||
case spec.KindNamed:
|
||||
return tr.Name == "InputFile"
|
||||
case spec.KindArray:
|
||||
if tr.ElemType != nil {
|
||||
return mentionsInputFileTr(*tr.ElemType)
|
||||
}
|
||||
case spec.KindOneOf:
|
||||
for _, v := range tr.Variants {
|
||||
if v == "InputFile" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// fileCheck returns the HasFile guard line for a file-carrying field.
|
||||
// Both named InputFile and InputFile-or-String oneOf fields are now *InputFile,
|
||||
// so no type assertion is needed in either case.
|
||||
func fileCheck(f spec.Field) string {
|
||||
return fmt.Sprintf("\tif p.%s != nil && p.%s.IsLocalUpload() { return true }\n", f.Name, f.Name)
|
||||
}
|
||||
|
||||
// multipartFileEntry returns the MultipartFiles append block for a file field.
|
||||
// Both named InputFile and InputFile-or-String oneOf fields are now *InputFile,
|
||||
// so the same code works for both cases.
|
||||
func multipartFileEntry(f spec.Field) string {
|
||||
jsonName := f.JSONName
|
||||
return fmt.Sprintf(
|
||||
"\tif p.%s != nil && p.%s.IsLocalUpload() {\n\t\tname := p.%s.Filename\n\t\tif name == \"\" { name = %q }\n\t\tfiles = append(files, client.MultipartFile{FieldName: %q, Filename: name, Reader: p.%s.Reader})\n\t}\n",
|
||||
f.Name, f.Name, f.Name, jsonName, jsonName, f.Name)
|
||||
}
|
||||
|
||||
// multipartFieldEntry generates the line that adds f to the multipart map.
|
||||
// Required scalar fields go in unconditionally; optional ones go in only
|
||||
// when non-zero/non-empty.
|
||||
func multipartFieldEntry(f spec.Field) string {
|
||||
switch f.Type.Kind {
|
||||
case spec.KindPrimitive:
|
||||
switch f.Type.Name {
|
||||
case "int64":
|
||||
if f.Required {
|
||||
return fmt.Sprintf("\tout[%q] = strconv.FormatInt(p.%s, 10)\n", f.JSONName, f.Name)
|
||||
}
|
||||
return fmt.Sprintf("\tif p.%s != nil { out[%q] = strconv.FormatInt(*p.%s, 10) }\n", f.Name, f.JSONName, f.Name)
|
||||
case "string":
|
||||
if f.Required {
|
||||
return fmt.Sprintf("\tout[%q] = p.%s\n", f.JSONName, f.Name)
|
||||
}
|
||||
return fmt.Sprintf("\tif p.%s != \"\" { out[%q] = p.%s }\n", f.Name, f.JSONName, f.Name)
|
||||
case "bool":
|
||||
if f.Required {
|
||||
return fmt.Sprintf("\tout[%q] = strconv.FormatBool(p.%s)\n", f.JSONName, f.Name)
|
||||
}
|
||||
return fmt.Sprintf("\tif p.%s != nil { out[%q] = strconv.FormatBool(*p.%s) }\n", f.Name, f.JSONName, f.Name)
|
||||
case "float64":
|
||||
if f.Required {
|
||||
return fmt.Sprintf("\tout[%q] = strconv.FormatFloat(p.%s, 'f', -1, 64)\n", f.JSONName, f.Name)
|
||||
}
|
||||
return fmt.Sprintf("\tif p.%s != nil { out[%q] = strconv.FormatFloat(*p.%s, 'f', -1, 64) }\n", f.Name, f.JSONName, f.Name)
|
||||
}
|
||||
case spec.KindOneOf:
|
||||
// Integer-or-String → ChatID: use .String() wire form.
|
||||
if matchesVariants(f.Type.Variants, "int64", "string") {
|
||||
if f.Required {
|
||||
return fmt.Sprintf("\tout[%q] = p.%s.String()\n", f.JSONName, f.Name)
|
||||
}
|
||||
return fmt.Sprintf("\tif !p.%s.IsZero() { out[%q] = p.%s.String() }\n", f.Name, f.JSONName, f.Name)
|
||||
}
|
||||
// InputFile-or-String → *InputFile: non-upload branch sends PathOrID.
|
||||
if matchesVariants(f.Type.Variants, "InputFile", "string") {
|
||||
return fmt.Sprintf("\tif p.%s != nil && !p.%s.IsLocalUpload() && p.%s.PathOrID != \"\" { out[%q] = p.%s.PathOrID }\n",
|
||||
f.Name, f.Name, f.Name, f.JSONName, f.Name)
|
||||
}
|
||||
// Sealed-interface unions — JSON-marshal.
|
||||
if f.Required {
|
||||
return fmt.Sprintf("\tif b, _ := json.Marshal(p.%s); len(b) > 0 && string(b) != \"null\" { out[%q] = string(b) }\n", f.Name, f.JSONName)
|
||||
}
|
||||
return fmt.Sprintf("\tif p.%s != nil { if b, _ := json.Marshal(p.%s); len(b) > 0 && string(b) != \"null\" { out[%q] = string(b) } }\n", f.Name, f.Name, f.JSONName)
|
||||
}
|
||||
// Named or array: fall back to JSON-marshal to JSON string.
|
||||
if f.Required {
|
||||
return fmt.Sprintf("\tif b, _ := json.Marshal(p.%s); len(b) > 0 { out[%q] = string(b) }\n", f.Name, f.JSONName)
|
||||
}
|
||||
return fmt.Sprintf("\tif p.%s != nil { if b, _ := json.Marshal(p.%s); len(b) > 0 { out[%q] = string(b) } }\n", f.Name, f.Name, f.JSONName)
|
||||
}
|
||||
|
||||
func returnGoType(tr spec.TypeRef) string {
|
||||
switch tr.Kind {
|
||||
case spec.KindPrimitive:
|
||||
return tr.Name
|
||||
case spec.KindNamed:
|
||||
// Sealed-interface unions are returned by interface value, not pointer
|
||||
// (you can't take a pointer to an interface in any useful way; the
|
||||
// generated UnmarshalXxx returns the interface directly).
|
||||
if _, ok := knownDiscriminators[tr.Name]; ok {
|
||||
return tr.Name
|
||||
}
|
||||
// MessageOrBool is a hand-coded runtime wrapper — pointer return.
|
||||
return "*" + tr.Name
|
||||
case spec.KindArray:
|
||||
if tr.ElemType == nil {
|
||||
return "[]any"
|
||||
}
|
||||
return "[]" + returnGoElem(*tr.ElemType)
|
||||
case spec.KindOneOf:
|
||||
// Integer-or-String return (rare but possible).
|
||||
if matchesVariants(tr.Variants, "int64", "string") {
|
||||
return "ChatID"
|
||||
}
|
||||
return "any"
|
||||
}
|
||||
return "any"
|
||||
}
|
||||
|
||||
func returnGoElem(tr spec.TypeRef) string {
|
||||
switch tr.Kind {
|
||||
case spec.KindPrimitive:
|
||||
return tr.Name
|
||||
case spec.KindNamed:
|
||||
return tr.Name
|
||||
case spec.KindArray:
|
||||
if tr.ElemType == nil {
|
||||
return "any"
|
||||
}
|
||||
return "[]" + returnGoElem(*tr.ElemType)
|
||||
}
|
||||
return "any"
|
||||
}
|
||||
|
||||
// emitMethods renders methods.gen.go.
|
||||
func (e *emitter) emitMethods() error {
|
||||
t, err := template.New("methods").Funcs(funcs()).Parse(methodsTmpl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse methods.tmpl: %w", err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if execErr := t.Execute(&buf, e.api); execErr != nil {
|
||||
return fmt.Errorf("execute methods.tmpl: %w", execErr)
|
||||
}
|
||||
src, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("gofmt methods.gen.go: %w\n--- unformatted ---\n%s", err, buf.String())
|
||||
}
|
||||
return os.WriteFile(filepath.Join(e.outDir, "methods.gen.go"), src, 0o600)
|
||||
}
|
||||
|
||||
// emitEnums renders enums.gen.go.
|
||||
func (e *emitter) emitEnums() error {
|
||||
t, err := template.New("enums").Funcs(funcs()).Parse(enumsTmpl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse enums.tmpl: %w", err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if execErr := t.Execute(&buf, e.api); execErr != nil {
|
||||
return fmt.Errorf("execute enums.tmpl: %w", execErr)
|
||||
}
|
||||
src, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("gofmt enums.gen.go: %w\n--- unformatted ---\n%s", err, buf.String())
|
||||
}
|
||||
return os.WriteFile(filepath.Join(e.outDir, "enums.gen.go"), src, 0o600)
|
||||
}
|
||||
|
||||
// goType returns the Go type expression for a TypeRef.
|
||||
// Optional fields use pointer types for primitives and named types,
|
||||
// or rely on omitempty for slices and maps. parameter `optional` controls
|
||||
// whether to wrap pointer-style.
|
||||
func goType(tr spec.TypeRef, optional bool) string {
|
||||
switch tr.Kind {
|
||||
case spec.KindPrimitive:
|
||||
if optional && (tr.Name == "bool" || tr.Name == "int64" || tr.Name == "float64") {
|
||||
return "*" + tr.Name
|
||||
}
|
||||
return tr.Name
|
||||
case spec.KindNamed:
|
||||
// Named types are always pointer-optional when optional, except:
|
||||
// 1. Union (interface) types — they are naturally nil-able; pointer-to-interface is invalid.
|
||||
// 2. InputFile is always pointer-typed even when required: the
|
||||
// multipart helpers (fileCheck, multipartFileEntry) call
|
||||
// f.IsLocalUpload() and dereference Reader, both of which
|
||||
// expect a pointer receiver.
|
||||
if _, isUnion := knownDiscriminators[tr.Name]; isUnion {
|
||||
// Interface type — never add *.
|
||||
return tr.Name
|
||||
}
|
||||
if optional || tr.Name == "InputFile" {
|
||||
return "*" + tr.Name
|
||||
}
|
||||
return tr.Name
|
||||
case spec.KindArray:
|
||||
if tr.ElemType == nil {
|
||||
return "[]any"
|
||||
}
|
||||
// Inside slices, the element shape is its own thing — never wrap
|
||||
// the element in a pointer just because the field is optional.
|
||||
return "[]" + goType(*tr.ElemType, false)
|
||||
case spec.KindOneOf:
|
||||
// Integer-or-String: typed ChatID wrapper.
|
||||
if matchesVariants(tr.Variants, "int64", "string") {
|
||||
if optional {
|
||||
return "*ChatID"
|
||||
}
|
||||
return "ChatID"
|
||||
}
|
||||
// InputFile-or-String: *InputFile runtime helper handles both.
|
||||
if matchesVariants(tr.Variants, "InputFile", "string") {
|
||||
return "*InputFile"
|
||||
}
|
||||
// All-named variants sealed interface: fall back to interface.
|
||||
return "any"
|
||||
}
|
||||
return "any"
|
||||
}
|
||||
|
||||
// unionField pairs a struct field with the name of its union type.
|
||||
type unionField struct {
|
||||
Field spec.Field
|
||||
UnionName string // e.g. "ChatMember"
|
||||
}
|
||||
|
||||
// unionFieldsOf returns the subset of t.Fields whose type is a known
|
||||
// discriminated union (directly or as array element).
|
||||
func unionFieldsOf(t spec.TypeDecl) []unionField {
|
||||
var out []unionField
|
||||
for _, f := range t.Fields {
|
||||
if u, ok := unionTypeFor(f.Type); ok {
|
||||
out = append(out, unionField{Field: f, UnionName: u})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// unionTypeFor inspects a TypeRef and reports whether it (or its array
|
||||
// element) is a known discriminated union. Returns the union name and true.
|
||||
func unionTypeFor(tr spec.TypeRef) (string, bool) {
|
||||
switch tr.Kind {
|
||||
case spec.KindNamed:
|
||||
if _, ok := knownDiscriminators[tr.Name]; ok {
|
||||
return tr.Name, true
|
||||
}
|
||||
case spec.KindArray:
|
||||
if tr.ElemType != nil {
|
||||
return unionTypeFor(*tr.ElemType)
|
||||
}
|
||||
case spec.KindOneOf:
|
||||
if u := unionNameByVariants(tr.Variants); u != "" {
|
||||
return u, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// unionNameByVariants finds the parent union whose variant type names exactly
|
||||
// match the given variant set (order-insensitive).
|
||||
func unionNameByVariants(variants []string) string {
|
||||
for parentName, ds := range knownDiscriminators {
|
||||
wanted := make([]string, 0, len(ds.Variants))
|
||||
for _, vt := range ds.Variants {
|
||||
wanted = append(wanted, vt)
|
||||
}
|
||||
if matchesVariants(variants, wanted...) {
|
||||
return parentName
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// hasUnionElem reports whether tr is an array whose element type is a known union.
|
||||
func hasUnionElem(tr spec.TypeRef) bool {
|
||||
if tr.Kind != spec.KindArray || tr.ElemType == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := unionTypeFor(*tr.ElemType)
|
||||
return ok
|
||||
}
|
||||
|
||||
// matchesVariants reports whether got equals want as a set (order-insensitive).
|
||||
func matchesVariants(got []string, want ...string) bool {
|
||||
if len(got) != len(want) {
|
||||
return false
|
||||
}
|
||||
seen := make(map[string]int, len(got))
|
||||
for _, g := range got {
|
||||
seen[g]++
|
||||
}
|
||||
for _, w := range want {
|
||||
seen[w]--
|
||||
}
|
||||
for _, v := range seen {
|
||||
if v != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// goField returns the Go struct-field declaration for a Field.
|
||||
func goField(f spec.Field) string {
|
||||
tag := fmt.Sprintf("`json:%q`", f.JSONName+omitempty(f))
|
||||
return fmt.Sprintf("%s %s %s", f.Name, goType(f.Type, !f.Required), tag)
|
||||
}
|
||||
|
||||
func omitempty(f spec.Field) string {
|
||||
if f.Required {
|
||||
return ""
|
||||
}
|
||||
return ",omitempty"
|
||||
}
|
||||
|
||||
// docComment converts a doc string into a Go-style block comment with
|
||||
// a leading "// " on each line.
|
||||
func docComment(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
for _, line := range splitLines(s) {
|
||||
buf.WriteString("// ")
|
||||
buf.WriteString(line)
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func splitLines(s string) []string {
|
||||
var out []string
|
||||
start := 0
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == '\n' {
|
||||
out = append(out, s[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if start < len(s) {
|
||||
out = append(out, s[start:])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// hasVariants reports whether the variant list contains all of the named strings (order-insensitive).
|
||||
func hasVariants(variants []string, names ...string) bool {
|
||||
return matchesVariants(variants, names...)
|
||||
}
|
||||
|
||||
// buildUnionTypeSet returns the set of all type names that generate interface types
|
||||
// (i.e., types with one_of). This includes knownDiscriminators and marker-interface
|
||||
// unions not covered by the discriminator map.
|
||||
func buildUnionTypeSet(api *spec.API) map[string]bool {
|
||||
s := make(map[string]bool, len(knownDiscriminators)+16)
|
||||
for name := range knownDiscriminators {
|
||||
s[name] = true
|
||||
}
|
||||
for _, t := range api.Types {
|
||||
if len(t.OneOf) > 0 {
|
||||
s[t.Name] = true
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// makeSentinelValue returns a sentinelValue func that uses the given union type set.
|
||||
// It returns a minimal valid Go expression for a spec.Field's type,
|
||||
// used in generated test param literals.
|
||||
func makeSentinelValue(unionTypes map[string]bool) func(spec.Field) string {
|
||||
return func(f spec.Field) string {
|
||||
return sentinelForField(f, unionTypes)
|
||||
}
|
||||
}
|
||||
|
||||
func sentinelForField(f spec.Field, unionTypes map[string]bool) string {
|
||||
tr := f.Type
|
||||
switch tr.Kind {
|
||||
case spec.KindPrimitive:
|
||||
switch tr.Name {
|
||||
case "int64":
|
||||
return "42"
|
||||
case "string":
|
||||
return `"test_value"`
|
||||
case "bool":
|
||||
return "true"
|
||||
case "float64":
|
||||
return "1.0"
|
||||
}
|
||||
case spec.KindNamed:
|
||||
switch tr.Name {
|
||||
case "ChatID":
|
||||
return "ChatIDFromInt(123)"
|
||||
case "InputFile":
|
||||
return `&InputFile{PathOrID: "file_id_test"}`
|
||||
}
|
||||
// Interface (union) types are nil-able.
|
||||
if unionTypes[tr.Name] {
|
||||
return "nil"
|
||||
}
|
||||
// Required named struct types are value types in the generated struct.
|
||||
if f.Required {
|
||||
return tr.Name + "{}"
|
||||
}
|
||||
return "&" + tr.Name + "{}"
|
||||
case spec.KindArray:
|
||||
return "nil"
|
||||
case spec.KindOneOf:
|
||||
if hasVariants(tr.Variants, "int64", "string") {
|
||||
return "ChatIDFromInt(123)"
|
||||
}
|
||||
if hasVariants(tr.Variants, "InputFile", "string") {
|
||||
return `&InputFile{PathOrID: "file_id_test"}`
|
||||
}
|
||||
// Sealed named-union interface: use nil (any).
|
||||
return "nil"
|
||||
}
|
||||
return "nil"
|
||||
}
|
||||
|
||||
// successResp returns a backtick Go string literal containing a minimal
|
||||
// {"ok":true,"result":...} JSON body for the method's return type.
|
||||
func successResp(m spec.MethodDecl) string {
|
||||
body := successBody(m.Returns)
|
||||
return "`{\"ok\":true,\"result\":" + body + "}`"
|
||||
}
|
||||
|
||||
func successBody(tr spec.TypeRef) string {
|
||||
switch tr.Kind {
|
||||
case spec.KindPrimitive:
|
||||
switch tr.Name {
|
||||
case "bool":
|
||||
return "true"
|
||||
case "int64", "float64":
|
||||
return "0"
|
||||
case "string":
|
||||
return `""`
|
||||
}
|
||||
case spec.KindNamed:
|
||||
if tr.Name == "MessageOrBool" {
|
||||
return "true"
|
||||
}
|
||||
// Sealed-interface unions need a discriminator field so UnmarshalXxx can dispatch.
|
||||
// Pick the lexicographically first variant value for determinism (map
|
||||
// iteration order in Go is randomized — using `range` directly produces
|
||||
// non-deterministic regen output).
|
||||
if disc, ok := knownDiscriminators[tr.Name]; ok && disc.Field != "" {
|
||||
values := make([]string, 0, len(disc.Variants))
|
||||
for v := range disc.Variants {
|
||||
values = append(values, v)
|
||||
}
|
||||
sort.Strings(values)
|
||||
if len(values) > 0 {
|
||||
return fmt.Sprintf(`{"%s":"%s"}`, disc.Field, values[0])
|
||||
}
|
||||
}
|
||||
// MaybeInaccessibleMessage uses date==0 → InaccessibleMessage variant.
|
||||
if tr.Name == "MaybeInaccessibleMessage" {
|
||||
return `{"date":0,"chat":{"id":1,"type":"private"},"message_id":1}`
|
||||
}
|
||||
return "{}"
|
||||
case spec.KindArray:
|
||||
return "[]"
|
||||
case spec.KindOneOf:
|
||||
return "null"
|
||||
}
|
||||
return "null"
|
||||
}
|
||||
|
||||
// emitTests renders methods_gen_test.go.
|
||||
func (e *emitter) emitTests() error {
|
||||
unionTypes := buildUnionTypeSet(e.api)
|
||||
|
||||
// Add test-specific helpers to the shared func map.
|
||||
fm := funcs()
|
||||
fm["sentinelValue"] = makeSentinelValue(unionTypes)
|
||||
fm["successResp"] = successResp
|
||||
|
||||
t, err := template.New("tests").Funcs(fm).Parse(testsTmpl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse tests.tmpl: %w", err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if execErr := t.Execute(&buf, e.api); execErr != nil {
|
||||
return fmt.Errorf("execute tests.tmpl: %w", execErr)
|
||||
}
|
||||
src, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("gofmt methods_gen_test.go: %w\n--- unformatted ---\n%s", err, buf.String())
|
||||
}
|
||||
return os.WriteFile(filepath.Join(e.outDir, "methods_gen_test.go"), src, 0o600)
|
||||
}
|
||||
Reference in New Issue
Block a user