Initial release of go-telegram

A fully-generated, strongly-typed Go client for the Telegram Bot API.

* 176 methods + 301 types generated from Bot API v10.0
* 1408 auto-generated tests (8 scenarios per method)
* Typed unions throughout — no 'any' in the public surface
* Pluggable HTTP transport and JSON codec (default goccy/go-json)
* Built-in retry middleware honouring Telegram's retry_after
* Generic dispatcher with filters and conversation handlers
* Self-verifying codegen pipeline (regen → audit → emit → run tests)
* 14 example bots covering common patterns
This commit is contained in:
2026-05-09 13:09:27 +01:00
commit ac7cae8fa7
164 changed files with 100239 additions and 0 deletions
+274
View File
@@ -0,0 +1,274 @@
name: ci
on:
push:
branches: [main]
pull_request:
workflow_dispatch:
inputs:
dry-run-release:
description: "Compute release version, do not tag or release"
type: boolean
default: false
permissions:
contents: write
pull-requests: read
packages: write
# Cancel any in-flight CI runs for the same branch when a new push lands.
concurrency:
group: ci-${{ github.ref }}
cancel-in-progress: true
jobs:
vet:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.25.x'
check-latest: true
- uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
- run: go vet ./...
staticcheck:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.25.x'
check-latest: true
- uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
- run: go install honnef.co/go/tools/cmd/staticcheck@v0.7.0
- run: staticcheck ./...
govulncheck:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.25.x'
check-latest: true
- uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
- run: go install golang.org/x/vuln/cmd/govulncheck@latest
- run: govulncheck ./...
gosec:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.25.x'
check-latest: true
- uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
- run: go install github.com/securego/gosec/v2/cmd/gosec@latest
# G404: math/rand/v2 jitter in transport/backoff.go — intentional (not crypto)
# G304: os.ReadFile from CLI flag variable — intentional (tool, not server)
# G306: 0o644 on generated doc artifacts in cmd/scrape — intentional
# G204: git subprocess in cmd/audit uses CLI flag path — intentional (operator tool)
# G706: log.Printf with values from Telegram/env in examples — illustrative,
# library users are expected to sanitise before logging in production
- run: gosec -quiet -exclude=G404,G304,G306,G204,G706 -exclude-dir=testdata ./...
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.25.x'
check-latest: true
- uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
- run: go test -race -coverprofile=coverage.out ./...
- name: Build all examples
run: go build ./examples/...
- uses: actions/upload-artifact@v4
with:
name: coverage
path: coverage.out
codegen-clean:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.25.x'
check-latest: true
- uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
- name: Regenerate against pinned snapshot
run: make regen-from-fixture
- name: Assert clean diff
run: git diff --exit-code internal/spec/api.json api/
audit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0 # need history for drift comparison
- uses: actions/setup-go@v5
with:
go-version: '1.25.x'
check-latest: true
- uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
- name: Audit fallbacks
run: make audit
- name: Audit drift vs base
# On PRs: compare against the merge base (origin/<base>).
# On push to main: compare against the parent commit.
# Drift is informational; doesn't fail CI.
run: |
if [ "${{ github.event_name }}" = "pull_request" ]; then
BASE="origin/${{ github.base_ref }}"
else
BASE="HEAD~1"
fi
echo "Drift base: $BASE"
go run ./cmd/audit -ir internal/spec/api.json -drift -against "$BASE" || true
# Aggregate gate — depends on every check above. Used as a single
# required status check in branch protection AND as a dependency for the
# release job below.
ci-success:
runs-on: ubuntu-latest
needs: [vet, staticcheck, govulncheck, gosec, test, codegen-clean, audit]
if: always()
steps:
- name: All checks passed
if: ${{ !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') }}
run: echo "ci-success"
- name: At least one check failed
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') }}
run: |
echo "Failed/cancelled jobs:"
echo '${{ toJSON(needs) }}'
exit 1
# Auto-release fires on every clean push to main (and on manual
# workflow_dispatch for testing). Computes next SemVer from commit
# history via lukaszraczylo/semver-generator, dual-tags
# (v<X.Y.Z> + bot-api-v<A.B>), runs GoReleaser.
release:
needs: ci-success
if: |
(github.event_name == 'push' && github.ref == 'refs/heads/main') ||
github.event_name == 'workflow_dispatch'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-go@v5
with:
go-version: '1.25.x'
check-latest: true
# Action interface from
# https://github.com/lukaszraczylo/semver-generator/blob/main/action.yml
# Inputs: repository_local: true (use already-cloned repo)
# existing: true (respect existing tags as base)
# Output: semantic_version (bare version string, no "v" prefix)
- name: Compute next SemVer
id: semver
uses: lukaszraczylo/semver-generator@v1
with:
repository_local: true
existing: true
config_file: .semver.yaml
- name: Read Bot API version
id: api_version
run: |
VERSION=$(python3 -c 'import json; print(json.load(open("internal/spec/api.json"))["version"])')
if [ -z "$VERSION" ] || [ "$VERSION" = "null" ]; then
echo "tag=" >> "$GITHUB_OUTPUT"
else
echo "tag=bot-api-v${VERSION}" >> "$GITHUB_OUTPUT"
fi
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
- name: Dry-run summary
if: github.event_name == 'workflow_dispatch' && inputs.dry-run-release == true
run: |
echo "Would release: v${{ steps.semver.outputs.semantic_version }}"
if [ -n "${{ steps.api_version.outputs.tag }}" ]; then
echo "Would also tag: ${{ steps.api_version.outputs.tag }} (Bot API ${{ steps.api_version.outputs.version }})"
fi
echo "Skipping tag + release (dry-run)."
- name: Tag library + bot-api versions
if: github.event_name != 'workflow_dispatch' || inputs.dry-run-release == false
env:
LIB_TAG: v${{ steps.semver.outputs.semantic_version }}
API_TAG: ${{ steps.api_version.outputs.tag }}
API_VER: ${{ steps.api_version.outputs.version }}
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git tag -a "$LIB_TAG" -m "Release $LIB_TAG"
git push origin "$LIB_TAG"
if [ -n "$API_TAG" ]; then
# Force-update the bot-api tag so it always points at the latest
# release that supports that API version.
if git rev-parse "$API_TAG" >/dev/null 2>&1; then
git tag -f -a "$API_TAG" -m "go-telegram release $LIB_TAG (Bot API $API_VER)"
git push -f origin "$API_TAG"
else
git tag -a "$API_TAG" -m "go-telegram release $LIB_TAG (Bot API $API_VER)"
git push origin "$API_TAG"
fi
fi
- name: Run GoReleaser
if: github.event_name != 'workflow_dispatch' || inputs.dry-run-release == false
uses: goreleaser/goreleaser-action@v6
with:
distribution: goreleaser
version: '~> v2'
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
BOT_API_VERSION: ${{ steps.api_version.outputs.version }}
+118
View File
@@ -0,0 +1,118 @@
name: regen
on:
schedule:
- cron: "0 6 * * 1" # Monday 06:00 UTC
workflow_dispatch: {}
permissions:
contents: write
pull-requests: write
jobs:
regen:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0 # full history so audit -drift can compare against main
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.25.x'
check-latest: true
- name: Capture latest snapshot
id: snapshot
run: |
DATE=$(date +%Y-%m-%d)
DEST="testdata/html/snapshot_${DATE}.html"
curl -fsSL --user-agent "go-telegram codegen scraper" \
https://core.telegram.org/bots/api > "$DEST"
ln -sf "snapshot_${DATE}.html" testdata/html/latest.html
echo "date=$DATE" >> $GITHUB_OUTPUT
echo "dest=$DEST" >> $GITHUB_OUTPUT
- name: Regenerate (scrape + emit, with clean-generated)
# `make regen` depends on `clean-generated`, which sweeps any orphan
# api/*.gen.go files left behind by removed/renamed methods.
run: make regen
- name: Audit fallbacks
id: audit
run: |
set +e
OUT=$(make audit 2>&1)
STATUS=$?
set -e
echo "$OUT"
{
echo 'output<<EOF'
echo "$OUT"
echo 'EOF'
} >> $GITHUB_OUTPUT
echo "status=$STATUS" >> $GITHUB_OUTPUT
# Don't fail the workflow on fallbacks — surface them in the PR body so
# the reviewer can extend overrides.json or fix scraper patterns.
- name: Audit drift vs main
id: drift
run: |
set +e
DRIFT=$(go run ./cmd/audit -ir internal/spec/api.json -drift -against origin/main 2>&1)
set -e
echo "$DRIFT"
{
echo 'output<<EOF'
echo "$DRIFT"
echo 'EOF'
} >> $GITHUB_OUTPUT
- name: Run tests
run: go test -race ./...
- name: Detect changes
id: diff
run: |
git status --porcelain
if git diff --quiet internal/spec/api.json api/ testdata/html/; then
echo "no_changes=true" >> $GITHUB_OUTPUT
fi
- name: Read API version
if: steps.diff.outputs.no_changes != 'true'
id: meta
run: |
VERSION=$(python3 -c 'import json; print(json.load(open("internal/spec/api.json")).get("version", "unknown"))')
echo "version=$VERSION" >> $GITHUB_OUTPUT
- name: Open PR
if: steps.diff.outputs.no_changes != 'true'
uses: peter-evans/create-pull-request@v7
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: |
chore(api): regenerate from Telegram Bot API v${{ steps.meta.outputs.version }}
branch: regen/api-v${{ steps.meta.outputs.version }}
title: "chore(api): regenerate from Telegram Bot API v${{ steps.meta.outputs.version }}"
labels: automated, api-update
body: |
Automated regeneration from `https://core.telegram.org/bots/api`.
**API version:** v${{ steps.meta.outputs.version }}
**Snapshot date:** ${{ steps.snapshot.outputs.date }}
## Audit (fallbacks)
```
${{ steps.audit.outputs.output }}
```
## Drift vs `main`
```
${{ steps.drift.outputs.output }}
```
Inspect the IR diff (`internal/spec/api.json`) for added/changed/removed methods.
CI must pass before merge. Auto-merge: enable when satisfied with the diff.
+45
View File
@@ -0,0 +1,45 @@
# Binaries
/bin/
*.exe
*.dll
*.so
*.dylib
# Test artifacts
*.test
*.out
coverage.out
coverage.html
# IDE
.idea/
.vscode/
*.swp
*~
# OS
.DS_Store
# Local env
.env
.env.local
# Stray binaries at repo root from `go build ./cmd/...` or `go run` artefacts.
# Listed explicitly (not via /* glob) so source dirs are never accidentally ignored.
/echo
/webhook
/genapi
/scrape
/audit
/callback
/conversation
/files
/inline
/middleware
/stateful
/admin
/moderation
/pagination
/payments
/polls
/welcome
+72
View File
@@ -0,0 +1,72 @@
version: "2"
linters:
default: none
enable:
- govet
- staticcheck
- unused
- errcheck
- gosec
settings:
unused:
field-writes-are-uses: true
post-statements-are-reads: true
exported-is-used: true
exported-fields-are-used: true
govet:
enable-all: true
disable:
- fieldalignment # Field order is intentional per project spec; alignment not enforced.
staticcheck:
checks: ["all"]
exclusions:
presets:
- common-false-positives
rules:
- path: _test\.go
linters:
- unused
- errcheck
- path: \.pb\.go$
linters:
- all
- path: transport/backoff\.go
linters:
- gosec
text: "G404" # math/rand/v2 acceptable for jitter
- path: cmd/scrape/
linters:
- gosec
text: "G306" # 0o644 intentional: output files are world-readable docs artifacts
- path: cmd/audit/
linters:
- gosec
text: "G204" # git subprocess with CLI flag path — intentional operator tool
- path: cmd/audit/.*_test\.go
linters:
- gosec
text: "G302" # test helper sets 0o755 on fake git binary — executable permission intentional
formatters:
enable:
- gofmt
settings:
gofmt:
simplify: true
run:
timeout: 5m
tests: true
modules-download-mode: readonly
output:
formats:
text:
path: stdout
colors: true
sort-results: true
+121
View File
@@ -0,0 +1,121 @@
# Documentation: https://goreleaser.com
version: 2
project_name: go-telegram
before:
hooks:
- go mod tidy
builds:
- id: scrape
main: ./cmd/scrape
binary: tg-scrape
env:
- CGO_ENABLED=0
goos:
- linux
- darwin
- windows
goarch:
- amd64
- arm64
ldflags:
- -s -w -X main.version={{.Version}} -X main.commit={{.Commit}}
- id: genapi
main: ./cmd/genapi
binary: tg-genapi
env:
- CGO_ENABLED=0
goos:
- linux
- darwin
- windows
goarch:
- amd64
- arm64
ldflags:
- -s -w -X main.version={{.Version}} -X main.commit={{.Commit}}
- id: audit
main: ./cmd/audit
binary: tg-audit
env:
- CGO_ENABLED=0
goos:
- linux
- darwin
- windows
goarch:
- amd64
- arm64
ldflags:
- -s -w -X main.version={{.Version}} -X main.commit={{.Commit}}
archives:
- id: default
formats: [tar.gz]
name_template: >-
{{ .ProjectName }}_{{ .Version }}_
{{- title .Os }}_
{{- if eq .Arch "amd64" }}x86_64
{{- else if eq .Arch "386" }}i386
{{- else }}{{ .Arch }}{{ end }}
format_overrides:
- goos: windows
formats: [zip]
files:
- LICENSE
- README.md
- CHANGELOG.md
checksum:
name_template: 'checksums.txt'
snapshot:
version_template: '{{ incpatch .Version }}-next'
changelog:
sort: asc
use: github
groups:
- title: Features
regexp: '^.*?feat(\(.+\))?:.+$'
order: 0
- title: Bug fixes
regexp: '^.*?fix(\(.+\))?:.+$'
order: 1
- title: Documentation
regexp: '^.*?docs(\(.+\))?:.+$'
order: 2
- title: Other
order: 999
filters:
exclude:
- '^chore:'
- '^test:'
- '^style:'
- 'merge conflict'
- Merge pull request
- Merge remote-tracking branch
- Merge branch
- go mod tidy
release:
github:
owner: lukaszraczylo
name: go-telegram
prerelease: auto
mode: replace
header: |
## go-telegram {{ .Tag }}
Released on {{ .Date }} for Telegram Bot API regeneration tooling.
Built against **Bot API {{ envOrDefault "BOT_API_VERSION" "(unspecified)" }}**.
The library itself ships via `go get github.com/lukaszraczylo/go-telegram@{{ .Tag }}`.
Binaries below are the codegen tools (`tg-scrape`, `tg-genapi`, `tg-audit`)
for users who want to vendor regen tooling.
footer: |
**Full changelog**: https://github.com/lukaszraczylo/go-telegram/compare/{{ .PreviousTag }}...{{ .Tag }}
+43
View File
@@ -0,0 +1,43 @@
# Configuration for lukaszraczylo/semver-generator.
# Reference: https://github.com/lukaszraczylo/semver-generator
#
# Word matching is fuzzy + case-insensitive. The keywords below mirror
# Conventional Commits prefixes used in this repo's git history.
version: 1
# Respect existing v* tags as the version baseline. semver-generator finds
# the highest existing tag and bumps from there.
force:
existing: true
# Skip merge commits and machine-generated traffic that would otherwise
# spuriously bump the version.
blacklist:
- "Merge branch"
- "Merge pull request"
- "Merge remote-tracking branch"
- "go mod tidy"
# Strip the auto-generated bot-api-vX.Y tag prefix when scanning existing
# tags — those are markers that point at library releases, not version
# sources themselves.
tag_prefixes:
- "bot-api-"
wording:
patch:
- "fix"
- "chore"
- "docs"
- "test"
- "style"
- "refactor"
- "build"
- "ci"
- "perf"
minor:
- "feat"
major:
- "breaking"
- "BREAKING CHANGE"
+56
View File
@@ -0,0 +1,56 @@
# Changelog
All notable changes to this project will be documented in this file. The format follows [Keep a Changelog](https://keepachangelog.com/) and this project adheres to [Semantic Versioning](https://semver.org/).
## [Unreleased]
## [1.0.0] - 2026-05-09
Initial public release. Built against Telegram Bot API v10.0 (176 methods, 301 types).
### Library surface
- `client.Bot` with pluggable `HTTPDoer`, `Codec` (default `github.com/goccy/go-json`), and `Logger` interfaces.
- Generic `client.Call[Req, Resp]` and `client.CallRaw[Req]` helpers — every API method funnels through one of these.
- `client.RetryDoer` middleware — exponential backoff with jitter; honours Telegram's `retry_after`; replays request bodies across attempts. Configurable via `RetryOption` functions (`WithMaxAttempts`, `WithBase`, `WithMax`, `WithFactor`, `WithJitter`).
- Typed errors: `*APIError` (sentinel-mapped: `ErrUnauthorized`, `ErrForbidden`, `ErrTooManyRequests`, `ErrChatNotFound`, `ErrMessageNotModified`, `ErrBadRequest`, `ErrUserNotFound`, `ErrMessageNotFound`), `*NetworkError`, `*ParseError`. Bot tokens redacted from error messages.
### Update delivery
- `transport.LongPoller` — getUpdates loop with `ExponentialBackoff`, retry-after honouring, graceful shutdown via `Stop`.
- `transport.WebhookServer``http.Handler` + `ListenAndServe`; secret-token verification (constant-time); 1 MiB body cap; in-flight handler tracking via `WaitGroup`.
### Dispatcher
- Generic `dispatch.Handler[T]` and `dispatch.Middleware[T]`.
- 21 typed handler registrations: `OnCommand`, `OnText`, `OnCallback`, `OnInlineQuery`, `OnEditedMessage`, `OnChannelPost`, `OnEditedChannelPost`, `OnMyChatMember`, `OnChatMember`, `OnChatJoinRequest`, `OnPreCheckoutQuery`, `OnShippingQuery`, `OnPoll`, `OnPollAnswer`, `OnChosenInlineResult`, `OnMessageReaction`, `OnMessageReactionCount`, `OnChatBoost`, `OnRemovedChatBoost`, `OnBusinessConnection`, `OnPurchasedPaidMedia`. Filter variants available for message, callback, inline-query, chat-member, chat-join-request, and pre-checkout-query types.
- Composable filters in `dispatch/filters/{message,callback,inline,chatmember,chatjoinrequest,precheckoutquery}` packages with `And`/`Or`/`Not`/`All`/`Any` combinators and 20+ filter helpers.
- `dispatch/conversation` — multi-step state machines with `Storage` interface (`MemoryStorage` default), pluggable `KeyStrategy` (`KeyByUser`, `KeyByChat`, `KeyByUserAndChat`), entry/state/exit/fallback steps, `AllowReEntry`. `Next(state)` and `End()` sentinel errors drive transitions. `Conversation.Dispatch` integrates as middleware via `router.Use`.
- Per-update goroutine pool (default 50; configurable via `WithMaxConcurrency`). Pass `0` for serial dispatch.
- Handler groups with `EndGroups`/`ContinueGroups` flow control.
- `NamedHandlers[T]` for runtime registration and replacement.
- Panic-recovery middleware + automatic handler-error logging registered by default.
### Generated API
- Full Bot API v10.0 surface in `api/*.gen.go` — 176 methods, 301 types, regenerated from a committed HTML snapshot of `core.telegram.org/bots/api`.
- Strongly typed unions: `ChatID` (Integer-or-String), `*InputFile` (InputFile-or-String), `MessageOrBool` (Message-or-True returns).
- 13 discriminated unions with auto-decode: `ChatMember`, `MessageOrigin`, `ReactionType`, `PaidMedia`, `BackgroundType`, `BackgroundFill`, `ChatBoostSource`, `RevenueWithdrawalState`, `TransactionPartner`, `MenuButton`, `OwnedGift`, `StoryAreaType`, `MaybeInaccessibleMessage`. Parent structs containing union-typed fields auto-decode via generated `UnmarshalJSON`.
- Sealed-interface return types use `client.CallRaw` + `Unmarshal<Name>` dispatch (e.g., `GetChatMember` returns `ChatMember` interface, decoded into the correct concrete variant).
### Runtime helpers
- `api.MeCache` — concurrent-safe `getMe` cache; zero-value safe.
- `api.DownloadFile` / `api.DownloadFileByPath` — fetch file contents from the Telegram CDN.
- `(*Message).GetSender()` — unifies `From`, `SenderChat`, and anonymous-admin fields into a `*Sender`.
### Codegen pipeline (`cmd/scrape`, `cmd/genapi`, `cmd/audit`)
- Two-stage codegen: HTML → IR (`internal/spec/api.json`) → Go.
- `internal/spec/overrides.json` pins specific method returns/field types when scraper regex does not match a particular doc phrasing.
- `cmd/audit` reports any-typed fields, fallback `bool` returns, and signature drift vs HEAD's IR. Exit codes: 0 clean, 1 fallback, 2 drift, 3 invalid.
- `make regen` is self-verifying: clean → scrape → audit → emit (code + 1428 tests) → run tests.
- Auto-generated tests cover 8 scenarios per method (1428 total): Success, APIError 429, NetworkError, ParseError, ContextCanceled, MissingRequiredFields (400 + ErrBadRequest), Forbidden (403 + ErrForbidden), ServerError (500 + IsRetryable).
### Tooling
- `.github/workflows/ci.yml` — Go matrix 1.23/1.24, vet, staticcheck, govulncheck, gosec, race tests, codegen-clean check, audit + drift detection.
- `.github/workflows/regen.yml` — weekly cron + workflow_dispatch; scrapes live API, regenerates, runs tests, opens auto-PR with audit summary in body.
- `.github/workflows/release.yml` — semver-generator-driven version bump; dual tagging (library SemVer + `bot-api-v<X.Y>` marker); GoReleaser.
- `.goreleaser.yaml` — ships `tg-scrape`, `tg-genapi`, `tg-audit` binaries for Linux/macOS/Windows × amd64/arm64.
### Examples (14)
echo, webhook, callback, conversation, files, inline, middleware, stateful, welcome, moderation, polls, payments, pagination, admin.
+71
View File
@@ -0,0 +1,71 @@
# Contributing
Thanks for your interest in go-telegram. The library mixes hand-written and generated code; this guide explains how to update each.
## Project layout
- **`client/`** — hand-written Bot client, generic Call helper, error taxonomy, retry middleware. Stable; rarely changes.
- **`transport/`** — long-poll and webhook updaters. Hand-written.
- **`dispatch/`** — typed router with command/text/callback matchers. Hand-written.
- **`api/`** — generated types and method wrappers (`*.gen.go`) plus runtime helpers (`runtime.go`, `download.go`, `me.go`).
- **`internal/spec/`** — IR types + the committed `api.json` snapshot of the Telegram Bot API.
- **`cmd/scrape/`** — HTML scraper that produces `internal/spec/api.json`.
- **`cmd/genapi/`** — emitter that consumes `api.json` and renders `api/*.gen.go`.
## Workflows
### Updating to a newer Telegram Bot API version
```bash
make snapshot # fetch latest HTML from core.telegram.org
make regen # scrape + emit
go test -race ./... # verify
```
If the live page introduces a phrasing the scraper doesn't recognise, you'll see methods falling back to `bool` returns or struct fields typed `any`. Check the audit script in `cmd/scrape/method_test.go` and add new patterns to `cmd/scrape/method.go` and / or `cmd/scrape/table.go`. Then `go test -update ./cmd/scrape/...` to refresh the small-fixture golden, and re-run `make regen`.
### Adding a new union for auto-decode
If Telegram introduces a new discriminated union type (similar to `ChatMember`):
1. Add an entry to `knownDiscriminators` in `cmd/genapi/emitter.go`.
2. Run `make regen`.
3. The emitter will produce `UnmarshalXxx` for the union and per-struct `UnmarshalJSON` for any field referencing it.
### Updating runtime helpers
Edit `api/runtime.go`, `api/download.go`, `api/me.go`, or any of `client/*.go`, `transport/*.go`, `dispatch/*.go`. Add tests for new functionality. CI runs `go test -race ./...`, `go vet`, `staticcheck`, and the codegen-clean check (which asserts the committed `api/` matches what the pipeline produces from the committed snapshot).
### Conventions
- Doc comments on every exported symbol — generated types carry verbatim Telegram prose.
- No `//nolint` directives anywhere; if the linter complains, fix the code or update `.golangci.yml`.
- No reordering of struct fields for `fieldalignment` — JSON field order tracks the spec for diff readability.
- TDD where practical: failing test, then implementation, then commit.
- Conventional Commits style for messages: `feat(...):`, `fix(...):`, `docs(...):`, `chore(...):`, `refactor(...):`, `test(...):`.
## Running locally
```bash
make test # unit tests
make test-race # with race detector (CI default)
make lint # vet + staticcheck
make integration # live API smoke tests (requires TELEGRAM_BOT_TOKEN)
```
The codegen tooling in `cmd/scrape` pulls `golang.org/x/net/html`. The runtime library packages depend only on the standard library plus `stretchr/testify` (test-only).
## Releasing
1. Bump version in `CHANGELOG.md`.
2. Tag with `git tag -a v0.x.0 -m "summary"` (no leading 'v' alone — use the full SemVer triple).
3. `git push --tags`.
(There is no GoReleaser config yet; releases are tag-only and `go install` works against tags.)
## Reporting issues
File issues on the GitHub repository with:
- The Telegram Bot API method involved (if applicable).
- A minimal reproduction (mocked HTTP transport is fine).
- The library version (`go list -m github.com/lukaszraczylo/go-telegram`).
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 Lukasz Raczylo
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+76
View File
@@ -0,0 +1,76 @@
.PHONY: test test-race lint vet integration regen snapshot regen-from-fixture test-update-golden clean clean-generated audit audit-drift help
GO ?= go
help:
@echo "Targets:"
@echo " test - run unit tests"
@echo " test-race - run unit tests with race detector"
@echo " lint - go vet + staticcheck"
@echo " integration - run integration suite (requires TELEGRAM_BOT_TOKEN)"
@echo " snapshot - capture HTML snapshot from live API (Plan 2)"
@echo " regen - regenerate api/ from latest snapshot (Plan 2)"
@echo " regen-from-fixture - deterministic regen from pinned fixture (Plan 2)"
@echo " test-update-golden - refresh golden test fixtures (Plan 2)"
@echo " audit - report any-typed/bool fallbacks in current IR"
@echo " audit-drift - audit + compare against HEAD's IR for signature changes"
@echo " clean-generated - delete generated api/*.gen.go and internal/spec/api.json"
@echo " clean - clean-generated + transient artefacts (binaries, coverage)"
test:
$(GO) test ./...
test-race:
$(GO) test -race ./...
vet:
$(GO) vet ./...
lint: vet
@which staticcheck > /dev/null || (echo "install staticcheck: go install honnef.co/go/tools/cmd/staticcheck@latest" && exit 1)
staticcheck ./...
integration:
$(GO) test -tags=integration -v ./test/integration/...
SCRAPE_INPUT ?= testdata/html/snapshot_2026-05-08.html
SCRAPE_OUTPUT ?= internal/spec/api.json
snapshot:
./scripts/snapshot.sh
regen: clean-generated
$(GO) run ./cmd/scrape -input testdata/html/latest.html -output $(SCRAPE_OUTPUT)
$(GO) run ./cmd/audit -ir $(SCRAPE_OUTPUT)
$(GO) run ./cmd/genapi -input $(SCRAPE_OUTPUT) -outdir api
$(GO) test ./api/...
regen-from-fixture: clean-generated
$(GO) run ./cmd/scrape -input $(SCRAPE_INPUT) -output $(SCRAPE_OUTPUT)
$(GO) run ./cmd/audit -ir $(SCRAPE_OUTPUT)
$(GO) run ./cmd/genapi -input $(SCRAPE_OUTPUT) -outdir api
$(GO) test ./api/...
audit:
$(GO) run ./cmd/audit -ir $(SCRAPE_OUTPUT)
audit-drift:
$(GO) run ./cmd/audit -ir $(SCRAPE_OUTPUT) -drift
test-update-golden:
$(GO) test -run TestEmit -update ./cmd/genapi/...
$(GO) test -run TestScrape -update ./cmd/scrape/...
# clean-generated removes ONLY codegen output. Source code (cmd/scrape,
# cmd/genapi, runtime helpers) is untouched. Run before regen to avoid
# orphan files lingering when the IR shrinks (renamed/removed methods).
clean-generated:
rm -f api/*.gen.go api/*_gen_test.go
rm -f internal/spec/api.json
# clean removes generated output AND transient artefacts (binaries
# accidentally left at repo root, coverage reports). Source code is
# never touched.
clean: clean-generated
rm -f coverage.out coverage.html
rm -f echo webhook genapi scrape callback files inline conversation middleware stateful
+333
View File
@@ -0,0 +1,333 @@
# go-telegram
> A fully-generated, strongly-typed Go client for the Telegram Bot API — no `any`, no guessing.
[![CI](https://github.com/lukaszraczylo/go-telegram/actions/workflows/ci.yml/badge.svg)](https://github.com/lukaszraczylo/go-telegram/actions/workflows/ci.yml)
[![Go Reference](https://pkg.go.dev/badge/github.com/lukaszraczylo/go-telegram.svg)](https://pkg.go.dev/github.com/lukaszraczylo/go-telegram)
[![Go Version](https://img.shields.io/github/go-mod/go-version/lukaszraczylo/go-telegram)](go.mod)
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
> Bot API **v10.0** · 176 methods · 301 types · 1428 auto-generated tests
Most Telegram bot libraries expose Telegram's "Integer or String" fields as `interface{}` or `any`. Every union type in go-telegram is a real Go type with compile-time safety and auto-decoding. The entire API surface is code-generated from a committed HTML snapshot of the live Telegram docs — regenerating picks up new Bot API versions in one command, with a self-verifying pipeline that catches regressions before they ship.
```go
bot := client.New(os.Getenv("TELEGRAM_BOT_TOKEN"),
client.WithHTTPClient(client.NewRetryDoer(client.NewDefaultHTTPDoer())),
)
router := dispatch.New(bot)
router.OnCommand("/start", func(c *dispatch.Context, m *api.Message) error {
_, err := api.SendMessage(c.Ctx, c.Bot, &api.SendMessageParams{
ChatID: api.ChatIDFromInt(m.Chat.ID),
Text: "Hello! Send me anything to echo.",
})
return err
})
router.OnText(`.+`, func(c *dispatch.Context, m *api.Message) error {
_, err := api.SendMessage(c.Ctx, c.Bot, &api.SendMessageParams{
ChatID: api.ChatIDFromInt(m.Chat.ID),
Text: m.Text,
ReplyParameters: &api.ReplyParameters{MessageID: m.MessageID},
})
return err
})
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
router.Run(ctx, transport.NewLongPoller(bot))
```
## Why go-telegram
| Feature | What it means for you |
|---|---|
| **Typed unions** | `ChatID`, `MessageOrBool`, `InputFile`, and 13 discriminated-union interfaces give you `switch v.(type)` instead of runtime panics |
| **Full Bot API v10.0** | 176 methods and 301 types — all generated, none hand-written, nothing missing |
| **Self-verifying codegen** | `make snapshot && make regen` regenerates everything and runs 1428 tests; any regression fails the pipeline |
| **Pluggable transport + codec** | `HTTPDoer` and `Codec` are one-method interfaces — swap in fasthttp, sonic, or your test fake without forking |
| **Retry middleware** | `RetryDoer` honours Telegram's `retry_after`, backs off on 5xx, replays request bodies |
| **Composable dispatcher** | Per-update goroutine pool (default 50), filter combinators (`And`/`Or`/`Not`), conversation state machines, named handlers |
## Quickstart
```bash
go get github.com/lukaszraczylo/go-telegram
```
Full echo bot — long-poll, graceful shutdown, retry on 429:
```go
package main
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/client"
"github.com/lukaszraczylo/go-telegram/dispatch"
"github.com/lukaszraczylo/go-telegram/transport"
)
func main() {
token := os.Getenv("TELEGRAM_BOT_TOKEN")
if token == "" {
log.Fatal("TELEGRAM_BOT_TOKEN required")
}
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
bot := client.New(token,
client.WithHTTPClient(client.NewRetryDoer(client.NewDefaultHTTPDoer())),
)
router := dispatch.New(bot)
router.OnCommand("/start", func(c *dispatch.Context, m *api.Message) error {
_, err := api.SendMessage(c.Ctx, c.Bot, &api.SendMessageParams{
ChatID: api.ChatIDFromInt(m.Chat.ID),
Text: fmt.Sprintf("Hello %s! Send me anything.", m.From.FirstName),
})
return err
})
router.OnText(`.+`, func(c *dispatch.Context, m *api.Message) error {
_, err := api.SendMessage(c.Ctx, c.Bot, &api.SendMessageParams{
ChatID: api.ChatIDFromInt(m.Chat.ID),
Text: m.Text,
ReplyParameters: &api.ReplyParameters{MessageID: m.MessageID},
})
return err
})
if err := router.Run(ctx, transport.NewLongPoller(bot)); err != nil && err != context.Canceled {
log.Printf("router exited: %v", err)
}
}
```
## Examples
Run any example: `TELEGRAM_BOT_TOKEN=xxx go run ./examples/<name>`
| Category | Example | What it shows |
|---|---|---|
| **Basics** | [`echo`](examples/echo) | Long-poll echo bot |
| | [`webhook`](examples/webhook) | Webhook server with secret-token verification |
| | [`files`](examples/files) | Upload and download cycle |
| | [`inline`](examples/inline) | Inline-mode results |
| **Conversations & state** | [`conversation`](examples/conversation) | Multi-step state machine with `/cancel` exit |
| | [`stateful`](examples/stateful) | Per-user state via closures |
| | [`callback`](examples/callback) | Inline keyboards and callback query handling |
| | [`pagination`](examples/pagination) | Multi-page inline keyboard |
| **Group management** | [`welcome`](examples/welcome) | Greet new chat members |
| | [`moderation`](examples/moderation) | Kick/ban/mute/warn with permission checks |
| | [`admin`](examples/admin) | Auth middleware allowlist |
| **Advanced** | [`middleware`](examples/middleware) | `Use` chains |
| | [`polls`](examples/polls) | `sendPoll` and answer tally |
| | [`payments`](examples/payments) | Invoice → pre-checkout → success |
## Concepts
<details>
<summary>Bot client and pluggable transport</summary>
`client.New` accepts functional options:
```go
bot := client.New(token,
client.WithHTTPClient(doer), // any HTTPDoer (one-method interface)
client.WithCodec(myCodec), // any Codec (Marshal + Unmarshal)
client.WithLogger(myLogger),
client.WithBaseURL("https://..."), // proxy or local Bot API server
)
```
`HTTPDoer` is `Do(*http.Request) (*http.Response, error)` — a plain `*http.Client` satisfies it.
`Codec` is `Marshal(any) ([]byte, error)` + `Unmarshal([]byte, any) error` — the default wraps `goccy/go-json`.
Every API call goes through `client.Call[Req, Resp]`; per-method generated functions are thin wrappers.
</details>
<details>
<summary>Typed unions — no any</summary>
Telegram's docs describe many fields as "Integer or String" or "one of N types". go-telegram turns every one of these into a concrete Go type.
```go
// ChatID: construct from int64 or @username
chatID := api.ChatIDFromInt(123456789)
chatID := api.ChatIDFromString("@mychannel")
// Discriminated unions — 13 interfaces with auto-decode via generated UnmarshalJSON
for _, u := range updates {
if u.MyChatMember == nil {
continue
}
switch v := u.MyChatMember.OldChatMember.(type) {
case *api.ChatMemberOwner:
log.Println("was owner")
case *api.ChatMemberAdministrator:
log.Printf("was admin: can_post=%v", v.CanPostMessages)
}
}
```
Full union list: `ChatMember`, `MessageOrigin`, `ReactionType`, `PaidMedia`, `BackgroundType`, `BackgroundFill`, `ChatBoostSource`, `RevenueWithdrawalState`, `TransactionPartner`, `MenuButton`, `OwnedGift`, `StoryAreaType`, `MaybeInaccessibleMessage`, plus `ChatID`, `MessageOrBool`, and `InputFile`.
</details>
<details>
<summary>Dispatcher, filters, and conversations</summary>
The router dispatches each update in its own goroutine (semaphore-bounded, default 50):
```go
r := dispatch.New(bot, dispatch.WithMaxConcurrency(50))
r.OnCommand("/start", handler)
r.OnText(`^hi (\w+)`, handler)
r.OnCallback(`^like:\d+`, handler)
r.OnInlineQuery(handler)
r.OnMyChatMember(handler)
// + 20 more typed On* methods
```
**Composable filters** — each update type has its own filter package:
```go
import "github.com/lukaszraczylo/go-telegram/dispatch/filters/message"
r.OnMessageFilter(
message.Command("/admin").And(message.IsReply()),
handler,
)
```
Filter packages: `message`, `callback`, `inline`, `chatmember`, `chatjoinrequest`, `precheckoutquery`. Combinators: `And`, `Or`, `Not`, `All`, `Any`.
**Conversation state machines** — multi-step flows with pluggable storage:
```go
conv := &conversation.Conversation{
EntryPoints: []conversation.Step{{
Filter: dispatch.FilterFunc(func(c *dispatch.Context, u *api.Update) bool {
return u.Message != nil && u.Message.Text == "/start"
}),
Handler: func(c *dispatch.Context, u *api.Update) error {
// send prompt, advance state
return conversation.Next("await_name")
},
}},
States: map[conversation.State][]conversation.Step{
"await_name": {{
Handler: func(c *dispatch.Context, u *api.Update) error {
return conversation.End()
},
}},
},
}
router.Use(conv.Dispatch)
```
Key strategies: `KeyByUser`, `KeyByChat`, `KeyByUserAndChat` (default). Default storage: `MemoryStorage` (in-process, concurrency-safe). Implement the `Storage` interface for Redis or any other backend.
</details>
<details>
<summary>Errors and retry middleware</summary>
Wrap the default HTTP doer with `RetryDoer` for production:
```go
bot := client.New(token,
client.WithHTTPClient(
client.NewRetryDoer(
client.NewDefaultHTTPDoer(),
client.WithMaxAttempts(5),
client.WithBaseBackoff(500*time.Millisecond),
),
),
)
```
`RetryDoer` retries on 429, 5xx, and transient network errors. On a 429 it reads `retry_after` from Telegram's response body and waits exactly that long — overriding any backoff calculation. Request bodies are buffered and replayed across attempts.
Sentinel errors for `errors.Is` checks: `client.ErrForbidden`, `client.ErrNotFound`, `client.ErrUnauthorized`.
</details>
<details>
<summary>Handler groups and named handlers</summary>
Priority-ordered groups with flow control signals:
```go
// Group 0 runs first — return EndGroups to stop, ContinueGroups to continue
r.Group(0).OnText(`.*`, authMiddleware)
r.Group(1).OnText(`.*`, businessHandler)
```
Named handlers — register and replace at runtime:
```go
named := dispatch.NewNamedHandlers[*api.Message]()
named.Set("main", myHandler)
r.OnCommand("/cmd", named.Handler())
// later: named.Set("main", updatedHandler)
```
</details>
## Codegen pipeline
The full API surface in `api/*.gen.go` is generated from a committed HTML snapshot of `core.telegram.org/bots/api`:
```bash
make snapshot # fetch and commit latest HTML from core.telegram.org
make regen # scrape → audit → emit Go code → run generated tests
go test -race ./...
```
`make regen` is self-verifying. The audit tool (`cmd/audit`) checks:
- `any`-typed fields or returns that escaped the union machinery
- Methods returning `bool` not on the approved list (`internal/spec/overrides.json`)
- Signature drift vs HEAD's IR (added/removed/changed return types)
Exit codes: 0 clean · 1 fallback · 2 drift · 3 invalid. CI runs the audit on every PR. A weekly `regen.yml` workflow opens a PR with regenerated code and the audit summary in the body.
To track a new Bot API release: run `make snapshot && make regen`, review the audit output, update `internal/spec/overrides.json` for any newly unparseable methods, and submit a PR.
## Testing
Mock the one-method `HTTPDoer` interface to test handlers in isolation — no test server needed:
```go
type fakeDoer struct{ body string }
func (f fakeDoer) Do(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader(f.body)),
}, nil
}
bot := client.New("token", client.WithHTTPClient(fakeDoer{
body: `{"ok":true,"result":{"message_id":1,"date":0,"chat":{"id":1,"type":"private"}}}`,
}))
```
The library's own generated test suite (`api/methods_gen_test.go`) covers 176 methods × 8 scenarios each: Success, APIError, NetworkError, ParseError, ContextCanceled, MissingRequiredFields, Forbidden, ServerError.
## Contributing
See [CONTRIBUTING.md](CONTRIBUTING.md).
## License
MIT
+57
View File
@@ -0,0 +1,57 @@
package api
import (
"context"
"fmt"
"io"
"net/http"
"github.com/lukaszraczylo/go-telegram/client"
)
// DownloadFile fetches the contents of a Telegram-hosted file given a
// previously-uploaded file_id. It calls GetFile to resolve the file's
// download path, then issues an HTTP GET to the file CDN endpoint.
//
// The returned io.ReadCloser must be closed by the caller. The size of
// the file is reported via *File.FileSize when known.
//
// For files larger than 20 MB, Telegram requires a self-hosted Bot API
// server (default api.telegram.org has a 20 MB limit on getFile).
func DownloadFile(ctx context.Context, b *client.Bot, fileID string) (io.ReadCloser, *File, error) {
f, err := GetFile(ctx, b, &GetFileParams{FileID: fileID})
if err != nil {
return nil, nil, fmt.Errorf("getFile: %w", err)
}
if f == nil || f.FilePath == "" {
return nil, f, fmt.Errorf("telegram: file %q has no download path", fileID)
}
rc, err := DownloadFileByPath(ctx, b, f.FilePath)
if err != nil {
return nil, f, err
}
return rc, f, nil
}
// DownloadFileByPath fetches a file by its file_path (typically obtained
// from a prior File response). Useful when the caller already has a
// *File and wants to skip the GetFile round-trip.
func DownloadFileByPath(ctx context.Context, b *client.Bot, filePath string) (io.ReadCloser, error) {
url := fmt.Sprintf("%s/file/bot%s/%s", b.BaseURL(), b.Token(), filePath)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := b.HTTP().Do(req)
if err != nil {
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
}
return nil, fmt.Errorf("download: %w", err)
}
if resp.StatusCode != http.StatusOK {
_ = resp.Body.Close()
return nil, fmt.Errorf("download: HTTP %d", resp.StatusCode)
}
return resp.Body, nil
}
+80
View File
@@ -0,0 +1,80 @@
package api
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/lukaszraczylo/go-telegram/client"
"github.com/stretchr/testify/require"
)
func TestDownloadFile_HappyPath(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.HasSuffix(r.URL.Path, "/getFile"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"ok":true,"result":{"file_id":"abc","file_unique_id":"u","file_size":11,"file_path":"documents/hello.txt"}}`))
case strings.HasPrefix(r.URL.Path, "/file/bot"):
_, _ = w.Write([]byte("hello world"))
default:
http.NotFound(w, r)
}
}))
t.Cleanup(srv.Close)
bot := client.New("123:abc", client.WithBaseURL(srv.URL))
rc, file, err := DownloadFile(context.Background(), bot, "abc")
require.NoError(t, err)
defer rc.Close()
require.Equal(t, "documents/hello.txt", file.FilePath)
body, err := io.ReadAll(rc)
require.NoError(t, err)
require.Equal(t, "hello world", string(body))
}
func TestDownloadFile_GetFileFailure(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"ok":false,"error_code":400,"description":"Bad Request: invalid file_id"}`))
}))
t.Cleanup(srv.Close)
bot := client.New("t", client.WithBaseURL(srv.URL))
_, _, err := DownloadFile(context.Background(), bot, "bad")
require.Error(t, err)
require.Contains(t, err.Error(), "getFile")
}
func TestDownloadFile_NoFilePath(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// result without file_path
_, _ = w.Write([]byte(`{"ok":true,"result":{"file_id":"abc","file_unique_id":"u"}}`))
}))
t.Cleanup(srv.Close)
bot := client.New("t", client.WithBaseURL(srv.URL))
_, _, err := DownloadFile(context.Background(), bot, "abc")
require.Error(t, err)
require.Contains(t, err.Error(), "no download path")
}
func TestDownloadFileByPath_HTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/file/bot") {
w.WriteHeader(http.StatusForbidden)
return
}
http.NotFound(w, r)
}))
t.Cleanup(srv.Close)
bot := client.New("t", client.WithBaseURL(srv.URL))
_, err := DownloadFileByPath(context.Background(), bot, "secret/file")
require.Error(t, err)
require.Contains(t, err.Error(), "403")
}
+60
View File
@@ -0,0 +1,60 @@
// Code generated by cmd/genapi. DO NOT EDIT.
//go:build !ignore_autogenerated
package api
// ParseMode controls how Telegram interprets formatting in message text.
type ParseMode string
const (
ParseModeMarkdown ParseMode = "Markdown" // legacy
ParseModeMarkdownV2 ParseMode = "MarkdownV2"
ParseModeHTML ParseMode = "HTML"
)
// ChatType is the type of a Telegram chat.
type ChatType string
const (
ChatTypePrivate ChatType = "private"
ChatTypeGroup ChatType = "group"
ChatTypeSupergroup ChatType = "supergroup"
ChatTypeChannel ChatType = "channel"
)
// UpdateType identifies an Update payload variant. Used by allowed_updates
// in getUpdates / setWebhook.
type UpdateType string
const (
UpdateMessage UpdateType = "message"
UpdateEditedMessage UpdateType = "edited_message"
UpdateChannelPost UpdateType = "channel_post"
UpdateEditedChannelPost UpdateType = "edited_channel_post"
UpdateCallbackQuery UpdateType = "callback_query"
UpdateInlineQuery UpdateType = "inline_query"
)
// MessageEntityType is the kind of an entity (mention, hashtag, command, ...).
type MessageEntityType string
const (
EntityMention MessageEntityType = "mention"
EntityHashtag MessageEntityType = "hashtag"
EntityCashtag MessageEntityType = "cashtag"
EntityBotCommand MessageEntityType = "bot_command"
EntityURL MessageEntityType = "url"
EntityEmail MessageEntityType = "email"
EntityPhoneNumber MessageEntityType = "phone_number"
EntityBold MessageEntityType = "bold"
EntityItalic MessageEntityType = "italic"
EntityUnderline MessageEntityType = "underline"
EntityStrike MessageEntityType = "strikethrough"
EntitySpoiler MessageEntityType = "spoiler"
EntityCode MessageEntityType = "code"
EntityPre MessageEntityType = "pre"
EntityTextLink MessageEntityType = "text_link"
EntityTextMention MessageEntityType = "text_mention"
EntityCustomEmoji MessageEntityType = "custom_emoji"
)
+50
View File
@@ -0,0 +1,50 @@
package api
import (
"context"
"sync"
"github.com/lukaszraczylo/go-telegram/client"
)
// MeCache caches the result of GetMe across calls. Construct one per
// Bot and call Get to retrieve the cached User on subsequent invocations.
//
// var meCache api.MeCache
// me, err := meCache.Get(ctx, bot)
//
// MeCache is safe for concurrent use.
type MeCache struct {
mu sync.Mutex
cached *User
}
// Get returns the User from a cached GetMe call. If the cache is empty,
// it calls GetMe and populates the cache on success.
func (c *MeCache) Get(ctx context.Context, b *client.Bot) (*User, error) {
c.mu.Lock()
if c.cached != nil {
u := c.cached
c.mu.Unlock()
return u, nil
}
c.mu.Unlock()
u, err := GetMe(ctx, b, &GetMeParams{})
if err != nil {
return nil, err
}
c.mu.Lock()
c.cached = u
c.mu.Unlock()
return u, nil
}
// Reset clears the cache. Useful in tests or after the bot's identity
// is known to have changed (very rare).
func (c *MeCache) Reset() {
c.mu.Lock()
c.cached = nil
c.mu.Unlock()
}
+58
View File
@@ -0,0 +1,58 @@
package api
import (
"context"
"net/http"
"strings"
"sync/atomic"
"testing"
"github.com/lukaszraczylo/go-telegram/client"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestMeCache_FetchesOnce(t *testing.T) {
m := &mockDoer{}
var calls atomic.Int32
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
if strings.HasSuffix(r.URL.Path, "/getMe") {
calls.Add(1)
return true
}
return false
})).Return(newJSONResp(200, `{"ok":true,"result":{"id":1,"is_bot":true,"first_name":"echo","username":"echo_bot"}}`), nil)
bot := client.New("t", client.WithHTTPClient(m))
var cache MeCache
me1, err := cache.Get(context.Background(), bot)
require.NoError(t, err)
require.Equal(t, "echo_bot", me1.Username)
me2, err := cache.Get(context.Background(), bot)
require.NoError(t, err)
require.Same(t, me1, me2)
require.Equal(t, int32(1), calls.Load(), "should fetch only once")
}
func TestMeCache_Reset(t *testing.T) {
var calls atomic.Int32
m := &mockDoer{}
m.On("Do", mock.Anything).Run(func(args mock.Arguments) {
calls.Add(1)
}).Return(newJSONResp(200, `{"ok":true,"result":{"id":1,"is_bot":true,"first_name":"echo","username":"echo_bot"}}`), nil).Once()
m.On("Do", mock.Anything).Run(func(args mock.Arguments) {
calls.Add(1)
}).Return(newJSONResp(200, `{"ok":true,"result":{"id":1,"is_bot":true,"first_name":"echo","username":"echo_bot"}}`), nil).Once()
bot := client.New("t", client.WithHTTPClient(m))
var cache MeCache
_, err := cache.Get(context.Background(), bot)
require.NoError(t, err)
cache.Reset()
_, err = cache.Get(context.Background(), bot)
require.NoError(t, err)
require.Equal(t, int32(2), calls.Load())
}
+5145
View File
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+143
View File
@@ -0,0 +1,143 @@
// Package api contains the Telegram Bot API types and method wrappers.
// Most of the package is generated by cmd/genapi from internal/spec/api.json;
// this file holds the runtime types that are intentionally hand-coded.
//
// InputFile carries either a local upload (Reader+Filename) or a reference
// to a previously-uploaded file (file_id) / URL Telegram can fetch. It is
// not a pure JSON type, so the codegen skips it (see runtimeTypes in
// cmd/genapi/emitter.go).
//
// ResponseParameters mirrors client.ResponseParameters so callers importing
// only `api` can access retry_after and migrate_to_chat_id without pulling
// in the client package.
package api
import (
"bytes"
"fmt"
"github.com/goccy/go-json"
"io"
"strconv"
)
// InputFile carries either a file path (for upload) or a Telegram file_id
// / URL string (for reuse). When PathOrID names a local file, the request
// is sent as multipart/form-data; otherwise the value is sent inline.
type InputFile struct {
// PathOrID is one of: an absolute or relative filesystem path, a
// previously-uploaded Telegram file_id, or an HTTPS URL Telegram
// can fetch.
PathOrID string
// Reader, when non-nil, is used as the file content (Filename names it).
Reader io.Reader
// Filename is the upload filename used when Reader is set.
Filename string
}
// IsLocalUpload reports whether this InputFile triggers a multipart upload.
func (f *InputFile) IsLocalUpload() bool {
if f == nil {
return false
}
return f.Reader != nil
}
// ResponseParameters is the optional metadata Telegram includes on certain
// failures. The most common is RetryAfter (seconds) on 429 responses.
//
// https://core.telegram.org/bots/api#responseparameters
type ResponseParameters struct {
MigrateToChatID int64 `json:"migrate_to_chat_id,omitempty"`
RetryAfter int `json:"retry_after,omitempty"`
}
// ChatID identifies a chat by either numeric id or "@username". The Telegram
// Bot API spells the same field as either an integer or a string; ChatID
// preserves both forms with explicit constructors and a custom MarshalJSON
// so callers never see `any` at the source level.
type ChatID struct {
int64Set bool
intID int64
username string
}
// ChatIDFromInt builds a ChatID for a numeric chat identifier (e.g. -1001234567890).
func ChatIDFromInt(id int64) ChatID { return ChatID{int64Set: true, intID: id} }
// ChatIDFromUsername builds a ChatID for a public chat (e.g. "@channel").
// The leading "@" is required by Telegram for usernames.
func ChatIDFromUsername(name string) ChatID { return ChatID{username: name} }
// IsZero reports whether c carries no value.
func (c ChatID) IsZero() bool { return !c.int64Set && c.username == "" }
// String returns the wire form (decimal integer or "@name") for use in
// multipart bodies.
func (c ChatID) String() string {
if c.int64Set {
return strconv.FormatInt(c.intID, 10)
}
return c.username
}
// MarshalJSON emits either a JSON number (integer form) or a JSON string
// (@username form). Empty values marshal as "null".
func (c ChatID) MarshalJSON() ([]byte, error) {
if c.int64Set {
return []byte(strconv.FormatInt(c.intID, 10)), nil
}
if c.username != "" {
return json.Marshal(c.username)
}
return []byte("null"), nil
}
// UnmarshalJSON accepts either a JSON number or a JSON string.
func (c *ChatID) UnmarshalJSON(data []byte) error {
data = bytes.TrimSpace(data)
if len(data) == 0 || bytes.Equal(data, []byte("null")) {
*c = ChatID{}
return nil
}
if data[0] == '"' {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
*c = ChatIDFromUsername(s)
return nil
}
var n int64
if err := json.Unmarshal(data, &n); err != nil {
return fmt.Errorf("ChatID: %w", err)
}
*c = ChatIDFromInt(n)
return nil
}
// MessageOrBool wraps the "Message or True" return shape Telegram uses on
// edit methods (editMessageText, editMessageCaption, etc.). When the bot
// edits a regular chat message, Message is non-nil; when it edits an
// inline message, OK is true.
type MessageOrBool struct {
Message *Message
OK bool
}
// UnmarshalJSON decodes either {...} into Message or `true`/`false` into OK.
func (m *MessageOrBool) UnmarshalJSON(data []byte) error {
data = bytes.TrimSpace(data)
if len(data) == 0 {
return nil
}
if data[0] == '{' {
m.Message = new(Message)
return json.Unmarshal(data, m.Message)
}
var b bool
if err := json.Unmarshal(data, &b); err != nil {
return fmt.Errorf("MessageOrBool: %w", err)
}
m.OK = b
return nil
}
+93
View File
@@ -0,0 +1,93 @@
package api
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestChatID_IntForm(t *testing.T) {
c := ChatIDFromInt(-1001234567890)
require.False(t, c.IsZero())
require.Equal(t, "-1001234567890", c.String())
data, err := json.Marshal(c)
require.NoError(t, err)
require.Equal(t, "-1001234567890", string(data))
var c2 ChatID
require.NoError(t, json.Unmarshal(data, &c2))
require.Equal(t, c, c2)
}
func TestChatID_UsernameForm(t *testing.T) {
c := ChatIDFromUsername("@channel")
require.False(t, c.IsZero())
require.Equal(t, "@channel", c.String())
data, err := json.Marshal(c)
require.NoError(t, err)
require.Equal(t, `"@channel"`, string(data))
var c2 ChatID
require.NoError(t, json.Unmarshal(data, &c2))
require.Equal(t, c, c2)
}
func TestChatID_Zero(t *testing.T) {
var c ChatID
require.True(t, c.IsZero())
require.Equal(t, "", c.String())
data, err := json.Marshal(c)
require.NoError(t, err)
require.Equal(t, "null", string(data))
var c2 ChatID
require.NoError(t, json.Unmarshal([]byte("null"), &c2))
require.True(t, c2.IsZero())
}
func TestChatID_UnmarshalInvalid(t *testing.T) {
var c ChatID
err := json.Unmarshal([]byte(`"not-a-number"`), &c)
require.NoError(t, err) // string always succeeds as username
require.Equal(t, "not-a-number", c.username)
}
func TestMessageOrBool_TrueForm(t *testing.T) {
var m MessageOrBool
require.NoError(t, json.Unmarshal([]byte("true"), &m))
require.True(t, m.OK)
require.Nil(t, m.Message)
}
func TestMessageOrBool_FalseForm(t *testing.T) {
var m MessageOrBool
require.NoError(t, json.Unmarshal([]byte("false"), &m))
require.False(t, m.OK)
require.Nil(t, m.Message)
}
func TestMessageOrBool_MessageForm(t *testing.T) {
// Message is a generated type; we can only test that it unmarshals without
// error into the struct — the generated api/*.gen.go is not available in
// the test build unless built. Use build tag !ignore_autogenerated default.
// Skip if Message type is not yet present (bootstrap phase).
data := []byte(`{"message_id":42,"date":0,"chat":{"id":1,"type":"private"}}`)
var m MessageOrBool
require.NoError(t, json.Unmarshal(data, &m))
require.NotNil(t, m.Message)
require.False(t, m.OK)
}
func TestInputFile_IsLocalUpload(t *testing.T) {
require.False(t, (*InputFile)(nil).IsLocalUpload())
require.False(t, (&InputFile{PathOrID: "AgADAgADu7gxG..."}).IsLocalUpload())
require.True(t, (&InputFile{Reader: nopReader{}}).IsLocalUpload())
}
type nopReader struct{}
func (nopReader) Read(p []byte) (int, error) { return 0, nil }
+90
View File
@@ -0,0 +1,90 @@
package api
// Sender condenses the various ways a Telegram update can identify the
// originator of a message or reaction into a single shape. Use the
// GetSender methods on supported types to construct one.
type Sender struct {
// User is the human user who sent the update, when applicable.
User *User
// Chat is the chat that sent the update (channel forwards,
// anonymous group admins, anonymous channel posts).
Chat *Chat
// IsAutomaticForward is true when the update originated as an
// automatic forward from a linked channel.
IsAutomaticForward bool
// ChatID is the chat the update was delivered into. Used to
// distinguish "this user" from "this anonymous admin posting
// in <chat>" when User is nil.
ChatID int64
// AuthorSignature is the custom title of an anonymous group
// administrator. Only meaningful when Chat == this chat.
AuthorSignature string
}
// ID returns the most-specific identifier available: prefers Chat.ID
// over User.ID. Returns 0 if neither is set.
func (s *Sender) ID() int64 {
if s == nil {
return 0
}
if s.Chat != nil {
return s.Chat.ID
}
if s.User != nil {
return s.User.ID
}
return 0
}
// IsAnonymousAdmin reports whether the sender is a group admin posting
// anonymously (Chat equals the message's own chat).
func (s *Sender) IsAnonymousAdmin() bool {
return s != nil && s.Chat != nil && s.Chat.ID == s.ChatID
}
// IsAnonymousChannel reports whether the sender is an anonymous
// channel post (Chat differs from the message's own chat).
func (s *Sender) IsAnonymousChannel() bool {
return s != nil && s.Chat != nil && s.Chat.ID != s.ChatID
}
// GetSender constructs a Sender for a Message. The result is never nil.
func (m *Message) GetSender() *Sender {
if m == nil {
return &Sender{}
}
isAuto := false
if m.IsAutomaticForward != nil {
isAuto = *m.IsAutomaticForward
}
return &Sender{
User: m.From,
Chat: m.SenderChat,
IsAutomaticForward: isAuto,
ChatID: m.Chat.ID,
AuthorSignature: m.AuthorSignature,
}
}
// GetSender constructs a Sender for a MessageReactionUpdated.
func (mru *MessageReactionUpdated) GetSender() *Sender {
if mru == nil {
return &Sender{}
}
return &Sender{
User: mru.User,
Chat: mru.ActorChat,
ChatID: mru.Chat.ID,
}
}
// GetSender constructs a Sender for a PollAnswer.
func (pa *PollAnswer) GetSender() *Sender {
if pa == nil {
return &Sender{}
}
return &Sender{
User: pa.User,
Chat: pa.VoterChat,
}
}
+366
View File
@@ -0,0 +1,366 @@
package api
import (
"testing"
)
func TestSenderID(t *testing.T) {
tests := []struct {
name string
sender *Sender
want int64
}{
{
name: "nil sender",
sender: nil,
want: 0,
},
{
name: "empty sender",
sender: &Sender{},
want: 0,
},
{
name: "user only",
sender: &Sender{
User: &User{ID: 123},
},
want: 123,
},
{
name: "chat only",
sender: &Sender{
Chat: &Chat{ID: 456},
},
want: 456,
},
{
name: "chat prefers over user",
sender: &Sender{
User: &User{ID: 123},
Chat: &Chat{ID: 456},
},
want: 456,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.sender.ID()
if got != tt.want {
t.Errorf("ID() = %d, want %d", got, tt.want)
}
})
}
}
func chatEqual(a, b *Chat) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return a.ID == b.ID
}
func userEqual(a, b *User) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return a.ID == b.ID
}
func TestSenderIsAnonymousAdmin(t *testing.T) {
tests := []struct {
name string
sender *Sender
want bool
}{
{
name: "nil sender",
sender: nil,
want: false,
},
{
name: "no chat",
sender: &Sender{User: &User{ID: 123}, ChatID: 456},
want: false,
},
{
name: "chat id matches (anonymous admin)",
sender: &Sender{
Chat: &Chat{ID: 789},
ChatID: 789,
},
want: true,
},
{
name: "chat id differs (not anonymous admin)",
sender: &Sender{
Chat: &Chat{ID: 789},
ChatID: 456,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.sender.IsAnonymousAdmin()
if got != tt.want {
t.Errorf("IsAnonymousAdmin() = %v, want %v", got, tt.want)
}
})
}
}
func TestSenderIsAnonymousChannel(t *testing.T) {
tests := []struct {
name string
sender *Sender
want bool
}{
{
name: "nil sender",
sender: nil,
want: false,
},
{
name: "no chat",
sender: &Sender{User: &User{ID: 123}, ChatID: 456},
want: false,
},
{
name: "chat id differs (anonymous channel)",
sender: &Sender{
Chat: &Chat{ID: 789},
ChatID: 456,
},
want: true,
},
{
name: "chat id matches (not anonymous channel)",
sender: &Sender{
Chat: &Chat{ID: 789},
ChatID: 789,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.sender.IsAnonymousChannel()
if got != tt.want {
t.Errorf("IsAnonymousChannel() = %v, want %v", got, tt.want)
}
})
}
}
func TestMessageGetSender(t *testing.T) {
tests := []struct {
name string
msg *Message
want *Sender
}{
{
name: "nil message",
msg: nil,
want: &Sender{},
},
{
name: "regular user message",
msg: &Message{
From: &User{ID: 123},
Chat: Chat{ID: 456},
},
want: &Sender{
User: &User{ID: 123},
ChatID: 456,
},
},
{
name: "channel forward",
msg: &Message{
From: &User{ID: 123},
SenderChat: &Chat{ID: 789},
Chat: Chat{ID: 456},
},
want: &Sender{
User: &User{ID: 123},
Chat: &Chat{ID: 789},
ChatID: 456,
},
},
{
name: "anonymous admin",
msg: &Message{
SenderChat: &Chat{ID: 456},
Chat: Chat{ID: 456},
AuthorSignature: "Admin Signature",
},
want: &Sender{
Chat: &Chat{ID: 456},
ChatID: 456,
AuthorSignature: "Admin Signature",
},
},
{
name: "anonymous channel post",
msg: &Message{
SenderChat: &Chat{ID: 789},
Chat: Chat{ID: 456},
},
want: &Sender{
Chat: &Chat{ID: 789},
ChatID: 456,
},
},
{
name: "automatic forward",
msg: &Message{
From: &User{ID: 123},
IsAutomaticForward: func() *bool {
b := true
return &b
}(),
Chat: Chat{ID: 456},
},
want: &Sender{
User: &User{ID: 123},
IsAutomaticForward: true,
ChatID: 456,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.msg.GetSender()
if got == nil {
t.Fatal("GetSender() returned nil")
}
if !userEqual(got.User, tt.want.User) {
t.Errorf("User: got %v, want %v", got.User, tt.want.User)
}
if !chatEqual(got.Chat, tt.want.Chat) {
t.Errorf("Chat: got %v, want %v", got.Chat, tt.want.Chat)
}
if got.IsAutomaticForward != tt.want.IsAutomaticForward {
t.Errorf("IsAutomaticForward: got %v, want %v", got.IsAutomaticForward, tt.want.IsAutomaticForward)
}
if got.ChatID != tt.want.ChatID {
t.Errorf("ChatID: got %d, want %d", got.ChatID, tt.want.ChatID)
}
if got.AuthorSignature != tt.want.AuthorSignature {
t.Errorf("AuthorSignature: got %q, want %q", got.AuthorSignature, tt.want.AuthorSignature)
}
})
}
}
func TestMessageReactionUpdatedGetSender(t *testing.T) {
tests := []struct {
name string
mru *MessageReactionUpdated
want *Sender
}{
{
name: "nil reaction",
mru: nil,
want: &Sender{},
},
{
name: "user reaction",
mru: &MessageReactionUpdated{
User: &User{ID: 123},
Chat: Chat{ID: 456},
},
want: &Sender{
User: &User{ID: 123},
ChatID: 456,
},
},
{
name: "anonymous reaction",
mru: &MessageReactionUpdated{
ActorChat: &Chat{ID: 789},
Chat: Chat{ID: 456},
},
want: &Sender{
Chat: &Chat{ID: 789},
ChatID: 456,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.mru.GetSender()
if got == nil {
t.Fatal("GetSender() returned nil")
}
if !userEqual(got.User, tt.want.User) {
t.Errorf("User: got %v, want %v", got.User, tt.want.User)
}
if !chatEqual(got.Chat, tt.want.Chat) {
t.Errorf("Chat: got %v, want %v", got.Chat, tt.want.Chat)
}
if got.ChatID != tt.want.ChatID {
t.Errorf("ChatID: got %d, want %d", got.ChatID, tt.want.ChatID)
}
})
}
}
func TestPollAnswerGetSender(t *testing.T) {
tests := []struct {
name string
pa *PollAnswer
want *Sender
}{
{
name: "nil poll answer",
pa: nil,
want: &Sender{},
},
{
name: "user vote",
pa: &PollAnswer{
User: &User{ID: 123},
},
want: &Sender{
User: &User{ID: 123},
},
},
{
name: "anonymous vote",
pa: &PollAnswer{
VoterChat: &Chat{ID: 789},
},
want: &Sender{
Chat: &Chat{ID: 789},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.pa.GetSender()
if got == nil {
t.Fatal("GetSender() returned nil")
}
if !userEqual(got.User, tt.want.User) {
t.Errorf("User: got %v, want %v", got.User, tt.want.User)
}
if !chatEqual(got.Chat, tt.want.Chat) {
t.Errorf("Chat: got %v, want %v", got.Chat, tt.want.Chat)
}
})
}
}
+29
View File
@@ -0,0 +1,29 @@
package api
import (
"bytes"
"io"
"net/http"
"github.com/stretchr/testify/mock"
)
// mockDoer is a testify-mock HTTPDoer shared by hand-written tests.
type mockDoer struct{ mock.Mock }
func (m *mockDoer) Do(r *http.Request) (*http.Response, error) {
args := m.Called(r)
if v := args.Get(0); v != nil {
return v.(*http.Response), args.Error(1)
}
return nil, args.Error(1)
}
// newJSONResp constructs an *http.Response with a JSON body.
func newJSONResp(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Body: io.NopCloser(bytes.NewBufferString(body)),
Header: http.Header{"Content-Type": []string{"application/json"}},
}
}
+5871
View File
File diff suppressed because it is too large Load Diff
+171
View File
@@ -0,0 +1,171 @@
package client
import (
"bytes"
"context"
"errors"
"github.com/goccy/go-json"
"io"
"net/http"
"reflect"
)
// Call is the single point through which every Telegram Bot API method
// invocation flows. It marshals the request, signs the URL with the bot
// token, dispatches via HTTPDoer, decodes the Result envelope, and
// translates non-OK responses into typed errors.
//
// It is generic over both request and response types. Methods with no
// parameters may pass a nil Req; the helper sends "{}" in that case so
// Telegram receives a syntactically valid empty object.
//
// Call is exported because the api package (which lives outside this one)
// invokes it from generated method wrappers. User code should not normally
// call it directly — use the typed wrappers in package api instead.
func Call[Req any, Resp any](ctx context.Context, b *Bot, method string, req Req) (Resp, error) {
var zero Resp
if mp, ok := any(req).(multipartRequest); ok {
if mp == nil {
return zero, &ParseError{Err: errors.New("client: nil multipart request")}
}
if mp.HasFile() {
return callMultipart[Resp](ctx, b, method, mp)
}
}
body, err := encodeJSONBody(b.codec, req)
if err != nil {
return zero, err
}
url := b.base + "/bot" + b.token + "/" + method
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
if err != nil {
return zero, &NetworkError{Err: err}
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "application/json")
resp, err := b.http.Do(httpReq)
if err != nil {
// Surface ctx errors faithfully so callers can errors.Is(err, ctx.Err()).
if ctxErr := ctx.Err(); ctxErr != nil {
return zero, ctxErr
}
return zero, &NetworkError{Err: err}
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return zero, &NetworkError{Err: err}
}
return decodeResult[Resp](b.codec, raw)
}
// CallRaw is like Call but returns the raw JSON of the result field
// instead of decoding it into a typed value. Generated method wrappers
// for sealed-interface return types (ChatMember, MenuButton, etc.) use
// this helper, then dispatch through the union's UnmarshalXxx function.
//
// CallRaw still translates non-OK responses into *APIError just like Call.
func CallRaw[Req any](ctx context.Context, b *Bot, method string, req Req) (json.RawMessage, error) {
if mp, ok := any(req).(multipartRequest); ok {
if mp == nil {
return nil, &ParseError{Err: errors.New("client: nil multipart request")}
}
if mp.HasFile() {
return callMultipartRaw(ctx, b, method, mp)
}
}
body, err := encodeJSONBody(b.codec, req)
if err != nil {
return nil, err
}
url := b.base + "/bot" + b.token + "/" + method
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
if err != nil {
return nil, &NetworkError{Err: err}
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "application/json")
resp, err := b.http.Do(httpReq)
if err != nil {
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
}
return nil, &NetworkError{Err: err}
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return nil, &NetworkError{Err: err}
}
return decodeResultRaw(b.codec, raw)
}
// decodeResultRaw is decodeResult's sibling that returns the raw result
// field instead of typing it.
func decodeResultRaw(codec Codec, raw []byte) (json.RawMessage, error) {
var env Result[json.RawMessage]
if err := codec.Unmarshal(raw, &env); err != nil {
return nil, &ParseError{Err: err, Body: copyBody(raw)}
}
if !env.OK {
return nil, mapAPIError(env.ErrorCode, env.Description, env.Parameters)
}
return env.Result, nil
}
// encodeJSONBody marshals req to a JSON body. A nil interface or nil
// pointer req yields "{}" so Telegram receives a valid empty object.
func encodeJSONBody(codec Codec, req any) (io.Reader, error) {
if req == nil || isNilPointer(req) {
return bytes.NewBufferString("{}"), nil
}
data, err := codec.Marshal(req)
if err != nil {
return nil, &ParseError{Err: err}
}
return bytes.NewReader(data), nil
}
// decodeResult unmarshals raw into Result[Resp] and translates non-OK
// responses into *APIError.
func decodeResult[Resp any](codec Codec, raw []byte) (Resp, error) {
var zero Resp
var env Result[Resp]
if err := codec.Unmarshal(raw, &env); err != nil {
return zero, &ParseError{Err: err, Body: copyBody(raw)}
}
if !env.OK {
return zero, mapAPIError(env.ErrorCode, env.Description, env.Parameters)
}
return env.Result, nil
}
// isNilPointer returns true when v is a typed nil pointer (the interface
// itself is non-nil because it carries a type, but the underlying value
// is nil). One reflect call per request; not on a hot path that demands
// allocation-freedom.
func isNilPointer(v any) bool {
rv := reflect.ValueOf(v)
return rv.Kind() == reflect.Ptr && rv.IsNil()
}
func copyBody(b []byte) []byte {
const max = 4096
if len(b) > max {
b = b[:max]
}
out := make([]byte, len(b))
copy(out, b)
return out
}
+121
View File
@@ -0,0 +1,121 @@
package client
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type mockDoer struct{ mock.Mock }
func (m *mockDoer) Do(r *http.Request) (*http.Response, error) {
args := m.Called(r)
if v := args.Get(0); v != nil {
return v.(*http.Response), args.Error(1)
}
return nil, args.Error(1)
}
func newResp(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Body: io.NopCloser(bytes.NewBufferString(body)),
Header: http.Header{"Content-Type": []string{"application/json"}},
}
}
type echoReq struct {
ChatID int64 `json:"chat_id"`
Text string `json:"text"`
}
type echoResp struct {
MessageID int64 `json:"message_id"`
}
func TestCall_Success(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
if !strings.HasSuffix(r.URL.Path, "/bot123:abc/sendEcho") {
return false
}
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(r.Body)
return strings.Contains(buf.String(), `"chat_id":42`)
})).Return(newResp(200, `{"ok":true,"result":{"message_id":7}}`), nil)
b := New("123:abc", WithHTTPClient(m))
out, err := Call[*echoReq, *echoResp](context.Background(), b, "sendEcho", &echoReq{ChatID: 42, Text: "hi"})
require.NoError(t, err)
require.Equal(t, int64(7), out.MessageID)
m.AssertExpectations(t)
}
func TestCall_APIError(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.Anything).Return(
newResp(200, `{"ok":false,"error_code":429,"description":"Too Many Requests: retry after 3","parameters":{"retry_after":3}}`), nil)
b := New("t", WithHTTPClient(m))
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", &echoReq{})
require.Error(t, err)
var ae *APIError
require.ErrorAs(t, err, &ae)
require.Equal(t, 429, ae.Code)
require.True(t, ae.IsRetryable())
require.True(t, errors.Is(err, ErrTooManyRequests))
}
func TestCall_NetworkError(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.Anything).Return(nil, errors.New("dial timeout"))
b := New("t", WithHTTPClient(m))
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", &echoReq{})
require.Error(t, err)
var ne *NetworkError
require.ErrorAs(t, err, &ne)
}
func TestCall_ParseError(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.Anything).Return(newResp(200, `not json`), nil)
b := New("t", WithHTTPClient(m))
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", &echoReq{})
require.Error(t, err)
var pe *ParseError
require.ErrorAs(t, err, &pe)
}
func TestCall_ContextCanceled(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.Anything).Return(nil, context.Canceled).Maybe()
ctx, cancel := context.WithCancel(context.Background())
cancel()
b := New("t", WithHTTPClient(m))
_, err := Call[*echoReq, *echoResp](ctx, b, "x", &echoReq{})
require.ErrorIs(t, err, context.Canceled)
}
func TestCall_NilRequest(t *testing.T) {
// Methods with no params (e.g. getMe) may pass a nil Req value.
m := &mockDoer{}
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(r.Body)
return buf.String() == "{}"
})).Return(newResp(200, `{"ok":true,"result":{"message_id":0}}`), nil)
b := New("t", WithHTTPClient(m))
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", nil)
require.NoError(t, err)
}
+48
View File
@@ -0,0 +1,48 @@
package client
const defaultBaseURL = "https://api.telegram.org"
// Bot is the Telegram Bot API client. Construct via New. All API methods
// (declared in package api) hang off *Bot via thin wrappers around call.
type Bot struct {
token string
base string
http HTTPDoer
codec Codec
logger Logger
}
// Token returns the bot token. Exposed for advanced use cases (custom
// transports, manual URL building); ordinary code does not need it.
func (b *Bot) Token() string { return b.token }
// BaseURL returns the configured Telegram API base URL.
func (b *Bot) BaseURL() string { return b.base }
// HTTP returns the underlying HTTPDoer. Exposed for adapters that need
// to share connection pools or for diagnostic checks.
func (b *Bot) HTTP() HTTPDoer { return b.http }
// Codec returns the configured Codec.
func (b *Bot) Codec() Codec { return b.codec }
// Logger returns the configured Logger.
func (b *Bot) Logger() Logger { return b.logger }
// New constructs a Bot with the given token and optional configuration.
// The default HTTP client is tuned for long-poll workloads (see
// NewDefaultHTTPDoer); the default codec wraps encoding/json; the default
// logger discards records.
func New(token string, opts ...Option) *Bot {
b := &Bot{
token: token,
base: defaultBaseURL,
http: NewDefaultHTTPDoer(),
codec: DefaultCodec{},
logger: NoopLogger{},
}
for _, o := range opts {
o(b)
}
return b
}
+42
View File
@@ -0,0 +1,42 @@
package client
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestNew_Defaults(t *testing.T) {
b := New("123:abc")
require.Equal(t, "123:abc", b.token)
require.Equal(t, defaultBaseURL, b.base)
require.NotNil(t, b.http)
require.NotNil(t, b.codec)
require.NotNil(t, b.logger)
}
func TestNew_OptionsApplied(t *testing.T) {
custom := &http.Client{}
type fakeCodec struct{ DefaultCodec }
c := fakeCodec{}
b := New("t",
WithHTTPClient(custom),
WithCodec(c),
WithBaseURL("https://example.test"),
WithLogger(NoopLogger{}),
)
require.Same(t, custom, b.http)
require.Equal(t, c, b.codec)
require.Equal(t, "https://example.test", b.base)
}
func TestResultRoundTrip(t *testing.T) {
in := Result[int64]{OK: true, Result: 42}
data, err := DefaultCodec{}.Marshal(in)
require.NoError(t, err)
var out Result[int64]
require.NoError(t, DefaultCodec{}.Unmarshal(data, &out))
require.Equal(t, in, out)
}
+22
View File
@@ -0,0 +1,22 @@
// Package client provides HTTP client primitives for the Telegram Bot API.
package client
import "github.com/goccy/go-json"
// Codec encodes/decodes JSON payloads exchanged with the Telegram Bot API.
// The default implementation wraps goccy/go-json. Users may plug in
// bytedance/sonic or any compatible encoder by passing
// WithCodec to New.
type Codec interface {
Marshal(v any) ([]byte, error)
Unmarshal(data []byte, v any) error
}
// DefaultCodec wraps goccy/go-json. It is the zero-value safe default.
type DefaultCodec struct{}
// Marshal calls json.Marshal.
func (DefaultCodec) Marshal(v any) ([]byte, error) { return json.Marshal(v) }
// Unmarshal calls json.Unmarshal.
func (DefaultCodec) Unmarshal(data []byte, v any) error { return json.Unmarshal(data, v) }
+29
View File
@@ -0,0 +1,29 @@
package client
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestDefaultCodec_RoundTrip(t *testing.T) {
c := DefaultCodec{}
type payload struct {
Name string `json:"name"`
N int `json:"n"`
}
in := payload{Name: "x", N: 7}
data, err := c.Marshal(in)
require.NoError(t, err)
require.JSONEq(t, `{"name":"x","n":7}`, string(data))
var out payload
require.NoError(t, c.Unmarshal(data, &out))
require.Equal(t, in, out)
}
func TestDefaultCodec_UnmarshalError(t *testing.T) {
var v map[string]any
err := DefaultCodec{}.Unmarshal([]byte(`not json`), &v)
require.Error(t, err)
}
+107
View File
@@ -0,0 +1,107 @@
package client
import (
"errors"
"fmt"
"strings"
"time"
)
// APIError represents a non-OK Telegram Bot API response.
// It satisfies error and unwraps to a sentinel (ErrUnauthorized, etc.)
// where the description matches a known prefix, enabling errors.Is checks.
type APIError struct {
Code int
Description string
Parameters *ResponseParameters
// sentinel, if non-nil, is the wrapped sentinel error returned by
// Unwrap. It is set by mapAPIError based on Code+Description.
sentinel error
}
// Error implements error.
func (e *APIError) Error() string {
return fmt.Sprintf("telegram: %d %s", e.Code, e.Description)
}
// Unwrap returns the matched sentinel error, if any.
func (e *APIError) Unwrap() error { return e.sentinel }
// IsRetryable returns true for transient HTTP statuses (429, 5xx).
func (e *APIError) IsRetryable() bool {
return e.Code == 429 || (e.Code >= 500 && e.Code < 600)
}
// RetryAfter returns the recommended back-off duration. It honours the
// Telegram-supplied retry_after parameter; if absent, returns 0.
func (e *APIError) RetryAfter() time.Duration {
if e.Parameters == nil {
return 0
}
return time.Duration(e.Parameters.RetryAfter) * time.Second
}
// NetworkError wraps a transport-level failure (DNS, TCP, TLS, timeout
// short of an HTTP response).
type NetworkError struct{ Err error }
func (e *NetworkError) Error() string { return "telegram: network: " + redactToken(e.Err.Error()) }
func (e *NetworkError) Unwrap() error { return e.Err }
// ParseError wraps a JSON decode failure on a response body. Body is
// retained (truncated to 4 KiB); Error() displays up to 256 bytes for diagnostics.
type ParseError struct {
Err error
Body []byte
}
func (e *ParseError) Error() string {
body := e.Body
if len(body) > 256 {
body = body[:256]
}
return fmt.Sprintf("telegram: parse: %s (body=%q)", redactToken(e.Err.Error()), body)
}
func (e *ParseError) Unwrap() error { return e.Err }
// Sentinel errors returned via APIError.Unwrap when the description matches.
// Compare with errors.Is.
var (
ErrUnauthorized = errors.New("telegram: unauthorized")
ErrChatNotFound = errors.New("telegram: chat not found")
ErrMessageNotModified = errors.New("telegram: message is not modified")
ErrTooManyRequests = errors.New("telegram: too many requests")
ErrBadRequest = errors.New("telegram: bad request")
ErrForbidden = errors.New("telegram: forbidden")
ErrUserNotFound = errors.New("telegram: user not found")
ErrMessageNotFound = errors.New("telegram: message not found")
)
// mapAPIError builds an *APIError and attaches the appropriate sentinel
// based on Code+Description. It is the single point where wire-level
// failures are translated into the Go error taxonomy.
func mapAPIError(code int, description string, params *ResponseParameters) *APIError {
e := &APIError{Code: code, Description: description, Parameters: params}
switch {
case code == 401:
e.sentinel = ErrUnauthorized
case code == 403:
e.sentinel = ErrForbidden
case code == 429:
e.sentinel = ErrTooManyRequests
case code == 400 && strings.Contains(description, "user not found"):
e.sentinel = ErrUserNotFound
case code == 400 && strings.Contains(description, "message to") && strings.Contains(description, "not found"):
e.sentinel = ErrMessageNotFound
case code == 400 && strings.Contains(description, "chat not found"):
e.sentinel = ErrChatNotFound
case code == 400 && strings.Contains(description, "message is not modified"):
e.sentinel = ErrMessageNotModified
case code == 400:
e.sentinel = ErrBadRequest
}
return e
}
+58
View File
@@ -0,0 +1,58 @@
package client
import (
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestAPIError_FieldsAndMethods(t *testing.T) {
e := &APIError{
Code: 429,
Description: "Too Many Requests: retry after 5",
Parameters: &ResponseParameters{RetryAfter: 5},
}
require.Equal(t, "telegram: 429 Too Many Requests: retry after 5", e.Error())
require.True(t, e.IsRetryable())
require.Equal(t, 5*time.Second, e.RetryAfter())
}
func TestAPIError_Sentinels(t *testing.T) {
cases := []struct {
code int
desc string
sentinel error
}{
{401, "Unauthorized", ErrUnauthorized},
{400, "Bad Request: chat not found", ErrChatNotFound},
{400, "Bad Request: message is not modified", ErrMessageNotModified},
{429, "Too Many Requests: retry after 1", ErrTooManyRequests},
{400, "Bad Request: user not found", ErrUserNotFound},
{400, "Bad Request: message to delete not found", ErrMessageNotFound},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
e := mapAPIError(c.code, c.desc, nil)
require.True(t, errors.Is(e, c.sentinel), "expected %v to wrap %v", e, c.sentinel)
})
}
}
func TestAPIError_IsRetryable(t *testing.T) {
require.True(t, (&APIError{Code: 500}).IsRetryable())
require.True(t, (&APIError{Code: 502}).IsRetryable())
require.True(t, (&APIError{Code: 429}).IsRetryable())
require.False(t, (&APIError{Code: 400}).IsRetryable())
require.False(t, (&APIError{Code: 401}).IsRetryable())
}
func TestNetworkAndParseErrorWrapping(t *testing.T) {
inner := errors.New("dial tcp: timeout")
ne := &NetworkError{Err: inner}
require.ErrorIs(t, ne, inner)
pe := &ParseError{Err: errors.New("unexpected EOF"), Body: []byte("garbage")}
require.Contains(t, pe.Error(), "garbage")
}
+226
View File
@@ -0,0 +1,226 @@
package client
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// client.go option getters
// ---------------------------------------------------------------------------
func TestBot_Getters(t *testing.T) {
b := New("mytoken",
WithBaseURL("http://localhost:9999"),
WithCodec(DefaultCodec{}),
WithLogger(NoopLogger{}),
)
require.Equal(t, "mytoken", b.Token())
require.Equal(t, "http://localhost:9999", b.BaseURL())
require.NotNil(t, b.HTTP())
require.NotNil(t, b.Codec())
require.NotNil(t, b.Logger())
}
func TestWithLogger_NilBecomesNoop(t *testing.T) {
b := New("t", WithLogger(nil))
require.IsType(t, NoopLogger{}, b.Logger())
}
func TestNoopLogger_AllMethods(t *testing.T) {
l := NoopLogger{}
// None of these should panic.
l.Debug("msg")
l.Info("msg", "k", "v")
l.Warn("msg")
l.Error("msg", "err", "oops")
}
// ---------------------------------------------------------------------------
// RetryOption setters
// ---------------------------------------------------------------------------
func TestRetryOptions_Applied(t *testing.T) {
d := NewRetryDoer(nil,
WithMaxAttempts(7),
WithBaseBackoff(1*time.Second),
WithMaxBackoff(60*time.Second),
WithBackoffFactor(3.0),
WithJitter(0.5),
)
require.Equal(t, 7, d.maxAttempts)
require.Equal(t, 1*time.Second, d.base)
require.Equal(t, 60*time.Second, d.max)
require.Equal(t, 3.0, d.factor)
require.Equal(t, 0.5, d.jitter)
}
// ---------------------------------------------------------------------------
// RetryDoer.delay — override path
// ---------------------------------------------------------------------------
func TestRetryDoer_DelayOverride(t *testing.T) {
d := NewRetryDoer(nil)
got := d.delay(1, 5*time.Second)
require.Equal(t, 5*time.Second, got)
}
func TestRetryDoer_DelayExponential(t *testing.T) {
d := NewRetryDoer(nil,
WithBaseBackoff(100*time.Millisecond),
WithMaxBackoff(10*time.Second),
WithJitter(0), // no jitter for deterministic test
WithBackoffFactor(2.0),
)
d1 := d.delay(1, 0)
d2 := d.delay(2, 0)
require.Greater(t, int64(d2), int64(d1), "backoff should grow")
}
func TestRetryDoer_DelayMaxCap(t *testing.T) {
d := NewRetryDoer(nil,
WithBaseBackoff(1*time.Second),
WithMaxBackoff(2*time.Second),
WithJitter(0),
WithBackoffFactor(100.0),
)
delay := d.delay(10, 0)
require.LessOrEqual(t, delay, 2*time.Second)
}
// ---------------------------------------------------------------------------
// errors.go — RetryAfter nil parameters + ParseError.Unwrap
// ---------------------------------------------------------------------------
func TestAPIError_RetryAfterNilParams(t *testing.T) {
e := &APIError{Code: 429, Description: "Too Many Requests", Parameters: nil}
require.Equal(t, time.Duration(0), e.RetryAfter())
}
func TestParseError_Unwrap(t *testing.T) {
inner := errors.New("decode error")
pe := &ParseError{Err: inner, Body: []byte("body")}
require.ErrorIs(t, pe, inner)
}
func TestParseError_LongBodyTruncated(t *testing.T) {
body := bytes.Repeat([]byte("x"), 1000)
pe := &ParseError{Err: errors.New("e"), Body: body}
msg := pe.Error()
// Error() truncates body to 256 for display — should not include all 1000 chars
require.Less(t, len(msg), 800, "should truncate body in Error()")
}
func TestNetworkError_Unwrap(t *testing.T) {
inner := errors.New("tcp error")
ne := &NetworkError{Err: inner}
require.ErrorIs(t, ne, inner)
}
// ---------------------------------------------------------------------------
// mapAPIError — missing sentinel branches (generic 400, unmapped 500)
// ---------------------------------------------------------------------------
func TestMapAPIError_Generic400(t *testing.T) {
e := mapAPIError(400, "Bad Request: some unknown thing", nil)
require.True(t, errors.Is(e, ErrBadRequest))
}
func TestMapAPIError_Unmapped500(t *testing.T) {
e := mapAPIError(500, "Internal Server Error", nil)
require.Nil(t, e.sentinel)
require.Equal(t, 500, e.Code)
}
func TestMapAPIError_403(t *testing.T) {
e := mapAPIError(403, "Forbidden: bot was blocked", nil)
require.True(t, errors.Is(e, ErrForbidden))
}
// ---------------------------------------------------------------------------
// callMultipart — ctx cancelled
// ---------------------------------------------------------------------------
func TestCallMultipart_ContextCancelled(t *testing.T) {
// A doer that blocks then returns context error.
blocker := &extraBlockingDoer{done: make(chan struct{})}
b := New("t", WithHTTPClient(blocker))
ctx, cancel := context.WithCancel(context.Background())
mp := &extraFakeMultipartReq{
fields: map[string]string{"chat_id": "1"},
files: []MultipartFile{
{FieldName: "document", Filename: "f.txt", Reader: bytes.NewReader([]byte("data"))},
},
}
go func() {
time.Sleep(10 * time.Millisecond)
cancel()
close(blocker.done)
}()
_, err := callMultipart[*struct{}](ctx, b, "sendDocument", mp)
require.Error(t, err)
}
type extraBlockingDoer struct{ done chan struct{} }
func (b *extraBlockingDoer) Do(r *http.Request) (*http.Response, error) {
<-b.done
return nil, r.Context().Err()
}
type extraFakeMultipartReq struct {
fields map[string]string
files []MultipartFile
}
func (f *extraFakeMultipartReq) HasFile() bool { return len(f.files) > 0 }
func (f *extraFakeMultipartReq) MultipartFiles() []MultipartFile { return f.files }
func (f *extraFakeMultipartReq) MultipartFields() map[string]string { return f.fields }
// ---------------------------------------------------------------------------
// copyBody size cap
// ---------------------------------------------------------------------------
func TestCopyBody_LargeBodyCapped(t *testing.T) {
big := bytes.Repeat([]byte("a"), 8000)
out := copyBody(big)
require.Len(t, out, 4096)
}
func TestCopyBody_SmallBody(t *testing.T) {
small := []byte("hello")
out := copyBody(small)
require.Equal(t, small, out)
}
// ---------------------------------------------------------------------------
// Call — 5xx non-200 HTTP status (transport level)
// ---------------------------------------------------------------------------
func TestCall_5xxHTTPStatus(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.Anything).Return(&http.Response{
StatusCode: 500,
Body: io.NopCloser(bytes.NewBufferString(`{"ok":false,"error_code":500,"description":"Internal"}`)),
Header: http.Header{"Content-Type": []string{"application/json"}},
}, nil)
b := New("t", WithHTTPClient(m))
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", &echoReq{})
require.Error(t, err)
var ae *APIError
require.ErrorAs(t, err, &ae)
require.Equal(t, 500, ae.Code)
}
+40
View File
@@ -0,0 +1,40 @@
package client
import (
"net"
"net/http"
"time"
)
// HTTPDoer abstracts the HTTP transport. The default is a net/http client
// tuned for Telegram's long-poll usage. Users may plug in valyala/fasthttp
// (via an adapter), or any custom retry/circuit-breaker client by passing
// WithHTTPClient to New.
type HTTPDoer interface {
Do(req *http.Request) (*http.Response, error)
}
// NewDefaultHTTPDoer returns an *http.Client with sensible defaults for
// Telegram Bot API usage:
// - 60s overall timeout (longer than typical long-poll Timeout=30s).
// - Connection pooling sized for a small number of long-lived hosts.
// - HTTP/2 enabled (default in net/http).
func NewDefaultHTTPDoer() *http.Client {
t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConns: 16,
MaxIdleConnsPerHost: 8,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ForceAttemptHTTP2: true,
}
return &http.Client{
Transport: t,
Timeout: 60 * time.Second,
}
}
+24
View File
@@ -0,0 +1,24 @@
package client
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestDefaultHTTPClient_Do(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
}))
t.Cleanup(srv.Close)
doer := NewDefaultHTTPDoer()
req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
require.NoError(t, err)
resp, err := doer.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusTeapot, resp.StatusCode)
}
+19
View File
@@ -0,0 +1,19 @@
package client
// Logger is a slog-shaped logging interface. Users pass any compatible
// implementation via WithLogger. The default is NoopLogger, which discards
// everything.
type Logger interface {
Debug(msg string, attrs ...any)
Info(msg string, attrs ...any)
Warn(msg string, attrs ...any)
Error(msg string, attrs ...any)
}
// NoopLogger discards all log records. It is the zero-value safe default.
type NoopLogger struct{}
func (NoopLogger) Debug(string, ...any) {}
func (NoopLogger) Info(string, ...any) {}
func (NoopLogger) Warn(string, ...any) {}
func (NoopLogger) Error(string, ...any) {}
+11
View File
@@ -0,0 +1,11 @@
package client
import "testing"
func TestNoopLogger_DoesNotPanic(t *testing.T) {
var l Logger = NoopLogger{}
l.Debug("d", "k", "v")
l.Info("i")
l.Warn("w")
l.Error("e")
}
+146
View File
@@ -0,0 +1,146 @@
package client
import (
"context"
"github.com/goccy/go-json"
"io"
"mime/multipart"
"net/http"
)
// multipartRequest is implemented by request structs that may carry an
// InputFile. The codegen emits this interface for any method whose IR
// MethodDecl.HasFiles is true.
//
// HasFile returns true if at least one file field is set; if false, the
// request is sent as plain JSON via the regular Call path.
//
// MultipartFiles returns one entry per file field that should be uploaded.
// The accompanying scalar/object fields are returned by MultipartFields.
type multipartRequest interface {
HasFile() bool
MultipartFiles() []MultipartFile
MultipartFields() map[string]string
}
// MultipartFile describes a single file part in a multipart upload.
type MultipartFile struct {
FieldName string
Filename string
Reader io.Reader
}
// callMultipart performs a multipart/form-data POST. It is invoked by Call
// when the request implements multipartRequest and HasFile() is true.
func callMultipart[Resp any](ctx context.Context, b *Bot, method string, mp multipartRequest) (Resp, error) {
var zero Resp
pr, pw := io.Pipe()
mw := multipart.NewWriter(pw)
// Stream-write the multipart body in a goroutine so we don't buffer
// large files in memory.
go func() {
defer func() { _ = pw.Close() }()
defer func() { _ = mw.Close() }()
for k, v := range mp.MultipartFields() {
if err := mw.WriteField(k, v); err != nil {
_ = pw.CloseWithError(err)
return
}
}
for _, f := range mp.MultipartFiles() {
part, err := mw.CreateFormFile(f.FieldName, f.Filename)
if err != nil {
_ = pw.CloseWithError(err)
return
}
if _, err := io.Copy(part, f.Reader); err != nil {
_ = pw.CloseWithError(err)
return
}
}
}()
url := b.base + "/bot" + b.token + "/" + method
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, pr)
if err != nil {
_ = pr.CloseWithError(err)
return zero, &NetworkError{Err: err}
}
req.Header.Set("Content-Type", mw.FormDataContentType())
req.Header.Set("Accept", "application/json")
resp, err := b.http.Do(req)
if err != nil {
_ = pr.CloseWithError(err)
if ctxErr := ctx.Err(); ctxErr != nil {
return zero, ctxErr
}
return zero, &NetworkError{Err: err}
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
_ = pr.CloseWithError(err)
return zero, &NetworkError{Err: err}
}
return decodeResult[Resp](b.codec, raw)
}
// callMultipartRaw is callMultipart's sibling that returns the raw result
// JSON instead of decoding into a typed value. Used by generated method
// wrappers whose return type is a sealed-interface union.
func callMultipartRaw(ctx context.Context, b *Bot, method string, mp multipartRequest) (json.RawMessage, error) {
pr, pw := io.Pipe()
mw := multipart.NewWriter(pw)
go func() {
defer func() { _ = pw.Close() }()
defer func() { _ = mw.Close() }()
for k, v := range mp.MultipartFields() {
if err := mw.WriteField(k, v); err != nil {
_ = pw.CloseWithError(err)
return
}
}
for _, f := range mp.MultipartFiles() {
part, err := mw.CreateFormFile(f.FieldName, f.Filename)
if err != nil {
_ = pw.CloseWithError(err)
return
}
if _, err := io.Copy(part, f.Reader); err != nil {
_ = pw.CloseWithError(err)
return
}
}
}()
url := b.base + "/bot" + b.token + "/" + method
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, pr)
if err != nil {
_ = pr.CloseWithError(err)
return nil, &NetworkError{Err: err}
}
req.Header.Set("Content-Type", mw.FormDataContentType())
req.Header.Set("Accept", "application/json")
resp, err := b.http.Do(req)
if err != nil {
_ = pr.CloseWithError(err)
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
}
return nil, &NetworkError{Err: err}
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
_ = pr.CloseWithError(err)
return nil, &NetworkError{Err: err}
}
return decodeResultRaw(b.codec, raw)
}
+103
View File
@@ -0,0 +1,103 @@
package client
import (
"context"
"errors"
"io"
"mime"
"mime/multipart"
"net/http"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type fakeMultipartReq struct {
chatID int64
body string
}
func (f *fakeMultipartReq) HasFile() bool { return true }
func (f *fakeMultipartReq) MultipartFields() map[string]string {
return map[string]string{"chat_id": "42"}
}
func (f *fakeMultipartReq) MultipartFiles() []MultipartFile {
return []MultipartFile{{
FieldName: "document",
Filename: "hello.txt",
Reader: strings.NewReader(f.body),
}}
}
type fileResp struct {
MessageID int64 `json:"message_id"`
}
func TestCallMultipart_Success(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
ct := r.Header.Get("Content-Type")
if !strings.HasPrefix(ct, "multipart/form-data") {
return false
}
_, params, err := mime.ParseMediaType(ct)
if err != nil {
return false
}
mr := multipart.NewReader(r.Body, params["boundary"])
seenChat := false
seenFile := false
for {
p, err := mr.NextPart()
if err == io.EOF {
break
}
if err != nil {
return false
}
switch p.FormName() {
case "chat_id":
body, _ := io.ReadAll(p)
seenChat = string(body) == "42"
case "document":
body, _ := io.ReadAll(p)
seenFile = string(body) == "hello world"
}
}
return seenChat && seenFile
})).Return(newResp(200, `{"ok":true,"result":{"message_id":99}}`), nil)
b := New("t", WithHTTPClient(m))
out, err := Call[*fakeMultipartReq, *fileResp](context.Background(), b, "sendDocument", &fakeMultipartReq{chatID: 42, body: "hello world"})
require.NoError(t, err)
require.Equal(t, int64(99), out.MessageID)
}
func TestCallMultipart_NoGoroutineLeakOnError(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.Anything).Return(nil, errors.New("dial timeout"))
b := New("t", WithHTTPClient(m))
before := runtime.NumGoroutine()
for i := 0; i < 50; i++ {
_, _ = Call[*fakeMultipartReq, *fileResp](
context.Background(), b, "sendDocument",
&fakeMultipartReq{chatID: 42, body: strings.Repeat("x", 1<<14)},
)
}
// Allow goroutines to finish exiting after Close propagates.
time.Sleep(50 * time.Millisecond)
runtime.GC()
after := runtime.NumGoroutine()
// A small drift is normal (timers, finalizers); 5 is generous.
if after-before > 5 {
t.Fatalf("goroutine leak: before=%d after=%d", before, after)
}
}
+29
View File
@@ -0,0 +1,29 @@
package client
// Option configures a Bot at construction time. Per-call configuration is
// expressed via typed parameter structs (e.g. SendMessageParams), not options.
type Option func(*Bot)
// WithHTTPClient overrides the HTTP transport. Pass any HTTPDoer
// implementation (e.g. an *http.Client wrapping a custom RoundTripper, or
// a fasthttp adapter).
func WithHTTPClient(c HTTPDoer) Option { return func(b *Bot) { b.http = c } }
// WithCodec overrides the JSON codec. Pass goccy/go-json, sonic, or any
// type implementing Codec to swap out encoding/json.
func WithCodec(c Codec) Option { return func(b *Bot) { b.codec = c } }
// WithBaseURL overrides the API base URL. Useful for testing against a
// local httptest.Server, or for self-hosted Bot API servers.
func WithBaseURL(url string) Option { return func(b *Bot) { b.base = url } }
// WithLogger sets the logger used for diagnostic events. Passing nil
// silently disables logging.
func WithLogger(l Logger) Option {
return func(b *Bot) {
if l == nil {
l = NoopLogger{}
}
b.logger = l
}
}
+15
View File
@@ -0,0 +1,15 @@
package client
import "regexp"
// tokenInURL matches a Telegram bot token segment in a URL path. Tokens
// have the form <bot_id>:<api_key>, where bot_id is digits and api_key
// is 35 base64-url characters. The pattern is conservative: matches
// /bot<id>:<key>/ to avoid false positives.
var tokenInURL = regexp.MustCompile(`/bot(\d{5,15}):([A-Za-z0-9_-]{30,40})/`)
// redactToken replaces any bot token in s with /bot<REDACTED>/. Used by
// error formatters so logs don't leak credentials.
func redactToken(s string) string {
return tokenInURL.ReplaceAllString(s, "/bot<REDACTED>/")
}
+46
View File
@@ -0,0 +1,46 @@
package client
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestRedactToken(t *testing.T) {
cases := []struct {
name string
in string
want string
}{
{"plain bot URL", "https://api.telegram.org/bot123456789:ABCdefGHIjklMNOpqrSTUvwxYZ0123456789/getMe",
"https://api.telegram.org/bot<REDACTED>/getMe"},
{"in net/http error", `Post "https://api.telegram.org/bot987654321:Z9YxWvUtSrQpOnMlKjIhGfEdCbA9876543210/sendMessage": dial tcp: lookup api.telegram.org: no such host`,
`Post "https://api.telegram.org/bot<REDACTED>/sendMessage": dial tcp: lookup api.telegram.org: no such host`},
{"no token", "regular error message", "regular error message"},
{"underscore + dash in token", "/bot123456789:abc-def_ghi-jkl_mno-pqr_stu-vwx_yz/sendDocument",
"/bot<REDACTED>/sendDocument"},
{"too short id (no match)", "/bot123:abc/getMe", "/bot123:abc/getMe"},
{"too short key (no match)", "/bot123456789:short/getMe", "/bot123456789:short/getMe"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
require.Equal(t, c.want, redactToken(c.in))
})
}
}
func TestNetworkError_RedactsToken(t *testing.T) {
inner := errors.New(`Post "https://api.telegram.org/bot1234567890:ABCdefGHIjklMNOpqrSTUvwxYZ0123456789/getMe": dial tcp: timeout`)
e := &NetworkError{Err: inner}
require.NotContains(t, e.Error(), "ABCdefGHIjklMNOpqrSTUvwxYZ")
require.Contains(t, e.Error(), "<REDACTED>")
}
func TestParseError_RedactsToken(t *testing.T) {
inner := fmt.Errorf(`unexpected response from /bot1234567890:ABCdefGHIjklMNOpqrSTUvwxYZ0123456789/getMe`)
e := &ParseError{Err: inner, Body: []byte("garbage")}
require.NotContains(t, e.Error(), "ABCdefGHI")
require.Contains(t, e.Error(), "<REDACTED>")
}
+27
View File
@@ -0,0 +1,27 @@
package client
// Result is the universal Telegram API response envelope. Every successful
// response is shaped {"ok":true,"result":T,...}; failure responses set ok
// to false and populate ErrorCode / Description / Parameters.
//
// Result is generic over T so generated method wrappers can decode the
// strongly-typed payload directly. Users do not normally construct or
// inspect Result values; method wrappers unwrap them and return either
// the typed payload or a *APIError.
type Result[T any] struct {
OK bool `json:"ok"`
Result T `json:"result,omitempty"`
ErrorCode int `json:"error_code,omitempty"`
Description string `json:"description,omitempty"`
Parameters *ResponseParameters `json:"parameters,omitempty"`
}
// ResponseParameters is the optional metadata Telegram includes on certain
// failures. The most common is RetryAfter (seconds) on 429 responses.
//
// This type is duplicated in package api for users; keeping a copy here
// avoids an import cycle (api imports client, not vice versa).
type ResponseParameters struct {
MigrateToChatID int64 `json:"migrate_to_chat_id,omitempty"`
RetryAfter int `json:"retry_after,omitempty"`
}
+225
View File
@@ -0,0 +1,225 @@
package client
import (
"bytes"
"context"
crand "crypto/rand"
"encoding/binary"
"github.com/goccy/go-json"
"io"
"math"
"net/http"
"time"
)
// RetryDoer is an HTTPDoer that retries transient failures (429, 5xx,
// and network errors) with exponential backoff. It honours the
// retry_after value Telegram supplies on rate-limit responses.
//
// Wrap any HTTPDoer to add retry behaviour:
//
// bot := client.New(token, client.WithHTTPClient(
// client.NewRetryDoer(client.NewDefaultHTTPDoer())))
type RetryDoer struct {
inner HTTPDoer
maxAttempts int
base time.Duration
max time.Duration
factor float64
jitter float64
}
// RetryOption configures a RetryDoer.
type RetryOption func(*RetryDoer)
// WithMaxAttempts sets the maximum number of attempts (including the
// initial one). Default 4 (one initial + three retries).
func WithMaxAttempts(n int) RetryOption {
return func(d *RetryDoer) { d.maxAttempts = n }
}
// WithBaseBackoff sets the initial backoff duration. Default 500ms.
func WithBaseBackoff(d time.Duration) RetryOption {
return func(r *RetryDoer) { r.base = d }
}
// WithMaxBackoff caps the backoff at max. Default 30s.
func WithMaxBackoff(d time.Duration) RetryOption {
return func(r *RetryDoer) { r.max = d }
}
// WithBackoffFactor sets the exponential growth factor. Default 2.0.
func WithBackoffFactor(f float64) RetryOption {
return func(r *RetryDoer) { r.factor = f }
}
// WithJitter sets the jitter fraction (0..1) applied to each backoff.
// Default 0.2.
func WithJitter(j float64) RetryOption {
return func(r *RetryDoer) { r.jitter = j }
}
// NewRetryDoer wraps inner with retry behaviour.
func NewRetryDoer(inner HTTPDoer, opts ...RetryOption) *RetryDoer {
d := &RetryDoer{
inner: inner,
maxAttempts: 4,
base: 500 * time.Millisecond,
max: 30 * time.Second,
factor: 2.0,
jitter: 0.2,
}
for _, o := range opts {
o(d)
}
return d
}
// Do dispatches via the inner HTTPDoer and retries on transient failures.
// The request body is buffered on first attempt so it can be replayed.
func (d *RetryDoer) Do(req *http.Request) (*http.Response, error) {
// Buffer the body so we can replay it across attempts.
var body []byte
if req.Body != nil {
b, err := io.ReadAll(req.Body)
if err != nil {
return nil, &NetworkError{Err: err}
}
_ = req.Body.Close()
body = b
}
var lastResp *http.Response
var lastErr error
for attempt := 1; attempt <= d.maxAttempts; attempt++ {
if body != nil {
req.Body = io.NopCloser(bytes.NewReader(body))
}
resp, err := d.inner.Do(req)
// Network errors: maybe retry.
if err != nil {
// Honour ctx cancellation.
if ctxErr := req.Context().Err(); ctxErr != nil {
return nil, ctxErr
}
lastErr = err
if attempt < d.maxAttempts {
if !d.sleep(req.Context(), d.delay(attempt, 0)) {
return nil, req.Context().Err()
}
continue
}
return nil, err
}
// HTTP 200: Telegram almost always returns 200 even for errors.
// Peek the body to detect retryable Telegram error payloads.
if resp.StatusCode == http.StatusOK {
data, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, &NetworkError{Err: readErr}
}
// Re-attach the buffered body for the caller.
resp.Body = io.NopCloser(bytes.NewReader(data))
if isRetryablePayload(data) && attempt < d.maxAttempts {
lastResp = resp
wait := retryAfterFromPayload(data)
if !d.sleep(req.Context(), d.delay(attempt, wait)) {
return nil, req.Context().Err()
}
continue
}
return resp, nil
}
// Non-200 status (rare with Telegram; usually 200 + ok:false).
// Treat 5xx and 429 as retryable.
if (resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode >= http.StatusInternalServerError) && attempt < d.maxAttempts {
_ = resp.Body.Close()
lastResp = resp
if !d.sleep(req.Context(), d.delay(attempt, 0)) {
return nil, req.Context().Err()
}
continue
}
return resp, nil
}
if lastErr != nil {
return nil, lastErr
}
return lastResp, nil
}
// delay computes the wait duration for the given attempt (1-based).
// override, when non-zero, takes precedence (used to honour Telegram's
// retry_after value).
func (d *RetryDoer) delay(attempt int, override time.Duration) time.Duration {
if override > 0 {
return override
}
delay := float64(d.base) * math.Pow(d.factor, float64(attempt-1))
if d.jitter > 0 {
var b [8]byte
_, _ = crand.Read(b[:])
f := float64(binary.LittleEndian.Uint64(b[:])) / (1 << 64)
delay *= 1 + (f*2-1)*d.jitter
}
if delay > float64(d.max) {
delay = float64(d.max)
}
if delay < 0 {
delay = 0
}
return time.Duration(delay)
}
// sleep waits for dur or ctx cancellation. Returns false if cancelled.
func (d *RetryDoer) sleep(ctx context.Context, dur time.Duration) bool {
if dur <= 0 {
return true
}
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
return true
case <-ctx.Done():
return false
}
}
// isRetryablePayload reports whether body is a Telegram error response
// indicating a retryable failure (429 or 5xx error_code).
func isRetryablePayload(body []byte) bool {
var env struct {
OK bool `json:"ok"`
ErrorCode int `json:"error_code"`
}
if err := json.Unmarshal(body, &env); err != nil {
return false
}
if env.OK {
return false
}
return env.ErrorCode == 429 || (env.ErrorCode >= 500 && env.ErrorCode < 600)
}
// retryAfterFromPayload extracts the retry_after value from a Telegram
// error response body and returns it as a duration. Returns 0 if absent.
func retryAfterFromPayload(body []byte) time.Duration {
var env struct {
Parameters struct {
RetryAfter int `json:"retry_after"`
} `json:"parameters"`
}
if err := json.Unmarshal(body, &env); err != nil {
return 0
}
return time.Duration(env.Parameters.RetryAfter) * time.Second
}
+144
View File
@@ -0,0 +1,144 @@
package client
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type retryMockDoer struct{ mock.Mock }
func (m *retryMockDoer) Do(r *http.Request) (*http.Response, error) {
args := m.Called(r)
if v := args.Get(0); v != nil {
return v.(*http.Response), args.Error(1)
}
return nil, args.Error(1)
}
func okResp(body string) *http.Response {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewBufferString(body)),
Header: http.Header{"Content-Type": []string{"application/json"}},
}
}
func TestRetryDoer_HappyPath(t *testing.T) {
m := &retryMockDoer{}
m.On("Do", mock.Anything).Return(okResp(`{"ok":true,"result":"hi"}`), nil).Once()
d := NewRetryDoer(m)
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
resp, err := d.Do(req)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
m.AssertExpectations(t)
}
func TestRetryDoer_RetriesOnNetworkError(t *testing.T) {
m := &retryMockDoer{}
m.On("Do", mock.Anything).Return(nil, errors.New("dial timeout")).Once()
m.On("Do", mock.Anything).Return(okResp(`{"ok":true,"result":"hi"}`), nil).Once()
d := NewRetryDoer(m, WithBaseBackoff(time.Millisecond))
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
resp, err := d.Do(req)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
m.AssertExpectations(t)
}
func TestRetryDoer_HonoursRetryAfter(t *testing.T) {
m := &retryMockDoer{}
m.On("Do", mock.Anything).Return(
okResp(`{"ok":false,"error_code":429,"description":"Too Many","parameters":{"retry_after":1}}`), nil).Once()
m.On("Do", mock.Anything).Return(okResp(`{"ok":true,"result":1}`), nil).Once()
// base is 10s — retry_after=1s should override it (much shorter wait).
d := NewRetryDoer(m, WithBaseBackoff(10*time.Second))
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
start := time.Now()
resp, err := d.Do(req)
elapsed := time.Since(start)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
require.GreaterOrEqual(t, elapsed, 900*time.Millisecond, "should honour retry_after=1s")
require.Less(t, elapsed, 3*time.Second, "should NOT use base backoff (10s)")
m.AssertExpectations(t)
}
func TestRetryDoer_Retries5xx(t *testing.T) {
m := &retryMockDoer{}
m.On("Do", mock.Anything).Return(
okResp(`{"ok":false,"error_code":500,"description":"Internal Server Error"}`), nil).Once()
m.On("Do", mock.Anything).Return(okResp(`{"ok":true,"result":1}`), nil).Once()
d := NewRetryDoer(m, WithBaseBackoff(time.Millisecond))
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
resp, err := d.Do(req)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
m.AssertExpectations(t)
}
func TestRetryDoer_AllAttemptsFail(t *testing.T) {
m := &retryMockDoer{}
m.On("Do", mock.Anything).Return(nil, errors.New("dial timeout"))
d := NewRetryDoer(m, WithMaxAttempts(3), WithBaseBackoff(time.Millisecond))
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
_, err := d.Do(req)
require.Error(t, err)
require.Contains(t, err.Error(), "dial timeout")
require.Equal(t, 3, len(m.Calls))
}
func TestRetryDoer_ContextCancellationAborts(t *testing.T) {
m := &retryMockDoer{}
m.On("Do", mock.Anything).Return(
okResp(`{"ok":false,"error_code":500,"description":"server error"}`), nil).Maybe()
d := NewRetryDoer(m, WithBaseBackoff(100*time.Millisecond))
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
req, _ := http.NewRequestWithContext(ctx, "POST", "http://x", strings.NewReader(`{}`))
_, err := d.Do(req)
require.Error(t, err)
require.True(t, errors.Is(err, context.DeadlineExceeded))
}
func TestRetryDoer_ReplaysBody(t *testing.T) {
m := &retryMockDoer{}
var seen []string
// First call: capture body, return 500 to trigger retry.
m.On("Do", mock.Anything).Return(okResp(`{"ok":false,"error_code":500}`), nil).Once().Run(func(args mock.Arguments) {
r := args.Get(0).(*http.Request)
body, _ := io.ReadAll(r.Body)
seen = append(seen, string(body))
})
// Second call: capture body, return success.
m.On("Do", mock.Anything).Return(okResp(`{"ok":true}`), nil).Once().Run(func(args mock.Arguments) {
r := args.Get(0).(*http.Request)
body, _ := io.ReadAll(r.Body)
seen = append(seen, string(body))
})
d := NewRetryDoer(m, WithBaseBackoff(time.Millisecond))
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{"chat_id":42}`))
_, err := d.Do(req)
require.NoError(t, err)
require.Len(t, seen, 2)
require.Equal(t, seen[0], seen[1])
require.Equal(t, `{"chat_id":42}`, seen[0])
m.AssertExpectations(t)
}
+378
View File
@@ -0,0 +1,378 @@
package main
import (
"encoding/json"
"flag"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"testing"
"github.com/lukaszraczylo/go-telegram/internal/spec"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// loadIR
// ---------------------------------------------------------------------------
func TestLoadIR_ValidFile(t *testing.T) {
api := &spec.API{
Methods: []spec.MethodDecl{
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
},
}
data, err := json.Marshal(api)
require.NoError(t, err)
tmp := filepath.Join(t.TempDir(), "api.json")
require.NoError(t, os.WriteFile(tmp, data, 0o600))
loaded, err := loadIR(tmp)
require.NoError(t, err)
require.Len(t, loaded.Methods, 1)
require.Equal(t, "getMe", loaded.Methods[0].Name)
}
func TestLoadIR_MissingFile(t *testing.T) {
_, err := loadIR("/nonexistent/path/api.json")
require.Error(t, err)
require.Contains(t, err.Error(), "open IR")
}
func TestLoadIR_InvalidJSON(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "bad.json")
require.NoError(t, os.WriteFile(tmp, []byte("not json"), 0o600))
_, err := loadIR(tmp)
require.Error(t, err)
require.Contains(t, err.Error(), "decode IR")
}
// ---------------------------------------------------------------------------
// auditBool
// ---------------------------------------------------------------------------
func TestAuditBool_LongDocTruncated(t *testing.T) {
longDoc := make([]byte, 200)
for i := range longDoc {
longDoc[i] = 'a'
}
api := &spec.API{
Methods: []spec.MethodDecl{
{Name: "myMethod", Doc: string(longDoc), Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
},
}
problems := auditBool(api, &spec.Overrides{})
require.Len(t, problems, 1)
require.Contains(t, problems[0], "…")
}
func TestAuditBool_TrueIsReturnedVariant(t *testing.T) {
api := &spec.API{
Methods: []spec.MethodDecl{
{Name: "doThing", Doc: "true is returned on success.", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
},
}
require.Empty(t, auditBool(api, &spec.Overrides{}))
}
func TestAuditBool_ReturnsBoolean(t *testing.T) {
api := &spec.API{
Methods: []spec.MethodDecl{
{Name: "doThing", Doc: "Returns Boolean on success.", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
},
}
require.Empty(t, auditBool(api, &spec.Overrides{}))
}
// ---------------------------------------------------------------------------
// formatTypeRef
// ---------------------------------------------------------------------------
func TestFormatTypeRef_AllBranches(t *testing.T) {
cases := []struct {
tr spec.TypeRef
want string
}{
{spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}, "bool"},
{spec.TypeRef{Kind: spec.KindNamed, Name: "User"}, "User"},
{spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}}, "[]Update"},
{spec.TypeRef{Kind: spec.KindArray}, "[]any"},
{spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64", "string"}}, "(int64 | string)"},
{spec.TypeRef{Kind: spec.Kind(99)}, "?"},
}
for _, c := range cases {
got := formatTypeRef(c.tr)
require.Equal(t, c.want, got, "for kind=%v name=%v", c.tr.Kind, c.tr.Name)
}
}
// ---------------------------------------------------------------------------
// auditDrift
// ---------------------------------------------------------------------------
func TestAuditDrift_InvalidRefReturnsError(t *testing.T) {
cur := &spec.API{
Methods: []spec.MethodDecl{
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
},
}
_, err := auditDrift("internal/spec/api.json", "THIS_REF_DOES_NOT_EXIST", cur)
require.Error(t, err)
}
func TestAuditDrift_SameRefNoDrift(t *testing.T) {
irPath := "../../internal/spec/api.json"
cur, err := loadIR(irPath)
if err != nil {
t.Skip("api.json not available, skipping drift test")
}
changes, err := auditDrift(irPath, "HEAD", cur)
require.NoError(t, err)
require.Empty(t, changes)
}
func TestAuditDrift_InvalidJSONFromGit(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("shell script not supported on Windows")
}
tmp := t.TempDir()
fakeGit := filepath.Join(tmp, "git")
require.NoError(t, os.WriteFile(fakeGit, []byte("#!/bin/sh\necho 'not valid json'\n"), 0o600))
require.NoError(t, os.Chmod(fakeGit, 0o755))
origPATH := os.Getenv("PATH")
t.Cleanup(func() { _ = os.Setenv("PATH", origPATH) })
_ = os.Setenv("PATH", tmp+string(os.PathListSeparator)+origPATH)
_, err := auditDrift("internal/spec/api.json", "HEAD", &spec.API{})
require.Error(t, err)
require.Contains(t, err.Error(), "decode")
}
// ---------------------------------------------------------------------------
// auditAny
// ---------------------------------------------------------------------------
func TestAuditAny_FlagsUnknownMethodReturn(t *testing.T) {
api := &spec.API{
Methods: []spec.MethodDecl{
{
Name: "weirdMethod",
Returns: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"A", "B", "C"}},
},
},
}
out := auditAny(api)
require.Len(t, out, 1)
require.Contains(t, out[0], "any return: weirdMethod")
}
func TestAuditAny_FlagsUnknownMethodParam(t *testing.T) {
api := &spec.API{
Methods: []spec.MethodDecl{
{
Name: "weirdMethod",
Params: []spec.Field{
{Name: "Thing", JSONName: "thing", Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"X", "Y", "Z"}}},
},
},
},
}
out := auditAny(api)
require.Len(t, out, 1)
require.Contains(t, out[0], "any param: weirdMethod.Thing")
}
// ---------------------------------------------------------------------------
// diffSignatures
// ---------------------------------------------------------------------------
func TestDiffSignatures_UnchangedNoDrift(t *testing.T) {
prev := &spec.API{
Methods: []spec.MethodDecl{
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
},
}
cur := &spec.API{
Methods: []spec.MethodDecl{
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
},
}
require.Empty(t, diffSignatures(prev, cur))
}
// ---------------------------------------------------------------------------
// typeRefEqual
// ---------------------------------------------------------------------------
func TestTypeRefEqual_ArrayNilElemDiffers(t *testing.T) {
a := spec.TypeRef{Kind: spec.KindArray}
b := spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}}
require.False(t, typeRefEqual(a, b))
require.False(t, typeRefEqual(b, a))
}
// ---------------------------------------------------------------------------
// TestHelperMain: subprocess helper for main() coverage.
//
// When AUDIT_HELPER_MAIN=1 is set, this function:
// 1. Resets flag.CommandLine so main()'s flag.Parse() gets a clean slate.
// 2. Sets os.Args to the args encoded in AUDIT_HELPER_ARGS.
// 3. Calls main() which calls os.Exit — the test process terminates with
// main's exit code, which the parent test captures.
// ---------------------------------------------------------------------------
func TestHelperMain(t *testing.T) {
if os.Getenv("AUDIT_HELPER_MAIN") != "1" {
t.Skip("not running as subprocess helper")
}
// Reset flag.CommandLine so main()'s flag.Parse() gets a clean slate.
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
// Decode args from JSON-encoded env var.
encoded := os.Getenv("AUDIT_HELPER_ARGS")
if encoded != "" {
var args []string
if err := json.Unmarshal([]byte(encoded), &args); err == nil {
os.Args = append([]string{os.Args[0]}, args...)
}
} else {
os.Args = os.Args[:1]
}
main()
}
// runMain runs main() via the test binary subprocess so that coverage counters
// from main() are included in the profile. Args are JSON-encoded in an env var
// to avoid conflicts with the test binary's own flag parsing.
func runMain(t *testing.T, extraEnv []string, args ...string) (string, int) {
t.Helper()
argsJSON, _ := json.Marshal(args)
cmd := exec.Command(os.Args[0], "-test.run=TestHelperMain", "-test.v=false")
cmd.Env = append(os.Environ(), "AUDIT_HELPER_MAIN=1", "AUDIT_HELPER_ARGS="+string(argsJSON))
cmd.Env = append(cmd.Env, extraEnv...)
out, err := cmd.CombinedOutput()
code := 0
if err != nil {
if ee, ok := err.(*exec.ExitError); ok {
code = ee.ExitCode()
}
}
return string(out), code
}
// ---------------------------------------------------------------------------
// main() integration tests — exercise main() code paths via subprocess
// ---------------------------------------------------------------------------
func TestMain_CleanExitsZero(t *testing.T) {
tmp := t.TempDir()
ir := writeIR(t, tmp, &spec.API{
Methods: []spec.MethodDecl{
{Name: "getMe", Doc: "Returns True on success.", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
},
})
ov := writeOverrides(t, tmp)
out, code := runMain(t, nil, "-ir", ir, "-overrides", ov)
require.Equal(t, exitClean, code, "expected exit 0 (clean)\nout: %s", out)
require.Contains(t, out, "clean")
}
func TestMain_FallbackExitsOne(t *testing.T) {
tmp := t.TempDir()
ir := writeIR(t, tmp, &spec.API{
Methods: []spec.MethodDecl{
{Name: "doSomething", Doc: "Does something.", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
},
})
ov := writeOverrides(t, tmp)
out, code := runMain(t, nil, "-ir", ir, "-overrides", ov)
require.Equal(t, exitFallback, code, "expected exit 1 (fallback)\nout: %s", out)
require.Contains(t, out, "bool fallback")
}
func TestMain_InvalidIRExitsThree(t *testing.T) {
tmp := t.TempDir()
bad := filepath.Join(tmp, "bad.json")
require.NoError(t, os.WriteFile(bad, []byte("not json"), 0o600))
ov := writeOverrides(t, tmp)
out, code := runMain(t, nil, "-ir", bad, "-overrides", ov)
require.Equal(t, exitInvalid, code, "expected exit 3 (invalid IR)\nout: %s", out)
}
func TestMain_InvalidOverridesExitsThree(t *testing.T) {
tmp := t.TempDir()
ir := writeIR(t, tmp, &spec.API{})
bad := filepath.Join(tmp, "bad_ov.json")
require.NoError(t, os.WriteFile(bad, []byte("not json"), 0o600))
out, code := runMain(t, nil, "-ir", ir, "-overrides", bad)
require.Equal(t, exitInvalid, code, "expected exit 3 (invalid overrides)\nout: %s", out)
}
func TestMain_DriftDetectedExitsTwo(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("shell script not supported on Windows")
}
tmp := t.TempDir()
prevAPI := &spec.API{
Methods: []spec.MethodDecl{
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
},
}
curAPI := &spec.API{
Methods: []spec.MethodDecl{
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}},
},
}
curIR := writeIR(t, tmp, curAPI)
ov := writeOverrides(t, tmp)
prevData, err := json.Marshal(prevAPI)
require.NoError(t, err)
prevFile := filepath.Join(tmp, "prev.json")
require.NoError(t, os.WriteFile(prevFile, prevData, 0o600))
fakeGit := filepath.Join(tmp, "git")
script := fmt.Sprintf("#!/bin/sh\ncat %s\n", prevFile)
require.NoError(t, os.WriteFile(fakeGit, []byte(script), 0o600))
require.NoError(t, os.Chmod(fakeGit, 0o755))
newPATH := tmp + string(os.PathListSeparator) + os.Getenv("PATH")
out, code := runMain(t,
[]string{"PATH=" + newPATH},
"-ir", curIR, "-overrides", ov, "-drift", "-against", "HEAD~1",
)
require.Equal(t, exitDrift, code, "expected exit 2 (drift)\nout: %s", out)
}
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
func writeIR(t *testing.T, dir string, api *spec.API) string {
t.Helper()
data, err := json.Marshal(api)
require.NoError(t, err)
p := filepath.Join(dir, "api.json")
require.NoError(t, os.WriteFile(p, data, 0o600))
return p
}
func writeOverrides(t *testing.T, dir string) string {
t.Helper()
p := filepath.Join(dir, "overrides.json")
require.NoError(t, os.WriteFile(p, []byte("{}"), 0o600))
return p
}
+291
View File
@@ -0,0 +1,291 @@
// Command audit reports IR-level codegen fallbacks and signature drift.
//
// Usage:
//
// audit -ir <path> (default internal/spec/api.json)
// audit -overrides <path> (default internal/spec/overrides.json)
// audit -drift (compare against -against ref's IR; off by default)
// audit -against <ref> (git ref to diff drift against; default HEAD~1)
//
// Exit codes:
//
// 0 — clean
// 1 — unaccounted bool fallbacks or any-typed fields
// 2 — drift detected (signature changed)
// 3 — invalid IR or overrides
package main
import (
"flag"
"fmt"
"github.com/goccy/go-json"
"os"
"os/exec"
"strings"
"github.com/lukaszraczylo/go-telegram/internal/spec"
)
const (
exitClean = 0
exitFallback = 1
exitDrift = 2
exitInvalid = 3
)
func main() {
irPath := flag.String("ir", "internal/spec/api.json", "path to IR JSON")
ovPath := flag.String("overrides", "internal/spec/overrides.json", "path to overrides JSON")
checkDrift := flag.Bool("drift", false, "compare against -against ref's IR for signature changes")
againstRef := flag.String("against", "HEAD~1", "git ref to diff drift against (e.g. origin/main, HEAD~1)")
flag.Parse()
api, err := loadIR(*irPath)
if err != nil {
fmt.Fprintln(os.Stderr, "audit:", err)
os.Exit(exitInvalid)
}
overrides, err := spec.LoadOverrides(*ovPath)
if err != nil {
fmt.Fprintln(os.Stderr, "audit:", err)
os.Exit(exitInvalid)
}
var problems []string
problems = append(problems, auditBool(api, overrides)...)
problems = append(problems, auditAny(api)...)
driftFound := false
if *checkDrift {
if d, err := auditDrift(*irPath, *againstRef, api); err != nil {
fmt.Fprintln(os.Stderr, "audit: drift check skipped:", err)
} else if len(d) > 0 {
fmt.Println("Drift detected (signatures changed since HEAD):")
for _, p := range d {
fmt.Println(" ", p)
}
driftFound = true
}
}
if len(problems) == 0 && !driftFound {
fmt.Println("audit: clean")
os.Exit(exitClean)
}
if len(problems) > 0 {
fmt.Println("Codegen fallbacks requiring action:")
for _, p := range problems {
fmt.Println(" ", p)
}
fmt.Println()
fmt.Println("To resolve: extend cmd/scrape/method.go regex patterns OR")
fmt.Println("add an entry to internal/spec/overrides.json.")
os.Exit(exitFallback)
}
// drift only, no fallbacks
os.Exit(exitDrift)
}
func loadIR(path string) (*spec.API, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open IR: %w", err)
}
defer func() { _ = f.Close() }()
var api spec.API
if err := json.NewDecoder(f).Decode(&api); err != nil {
return nil, fmt.Errorf("decode IR: %w", err)
}
return &api, nil
}
// auditBool returns problems for methods returning bool whose docs don't
// actually say "Returns True" / etc. and which aren't in the approved list.
func auditBool(api *spec.API, ov *spec.Overrides) []string {
var out []string
for _, m := range api.Methods {
if m.Returns.Kind != spec.KindPrimitive || m.Returns.Name != "bool" {
continue
}
if ov.IsBoolApproved(m.Name) {
continue
}
if looksGenuinelyBool(m.Doc) {
continue
}
snippet := m.Doc
if len(snippet) > 120 {
snippet = snippet[:120] + "…"
}
out = append(out, fmt.Sprintf("bool fallback: %s — doc: %q", m.Name, snippet))
}
return out
}
func looksGenuinelyBool(doc string) bool {
for _, p := range []string{
"Returns True", "Returns true",
"True is returned", "true is returned",
"Returns Boolean", "Returns Bool",
} {
if strings.Contains(doc, p) {
return true
}
}
return false
}
// auditAny scans the IR for any KindOneOf TypeRef that would render as
// `any` in generated code (not matched by ChatID/InputFile-or-string/known
// union heuristics). Reports each occurrence with location.
func auditAny(api *spec.API) []string {
var out []string
isKnownUnion := func(variants []string) bool {
if hasVariants(variants, "int64", "string") {
return true // ChatID
}
if hasVariants(variants, "InputFile", "string") {
return true // *InputFile
}
// ReplyMarkup union: all four keyboard types — emitter renders as `any` intentionally
if hasVariants(variants, "InlineKeyboardMarkup", "ReplyKeyboardMarkup", "ReplyKeyboardRemove", "ForceReply") {
return true
}
for _, t := range api.Types {
if len(t.OneOf) > 0 && sameSet(variants, t.OneOf) {
return true
}
}
return false
}
isAny := func(tr spec.TypeRef) bool {
return tr.Kind == spec.KindOneOf && !isKnownUnion(tr.Variants)
}
for _, t := range api.Types {
for _, f := range t.Fields {
if isAny(f.Type) {
out = append(out, fmt.Sprintf("any field: %s.%s (variants=%v)", t.Name, f.Name, f.Type.Variants))
}
}
}
for _, m := range api.Methods {
if isAny(m.Returns) {
out = append(out, fmt.Sprintf("any return: %s (variants=%v)", m.Name, m.Returns.Variants))
}
for _, p := range m.Params {
if isAny(p.Type) {
out = append(out, fmt.Sprintf("any param: %s.%s (variants=%v)", m.Name, p.Name, p.Type.Variants))
}
}
}
return out
}
func hasVariants(got []string, want ...string) bool {
if len(got) != len(want) {
return false
}
seen := map[string]int{}
for _, g := range got {
seen[g]++
}
for _, w := range want {
seen[w]--
}
for _, v := range seen {
if v != 0 {
return false
}
}
return true
}
func sameSet(a, b []string) bool {
if len(a) != len(b) {
return false
}
return hasVariants(a, b...)
}
// auditDrift compares method/return signatures between the given git ref's
// version of irPath and the in-memory current IR. Returns a list of
// human-readable change descriptions.
func auditDrift(irPath, againstRef string, current *spec.API) ([]string, error) {
cmd := exec.Command("git", "show", againstRef+":"+irPath) // #nosec G204 - operator tool, ref controlled by caller
out, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("git show %s: %w", againstRef, err)
}
var prev spec.API
if err := json.Unmarshal(out, &prev); err != nil {
return nil, fmt.Errorf("decode %s IR: %w", againstRef, err)
}
return diffSignatures(&prev, current), nil
}
func diffSignatures(prev, cur *spec.API) []string {
var changes []string
pmeth := indexByName(prev.Methods, func(m spec.MethodDecl) string { return m.Name })
cmeth := indexByName(cur.Methods, func(m spec.MethodDecl) string { return m.Name })
for name, p := range pmeth {
c, ok := cmeth[name]
if !ok {
changes = append(changes, fmt.Sprintf("removed method: %s", name))
continue
}
if !typeRefEqual(p.Returns, c.Returns) {
changes = append(changes, fmt.Sprintf(
"method %s return changed: %s → %s",
name, formatTypeRef(p.Returns), formatTypeRef(c.Returns)))
}
}
for name := range cmeth {
if _, ok := pmeth[name]; !ok {
changes = append(changes, fmt.Sprintf("added method: %s", name))
}
}
return changes
}
func indexByName[T any](xs []T, f func(T) string) map[string]T {
out := map[string]T{}
for _, x := range xs {
out[f(x)] = x
}
return out
}
func typeRefEqual(a, b spec.TypeRef) bool {
if a.Kind != b.Kind || a.Name != b.Name {
return false
}
if (a.ElemType == nil) != (b.ElemType == nil) {
return false
}
if a.ElemType != nil && !typeRefEqual(*a.ElemType, *b.ElemType) {
return false
}
return sameSet(a.Variants, b.Variants)
}
func formatTypeRef(t spec.TypeRef) string {
switch t.Kind {
case spec.KindPrimitive:
return t.Name
case spec.KindNamed:
return t.Name
case spec.KindArray:
if t.ElemType != nil {
return "[]" + formatTypeRef(*t.ElemType)
}
return "[]any"
case spec.KindOneOf:
return "(" + strings.Join(t.Variants, " | ") + ")"
}
return "?"
}
+216
View File
@@ -0,0 +1,216 @@
package main
import (
"testing"
"github.com/lukaszraczylo/go-telegram/internal/spec"
"github.com/stretchr/testify/require"
)
// ---- auditBool -----------------------------------------------------------
func TestAuditBool_FlagsUnapprovedBoolMethod(t *testing.T) {
api := &spec.API{
Methods: []spec.MethodDecl{
{Name: "getMe", Doc: "A simple method.", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
},
}
ov := &spec.Overrides{}
problems := auditBool(api, ov)
require.Len(t, problems, 1)
require.Contains(t, problems[0], "bool fallback: getMe")
}
func TestAuditBool_SkipsApprovedBoolMethod(t *testing.T) {
api := &spec.API{
Methods: []spec.MethodDecl{
{Name: "setWebhook", Doc: "Use this to set webhook.", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
},
}
ov := &spec.Overrides{ApprovedBoolMethods: []string{"setWebhook"}}
problems := auditBool(api, ov)
require.Empty(t, problems)
}
func TestAuditBool_SkipsMethodWithReturnsTrueDoc(t *testing.T) {
api := &spec.API{
Methods: []spec.MethodDecl{
{Name: "doThing", Doc: "Returns True on success.", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
},
}
ov := &spec.Overrides{}
problems := auditBool(api, ov)
require.Empty(t, problems)
}
func TestAuditBool_SkipsNonBoolMethods(t *testing.T) {
api := &spec.API{
Methods: []spec.MethodDecl{
{Name: "getMe", Doc: "Gets user.", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
},
}
ov := &spec.Overrides{}
require.Empty(t, auditBool(api, ov))
}
// ---- auditAny ------------------------------------------------------------
func TestAuditAny_FlagsUnrecognisedOneOf(t *testing.T) {
api := &spec.API{
Types: []spec.TypeDecl{
{
Name: "Foo",
Fields: []spec.Field{
{Name: "Bar", JSONName: "bar", Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"A", "B", "C"}}},
},
},
},
}
out := auditAny(api)
require.Len(t, out, 1)
require.Contains(t, out[0], "any field: Foo.Bar")
}
func TestAuditAny_SkipsChatIDShape(t *testing.T) {
api := &spec.API{
Types: []spec.TypeDecl{
{
Name: "SendMessage",
Fields: []spec.Field{
{Name: "ChatID", JSONName: "chat_id", Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64", "string"}}},
},
},
},
}
require.Empty(t, auditAny(api))
}
func TestAuditAny_SkipsKnownUnion(t *testing.T) {
api := &spec.API{
Types: []spec.TypeDecl{
{Name: "InputMedia", OneOf: []string{"InputMediaPhoto", "InputMediaVideo"}},
{
Name: "SomeMethod",
Fields: []spec.Field{
{Name: "Media", JSONName: "media", Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputMediaPhoto", "InputMediaVideo"}}},
},
},
},
}
require.Empty(t, auditAny(api))
}
func TestAuditAny_SkipsReplyMarkupShape(t *testing.T) {
api := &spec.API{
Methods: []spec.MethodDecl{
{
Name: "sendMessage",
Params: []spec.Field{
{Name: "ReplyMarkup", JSONName: "reply_markup", Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InlineKeyboardMarkup", "ReplyKeyboardMarkup", "ReplyKeyboardRemove", "ForceReply"}}},
},
},
},
}
require.Empty(t, auditAny(api))
}
func TestAuditAny_SkipsInputFileShape(t *testing.T) {
api := &spec.API{
Methods: []spec.MethodDecl{
{
Name: "sendPhoto",
Params: []spec.Field{
{Name: "Photo", JSONName: "photo", Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputFile", "string"}}},
},
},
},
}
require.Empty(t, auditAny(api))
}
// ---- diffSignatures ------------------------------------------------------
func TestDiffSignatures_AddedMethod(t *testing.T) {
prev := &spec.API{}
cur := &spec.API{
Methods: []spec.MethodDecl{
{Name: "newMethod", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
},
}
changes := diffSignatures(prev, cur)
require.Len(t, changes, 1)
require.Contains(t, changes[0], "added method: newMethod")
}
func TestDiffSignatures_RemovedMethod(t *testing.T) {
prev := &spec.API{
Methods: []spec.MethodDecl{
{Name: "oldMethod", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
},
}
cur := &spec.API{}
changes := diffSignatures(prev, cur)
require.Len(t, changes, 1)
require.Contains(t, changes[0], "removed method: oldMethod")
}
func TestDiffSignatures_ChangedReturn(t *testing.T) {
prev := &spec.API{
Methods: []spec.MethodDecl{
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
},
}
cur := &spec.API{
Methods: []spec.MethodDecl{
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
},
}
changes := diffSignatures(prev, cur)
require.Len(t, changes, 1)
require.Contains(t, changes[0], "getMe")
require.Contains(t, changes[0], "bool")
require.Contains(t, changes[0], "User")
}
func TestDiffSignatures_Clean(t *testing.T) {
api := &spec.API{
Methods: []spec.MethodDecl{
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
},
}
require.Empty(t, diffSignatures(api, api))
}
// ---- typeRefEqual --------------------------------------------------------
func TestTypeRefEqual_Primitive(t *testing.T) {
a := spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}
require.True(t, typeRefEqual(a, a))
b := spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}
require.False(t, typeRefEqual(a, b))
}
func TestTypeRefEqual_Array(t *testing.T) {
elem := &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}
a := spec.TypeRef{Kind: spec.KindArray, ElemType: elem}
b := spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}}
require.True(t, typeRefEqual(a, b))
c := spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}}
require.False(t, typeRefEqual(a, c))
}
func TestTypeRefEqual_OneOf(t *testing.T) {
a := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64", "string"}}
b := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"string", "int64"}}
require.True(t, typeRefEqual(a, b))
c := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64"}}
require.False(t, typeRefEqual(a, c))
}
func TestTypeRefEqual_NilVsNonNilElem(t *testing.T) {
a := spec.TypeRef{Kind: spec.KindArray}
b := spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}}
require.False(t, typeRefEqual(a, b))
}
+749
View File
@@ -0,0 +1,749 @@
package main
import (
"bytes"
_ "embed"
"fmt"
"github.com/goccy/go-json"
"go/format"
"os"
"path/filepath"
"sort"
"text/template"
"github.com/lukaszraczylo/go-telegram/internal/spec"
)
//go:embed types.tmpl
var typesTmpl string
//go:embed methods.tmpl
var methodsTmpl string
//go:embed enums.tmpl
var enumsTmpl string
//go:embed tests.tmpl
var testsTmpl string
// runtimeTypes lists types that are intentionally hand-coded and must not be
// emitted by the code generator. Skipping them prevents collisions between
// generated and hand-coded definitions.
var runtimeTypes = map[string]bool{
"InputFile": true,
"ResponseParameters": true,
"ChatID": true,
"MessageOrBool": true,
}
// discriminatorSpec describes how to decode a sealed-interface union by
// peeking at a single JSON field.
type discriminatorSpec struct {
Field string // JSON field name to peek at
Variants map[string]string // discriminator value → concrete Go type name
}
// knownDiscriminators maps parent union name → discriminator spec.
// Used by the template helpers hasDiscriminator / discriminatorField /
// discriminatorMap to emit UnmarshalXxx helpers.
var knownDiscriminators = map[string]discriminatorSpec{
"ChatMember": {
Field: "status",
Variants: map[string]string{
"creator": "ChatMemberOwner",
"administrator": "ChatMemberAdministrator",
"member": "ChatMemberMember",
"restricted": "ChatMemberRestricted",
"left": "ChatMemberLeft",
"kicked": "ChatMemberBanned",
},
},
"MessageOrigin": {
Field: "type",
Variants: map[string]string{
"user": "MessageOriginUser",
"hidden_user": "MessageOriginHiddenUser",
"chat": "MessageOriginChat",
"channel": "MessageOriginChannel",
},
},
"ReactionType": {
Field: "type",
Variants: map[string]string{
"emoji": "ReactionTypeEmoji",
"custom_emoji": "ReactionTypeCustomEmoji",
"paid": "ReactionTypePaid",
},
},
"PaidMedia": {
Field: "type",
Variants: map[string]string{
"preview": "PaidMediaPreview",
"photo": "PaidMediaPhoto",
"video": "PaidMediaVideo",
},
},
"BackgroundType": {
Field: "type",
Variants: map[string]string{
"fill": "BackgroundTypeFill",
"wallpaper": "BackgroundTypeWallpaper",
"pattern": "BackgroundTypePattern",
"chat_theme": "BackgroundTypeChatTheme",
},
},
"BackgroundFill": {
Field: "type",
Variants: map[string]string{
"solid": "BackgroundFillSolid",
"gradient": "BackgroundFillGradient",
"freeform_gradient": "BackgroundFillFreeformGradient",
},
},
"ChatBoostSource": {
Field: "source",
Variants: map[string]string{
"premium": "ChatBoostSourcePremium",
"gift_code": "ChatBoostSourceGiftCode",
"giveaway": "ChatBoostSourceGiveaway",
},
},
"RevenueWithdrawalState": {
Field: "type",
Variants: map[string]string{
"pending": "RevenueWithdrawalStatePending",
"succeeded": "RevenueWithdrawalStateSucceeded",
"failed": "RevenueWithdrawalStateFailed",
},
},
"TransactionPartner": {
Field: "type",
Variants: map[string]string{
"fragment": "TransactionPartnerFragment",
"user": "TransactionPartnerUser",
"telegram_ads": "TransactionPartnerTelegramAds",
"telegram_api": "TransactionPartnerTelegramApi",
"other": "TransactionPartnerOther",
},
},
"MenuButton": {
Field: "type",
Variants: map[string]string{
"commands": "MenuButtonCommands",
"web_app": "MenuButtonWebApp",
"default": "MenuButtonDefault",
},
},
"OwnedGift": {
Field: "type",
Variants: map[string]string{
"regular": "OwnedGiftRegular",
"unique": "OwnedGiftUnique",
},
},
"StoryAreaType": {
Field: "type",
Variants: map[string]string{
"location": "StoryAreaTypeLocation",
"suggested_reaction": "StoryAreaTypeSuggestedReaction",
"link": "StoryAreaTypeLink",
"weather": "StoryAreaTypeWeather",
"unique_gift": "StoryAreaTypeUniqueGift",
},
},
// MaybeInaccessibleMessage uses an integer discriminator (date field).
// Variants is nil — the standard template block is skipped; a
// hand-coded UnmarshalMaybeInaccessibleMessage is emitted instead.
"MaybeInaccessibleMessage": {
Field: "",
Variants: nil,
},
}
// emitter renders Go source from a spec.API IR.
type emitter struct {
api *spec.API
outDir string
}
func newEmitter(api *spec.API, outDir string) *emitter {
return &emitter{api: api, outDir: outDir}
}
// emitTypes renders types.gen.go.
func (e *emitter) emitTypes() error {
t, err := template.New("types").Funcs(funcs()).Parse(typesTmpl)
if err != nil {
return fmt.Errorf("parse types.tmpl: %w", err)
}
filtered := *e.api
filtered.Types = nil
for _, typ := range e.api.Types {
if !runtimeTypes[typ.Name] {
filtered.Types = append(filtered.Types, typ)
}
}
var buf bytes.Buffer
if execErr := t.Execute(&buf, &filtered); execErr != nil {
return fmt.Errorf("execute types.tmpl: %w", execErr)
}
src, err := format.Source(buf.Bytes())
if err != nil {
// Surface the unformatted output so debugging is possible.
return fmt.Errorf("gofmt types.gen.go: %w\n--- unformatted ---\n%s", err, buf.String())
}
return os.WriteFile(filepath.Join(e.outDir, "types.gen.go"), src, 0o600)
}
// loadAPI reads and decodes the IR JSON.
func loadAPI(path string) (*spec.API, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var api spec.API
if err := json.Unmarshal(data, &api); err != nil {
return nil, err
}
return &api, nil
}
// funcs is the FuncMap shared across templates.
func funcs() template.FuncMap {
return template.FuncMap{
"goType": goType,
"goField": goField,
"docComment": docComment,
"isOptional": func(f spec.Field) bool { return !f.Required },
"not": func(b bool) bool { return !b },
"title": title,
"isFileField": isFileField,
"fileCheck": fileCheck,
"multipartFieldEntry": multipartFieldEntry,
"multipartFileEntry": multipartFileEntry,
"returnGoType": returnGoType,
// discriminator helpers for types.tmpl
"hasDiscriminator": func(name string) bool { s, ok := knownDiscriminators[name]; return ok && len(s.Variants) > 0 },
"isSealedUnionReturn": func(tr spec.TypeRef) bool {
if tr.Kind != spec.KindNamed {
return false
}
s, ok := knownDiscriminators[tr.Name]
return ok && len(s.Variants) > 0
},
"isMaybeInaccessibleMessage": func(name string) bool { return name == "MaybeInaccessibleMessage" },
"discriminatorField": func(name string) string { return knownDiscriminators[name].Field },
"discriminatorMap": func(name string) map[string]string { return knownDiscriminators[name].Variants },
// union-field helpers for per-struct UnmarshalJSON emission
"unionFields": unionFieldsOf,
"isArrayUnion": func(tr spec.TypeRef) bool { return hasUnionElem(tr) },
"unionTypeName": func(tr spec.TypeRef) string { name, _ := unionTypeFor(tr); return name },
}
}
// title upper-cases the first byte of s (ASCII only — all Telegram method names are ASCII).
func title(s string) string {
if s == "" {
return ""
}
r := s[0]
if r >= 'a' && r <= 'z' {
r = r - 'a' + 'A'
}
return string(r) + s[1:]
}
// isFileField reports whether the field carries an InputFile.
func isFileField(f spec.Field) bool {
return mentionsInputFileTr(f.Type)
}
func mentionsInputFileTr(tr spec.TypeRef) bool {
switch tr.Kind {
case spec.KindNamed:
return tr.Name == "InputFile"
case spec.KindArray:
if tr.ElemType != nil {
return mentionsInputFileTr(*tr.ElemType)
}
case spec.KindOneOf:
for _, v := range tr.Variants {
if v == "InputFile" {
return true
}
}
}
return false
}
// fileCheck returns the HasFile guard line for a file-carrying field.
// Both named InputFile and InputFile-or-String oneOf fields are now *InputFile,
// so no type assertion is needed in either case.
func fileCheck(f spec.Field) string {
return fmt.Sprintf("\tif p.%s != nil && p.%s.IsLocalUpload() { return true }\n", f.Name, f.Name)
}
// multipartFileEntry returns the MultipartFiles append block for a file field.
// Both named InputFile and InputFile-or-String oneOf fields are now *InputFile,
// so the same code works for both cases.
func multipartFileEntry(f spec.Field) string {
jsonName := f.JSONName
return fmt.Sprintf(
"\tif p.%s != nil && p.%s.IsLocalUpload() {\n\t\tname := p.%s.Filename\n\t\tif name == \"\" { name = %q }\n\t\tfiles = append(files, client.MultipartFile{FieldName: %q, Filename: name, Reader: p.%s.Reader})\n\t}\n",
f.Name, f.Name, f.Name, jsonName, jsonName, f.Name)
}
// multipartFieldEntry generates the line that adds f to the multipart map.
// Required scalar fields go in unconditionally; optional ones go in only
// when non-zero/non-empty.
func multipartFieldEntry(f spec.Field) string {
switch f.Type.Kind {
case spec.KindPrimitive:
switch f.Type.Name {
case "int64":
if f.Required {
return fmt.Sprintf("\tout[%q] = strconv.FormatInt(p.%s, 10)\n", f.JSONName, f.Name)
}
return fmt.Sprintf("\tif p.%s != nil { out[%q] = strconv.FormatInt(*p.%s, 10) }\n", f.Name, f.JSONName, f.Name)
case "string":
if f.Required {
return fmt.Sprintf("\tout[%q] = p.%s\n", f.JSONName, f.Name)
}
return fmt.Sprintf("\tif p.%s != \"\" { out[%q] = p.%s }\n", f.Name, f.JSONName, f.Name)
case "bool":
if f.Required {
return fmt.Sprintf("\tout[%q] = strconv.FormatBool(p.%s)\n", f.JSONName, f.Name)
}
return fmt.Sprintf("\tif p.%s != nil { out[%q] = strconv.FormatBool(*p.%s) }\n", f.Name, f.JSONName, f.Name)
case "float64":
if f.Required {
return fmt.Sprintf("\tout[%q] = strconv.FormatFloat(p.%s, 'f', -1, 64)\n", f.JSONName, f.Name)
}
return fmt.Sprintf("\tif p.%s != nil { out[%q] = strconv.FormatFloat(*p.%s, 'f', -1, 64) }\n", f.Name, f.JSONName, f.Name)
}
case spec.KindOneOf:
// Integer-or-String → ChatID: use .String() wire form.
if matchesVariants(f.Type.Variants, "int64", "string") {
if f.Required {
return fmt.Sprintf("\tout[%q] = p.%s.String()\n", f.JSONName, f.Name)
}
return fmt.Sprintf("\tif !p.%s.IsZero() { out[%q] = p.%s.String() }\n", f.Name, f.JSONName, f.Name)
}
// InputFile-or-String → *InputFile: non-upload branch sends PathOrID.
if matchesVariants(f.Type.Variants, "InputFile", "string") {
return fmt.Sprintf("\tif p.%s != nil && !p.%s.IsLocalUpload() && p.%s.PathOrID != \"\" { out[%q] = p.%s.PathOrID }\n",
f.Name, f.Name, f.Name, f.JSONName, f.Name)
}
// Sealed-interface unions — JSON-marshal.
if f.Required {
return fmt.Sprintf("\tif b, _ := json.Marshal(p.%s); len(b) > 0 && string(b) != \"null\" { out[%q] = string(b) }\n", f.Name, f.JSONName)
}
return fmt.Sprintf("\tif p.%s != nil { if b, _ := json.Marshal(p.%s); len(b) > 0 && string(b) != \"null\" { out[%q] = string(b) } }\n", f.Name, f.Name, f.JSONName)
}
// Named or array: fall back to JSON-marshal to JSON string.
if f.Required {
return fmt.Sprintf("\tif b, _ := json.Marshal(p.%s); len(b) > 0 { out[%q] = string(b) }\n", f.Name, f.JSONName)
}
return fmt.Sprintf("\tif p.%s != nil { if b, _ := json.Marshal(p.%s); len(b) > 0 { out[%q] = string(b) } }\n", f.Name, f.Name, f.JSONName)
}
func returnGoType(tr spec.TypeRef) string {
switch tr.Kind {
case spec.KindPrimitive:
return tr.Name
case spec.KindNamed:
// Sealed-interface unions are returned by interface value, not pointer
// (you can't take a pointer to an interface in any useful way; the
// generated UnmarshalXxx returns the interface directly).
if _, ok := knownDiscriminators[tr.Name]; ok {
return tr.Name
}
// MessageOrBool is a hand-coded runtime wrapper — pointer return.
return "*" + tr.Name
case spec.KindArray:
if tr.ElemType == nil {
return "[]any"
}
return "[]" + returnGoElem(*tr.ElemType)
case spec.KindOneOf:
// Integer-or-String return (rare but possible).
if matchesVariants(tr.Variants, "int64", "string") {
return "ChatID"
}
return "any"
}
return "any"
}
func returnGoElem(tr spec.TypeRef) string {
switch tr.Kind {
case spec.KindPrimitive:
return tr.Name
case spec.KindNamed:
return tr.Name
case spec.KindArray:
if tr.ElemType == nil {
return "any"
}
return "[]" + returnGoElem(*tr.ElemType)
}
return "any"
}
// emitMethods renders methods.gen.go.
func (e *emitter) emitMethods() error {
t, err := template.New("methods").Funcs(funcs()).Parse(methodsTmpl)
if err != nil {
return fmt.Errorf("parse methods.tmpl: %w", err)
}
var buf bytes.Buffer
if execErr := t.Execute(&buf, e.api); execErr != nil {
return fmt.Errorf("execute methods.tmpl: %w", execErr)
}
src, err := format.Source(buf.Bytes())
if err != nil {
return fmt.Errorf("gofmt methods.gen.go: %w\n--- unformatted ---\n%s", err, buf.String())
}
return os.WriteFile(filepath.Join(e.outDir, "methods.gen.go"), src, 0o600)
}
// emitEnums renders enums.gen.go.
func (e *emitter) emitEnums() error {
t, err := template.New("enums").Funcs(funcs()).Parse(enumsTmpl)
if err != nil {
return fmt.Errorf("parse enums.tmpl: %w", err)
}
var buf bytes.Buffer
if execErr := t.Execute(&buf, e.api); execErr != nil {
return fmt.Errorf("execute enums.tmpl: %w", execErr)
}
src, err := format.Source(buf.Bytes())
if err != nil {
return fmt.Errorf("gofmt enums.gen.go: %w\n--- unformatted ---\n%s", err, buf.String())
}
return os.WriteFile(filepath.Join(e.outDir, "enums.gen.go"), src, 0o600)
}
// goType returns the Go type expression for a TypeRef.
// Optional fields use pointer types for primitives and named types,
// or rely on omitempty for slices and maps. parameter `optional` controls
// whether to wrap pointer-style.
func goType(tr spec.TypeRef, optional bool) string {
switch tr.Kind {
case spec.KindPrimitive:
if optional && (tr.Name == "bool" || tr.Name == "int64" || tr.Name == "float64") {
return "*" + tr.Name
}
return tr.Name
case spec.KindNamed:
// Named types are always pointer-optional when optional, except:
// 1. Union (interface) types — they are naturally nil-able; pointer-to-interface is invalid.
// 2. InputFile is always pointer-typed even when required: the
// multipart helpers (fileCheck, multipartFileEntry) call
// f.IsLocalUpload() and dereference Reader, both of which
// expect a pointer receiver.
if _, isUnion := knownDiscriminators[tr.Name]; isUnion {
// Interface type — never add *.
return tr.Name
}
if optional || tr.Name == "InputFile" {
return "*" + tr.Name
}
return tr.Name
case spec.KindArray:
if tr.ElemType == nil {
return "[]any"
}
// Inside slices, the element shape is its own thing — never wrap
// the element in a pointer just because the field is optional.
return "[]" + goType(*tr.ElemType, false)
case spec.KindOneOf:
// Integer-or-String: typed ChatID wrapper.
if matchesVariants(tr.Variants, "int64", "string") {
if optional {
return "*ChatID"
}
return "ChatID"
}
// InputFile-or-String: *InputFile runtime helper handles both.
if matchesVariants(tr.Variants, "InputFile", "string") {
return "*InputFile"
}
// All-named variants sealed interface: fall back to interface.
return "any"
}
return "any"
}
// unionField pairs a struct field with the name of its union type.
type unionField struct {
Field spec.Field
UnionName string // e.g. "ChatMember"
}
// unionFieldsOf returns the subset of t.Fields whose type is a known
// discriminated union (directly or as array element).
func unionFieldsOf(t spec.TypeDecl) []unionField {
var out []unionField
for _, f := range t.Fields {
if u, ok := unionTypeFor(f.Type); ok {
out = append(out, unionField{Field: f, UnionName: u})
}
}
return out
}
// unionTypeFor inspects a TypeRef and reports whether it (or its array
// element) is a known discriminated union. Returns the union name and true.
func unionTypeFor(tr spec.TypeRef) (string, bool) {
switch tr.Kind {
case spec.KindNamed:
if _, ok := knownDiscriminators[tr.Name]; ok {
return tr.Name, true
}
case spec.KindArray:
if tr.ElemType != nil {
return unionTypeFor(*tr.ElemType)
}
case spec.KindOneOf:
if u := unionNameByVariants(tr.Variants); u != "" {
return u, true
}
}
return "", false
}
// unionNameByVariants finds the parent union whose variant type names exactly
// match the given variant set (order-insensitive).
func unionNameByVariants(variants []string) string {
for parentName, ds := range knownDiscriminators {
wanted := make([]string, 0, len(ds.Variants))
for _, vt := range ds.Variants {
wanted = append(wanted, vt)
}
if matchesVariants(variants, wanted...) {
return parentName
}
}
return ""
}
// hasUnionElem reports whether tr is an array whose element type is a known union.
func hasUnionElem(tr spec.TypeRef) bool {
if tr.Kind != spec.KindArray || tr.ElemType == nil {
return false
}
_, ok := unionTypeFor(*tr.ElemType)
return ok
}
// matchesVariants reports whether got equals want as a set (order-insensitive).
func matchesVariants(got []string, want ...string) bool {
if len(got) != len(want) {
return false
}
seen := make(map[string]int, len(got))
for _, g := range got {
seen[g]++
}
for _, w := range want {
seen[w]--
}
for _, v := range seen {
if v != 0 {
return false
}
}
return true
}
// goField returns the Go struct-field declaration for a Field.
func goField(f spec.Field) string {
tag := fmt.Sprintf("`json:%q`", f.JSONName+omitempty(f))
return fmt.Sprintf("%s %s %s", f.Name, goType(f.Type, !f.Required), tag)
}
func omitempty(f spec.Field) string {
if f.Required {
return ""
}
return ",omitempty"
}
// docComment converts a doc string into a Go-style block comment with
// a leading "// " on each line.
func docComment(s string) string {
if s == "" {
return ""
}
var buf bytes.Buffer
for _, line := range splitLines(s) {
buf.WriteString("// ")
buf.WriteString(line)
buf.WriteByte('\n')
}
return buf.String()
}
func splitLines(s string) []string {
var out []string
start := 0
for i := 0; i < len(s); i++ {
if s[i] == '\n' {
out = append(out, s[start:i])
start = i + 1
}
}
if start < len(s) {
out = append(out, s[start:])
}
return out
}
// hasVariants reports whether the variant list contains all of the named strings (order-insensitive).
func hasVariants(variants []string, names ...string) bool {
return matchesVariants(variants, names...)
}
// buildUnionTypeSet returns the set of all type names that generate interface types
// (i.e., types with one_of). This includes knownDiscriminators and marker-interface
// unions not covered by the discriminator map.
func buildUnionTypeSet(api *spec.API) map[string]bool {
s := make(map[string]bool, len(knownDiscriminators)+16)
for name := range knownDiscriminators {
s[name] = true
}
for _, t := range api.Types {
if len(t.OneOf) > 0 {
s[t.Name] = true
}
}
return s
}
// makeSentinelValue returns a sentinelValue func that uses the given union type set.
// It returns a minimal valid Go expression for a spec.Field's type,
// used in generated test param literals.
func makeSentinelValue(unionTypes map[string]bool) func(spec.Field) string {
return func(f spec.Field) string {
return sentinelForField(f, unionTypes)
}
}
func sentinelForField(f spec.Field, unionTypes map[string]bool) string {
tr := f.Type
switch tr.Kind {
case spec.KindPrimitive:
switch tr.Name {
case "int64":
return "42"
case "string":
return `"test_value"`
case "bool":
return "true"
case "float64":
return "1.0"
}
case spec.KindNamed:
switch tr.Name {
case "ChatID":
return "ChatIDFromInt(123)"
case "InputFile":
return `&InputFile{PathOrID: "file_id_test"}`
}
// Interface (union) types are nil-able.
if unionTypes[tr.Name] {
return "nil"
}
// Required named struct types are value types in the generated struct.
if f.Required {
return tr.Name + "{}"
}
return "&" + tr.Name + "{}"
case spec.KindArray:
return "nil"
case spec.KindOneOf:
if hasVariants(tr.Variants, "int64", "string") {
return "ChatIDFromInt(123)"
}
if hasVariants(tr.Variants, "InputFile", "string") {
return `&InputFile{PathOrID: "file_id_test"}`
}
// Sealed named-union interface: use nil (any).
return "nil"
}
return "nil"
}
// successResp returns a backtick Go string literal containing a minimal
// {"ok":true,"result":...} JSON body for the method's return type.
func successResp(m spec.MethodDecl) string {
body := successBody(m.Returns)
return "`{\"ok\":true,\"result\":" + body + "}`"
}
func successBody(tr spec.TypeRef) string {
switch tr.Kind {
case spec.KindPrimitive:
switch tr.Name {
case "bool":
return "true"
case "int64", "float64":
return "0"
case "string":
return `""`
}
case spec.KindNamed:
if tr.Name == "MessageOrBool" {
return "true"
}
// Sealed-interface unions need a discriminator field so UnmarshalXxx can dispatch.
// Pick the lexicographically first variant value for determinism (map
// iteration order in Go is randomized — using `range` directly produces
// non-deterministic regen output).
if disc, ok := knownDiscriminators[tr.Name]; ok && disc.Field != "" {
values := make([]string, 0, len(disc.Variants))
for v := range disc.Variants {
values = append(values, v)
}
sort.Strings(values)
if len(values) > 0 {
return fmt.Sprintf(`{"%s":"%s"}`, disc.Field, values[0])
}
}
// MaybeInaccessibleMessage uses date==0 → InaccessibleMessage variant.
if tr.Name == "MaybeInaccessibleMessage" {
return `{"date":0,"chat":{"id":1,"type":"private"},"message_id":1}`
}
return "{}"
case spec.KindArray:
return "[]"
case spec.KindOneOf:
return "null"
}
return "null"
}
// emitTests renders methods_gen_test.go.
func (e *emitter) emitTests() error {
unionTypes := buildUnionTypeSet(e.api)
// Add test-specific helpers to the shared func map.
fm := funcs()
fm["sentinelValue"] = makeSentinelValue(unionTypes)
fm["successResp"] = successResp
t, err := template.New("tests").Funcs(fm).Parse(testsTmpl)
if err != nil {
return fmt.Errorf("parse tests.tmpl: %w", err)
}
var buf bytes.Buffer
if execErr := t.Execute(&buf, e.api); execErr != nil {
return fmt.Errorf("execute tests.tmpl: %w", execErr)
}
src, err := format.Source(buf.Bytes())
if err != nil {
return fmt.Errorf("gofmt methods_gen_test.go: %w\n--- unformatted ---\n%s", err, buf.String())
}
return os.WriteFile(filepath.Join(e.outDir, "methods_gen_test.go"), src, 0o600)
}
+97
View File
@@ -0,0 +1,97 @@
package main
import (
"flag"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
var updateGolden = flag.Bool("update", false, "update golden files")
func TestEmit_Types_FixtureGolden(t *testing.T) {
api, err := loadAPI("../../testdata/golden/api_small_fixture.json")
require.NoError(t, err)
tmp := t.TempDir()
e := newEmitter(api, tmp)
require.NoError(t, e.emitTypes())
got, err := os.ReadFile(filepath.Join(tmp, "types.gen.go"))
require.NoError(t, err)
goldenPath := "../../testdata/golden/types.gen.go"
if *updateGolden {
require.NoError(t, os.WriteFile(goldenPath, got, 0o600))
return
}
expected, err := os.ReadFile(goldenPath)
require.NoError(t, err, "missing golden; run `go test -update ./cmd/genapi/...`")
require.Equal(t, string(expected), string(got))
}
func TestEmit_Enums_FixtureGolden(t *testing.T) {
api, err := loadAPI("../../testdata/golden/api_small_fixture.json")
require.NoError(t, err)
tmp := t.TempDir()
e := newEmitter(api, tmp)
require.NoError(t, e.emitEnums())
got, err := os.ReadFile(filepath.Join(tmp, "enums.gen.go"))
require.NoError(t, err)
goldenPath := "../../testdata/golden/enums.gen.go"
if *updateGolden {
require.NoError(t, os.WriteFile(goldenPath, got, 0o600))
return
}
expected, err := os.ReadFile(goldenPath)
require.NoError(t, err, "missing golden; run `go test -update ./cmd/genapi/...`")
require.Equal(t, string(expected), string(got))
}
func TestEmit_Methods_FixtureGolden(t *testing.T) {
api, err := loadAPI("../../testdata/golden/api_small_fixture.json")
require.NoError(t, err)
tmp := t.TempDir()
e := newEmitter(api, tmp)
require.NoError(t, e.emitTypes()) // some methods reference types
require.NoError(t, e.emitMethods())
got, err := os.ReadFile(filepath.Join(tmp, "methods.gen.go"))
require.NoError(t, err)
goldenPath := "../../testdata/golden/methods.gen.go"
if *updateGolden {
require.NoError(t, os.WriteFile(goldenPath, got, 0o600))
return
}
expected, err := os.ReadFile(goldenPath)
require.NoError(t, err, "missing golden; run `go test -update ./cmd/genapi/...`")
require.Equal(t, string(expected), string(got))
}
func TestEmit_Tests_FixtureGolden(t *testing.T) {
api, err := loadAPI("../../testdata/golden/api_small_fixture.json")
require.NoError(t, err)
tmp := t.TempDir()
e := newEmitter(api, tmp)
require.NoError(t, e.emitTests())
got, err := os.ReadFile(filepath.Join(tmp, "methods_gen_test.go"))
require.NoError(t, err)
goldenPath := "../../testdata/golden/methods_gen_test.go"
if *updateGolden {
require.NoError(t, os.WriteFile(goldenPath, got, 0o600))
return
}
expected, err := os.ReadFile(goldenPath)
require.NoError(t, err, "missing golden; run `go test -update ./cmd/genapi/...`")
require.Equal(t, string(expected), string(got))
}
+60
View File
@@ -0,0 +1,60 @@
// Code generated by cmd/genapi. DO NOT EDIT.
//go:build !ignore_autogenerated
package api
// ParseMode controls how Telegram interprets formatting in message text.
type ParseMode string
const (
ParseModeMarkdown ParseMode = "Markdown" // legacy
ParseModeMarkdownV2 ParseMode = "MarkdownV2"
ParseModeHTML ParseMode = "HTML"
)
// ChatType is the type of a Telegram chat.
type ChatType string
const (
ChatTypePrivate ChatType = "private"
ChatTypeGroup ChatType = "group"
ChatTypeSupergroup ChatType = "supergroup"
ChatTypeChannel ChatType = "channel"
)
// UpdateType identifies an Update payload variant. Used by allowed_updates
// in getUpdates / setWebhook.
type UpdateType string
const (
UpdateMessage UpdateType = "message"
UpdateEditedMessage UpdateType = "edited_message"
UpdateChannelPost UpdateType = "channel_post"
UpdateEditedChannelPost UpdateType = "edited_channel_post"
UpdateCallbackQuery UpdateType = "callback_query"
UpdateInlineQuery UpdateType = "inline_query"
)
// MessageEntityType is the kind of an entity (mention, hashtag, command, ...).
type MessageEntityType string
const (
EntityMention MessageEntityType = "mention"
EntityHashtag MessageEntityType = "hashtag"
EntityCashtag MessageEntityType = "cashtag"
EntityBotCommand MessageEntityType = "bot_command"
EntityURL MessageEntityType = "url"
EntityEmail MessageEntityType = "email"
EntityPhoneNumber MessageEntityType = "phone_number"
EntityBold MessageEntityType = "bold"
EntityItalic MessageEntityType = "italic"
EntityUnderline MessageEntityType = "underline"
EntityStrike MessageEntityType = "strikethrough"
EntitySpoiler MessageEntityType = "spoiler"
EntityCode MessageEntityType = "code"
EntityPre MessageEntityType = "pre"
EntityTextLink MessageEntityType = "text_link"
EntityTextMention MessageEntityType = "text_mention"
EntityCustomEmoji MessageEntityType = "custom_emoji"
)
+645
View File
@@ -0,0 +1,645 @@
package main
import (
"testing"
"github.com/lukaszraczylo/go-telegram/internal/spec"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// goType — all branches
// ---------------------------------------------------------------------------
func TestGoType_Primitive(t *testing.T) {
cases := []struct {
name string
optional bool
want string
}{
{"bool", false, "bool"},
{"bool", true, "*bool"},
{"int64", false, "int64"},
{"int64", true, "*int64"},
{"float64", false, "float64"},
{"float64", true, "*float64"},
{"string", false, "string"},
{"string", true, "string"}, // string is not pointer-wrapped
}
for _, c := range cases {
tr := spec.TypeRef{Kind: spec.KindPrimitive, Name: c.name}
got := goType(tr, c.optional)
require.Equal(t, c.want, got, "goType(%q, optional=%v)", c.name, c.optional)
}
}
func TestGoType_Named_Required(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}
require.Equal(t, "Message", goType(tr, false))
require.Equal(t, "*Message", goType(tr, true))
}
func TestGoType_Named_InputFile(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "InputFile"}
// InputFile is always pointer even when required.
require.Equal(t, "*InputFile", goType(tr, false))
require.Equal(t, "*InputFile", goType(tr, true))
}
func TestGoType_Named_UnionInterface(t *testing.T) {
// ChatMember is a known discriminated union — no * even when optional.
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}
require.Equal(t, "ChatMember", goType(tr, false))
require.Equal(t, "ChatMember", goType(tr, true))
}
func TestGoType_Array_NilElem(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindArray}
require.Equal(t, "[]any", goType(tr, false))
}
func TestGoType_Array_WithElem(t *testing.T) {
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
require.Equal(t, "[]Update", goType(tr, false))
}
func TestGoType_OneOf_ChatID(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64", "string"}}
require.Equal(t, "ChatID", goType(tr, false))
require.Equal(t, "*ChatID", goType(tr, true))
}
func TestGoType_OneOf_InputFile(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputFile", "string"}}
require.Equal(t, "*InputFile", goType(tr, false))
}
func TestGoType_OneOf_SealedInterface(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"A", "B"}}
require.Equal(t, "any", goType(tr, false))
}
func TestGoType_Unknown(t *testing.T) {
tr := spec.TypeRef{Kind: spec.Kind(99)}
require.Equal(t, "any", goType(tr, false))
}
// ---------------------------------------------------------------------------
// returnGoType — all branches
// ---------------------------------------------------------------------------
func TestReturnGoType_Primitive(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}
require.Equal(t, "bool", returnGoType(tr))
}
func TestReturnGoType_Named(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}
require.Equal(t, "*Message", returnGoType(tr))
}
func TestReturnGoType_Array_NilElem(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindArray}
require.Equal(t, "[]any", returnGoType(tr))
}
func TestReturnGoType_Array_WithElem(t *testing.T) {
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
require.Equal(t, "[]Update", returnGoType(tr))
}
func TestReturnGoType_OneOf_ChatID(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64", "string"}}
require.Equal(t, "ChatID", returnGoType(tr))
}
func TestReturnGoType_OneOf_Other(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"A", "B"}}
require.Equal(t, "any", returnGoType(tr))
}
func TestReturnGoType_Unknown(t *testing.T) {
tr := spec.TypeRef{Kind: spec.Kind(99)}
require.Equal(t, "any", returnGoType(tr))
}
// ---------------------------------------------------------------------------
// returnGoElem — all branches
// ---------------------------------------------------------------------------
func TestReturnGoElem_Primitive(t *testing.T) {
require.Equal(t, "int64", returnGoElem(spec.TypeRef{Kind: spec.KindPrimitive, Name: "int64"}))
}
func TestReturnGoElem_Named(t *testing.T) {
require.Equal(t, "Message", returnGoElem(spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}))
}
func TestReturnGoElem_Array_NilElem(t *testing.T) {
require.Equal(t, "any", returnGoElem(spec.TypeRef{Kind: spec.KindArray}))
}
func TestReturnGoElem_Array_WithElem(t *testing.T) {
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "PhotoSize"}
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
require.Equal(t, "[]PhotoSize", returnGoElem(tr))
}
func TestReturnGoElem_Unknown(t *testing.T) {
require.Equal(t, "any", returnGoElem(spec.TypeRef{Kind: spec.Kind(99)}))
}
// ---------------------------------------------------------------------------
// multipartFieldEntry — all branches
// ---------------------------------------------------------------------------
func makeField(name, jname, typName string, kind spec.Kind, required bool) spec.Field {
return spec.Field{
Name: name,
JSONName: jname,
Type: spec.TypeRef{Kind: kind, Name: typName},
Required: required,
}
}
func makeFieldVariants(name, jname string, kind spec.Kind, variants []string, required bool) spec.Field {
return spec.Field{
Name: name,
JSONName: jname,
Type: spec.TypeRef{Kind: kind, Variants: variants},
Required: required,
}
}
func TestMultipartFieldEntry_Int64Required(t *testing.T) {
f := makeField("ChatID", "chat_id", "int64", spec.KindPrimitive, true)
got := multipartFieldEntry(f)
require.Contains(t, got, `FormatInt`)
require.NotContains(t, got, "if p.")
}
func TestMultipartFieldEntry_Int64Optional(t *testing.T) {
f := makeField("MessageThreadID", "message_thread_id", "int64", spec.KindPrimitive, false)
got := multipartFieldEntry(f)
require.Contains(t, got, `FormatInt`)
require.Contains(t, got, "if p.")
}
func TestMultipartFieldEntry_StringRequired(t *testing.T) {
f := makeField("Text", "text", "string", spec.KindPrimitive, true)
got := multipartFieldEntry(f)
require.Contains(t, got, `out["text"]`)
require.NotContains(t, got, "if p.Text")
}
func TestMultipartFieldEntry_StringOptional(t *testing.T) {
f := makeField("ParseMode", "parse_mode", "string", spec.KindPrimitive, false)
got := multipartFieldEntry(f)
require.Contains(t, got, `if p.ParseMode`)
}
func TestMultipartFieldEntry_BoolRequired(t *testing.T) {
f := makeField("DisableNotification", "disable_notification", "bool", spec.KindPrimitive, true)
got := multipartFieldEntry(f)
require.Contains(t, got, `FormatBool`)
require.NotContains(t, got, "if p.")
}
func TestMultipartFieldEntry_BoolOptional(t *testing.T) {
f := makeField("Protected", "protect_content", "bool", spec.KindPrimitive, false)
got := multipartFieldEntry(f)
require.Contains(t, got, `FormatBool`)
require.Contains(t, got, "if p.")
}
func TestMultipartFieldEntry_Float64Required(t *testing.T) {
f := makeField("Latitude", "latitude", "float64", spec.KindPrimitive, true)
got := multipartFieldEntry(f)
require.Contains(t, got, `FormatFloat`)
require.NotContains(t, got, "if p.")
}
func TestMultipartFieldEntry_Float64Optional(t *testing.T) {
f := makeField("Longitude", "longitude", "float64", spec.KindPrimitive, false)
got := multipartFieldEntry(f)
require.Contains(t, got, `FormatFloat`)
require.Contains(t, got, "if p.")
}
func TestMultipartFieldEntry_OneOf_ChatIDRequired(t *testing.T) {
f := makeFieldVariants("ChatID", "chat_id", spec.KindOneOf, []string{"int64", "string"}, true)
got := multipartFieldEntry(f)
require.Contains(t, got, `.String()`)
require.NotContains(t, got, "IsZero")
}
func TestMultipartFieldEntry_OneOf_ChatIDOptional(t *testing.T) {
f := makeFieldVariants("ChatID", "chat_id", spec.KindOneOf, []string{"int64", "string"}, false)
got := multipartFieldEntry(f)
require.Contains(t, got, `IsZero`)
}
func TestMultipartFieldEntry_OneOf_InputFileOrString(t *testing.T) {
f := makeFieldVariants("Photo", "photo", spec.KindOneOf, []string{"InputFile", "string"}, false)
got := multipartFieldEntry(f)
require.Contains(t, got, `PathOrID`)
}
func TestMultipartFieldEntry_OneOf_SealedRequired(t *testing.T) {
f := makeFieldVariants("Markup", "reply_markup", spec.KindOneOf, []string{"A", "B"}, true)
got := multipartFieldEntry(f)
require.Contains(t, got, `json.Marshal`)
}
func TestMultipartFieldEntry_OneOf_SealedOptional(t *testing.T) {
f := makeFieldVariants("Markup", "reply_markup", spec.KindOneOf, []string{"A", "B"}, false)
got := multipartFieldEntry(f)
require.Contains(t, got, `json.Marshal`)
require.Contains(t, got, "if p.Markup")
}
func TestMultipartFieldEntry_Named_Required(t *testing.T) {
f := makeField("Entities", "entities", "MessageEntity", spec.KindNamed, true)
got := multipartFieldEntry(f)
require.Contains(t, got, `json.Marshal`)
require.NotContains(t, got, "if p.")
}
func TestMultipartFieldEntry_Named_Optional(t *testing.T) {
f := makeField("Entities", "entities", "MessageEntity", spec.KindNamed, false)
got := multipartFieldEntry(f)
require.Contains(t, got, `json.Marshal`)
require.Contains(t, got, "if p.")
}
// ---------------------------------------------------------------------------
// unionTypeFor — all branches
// ---------------------------------------------------------------------------
func TestUnionTypeFor_DirectNamed(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}
name, ok := unionTypeFor(tr)
require.True(t, ok)
require.Equal(t, "ChatMember", name)
}
func TestUnionTypeFor_Array(t *testing.T) {
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
name, ok := unionTypeFor(tr)
require.True(t, ok)
require.Equal(t, "ChatMember", name)
}
func TestUnionTypeFor_ArrayNilElem(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindArray}
_, ok := unionTypeFor(tr)
require.False(t, ok)
}
func TestUnionTypeFor_NotUnion(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}
_, ok := unionTypeFor(tr)
require.False(t, ok)
}
func TestUnionTypeFor_Unknown(t *testing.T) {
tr := spec.TypeRef{Kind: spec.Kind(99)}
_, ok := unionTypeFor(tr)
require.False(t, ok)
}
// ---------------------------------------------------------------------------
// unionNameByVariants
// ---------------------------------------------------------------------------
func TestUnionNameByVariants_ChatMember(t *testing.T) {
// Use the actual variants from knownDiscriminators["ChatMember"].
variants := []string{
"ChatMemberOwner", "ChatMemberAdministrator", "ChatMemberMember",
"ChatMemberRestricted", "ChatMemberLeft", "ChatMemberBanned",
}
name := unionNameByVariants(variants)
require.Equal(t, "ChatMember", name)
}
func TestUnionNameByVariants_Unknown(t *testing.T) {
name := unionNameByVariants([]string{"X", "Y", "Z"})
require.Equal(t, "", name)
}
// ---------------------------------------------------------------------------
// hasUnionElem
// ---------------------------------------------------------------------------
func TestHasUnionElem_NonArray(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}
require.False(t, hasUnionElem(tr))
}
func TestHasUnionElem_ArrayNilElem(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindArray}
require.False(t, hasUnionElem(tr))
}
func TestHasUnionElem_ArrayUnionElem(t *testing.T) {
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
require.True(t, hasUnionElem(tr))
}
func TestHasUnionElem_ArrayNonUnionElem(t *testing.T) {
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
require.False(t, hasUnionElem(tr))
}
// ---------------------------------------------------------------------------
// unionFieldsOf
// ---------------------------------------------------------------------------
func TestUnionFieldsOf_WithUnionField(t *testing.T) {
td := spec.TypeDecl{
Name: "ChatMemberUpdated",
Fields: []spec.Field{
{Name: "NewChatMember", JSONName: "new_chat_member", Type: spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}},
{Name: "OldChatMember", JSONName: "old_chat_member", Type: spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}},
{Name: "Date", JSONName: "date", Type: spec.TypeRef{Kind: spec.KindPrimitive, Name: "int64"}},
},
}
uf := unionFieldsOf(td)
require.Len(t, uf, 2)
require.Equal(t, "ChatMember", uf[0].UnionName)
}
// ---------------------------------------------------------------------------
// splitLines — edge cases
// ---------------------------------------------------------------------------
func TestSplitLines_Empty(t *testing.T) {
require.Empty(t, splitLines(""))
}
func TestSplitLines_NoNewline(t *testing.T) {
got := splitLines("hello world")
require.Equal(t, []string{"hello world"}, got)
}
func TestSplitLines_TrailingNewline(t *testing.T) {
got := splitLines("line1\nline2\n")
require.Equal(t, []string{"line1", "line2"}, got)
}
func TestSplitLines_MultiLine(t *testing.T) {
got := splitLines("a\nb\nc")
require.Equal(t, []string{"a", "b", "c"}, got)
}
// ---------------------------------------------------------------------------
// docComment
// ---------------------------------------------------------------------------
func TestDocComment_Empty(t *testing.T) {
require.Equal(t, "", docComment(""))
}
func TestDocComment_SingleLine(t *testing.T) {
got := docComment("Hello world.")
require.Equal(t, "// Hello world.\n", got)
}
func TestDocComment_MultiLine(t *testing.T) {
got := docComment("Line 1\nLine 2")
require.Contains(t, got, "// Line 1\n")
require.Contains(t, got, "// Line 2\n")
}
// ---------------------------------------------------------------------------
// title
// ---------------------------------------------------------------------------
func TestTitle_Empty(t *testing.T) {
require.Equal(t, "", title(""))
}
func TestTitle_Lowercase(t *testing.T) {
require.Equal(t, "SendMessage", title("sendMessage"))
}
func TestTitle_AlreadyUpper(t *testing.T) {
require.Equal(t, "GetMe", title("GetMe"))
}
// ---------------------------------------------------------------------------
// mentionsInputFileTr
// ---------------------------------------------------------------------------
func TestMentionsInputFileTr_Named(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "InputFile"}
require.True(t, mentionsInputFileTr(tr))
}
func TestMentionsInputFileTr_NotInputFile(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}
require.False(t, mentionsInputFileTr(tr))
}
func TestMentionsInputFileTr_Array(t *testing.T) {
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "InputFile"}
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
require.True(t, mentionsInputFileTr(tr))
}
func TestMentionsInputFileTr_ArrayNilElem(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindArray}
require.False(t, mentionsInputFileTr(tr))
}
func TestMentionsInputFileTr_OneOf_WithInputFile(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputFile", "string"}}
require.True(t, mentionsInputFileTr(tr))
}
func TestMentionsInputFileTr_OneOf_WithoutInputFile(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"A", "B"}}
require.False(t, mentionsInputFileTr(tr))
}
// ---------------------------------------------------------------------------
// loadAPI — error paths
// ---------------------------------------------------------------------------
func TestLoadAPI_MissingFile(t *testing.T) {
_, err := loadAPI("/nonexistent/path/api.json")
require.Error(t, err)
}
// ---------------------------------------------------------------------------
// runtimeTypes filter in emitTypes
// ---------------------------------------------------------------------------
func TestRuntimeTypes_NeverEmitted(t *testing.T) {
for name := range runtimeTypes {
require.True(t, runtimeTypes[name], "runtimeType %q should be true", name)
}
require.True(t, runtimeTypes["InputFile"])
require.True(t, runtimeTypes["ChatID"])
require.True(t, runtimeTypes["MessageOrBool"])
require.True(t, runtimeTypes["ResponseParameters"])
}
// ---------------------------------------------------------------------------
// sentinelForField — all branches
// ---------------------------------------------------------------------------
func TestSentinelForField(t *testing.T) {
unionTypes := map[string]bool{"ChatMember": true}
cases := []struct {
name string
field spec.Field
contains string
}{
{
name: "int64 primitive",
field: makeField("Count", "count", "int64", spec.KindPrimitive, true),
contains: "42",
},
{
name: "string primitive",
field: makeField("Text", "text", "string", spec.KindPrimitive, true),
contains: "test_value",
},
{
name: "bool primitive",
field: makeField("Flag", "flag", "bool", spec.KindPrimitive, true),
contains: "true",
},
{
name: "float64 primitive",
field: makeField("Lat", "lat", "float64", spec.KindPrimitive, true),
contains: "1.0",
},
{
name: "named ChatID",
field: makeField("ChatID", "chat_id", "ChatID", spec.KindNamed, true),
contains: "ChatIDFromInt",
},
{
name: "named InputFile",
field: makeField("Photo", "photo", "InputFile", spec.KindNamed, true),
contains: "InputFile",
},
{
name: "named union (nil-able)",
field: makeField("Member", "member", "ChatMember", spec.KindNamed, true),
contains: "nil",
},
{
name: "named required struct",
field: makeField("Chat", "chat", "Chat", spec.KindNamed, true),
contains: "Chat{}",
},
{
name: "named optional struct",
field: makeField("Chat", "chat", "Chat", spec.KindNamed, false),
contains: "&Chat{}",
},
{
name: "array",
field: spec.Field{Name: "Items", JSONName: "items", Type: spec.TypeRef{Kind: spec.KindArray}},
contains: "nil",
},
{
name: "oneOf ChatID variants",
field: makeFieldVariants("ChatID", "chat_id", spec.KindOneOf, []string{"int64", "string"}, true),
contains: "ChatIDFromInt",
},
{
name: "oneOf InputFile variants",
field: makeFieldVariants("Photo", "photo", spec.KindOneOf, []string{"InputFile", "string"}, true),
contains: "InputFile",
},
{
name: "oneOf sealed",
field: makeFieldVariants("Markup", "markup", spec.KindOneOf, []string{"A", "B"}, true),
contains: "nil",
},
{
name: "unknown kind",
field: spec.Field{Name: "X", JSONName: "x", Type: spec.TypeRef{Kind: spec.Kind(99)}},
contains: "nil",
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got := sentinelForField(c.field, unionTypes)
require.Contains(t, got, c.contains, "sentinelForField for %q", c.name)
})
}
}
// ---------------------------------------------------------------------------
// successBody — all branches
// ---------------------------------------------------------------------------
func TestSuccessBody(t *testing.T) {
cases := []struct {
name string
tr spec.TypeRef
want string
}{
{"bool", spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}, "true"},
{"int64", spec.TypeRef{Kind: spec.KindPrimitive, Name: "int64"}, "0"},
{"float64", spec.TypeRef{Kind: spec.KindPrimitive, Name: "float64"}, "0"},
{"string", spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}, `""`},
{"MessageOrBool", spec.TypeRef{Kind: spec.KindNamed, Name: "MessageOrBool"}, "true"},
{"named", spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}, "{}"},
{"array", spec.TypeRef{Kind: spec.KindArray}, "[]"},
{"oneOf", spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"A", "B"}}, "null"},
{"unknown", spec.TypeRef{Kind: spec.Kind(99)}, "null"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got := successBody(c.tr)
require.Equal(t, c.want, got)
})
}
}
// ---------------------------------------------------------------------------
// unionTypeFor — KindOneOf branch (variant-set match)
// ---------------------------------------------------------------------------
func TestUnionTypeFor_OneOfVariants(t *testing.T) {
// These variant *type names* match the ChatMember discriminator.
tr := spec.TypeRef{
Kind: spec.KindOneOf,
Variants: []string{
"ChatMemberOwner", "ChatMemberAdministrator", "ChatMemberMember",
"ChatMemberRestricted", "ChatMemberLeft", "ChatMemberBanned",
},
}
name, ok := unionTypeFor(tr)
require.True(t, ok)
require.Equal(t, "ChatMember", name)
}
func TestUnionTypeFor_OneOfNoMatch(t *testing.T) {
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"Foo", "Bar"}}
_, ok := unionTypeFor(tr)
require.False(t, ok)
}
// ---------------------------------------------------------------------------
// funcs() returns a non-nil FuncMap with expected keys
// ---------------------------------------------------------------------------
func TestFuncs_HasExpectedKeys(t *testing.T) {
fm := funcs()
require.NotNil(t, fm)
for _, key := range []string{"goType", "docComment", "returnGoType", "unionFields"} {
require.NotNil(t, fm[key], "funcs() missing key %q", key)
}
}
+49
View File
@@ -0,0 +1,49 @@
// Command genapi reads internal/spec/api.json and emits api/*.gen.go.
//
// Usage:
//
// genapi -input <file> (default: internal/spec/api.json)
// genapi -outdir <dir> (default: api)
package main
import (
"flag"
"fmt"
"os"
)
func main() {
input := flag.String("input", "internal/spec/api.json", "IR JSON path")
outdir := flag.String("outdir", "api", "output directory")
flag.Parse()
if err := run(*input, *outdir); err != nil {
fmt.Fprintln(os.Stderr, "genapi:", err)
os.Exit(1)
}
}
// run is filled in by P2.T8/T9/T10.
func run(input, outdir string) error {
api, err := loadAPI(input)
if err != nil {
return fmt.Errorf("load api: %w", err)
}
if err := os.MkdirAll(outdir, 0o750); err != nil {
return err
}
e := newEmitter(api, outdir)
if err := e.emitTypes(); err != nil {
return err
}
if err := e.emitMethods(); err != nil {
return err
}
if err := e.emitEnums(); err != nil {
return err
}
if err := e.emitTests(); err != nil {
return err
}
return nil
}
+58
View File
@@ -0,0 +1,58 @@
// Code generated by cmd/genapi. DO NOT EDIT.
//go:build !ignore_autogenerated
package api
import (
"context"
"github.com/goccy/go-json"
"strconv"
"github.com/lukaszraczylo/go-telegram/client"
)
var _ = strconv.Itoa // keep import for multipart helpers
var _ = json.Marshal // keep import for complex multipart fields
{{range .Methods}}
// {{title .Name}}Params is the parameter set for {{title .Name}}.
//
{{docComment .Doc -}}
type {{title .Name}}Params struct {
{{range .Params}}{{docComment .Doc}} {{goField .}}
{{end}}}
{{if .HasFiles}}
// HasFile reports whether a multipart upload is required.
func (p *{{title .Name}}Params) HasFile() bool {
{{range .Params}}{{if isFileField .}}{{fileCheck .}}{{end}}{{end}} return false
}
// MultipartFields returns the non-file fields used in the multipart body.
func (p *{{title .Name}}Params) MultipartFields() map[string]string {
out := map[string]string{}
{{range .Params}}{{if not (isFileField .)}}{{multipartFieldEntry .}}{{end}}{{end}} return out
}
// MultipartFiles returns the file parts.
func (p *{{title .Name}}Params) MultipartFiles() []client.MultipartFile {
var files []client.MultipartFile
{{range .Params}}{{if isFileField .}}{{multipartFileEntry .}}{{end}}{{end}} return files
}
{{end}}
// {{title .Name}} calls the {{.Name}} Telegram Bot API method.
//
{{docComment .Doc -}}
func {{title .Name}}(ctx context.Context, b *client.Bot, p *{{title .Name}}Params) ({{returnGoType .Returns}}, error) {
{{if isSealedUnionReturn .Returns -}}
raw, err := client.CallRaw[*{{title .Name}}Params](ctx, b, "{{.Name}}", p)
if err != nil {
return nil, err
}
return Unmarshal{{.Returns.Name}}(raw)
{{else -}}
return client.Call[*{{title .Name}}Params, {{returnGoType .Returns}}](ctx, b, "{{.Name}}", p)
{{end -}}
}
{{end}}
+220
View File
@@ -0,0 +1,220 @@
// Code generated by cmd/genapi. DO NOT EDIT.
//go:build !ignore_autogenerated
package api
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"github.com/lukaszraczylo/go-telegram/client"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// genTestMockDoer is a testify-mock HTTPDoer used by generated tests only.
type genTestMockDoer struct{ mock.Mock }
func (m *genTestMockDoer) Do(r *http.Request) (*http.Response, error) {
args := m.Called(r)
if v := args.Get(0); v != nil {
return v.(*http.Response), args.Error(1)
}
return nil, args.Error(1)
}
func genTestResp(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Body: io.NopCloser(bytes.NewBufferString(body)),
Header: http.Header{"Content-Type": []string{"application/json"}},
}
}
{{range .Methods}}{{$m := .}}{{$mName := title .Name}}{{$mWire := .Name}}
func Test_{{$mName}}_Success(t *testing.T) {
m := &genTestMockDoer{}
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
return strings.HasSuffix(r.URL.Path, "/{{$mWire}}")
})).Return(genTestResp(200, {{successResp $m}}), nil)
bot := client.New("test:token", client.WithHTTPClient(m))
{{- if .Params}}
params := &{{$mName}}Params{
{{- range .Params}}{{if .Required}}
{{.Name}}: {{sentinelValue .}},{{end}}
{{- end}}
}
_, err := {{$mName}}(context.Background(), bot, params)
{{- else}}
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
{{- end}}
require.NoError(t, err)
}
func Test_{{$mName}}_APIError(t *testing.T) {
m := &genTestMockDoer{}
m.On("Do", mock.Anything).Return(
genTestResp(200, `{"ok":false,"error_code":429,"description":"Too Many Requests","parameters":{"retry_after":1}}`), nil)
bot := client.New("test:token", client.WithHTTPClient(m))
{{- if .Params}}
params := &{{$mName}}Params{
{{- range .Params}}{{if .Required}}
{{.Name}}: {{sentinelValue .}},{{end}}
{{- end}}
}
_, err := {{$mName}}(context.Background(), bot, params)
{{- else}}
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
{{- end}}
require.Error(t, err)
var ae *client.APIError
require.ErrorAs(t, err, &ae)
require.Equal(t, 429, ae.Code)
require.True(t, ae.IsRetryable())
}
func Test_{{$mName}}_NetworkError(t *testing.T) {
m := &genTestMockDoer{}
m.On("Do", mock.Anything).Return(nil, errors.New("dial tcp: timeout"))
bot := client.New("test:token", client.WithHTTPClient(m))
{{- if .Params}}
params := &{{$mName}}Params{
{{- range .Params}}{{if .Required}}
{{.Name}}: {{sentinelValue .}},{{end}}
{{- end}}
}
_, err := {{$mName}}(context.Background(), bot, params)
{{- else}}
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
{{- end}}
require.Error(t, err)
var ne *client.NetworkError
require.ErrorAs(t, err, &ne)
}
func Test_{{$mName}}_ParseError(t *testing.T) {
m := &genTestMockDoer{}
m.On("Do", mock.Anything).Return(genTestResp(200, `not json`), nil)
bot := client.New("test:token", client.WithHTTPClient(m))
{{- if .Params}}
params := &{{$mName}}Params{
{{- range .Params}}{{if .Required}}
{{.Name}}: {{sentinelValue .}},{{end}}
{{- end}}
}
_, err := {{$mName}}(context.Background(), bot, params)
{{- else}}
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
{{- end}}
require.Error(t, err)
var pe *client.ParseError
require.ErrorAs(t, err, &pe)
}
func Test_{{$mName}}_ContextCanceled(t *testing.T) {
m := &genTestMockDoer{}
m.On("Do", mock.Anything).Return(nil, context.Canceled).Maybe()
ctx, cancel := context.WithCancel(context.Background())
cancel()
bot := client.New("test:token", client.WithHTTPClient(m))
{{- if .Params}}
params := &{{$mName}}Params{
{{- range .Params}}{{if .Required}}
{{.Name}}: {{sentinelValue .}},{{end}}
{{- end}}
}
_, err := {{$mName}}(ctx, bot, params)
{{- else}}
_, err := {{$mName}}(ctx, bot, &{{$mName}}Params{})
{{- end}}
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
}
// Test_{{$mName}}_MissingRequiredFields exercises Telegram's server-side
// validation: when a required field is omitted, Telegram returns 400 with
// a description like "Bad Request: <field> is empty". The library must
// surface this as *APIError with the ErrBadRequest sentinel.
func Test_{{$mName}}_MissingRequiredFields(t *testing.T) {
m := &genTestMockDoer{}
m.On("Do", mock.Anything).Return(
genTestResp(200, `{"ok":false,"error_code":400,"description":"Bad Request: chat_id is empty"}`), nil)
bot := client.New("test:token", client.WithHTTPClient(m))
// Send a Params with all required fields zeroed — simulates a caller
// that forgot to populate them. The bot library marshals as-is and
// surfaces Telegram's 400 reply.
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
require.Error(t, err)
var ae *client.APIError
require.ErrorAs(t, err, &ae)
require.Equal(t, 400, ae.Code)
require.True(t, errors.Is(err, client.ErrBadRequest))
require.False(t, ae.IsRetryable())
}
// Test_{{$mName}}_Forbidden exercises the 403 path (bot blocked by user,
// removed from chat, etc.). The library must surface the ErrForbidden
// sentinel.
func Test_{{$mName}}_Forbidden(t *testing.T) {
m := &genTestMockDoer{}
m.On("Do", mock.Anything).Return(
genTestResp(200, `{"ok":false,"error_code":403,"description":"Forbidden: bot was blocked by the user"}`), nil)
bot := client.New("test:token", client.WithHTTPClient(m))
{{- if .Params}}
params := &{{$mName}}Params{
{{- range .Params}}{{if .Required}}
{{.Name}}: {{sentinelValue .}},{{end}}
{{- end}}
}
_, err := {{$mName}}(context.Background(), bot, params)
{{- else}}
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
{{- end}}
require.Error(t, err)
var ae *client.APIError
require.ErrorAs(t, err, &ae)
require.Equal(t, 403, ae.Code)
require.True(t, errors.Is(err, client.ErrForbidden))
require.False(t, ae.IsRetryable())
}
// Test_{{$mName}}_ServerError exercises the 5xx path. The library must
// classify these as retryable so RetryDoer / user retry logic kicks in.
func Test_{{$mName}}_ServerError(t *testing.T) {
m := &genTestMockDoer{}
m.On("Do", mock.Anything).Return(
genTestResp(200, `{"ok":false,"error_code":500,"description":"Internal server error"}`), nil)
bot := client.New("test:token", client.WithHTTPClient(m))
{{- if .Params}}
params := &{{$mName}}Params{
{{- range .Params}}{{if .Required}}
{{.Name}}: {{sentinelValue .}},{{end}}
{{- end}}
}
_, err := {{$mName}}(context.Background(), bot, params)
{{- else}}
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
{{- end}}
require.Error(t, err)
var ae *client.APIError
require.ErrorAs(t, err, &ae)
require.Equal(t, 500, ae.Code)
require.True(t, ae.IsRetryable(), "5xx must be retryable")
}
{{end}}
+126
View File
@@ -0,0 +1,126 @@
// Code generated by cmd/genapi. DO NOT EDIT.
//go:build !ignore_autogenerated
// Package api contains the Telegram Bot API object types and method
// wrappers, generated from the live documentation by cmd/genapi.
package api
import (
"github.com/goccy/go-json"
"fmt"
"io"
)
var _ = io.Discard // keep import even if no fields use io
var _ = json.Marshal // keep import for UnmarshalXxx helpers
var _ = fmt.Errorf // keep import for UnmarshalXxx helpers
{{range .Types}}
{{- $td := . -}}
{{if .OneOf}}
// {{.Name}} is a union type. The following concrete variants implement
// it:
{{range .OneOf}}// - {{.}}
{{end}}//
{{docComment .Doc -}}
type {{.Name}} interface{ is{{.Name}}() }
{{range .OneOf}}
// is{{$td.Name}} is the marker method that makes {{.}} implement {{$td.Name}}.
func (*{{.}}) is{{$td.Name}}() {}
{{end}}
{{if hasDiscriminator .Name}}
// Unmarshal{{.Name}} decodes a {{.Name}} from JSON by inspecting the
// "{{discriminatorField .Name}}" field and dispatching to the correct concrete type.
func Unmarshal{{.Name}}(data []byte) ({{.Name}}, error) {
var probe struct {
V string `json:"{{discriminatorField .Name}}"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return nil, err
}
var v {{.Name}}
switch probe.V {
{{range $val, $typ := discriminatorMap .Name}} case {{printf "%q" $val}}:
v = &{{$typ}}{}
{{end}} default:
return nil, fmt.Errorf("{{.Name}}: unknown {{discriminatorField .Name}} %q", probe.V)
}
if err := json.Unmarshal(data, v); err != nil {
return nil, err
}
return v, nil
}
{{end}}
{{if isMaybeInaccessibleMessage .Name}}
// UnmarshalMaybeInaccessibleMessage decodes a JSON object into the correct
// MaybeInaccessibleMessage variant. Telegram uses the date field as a
// discriminator: date == 0 indicates InaccessibleMessage; any other value
// indicates a real Message.
func UnmarshalMaybeInaccessibleMessage(data []byte) (MaybeInaccessibleMessage, error) {
var probe struct {
Date int64 `json:"date"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return nil, fmt.Errorf("MaybeInaccessibleMessage: %w", err)
}
if probe.Date == 0 {
v := &InaccessibleMessage{}
if err := json.Unmarshal(data, v); err != nil {
return nil, fmt.Errorf("InaccessibleMessage: %w", err)
}
return v, nil
}
v := &Message{}
if err := json.Unmarshal(data, v); err != nil {
return nil, fmt.Errorf("Message: %w", err)
}
return v, nil
}
{{end}}
{{else}}
{{docComment .Doc -}}
type {{.Name}} struct {
{{range .Fields}}{{docComment .Doc}}{{goField .}}
{{end}}}
{{$unionFields := unionFields .}}{{if $unionFields}}
// UnmarshalJSON decodes {{.Name}} by dispatching union-typed fields
// ({{range $i, $u := $unionFields}}{{if $i}}, {{end}}{{$u.Field.Name}}{{end}}) through their concrete UnmarshalXxx helpers.
func (m *{{.Name}}) UnmarshalJSON(data []byte) error {
type Alias {{.Name}}
aux := &struct {
{{range $unionFields}}{{.Field.Name}} json.RawMessage `json:"{{.Field.JSONName}},omitempty"`
{{end}}*Alias
}{Alias: (*Alias)(m)}
if err := json.Unmarshal(data, aux); err != nil {
return err
}
{{range $unionFields}}{{$f := .Field}}{{$u := .UnionName}}
if len(aux.{{$f.Name}}) > 0 && string(aux.{{$f.Name}}) != "null" {
{{if isArrayUnion $f.Type}}var raws []json.RawMessage
if err := json.Unmarshal(aux.{{$f.Name}}, &raws); err != nil {
return fmt.Errorf("decoding {{$f.JSONName}}: %w", err)
}
decoded := make([]{{$u}}, 0, len(raws))
for i, r := range raws {
v, err := Unmarshal{{$u}}(r)
if err != nil {
return fmt.Errorf("decoding {{$f.JSONName}}[%d]: %w", i, err)
}
decoded = append(decoded, v)
}
m.{{$f.Name}} = decoded
{{else}}v, err := Unmarshal{{$u}}(aux.{{$f.Name}})
if err != nil {
return fmt.Errorf("decoding {{$f.JSONName}}: %w", err)
}
m.{{$f.Name}} = v
{{end}}
}
{{end}}
return nil
}
{{end}}
{{end}}
{{end}}
+211
View File
@@ -0,0 +1,211 @@
package main
import (
"testing"
"github.com/lukaszraczylo/go-telegram/internal/spec"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// parseTypeRef — edge cases
// ---------------------------------------------------------------------------
func TestParseTypeRef_Empty(t *testing.T) {
// Empty string → named with empty name (fallback).
got := parseTypeRef("")
require.Equal(t, spec.KindNamed, got.Kind)
require.Equal(t, "", got.Name)
}
func TestParseTypeRef_Whitespace(t *testing.T) {
got := parseTypeRef(" Integer ")
require.Equal(t, spec.KindPrimitive, got.Kind)
require.Equal(t, "int64", got.Name)
}
func TestParseTypeRef_True(t *testing.T) {
got := parseTypeRef("True")
require.Equal(t, spec.KindPrimitive, got.Kind)
require.Equal(t, "bool", got.Name)
}
func TestParseTypeRef_False(t *testing.T) {
got := parseTypeRef("False")
require.Equal(t, spec.KindPrimitive, got.Kind)
require.Equal(t, "bool", got.Name)
}
func TestParseTypeRef_FloatNumber(t *testing.T) {
got := parseTypeRef("Float number")
require.Equal(t, spec.KindPrimitive, got.Kind)
require.Equal(t, "float64", got.Name)
}
func TestParseTypeRef_Int(t *testing.T) {
got := parseTypeRef("Int")
require.Equal(t, spec.KindPrimitive, got.Kind)
require.Equal(t, "int64", got.Name)
}
func TestParseTypeRef_Bool(t *testing.T) {
got := parseTypeRef("Bool")
require.Equal(t, spec.KindPrimitive, got.Kind)
require.Equal(t, "bool", got.Name)
}
func TestParseTypeRef_CommaAndUnion(t *testing.T) {
// "Foo, Bar and Baz" → oneOf{Foo, Bar, Baz}
got := parseTypeRef("InputMediaPhoto, InputMediaVideo and InputMediaDocument")
require.Equal(t, spec.KindOneOf, got.Kind)
require.Len(t, got.Variants, 3)
require.Contains(t, got.Variants, "InputMediaPhoto")
require.Contains(t, got.Variants, "InputMediaVideo")
require.Contains(t, got.Variants, "InputMediaDocument")
}
func TestParseTypeRef_ArrayOfNothing(t *testing.T) {
// "Array of " with trailing space — TrimSpace removes the trailing space
// leaving "Array of" which does NOT match the "Array of " prefix, so it
// falls through to primitiveOrNamed and returns KindNamed (not KindArray).
got := parseTypeRef("Array of ")
require.Equal(t, spec.KindNamed, got.Kind)
}
// ---------------------------------------------------------------------------
// splitCommaAnd
// ---------------------------------------------------------------------------
func TestSplitCommaAnd_ThreeVariants(t *testing.T) {
got := splitCommaAnd("A, B and C")
require.Equal(t, []string{"A", "B", "C"}, got)
}
func TestSplitCommaAnd_FourVariants(t *testing.T) {
got := splitCommaAnd("A, B, C and D")
require.Equal(t, []string{"A", "B", "C", "D"}, got)
}
func TestSplitCommaAnd_ExtraSpaces(t *testing.T) {
got := splitCommaAnd(" Foo , Bar and Baz ")
require.Len(t, got, 3)
}
// ---------------------------------------------------------------------------
// goName — edge cases
// ---------------------------------------------------------------------------
func TestGoName_Empty(t *testing.T) {
require.Equal(t, "", goName(""))
}
func TestGoName_SingleWord(t *testing.T) {
require.Equal(t, "Photo", goName("photo"))
}
func TestGoName_JSON(t *testing.T) {
require.Equal(t, "JSON", goName("json"))
}
func TestGoName_HTML(t *testing.T) {
require.Equal(t, "HTML", goName("html"))
}
func TestGoName_HTTPS(t *testing.T) {
require.Equal(t, "HTTPS", goName("https"))
}
func TestGoName_AlreadyUpperSegment(t *testing.T) {
// Segment that starts with uppercase letter should be passed through.
require.Equal(t, "MediaGroupID", goName("media_group_id"))
}
// ---------------------------------------------------------------------------
// extractReturn — additional patterns
// ---------------------------------------------------------------------------
func TestExtractReturn_ArrayPattern(t *testing.T) {
desc := "Returns an Array of Update objects."
got := extractReturn(desc)
require.Equal(t, spec.KindArray, got.Kind)
require.Equal(t, "Update", got.ElemType.Name)
}
func TestExtractReturn_BoolPattern(t *testing.T) {
desc := "Returns True on success."
got := extractReturn(desc)
require.Equal(t, spec.KindPrimitive, got.Kind)
require.Equal(t, "bool", got.Name)
}
func TestExtractReturn_OnSuccessTrueIsReturned(t *testing.T) {
desc := "On success, true is returned."
got := extractReturn(desc)
require.Equal(t, spec.KindPrimitive, got.Kind)
require.Equal(t, "bool", got.Name)
}
func TestExtractReturn_NamedObject(t *testing.T) {
desc := "On success, returns a Message object."
got := extractReturn(desc)
require.Equal(t, spec.KindNamed, got.Kind)
require.Equal(t, "Message", got.Name)
}
func TestExtractReturn_MessageOrBool(t *testing.T) {
desc := "On success, the edited Message is returned, otherwise True is returned."
got := extractReturn(desc)
require.Equal(t, spec.KindNamed, got.Kind)
require.Equal(t, "MessageOrBool", got.Name)
}
func TestExtractReturn_InFormOf(t *testing.T) {
desc := "The answer is provided in form of a ChatInviteLink object."
got := extractReturn(desc)
require.Equal(t, spec.KindNamed, got.Kind)
require.Equal(t, "ChatInviteLink", got.Name)
}
func TestExtractReturn_Fallback(t *testing.T) {
// No recognized pattern → bool fallback.
got := extractReturn("This method does something interesting.")
require.Equal(t, spec.KindPrimitive, got.Kind)
require.Equal(t, "bool", got.Name)
}
func TestExtractReturn_MultipleReturnsFirstWins(t *testing.T) {
// Doc with multiple "Returns" phrases — first matching pattern should win.
// The indefinite-article pattern ("Returns a X object") appears earlier in
// the priority list than "Returns True", so it matches "Returns a Message"
// before the bool pattern can fire.
desc := "Returns True on success. You can also Returns a Message object later."
got := extractReturn(desc)
// The indefinite-article pattern fires first → returns Message (KindNamed).
require.Equal(t, spec.KindNamed, got.Kind)
require.Equal(t, "Message", got.Name)
}
// ---------------------------------------------------------------------------
// extractVersion
// ---------------------------------------------------------------------------
func TestExtractVersion_InTitle(t *testing.T) {
sections := []section{
{Title: "Bot API 7.3", Description: ""},
}
require.Equal(t, "7.3", extractVersion(sections))
}
func TestExtractVersion_InDescription(t *testing.T) {
sections := []section{
{Title: "April 2024", Description: "Released Bot API 7.2."},
}
require.Equal(t, "7.2", extractVersion(sections))
}
func TestExtractVersion_NotFound(t *testing.T) {
sections := []section{
{Title: "Introduction", Description: "Welcome to the API."},
}
require.Equal(t, "", extractVersion(sections))
}
+77
View File
@@ -0,0 +1,77 @@
// Command scrape parses the Telegram Bot API HTML page into the IR
// (internal/spec.API) and writes it to internal/spec/api.json.
//
// Usage:
//
// scrape -input <file> (read HTML from local file)
// scrape -url <url> (fetch HTML from URL; default: live docs)
// scrape -output <file> (output path; default: internal/spec/api.json)
package main
import (
"errors"
"flag"
"fmt"
"io"
"net/http"
"os"
"time"
"github.com/lukaszraczylo/go-telegram/internal/spec"
)
const defaultURL = "https://core.telegram.org/bots/api"
func main() {
input := flag.String("input", "", "local HTML file (overrides -url)")
url := flag.String("url", defaultURL, "URL to fetch HTML from")
output := flag.String("output", "internal/spec/api.json", "output path")
overridesPath := flag.String("overrides", "internal/spec/overrides.json", "path to overrides JSON")
flag.Parse()
if err := run(*input, *url, *output, *overridesPath); err != nil {
fmt.Fprintln(os.Stderr, "scrape:", err)
os.Exit(1)
}
}
func run(input, url, output, overridesPath string) error {
htmlBytes, err := readHTML(input, url)
if err != nil {
return fmt.Errorf("read html: %w", err)
}
api, err := scrape(htmlBytes)
if err != nil {
return fmt.Errorf("scrape: %w", err)
}
overrides, err := spec.LoadOverrides(overridesPath)
if err != nil {
return fmt.Errorf("load overrides: %w", err)
}
overrides.Apply(api)
return writeJSON(output, api)
}
func readHTML(input, url string) ([]byte, error) {
if input != "" {
return os.ReadFile(input)
}
c := &http.Client{Timeout: 30 * time.Second}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", "go-telegram codegen scraper")
resp, err := c.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, errors.New(resp.Status)
}
return io.ReadAll(resp.Body)
}
+149
View File
@@ -0,0 +1,149 @@
package main
import (
"regexp"
"strings"
"github.com/lukaszraczylo/go-telegram/internal/spec"
)
// extractReturn pulls the return type from a method's description prose.
//
// Patterns we handle (in priority order):
//
// "Returns an Array of X" / "On success, an Array of X is returned" → array of named X
// "an array of X of the sent messages is returned" → array of named X
// "the edited X is returned, otherwise True is returned" → XOrBool
// "Returns ... as a X object" / "Returns ... as X object" → named X
// "Returns ... as String on success" → string
// "On success, returns a X object" / "Returns a X object" → named X (indefinite article)
// "On success, an? X is returned" / "On success, the X is returned" → named X
// "Returns True" / "On success, true is returned" → bool
// "Returns the verb-ed X" → named X
// "On success, X is returned" → named X
// "Returns X on success" (no article) → named X
// "in form of a X" → named X
// fallback: bool
func extractReturn(desc string) spec.TypeRef {
// Normalise; strip *bold* markers because Telegram uses italics.
d := strings.ReplaceAll(desc, "*", "")
patterns := []struct {
re *regexp.Regexp
fn func([]string) spec.TypeRef
}{
// Array patterns first — most specific.
{regexp.MustCompile(`Returns an? [Aa]rray of ([A-Z][A-Za-z0-9]+)`), func(m []string) spec.TypeRef {
elem := primitiveOrNamed(m[1])
return spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
}},
{regexp.MustCompile(`On success(?:,)?\s+(?:an?\s+)?[Aa]rray of ([A-Z][A-Za-z0-9]+)(?:\s+objects?)?\s+(?:is|are|that\s+\S+\s+\S+\s+)?(?:is |are )?returned`), func(m []string) spec.TypeRef {
elem := primitiveOrNamed(m[1])
return spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
}},
// "an array of X of the sent messages is returned" (ForwardMessages/CopyMessages shape).
{regexp.MustCompile(`(?:[Oo]n success[,.]?\s+)?an? array of ([A-Z][A-Za-z0-9]+)(?:\s+of [^.]+?)?\s+(?:objects\s+)?(?:is|are) returned`), func(m []string) spec.TypeRef {
elem := primitiveOrNamed(m[1])
return spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
}},
// "Message or True" conditional return → XOrBool sentinel.
{regexp.MustCompile(`the (?:edited|sent|stopped)?\s*([A-Z][A-Za-z0-9]+)\s+is returned, otherwise (?:True|true) is returned`), func(m []string) spec.TypeRef {
return spec.TypeRef{Kind: spec.KindNamed, Name: m[1] + "OrBool"}
}},
// "Returns ... as a X object" / "Returns ... as X object" (with or without article).
{regexp.MustCompile(`[Rr]eturns? (?:.+? )?as (?:an? )?([A-Z][A-Za-z0-9]+) object`), func(m []string) spec.TypeRef {
return primitiveOrNamed(m[1])
}},
// "Returns ... as String on success" / "Returns ... as X on success" (named type after "as").
{regexp.MustCompile(`[Rr]eturns? (?:.+? )?as ([A-Z][A-Za-z0-9]+) on success`), func(m []string) spec.TypeRef {
return primitiveOrNamed(m[1])
}},
// Indefinite article: "On success, returns a X object" / "Returns a X object".
{regexp.MustCompile(`(?:[Oo]n success[,.]?\s+)?[Rr]eturns? an? ([A-Z][A-Za-z0-9]+)(?:\s+object)?`), func(m []string) spec.TypeRef {
return primitiveOrNamed(m[1])
}},
// "On success, an? X is returned" / "On success, the stopped X is returned".
{regexp.MustCompile(`On success,\s+(?:an?|the)?\s*(?:[a-z]+\s+)?([A-Z][A-Za-z0-9]+)(?:\s+object)?\s+is returned`), func(m []string) spec.TypeRef {
return primitiveOrNamed(m[1])
}},
// Explicit True — must come before the broad "Returns X" pattern.
{regexp.MustCompile(`Returns True`), func(m []string) spec.TypeRef {
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}
}},
{regexp.MustCompile(`(?i)on success, true is returned`), func(m []string) spec.TypeRef {
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}
}},
// "Returns the verb-ed X" — accepts any verb prefix (uploaded, revoked, …).
{regexp.MustCompile(`Returns (?:the|an?)\s+(?:[a-z]+ )?([A-Z][A-Za-z0-9]+)`), func(m []string) spec.TypeRef {
return primitiveOrNamed(m[1])
}},
// "On success, X is returned" (no article).
{regexp.MustCompile(`On success(?:,)?\s+(?:the\s+)?(?:newly\s+)?(?:edited\s+|sent\s+|created\s+|updated\s+)?([A-Z][A-Za-z0-9]+)\s+is returned`), func(m []string) spec.TypeRef {
return primitiveOrNamed(m[1])
}},
// "Returns X on success" (no article, e.g. "Returns OwnedGifts on success").
{regexp.MustCompile(`[Rr]eturns ([A-Z][A-Za-z0-9]+) on success`), func(m []string) spec.TypeRef {
return primitiveOrNamed(m[1])
}},
// "in form of a X".
{regexp.MustCompile(`in (?:the )?form of (?:a )?([A-Z][A-Za-z0-9]+)`), func(m []string) spec.TypeRef {
return primitiveOrNamed(m[1])
}},
}
for _, p := range patterns {
if m := p.re.FindStringSubmatch(d); m != nil {
return p.fn(m)
}
}
// Fallback: bool. Better than panic; method-by-method tests would
// catch any regression.
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}
}
// hasFilesParams returns true if any param mentions InputFile (the
// scraper convention triggering multipart/form-data).
func hasFilesParams(params []spec.Field) bool {
for _, p := range params {
if mentionsInputFile(p.Type) {
return true
}
}
return false
}
func mentionsInputFile(tr spec.TypeRef) bool {
switch tr.Kind {
case spec.KindNamed:
return tr.Name == "InputFile" || strings.HasPrefix(tr.Name, "InputMedia") || strings.HasPrefix(tr.Name, "InputPaidMedia")
case spec.KindArray:
if tr.ElemType != nil {
return mentionsInputFile(*tr.ElemType)
}
case spec.KindOneOf:
for _, v := range tr.Variants {
if v == "InputFile" || strings.HasPrefix(v, "InputMedia") || strings.HasPrefix(v, "InputPaidMedia") {
return true
}
}
}
return false
}
// extractVersion finds the API version string in a "Bot API X.Y[.Z]" heading.
var versionRE = regexp.MustCompile(`Bot API (\d+\.\d+(?:\.\d+)?)`)
// extractVersion finds the API version string. The live docs page emits
// the version as "<strong>Bot API X.Y</strong>" inside a paragraph below
// a date heading; the small fixture uses an h4 "Bot API X.Y" instead.
// Both shapes are handled here by also scanning section descriptions.
func extractVersion(sections []section) string {
for _, s := range sections {
if m := versionRE.FindStringSubmatch(s.Title); m != nil {
return m[1]
}
if m := versionRE.FindStringSubmatch(s.Description); m != nil {
return m[1]
}
}
return ""
}
+76
View File
@@ -0,0 +1,76 @@
package main
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/lukaszraczylo/go-telegram/internal/spec"
)
func TestExtractReturn(t *testing.T) {
cases := []struct {
in string
want spec.TypeRef
}{
{"Returns basic information about the bot in form of a User object.", spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
{"On success, the sent Message is returned.", spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}},
{"Returns an Array of Update objects.", spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}}},
{"Returns True on success.", spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
{"On success, True is returned.", spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
// Issue 5: "Message or True" conditional return → MessageOrBool sentinel.
{"On success, if the edited message is not an inline message, the edited Message is returned, otherwise True is returned.", spec.TypeRef{Kind: spec.KindNamed, Name: "MessageOrBool"}},
// Issue 1: new phrasings.
{"On success, returns a WebhookInfo object.", spec.TypeRef{Kind: spec.KindNamed, Name: "WebhookInfo"}},
{"Returns a UserProfilePhotos object.", spec.TypeRef{Kind: spec.KindNamed, Name: "UserProfilePhotos"}},
{"Returns the uploaded File.", spec.TypeRef{Kind: spec.KindNamed, Name: "File"}},
{"On success, the stopped Poll is returned.", spec.TypeRef{Kind: spec.KindNamed, Name: "Poll"}},
{"On success, an Array of MessageId is returned.", spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "MessageId"}}},
{"On success, an array of Message objects that were sent is returned.", spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}}},
// ForwardMessages/CopyMessages shape: "an array of X of the sent messages is returned".
{"On success, an array of MessageId of the sent messages is returned.", spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "MessageId"}}},
// "Returns X on success" (no article) — OwnedGifts, StarAmount, Story, MenuButton, etc.
{"Returns the gifts received and owned by a managed business account. Returns OwnedGifts on success.", spec.TypeRef{Kind: spec.KindNamed, Name: "OwnedGifts"}},
{"Returns StarAmount on success.", spec.TypeRef{Kind: spec.KindNamed, Name: "StarAmount"}},
{"Posts a story on behalf of a managed business account. Returns Story on success.", spec.TypeRef{Kind: spec.KindNamed, Name: "Story"}},
{"Returns MenuButton on success.", spec.TypeRef{Kind: spec.KindNamed, Name: "MenuButton"}},
// "Returns ... as X object" (no article before type) — ChatInviteLink variants.
{"Returns the new invite link as ChatInviteLink object.", spec.TypeRef{Kind: spec.KindNamed, Name: "ChatInviteLink"}},
{"Returns the revoked invite link as ChatInviteLink object.", spec.TypeRef{Kind: spec.KindNamed, Name: "ChatInviteLink"}},
// "Returns ... as a X object" (with article) — createForumTopic.
{"Returns information about the created topic as a ForumTopic object.", spec.TypeRef{Kind: spec.KindNamed, Name: "ForumTopic"}},
// "Returns ... as String on success" — exportChatInviteLink / createInvoiceLink.
{"Returns the new invite link as String on success.", spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}},
{"Returns the created invoice link as String on success.", spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}},
// "Returns Int on success" — getChatMemberCount.
{"Returns Int on success.", spec.TypeRef{Kind: spec.KindPrimitive, Name: "int64"}},
}
for _, c := range cases {
require.Equal(t, c.want, extractReturn(c.in), c.in)
}
}
func TestHasFilesParams(t *testing.T) {
require.True(t, hasFilesParams([]spec.Field{
{Type: spec.TypeRef{Kind: spec.KindNamed, Name: "InputFile"}},
}))
require.True(t, hasFilesParams([]spec.Field{
{Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputFile", "string"}}},
}))
require.False(t, hasFilesParams([]spec.Field{
{Type: spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}},
}))
// Issue 2: Array of InputMedia* union triggers HasFiles.
require.True(t, hasFilesParams([]spec.Field{
{Type: spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputMediaPhoto", "InputMediaVideo"}}}},
}))
}
func TestExtractVersion(t *testing.T) {
sections := []section{{Title: "Recent changes"}, {Title: "Bot API 7.10"}, {Title: "Available types"}}
require.Equal(t, "7.10", extractVersion(sections))
// Issue 4: 3-part version must not be truncated.
sections3 := []section{{Description: "Bot API 8.0.1"}}
require.Equal(t, "8.0.1", extractVersion(sections3))
}
+84
View File
@@ -0,0 +1,84 @@
package main
import (
"bytes"
"fmt"
"github.com/goccy/go-json"
"os"
"golang.org/x/net/html"
"github.com/lukaszraczylo/go-telegram/internal/spec"
)
// scrape (the package-level implementation overriding the stub in main.go;
// remove the stub from main.go in this task) parses the docs HTML into IR.
func scrape(htmlBytes []byte) (*spec.API, error) {
doc, err := html.Parse(bytes.NewReader(htmlBytes))
if err != nil {
return nil, fmt.Errorf("html parse: %w", err)
}
sections := walk(doc)
api := &spec.API{Version: extractVersion(sections)}
for _, s := range sections {
switch {
case isMethodTitle(s.Title):
api.Methods = append(api.Methods, methodFromSection(s))
case isTypeTitle(s.Title):
api.Types = append(api.Types, typeFromSection(s))
}
}
return api, nil
}
func typeFromSection(s section) spec.TypeDecl {
td := spec.TypeDecl{Name: s.Title, Doc: s.Description}
if len(s.Tables) > 0 {
td.Fields = parseFieldsTable(s.Tables[0])
} else if len(s.Lists) > 0 {
// Union: extract variant names from <li><a>...</a></li>.
td.OneOf = extractListLinks(s.Lists[0])
}
return td
}
func methodFromSection(s section) spec.MethodDecl {
md := spec.MethodDecl{Name: s.Title, Doc: s.Description, Returns: extractReturn(s.Description)}
if len(s.Tables) > 0 {
md.Params = parseParamsTable(s.Tables[0])
}
md.HasFiles = hasFilesParams(md.Params)
return md
}
// extractListLinks pulls anchor texts out of a <ul>: each <li><a>X</a></li>
// contributes "X" to the result. Used for union variant lists.
func extractListLinks(ul *html.Node) []string {
var names []string
var visit func(*html.Node)
visit = func(n *html.Node) {
if n.Type == html.ElementNode && n.Data == "a" {
names = append(names, textOf(n))
return
}
for c := n.FirstChild; c != nil; c = c.NextSibling {
visit(c)
}
}
visit(ul)
return names
}
// writeJSON marshals the IR with stable, human-readable formatting and
// writes it to path. Marshalling is deterministic: types and methods
// preserve scrape order; struct fields use IR-defined order.
func writeJSON(path string, api *spec.API) error {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.SetIndent("", " ")
enc.SetEscapeHTML(false)
if err := enc.Encode(api); err != nil {
return err
}
return os.WriteFile(path, buf.Bytes(), 0o644)
}
+36
View File
@@ -0,0 +1,36 @@
package main
import (
"bytes"
"encoding/json"
"flag"
"os"
"testing"
"github.com/stretchr/testify/require"
)
var update = flag.Bool("update", false, "update golden files")
func TestScrape_Golden_SmallFixture(t *testing.T) {
htmlBytes, err := os.ReadFile("../../testdata/html/small_fixture.html")
require.NoError(t, err)
api, err := scrape(htmlBytes)
require.NoError(t, err)
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.SetIndent("", " ")
enc.SetEscapeHTML(false)
require.NoError(t, enc.Encode(api))
goldenPath := "../../testdata/golden/api_small_fixture.json"
if *update {
require.NoError(t, os.WriteFile(goldenPath, buf.Bytes(), 0o644))
return
}
expected, err := os.ReadFile(goldenPath)
require.NoError(t, err, "missing golden; run `go test -update ./cmd/scrape/...` to create")
require.Equal(t, string(expected), buf.String())
}
+224
View File
@@ -0,0 +1,224 @@
package main
import (
"strings"
"golang.org/x/net/html"
"github.com/lukaszraczylo/go-telegram/internal/spec"
)
// parseFieldsTable walks a <table> for an object-type definition.
// Columns: Field, Type, Description (optional column orders are not
// supported; Telegram's docs use a stable layout).
//
// Optional fields are detected via the "Optional." prefix in the
// description text, which is the documented convention.
func parseFieldsTable(t *html.Node) []spec.Field {
rows := tableRows(t)
if len(rows) == 0 {
return nil
}
var fields []spec.Field
for _, row := range rows[1:] { // skip header
cells := rowCells(row)
if len(cells) < 3 {
continue
}
jname := strings.TrimSpace(textOf(cells[0]))
typeText := strings.TrimSpace(textOf(cells[1]))
desc := strings.TrimSpace(textOf(cells[2]))
required := !strings.HasPrefix(desc, "Optional.")
fields = append(fields, spec.Field{
Name: goName(jname),
JSONName: jname,
Type: parseTypeRef(typeText),
Required: required,
Doc: desc,
})
}
return fields
}
// parseParamsTable walks a <table> for a method definition.
// Columns: Parameter, Type, Required, Description.
func parseParamsTable(t *html.Node) []spec.Field {
rows := tableRows(t)
if len(rows) == 0 {
return nil
}
var params []spec.Field
for _, row := range rows[1:] {
cells := rowCells(row)
if len(cells) < 4 {
continue
}
jname := strings.TrimSpace(textOf(cells[0]))
typeText := strings.TrimSpace(textOf(cells[1]))
req := strings.EqualFold(strings.TrimSpace(textOf(cells[2])), "Yes")
desc := strings.TrimSpace(textOf(cells[3]))
params = append(params, spec.Field{
Name: goName(jname),
JSONName: jname,
Type: parseTypeRef(typeText),
Required: req,
Doc: desc,
})
}
return params
}
// tableRows returns the <tr> children of a <table>, skipping over
// any <thead>/<tbody> wrappers.
func tableRows(t *html.Node) []*html.Node {
var rows []*html.Node
var visit func(*html.Node)
visit = func(n *html.Node) {
if n.Type == html.ElementNode && n.Data == "tr" {
rows = append(rows, n)
return
}
for c := n.FirstChild; c != nil; c = c.NextSibling {
visit(c)
}
}
visit(t)
return rows
}
// rowCells returns the <td> (or <th>) children of a <tr>.
func rowCells(tr *html.Node) []*html.Node {
var cells []*html.Node
for c := tr.FirstChild; c != nil; c = c.NextSibling {
if c.Type == html.ElementNode && (c.Data == "td" || c.Data == "th") {
cells = append(cells, c)
}
}
return cells
}
// goName converts a snake_case JSON identifier to PascalCase.
// Special-cases common acronyms used in the Telegram docs.
func goName(s string) string {
if s == "" {
return ""
}
parts := strings.Split(s, "_")
var b strings.Builder
for _, p := range parts {
if p == "" {
continue
}
switch p {
case "id":
b.WriteString("ID")
case "url":
b.WriteString("URL")
case "ip":
b.WriteString("IP")
case "https":
b.WriteString("HTTPS")
case "json":
b.WriteString("JSON")
case "html":
b.WriteString("HTML")
default:
if p[0] >= 'a' && p[0] <= 'z' {
b.WriteByte(p[0] - 'a' + 'A')
b.WriteString(p[1:])
} else {
b.WriteString(p)
}
}
}
return b.String()
}
// parseTypeRef decodes the type-cell text into a spec.TypeRef.
//
// Recognised shapes:
//
// "Integer" → primitive int64
// "String" → primitive string
// "Boolean" / "True" → primitive bool
// "Float" / "Float number"→ primitive float64
// "Array of X" → array of (parseTypeRef of X)
// "Array of Array of X" → array of array of X
// "Foo" → named Foo
// "Foo or Bar" → oneOf {Foo, Bar}
// "InputFile or String" → oneOf (caller may translate to InputFile)
//
// parseTypeRef decodes the type-cell text into a spec.TypeRef.
//
// Recognised shapes:
//
// "Integer" → primitive int64
// "String" → primitive string
// "Boolean" / "True" → primitive bool
// "Float" / "Float number"→ primitive float64
// "Array of X" → array of (parseTypeRef of X)
// "Array of Array of X" → array of array of X
// "Foo" → named Foo
// "Foo or Bar" → oneOf {Foo, Bar}
// "Foo, Bar and Baz" → oneOf {Foo, Bar, Baz} (Telegram's comma+and union form)
// "InputFile or String" → oneOf (caller may translate to InputFile)
func parseTypeRef(s string) spec.TypeRef {
s = strings.TrimSpace(s)
// Array prefix.
if rest, ok := strings.CutPrefix(s, "Array of "); ok {
elem := parseTypeRef(rest)
return spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
}
// Comma-and union ("X, Y, Z and W") — used by Telegram for ≥3-variant unions.
if strings.Contains(s, ", ") && strings.Contains(s, " and ") {
parts := splitCommaAnd(s)
variants := make([]string, 0, len(parts))
for _, p := range parts {
variants = append(variants, primitiveOrNamed(strings.TrimSpace(p)).Name)
}
return spec.TypeRef{Kind: spec.KindOneOf, Variants: variants}
}
// "X or Y" union (the 2-variant form).
if strings.Contains(s, " or ") {
parts := strings.Split(s, " or ")
variants := make([]string, 0, len(parts))
for _, p := range parts {
variants = append(variants, primitiveOrNamed(strings.TrimSpace(p)).Name)
}
return spec.TypeRef{Kind: spec.KindOneOf, Variants: variants}
}
return primitiveOrNamed(s)
}
// splitCommaAnd splits "A, B, C and D" → ["A", "B", "C", "D"].
func splitCommaAnd(s string) []string {
// Replace " and " with ", " then split on ", ".
s = strings.ReplaceAll(s, " and ", ", ")
parts := strings.Split(s, ", ")
out := make([]string, 0, len(parts))
for _, p := range parts {
if p = strings.TrimSpace(p); p != "" {
out = append(out, p)
}
}
return out
}
// primitiveOrNamed maps a single-word type cell to either a primitive
// or a named TypeRef. Unrecognised words are treated as named types.
func primitiveOrNamed(s string) spec.TypeRef {
switch s {
case "Integer", "Int":
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "int64"}
case "String":
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}
case "Boolean", "Bool", "True", "False":
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}
case "Float", "Float number":
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "float64"}
default:
return spec.TypeRef{Kind: spec.KindNamed, Name: s}
}
}
+92
View File
@@ -0,0 +1,92 @@
package main
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/lukaszraczylo/go-telegram/internal/spec"
)
func TestGoName(t *testing.T) {
cases := []struct{ in, want string }{
{"chat_id", "ChatID"},
{"first_name", "FirstName"},
{"is_bot", "IsBot"},
{"url", "URL"},
{"ip_address", "IPAddress"},
{"language_code", "LanguageCode"},
{"webhook_URL", "WebhookURL"}, // Issue 3: already-uppercase segment must not be corrupted.
}
for _, c := range cases {
require.Equal(t, c.want, goName(c.in), c.in)
}
}
func TestParseTypeRef(t *testing.T) {
cases := []struct {
in string
want spec.TypeRef
}{
{"Integer", spec.TypeRef{Kind: spec.KindPrimitive, Name: "int64"}},
{"String", spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}},
{"Boolean", spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
{"Float", spec.TypeRef{Kind: spec.KindPrimitive, Name: "float64"}},
{"Message", spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}},
{"Array of Update", spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}}},
{"Array of Array of PhotoSize", spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "PhotoSize"}}}},
{"Integer or String", spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64", "string"}}},
{"InputFile or String", spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputFile", "string"}}},
}
for _, c := range cases {
require.Equal(t, c.want, parseTypeRef(c.in), c.in)
}
}
func TestParseFieldsTable_FromFixture(t *testing.T) {
doc := parse(t, "../../testdata/html/small_fixture.html")
sections := walk(doc)
var user *section
for i := range sections {
if sections[i].Title == "User" {
user = &sections[i]
break
}
}
require.NotNil(t, user)
require.Len(t, user.Tables, 1)
fields := parseFieldsTable(user.Tables[0])
require.Len(t, fields, 4)
require.Equal(t, "ID", fields[0].Name)
require.Equal(t, "id", fields[0].JSONName)
require.Equal(t, spec.KindPrimitive, fields[0].Type.Kind)
require.True(t, fields[0].Required)
require.Equal(t, "LastName", fields[3].Name)
require.False(t, fields[3].Required) // "Optional." prefix
}
func TestParseParamsTable_FromFixture(t *testing.T) {
doc := parse(t, "../../testdata/html/small_fixture.html")
sections := walk(doc)
var sm *section
for i := range sections {
if sections[i].Title == "sendMessage" {
sm = &sections[i]
break
}
}
require.NotNil(t, sm)
require.Len(t, sm.Tables, 1)
params := parseParamsTable(sm.Tables[0])
require.Len(t, params, 3)
require.Equal(t, "ChatID", params[0].Name)
require.True(t, params[0].Required)
require.Equal(t, spec.KindOneOf, params[0].Type.Kind)
require.Equal(t, []string{"int64", "string"}, params[0].Type.Variants)
require.Equal(t, "ParseMode", params[2].Name)
require.False(t, params[2].Required) // "Optional"
}
+137
View File
@@ -0,0 +1,137 @@
package main
import (
"strings"
"golang.org/x/net/html"
)
// section is an h4-anchored block of the docs page. Title is the
// heading text (e.g. "User" or "sendMessage"). Description is the
// concatenation of immediately-following <p> paragraphs (until the
// next h4 / h3 / table / list). Tables and Lists hold raw nodes for
// later parsing by the table/oneof extractors.
type section struct {
Title string
Description string
Tables []*html.Node // <table> nodes
Lists []*html.Node // <ul> nodes (used for oneof variant lists)
}
// walk parses the page and returns sections in document order.
// Sections whose title contains a space (e.g. "Bot API 7.10") are
// included; later passes ignore them or treat them specially.
func walk(doc *html.Node) []section {
var (
sections []section
current *section
)
var visit func(n *html.Node)
visit = func(n *html.Node) {
if n.Type == html.ElementNode {
switch n.Data {
case "h4":
if current != nil {
sections = append(sections, *current)
}
current = &section{Title: textOf(n)}
// Don't recurse into the heading; we already have its text.
return
case "h3":
// h3 (e.g. "Available methods") delimits a section;
// flush the current h4 section but do not start a new one.
if current != nil {
sections = append(sections, *current)
current = nil
}
return
case "p":
if current != nil {
if current.Description != "" {
current.Description += "\n"
}
current.Description += strings.TrimSpace(textOf(n))
}
return
case "table":
if current != nil {
current.Tables = append(current.Tables, n)
}
return
case "ul":
if current != nil {
current.Lists = append(current.Lists, n)
}
return
}
}
for c := n.FirstChild; c != nil; c = c.NextSibling {
visit(c)
}
}
visit(doc)
if current != nil {
sections = append(sections, *current)
}
return sections
}
// textOf returns the concatenated text content of n and descendants,
// with adjacent whitespace collapsed to single spaces.
func textOf(n *html.Node) string {
var sb strings.Builder
var w func(*html.Node)
w = func(n *html.Node) {
if n.Type == html.TextNode {
sb.WriteString(n.Data)
return
}
for c := n.FirstChild; c != nil; c = c.NextSibling {
w(c)
}
}
w(n)
return collapseWS(sb.String())
}
func collapseWS(s string) string {
var b strings.Builder
prevSpace := false
for _, r := range s {
if r == ' ' || r == '\t' || r == '\n' || r == '\r' {
if !prevSpace {
b.WriteByte(' ')
}
prevSpace = true
continue
}
prevSpace = false
b.WriteRune(r)
}
return strings.TrimSpace(b.String())
}
// isMethodTitle returns true for headings that look like method names
// (camelCase starting with a lowercase letter; e.g. "sendMessage").
func isMethodTitle(s string) bool {
if s == "" {
return false
}
r := s[0]
return r >= 'a' && r <= 'z'
}
// isTypeTitle returns true for headings that look like type names
// (PascalCase; e.g. "Message"). Allows a leading-uppercase only;
// excludes spaces (which would denote a header like "Bot API 7.10").
func isTypeTitle(s string) bool {
if s == "" {
return false
}
r := s[0]
if r < 'A' || r > 'Z' {
return false
}
return !strings.Contains(s, " ")
}
+69
View File
@@ -0,0 +1,69 @@
package main
import (
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/net/html"
)
func parse(t *testing.T, path string) *html.Node {
t.Helper()
f, err := os.Open(path)
require.NoError(t, err)
defer f.Close()
doc, err := html.Parse(f)
require.NoError(t, err)
return doc
}
func TestWalk_FixtureSections(t *testing.T) {
doc := parse(t, "../../testdata/html/small_fixture.html")
sections := walk(doc)
titles := make([]string, 0, len(sections))
for _, s := range sections {
titles = append(titles, s.Title)
}
require.Contains(t, titles, "User")
require.Contains(t, titles, "ChatMember")
require.Contains(t, titles, "getMe")
require.Contains(t, titles, "sendMessage")
require.Contains(t, titles, "sendDocument")
require.Contains(t, titles, "getUpdates")
require.Contains(t, titles, "Bot API 7.10")
}
func TestIsMethodTitle(t *testing.T) {
require.True(t, isMethodTitle("sendMessage"))
require.True(t, isMethodTitle("getMe"))
require.False(t, isMethodTitle("Message"))
require.False(t, isMethodTitle(""))
require.False(t, isMethodTitle("Bot API 7.10"))
}
func TestIsTypeTitle(t *testing.T) {
require.True(t, isTypeTitle("Message"))
require.True(t, isTypeTitle("ChatMember"))
require.False(t, isTypeTitle("sendMessage"))
require.False(t, isTypeTitle("Bot API 7.10"))
require.False(t, isTypeTitle(""))
}
func TestSection_DescriptionAndTables(t *testing.T) {
doc := parse(t, "../../testdata/html/small_fixture.html")
sections := walk(doc)
var sm *section
for i, s := range sections {
if s.Title == "sendMessage" {
sm = &sections[i]
break
}
}
require.NotNil(t, sm)
require.True(t, strings.Contains(sm.Description, "Use this method to send text messages"))
require.Len(t, sm.Tables, 1)
}
+40
View File
@@ -0,0 +1,40 @@
// Package dispatch provides a typed router for Telegram updates. It
// consumes any transport.Updater and dispatches updates to handlers
// registered by command, regex, or update-payload kind.
package dispatch
import (
"context"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/client"
)
// Context bundles the per-update state every handler receives.
//
// Ctx is the request context propagated from Router.Run; cancelling the
// run cancels every handler.
//
// Bot is the API client. Handlers reply by calling api.SendMessage(c.Ctx,
// c.Bot, ...) etc.
//
// Update is the raw update; payload-typed handlers also receive a
// narrowed pointer to one of its sub-fields.
//
// Values is a per-update bag matchers populate. Conventional keys:
//
// "command": string, the matched bot command (e.g. "/start")
// "command_args": string, everything after the command
// "regex_match": []string, regex sub-matches when OnText matches
type Context struct {
Ctx context.Context
Bot *client.Bot
Update *api.Update
Values map[string]any
}
// NewContext constructs a Context. Used by Router internally; exposed for
// custom test harnesses.
func NewContext(ctx context.Context, b *client.Bot, u *api.Update) *Context {
return &Context{Ctx: ctx, Bot: b, Update: u, Values: map[string]any{}}
}
+494
View File
@@ -0,0 +1,494 @@
package conversation_test
import (
"context"
"strings"
"sync"
"sync/atomic"
"testing"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/client"
"github.com/lukaszraczylo/go-telegram/dispatch"
"github.com/lukaszraczylo/go-telegram/dispatch/conversation"
"github.com/stretchr/testify/require"
)
// ---- helpers ---------------------------------------------------------------
func msgUpd(userID, chatID int64, text string) api.Update {
return api.Update{
UpdateID: 1,
Message: &api.Message{
MessageID: 1,
From: &api.User{ID: userID},
Chat: api.Chat{ID: chatID},
Text: text,
},
}
}
func makeCtx(u *api.Update) *dispatch.Context {
return dispatch.NewContext(context.Background(), client.New("t"), u)
}
// anyMsg matches any update that has a Message.
var anyMsg = func(u *api.Update) bool { return u.Message != nil }
// hasPrefix returns a filter matching updates whose Message.Text has prefix p.
func hasPrefix(p string) dispatch.Filter[*api.Update] {
return func(u *api.Update) bool {
return u.Message != nil && strings.HasPrefix(u.Message.Text, p)
}
}
// fakeUpdater feeds a fixed set of updates then closes (mirrors router_test.go).
type fakeUpdater struct{ ch chan api.Update }
func newFake(ups ...api.Update) *fakeUpdater {
ch := make(chan api.Update, len(ups))
for _, u := range ups {
ch <- u
}
close(ch)
return &fakeUpdater{ch: ch}
}
func (f *fakeUpdater) Updates() <-chan api.Update { return f.ch }
func (f *fakeUpdater) Run(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }
func (f *fakeUpdater) Stop(ctx context.Context) error { return nil }
// ---- Storage tests ---------------------------------------------------------
func TestStorage_ErrKeyNotFound(t *testing.T) {
s := conversation.NewMemoryStorage()
_, err := s.Get(context.Background(), "missing")
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
}
func TestStorage_SetAndGet(t *testing.T) {
ctx := context.Background()
s := conversation.NewMemoryStorage()
require.NoError(t, s.Set(ctx, "k", "state_a"))
v, err := s.Get(ctx, "k")
require.NoError(t, err)
require.Equal(t, conversation.State("state_a"), v)
}
func TestStorage_Delete(t *testing.T) {
ctx := context.Background()
s := conversation.NewMemoryStorage()
require.NoError(t, s.Set(ctx, "k", "state_a"))
require.NoError(t, s.Delete(ctx, "k"))
_, err := s.Get(ctx, "k")
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
}
func TestStorage_DeleteNonExistentIsNoop(t *testing.T) {
require.NoError(t, conversation.NewMemoryStorage().Delete(context.Background(), "gone"))
}
// ---- Key strategy tests ----------------------------------------------------
func TestKeyByUser_Variants(t *testing.T) {
t.Run("message", func(t *testing.T) {
u := msgUpd(42, 100, "hi")
require.Equal(t, "u:42", conversation.KeyByUser(&u))
})
t.Run("edited_message", func(t *testing.T) {
u := api.Update{EditedMessage: &api.Message{From: &api.User{ID: 7}, Chat: api.Chat{ID: 1}}}
require.Equal(t, "u:7", conversation.KeyByUser(&u))
})
t.Run("callback_query", func(t *testing.T) {
u := api.Update{CallbackQuery: &api.CallbackQuery{From: api.User{ID: 99}}}
require.Equal(t, "u:99", conversation.KeyByUser(&u))
})
t.Run("inline_query", func(t *testing.T) {
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
require.Equal(t, "u:5", conversation.KeyByUser(&u))
})
t.Run("empty", func(t *testing.T) {
require.Equal(t, "", conversation.KeyByUser(&api.Update{}))
})
}
func TestKeyByChat_Variants(t *testing.T) {
t.Run("message", func(t *testing.T) {
u := msgUpd(1, 200, "")
require.Equal(t, "c:200", conversation.KeyByChat(&u))
})
t.Run("inline_has_no_chat", func(t *testing.T) {
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
require.Equal(t, "", conversation.KeyByChat(&u))
})
}
func TestKeyByUserAndChat(t *testing.T) {
u := msgUpd(42, 100, "")
require.Equal(t, "uc:100:42", conversation.KeyByUserAndChat(&u))
}
// ---- Handler / state machine tests -----------------------------------------
func buildConv() *conversation.Conversation {
return &conversation.Conversation{
EntryPoints: []conversation.Step{{
Filter: hasPrefix("/start"),
Handler: func(c *dispatch.Context, u *api.Update) error {
return conversation.Next("await_name")
},
}},
States: map[conversation.State][]conversation.Step{
"await_name": {{
Filter: anyMsg,
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("await_age") },
}},
"await_age": {{
Filter: anyMsg,
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
}},
},
Exits: []conversation.Step{{
Filter: hasPrefix("/cancel"),
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
}},
}
}
func TestConversation_FullFlow(t *testing.T) {
conv := buildConv()
var downstream int
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
downstream++
return nil
})
mw := conv.Dispatch(noop)
key := "uc:1:42"
// 1. /start → enters, state = await_name
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
v, err := conv.Storage.Get(context.Background(), key)
require.NoError(t, err)
require.Equal(t, conversation.State("await_name"), v)
require.Equal(t, 0, downstream, "entry claimed update")
// 2. name → state = await_age
u2 := msgUpd(42, 1, "Alice")
require.NoError(t, mw(makeCtx(&u2), &u2))
v, err = conv.Storage.Get(context.Background(), key)
require.NoError(t, err)
require.Equal(t, conversation.State("await_age"), v)
// 3. age → End, key deleted
u3 := msgUpd(42, 1, "30")
require.NoError(t, mw(makeCtx(&u3), &u3))
_, err = conv.Storage.Get(context.Background(), key)
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
}
func TestConversation_ExitsCancelMidFlow(t *testing.T) {
conv := buildConv()
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
mw := conv.Dispatch(noop)
// Start conversation.
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
_, err := conv.Storage.Get(context.Background(), "uc:1:42")
require.NoError(t, err)
// Cancel mid-flow.
u2 := msgUpd(42, 1, "/cancel")
require.NoError(t, mw(makeCtx(&u2), &u2))
_, err = conv.Storage.Get(context.Background(), "uc:1:42")
require.ErrorIs(t, err, conversation.ErrKeyNotFound, "exit should clear state")
}
func TestConversation_FallbackFiresWhenNoStateStepMatches(t *testing.T) {
fallbackHit := false
conv := &conversation.Conversation{
EntryPoints: []conversation.Step{{
Filter: hasPrefix("/start"),
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("waiting") },
}},
States: map[conversation.State][]conversation.Step{
// No steps for "waiting" that match a callback query.
"waiting": {},
},
Fallbacks: []conversation.Step{{
Filter: anyMsg,
Handler: func(c *dispatch.Context, u *api.Update) error {
fallbackHit = true
return nil
},
}},
}
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
mw := conv.Dispatch(noop)
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
u2 := msgUpd(42, 1, "unexpected text")
require.NoError(t, mw(makeCtx(&u2), &u2))
require.True(t, fallbackHit, "fallback should have fired")
}
func TestConversation_NoActiveConv_PassesToDownstream(t *testing.T) {
conv := buildConv()
downstreamHit := false
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
downstreamHit = true
return nil
})
mw := conv.Dispatch(downstream)
// Random message that doesn't match /start
u := msgUpd(42, 1, "hello")
require.NoError(t, mw(makeCtx(&u), &u))
require.True(t, downstreamHit, "unmatched update should reach downstream")
}
func TestConversation_EmptyKey_PassesThrough(t *testing.T) {
// InlineQuery has no chatID → KeyByUserAndChat returns "" → pass through.
conv := buildConv()
downstreamHit := false
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
downstreamHit = true
return nil
})
mw := conv.Dispatch(downstream)
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
require.NoError(t, mw(makeCtx(&u), &u))
require.True(t, downstreamHit)
}
func TestConversation_AllowReEntry(t *testing.T) {
conv := buildConv()
conv.AllowReEntry = true
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
mw := conv.Dispatch(noop)
// Start.
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
require.Equal(t, conversation.State("await_name"), v)
// Advance once.
u2 := msgUpd(42, 1, "Alice")
require.NoError(t, mw(makeCtx(&u2), &u2))
v, _ = conv.Storage.Get(context.Background(), "uc:1:42")
require.Equal(t, conversation.State("await_age"), v)
// Re-enter with /start — should restart to await_name even though mid-flow.
u3 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u3), &u3))
v, _ = conv.Storage.Get(context.Background(), "uc:1:42")
require.Equal(t, conversation.State("await_name"), v, "AllowReEntry should restart")
}
func TestConversation_NoReEntry_EntryIgnoredWhenActive(t *testing.T) {
conv := buildConv()
conv.AllowReEntry = false
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
mw := conv.Dispatch(noop)
// Start.
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
// Advance to await_age.
u2 := msgUpd(42, 1, "Alice")
require.NoError(t, mw(makeCtx(&u2), &u2))
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
require.Equal(t, conversation.State("await_age"), v)
// /start again — should NOT restart; state should stay await_age since
// /start matches the state step filter (anyMsg) and advances.
// Actually /start is handled by "await_age" anyMsg step → End().
u3 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u3), &u3))
// State ended (End() called by await_age step).
_, err := conv.Storage.Get(context.Background(), "uc:1:42")
require.ErrorIs(t, err, conversation.ErrKeyNotFound, "state step should have consumed /start when AllowReEntry=false")
}
func TestConversation_StayInState_NilReturn(t *testing.T) {
// Handler returning nil keeps state unchanged.
stored := false
conv := &conversation.Conversation{
EntryPoints: []conversation.Step{{
Filter: hasPrefix("/start"),
Handler: func(c *dispatch.Context, u *api.Update) error {
return conversation.Next("waiting")
},
}},
States: map[conversation.State][]conversation.Step{
"waiting": {{
Filter: anyMsg,
Handler: func(c *dispatch.Context, u *api.Update) error {
stored = true
return nil // stay in current state
},
}},
},
}
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
mw := conv.Dispatch(noop)
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
u2 := msgUpd(42, 1, "something")
require.NoError(t, mw(makeCtx(&u2), &u2))
require.True(t, stored)
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
require.Equal(t, conversation.State("waiting"), v, "nil return should leave state unchanged")
}
func TestConversation_ActiveNoMatch_Swallows(t *testing.T) {
// Active conversation with no matching state step and no fallback:
// update is swallowed (not passed downstream).
conv := &conversation.Conversation{
EntryPoints: []conversation.Step{{
Filter: hasPrefix("/start"),
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("waiting") },
}},
States: map[conversation.State][]conversation.Step{
"waiting": {{
// Only matches /done specifically.
Filter: hasPrefix("/done"),
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
}},
},
}
downstreamHit := false
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
downstreamHit = true
return nil
})
mw := conv.Dispatch(downstream)
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
// Random text doesn't match /done and there's no fallback → swallowed.
u2 := msgUpd(42, 1, "random")
require.NoError(t, mw(makeCtx(&u2), &u2))
require.False(t, downstreamHit, "active conv with no matching step should swallow")
}
// ---- Via Router.Run --------------------------------------------------------
func TestConversation_ViaRouter(t *testing.T) {
var steps atomic.Int32
conv := &conversation.Conversation{
EntryPoints: []conversation.Step{{
Filter: hasPrefix("/start"),
Handler: func(c *dispatch.Context, u *api.Update) error {
steps.Add(1)
return conversation.Next("await_name")
},
}},
States: map[conversation.State][]conversation.Step{
"await_name": {{
Filter: anyMsg,
Handler: func(c *dispatch.Context, u *api.Update) error {
steps.Add(1)
return conversation.Next("await_age")
},
}},
"await_age": {{
Filter: anyMsg,
Handler: func(c *dispatch.Context, u *api.Update) error {
steps.Add(1)
return conversation.End()
},
}},
},
}
router := dispatch.New(client.New("t"), dispatch.WithMaxConcurrency(0)) // serial
router.Use(conv.Dispatch)
ups := []api.Update{
msgUpd(42, 1, "/start"),
msgUpd(42, 1, "Alice"),
msgUpd(42, 1, "30"),
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error, 1)
go func() { errCh <- router.Run(ctx, newFake(ups...)) }()
// Wait for updater channel to drain (Run returns when closed).
err := <-errCh
if err != nil && err != context.Canceled {
t.Fatalf("Run error: %v", err)
}
require.Equal(t, int32(3), steps.Load(), "all three steps should have fired")
}
// ---- Concurrent storage safety ---------------------------------------------
func TestConversation_ConcurrentStorageAccess(t *testing.T) {
// 15 goroutines each running a full /start → name → age flow against the
// same shared storage but DIFFERENT keys (one per goroutine). Validates
// no data races.
const numUsers = 15
conv := buildConv()
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
mw := conv.Dispatch(noop)
var wg sync.WaitGroup
wg.Add(numUsers)
for i := 0; i < numUsers; i++ {
go func(uid int64) {
defer wg.Done()
u1 := msgUpd(uid, uid, "/start")
_ = mw(makeCtx(&u1), &u1)
u2 := msgUpd(uid, uid, "Alice")
_ = mw(makeCtx(&u2), &u2)
u3 := msgUpd(uid, uid, "30")
_ = mw(makeCtx(&u3), &u3)
}(int64(i + 1))
}
wg.Wait()
// Race detector catches bugs; no assertion needed beyond clean finish.
}
func TestConversation_ConcurrentSameKey(t *testing.T) {
// 12 goroutines hammer the same key concurrently. Storage must not panic
// or corrupt state. Race detector validates lock discipline.
const goroutines = 12
s := conversation.NewMemoryStorage()
ctx := context.Background()
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(i int) {
defer wg.Done()
_ = s.Set(ctx, "shared", conversation.State("step"))
_, _ = s.Get(ctx, "shared")
if i%4 == 0 {
_ = s.Delete(ctx, "shared")
}
}(i)
}
wg.Wait()
}
+176
View File
@@ -0,0 +1,176 @@
package conversation
import (
"context"
"errors"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// stateTransition is a sentinel error type carrying a state transition
// or end signal. Conversation handlers return one of these (via Next or
// End helpers below) to drive the state machine.
type stateTransition struct {
next State
end bool
}
func (e *stateTransition) Error() string {
if e.end {
return "conversation: end"
}
return "conversation: → " + string(e.next)
}
// Next signals the conversation should advance to the given state.
// Conversation handlers return Next("state_name") to transition.
func Next(s State) error {
return &stateTransition{next: s}
}
// End signals the conversation has finished and state should be cleared.
// Conversation handlers return End() to terminate.
func End() error {
return &stateTransition{end: true}
}
// Handler defines a step in the conversation. Receives the dispatch context
// and the raw update. Returns:
// - nil to stay in the current state
// - Next("state") to transition to a different state
// - End() to end the conversation
// - any other non-nil error to surface to the dispatcher (state unchanged)
type Handler func(ctx *dispatch.Context, u *api.Update) error
// Step pairs a filter with a handler for one conversation step.
type Step struct {
Filter dispatch.Filter[*api.Update]
Handler Handler
}
// Conversation is a stateful handler with entry, per-state, exit and
// fallback steps. A conversation is keyed by KeyStrategy (default
// KeyByUserAndChat) and persisted by Storage (default in-memory).
type Conversation struct {
// EntryPoints starts a new conversation when a matching filter fires
// and no conversation is already active for the key.
EntryPoints []Step
// States maps each state to the steps that handle it.
States map[State][]Step
// Exits, if any match, end the active conversation early. Useful for
// /cancel-style commands.
Exits []Step
// Fallbacks run when no state step matches the current update.
Fallbacks []Step
// Storage persists conversation state. Defaults to NewMemoryStorage.
Storage Storage
// KeyStrategy derives the persistence key. Defaults to KeyByUserAndChat.
KeyStrategy KeyStrategy
// AllowReEntry, when true, lets entry-point steps fire even while a
// conversation is already active for the key (effectively restarting it).
AllowReEntry bool
}
// Dispatch is a global middleware-shaped Handler that consumes updates
// and routes them through the conversation graph. Register via
// router.Use(conv.Dispatch).
//
// If the conversation claims an update, downstream handlers are skipped.
// If the conversation does not claim it, downstream handlers run as normal.
func (c *Conversation) Dispatch(next dispatch.Handler[*api.Update]) dispatch.Handler[*api.Update] {
if c.Storage == nil {
c.Storage = NewMemoryStorage()
}
if c.KeyStrategy == nil {
c.KeyStrategy = KeyByUserAndChat
}
return func(dctx *dispatch.Context, u *api.Update) error {
key := c.KeyStrategy(u)
if key == "" {
return next(dctx, u)
}
ctx := dctx.Ctx
current, err := c.Storage.Get(ctx, key)
if err != nil && !errors.Is(err, ErrKeyNotFound) {
return err
}
active := !errors.Is(err, ErrKeyNotFound)
// Try exits first (always allowed if conversation is active).
if active {
for _, step := range c.Exits {
if step.Filter(u) {
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
return err
}
return nil
}
}
}
// Try entry points (only if no active conversation, or AllowReEntry).
if !active || c.AllowReEntry {
for _, step := range c.EntryPoints {
if step.Filter(u) {
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
return err
}
return nil
}
}
}
if !active {
return next(dctx, u)
}
// Active conversation: try state steps.
for _, step := range c.States[current] {
if step.Filter(u) {
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
return err
}
return nil
}
}
// Fallbacks if no state step matched.
for _, step := range c.Fallbacks {
if step.Filter(u) {
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
return err
}
return nil
}
}
// Active conversation but no step matched and no fallback: swallow the
// update (do NOT pass to downstream handlers, since the user is
// mid-conversation and an unrelated handler would surprise them).
return nil
}
}
// runStep invokes the handler and applies its return-value state transition.
func (c *Conversation) runStep(ctx context.Context, dctx *dispatch.Context, u *api.Update, key string, h Handler) error {
err := h(dctx, u)
if err == nil {
return nil
}
var trans *stateTransition
if errors.As(err, &trans) {
if trans.end {
return c.Storage.Delete(ctx, key)
}
return c.Storage.Set(ctx, key, trans.next)
}
return err
}
+43
View File
@@ -0,0 +1,43 @@
package conversation
import (
"context"
"sync"
)
// MemoryStorage is the default in-process Storage. It is safe for
// concurrent use. Conversation state is lost on process restart; use
// a custom Storage backed by a database for persistent flows.
type MemoryStorage struct {
mu sync.RWMutex
state map[string]State
}
// NewMemoryStorage constructs an empty in-memory storage.
func NewMemoryStorage() *MemoryStorage {
return &MemoryStorage{state: map[string]State{}}
}
func (s *MemoryStorage) Get(_ context.Context, key string) (State, error) {
s.mu.RLock()
defer s.mu.RUnlock()
v, ok := s.state[key]
if !ok {
return "", ErrKeyNotFound
}
return v, nil
}
func (s *MemoryStorage) Set(_ context.Context, key string, state State) error {
s.mu.Lock()
defer s.mu.Unlock()
s.state[key] = state
return nil
}
func (s *MemoryStorage) Delete(_ context.Context, key string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.state, key)
return nil
}
+87
View File
@@ -0,0 +1,87 @@
package conversation
import (
"context"
"sync"
"testing"
"github.com/stretchr/testify/require"
)
func TestMemoryStorage_GetSetDelete(t *testing.T) {
ctx := context.Background()
s := NewMemoryStorage()
// Get on empty key returns ErrKeyNotFound.
_, err := s.Get(ctx, "k1")
require.ErrorIs(t, err, ErrKeyNotFound)
// Set then Get returns the stored state.
require.NoError(t, s.Set(ctx, "k1", "step_a"))
v, err := s.Get(ctx, "k1")
require.NoError(t, err)
require.Equal(t, State("step_a"), v)
// Overwrite works.
require.NoError(t, s.Set(ctx, "k1", "step_b"))
v, err = s.Get(ctx, "k1")
require.NoError(t, err)
require.Equal(t, State("step_b"), v)
// Delete removes the key.
require.NoError(t, s.Delete(ctx, "k1"))
_, err = s.Get(ctx, "k1")
require.ErrorIs(t, err, ErrKeyNotFound)
// Delete of non-existent key is a no-op (no error).
require.NoError(t, s.Delete(ctx, "nonexistent"))
}
func TestMemoryStorage_MultipleKeys(t *testing.T) {
ctx := context.Background()
s := NewMemoryStorage()
require.NoError(t, s.Set(ctx, "a", "stateA"))
require.NoError(t, s.Set(ctx, "b", "stateB"))
va, err := s.Get(ctx, "a")
require.NoError(t, err)
require.Equal(t, State("stateA"), va)
vb, err := s.Get(ctx, "b")
require.NoError(t, err)
require.Equal(t, State("stateB"), vb)
// Delete one key; the other remains.
require.NoError(t, s.Delete(ctx, "a"))
_, err = s.Get(ctx, "a")
require.ErrorIs(t, err, ErrKeyNotFound)
vb, err = s.Get(ctx, "b")
require.NoError(t, err)
require.Equal(t, State("stateB"), vb)
}
func TestMemoryStorage_Concurrent(t *testing.T) {
// 20 goroutines hammering the same key concurrently — no data race.
ctx := context.Background()
s := NewMemoryStorage()
const goroutines = 20
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(i int) {
defer wg.Done()
key := "shared"
_ = s.Set(ctx, key, State("step"))
_, _ = s.Get(ctx, key)
if i%3 == 0 {
_ = s.Delete(ctx, key)
}
}(i)
}
wg.Wait()
// No assertion needed — race detector catches the bug if present.
}
+79
View File
@@ -0,0 +1,79 @@
package conversation
import (
"fmt"
"github.com/lukaszraczylo/go-telegram/api"
)
// KeyStrategy derives a persistence key from an update. Strategies
// determine how conversation scope works — per-user, per-chat, or
// per-user-and-chat. Implementations must return a stable string for
// the same logical scope across updates.
//
// Returns the empty string if the update doesn't have enough context
// to derive a key (in which case the conversation handler skips it).
type KeyStrategy func(u *api.Update) string
// KeyByUser derives a key from the sending user's ID. Useful for DM
// conversations and any flow that should follow the user across chats.
var KeyByUser KeyStrategy = func(u *api.Update) string {
if uid := userID(u); uid != 0 {
return fmt.Sprintf("u:%d", uid)
}
return ""
}
// KeyByChat derives a key from the chat ID. Useful for group flows where
// any user in the chat can drive the conversation.
var KeyByChat KeyStrategy = func(u *api.Update) string {
if cid := chatID(u); cid != 0 {
return fmt.Sprintf("c:%d", cid)
}
return ""
}
// KeyByUserAndChat derives a key from both user and chat IDs. The most
// common strategy: each user has their own conversation per chat.
var KeyByUserAndChat KeyStrategy = func(u *api.Update) string {
uid := userID(u)
cid := chatID(u)
if uid == 0 || cid == 0 {
return ""
}
return fmt.Sprintf("uc:%d:%d", cid, uid)
}
// userID extracts the sending user's ID from any update payload.
func userID(u *api.Update) int64 {
switch {
case u.Message != nil && u.Message.From != nil:
return u.Message.From.ID
case u.EditedMessage != nil && u.EditedMessage.From != nil:
return u.EditedMessage.From.ID
case u.CallbackQuery != nil:
return u.CallbackQuery.From.ID
case u.InlineQuery != nil:
return u.InlineQuery.From.ID
}
return 0
}
// chatID extracts the relevant chat ID.
func chatID(u *api.Update) int64 {
switch {
case u.Message != nil:
return u.Message.Chat.ID
case u.EditedMessage != nil:
return u.EditedMessage.Chat.ID
case u.ChannelPost != nil:
return u.ChannelPost.Chat.ID
case u.EditedChannelPost != nil:
return u.EditedChannelPost.Chat.ID
case u.CallbackQuery != nil && u.CallbackQuery.Message != nil:
if msg, ok := u.CallbackQuery.Message.(*api.Message); ok {
return msg.Chat.ID
}
}
return 0
}
+115
View File
@@ -0,0 +1,115 @@
package conversation
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/stretchr/testify/require"
)
// helpers to build api.Update variants.
func msgUpdate(userID, chatID int64) *api.Update {
return &api.Update{
Message: &api.Message{
From: &api.User{ID: userID},
Chat: api.Chat{ID: chatID},
},
}
}
func editedMsgUpdate(userID, chatID int64) *api.Update {
return &api.Update{
EditedMessage: &api.Message{
From: &api.User{ID: userID},
Chat: api.Chat{ID: chatID},
},
}
}
func callbackUpdate(userID, chatID int64) *api.Update {
return &api.Update{
CallbackQuery: &api.CallbackQuery{
From: api.User{ID: userID},
Message: &api.Message{Chat: api.Chat{ID: chatID}},
},
}
}
func inlineUpdate(userID int64) *api.Update {
return &api.Update{
InlineQuery: &api.InlineQuery{
From: api.User{ID: userID},
},
}
}
func emptyUpdate() *api.Update { return &api.Update{} }
func TestKeyByUser(t *testing.T) {
t.Run("message update", func(t *testing.T) {
require.Equal(t, "u:42", KeyByUser(msgUpdate(42, 100)))
})
t.Run("edited message", func(t *testing.T) {
require.Equal(t, "u:7", KeyByUser(editedMsgUpdate(7, 100)))
})
t.Run("callback query", func(t *testing.T) {
require.Equal(t, "u:99", KeyByUser(callbackUpdate(99, 100)))
})
t.Run("inline query", func(t *testing.T) {
require.Equal(t, "u:5", KeyByUser(inlineUpdate(5)))
})
t.Run("empty update returns empty string", func(t *testing.T) {
require.Equal(t, "", KeyByUser(emptyUpdate()))
})
}
func TestKeyByChat(t *testing.T) {
t.Run("message update", func(t *testing.T) {
require.Equal(t, "c:100", KeyByChat(msgUpdate(42, 100)))
})
t.Run("edited message", func(t *testing.T) {
require.Equal(t, "c:200", KeyByChat(editedMsgUpdate(7, 200)))
})
t.Run("callback with accessible message", func(t *testing.T) {
require.Equal(t, "c:300", KeyByChat(callbackUpdate(99, 300)))
})
t.Run("inline query has no chat → empty", func(t *testing.T) {
require.Equal(t, "", KeyByChat(inlineUpdate(5)))
})
t.Run("empty update returns empty string", func(t *testing.T) {
require.Equal(t, "", KeyByChat(emptyUpdate()))
})
}
func TestKeyByUserAndChat(t *testing.T) {
t.Run("message update", func(t *testing.T) {
require.Equal(t, "uc:100:42", KeyByUserAndChat(msgUpdate(42, 100)))
})
t.Run("edited message", func(t *testing.T) {
require.Equal(t, "uc:200:7", KeyByUserAndChat(editedMsgUpdate(7, 200)))
})
t.Run("callback query", func(t *testing.T) {
require.Equal(t, "uc:300:99", KeyByUserAndChat(callbackUpdate(99, 300)))
})
t.Run("inline query has no chat → empty", func(t *testing.T) {
require.Equal(t, "", KeyByUserAndChat(inlineUpdate(5)))
})
t.Run("empty update returns empty string", func(t *testing.T) {
require.Equal(t, "", KeyByUserAndChat(emptyUpdate()))
})
}
func TestKeyByUserAndChat_CallbackInaccessibleMessage(t *testing.T) {
// CallbackQuery.Message is InaccessibleMessage (not *Message) — chatID returns 0.
u := &api.Update{
CallbackQuery: &api.CallbackQuery{
From: api.User{ID: 10},
Message: &api.InaccessibleMessage{}, // implements MaybeInaccessibleMessage, not *api.Message
},
}
// userID picks up From.ID=10 but chatID fails type assertion → 0
require.Equal(t, "", KeyByUserAndChat(u), "no key when message inaccessible")
// KeyByUser still works since From is set.
require.Equal(t, "u:10", KeyByUser(u))
}
+9
View File
@@ -0,0 +1,9 @@
// Package conversation implements a stateful conversation handler for the
// go-telegram dispatch router. It provides a state-machine abstraction over
// multi-step Telegram bot interactions, with pluggable storage and flexible
// key strategies.
package conversation
// State is a label identifying a node in the conversation graph.
// The empty string is the implicit "no active conversation" state.
type State string
+20
View File
@@ -0,0 +1,20 @@
package conversation
import (
"context"
"errors"
)
// ErrKeyNotFound is returned by Storage.Get when no conversation is active
// for the given key.
var ErrKeyNotFound = errors.New("conversation: key not found")
// Storage persists per-user (or per-chat, per-message — depending on the
// KeyStrategy in use) conversation state across update deliveries.
//
// Implementations must be safe for concurrent use.
type Storage interface {
Get(ctx context.Context, key string) (State, error)
Set(ctx context.Context, key string, state State) error
Delete(ctx context.Context, key string) error
}
+70
View File
@@ -0,0 +1,70 @@
package dispatch
// Filter is a predicate over a typed payload (e.g. *api.Message). Filters
// compose via And/Or/Not for multi-condition matching.
//
// Example:
//
// f := message.HasPhoto().And(message.InChat(-100123456789))
type Filter[T any] func(payload T) bool
// And returns a Filter that matches iff f and every one of others matches.
func (f Filter[T]) And(others ...Filter[T]) Filter[T] {
return func(payload T) bool {
if !f(payload) {
return false
}
for _, o := range others {
if !o(payload) {
return false
}
}
return true
}
}
// Or returns a Filter that matches iff f matches OR any of others matches.
func (f Filter[T]) Or(others ...Filter[T]) Filter[T] {
return func(payload T) bool {
if f(payload) {
return true
}
for _, o := range others {
if o(payload) {
return true
}
}
return false
}
}
// Not returns a Filter that inverts f.
func (f Filter[T]) Not() Filter[T] {
return func(payload T) bool { return !f(payload) }
}
// All combines filters with AND. Returns a Filter that matches when all match.
// Returns a filter that always matches when filters is empty.
func All[T any](filters ...Filter[T]) Filter[T] {
return func(payload T) bool {
for _, f := range filters {
if !f(payload) {
return false
}
}
return true
}
}
// Any combines filters with OR. Returns a Filter that matches when at least
// one matches. Returns a filter that never matches when filters is empty.
func Any[T any](filters ...Filter[T]) Filter[T] {
return func(payload T) bool {
for _, f := range filters {
if f(payload) {
return true
}
}
return false
}
}
+87
View File
@@ -0,0 +1,87 @@
package dispatch
import (
"testing"
"github.com/stretchr/testify/require"
)
func alwaysTrue[T any]() Filter[T] { return func(_ T) bool { return true } }
func alwaysFalse[T any]() Filter[T] { return func(_ T) bool { return false } }
func TestFilter_And(t *testing.T) {
t.Run("all true", func(t *testing.T) {
f := alwaysTrue[int]().And(alwaysTrue[int](), alwaysTrue[int]())
require.True(t, f(0))
})
t.Run("first false", func(t *testing.T) {
f := alwaysFalse[int]().And(alwaysTrue[int]())
require.False(t, f(0))
})
t.Run("other false", func(t *testing.T) {
f := alwaysTrue[int]().And(alwaysFalse[int]())
require.False(t, f(0))
})
t.Run("no others — acts as identity", func(t *testing.T) {
require.True(t, alwaysTrue[int]().And()(0))
require.False(t, alwaysFalse[int]().And()(0))
})
}
func TestFilter_Or(t *testing.T) {
t.Run("first true", func(t *testing.T) {
f := alwaysTrue[int]().Or(alwaysFalse[int]())
require.True(t, f(0))
})
t.Run("other true", func(t *testing.T) {
f := alwaysFalse[int]().Or(alwaysTrue[int]())
require.True(t, f(0))
})
t.Run("all false", func(t *testing.T) {
f := alwaysFalse[int]().Or(alwaysFalse[int]())
require.False(t, f(0))
})
t.Run("no others", func(t *testing.T) {
require.True(t, alwaysTrue[int]().Or()(0))
require.False(t, alwaysFalse[int]().Or()(0))
})
}
func TestFilter_Not(t *testing.T) {
require.False(t, alwaysTrue[int]().Not()(0))
require.True(t, alwaysFalse[int]().Not()(0))
}
func TestAll(t *testing.T) {
t.Run("all true", func(t *testing.T) {
require.True(t, All(alwaysTrue[int](), alwaysTrue[int]())(0))
})
t.Run("one false", func(t *testing.T) {
require.False(t, All(alwaysTrue[int](), alwaysFalse[int]())(0))
})
t.Run("empty — always true", func(t *testing.T) {
require.True(t, All[int]()(0))
})
}
func TestAny(t *testing.T) {
t.Run("one true", func(t *testing.T) {
require.True(t, Any(alwaysFalse[int](), alwaysTrue[int]())(0))
})
t.Run("all false", func(t *testing.T) {
require.False(t, Any(alwaysFalse[int](), alwaysFalse[int]())(0))
})
t.Run("empty — always false", func(t *testing.T) {
require.False(t, Any[int]()(0))
})
}
func TestFilter_Composition(t *testing.T) {
// (true AND false) OR true == true
f := alwaysTrue[int]().And(alwaysFalse[int]()).Or(alwaysTrue[int]())
require.True(t, f(0))
// NOT (true OR false) == false
g := alwaysTrue[int]().Or(alwaysFalse[int]()).Not()
require.False(t, g(0))
}
+43
View File
@@ -0,0 +1,43 @@
// Package callback provides Filter helpers for *api.CallbackQuery payloads.
package callback
import (
"regexp"
"strings"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// Data returns a Filter that matches callback queries whose Data matches
// pattern (regex). Panics at registration time on an invalid pattern.
func Data(pattern string) dispatch.Filter[*api.CallbackQuery] {
re := regexp.MustCompile(pattern)
return func(q *api.CallbackQuery) bool {
return q != nil && re.MatchString(q.Data)
}
}
// DataEquals returns a Filter that matches callback queries whose Data equals
// s exactly.
func DataEquals(s string) dispatch.Filter[*api.CallbackQuery] {
return func(q *api.CallbackQuery) bool {
return q != nil && q.Data == s
}
}
// DataPrefix returns a Filter that matches callback queries whose Data starts
// with prefix.
func DataPrefix(prefix string) dispatch.Filter[*api.CallbackQuery] {
return func(q *api.CallbackQuery) bool {
return q != nil && strings.HasPrefix(q.Data, prefix)
}
}
// FromUser returns a Filter that matches callback queries whose From.ID equals
// userID.
func FromUser(userID int64) dispatch.Filter[*api.CallbackQuery] {
return func(q *api.CallbackQuery) bool {
return q != nil && q.From.ID == userID
}
}
@@ -0,0 +1,56 @@
package callback_test
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
cbfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/callback"
"github.com/stretchr/testify/require"
)
func cq(data string, userID int64) *api.CallbackQuery {
return &api.CallbackQuery{
ID: "q",
From: api.User{ID: userID},
Data: data,
}
}
func TestData(t *testing.T) {
f := cbfilter.Data(`^like:\d+$`)
require.True(t, f(cq("like:42", 1)))
require.False(t, f(cq("dislike:42", 1)))
require.False(t, f(nil))
}
func TestData_PanicsOnBadPattern(t *testing.T) {
require.Panics(t, func() { cbfilter.Data(`[bad`) })
}
func TestDataEquals(t *testing.T) {
f := cbfilter.DataEquals("yes")
require.True(t, f(cq("yes", 1)))
require.False(t, f(cq("yes please", 1)))
require.False(t, f(nil))
}
func TestDataPrefix(t *testing.T) {
f := cbfilter.DataPrefix("vote:")
require.True(t, f(cq("vote:up", 1)))
require.False(t, f(cq("novote:up", 1)))
require.False(t, f(nil))
}
func TestFromUser(t *testing.T) {
f := cbfilter.FromUser(7)
require.True(t, f(cq("data", 7)))
require.False(t, f(cq("data", 8)))
require.False(t, f(nil))
}
func TestComposedCallbackFilters(t *testing.T) {
f := cbfilter.DataPrefix("vote:").And(cbfilter.FromUser(7))
require.True(t, f(cq("vote:up", 7)))
require.False(t, f(cq("vote:up", 8)))
require.False(t, f(cq("other", 7)))
}
@@ -0,0 +1,23 @@
// Package chatjoinrequest provides Filter helpers for *api.ChatJoinRequest payloads.
package chatjoinrequest
import (
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// FromUser returns a Filter that matches join requests where the requesting
// user's ID equals uid.
func FromUser(uid int64) dispatch.Filter[*api.ChatJoinRequest] {
return func(r *api.ChatJoinRequest) bool {
return r != nil && r.From.ID == uid
}
}
// InChat returns a Filter that matches join requests directed at the chat
// with the given chat ID.
func InChat(cid int64) dispatch.Filter[*api.ChatJoinRequest] {
return func(r *api.ChatJoinRequest) bool {
return r != nil && r.Chat.ID == cid
}
}
@@ -0,0 +1,37 @@
package chatjoinrequest_test
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
cjrfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/chatjoinrequest"
"github.com/stretchr/testify/require"
)
func joinRequest(fromID, chatID int64) *api.ChatJoinRequest {
return &api.ChatJoinRequest{
Chat: api.Chat{ID: chatID},
From: api.User{ID: fromID},
}
}
func TestFromUser_Matches(t *testing.T) {
f := cjrfilter.FromUser(10)
require.True(t, f(joinRequest(10, 100)))
require.False(t, f(joinRequest(99, 100)))
require.False(t, f(nil))
}
func TestInChat_Matches(t *testing.T) {
f := cjrfilter.InChat(100)
require.True(t, f(joinRequest(10, 100)))
require.False(t, f(joinRequest(10, 200)))
require.False(t, f(nil))
}
func TestComposedFilters(t *testing.T) {
f := cjrfilter.FromUser(10).And(cjrfilter.InChat(100))
require.True(t, f(joinRequest(10, 100)))
require.False(t, f(joinRequest(10, 200)))
require.False(t, f(joinRequest(99, 100)))
}
+41
View File
@@ -0,0 +1,41 @@
// Package chatmember provides Filter helpers for *api.ChatMemberUpdated payloads.
package chatmember
import (
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// NewStatus returns a Filter that matches updates where the new chat member
// status equals s (e.g. "member", "administrator", "kicked", "left").
func NewStatus(s string) dispatch.Filter[*api.ChatMemberUpdated] {
return func(u *api.ChatMemberUpdated) bool {
if u == nil {
return false
}
switch m := u.NewChatMember.(type) {
case *api.ChatMemberOwner:
return m.Status == s
case *api.ChatMemberAdministrator:
return m.Status == s
case *api.ChatMemberMember:
return m.Status == s
case *api.ChatMemberRestricted:
return m.Status == s
case *api.ChatMemberLeft:
return m.Status == s
case *api.ChatMemberBanned:
return m.Status == s
default:
return false
}
}
}
// FromUser returns a Filter that matches updates where the acting user
// (From.ID) equals uid.
func FromUser(uid int64) dispatch.Filter[*api.ChatMemberUpdated] {
return func(u *api.ChatMemberUpdated) bool {
return u != nil && u.From.ID == uid
}
}
@@ -0,0 +1,95 @@
package chatmember_test
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
cmfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/chatmember"
"github.com/stretchr/testify/require"
)
func memberUpdate(status string, fromID int64) *api.ChatMemberUpdated {
var newMember api.ChatMember
switch status {
case "member":
newMember = &api.ChatMemberMember{Status: status}
case "administrator":
newMember = &api.ChatMemberAdministrator{Status: status}
case "kicked":
newMember = &api.ChatMemberBanned{Status: status}
case "left":
newMember = &api.ChatMemberLeft{Status: status}
default:
newMember = &api.ChatMemberMember{Status: status}
}
return &api.ChatMemberUpdated{
From: api.User{ID: fromID},
NewChatMember: newMember,
}
}
func TestNewStatus_Matches(t *testing.T) {
f := cmfilter.NewStatus("member")
require.True(t, f(memberUpdate("member", 1)))
require.False(t, f(memberUpdate("kicked", 1)))
require.False(t, f(nil))
}
func TestNewStatus_Administrator(t *testing.T) {
f := cmfilter.NewStatus("administrator")
require.True(t, f(memberUpdate("administrator", 1)))
require.False(t, f(memberUpdate("member", 1)))
}
func TestNewStatus_Kicked(t *testing.T) {
f := cmfilter.NewStatus("kicked")
require.True(t, f(memberUpdate("kicked", 1)))
require.False(t, f(memberUpdate("left", 1)))
}
func TestNewStatus_Left(t *testing.T) {
f := cmfilter.NewStatus("left")
require.True(t, f(memberUpdate("left", 1)))
require.False(t, f(memberUpdate("member", 1)))
}
func TestFromUser_Matches(t *testing.T) {
f := cmfilter.FromUser(42)
require.True(t, f(memberUpdate("member", 42)))
require.False(t, f(memberUpdate("member", 99)))
require.False(t, f(nil))
}
func TestComposedFilters(t *testing.T) {
f := cmfilter.NewStatus("member").And(cmfilter.FromUser(7))
require.True(t, f(memberUpdate("member", 7)))
require.False(t, f(memberUpdate("member", 8)))
require.False(t, f(memberUpdate("kicked", 7)))
}
func TestNewStatus_Owner(t *testing.T) {
u := &api.ChatMemberUpdated{
From: api.User{ID: 1},
NewChatMember: &api.ChatMemberOwner{Status: "creator"},
}
require.True(t, cmfilter.NewStatus("creator")(u))
require.False(t, cmfilter.NewStatus("member")(u))
}
func TestNewStatus_Restricted(t *testing.T) {
u := &api.ChatMemberUpdated{
From: api.User{ID: 1},
NewChatMember: &api.ChatMemberRestricted{Status: "restricted"},
}
require.True(t, cmfilter.NewStatus("restricted")(u))
require.False(t, cmfilter.NewStatus("member")(u))
}
func TestNewStatus_UnknownType(t *testing.T) {
// nil NewChatMember → default branch → false
u := &api.ChatMemberUpdated{
From: api.User{ID: 1},
NewChatMember: nil,
}
require.False(t, cmfilter.NewStatus("member")(u))
}
+35
View File
@@ -0,0 +1,35 @@
// Package inline provides Filter helpers for *api.InlineQuery payloads.
package inline
import (
"regexp"
"strings"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// Query returns a Filter that matches inline queries whose Query field matches
// pattern (regex). Panics at registration time on an invalid pattern.
func Query(pattern string) dispatch.Filter[*api.InlineQuery] {
re := regexp.MustCompile(pattern)
return func(q *api.InlineQuery) bool {
return q != nil && re.MatchString(q.Query)
}
}
// QueryEquals returns a Filter that matches inline queries whose Query equals
// s exactly.
func QueryEquals(s string) dispatch.Filter[*api.InlineQuery] {
return func(q *api.InlineQuery) bool {
return q != nil && q.Query == s
}
}
// QueryPrefix returns a Filter that matches inline queries whose Query starts
// with prefix.
func QueryPrefix(prefix string) dispatch.Filter[*api.InlineQuery] {
return func(q *api.InlineQuery) bool {
return q != nil && strings.HasPrefix(q.Query, prefix)
}
}
+45
View File
@@ -0,0 +1,45 @@
package inline_test
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
ilfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/inline"
"github.com/stretchr/testify/require"
)
func iq(query string) *api.InlineQuery {
return &api.InlineQuery{ID: "i", From: api.User{ID: 1}, Query: query}
}
func TestQuery(t *testing.T) {
f := ilfilter.Query(`^find`)
require.True(t, f(iq("find me")))
require.False(t, f(iq("search me")))
require.False(t, f(nil))
}
func TestQuery_PanicsOnBadPattern(t *testing.T) {
require.Panics(t, func() { ilfilter.Query(`[bad`) })
}
func TestQueryEquals(t *testing.T) {
f := ilfilter.QueryEquals("exact")
require.True(t, f(iq("exact")))
require.False(t, f(iq("exact match")))
require.False(t, f(nil))
}
func TestQueryPrefix(t *testing.T) {
f := ilfilter.QueryPrefix("@user")
require.True(t, f(iq("@username")))
require.False(t, f(iq("no prefix")))
require.False(t, f(nil))
}
func TestComposedInlineFilters(t *testing.T) {
f := ilfilter.QueryPrefix("find").Or(ilfilter.QueryEquals("help"))
require.True(t, f(iq("find me")))
require.True(t, f(iq("help")))
require.False(t, f(iq("other")))
}
+142
View File
@@ -0,0 +1,142 @@
// Package message provides Filter helpers for *api.Message payloads.
package message
import (
"regexp"
"strings"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// Text returns a Filter that matches messages whose Text matches pattern (regex).
// Panics at registration time on an invalid pattern.
func Text(pattern string) dispatch.Filter[*api.Message] {
re := regexp.MustCompile(pattern)
return func(m *api.Message) bool {
return m != nil && re.MatchString(m.Text)
}
}
// TextEquals returns a Filter that matches messages whose Text equals s exactly.
func TextEquals(s string) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.Text == s
}
}
// TextPrefix returns a Filter that matches messages whose Text starts with prefix.
func TextPrefix(prefix string) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && strings.HasPrefix(m.Text, prefix)
}
}
// TextContains returns a Filter that matches messages whose Text contains sub.
func TextContains(sub string) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && strings.Contains(m.Text, sub)
}
}
// Command returns a Filter that matches messages whose first entity is a
// bot_command equal to "/<name>" (with or without "@BotName" suffix).
func Command(name string) dispatch.Filter[*api.Message] {
want := "/" + strings.TrimPrefix(name, "/")
return func(m *api.Message) bool {
if m == nil || len(m.Entities) == 0 || m.Text == "" {
return false
}
first := m.Entities[0]
if first.Type != string(api.EntityBotCommand) || first.Offset != 0 {
return false
}
end := int(first.Length)
runes := []rune(m.Text)
if end > len(runes) {
return false
}
cmd := string(runes[:end])
if i := strings.Index(cmd, "@"); i >= 0 {
cmd = cmd[:i]
}
return cmd == want
}
}
// AnyCommand returns a Filter that matches any message starting with a
// bot_command entity at offset 0.
func AnyCommand() dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
if m == nil || len(m.Entities) == 0 {
return false
}
first := m.Entities[0]
return first.Type == string(api.EntityBotCommand) && first.Offset == 0
}
}
// IsReply returns a Filter that matches messages that have ReplyToMessage set.
func IsReply() dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.ReplyToMessage != nil
}
}
// IsForward returns a Filter that matches messages that have ForwardOrigin set.
func IsForward() dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.ForwardOrigin != nil
}
}
// HasPhoto returns a Filter that matches messages with a Photo attachment.
func HasPhoto() dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && len(m.Photo) > 0
}
}
// HasDocument returns a Filter that matches messages with a Document attachment.
func HasDocument() dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.Document != nil
}
}
// HasEntity returns a Filter that matches messages whose Entities contain at
// least one entity of type t (e.g. string(api.EntityBotCommand)).
func HasEntity(t string) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
if m == nil {
return false
}
for _, e := range m.Entities {
if e.Type == t {
return true
}
}
return false
}
}
// ChatType returns a Filter that matches messages whose Chat.Type equals t.
func ChatType(t api.ChatType) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.Chat.Type == string(t)
}
}
// FromUser returns a Filter that matches messages whose From.ID equals userID.
func FromUser(userID int64) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.From != nil && m.From.ID == userID
}
}
// InChat returns a Filter that matches messages whose Chat.ID equals chatID.
func InChat(chatID int64) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.Chat.ID == chatID
}
}
+188
View File
@@ -0,0 +1,188 @@
package message_test
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
msgfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/message"
"github.com/stretchr/testify/require"
)
func msg(text string) *api.Message {
return &api.Message{
MessageID: 1,
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: text,
}
}
func cmdMsg(cmd string) *api.Message {
text := cmd
return &api.Message{
MessageID: 1,
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: text,
Entities: []api.MessageEntity{
{Type: string(api.EntityBotCommand), Offset: 0, Length: int64(len([]rune(text)))},
},
}
}
func TestText(t *testing.T) {
f := msgfilter.Text(`^hello`)
require.True(t, f(msg("hello world")))
require.False(t, f(msg("world hello")))
require.False(t, f(nil))
}
func TestText_PanicsOnBadPattern(t *testing.T) {
require.Panics(t, func() { msgfilter.Text(`[invalid`) })
}
func TestTextEquals(t *testing.T) {
f := msgfilter.TextEquals("hi")
require.True(t, f(msg("hi")))
require.False(t, f(msg("hi there")))
require.False(t, f(nil))
}
func TestTextPrefix(t *testing.T) {
f := msgfilter.TextPrefix("/start")
require.True(t, f(msg("/start now")))
require.False(t, f(msg("no prefix")))
require.False(t, f(nil))
}
func TestTextContains(t *testing.T) {
f := msgfilter.TextContains("bot")
require.True(t, f(msg("my bot is cool")))
require.False(t, f(msg("nothing here")))
require.False(t, f(nil))
}
func TestCommand(t *testing.T) {
t.Run("matches exact command", func(t *testing.T) {
f := msgfilter.Command("/start")
require.True(t, f(cmdMsg("/start")))
})
t.Run("matches without leading slash", func(t *testing.T) {
f := msgfilter.Command("start")
require.True(t, f(cmdMsg("/start")))
})
t.Run("strips BotName suffix", func(t *testing.T) {
m := &api.Message{
Text: "/start@MyBot",
Entities: []api.MessageEntity{{Type: string(api.EntityBotCommand), Offset: 0, Length: 12}},
}
f := msgfilter.Command("/start")
require.True(t, f(m))
})
t.Run("no match different command", func(t *testing.T) {
f := msgfilter.Command("/stop")
require.False(t, f(cmdMsg("/start")))
})
t.Run("nil message", func(t *testing.T) {
require.False(t, msgfilter.Command("/start")(nil))
})
t.Run("no entities", func(t *testing.T) {
require.False(t, msgfilter.Command("/start")(msg("/start")))
})
}
func TestAnyCommand(t *testing.T) {
f := msgfilter.AnyCommand()
require.True(t, f(cmdMsg("/anything")))
require.False(t, f(msg("plain text")))
require.False(t, f(nil))
}
func TestIsReply(t *testing.T) {
f := msgfilter.IsReply()
m := msg("reply")
m.ReplyToMessage = &api.Message{MessageID: 2}
require.True(t, f(m))
require.False(t, f(msg("no reply")))
require.False(t, f(nil))
}
func TestIsForward(t *testing.T) {
// ForwardOrigin is a MessageOrigin interface; set via a concrete type.
f := msgfilter.IsForward()
m := msg("fwd")
m.ForwardOrigin = &api.MessageOriginUser{Type: "user"}
require.True(t, f(m))
require.False(t, f(msg("no fwd")))
require.False(t, f(nil))
}
func TestHasPhoto(t *testing.T) {
f := msgfilter.HasPhoto()
m := msg("")
m.Photo = []api.PhotoSize{{FileID: "x", Width: 100, Height: 100}}
require.True(t, f(m))
require.False(t, f(msg("no photo")))
require.False(t, f(nil))
}
func TestHasDocument(t *testing.T) {
f := msgfilter.HasDocument()
m := msg("")
m.Document = &api.Document{FileID: "doc1"}
require.True(t, f(m))
require.False(t, f(msg("no doc")))
require.False(t, f(nil))
}
func TestHasEntity(t *testing.T) {
f := msgfilter.HasEntity(string(api.EntityURL))
m := msg("check https://example.com")
m.Entities = []api.MessageEntity{{Type: string(api.EntityURL), Offset: 6, Length: 19}}
require.True(t, f(m))
require.False(t, f(msg("plain")))
require.False(t, f(nil))
}
func TestChatType(t *testing.T) {
f := msgfilter.ChatType(api.ChatTypePrivate)
private := msg("hi")
require.True(t, f(private))
group := msg("hi")
group.Chat.Type = string(api.ChatTypeGroup)
require.False(t, f(group))
require.False(t, f(nil))
}
func TestFromUser(t *testing.T) {
f := msgfilter.FromUser(42)
m := msg("hi")
m.From = &api.User{ID: 42}
require.True(t, f(m))
m2 := msg("hi")
m2.From = &api.User{ID: 99}
require.False(t, f(m2))
require.False(t, f(msg("no from")))
require.False(t, f(nil))
}
func TestInChat(t *testing.T) {
f := msgfilter.InChat(1)
require.True(t, f(msg("hi")))
m2 := msg("hi")
m2.Chat.ID = 2
require.False(t, f(m2))
require.False(t, f(nil))
}
func TestComposedMessageFilters(t *testing.T) {
// private chat AND contains "hello"
f := msgfilter.ChatType(api.ChatTypePrivate).And(msgfilter.TextContains("hello"))
m := msg("say hello")
require.True(t, f(m))
m2 := msg("say hello")
m2.Chat.Type = string(api.ChatTypeGroup)
require.False(t, f(m2))
}
@@ -0,0 +1,23 @@
// Package precheckoutquery provides Filter helpers for *api.PreCheckoutQuery payloads.
package precheckoutquery
import (
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// Currency returns a Filter that matches pre-checkout queries with the given
// ISO 4217 currency code (e.g. "USD", "EUR", "XTR").
func Currency(c string) dispatch.Filter[*api.PreCheckoutQuery] {
return func(q *api.PreCheckoutQuery) bool {
return q != nil && q.Currency == c
}
}
// FromUser returns a Filter that matches pre-checkout queries sent by the
// user with the given ID.
func FromUser(uid int64) dispatch.Filter[*api.PreCheckoutQuery] {
return func(q *api.PreCheckoutQuery) bool {
return q != nil && q.From.ID == uid
}
}
@@ -0,0 +1,38 @@
package precheckoutquery_test
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
pcqfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/precheckoutquery"
"github.com/stretchr/testify/require"
)
func pcq(currency string, fromID int64) *api.PreCheckoutQuery {
return &api.PreCheckoutQuery{
ID: "q",
Currency: currency,
From: api.User{ID: fromID},
}
}
func TestCurrency_Matches(t *testing.T) {
f := pcqfilter.Currency("USD")
require.True(t, f(pcq("USD", 1)))
require.False(t, f(pcq("EUR", 1)))
require.False(t, f(nil))
}
func TestFromUser_Matches(t *testing.T) {
f := pcqfilter.FromUser(5)
require.True(t, f(pcq("USD", 5)))
require.False(t, f(pcq("USD", 9)))
require.False(t, f(nil))
}
func TestComposedFilters(t *testing.T) {
f := pcqfilter.Currency("XTR").And(pcqfilter.FromUser(42))
require.True(t, f(pcq("XTR", 42)))
require.False(t, f(pcq("XTR", 99)))
require.False(t, f(pcq("USD", 42)))
}
+186
View File
@@ -0,0 +1,186 @@
package dispatch
import (
"errors"
"regexp"
"sort"
"github.com/lukaszraczylo/go-telegram/api"
)
// ErrEndGroups stops dispatch from running any further handlers in any
// group for this update when returned by a handler. Use it to indicate
// the update has been definitively handled.
//
// errors.Is(err, ErrEndGroups) is the canonical check, though dispatch
// itself recognises it by exact identity.
var ErrEndGroups = errors.New("dispatch: end groups")
// ErrContinueGroups signals that this group's handler should be treated
// as not-matching when returned by a handler: dispatch moves on to the
// next handler in the same group, then to subsequent groups.
//
// Without ErrContinueGroups, a non-error return from a matched handler
// stops dispatch (default first-match-wins semantics).
var ErrContinueGroups = errors.New("dispatch: continue groups")
// RouterScope registers handlers into a specific priority group on its parent
// Router. Group 0 runs first, then group 1, etc. Within a group, handlers run
// in registration order; the first non-skipped match terminates dispatch
// unless the handler returns ErrContinueGroups.
type RouterScope struct {
router *Router
group int
}
// Group returns a RouterScope that registers handlers in the given group.
// Group 0 (the default) runs first, then group 1, etc. Within a group,
// handlers run in registration order; the first non-skipped match
// terminates dispatch unless the handler returns ErrContinueGroups.
func (r *Router) Group(group int) *RouterScope {
return &RouterScope{router: r, group: group}
}
// OnCommand registers a command handler in this group.
func (s *RouterScope) OnCommand(cmd string, h Handler[*api.Message]) {
s.router.groupCommands = append(s.router.groupCommands, groupCommandRoute{
cmd: cmd, group: s.group, handler: h,
})
}
// OnText registers a regex text handler in this group.
// Panics at registration time if pattern is not a valid regular expression.
func (s *RouterScope) OnText(pattern string, h Handler[*api.Message]) {
s.router.groupTexts = append(s.router.groupTexts, groupTextRoute{
re: regexp.MustCompile(pattern), group: s.group, handler: h,
})
}
// OnMessageFilter registers a filter-based message handler in this group.
func (s *RouterScope) OnMessageFilter(f Filter[*api.Message], h Handler[*api.Message]) {
s.router.groupMessageFilters = append(s.router.groupMessageFilters, groupMessageFilterRoute{
filter: f, group: s.group, handler: h,
})
}
// group-aware route types
type groupCommandRoute struct {
cmd string
group int
handler Handler[*api.Message]
}
type groupTextRoute struct {
re *regexp.Regexp
group int
handler Handler[*api.Message]
}
type groupMessageFilterRoute struct {
filter Filter[*api.Message]
group int
handler Handler[*api.Message]
}
// dispatchGroups runs message handlers registered via RouterScope.Group().
// It collects all matching groups, sorts by group number, and applies
// first-match-wins semantics within each group. Handlers may return
// ErrContinueGroups (skip to next handler/group) or ErrEndGroups (stop all groups).
// A non-sentinel error stops dispatch and is returned to the caller.
func (r *Router) dispatchGroups(c *Context, m *api.Message) error {
// Collect group numbers present.
groupSet := map[int]struct{}{}
for _, gr := range r.groupCommands {
groupSet[gr.group] = struct{}{}
}
for _, gr := range r.groupTexts {
groupSet[gr.group] = struct{}{}
}
for _, gr := range r.groupMessageFilters {
groupSet[gr.group] = struct{}{}
}
if len(groupSet) == 0 {
return nil
}
groups := make([]int, 0, len(groupSet))
for g := range groupSet {
groups = append(groups, g)
}
sort.Ints(groups)
for _, g := range groups {
matched, err := r.runGroupHandlers(c, m, g)
if err != nil {
if errors.Is(err, ErrEndGroups) {
return nil
}
return err
}
if matched {
// First-match-wins: stop further groups.
return nil
}
// No match or ErrContinueGroups from all handlers: try next group.
}
return nil
}
// runGroupHandlers runs all handlers in group g against m, in registration
// order. Returns (true, nil) when a handler matched (returned nil). Returns
// (false, nil) when all handlers returned ErrContinueGroups. Returns
// (false, err) for ErrEndGroups or any non-sentinel error.
func (r *Router) runGroupHandlers(c *Context, m *api.Message, g int) (matched bool, err error) {
// Commands.
if cmd, args, ok := extractCommand(m); ok {
for _, route := range r.groupCommands {
if route.group != g || route.cmd != cmd {
continue
}
c.Values["command"] = cmd
c.Values["command_args"] = args
if err := route.handler(c, m); err != nil {
if errors.Is(err, ErrContinueGroups) {
continue
}
return false, err
}
return true, nil
}
}
// Text regex.
if m.Text != "" {
for _, route := range r.groupTexts {
if route.group != g {
continue
}
subs := route.re.FindStringSubmatch(m.Text)
if subs == nil {
continue
}
c.Values["regex_match"] = subs
if err := route.handler(c, m); err != nil {
if errors.Is(err, ErrContinueGroups) {
continue
}
return false, err
}
return true, nil
}
}
// Filter-based.
for _, route := range r.groupMessageFilters {
if route.group != g || !route.filter(m) {
continue
}
if err := route.handler(c, m); err != nil {
if errors.Is(err, ErrContinueGroups) {
continue
}
return false, err
}
return true, nil
}
return false, nil
}
+209
View File
@@ -0,0 +1,209 @@
package dispatch
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/client"
"github.com/stretchr/testify/require"
)
// msgUpdate builds a simple private message update.
func msgUpdate(id int64, text string) api.Update {
return api.Update{
UpdateID: id,
Message: &api.Message{
MessageID: id,
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: text,
},
}
}
// cmdUpdate builds a command message update.
func cmdUpdate(id int64, cmd string) api.Update {
return api.Update{
UpdateID: id,
Message: &api.Message{
MessageID: id,
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: cmd,
Entities: []api.MessageEntity{
{Type: string(api.EntityBotCommand), Offset: 0, Length: int64(len(cmd))},
},
},
}
}
// runSingle fires one update through the router and waits for it to complete.
func runSingle(t *testing.T, r *Router, up api.Update) {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(up))
}
// TestGroup_Order verifies group 0 fires before group 1.
func TestGroup_Order(t *testing.T) {
r := New(client.New("t"))
var order []int
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
order = append(order, 0)
return ErrContinueGroups // let group 1 also run
})
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
order = append(order, 1)
return nil
})
runSingle(t, r, msgUpdate(1, "hello"))
require.Equal(t, []int{0, 1}, order)
}
// TestGroup_FirstMatchWins verifies group 0 match stops group 1 by default.
func TestGroup_FirstMatchWins(t *testing.T) {
r := New(client.New("t"))
var fired []int
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
fired = append(fired, 0)
return nil // matched — group 1 must NOT run
})
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
fired = append(fired, 1)
return nil
})
runSingle(t, r, msgUpdate(1, "hello"))
require.Equal(t, []int{0}, fired)
}
// TestGroup_ErrContinueGroups lets group 1 run when group 0 returns ErrContinueGroups.
func TestGroup_ErrContinueGroups(t *testing.T) {
r := New(client.New("t"))
g1Hit := make(chan struct{}, 1)
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
return ErrContinueGroups
})
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
g1Hit <- struct{}{}
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(msgUpdate(1, "ping"))) }()
select {
case <-g1Hit:
case <-ctx.Done():
t.Fatal("group 1 handler never fired")
}
}
// TestGroup_ErrEndGroups stops all further groups.
func TestGroup_ErrEndGroups(t *testing.T) {
r := New(client.New("t"))
var fired []int
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
fired = append(fired, 0)
return ErrEndGroups
})
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
fired = append(fired, 1)
return nil
})
runSingle(t, r, msgUpdate(1, "hello"))
require.Equal(t, []int{0}, fired)
}
// TestGroup_NonSentinelError propagates error and stops further groups.
func TestGroup_NonSentinelError(t *testing.T) {
r := New(client.New("t"), WithMaxConcurrency(0))
var fired []int
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
fired = append(fired, 0)
return context.DeadlineExceeded // non-sentinel real error
})
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
fired = append(fired, 1)
return nil
})
runSingle(t, r, msgUpdate(1, "hello"))
// group 1 must not fire
require.Equal(t, []int{0}, fired)
}
// TestGroup_Command verifies OnCommand in a group works.
func TestGroup_Command(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.Group(0).OnCommand("/start", func(c *Context, m *api.Message) error {
hit <- "g0-start"
return nil
})
r.Group(1).OnCommand("/start", func(c *Context, m *api.Message) error {
hit <- "g1-start"
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(cmdUpdate(1, "/start"))) }()
got := <-hit
require.Equal(t, "g0-start", got)
}
// TestGroup_MessageFilter verifies OnMessageFilter in a group works.
func TestGroup_MessageFilter(t *testing.T) {
r := New(client.New("t"))
hit := make(chan bool, 1)
r.Group(0).OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return m != nil && m.Text == "ok" }),
func(c *Context, m *api.Message) error {
hit <- true
return nil
},
)
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(msgUpdate(1, "ok"))) }()
require.True(t, <-hit)
}
// TestGroup_ErrContinueGroups_WithCommand verifies ErrContinueGroups works for commands across groups.
func TestGroup_ErrContinueGroups_WithCommand(t *testing.T) {
r := New(client.New("t"))
var count atomic.Int32
r.Group(0).OnCommand("/ping", func(c *Context, m *api.Message) error {
count.Add(1)
return ErrContinueGroups
})
r.Group(1).OnCommand("/ping", func(c *Context, m *api.Message) error {
count.Add(10)
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(cmdUpdate(1, "/ping"))) }()
time.Sleep(100 * time.Millisecond)
cancel()
require.Equal(t, int32(11), count.Load())
}
+21
View File
@@ -0,0 +1,21 @@
package dispatch
// Handler is a generic handler over update payload type T. T is typically
// *api.Message, *api.CallbackQuery, *api.InlineQuery, or *api.Update for
// global middleware.
type Handler[T any] func(ctx *Context, payload T) error
// Middleware wraps a Handler[T] with cross-cutting behaviour (logging,
// recovery, auth). Middleware composition is left-to-right: Use(a,b,c)
// runs as a(b(c(handler))).
type Middleware[T any] func(Handler[T]) Handler[T]
// Chain composes a slice of middleware into a single Middleware[T].
func Chain[T any](mws ...Middleware[T]) Middleware[T] {
return func(h Handler[T]) Handler[T] {
for i := len(mws) - 1; i >= 0; i-- {
h = mws[i](h)
}
return h
}
}
+27
View File
@@ -0,0 +1,27 @@
package dispatch
import (
"fmt"
"runtime/debug"
"github.com/lukaszraczylo/go-telegram/api"
)
// Recovery returns middleware that recovers from panics in downstream
// handlers, converting them into a returned error and logging via the
// bot's configured logger. Registered automatically by NewRouter.
func Recovery() Middleware[*api.Update] {
return func(next Handler[*api.Update]) Handler[*api.Update] {
return func(c *Context, u *api.Update) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in handler: %v\n%s", r, debug.Stack())
if c.Bot != nil {
c.Bot.Logger().Error("dispatch recovered panic", "err", err)
}
}
}()
return next(c, u)
}
}
}
+101
View File
@@ -0,0 +1,101 @@
package dispatch
import (
"fmt"
"sync"
)
// NamedHandlers manages handlers by string name, allowing runtime
// registration, replacement, and removal. This complements the Router's
// registration methods: each registration via Named*() also gets a name
// for later lookup.
//
// Use case: a plugin system that loads/unloads command handlers without
// restarting the bot.
type NamedHandlers[T any] struct {
mu sync.RWMutex
handlers map[string]Handler[T]
order []string // preserves registration order
}
// NewNamedHandlers returns a new, empty NamedHandlers[T].
func NewNamedHandlers[T any]() *NamedHandlers[T] {
return &NamedHandlers[T]{handlers: map[string]Handler[T]{}}
}
// Set registers or replaces the handler under name. If name is new, it is
// appended to the end of the registration order.
func (n *NamedHandlers[T]) Set(name string, h Handler[T]) {
n.mu.Lock()
defer n.mu.Unlock()
if _, exists := n.handlers[name]; !exists {
n.order = append(n.order, name)
}
n.handlers[name] = h
}
// Remove unregisters the handler under name. Returns true if it existed.
func (n *NamedHandlers[T]) Remove(name string) bool {
n.mu.Lock()
defer n.mu.Unlock()
if _, ok := n.handlers[name]; !ok {
return false
}
delete(n.handlers, name)
for i, k := range n.order {
if k == name {
n.order = append(n.order[:i], n.order[i+1:]...)
break
}
}
return true
}
// Has reports whether name is registered.
func (n *NamedHandlers[T]) Has(name string) bool {
n.mu.RLock()
defer n.mu.RUnlock()
_, ok := n.handlers[name]
return ok
}
// Names returns the registered names in registration order.
func (n *NamedHandlers[T]) Names() []string {
n.mu.RLock()
defer n.mu.RUnlock()
out := make([]string, len(n.order))
copy(out, n.order)
return out
}
// Handler returns a single Handler[T] that runs each registered handler
// in registration order, first non-nil error stops the chain. Use this
// to wire NamedHandlers into a Router.OnXxx call:
//
// names := dispatch.NewNamedHandlers[*api.Message]()
// names.Set("logger", loggingHandler)
// names.Set("audit", auditHandler)
// router.OnCommand("/admin", names.Handler())
//
// Subsequent Set/Remove calls take effect on the next dispatch.
func (n *NamedHandlers[T]) Handler() Handler[T] {
return func(c *Context, payload T) error {
n.mu.RLock()
names := make([]string, len(n.order))
copy(names, n.order)
n.mu.RUnlock()
for _, name := range names {
n.mu.RLock()
h, ok := n.handlers[name]
n.mu.RUnlock()
if !ok {
continue
}
if err := h(c, payload); err != nil {
return fmt.Errorf("named handler %q: %w", name, err)
}
}
return nil
}
}
+153
View File
@@ -0,0 +1,153 @@
package dispatch
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/stretchr/testify/require"
)
// makeMsg returns a minimal *api.Message for use in handler tests.
func makeMsg() *api.Message {
return &api.Message{MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}}
}
// makeCtx returns a minimal *Context (nil bot is fine for unit tests).
func makeCtx() *Context {
return NewContext(context.Background(), nil, &api.Update{})
}
func TestNamedHandlers_SetAndHas(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
require.False(t, n.Has("a"))
n.Set("a", func(c *Context, m *api.Message) error { return nil })
require.True(t, n.Has("a"))
}
func TestNamedHandlers_Names_RegistrationOrder(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
n.Set("first", func(c *Context, m *api.Message) error { return nil })
n.Set("second", func(c *Context, m *api.Message) error { return nil })
n.Set("third", func(c *Context, m *api.Message) error { return nil })
require.Equal(t, []string{"first", "second", "third"}, n.Names())
}
func TestNamedHandlers_Remove(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
n.Set("a", func(c *Context, m *api.Message) error { return nil })
n.Set("b", func(c *Context, m *api.Message) error { return nil })
removed := n.Remove("a")
require.True(t, removed)
require.False(t, n.Has("a"))
require.Equal(t, []string{"b"}, n.Names())
// Remove non-existent returns false.
require.False(t, n.Remove("nonexistent"))
}
func TestNamedHandlers_Replacement_SameOrderSlot(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
n.Set("a", func(c *Context, m *api.Message) error { return nil })
n.Set("b", func(c *Context, m *api.Message) error { return nil })
var called string
n.Set("a", func(c *Context, m *api.Message) error {
called = "replaced-a"
return nil
})
// Order must not change; "a" stays first.
require.Equal(t, []string{"a", "b"}, n.Names())
h := n.Handler()
_ = h(makeCtx(), makeMsg())
require.Equal(t, "replaced-a", called)
}
func TestNamedHandlers_Handler_RunsInOrder(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
var calls []string
n.Set("first", func(c *Context, m *api.Message) error {
calls = append(calls, "first")
return nil
})
n.Set("second", func(c *Context, m *api.Message) error {
calls = append(calls, "second")
return nil
})
h := n.Handler()
require.NoError(t, h(makeCtx(), makeMsg()))
require.Equal(t, []string{"first", "second"}, calls)
}
func TestNamedHandlers_Handler_ErrorWrappedAndStops(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
sentinel := errors.New("boom")
n.Set("ok", func(c *Context, m *api.Message) error { return nil })
n.Set("fail", func(c *Context, m *api.Message) error { return sentinel })
n.Set("never", func(c *Context, m *api.Message) error {
t.Fatal("should not be called after an error")
return nil
})
h := n.Handler()
err := h(makeCtx(), makeMsg())
require.Error(t, err)
require.True(t, errors.Is(err, sentinel))
require.Contains(t, err.Error(), `named handler "fail"`)
}
func TestNamedHandlers_Concurrent_SetRemove(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
// Pre-populate so Handler() has something to iterate.
for i := range 5 {
name := fmt.Sprintf("h%d", i)
n.Set(name, func(c *Context, m *api.Message) error { return nil })
}
h := n.Handler()
var wg sync.WaitGroup
// Concurrent readers (invoke handler).
for range 20 {
wg.Add(1)
go func() {
defer wg.Done()
_ = h(makeCtx(), makeMsg())
}()
}
// Concurrent writers.
for i := range 5 {
wg.Add(1)
go func(i int) {
defer wg.Done()
name := fmt.Sprintf("new%d", i)
n.Set(name, func(c *Context, m *api.Message) error { return nil })
n.Remove(fmt.Sprintf("h%d", i))
}(i)
}
wg.Wait()
}
func TestNamedHandlers_RemoveAndReinstate(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
n.Set("a", func(c *Context, m *api.Message) error { return nil })
n.Remove("a")
require.False(t, n.Has("a"))
// Re-register after removal; should be added at end.
n.Set("b", func(c *Context, m *api.Message) error { return nil })
n.Set("a", func(c *Context, m *api.Message) error { return nil })
require.Equal(t, []string{"b", "a"}, n.Names())
}
+582
View File
@@ -0,0 +1,582 @@
package dispatch
import (
"context"
"regexp"
"strings"
"sync"
"unicode/utf8"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/client"
"github.com/lukaszraczylo/go-telegram/transport"
)
// Router dispatches updates from any Updater to typed handlers.
//
// Matchers run in registration order; first match wins. A panic-recovery
// middleware is attached automatically and runs around every dispatch.
type Router struct {
bot *client.Bot
commands []commandRoute
texts []textRoute
callbacks []callbackRoute
inlines []Handler[*api.InlineQuery]
editedMsg []Handler[*api.Message]
channelPosts []Handler[*api.Message]
editedChannelPosts []Handler[*api.Message]
messageFilters []messageFilterRoute
callbackFilters []callbackFilterRoute
inlineFilters []inlineFilterRoute
// typed update handlers
myChatMember []Handler[*api.ChatMemberUpdated]
chatMember []Handler[*api.ChatMemberUpdated]
chatJoinRequest []Handler[*api.ChatJoinRequest]
preCheckoutQuery []Handler[*api.PreCheckoutQuery]
shippingQuery []Handler[*api.ShippingQuery]
poll []Handler[*api.Poll]
pollAnswer []Handler[*api.PollAnswer]
chosenInlineResult []Handler[*api.ChosenInlineResult]
messageReaction []Handler[*api.MessageReactionUpdated]
messageReactionCnt []Handler[*api.MessageReactionCountUpdated]
chatBoost []Handler[*api.ChatBoostUpdated]
removedChatBoost []Handler[*api.ChatBoostRemoved]
businessConn []Handler[*api.BusinessConnection]
purchasedPaidMedia []Handler[*api.PaidMediaPurchased]
myChatMemberFilters []chatMemberFilterRoute
chatMemberFilters []chatMemberFilterRoute
chatJoinRequestFilters []chatJoinRequestFilterRoute
preCheckoutFilters []preCheckoutFilterRoute
// group-priority routes (registered via Router.Group())
groupCommands []groupCommandRoute
groupTexts []groupTextRoute
groupMessageFilters []groupMessageFilterRoute
globalMW []Middleware[*api.Update]
maxConcurrency int // default 50; 0 = serial (legacy)
sem chan struct{}
}
type messageFilterRoute struct {
filter Filter[*api.Message]
handler Handler[*api.Message]
}
type callbackFilterRoute struct {
filter Filter[*api.CallbackQuery]
handler Handler[*api.CallbackQuery]
}
type inlineFilterRoute struct {
filter Filter[*api.InlineQuery]
handler Handler[*api.InlineQuery]
}
type chatMemberFilterRoute struct {
filter Filter[*api.ChatMemberUpdated]
handler Handler[*api.ChatMemberUpdated]
}
type chatJoinRequestFilterRoute struct {
filter Filter[*api.ChatJoinRequest]
handler Handler[*api.ChatJoinRequest]
}
type preCheckoutFilterRoute struct {
filter Filter[*api.PreCheckoutQuery]
handler Handler[*api.PreCheckoutQuery]
}
// RouterOption configures a Router at construction time.
type RouterOption func(*Router)
// WithMaxConcurrency sets the maximum number of updates processed in parallel.
// Default is 50. Pass 0 to dispatch serially (one update at a time, in the
// calling goroutine — the legacy behaviour before v1.1.0).
//
// Note: concurrent dispatch means handlers for different updates may run
// simultaneously. Handlers that mutate shared state must be safe for concurrent
// access.
func WithMaxConcurrency(n int) RouterOption {
return func(r *Router) { r.maxConcurrency = n }
}
type commandRoute struct {
cmd string
handler Handler[*api.Message]
}
type textRoute struct {
re *regexp.Regexp
handler Handler[*api.Message]
}
type callbackRoute struct {
re *regexp.Regexp
handler Handler[*api.CallbackQuery]
}
// New constructs a Router. Recovery middleware is added by default; users
// can disable it by passing WithoutRecovery (not implemented here, but
// the hook is in place via Use).
func New(b *client.Bot, opts ...RouterOption) *Router {
r := &Router{bot: b, maxConcurrency: 50}
for _, o := range opts {
o(r)
}
if r.maxConcurrency > 0 {
r.sem = make(chan struct{}, r.maxConcurrency)
}
r.Use(Recovery())
return r
}
// Use registers a global middleware applied to every Update dispatch.
func (r *Router) Use(mw Middleware[*api.Update]) { r.globalMW = append(r.globalMW, mw) }
// OnCommand registers a handler for a slash command. The command string
// includes the leading slash (e.g. "/start"). Matching strips an optional
// "@BotName" suffix.
func (r *Router) OnCommand(cmd string, h Handler[*api.Message]) {
r.commands = append(r.commands, commandRoute{cmd: cmd, handler: h})
}
// OnText registers a handler for messages whose Text matches the regex.
//
// Panics at registration time if pattern is not a valid regular expression.
func (r *Router) OnText(pattern string, h Handler[*api.Message]) {
r.texts = append(r.texts, textRoute{re: regexp.MustCompile(pattern), handler: h})
}
// OnCallback registers a handler for callback queries whose Data matches
// the regex.
//
// Panics at registration time if pattern is not a valid regular expression.
func (r *Router) OnCallback(pattern string, h Handler[*api.CallbackQuery]) {
r.callbacks = append(r.callbacks, callbackRoute{re: regexp.MustCompile(pattern), handler: h})
}
// OnInlineQuery registers a handler for inline queries (one matcher only;
// inline queries are not partitioned by content here).
func (r *Router) OnInlineQuery(h Handler[*api.InlineQuery]) {
r.inlines = append(r.inlines, h)
}
// OnEditedMessage registers a handler for edited message updates.
func (r *Router) OnEditedMessage(h Handler[*api.Message]) {
r.editedMsg = append(r.editedMsg, h)
}
// OnChannelPost registers a handler for channel post updates.
func (r *Router) OnChannelPost(h Handler[*api.Message]) {
r.channelPosts = append(r.channelPosts, h)
}
// OnEditedChannelPost registers a handler for edited channel post updates.
func (r *Router) OnEditedChannelPost(h Handler[*api.Message]) {
r.editedChannelPosts = append(r.editedChannelPosts, h)
}
// OnMessageFilter registers a typed message handler gated by filter f.
// Filter routes are checked after command and text routes; first match wins.
func (r *Router) OnMessageFilter(f Filter[*api.Message], h Handler[*api.Message]) {
r.messageFilters = append(r.messageFilters, messageFilterRoute{filter: f, handler: h})
}
// OnCallbackFilter registers a typed callback-query handler gated by filter f.
// Filter routes are checked after pattern-based OnCallback routes; first match wins.
func (r *Router) OnCallbackFilter(f Filter[*api.CallbackQuery], h Handler[*api.CallbackQuery]) {
r.callbackFilters = append(r.callbackFilters, callbackFilterRoute{filter: f, handler: h})
}
// OnInlineQueryFilter registers an inline-query handler gated by filter f.
// Filter routes are checked after bare OnInlineQuery handlers; first match wins.
func (r *Router) OnInlineQueryFilter(f Filter[*api.InlineQuery], h Handler[*api.InlineQuery]) {
r.inlineFilters = append(r.inlineFilters, inlineFilterRoute{filter: f, handler: h})
}
// OnMyChatMember registers a handler for bot's own chat member status changes.
func (r *Router) OnMyChatMember(h Handler[*api.ChatMemberUpdated]) {
r.myChatMember = append(r.myChatMember, h)
}
// OnMyChatMemberFilter registers a filtered handler for bot's own chat member status changes.
func (r *Router) OnMyChatMemberFilter(f Filter[*api.ChatMemberUpdated], h Handler[*api.ChatMemberUpdated]) {
r.myChatMemberFilters = append(r.myChatMemberFilters, chatMemberFilterRoute{filter: f, handler: h})
}
// OnChatMember registers a handler for chat member status changes.
func (r *Router) OnChatMember(h Handler[*api.ChatMemberUpdated]) {
r.chatMember = append(r.chatMember, h)
}
// OnChatMemberFilter registers a filtered handler for chat member status changes.
func (r *Router) OnChatMemberFilter(f Filter[*api.ChatMemberUpdated], h Handler[*api.ChatMemberUpdated]) {
r.chatMemberFilters = append(r.chatMemberFilters, chatMemberFilterRoute{filter: f, handler: h})
}
// OnChatJoinRequest registers a handler for chat join requests.
func (r *Router) OnChatJoinRequest(h Handler[*api.ChatJoinRequest]) {
r.chatJoinRequest = append(r.chatJoinRequest, h)
}
// OnChatJoinRequestFilter registers a filtered handler for chat join requests.
func (r *Router) OnChatJoinRequestFilter(f Filter[*api.ChatJoinRequest], h Handler[*api.ChatJoinRequest]) {
r.chatJoinRequestFilters = append(r.chatJoinRequestFilters, chatJoinRequestFilterRoute{filter: f, handler: h})
}
// OnPreCheckoutQuery registers a handler for pre-checkout queries.
func (r *Router) OnPreCheckoutQuery(h Handler[*api.PreCheckoutQuery]) {
r.preCheckoutQuery = append(r.preCheckoutQuery, h)
}
// OnPreCheckoutQueryFilter registers a filtered handler for pre-checkout queries.
func (r *Router) OnPreCheckoutQueryFilter(f Filter[*api.PreCheckoutQuery], h Handler[*api.PreCheckoutQuery]) {
r.preCheckoutFilters = append(r.preCheckoutFilters, preCheckoutFilterRoute{filter: f, handler: h})
}
// OnShippingQuery registers a handler for shipping queries.
func (r *Router) OnShippingQuery(h Handler[*api.ShippingQuery]) {
r.shippingQuery = append(r.shippingQuery, h)
}
// OnPoll registers a handler for poll state updates.
func (r *Router) OnPoll(h Handler[*api.Poll]) {
r.poll = append(r.poll, h)
}
// OnPollAnswer registers a handler for poll answer updates.
func (r *Router) OnPollAnswer(h Handler[*api.PollAnswer]) {
r.pollAnswer = append(r.pollAnswer, h)
}
// OnChosenInlineResult registers a handler for chosen inline results.
func (r *Router) OnChosenInlineResult(h Handler[*api.ChosenInlineResult]) {
r.chosenInlineResult = append(r.chosenInlineResult, h)
}
// OnMessageReaction registers a handler for message reaction updates.
func (r *Router) OnMessageReaction(h Handler[*api.MessageReactionUpdated]) {
r.messageReaction = append(r.messageReaction, h)
}
// OnMessageReactionCount registers a handler for anonymous message reaction count updates.
func (r *Router) OnMessageReactionCount(h Handler[*api.MessageReactionCountUpdated]) {
r.messageReactionCnt = append(r.messageReactionCnt, h)
}
// OnChatBoost registers a handler for chat boost updates.
func (r *Router) OnChatBoost(h Handler[*api.ChatBoostUpdated]) {
r.chatBoost = append(r.chatBoost, h)
}
// OnRemovedChatBoost registers a handler for removed chat boost updates.
func (r *Router) OnRemovedChatBoost(h Handler[*api.ChatBoostRemoved]) {
r.removedChatBoost = append(r.removedChatBoost, h)
}
// OnBusinessConnection registers a handler for business connection updates.
func (r *Router) OnBusinessConnection(h Handler[*api.BusinessConnection]) {
r.businessConn = append(r.businessConn, h)
}
// OnPurchasedPaidMedia registers a handler for purchased paid media updates.
func (r *Router) OnPurchasedPaidMedia(h Handler[*api.PaidMediaPurchased]) {
r.purchasedPaidMedia = append(r.purchasedPaidMedia, h)
}
// Run consumes the Updater and dispatches each update. It blocks until
// the Updater's channel is closed or ctx is cancelled.
//
// By default updates are processed concurrently (up to WithMaxConcurrency(50)
// goroutines). Handlers for different updates may therefore run simultaneously;
// shared state must be protected. Pass WithMaxConcurrency(0) to New to restore
// serial (legacy) behaviour.
//
// Run waits for all in-flight handlers to finish before returning.
func (r *Router) Run(ctx context.Context, u transport.Updater) error {
runErr := make(chan error, 1)
go func() { runErr <- u.Run(ctx) }()
root := r.dispatch
for i := len(r.globalMW) - 1; i >= 0; i-- {
root = r.globalMW[i](root)
}
var wg sync.WaitGroup
defer wg.Wait()
dispatch := func(up api.Update) {
c := NewContext(ctx, r.bot, &up)
if err := root(c, &up); err != nil {
if r.bot != nil {
r.bot.Logger().Error("dispatch handler error", "err", err, "update_id", up.UpdateID)
}
}
}
for {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-runErr:
return err
case up, ok := <-u.Updates():
if !ok {
// Channel closed; consume the run error if pending.
select {
case err := <-runErr:
return err
default:
}
return nil
}
if r.sem == nil {
// Serial mode (legacy / WithMaxConcurrency(0)).
dispatch(up)
continue
}
// Concurrent mode: acquire semaphore slot then launch goroutine.
select {
case r.sem <- struct{}{}:
case <-ctx.Done():
return ctx.Err()
}
wg.Add(1)
go func(up api.Update) {
defer func() {
<-r.sem
wg.Done()
}()
dispatch(up)
}(up)
}
}
}
func (r *Router) dispatch(c *Context, u *api.Update) error {
switch {
case u.Message != nil:
return r.handleMessage(c, u.Message)
case u.EditedMessage != nil:
return runHandlers(r.editedMsg, c, u.EditedMessage)
case u.ChannelPost != nil:
return runHandlers(r.channelPosts, c, u.ChannelPost)
case u.EditedChannelPost != nil:
return runHandlers(r.editedChannelPosts, c, u.EditedChannelPost)
case u.CallbackQuery != nil:
return r.handleCallback(c, u.CallbackQuery)
case u.InlineQuery != nil:
if err := runHandlers(r.inlines, c, u.InlineQuery); err != nil {
return err
}
for _, route := range r.inlineFilters {
if route.filter(u.InlineQuery) {
return route.handler(c, u.InlineQuery)
}
}
return nil
case u.MyChatMember != nil:
return r.handleChatMemberUpdate(c, u.MyChatMember, r.myChatMember, r.myChatMemberFilters)
case u.ChatMember != nil:
return r.handleChatMemberUpdate(c, u.ChatMember, r.chatMember, r.chatMemberFilters)
case u.ChatJoinRequest != nil:
return r.handleChatJoinRequest(c, u.ChatJoinRequest)
case u.PreCheckoutQuery != nil:
return r.handlePreCheckoutQuery(c, u.PreCheckoutQuery)
case u.ShippingQuery != nil:
return runHandlers(r.shippingQuery, c, u.ShippingQuery)
case u.Poll != nil:
return runHandlers(r.poll, c, u.Poll)
case u.PollAnswer != nil:
return runHandlers(r.pollAnswer, c, u.PollAnswer)
case u.ChosenInlineResult != nil:
return runHandlers(r.chosenInlineResult, c, u.ChosenInlineResult)
case u.MessageReaction != nil:
return runHandlers(r.messageReaction, c, u.MessageReaction)
case u.MessageReactionCount != nil:
return runHandlers(r.messageReactionCnt, c, u.MessageReactionCount)
case u.ChatBoost != nil:
return runHandlers(r.chatBoost, c, u.ChatBoost)
case u.RemovedChatBoost != nil:
return runHandlers(r.removedChatBoost, c, u.RemovedChatBoost)
case u.BusinessConnection != nil:
return runHandlers(r.businessConn, c, u.BusinessConnection)
case u.PurchasedPaidMedia != nil:
return runHandlers(r.purchasedPaidMedia, c, u.PurchasedPaidMedia)
}
return nil
}
func (r *Router) handleChatMemberUpdate(c *Context, payload *api.ChatMemberUpdated, handlers []Handler[*api.ChatMemberUpdated], filters []chatMemberFilterRoute) error {
if err := runHandlers(handlers, c, payload); err != nil {
return err
}
for _, route := range filters {
if route.filter(payload) {
return route.handler(c, payload)
}
}
return nil
}
func (r *Router) handleChatJoinRequest(c *Context, payload *api.ChatJoinRequest) error {
if err := runHandlers(r.chatJoinRequest, c, payload); err != nil {
return err
}
for _, route := range r.chatJoinRequestFilters {
if route.filter(payload) {
return route.handler(c, payload)
}
}
return nil
}
func (r *Router) handlePreCheckoutQuery(c *Context, payload *api.PreCheckoutQuery) error {
if err := runHandlers(r.preCheckoutQuery, c, payload); err != nil {
return err
}
for _, route := range r.preCheckoutFilters {
if route.filter(payload) {
return route.handler(c, payload)
}
}
return nil
}
// runHandlers invokes each handler in order; returns the first non-nil error.
func runHandlers[T any](handlers []Handler[T], c *Context, payload T) error {
for _, h := range handlers {
if err := h(c, payload); err != nil {
return err
}
}
return nil
}
func (r *Router) handleMessage(c *Context, m *api.Message) error {
// Try command first (entity-aware).
if cmd, args, ok := extractCommand(m); ok {
for _, route := range r.commands {
if route.cmd == cmd {
c.Values["command"] = cmd
c.Values["command_args"] = args
return route.handler(c, m)
}
}
}
// Then text regex matchers.
if m.Text != "" {
for _, route := range r.texts {
if subs := route.re.FindStringSubmatch(m.Text); subs != nil {
c.Values["regex_match"] = subs
return route.handler(c, m)
}
}
}
// Filter-based routes.
for _, route := range r.messageFilters {
if route.filter(m) {
return route.handler(c, m)
}
}
// Group-priority routes (registered via RouterScope.Group()).
return r.dispatchGroups(c, m)
}
func (r *Router) handleCallback(c *Context, q *api.CallbackQuery) error {
for _, route := range r.callbacks {
if subs := route.re.FindStringSubmatch(q.Data); subs != nil {
c.Values["regex_match"] = subs
return route.handler(c, q)
}
}
// Filter-based routes checked after pattern routes.
for _, route := range r.callbackFilters {
if route.filter(q) {
return route.handler(c, q)
}
}
return nil
}
// extractCommand returns the command (e.g. "/start") and the remaining
// argument string, when m carries a leading bot_command entity. It strips
// optional "@BotName" suffix on the command itself.
func extractCommand(m *api.Message) (cmd, args string, ok bool) {
if len(m.Entities) == 0 || m.Text == "" {
return "", "", false
}
first := m.Entities[0]
if first.Type != string(api.EntityBotCommand) || first.Offset != 0 {
return "", "", false
}
cmd, sliceOk := utf16Slice(m.Text, int(first.Offset), int(first.Length))
if !sliceOk {
return "", "", false
}
if i := strings.Index(cmd, "@"); i >= 0 {
cmd = cmd[:i]
}
end := int(first.Offset) + int(first.Length)
rest, _ := utf16Slice(m.Text, end, utf16Len(m.Text)-end)
args = strings.TrimSpace(rest)
return cmd, args, true
}
// utf16Slice returns the substring of s identified by a UTF-16 offset/length
// pair, as Telegram's MessageEntity uses. ok is false if the indices fall
// outside s's UTF-16 length.
func utf16Slice(s string, offset, length int) (string, bool) {
runes := []rune(s)
var startBytes, endBytes int
var u16 int
found := false
for i, r := range runes {
if u16 == offset {
startBytes = byteIndex(runes, i)
found = true
}
if u16 == offset+length {
endBytes = byteIndex(runes, i)
return s[startBytes:endBytes], true
}
if r > 0xFFFF {
u16 += 2
} else {
u16++
}
}
if found && u16 == offset+length {
return s[startBytes:], true
}
return "", false
}
func byteIndex(runes []rune, runeIdx int) int {
n := 0
for i := 0; i < runeIdx; i++ {
n += utf8.RuneLen(runes[i])
}
return n
}
func utf16Len(s string) int {
n := 0
for _, r := range s {
if r > 0xFFFF {
n += 2
} else {
n++
}
}
return n
}
+940
View File
@@ -0,0 +1,940 @@
package dispatch
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/client"
"github.com/stretchr/testify/require"
)
// fakeUpdater feeds a fixed slice of updates then closes.
type fakeUpdater struct{ ch chan api.Update }
func newFake(ups ...api.Update) *fakeUpdater {
ch := make(chan api.Update, len(ups))
for _, u := range ups {
ch <- u
}
close(ch)
return &fakeUpdater{ch: ch}
}
func (f *fakeUpdater) Updates() <-chan api.Update { return f.ch }
func (f *fakeUpdater) Run(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }
func (f *fakeUpdater) Stop(ctx context.Context) error { return nil }
func cmdMessage(text string) api.Update {
return api.Update{
UpdateID: 1,
Message: &api.Message{
MessageID: 1, Date: 0, Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: text,
Entities: []api.MessageEntity{{Type: string(api.EntityBotCommand), Offset: 0, Length: int64(indexEnd(text))}},
},
}
}
func indexEnd(text string) int {
for i, r := range text {
if r == ' ' {
return i
}
}
return len(text)
}
func TestRouter_OnCommandMatches(t *testing.T) {
b := client.New("t")
r := New(b)
hit := make(chan string, 1)
r.OnCommand("/start", func(c *Context, m *api.Message) error {
hit <- c.Values["command"].(string)
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(cmdMessage("/start"))) }()
require.Equal(t, "/start", <-hit)
}
func TestRouter_OnCommandStripsBotName(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnCommand("/start", func(c *Context, m *api.Message) error {
hit <- "matched"
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(cmdMessage("/start@MyBot hello"))) }()
require.Equal(t, "matched", <-hit)
}
func TestRouter_OnText(t *testing.T) {
r := New(client.New("t"))
hit := make(chan []string, 1)
r.OnText(`^hello (\w+)$`, func(c *Context, m *api.Message) error {
hit <- c.Values["regex_match"].([]string)
return nil
})
u := api.Update{UpdateID: 1, Message: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "hello world",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
subs := <-hit
require.Equal(t, "world", subs[1])
}
func TestRouter_OnCallback(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnCallback(`^like:(\d+)$`, func(c *Context, q *api.CallbackQuery) error {
hit <- q.Data
return nil
})
u := api.Update{UpdateID: 1, CallbackQuery: &api.CallbackQuery{
ID: "x", From: api.User{ID: 1}, ChatInstance: "y", Data: "like:42",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
require.Equal(t, "like:42", <-hit)
}
func TestRouter_NoMatch(t *testing.T) {
r := New(client.New("t"))
called := false
r.OnCommand("/start", func(c *Context, m *api.Message) error {
called = true
return nil
})
u := api.Update{UpdateID: 1, Message: &api.Message{Text: "no command"}}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(u))
require.False(t, called)
}
func TestRouter_PanicRecovery(t *testing.T) {
r := New(client.New("t"))
r.OnCommand("/boom", func(c *Context, m *api.Message) error {
panic("kaboom")
})
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
// Should not propagate panic to Run.
require.NotPanics(t, func() { _ = r.Run(ctx, newFake(cmdMessage("/boom"))) })
}
// TestRouter_NonASCIICommand verifies that UTF-16 entity offsets are used
// correctly when the command contains non-ASCII runes. "/старт" is 6 runes,
// each a BMP code point, so UTF-16 length == 6.
func TestRouter_NonASCIICommand(t *testing.T) {
const text = "/старт аргумент"
// "/старт" = 1 + 5 runes, all BMP → UTF-16 length 6
const cmdU16Len = int64(6)
u := api.Update{
UpdateID: 1,
Message: &api.Message{
MessageID: 1,
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: text,
Entities: []api.MessageEntity{
{Type: string(api.EntityBotCommand), Offset: 0, Length: cmdU16Len},
},
},
}
r := New(client.New("t"))
hit := make(chan [2]string, 1)
r.OnCommand("/старт", func(c *Context, m *api.Message) error {
hit <- [2]string{
c.Values["command"].(string),
c.Values["command_args"].(string),
}
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
got := <-hit
require.Equal(t, "/старт", got[0])
require.Equal(t, "аргумент", got[1])
}
// TestRouter_CommandValuesNotLeakedOnNoMatch verifies that c.Values["command"]
// is not set when a command entity is present but no route matches, so a
// subsequent text handler doesn't see stale values.
func TestRouter_CommandValuesNotLeakedOnNoMatch(t *testing.T) {
r := New(client.New("t"))
// Register a text handler that should fire as fallback.
leaked := make(chan bool, 1)
r.OnText(`.*`, func(c *Context, m *api.Message) error {
_, hasCmd := c.Values["command"]
leaked <- hasCmd
return nil
})
// No OnCommand registered, so the command entity won't match any route.
u := api.Update{UpdateID: 1, Message: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"},
Text: "/unknown",
Entities: []api.MessageEntity{{Type: string(api.EntityBotCommand), Offset: 0, Length: 8}},
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
require.False(t, <-leaked, "command value must not leak into text handler")
}
func TestRouter_MiddlewareOrder(t *testing.T) {
r := New(client.New("t"))
var order []string
r.Use(func(next Handler[*api.Update]) Handler[*api.Update] {
return func(c *Context, u *api.Update) error {
order = append(order, "before-1")
err := next(c, u)
order = append(order, "after-1")
return err
}
})
r.Use(func(next Handler[*api.Update]) Handler[*api.Update] {
return func(c *Context, u *api.Update) error {
order = append(order, "before-2")
err := next(c, u)
order = append(order, "after-2")
return err
}
})
r.OnCommand("/x", func(c *Context, m *api.Message) error {
order = append(order, "handler")
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(cmdMessage("/x")))
require.Equal(t,
[]string{"before-1", "before-2", "handler", "after-2", "after-1"},
order)
}
func TestRouter_OnChannelPost(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnChannelPost(func(c *Context, m *api.Message) error {
hit <- m.MessageID
return nil
})
u := api.Update{UpdateID: 1, ChannelPost: &api.Message{
MessageID: 99, Chat: api.Chat{ID: -100, Type: string(api.ChatTypeChannel)},
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
require.Equal(t, int64(99), <-hit)
}
func TestRouter_RunsAllHandlersForEditedMessage(t *testing.T) {
r := New(client.New("t"))
var hits []string
r.OnEditedMessage(func(c *Context, m *api.Message) error {
hits = append(hits, "first")
return nil
})
r.OnEditedMessage(func(c *Context, m *api.Message) error {
hits = append(hits, "second")
return nil
})
u := api.Update{UpdateID: 1, EditedMessage: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "edited",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(u))
require.Equal(t, []string{"first", "second"}, hits)
}
// ---------------------------------------------------------------------------
// Filter-route tests
// ---------------------------------------------------------------------------
func TestRouter_OnMessageFilter_Matches(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return m != nil && m.Text == "ping" }),
func(c *Context, m *api.Message) error { hit <- m.Text; return nil },
)
u := api.Update{UpdateID: 1, Message: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "ping",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
require.Equal(t, "ping", <-hit)
}
func TestRouter_OnMessageFilter_NoMatch(t *testing.T) {
r := New(client.New("t"))
called := false
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return false }),
func(c *Context, m *api.Message) error { called = true; return nil },
)
u := api.Update{UpdateID: 1, Message: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "any",
}}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(u))
require.False(t, called)
}
// Command routes must take priority over filter routes.
func TestRouter_OnMessageFilter_CommandWins(t *testing.T) {
r := New(client.New("t"))
var winner string
r.OnCommand("/start", func(c *Context, m *api.Message) error { winner = "command"; return nil })
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return true }),
func(c *Context, m *api.Message) error { winner = "filter"; return nil },
)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(cmdMessage("/start")))
require.Equal(t, "command", winner)
}
func TestRouter_OnCallbackFilter_Matches(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnCallbackFilter(
Filter[*api.CallbackQuery](func(q *api.CallbackQuery) bool { return q != nil && q.Data == "yes" }),
func(c *Context, q *api.CallbackQuery) error { hit <- q.Data; return nil },
)
u := api.Update{UpdateID: 1, CallbackQuery: &api.CallbackQuery{
ID: "x", From: api.User{ID: 1}, ChatInstance: "y", Data: "yes",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
require.Equal(t, "yes", <-hit)
}
// Pattern-based OnCallback wins over OnCallbackFilter when both match.
func TestRouter_OnCallbackFilter_PatternWins(t *testing.T) {
r := New(client.New("t"))
var winner string
r.OnCallback(`^yes$`, func(c *Context, q *api.CallbackQuery) error { winner = "pattern"; return nil })
r.OnCallbackFilter(
Filter[*api.CallbackQuery](func(q *api.CallbackQuery) bool { return true }),
func(c *Context, q *api.CallbackQuery) error { winner = "filter"; return nil },
)
u := api.Update{UpdateID: 1, CallbackQuery: &api.CallbackQuery{
ID: "x", From: api.User{ID: 1}, ChatInstance: "y", Data: "yes",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(u))
require.Equal(t, "pattern", winner)
}
func TestRouter_OnInlineQueryFilter_Matches(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnInlineQueryFilter(
Filter[*api.InlineQuery](func(q *api.InlineQuery) bool { return q != nil && q.Query == "find" }),
func(c *Context, q *api.InlineQuery) error { hit <- q.Query; return nil },
)
u := api.Update{UpdateID: 1, InlineQuery: &api.InlineQuery{
ID: "i", From: api.User{ID: 1}, Query: "find",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
require.Equal(t, "find", <-hit)
}
func TestRouter_FilterChain_Composition(t *testing.T) {
// Filter: private chat AND text contains "hello"
privateChat := Filter[*api.Message](func(m *api.Message) bool {
return m != nil && m.Chat.Type == string(api.ChatTypePrivate)
})
hasHello := Filter[*api.Message](func(m *api.Message) bool {
return m != nil && len(m.Text) > 0 && containsStr(m.Text, "hello")
})
combined := privateChat.And(hasHello)
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnMessageFilter(combined, func(c *Context, m *api.Message) error { hit <- m.Text; return nil })
match := api.Update{UpdateID: 1, Message: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)}, Text: "say hello",
}}
noMatch := api.Update{UpdateID: 2, Message: &api.Message{
MessageID: 2, Chat: api.Chat{ID: 2, Type: string(api.ChatTypeGroup)}, Text: "say hello",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(match, noMatch)) }()
require.Equal(t, "say hello", <-hit)
}
// containsStr is a helper to avoid importing strings in test file unnecessarily.
func containsStr(s, sub string) bool {
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsSubstr(s, sub))
}
func containsSubstr(s, sub string) bool {
for i := 0; i <= len(s)-len(sub); i++ {
if s[i:i+len(sub)] == sub {
return true
}
}
return false
}
// ---------------------------------------------------------------------------
// Concurrent dispatch tests
// ---------------------------------------------------------------------------
// fakeSlowUpdater feeds n updates then blocks until ctx cancel.
type fakeSlowUpdater struct {
ch chan api.Update
}
func newSlowFake(ups ...api.Update) *fakeSlowUpdater {
ch := make(chan api.Update, len(ups))
for _, u := range ups {
ch <- u
}
close(ch)
return &fakeSlowUpdater{ch: ch}
}
func (f *fakeSlowUpdater) Updates() <-chan api.Update { return f.ch }
func (f *fakeSlowUpdater) Run(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }
func (f *fakeSlowUpdater) Stop(ctx context.Context) error { return nil }
func TestRouter_ConcurrentDispatch_AllHandlersFire(t *testing.T) {
const n = 100
var fired atomic.Int64
ups := make([]api.Update, n)
for i := range ups {
ups[i] = api.Update{UpdateID: int64(i + 1), Message: &api.Message{
MessageID: int64(i + 1),
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: "hi",
}}
}
r := New(client.New("t"), WithMaxConcurrency(20))
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return true }),
func(c *Context, m *api.Message) error { fired.Add(1); return nil },
)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_ = r.Run(ctx, newSlowFake(ups...))
require.Equal(t, int64(n), fired.Load())
}
func TestRouter_ConcurrentDispatch_SemaphoreBoundsConcurrency(t *testing.T) {
const limit = 5
const n = 30
var inFlight atomic.Int64
var maxSeen atomic.Int64
ready := make(chan struct{}) // signals handler to proceed
started := make(chan struct{}) // first handler signals it's running
ups := make([]api.Update, n)
for i := range ups {
ups[i] = api.Update{UpdateID: int64(i + 1), Message: &api.Message{
MessageID: int64(i + 1),
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: "hi",
}}
}
once := atomic.Bool{}
r := New(client.New("t"), WithMaxConcurrency(limit))
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return true }),
func(c *Context, m *api.Message) error {
cur := inFlight.Add(1)
for {
old := maxSeen.Load()
if cur <= old || maxSeen.CompareAndSwap(old, cur) {
break
}
}
if once.CompareAndSwap(false, true) {
close(started)
}
<-ready
inFlight.Add(-1)
return nil
},
)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go func() { _ = r.Run(ctx, newSlowFake(ups...)) }()
select {
case <-started:
case <-ctx.Done():
t.Fatal("timed out waiting for first handler")
}
// Give the pool a moment to fill up.
time.Sleep(50 * time.Millisecond)
close(ready)
// Wait for Run to drain by cancelling context after a short wait.
time.Sleep(200 * time.Millisecond)
cancel()
require.LessOrEqual(t, maxSeen.Load(), int64(limit),
"in-flight goroutines exceeded semaphore limit")
}
func TestRouter_ConcurrentDispatch_WaitsForInFlight(t *testing.T) {
unblock := make(chan struct{})
done := make(chan struct{})
r := New(client.New("t"), WithMaxConcurrency(10))
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return true }),
func(c *Context, m *api.Message) error {
<-unblock
return nil
},
)
u := api.Update{UpdateID: 1, Message: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)}, Text: "hi",
}}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
go func() {
_ = r.Run(ctx, newSlowFake(u))
close(done)
}()
// Give Run time to pick up the update and launch the goroutine.
time.Sleep(30 * time.Millisecond)
cancel() // trigger Run to exit its loop
// Run should not return until handler unblocks.
select {
case <-done:
t.Fatal("Run returned before in-flight handler finished")
case <-time.After(50 * time.Millisecond):
}
close(unblock)
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("Run did not return after handler finished")
}
}
func TestRouter_SerialMode_NoRace(t *testing.T) {
// WithMaxConcurrency(0) — serial; shared slice is safe without a mutex.
var order []int64
const n = 20
ups := make([]api.Update, n)
for i := range ups {
ups[i] = api.Update{UpdateID: int64(i + 1), Message: &api.Message{
MessageID: int64(i + 1),
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: "hi",
}}
}
r := New(client.New("t"), WithMaxConcurrency(0))
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return true }),
func(c *Context, m *api.Message) error {
order = append(order, m.MessageID)
return nil
},
)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = r.Run(ctx, newSlowFake(ups...))
require.Len(t, order, n)
for i, v := range order {
require.Equal(t, int64(i+1), v)
}
}
// liveUpdater is an updater whose channel stays open until stopCh is closed.
type liveUpdater struct {
ch chan api.Update
stopCh chan struct{}
}
func newLiveUpdater() *liveUpdater {
return &liveUpdater{ch: make(chan api.Update, 8), stopCh: make(chan struct{})}
}
func (l *liveUpdater) Send(u api.Update) { l.ch <- u }
func (l *liveUpdater) Close() { close(l.stopCh) }
func (l *liveUpdater) Updates() <-chan api.Update { return l.ch }
func (l *liveUpdater) Stop(ctx context.Context) error { return nil }
func (l *liveUpdater) Run(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-l.stopCh:
return nil
}
}
// ---------------------------------------------------------------------------
// Typed handler tests (Feature 1)
// ---------------------------------------------------------------------------
func TestRouter_OnMyChatMember(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnMyChatMember(func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.From.ID; return nil })
upd := api.Update{UpdateID: 1, MyChatMember: &api.ChatMemberUpdated{
From: api.User{ID: 42},
Chat: api.Chat{ID: 1},
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(42), <-hit)
}
func TestRouter_OnMyChatMemberFilter(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
f := Filter[*api.ChatMemberUpdated](func(u *api.ChatMemberUpdated) bool { return u.From.ID == 99 })
r.OnMyChatMemberFilter(f, func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.From.ID; return nil })
match := api.Update{UpdateID: 1, MyChatMember: &api.ChatMemberUpdated{From: api.User{ID: 99}}}
noMatch := api.Update{UpdateID: 2, MyChatMember: &api.ChatMemberUpdated{From: api.User{ID: 1}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(noMatch, match)) }()
require.Equal(t, int64(99), <-hit)
}
func TestRouter_OnChatMember(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnChatMember(func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, ChatMember: &api.ChatMemberUpdated{
From: api.User{ID: 1},
Chat: api.Chat{ID: 77},
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(77), <-hit)
}
func TestRouter_OnChatMemberFilter(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
f := Filter[*api.ChatMemberUpdated](func(u *api.ChatMemberUpdated) bool { return u.Chat.ID == 55 })
r.OnChatMemberFilter(f, func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, ChatMember: &api.ChatMemberUpdated{Chat: api.Chat{ID: 55}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(55), <-hit)
}
func TestRouter_OnChatJoinRequest(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnChatJoinRequest(func(c *Context, req *api.ChatJoinRequest) error { hit <- req.From.ID; return nil })
upd := api.Update{UpdateID: 1, ChatJoinRequest: &api.ChatJoinRequest{
From: api.User{ID: 11},
Chat: api.Chat{ID: 1},
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(11), <-hit)
}
func TestRouter_OnChatJoinRequestFilter(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
f := Filter[*api.ChatJoinRequest](func(req *api.ChatJoinRequest) bool { return req.Chat.ID == 22 })
r.OnChatJoinRequestFilter(f, func(c *Context, req *api.ChatJoinRequest) error { hit <- req.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, ChatJoinRequest: &api.ChatJoinRequest{
From: api.User{ID: 1},
Chat: api.Chat{ID: 22},
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(22), <-hit)
}
func TestRouter_OnPreCheckoutQuery(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnPreCheckoutQuery(func(c *Context, q *api.PreCheckoutQuery) error { hit <- q.Currency; return nil })
upd := api.Update{UpdateID: 1, PreCheckoutQuery: &api.PreCheckoutQuery{
ID: "q1", From: api.User{ID: 1}, Currency: "USD",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "USD", <-hit)
}
func TestRouter_OnPreCheckoutQueryFilter(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
f := Filter[*api.PreCheckoutQuery](func(q *api.PreCheckoutQuery) bool { return q.Currency == "EUR" })
r.OnPreCheckoutQueryFilter(f, func(c *Context, q *api.PreCheckoutQuery) error { hit <- q.Currency; return nil })
upd := api.Update{UpdateID: 1, PreCheckoutQuery: &api.PreCheckoutQuery{
ID: "q1", From: api.User{ID: 1}, Currency: "EUR",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "EUR", <-hit)
}
func TestRouter_OnShippingQuery(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnShippingQuery(func(c *Context, q *api.ShippingQuery) error { hit <- q.ID; return nil })
upd := api.Update{UpdateID: 1, ShippingQuery: &api.ShippingQuery{ID: "sq1", From: api.User{ID: 1}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "sq1", <-hit)
}
func TestRouter_OnPoll(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnPoll(func(c *Context, p *api.Poll) error { hit <- p.ID; return nil })
upd := api.Update{UpdateID: 1, Poll: &api.Poll{ID: "poll1"}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "poll1", <-hit)
}
func TestRouter_OnPollAnswer(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnPollAnswer(func(c *Context, a *api.PollAnswer) error { hit <- a.PollID; return nil })
upd := api.Update{UpdateID: 1, PollAnswer: &api.PollAnswer{PollID: "p1", OptionIds: []int64{0}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "p1", <-hit)
}
func TestRouter_OnChosenInlineResult(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnChosenInlineResult(func(c *Context, res *api.ChosenInlineResult) error { hit <- res.ResultID; return nil })
upd := api.Update{UpdateID: 1, ChosenInlineResult: &api.ChosenInlineResult{ResultID: "r1", From: api.User{ID: 1}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "r1", <-hit)
}
func TestRouter_OnMessageReaction(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnMessageReaction(func(c *Context, u *api.MessageReactionUpdated) error { hit <- u.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, MessageReaction: &api.MessageReactionUpdated{Chat: api.Chat{ID: 33}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(33), <-hit)
}
func TestRouter_OnMessageReactionCount(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnMessageReactionCount(func(c *Context, u *api.MessageReactionCountUpdated) error { hit <- u.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, MessageReactionCount: &api.MessageReactionCountUpdated{Chat: api.Chat{ID: 44}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(44), <-hit)
}
func TestRouter_OnChatBoost(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnChatBoost(func(c *Context, u *api.ChatBoostUpdated) error { hit <- u.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, ChatBoost: &api.ChatBoostUpdated{Chat: api.Chat{ID: 55}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(55), <-hit)
}
func TestRouter_OnRemovedChatBoost(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnRemovedChatBoost(func(c *Context, u *api.ChatBoostRemoved) error { hit <- u.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, RemovedChatBoost: &api.ChatBoostRemoved{Chat: api.Chat{ID: 66}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(66), <-hit)
}
func TestRouter_OnBusinessConnection(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnBusinessConnection(func(c *Context, bc *api.BusinessConnection) error { hit <- bc.ID; return nil })
upd := api.Update{UpdateID: 1, BusinessConnection: &api.BusinessConnection{ID: "bc1", UserChatID: 1, User: api.User{ID: 1}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "bc1", <-hit)
}
func TestRouter_OnPurchasedPaidMedia(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnPurchasedPaidMedia(func(c *Context, p *api.PaidMediaPurchased) error { hit <- p.PaidMediaPayload; return nil })
upd := api.Update{UpdateID: 1, PurchasedPaidMedia: &api.PaidMediaPurchased{From: api.User{ID: 1}, PaidMediaPayload: "payload1"}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "payload1", <-hit)
}
func TestRouter_ContextCancel_UnblocksWaitingAcquire(t *testing.T) {
// Fill the semaphore with slow handlers, send one more update, then cancel
// ctx. Run must unblock from the semaphore-acquire select and return.
const limit = 2
unblock := make(chan struct{})
slowHandler := func(c *Context, m *api.Message) error {
<-unblock
return nil
}
lu := newLiveUpdater()
r := New(client.New("t"), WithMaxConcurrency(limit))
r.OnMessageFilter(Filter[*api.Message](func(m *api.Message) bool { return true }), slowHandler)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
runDone := make(chan error, 1)
go func() { runDone <- r.Run(ctx, lu) }()
// Send enough updates to fill semaphore.
for i := range limit {
lu.Send(api.Update{UpdateID: int64(i + 1), Message: &api.Message{
MessageID: int64(i + 1),
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: "hi",
}})
}
// Give goroutines time to acquire all semaphore slots.
time.Sleep(50 * time.Millisecond)
// Send one more update — Run will block trying to acquire the full semaphore.
lu.Send(api.Update{UpdateID: int64(limit + 1), Message: &api.Message{
MessageID: int64(limit + 1),
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: "extra",
}})
// Give Run a moment to reach the semaphore-acquire select.
time.Sleep(30 * time.Millisecond)
cancel()
// Unblock handlers so wg.Wait() inside Run can complete, allowing Run to
// return (and write to runDone).
close(unblock)
select {
case err := <-runDone:
require.Error(t, err)
case <-time.After(2 * time.Second):
t.Fatal("Run did not unblock after context cancel")
}
}
+5
View File
@@ -0,0 +1,5 @@
// Package gotelegram is the module root.
//
// The public API lives in the api, client, transport, and dispatch packages.
// See https://github.com/lukaszraczylo/go-telegram for documentation.
package gotelegram
+39
View File
@@ -0,0 +1,39 @@
# Examples
Each subdirectory contains a self-contained sample bot demonstrating one feature area.
| Example | What it shows |
|---|---|
| [echo](./echo) | Long-poll bot that echoes text back to the sender |
| [webhook](./webhook) | Webhook delivery with secret-token verification |
| [callback](./callback) | Inline keyboard with callback queries and counter state |
| [conversation](./conversation) | Multi-step conversation flow with `dispatch/conversation` |
| [files](./files) | Upload and download files via `api.DownloadFile` |
| [inline](./inline) | Inline-mode bot returning search-style results |
| [middleware](./middleware) | Custom middleware chains via `Router.Use` |
| [stateful](./stateful) | Per-user state managed via closures |
| [welcome](./welcome) | Greet new chat members; detect and log departures |
| [moderation](./moderation) | `/kick`, `/ban`, `/mute`, `/warn` with admin permission checks |
| [polls](./polls) | Create polls and tally answers via `OnPollAnswer` |
| [payments](./payments) | Telegram Payments: sendInvoice → pre_checkout_query → successful_payment |
| [pagination](./pagination) | Multi-page inline keyboard with stateless prev/next navigation |
| [admin](./admin) | Auth middleware allowlisting specific user IDs via `Router.Use` |
## Running
All examples follow the same pattern:
```bash
export TELEGRAM_BOT_TOKEN=123456:ABC...
go run ./examples/<name>
```
Webhook examples need a public HTTPS endpoint (use Cloudflare Tunnel, ngrok, or similar).
## Common patterns
**Retry-safe HTTP** — every example wraps the HTTP client with `client.NewRetryDoer`, which automatically honours Telegram's `retry_after` field on 429 responses.
**Graceful shutdown** — all examples use `signal.NotifyContext` so the bot drains cleanly on `SIGINT`/`SIGTERM`.
**Structured logging** — for production, wire a logger via `client.WithLogger` and wrap the process in supervision (systemd unit, k8s liveness probe, etc.).
+40
View File
@@ -0,0 +1,40 @@
# admin
Authentication middleware that restricts the bot to an allowlist of Telegram user IDs.
## What it shows
- `router.Use(...)` to install a global `Middleware[*api.Update]`
- Parsing `ALLOWED_USERS` env var into a `map[int64]bool` lookup set
- Extracting sender ID from multiple update types in one helper
- Silent drop pattern for unauthorized updates (no error, no reply)
## Environment variables
| Variable | Required | Description |
|---|---|---|
| `TELEGRAM_BOT_TOKEN` | Yes | Bot token from @BotFather |
| `ALLOWED_USERS` | No | Comma-separated numeric user IDs, e.g. `123456,789012`. If unset, all users are permitted. |
## Finding your user ID
Send `/whoami` to the bot — it replies with your numeric Telegram user ID. Add that ID to `ALLOWED_USERS` to restrict the bot to you.
## Extending
Combine with `examples/moderation` to ensure only group admins can invoke moderation commands:
```go
router.Use(allowlistMiddleware(adminIDs))
router.OnCommand("/ban", banHandler)
```
For group-context admin checks (verify the sender is an admin of *that specific group*), use `api.GetChatAdministrators` and check the result dynamically rather than a static ID list.
## Running
```bash
export TELEGRAM_BOT_TOKEN=123456:ABC...
export ALLOWED_USERS=111111,222222
go run ./examples/admin
```

Some files were not shown because too many files have changed in this diff Show More