mirror of
https://github.com/lukaszraczylo/go-telegram.git
synced 2026-06-05 22:43:59 +00:00
Initial release of go-telegram
A fully-generated, strongly-typed Go client for the Telegram Bot API. * 176 methods + 301 types generated from Bot API v10.0 * 1408 auto-generated tests (8 scenarios per method) * Typed unions throughout — no 'any' in the public surface * Pluggable HTTP transport and JSON codec (default goccy/go-json) * Built-in retry middleware honouring Telegram's retry_after * Generic dispatcher with filters and conversation handlers * Self-verifying codegen pipeline (regen → audit → emit → run tests) * 14 example bots covering common patterns
This commit is contained in:
@@ -0,0 +1,274 @@
|
|||||||
|
name: ci
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
dry-run-release:
|
||||||
|
description: "Compute release version, do not tag or release"
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
pull-requests: read
|
||||||
|
packages: write
|
||||||
|
|
||||||
|
# Cancel any in-flight CI runs for the same branch when a new push lands.
|
||||||
|
concurrency:
|
||||||
|
group: ci-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
vet:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.25.x'
|
||||||
|
check-latest: true
|
||||||
|
- uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/go/pkg/mod
|
||||||
|
~/.cache/go-build
|
||||||
|
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
|
||||||
|
- run: go vet ./...
|
||||||
|
|
||||||
|
staticcheck:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.25.x'
|
||||||
|
check-latest: true
|
||||||
|
- uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/go/pkg/mod
|
||||||
|
~/.cache/go-build
|
||||||
|
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
|
||||||
|
- run: go install honnef.co/go/tools/cmd/staticcheck@v0.7.0
|
||||||
|
- run: staticcheck ./...
|
||||||
|
|
||||||
|
govulncheck:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.25.x'
|
||||||
|
check-latest: true
|
||||||
|
- uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/go/pkg/mod
|
||||||
|
~/.cache/go-build
|
||||||
|
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
|
||||||
|
- run: go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||||
|
- run: govulncheck ./...
|
||||||
|
|
||||||
|
gosec:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.25.x'
|
||||||
|
check-latest: true
|
||||||
|
- uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/go/pkg/mod
|
||||||
|
~/.cache/go-build
|
||||||
|
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
|
||||||
|
- run: go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||||
|
# G404: math/rand/v2 jitter in transport/backoff.go — intentional (not crypto)
|
||||||
|
# G304: os.ReadFile from CLI flag variable — intentional (tool, not server)
|
||||||
|
# G306: 0o644 on generated doc artifacts in cmd/scrape — intentional
|
||||||
|
# G204: git subprocess in cmd/audit uses CLI flag path — intentional (operator tool)
|
||||||
|
# G706: log.Printf with values from Telegram/env in examples — illustrative,
|
||||||
|
# library users are expected to sanitise before logging in production
|
||||||
|
- run: gosec -quiet -exclude=G404,G304,G306,G204,G706 -exclude-dir=testdata ./...
|
||||||
|
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.25.x'
|
||||||
|
check-latest: true
|
||||||
|
- uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/go/pkg/mod
|
||||||
|
~/.cache/go-build
|
||||||
|
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
|
||||||
|
- run: go test -race -coverprofile=coverage.out ./...
|
||||||
|
- name: Build all examples
|
||||||
|
run: go build ./examples/...
|
||||||
|
- uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: coverage
|
||||||
|
path: coverage.out
|
||||||
|
|
||||||
|
codegen-clean:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.25.x'
|
||||||
|
check-latest: true
|
||||||
|
- uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/go/pkg/mod
|
||||||
|
~/.cache/go-build
|
||||||
|
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
|
||||||
|
- name: Regenerate against pinned snapshot
|
||||||
|
run: make regen-from-fixture
|
||||||
|
- name: Assert clean diff
|
||||||
|
run: git diff --exit-code internal/spec/api.json api/
|
||||||
|
|
||||||
|
audit:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0 # need history for drift comparison
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.25.x'
|
||||||
|
check-latest: true
|
||||||
|
- uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/go/pkg/mod
|
||||||
|
~/.cache/go-build
|
||||||
|
key: ${{ runner.os }}-go-1.25-${{ hashFiles('**/go.sum') }}
|
||||||
|
- name: Audit fallbacks
|
||||||
|
run: make audit
|
||||||
|
- name: Audit drift vs base
|
||||||
|
# On PRs: compare against the merge base (origin/<base>).
|
||||||
|
# On push to main: compare against the parent commit.
|
||||||
|
# Drift is informational; doesn't fail CI.
|
||||||
|
run: |
|
||||||
|
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||||
|
BASE="origin/${{ github.base_ref }}"
|
||||||
|
else
|
||||||
|
BASE="HEAD~1"
|
||||||
|
fi
|
||||||
|
echo "Drift base: $BASE"
|
||||||
|
go run ./cmd/audit -ir internal/spec/api.json -drift -against "$BASE" || true
|
||||||
|
|
||||||
|
# Aggregate gate — depends on every check above. Used as a single
|
||||||
|
# required status check in branch protection AND as a dependency for the
|
||||||
|
# release job below.
|
||||||
|
ci-success:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [vet, staticcheck, govulncheck, gosec, test, codegen-clean, audit]
|
||||||
|
if: always()
|
||||||
|
steps:
|
||||||
|
- name: All checks passed
|
||||||
|
if: ${{ !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') }}
|
||||||
|
run: echo "ci-success"
|
||||||
|
- name: At least one check failed
|
||||||
|
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') }}
|
||||||
|
run: |
|
||||||
|
echo "Failed/cancelled jobs:"
|
||||||
|
echo '${{ toJSON(needs) }}'
|
||||||
|
exit 1
|
||||||
|
|
||||||
|
# Auto-release fires on every clean push to main (and on manual
|
||||||
|
# workflow_dispatch for testing). Computes next SemVer from commit
|
||||||
|
# history via lukaszraczylo/semver-generator, dual-tags
|
||||||
|
# (v<X.Y.Z> + bot-api-v<A.B>), runs GoReleaser.
|
||||||
|
release:
|
||||||
|
needs: ci-success
|
||||||
|
if: |
|
||||||
|
(github.event_name == 'push' && github.ref == 'refs/heads/main') ||
|
||||||
|
github.event_name == 'workflow_dispatch'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.25.x'
|
||||||
|
check-latest: true
|
||||||
|
|
||||||
|
# Action interface from
|
||||||
|
# https://github.com/lukaszraczylo/semver-generator/blob/main/action.yml
|
||||||
|
# Inputs: repository_local: true (use already-cloned repo)
|
||||||
|
# existing: true (respect existing tags as base)
|
||||||
|
# Output: semantic_version (bare version string, no "v" prefix)
|
||||||
|
- name: Compute next SemVer
|
||||||
|
id: semver
|
||||||
|
uses: lukaszraczylo/semver-generator@v1
|
||||||
|
with:
|
||||||
|
repository_local: true
|
||||||
|
existing: true
|
||||||
|
config_file: .semver.yaml
|
||||||
|
|
||||||
|
- name: Read Bot API version
|
||||||
|
id: api_version
|
||||||
|
run: |
|
||||||
|
VERSION=$(python3 -c 'import json; print(json.load(open("internal/spec/api.json"))["version"])')
|
||||||
|
if [ -z "$VERSION" ] || [ "$VERSION" = "null" ]; then
|
||||||
|
echo "tag=" >> "$GITHUB_OUTPUT"
|
||||||
|
else
|
||||||
|
echo "tag=bot-api-v${VERSION}" >> "$GITHUB_OUTPUT"
|
||||||
|
fi
|
||||||
|
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
- name: Dry-run summary
|
||||||
|
if: github.event_name == 'workflow_dispatch' && inputs.dry-run-release == true
|
||||||
|
run: |
|
||||||
|
echo "Would release: v${{ steps.semver.outputs.semantic_version }}"
|
||||||
|
if [ -n "${{ steps.api_version.outputs.tag }}" ]; then
|
||||||
|
echo "Would also tag: ${{ steps.api_version.outputs.tag }} (Bot API ${{ steps.api_version.outputs.version }})"
|
||||||
|
fi
|
||||||
|
echo "Skipping tag + release (dry-run)."
|
||||||
|
|
||||||
|
- name: Tag library + bot-api versions
|
||||||
|
if: github.event_name != 'workflow_dispatch' || inputs.dry-run-release == false
|
||||||
|
env:
|
||||||
|
LIB_TAG: v${{ steps.semver.outputs.semantic_version }}
|
||||||
|
API_TAG: ${{ steps.api_version.outputs.tag }}
|
||||||
|
API_VER: ${{ steps.api_version.outputs.version }}
|
||||||
|
run: |
|
||||||
|
git config user.name "github-actions[bot]"
|
||||||
|
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||||
|
git tag -a "$LIB_TAG" -m "Release $LIB_TAG"
|
||||||
|
git push origin "$LIB_TAG"
|
||||||
|
|
||||||
|
if [ -n "$API_TAG" ]; then
|
||||||
|
# Force-update the bot-api tag so it always points at the latest
|
||||||
|
# release that supports that API version.
|
||||||
|
if git rev-parse "$API_TAG" >/dev/null 2>&1; then
|
||||||
|
git tag -f -a "$API_TAG" -m "go-telegram release $LIB_TAG (Bot API $API_VER)"
|
||||||
|
git push -f origin "$API_TAG"
|
||||||
|
else
|
||||||
|
git tag -a "$API_TAG" -m "go-telegram release $LIB_TAG (Bot API $API_VER)"
|
||||||
|
git push origin "$API_TAG"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Run GoReleaser
|
||||||
|
if: github.event_name != 'workflow_dispatch' || inputs.dry-run-release == false
|
||||||
|
uses: goreleaser/goreleaser-action@v6
|
||||||
|
with:
|
||||||
|
distribution: goreleaser
|
||||||
|
version: '~> v2'
|
||||||
|
args: release --clean
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
BOT_API_VERSION: ${{ steps.api_version.outputs.version }}
|
||||||
@@ -0,0 +1,118 @@
|
|||||||
|
name: regen
|
||||||
|
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: "0 6 * * 1" # Monday 06:00 UTC
|
||||||
|
workflow_dispatch: {}
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
regen:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0 # full history so audit -drift can compare against main
|
||||||
|
|
||||||
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.25.x'
|
||||||
|
check-latest: true
|
||||||
|
|
||||||
|
- name: Capture latest snapshot
|
||||||
|
id: snapshot
|
||||||
|
run: |
|
||||||
|
DATE=$(date +%Y-%m-%d)
|
||||||
|
DEST="testdata/html/snapshot_${DATE}.html"
|
||||||
|
curl -fsSL --user-agent "go-telegram codegen scraper" \
|
||||||
|
https://core.telegram.org/bots/api > "$DEST"
|
||||||
|
ln -sf "snapshot_${DATE}.html" testdata/html/latest.html
|
||||||
|
echo "date=$DATE" >> $GITHUB_OUTPUT
|
||||||
|
echo "dest=$DEST" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
- name: Regenerate (scrape + emit, with clean-generated)
|
||||||
|
# `make regen` depends on `clean-generated`, which sweeps any orphan
|
||||||
|
# api/*.gen.go files left behind by removed/renamed methods.
|
||||||
|
run: make regen
|
||||||
|
|
||||||
|
- name: Audit fallbacks
|
||||||
|
id: audit
|
||||||
|
run: |
|
||||||
|
set +e
|
||||||
|
OUT=$(make audit 2>&1)
|
||||||
|
STATUS=$?
|
||||||
|
set -e
|
||||||
|
echo "$OUT"
|
||||||
|
{
|
||||||
|
echo 'output<<EOF'
|
||||||
|
echo "$OUT"
|
||||||
|
echo 'EOF'
|
||||||
|
} >> $GITHUB_OUTPUT
|
||||||
|
echo "status=$STATUS" >> $GITHUB_OUTPUT
|
||||||
|
# Don't fail the workflow on fallbacks — surface them in the PR body so
|
||||||
|
# the reviewer can extend overrides.json or fix scraper patterns.
|
||||||
|
|
||||||
|
- name: Audit drift vs main
|
||||||
|
id: drift
|
||||||
|
run: |
|
||||||
|
set +e
|
||||||
|
DRIFT=$(go run ./cmd/audit -ir internal/spec/api.json -drift -against origin/main 2>&1)
|
||||||
|
set -e
|
||||||
|
echo "$DRIFT"
|
||||||
|
{
|
||||||
|
echo 'output<<EOF'
|
||||||
|
echo "$DRIFT"
|
||||||
|
echo 'EOF'
|
||||||
|
} >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: go test -race ./...
|
||||||
|
|
||||||
|
- name: Detect changes
|
||||||
|
id: diff
|
||||||
|
run: |
|
||||||
|
git status --porcelain
|
||||||
|
if git diff --quiet internal/spec/api.json api/ testdata/html/; then
|
||||||
|
echo "no_changes=true" >> $GITHUB_OUTPUT
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Read API version
|
||||||
|
if: steps.diff.outputs.no_changes != 'true'
|
||||||
|
id: meta
|
||||||
|
run: |
|
||||||
|
VERSION=$(python3 -c 'import json; print(json.load(open("internal/spec/api.json")).get("version", "unknown"))')
|
||||||
|
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
- name: Open PR
|
||||||
|
if: steps.diff.outputs.no_changes != 'true'
|
||||||
|
uses: peter-evans/create-pull-request@v7
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
commit-message: |
|
||||||
|
chore(api): regenerate from Telegram Bot API v${{ steps.meta.outputs.version }}
|
||||||
|
branch: regen/api-v${{ steps.meta.outputs.version }}
|
||||||
|
title: "chore(api): regenerate from Telegram Bot API v${{ steps.meta.outputs.version }}"
|
||||||
|
labels: automated, api-update
|
||||||
|
body: |
|
||||||
|
Automated regeneration from `https://core.telegram.org/bots/api`.
|
||||||
|
|
||||||
|
**API version:** v${{ steps.meta.outputs.version }}
|
||||||
|
**Snapshot date:** ${{ steps.snapshot.outputs.date }}
|
||||||
|
|
||||||
|
## Audit (fallbacks)
|
||||||
|
```
|
||||||
|
${{ steps.audit.outputs.output }}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Drift vs `main`
|
||||||
|
```
|
||||||
|
${{ steps.drift.outputs.output }}
|
||||||
|
```
|
||||||
|
|
||||||
|
Inspect the IR diff (`internal/spec/api.json`) for added/changed/removed methods.
|
||||||
|
|
||||||
|
CI must pass before merge. Auto-merge: enable when satisfied with the diff.
|
||||||
+45
@@ -0,0 +1,45 @@
|
|||||||
|
# Binaries
|
||||||
|
/bin/
|
||||||
|
*.exe
|
||||||
|
*.dll
|
||||||
|
*.so
|
||||||
|
*.dylib
|
||||||
|
|
||||||
|
# Test artifacts
|
||||||
|
*.test
|
||||||
|
*.out
|
||||||
|
coverage.out
|
||||||
|
coverage.html
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*~
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# Local env
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
|
||||||
|
# Stray binaries at repo root from `go build ./cmd/...` or `go run` artefacts.
|
||||||
|
# Listed explicitly (not via /* glob) so source dirs are never accidentally ignored.
|
||||||
|
/echo
|
||||||
|
/webhook
|
||||||
|
/genapi
|
||||||
|
/scrape
|
||||||
|
/audit
|
||||||
|
/callback
|
||||||
|
/conversation
|
||||||
|
/files
|
||||||
|
/inline
|
||||||
|
/middleware
|
||||||
|
/stateful
|
||||||
|
/admin
|
||||||
|
/moderation
|
||||||
|
/pagination
|
||||||
|
/payments
|
||||||
|
/polls
|
||||||
|
/welcome
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
version: "2"
|
||||||
|
|
||||||
|
linters:
|
||||||
|
default: none
|
||||||
|
enable:
|
||||||
|
- govet
|
||||||
|
- staticcheck
|
||||||
|
- unused
|
||||||
|
- errcheck
|
||||||
|
- gosec
|
||||||
|
|
||||||
|
settings:
|
||||||
|
unused:
|
||||||
|
field-writes-are-uses: true
|
||||||
|
post-statements-are-reads: true
|
||||||
|
exported-is-used: true
|
||||||
|
exported-fields-are-used: true
|
||||||
|
|
||||||
|
govet:
|
||||||
|
enable-all: true
|
||||||
|
disable:
|
||||||
|
- fieldalignment # Field order is intentional per project spec; alignment not enforced.
|
||||||
|
|
||||||
|
staticcheck:
|
||||||
|
checks: ["all"]
|
||||||
|
|
||||||
|
exclusions:
|
||||||
|
presets:
|
||||||
|
- common-false-positives
|
||||||
|
rules:
|
||||||
|
- path: _test\.go
|
||||||
|
linters:
|
||||||
|
- unused
|
||||||
|
- errcheck
|
||||||
|
- path: \.pb\.go$
|
||||||
|
linters:
|
||||||
|
- all
|
||||||
|
- path: transport/backoff\.go
|
||||||
|
linters:
|
||||||
|
- gosec
|
||||||
|
text: "G404" # math/rand/v2 acceptable for jitter
|
||||||
|
- path: cmd/scrape/
|
||||||
|
linters:
|
||||||
|
- gosec
|
||||||
|
text: "G306" # 0o644 intentional: output files are world-readable docs artifacts
|
||||||
|
- path: cmd/audit/
|
||||||
|
linters:
|
||||||
|
- gosec
|
||||||
|
text: "G204" # git subprocess with CLI flag path — intentional operator tool
|
||||||
|
- path: cmd/audit/.*_test\.go
|
||||||
|
linters:
|
||||||
|
- gosec
|
||||||
|
text: "G302" # test helper sets 0o755 on fake git binary — executable permission intentional
|
||||||
|
|
||||||
|
formatters:
|
||||||
|
enable:
|
||||||
|
- gofmt
|
||||||
|
settings:
|
||||||
|
gofmt:
|
||||||
|
simplify: true
|
||||||
|
|
||||||
|
run:
|
||||||
|
timeout: 5m
|
||||||
|
tests: true
|
||||||
|
modules-download-mode: readonly
|
||||||
|
|
||||||
|
output:
|
||||||
|
formats:
|
||||||
|
text:
|
||||||
|
path: stdout
|
||||||
|
colors: true
|
||||||
|
sort-results: true
|
||||||
@@ -0,0 +1,121 @@
|
|||||||
|
# Documentation: https://goreleaser.com
|
||||||
|
version: 2
|
||||||
|
|
||||||
|
project_name: go-telegram
|
||||||
|
|
||||||
|
before:
|
||||||
|
hooks:
|
||||||
|
- go mod tidy
|
||||||
|
|
||||||
|
builds:
|
||||||
|
- id: scrape
|
||||||
|
main: ./cmd/scrape
|
||||||
|
binary: tg-scrape
|
||||||
|
env:
|
||||||
|
- CGO_ENABLED=0
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
- darwin
|
||||||
|
- windows
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X main.version={{.Version}} -X main.commit={{.Commit}}
|
||||||
|
|
||||||
|
- id: genapi
|
||||||
|
main: ./cmd/genapi
|
||||||
|
binary: tg-genapi
|
||||||
|
env:
|
||||||
|
- CGO_ENABLED=0
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
- darwin
|
||||||
|
- windows
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X main.version={{.Version}} -X main.commit={{.Commit}}
|
||||||
|
|
||||||
|
- id: audit
|
||||||
|
main: ./cmd/audit
|
||||||
|
binary: tg-audit
|
||||||
|
env:
|
||||||
|
- CGO_ENABLED=0
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
- darwin
|
||||||
|
- windows
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X main.version={{.Version}} -X main.commit={{.Commit}}
|
||||||
|
|
||||||
|
archives:
|
||||||
|
- id: default
|
||||||
|
formats: [tar.gz]
|
||||||
|
name_template: >-
|
||||||
|
{{ .ProjectName }}_{{ .Version }}_
|
||||||
|
{{- title .Os }}_
|
||||||
|
{{- if eq .Arch "amd64" }}x86_64
|
||||||
|
{{- else if eq .Arch "386" }}i386
|
||||||
|
{{- else }}{{ .Arch }}{{ end }}
|
||||||
|
format_overrides:
|
||||||
|
- goos: windows
|
||||||
|
formats: [zip]
|
||||||
|
files:
|
||||||
|
- LICENSE
|
||||||
|
- README.md
|
||||||
|
- CHANGELOG.md
|
||||||
|
|
||||||
|
checksum:
|
||||||
|
name_template: 'checksums.txt'
|
||||||
|
|
||||||
|
snapshot:
|
||||||
|
version_template: '{{ incpatch .Version }}-next'
|
||||||
|
|
||||||
|
changelog:
|
||||||
|
sort: asc
|
||||||
|
use: github
|
||||||
|
groups:
|
||||||
|
- title: Features
|
||||||
|
regexp: '^.*?feat(\(.+\))?:.+$'
|
||||||
|
order: 0
|
||||||
|
- title: Bug fixes
|
||||||
|
regexp: '^.*?fix(\(.+\))?:.+$'
|
||||||
|
order: 1
|
||||||
|
- title: Documentation
|
||||||
|
regexp: '^.*?docs(\(.+\))?:.+$'
|
||||||
|
order: 2
|
||||||
|
- title: Other
|
||||||
|
order: 999
|
||||||
|
filters:
|
||||||
|
exclude:
|
||||||
|
- '^chore:'
|
||||||
|
- '^test:'
|
||||||
|
- '^style:'
|
||||||
|
- 'merge conflict'
|
||||||
|
- Merge pull request
|
||||||
|
- Merge remote-tracking branch
|
||||||
|
- Merge branch
|
||||||
|
- go mod tidy
|
||||||
|
|
||||||
|
release:
|
||||||
|
github:
|
||||||
|
owner: lukaszraczylo
|
||||||
|
name: go-telegram
|
||||||
|
prerelease: auto
|
||||||
|
mode: replace
|
||||||
|
header: |
|
||||||
|
## go-telegram {{ .Tag }}
|
||||||
|
|
||||||
|
Released on {{ .Date }} for Telegram Bot API regeneration tooling.
|
||||||
|
Built against **Bot API {{ envOrDefault "BOT_API_VERSION" "(unspecified)" }}**.
|
||||||
|
|
||||||
|
The library itself ships via `go get github.com/lukaszraczylo/go-telegram@{{ .Tag }}`.
|
||||||
|
Binaries below are the codegen tools (`tg-scrape`, `tg-genapi`, `tg-audit`)
|
||||||
|
for users who want to vendor regen tooling.
|
||||||
|
footer: |
|
||||||
|
**Full changelog**: https://github.com/lukaszraczylo/go-telegram/compare/{{ .PreviousTag }}...{{ .Tag }}
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
# Configuration for lukaszraczylo/semver-generator.
|
||||||
|
# Reference: https://github.com/lukaszraczylo/semver-generator
|
||||||
|
#
|
||||||
|
# Word matching is fuzzy + case-insensitive. The keywords below mirror
|
||||||
|
# Conventional Commits prefixes used in this repo's git history.
|
||||||
|
|
||||||
|
version: 1
|
||||||
|
|
||||||
|
# Respect existing v* tags as the version baseline. semver-generator finds
|
||||||
|
# the highest existing tag and bumps from there.
|
||||||
|
force:
|
||||||
|
existing: true
|
||||||
|
|
||||||
|
# Skip merge commits and machine-generated traffic that would otherwise
|
||||||
|
# spuriously bump the version.
|
||||||
|
blacklist:
|
||||||
|
- "Merge branch"
|
||||||
|
- "Merge pull request"
|
||||||
|
- "Merge remote-tracking branch"
|
||||||
|
- "go mod tidy"
|
||||||
|
|
||||||
|
# Strip the auto-generated bot-api-vX.Y tag prefix when scanning existing
|
||||||
|
# tags — those are markers that point at library releases, not version
|
||||||
|
# sources themselves.
|
||||||
|
tag_prefixes:
|
||||||
|
- "bot-api-"
|
||||||
|
|
||||||
|
wording:
|
||||||
|
patch:
|
||||||
|
- "fix"
|
||||||
|
- "chore"
|
||||||
|
- "docs"
|
||||||
|
- "test"
|
||||||
|
- "style"
|
||||||
|
- "refactor"
|
||||||
|
- "build"
|
||||||
|
- "ci"
|
||||||
|
- "perf"
|
||||||
|
minor:
|
||||||
|
- "feat"
|
||||||
|
major:
|
||||||
|
- "breaking"
|
||||||
|
- "BREAKING CHANGE"
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
# Changelog
|
||||||
|
|
||||||
|
All notable changes to this project will be documented in this file. The format follows [Keep a Changelog](https://keepachangelog.com/) and this project adheres to [Semantic Versioning](https://semver.org/).
|
||||||
|
|
||||||
|
## [Unreleased]
|
||||||
|
|
||||||
|
## [1.0.0] - 2026-05-09
|
||||||
|
|
||||||
|
Initial public release. Built against Telegram Bot API v10.0 (176 methods, 301 types).
|
||||||
|
|
||||||
|
### Library surface
|
||||||
|
- `client.Bot` with pluggable `HTTPDoer`, `Codec` (default `github.com/goccy/go-json`), and `Logger` interfaces.
|
||||||
|
- Generic `client.Call[Req, Resp]` and `client.CallRaw[Req]` helpers — every API method funnels through one of these.
|
||||||
|
- `client.RetryDoer` middleware — exponential backoff with jitter; honours Telegram's `retry_after`; replays request bodies across attempts. Configurable via `RetryOption` functions (`WithMaxAttempts`, `WithBase`, `WithMax`, `WithFactor`, `WithJitter`).
|
||||||
|
- Typed errors: `*APIError` (sentinel-mapped: `ErrUnauthorized`, `ErrForbidden`, `ErrTooManyRequests`, `ErrChatNotFound`, `ErrMessageNotModified`, `ErrBadRequest`, `ErrUserNotFound`, `ErrMessageNotFound`), `*NetworkError`, `*ParseError`. Bot tokens redacted from error messages.
|
||||||
|
|
||||||
|
### Update delivery
|
||||||
|
- `transport.LongPoller` — getUpdates loop with `ExponentialBackoff`, retry-after honouring, graceful shutdown via `Stop`.
|
||||||
|
- `transport.WebhookServer` — `http.Handler` + `ListenAndServe`; secret-token verification (constant-time); 1 MiB body cap; in-flight handler tracking via `WaitGroup`.
|
||||||
|
|
||||||
|
### Dispatcher
|
||||||
|
- Generic `dispatch.Handler[T]` and `dispatch.Middleware[T]`.
|
||||||
|
- 21 typed handler registrations: `OnCommand`, `OnText`, `OnCallback`, `OnInlineQuery`, `OnEditedMessage`, `OnChannelPost`, `OnEditedChannelPost`, `OnMyChatMember`, `OnChatMember`, `OnChatJoinRequest`, `OnPreCheckoutQuery`, `OnShippingQuery`, `OnPoll`, `OnPollAnswer`, `OnChosenInlineResult`, `OnMessageReaction`, `OnMessageReactionCount`, `OnChatBoost`, `OnRemovedChatBoost`, `OnBusinessConnection`, `OnPurchasedPaidMedia`. Filter variants available for message, callback, inline-query, chat-member, chat-join-request, and pre-checkout-query types.
|
||||||
|
- Composable filters in `dispatch/filters/{message,callback,inline,chatmember,chatjoinrequest,precheckoutquery}` packages with `And`/`Or`/`Not`/`All`/`Any` combinators and 20+ filter helpers.
|
||||||
|
- `dispatch/conversation` — multi-step state machines with `Storage` interface (`MemoryStorage` default), pluggable `KeyStrategy` (`KeyByUser`, `KeyByChat`, `KeyByUserAndChat`), entry/state/exit/fallback steps, `AllowReEntry`. `Next(state)` and `End()` sentinel errors drive transitions. `Conversation.Dispatch` integrates as middleware via `router.Use`.
|
||||||
|
- Per-update goroutine pool (default 50; configurable via `WithMaxConcurrency`). Pass `0` for serial dispatch.
|
||||||
|
- Handler groups with `EndGroups`/`ContinueGroups` flow control.
|
||||||
|
- `NamedHandlers[T]` for runtime registration and replacement.
|
||||||
|
- Panic-recovery middleware + automatic handler-error logging registered by default.
|
||||||
|
|
||||||
|
### Generated API
|
||||||
|
- Full Bot API v10.0 surface in `api/*.gen.go` — 176 methods, 301 types, regenerated from a committed HTML snapshot of `core.telegram.org/bots/api`.
|
||||||
|
- Strongly typed unions: `ChatID` (Integer-or-String), `*InputFile` (InputFile-or-String), `MessageOrBool` (Message-or-True returns).
|
||||||
|
- 13 discriminated unions with auto-decode: `ChatMember`, `MessageOrigin`, `ReactionType`, `PaidMedia`, `BackgroundType`, `BackgroundFill`, `ChatBoostSource`, `RevenueWithdrawalState`, `TransactionPartner`, `MenuButton`, `OwnedGift`, `StoryAreaType`, `MaybeInaccessibleMessage`. Parent structs containing union-typed fields auto-decode via generated `UnmarshalJSON`.
|
||||||
|
- Sealed-interface return types use `client.CallRaw` + `Unmarshal<Name>` dispatch (e.g., `GetChatMember` returns `ChatMember` interface, decoded into the correct concrete variant).
|
||||||
|
|
||||||
|
### Runtime helpers
|
||||||
|
- `api.MeCache` — concurrent-safe `getMe` cache; zero-value safe.
|
||||||
|
- `api.DownloadFile` / `api.DownloadFileByPath` — fetch file contents from the Telegram CDN.
|
||||||
|
- `(*Message).GetSender()` — unifies `From`, `SenderChat`, and anonymous-admin fields into a `*Sender`.
|
||||||
|
|
||||||
|
### Codegen pipeline (`cmd/scrape`, `cmd/genapi`, `cmd/audit`)
|
||||||
|
- Two-stage codegen: HTML → IR (`internal/spec/api.json`) → Go.
|
||||||
|
- `internal/spec/overrides.json` pins specific method returns/field types when scraper regex does not match a particular doc phrasing.
|
||||||
|
- `cmd/audit` reports any-typed fields, fallback `bool` returns, and signature drift vs HEAD's IR. Exit codes: 0 clean, 1 fallback, 2 drift, 3 invalid.
|
||||||
|
- `make regen` is self-verifying: clean → scrape → audit → emit (code + 1428 tests) → run tests.
|
||||||
|
- Auto-generated tests cover 8 scenarios per method (1428 total): Success, APIError 429, NetworkError, ParseError, ContextCanceled, MissingRequiredFields (400 + ErrBadRequest), Forbidden (403 + ErrForbidden), ServerError (500 + IsRetryable).
|
||||||
|
|
||||||
|
### Tooling
|
||||||
|
- `.github/workflows/ci.yml` — Go matrix 1.23/1.24, vet, staticcheck, govulncheck, gosec, race tests, codegen-clean check, audit + drift detection.
|
||||||
|
- `.github/workflows/regen.yml` — weekly cron + workflow_dispatch; scrapes live API, regenerates, runs tests, opens auto-PR with audit summary in body.
|
||||||
|
- `.github/workflows/release.yml` — semver-generator-driven version bump; dual tagging (library SemVer + `bot-api-v<X.Y>` marker); GoReleaser.
|
||||||
|
- `.goreleaser.yaml` — ships `tg-scrape`, `tg-genapi`, `tg-audit` binaries for Linux/macOS/Windows × amd64/arm64.
|
||||||
|
|
||||||
|
### Examples (14)
|
||||||
|
echo, webhook, callback, conversation, files, inline, middleware, stateful, welcome, moderation, polls, payments, pagination, admin.
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
# Contributing
|
||||||
|
|
||||||
|
Thanks for your interest in go-telegram. The library mixes hand-written and generated code; this guide explains how to update each.
|
||||||
|
|
||||||
|
## Project layout
|
||||||
|
|
||||||
|
- **`client/`** — hand-written Bot client, generic Call helper, error taxonomy, retry middleware. Stable; rarely changes.
|
||||||
|
- **`transport/`** — long-poll and webhook updaters. Hand-written.
|
||||||
|
- **`dispatch/`** — typed router with command/text/callback matchers. Hand-written.
|
||||||
|
- **`api/`** — generated types and method wrappers (`*.gen.go`) plus runtime helpers (`runtime.go`, `download.go`, `me.go`).
|
||||||
|
- **`internal/spec/`** — IR types + the committed `api.json` snapshot of the Telegram Bot API.
|
||||||
|
- **`cmd/scrape/`** — HTML scraper that produces `internal/spec/api.json`.
|
||||||
|
- **`cmd/genapi/`** — emitter that consumes `api.json` and renders `api/*.gen.go`.
|
||||||
|
|
||||||
|
## Workflows
|
||||||
|
|
||||||
|
### Updating to a newer Telegram Bot API version
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make snapshot # fetch latest HTML from core.telegram.org
|
||||||
|
make regen # scrape + emit
|
||||||
|
go test -race ./... # verify
|
||||||
|
```
|
||||||
|
|
||||||
|
If the live page introduces a phrasing the scraper doesn't recognise, you'll see methods falling back to `bool` returns or struct fields typed `any`. Check the audit script in `cmd/scrape/method_test.go` and add new patterns to `cmd/scrape/method.go` and / or `cmd/scrape/table.go`. Then `go test -update ./cmd/scrape/...` to refresh the small-fixture golden, and re-run `make regen`.
|
||||||
|
|
||||||
|
### Adding a new union for auto-decode
|
||||||
|
|
||||||
|
If Telegram introduces a new discriminated union type (similar to `ChatMember`):
|
||||||
|
|
||||||
|
1. Add an entry to `knownDiscriminators` in `cmd/genapi/emitter.go`.
|
||||||
|
2. Run `make regen`.
|
||||||
|
3. The emitter will produce `UnmarshalXxx` for the union and per-struct `UnmarshalJSON` for any field referencing it.
|
||||||
|
|
||||||
|
### Updating runtime helpers
|
||||||
|
|
||||||
|
Edit `api/runtime.go`, `api/download.go`, `api/me.go`, or any of `client/*.go`, `transport/*.go`, `dispatch/*.go`. Add tests for new functionality. CI runs `go test -race ./...`, `go vet`, `staticcheck`, and the codegen-clean check (which asserts the committed `api/` matches what the pipeline produces from the committed snapshot).
|
||||||
|
|
||||||
|
### Conventions
|
||||||
|
|
||||||
|
- Doc comments on every exported symbol — generated types carry verbatim Telegram prose.
|
||||||
|
- No `//nolint` directives anywhere; if the linter complains, fix the code or update `.golangci.yml`.
|
||||||
|
- No reordering of struct fields for `fieldalignment` — JSON field order tracks the spec for diff readability.
|
||||||
|
- TDD where practical: failing test, then implementation, then commit.
|
||||||
|
- Conventional Commits style for messages: `feat(...):`, `fix(...):`, `docs(...):`, `chore(...):`, `refactor(...):`, `test(...):`.
|
||||||
|
|
||||||
|
## Running locally
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make test # unit tests
|
||||||
|
make test-race # with race detector (CI default)
|
||||||
|
make lint # vet + staticcheck
|
||||||
|
make integration # live API smoke tests (requires TELEGRAM_BOT_TOKEN)
|
||||||
|
```
|
||||||
|
|
||||||
|
The codegen tooling in `cmd/scrape` pulls `golang.org/x/net/html`. The runtime library packages depend only on the standard library plus `stretchr/testify` (test-only).
|
||||||
|
|
||||||
|
## Releasing
|
||||||
|
|
||||||
|
1. Bump version in `CHANGELOG.md`.
|
||||||
|
2. Tag with `git tag -a v0.x.0 -m "summary"` (no leading 'v' alone — use the full SemVer triple).
|
||||||
|
3. `git push --tags`.
|
||||||
|
|
||||||
|
(There is no GoReleaser config yet; releases are tag-only and `go install` works against tags.)
|
||||||
|
|
||||||
|
## Reporting issues
|
||||||
|
|
||||||
|
File issues on the GitHub repository with:
|
||||||
|
- The Telegram Bot API method involved (if applicable).
|
||||||
|
- A minimal reproduction (mocked HTTP transport is fine).
|
||||||
|
- The library version (`go list -m github.com/lukaszraczylo/go-telegram`).
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2026 Lukasz Raczylo
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
@@ -0,0 +1,76 @@
|
|||||||
|
.PHONY: test test-race lint vet integration regen snapshot regen-from-fixture test-update-golden clean clean-generated audit audit-drift help
|
||||||
|
|
||||||
|
GO ?= go
|
||||||
|
|
||||||
|
help:
|
||||||
|
@echo "Targets:"
|
||||||
|
@echo " test - run unit tests"
|
||||||
|
@echo " test-race - run unit tests with race detector"
|
||||||
|
@echo " lint - go vet + staticcheck"
|
||||||
|
@echo " integration - run integration suite (requires TELEGRAM_BOT_TOKEN)"
|
||||||
|
@echo " snapshot - capture HTML snapshot from live API (Plan 2)"
|
||||||
|
@echo " regen - regenerate api/ from latest snapshot (Plan 2)"
|
||||||
|
@echo " regen-from-fixture - deterministic regen from pinned fixture (Plan 2)"
|
||||||
|
@echo " test-update-golden - refresh golden test fixtures (Plan 2)"
|
||||||
|
@echo " audit - report any-typed/bool fallbacks in current IR"
|
||||||
|
@echo " audit-drift - audit + compare against HEAD's IR for signature changes"
|
||||||
|
@echo " clean-generated - delete generated api/*.gen.go and internal/spec/api.json"
|
||||||
|
@echo " clean - clean-generated + transient artefacts (binaries, coverage)"
|
||||||
|
|
||||||
|
test:
|
||||||
|
$(GO) test ./...
|
||||||
|
|
||||||
|
test-race:
|
||||||
|
$(GO) test -race ./...
|
||||||
|
|
||||||
|
vet:
|
||||||
|
$(GO) vet ./...
|
||||||
|
|
||||||
|
lint: vet
|
||||||
|
@which staticcheck > /dev/null || (echo "install staticcheck: go install honnef.co/go/tools/cmd/staticcheck@latest" && exit 1)
|
||||||
|
staticcheck ./...
|
||||||
|
|
||||||
|
integration:
|
||||||
|
$(GO) test -tags=integration -v ./test/integration/...
|
||||||
|
|
||||||
|
SCRAPE_INPUT ?= testdata/html/snapshot_2026-05-08.html
|
||||||
|
SCRAPE_OUTPUT ?= internal/spec/api.json
|
||||||
|
|
||||||
|
snapshot:
|
||||||
|
./scripts/snapshot.sh
|
||||||
|
|
||||||
|
regen: clean-generated
|
||||||
|
$(GO) run ./cmd/scrape -input testdata/html/latest.html -output $(SCRAPE_OUTPUT)
|
||||||
|
$(GO) run ./cmd/audit -ir $(SCRAPE_OUTPUT)
|
||||||
|
$(GO) run ./cmd/genapi -input $(SCRAPE_OUTPUT) -outdir api
|
||||||
|
$(GO) test ./api/...
|
||||||
|
|
||||||
|
regen-from-fixture: clean-generated
|
||||||
|
$(GO) run ./cmd/scrape -input $(SCRAPE_INPUT) -output $(SCRAPE_OUTPUT)
|
||||||
|
$(GO) run ./cmd/audit -ir $(SCRAPE_OUTPUT)
|
||||||
|
$(GO) run ./cmd/genapi -input $(SCRAPE_OUTPUT) -outdir api
|
||||||
|
$(GO) test ./api/...
|
||||||
|
|
||||||
|
audit:
|
||||||
|
$(GO) run ./cmd/audit -ir $(SCRAPE_OUTPUT)
|
||||||
|
|
||||||
|
audit-drift:
|
||||||
|
$(GO) run ./cmd/audit -ir $(SCRAPE_OUTPUT) -drift
|
||||||
|
|
||||||
|
test-update-golden:
|
||||||
|
$(GO) test -run TestEmit -update ./cmd/genapi/...
|
||||||
|
$(GO) test -run TestScrape -update ./cmd/scrape/...
|
||||||
|
|
||||||
|
# clean-generated removes ONLY codegen output. Source code (cmd/scrape,
|
||||||
|
# cmd/genapi, runtime helpers) is untouched. Run before regen to avoid
|
||||||
|
# orphan files lingering when the IR shrinks (renamed/removed methods).
|
||||||
|
clean-generated:
|
||||||
|
rm -f api/*.gen.go api/*_gen_test.go
|
||||||
|
rm -f internal/spec/api.json
|
||||||
|
|
||||||
|
# clean removes generated output AND transient artefacts (binaries
|
||||||
|
# accidentally left at repo root, coverage reports). Source code is
|
||||||
|
# never touched.
|
||||||
|
clean: clean-generated
|
||||||
|
rm -f coverage.out coverage.html
|
||||||
|
rm -f echo webhook genapi scrape callback files inline conversation middleware stateful
|
||||||
@@ -0,0 +1,333 @@
|
|||||||
|
# go-telegram
|
||||||
|
|
||||||
|
> A fully-generated, strongly-typed Go client for the Telegram Bot API — no `any`, no guessing.
|
||||||
|
|
||||||
|
[](https://github.com/lukaszraczylo/go-telegram/actions/workflows/ci.yml)
|
||||||
|
[](https://pkg.go.dev/github.com/lukaszraczylo/go-telegram)
|
||||||
|
[](go.mod)
|
||||||
|
[](LICENSE)
|
||||||
|
|
||||||
|
> Bot API **v10.0** · 176 methods · 301 types · 1428 auto-generated tests
|
||||||
|
|
||||||
|
Most Telegram bot libraries expose Telegram's "Integer or String" fields as `interface{}` or `any`. Every union type in go-telegram is a real Go type with compile-time safety and auto-decoding. The entire API surface is code-generated from a committed HTML snapshot of the live Telegram docs — regenerating picks up new Bot API versions in one command, with a self-verifying pipeline that catches regressions before they ship.
|
||||||
|
|
||||||
|
```go
|
||||||
|
bot := client.New(os.Getenv("TELEGRAM_BOT_TOKEN"),
|
||||||
|
client.WithHTTPClient(client.NewRetryDoer(client.NewDefaultHTTPDoer())),
|
||||||
|
)
|
||||||
|
|
||||||
|
router := dispatch.New(bot)
|
||||||
|
router.OnCommand("/start", func(c *dispatch.Context, m *api.Message) error {
|
||||||
|
_, err := api.SendMessage(c.Ctx, c.Bot, &api.SendMessageParams{
|
||||||
|
ChatID: api.ChatIDFromInt(m.Chat.ID),
|
||||||
|
Text: "Hello! Send me anything to echo.",
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
router.OnText(`.+`, func(c *dispatch.Context, m *api.Message) error {
|
||||||
|
_, err := api.SendMessage(c.Ctx, c.Bot, &api.SendMessageParams{
|
||||||
|
ChatID: api.ChatIDFromInt(m.Chat.ID),
|
||||||
|
Text: m.Text,
|
||||||
|
ReplyParameters: &api.ReplyParameters{MessageID: m.MessageID},
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||||
|
defer stop()
|
||||||
|
router.Run(ctx, transport.NewLongPoller(bot))
|
||||||
|
```
|
||||||
|
|
||||||
|
## Why go-telegram
|
||||||
|
|
||||||
|
| Feature | What it means for you |
|
||||||
|
|---|---|
|
||||||
|
| **Typed unions** | `ChatID`, `MessageOrBool`, `InputFile`, and 13 discriminated-union interfaces give you `switch v.(type)` instead of runtime panics |
|
||||||
|
| **Full Bot API v10.0** | 176 methods and 301 types — all generated, none hand-written, nothing missing |
|
||||||
|
| **Self-verifying codegen** | `make snapshot && make regen` regenerates everything and runs 1428 tests; any regression fails the pipeline |
|
||||||
|
| **Pluggable transport + codec** | `HTTPDoer` and `Codec` are one-method interfaces — swap in fasthttp, sonic, or your test fake without forking |
|
||||||
|
| **Retry middleware** | `RetryDoer` honours Telegram's `retry_after`, backs off on 5xx, replays request bodies |
|
||||||
|
| **Composable dispatcher** | Per-update goroutine pool (default 50), filter combinators (`And`/`Or`/`Not`), conversation state machines, named handlers |
|
||||||
|
|
||||||
|
## Quickstart
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get github.com/lukaszraczylo/go-telegram
|
||||||
|
```
|
||||||
|
|
||||||
|
Full echo bot — long-poll, graceful shutdown, retry on 429:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/client"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/transport"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
token := os.Getenv("TELEGRAM_BOT_TOKEN")
|
||||||
|
if token == "" {
|
||||||
|
log.Fatal("TELEGRAM_BOT_TOKEN required")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
bot := client.New(token,
|
||||||
|
client.WithHTTPClient(client.NewRetryDoer(client.NewDefaultHTTPDoer())),
|
||||||
|
)
|
||||||
|
|
||||||
|
router := dispatch.New(bot)
|
||||||
|
router.OnCommand("/start", func(c *dispatch.Context, m *api.Message) error {
|
||||||
|
_, err := api.SendMessage(c.Ctx, c.Bot, &api.SendMessageParams{
|
||||||
|
ChatID: api.ChatIDFromInt(m.Chat.ID),
|
||||||
|
Text: fmt.Sprintf("Hello %s! Send me anything.", m.From.FirstName),
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
router.OnText(`.+`, func(c *dispatch.Context, m *api.Message) error {
|
||||||
|
_, err := api.SendMessage(c.Ctx, c.Bot, &api.SendMessageParams{
|
||||||
|
ChatID: api.ChatIDFromInt(m.Chat.ID),
|
||||||
|
Text: m.Text,
|
||||||
|
ReplyParameters: &api.ReplyParameters{MessageID: m.MessageID},
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := router.Run(ctx, transport.NewLongPoller(bot)); err != nil && err != context.Canceled {
|
||||||
|
log.Printf("router exited: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
Run any example: `TELEGRAM_BOT_TOKEN=xxx go run ./examples/<name>`
|
||||||
|
|
||||||
|
| Category | Example | What it shows |
|
||||||
|
|---|---|---|
|
||||||
|
| **Basics** | [`echo`](examples/echo) | Long-poll echo bot |
|
||||||
|
| | [`webhook`](examples/webhook) | Webhook server with secret-token verification |
|
||||||
|
| | [`files`](examples/files) | Upload and download cycle |
|
||||||
|
| | [`inline`](examples/inline) | Inline-mode results |
|
||||||
|
| **Conversations & state** | [`conversation`](examples/conversation) | Multi-step state machine with `/cancel` exit |
|
||||||
|
| | [`stateful`](examples/stateful) | Per-user state via closures |
|
||||||
|
| | [`callback`](examples/callback) | Inline keyboards and callback query handling |
|
||||||
|
| | [`pagination`](examples/pagination) | Multi-page inline keyboard |
|
||||||
|
| **Group management** | [`welcome`](examples/welcome) | Greet new chat members |
|
||||||
|
| | [`moderation`](examples/moderation) | Kick/ban/mute/warn with permission checks |
|
||||||
|
| | [`admin`](examples/admin) | Auth middleware allowlist |
|
||||||
|
| **Advanced** | [`middleware`](examples/middleware) | `Use` chains |
|
||||||
|
| | [`polls`](examples/polls) | `sendPoll` and answer tally |
|
||||||
|
| | [`payments`](examples/payments) | Invoice → pre-checkout → success |
|
||||||
|
|
||||||
|
## Concepts
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Bot client and pluggable transport</summary>
|
||||||
|
|
||||||
|
`client.New` accepts functional options:
|
||||||
|
|
||||||
|
```go
|
||||||
|
bot := client.New(token,
|
||||||
|
client.WithHTTPClient(doer), // any HTTPDoer (one-method interface)
|
||||||
|
client.WithCodec(myCodec), // any Codec (Marshal + Unmarshal)
|
||||||
|
client.WithLogger(myLogger),
|
||||||
|
client.WithBaseURL("https://..."), // proxy or local Bot API server
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
`HTTPDoer` is `Do(*http.Request) (*http.Response, error)` — a plain `*http.Client` satisfies it.
|
||||||
|
`Codec` is `Marshal(any) ([]byte, error)` + `Unmarshal([]byte, any) error` — the default wraps `goccy/go-json`.
|
||||||
|
|
||||||
|
Every API call goes through `client.Call[Req, Resp]`; per-method generated functions are thin wrappers.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Typed unions — no any</summary>
|
||||||
|
|
||||||
|
Telegram's docs describe many fields as "Integer or String" or "one of N types". go-telegram turns every one of these into a concrete Go type.
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ChatID: construct from int64 or @username
|
||||||
|
chatID := api.ChatIDFromInt(123456789)
|
||||||
|
chatID := api.ChatIDFromString("@mychannel")
|
||||||
|
|
||||||
|
// Discriminated unions — 13 interfaces with auto-decode via generated UnmarshalJSON
|
||||||
|
for _, u := range updates {
|
||||||
|
if u.MyChatMember == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch v := u.MyChatMember.OldChatMember.(type) {
|
||||||
|
case *api.ChatMemberOwner:
|
||||||
|
log.Println("was owner")
|
||||||
|
case *api.ChatMemberAdministrator:
|
||||||
|
log.Printf("was admin: can_post=%v", v.CanPostMessages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Full union list: `ChatMember`, `MessageOrigin`, `ReactionType`, `PaidMedia`, `BackgroundType`, `BackgroundFill`, `ChatBoostSource`, `RevenueWithdrawalState`, `TransactionPartner`, `MenuButton`, `OwnedGift`, `StoryAreaType`, `MaybeInaccessibleMessage`, plus `ChatID`, `MessageOrBool`, and `InputFile`.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Dispatcher, filters, and conversations</summary>
|
||||||
|
|
||||||
|
The router dispatches each update in its own goroutine (semaphore-bounded, default 50):
|
||||||
|
|
||||||
|
```go
|
||||||
|
r := dispatch.New(bot, dispatch.WithMaxConcurrency(50))
|
||||||
|
|
||||||
|
r.OnCommand("/start", handler)
|
||||||
|
r.OnText(`^hi (\w+)`, handler)
|
||||||
|
r.OnCallback(`^like:\d+`, handler)
|
||||||
|
r.OnInlineQuery(handler)
|
||||||
|
r.OnMyChatMember(handler)
|
||||||
|
// + 20 more typed On* methods
|
||||||
|
```
|
||||||
|
|
||||||
|
**Composable filters** — each update type has its own filter package:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/lukaszraczylo/go-telegram/dispatch/filters/message"
|
||||||
|
|
||||||
|
r.OnMessageFilter(
|
||||||
|
message.Command("/admin").And(message.IsReply()),
|
||||||
|
handler,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Filter packages: `message`, `callback`, `inline`, `chatmember`, `chatjoinrequest`, `precheckoutquery`. Combinators: `And`, `Or`, `Not`, `All`, `Any`.
|
||||||
|
|
||||||
|
**Conversation state machines** — multi-step flows with pluggable storage:
|
||||||
|
|
||||||
|
```go
|
||||||
|
conv := &conversation.Conversation{
|
||||||
|
EntryPoints: []conversation.Step{{
|
||||||
|
Filter: dispatch.FilterFunc(func(c *dispatch.Context, u *api.Update) bool {
|
||||||
|
return u.Message != nil && u.Message.Text == "/start"
|
||||||
|
}),
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||||
|
// send prompt, advance state
|
||||||
|
return conversation.Next("await_name")
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
States: map[conversation.State][]conversation.Step{
|
||||||
|
"await_name": {{
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||||
|
return conversation.End()
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router.Use(conv.Dispatch)
|
||||||
|
```
|
||||||
|
|
||||||
|
Key strategies: `KeyByUser`, `KeyByChat`, `KeyByUserAndChat` (default). Default storage: `MemoryStorage` (in-process, concurrency-safe). Implement the `Storage` interface for Redis or any other backend.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Errors and retry middleware</summary>
|
||||||
|
|
||||||
|
Wrap the default HTTP doer with `RetryDoer` for production:
|
||||||
|
|
||||||
|
```go
|
||||||
|
bot := client.New(token,
|
||||||
|
client.WithHTTPClient(
|
||||||
|
client.NewRetryDoer(
|
||||||
|
client.NewDefaultHTTPDoer(),
|
||||||
|
client.WithMaxAttempts(5),
|
||||||
|
client.WithBaseBackoff(500*time.Millisecond),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
`RetryDoer` retries on 429, 5xx, and transient network errors. On a 429 it reads `retry_after` from Telegram's response body and waits exactly that long — overriding any backoff calculation. Request bodies are buffered and replayed across attempts.
|
||||||
|
|
||||||
|
Sentinel errors for `errors.Is` checks: `client.ErrForbidden`, `client.ErrNotFound`, `client.ErrUnauthorized`.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Handler groups and named handlers</summary>
|
||||||
|
|
||||||
|
Priority-ordered groups with flow control signals:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Group 0 runs first — return EndGroups to stop, ContinueGroups to continue
|
||||||
|
r.Group(0).OnText(`.*`, authMiddleware)
|
||||||
|
r.Group(1).OnText(`.*`, businessHandler)
|
||||||
|
```
|
||||||
|
|
||||||
|
Named handlers — register and replace at runtime:
|
||||||
|
|
||||||
|
```go
|
||||||
|
named := dispatch.NewNamedHandlers[*api.Message]()
|
||||||
|
named.Set("main", myHandler)
|
||||||
|
r.OnCommand("/cmd", named.Handler())
|
||||||
|
// later: named.Set("main", updatedHandler)
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Codegen pipeline
|
||||||
|
|
||||||
|
The full API surface in `api/*.gen.go` is generated from a committed HTML snapshot of `core.telegram.org/bots/api`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make snapshot # fetch and commit latest HTML from core.telegram.org
|
||||||
|
make regen # scrape → audit → emit Go code → run generated tests
|
||||||
|
go test -race ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
`make regen` is self-verifying. The audit tool (`cmd/audit`) checks:
|
||||||
|
|
||||||
|
- `any`-typed fields or returns that escaped the union machinery
|
||||||
|
- Methods returning `bool` not on the approved list (`internal/spec/overrides.json`)
|
||||||
|
- Signature drift vs HEAD's IR (added/removed/changed return types)
|
||||||
|
|
||||||
|
Exit codes: 0 clean · 1 fallback · 2 drift · 3 invalid. CI runs the audit on every PR. A weekly `regen.yml` workflow opens a PR with regenerated code and the audit summary in the body.
|
||||||
|
|
||||||
|
To track a new Bot API release: run `make snapshot && make regen`, review the audit output, update `internal/spec/overrides.json` for any newly unparseable methods, and submit a PR.
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
Mock the one-method `HTTPDoer` interface to test handlers in isolation — no test server needed:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type fakeDoer struct{ body string }
|
||||||
|
func (f fakeDoer) Do(*http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: io.NopCloser(strings.NewReader(f.body)),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
bot := client.New("token", client.WithHTTPClient(fakeDoer{
|
||||||
|
body: `{"ok":true,"result":{"message_id":1,"date":0,"chat":{"id":1,"type":"private"}}}`,
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
The library's own generated test suite (`api/methods_gen_test.go`) covers 176 methods × 8 scenarios each: Success, APIError, NetworkError, ParseError, ContextCanceled, MissingRequiredFields, Forbidden, ServerError.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
See [CONTRIBUTING.md](CONTRIBUTING.md).
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DownloadFile fetches the contents of a Telegram-hosted file given a
|
||||||
|
// previously-uploaded file_id. It calls GetFile to resolve the file's
|
||||||
|
// download path, then issues an HTTP GET to the file CDN endpoint.
|
||||||
|
//
|
||||||
|
// The returned io.ReadCloser must be closed by the caller. The size of
|
||||||
|
// the file is reported via *File.FileSize when known.
|
||||||
|
//
|
||||||
|
// For files larger than 20 MB, Telegram requires a self-hosted Bot API
|
||||||
|
// server (default api.telegram.org has a 20 MB limit on getFile).
|
||||||
|
func DownloadFile(ctx context.Context, b *client.Bot, fileID string) (io.ReadCloser, *File, error) {
|
||||||
|
f, err := GetFile(ctx, b, &GetFileParams{FileID: fileID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("getFile: %w", err)
|
||||||
|
}
|
||||||
|
if f == nil || f.FilePath == "" {
|
||||||
|
return nil, f, fmt.Errorf("telegram: file %q has no download path", fileID)
|
||||||
|
}
|
||||||
|
rc, err := DownloadFileByPath(ctx, b, f.FilePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, f, err
|
||||||
|
}
|
||||||
|
return rc, f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DownloadFileByPath fetches a file by its file_path (typically obtained
|
||||||
|
// from a prior File response). Useful when the caller already has a
|
||||||
|
// *File and wants to skip the GetFile round-trip.
|
||||||
|
func DownloadFileByPath(ctx context.Context, b *client.Bot, filePath string) (io.ReadCloser, error) {
|
||||||
|
url := fmt.Sprintf("%s/file/bot%s/%s", b.BaseURL(), b.Token(), filePath)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := b.HTTP().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||||
|
return nil, ctxErr
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("download: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
return nil, fmt.Errorf("download: HTTP %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return resp.Body, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/client"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDownloadFile_HappyPath(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case strings.HasSuffix(r.URL.Path, "/getFile"):
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"ok":true,"result":{"file_id":"abc","file_unique_id":"u","file_size":11,"file_path":"documents/hello.txt"}}`))
|
||||||
|
case strings.HasPrefix(r.URL.Path, "/file/bot"):
|
||||||
|
_, _ = w.Write([]byte("hello world"))
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
bot := client.New("123:abc", client.WithBaseURL(srv.URL))
|
||||||
|
rc, file, err := DownloadFile(context.Background(), bot, "abc")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rc.Close()
|
||||||
|
require.Equal(t, "documents/hello.txt", file.FilePath)
|
||||||
|
body, err := io.ReadAll(rc)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "hello world", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadFile_GetFileFailure(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"ok":false,"error_code":400,"description":"Bad Request: invalid file_id"}`))
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
bot := client.New("t", client.WithBaseURL(srv.URL))
|
||||||
|
_, _, err := DownloadFile(context.Background(), bot, "bad")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "getFile")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadFile_NoFilePath(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
// result without file_path
|
||||||
|
_, _ = w.Write([]byte(`{"ok":true,"result":{"file_id":"abc","file_unique_id":"u"}}`))
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
bot := client.New("t", client.WithBaseURL(srv.URL))
|
||||||
|
_, _, err := DownloadFile(context.Background(), bot, "abc")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "no download path")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadFileByPath_HTTPError(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.HasPrefix(r.URL.Path, "/file/bot") {
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
bot := client.New("t", client.WithBaseURL(srv.URL))
|
||||||
|
_, err := DownloadFileByPath(context.Background(), bot, "secret/file")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "403")
|
||||||
|
}
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
// Code generated by cmd/genapi. DO NOT EDIT.
|
||||||
|
|
||||||
|
//go:build !ignore_autogenerated
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
// ParseMode controls how Telegram interprets formatting in message text.
|
||||||
|
type ParseMode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ParseModeMarkdown ParseMode = "Markdown" // legacy
|
||||||
|
ParseModeMarkdownV2 ParseMode = "MarkdownV2"
|
||||||
|
ParseModeHTML ParseMode = "HTML"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatType is the type of a Telegram chat.
|
||||||
|
type ChatType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ChatTypePrivate ChatType = "private"
|
||||||
|
ChatTypeGroup ChatType = "group"
|
||||||
|
ChatTypeSupergroup ChatType = "supergroup"
|
||||||
|
ChatTypeChannel ChatType = "channel"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpdateType identifies an Update payload variant. Used by allowed_updates
|
||||||
|
// in getUpdates / setWebhook.
|
||||||
|
type UpdateType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
UpdateMessage UpdateType = "message"
|
||||||
|
UpdateEditedMessage UpdateType = "edited_message"
|
||||||
|
UpdateChannelPost UpdateType = "channel_post"
|
||||||
|
UpdateEditedChannelPost UpdateType = "edited_channel_post"
|
||||||
|
UpdateCallbackQuery UpdateType = "callback_query"
|
||||||
|
UpdateInlineQuery UpdateType = "inline_query"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageEntityType is the kind of an entity (mention, hashtag, command, ...).
|
||||||
|
type MessageEntityType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
EntityMention MessageEntityType = "mention"
|
||||||
|
EntityHashtag MessageEntityType = "hashtag"
|
||||||
|
EntityCashtag MessageEntityType = "cashtag"
|
||||||
|
EntityBotCommand MessageEntityType = "bot_command"
|
||||||
|
EntityURL MessageEntityType = "url"
|
||||||
|
EntityEmail MessageEntityType = "email"
|
||||||
|
EntityPhoneNumber MessageEntityType = "phone_number"
|
||||||
|
EntityBold MessageEntityType = "bold"
|
||||||
|
EntityItalic MessageEntityType = "italic"
|
||||||
|
EntityUnderline MessageEntityType = "underline"
|
||||||
|
EntityStrike MessageEntityType = "strikethrough"
|
||||||
|
EntitySpoiler MessageEntityType = "spoiler"
|
||||||
|
EntityCode MessageEntityType = "code"
|
||||||
|
EntityPre MessageEntityType = "pre"
|
||||||
|
EntityTextLink MessageEntityType = "text_link"
|
||||||
|
EntityTextMention MessageEntityType = "text_mention"
|
||||||
|
EntityCustomEmoji MessageEntityType = "custom_emoji"
|
||||||
|
)
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MeCache caches the result of GetMe across calls. Construct one per
|
||||||
|
// Bot and call Get to retrieve the cached User on subsequent invocations.
|
||||||
|
//
|
||||||
|
// var meCache api.MeCache
|
||||||
|
// me, err := meCache.Get(ctx, bot)
|
||||||
|
//
|
||||||
|
// MeCache is safe for concurrent use.
|
||||||
|
type MeCache struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
cached *User
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the User from a cached GetMe call. If the cache is empty,
|
||||||
|
// it calls GetMe and populates the cache on success.
|
||||||
|
func (c *MeCache) Get(ctx context.Context, b *client.Bot) (*User, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.cached != nil {
|
||||||
|
u := c.cached
|
||||||
|
c.mu.Unlock()
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
u, err := GetMe(ctx, b, &GetMeParams{})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
c.cached = u
|
||||||
|
c.mu.Unlock()
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset clears the cache. Useful in tests or after the bot's identity
|
||||||
|
// is known to have changed (very rare).
|
||||||
|
func (c *MeCache) Reset() {
|
||||||
|
c.mu.Lock()
|
||||||
|
c.cached = nil
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/client"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMeCache_FetchesOnce(t *testing.T) {
|
||||||
|
m := &mockDoer{}
|
||||||
|
var calls atomic.Int32
|
||||||
|
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
|
||||||
|
if strings.HasSuffix(r.URL.Path, "/getMe") {
|
||||||
|
calls.Add(1)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})).Return(newJSONResp(200, `{"ok":true,"result":{"id":1,"is_bot":true,"first_name":"echo","username":"echo_bot"}}`), nil)
|
||||||
|
|
||||||
|
bot := client.New("t", client.WithHTTPClient(m))
|
||||||
|
var cache MeCache
|
||||||
|
|
||||||
|
me1, err := cache.Get(context.Background(), bot)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "echo_bot", me1.Username)
|
||||||
|
|
||||||
|
me2, err := cache.Get(context.Background(), bot)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Same(t, me1, me2)
|
||||||
|
require.Equal(t, int32(1), calls.Load(), "should fetch only once")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMeCache_Reset(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
m := &mockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
calls.Add(1)
|
||||||
|
}).Return(newJSONResp(200, `{"ok":true,"result":{"id":1,"is_bot":true,"first_name":"echo","username":"echo_bot"}}`), nil).Once()
|
||||||
|
m.On("Do", mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
calls.Add(1)
|
||||||
|
}).Return(newJSONResp(200, `{"ok":true,"result":{"id":1,"is_bot":true,"first_name":"echo","username":"echo_bot"}}`), nil).Once()
|
||||||
|
|
||||||
|
bot := client.New("t", client.WithHTTPClient(m))
|
||||||
|
var cache MeCache
|
||||||
|
|
||||||
|
_, err := cache.Get(context.Background(), bot)
|
||||||
|
require.NoError(t, err)
|
||||||
|
cache.Reset()
|
||||||
|
_, err = cache.Get(context.Background(), bot)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int32(2), calls.Load())
|
||||||
|
}
|
||||||
+5145
File diff suppressed because it is too large
Load Diff
+24703
File diff suppressed because it is too large
Load Diff
+143
@@ -0,0 +1,143 @@
|
|||||||
|
// Package api contains the Telegram Bot API types and method wrappers.
|
||||||
|
// Most of the package is generated by cmd/genapi from internal/spec/api.json;
|
||||||
|
// this file holds the runtime types that are intentionally hand-coded.
|
||||||
|
//
|
||||||
|
// InputFile carries either a local upload (Reader+Filename) or a reference
|
||||||
|
// to a previously-uploaded file (file_id) / URL Telegram can fetch. It is
|
||||||
|
// not a pure JSON type, so the codegen skips it (see runtimeTypes in
|
||||||
|
// cmd/genapi/emitter.go).
|
||||||
|
//
|
||||||
|
// ResponseParameters mirrors client.ResponseParameters so callers importing
|
||||||
|
// only `api` can access retry_after and migrate_to_chat_id without pulling
|
||||||
|
// in the client package.
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"github.com/goccy/go-json"
|
||||||
|
"io"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InputFile carries either a file path (for upload) or a Telegram file_id
|
||||||
|
// / URL string (for reuse). When PathOrID names a local file, the request
|
||||||
|
// is sent as multipart/form-data; otherwise the value is sent inline.
|
||||||
|
type InputFile struct {
|
||||||
|
// PathOrID is one of: an absolute or relative filesystem path, a
|
||||||
|
// previously-uploaded Telegram file_id, or an HTTPS URL Telegram
|
||||||
|
// can fetch.
|
||||||
|
PathOrID string
|
||||||
|
// Reader, when non-nil, is used as the file content (Filename names it).
|
||||||
|
Reader io.Reader
|
||||||
|
// Filename is the upload filename used when Reader is set.
|
||||||
|
Filename string
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLocalUpload reports whether this InputFile triggers a multipart upload.
|
||||||
|
func (f *InputFile) IsLocalUpload() bool {
|
||||||
|
if f == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return f.Reader != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseParameters is the optional metadata Telegram includes on certain
|
||||||
|
// failures. The most common is RetryAfter (seconds) on 429 responses.
|
||||||
|
//
|
||||||
|
// https://core.telegram.org/bots/api#responseparameters
|
||||||
|
type ResponseParameters struct {
|
||||||
|
MigrateToChatID int64 `json:"migrate_to_chat_id,omitempty"`
|
||||||
|
RetryAfter int `json:"retry_after,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatID identifies a chat by either numeric id or "@username". The Telegram
|
||||||
|
// Bot API spells the same field as either an integer or a string; ChatID
|
||||||
|
// preserves both forms with explicit constructors and a custom MarshalJSON
|
||||||
|
// so callers never see `any` at the source level.
|
||||||
|
type ChatID struct {
|
||||||
|
int64Set bool
|
||||||
|
intID int64
|
||||||
|
username string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatIDFromInt builds a ChatID for a numeric chat identifier (e.g. -1001234567890).
|
||||||
|
func ChatIDFromInt(id int64) ChatID { return ChatID{int64Set: true, intID: id} }
|
||||||
|
|
||||||
|
// ChatIDFromUsername builds a ChatID for a public chat (e.g. "@channel").
|
||||||
|
// The leading "@" is required by Telegram for usernames.
|
||||||
|
func ChatIDFromUsername(name string) ChatID { return ChatID{username: name} }
|
||||||
|
|
||||||
|
// IsZero reports whether c carries no value.
|
||||||
|
func (c ChatID) IsZero() bool { return !c.int64Set && c.username == "" }
|
||||||
|
|
||||||
|
// String returns the wire form (decimal integer or "@name") for use in
|
||||||
|
// multipart bodies.
|
||||||
|
func (c ChatID) String() string {
|
||||||
|
if c.int64Set {
|
||||||
|
return strconv.FormatInt(c.intID, 10)
|
||||||
|
}
|
||||||
|
return c.username
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON emits either a JSON number (integer form) or a JSON string
|
||||||
|
// (@username form). Empty values marshal as "null".
|
||||||
|
func (c ChatID) MarshalJSON() ([]byte, error) {
|
||||||
|
if c.int64Set {
|
||||||
|
return []byte(strconv.FormatInt(c.intID, 10)), nil
|
||||||
|
}
|
||||||
|
if c.username != "" {
|
||||||
|
return json.Marshal(c.username)
|
||||||
|
}
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON accepts either a JSON number or a JSON string.
|
||||||
|
func (c *ChatID) UnmarshalJSON(data []byte) error {
|
||||||
|
data = bytes.TrimSpace(data)
|
||||||
|
if len(data) == 0 || bytes.Equal(data, []byte("null")) {
|
||||||
|
*c = ChatID{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if data[0] == '"' {
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(data, &s); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*c = ChatIDFromUsername(s)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var n int64
|
||||||
|
if err := json.Unmarshal(data, &n); err != nil {
|
||||||
|
return fmt.Errorf("ChatID: %w", err)
|
||||||
|
}
|
||||||
|
*c = ChatIDFromInt(n)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageOrBool wraps the "Message or True" return shape Telegram uses on
|
||||||
|
// edit methods (editMessageText, editMessageCaption, etc.). When the bot
|
||||||
|
// edits a regular chat message, Message is non-nil; when it edits an
|
||||||
|
// inline message, OK is true.
|
||||||
|
type MessageOrBool struct {
|
||||||
|
Message *Message
|
||||||
|
OK bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON decodes either {...} into Message or `true`/`false` into OK.
|
||||||
|
func (m *MessageOrBool) UnmarshalJSON(data []byte) error {
|
||||||
|
data = bytes.TrimSpace(data)
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if data[0] == '{' {
|
||||||
|
m.Message = new(Message)
|
||||||
|
return json.Unmarshal(data, m.Message)
|
||||||
|
}
|
||||||
|
var b bool
|
||||||
|
if err := json.Unmarshal(data, &b); err != nil {
|
||||||
|
return fmt.Errorf("MessageOrBool: %w", err)
|
||||||
|
}
|
||||||
|
m.OK = b
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChatID_IntForm(t *testing.T) {
|
||||||
|
c := ChatIDFromInt(-1001234567890)
|
||||||
|
require.False(t, c.IsZero())
|
||||||
|
require.Equal(t, "-1001234567890", c.String())
|
||||||
|
|
||||||
|
data, err := json.Marshal(c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "-1001234567890", string(data))
|
||||||
|
|
||||||
|
var c2 ChatID
|
||||||
|
require.NoError(t, json.Unmarshal(data, &c2))
|
||||||
|
require.Equal(t, c, c2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatID_UsernameForm(t *testing.T) {
|
||||||
|
c := ChatIDFromUsername("@channel")
|
||||||
|
require.False(t, c.IsZero())
|
||||||
|
require.Equal(t, "@channel", c.String())
|
||||||
|
|
||||||
|
data, err := json.Marshal(c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, `"@channel"`, string(data))
|
||||||
|
|
||||||
|
var c2 ChatID
|
||||||
|
require.NoError(t, json.Unmarshal(data, &c2))
|
||||||
|
require.Equal(t, c, c2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatID_Zero(t *testing.T) {
|
||||||
|
var c ChatID
|
||||||
|
require.True(t, c.IsZero())
|
||||||
|
require.Equal(t, "", c.String())
|
||||||
|
|
||||||
|
data, err := json.Marshal(c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "null", string(data))
|
||||||
|
|
||||||
|
var c2 ChatID
|
||||||
|
require.NoError(t, json.Unmarshal([]byte("null"), &c2))
|
||||||
|
require.True(t, c2.IsZero())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatID_UnmarshalInvalid(t *testing.T) {
|
||||||
|
var c ChatID
|
||||||
|
err := json.Unmarshal([]byte(`"not-a-number"`), &c)
|
||||||
|
require.NoError(t, err) // string always succeeds as username
|
||||||
|
require.Equal(t, "not-a-number", c.username)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageOrBool_TrueForm(t *testing.T) {
|
||||||
|
var m MessageOrBool
|
||||||
|
require.NoError(t, json.Unmarshal([]byte("true"), &m))
|
||||||
|
require.True(t, m.OK)
|
||||||
|
require.Nil(t, m.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageOrBool_FalseForm(t *testing.T) {
|
||||||
|
var m MessageOrBool
|
||||||
|
require.NoError(t, json.Unmarshal([]byte("false"), &m))
|
||||||
|
require.False(t, m.OK)
|
||||||
|
require.Nil(t, m.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageOrBool_MessageForm(t *testing.T) {
|
||||||
|
// Message is a generated type; we can only test that it unmarshals without
|
||||||
|
// error into the struct — the generated api/*.gen.go is not available in
|
||||||
|
// the test build unless built. Use build tag !ignore_autogenerated default.
|
||||||
|
// Skip if Message type is not yet present (bootstrap phase).
|
||||||
|
data := []byte(`{"message_id":42,"date":0,"chat":{"id":1,"type":"private"}}`)
|
||||||
|
var m MessageOrBool
|
||||||
|
require.NoError(t, json.Unmarshal(data, &m))
|
||||||
|
require.NotNil(t, m.Message)
|
||||||
|
require.False(t, m.OK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInputFile_IsLocalUpload(t *testing.T) {
|
||||||
|
require.False(t, (*InputFile)(nil).IsLocalUpload())
|
||||||
|
require.False(t, (&InputFile{PathOrID: "AgADAgADu7gxG..."}).IsLocalUpload())
|
||||||
|
require.True(t, (&InputFile{Reader: nopReader{}}).IsLocalUpload())
|
||||||
|
}
|
||||||
|
|
||||||
|
type nopReader struct{}
|
||||||
|
|
||||||
|
func (nopReader) Read(p []byte) (int, error) { return 0, nil }
|
||||||
@@ -0,0 +1,90 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
// Sender condenses the various ways a Telegram update can identify the
|
||||||
|
// originator of a message or reaction into a single shape. Use the
|
||||||
|
// GetSender methods on supported types to construct one.
|
||||||
|
type Sender struct {
|
||||||
|
// User is the human user who sent the update, when applicable.
|
||||||
|
User *User
|
||||||
|
// Chat is the chat that sent the update (channel forwards,
|
||||||
|
// anonymous group admins, anonymous channel posts).
|
||||||
|
Chat *Chat
|
||||||
|
// IsAutomaticForward is true when the update originated as an
|
||||||
|
// automatic forward from a linked channel.
|
||||||
|
IsAutomaticForward bool
|
||||||
|
// ChatID is the chat the update was delivered into. Used to
|
||||||
|
// distinguish "this user" from "this anonymous admin posting
|
||||||
|
// in <chat>" when User is nil.
|
||||||
|
ChatID int64
|
||||||
|
// AuthorSignature is the custom title of an anonymous group
|
||||||
|
// administrator. Only meaningful when Chat == this chat.
|
||||||
|
AuthorSignature string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ID returns the most-specific identifier available: prefers Chat.ID
|
||||||
|
// over User.ID. Returns 0 if neither is set.
|
||||||
|
func (s *Sender) ID() int64 {
|
||||||
|
if s == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if s.Chat != nil {
|
||||||
|
return s.Chat.ID
|
||||||
|
}
|
||||||
|
if s.User != nil {
|
||||||
|
return s.User.ID
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAnonymousAdmin reports whether the sender is a group admin posting
|
||||||
|
// anonymously (Chat equals the message's own chat).
|
||||||
|
func (s *Sender) IsAnonymousAdmin() bool {
|
||||||
|
return s != nil && s.Chat != nil && s.Chat.ID == s.ChatID
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAnonymousChannel reports whether the sender is an anonymous
|
||||||
|
// channel post (Chat differs from the message's own chat).
|
||||||
|
func (s *Sender) IsAnonymousChannel() bool {
|
||||||
|
return s != nil && s.Chat != nil && s.Chat.ID != s.ChatID
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSender constructs a Sender for a Message. The result is never nil.
|
||||||
|
func (m *Message) GetSender() *Sender {
|
||||||
|
if m == nil {
|
||||||
|
return &Sender{}
|
||||||
|
}
|
||||||
|
isAuto := false
|
||||||
|
if m.IsAutomaticForward != nil {
|
||||||
|
isAuto = *m.IsAutomaticForward
|
||||||
|
}
|
||||||
|
return &Sender{
|
||||||
|
User: m.From,
|
||||||
|
Chat: m.SenderChat,
|
||||||
|
IsAutomaticForward: isAuto,
|
||||||
|
ChatID: m.Chat.ID,
|
||||||
|
AuthorSignature: m.AuthorSignature,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSender constructs a Sender for a MessageReactionUpdated.
|
||||||
|
func (mru *MessageReactionUpdated) GetSender() *Sender {
|
||||||
|
if mru == nil {
|
||||||
|
return &Sender{}
|
||||||
|
}
|
||||||
|
return &Sender{
|
||||||
|
User: mru.User,
|
||||||
|
Chat: mru.ActorChat,
|
||||||
|
ChatID: mru.Chat.ID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSender constructs a Sender for a PollAnswer.
|
||||||
|
func (pa *PollAnswer) GetSender() *Sender {
|
||||||
|
if pa == nil {
|
||||||
|
return &Sender{}
|
||||||
|
}
|
||||||
|
return &Sender{
|
||||||
|
User: pa.User,
|
||||||
|
Chat: pa.VoterChat,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,366 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSenderID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sender *Sender
|
||||||
|
want int64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil sender",
|
||||||
|
sender: nil,
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty sender",
|
||||||
|
sender: &Sender{},
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user only",
|
||||||
|
sender: &Sender{
|
||||||
|
User: &User{ID: 123},
|
||||||
|
},
|
||||||
|
want: 123,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat only",
|
||||||
|
sender: &Sender{
|
||||||
|
Chat: &Chat{ID: 456},
|
||||||
|
},
|
||||||
|
want: 456,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat prefers over user",
|
||||||
|
sender: &Sender{
|
||||||
|
User: &User{ID: 123},
|
||||||
|
Chat: &Chat{ID: 456},
|
||||||
|
},
|
||||||
|
want: 456,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := tt.sender.ID()
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("ID() = %d, want %d", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func chatEqual(a, b *Chat) bool {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if a == nil || b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return a.ID == b.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
func userEqual(a, b *User) bool {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if a == nil || b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return a.ID == b.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSenderIsAnonymousAdmin(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sender *Sender
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil sender",
|
||||||
|
sender: nil,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no chat",
|
||||||
|
sender: &Sender{User: &User{ID: 123}, ChatID: 456},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat id matches (anonymous admin)",
|
||||||
|
sender: &Sender{
|
||||||
|
Chat: &Chat{ID: 789},
|
||||||
|
ChatID: 789,
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat id differs (not anonymous admin)",
|
||||||
|
sender: &Sender{
|
||||||
|
Chat: &Chat{ID: 789},
|
||||||
|
ChatID: 456,
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := tt.sender.IsAnonymousAdmin()
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("IsAnonymousAdmin() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSenderIsAnonymousChannel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sender *Sender
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil sender",
|
||||||
|
sender: nil,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no chat",
|
||||||
|
sender: &Sender{User: &User{ID: 123}, ChatID: 456},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat id differs (anonymous channel)",
|
||||||
|
sender: &Sender{
|
||||||
|
Chat: &Chat{ID: 789},
|
||||||
|
ChatID: 456,
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat id matches (not anonymous channel)",
|
||||||
|
sender: &Sender{
|
||||||
|
Chat: &Chat{ID: 789},
|
||||||
|
ChatID: 789,
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := tt.sender.IsAnonymousChannel()
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("IsAnonymousChannel() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageGetSender(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
msg *Message
|
||||||
|
want *Sender
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil message",
|
||||||
|
msg: nil,
|
||||||
|
want: &Sender{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "regular user message",
|
||||||
|
msg: &Message{
|
||||||
|
From: &User{ID: 123},
|
||||||
|
Chat: Chat{ID: 456},
|
||||||
|
},
|
||||||
|
want: &Sender{
|
||||||
|
User: &User{ID: 123},
|
||||||
|
ChatID: 456,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "channel forward",
|
||||||
|
msg: &Message{
|
||||||
|
From: &User{ID: 123},
|
||||||
|
SenderChat: &Chat{ID: 789},
|
||||||
|
Chat: Chat{ID: 456},
|
||||||
|
},
|
||||||
|
want: &Sender{
|
||||||
|
User: &User{ID: 123},
|
||||||
|
Chat: &Chat{ID: 789},
|
||||||
|
ChatID: 456,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "anonymous admin",
|
||||||
|
msg: &Message{
|
||||||
|
SenderChat: &Chat{ID: 456},
|
||||||
|
Chat: Chat{ID: 456},
|
||||||
|
AuthorSignature: "Admin Signature",
|
||||||
|
},
|
||||||
|
want: &Sender{
|
||||||
|
Chat: &Chat{ID: 456},
|
||||||
|
ChatID: 456,
|
||||||
|
AuthorSignature: "Admin Signature",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "anonymous channel post",
|
||||||
|
msg: &Message{
|
||||||
|
SenderChat: &Chat{ID: 789},
|
||||||
|
Chat: Chat{ID: 456},
|
||||||
|
},
|
||||||
|
want: &Sender{
|
||||||
|
Chat: &Chat{ID: 789},
|
||||||
|
ChatID: 456,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "automatic forward",
|
||||||
|
msg: &Message{
|
||||||
|
From: &User{ID: 123},
|
||||||
|
IsAutomaticForward: func() *bool {
|
||||||
|
b := true
|
||||||
|
return &b
|
||||||
|
}(),
|
||||||
|
Chat: Chat{ID: 456},
|
||||||
|
},
|
||||||
|
want: &Sender{
|
||||||
|
User: &User{ID: 123},
|
||||||
|
IsAutomaticForward: true,
|
||||||
|
ChatID: 456,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := tt.msg.GetSender()
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("GetSender() returned nil")
|
||||||
|
}
|
||||||
|
if !userEqual(got.User, tt.want.User) {
|
||||||
|
t.Errorf("User: got %v, want %v", got.User, tt.want.User)
|
||||||
|
}
|
||||||
|
if !chatEqual(got.Chat, tt.want.Chat) {
|
||||||
|
t.Errorf("Chat: got %v, want %v", got.Chat, tt.want.Chat)
|
||||||
|
}
|
||||||
|
if got.IsAutomaticForward != tt.want.IsAutomaticForward {
|
||||||
|
t.Errorf("IsAutomaticForward: got %v, want %v", got.IsAutomaticForward, tt.want.IsAutomaticForward)
|
||||||
|
}
|
||||||
|
if got.ChatID != tt.want.ChatID {
|
||||||
|
t.Errorf("ChatID: got %d, want %d", got.ChatID, tt.want.ChatID)
|
||||||
|
}
|
||||||
|
if got.AuthorSignature != tt.want.AuthorSignature {
|
||||||
|
t.Errorf("AuthorSignature: got %q, want %q", got.AuthorSignature, tt.want.AuthorSignature)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageReactionUpdatedGetSender(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mru *MessageReactionUpdated
|
||||||
|
want *Sender
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil reaction",
|
||||||
|
mru: nil,
|
||||||
|
want: &Sender{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user reaction",
|
||||||
|
mru: &MessageReactionUpdated{
|
||||||
|
User: &User{ID: 123},
|
||||||
|
Chat: Chat{ID: 456},
|
||||||
|
},
|
||||||
|
want: &Sender{
|
||||||
|
User: &User{ID: 123},
|
||||||
|
ChatID: 456,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "anonymous reaction",
|
||||||
|
mru: &MessageReactionUpdated{
|
||||||
|
ActorChat: &Chat{ID: 789},
|
||||||
|
Chat: Chat{ID: 456},
|
||||||
|
},
|
||||||
|
want: &Sender{
|
||||||
|
Chat: &Chat{ID: 789},
|
||||||
|
ChatID: 456,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := tt.mru.GetSender()
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("GetSender() returned nil")
|
||||||
|
}
|
||||||
|
if !userEqual(got.User, tt.want.User) {
|
||||||
|
t.Errorf("User: got %v, want %v", got.User, tt.want.User)
|
||||||
|
}
|
||||||
|
if !chatEqual(got.Chat, tt.want.Chat) {
|
||||||
|
t.Errorf("Chat: got %v, want %v", got.Chat, tt.want.Chat)
|
||||||
|
}
|
||||||
|
if got.ChatID != tt.want.ChatID {
|
||||||
|
t.Errorf("ChatID: got %d, want %d", got.ChatID, tt.want.ChatID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPollAnswerGetSender(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pa *PollAnswer
|
||||||
|
want *Sender
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil poll answer",
|
||||||
|
pa: nil,
|
||||||
|
want: &Sender{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user vote",
|
||||||
|
pa: &PollAnswer{
|
||||||
|
User: &User{ID: 123},
|
||||||
|
},
|
||||||
|
want: &Sender{
|
||||||
|
User: &User{ID: 123},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "anonymous vote",
|
||||||
|
pa: &PollAnswer{
|
||||||
|
VoterChat: &Chat{ID: 789},
|
||||||
|
},
|
||||||
|
want: &Sender{
|
||||||
|
Chat: &Chat{ID: 789},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := tt.pa.GetSender()
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("GetSender() returned nil")
|
||||||
|
}
|
||||||
|
if !userEqual(got.User, tt.want.User) {
|
||||||
|
t.Errorf("User: got %v, want %v", got.User, tt.want.User)
|
||||||
|
}
|
||||||
|
if !chatEqual(got.Chat, tt.want.Chat) {
|
||||||
|
t.Errorf("Chat: got %v, want %v", got.Chat, tt.want.Chat)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockDoer is a testify-mock HTTPDoer shared by hand-written tests.
|
||||||
|
type mockDoer struct{ mock.Mock }
|
||||||
|
|
||||||
|
func (m *mockDoer) Do(r *http.Request) (*http.Response, error) {
|
||||||
|
args := m.Called(r)
|
||||||
|
if v := args.Get(0); v != nil {
|
||||||
|
return v.(*http.Response), args.Error(1)
|
||||||
|
}
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newJSONResp constructs an *http.Response with a JSON body.
|
||||||
|
func newJSONResp(status int, body string) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: status,
|
||||||
|
Body: io.NopCloser(bytes.NewBufferString(body)),
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
}
|
||||||
|
}
|
||||||
+5871
File diff suppressed because it is too large
Load Diff
+171
@@ -0,0 +1,171 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"github.com/goccy/go-json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Call is the single point through which every Telegram Bot API method
|
||||||
|
// invocation flows. It marshals the request, signs the URL with the bot
|
||||||
|
// token, dispatches via HTTPDoer, decodes the Result envelope, and
|
||||||
|
// translates non-OK responses into typed errors.
|
||||||
|
//
|
||||||
|
// It is generic over both request and response types. Methods with no
|
||||||
|
// parameters may pass a nil Req; the helper sends "{}" in that case so
|
||||||
|
// Telegram receives a syntactically valid empty object.
|
||||||
|
//
|
||||||
|
// Call is exported because the api package (which lives outside this one)
|
||||||
|
// invokes it from generated method wrappers. User code should not normally
|
||||||
|
// call it directly — use the typed wrappers in package api instead.
|
||||||
|
func Call[Req any, Resp any](ctx context.Context, b *Bot, method string, req Req) (Resp, error) {
|
||||||
|
var zero Resp
|
||||||
|
|
||||||
|
if mp, ok := any(req).(multipartRequest); ok {
|
||||||
|
if mp == nil {
|
||||||
|
return zero, &ParseError{Err: errors.New("client: nil multipart request")}
|
||||||
|
}
|
||||||
|
if mp.HasFile() {
|
||||||
|
return callMultipart[Resp](ctx, b, method, mp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := encodeJSONBody(b.codec, req)
|
||||||
|
if err != nil {
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
|
||||||
|
url := b.base + "/bot" + b.token + "/" + method
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
|
||||||
|
if err != nil {
|
||||||
|
return zero, &NetworkError{Err: err}
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
httpReq.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := b.http.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
// Surface ctx errors faithfully so callers can errors.Is(err, ctx.Err()).
|
||||||
|
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||||
|
return zero, ctxErr
|
||||||
|
}
|
||||||
|
return zero, &NetworkError{Err: err}
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return zero, &NetworkError{Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return decodeResult[Resp](b.codec, raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CallRaw is like Call but returns the raw JSON of the result field
|
||||||
|
// instead of decoding it into a typed value. Generated method wrappers
|
||||||
|
// for sealed-interface return types (ChatMember, MenuButton, etc.) use
|
||||||
|
// this helper, then dispatch through the union's UnmarshalXxx function.
|
||||||
|
//
|
||||||
|
// CallRaw still translates non-OK responses into *APIError just like Call.
|
||||||
|
func CallRaw[Req any](ctx context.Context, b *Bot, method string, req Req) (json.RawMessage, error) {
|
||||||
|
if mp, ok := any(req).(multipartRequest); ok {
|
||||||
|
if mp == nil {
|
||||||
|
return nil, &ParseError{Err: errors.New("client: nil multipart request")}
|
||||||
|
}
|
||||||
|
if mp.HasFile() {
|
||||||
|
return callMultipartRaw(ctx, b, method, mp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := encodeJSONBody(b.codec, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
url := b.base + "/bot" + b.token + "/" + method
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &NetworkError{Err: err}
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
httpReq.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := b.http.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||||
|
return nil, ctxErr
|
||||||
|
}
|
||||||
|
return nil, &NetworkError{Err: err}
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &NetworkError{Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return decodeResultRaw(b.codec, raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeResultRaw is decodeResult's sibling that returns the raw result
|
||||||
|
// field instead of typing it.
|
||||||
|
func decodeResultRaw(codec Codec, raw []byte) (json.RawMessage, error) {
|
||||||
|
var env Result[json.RawMessage]
|
||||||
|
if err := codec.Unmarshal(raw, &env); err != nil {
|
||||||
|
return nil, &ParseError{Err: err, Body: copyBody(raw)}
|
||||||
|
}
|
||||||
|
if !env.OK {
|
||||||
|
return nil, mapAPIError(env.ErrorCode, env.Description, env.Parameters)
|
||||||
|
}
|
||||||
|
return env.Result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeJSONBody marshals req to a JSON body. A nil interface or nil
|
||||||
|
// pointer req yields "{}" so Telegram receives a valid empty object.
|
||||||
|
func encodeJSONBody(codec Codec, req any) (io.Reader, error) {
|
||||||
|
if req == nil || isNilPointer(req) {
|
||||||
|
return bytes.NewBufferString("{}"), nil
|
||||||
|
}
|
||||||
|
data, err := codec.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &ParseError{Err: err}
|
||||||
|
}
|
||||||
|
return bytes.NewReader(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeResult unmarshals raw into Result[Resp] and translates non-OK
|
||||||
|
// responses into *APIError.
|
||||||
|
func decodeResult[Resp any](codec Codec, raw []byte) (Resp, error) {
|
||||||
|
var zero Resp
|
||||||
|
var env Result[Resp]
|
||||||
|
if err := codec.Unmarshal(raw, &env); err != nil {
|
||||||
|
return zero, &ParseError{Err: err, Body: copyBody(raw)}
|
||||||
|
}
|
||||||
|
if !env.OK {
|
||||||
|
return zero, mapAPIError(env.ErrorCode, env.Description, env.Parameters)
|
||||||
|
}
|
||||||
|
return env.Result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isNilPointer returns true when v is a typed nil pointer (the interface
|
||||||
|
// itself is non-nil because it carries a type, but the underlying value
|
||||||
|
// is nil). One reflect call per request; not on a hot path that demands
|
||||||
|
// allocation-freedom.
|
||||||
|
func isNilPointer(v any) bool {
|
||||||
|
rv := reflect.ValueOf(v)
|
||||||
|
return rv.Kind() == reflect.Ptr && rv.IsNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyBody(b []byte) []byte {
|
||||||
|
const max = 4096
|
||||||
|
if len(b) > max {
|
||||||
|
b = b[:max]
|
||||||
|
}
|
||||||
|
out := make([]byte, len(b))
|
||||||
|
copy(out, b)
|
||||||
|
return out
|
||||||
|
}
|
||||||
@@ -0,0 +1,121 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockDoer struct{ mock.Mock }
|
||||||
|
|
||||||
|
func (m *mockDoer) Do(r *http.Request) (*http.Response, error) {
|
||||||
|
args := m.Called(r)
|
||||||
|
if v := args.Get(0); v != nil {
|
||||||
|
return v.(*http.Response), args.Error(1)
|
||||||
|
}
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newResp(status int, body string) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: status,
|
||||||
|
Body: io.NopCloser(bytes.NewBufferString(body)),
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type echoReq struct {
|
||||||
|
ChatID int64 `json:"chat_id"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
type echoResp struct {
|
||||||
|
MessageID int64 `json:"message_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCall_Success(t *testing.T) {
|
||||||
|
m := &mockDoer{}
|
||||||
|
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
|
||||||
|
if !strings.HasSuffix(r.URL.Path, "/bot123:abc/sendEcho") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
_, _ = buf.ReadFrom(r.Body)
|
||||||
|
return strings.Contains(buf.String(), `"chat_id":42`)
|
||||||
|
})).Return(newResp(200, `{"ok":true,"result":{"message_id":7}}`), nil)
|
||||||
|
|
||||||
|
b := New("123:abc", WithHTTPClient(m))
|
||||||
|
out, err := Call[*echoReq, *echoResp](context.Background(), b, "sendEcho", &echoReq{ChatID: 42, Text: "hi"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(7), out.MessageID)
|
||||||
|
m.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCall_APIError(t *testing.T) {
|
||||||
|
m := &mockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(
|
||||||
|
newResp(200, `{"ok":false,"error_code":429,"description":"Too Many Requests: retry after 3","parameters":{"retry_after":3}}`), nil)
|
||||||
|
|
||||||
|
b := New("t", WithHTTPClient(m))
|
||||||
|
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", &echoReq{})
|
||||||
|
require.Error(t, err)
|
||||||
|
var ae *APIError
|
||||||
|
require.ErrorAs(t, err, &ae)
|
||||||
|
require.Equal(t, 429, ae.Code)
|
||||||
|
require.True(t, ae.IsRetryable())
|
||||||
|
require.True(t, errors.Is(err, ErrTooManyRequests))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCall_NetworkError(t *testing.T) {
|
||||||
|
m := &mockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(nil, errors.New("dial timeout"))
|
||||||
|
|
||||||
|
b := New("t", WithHTTPClient(m))
|
||||||
|
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", &echoReq{})
|
||||||
|
require.Error(t, err)
|
||||||
|
var ne *NetworkError
|
||||||
|
require.ErrorAs(t, err, &ne)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCall_ParseError(t *testing.T) {
|
||||||
|
m := &mockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(newResp(200, `not json`), nil)
|
||||||
|
|
||||||
|
b := New("t", WithHTTPClient(m))
|
||||||
|
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", &echoReq{})
|
||||||
|
require.Error(t, err)
|
||||||
|
var pe *ParseError
|
||||||
|
require.ErrorAs(t, err, &pe)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCall_ContextCanceled(t *testing.T) {
|
||||||
|
m := &mockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(nil, context.Canceled).Maybe()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
b := New("t", WithHTTPClient(m))
|
||||||
|
_, err := Call[*echoReq, *echoResp](ctx, b, "x", &echoReq{})
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCall_NilRequest(t *testing.T) {
|
||||||
|
// Methods with no params (e.g. getMe) may pass a nil Req value.
|
||||||
|
m := &mockDoer{}
|
||||||
|
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
_, _ = buf.ReadFrom(r.Body)
|
||||||
|
return buf.String() == "{}"
|
||||||
|
})).Return(newResp(200, `{"ok":true,"result":{"message_id":0}}`), nil)
|
||||||
|
|
||||||
|
b := New("t", WithHTTPClient(m))
|
||||||
|
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
const defaultBaseURL = "https://api.telegram.org"
|
||||||
|
|
||||||
|
// Bot is the Telegram Bot API client. Construct via New. All API methods
|
||||||
|
// (declared in package api) hang off *Bot via thin wrappers around call.
|
||||||
|
type Bot struct {
|
||||||
|
token string
|
||||||
|
base string
|
||||||
|
http HTTPDoer
|
||||||
|
codec Codec
|
||||||
|
logger Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token returns the bot token. Exposed for advanced use cases (custom
|
||||||
|
// transports, manual URL building); ordinary code does not need it.
|
||||||
|
func (b *Bot) Token() string { return b.token }
|
||||||
|
|
||||||
|
// BaseURL returns the configured Telegram API base URL.
|
||||||
|
func (b *Bot) BaseURL() string { return b.base }
|
||||||
|
|
||||||
|
// HTTP returns the underlying HTTPDoer. Exposed for adapters that need
|
||||||
|
// to share connection pools or for diagnostic checks.
|
||||||
|
func (b *Bot) HTTP() HTTPDoer { return b.http }
|
||||||
|
|
||||||
|
// Codec returns the configured Codec.
|
||||||
|
func (b *Bot) Codec() Codec { return b.codec }
|
||||||
|
|
||||||
|
// Logger returns the configured Logger.
|
||||||
|
func (b *Bot) Logger() Logger { return b.logger }
|
||||||
|
|
||||||
|
// New constructs a Bot with the given token and optional configuration.
|
||||||
|
// The default HTTP client is tuned for long-poll workloads (see
|
||||||
|
// NewDefaultHTTPDoer); the default codec wraps encoding/json; the default
|
||||||
|
// logger discards records.
|
||||||
|
func New(token string, opts ...Option) *Bot {
|
||||||
|
b := &Bot{
|
||||||
|
token: token,
|
||||||
|
base: defaultBaseURL,
|
||||||
|
http: NewDefaultHTTPDoer(),
|
||||||
|
codec: DefaultCodec{},
|
||||||
|
logger: NoopLogger{},
|
||||||
|
}
|
||||||
|
for _, o := range opts {
|
||||||
|
o(b)
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNew_Defaults(t *testing.T) {
|
||||||
|
b := New("123:abc")
|
||||||
|
require.Equal(t, "123:abc", b.token)
|
||||||
|
require.Equal(t, defaultBaseURL, b.base)
|
||||||
|
require.NotNil(t, b.http)
|
||||||
|
require.NotNil(t, b.codec)
|
||||||
|
require.NotNil(t, b.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew_OptionsApplied(t *testing.T) {
|
||||||
|
custom := &http.Client{}
|
||||||
|
type fakeCodec struct{ DefaultCodec }
|
||||||
|
c := fakeCodec{}
|
||||||
|
|
||||||
|
b := New("t",
|
||||||
|
WithHTTPClient(custom),
|
||||||
|
WithCodec(c),
|
||||||
|
WithBaseURL("https://example.test"),
|
||||||
|
WithLogger(NoopLogger{}),
|
||||||
|
)
|
||||||
|
require.Same(t, custom, b.http)
|
||||||
|
require.Equal(t, c, b.codec)
|
||||||
|
require.Equal(t, "https://example.test", b.base)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResultRoundTrip(t *testing.T) {
|
||||||
|
in := Result[int64]{OK: true, Result: 42}
|
||||||
|
data, err := DefaultCodec{}.Marshal(in)
|
||||||
|
require.NoError(t, err)
|
||||||
|
var out Result[int64]
|
||||||
|
require.NoError(t, DefaultCodec{}.Unmarshal(data, &out))
|
||||||
|
require.Equal(t, in, out)
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
// Package client provides HTTP client primitives for the Telegram Bot API.
|
||||||
|
package client
|
||||||
|
|
||||||
|
import "github.com/goccy/go-json"
|
||||||
|
|
||||||
|
// Codec encodes/decodes JSON payloads exchanged with the Telegram Bot API.
|
||||||
|
// The default implementation wraps goccy/go-json. Users may plug in
|
||||||
|
// bytedance/sonic or any compatible encoder by passing
|
||||||
|
// WithCodec to New.
|
||||||
|
type Codec interface {
|
||||||
|
Marshal(v any) ([]byte, error)
|
||||||
|
Unmarshal(data []byte, v any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultCodec wraps goccy/go-json. It is the zero-value safe default.
|
||||||
|
type DefaultCodec struct{}
|
||||||
|
|
||||||
|
// Marshal calls json.Marshal.
|
||||||
|
func (DefaultCodec) Marshal(v any) ([]byte, error) { return json.Marshal(v) }
|
||||||
|
|
||||||
|
// Unmarshal calls json.Unmarshal.
|
||||||
|
func (DefaultCodec) Unmarshal(data []byte, v any) error { return json.Unmarshal(data, v) }
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultCodec_RoundTrip(t *testing.T) {
|
||||||
|
c := DefaultCodec{}
|
||||||
|
type payload struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
N int `json:"n"`
|
||||||
|
}
|
||||||
|
in := payload{Name: "x", N: 7}
|
||||||
|
data, err := c.Marshal(in)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.JSONEq(t, `{"name":"x","n":7}`, string(data))
|
||||||
|
|
||||||
|
var out payload
|
||||||
|
require.NoError(t, c.Unmarshal(data, &out))
|
||||||
|
require.Equal(t, in, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultCodec_UnmarshalError(t *testing.T) {
|
||||||
|
var v map[string]any
|
||||||
|
err := DefaultCodec{}.Unmarshal([]byte(`not json`), &v)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// APIError represents a non-OK Telegram Bot API response.
|
||||||
|
// It satisfies error and unwraps to a sentinel (ErrUnauthorized, etc.)
|
||||||
|
// where the description matches a known prefix, enabling errors.Is checks.
|
||||||
|
type APIError struct {
|
||||||
|
Code int
|
||||||
|
Description string
|
||||||
|
Parameters *ResponseParameters
|
||||||
|
|
||||||
|
// sentinel, if non-nil, is the wrapped sentinel error returned by
|
||||||
|
// Unwrap. It is set by mapAPIError based on Code+Description.
|
||||||
|
sentinel error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements error.
|
||||||
|
func (e *APIError) Error() string {
|
||||||
|
return fmt.Sprintf("telegram: %d %s", e.Code, e.Description)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap returns the matched sentinel error, if any.
|
||||||
|
func (e *APIError) Unwrap() error { return e.sentinel }
|
||||||
|
|
||||||
|
// IsRetryable returns true for transient HTTP statuses (429, 5xx).
|
||||||
|
func (e *APIError) IsRetryable() bool {
|
||||||
|
return e.Code == 429 || (e.Code >= 500 && e.Code < 600)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryAfter returns the recommended back-off duration. It honours the
|
||||||
|
// Telegram-supplied retry_after parameter; if absent, returns 0.
|
||||||
|
func (e *APIError) RetryAfter() time.Duration {
|
||||||
|
if e.Parameters == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return time.Duration(e.Parameters.RetryAfter) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetworkError wraps a transport-level failure (DNS, TCP, TLS, timeout
|
||||||
|
// short of an HTTP response).
|
||||||
|
type NetworkError struct{ Err error }
|
||||||
|
|
||||||
|
func (e *NetworkError) Error() string { return "telegram: network: " + redactToken(e.Err.Error()) }
|
||||||
|
|
||||||
|
func (e *NetworkError) Unwrap() error { return e.Err }
|
||||||
|
|
||||||
|
// ParseError wraps a JSON decode failure on a response body. Body is
|
||||||
|
// retained (truncated to 4 KiB); Error() displays up to 256 bytes for diagnostics.
|
||||||
|
type ParseError struct {
|
||||||
|
Err error
|
||||||
|
Body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ParseError) Error() string {
|
||||||
|
body := e.Body
|
||||||
|
if len(body) > 256 {
|
||||||
|
body = body[:256]
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("telegram: parse: %s (body=%q)", redactToken(e.Err.Error()), body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ParseError) Unwrap() error { return e.Err }
|
||||||
|
|
||||||
|
// Sentinel errors returned via APIError.Unwrap when the description matches.
|
||||||
|
// Compare with errors.Is.
|
||||||
|
var (
|
||||||
|
ErrUnauthorized = errors.New("telegram: unauthorized")
|
||||||
|
ErrChatNotFound = errors.New("telegram: chat not found")
|
||||||
|
ErrMessageNotModified = errors.New("telegram: message is not modified")
|
||||||
|
ErrTooManyRequests = errors.New("telegram: too many requests")
|
||||||
|
ErrBadRequest = errors.New("telegram: bad request")
|
||||||
|
ErrForbidden = errors.New("telegram: forbidden")
|
||||||
|
ErrUserNotFound = errors.New("telegram: user not found")
|
||||||
|
ErrMessageNotFound = errors.New("telegram: message not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
// mapAPIError builds an *APIError and attaches the appropriate sentinel
|
||||||
|
// based on Code+Description. It is the single point where wire-level
|
||||||
|
// failures are translated into the Go error taxonomy.
|
||||||
|
func mapAPIError(code int, description string, params *ResponseParameters) *APIError {
|
||||||
|
e := &APIError{Code: code, Description: description, Parameters: params}
|
||||||
|
switch {
|
||||||
|
case code == 401:
|
||||||
|
e.sentinel = ErrUnauthorized
|
||||||
|
case code == 403:
|
||||||
|
e.sentinel = ErrForbidden
|
||||||
|
case code == 429:
|
||||||
|
e.sentinel = ErrTooManyRequests
|
||||||
|
case code == 400 && strings.Contains(description, "user not found"):
|
||||||
|
e.sentinel = ErrUserNotFound
|
||||||
|
case code == 400 && strings.Contains(description, "message to") && strings.Contains(description, "not found"):
|
||||||
|
e.sentinel = ErrMessageNotFound
|
||||||
|
case code == 400 && strings.Contains(description, "chat not found"):
|
||||||
|
e.sentinel = ErrChatNotFound
|
||||||
|
case code == 400 && strings.Contains(description, "message is not modified"):
|
||||||
|
e.sentinel = ErrMessageNotModified
|
||||||
|
case code == 400:
|
||||||
|
e.sentinel = ErrBadRequest
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAPIError_FieldsAndMethods(t *testing.T) {
|
||||||
|
e := &APIError{
|
||||||
|
Code: 429,
|
||||||
|
Description: "Too Many Requests: retry after 5",
|
||||||
|
Parameters: &ResponseParameters{RetryAfter: 5},
|
||||||
|
}
|
||||||
|
require.Equal(t, "telegram: 429 Too Many Requests: retry after 5", e.Error())
|
||||||
|
require.True(t, e.IsRetryable())
|
||||||
|
require.Equal(t, 5*time.Second, e.RetryAfter())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIError_Sentinels(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
code int
|
||||||
|
desc string
|
||||||
|
sentinel error
|
||||||
|
}{
|
||||||
|
{401, "Unauthorized", ErrUnauthorized},
|
||||||
|
{400, "Bad Request: chat not found", ErrChatNotFound},
|
||||||
|
{400, "Bad Request: message is not modified", ErrMessageNotModified},
|
||||||
|
{429, "Too Many Requests: retry after 1", ErrTooManyRequests},
|
||||||
|
{400, "Bad Request: user not found", ErrUserNotFound},
|
||||||
|
{400, "Bad Request: message to delete not found", ErrMessageNotFound},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.desc, func(t *testing.T) {
|
||||||
|
e := mapAPIError(c.code, c.desc, nil)
|
||||||
|
require.True(t, errors.Is(e, c.sentinel), "expected %v to wrap %v", e, c.sentinel)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIError_IsRetryable(t *testing.T) {
|
||||||
|
require.True(t, (&APIError{Code: 500}).IsRetryable())
|
||||||
|
require.True(t, (&APIError{Code: 502}).IsRetryable())
|
||||||
|
require.True(t, (&APIError{Code: 429}).IsRetryable())
|
||||||
|
require.False(t, (&APIError{Code: 400}).IsRetryable())
|
||||||
|
require.False(t, (&APIError{Code: 401}).IsRetryable())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkAndParseErrorWrapping(t *testing.T) {
|
||||||
|
inner := errors.New("dial tcp: timeout")
|
||||||
|
ne := &NetworkError{Err: inner}
|
||||||
|
require.ErrorIs(t, ne, inner)
|
||||||
|
|
||||||
|
pe := &ParseError{Err: errors.New("unexpected EOF"), Body: []byte("garbage")}
|
||||||
|
require.Contains(t, pe.Error(), "garbage")
|
||||||
|
}
|
||||||
@@ -0,0 +1,226 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// client.go option getters
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestBot_Getters(t *testing.T) {
|
||||||
|
b := New("mytoken",
|
||||||
|
WithBaseURL("http://localhost:9999"),
|
||||||
|
WithCodec(DefaultCodec{}),
|
||||||
|
WithLogger(NoopLogger{}),
|
||||||
|
)
|
||||||
|
require.Equal(t, "mytoken", b.Token())
|
||||||
|
require.Equal(t, "http://localhost:9999", b.BaseURL())
|
||||||
|
require.NotNil(t, b.HTTP())
|
||||||
|
require.NotNil(t, b.Codec())
|
||||||
|
require.NotNil(t, b.Logger())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithLogger_NilBecomesNoop(t *testing.T) {
|
||||||
|
b := New("t", WithLogger(nil))
|
||||||
|
require.IsType(t, NoopLogger{}, b.Logger())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoopLogger_AllMethods(t *testing.T) {
|
||||||
|
l := NoopLogger{}
|
||||||
|
// None of these should panic.
|
||||||
|
l.Debug("msg")
|
||||||
|
l.Info("msg", "k", "v")
|
||||||
|
l.Warn("msg")
|
||||||
|
l.Error("msg", "err", "oops")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// RetryOption setters
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRetryOptions_Applied(t *testing.T) {
|
||||||
|
d := NewRetryDoer(nil,
|
||||||
|
WithMaxAttempts(7),
|
||||||
|
WithBaseBackoff(1*time.Second),
|
||||||
|
WithMaxBackoff(60*time.Second),
|
||||||
|
WithBackoffFactor(3.0),
|
||||||
|
WithJitter(0.5),
|
||||||
|
)
|
||||||
|
require.Equal(t, 7, d.maxAttempts)
|
||||||
|
require.Equal(t, 1*time.Second, d.base)
|
||||||
|
require.Equal(t, 60*time.Second, d.max)
|
||||||
|
require.Equal(t, 3.0, d.factor)
|
||||||
|
require.Equal(t, 0.5, d.jitter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// RetryDoer.delay — override path
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRetryDoer_DelayOverride(t *testing.T) {
|
||||||
|
d := NewRetryDoer(nil)
|
||||||
|
got := d.delay(1, 5*time.Second)
|
||||||
|
require.Equal(t, 5*time.Second, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryDoer_DelayExponential(t *testing.T) {
|
||||||
|
d := NewRetryDoer(nil,
|
||||||
|
WithBaseBackoff(100*time.Millisecond),
|
||||||
|
WithMaxBackoff(10*time.Second),
|
||||||
|
WithJitter(0), // no jitter for deterministic test
|
||||||
|
WithBackoffFactor(2.0),
|
||||||
|
)
|
||||||
|
d1 := d.delay(1, 0)
|
||||||
|
d2 := d.delay(2, 0)
|
||||||
|
require.Greater(t, int64(d2), int64(d1), "backoff should grow")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryDoer_DelayMaxCap(t *testing.T) {
|
||||||
|
d := NewRetryDoer(nil,
|
||||||
|
WithBaseBackoff(1*time.Second),
|
||||||
|
WithMaxBackoff(2*time.Second),
|
||||||
|
WithJitter(0),
|
||||||
|
WithBackoffFactor(100.0),
|
||||||
|
)
|
||||||
|
delay := d.delay(10, 0)
|
||||||
|
require.LessOrEqual(t, delay, 2*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// errors.go — RetryAfter nil parameters + ParseError.Unwrap
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAPIError_RetryAfterNilParams(t *testing.T) {
|
||||||
|
e := &APIError{Code: 429, Description: "Too Many Requests", Parameters: nil}
|
||||||
|
require.Equal(t, time.Duration(0), e.RetryAfter())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseError_Unwrap(t *testing.T) {
|
||||||
|
inner := errors.New("decode error")
|
||||||
|
pe := &ParseError{Err: inner, Body: []byte("body")}
|
||||||
|
require.ErrorIs(t, pe, inner)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseError_LongBodyTruncated(t *testing.T) {
|
||||||
|
body := bytes.Repeat([]byte("x"), 1000)
|
||||||
|
pe := &ParseError{Err: errors.New("e"), Body: body}
|
||||||
|
msg := pe.Error()
|
||||||
|
// Error() truncates body to 256 for display — should not include all 1000 chars
|
||||||
|
require.Less(t, len(msg), 800, "should truncate body in Error()")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkError_Unwrap(t *testing.T) {
|
||||||
|
inner := errors.New("tcp error")
|
||||||
|
ne := &NetworkError{Err: inner}
|
||||||
|
require.ErrorIs(t, ne, inner)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// mapAPIError — missing sentinel branches (generic 400, unmapped 500)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestMapAPIError_Generic400(t *testing.T) {
|
||||||
|
e := mapAPIError(400, "Bad Request: some unknown thing", nil)
|
||||||
|
require.True(t, errors.Is(e, ErrBadRequest))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapAPIError_Unmapped500(t *testing.T) {
|
||||||
|
e := mapAPIError(500, "Internal Server Error", nil)
|
||||||
|
require.Nil(t, e.sentinel)
|
||||||
|
require.Equal(t, 500, e.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapAPIError_403(t *testing.T) {
|
||||||
|
e := mapAPIError(403, "Forbidden: bot was blocked", nil)
|
||||||
|
require.True(t, errors.Is(e, ErrForbidden))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// callMultipart — ctx cancelled
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCallMultipart_ContextCancelled(t *testing.T) {
|
||||||
|
// A doer that blocks then returns context error.
|
||||||
|
blocker := &extraBlockingDoer{done: make(chan struct{})}
|
||||||
|
|
||||||
|
b := New("t", WithHTTPClient(blocker))
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
mp := &extraFakeMultipartReq{
|
||||||
|
fields: map[string]string{"chat_id": "1"},
|
||||||
|
files: []MultipartFile{
|
||||||
|
{FieldName: "document", Filename: "f.txt", Reader: bytes.NewReader([]byte("data"))},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
close(blocker.done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := callMultipart[*struct{}](ctx, b, "sendDocument", mp)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type extraBlockingDoer struct{ done chan struct{} }
|
||||||
|
|
||||||
|
func (b *extraBlockingDoer) Do(r *http.Request) (*http.Response, error) {
|
||||||
|
<-b.done
|
||||||
|
return nil, r.Context().Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
type extraFakeMultipartReq struct {
|
||||||
|
fields map[string]string
|
||||||
|
files []MultipartFile
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *extraFakeMultipartReq) HasFile() bool { return len(f.files) > 0 }
|
||||||
|
func (f *extraFakeMultipartReq) MultipartFiles() []MultipartFile { return f.files }
|
||||||
|
func (f *extraFakeMultipartReq) MultipartFields() map[string]string { return f.fields }
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// copyBody size cap
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCopyBody_LargeBodyCapped(t *testing.T) {
|
||||||
|
big := bytes.Repeat([]byte("a"), 8000)
|
||||||
|
out := copyBody(big)
|
||||||
|
require.Len(t, out, 4096)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyBody_SmallBody(t *testing.T) {
|
||||||
|
small := []byte("hello")
|
||||||
|
out := copyBody(small)
|
||||||
|
require.Equal(t, small, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Call — 5xx non-200 HTTP status (transport level)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCall_5xxHTTPStatus(t *testing.T) {
|
||||||
|
m := &mockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(&http.Response{
|
||||||
|
StatusCode: 500,
|
||||||
|
Body: io.NopCloser(bytes.NewBufferString(`{"ok":false,"error_code":500,"description":"Internal"}`)),
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
b := New("t", WithHTTPClient(m))
|
||||||
|
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", &echoReq{})
|
||||||
|
require.Error(t, err)
|
||||||
|
var ae *APIError
|
||||||
|
require.ErrorAs(t, err, &ae)
|
||||||
|
require.Equal(t, 500, ae.Code)
|
||||||
|
}
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTPDoer abstracts the HTTP transport. The default is a net/http client
|
||||||
|
// tuned for Telegram's long-poll usage. Users may plug in valyala/fasthttp
|
||||||
|
// (via an adapter), or any custom retry/circuit-breaker client by passing
|
||||||
|
// WithHTTPClient to New.
|
||||||
|
type HTTPDoer interface {
|
||||||
|
Do(req *http.Request) (*http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultHTTPDoer returns an *http.Client with sensible defaults for
|
||||||
|
// Telegram Bot API usage:
|
||||||
|
// - 60s overall timeout (longer than typical long-poll Timeout=30s).
|
||||||
|
// - Connection pooling sized for a small number of long-lived hosts.
|
||||||
|
// - HTTP/2 enabled (default in net/http).
|
||||||
|
func NewDefaultHTTPDoer() *http.Client {
|
||||||
|
t := &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
MaxIdleConns: 16,
|
||||||
|
MaxIdleConnsPerHost: 8,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
}
|
||||||
|
return &http.Client{
|
||||||
|
Transport: t,
|
||||||
|
Timeout: 60 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultHTTPClient_Do(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusTeapot)
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
doer := NewDefaultHTTPDoer()
|
||||||
|
req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
resp, err := doer.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusTeapot, resp.StatusCode)
|
||||||
|
}
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
// Logger is a slog-shaped logging interface. Users pass any compatible
|
||||||
|
// implementation via WithLogger. The default is NoopLogger, which discards
|
||||||
|
// everything.
|
||||||
|
type Logger interface {
|
||||||
|
Debug(msg string, attrs ...any)
|
||||||
|
Info(msg string, attrs ...any)
|
||||||
|
Warn(msg string, attrs ...any)
|
||||||
|
Error(msg string, attrs ...any)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NoopLogger discards all log records. It is the zero-value safe default.
|
||||||
|
type NoopLogger struct{}
|
||||||
|
|
||||||
|
func (NoopLogger) Debug(string, ...any) {}
|
||||||
|
func (NoopLogger) Info(string, ...any) {}
|
||||||
|
func (NoopLogger) Warn(string, ...any) {}
|
||||||
|
func (NoopLogger) Error(string, ...any) {}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestNoopLogger_DoesNotPanic(t *testing.T) {
|
||||||
|
var l Logger = NoopLogger{}
|
||||||
|
l.Debug("d", "k", "v")
|
||||||
|
l.Info("i")
|
||||||
|
l.Warn("w")
|
||||||
|
l.Error("e")
|
||||||
|
}
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/goccy/go-json"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// multipartRequest is implemented by request structs that may carry an
|
||||||
|
// InputFile. The codegen emits this interface for any method whose IR
|
||||||
|
// MethodDecl.HasFiles is true.
|
||||||
|
//
|
||||||
|
// HasFile returns true if at least one file field is set; if false, the
|
||||||
|
// request is sent as plain JSON via the regular Call path.
|
||||||
|
//
|
||||||
|
// MultipartFiles returns one entry per file field that should be uploaded.
|
||||||
|
// The accompanying scalar/object fields are returned by MultipartFields.
|
||||||
|
type multipartRequest interface {
|
||||||
|
HasFile() bool
|
||||||
|
MultipartFiles() []MultipartFile
|
||||||
|
MultipartFields() map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultipartFile describes a single file part in a multipart upload.
|
||||||
|
type MultipartFile struct {
|
||||||
|
FieldName string
|
||||||
|
Filename string
|
||||||
|
Reader io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
// callMultipart performs a multipart/form-data POST. It is invoked by Call
|
||||||
|
// when the request implements multipartRequest and HasFile() is true.
|
||||||
|
func callMultipart[Resp any](ctx context.Context, b *Bot, method string, mp multipartRequest) (Resp, error) {
|
||||||
|
var zero Resp
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
mw := multipart.NewWriter(pw)
|
||||||
|
|
||||||
|
// Stream-write the multipart body in a goroutine so we don't buffer
|
||||||
|
// large files in memory.
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
defer func() { _ = mw.Close() }()
|
||||||
|
for k, v := range mp.MultipartFields() {
|
||||||
|
if err := mw.WriteField(k, v); err != nil {
|
||||||
|
_ = pw.CloseWithError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, f := range mp.MultipartFiles() {
|
||||||
|
part, err := mw.CreateFormFile(f.FieldName, f.Filename)
|
||||||
|
if err != nil {
|
||||||
|
_ = pw.CloseWithError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := io.Copy(part, f.Reader); err != nil {
|
||||||
|
_ = pw.CloseWithError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
url := b.base + "/bot" + b.token + "/" + method
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, pr)
|
||||||
|
if err != nil {
|
||||||
|
_ = pr.CloseWithError(err)
|
||||||
|
return zero, &NetworkError{Err: err}
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", mw.FormDataContentType())
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := b.http.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
_ = pr.CloseWithError(err)
|
||||||
|
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||||
|
return zero, ctxErr
|
||||||
|
}
|
||||||
|
return zero, &NetworkError{Err: err}
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
_ = pr.CloseWithError(err)
|
||||||
|
return zero, &NetworkError{Err: err}
|
||||||
|
}
|
||||||
|
return decodeResult[Resp](b.codec, raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
// callMultipartRaw is callMultipart's sibling that returns the raw result
|
||||||
|
// JSON instead of decoding into a typed value. Used by generated method
|
||||||
|
// wrappers whose return type is a sealed-interface union.
|
||||||
|
func callMultipartRaw(ctx context.Context, b *Bot, method string, mp multipartRequest) (json.RawMessage, error) {
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
mw := multipart.NewWriter(pw)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
defer func() { _ = mw.Close() }()
|
||||||
|
for k, v := range mp.MultipartFields() {
|
||||||
|
if err := mw.WriteField(k, v); err != nil {
|
||||||
|
_ = pw.CloseWithError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, f := range mp.MultipartFiles() {
|
||||||
|
part, err := mw.CreateFormFile(f.FieldName, f.Filename)
|
||||||
|
if err != nil {
|
||||||
|
_ = pw.CloseWithError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := io.Copy(part, f.Reader); err != nil {
|
||||||
|
_ = pw.CloseWithError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
url := b.base + "/bot" + b.token + "/" + method
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, pr)
|
||||||
|
if err != nil {
|
||||||
|
_ = pr.CloseWithError(err)
|
||||||
|
return nil, &NetworkError{Err: err}
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", mw.FormDataContentType())
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := b.http.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
_ = pr.CloseWithError(err)
|
||||||
|
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||||
|
return nil, ctxErr
|
||||||
|
}
|
||||||
|
return nil, &NetworkError{Err: err}
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
_ = pr.CloseWithError(err)
|
||||||
|
return nil, &NetworkError{Err: err}
|
||||||
|
}
|
||||||
|
return decodeResultRaw(b.codec, raw)
|
||||||
|
}
|
||||||
@@ -0,0 +1,103 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"mime"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeMultipartReq struct {
|
||||||
|
chatID int64
|
||||||
|
body string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeMultipartReq) HasFile() bool { return true }
|
||||||
|
func (f *fakeMultipartReq) MultipartFields() map[string]string {
|
||||||
|
return map[string]string{"chat_id": "42"}
|
||||||
|
}
|
||||||
|
func (f *fakeMultipartReq) MultipartFiles() []MultipartFile {
|
||||||
|
return []MultipartFile{{
|
||||||
|
FieldName: "document",
|
||||||
|
Filename: "hello.txt",
|
||||||
|
Reader: strings.NewReader(f.body),
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileResp struct {
|
||||||
|
MessageID int64 `json:"message_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallMultipart_Success(t *testing.T) {
|
||||||
|
m := &mockDoer{}
|
||||||
|
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
|
||||||
|
ct := r.Header.Get("Content-Type")
|
||||||
|
if !strings.HasPrefix(ct, "multipart/form-data") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, params, err := mime.ParseMediaType(ct)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
mr := multipart.NewReader(r.Body, params["boundary"])
|
||||||
|
seenChat := false
|
||||||
|
seenFile := false
|
||||||
|
for {
|
||||||
|
p, err := mr.NextPart()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch p.FormName() {
|
||||||
|
case "chat_id":
|
||||||
|
body, _ := io.ReadAll(p)
|
||||||
|
seenChat = string(body) == "42"
|
||||||
|
case "document":
|
||||||
|
body, _ := io.ReadAll(p)
|
||||||
|
seenFile = string(body) == "hello world"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return seenChat && seenFile
|
||||||
|
})).Return(newResp(200, `{"ok":true,"result":{"message_id":99}}`), nil)
|
||||||
|
|
||||||
|
b := New("t", WithHTTPClient(m))
|
||||||
|
out, err := Call[*fakeMultipartReq, *fileResp](context.Background(), b, "sendDocument", &fakeMultipartReq{chatID: 42, body: "hello world"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(99), out.MessageID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallMultipart_NoGoroutineLeakOnError(t *testing.T) {
|
||||||
|
m := &mockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(nil, errors.New("dial timeout"))
|
||||||
|
|
||||||
|
b := New("t", WithHTTPClient(m))
|
||||||
|
before := runtime.NumGoroutine()
|
||||||
|
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
_, _ = Call[*fakeMultipartReq, *fileResp](
|
||||||
|
context.Background(), b, "sendDocument",
|
||||||
|
&fakeMultipartReq{chatID: 42, body: strings.Repeat("x", 1<<14)},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow goroutines to finish exiting after Close propagates.
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
runtime.GC()
|
||||||
|
after := runtime.NumGoroutine()
|
||||||
|
|
||||||
|
// A small drift is normal (timers, finalizers); 5 is generous.
|
||||||
|
if after-before > 5 {
|
||||||
|
t.Fatalf("goroutine leak: before=%d after=%d", before, after)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
// Option configures a Bot at construction time. Per-call configuration is
|
||||||
|
// expressed via typed parameter structs (e.g. SendMessageParams), not options.
|
||||||
|
type Option func(*Bot)
|
||||||
|
|
||||||
|
// WithHTTPClient overrides the HTTP transport. Pass any HTTPDoer
|
||||||
|
// implementation (e.g. an *http.Client wrapping a custom RoundTripper, or
|
||||||
|
// a fasthttp adapter).
|
||||||
|
func WithHTTPClient(c HTTPDoer) Option { return func(b *Bot) { b.http = c } }
|
||||||
|
|
||||||
|
// WithCodec overrides the JSON codec. Pass goccy/go-json, sonic, or any
|
||||||
|
// type implementing Codec to swap out encoding/json.
|
||||||
|
func WithCodec(c Codec) Option { return func(b *Bot) { b.codec = c } }
|
||||||
|
|
||||||
|
// WithBaseURL overrides the API base URL. Useful for testing against a
|
||||||
|
// local httptest.Server, or for self-hosted Bot API servers.
|
||||||
|
func WithBaseURL(url string) Option { return func(b *Bot) { b.base = url } }
|
||||||
|
|
||||||
|
// WithLogger sets the logger used for diagnostic events. Passing nil
|
||||||
|
// silently disables logging.
|
||||||
|
func WithLogger(l Logger) Option {
|
||||||
|
return func(b *Bot) {
|
||||||
|
if l == nil {
|
||||||
|
l = NoopLogger{}
|
||||||
|
}
|
||||||
|
b.logger = l
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import "regexp"
|
||||||
|
|
||||||
|
// tokenInURL matches a Telegram bot token segment in a URL path. Tokens
|
||||||
|
// have the form <bot_id>:<api_key>, where bot_id is digits and api_key
|
||||||
|
// is 35 base64-url characters. The pattern is conservative: matches
|
||||||
|
// /bot<id>:<key>/ to avoid false positives.
|
||||||
|
var tokenInURL = regexp.MustCompile(`/bot(\d{5,15}):([A-Za-z0-9_-]{30,40})/`)
|
||||||
|
|
||||||
|
// redactToken replaces any bot token in s with /bot<REDACTED>/. Used by
|
||||||
|
// error formatters so logs don't leak credentials.
|
||||||
|
func redactToken(s string) string {
|
||||||
|
return tokenInURL.ReplaceAllString(s, "/bot<REDACTED>/")
|
||||||
|
}
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRedactToken(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"plain bot URL", "https://api.telegram.org/bot123456789:ABCdefGHIjklMNOpqrSTUvwxYZ0123456789/getMe",
|
||||||
|
"https://api.telegram.org/bot<REDACTED>/getMe"},
|
||||||
|
{"in net/http error", `Post "https://api.telegram.org/bot987654321:Z9YxWvUtSrQpOnMlKjIhGfEdCbA9876543210/sendMessage": dial tcp: lookup api.telegram.org: no such host`,
|
||||||
|
`Post "https://api.telegram.org/bot<REDACTED>/sendMessage": dial tcp: lookup api.telegram.org: no such host`},
|
||||||
|
{"no token", "regular error message", "regular error message"},
|
||||||
|
{"underscore + dash in token", "/bot123456789:abc-def_ghi-jkl_mno-pqr_stu-vwx_yz/sendDocument",
|
||||||
|
"/bot<REDACTED>/sendDocument"},
|
||||||
|
{"too short id (no match)", "/bot123:abc/getMe", "/bot123:abc/getMe"},
|
||||||
|
{"too short key (no match)", "/bot123456789:short/getMe", "/bot123456789:short/getMe"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
require.Equal(t, c.want, redactToken(c.in))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkError_RedactsToken(t *testing.T) {
|
||||||
|
inner := errors.New(`Post "https://api.telegram.org/bot1234567890:ABCdefGHIjklMNOpqrSTUvwxYZ0123456789/getMe": dial tcp: timeout`)
|
||||||
|
e := &NetworkError{Err: inner}
|
||||||
|
require.NotContains(t, e.Error(), "ABCdefGHIjklMNOpqrSTUvwxYZ")
|
||||||
|
require.Contains(t, e.Error(), "<REDACTED>")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseError_RedactsToken(t *testing.T) {
|
||||||
|
inner := fmt.Errorf(`unexpected response from /bot1234567890:ABCdefGHIjklMNOpqrSTUvwxYZ0123456789/getMe`)
|
||||||
|
e := &ParseError{Err: inner, Body: []byte("garbage")}
|
||||||
|
require.NotContains(t, e.Error(), "ABCdefGHI")
|
||||||
|
require.Contains(t, e.Error(), "<REDACTED>")
|
||||||
|
}
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
// Result is the universal Telegram API response envelope. Every successful
|
||||||
|
// response is shaped {"ok":true,"result":T,...}; failure responses set ok
|
||||||
|
// to false and populate ErrorCode / Description / Parameters.
|
||||||
|
//
|
||||||
|
// Result is generic over T so generated method wrappers can decode the
|
||||||
|
// strongly-typed payload directly. Users do not normally construct or
|
||||||
|
// inspect Result values; method wrappers unwrap them and return either
|
||||||
|
// the typed payload or a *APIError.
|
||||||
|
type Result[T any] struct {
|
||||||
|
OK bool `json:"ok"`
|
||||||
|
Result T `json:"result,omitempty"`
|
||||||
|
ErrorCode int `json:"error_code,omitempty"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Parameters *ResponseParameters `json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseParameters is the optional metadata Telegram includes on certain
|
||||||
|
// failures. The most common is RetryAfter (seconds) on 429 responses.
|
||||||
|
//
|
||||||
|
// This type is duplicated in package api for users; keeping a copy here
|
||||||
|
// avoids an import cycle (api imports client, not vice versa).
|
||||||
|
type ResponseParameters struct {
|
||||||
|
MigrateToChatID int64 `json:"migrate_to_chat_id,omitempty"`
|
||||||
|
RetryAfter int `json:"retry_after,omitempty"`
|
||||||
|
}
|
||||||
+225
@@ -0,0 +1,225 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
crand "crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"github.com/goccy/go-json"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RetryDoer is an HTTPDoer that retries transient failures (429, 5xx,
|
||||||
|
// and network errors) with exponential backoff. It honours the
|
||||||
|
// retry_after value Telegram supplies on rate-limit responses.
|
||||||
|
//
|
||||||
|
// Wrap any HTTPDoer to add retry behaviour:
|
||||||
|
//
|
||||||
|
// bot := client.New(token, client.WithHTTPClient(
|
||||||
|
// client.NewRetryDoer(client.NewDefaultHTTPDoer())))
|
||||||
|
type RetryDoer struct {
|
||||||
|
inner HTTPDoer
|
||||||
|
maxAttempts int
|
||||||
|
base time.Duration
|
||||||
|
max time.Duration
|
||||||
|
factor float64
|
||||||
|
jitter float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryOption configures a RetryDoer.
|
||||||
|
type RetryOption func(*RetryDoer)
|
||||||
|
|
||||||
|
// WithMaxAttempts sets the maximum number of attempts (including the
|
||||||
|
// initial one). Default 4 (one initial + three retries).
|
||||||
|
func WithMaxAttempts(n int) RetryOption {
|
||||||
|
return func(d *RetryDoer) { d.maxAttempts = n }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBaseBackoff sets the initial backoff duration. Default 500ms.
|
||||||
|
func WithBaseBackoff(d time.Duration) RetryOption {
|
||||||
|
return func(r *RetryDoer) { r.base = d }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMaxBackoff caps the backoff at max. Default 30s.
|
||||||
|
func WithMaxBackoff(d time.Duration) RetryOption {
|
||||||
|
return func(r *RetryDoer) { r.max = d }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBackoffFactor sets the exponential growth factor. Default 2.0.
|
||||||
|
func WithBackoffFactor(f float64) RetryOption {
|
||||||
|
return func(r *RetryDoer) { r.factor = f }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithJitter sets the jitter fraction (0..1) applied to each backoff.
|
||||||
|
// Default 0.2.
|
||||||
|
func WithJitter(j float64) RetryOption {
|
||||||
|
return func(r *RetryDoer) { r.jitter = j }
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRetryDoer wraps inner with retry behaviour.
|
||||||
|
func NewRetryDoer(inner HTTPDoer, opts ...RetryOption) *RetryDoer {
|
||||||
|
d := &RetryDoer{
|
||||||
|
inner: inner,
|
||||||
|
maxAttempts: 4,
|
||||||
|
base: 500 * time.Millisecond,
|
||||||
|
max: 30 * time.Second,
|
||||||
|
factor: 2.0,
|
||||||
|
jitter: 0.2,
|
||||||
|
}
|
||||||
|
for _, o := range opts {
|
||||||
|
o(d)
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do dispatches via the inner HTTPDoer and retries on transient failures.
|
||||||
|
// The request body is buffered on first attempt so it can be replayed.
|
||||||
|
func (d *RetryDoer) Do(req *http.Request) (*http.Response, error) {
|
||||||
|
// Buffer the body so we can replay it across attempts.
|
||||||
|
var body []byte
|
||||||
|
if req.Body != nil {
|
||||||
|
b, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &NetworkError{Err: err}
|
||||||
|
}
|
||||||
|
_ = req.Body.Close()
|
||||||
|
body = b
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastResp *http.Response
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for attempt := 1; attempt <= d.maxAttempts; attempt++ {
|
||||||
|
if body != nil {
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
}
|
||||||
|
resp, err := d.inner.Do(req)
|
||||||
|
|
||||||
|
// Network errors: maybe retry.
|
||||||
|
if err != nil {
|
||||||
|
// Honour ctx cancellation.
|
||||||
|
if ctxErr := req.Context().Err(); ctxErr != nil {
|
||||||
|
return nil, ctxErr
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
if attempt < d.maxAttempts {
|
||||||
|
if !d.sleep(req.Context(), d.delay(attempt, 0)) {
|
||||||
|
return nil, req.Context().Err()
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTP 200: Telegram almost always returns 200 even for errors.
|
||||||
|
// Peek the body to detect retryable Telegram error payloads.
|
||||||
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
data, readErr := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, &NetworkError{Err: readErr}
|
||||||
|
}
|
||||||
|
// Re-attach the buffered body for the caller.
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(data))
|
||||||
|
|
||||||
|
if isRetryablePayload(data) && attempt < d.maxAttempts {
|
||||||
|
lastResp = resp
|
||||||
|
wait := retryAfterFromPayload(data)
|
||||||
|
if !d.sleep(req.Context(), d.delay(attempt, wait)) {
|
||||||
|
return nil, req.Context().Err()
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-200 status (rare with Telegram; usually 200 + ok:false).
|
||||||
|
// Treat 5xx and 429 as retryable.
|
||||||
|
if (resp.StatusCode == http.StatusTooManyRequests ||
|
||||||
|
resp.StatusCode >= http.StatusInternalServerError) && attempt < d.maxAttempts {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
lastResp = resp
|
||||||
|
if !d.sleep(req.Context(), d.delay(attempt, 0)) {
|
||||||
|
return nil, req.Context().Err()
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastErr != nil {
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
|
return lastResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// delay computes the wait duration for the given attempt (1-based).
|
||||||
|
// override, when non-zero, takes precedence (used to honour Telegram's
|
||||||
|
// retry_after value).
|
||||||
|
func (d *RetryDoer) delay(attempt int, override time.Duration) time.Duration {
|
||||||
|
if override > 0 {
|
||||||
|
return override
|
||||||
|
}
|
||||||
|
delay := float64(d.base) * math.Pow(d.factor, float64(attempt-1))
|
||||||
|
if d.jitter > 0 {
|
||||||
|
var b [8]byte
|
||||||
|
_, _ = crand.Read(b[:])
|
||||||
|
f := float64(binary.LittleEndian.Uint64(b[:])) / (1 << 64)
|
||||||
|
delay *= 1 + (f*2-1)*d.jitter
|
||||||
|
}
|
||||||
|
if delay > float64(d.max) {
|
||||||
|
delay = float64(d.max)
|
||||||
|
}
|
||||||
|
if delay < 0 {
|
||||||
|
delay = 0
|
||||||
|
}
|
||||||
|
return time.Duration(delay)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sleep waits for dur or ctx cancellation. Returns false if cancelled.
|
||||||
|
func (d *RetryDoer) sleep(ctx context.Context, dur time.Duration) bool {
|
||||||
|
if dur <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
t := time.NewTimer(dur)
|
||||||
|
defer t.Stop()
|
||||||
|
select {
|
||||||
|
case <-t.C:
|
||||||
|
return true
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isRetryablePayload reports whether body is a Telegram error response
|
||||||
|
// indicating a retryable failure (429 or 5xx error_code).
|
||||||
|
func isRetryablePayload(body []byte) bool {
|
||||||
|
var env struct {
|
||||||
|
OK bool `json:"ok"`
|
||||||
|
ErrorCode int `json:"error_code"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &env); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if env.OK {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return env.ErrorCode == 429 || (env.ErrorCode >= 500 && env.ErrorCode < 600)
|
||||||
|
}
|
||||||
|
|
||||||
|
// retryAfterFromPayload extracts the retry_after value from a Telegram
|
||||||
|
// error response body and returns it as a duration. Returns 0 if absent.
|
||||||
|
func retryAfterFromPayload(body []byte) time.Duration {
|
||||||
|
var env struct {
|
||||||
|
Parameters struct {
|
||||||
|
RetryAfter int `json:"retry_after"`
|
||||||
|
} `json:"parameters"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &env); err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return time.Duration(env.Parameters.RetryAfter) * time.Second
|
||||||
|
}
|
||||||
@@ -0,0 +1,144 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type retryMockDoer struct{ mock.Mock }
|
||||||
|
|
||||||
|
func (m *retryMockDoer) Do(r *http.Request) (*http.Response, error) {
|
||||||
|
args := m.Called(r)
|
||||||
|
if v := args.Get(0); v != nil {
|
||||||
|
return v.(*http.Response), args.Error(1)
|
||||||
|
}
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func okResp(body string) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: io.NopCloser(bytes.NewBufferString(body)),
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryDoer_HappyPath(t *testing.T) {
|
||||||
|
m := &retryMockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(okResp(`{"ok":true,"result":"hi"}`), nil).Once()
|
||||||
|
|
||||||
|
d := NewRetryDoer(m)
|
||||||
|
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
|
||||||
|
resp, err := d.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 200, resp.StatusCode)
|
||||||
|
m.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryDoer_RetriesOnNetworkError(t *testing.T) {
|
||||||
|
m := &retryMockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(nil, errors.New("dial timeout")).Once()
|
||||||
|
m.On("Do", mock.Anything).Return(okResp(`{"ok":true,"result":"hi"}`), nil).Once()
|
||||||
|
|
||||||
|
d := NewRetryDoer(m, WithBaseBackoff(time.Millisecond))
|
||||||
|
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
|
||||||
|
resp, err := d.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 200, resp.StatusCode)
|
||||||
|
m.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryDoer_HonoursRetryAfter(t *testing.T) {
|
||||||
|
m := &retryMockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(
|
||||||
|
okResp(`{"ok":false,"error_code":429,"description":"Too Many","parameters":{"retry_after":1}}`), nil).Once()
|
||||||
|
m.On("Do", mock.Anything).Return(okResp(`{"ok":true,"result":1}`), nil).Once()
|
||||||
|
|
||||||
|
// base is 10s — retry_after=1s should override it (much shorter wait).
|
||||||
|
d := NewRetryDoer(m, WithBaseBackoff(10*time.Second))
|
||||||
|
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
|
||||||
|
start := time.Now()
|
||||||
|
resp, err := d.Do(req)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 200, resp.StatusCode)
|
||||||
|
require.GreaterOrEqual(t, elapsed, 900*time.Millisecond, "should honour retry_after=1s")
|
||||||
|
require.Less(t, elapsed, 3*time.Second, "should NOT use base backoff (10s)")
|
||||||
|
m.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryDoer_Retries5xx(t *testing.T) {
|
||||||
|
m := &retryMockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(
|
||||||
|
okResp(`{"ok":false,"error_code":500,"description":"Internal Server Error"}`), nil).Once()
|
||||||
|
m.On("Do", mock.Anything).Return(okResp(`{"ok":true,"result":1}`), nil).Once()
|
||||||
|
|
||||||
|
d := NewRetryDoer(m, WithBaseBackoff(time.Millisecond))
|
||||||
|
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
|
||||||
|
resp, err := d.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 200, resp.StatusCode)
|
||||||
|
m.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryDoer_AllAttemptsFail(t *testing.T) {
|
||||||
|
m := &retryMockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(nil, errors.New("dial timeout"))
|
||||||
|
|
||||||
|
d := NewRetryDoer(m, WithMaxAttempts(3), WithBaseBackoff(time.Millisecond))
|
||||||
|
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
|
||||||
|
_, err := d.Do(req)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "dial timeout")
|
||||||
|
require.Equal(t, 3, len(m.Calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryDoer_ContextCancellationAborts(t *testing.T) {
|
||||||
|
m := &retryMockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(
|
||||||
|
okResp(`{"ok":false,"error_code":500,"description":"server error"}`), nil).Maybe()
|
||||||
|
|
||||||
|
d := NewRetryDoer(m, WithBaseBackoff(100*time.Millisecond))
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, "POST", "http://x", strings.NewReader(`{}`))
|
||||||
|
_, err := d.Do(req)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, errors.Is(err, context.DeadlineExceeded))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryDoer_ReplaysBody(t *testing.T) {
|
||||||
|
m := &retryMockDoer{}
|
||||||
|
var seen []string
|
||||||
|
|
||||||
|
// First call: capture body, return 500 to trigger retry.
|
||||||
|
m.On("Do", mock.Anything).Return(okResp(`{"ok":false,"error_code":500}`), nil).Once().Run(func(args mock.Arguments) {
|
||||||
|
r := args.Get(0).(*http.Request)
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
seen = append(seen, string(body))
|
||||||
|
})
|
||||||
|
// Second call: capture body, return success.
|
||||||
|
m.On("Do", mock.Anything).Return(okResp(`{"ok":true}`), nil).Once().Run(func(args mock.Arguments) {
|
||||||
|
r := args.Get(0).(*http.Request)
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
seen = append(seen, string(body))
|
||||||
|
})
|
||||||
|
|
||||||
|
d := NewRetryDoer(m, WithBaseBackoff(time.Millisecond))
|
||||||
|
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{"chat_id":42}`))
|
||||||
|
_, err := d.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, seen, 2)
|
||||||
|
require.Equal(t, seen[0], seen[1])
|
||||||
|
require.Equal(t, `{"chat_id":42}`, seen[0])
|
||||||
|
m.AssertExpectations(t)
|
||||||
|
}
|
||||||
@@ -0,0 +1,378 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// loadIR
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestLoadIR_ValidFile(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(api)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tmp := filepath.Join(t.TempDir(), "api.json")
|
||||||
|
require.NoError(t, os.WriteFile(tmp, data, 0o600))
|
||||||
|
|
||||||
|
loaded, err := loadIR(tmp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, loaded.Methods, 1)
|
||||||
|
require.Equal(t, "getMe", loaded.Methods[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadIR_MissingFile(t *testing.T) {
|
||||||
|
_, err := loadIR("/nonexistent/path/api.json")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "open IR")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadIR_InvalidJSON(t *testing.T) {
|
||||||
|
tmp := filepath.Join(t.TempDir(), "bad.json")
|
||||||
|
require.NoError(t, os.WriteFile(tmp, []byte("not json"), 0o600))
|
||||||
|
|
||||||
|
_, err := loadIR(tmp)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "decode IR")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// auditBool
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAuditBool_LongDocTruncated(t *testing.T) {
|
||||||
|
longDoc := make([]byte, 200)
|
||||||
|
for i := range longDoc {
|
||||||
|
longDoc[i] = 'a'
|
||||||
|
}
|
||||||
|
api := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "myMethod", Doc: string(longDoc), Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
problems := auditBool(api, &spec.Overrides{})
|
||||||
|
require.Len(t, problems, 1)
|
||||||
|
require.Contains(t, problems[0], "…")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditBool_TrueIsReturnedVariant(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "doThing", Doc: "true is returned on success.", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Empty(t, auditBool(api, &spec.Overrides{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditBool_ReturnsBoolean(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "doThing", Doc: "Returns Boolean on success.", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Empty(t, auditBool(api, &spec.Overrides{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// formatTypeRef
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestFormatTypeRef_AllBranches(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
tr spec.TypeRef
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}, "bool"},
|
||||||
|
{spec.TypeRef{Kind: spec.KindNamed, Name: "User"}, "User"},
|
||||||
|
{spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}}, "[]Update"},
|
||||||
|
{spec.TypeRef{Kind: spec.KindArray}, "[]any"},
|
||||||
|
{spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64", "string"}}, "(int64 | string)"},
|
||||||
|
{spec.TypeRef{Kind: spec.Kind(99)}, "?"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
got := formatTypeRef(c.tr)
|
||||||
|
require.Equal(t, c.want, got, "for kind=%v name=%v", c.tr.Kind, c.tr.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// auditDrift
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAuditDrift_InvalidRefReturnsError(t *testing.T) {
|
||||||
|
cur := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, err := auditDrift("internal/spec/api.json", "THIS_REF_DOES_NOT_EXIST", cur)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditDrift_SameRefNoDrift(t *testing.T) {
|
||||||
|
irPath := "../../internal/spec/api.json"
|
||||||
|
cur, err := loadIR(irPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Skip("api.json not available, skipping drift test")
|
||||||
|
}
|
||||||
|
changes, err := auditDrift(irPath, "HEAD", cur)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, changes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditDrift_InvalidJSONFromGit(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("shell script not supported on Windows")
|
||||||
|
}
|
||||||
|
tmp := t.TempDir()
|
||||||
|
fakeGit := filepath.Join(tmp, "git")
|
||||||
|
require.NoError(t, os.WriteFile(fakeGit, []byte("#!/bin/sh\necho 'not valid json'\n"), 0o600))
|
||||||
|
require.NoError(t, os.Chmod(fakeGit, 0o755))
|
||||||
|
|
||||||
|
origPATH := os.Getenv("PATH")
|
||||||
|
t.Cleanup(func() { _ = os.Setenv("PATH", origPATH) })
|
||||||
|
_ = os.Setenv("PATH", tmp+string(os.PathListSeparator)+origPATH)
|
||||||
|
|
||||||
|
_, err := auditDrift("internal/spec/api.json", "HEAD", &spec.API{})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "decode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// auditAny
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAuditAny_FlagsUnknownMethodReturn(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{
|
||||||
|
Name: "weirdMethod",
|
||||||
|
Returns: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"A", "B", "C"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out := auditAny(api)
|
||||||
|
require.Len(t, out, 1)
|
||||||
|
require.Contains(t, out[0], "any return: weirdMethod")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditAny_FlagsUnknownMethodParam(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{
|
||||||
|
Name: "weirdMethod",
|
||||||
|
Params: []spec.Field{
|
||||||
|
{Name: "Thing", JSONName: "thing", Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"X", "Y", "Z"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out := auditAny(api)
|
||||||
|
require.Len(t, out, 1)
|
||||||
|
require.Contains(t, out[0], "any param: weirdMethod.Thing")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// diffSignatures
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestDiffSignatures_UnchangedNoDrift(t *testing.T) {
|
||||||
|
prev := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cur := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Empty(t, diffSignatures(prev, cur))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// typeRefEqual
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestTypeRefEqual_ArrayNilElemDiffers(t *testing.T) {
|
||||||
|
a := spec.TypeRef{Kind: spec.KindArray}
|
||||||
|
b := spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}}
|
||||||
|
require.False(t, typeRefEqual(a, b))
|
||||||
|
require.False(t, typeRefEqual(b, a))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestHelperMain: subprocess helper for main() coverage.
|
||||||
|
//
|
||||||
|
// When AUDIT_HELPER_MAIN=1 is set, this function:
|
||||||
|
// 1. Resets flag.CommandLine so main()'s flag.Parse() gets a clean slate.
|
||||||
|
// 2. Sets os.Args to the args encoded in AUDIT_HELPER_ARGS.
|
||||||
|
// 3. Calls main() which calls os.Exit — the test process terminates with
|
||||||
|
// main's exit code, which the parent test captures.
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHelperMain(t *testing.T) {
|
||||||
|
if os.Getenv("AUDIT_HELPER_MAIN") != "1" {
|
||||||
|
t.Skip("not running as subprocess helper")
|
||||||
|
}
|
||||||
|
// Reset flag.CommandLine so main()'s flag.Parse() gets a clean slate.
|
||||||
|
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
|
||||||
|
|
||||||
|
// Decode args from JSON-encoded env var.
|
||||||
|
encoded := os.Getenv("AUDIT_HELPER_ARGS")
|
||||||
|
if encoded != "" {
|
||||||
|
var args []string
|
||||||
|
if err := json.Unmarshal([]byte(encoded), &args); err == nil {
|
||||||
|
os.Args = append([]string{os.Args[0]}, args...)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
os.Args = os.Args[:1]
|
||||||
|
}
|
||||||
|
|
||||||
|
main()
|
||||||
|
}
|
||||||
|
|
||||||
|
// runMain runs main() via the test binary subprocess so that coverage counters
|
||||||
|
// from main() are included in the profile. Args are JSON-encoded in an env var
|
||||||
|
// to avoid conflicts with the test binary's own flag parsing.
|
||||||
|
func runMain(t *testing.T, extraEnv []string, args ...string) (string, int) {
|
||||||
|
t.Helper()
|
||||||
|
argsJSON, _ := json.Marshal(args)
|
||||||
|
cmd := exec.Command(os.Args[0], "-test.run=TestHelperMain", "-test.v=false")
|
||||||
|
cmd.Env = append(os.Environ(), "AUDIT_HELPER_MAIN=1", "AUDIT_HELPER_ARGS="+string(argsJSON))
|
||||||
|
cmd.Env = append(cmd.Env, extraEnv...)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
code := 0
|
||||||
|
if err != nil {
|
||||||
|
if ee, ok := err.(*exec.ExitError); ok {
|
||||||
|
code = ee.ExitCode()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(out), code
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// main() integration tests — exercise main() code paths via subprocess
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestMain_CleanExitsZero(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
ir := writeIR(t, tmp, &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "getMe", Doc: "Returns True on success.", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ov := writeOverrides(t, tmp)
|
||||||
|
|
||||||
|
out, code := runMain(t, nil, "-ir", ir, "-overrides", ov)
|
||||||
|
require.Equal(t, exitClean, code, "expected exit 0 (clean)\nout: %s", out)
|
||||||
|
require.Contains(t, out, "clean")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMain_FallbackExitsOne(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
ir := writeIR(t, tmp, &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "doSomething", Doc: "Does something.", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ov := writeOverrides(t, tmp)
|
||||||
|
|
||||||
|
out, code := runMain(t, nil, "-ir", ir, "-overrides", ov)
|
||||||
|
require.Equal(t, exitFallback, code, "expected exit 1 (fallback)\nout: %s", out)
|
||||||
|
require.Contains(t, out, "bool fallback")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMain_InvalidIRExitsThree(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
bad := filepath.Join(tmp, "bad.json")
|
||||||
|
require.NoError(t, os.WriteFile(bad, []byte("not json"), 0o600))
|
||||||
|
ov := writeOverrides(t, tmp)
|
||||||
|
|
||||||
|
out, code := runMain(t, nil, "-ir", bad, "-overrides", ov)
|
||||||
|
require.Equal(t, exitInvalid, code, "expected exit 3 (invalid IR)\nout: %s", out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMain_InvalidOverridesExitsThree(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
ir := writeIR(t, tmp, &spec.API{})
|
||||||
|
bad := filepath.Join(tmp, "bad_ov.json")
|
||||||
|
require.NoError(t, os.WriteFile(bad, []byte("not json"), 0o600))
|
||||||
|
|
||||||
|
out, code := runMain(t, nil, "-ir", ir, "-overrides", bad)
|
||||||
|
require.Equal(t, exitInvalid, code, "expected exit 3 (invalid overrides)\nout: %s", out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMain_DriftDetectedExitsTwo(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("shell script not supported on Windows")
|
||||||
|
}
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
prevAPI := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
curAPI := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
curIR := writeIR(t, tmp, curAPI)
|
||||||
|
ov := writeOverrides(t, tmp)
|
||||||
|
|
||||||
|
prevData, err := json.Marshal(prevAPI)
|
||||||
|
require.NoError(t, err)
|
||||||
|
prevFile := filepath.Join(tmp, "prev.json")
|
||||||
|
require.NoError(t, os.WriteFile(prevFile, prevData, 0o600))
|
||||||
|
|
||||||
|
fakeGit := filepath.Join(tmp, "git")
|
||||||
|
script := fmt.Sprintf("#!/bin/sh\ncat %s\n", prevFile)
|
||||||
|
require.NoError(t, os.WriteFile(fakeGit, []byte(script), 0o600))
|
||||||
|
require.NoError(t, os.Chmod(fakeGit, 0o755))
|
||||||
|
|
||||||
|
newPATH := tmp + string(os.PathListSeparator) + os.Getenv("PATH")
|
||||||
|
|
||||||
|
out, code := runMain(t,
|
||||||
|
[]string{"PATH=" + newPATH},
|
||||||
|
"-ir", curIR, "-overrides", ov, "-drift", "-against", "HEAD~1",
|
||||||
|
)
|
||||||
|
require.Equal(t, exitDrift, code, "expected exit 2 (drift)\nout: %s", out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func writeIR(t *testing.T, dir string, api *spec.API) string {
|
||||||
|
t.Helper()
|
||||||
|
data, err := json.Marshal(api)
|
||||||
|
require.NoError(t, err)
|
||||||
|
p := filepath.Join(dir, "api.json")
|
||||||
|
require.NoError(t, os.WriteFile(p, data, 0o600))
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOverrides(t *testing.T, dir string) string {
|
||||||
|
t.Helper()
|
||||||
|
p := filepath.Join(dir, "overrides.json")
|
||||||
|
require.NoError(t, os.WriteFile(p, []byte("{}"), 0o600))
|
||||||
|
return p
|
||||||
|
}
|
||||||
@@ -0,0 +1,291 @@
|
|||||||
|
// Command audit reports IR-level codegen fallbacks and signature drift.
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
//
|
||||||
|
// audit -ir <path> (default internal/spec/api.json)
|
||||||
|
// audit -overrides <path> (default internal/spec/overrides.json)
|
||||||
|
// audit -drift (compare against -against ref's IR; off by default)
|
||||||
|
// audit -against <ref> (git ref to diff drift against; default HEAD~1)
|
||||||
|
//
|
||||||
|
// Exit codes:
|
||||||
|
//
|
||||||
|
// 0 — clean
|
||||||
|
// 1 — unaccounted bool fallbacks or any-typed fields
|
||||||
|
// 2 — drift detected (signature changed)
|
||||||
|
// 3 — invalid IR or overrides
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"github.com/goccy/go-json"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
exitClean = 0
|
||||||
|
exitFallback = 1
|
||||||
|
exitDrift = 2
|
||||||
|
exitInvalid = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
irPath := flag.String("ir", "internal/spec/api.json", "path to IR JSON")
|
||||||
|
ovPath := flag.String("overrides", "internal/spec/overrides.json", "path to overrides JSON")
|
||||||
|
checkDrift := flag.Bool("drift", false, "compare against -against ref's IR for signature changes")
|
||||||
|
againstRef := flag.String("against", "HEAD~1", "git ref to diff drift against (e.g. origin/main, HEAD~1)")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
api, err := loadIR(*irPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "audit:", err)
|
||||||
|
os.Exit(exitInvalid)
|
||||||
|
}
|
||||||
|
overrides, err := spec.LoadOverrides(*ovPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "audit:", err)
|
||||||
|
os.Exit(exitInvalid)
|
||||||
|
}
|
||||||
|
|
||||||
|
var problems []string
|
||||||
|
|
||||||
|
problems = append(problems, auditBool(api, overrides)...)
|
||||||
|
problems = append(problems, auditAny(api)...)
|
||||||
|
|
||||||
|
driftFound := false
|
||||||
|
if *checkDrift {
|
||||||
|
if d, err := auditDrift(*irPath, *againstRef, api); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "audit: drift check skipped:", err)
|
||||||
|
} else if len(d) > 0 {
|
||||||
|
fmt.Println("Drift detected (signatures changed since HEAD):")
|
||||||
|
for _, p := range d {
|
||||||
|
fmt.Println(" ", p)
|
||||||
|
}
|
||||||
|
driftFound = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(problems) == 0 && !driftFound {
|
||||||
|
fmt.Println("audit: clean")
|
||||||
|
os.Exit(exitClean)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(problems) > 0 {
|
||||||
|
fmt.Println("Codegen fallbacks requiring action:")
|
||||||
|
for _, p := range problems {
|
||||||
|
fmt.Println(" ", p)
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("To resolve: extend cmd/scrape/method.go regex patterns OR")
|
||||||
|
fmt.Println("add an entry to internal/spec/overrides.json.")
|
||||||
|
os.Exit(exitFallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
// drift only, no fallbacks
|
||||||
|
os.Exit(exitDrift)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadIR(path string) (*spec.API, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open IR: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = f.Close() }()
|
||||||
|
var api spec.API
|
||||||
|
if err := json.NewDecoder(f).Decode(&api); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode IR: %w", err)
|
||||||
|
}
|
||||||
|
return &api, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// auditBool returns problems for methods returning bool whose docs don't
|
||||||
|
// actually say "Returns True" / etc. and which aren't in the approved list.
|
||||||
|
func auditBool(api *spec.API, ov *spec.Overrides) []string {
|
||||||
|
var out []string
|
||||||
|
for _, m := range api.Methods {
|
||||||
|
if m.Returns.Kind != spec.KindPrimitive || m.Returns.Name != "bool" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ov.IsBoolApproved(m.Name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if looksGenuinelyBool(m.Doc) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
snippet := m.Doc
|
||||||
|
if len(snippet) > 120 {
|
||||||
|
snippet = snippet[:120] + "…"
|
||||||
|
}
|
||||||
|
out = append(out, fmt.Sprintf("bool fallback: %s — doc: %q", m.Name, snippet))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksGenuinelyBool(doc string) bool {
|
||||||
|
for _, p := range []string{
|
||||||
|
"Returns True", "Returns true",
|
||||||
|
"True is returned", "true is returned",
|
||||||
|
"Returns Boolean", "Returns Bool",
|
||||||
|
} {
|
||||||
|
if strings.Contains(doc, p) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// auditAny scans the IR for any KindOneOf TypeRef that would render as
|
||||||
|
// `any` in generated code (not matched by ChatID/InputFile-or-string/known
|
||||||
|
// union heuristics). Reports each occurrence with location.
|
||||||
|
func auditAny(api *spec.API) []string {
|
||||||
|
var out []string
|
||||||
|
isKnownUnion := func(variants []string) bool {
|
||||||
|
if hasVariants(variants, "int64", "string") {
|
||||||
|
return true // ChatID
|
||||||
|
}
|
||||||
|
if hasVariants(variants, "InputFile", "string") {
|
||||||
|
return true // *InputFile
|
||||||
|
}
|
||||||
|
// ReplyMarkup union: all four keyboard types — emitter renders as `any` intentionally
|
||||||
|
if hasVariants(variants, "InlineKeyboardMarkup", "ReplyKeyboardMarkup", "ReplyKeyboardRemove", "ForceReply") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, t := range api.Types {
|
||||||
|
if len(t.OneOf) > 0 && sameSet(variants, t.OneOf) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
isAny := func(tr spec.TypeRef) bool {
|
||||||
|
return tr.Kind == spec.KindOneOf && !isKnownUnion(tr.Variants)
|
||||||
|
}
|
||||||
|
for _, t := range api.Types {
|
||||||
|
for _, f := range t.Fields {
|
||||||
|
if isAny(f.Type) {
|
||||||
|
out = append(out, fmt.Sprintf("any field: %s.%s (variants=%v)", t.Name, f.Name, f.Type.Variants))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, m := range api.Methods {
|
||||||
|
if isAny(m.Returns) {
|
||||||
|
out = append(out, fmt.Sprintf("any return: %s (variants=%v)", m.Name, m.Returns.Variants))
|
||||||
|
}
|
||||||
|
for _, p := range m.Params {
|
||||||
|
if isAny(p.Type) {
|
||||||
|
out = append(out, fmt.Sprintf("any param: %s.%s (variants=%v)", m.Name, p.Name, p.Type.Variants))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasVariants(got []string, want ...string) bool {
|
||||||
|
if len(got) != len(want) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
seen := map[string]int{}
|
||||||
|
for _, g := range got {
|
||||||
|
seen[g]++
|
||||||
|
}
|
||||||
|
for _, w := range want {
|
||||||
|
seen[w]--
|
||||||
|
}
|
||||||
|
for _, v := range seen {
|
||||||
|
if v != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func sameSet(a, b []string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return hasVariants(a, b...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// auditDrift compares method/return signatures between the given git ref's
|
||||||
|
// version of irPath and the in-memory current IR. Returns a list of
|
||||||
|
// human-readable change descriptions.
|
||||||
|
func auditDrift(irPath, againstRef string, current *spec.API) ([]string, error) {
|
||||||
|
cmd := exec.Command("git", "show", againstRef+":"+irPath) // #nosec G204 - operator tool, ref controlled by caller
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("git show %s: %w", againstRef, err)
|
||||||
|
}
|
||||||
|
var prev spec.API
|
||||||
|
if err := json.Unmarshal(out, &prev); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode %s IR: %w", againstRef, err)
|
||||||
|
}
|
||||||
|
return diffSignatures(&prev, current), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func diffSignatures(prev, cur *spec.API) []string {
|
||||||
|
var changes []string
|
||||||
|
|
||||||
|
pmeth := indexByName(prev.Methods, func(m spec.MethodDecl) string { return m.Name })
|
||||||
|
cmeth := indexByName(cur.Methods, func(m spec.MethodDecl) string { return m.Name })
|
||||||
|
|
||||||
|
for name, p := range pmeth {
|
||||||
|
c, ok := cmeth[name]
|
||||||
|
if !ok {
|
||||||
|
changes = append(changes, fmt.Sprintf("removed method: %s", name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !typeRefEqual(p.Returns, c.Returns) {
|
||||||
|
changes = append(changes, fmt.Sprintf(
|
||||||
|
"method %s return changed: %s → %s",
|
||||||
|
name, formatTypeRef(p.Returns), formatTypeRef(c.Returns)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for name := range cmeth {
|
||||||
|
if _, ok := pmeth[name]; !ok {
|
||||||
|
changes = append(changes, fmt.Sprintf("added method: %s", name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return changes
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexByName[T any](xs []T, f func(T) string) map[string]T {
|
||||||
|
out := map[string]T{}
|
||||||
|
for _, x := range xs {
|
||||||
|
out[f(x)] = x
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func typeRefEqual(a, b spec.TypeRef) bool {
|
||||||
|
if a.Kind != b.Kind || a.Name != b.Name {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if (a.ElemType == nil) != (b.ElemType == nil) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if a.ElemType != nil && !typeRefEqual(*a.ElemType, *b.ElemType) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return sameSet(a.Variants, b.Variants)
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatTypeRef(t spec.TypeRef) string {
|
||||||
|
switch t.Kind {
|
||||||
|
case spec.KindPrimitive:
|
||||||
|
return t.Name
|
||||||
|
case spec.KindNamed:
|
||||||
|
return t.Name
|
||||||
|
case spec.KindArray:
|
||||||
|
if t.ElemType != nil {
|
||||||
|
return "[]" + formatTypeRef(*t.ElemType)
|
||||||
|
}
|
||||||
|
return "[]any"
|
||||||
|
case spec.KindOneOf:
|
||||||
|
return "(" + strings.Join(t.Variants, " | ") + ")"
|
||||||
|
}
|
||||||
|
return "?"
|
||||||
|
}
|
||||||
@@ -0,0 +1,216 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---- auditBool -----------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAuditBool_FlagsUnapprovedBoolMethod(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "getMe", Doc: "A simple method.", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ov := &spec.Overrides{}
|
||||||
|
problems := auditBool(api, ov)
|
||||||
|
require.Len(t, problems, 1)
|
||||||
|
require.Contains(t, problems[0], "bool fallback: getMe")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditBool_SkipsApprovedBoolMethod(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "setWebhook", Doc: "Use this to set webhook.", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ov := &spec.Overrides{ApprovedBoolMethods: []string{"setWebhook"}}
|
||||||
|
problems := auditBool(api, ov)
|
||||||
|
require.Empty(t, problems)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditBool_SkipsMethodWithReturnsTrueDoc(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "doThing", Doc: "Returns True on success.", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ov := &spec.Overrides{}
|
||||||
|
problems := auditBool(api, ov)
|
||||||
|
require.Empty(t, problems)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditBool_SkipsNonBoolMethods(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "getMe", Doc: "Gets user.", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ov := &spec.Overrides{}
|
||||||
|
require.Empty(t, auditBool(api, ov))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- auditAny ------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAuditAny_FlagsUnrecognisedOneOf(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Types: []spec.TypeDecl{
|
||||||
|
{
|
||||||
|
Name: "Foo",
|
||||||
|
Fields: []spec.Field{
|
||||||
|
{Name: "Bar", JSONName: "bar", Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"A", "B", "C"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out := auditAny(api)
|
||||||
|
require.Len(t, out, 1)
|
||||||
|
require.Contains(t, out[0], "any field: Foo.Bar")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditAny_SkipsChatIDShape(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Types: []spec.TypeDecl{
|
||||||
|
{
|
||||||
|
Name: "SendMessage",
|
||||||
|
Fields: []spec.Field{
|
||||||
|
{Name: "ChatID", JSONName: "chat_id", Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64", "string"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Empty(t, auditAny(api))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditAny_SkipsKnownUnion(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Types: []spec.TypeDecl{
|
||||||
|
{Name: "InputMedia", OneOf: []string{"InputMediaPhoto", "InputMediaVideo"}},
|
||||||
|
{
|
||||||
|
Name: "SomeMethod",
|
||||||
|
Fields: []spec.Field{
|
||||||
|
{Name: "Media", JSONName: "media", Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputMediaPhoto", "InputMediaVideo"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Empty(t, auditAny(api))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditAny_SkipsReplyMarkupShape(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{
|
||||||
|
Name: "sendMessage",
|
||||||
|
Params: []spec.Field{
|
||||||
|
{Name: "ReplyMarkup", JSONName: "reply_markup", Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InlineKeyboardMarkup", "ReplyKeyboardMarkup", "ReplyKeyboardRemove", "ForceReply"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Empty(t, auditAny(api))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditAny_SkipsInputFileShape(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{
|
||||||
|
Name: "sendPhoto",
|
||||||
|
Params: []spec.Field{
|
||||||
|
{Name: "Photo", JSONName: "photo", Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputFile", "string"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Empty(t, auditAny(api))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- diffSignatures ------------------------------------------------------
|
||||||
|
|
||||||
|
func TestDiffSignatures_AddedMethod(t *testing.T) {
|
||||||
|
prev := &spec.API{}
|
||||||
|
cur := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "newMethod", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
changes := diffSignatures(prev, cur)
|
||||||
|
require.Len(t, changes, 1)
|
||||||
|
require.Contains(t, changes[0], "added method: newMethod")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiffSignatures_RemovedMethod(t *testing.T) {
|
||||||
|
prev := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "oldMethod", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cur := &spec.API{}
|
||||||
|
changes := diffSignatures(prev, cur)
|
||||||
|
require.Len(t, changes, 1)
|
||||||
|
require.Contains(t, changes[0], "removed method: oldMethod")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiffSignatures_ChangedReturn(t *testing.T) {
|
||||||
|
prev := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cur := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
changes := diffSignatures(prev, cur)
|
||||||
|
require.Len(t, changes, 1)
|
||||||
|
require.Contains(t, changes[0], "getMe")
|
||||||
|
require.Contains(t, changes[0], "bool")
|
||||||
|
require.Contains(t, changes[0], "User")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiffSignatures_Clean(t *testing.T) {
|
||||||
|
api := &spec.API{
|
||||||
|
Methods: []spec.MethodDecl{
|
||||||
|
{Name: "getMe", Returns: spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Empty(t, diffSignatures(api, api))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- typeRefEqual --------------------------------------------------------
|
||||||
|
|
||||||
|
func TestTypeRefEqual_Primitive(t *testing.T) {
|
||||||
|
a := spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}
|
||||||
|
require.True(t, typeRefEqual(a, a))
|
||||||
|
b := spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}
|
||||||
|
require.False(t, typeRefEqual(a, b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeRefEqual_Array(t *testing.T) {
|
||||||
|
elem := &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}
|
||||||
|
a := spec.TypeRef{Kind: spec.KindArray, ElemType: elem}
|
||||||
|
b := spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}}
|
||||||
|
require.True(t, typeRefEqual(a, b))
|
||||||
|
|
||||||
|
c := spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}}
|
||||||
|
require.False(t, typeRefEqual(a, c))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeRefEqual_OneOf(t *testing.T) {
|
||||||
|
a := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64", "string"}}
|
||||||
|
b := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"string", "int64"}}
|
||||||
|
require.True(t, typeRefEqual(a, b))
|
||||||
|
|
||||||
|
c := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64"}}
|
||||||
|
require.False(t, typeRefEqual(a, c))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeRefEqual_NilVsNonNilElem(t *testing.T) {
|
||||||
|
a := spec.TypeRef{Kind: spec.KindArray}
|
||||||
|
b := spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}}
|
||||||
|
require.False(t, typeRefEqual(a, b))
|
||||||
|
}
|
||||||
@@ -0,0 +1,749 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
_ "embed"
|
||||||
|
"fmt"
|
||||||
|
"github.com/goccy/go-json"
|
||||||
|
"go/format"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed types.tmpl
|
||||||
|
var typesTmpl string
|
||||||
|
|
||||||
|
//go:embed methods.tmpl
|
||||||
|
var methodsTmpl string
|
||||||
|
|
||||||
|
//go:embed enums.tmpl
|
||||||
|
var enumsTmpl string
|
||||||
|
|
||||||
|
//go:embed tests.tmpl
|
||||||
|
var testsTmpl string
|
||||||
|
|
||||||
|
// runtimeTypes lists types that are intentionally hand-coded and must not be
|
||||||
|
// emitted by the code generator. Skipping them prevents collisions between
|
||||||
|
// generated and hand-coded definitions.
|
||||||
|
var runtimeTypes = map[string]bool{
|
||||||
|
"InputFile": true,
|
||||||
|
"ResponseParameters": true,
|
||||||
|
"ChatID": true,
|
||||||
|
"MessageOrBool": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// discriminatorSpec describes how to decode a sealed-interface union by
|
||||||
|
// peeking at a single JSON field.
|
||||||
|
type discriminatorSpec struct {
|
||||||
|
Field string // JSON field name to peek at
|
||||||
|
Variants map[string]string // discriminator value → concrete Go type name
|
||||||
|
}
|
||||||
|
|
||||||
|
// knownDiscriminators maps parent union name → discriminator spec.
|
||||||
|
// Used by the template helpers hasDiscriminator / discriminatorField /
|
||||||
|
// discriminatorMap to emit UnmarshalXxx helpers.
|
||||||
|
var knownDiscriminators = map[string]discriminatorSpec{
|
||||||
|
"ChatMember": {
|
||||||
|
Field: "status",
|
||||||
|
Variants: map[string]string{
|
||||||
|
"creator": "ChatMemberOwner",
|
||||||
|
"administrator": "ChatMemberAdministrator",
|
||||||
|
"member": "ChatMemberMember",
|
||||||
|
"restricted": "ChatMemberRestricted",
|
||||||
|
"left": "ChatMemberLeft",
|
||||||
|
"kicked": "ChatMemberBanned",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"MessageOrigin": {
|
||||||
|
Field: "type",
|
||||||
|
Variants: map[string]string{
|
||||||
|
"user": "MessageOriginUser",
|
||||||
|
"hidden_user": "MessageOriginHiddenUser",
|
||||||
|
"chat": "MessageOriginChat",
|
||||||
|
"channel": "MessageOriginChannel",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"ReactionType": {
|
||||||
|
Field: "type",
|
||||||
|
Variants: map[string]string{
|
||||||
|
"emoji": "ReactionTypeEmoji",
|
||||||
|
"custom_emoji": "ReactionTypeCustomEmoji",
|
||||||
|
"paid": "ReactionTypePaid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"PaidMedia": {
|
||||||
|
Field: "type",
|
||||||
|
Variants: map[string]string{
|
||||||
|
"preview": "PaidMediaPreview",
|
||||||
|
"photo": "PaidMediaPhoto",
|
||||||
|
"video": "PaidMediaVideo",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"BackgroundType": {
|
||||||
|
Field: "type",
|
||||||
|
Variants: map[string]string{
|
||||||
|
"fill": "BackgroundTypeFill",
|
||||||
|
"wallpaper": "BackgroundTypeWallpaper",
|
||||||
|
"pattern": "BackgroundTypePattern",
|
||||||
|
"chat_theme": "BackgroundTypeChatTheme",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"BackgroundFill": {
|
||||||
|
Field: "type",
|
||||||
|
Variants: map[string]string{
|
||||||
|
"solid": "BackgroundFillSolid",
|
||||||
|
"gradient": "BackgroundFillGradient",
|
||||||
|
"freeform_gradient": "BackgroundFillFreeformGradient",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"ChatBoostSource": {
|
||||||
|
Field: "source",
|
||||||
|
Variants: map[string]string{
|
||||||
|
"premium": "ChatBoostSourcePremium",
|
||||||
|
"gift_code": "ChatBoostSourceGiftCode",
|
||||||
|
"giveaway": "ChatBoostSourceGiveaway",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"RevenueWithdrawalState": {
|
||||||
|
Field: "type",
|
||||||
|
Variants: map[string]string{
|
||||||
|
"pending": "RevenueWithdrawalStatePending",
|
||||||
|
"succeeded": "RevenueWithdrawalStateSucceeded",
|
||||||
|
"failed": "RevenueWithdrawalStateFailed",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"TransactionPartner": {
|
||||||
|
Field: "type",
|
||||||
|
Variants: map[string]string{
|
||||||
|
"fragment": "TransactionPartnerFragment",
|
||||||
|
"user": "TransactionPartnerUser",
|
||||||
|
"telegram_ads": "TransactionPartnerTelegramAds",
|
||||||
|
"telegram_api": "TransactionPartnerTelegramApi",
|
||||||
|
"other": "TransactionPartnerOther",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"MenuButton": {
|
||||||
|
Field: "type",
|
||||||
|
Variants: map[string]string{
|
||||||
|
"commands": "MenuButtonCommands",
|
||||||
|
"web_app": "MenuButtonWebApp",
|
||||||
|
"default": "MenuButtonDefault",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"OwnedGift": {
|
||||||
|
Field: "type",
|
||||||
|
Variants: map[string]string{
|
||||||
|
"regular": "OwnedGiftRegular",
|
||||||
|
"unique": "OwnedGiftUnique",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"StoryAreaType": {
|
||||||
|
Field: "type",
|
||||||
|
Variants: map[string]string{
|
||||||
|
"location": "StoryAreaTypeLocation",
|
||||||
|
"suggested_reaction": "StoryAreaTypeSuggestedReaction",
|
||||||
|
"link": "StoryAreaTypeLink",
|
||||||
|
"weather": "StoryAreaTypeWeather",
|
||||||
|
"unique_gift": "StoryAreaTypeUniqueGift",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// MaybeInaccessibleMessage uses an integer discriminator (date field).
|
||||||
|
// Variants is nil — the standard template block is skipped; a
|
||||||
|
// hand-coded UnmarshalMaybeInaccessibleMessage is emitted instead.
|
||||||
|
"MaybeInaccessibleMessage": {
|
||||||
|
Field: "",
|
||||||
|
Variants: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// emitter renders Go source from a spec.API IR.
|
||||||
|
type emitter struct {
|
||||||
|
api *spec.API
|
||||||
|
outDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEmitter(api *spec.API, outDir string) *emitter {
|
||||||
|
return &emitter{api: api, outDir: outDir}
|
||||||
|
}
|
||||||
|
|
||||||
|
// emitTypes renders types.gen.go.
|
||||||
|
func (e *emitter) emitTypes() error {
|
||||||
|
t, err := template.New("types").Funcs(funcs()).Parse(typesTmpl)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse types.tmpl: %w", err)
|
||||||
|
}
|
||||||
|
filtered := *e.api
|
||||||
|
filtered.Types = nil
|
||||||
|
for _, typ := range e.api.Types {
|
||||||
|
if !runtimeTypes[typ.Name] {
|
||||||
|
filtered.Types = append(filtered.Types, typ)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if execErr := t.Execute(&buf, &filtered); execErr != nil {
|
||||||
|
return fmt.Errorf("execute types.tmpl: %w", execErr)
|
||||||
|
}
|
||||||
|
src, err := format.Source(buf.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
// Surface the unformatted output so debugging is possible.
|
||||||
|
return fmt.Errorf("gofmt types.gen.go: %w\n--- unformatted ---\n%s", err, buf.String())
|
||||||
|
}
|
||||||
|
return os.WriteFile(filepath.Join(e.outDir, "types.gen.go"), src, 0o600)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadAPI reads and decodes the IR JSON.
|
||||||
|
func loadAPI(path string) (*spec.API, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var api spec.API
|
||||||
|
if err := json.Unmarshal(data, &api); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &api, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// funcs is the FuncMap shared across templates.
|
||||||
|
func funcs() template.FuncMap {
|
||||||
|
return template.FuncMap{
|
||||||
|
"goType": goType,
|
||||||
|
"goField": goField,
|
||||||
|
"docComment": docComment,
|
||||||
|
"isOptional": func(f spec.Field) bool { return !f.Required },
|
||||||
|
"not": func(b bool) bool { return !b },
|
||||||
|
"title": title,
|
||||||
|
"isFileField": isFileField,
|
||||||
|
"fileCheck": fileCheck,
|
||||||
|
"multipartFieldEntry": multipartFieldEntry,
|
||||||
|
"multipartFileEntry": multipartFileEntry,
|
||||||
|
"returnGoType": returnGoType,
|
||||||
|
// discriminator helpers for types.tmpl
|
||||||
|
"hasDiscriminator": func(name string) bool { s, ok := knownDiscriminators[name]; return ok && len(s.Variants) > 0 },
|
||||||
|
"isSealedUnionReturn": func(tr spec.TypeRef) bool {
|
||||||
|
if tr.Kind != spec.KindNamed {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := knownDiscriminators[tr.Name]
|
||||||
|
return ok && len(s.Variants) > 0
|
||||||
|
},
|
||||||
|
"isMaybeInaccessibleMessage": func(name string) bool { return name == "MaybeInaccessibleMessage" },
|
||||||
|
"discriminatorField": func(name string) string { return knownDiscriminators[name].Field },
|
||||||
|
"discriminatorMap": func(name string) map[string]string { return knownDiscriminators[name].Variants },
|
||||||
|
// union-field helpers for per-struct UnmarshalJSON emission
|
||||||
|
"unionFields": unionFieldsOf,
|
||||||
|
"isArrayUnion": func(tr spec.TypeRef) bool { return hasUnionElem(tr) },
|
||||||
|
"unionTypeName": func(tr spec.TypeRef) string { name, _ := unionTypeFor(tr); return name },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// title upper-cases the first byte of s (ASCII only — all Telegram method names are ASCII).
|
||||||
|
func title(s string) string {
|
||||||
|
if s == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
r := s[0]
|
||||||
|
if r >= 'a' && r <= 'z' {
|
||||||
|
r = r - 'a' + 'A'
|
||||||
|
}
|
||||||
|
return string(r) + s[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// isFileField reports whether the field carries an InputFile.
|
||||||
|
func isFileField(f spec.Field) bool {
|
||||||
|
return mentionsInputFileTr(f.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mentionsInputFileTr(tr spec.TypeRef) bool {
|
||||||
|
switch tr.Kind {
|
||||||
|
case spec.KindNamed:
|
||||||
|
return tr.Name == "InputFile"
|
||||||
|
case spec.KindArray:
|
||||||
|
if tr.ElemType != nil {
|
||||||
|
return mentionsInputFileTr(*tr.ElemType)
|
||||||
|
}
|
||||||
|
case spec.KindOneOf:
|
||||||
|
for _, v := range tr.Variants {
|
||||||
|
if v == "InputFile" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// fileCheck returns the HasFile guard line for a file-carrying field.
|
||||||
|
// Both named InputFile and InputFile-or-String oneOf fields are now *InputFile,
|
||||||
|
// so no type assertion is needed in either case.
|
||||||
|
func fileCheck(f spec.Field) string {
|
||||||
|
return fmt.Sprintf("\tif p.%s != nil && p.%s.IsLocalUpload() { return true }\n", f.Name, f.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// multipartFileEntry returns the MultipartFiles append block for a file field.
|
||||||
|
// Both named InputFile and InputFile-or-String oneOf fields are now *InputFile,
|
||||||
|
// so the same code works for both cases.
|
||||||
|
func multipartFileEntry(f spec.Field) string {
|
||||||
|
jsonName := f.JSONName
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"\tif p.%s != nil && p.%s.IsLocalUpload() {\n\t\tname := p.%s.Filename\n\t\tif name == \"\" { name = %q }\n\t\tfiles = append(files, client.MultipartFile{FieldName: %q, Filename: name, Reader: p.%s.Reader})\n\t}\n",
|
||||||
|
f.Name, f.Name, f.Name, jsonName, jsonName, f.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// multipartFieldEntry generates the line that adds f to the multipart map.
|
||||||
|
// Required scalar fields go in unconditionally; optional ones go in only
|
||||||
|
// when non-zero/non-empty.
|
||||||
|
func multipartFieldEntry(f spec.Field) string {
|
||||||
|
switch f.Type.Kind {
|
||||||
|
case spec.KindPrimitive:
|
||||||
|
switch f.Type.Name {
|
||||||
|
case "int64":
|
||||||
|
if f.Required {
|
||||||
|
return fmt.Sprintf("\tout[%q] = strconv.FormatInt(p.%s, 10)\n", f.JSONName, f.Name)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("\tif p.%s != nil { out[%q] = strconv.FormatInt(*p.%s, 10) }\n", f.Name, f.JSONName, f.Name)
|
||||||
|
case "string":
|
||||||
|
if f.Required {
|
||||||
|
return fmt.Sprintf("\tout[%q] = p.%s\n", f.JSONName, f.Name)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("\tif p.%s != \"\" { out[%q] = p.%s }\n", f.Name, f.JSONName, f.Name)
|
||||||
|
case "bool":
|
||||||
|
if f.Required {
|
||||||
|
return fmt.Sprintf("\tout[%q] = strconv.FormatBool(p.%s)\n", f.JSONName, f.Name)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("\tif p.%s != nil { out[%q] = strconv.FormatBool(*p.%s) }\n", f.Name, f.JSONName, f.Name)
|
||||||
|
case "float64":
|
||||||
|
if f.Required {
|
||||||
|
return fmt.Sprintf("\tout[%q] = strconv.FormatFloat(p.%s, 'f', -1, 64)\n", f.JSONName, f.Name)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("\tif p.%s != nil { out[%q] = strconv.FormatFloat(*p.%s, 'f', -1, 64) }\n", f.Name, f.JSONName, f.Name)
|
||||||
|
}
|
||||||
|
case spec.KindOneOf:
|
||||||
|
// Integer-or-String → ChatID: use .String() wire form.
|
||||||
|
if matchesVariants(f.Type.Variants, "int64", "string") {
|
||||||
|
if f.Required {
|
||||||
|
return fmt.Sprintf("\tout[%q] = p.%s.String()\n", f.JSONName, f.Name)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("\tif !p.%s.IsZero() { out[%q] = p.%s.String() }\n", f.Name, f.JSONName, f.Name)
|
||||||
|
}
|
||||||
|
// InputFile-or-String → *InputFile: non-upload branch sends PathOrID.
|
||||||
|
if matchesVariants(f.Type.Variants, "InputFile", "string") {
|
||||||
|
return fmt.Sprintf("\tif p.%s != nil && !p.%s.IsLocalUpload() && p.%s.PathOrID != \"\" { out[%q] = p.%s.PathOrID }\n",
|
||||||
|
f.Name, f.Name, f.Name, f.JSONName, f.Name)
|
||||||
|
}
|
||||||
|
// Sealed-interface unions — JSON-marshal.
|
||||||
|
if f.Required {
|
||||||
|
return fmt.Sprintf("\tif b, _ := json.Marshal(p.%s); len(b) > 0 && string(b) != \"null\" { out[%q] = string(b) }\n", f.Name, f.JSONName)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("\tif p.%s != nil { if b, _ := json.Marshal(p.%s); len(b) > 0 && string(b) != \"null\" { out[%q] = string(b) } }\n", f.Name, f.Name, f.JSONName)
|
||||||
|
}
|
||||||
|
// Named or array: fall back to JSON-marshal to JSON string.
|
||||||
|
if f.Required {
|
||||||
|
return fmt.Sprintf("\tif b, _ := json.Marshal(p.%s); len(b) > 0 { out[%q] = string(b) }\n", f.Name, f.JSONName)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("\tif p.%s != nil { if b, _ := json.Marshal(p.%s); len(b) > 0 { out[%q] = string(b) } }\n", f.Name, f.Name, f.JSONName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func returnGoType(tr spec.TypeRef) string {
|
||||||
|
switch tr.Kind {
|
||||||
|
case spec.KindPrimitive:
|
||||||
|
return tr.Name
|
||||||
|
case spec.KindNamed:
|
||||||
|
// Sealed-interface unions are returned by interface value, not pointer
|
||||||
|
// (you can't take a pointer to an interface in any useful way; the
|
||||||
|
// generated UnmarshalXxx returns the interface directly).
|
||||||
|
if _, ok := knownDiscriminators[tr.Name]; ok {
|
||||||
|
return tr.Name
|
||||||
|
}
|
||||||
|
// MessageOrBool is a hand-coded runtime wrapper — pointer return.
|
||||||
|
return "*" + tr.Name
|
||||||
|
case spec.KindArray:
|
||||||
|
if tr.ElemType == nil {
|
||||||
|
return "[]any"
|
||||||
|
}
|
||||||
|
return "[]" + returnGoElem(*tr.ElemType)
|
||||||
|
case spec.KindOneOf:
|
||||||
|
// Integer-or-String return (rare but possible).
|
||||||
|
if matchesVariants(tr.Variants, "int64", "string") {
|
||||||
|
return "ChatID"
|
||||||
|
}
|
||||||
|
return "any"
|
||||||
|
}
|
||||||
|
return "any"
|
||||||
|
}
|
||||||
|
|
||||||
|
func returnGoElem(tr spec.TypeRef) string {
|
||||||
|
switch tr.Kind {
|
||||||
|
case spec.KindPrimitive:
|
||||||
|
return tr.Name
|
||||||
|
case spec.KindNamed:
|
||||||
|
return tr.Name
|
||||||
|
case spec.KindArray:
|
||||||
|
if tr.ElemType == nil {
|
||||||
|
return "any"
|
||||||
|
}
|
||||||
|
return "[]" + returnGoElem(*tr.ElemType)
|
||||||
|
}
|
||||||
|
return "any"
|
||||||
|
}
|
||||||
|
|
||||||
|
// emitMethods renders methods.gen.go.
|
||||||
|
func (e *emitter) emitMethods() error {
|
||||||
|
t, err := template.New("methods").Funcs(funcs()).Parse(methodsTmpl)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse methods.tmpl: %w", err)
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if execErr := t.Execute(&buf, e.api); execErr != nil {
|
||||||
|
return fmt.Errorf("execute methods.tmpl: %w", execErr)
|
||||||
|
}
|
||||||
|
src, err := format.Source(buf.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("gofmt methods.gen.go: %w\n--- unformatted ---\n%s", err, buf.String())
|
||||||
|
}
|
||||||
|
return os.WriteFile(filepath.Join(e.outDir, "methods.gen.go"), src, 0o600)
|
||||||
|
}
|
||||||
|
|
||||||
|
// emitEnums renders enums.gen.go.
|
||||||
|
func (e *emitter) emitEnums() error {
|
||||||
|
t, err := template.New("enums").Funcs(funcs()).Parse(enumsTmpl)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse enums.tmpl: %w", err)
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if execErr := t.Execute(&buf, e.api); execErr != nil {
|
||||||
|
return fmt.Errorf("execute enums.tmpl: %w", execErr)
|
||||||
|
}
|
||||||
|
src, err := format.Source(buf.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("gofmt enums.gen.go: %w\n--- unformatted ---\n%s", err, buf.String())
|
||||||
|
}
|
||||||
|
return os.WriteFile(filepath.Join(e.outDir, "enums.gen.go"), src, 0o600)
|
||||||
|
}
|
||||||
|
|
||||||
|
// goType returns the Go type expression for a TypeRef.
|
||||||
|
// Optional fields use pointer types for primitives and named types,
|
||||||
|
// or rely on omitempty for slices and maps. parameter `optional` controls
|
||||||
|
// whether to wrap pointer-style.
|
||||||
|
func goType(tr spec.TypeRef, optional bool) string {
|
||||||
|
switch tr.Kind {
|
||||||
|
case spec.KindPrimitive:
|
||||||
|
if optional && (tr.Name == "bool" || tr.Name == "int64" || tr.Name == "float64") {
|
||||||
|
return "*" + tr.Name
|
||||||
|
}
|
||||||
|
return tr.Name
|
||||||
|
case spec.KindNamed:
|
||||||
|
// Named types are always pointer-optional when optional, except:
|
||||||
|
// 1. Union (interface) types — they are naturally nil-able; pointer-to-interface is invalid.
|
||||||
|
// 2. InputFile is always pointer-typed even when required: the
|
||||||
|
// multipart helpers (fileCheck, multipartFileEntry) call
|
||||||
|
// f.IsLocalUpload() and dereference Reader, both of which
|
||||||
|
// expect a pointer receiver.
|
||||||
|
if _, isUnion := knownDiscriminators[tr.Name]; isUnion {
|
||||||
|
// Interface type — never add *.
|
||||||
|
return tr.Name
|
||||||
|
}
|
||||||
|
if optional || tr.Name == "InputFile" {
|
||||||
|
return "*" + tr.Name
|
||||||
|
}
|
||||||
|
return tr.Name
|
||||||
|
case spec.KindArray:
|
||||||
|
if tr.ElemType == nil {
|
||||||
|
return "[]any"
|
||||||
|
}
|
||||||
|
// Inside slices, the element shape is its own thing — never wrap
|
||||||
|
// the element in a pointer just because the field is optional.
|
||||||
|
return "[]" + goType(*tr.ElemType, false)
|
||||||
|
case spec.KindOneOf:
|
||||||
|
// Integer-or-String: typed ChatID wrapper.
|
||||||
|
if matchesVariants(tr.Variants, "int64", "string") {
|
||||||
|
if optional {
|
||||||
|
return "*ChatID"
|
||||||
|
}
|
||||||
|
return "ChatID"
|
||||||
|
}
|
||||||
|
// InputFile-or-String: *InputFile runtime helper handles both.
|
||||||
|
if matchesVariants(tr.Variants, "InputFile", "string") {
|
||||||
|
return "*InputFile"
|
||||||
|
}
|
||||||
|
// All-named variants sealed interface: fall back to interface.
|
||||||
|
return "any"
|
||||||
|
}
|
||||||
|
return "any"
|
||||||
|
}
|
||||||
|
|
||||||
|
// unionField pairs a struct field with the name of its union type.
|
||||||
|
type unionField struct {
|
||||||
|
Field spec.Field
|
||||||
|
UnionName string // e.g. "ChatMember"
|
||||||
|
}
|
||||||
|
|
||||||
|
// unionFieldsOf returns the subset of t.Fields whose type is a known
|
||||||
|
// discriminated union (directly or as array element).
|
||||||
|
func unionFieldsOf(t spec.TypeDecl) []unionField {
|
||||||
|
var out []unionField
|
||||||
|
for _, f := range t.Fields {
|
||||||
|
if u, ok := unionTypeFor(f.Type); ok {
|
||||||
|
out = append(out, unionField{Field: f, UnionName: u})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// unionTypeFor inspects a TypeRef and reports whether it (or its array
|
||||||
|
// element) is a known discriminated union. Returns the union name and true.
|
||||||
|
func unionTypeFor(tr spec.TypeRef) (string, bool) {
|
||||||
|
switch tr.Kind {
|
||||||
|
case spec.KindNamed:
|
||||||
|
if _, ok := knownDiscriminators[tr.Name]; ok {
|
||||||
|
return tr.Name, true
|
||||||
|
}
|
||||||
|
case spec.KindArray:
|
||||||
|
if tr.ElemType != nil {
|
||||||
|
return unionTypeFor(*tr.ElemType)
|
||||||
|
}
|
||||||
|
case spec.KindOneOf:
|
||||||
|
if u := unionNameByVariants(tr.Variants); u != "" {
|
||||||
|
return u, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// unionNameByVariants finds the parent union whose variant type names exactly
|
||||||
|
// match the given variant set (order-insensitive).
|
||||||
|
func unionNameByVariants(variants []string) string {
|
||||||
|
for parentName, ds := range knownDiscriminators {
|
||||||
|
wanted := make([]string, 0, len(ds.Variants))
|
||||||
|
for _, vt := range ds.Variants {
|
||||||
|
wanted = append(wanted, vt)
|
||||||
|
}
|
||||||
|
if matchesVariants(variants, wanted...) {
|
||||||
|
return parentName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasUnionElem reports whether tr is an array whose element type is a known union.
|
||||||
|
func hasUnionElem(tr spec.TypeRef) bool {
|
||||||
|
if tr.Kind != spec.KindArray || tr.ElemType == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, ok := unionTypeFor(*tr.ElemType)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchesVariants reports whether got equals want as a set (order-insensitive).
|
||||||
|
func matchesVariants(got []string, want ...string) bool {
|
||||||
|
if len(got) != len(want) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
seen := make(map[string]int, len(got))
|
||||||
|
for _, g := range got {
|
||||||
|
seen[g]++
|
||||||
|
}
|
||||||
|
for _, w := range want {
|
||||||
|
seen[w]--
|
||||||
|
}
|
||||||
|
for _, v := range seen {
|
||||||
|
if v != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// goField returns the Go struct-field declaration for a Field.
|
||||||
|
func goField(f spec.Field) string {
|
||||||
|
tag := fmt.Sprintf("`json:%q`", f.JSONName+omitempty(f))
|
||||||
|
return fmt.Sprintf("%s %s %s", f.Name, goType(f.Type, !f.Required), tag)
|
||||||
|
}
|
||||||
|
|
||||||
|
func omitempty(f spec.Field) string {
|
||||||
|
if f.Required {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return ",omitempty"
|
||||||
|
}
|
||||||
|
|
||||||
|
// docComment converts a doc string into a Go-style block comment with
|
||||||
|
// a leading "// " on each line.
|
||||||
|
func docComment(s string) string {
|
||||||
|
if s == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
for _, line := range splitLines(s) {
|
||||||
|
buf.WriteString("// ")
|
||||||
|
buf.WriteString(line)
|
||||||
|
buf.WriteByte('\n')
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitLines(s string) []string {
|
||||||
|
var out []string
|
||||||
|
start := 0
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
if s[i] == '\n' {
|
||||||
|
out = append(out, s[start:i])
|
||||||
|
start = i + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if start < len(s) {
|
||||||
|
out = append(out, s[start:])
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasVariants reports whether the variant list contains all of the named strings (order-insensitive).
|
||||||
|
func hasVariants(variants []string, names ...string) bool {
|
||||||
|
return matchesVariants(variants, names...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildUnionTypeSet returns the set of all type names that generate interface types
|
||||||
|
// (i.e., types with one_of). This includes knownDiscriminators and marker-interface
|
||||||
|
// unions not covered by the discriminator map.
|
||||||
|
func buildUnionTypeSet(api *spec.API) map[string]bool {
|
||||||
|
s := make(map[string]bool, len(knownDiscriminators)+16)
|
||||||
|
for name := range knownDiscriminators {
|
||||||
|
s[name] = true
|
||||||
|
}
|
||||||
|
for _, t := range api.Types {
|
||||||
|
if len(t.OneOf) > 0 {
|
||||||
|
s[t.Name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeSentinelValue returns a sentinelValue func that uses the given union type set.
|
||||||
|
// It returns a minimal valid Go expression for a spec.Field's type,
|
||||||
|
// used in generated test param literals.
|
||||||
|
func makeSentinelValue(unionTypes map[string]bool) func(spec.Field) string {
|
||||||
|
return func(f spec.Field) string {
|
||||||
|
return sentinelForField(f, unionTypes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sentinelForField(f spec.Field, unionTypes map[string]bool) string {
|
||||||
|
tr := f.Type
|
||||||
|
switch tr.Kind {
|
||||||
|
case spec.KindPrimitive:
|
||||||
|
switch tr.Name {
|
||||||
|
case "int64":
|
||||||
|
return "42"
|
||||||
|
case "string":
|
||||||
|
return `"test_value"`
|
||||||
|
case "bool":
|
||||||
|
return "true"
|
||||||
|
case "float64":
|
||||||
|
return "1.0"
|
||||||
|
}
|
||||||
|
case spec.KindNamed:
|
||||||
|
switch tr.Name {
|
||||||
|
case "ChatID":
|
||||||
|
return "ChatIDFromInt(123)"
|
||||||
|
case "InputFile":
|
||||||
|
return `&InputFile{PathOrID: "file_id_test"}`
|
||||||
|
}
|
||||||
|
// Interface (union) types are nil-able.
|
||||||
|
if unionTypes[tr.Name] {
|
||||||
|
return "nil"
|
||||||
|
}
|
||||||
|
// Required named struct types are value types in the generated struct.
|
||||||
|
if f.Required {
|
||||||
|
return tr.Name + "{}"
|
||||||
|
}
|
||||||
|
return "&" + tr.Name + "{}"
|
||||||
|
case spec.KindArray:
|
||||||
|
return "nil"
|
||||||
|
case spec.KindOneOf:
|
||||||
|
if hasVariants(tr.Variants, "int64", "string") {
|
||||||
|
return "ChatIDFromInt(123)"
|
||||||
|
}
|
||||||
|
if hasVariants(tr.Variants, "InputFile", "string") {
|
||||||
|
return `&InputFile{PathOrID: "file_id_test"}`
|
||||||
|
}
|
||||||
|
// Sealed named-union interface: use nil (any).
|
||||||
|
return "nil"
|
||||||
|
}
|
||||||
|
return "nil"
|
||||||
|
}
|
||||||
|
|
||||||
|
// successResp returns a backtick Go string literal containing a minimal
|
||||||
|
// {"ok":true,"result":...} JSON body for the method's return type.
|
||||||
|
func successResp(m spec.MethodDecl) string {
|
||||||
|
body := successBody(m.Returns)
|
||||||
|
return "`{\"ok\":true,\"result\":" + body + "}`"
|
||||||
|
}
|
||||||
|
|
||||||
|
func successBody(tr spec.TypeRef) string {
|
||||||
|
switch tr.Kind {
|
||||||
|
case spec.KindPrimitive:
|
||||||
|
switch tr.Name {
|
||||||
|
case "bool":
|
||||||
|
return "true"
|
||||||
|
case "int64", "float64":
|
||||||
|
return "0"
|
||||||
|
case "string":
|
||||||
|
return `""`
|
||||||
|
}
|
||||||
|
case spec.KindNamed:
|
||||||
|
if tr.Name == "MessageOrBool" {
|
||||||
|
return "true"
|
||||||
|
}
|
||||||
|
// Sealed-interface unions need a discriminator field so UnmarshalXxx can dispatch.
|
||||||
|
// Pick the lexicographically first variant value for determinism (map
|
||||||
|
// iteration order in Go is randomized — using `range` directly produces
|
||||||
|
// non-deterministic regen output).
|
||||||
|
if disc, ok := knownDiscriminators[tr.Name]; ok && disc.Field != "" {
|
||||||
|
values := make([]string, 0, len(disc.Variants))
|
||||||
|
for v := range disc.Variants {
|
||||||
|
values = append(values, v)
|
||||||
|
}
|
||||||
|
sort.Strings(values)
|
||||||
|
if len(values) > 0 {
|
||||||
|
return fmt.Sprintf(`{"%s":"%s"}`, disc.Field, values[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// MaybeInaccessibleMessage uses date==0 → InaccessibleMessage variant.
|
||||||
|
if tr.Name == "MaybeInaccessibleMessage" {
|
||||||
|
return `{"date":0,"chat":{"id":1,"type":"private"},"message_id":1}`
|
||||||
|
}
|
||||||
|
return "{}"
|
||||||
|
case spec.KindArray:
|
||||||
|
return "[]"
|
||||||
|
case spec.KindOneOf:
|
||||||
|
return "null"
|
||||||
|
}
|
||||||
|
return "null"
|
||||||
|
}
|
||||||
|
|
||||||
|
// emitTests renders methods_gen_test.go.
|
||||||
|
func (e *emitter) emitTests() error {
|
||||||
|
unionTypes := buildUnionTypeSet(e.api)
|
||||||
|
|
||||||
|
// Add test-specific helpers to the shared func map.
|
||||||
|
fm := funcs()
|
||||||
|
fm["sentinelValue"] = makeSentinelValue(unionTypes)
|
||||||
|
fm["successResp"] = successResp
|
||||||
|
|
||||||
|
t, err := template.New("tests").Funcs(fm).Parse(testsTmpl)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse tests.tmpl: %w", err)
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if execErr := t.Execute(&buf, e.api); execErr != nil {
|
||||||
|
return fmt.Errorf("execute tests.tmpl: %w", execErr)
|
||||||
|
}
|
||||||
|
src, err := format.Source(buf.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("gofmt methods_gen_test.go: %w\n--- unformatted ---\n%s", err, buf.String())
|
||||||
|
}
|
||||||
|
return os.WriteFile(filepath.Join(e.outDir, "methods_gen_test.go"), src, 0o600)
|
||||||
|
}
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
var updateGolden = flag.Bool("update", false, "update golden files")
|
||||||
|
|
||||||
|
func TestEmit_Types_FixtureGolden(t *testing.T) {
|
||||||
|
api, err := loadAPI("../../testdata/golden/api_small_fixture.json")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tmp := t.TempDir()
|
||||||
|
e := newEmitter(api, tmp)
|
||||||
|
require.NoError(t, e.emitTypes())
|
||||||
|
|
||||||
|
got, err := os.ReadFile(filepath.Join(tmp, "types.gen.go"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
goldenPath := "../../testdata/golden/types.gen.go"
|
||||||
|
if *updateGolden {
|
||||||
|
require.NoError(t, os.WriteFile(goldenPath, got, 0o600))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expected, err := os.ReadFile(goldenPath)
|
||||||
|
require.NoError(t, err, "missing golden; run `go test -update ./cmd/genapi/...`")
|
||||||
|
require.Equal(t, string(expected), string(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmit_Enums_FixtureGolden(t *testing.T) {
|
||||||
|
api, err := loadAPI("../../testdata/golden/api_small_fixture.json")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tmp := t.TempDir()
|
||||||
|
e := newEmitter(api, tmp)
|
||||||
|
require.NoError(t, e.emitEnums())
|
||||||
|
|
||||||
|
got, err := os.ReadFile(filepath.Join(tmp, "enums.gen.go"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
goldenPath := "../../testdata/golden/enums.gen.go"
|
||||||
|
if *updateGolden {
|
||||||
|
require.NoError(t, os.WriteFile(goldenPath, got, 0o600))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expected, err := os.ReadFile(goldenPath)
|
||||||
|
require.NoError(t, err, "missing golden; run `go test -update ./cmd/genapi/...`")
|
||||||
|
require.Equal(t, string(expected), string(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmit_Methods_FixtureGolden(t *testing.T) {
|
||||||
|
api, err := loadAPI("../../testdata/golden/api_small_fixture.json")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tmp := t.TempDir()
|
||||||
|
e := newEmitter(api, tmp)
|
||||||
|
require.NoError(t, e.emitTypes()) // some methods reference types
|
||||||
|
require.NoError(t, e.emitMethods())
|
||||||
|
|
||||||
|
got, err := os.ReadFile(filepath.Join(tmp, "methods.gen.go"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
goldenPath := "../../testdata/golden/methods.gen.go"
|
||||||
|
if *updateGolden {
|
||||||
|
require.NoError(t, os.WriteFile(goldenPath, got, 0o600))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expected, err := os.ReadFile(goldenPath)
|
||||||
|
require.NoError(t, err, "missing golden; run `go test -update ./cmd/genapi/...`")
|
||||||
|
require.Equal(t, string(expected), string(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmit_Tests_FixtureGolden(t *testing.T) {
|
||||||
|
api, err := loadAPI("../../testdata/golden/api_small_fixture.json")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tmp := t.TempDir()
|
||||||
|
e := newEmitter(api, tmp)
|
||||||
|
require.NoError(t, e.emitTests())
|
||||||
|
|
||||||
|
got, err := os.ReadFile(filepath.Join(tmp, "methods_gen_test.go"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
goldenPath := "../../testdata/golden/methods_gen_test.go"
|
||||||
|
if *updateGolden {
|
||||||
|
require.NoError(t, os.WriteFile(goldenPath, got, 0o600))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expected, err := os.ReadFile(goldenPath)
|
||||||
|
require.NoError(t, err, "missing golden; run `go test -update ./cmd/genapi/...`")
|
||||||
|
require.Equal(t, string(expected), string(got))
|
||||||
|
}
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
// Code generated by cmd/genapi. DO NOT EDIT.
|
||||||
|
|
||||||
|
//go:build !ignore_autogenerated
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
// ParseMode controls how Telegram interprets formatting in message text.
|
||||||
|
type ParseMode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ParseModeMarkdown ParseMode = "Markdown" // legacy
|
||||||
|
ParseModeMarkdownV2 ParseMode = "MarkdownV2"
|
||||||
|
ParseModeHTML ParseMode = "HTML"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatType is the type of a Telegram chat.
|
||||||
|
type ChatType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ChatTypePrivate ChatType = "private"
|
||||||
|
ChatTypeGroup ChatType = "group"
|
||||||
|
ChatTypeSupergroup ChatType = "supergroup"
|
||||||
|
ChatTypeChannel ChatType = "channel"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpdateType identifies an Update payload variant. Used by allowed_updates
|
||||||
|
// in getUpdates / setWebhook.
|
||||||
|
type UpdateType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
UpdateMessage UpdateType = "message"
|
||||||
|
UpdateEditedMessage UpdateType = "edited_message"
|
||||||
|
UpdateChannelPost UpdateType = "channel_post"
|
||||||
|
UpdateEditedChannelPost UpdateType = "edited_channel_post"
|
||||||
|
UpdateCallbackQuery UpdateType = "callback_query"
|
||||||
|
UpdateInlineQuery UpdateType = "inline_query"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageEntityType is the kind of an entity (mention, hashtag, command, ...).
|
||||||
|
type MessageEntityType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
EntityMention MessageEntityType = "mention"
|
||||||
|
EntityHashtag MessageEntityType = "hashtag"
|
||||||
|
EntityCashtag MessageEntityType = "cashtag"
|
||||||
|
EntityBotCommand MessageEntityType = "bot_command"
|
||||||
|
EntityURL MessageEntityType = "url"
|
||||||
|
EntityEmail MessageEntityType = "email"
|
||||||
|
EntityPhoneNumber MessageEntityType = "phone_number"
|
||||||
|
EntityBold MessageEntityType = "bold"
|
||||||
|
EntityItalic MessageEntityType = "italic"
|
||||||
|
EntityUnderline MessageEntityType = "underline"
|
||||||
|
EntityStrike MessageEntityType = "strikethrough"
|
||||||
|
EntitySpoiler MessageEntityType = "spoiler"
|
||||||
|
EntityCode MessageEntityType = "code"
|
||||||
|
EntityPre MessageEntityType = "pre"
|
||||||
|
EntityTextLink MessageEntityType = "text_link"
|
||||||
|
EntityTextMention MessageEntityType = "text_mention"
|
||||||
|
EntityCustomEmoji MessageEntityType = "custom_emoji"
|
||||||
|
)
|
||||||
@@ -0,0 +1,645 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// goType — all branches
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestGoType_Primitive(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
optional bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"bool", false, "bool"},
|
||||||
|
{"bool", true, "*bool"},
|
||||||
|
{"int64", false, "int64"},
|
||||||
|
{"int64", true, "*int64"},
|
||||||
|
{"float64", false, "float64"},
|
||||||
|
{"float64", true, "*float64"},
|
||||||
|
{"string", false, "string"},
|
||||||
|
{"string", true, "string"}, // string is not pointer-wrapped
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindPrimitive, Name: c.name}
|
||||||
|
got := goType(tr, c.optional)
|
||||||
|
require.Equal(t, c.want, got, "goType(%q, optional=%v)", c.name, c.optional)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoType_Named_Required(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}
|
||||||
|
require.Equal(t, "Message", goType(tr, false))
|
||||||
|
require.Equal(t, "*Message", goType(tr, true))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoType_Named_InputFile(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "InputFile"}
|
||||||
|
// InputFile is always pointer even when required.
|
||||||
|
require.Equal(t, "*InputFile", goType(tr, false))
|
||||||
|
require.Equal(t, "*InputFile", goType(tr, true))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoType_Named_UnionInterface(t *testing.T) {
|
||||||
|
// ChatMember is a known discriminated union — no * even when optional.
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}
|
||||||
|
require.Equal(t, "ChatMember", goType(tr, false))
|
||||||
|
require.Equal(t, "ChatMember", goType(tr, true))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoType_Array_NilElem(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindArray}
|
||||||
|
require.Equal(t, "[]any", goType(tr, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoType_Array_WithElem(t *testing.T) {
|
||||||
|
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
|
||||||
|
require.Equal(t, "[]Update", goType(tr, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoType_OneOf_ChatID(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64", "string"}}
|
||||||
|
require.Equal(t, "ChatID", goType(tr, false))
|
||||||
|
require.Equal(t, "*ChatID", goType(tr, true))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoType_OneOf_InputFile(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputFile", "string"}}
|
||||||
|
require.Equal(t, "*InputFile", goType(tr, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoType_OneOf_SealedInterface(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"A", "B"}}
|
||||||
|
require.Equal(t, "any", goType(tr, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoType_Unknown(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.Kind(99)}
|
||||||
|
require.Equal(t, "any", goType(tr, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// returnGoType — all branches
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestReturnGoType_Primitive(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}
|
||||||
|
require.Equal(t, "bool", returnGoType(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReturnGoType_Named(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}
|
||||||
|
require.Equal(t, "*Message", returnGoType(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReturnGoType_Array_NilElem(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindArray}
|
||||||
|
require.Equal(t, "[]any", returnGoType(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReturnGoType_Array_WithElem(t *testing.T) {
|
||||||
|
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
|
||||||
|
require.Equal(t, "[]Update", returnGoType(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReturnGoType_OneOf_ChatID(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64", "string"}}
|
||||||
|
require.Equal(t, "ChatID", returnGoType(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReturnGoType_OneOf_Other(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"A", "B"}}
|
||||||
|
require.Equal(t, "any", returnGoType(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReturnGoType_Unknown(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.Kind(99)}
|
||||||
|
require.Equal(t, "any", returnGoType(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// returnGoElem — all branches
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestReturnGoElem_Primitive(t *testing.T) {
|
||||||
|
require.Equal(t, "int64", returnGoElem(spec.TypeRef{Kind: spec.KindPrimitive, Name: "int64"}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReturnGoElem_Named(t *testing.T) {
|
||||||
|
require.Equal(t, "Message", returnGoElem(spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReturnGoElem_Array_NilElem(t *testing.T) {
|
||||||
|
require.Equal(t, "any", returnGoElem(spec.TypeRef{Kind: spec.KindArray}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReturnGoElem_Array_WithElem(t *testing.T) {
|
||||||
|
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "PhotoSize"}
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
|
||||||
|
require.Equal(t, "[]PhotoSize", returnGoElem(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReturnGoElem_Unknown(t *testing.T) {
|
||||||
|
require.Equal(t, "any", returnGoElem(spec.TypeRef{Kind: spec.Kind(99)}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// multipartFieldEntry — all branches
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func makeField(name, jname, typName string, kind spec.Kind, required bool) spec.Field {
|
||||||
|
return spec.Field{
|
||||||
|
Name: name,
|
||||||
|
JSONName: jname,
|
||||||
|
Type: spec.TypeRef{Kind: kind, Name: typName},
|
||||||
|
Required: required,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeFieldVariants(name, jname string, kind spec.Kind, variants []string, required bool) spec.Field {
|
||||||
|
return spec.Field{
|
||||||
|
Name: name,
|
||||||
|
JSONName: jname,
|
||||||
|
Type: spec.TypeRef{Kind: kind, Variants: variants},
|
||||||
|
Required: required,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_Int64Required(t *testing.T) {
|
||||||
|
f := makeField("ChatID", "chat_id", "int64", spec.KindPrimitive, true)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `FormatInt`)
|
||||||
|
require.NotContains(t, got, "if p.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_Int64Optional(t *testing.T) {
|
||||||
|
f := makeField("MessageThreadID", "message_thread_id", "int64", spec.KindPrimitive, false)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `FormatInt`)
|
||||||
|
require.Contains(t, got, "if p.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_StringRequired(t *testing.T) {
|
||||||
|
f := makeField("Text", "text", "string", spec.KindPrimitive, true)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `out["text"]`)
|
||||||
|
require.NotContains(t, got, "if p.Text")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_StringOptional(t *testing.T) {
|
||||||
|
f := makeField("ParseMode", "parse_mode", "string", spec.KindPrimitive, false)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `if p.ParseMode`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_BoolRequired(t *testing.T) {
|
||||||
|
f := makeField("DisableNotification", "disable_notification", "bool", spec.KindPrimitive, true)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `FormatBool`)
|
||||||
|
require.NotContains(t, got, "if p.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_BoolOptional(t *testing.T) {
|
||||||
|
f := makeField("Protected", "protect_content", "bool", spec.KindPrimitive, false)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `FormatBool`)
|
||||||
|
require.Contains(t, got, "if p.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_Float64Required(t *testing.T) {
|
||||||
|
f := makeField("Latitude", "latitude", "float64", spec.KindPrimitive, true)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `FormatFloat`)
|
||||||
|
require.NotContains(t, got, "if p.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_Float64Optional(t *testing.T) {
|
||||||
|
f := makeField("Longitude", "longitude", "float64", spec.KindPrimitive, false)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `FormatFloat`)
|
||||||
|
require.Contains(t, got, "if p.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_OneOf_ChatIDRequired(t *testing.T) {
|
||||||
|
f := makeFieldVariants("ChatID", "chat_id", spec.KindOneOf, []string{"int64", "string"}, true)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `.String()`)
|
||||||
|
require.NotContains(t, got, "IsZero")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_OneOf_ChatIDOptional(t *testing.T) {
|
||||||
|
f := makeFieldVariants("ChatID", "chat_id", spec.KindOneOf, []string{"int64", "string"}, false)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `IsZero`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_OneOf_InputFileOrString(t *testing.T) {
|
||||||
|
f := makeFieldVariants("Photo", "photo", spec.KindOneOf, []string{"InputFile", "string"}, false)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `PathOrID`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_OneOf_SealedRequired(t *testing.T) {
|
||||||
|
f := makeFieldVariants("Markup", "reply_markup", spec.KindOneOf, []string{"A", "B"}, true)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `json.Marshal`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_OneOf_SealedOptional(t *testing.T) {
|
||||||
|
f := makeFieldVariants("Markup", "reply_markup", spec.KindOneOf, []string{"A", "B"}, false)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `json.Marshal`)
|
||||||
|
require.Contains(t, got, "if p.Markup")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_Named_Required(t *testing.T) {
|
||||||
|
f := makeField("Entities", "entities", "MessageEntity", spec.KindNamed, true)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `json.Marshal`)
|
||||||
|
require.NotContains(t, got, "if p.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipartFieldEntry_Named_Optional(t *testing.T) {
|
||||||
|
f := makeField("Entities", "entities", "MessageEntity", spec.KindNamed, false)
|
||||||
|
got := multipartFieldEntry(f)
|
||||||
|
require.Contains(t, got, `json.Marshal`)
|
||||||
|
require.Contains(t, got, "if p.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// unionTypeFor — all branches
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestUnionTypeFor_DirectNamed(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}
|
||||||
|
name, ok := unionTypeFor(tr)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "ChatMember", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnionTypeFor_Array(t *testing.T) {
|
||||||
|
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
|
||||||
|
name, ok := unionTypeFor(tr)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "ChatMember", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnionTypeFor_ArrayNilElem(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindArray}
|
||||||
|
_, ok := unionTypeFor(tr)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnionTypeFor_NotUnion(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}
|
||||||
|
_, ok := unionTypeFor(tr)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnionTypeFor_Unknown(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.Kind(99)}
|
||||||
|
_, ok := unionTypeFor(tr)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// unionNameByVariants
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestUnionNameByVariants_ChatMember(t *testing.T) {
|
||||||
|
// Use the actual variants from knownDiscriminators["ChatMember"].
|
||||||
|
variants := []string{
|
||||||
|
"ChatMemberOwner", "ChatMemberAdministrator", "ChatMemberMember",
|
||||||
|
"ChatMemberRestricted", "ChatMemberLeft", "ChatMemberBanned",
|
||||||
|
}
|
||||||
|
name := unionNameByVariants(variants)
|
||||||
|
require.Equal(t, "ChatMember", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnionNameByVariants_Unknown(t *testing.T) {
|
||||||
|
name := unionNameByVariants([]string{"X", "Y", "Z"})
|
||||||
|
require.Equal(t, "", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// hasUnionElem
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHasUnionElem_NonArray(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}
|
||||||
|
require.False(t, hasUnionElem(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasUnionElem_ArrayNilElem(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindArray}
|
||||||
|
require.False(t, hasUnionElem(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasUnionElem_ArrayUnionElem(t *testing.T) {
|
||||||
|
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
|
||||||
|
require.True(t, hasUnionElem(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasUnionElem_ArrayNonUnionElem(t *testing.T) {
|
||||||
|
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
|
||||||
|
require.False(t, hasUnionElem(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// unionFieldsOf
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestUnionFieldsOf_WithUnionField(t *testing.T) {
|
||||||
|
td := spec.TypeDecl{
|
||||||
|
Name: "ChatMemberUpdated",
|
||||||
|
Fields: []spec.Field{
|
||||||
|
{Name: "NewChatMember", JSONName: "new_chat_member", Type: spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}},
|
||||||
|
{Name: "OldChatMember", JSONName: "old_chat_member", Type: spec.TypeRef{Kind: spec.KindNamed, Name: "ChatMember"}},
|
||||||
|
{Name: "Date", JSONName: "date", Type: spec.TypeRef{Kind: spec.KindPrimitive, Name: "int64"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
uf := unionFieldsOf(td)
|
||||||
|
require.Len(t, uf, 2)
|
||||||
|
require.Equal(t, "ChatMember", uf[0].UnionName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// splitLines — edge cases
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestSplitLines_Empty(t *testing.T) {
|
||||||
|
require.Empty(t, splitLines(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitLines_NoNewline(t *testing.T) {
|
||||||
|
got := splitLines("hello world")
|
||||||
|
require.Equal(t, []string{"hello world"}, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitLines_TrailingNewline(t *testing.T) {
|
||||||
|
got := splitLines("line1\nline2\n")
|
||||||
|
require.Equal(t, []string{"line1", "line2"}, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitLines_MultiLine(t *testing.T) {
|
||||||
|
got := splitLines("a\nb\nc")
|
||||||
|
require.Equal(t, []string{"a", "b", "c"}, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// docComment
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestDocComment_Empty(t *testing.T) {
|
||||||
|
require.Equal(t, "", docComment(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDocComment_SingleLine(t *testing.T) {
|
||||||
|
got := docComment("Hello world.")
|
||||||
|
require.Equal(t, "// Hello world.\n", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDocComment_MultiLine(t *testing.T) {
|
||||||
|
got := docComment("Line 1\nLine 2")
|
||||||
|
require.Contains(t, got, "// Line 1\n")
|
||||||
|
require.Contains(t, got, "// Line 2\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// title
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestTitle_Empty(t *testing.T) {
|
||||||
|
require.Equal(t, "", title(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTitle_Lowercase(t *testing.T) {
|
||||||
|
require.Equal(t, "SendMessage", title("sendMessage"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTitle_AlreadyUpper(t *testing.T) {
|
||||||
|
require.Equal(t, "GetMe", title("GetMe"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// mentionsInputFileTr
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestMentionsInputFileTr_Named(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "InputFile"}
|
||||||
|
require.True(t, mentionsInputFileTr(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMentionsInputFileTr_NotInputFile(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}
|
||||||
|
require.False(t, mentionsInputFileTr(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMentionsInputFileTr_Array(t *testing.T) {
|
||||||
|
elem := spec.TypeRef{Kind: spec.KindNamed, Name: "InputFile"}
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
|
||||||
|
require.True(t, mentionsInputFileTr(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMentionsInputFileTr_ArrayNilElem(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindArray}
|
||||||
|
require.False(t, mentionsInputFileTr(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMentionsInputFileTr_OneOf_WithInputFile(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputFile", "string"}}
|
||||||
|
require.True(t, mentionsInputFileTr(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMentionsInputFileTr_OneOf_WithoutInputFile(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"A", "B"}}
|
||||||
|
require.False(t, mentionsInputFileTr(tr))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// loadAPI — error paths
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestLoadAPI_MissingFile(t *testing.T) {
|
||||||
|
_, err := loadAPI("/nonexistent/path/api.json")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// runtimeTypes filter in emitTypes
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRuntimeTypes_NeverEmitted(t *testing.T) {
|
||||||
|
for name := range runtimeTypes {
|
||||||
|
require.True(t, runtimeTypes[name], "runtimeType %q should be true", name)
|
||||||
|
}
|
||||||
|
require.True(t, runtimeTypes["InputFile"])
|
||||||
|
require.True(t, runtimeTypes["ChatID"])
|
||||||
|
require.True(t, runtimeTypes["MessageOrBool"])
|
||||||
|
require.True(t, runtimeTypes["ResponseParameters"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// sentinelForField — all branches
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestSentinelForField(t *testing.T) {
|
||||||
|
unionTypes := map[string]bool{"ChatMember": true}
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
field spec.Field
|
||||||
|
contains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "int64 primitive",
|
||||||
|
field: makeField("Count", "count", "int64", spec.KindPrimitive, true),
|
||||||
|
contains: "42",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string primitive",
|
||||||
|
field: makeField("Text", "text", "string", spec.KindPrimitive, true),
|
||||||
|
contains: "test_value",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bool primitive",
|
||||||
|
field: makeField("Flag", "flag", "bool", spec.KindPrimitive, true),
|
||||||
|
contains: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "float64 primitive",
|
||||||
|
field: makeField("Lat", "lat", "float64", spec.KindPrimitive, true),
|
||||||
|
contains: "1.0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named ChatID",
|
||||||
|
field: makeField("ChatID", "chat_id", "ChatID", spec.KindNamed, true),
|
||||||
|
contains: "ChatIDFromInt",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named InputFile",
|
||||||
|
field: makeField("Photo", "photo", "InputFile", spec.KindNamed, true),
|
||||||
|
contains: "InputFile",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named union (nil-able)",
|
||||||
|
field: makeField("Member", "member", "ChatMember", spec.KindNamed, true),
|
||||||
|
contains: "nil",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named required struct",
|
||||||
|
field: makeField("Chat", "chat", "Chat", spec.KindNamed, true),
|
||||||
|
contains: "Chat{}",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named optional struct",
|
||||||
|
field: makeField("Chat", "chat", "Chat", spec.KindNamed, false),
|
||||||
|
contains: "&Chat{}",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array",
|
||||||
|
field: spec.Field{Name: "Items", JSONName: "items", Type: spec.TypeRef{Kind: spec.KindArray}},
|
||||||
|
contains: "nil",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "oneOf ChatID variants",
|
||||||
|
field: makeFieldVariants("ChatID", "chat_id", spec.KindOneOf, []string{"int64", "string"}, true),
|
||||||
|
contains: "ChatIDFromInt",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "oneOf InputFile variants",
|
||||||
|
field: makeFieldVariants("Photo", "photo", spec.KindOneOf, []string{"InputFile", "string"}, true),
|
||||||
|
contains: "InputFile",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "oneOf sealed",
|
||||||
|
field: makeFieldVariants("Markup", "markup", spec.KindOneOf, []string{"A", "B"}, true),
|
||||||
|
contains: "nil",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown kind",
|
||||||
|
field: spec.Field{Name: "X", JSONName: "x", Type: spec.TypeRef{Kind: spec.Kind(99)}},
|
||||||
|
contains: "nil",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
got := sentinelForField(c.field, unionTypes)
|
||||||
|
require.Contains(t, got, c.contains, "sentinelForField for %q", c.name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// successBody — all branches
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestSuccessBody(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
tr spec.TypeRef
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"bool", spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}, "true"},
|
||||||
|
{"int64", spec.TypeRef{Kind: spec.KindPrimitive, Name: "int64"}, "0"},
|
||||||
|
{"float64", spec.TypeRef{Kind: spec.KindPrimitive, Name: "float64"}, "0"},
|
||||||
|
{"string", spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}, `""`},
|
||||||
|
{"MessageOrBool", spec.TypeRef{Kind: spec.KindNamed, Name: "MessageOrBool"}, "true"},
|
||||||
|
{"named", spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}, "{}"},
|
||||||
|
{"array", spec.TypeRef{Kind: spec.KindArray}, "[]"},
|
||||||
|
{"oneOf", spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"A", "B"}}, "null"},
|
||||||
|
{"unknown", spec.TypeRef{Kind: spec.Kind(99)}, "null"},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
got := successBody(c.tr)
|
||||||
|
require.Equal(t, c.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// unionTypeFor — KindOneOf branch (variant-set match)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestUnionTypeFor_OneOfVariants(t *testing.T) {
|
||||||
|
// These variant *type names* match the ChatMember discriminator.
|
||||||
|
tr := spec.TypeRef{
|
||||||
|
Kind: spec.KindOneOf,
|
||||||
|
Variants: []string{
|
||||||
|
"ChatMemberOwner", "ChatMemberAdministrator", "ChatMemberMember",
|
||||||
|
"ChatMemberRestricted", "ChatMemberLeft", "ChatMemberBanned",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
name, ok := unionTypeFor(tr)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "ChatMember", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnionTypeFor_OneOfNoMatch(t *testing.T) {
|
||||||
|
tr := spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"Foo", "Bar"}}
|
||||||
|
_, ok := unionTypeFor(tr)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// funcs() returns a non-nil FuncMap with expected keys
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestFuncs_HasExpectedKeys(t *testing.T) {
|
||||||
|
fm := funcs()
|
||||||
|
require.NotNil(t, fm)
|
||||||
|
for _, key := range []string{"goType", "docComment", "returnGoType", "unionFields"} {
|
||||||
|
require.NotNil(t, fm[key], "funcs() missing key %q", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
// Command genapi reads internal/spec/api.json and emits api/*.gen.go.
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
//
|
||||||
|
// genapi -input <file> (default: internal/spec/api.json)
|
||||||
|
// genapi -outdir <dir> (default: api)
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
input := flag.String("input", "internal/spec/api.json", "IR JSON path")
|
||||||
|
outdir := flag.String("outdir", "api", "output directory")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if err := run(*input, *outdir); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "genapi:", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// run is filled in by P2.T8/T9/T10.
|
||||||
|
func run(input, outdir string) error {
|
||||||
|
api, err := loadAPI(input)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load api: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(outdir, 0o750); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
e := newEmitter(api, outdir)
|
||||||
|
if err := e.emitTypes(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := e.emitMethods(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := e.emitEnums(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := e.emitTests(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
// Code generated by cmd/genapi. DO NOT EDIT.
|
||||||
|
|
||||||
|
//go:build !ignore_autogenerated
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/goccy/go-json"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = strconv.Itoa // keep import for multipart helpers
|
||||||
|
var _ = json.Marshal // keep import for complex multipart fields
|
||||||
|
|
||||||
|
{{range .Methods}}
|
||||||
|
// {{title .Name}}Params is the parameter set for {{title .Name}}.
|
||||||
|
//
|
||||||
|
{{docComment .Doc -}}
|
||||||
|
type {{title .Name}}Params struct {
|
||||||
|
{{range .Params}}{{docComment .Doc}} {{goField .}}
|
||||||
|
{{end}}}
|
||||||
|
{{if .HasFiles}}
|
||||||
|
// HasFile reports whether a multipart upload is required.
|
||||||
|
func (p *{{title .Name}}Params) HasFile() bool {
|
||||||
|
{{range .Params}}{{if isFileField .}}{{fileCheck .}}{{end}}{{end}} return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultipartFields returns the non-file fields used in the multipart body.
|
||||||
|
func (p *{{title .Name}}Params) MultipartFields() map[string]string {
|
||||||
|
out := map[string]string{}
|
||||||
|
{{range .Params}}{{if not (isFileField .)}}{{multipartFieldEntry .}}{{end}}{{end}} return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultipartFiles returns the file parts.
|
||||||
|
func (p *{{title .Name}}Params) MultipartFiles() []client.MultipartFile {
|
||||||
|
var files []client.MultipartFile
|
||||||
|
{{range .Params}}{{if isFileField .}}{{multipartFileEntry .}}{{end}}{{end}} return files
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
// {{title .Name}} calls the {{.Name}} Telegram Bot API method.
|
||||||
|
//
|
||||||
|
{{docComment .Doc -}}
|
||||||
|
func {{title .Name}}(ctx context.Context, b *client.Bot, p *{{title .Name}}Params) ({{returnGoType .Returns}}, error) {
|
||||||
|
{{if isSealedUnionReturn .Returns -}}
|
||||||
|
raw, err := client.CallRaw[*{{title .Name}}Params](ctx, b, "{{.Name}}", p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return Unmarshal{{.Returns.Name}}(raw)
|
||||||
|
{{else -}}
|
||||||
|
return client.Call[*{{title .Name}}Params, {{returnGoType .Returns}}](ctx, b, "{{.Name}}", p)
|
||||||
|
{{end -}}
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
@@ -0,0 +1,220 @@
|
|||||||
|
// Code generated by cmd/genapi. DO NOT EDIT.
|
||||||
|
|
||||||
|
//go:build !ignore_autogenerated
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/client"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// genTestMockDoer is a testify-mock HTTPDoer used by generated tests only.
|
||||||
|
type genTestMockDoer struct{ mock.Mock }
|
||||||
|
|
||||||
|
func (m *genTestMockDoer) Do(r *http.Request) (*http.Response, error) {
|
||||||
|
args := m.Called(r)
|
||||||
|
if v := args.Get(0); v != nil {
|
||||||
|
return v.(*http.Response), args.Error(1)
|
||||||
|
}
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func genTestResp(status int, body string) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: status,
|
||||||
|
Body: io.NopCloser(bytes.NewBufferString(body)),
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{{range .Methods}}{{$m := .}}{{$mName := title .Name}}{{$mWire := .Name}}
|
||||||
|
func Test_{{$mName}}_Success(t *testing.T) {
|
||||||
|
m := &genTestMockDoer{}
|
||||||
|
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
|
||||||
|
return strings.HasSuffix(r.URL.Path, "/{{$mWire}}")
|
||||||
|
})).Return(genTestResp(200, {{successResp $m}}), nil)
|
||||||
|
|
||||||
|
bot := client.New("test:token", client.WithHTTPClient(m))
|
||||||
|
{{- if .Params}}
|
||||||
|
params := &{{$mName}}Params{
|
||||||
|
{{- range .Params}}{{if .Required}}
|
||||||
|
{{.Name}}: {{sentinelValue .}},{{end}}
|
||||||
|
{{- end}}
|
||||||
|
}
|
||||||
|
_, err := {{$mName}}(context.Background(), bot, params)
|
||||||
|
{{- else}}
|
||||||
|
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
|
||||||
|
{{- end}}
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_{{$mName}}_APIError(t *testing.T) {
|
||||||
|
m := &genTestMockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(
|
||||||
|
genTestResp(200, `{"ok":false,"error_code":429,"description":"Too Many Requests","parameters":{"retry_after":1}}`), nil)
|
||||||
|
|
||||||
|
bot := client.New("test:token", client.WithHTTPClient(m))
|
||||||
|
{{- if .Params}}
|
||||||
|
params := &{{$mName}}Params{
|
||||||
|
{{- range .Params}}{{if .Required}}
|
||||||
|
{{.Name}}: {{sentinelValue .}},{{end}}
|
||||||
|
{{- end}}
|
||||||
|
}
|
||||||
|
_, err := {{$mName}}(context.Background(), bot, params)
|
||||||
|
{{- else}}
|
||||||
|
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
|
||||||
|
{{- end}}
|
||||||
|
require.Error(t, err)
|
||||||
|
var ae *client.APIError
|
||||||
|
require.ErrorAs(t, err, &ae)
|
||||||
|
require.Equal(t, 429, ae.Code)
|
||||||
|
require.True(t, ae.IsRetryable())
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_{{$mName}}_NetworkError(t *testing.T) {
|
||||||
|
m := &genTestMockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(nil, errors.New("dial tcp: timeout"))
|
||||||
|
|
||||||
|
bot := client.New("test:token", client.WithHTTPClient(m))
|
||||||
|
{{- if .Params}}
|
||||||
|
params := &{{$mName}}Params{
|
||||||
|
{{- range .Params}}{{if .Required}}
|
||||||
|
{{.Name}}: {{sentinelValue .}},{{end}}
|
||||||
|
{{- end}}
|
||||||
|
}
|
||||||
|
_, err := {{$mName}}(context.Background(), bot, params)
|
||||||
|
{{- else}}
|
||||||
|
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
|
||||||
|
{{- end}}
|
||||||
|
require.Error(t, err)
|
||||||
|
var ne *client.NetworkError
|
||||||
|
require.ErrorAs(t, err, &ne)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_{{$mName}}_ParseError(t *testing.T) {
|
||||||
|
m := &genTestMockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(genTestResp(200, `not json`), nil)
|
||||||
|
|
||||||
|
bot := client.New("test:token", client.WithHTTPClient(m))
|
||||||
|
{{- if .Params}}
|
||||||
|
params := &{{$mName}}Params{
|
||||||
|
{{- range .Params}}{{if .Required}}
|
||||||
|
{{.Name}}: {{sentinelValue .}},{{end}}
|
||||||
|
{{- end}}
|
||||||
|
}
|
||||||
|
_, err := {{$mName}}(context.Background(), bot, params)
|
||||||
|
{{- else}}
|
||||||
|
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
|
||||||
|
{{- end}}
|
||||||
|
require.Error(t, err)
|
||||||
|
var pe *client.ParseError
|
||||||
|
require.ErrorAs(t, err, &pe)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_{{$mName}}_ContextCanceled(t *testing.T) {
|
||||||
|
m := &genTestMockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(nil, context.Canceled).Maybe()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
bot := client.New("test:token", client.WithHTTPClient(m))
|
||||||
|
{{- if .Params}}
|
||||||
|
params := &{{$mName}}Params{
|
||||||
|
{{- range .Params}}{{if .Required}}
|
||||||
|
{{.Name}}: {{sentinelValue .}},{{end}}
|
||||||
|
{{- end}}
|
||||||
|
}
|
||||||
|
_, err := {{$mName}}(ctx, bot, params)
|
||||||
|
{{- else}}
|
||||||
|
_, err := {{$mName}}(ctx, bot, &{{$mName}}Params{})
|
||||||
|
{{- end}}
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test_{{$mName}}_MissingRequiredFields exercises Telegram's server-side
|
||||||
|
// validation: when a required field is omitted, Telegram returns 400 with
|
||||||
|
// a description like "Bad Request: <field> is empty". The library must
|
||||||
|
// surface this as *APIError with the ErrBadRequest sentinel.
|
||||||
|
func Test_{{$mName}}_MissingRequiredFields(t *testing.T) {
|
||||||
|
m := &genTestMockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(
|
||||||
|
genTestResp(200, `{"ok":false,"error_code":400,"description":"Bad Request: chat_id is empty"}`), nil)
|
||||||
|
|
||||||
|
bot := client.New("test:token", client.WithHTTPClient(m))
|
||||||
|
// Send a Params with all required fields zeroed — simulates a caller
|
||||||
|
// that forgot to populate them. The bot library marshals as-is and
|
||||||
|
// surfaces Telegram's 400 reply.
|
||||||
|
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
|
||||||
|
require.Error(t, err)
|
||||||
|
var ae *client.APIError
|
||||||
|
require.ErrorAs(t, err, &ae)
|
||||||
|
require.Equal(t, 400, ae.Code)
|
||||||
|
require.True(t, errors.Is(err, client.ErrBadRequest))
|
||||||
|
require.False(t, ae.IsRetryable())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test_{{$mName}}_Forbidden exercises the 403 path (bot blocked by user,
|
||||||
|
// removed from chat, etc.). The library must surface the ErrForbidden
|
||||||
|
// sentinel.
|
||||||
|
func Test_{{$mName}}_Forbidden(t *testing.T) {
|
||||||
|
m := &genTestMockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(
|
||||||
|
genTestResp(200, `{"ok":false,"error_code":403,"description":"Forbidden: bot was blocked by the user"}`), nil)
|
||||||
|
|
||||||
|
bot := client.New("test:token", client.WithHTTPClient(m))
|
||||||
|
{{- if .Params}}
|
||||||
|
params := &{{$mName}}Params{
|
||||||
|
{{- range .Params}}{{if .Required}}
|
||||||
|
{{.Name}}: {{sentinelValue .}},{{end}}
|
||||||
|
{{- end}}
|
||||||
|
}
|
||||||
|
_, err := {{$mName}}(context.Background(), bot, params)
|
||||||
|
{{- else}}
|
||||||
|
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
|
||||||
|
{{- end}}
|
||||||
|
require.Error(t, err)
|
||||||
|
var ae *client.APIError
|
||||||
|
require.ErrorAs(t, err, &ae)
|
||||||
|
require.Equal(t, 403, ae.Code)
|
||||||
|
require.True(t, errors.Is(err, client.ErrForbidden))
|
||||||
|
require.False(t, ae.IsRetryable())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test_{{$mName}}_ServerError exercises the 5xx path. The library must
|
||||||
|
// classify these as retryable so RetryDoer / user retry logic kicks in.
|
||||||
|
func Test_{{$mName}}_ServerError(t *testing.T) {
|
||||||
|
m := &genTestMockDoer{}
|
||||||
|
m.On("Do", mock.Anything).Return(
|
||||||
|
genTestResp(200, `{"ok":false,"error_code":500,"description":"Internal server error"}`), nil)
|
||||||
|
|
||||||
|
bot := client.New("test:token", client.WithHTTPClient(m))
|
||||||
|
{{- if .Params}}
|
||||||
|
params := &{{$mName}}Params{
|
||||||
|
{{- range .Params}}{{if .Required}}
|
||||||
|
{{.Name}}: {{sentinelValue .}},{{end}}
|
||||||
|
{{- end}}
|
||||||
|
}
|
||||||
|
_, err := {{$mName}}(context.Background(), bot, params)
|
||||||
|
{{- else}}
|
||||||
|
_, err := {{$mName}}(context.Background(), bot, &{{$mName}}Params{})
|
||||||
|
{{- end}}
|
||||||
|
require.Error(t, err)
|
||||||
|
var ae *client.APIError
|
||||||
|
require.ErrorAs(t, err, &ae)
|
||||||
|
require.Equal(t, 500, ae.Code)
|
||||||
|
require.True(t, ae.IsRetryable(), "5xx must be retryable")
|
||||||
|
}
|
||||||
|
|
||||||
|
{{end}}
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
// Code generated by cmd/genapi. DO NOT EDIT.
|
||||||
|
|
||||||
|
//go:build !ignore_autogenerated
|
||||||
|
|
||||||
|
// Package api contains the Telegram Bot API object types and method
|
||||||
|
// wrappers, generated from the live documentation by cmd/genapi.
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/goccy/go-json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = io.Discard // keep import even if no fields use io
|
||||||
|
var _ = json.Marshal // keep import for UnmarshalXxx helpers
|
||||||
|
var _ = fmt.Errorf // keep import for UnmarshalXxx helpers
|
||||||
|
|
||||||
|
{{range .Types}}
|
||||||
|
{{- $td := . -}}
|
||||||
|
{{if .OneOf}}
|
||||||
|
// {{.Name}} is a union type. The following concrete variants implement
|
||||||
|
// it:
|
||||||
|
{{range .OneOf}}// - {{.}}
|
||||||
|
{{end}}//
|
||||||
|
{{docComment .Doc -}}
|
||||||
|
type {{.Name}} interface{ is{{.Name}}() }
|
||||||
|
|
||||||
|
{{range .OneOf}}
|
||||||
|
// is{{$td.Name}} is the marker method that makes {{.}} implement {{$td.Name}}.
|
||||||
|
func (*{{.}}) is{{$td.Name}}() {}
|
||||||
|
{{end}}
|
||||||
|
{{if hasDiscriminator .Name}}
|
||||||
|
// Unmarshal{{.Name}} decodes a {{.Name}} from JSON by inspecting the
|
||||||
|
// "{{discriminatorField .Name}}" field and dispatching to the correct concrete type.
|
||||||
|
func Unmarshal{{.Name}}(data []byte) ({{.Name}}, error) {
|
||||||
|
var probe struct {
|
||||||
|
V string `json:"{{discriminatorField .Name}}"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &probe); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var v {{.Name}}
|
||||||
|
switch probe.V {
|
||||||
|
{{range $val, $typ := discriminatorMap .Name}} case {{printf "%q" $val}}:
|
||||||
|
v = &{{$typ}}{}
|
||||||
|
{{end}} default:
|
||||||
|
return nil, fmt.Errorf("{{.Name}}: unknown {{discriminatorField .Name}} %q", probe.V)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, v); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
|
{{if isMaybeInaccessibleMessage .Name}}
|
||||||
|
// UnmarshalMaybeInaccessibleMessage decodes a JSON object into the correct
|
||||||
|
// MaybeInaccessibleMessage variant. Telegram uses the date field as a
|
||||||
|
// discriminator: date == 0 indicates InaccessibleMessage; any other value
|
||||||
|
// indicates a real Message.
|
||||||
|
func UnmarshalMaybeInaccessibleMessage(data []byte) (MaybeInaccessibleMessage, error) {
|
||||||
|
var probe struct {
|
||||||
|
Date int64 `json:"date"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &probe); err != nil {
|
||||||
|
return nil, fmt.Errorf("MaybeInaccessibleMessage: %w", err)
|
||||||
|
}
|
||||||
|
if probe.Date == 0 {
|
||||||
|
v := &InaccessibleMessage{}
|
||||||
|
if err := json.Unmarshal(data, v); err != nil {
|
||||||
|
return nil, fmt.Errorf("InaccessibleMessage: %w", err)
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
v := &Message{}
|
||||||
|
if err := json.Unmarshal(data, v); err != nil {
|
||||||
|
return nil, fmt.Errorf("Message: %w", err)
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
|
{{else}}
|
||||||
|
{{docComment .Doc -}}
|
||||||
|
type {{.Name}} struct {
|
||||||
|
{{range .Fields}}{{docComment .Doc}}{{goField .}}
|
||||||
|
{{end}}}
|
||||||
|
{{$unionFields := unionFields .}}{{if $unionFields}}
|
||||||
|
// UnmarshalJSON decodes {{.Name}} by dispatching union-typed fields
|
||||||
|
// ({{range $i, $u := $unionFields}}{{if $i}}, {{end}}{{$u.Field.Name}}{{end}}) through their concrete UnmarshalXxx helpers.
|
||||||
|
func (m *{{.Name}}) UnmarshalJSON(data []byte) error {
|
||||||
|
type Alias {{.Name}}
|
||||||
|
aux := &struct {
|
||||||
|
{{range $unionFields}}{{.Field.Name}} json.RawMessage `json:"{{.Field.JSONName}},omitempty"`
|
||||||
|
{{end}}*Alias
|
||||||
|
}{Alias: (*Alias)(m)}
|
||||||
|
if err := json.Unmarshal(data, aux); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
{{range $unionFields}}{{$f := .Field}}{{$u := .UnionName}}
|
||||||
|
if len(aux.{{$f.Name}}) > 0 && string(aux.{{$f.Name}}) != "null" {
|
||||||
|
{{if isArrayUnion $f.Type}}var raws []json.RawMessage
|
||||||
|
if err := json.Unmarshal(aux.{{$f.Name}}, &raws); err != nil {
|
||||||
|
return fmt.Errorf("decoding {{$f.JSONName}}: %w", err)
|
||||||
|
}
|
||||||
|
decoded := make([]{{$u}}, 0, len(raws))
|
||||||
|
for i, r := range raws {
|
||||||
|
v, err := Unmarshal{{$u}}(r)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("decoding {{$f.JSONName}}[%d]: %w", i, err)
|
||||||
|
}
|
||||||
|
decoded = append(decoded, v)
|
||||||
|
}
|
||||||
|
m.{{$f.Name}} = decoded
|
||||||
|
{{else}}v, err := Unmarshal{{$u}}(aux.{{$f.Name}})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("decoding {{$f.JSONName}}: %w", err)
|
||||||
|
}
|
||||||
|
m.{{$f.Name}} = v
|
||||||
|
{{end}}
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
|
{{end}}
|
||||||
|
{{end}}
|
||||||
@@ -0,0 +1,211 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// parseTypeRef — edge cases
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestParseTypeRef_Empty(t *testing.T) {
|
||||||
|
// Empty string → named with empty name (fallback).
|
||||||
|
got := parseTypeRef("")
|
||||||
|
require.Equal(t, spec.KindNamed, got.Kind)
|
||||||
|
require.Equal(t, "", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTypeRef_Whitespace(t *testing.T) {
|
||||||
|
got := parseTypeRef(" Integer ")
|
||||||
|
require.Equal(t, spec.KindPrimitive, got.Kind)
|
||||||
|
require.Equal(t, "int64", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTypeRef_True(t *testing.T) {
|
||||||
|
got := parseTypeRef("True")
|
||||||
|
require.Equal(t, spec.KindPrimitive, got.Kind)
|
||||||
|
require.Equal(t, "bool", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTypeRef_False(t *testing.T) {
|
||||||
|
got := parseTypeRef("False")
|
||||||
|
require.Equal(t, spec.KindPrimitive, got.Kind)
|
||||||
|
require.Equal(t, "bool", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTypeRef_FloatNumber(t *testing.T) {
|
||||||
|
got := parseTypeRef("Float number")
|
||||||
|
require.Equal(t, spec.KindPrimitive, got.Kind)
|
||||||
|
require.Equal(t, "float64", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTypeRef_Int(t *testing.T) {
|
||||||
|
got := parseTypeRef("Int")
|
||||||
|
require.Equal(t, spec.KindPrimitive, got.Kind)
|
||||||
|
require.Equal(t, "int64", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTypeRef_Bool(t *testing.T) {
|
||||||
|
got := parseTypeRef("Bool")
|
||||||
|
require.Equal(t, spec.KindPrimitive, got.Kind)
|
||||||
|
require.Equal(t, "bool", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTypeRef_CommaAndUnion(t *testing.T) {
|
||||||
|
// "Foo, Bar and Baz" → oneOf{Foo, Bar, Baz}
|
||||||
|
got := parseTypeRef("InputMediaPhoto, InputMediaVideo and InputMediaDocument")
|
||||||
|
require.Equal(t, spec.KindOneOf, got.Kind)
|
||||||
|
require.Len(t, got.Variants, 3)
|
||||||
|
require.Contains(t, got.Variants, "InputMediaPhoto")
|
||||||
|
require.Contains(t, got.Variants, "InputMediaVideo")
|
||||||
|
require.Contains(t, got.Variants, "InputMediaDocument")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTypeRef_ArrayOfNothing(t *testing.T) {
|
||||||
|
// "Array of " with trailing space — TrimSpace removes the trailing space
|
||||||
|
// leaving "Array of" which does NOT match the "Array of " prefix, so it
|
||||||
|
// falls through to primitiveOrNamed and returns KindNamed (not KindArray).
|
||||||
|
got := parseTypeRef("Array of ")
|
||||||
|
require.Equal(t, spec.KindNamed, got.Kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// splitCommaAnd
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestSplitCommaAnd_ThreeVariants(t *testing.T) {
|
||||||
|
got := splitCommaAnd("A, B and C")
|
||||||
|
require.Equal(t, []string{"A", "B", "C"}, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitCommaAnd_FourVariants(t *testing.T) {
|
||||||
|
got := splitCommaAnd("A, B, C and D")
|
||||||
|
require.Equal(t, []string{"A", "B", "C", "D"}, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitCommaAnd_ExtraSpaces(t *testing.T) {
|
||||||
|
got := splitCommaAnd(" Foo , Bar and Baz ")
|
||||||
|
require.Len(t, got, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// goName — edge cases
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestGoName_Empty(t *testing.T) {
|
||||||
|
require.Equal(t, "", goName(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoName_SingleWord(t *testing.T) {
|
||||||
|
require.Equal(t, "Photo", goName("photo"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoName_JSON(t *testing.T) {
|
||||||
|
require.Equal(t, "JSON", goName("json"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoName_HTML(t *testing.T) {
|
||||||
|
require.Equal(t, "HTML", goName("html"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoName_HTTPS(t *testing.T) {
|
||||||
|
require.Equal(t, "HTTPS", goName("https"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoName_AlreadyUpperSegment(t *testing.T) {
|
||||||
|
// Segment that starts with uppercase letter should be passed through.
|
||||||
|
require.Equal(t, "MediaGroupID", goName("media_group_id"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// extractReturn — additional patterns
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestExtractReturn_ArrayPattern(t *testing.T) {
|
||||||
|
desc := "Returns an Array of Update objects."
|
||||||
|
got := extractReturn(desc)
|
||||||
|
require.Equal(t, spec.KindArray, got.Kind)
|
||||||
|
require.Equal(t, "Update", got.ElemType.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractReturn_BoolPattern(t *testing.T) {
|
||||||
|
desc := "Returns True on success."
|
||||||
|
got := extractReturn(desc)
|
||||||
|
require.Equal(t, spec.KindPrimitive, got.Kind)
|
||||||
|
require.Equal(t, "bool", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractReturn_OnSuccessTrueIsReturned(t *testing.T) {
|
||||||
|
desc := "On success, true is returned."
|
||||||
|
got := extractReturn(desc)
|
||||||
|
require.Equal(t, spec.KindPrimitive, got.Kind)
|
||||||
|
require.Equal(t, "bool", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractReturn_NamedObject(t *testing.T) {
|
||||||
|
desc := "On success, returns a Message object."
|
||||||
|
got := extractReturn(desc)
|
||||||
|
require.Equal(t, spec.KindNamed, got.Kind)
|
||||||
|
require.Equal(t, "Message", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractReturn_MessageOrBool(t *testing.T) {
|
||||||
|
desc := "On success, the edited Message is returned, otherwise True is returned."
|
||||||
|
got := extractReturn(desc)
|
||||||
|
require.Equal(t, spec.KindNamed, got.Kind)
|
||||||
|
require.Equal(t, "MessageOrBool", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractReturn_InFormOf(t *testing.T) {
|
||||||
|
desc := "The answer is provided in form of a ChatInviteLink object."
|
||||||
|
got := extractReturn(desc)
|
||||||
|
require.Equal(t, spec.KindNamed, got.Kind)
|
||||||
|
require.Equal(t, "ChatInviteLink", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractReturn_Fallback(t *testing.T) {
|
||||||
|
// No recognized pattern → bool fallback.
|
||||||
|
got := extractReturn("This method does something interesting.")
|
||||||
|
require.Equal(t, spec.KindPrimitive, got.Kind)
|
||||||
|
require.Equal(t, "bool", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractReturn_MultipleReturnsFirstWins(t *testing.T) {
|
||||||
|
// Doc with multiple "Returns" phrases — first matching pattern should win.
|
||||||
|
// The indefinite-article pattern ("Returns a X object") appears earlier in
|
||||||
|
// the priority list than "Returns True", so it matches "Returns a Message"
|
||||||
|
// before the bool pattern can fire.
|
||||||
|
desc := "Returns True on success. You can also Returns a Message object later."
|
||||||
|
got := extractReturn(desc)
|
||||||
|
// The indefinite-article pattern fires first → returns Message (KindNamed).
|
||||||
|
require.Equal(t, spec.KindNamed, got.Kind)
|
||||||
|
require.Equal(t, "Message", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// extractVersion
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestExtractVersion_InTitle(t *testing.T) {
|
||||||
|
sections := []section{
|
||||||
|
{Title: "Bot API 7.3", Description: ""},
|
||||||
|
}
|
||||||
|
require.Equal(t, "7.3", extractVersion(sections))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractVersion_InDescription(t *testing.T) {
|
||||||
|
sections := []section{
|
||||||
|
{Title: "April 2024", Description: "Released Bot API 7.2."},
|
||||||
|
}
|
||||||
|
require.Equal(t, "7.2", extractVersion(sections))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractVersion_NotFound(t *testing.T) {
|
||||||
|
sections := []section{
|
||||||
|
{Title: "Introduction", Description: "Welcome to the API."},
|
||||||
|
}
|
||||||
|
require.Equal(t, "", extractVersion(sections))
|
||||||
|
}
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
// Command scrape parses the Telegram Bot API HTML page into the IR
|
||||||
|
// (internal/spec.API) and writes it to internal/spec/api.json.
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
//
|
||||||
|
// scrape -input <file> (read HTML from local file)
|
||||||
|
// scrape -url <url> (fetch HTML from URL; default: live docs)
|
||||||
|
// scrape -output <file> (output path; default: internal/spec/api.json)
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultURL = "https://core.telegram.org/bots/api"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
input := flag.String("input", "", "local HTML file (overrides -url)")
|
||||||
|
url := flag.String("url", defaultURL, "URL to fetch HTML from")
|
||||||
|
output := flag.String("output", "internal/spec/api.json", "output path")
|
||||||
|
overridesPath := flag.String("overrides", "internal/spec/overrides.json", "path to overrides JSON")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if err := run(*input, *url, *output, *overridesPath); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "scrape:", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func run(input, url, output, overridesPath string) error {
|
||||||
|
htmlBytes, err := readHTML(input, url)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read html: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
api, err := scrape(htmlBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("scrape: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
overrides, err := spec.LoadOverrides(overridesPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load overrides: %w", err)
|
||||||
|
}
|
||||||
|
overrides.Apply(api)
|
||||||
|
|
||||||
|
return writeJSON(output, api)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readHTML(input, url string) ([]byte, error) {
|
||||||
|
if input != "" {
|
||||||
|
return os.ReadFile(input)
|
||||||
|
}
|
||||||
|
c := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", "go-telegram codegen scraper")
|
||||||
|
resp, err := c.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, errors.New(resp.Status)
|
||||||
|
}
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
@@ -0,0 +1,149 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
// extractReturn pulls the return type from a method's description prose.
|
||||||
|
//
|
||||||
|
// Patterns we handle (in priority order):
|
||||||
|
//
|
||||||
|
// "Returns an Array of X" / "On success, an Array of X is returned" → array of named X
|
||||||
|
// "an array of X of the sent messages is returned" → array of named X
|
||||||
|
// "the edited X is returned, otherwise True is returned" → XOrBool
|
||||||
|
// "Returns ... as a X object" / "Returns ... as X object" → named X
|
||||||
|
// "Returns ... as String on success" → string
|
||||||
|
// "On success, returns a X object" / "Returns a X object" → named X (indefinite article)
|
||||||
|
// "On success, an? X is returned" / "On success, the X is returned" → named X
|
||||||
|
// "Returns True" / "On success, true is returned" → bool
|
||||||
|
// "Returns the verb-ed X" → named X
|
||||||
|
// "On success, X is returned" → named X
|
||||||
|
// "Returns X on success" (no article) → named X
|
||||||
|
// "in form of a X" → named X
|
||||||
|
// fallback: bool
|
||||||
|
func extractReturn(desc string) spec.TypeRef {
|
||||||
|
// Normalise; strip *bold* markers because Telegram uses italics.
|
||||||
|
d := strings.ReplaceAll(desc, "*", "")
|
||||||
|
|
||||||
|
patterns := []struct {
|
||||||
|
re *regexp.Regexp
|
||||||
|
fn func([]string) spec.TypeRef
|
||||||
|
}{
|
||||||
|
// Array patterns first — most specific.
|
||||||
|
{regexp.MustCompile(`Returns an? [Aa]rray of ([A-Z][A-Za-z0-9]+)`), func(m []string) spec.TypeRef {
|
||||||
|
elem := primitiveOrNamed(m[1])
|
||||||
|
return spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
|
||||||
|
}},
|
||||||
|
{regexp.MustCompile(`On success(?:,)?\s+(?:an?\s+)?[Aa]rray of ([A-Z][A-Za-z0-9]+)(?:\s+objects?)?\s+(?:is|are|that\s+\S+\s+\S+\s+)?(?:is |are )?returned`), func(m []string) spec.TypeRef {
|
||||||
|
elem := primitiveOrNamed(m[1])
|
||||||
|
return spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
|
||||||
|
}},
|
||||||
|
// "an array of X of the sent messages is returned" (ForwardMessages/CopyMessages shape).
|
||||||
|
{regexp.MustCompile(`(?:[Oo]n success[,.]?\s+)?an? array of ([A-Z][A-Za-z0-9]+)(?:\s+of [^.]+?)?\s+(?:objects\s+)?(?:is|are) returned`), func(m []string) spec.TypeRef {
|
||||||
|
elem := primitiveOrNamed(m[1])
|
||||||
|
return spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
|
||||||
|
}},
|
||||||
|
// "Message or True" conditional return → XOrBool sentinel.
|
||||||
|
{regexp.MustCompile(`the (?:edited|sent|stopped)?\s*([A-Z][A-Za-z0-9]+)\s+is returned, otherwise (?:True|true) is returned`), func(m []string) spec.TypeRef {
|
||||||
|
return spec.TypeRef{Kind: spec.KindNamed, Name: m[1] + "OrBool"}
|
||||||
|
}},
|
||||||
|
// "Returns ... as a X object" / "Returns ... as X object" (with or without article).
|
||||||
|
{regexp.MustCompile(`[Rr]eturns? (?:.+? )?as (?:an? )?([A-Z][A-Za-z0-9]+) object`), func(m []string) spec.TypeRef {
|
||||||
|
return primitiveOrNamed(m[1])
|
||||||
|
}},
|
||||||
|
// "Returns ... as String on success" / "Returns ... as X on success" (named type after "as").
|
||||||
|
{regexp.MustCompile(`[Rr]eturns? (?:.+? )?as ([A-Z][A-Za-z0-9]+) on success`), func(m []string) spec.TypeRef {
|
||||||
|
return primitiveOrNamed(m[1])
|
||||||
|
}},
|
||||||
|
// Indefinite article: "On success, returns a X object" / "Returns a X object".
|
||||||
|
{regexp.MustCompile(`(?:[Oo]n success[,.]?\s+)?[Rr]eturns? an? ([A-Z][A-Za-z0-9]+)(?:\s+object)?`), func(m []string) spec.TypeRef {
|
||||||
|
return primitiveOrNamed(m[1])
|
||||||
|
}},
|
||||||
|
// "On success, an? X is returned" / "On success, the stopped X is returned".
|
||||||
|
{regexp.MustCompile(`On success,\s+(?:an?|the)?\s*(?:[a-z]+\s+)?([A-Z][A-Za-z0-9]+)(?:\s+object)?\s+is returned`), func(m []string) spec.TypeRef {
|
||||||
|
return primitiveOrNamed(m[1])
|
||||||
|
}},
|
||||||
|
// Explicit True — must come before the broad "Returns X" pattern.
|
||||||
|
{regexp.MustCompile(`Returns True`), func(m []string) spec.TypeRef {
|
||||||
|
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}
|
||||||
|
}},
|
||||||
|
{regexp.MustCompile(`(?i)on success, true is returned`), func(m []string) spec.TypeRef {
|
||||||
|
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}
|
||||||
|
}},
|
||||||
|
// "Returns the verb-ed X" — accepts any verb prefix (uploaded, revoked, …).
|
||||||
|
{regexp.MustCompile(`Returns (?:the|an?)\s+(?:[a-z]+ )?([A-Z][A-Za-z0-9]+)`), func(m []string) spec.TypeRef {
|
||||||
|
return primitiveOrNamed(m[1])
|
||||||
|
}},
|
||||||
|
// "On success, X is returned" (no article).
|
||||||
|
{regexp.MustCompile(`On success(?:,)?\s+(?:the\s+)?(?:newly\s+)?(?:edited\s+|sent\s+|created\s+|updated\s+)?([A-Z][A-Za-z0-9]+)\s+is returned`), func(m []string) spec.TypeRef {
|
||||||
|
return primitiveOrNamed(m[1])
|
||||||
|
}},
|
||||||
|
// "Returns X on success" (no article, e.g. "Returns OwnedGifts on success").
|
||||||
|
{regexp.MustCompile(`[Rr]eturns ([A-Z][A-Za-z0-9]+) on success`), func(m []string) spec.TypeRef {
|
||||||
|
return primitiveOrNamed(m[1])
|
||||||
|
}},
|
||||||
|
// "in form of a X".
|
||||||
|
{regexp.MustCompile(`in (?:the )?form of (?:a )?([A-Z][A-Za-z0-9]+)`), func(m []string) spec.TypeRef {
|
||||||
|
return primitiveOrNamed(m[1])
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
for _, p := range patterns {
|
||||||
|
if m := p.re.FindStringSubmatch(d); m != nil {
|
||||||
|
return p.fn(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback: bool. Better than panic; method-by-method tests would
|
||||||
|
// catch any regression.
|
||||||
|
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasFilesParams returns true if any param mentions InputFile (the
|
||||||
|
// scraper convention triggering multipart/form-data).
|
||||||
|
func hasFilesParams(params []spec.Field) bool {
|
||||||
|
for _, p := range params {
|
||||||
|
if mentionsInputFile(p.Type) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func mentionsInputFile(tr spec.TypeRef) bool {
|
||||||
|
switch tr.Kind {
|
||||||
|
case spec.KindNamed:
|
||||||
|
return tr.Name == "InputFile" || strings.HasPrefix(tr.Name, "InputMedia") || strings.HasPrefix(tr.Name, "InputPaidMedia")
|
||||||
|
case spec.KindArray:
|
||||||
|
if tr.ElemType != nil {
|
||||||
|
return mentionsInputFile(*tr.ElemType)
|
||||||
|
}
|
||||||
|
case spec.KindOneOf:
|
||||||
|
for _, v := range tr.Variants {
|
||||||
|
if v == "InputFile" || strings.HasPrefix(v, "InputMedia") || strings.HasPrefix(v, "InputPaidMedia") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractVersion finds the API version string in a "Bot API X.Y[.Z]" heading.
|
||||||
|
var versionRE = regexp.MustCompile(`Bot API (\d+\.\d+(?:\.\d+)?)`)
|
||||||
|
|
||||||
|
// extractVersion finds the API version string. The live docs page emits
|
||||||
|
// the version as "<strong>Bot API X.Y</strong>" inside a paragraph below
|
||||||
|
// a date heading; the small fixture uses an h4 "Bot API X.Y" instead.
|
||||||
|
// Both shapes are handled here by also scanning section descriptions.
|
||||||
|
func extractVersion(sections []section) string {
|
||||||
|
for _, s := range sections {
|
||||||
|
if m := versionRE.FindStringSubmatch(s.Title); m != nil {
|
||||||
|
return m[1]
|
||||||
|
}
|
||||||
|
if m := versionRE.FindStringSubmatch(s.Description); m != nil {
|
||||||
|
return m[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -0,0 +1,76 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractReturn(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
in string
|
||||||
|
want spec.TypeRef
|
||||||
|
}{
|
||||||
|
{"Returns basic information about the bot in form of a User object.", spec.TypeRef{Kind: spec.KindNamed, Name: "User"}},
|
||||||
|
{"On success, the sent Message is returned.", spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}},
|
||||||
|
{"Returns an Array of Update objects.", spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}}},
|
||||||
|
{"Returns True on success.", spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
|
||||||
|
{"On success, True is returned.", spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
|
||||||
|
// Issue 5: "Message or True" conditional return → MessageOrBool sentinel.
|
||||||
|
{"On success, if the edited message is not an inline message, the edited Message is returned, otherwise True is returned.", spec.TypeRef{Kind: spec.KindNamed, Name: "MessageOrBool"}},
|
||||||
|
// Issue 1: new phrasings.
|
||||||
|
{"On success, returns a WebhookInfo object.", spec.TypeRef{Kind: spec.KindNamed, Name: "WebhookInfo"}},
|
||||||
|
{"Returns a UserProfilePhotos object.", spec.TypeRef{Kind: spec.KindNamed, Name: "UserProfilePhotos"}},
|
||||||
|
{"Returns the uploaded File.", spec.TypeRef{Kind: spec.KindNamed, Name: "File"}},
|
||||||
|
{"On success, the stopped Poll is returned.", spec.TypeRef{Kind: spec.KindNamed, Name: "Poll"}},
|
||||||
|
{"On success, an Array of MessageId is returned.", spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "MessageId"}}},
|
||||||
|
{"On success, an array of Message objects that were sent is returned.", spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}}},
|
||||||
|
// ForwardMessages/CopyMessages shape: "an array of X of the sent messages is returned".
|
||||||
|
{"On success, an array of MessageId of the sent messages is returned.", spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "MessageId"}}},
|
||||||
|
// "Returns X on success" (no article) — OwnedGifts, StarAmount, Story, MenuButton, etc.
|
||||||
|
{"Returns the gifts received and owned by a managed business account. Returns OwnedGifts on success.", spec.TypeRef{Kind: spec.KindNamed, Name: "OwnedGifts"}},
|
||||||
|
{"Returns StarAmount on success.", spec.TypeRef{Kind: spec.KindNamed, Name: "StarAmount"}},
|
||||||
|
{"Posts a story on behalf of a managed business account. Returns Story on success.", spec.TypeRef{Kind: spec.KindNamed, Name: "Story"}},
|
||||||
|
{"Returns MenuButton on success.", spec.TypeRef{Kind: spec.KindNamed, Name: "MenuButton"}},
|
||||||
|
// "Returns ... as X object" (no article before type) — ChatInviteLink variants.
|
||||||
|
{"Returns the new invite link as ChatInviteLink object.", spec.TypeRef{Kind: spec.KindNamed, Name: "ChatInviteLink"}},
|
||||||
|
{"Returns the revoked invite link as ChatInviteLink object.", spec.TypeRef{Kind: spec.KindNamed, Name: "ChatInviteLink"}},
|
||||||
|
// "Returns ... as a X object" (with article) — createForumTopic.
|
||||||
|
{"Returns information about the created topic as a ForumTopic object.", spec.TypeRef{Kind: spec.KindNamed, Name: "ForumTopic"}},
|
||||||
|
// "Returns ... as String on success" — exportChatInviteLink / createInvoiceLink.
|
||||||
|
{"Returns the new invite link as String on success.", spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}},
|
||||||
|
{"Returns the created invoice link as String on success.", spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}},
|
||||||
|
// "Returns Int on success" — getChatMemberCount.
|
||||||
|
{"Returns Int on success.", spec.TypeRef{Kind: spec.KindPrimitive, Name: "int64"}},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
require.Equal(t, c.want, extractReturn(c.in), c.in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasFilesParams(t *testing.T) {
|
||||||
|
require.True(t, hasFilesParams([]spec.Field{
|
||||||
|
{Type: spec.TypeRef{Kind: spec.KindNamed, Name: "InputFile"}},
|
||||||
|
}))
|
||||||
|
require.True(t, hasFilesParams([]spec.Field{
|
||||||
|
{Type: spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputFile", "string"}}},
|
||||||
|
}))
|
||||||
|
require.False(t, hasFilesParams([]spec.Field{
|
||||||
|
{Type: spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}},
|
||||||
|
}))
|
||||||
|
// Issue 2: Array of InputMedia* union triggers HasFiles.
|
||||||
|
require.True(t, hasFilesParams([]spec.Field{
|
||||||
|
{Type: spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputMediaPhoto", "InputMediaVideo"}}}},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractVersion(t *testing.T) {
|
||||||
|
sections := []section{{Title: "Recent changes"}, {Title: "Bot API 7.10"}, {Title: "Available types"}}
|
||||||
|
require.Equal(t, "7.10", extractVersion(sections))
|
||||||
|
|
||||||
|
// Issue 4: 3-part version must not be truncated.
|
||||||
|
sections3 := []section{{Description: "Bot API 8.0.1"}}
|
||||||
|
require.Equal(t, "8.0.1", extractVersion(sections3))
|
||||||
|
}
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"github.com/goccy/go-json"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/net/html"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
// scrape (the package-level implementation overriding the stub in main.go;
|
||||||
|
// remove the stub from main.go in this task) parses the docs HTML into IR.
|
||||||
|
func scrape(htmlBytes []byte) (*spec.API, error) {
|
||||||
|
doc, err := html.Parse(bytes.NewReader(htmlBytes))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("html parse: %w", err)
|
||||||
|
}
|
||||||
|
sections := walk(doc)
|
||||||
|
api := &spec.API{Version: extractVersion(sections)}
|
||||||
|
for _, s := range sections {
|
||||||
|
switch {
|
||||||
|
case isMethodTitle(s.Title):
|
||||||
|
api.Methods = append(api.Methods, methodFromSection(s))
|
||||||
|
case isTypeTitle(s.Title):
|
||||||
|
api.Types = append(api.Types, typeFromSection(s))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return api, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func typeFromSection(s section) spec.TypeDecl {
|
||||||
|
td := spec.TypeDecl{Name: s.Title, Doc: s.Description}
|
||||||
|
if len(s.Tables) > 0 {
|
||||||
|
td.Fields = parseFieldsTable(s.Tables[0])
|
||||||
|
} else if len(s.Lists) > 0 {
|
||||||
|
// Union: extract variant names from <li><a>...</a></li>.
|
||||||
|
td.OneOf = extractListLinks(s.Lists[0])
|
||||||
|
}
|
||||||
|
return td
|
||||||
|
}
|
||||||
|
|
||||||
|
func methodFromSection(s section) spec.MethodDecl {
|
||||||
|
md := spec.MethodDecl{Name: s.Title, Doc: s.Description, Returns: extractReturn(s.Description)}
|
||||||
|
if len(s.Tables) > 0 {
|
||||||
|
md.Params = parseParamsTable(s.Tables[0])
|
||||||
|
}
|
||||||
|
md.HasFiles = hasFilesParams(md.Params)
|
||||||
|
return md
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractListLinks pulls anchor texts out of a <ul>: each <li><a>X</a></li>
|
||||||
|
// contributes "X" to the result. Used for union variant lists.
|
||||||
|
func extractListLinks(ul *html.Node) []string {
|
||||||
|
var names []string
|
||||||
|
var visit func(*html.Node)
|
||||||
|
visit = func(n *html.Node) {
|
||||||
|
if n.Type == html.ElementNode && n.Data == "a" {
|
||||||
|
names = append(names, textOf(n))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for c := n.FirstChild; c != nil; c = c.NextSibling {
|
||||||
|
visit(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
visit(ul)
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeJSON marshals the IR with stable, human-readable formatting and
|
||||||
|
// writes it to path. Marshalling is deterministic: types and methods
|
||||||
|
// preserve scrape order; struct fields use IR-defined order.
|
||||||
|
func writeJSON(path string, api *spec.API) error {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
enc := json.NewEncoder(&buf)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
enc.SetEscapeHTML(false)
|
||||||
|
if err := enc.Encode(api); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(path, buf.Bytes(), 0o644)
|
||||||
|
}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
var update = flag.Bool("update", false, "update golden files")
|
||||||
|
|
||||||
|
func TestScrape_Golden_SmallFixture(t *testing.T) {
|
||||||
|
htmlBytes, err := os.ReadFile("../../testdata/html/small_fixture.html")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
api, err := scrape(htmlBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
enc := json.NewEncoder(&buf)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
enc.SetEscapeHTML(false)
|
||||||
|
require.NoError(t, enc.Encode(api))
|
||||||
|
|
||||||
|
goldenPath := "../../testdata/golden/api_small_fixture.json"
|
||||||
|
if *update {
|
||||||
|
require.NoError(t, os.WriteFile(goldenPath, buf.Bytes(), 0o644))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expected, err := os.ReadFile(goldenPath)
|
||||||
|
require.NoError(t, err, "missing golden; run `go test -update ./cmd/scrape/...` to create")
|
||||||
|
require.Equal(t, string(expected), buf.String())
|
||||||
|
}
|
||||||
@@ -0,0 +1,224 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/net/html"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseFieldsTable walks a <table> for an object-type definition.
|
||||||
|
// Columns: Field, Type, Description (optional column orders are not
|
||||||
|
// supported; Telegram's docs use a stable layout).
|
||||||
|
//
|
||||||
|
// Optional fields are detected via the "Optional." prefix in the
|
||||||
|
// description text, which is the documented convention.
|
||||||
|
func parseFieldsTable(t *html.Node) []spec.Field {
|
||||||
|
rows := tableRows(t)
|
||||||
|
if len(rows) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var fields []spec.Field
|
||||||
|
for _, row := range rows[1:] { // skip header
|
||||||
|
cells := rowCells(row)
|
||||||
|
if len(cells) < 3 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
jname := strings.TrimSpace(textOf(cells[0]))
|
||||||
|
typeText := strings.TrimSpace(textOf(cells[1]))
|
||||||
|
desc := strings.TrimSpace(textOf(cells[2]))
|
||||||
|
|
||||||
|
required := !strings.HasPrefix(desc, "Optional.")
|
||||||
|
fields = append(fields, spec.Field{
|
||||||
|
Name: goName(jname),
|
||||||
|
JSONName: jname,
|
||||||
|
Type: parseTypeRef(typeText),
|
||||||
|
Required: required,
|
||||||
|
Doc: desc,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseParamsTable walks a <table> for a method definition.
|
||||||
|
// Columns: Parameter, Type, Required, Description.
|
||||||
|
func parseParamsTable(t *html.Node) []spec.Field {
|
||||||
|
rows := tableRows(t)
|
||||||
|
if len(rows) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var params []spec.Field
|
||||||
|
for _, row := range rows[1:] {
|
||||||
|
cells := rowCells(row)
|
||||||
|
if len(cells) < 4 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
jname := strings.TrimSpace(textOf(cells[0]))
|
||||||
|
typeText := strings.TrimSpace(textOf(cells[1]))
|
||||||
|
req := strings.EqualFold(strings.TrimSpace(textOf(cells[2])), "Yes")
|
||||||
|
desc := strings.TrimSpace(textOf(cells[3]))
|
||||||
|
|
||||||
|
params = append(params, spec.Field{
|
||||||
|
Name: goName(jname),
|
||||||
|
JSONName: jname,
|
||||||
|
Type: parseTypeRef(typeText),
|
||||||
|
Required: req,
|
||||||
|
Doc: desc,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
|
// tableRows returns the <tr> children of a <table>, skipping over
|
||||||
|
// any <thead>/<tbody> wrappers.
|
||||||
|
func tableRows(t *html.Node) []*html.Node {
|
||||||
|
var rows []*html.Node
|
||||||
|
var visit func(*html.Node)
|
||||||
|
visit = func(n *html.Node) {
|
||||||
|
if n.Type == html.ElementNode && n.Data == "tr" {
|
||||||
|
rows = append(rows, n)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for c := n.FirstChild; c != nil; c = c.NextSibling {
|
||||||
|
visit(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
visit(t)
|
||||||
|
return rows
|
||||||
|
}
|
||||||
|
|
||||||
|
// rowCells returns the <td> (or <th>) children of a <tr>.
|
||||||
|
func rowCells(tr *html.Node) []*html.Node {
|
||||||
|
var cells []*html.Node
|
||||||
|
for c := tr.FirstChild; c != nil; c = c.NextSibling {
|
||||||
|
if c.Type == html.ElementNode && (c.Data == "td" || c.Data == "th") {
|
||||||
|
cells = append(cells, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cells
|
||||||
|
}
|
||||||
|
|
||||||
|
// goName converts a snake_case JSON identifier to PascalCase.
|
||||||
|
// Special-cases common acronyms used in the Telegram docs.
|
||||||
|
func goName(s string) string {
|
||||||
|
if s == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parts := strings.Split(s, "_")
|
||||||
|
var b strings.Builder
|
||||||
|
for _, p := range parts {
|
||||||
|
if p == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch p {
|
||||||
|
case "id":
|
||||||
|
b.WriteString("ID")
|
||||||
|
case "url":
|
||||||
|
b.WriteString("URL")
|
||||||
|
case "ip":
|
||||||
|
b.WriteString("IP")
|
||||||
|
case "https":
|
||||||
|
b.WriteString("HTTPS")
|
||||||
|
case "json":
|
||||||
|
b.WriteString("JSON")
|
||||||
|
case "html":
|
||||||
|
b.WriteString("HTML")
|
||||||
|
default:
|
||||||
|
if p[0] >= 'a' && p[0] <= 'z' {
|
||||||
|
b.WriteByte(p[0] - 'a' + 'A')
|
||||||
|
b.WriteString(p[1:])
|
||||||
|
} else {
|
||||||
|
b.WriteString(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseTypeRef decodes the type-cell text into a spec.TypeRef.
|
||||||
|
//
|
||||||
|
// Recognised shapes:
|
||||||
|
//
|
||||||
|
// "Integer" → primitive int64
|
||||||
|
// "String" → primitive string
|
||||||
|
// "Boolean" / "True" → primitive bool
|
||||||
|
// "Float" / "Float number"→ primitive float64
|
||||||
|
// "Array of X" → array of (parseTypeRef of X)
|
||||||
|
// "Array of Array of X" → array of array of X
|
||||||
|
// "Foo" → named Foo
|
||||||
|
// "Foo or Bar" → oneOf {Foo, Bar}
|
||||||
|
// "InputFile or String" → oneOf (caller may translate to InputFile)
|
||||||
|
//
|
||||||
|
// parseTypeRef decodes the type-cell text into a spec.TypeRef.
|
||||||
|
//
|
||||||
|
// Recognised shapes:
|
||||||
|
//
|
||||||
|
// "Integer" → primitive int64
|
||||||
|
// "String" → primitive string
|
||||||
|
// "Boolean" / "True" → primitive bool
|
||||||
|
// "Float" / "Float number"→ primitive float64
|
||||||
|
// "Array of X" → array of (parseTypeRef of X)
|
||||||
|
// "Array of Array of X" → array of array of X
|
||||||
|
// "Foo" → named Foo
|
||||||
|
// "Foo or Bar" → oneOf {Foo, Bar}
|
||||||
|
// "Foo, Bar and Baz" → oneOf {Foo, Bar, Baz} (Telegram's comma+and union form)
|
||||||
|
// "InputFile or String" → oneOf (caller may translate to InputFile)
|
||||||
|
func parseTypeRef(s string) spec.TypeRef {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
// Array prefix.
|
||||||
|
if rest, ok := strings.CutPrefix(s, "Array of "); ok {
|
||||||
|
elem := parseTypeRef(rest)
|
||||||
|
return spec.TypeRef{Kind: spec.KindArray, ElemType: &elem}
|
||||||
|
}
|
||||||
|
// Comma-and union ("X, Y, Z and W") — used by Telegram for ≥3-variant unions.
|
||||||
|
if strings.Contains(s, ", ") && strings.Contains(s, " and ") {
|
||||||
|
parts := splitCommaAnd(s)
|
||||||
|
variants := make([]string, 0, len(parts))
|
||||||
|
for _, p := range parts {
|
||||||
|
variants = append(variants, primitiveOrNamed(strings.TrimSpace(p)).Name)
|
||||||
|
}
|
||||||
|
return spec.TypeRef{Kind: spec.KindOneOf, Variants: variants}
|
||||||
|
}
|
||||||
|
// "X or Y" union (the 2-variant form).
|
||||||
|
if strings.Contains(s, " or ") {
|
||||||
|
parts := strings.Split(s, " or ")
|
||||||
|
variants := make([]string, 0, len(parts))
|
||||||
|
for _, p := range parts {
|
||||||
|
variants = append(variants, primitiveOrNamed(strings.TrimSpace(p)).Name)
|
||||||
|
}
|
||||||
|
return spec.TypeRef{Kind: spec.KindOneOf, Variants: variants}
|
||||||
|
}
|
||||||
|
return primitiveOrNamed(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitCommaAnd splits "A, B, C and D" → ["A", "B", "C", "D"].
|
||||||
|
func splitCommaAnd(s string) []string {
|
||||||
|
// Replace " and " with ", " then split on ", ".
|
||||||
|
s = strings.ReplaceAll(s, " and ", ", ")
|
||||||
|
parts := strings.Split(s, ", ")
|
||||||
|
out := make([]string, 0, len(parts))
|
||||||
|
for _, p := range parts {
|
||||||
|
if p = strings.TrimSpace(p); p != "" {
|
||||||
|
out = append(out, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// primitiveOrNamed maps a single-word type cell to either a primitive
|
||||||
|
// or a named TypeRef. Unrecognised words are treated as named types.
|
||||||
|
func primitiveOrNamed(s string) spec.TypeRef {
|
||||||
|
switch s {
|
||||||
|
case "Integer", "Int":
|
||||||
|
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "int64"}
|
||||||
|
case "String":
|
||||||
|
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}
|
||||||
|
case "Boolean", "Bool", "True", "False":
|
||||||
|
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}
|
||||||
|
case "Float", "Float number":
|
||||||
|
return spec.TypeRef{Kind: spec.KindPrimitive, Name: "float64"}
|
||||||
|
default:
|
||||||
|
return spec.TypeRef{Kind: spec.KindNamed, Name: s}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/internal/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGoName(t *testing.T) {
|
||||||
|
cases := []struct{ in, want string }{
|
||||||
|
{"chat_id", "ChatID"},
|
||||||
|
{"first_name", "FirstName"},
|
||||||
|
{"is_bot", "IsBot"},
|
||||||
|
{"url", "URL"},
|
||||||
|
{"ip_address", "IPAddress"},
|
||||||
|
{"language_code", "LanguageCode"},
|
||||||
|
{"webhook_URL", "WebhookURL"}, // Issue 3: already-uppercase segment must not be corrupted.
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
require.Equal(t, c.want, goName(c.in), c.in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTypeRef(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
in string
|
||||||
|
want spec.TypeRef
|
||||||
|
}{
|
||||||
|
{"Integer", spec.TypeRef{Kind: spec.KindPrimitive, Name: "int64"}},
|
||||||
|
{"String", spec.TypeRef{Kind: spec.KindPrimitive, Name: "string"}},
|
||||||
|
{"Boolean", spec.TypeRef{Kind: spec.KindPrimitive, Name: "bool"}},
|
||||||
|
{"Float", spec.TypeRef{Kind: spec.KindPrimitive, Name: "float64"}},
|
||||||
|
{"Message", spec.TypeRef{Kind: spec.KindNamed, Name: "Message"}},
|
||||||
|
{"Array of Update", spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "Update"}}},
|
||||||
|
{"Array of Array of PhotoSize", spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindArray, ElemType: &spec.TypeRef{Kind: spec.KindNamed, Name: "PhotoSize"}}}},
|
||||||
|
{"Integer or String", spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"int64", "string"}}},
|
||||||
|
{"InputFile or String", spec.TypeRef{Kind: spec.KindOneOf, Variants: []string{"InputFile", "string"}}},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
require.Equal(t, c.want, parseTypeRef(c.in), c.in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFieldsTable_FromFixture(t *testing.T) {
|
||||||
|
doc := parse(t, "../../testdata/html/small_fixture.html")
|
||||||
|
sections := walk(doc)
|
||||||
|
var user *section
|
||||||
|
for i := range sections {
|
||||||
|
if sections[i].Title == "User" {
|
||||||
|
user = §ions[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, user)
|
||||||
|
require.Len(t, user.Tables, 1)
|
||||||
|
|
||||||
|
fields := parseFieldsTable(user.Tables[0])
|
||||||
|
require.Len(t, fields, 4)
|
||||||
|
require.Equal(t, "ID", fields[0].Name)
|
||||||
|
require.Equal(t, "id", fields[0].JSONName)
|
||||||
|
require.Equal(t, spec.KindPrimitive, fields[0].Type.Kind)
|
||||||
|
require.True(t, fields[0].Required)
|
||||||
|
|
||||||
|
require.Equal(t, "LastName", fields[3].Name)
|
||||||
|
require.False(t, fields[3].Required) // "Optional." prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseParamsTable_FromFixture(t *testing.T) {
|
||||||
|
doc := parse(t, "../../testdata/html/small_fixture.html")
|
||||||
|
sections := walk(doc)
|
||||||
|
var sm *section
|
||||||
|
for i := range sections {
|
||||||
|
if sections[i].Title == "sendMessage" {
|
||||||
|
sm = §ions[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, sm)
|
||||||
|
require.Len(t, sm.Tables, 1)
|
||||||
|
|
||||||
|
params := parseParamsTable(sm.Tables[0])
|
||||||
|
require.Len(t, params, 3)
|
||||||
|
require.Equal(t, "ChatID", params[0].Name)
|
||||||
|
require.True(t, params[0].Required)
|
||||||
|
require.Equal(t, spec.KindOneOf, params[0].Type.Kind)
|
||||||
|
require.Equal(t, []string{"int64", "string"}, params[0].Type.Variants)
|
||||||
|
|
||||||
|
require.Equal(t, "ParseMode", params[2].Name)
|
||||||
|
require.False(t, params[2].Required) // "Optional"
|
||||||
|
}
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/net/html"
|
||||||
|
)
|
||||||
|
|
||||||
|
// section is an h4-anchored block of the docs page. Title is the
|
||||||
|
// heading text (e.g. "User" or "sendMessage"). Description is the
|
||||||
|
// concatenation of immediately-following <p> paragraphs (until the
|
||||||
|
// next h4 / h3 / table / list). Tables and Lists hold raw nodes for
|
||||||
|
// later parsing by the table/oneof extractors.
|
||||||
|
type section struct {
|
||||||
|
Title string
|
||||||
|
Description string
|
||||||
|
Tables []*html.Node // <table> nodes
|
||||||
|
Lists []*html.Node // <ul> nodes (used for oneof variant lists)
|
||||||
|
}
|
||||||
|
|
||||||
|
// walk parses the page and returns sections in document order.
|
||||||
|
// Sections whose title contains a space (e.g. "Bot API 7.10") are
|
||||||
|
// included; later passes ignore them or treat them specially.
|
||||||
|
func walk(doc *html.Node) []section {
|
||||||
|
var (
|
||||||
|
sections []section
|
||||||
|
current *section
|
||||||
|
)
|
||||||
|
|
||||||
|
var visit func(n *html.Node)
|
||||||
|
visit = func(n *html.Node) {
|
||||||
|
if n.Type == html.ElementNode {
|
||||||
|
switch n.Data {
|
||||||
|
case "h4":
|
||||||
|
if current != nil {
|
||||||
|
sections = append(sections, *current)
|
||||||
|
}
|
||||||
|
current = §ion{Title: textOf(n)}
|
||||||
|
// Don't recurse into the heading; we already have its text.
|
||||||
|
return
|
||||||
|
case "h3":
|
||||||
|
// h3 (e.g. "Available methods") delimits a section;
|
||||||
|
// flush the current h4 section but do not start a new one.
|
||||||
|
if current != nil {
|
||||||
|
sections = append(sections, *current)
|
||||||
|
current = nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case "p":
|
||||||
|
if current != nil {
|
||||||
|
if current.Description != "" {
|
||||||
|
current.Description += "\n"
|
||||||
|
}
|
||||||
|
current.Description += strings.TrimSpace(textOf(n))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case "table":
|
||||||
|
if current != nil {
|
||||||
|
current.Tables = append(current.Tables, n)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case "ul":
|
||||||
|
if current != nil {
|
||||||
|
current.Lists = append(current.Lists, n)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for c := n.FirstChild; c != nil; c = c.NextSibling {
|
||||||
|
visit(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
visit(doc)
|
||||||
|
if current != nil {
|
||||||
|
sections = append(sections, *current)
|
||||||
|
}
|
||||||
|
return sections
|
||||||
|
}
|
||||||
|
|
||||||
|
// textOf returns the concatenated text content of n and descendants,
|
||||||
|
// with adjacent whitespace collapsed to single spaces.
|
||||||
|
func textOf(n *html.Node) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
var w func(*html.Node)
|
||||||
|
w = func(n *html.Node) {
|
||||||
|
if n.Type == html.TextNode {
|
||||||
|
sb.WriteString(n.Data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for c := n.FirstChild; c != nil; c = c.NextSibling {
|
||||||
|
w(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w(n)
|
||||||
|
return collapseWS(sb.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func collapseWS(s string) string {
|
||||||
|
var b strings.Builder
|
||||||
|
prevSpace := false
|
||||||
|
for _, r := range s {
|
||||||
|
if r == ' ' || r == '\t' || r == '\n' || r == '\r' {
|
||||||
|
if !prevSpace {
|
||||||
|
b.WriteByte(' ')
|
||||||
|
}
|
||||||
|
prevSpace = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
prevSpace = false
|
||||||
|
b.WriteRune(r)
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(b.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// isMethodTitle returns true for headings that look like method names
|
||||||
|
// (camelCase starting with a lowercase letter; e.g. "sendMessage").
|
||||||
|
func isMethodTitle(s string) bool {
|
||||||
|
if s == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
r := s[0]
|
||||||
|
return r >= 'a' && r <= 'z'
|
||||||
|
}
|
||||||
|
|
||||||
|
// isTypeTitle returns true for headings that look like type names
|
||||||
|
// (PascalCase; e.g. "Message"). Allows a leading-uppercase only;
|
||||||
|
// excludes spaces (which would denote a header like "Bot API 7.10").
|
||||||
|
func isTypeTitle(s string) bool {
|
||||||
|
if s == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
r := s[0]
|
||||||
|
if r < 'A' || r > 'Z' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !strings.Contains(s, " ")
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/net/html"
|
||||||
|
)
|
||||||
|
|
||||||
|
func parse(t *testing.T, path string) *html.Node {
|
||||||
|
t.Helper()
|
||||||
|
f, err := os.Open(path)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer f.Close()
|
||||||
|
doc, err := html.Parse(f)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return doc
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWalk_FixtureSections(t *testing.T) {
|
||||||
|
doc := parse(t, "../../testdata/html/small_fixture.html")
|
||||||
|
sections := walk(doc)
|
||||||
|
|
||||||
|
titles := make([]string, 0, len(sections))
|
||||||
|
for _, s := range sections {
|
||||||
|
titles = append(titles, s.Title)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Contains(t, titles, "User")
|
||||||
|
require.Contains(t, titles, "ChatMember")
|
||||||
|
require.Contains(t, titles, "getMe")
|
||||||
|
require.Contains(t, titles, "sendMessage")
|
||||||
|
require.Contains(t, titles, "sendDocument")
|
||||||
|
require.Contains(t, titles, "getUpdates")
|
||||||
|
require.Contains(t, titles, "Bot API 7.10")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsMethodTitle(t *testing.T) {
|
||||||
|
require.True(t, isMethodTitle("sendMessage"))
|
||||||
|
require.True(t, isMethodTitle("getMe"))
|
||||||
|
require.False(t, isMethodTitle("Message"))
|
||||||
|
require.False(t, isMethodTitle(""))
|
||||||
|
require.False(t, isMethodTitle("Bot API 7.10"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTypeTitle(t *testing.T) {
|
||||||
|
require.True(t, isTypeTitle("Message"))
|
||||||
|
require.True(t, isTypeTitle("ChatMember"))
|
||||||
|
require.False(t, isTypeTitle("sendMessage"))
|
||||||
|
require.False(t, isTypeTitle("Bot API 7.10"))
|
||||||
|
require.False(t, isTypeTitle(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSection_DescriptionAndTables(t *testing.T) {
|
||||||
|
doc := parse(t, "../../testdata/html/small_fixture.html")
|
||||||
|
sections := walk(doc)
|
||||||
|
var sm *section
|
||||||
|
for i, s := range sections {
|
||||||
|
if s.Title == "sendMessage" {
|
||||||
|
sm = §ions[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, sm)
|
||||||
|
require.True(t, strings.Contains(sm.Description, "Use this method to send text messages"))
|
||||||
|
require.Len(t, sm.Tables, 1)
|
||||||
|
}
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
// Package dispatch provides a typed router for Telegram updates. It
|
||||||
|
// consumes any transport.Updater and dispatches updates to handlers
|
||||||
|
// registered by command, regex, or update-payload kind.
|
||||||
|
package dispatch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Context bundles the per-update state every handler receives.
|
||||||
|
//
|
||||||
|
// Ctx is the request context propagated from Router.Run; cancelling the
|
||||||
|
// run cancels every handler.
|
||||||
|
//
|
||||||
|
// Bot is the API client. Handlers reply by calling api.SendMessage(c.Ctx,
|
||||||
|
// c.Bot, ...) etc.
|
||||||
|
//
|
||||||
|
// Update is the raw update; payload-typed handlers also receive a
|
||||||
|
// narrowed pointer to one of its sub-fields.
|
||||||
|
//
|
||||||
|
// Values is a per-update bag matchers populate. Conventional keys:
|
||||||
|
//
|
||||||
|
// "command": string, the matched bot command (e.g. "/start")
|
||||||
|
// "command_args": string, everything after the command
|
||||||
|
// "regex_match": []string, regex sub-matches when OnText matches
|
||||||
|
type Context struct {
|
||||||
|
Ctx context.Context
|
||||||
|
Bot *client.Bot
|
||||||
|
Update *api.Update
|
||||||
|
Values map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewContext constructs a Context. Used by Router internally; exposed for
|
||||||
|
// custom test harnesses.
|
||||||
|
func NewContext(ctx context.Context, b *client.Bot, u *api.Update) *Context {
|
||||||
|
return &Context{Ctx: ctx, Bot: b, Update: u, Values: map[string]any{}}
|
||||||
|
}
|
||||||
@@ -0,0 +1,494 @@
|
|||||||
|
package conversation_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/client"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/dispatch/conversation"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---- helpers ---------------------------------------------------------------
|
||||||
|
|
||||||
|
func msgUpd(userID, chatID int64, text string) api.Update {
|
||||||
|
return api.Update{
|
||||||
|
UpdateID: 1,
|
||||||
|
Message: &api.Message{
|
||||||
|
MessageID: 1,
|
||||||
|
From: &api.User{ID: userID},
|
||||||
|
Chat: api.Chat{ID: chatID},
|
||||||
|
Text: text,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeCtx(u *api.Update) *dispatch.Context {
|
||||||
|
return dispatch.NewContext(context.Background(), client.New("t"), u)
|
||||||
|
}
|
||||||
|
|
||||||
|
// anyMsg matches any update that has a Message.
|
||||||
|
var anyMsg = func(u *api.Update) bool { return u.Message != nil }
|
||||||
|
|
||||||
|
// hasPrefix returns a filter matching updates whose Message.Text has prefix p.
|
||||||
|
func hasPrefix(p string) dispatch.Filter[*api.Update] {
|
||||||
|
return func(u *api.Update) bool {
|
||||||
|
return u.Message != nil && strings.HasPrefix(u.Message.Text, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeUpdater feeds a fixed set of updates then closes (mirrors router_test.go).
|
||||||
|
type fakeUpdater struct{ ch chan api.Update }
|
||||||
|
|
||||||
|
func newFake(ups ...api.Update) *fakeUpdater {
|
||||||
|
ch := make(chan api.Update, len(ups))
|
||||||
|
for _, u := range ups {
|
||||||
|
ch <- u
|
||||||
|
}
|
||||||
|
close(ch)
|
||||||
|
return &fakeUpdater{ch: ch}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeUpdater) Updates() <-chan api.Update { return f.ch }
|
||||||
|
func (f *fakeUpdater) Run(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }
|
||||||
|
func (f *fakeUpdater) Stop(ctx context.Context) error { return nil }
|
||||||
|
|
||||||
|
// ---- Storage tests ---------------------------------------------------------
|
||||||
|
|
||||||
|
func TestStorage_ErrKeyNotFound(t *testing.T) {
|
||||||
|
s := conversation.NewMemoryStorage()
|
||||||
|
_, err := s.Get(context.Background(), "missing")
|
||||||
|
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_SetAndGet(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := conversation.NewMemoryStorage()
|
||||||
|
require.NoError(t, s.Set(ctx, "k", "state_a"))
|
||||||
|
v, err := s.Get(ctx, "k")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, conversation.State("state_a"), v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_Delete(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := conversation.NewMemoryStorage()
|
||||||
|
require.NoError(t, s.Set(ctx, "k", "state_a"))
|
||||||
|
require.NoError(t, s.Delete(ctx, "k"))
|
||||||
|
_, err := s.Get(ctx, "k")
|
||||||
|
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_DeleteNonExistentIsNoop(t *testing.T) {
|
||||||
|
require.NoError(t, conversation.NewMemoryStorage().Delete(context.Background(), "gone"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Key strategy tests ----------------------------------------------------
|
||||||
|
|
||||||
|
func TestKeyByUser_Variants(t *testing.T) {
|
||||||
|
t.Run("message", func(t *testing.T) {
|
||||||
|
u := msgUpd(42, 100, "hi")
|
||||||
|
require.Equal(t, "u:42", conversation.KeyByUser(&u))
|
||||||
|
})
|
||||||
|
t.Run("edited_message", func(t *testing.T) {
|
||||||
|
u := api.Update{EditedMessage: &api.Message{From: &api.User{ID: 7}, Chat: api.Chat{ID: 1}}}
|
||||||
|
require.Equal(t, "u:7", conversation.KeyByUser(&u))
|
||||||
|
})
|
||||||
|
t.Run("callback_query", func(t *testing.T) {
|
||||||
|
u := api.Update{CallbackQuery: &api.CallbackQuery{From: api.User{ID: 99}}}
|
||||||
|
require.Equal(t, "u:99", conversation.KeyByUser(&u))
|
||||||
|
})
|
||||||
|
t.Run("inline_query", func(t *testing.T) {
|
||||||
|
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
|
||||||
|
require.Equal(t, "u:5", conversation.KeyByUser(&u))
|
||||||
|
})
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
require.Equal(t, "", conversation.KeyByUser(&api.Update{}))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyByChat_Variants(t *testing.T) {
|
||||||
|
t.Run("message", func(t *testing.T) {
|
||||||
|
u := msgUpd(1, 200, "")
|
||||||
|
require.Equal(t, "c:200", conversation.KeyByChat(&u))
|
||||||
|
})
|
||||||
|
t.Run("inline_has_no_chat", func(t *testing.T) {
|
||||||
|
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
|
||||||
|
require.Equal(t, "", conversation.KeyByChat(&u))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyByUserAndChat(t *testing.T) {
|
||||||
|
u := msgUpd(42, 100, "")
|
||||||
|
require.Equal(t, "uc:100:42", conversation.KeyByUserAndChat(&u))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Handler / state machine tests -----------------------------------------
|
||||||
|
|
||||||
|
func buildConv() *conversation.Conversation {
|
||||||
|
return &conversation.Conversation{
|
||||||
|
EntryPoints: []conversation.Step{{
|
||||||
|
Filter: hasPrefix("/start"),
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||||
|
return conversation.Next("await_name")
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
States: map[conversation.State][]conversation.Step{
|
||||||
|
"await_name": {{
|
||||||
|
Filter: anyMsg,
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("await_age") },
|
||||||
|
}},
|
||||||
|
"await_age": {{
|
||||||
|
Filter: anyMsg,
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
Exits: []conversation.Step{{
|
||||||
|
Filter: hasPrefix("/cancel"),
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConversation_FullFlow(t *testing.T) {
|
||||||
|
conv := buildConv()
|
||||||
|
|
||||||
|
var downstream int
|
||||||
|
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
|
||||||
|
downstream++
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
mw := conv.Dispatch(noop)
|
||||||
|
|
||||||
|
key := "uc:1:42"
|
||||||
|
|
||||||
|
// 1. /start → enters, state = await_name
|
||||||
|
u1 := msgUpd(42, 1, "/start")
|
||||||
|
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||||
|
v, err := conv.Storage.Get(context.Background(), key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, conversation.State("await_name"), v)
|
||||||
|
require.Equal(t, 0, downstream, "entry claimed update")
|
||||||
|
|
||||||
|
// 2. name → state = await_age
|
||||||
|
u2 := msgUpd(42, 1, "Alice")
|
||||||
|
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||||
|
v, err = conv.Storage.Get(context.Background(), key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, conversation.State("await_age"), v)
|
||||||
|
|
||||||
|
// 3. age → End, key deleted
|
||||||
|
u3 := msgUpd(42, 1, "30")
|
||||||
|
require.NoError(t, mw(makeCtx(&u3), &u3))
|
||||||
|
_, err = conv.Storage.Get(context.Background(), key)
|
||||||
|
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConversation_ExitsCancelMidFlow(t *testing.T) {
|
||||||
|
conv := buildConv()
|
||||||
|
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||||
|
mw := conv.Dispatch(noop)
|
||||||
|
|
||||||
|
// Start conversation.
|
||||||
|
u1 := msgUpd(42, 1, "/start")
|
||||||
|
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||||
|
_, err := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Cancel mid-flow.
|
||||||
|
u2 := msgUpd(42, 1, "/cancel")
|
||||||
|
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||||
|
_, err = conv.Storage.Get(context.Background(), "uc:1:42")
|
||||||
|
require.ErrorIs(t, err, conversation.ErrKeyNotFound, "exit should clear state")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConversation_FallbackFiresWhenNoStateStepMatches(t *testing.T) {
|
||||||
|
fallbackHit := false
|
||||||
|
conv := &conversation.Conversation{
|
||||||
|
EntryPoints: []conversation.Step{{
|
||||||
|
Filter: hasPrefix("/start"),
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("waiting") },
|
||||||
|
}},
|
||||||
|
States: map[conversation.State][]conversation.Step{
|
||||||
|
// No steps for "waiting" that match a callback query.
|
||||||
|
"waiting": {},
|
||||||
|
},
|
||||||
|
Fallbacks: []conversation.Step{{
|
||||||
|
Filter: anyMsg,
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||||
|
fallbackHit = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||||
|
mw := conv.Dispatch(noop)
|
||||||
|
|
||||||
|
u1 := msgUpd(42, 1, "/start")
|
||||||
|
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||||
|
|
||||||
|
u2 := msgUpd(42, 1, "unexpected text")
|
||||||
|
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||||
|
require.True(t, fallbackHit, "fallback should have fired")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConversation_NoActiveConv_PassesToDownstream(t *testing.T) {
|
||||||
|
conv := buildConv()
|
||||||
|
downstreamHit := false
|
||||||
|
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
|
||||||
|
downstreamHit = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
mw := conv.Dispatch(downstream)
|
||||||
|
|
||||||
|
// Random message that doesn't match /start
|
||||||
|
u := msgUpd(42, 1, "hello")
|
||||||
|
require.NoError(t, mw(makeCtx(&u), &u))
|
||||||
|
require.True(t, downstreamHit, "unmatched update should reach downstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConversation_EmptyKey_PassesThrough(t *testing.T) {
|
||||||
|
// InlineQuery has no chatID → KeyByUserAndChat returns "" → pass through.
|
||||||
|
conv := buildConv()
|
||||||
|
downstreamHit := false
|
||||||
|
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
|
||||||
|
downstreamHit = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
mw := conv.Dispatch(downstream)
|
||||||
|
|
||||||
|
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
|
||||||
|
require.NoError(t, mw(makeCtx(&u), &u))
|
||||||
|
require.True(t, downstreamHit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConversation_AllowReEntry(t *testing.T) {
|
||||||
|
conv := buildConv()
|
||||||
|
conv.AllowReEntry = true
|
||||||
|
|
||||||
|
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||||
|
mw := conv.Dispatch(noop)
|
||||||
|
|
||||||
|
// Start.
|
||||||
|
u1 := msgUpd(42, 1, "/start")
|
||||||
|
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||||
|
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||||
|
require.Equal(t, conversation.State("await_name"), v)
|
||||||
|
|
||||||
|
// Advance once.
|
||||||
|
u2 := msgUpd(42, 1, "Alice")
|
||||||
|
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||||
|
v, _ = conv.Storage.Get(context.Background(), "uc:1:42")
|
||||||
|
require.Equal(t, conversation.State("await_age"), v)
|
||||||
|
|
||||||
|
// Re-enter with /start — should restart to await_name even though mid-flow.
|
||||||
|
u3 := msgUpd(42, 1, "/start")
|
||||||
|
require.NoError(t, mw(makeCtx(&u3), &u3))
|
||||||
|
v, _ = conv.Storage.Get(context.Background(), "uc:1:42")
|
||||||
|
require.Equal(t, conversation.State("await_name"), v, "AllowReEntry should restart")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConversation_NoReEntry_EntryIgnoredWhenActive(t *testing.T) {
|
||||||
|
conv := buildConv()
|
||||||
|
conv.AllowReEntry = false
|
||||||
|
|
||||||
|
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||||
|
mw := conv.Dispatch(noop)
|
||||||
|
|
||||||
|
// Start.
|
||||||
|
u1 := msgUpd(42, 1, "/start")
|
||||||
|
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||||
|
|
||||||
|
// Advance to await_age.
|
||||||
|
u2 := msgUpd(42, 1, "Alice")
|
||||||
|
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||||
|
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||||
|
require.Equal(t, conversation.State("await_age"), v)
|
||||||
|
|
||||||
|
// /start again — should NOT restart; state should stay await_age since
|
||||||
|
// /start matches the state step filter (anyMsg) and advances.
|
||||||
|
// Actually /start is handled by "await_age" anyMsg step → End().
|
||||||
|
u3 := msgUpd(42, 1, "/start")
|
||||||
|
require.NoError(t, mw(makeCtx(&u3), &u3))
|
||||||
|
// State ended (End() called by await_age step).
|
||||||
|
_, err := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||||
|
require.ErrorIs(t, err, conversation.ErrKeyNotFound, "state step should have consumed /start when AllowReEntry=false")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConversation_StayInState_NilReturn(t *testing.T) {
|
||||||
|
// Handler returning nil keeps state unchanged.
|
||||||
|
stored := false
|
||||||
|
conv := &conversation.Conversation{
|
||||||
|
EntryPoints: []conversation.Step{{
|
||||||
|
Filter: hasPrefix("/start"),
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||||
|
return conversation.Next("waiting")
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
States: map[conversation.State][]conversation.Step{
|
||||||
|
"waiting": {{
|
||||||
|
Filter: anyMsg,
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||||
|
stored = true
|
||||||
|
return nil // stay in current state
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||||
|
mw := conv.Dispatch(noop)
|
||||||
|
|
||||||
|
u1 := msgUpd(42, 1, "/start")
|
||||||
|
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||||
|
|
||||||
|
u2 := msgUpd(42, 1, "something")
|
||||||
|
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||||
|
require.True(t, stored)
|
||||||
|
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||||
|
require.Equal(t, conversation.State("waiting"), v, "nil return should leave state unchanged")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConversation_ActiveNoMatch_Swallows(t *testing.T) {
|
||||||
|
// Active conversation with no matching state step and no fallback:
|
||||||
|
// update is swallowed (not passed downstream).
|
||||||
|
conv := &conversation.Conversation{
|
||||||
|
EntryPoints: []conversation.Step{{
|
||||||
|
Filter: hasPrefix("/start"),
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("waiting") },
|
||||||
|
}},
|
||||||
|
States: map[conversation.State][]conversation.Step{
|
||||||
|
"waiting": {{
|
||||||
|
// Only matches /done specifically.
|
||||||
|
Filter: hasPrefix("/done"),
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
downstreamHit := false
|
||||||
|
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
|
||||||
|
downstreamHit = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
mw := conv.Dispatch(downstream)
|
||||||
|
|
||||||
|
u1 := msgUpd(42, 1, "/start")
|
||||||
|
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||||
|
|
||||||
|
// Random text doesn't match /done and there's no fallback → swallowed.
|
||||||
|
u2 := msgUpd(42, 1, "random")
|
||||||
|
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||||
|
require.False(t, downstreamHit, "active conv with no matching step should swallow")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Via Router.Run --------------------------------------------------------
|
||||||
|
|
||||||
|
func TestConversation_ViaRouter(t *testing.T) {
|
||||||
|
var steps atomic.Int32
|
||||||
|
conv := &conversation.Conversation{
|
||||||
|
EntryPoints: []conversation.Step{{
|
||||||
|
Filter: hasPrefix("/start"),
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||||
|
steps.Add(1)
|
||||||
|
return conversation.Next("await_name")
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
States: map[conversation.State][]conversation.Step{
|
||||||
|
"await_name": {{
|
||||||
|
Filter: anyMsg,
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||||
|
steps.Add(1)
|
||||||
|
return conversation.Next("await_age")
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
"await_age": {{
|
||||||
|
Filter: anyMsg,
|
||||||
|
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||||
|
steps.Add(1)
|
||||||
|
return conversation.End()
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
router := dispatch.New(client.New("t"), dispatch.WithMaxConcurrency(0)) // serial
|
||||||
|
router.Use(conv.Dispatch)
|
||||||
|
|
||||||
|
ups := []api.Update{
|
||||||
|
msgUpd(42, 1, "/start"),
|
||||||
|
msgUpd(42, 1, "Alice"),
|
||||||
|
msgUpd(42, 1, "30"),
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() { errCh <- router.Run(ctx, newFake(ups...)) }()
|
||||||
|
|
||||||
|
// Wait for updater channel to drain (Run returns when closed).
|
||||||
|
err := <-errCh
|
||||||
|
if err != nil && err != context.Canceled {
|
||||||
|
t.Fatalf("Run error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, int32(3), steps.Load(), "all three steps should have fired")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Concurrent storage safety ---------------------------------------------
|
||||||
|
|
||||||
|
func TestConversation_ConcurrentStorageAccess(t *testing.T) {
|
||||||
|
// 15 goroutines each running a full /start → name → age flow against the
|
||||||
|
// same shared storage but DIFFERENT keys (one per goroutine). Validates
|
||||||
|
// no data races.
|
||||||
|
const numUsers = 15
|
||||||
|
|
||||||
|
conv := buildConv()
|
||||||
|
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||||
|
mw := conv.Dispatch(noop)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numUsers)
|
||||||
|
for i := 0; i < numUsers; i++ {
|
||||||
|
go func(uid int64) {
|
||||||
|
defer wg.Done()
|
||||||
|
u1 := msgUpd(uid, uid, "/start")
|
||||||
|
_ = mw(makeCtx(&u1), &u1)
|
||||||
|
u2 := msgUpd(uid, uid, "Alice")
|
||||||
|
_ = mw(makeCtx(&u2), &u2)
|
||||||
|
u3 := msgUpd(uid, uid, "30")
|
||||||
|
_ = mw(makeCtx(&u3), &u3)
|
||||||
|
}(int64(i + 1))
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
// Race detector catches bugs; no assertion needed beyond clean finish.
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConversation_ConcurrentSameKey(t *testing.T) {
|
||||||
|
// 12 goroutines hammer the same key concurrently. Storage must not panic
|
||||||
|
// or corrupt state. Race detector validates lock discipline.
|
||||||
|
const goroutines = 12
|
||||||
|
s := conversation.NewMemoryStorage()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(goroutines)
|
||||||
|
for i := 0; i < goroutines; i++ {
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = s.Set(ctx, "shared", conversation.State("step"))
|
||||||
|
_, _ = s.Get(ctx, "shared")
|
||||||
|
if i%4 == 0 {
|
||||||
|
_ = s.Delete(ctx, "shared")
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
@@ -0,0 +1,176 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||||
|
)
|
||||||
|
|
||||||
|
// stateTransition is a sentinel error type carrying a state transition
|
||||||
|
// or end signal. Conversation handlers return one of these (via Next or
|
||||||
|
// End helpers below) to drive the state machine.
|
||||||
|
type stateTransition struct {
|
||||||
|
next State
|
||||||
|
end bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *stateTransition) Error() string {
|
||||||
|
if e.end {
|
||||||
|
return "conversation: end"
|
||||||
|
}
|
||||||
|
return "conversation: → " + string(e.next)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next signals the conversation should advance to the given state.
|
||||||
|
// Conversation handlers return Next("state_name") to transition.
|
||||||
|
func Next(s State) error {
|
||||||
|
return &stateTransition{next: s}
|
||||||
|
}
|
||||||
|
|
||||||
|
// End signals the conversation has finished and state should be cleared.
|
||||||
|
// Conversation handlers return End() to terminate.
|
||||||
|
func End() error {
|
||||||
|
return &stateTransition{end: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler defines a step in the conversation. Receives the dispatch context
|
||||||
|
// and the raw update. Returns:
|
||||||
|
// - nil to stay in the current state
|
||||||
|
// - Next("state") to transition to a different state
|
||||||
|
// - End() to end the conversation
|
||||||
|
// - any other non-nil error to surface to the dispatcher (state unchanged)
|
||||||
|
type Handler func(ctx *dispatch.Context, u *api.Update) error
|
||||||
|
|
||||||
|
// Step pairs a filter with a handler for one conversation step.
|
||||||
|
type Step struct {
|
||||||
|
Filter dispatch.Filter[*api.Update]
|
||||||
|
Handler Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conversation is a stateful handler with entry, per-state, exit and
|
||||||
|
// fallback steps. A conversation is keyed by KeyStrategy (default
|
||||||
|
// KeyByUserAndChat) and persisted by Storage (default in-memory).
|
||||||
|
type Conversation struct {
|
||||||
|
// EntryPoints starts a new conversation when a matching filter fires
|
||||||
|
// and no conversation is already active for the key.
|
||||||
|
EntryPoints []Step
|
||||||
|
|
||||||
|
// States maps each state to the steps that handle it.
|
||||||
|
States map[State][]Step
|
||||||
|
|
||||||
|
// Exits, if any match, end the active conversation early. Useful for
|
||||||
|
// /cancel-style commands.
|
||||||
|
Exits []Step
|
||||||
|
|
||||||
|
// Fallbacks run when no state step matches the current update.
|
||||||
|
Fallbacks []Step
|
||||||
|
|
||||||
|
// Storage persists conversation state. Defaults to NewMemoryStorage.
|
||||||
|
Storage Storage
|
||||||
|
|
||||||
|
// KeyStrategy derives the persistence key. Defaults to KeyByUserAndChat.
|
||||||
|
KeyStrategy KeyStrategy
|
||||||
|
|
||||||
|
// AllowReEntry, when true, lets entry-point steps fire even while a
|
||||||
|
// conversation is already active for the key (effectively restarting it).
|
||||||
|
AllowReEntry bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dispatch is a global middleware-shaped Handler that consumes updates
|
||||||
|
// and routes them through the conversation graph. Register via
|
||||||
|
// router.Use(conv.Dispatch).
|
||||||
|
//
|
||||||
|
// If the conversation claims an update, downstream handlers are skipped.
|
||||||
|
// If the conversation does not claim it, downstream handlers run as normal.
|
||||||
|
func (c *Conversation) Dispatch(next dispatch.Handler[*api.Update]) dispatch.Handler[*api.Update] {
|
||||||
|
if c.Storage == nil {
|
||||||
|
c.Storage = NewMemoryStorage()
|
||||||
|
}
|
||||||
|
if c.KeyStrategy == nil {
|
||||||
|
c.KeyStrategy = KeyByUserAndChat
|
||||||
|
}
|
||||||
|
return func(dctx *dispatch.Context, u *api.Update) error {
|
||||||
|
key := c.KeyStrategy(u)
|
||||||
|
if key == "" {
|
||||||
|
return next(dctx, u)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := dctx.Ctx
|
||||||
|
current, err := c.Storage.Get(ctx, key)
|
||||||
|
if err != nil && !errors.Is(err, ErrKeyNotFound) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
active := !errors.Is(err, ErrKeyNotFound)
|
||||||
|
|
||||||
|
// Try exits first (always allowed if conversation is active).
|
||||||
|
if active {
|
||||||
|
for _, step := range c.Exits {
|
||||||
|
if step.Filter(u) {
|
||||||
|
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try entry points (only if no active conversation, or AllowReEntry).
|
||||||
|
if !active || c.AllowReEntry {
|
||||||
|
for _, step := range c.EntryPoints {
|
||||||
|
if step.Filter(u) {
|
||||||
|
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !active {
|
||||||
|
return next(dctx, u)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Active conversation: try state steps.
|
||||||
|
for _, step := range c.States[current] {
|
||||||
|
if step.Filter(u) {
|
||||||
|
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallbacks if no state step matched.
|
||||||
|
for _, step := range c.Fallbacks {
|
||||||
|
if step.Filter(u) {
|
||||||
|
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Active conversation but no step matched and no fallback: swallow the
|
||||||
|
// update (do NOT pass to downstream handlers, since the user is
|
||||||
|
// mid-conversation and an unrelated handler would surprise them).
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runStep invokes the handler and applies its return-value state transition.
|
||||||
|
func (c *Conversation) runStep(ctx context.Context, dctx *dispatch.Context, u *api.Update, key string, h Handler) error {
|
||||||
|
err := h(dctx, u)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var trans *stateTransition
|
||||||
|
if errors.As(err, &trans) {
|
||||||
|
if trans.end {
|
||||||
|
return c.Storage.Delete(ctx, key)
|
||||||
|
}
|
||||||
|
return c.Storage.Set(ctx, key, trans.next)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MemoryStorage is the default in-process Storage. It is safe for
|
||||||
|
// concurrent use. Conversation state is lost on process restart; use
|
||||||
|
// a custom Storage backed by a database for persistent flows.
|
||||||
|
type MemoryStorage struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
state map[string]State
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemoryStorage constructs an empty in-memory storage.
|
||||||
|
func NewMemoryStorage() *MemoryStorage {
|
||||||
|
return &MemoryStorage{state: map[string]State{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MemoryStorage) Get(_ context.Context, key string) (State, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
v, ok := s.state[key]
|
||||||
|
if !ok {
|
||||||
|
return "", ErrKeyNotFound
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MemoryStorage) Set(_ context.Context, key string, state State) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.state[key] = state
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MemoryStorage) Delete(_ context.Context, key string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.state, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMemoryStorage_GetSetDelete(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewMemoryStorage()
|
||||||
|
|
||||||
|
// Get on empty key returns ErrKeyNotFound.
|
||||||
|
_, err := s.Get(ctx, "k1")
|
||||||
|
require.ErrorIs(t, err, ErrKeyNotFound)
|
||||||
|
|
||||||
|
// Set then Get returns the stored state.
|
||||||
|
require.NoError(t, s.Set(ctx, "k1", "step_a"))
|
||||||
|
v, err := s.Get(ctx, "k1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, State("step_a"), v)
|
||||||
|
|
||||||
|
// Overwrite works.
|
||||||
|
require.NoError(t, s.Set(ctx, "k1", "step_b"))
|
||||||
|
v, err = s.Get(ctx, "k1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, State("step_b"), v)
|
||||||
|
|
||||||
|
// Delete removes the key.
|
||||||
|
require.NoError(t, s.Delete(ctx, "k1"))
|
||||||
|
_, err = s.Get(ctx, "k1")
|
||||||
|
require.ErrorIs(t, err, ErrKeyNotFound)
|
||||||
|
|
||||||
|
// Delete of non-existent key is a no-op (no error).
|
||||||
|
require.NoError(t, s.Delete(ctx, "nonexistent"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryStorage_MultipleKeys(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewMemoryStorage()
|
||||||
|
|
||||||
|
require.NoError(t, s.Set(ctx, "a", "stateA"))
|
||||||
|
require.NoError(t, s.Set(ctx, "b", "stateB"))
|
||||||
|
|
||||||
|
va, err := s.Get(ctx, "a")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, State("stateA"), va)
|
||||||
|
|
||||||
|
vb, err := s.Get(ctx, "b")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, State("stateB"), vb)
|
||||||
|
|
||||||
|
// Delete one key; the other remains.
|
||||||
|
require.NoError(t, s.Delete(ctx, "a"))
|
||||||
|
_, err = s.Get(ctx, "a")
|
||||||
|
require.ErrorIs(t, err, ErrKeyNotFound)
|
||||||
|
|
||||||
|
vb, err = s.Get(ctx, "b")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, State("stateB"), vb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryStorage_Concurrent(t *testing.T) {
|
||||||
|
// 20 goroutines hammering the same key concurrently — no data race.
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewMemoryStorage()
|
||||||
|
|
||||||
|
const goroutines = 20
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(goroutines)
|
||||||
|
|
||||||
|
for i := 0; i < goroutines; i++ {
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
key := "shared"
|
||||||
|
_ = s.Set(ctx, key, State("step"))
|
||||||
|
_, _ = s.Get(ctx, key)
|
||||||
|
if i%3 == 0 {
|
||||||
|
_ = s.Delete(ctx, key)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
// No assertion needed — race detector catches the bug if present.
|
||||||
|
}
|
||||||
@@ -0,0 +1,79 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KeyStrategy derives a persistence key from an update. Strategies
|
||||||
|
// determine how conversation scope works — per-user, per-chat, or
|
||||||
|
// per-user-and-chat. Implementations must return a stable string for
|
||||||
|
// the same logical scope across updates.
|
||||||
|
//
|
||||||
|
// Returns the empty string if the update doesn't have enough context
|
||||||
|
// to derive a key (in which case the conversation handler skips it).
|
||||||
|
type KeyStrategy func(u *api.Update) string
|
||||||
|
|
||||||
|
// KeyByUser derives a key from the sending user's ID. Useful for DM
|
||||||
|
// conversations and any flow that should follow the user across chats.
|
||||||
|
var KeyByUser KeyStrategy = func(u *api.Update) string {
|
||||||
|
if uid := userID(u); uid != 0 {
|
||||||
|
return fmt.Sprintf("u:%d", uid)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyByChat derives a key from the chat ID. Useful for group flows where
|
||||||
|
// any user in the chat can drive the conversation.
|
||||||
|
var KeyByChat KeyStrategy = func(u *api.Update) string {
|
||||||
|
if cid := chatID(u); cid != 0 {
|
||||||
|
return fmt.Sprintf("c:%d", cid)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyByUserAndChat derives a key from both user and chat IDs. The most
|
||||||
|
// common strategy: each user has their own conversation per chat.
|
||||||
|
var KeyByUserAndChat KeyStrategy = func(u *api.Update) string {
|
||||||
|
uid := userID(u)
|
||||||
|
cid := chatID(u)
|
||||||
|
if uid == 0 || cid == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("uc:%d:%d", cid, uid)
|
||||||
|
}
|
||||||
|
|
||||||
|
// userID extracts the sending user's ID from any update payload.
|
||||||
|
func userID(u *api.Update) int64 {
|
||||||
|
switch {
|
||||||
|
case u.Message != nil && u.Message.From != nil:
|
||||||
|
return u.Message.From.ID
|
||||||
|
case u.EditedMessage != nil && u.EditedMessage.From != nil:
|
||||||
|
return u.EditedMessage.From.ID
|
||||||
|
case u.CallbackQuery != nil:
|
||||||
|
return u.CallbackQuery.From.ID
|
||||||
|
case u.InlineQuery != nil:
|
||||||
|
return u.InlineQuery.From.ID
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatID extracts the relevant chat ID.
|
||||||
|
func chatID(u *api.Update) int64 {
|
||||||
|
switch {
|
||||||
|
case u.Message != nil:
|
||||||
|
return u.Message.Chat.ID
|
||||||
|
case u.EditedMessage != nil:
|
||||||
|
return u.EditedMessage.Chat.ID
|
||||||
|
case u.ChannelPost != nil:
|
||||||
|
return u.ChannelPost.Chat.ID
|
||||||
|
case u.EditedChannelPost != nil:
|
||||||
|
return u.EditedChannelPost.Chat.ID
|
||||||
|
case u.CallbackQuery != nil && u.CallbackQuery.Message != nil:
|
||||||
|
if msg, ok := u.CallbackQuery.Message.(*api.Message); ok {
|
||||||
|
return msg.Chat.ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
@@ -0,0 +1,115 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// helpers to build api.Update variants.
|
||||||
|
|
||||||
|
func msgUpdate(userID, chatID int64) *api.Update {
|
||||||
|
return &api.Update{
|
||||||
|
Message: &api.Message{
|
||||||
|
From: &api.User{ID: userID},
|
||||||
|
Chat: api.Chat{ID: chatID},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func editedMsgUpdate(userID, chatID int64) *api.Update {
|
||||||
|
return &api.Update{
|
||||||
|
EditedMessage: &api.Message{
|
||||||
|
From: &api.User{ID: userID},
|
||||||
|
Chat: api.Chat{ID: chatID},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackUpdate(userID, chatID int64) *api.Update {
|
||||||
|
return &api.Update{
|
||||||
|
CallbackQuery: &api.CallbackQuery{
|
||||||
|
From: api.User{ID: userID},
|
||||||
|
Message: &api.Message{Chat: api.Chat{ID: chatID}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func inlineUpdate(userID int64) *api.Update {
|
||||||
|
return &api.Update{
|
||||||
|
InlineQuery: &api.InlineQuery{
|
||||||
|
From: api.User{ID: userID},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func emptyUpdate() *api.Update { return &api.Update{} }
|
||||||
|
|
||||||
|
func TestKeyByUser(t *testing.T) {
|
||||||
|
t.Run("message update", func(t *testing.T) {
|
||||||
|
require.Equal(t, "u:42", KeyByUser(msgUpdate(42, 100)))
|
||||||
|
})
|
||||||
|
t.Run("edited message", func(t *testing.T) {
|
||||||
|
require.Equal(t, "u:7", KeyByUser(editedMsgUpdate(7, 100)))
|
||||||
|
})
|
||||||
|
t.Run("callback query", func(t *testing.T) {
|
||||||
|
require.Equal(t, "u:99", KeyByUser(callbackUpdate(99, 100)))
|
||||||
|
})
|
||||||
|
t.Run("inline query", func(t *testing.T) {
|
||||||
|
require.Equal(t, "u:5", KeyByUser(inlineUpdate(5)))
|
||||||
|
})
|
||||||
|
t.Run("empty update returns empty string", func(t *testing.T) {
|
||||||
|
require.Equal(t, "", KeyByUser(emptyUpdate()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyByChat(t *testing.T) {
|
||||||
|
t.Run("message update", func(t *testing.T) {
|
||||||
|
require.Equal(t, "c:100", KeyByChat(msgUpdate(42, 100)))
|
||||||
|
})
|
||||||
|
t.Run("edited message", func(t *testing.T) {
|
||||||
|
require.Equal(t, "c:200", KeyByChat(editedMsgUpdate(7, 200)))
|
||||||
|
})
|
||||||
|
t.Run("callback with accessible message", func(t *testing.T) {
|
||||||
|
require.Equal(t, "c:300", KeyByChat(callbackUpdate(99, 300)))
|
||||||
|
})
|
||||||
|
t.Run("inline query has no chat → empty", func(t *testing.T) {
|
||||||
|
require.Equal(t, "", KeyByChat(inlineUpdate(5)))
|
||||||
|
})
|
||||||
|
t.Run("empty update returns empty string", func(t *testing.T) {
|
||||||
|
require.Equal(t, "", KeyByChat(emptyUpdate()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyByUserAndChat(t *testing.T) {
|
||||||
|
t.Run("message update", func(t *testing.T) {
|
||||||
|
require.Equal(t, "uc:100:42", KeyByUserAndChat(msgUpdate(42, 100)))
|
||||||
|
})
|
||||||
|
t.Run("edited message", func(t *testing.T) {
|
||||||
|
require.Equal(t, "uc:200:7", KeyByUserAndChat(editedMsgUpdate(7, 200)))
|
||||||
|
})
|
||||||
|
t.Run("callback query", func(t *testing.T) {
|
||||||
|
require.Equal(t, "uc:300:99", KeyByUserAndChat(callbackUpdate(99, 300)))
|
||||||
|
})
|
||||||
|
t.Run("inline query has no chat → empty", func(t *testing.T) {
|
||||||
|
require.Equal(t, "", KeyByUserAndChat(inlineUpdate(5)))
|
||||||
|
})
|
||||||
|
t.Run("empty update returns empty string", func(t *testing.T) {
|
||||||
|
require.Equal(t, "", KeyByUserAndChat(emptyUpdate()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyByUserAndChat_CallbackInaccessibleMessage(t *testing.T) {
|
||||||
|
// CallbackQuery.Message is InaccessibleMessage (not *Message) — chatID returns 0.
|
||||||
|
u := &api.Update{
|
||||||
|
CallbackQuery: &api.CallbackQuery{
|
||||||
|
From: api.User{ID: 10},
|
||||||
|
Message: &api.InaccessibleMessage{}, // implements MaybeInaccessibleMessage, not *api.Message
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// userID picks up From.ID=10 but chatID fails type assertion → 0
|
||||||
|
require.Equal(t, "", KeyByUserAndChat(u), "no key when message inaccessible")
|
||||||
|
// KeyByUser still works since From is set.
|
||||||
|
require.Equal(t, "u:10", KeyByUser(u))
|
||||||
|
}
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
// Package conversation implements a stateful conversation handler for the
|
||||||
|
// go-telegram dispatch router. It provides a state-machine abstraction over
|
||||||
|
// multi-step Telegram bot interactions, with pluggable storage and flexible
|
||||||
|
// key strategies.
|
||||||
|
package conversation
|
||||||
|
|
||||||
|
// State is a label identifying a node in the conversation graph.
|
||||||
|
// The empty string is the implicit "no active conversation" state.
|
||||||
|
type State string
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrKeyNotFound is returned by Storage.Get when no conversation is active
|
||||||
|
// for the given key.
|
||||||
|
var ErrKeyNotFound = errors.New("conversation: key not found")
|
||||||
|
|
||||||
|
// Storage persists per-user (or per-chat, per-message — depending on the
|
||||||
|
// KeyStrategy in use) conversation state across update deliveries.
|
||||||
|
//
|
||||||
|
// Implementations must be safe for concurrent use.
|
||||||
|
type Storage interface {
|
||||||
|
Get(ctx context.Context, key string) (State, error)
|
||||||
|
Set(ctx context.Context, key string, state State) error
|
||||||
|
Delete(ctx context.Context, key string) error
|
||||||
|
}
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
package dispatch
|
||||||
|
|
||||||
|
// Filter is a predicate over a typed payload (e.g. *api.Message). Filters
|
||||||
|
// compose via And/Or/Not for multi-condition matching.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// f := message.HasPhoto().And(message.InChat(-100123456789))
|
||||||
|
type Filter[T any] func(payload T) bool
|
||||||
|
|
||||||
|
// And returns a Filter that matches iff f and every one of others matches.
|
||||||
|
func (f Filter[T]) And(others ...Filter[T]) Filter[T] {
|
||||||
|
return func(payload T) bool {
|
||||||
|
if !f(payload) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, o := range others {
|
||||||
|
if !o(payload) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Or returns a Filter that matches iff f matches OR any of others matches.
|
||||||
|
func (f Filter[T]) Or(others ...Filter[T]) Filter[T] {
|
||||||
|
return func(payload T) bool {
|
||||||
|
if f(payload) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, o := range others {
|
||||||
|
if o(payload) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not returns a Filter that inverts f.
|
||||||
|
func (f Filter[T]) Not() Filter[T] {
|
||||||
|
return func(payload T) bool { return !f(payload) }
|
||||||
|
}
|
||||||
|
|
||||||
|
// All combines filters with AND. Returns a Filter that matches when all match.
|
||||||
|
// Returns a filter that always matches when filters is empty.
|
||||||
|
func All[T any](filters ...Filter[T]) Filter[T] {
|
||||||
|
return func(payload T) bool {
|
||||||
|
for _, f := range filters {
|
||||||
|
if !f(payload) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Any combines filters with OR. Returns a Filter that matches when at least
|
||||||
|
// one matches. Returns a filter that never matches when filters is empty.
|
||||||
|
func Any[T any](filters ...Filter[T]) Filter[T] {
|
||||||
|
return func(payload T) bool {
|
||||||
|
for _, f := range filters {
|
||||||
|
if f(payload) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
package dispatch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func alwaysTrue[T any]() Filter[T] { return func(_ T) bool { return true } }
|
||||||
|
func alwaysFalse[T any]() Filter[T] { return func(_ T) bool { return false } }
|
||||||
|
|
||||||
|
func TestFilter_And(t *testing.T) {
|
||||||
|
t.Run("all true", func(t *testing.T) {
|
||||||
|
f := alwaysTrue[int]().And(alwaysTrue[int](), alwaysTrue[int]())
|
||||||
|
require.True(t, f(0))
|
||||||
|
})
|
||||||
|
t.Run("first false", func(t *testing.T) {
|
||||||
|
f := alwaysFalse[int]().And(alwaysTrue[int]())
|
||||||
|
require.False(t, f(0))
|
||||||
|
})
|
||||||
|
t.Run("other false", func(t *testing.T) {
|
||||||
|
f := alwaysTrue[int]().And(alwaysFalse[int]())
|
||||||
|
require.False(t, f(0))
|
||||||
|
})
|
||||||
|
t.Run("no others — acts as identity", func(t *testing.T) {
|
||||||
|
require.True(t, alwaysTrue[int]().And()(0))
|
||||||
|
require.False(t, alwaysFalse[int]().And()(0))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilter_Or(t *testing.T) {
|
||||||
|
t.Run("first true", func(t *testing.T) {
|
||||||
|
f := alwaysTrue[int]().Or(alwaysFalse[int]())
|
||||||
|
require.True(t, f(0))
|
||||||
|
})
|
||||||
|
t.Run("other true", func(t *testing.T) {
|
||||||
|
f := alwaysFalse[int]().Or(alwaysTrue[int]())
|
||||||
|
require.True(t, f(0))
|
||||||
|
})
|
||||||
|
t.Run("all false", func(t *testing.T) {
|
||||||
|
f := alwaysFalse[int]().Or(alwaysFalse[int]())
|
||||||
|
require.False(t, f(0))
|
||||||
|
})
|
||||||
|
t.Run("no others", func(t *testing.T) {
|
||||||
|
require.True(t, alwaysTrue[int]().Or()(0))
|
||||||
|
require.False(t, alwaysFalse[int]().Or()(0))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilter_Not(t *testing.T) {
|
||||||
|
require.False(t, alwaysTrue[int]().Not()(0))
|
||||||
|
require.True(t, alwaysFalse[int]().Not()(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAll(t *testing.T) {
|
||||||
|
t.Run("all true", func(t *testing.T) {
|
||||||
|
require.True(t, All(alwaysTrue[int](), alwaysTrue[int]())(0))
|
||||||
|
})
|
||||||
|
t.Run("one false", func(t *testing.T) {
|
||||||
|
require.False(t, All(alwaysTrue[int](), alwaysFalse[int]())(0))
|
||||||
|
})
|
||||||
|
t.Run("empty — always true", func(t *testing.T) {
|
||||||
|
require.True(t, All[int]()(0))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAny(t *testing.T) {
|
||||||
|
t.Run("one true", func(t *testing.T) {
|
||||||
|
require.True(t, Any(alwaysFalse[int](), alwaysTrue[int]())(0))
|
||||||
|
})
|
||||||
|
t.Run("all false", func(t *testing.T) {
|
||||||
|
require.False(t, Any(alwaysFalse[int](), alwaysFalse[int]())(0))
|
||||||
|
})
|
||||||
|
t.Run("empty — always false", func(t *testing.T) {
|
||||||
|
require.False(t, Any[int]()(0))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilter_Composition(t *testing.T) {
|
||||||
|
// (true AND false) OR true == true
|
||||||
|
f := alwaysTrue[int]().And(alwaysFalse[int]()).Or(alwaysTrue[int]())
|
||||||
|
require.True(t, f(0))
|
||||||
|
|
||||||
|
// NOT (true OR false) == false
|
||||||
|
g := alwaysTrue[int]().Or(alwaysFalse[int]()).Not()
|
||||||
|
require.False(t, g(0))
|
||||||
|
}
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
// Package callback provides Filter helpers for *api.CallbackQuery payloads.
|
||||||
|
package callback
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Data returns a Filter that matches callback queries whose Data matches
|
||||||
|
// pattern (regex). Panics at registration time on an invalid pattern.
|
||||||
|
func Data(pattern string) dispatch.Filter[*api.CallbackQuery] {
|
||||||
|
re := regexp.MustCompile(pattern)
|
||||||
|
return func(q *api.CallbackQuery) bool {
|
||||||
|
return q != nil && re.MatchString(q.Data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DataEquals returns a Filter that matches callback queries whose Data equals
|
||||||
|
// s exactly.
|
||||||
|
func DataEquals(s string) dispatch.Filter[*api.CallbackQuery] {
|
||||||
|
return func(q *api.CallbackQuery) bool {
|
||||||
|
return q != nil && q.Data == s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DataPrefix returns a Filter that matches callback queries whose Data starts
|
||||||
|
// with prefix.
|
||||||
|
func DataPrefix(prefix string) dispatch.Filter[*api.CallbackQuery] {
|
||||||
|
return func(q *api.CallbackQuery) bool {
|
||||||
|
return q != nil && strings.HasPrefix(q.Data, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromUser returns a Filter that matches callback queries whose From.ID equals
|
||||||
|
// userID.
|
||||||
|
func FromUser(userID int64) dispatch.Filter[*api.CallbackQuery] {
|
||||||
|
return func(q *api.CallbackQuery) bool {
|
||||||
|
return q != nil && q.From.ID == userID
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package callback_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
cbfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/callback"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func cq(data string, userID int64) *api.CallbackQuery {
|
||||||
|
return &api.CallbackQuery{
|
||||||
|
ID: "q",
|
||||||
|
From: api.User{ID: userID},
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestData(t *testing.T) {
|
||||||
|
f := cbfilter.Data(`^like:\d+$`)
|
||||||
|
require.True(t, f(cq("like:42", 1)))
|
||||||
|
require.False(t, f(cq("dislike:42", 1)))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestData_PanicsOnBadPattern(t *testing.T) {
|
||||||
|
require.Panics(t, func() { cbfilter.Data(`[bad`) })
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDataEquals(t *testing.T) {
|
||||||
|
f := cbfilter.DataEquals("yes")
|
||||||
|
require.True(t, f(cq("yes", 1)))
|
||||||
|
require.False(t, f(cq("yes please", 1)))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDataPrefix(t *testing.T) {
|
||||||
|
f := cbfilter.DataPrefix("vote:")
|
||||||
|
require.True(t, f(cq("vote:up", 1)))
|
||||||
|
require.False(t, f(cq("novote:up", 1)))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromUser(t *testing.T) {
|
||||||
|
f := cbfilter.FromUser(7)
|
||||||
|
require.True(t, f(cq("data", 7)))
|
||||||
|
require.False(t, f(cq("data", 8)))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComposedCallbackFilters(t *testing.T) {
|
||||||
|
f := cbfilter.DataPrefix("vote:").And(cbfilter.FromUser(7))
|
||||||
|
require.True(t, f(cq("vote:up", 7)))
|
||||||
|
require.False(t, f(cq("vote:up", 8)))
|
||||||
|
require.False(t, f(cq("other", 7)))
|
||||||
|
}
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
// Package chatjoinrequest provides Filter helpers for *api.ChatJoinRequest payloads.
|
||||||
|
package chatjoinrequest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FromUser returns a Filter that matches join requests where the requesting
|
||||||
|
// user's ID equals uid.
|
||||||
|
func FromUser(uid int64) dispatch.Filter[*api.ChatJoinRequest] {
|
||||||
|
return func(r *api.ChatJoinRequest) bool {
|
||||||
|
return r != nil && r.From.ID == uid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InChat returns a Filter that matches join requests directed at the chat
|
||||||
|
// with the given chat ID.
|
||||||
|
func InChat(cid int64) dispatch.Filter[*api.ChatJoinRequest] {
|
||||||
|
return func(r *api.ChatJoinRequest) bool {
|
||||||
|
return r != nil && r.Chat.ID == cid
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
package chatjoinrequest_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
cjrfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/chatjoinrequest"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func joinRequest(fromID, chatID int64) *api.ChatJoinRequest {
|
||||||
|
return &api.ChatJoinRequest{
|
||||||
|
Chat: api.Chat{ID: chatID},
|
||||||
|
From: api.User{ID: fromID},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromUser_Matches(t *testing.T) {
|
||||||
|
f := cjrfilter.FromUser(10)
|
||||||
|
require.True(t, f(joinRequest(10, 100)))
|
||||||
|
require.False(t, f(joinRequest(99, 100)))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInChat_Matches(t *testing.T) {
|
||||||
|
f := cjrfilter.InChat(100)
|
||||||
|
require.True(t, f(joinRequest(10, 100)))
|
||||||
|
require.False(t, f(joinRequest(10, 200)))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComposedFilters(t *testing.T) {
|
||||||
|
f := cjrfilter.FromUser(10).And(cjrfilter.InChat(100))
|
||||||
|
require.True(t, f(joinRequest(10, 100)))
|
||||||
|
require.False(t, f(joinRequest(10, 200)))
|
||||||
|
require.False(t, f(joinRequest(99, 100)))
|
||||||
|
}
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
// Package chatmember provides Filter helpers for *api.ChatMemberUpdated payloads.
|
||||||
|
package chatmember
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewStatus returns a Filter that matches updates where the new chat member
|
||||||
|
// status equals s (e.g. "member", "administrator", "kicked", "left").
|
||||||
|
func NewStatus(s string) dispatch.Filter[*api.ChatMemberUpdated] {
|
||||||
|
return func(u *api.ChatMemberUpdated) bool {
|
||||||
|
if u == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch m := u.NewChatMember.(type) {
|
||||||
|
case *api.ChatMemberOwner:
|
||||||
|
return m.Status == s
|
||||||
|
case *api.ChatMemberAdministrator:
|
||||||
|
return m.Status == s
|
||||||
|
case *api.ChatMemberMember:
|
||||||
|
return m.Status == s
|
||||||
|
case *api.ChatMemberRestricted:
|
||||||
|
return m.Status == s
|
||||||
|
case *api.ChatMemberLeft:
|
||||||
|
return m.Status == s
|
||||||
|
case *api.ChatMemberBanned:
|
||||||
|
return m.Status == s
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromUser returns a Filter that matches updates where the acting user
|
||||||
|
// (From.ID) equals uid.
|
||||||
|
func FromUser(uid int64) dispatch.Filter[*api.ChatMemberUpdated] {
|
||||||
|
return func(u *api.ChatMemberUpdated) bool {
|
||||||
|
return u != nil && u.From.ID == uid
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,95 @@
|
|||||||
|
package chatmember_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
cmfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/chatmember"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func memberUpdate(status string, fromID int64) *api.ChatMemberUpdated {
|
||||||
|
var newMember api.ChatMember
|
||||||
|
switch status {
|
||||||
|
case "member":
|
||||||
|
newMember = &api.ChatMemberMember{Status: status}
|
||||||
|
case "administrator":
|
||||||
|
newMember = &api.ChatMemberAdministrator{Status: status}
|
||||||
|
case "kicked":
|
||||||
|
newMember = &api.ChatMemberBanned{Status: status}
|
||||||
|
case "left":
|
||||||
|
newMember = &api.ChatMemberLeft{Status: status}
|
||||||
|
default:
|
||||||
|
newMember = &api.ChatMemberMember{Status: status}
|
||||||
|
}
|
||||||
|
return &api.ChatMemberUpdated{
|
||||||
|
From: api.User{ID: fromID},
|
||||||
|
NewChatMember: newMember,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStatus_Matches(t *testing.T) {
|
||||||
|
f := cmfilter.NewStatus("member")
|
||||||
|
require.True(t, f(memberUpdate("member", 1)))
|
||||||
|
require.False(t, f(memberUpdate("kicked", 1)))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStatus_Administrator(t *testing.T) {
|
||||||
|
f := cmfilter.NewStatus("administrator")
|
||||||
|
require.True(t, f(memberUpdate("administrator", 1)))
|
||||||
|
require.False(t, f(memberUpdate("member", 1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStatus_Kicked(t *testing.T) {
|
||||||
|
f := cmfilter.NewStatus("kicked")
|
||||||
|
require.True(t, f(memberUpdate("kicked", 1)))
|
||||||
|
require.False(t, f(memberUpdate("left", 1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStatus_Left(t *testing.T) {
|
||||||
|
f := cmfilter.NewStatus("left")
|
||||||
|
require.True(t, f(memberUpdate("left", 1)))
|
||||||
|
require.False(t, f(memberUpdate("member", 1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromUser_Matches(t *testing.T) {
|
||||||
|
f := cmfilter.FromUser(42)
|
||||||
|
require.True(t, f(memberUpdate("member", 42)))
|
||||||
|
require.False(t, f(memberUpdate("member", 99)))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComposedFilters(t *testing.T) {
|
||||||
|
f := cmfilter.NewStatus("member").And(cmfilter.FromUser(7))
|
||||||
|
require.True(t, f(memberUpdate("member", 7)))
|
||||||
|
require.False(t, f(memberUpdate("member", 8)))
|
||||||
|
require.False(t, f(memberUpdate("kicked", 7)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStatus_Owner(t *testing.T) {
|
||||||
|
u := &api.ChatMemberUpdated{
|
||||||
|
From: api.User{ID: 1},
|
||||||
|
NewChatMember: &api.ChatMemberOwner{Status: "creator"},
|
||||||
|
}
|
||||||
|
require.True(t, cmfilter.NewStatus("creator")(u))
|
||||||
|
require.False(t, cmfilter.NewStatus("member")(u))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStatus_Restricted(t *testing.T) {
|
||||||
|
u := &api.ChatMemberUpdated{
|
||||||
|
From: api.User{ID: 1},
|
||||||
|
NewChatMember: &api.ChatMemberRestricted{Status: "restricted"},
|
||||||
|
}
|
||||||
|
require.True(t, cmfilter.NewStatus("restricted")(u))
|
||||||
|
require.False(t, cmfilter.NewStatus("member")(u))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStatus_UnknownType(t *testing.T) {
|
||||||
|
// nil NewChatMember → default branch → false
|
||||||
|
u := &api.ChatMemberUpdated{
|
||||||
|
From: api.User{ID: 1},
|
||||||
|
NewChatMember: nil,
|
||||||
|
}
|
||||||
|
require.False(t, cmfilter.NewStatus("member")(u))
|
||||||
|
}
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
// Package inline provides Filter helpers for *api.InlineQuery payloads.
|
||||||
|
package inline
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Query returns a Filter that matches inline queries whose Query field matches
|
||||||
|
// pattern (regex). Panics at registration time on an invalid pattern.
|
||||||
|
func Query(pattern string) dispatch.Filter[*api.InlineQuery] {
|
||||||
|
re := regexp.MustCompile(pattern)
|
||||||
|
return func(q *api.InlineQuery) bool {
|
||||||
|
return q != nil && re.MatchString(q.Query)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryEquals returns a Filter that matches inline queries whose Query equals
|
||||||
|
// s exactly.
|
||||||
|
func QueryEquals(s string) dispatch.Filter[*api.InlineQuery] {
|
||||||
|
return func(q *api.InlineQuery) bool {
|
||||||
|
return q != nil && q.Query == s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryPrefix returns a Filter that matches inline queries whose Query starts
|
||||||
|
// with prefix.
|
||||||
|
func QueryPrefix(prefix string) dispatch.Filter[*api.InlineQuery] {
|
||||||
|
return func(q *api.InlineQuery) bool {
|
||||||
|
return q != nil && strings.HasPrefix(q.Query, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
package inline_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
ilfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/inline"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func iq(query string) *api.InlineQuery {
|
||||||
|
return &api.InlineQuery{ID: "i", From: api.User{ID: 1}, Query: query}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuery(t *testing.T) {
|
||||||
|
f := ilfilter.Query(`^find`)
|
||||||
|
require.True(t, f(iq("find me")))
|
||||||
|
require.False(t, f(iq("search me")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuery_PanicsOnBadPattern(t *testing.T) {
|
||||||
|
require.Panics(t, func() { ilfilter.Query(`[bad`) })
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryEquals(t *testing.T) {
|
||||||
|
f := ilfilter.QueryEquals("exact")
|
||||||
|
require.True(t, f(iq("exact")))
|
||||||
|
require.False(t, f(iq("exact match")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryPrefix(t *testing.T) {
|
||||||
|
f := ilfilter.QueryPrefix("@user")
|
||||||
|
require.True(t, f(iq("@username")))
|
||||||
|
require.False(t, f(iq("no prefix")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComposedInlineFilters(t *testing.T) {
|
||||||
|
f := ilfilter.QueryPrefix("find").Or(ilfilter.QueryEquals("help"))
|
||||||
|
require.True(t, f(iq("find me")))
|
||||||
|
require.True(t, f(iq("help")))
|
||||||
|
require.False(t, f(iq("other")))
|
||||||
|
}
|
||||||
@@ -0,0 +1,142 @@
|
|||||||
|
// Package message provides Filter helpers for *api.Message payloads.
|
||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Text returns a Filter that matches messages whose Text matches pattern (regex).
|
||||||
|
// Panics at registration time on an invalid pattern.
|
||||||
|
func Text(pattern string) dispatch.Filter[*api.Message] {
|
||||||
|
re := regexp.MustCompile(pattern)
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
return m != nil && re.MatchString(m.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextEquals returns a Filter that matches messages whose Text equals s exactly.
|
||||||
|
func TextEquals(s string) dispatch.Filter[*api.Message] {
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
return m != nil && m.Text == s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextPrefix returns a Filter that matches messages whose Text starts with prefix.
|
||||||
|
func TextPrefix(prefix string) dispatch.Filter[*api.Message] {
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
return m != nil && strings.HasPrefix(m.Text, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextContains returns a Filter that matches messages whose Text contains sub.
|
||||||
|
func TextContains(sub string) dispatch.Filter[*api.Message] {
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
return m != nil && strings.Contains(m.Text, sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Command returns a Filter that matches messages whose first entity is a
|
||||||
|
// bot_command equal to "/<name>" (with or without "@BotName" suffix).
|
||||||
|
func Command(name string) dispatch.Filter[*api.Message] {
|
||||||
|
want := "/" + strings.TrimPrefix(name, "/")
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
if m == nil || len(m.Entities) == 0 || m.Text == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
first := m.Entities[0]
|
||||||
|
if first.Type != string(api.EntityBotCommand) || first.Offset != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
end := int(first.Length)
|
||||||
|
runes := []rune(m.Text)
|
||||||
|
if end > len(runes) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
cmd := string(runes[:end])
|
||||||
|
if i := strings.Index(cmd, "@"); i >= 0 {
|
||||||
|
cmd = cmd[:i]
|
||||||
|
}
|
||||||
|
return cmd == want
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AnyCommand returns a Filter that matches any message starting with a
|
||||||
|
// bot_command entity at offset 0.
|
||||||
|
func AnyCommand() dispatch.Filter[*api.Message] {
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
if m == nil || len(m.Entities) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
first := m.Entities[0]
|
||||||
|
return first.Type == string(api.EntityBotCommand) && first.Offset == 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsReply returns a Filter that matches messages that have ReplyToMessage set.
|
||||||
|
func IsReply() dispatch.Filter[*api.Message] {
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
return m != nil && m.ReplyToMessage != nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsForward returns a Filter that matches messages that have ForwardOrigin set.
|
||||||
|
func IsForward() dispatch.Filter[*api.Message] {
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
return m != nil && m.ForwardOrigin != nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasPhoto returns a Filter that matches messages with a Photo attachment.
|
||||||
|
func HasPhoto() dispatch.Filter[*api.Message] {
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
return m != nil && len(m.Photo) > 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasDocument returns a Filter that matches messages with a Document attachment.
|
||||||
|
func HasDocument() dispatch.Filter[*api.Message] {
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
return m != nil && m.Document != nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasEntity returns a Filter that matches messages whose Entities contain at
|
||||||
|
// least one entity of type t (e.g. string(api.EntityBotCommand)).
|
||||||
|
func HasEntity(t string) dispatch.Filter[*api.Message] {
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
if m == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, e := range m.Entities {
|
||||||
|
if e.Type == t {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatType returns a Filter that matches messages whose Chat.Type equals t.
|
||||||
|
func ChatType(t api.ChatType) dispatch.Filter[*api.Message] {
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
return m != nil && m.Chat.Type == string(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromUser returns a Filter that matches messages whose From.ID equals userID.
|
||||||
|
func FromUser(userID int64) dispatch.Filter[*api.Message] {
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
return m != nil && m.From != nil && m.From.ID == userID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InChat returns a Filter that matches messages whose Chat.ID equals chatID.
|
||||||
|
func InChat(chatID int64) dispatch.Filter[*api.Message] {
|
||||||
|
return func(m *api.Message) bool {
|
||||||
|
return m != nil && m.Chat.ID == chatID
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,188 @@
|
|||||||
|
package message_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
msgfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/message"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func msg(text string) *api.Message {
|
||||||
|
return &api.Message{
|
||||||
|
MessageID: 1,
|
||||||
|
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||||
|
Text: text,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdMsg(cmd string) *api.Message {
|
||||||
|
text := cmd
|
||||||
|
return &api.Message{
|
||||||
|
MessageID: 1,
|
||||||
|
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||||
|
Text: text,
|
||||||
|
Entities: []api.MessageEntity{
|
||||||
|
{Type: string(api.EntityBotCommand), Offset: 0, Length: int64(len([]rune(text)))},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestText(t *testing.T) {
|
||||||
|
f := msgfilter.Text(`^hello`)
|
||||||
|
require.True(t, f(msg("hello world")))
|
||||||
|
require.False(t, f(msg("world hello")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestText_PanicsOnBadPattern(t *testing.T) {
|
||||||
|
require.Panics(t, func() { msgfilter.Text(`[invalid`) })
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTextEquals(t *testing.T) {
|
||||||
|
f := msgfilter.TextEquals("hi")
|
||||||
|
require.True(t, f(msg("hi")))
|
||||||
|
require.False(t, f(msg("hi there")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTextPrefix(t *testing.T) {
|
||||||
|
f := msgfilter.TextPrefix("/start")
|
||||||
|
require.True(t, f(msg("/start now")))
|
||||||
|
require.False(t, f(msg("no prefix")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTextContains(t *testing.T) {
|
||||||
|
f := msgfilter.TextContains("bot")
|
||||||
|
require.True(t, f(msg("my bot is cool")))
|
||||||
|
require.False(t, f(msg("nothing here")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommand(t *testing.T) {
|
||||||
|
t.Run("matches exact command", func(t *testing.T) {
|
||||||
|
f := msgfilter.Command("/start")
|
||||||
|
require.True(t, f(cmdMsg("/start")))
|
||||||
|
})
|
||||||
|
t.Run("matches without leading slash", func(t *testing.T) {
|
||||||
|
f := msgfilter.Command("start")
|
||||||
|
require.True(t, f(cmdMsg("/start")))
|
||||||
|
})
|
||||||
|
t.Run("strips BotName suffix", func(t *testing.T) {
|
||||||
|
m := &api.Message{
|
||||||
|
Text: "/start@MyBot",
|
||||||
|
Entities: []api.MessageEntity{{Type: string(api.EntityBotCommand), Offset: 0, Length: 12}},
|
||||||
|
}
|
||||||
|
f := msgfilter.Command("/start")
|
||||||
|
require.True(t, f(m))
|
||||||
|
})
|
||||||
|
t.Run("no match different command", func(t *testing.T) {
|
||||||
|
f := msgfilter.Command("/stop")
|
||||||
|
require.False(t, f(cmdMsg("/start")))
|
||||||
|
})
|
||||||
|
t.Run("nil message", func(t *testing.T) {
|
||||||
|
require.False(t, msgfilter.Command("/start")(nil))
|
||||||
|
})
|
||||||
|
t.Run("no entities", func(t *testing.T) {
|
||||||
|
require.False(t, msgfilter.Command("/start")(msg("/start")))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnyCommand(t *testing.T) {
|
||||||
|
f := msgfilter.AnyCommand()
|
||||||
|
require.True(t, f(cmdMsg("/anything")))
|
||||||
|
require.False(t, f(msg("plain text")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsReply(t *testing.T) {
|
||||||
|
f := msgfilter.IsReply()
|
||||||
|
m := msg("reply")
|
||||||
|
m.ReplyToMessage = &api.Message{MessageID: 2}
|
||||||
|
require.True(t, f(m))
|
||||||
|
require.False(t, f(msg("no reply")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsForward(t *testing.T) {
|
||||||
|
// ForwardOrigin is a MessageOrigin interface; set via a concrete type.
|
||||||
|
f := msgfilter.IsForward()
|
||||||
|
m := msg("fwd")
|
||||||
|
m.ForwardOrigin = &api.MessageOriginUser{Type: "user"}
|
||||||
|
require.True(t, f(m))
|
||||||
|
require.False(t, f(msg("no fwd")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasPhoto(t *testing.T) {
|
||||||
|
f := msgfilter.HasPhoto()
|
||||||
|
m := msg("")
|
||||||
|
m.Photo = []api.PhotoSize{{FileID: "x", Width: 100, Height: 100}}
|
||||||
|
require.True(t, f(m))
|
||||||
|
require.False(t, f(msg("no photo")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasDocument(t *testing.T) {
|
||||||
|
f := msgfilter.HasDocument()
|
||||||
|
m := msg("")
|
||||||
|
m.Document = &api.Document{FileID: "doc1"}
|
||||||
|
require.True(t, f(m))
|
||||||
|
require.False(t, f(msg("no doc")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasEntity(t *testing.T) {
|
||||||
|
f := msgfilter.HasEntity(string(api.EntityURL))
|
||||||
|
m := msg("check https://example.com")
|
||||||
|
m.Entities = []api.MessageEntity{{Type: string(api.EntityURL), Offset: 6, Length: 19}}
|
||||||
|
require.True(t, f(m))
|
||||||
|
require.False(t, f(msg("plain")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatType(t *testing.T) {
|
||||||
|
f := msgfilter.ChatType(api.ChatTypePrivate)
|
||||||
|
private := msg("hi")
|
||||||
|
require.True(t, f(private))
|
||||||
|
|
||||||
|
group := msg("hi")
|
||||||
|
group.Chat.Type = string(api.ChatTypeGroup)
|
||||||
|
require.False(t, f(group))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromUser(t *testing.T) {
|
||||||
|
f := msgfilter.FromUser(42)
|
||||||
|
m := msg("hi")
|
||||||
|
m.From = &api.User{ID: 42}
|
||||||
|
require.True(t, f(m))
|
||||||
|
|
||||||
|
m2 := msg("hi")
|
||||||
|
m2.From = &api.User{ID: 99}
|
||||||
|
require.False(t, f(m2))
|
||||||
|
|
||||||
|
require.False(t, f(msg("no from")))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInChat(t *testing.T) {
|
||||||
|
f := msgfilter.InChat(1)
|
||||||
|
require.True(t, f(msg("hi")))
|
||||||
|
m2 := msg("hi")
|
||||||
|
m2.Chat.ID = 2
|
||||||
|
require.False(t, f(m2))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComposedMessageFilters(t *testing.T) {
|
||||||
|
// private chat AND contains "hello"
|
||||||
|
f := msgfilter.ChatType(api.ChatTypePrivate).And(msgfilter.TextContains("hello"))
|
||||||
|
m := msg("say hello")
|
||||||
|
require.True(t, f(m))
|
||||||
|
|
||||||
|
m2 := msg("say hello")
|
||||||
|
m2.Chat.Type = string(api.ChatTypeGroup)
|
||||||
|
require.False(t, f(m2))
|
||||||
|
}
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
// Package precheckoutquery provides Filter helpers for *api.PreCheckoutQuery payloads.
|
||||||
|
package precheckoutquery
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Currency returns a Filter that matches pre-checkout queries with the given
|
||||||
|
// ISO 4217 currency code (e.g. "USD", "EUR", "XTR").
|
||||||
|
func Currency(c string) dispatch.Filter[*api.PreCheckoutQuery] {
|
||||||
|
return func(q *api.PreCheckoutQuery) bool {
|
||||||
|
return q != nil && q.Currency == c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromUser returns a Filter that matches pre-checkout queries sent by the
|
||||||
|
// user with the given ID.
|
||||||
|
func FromUser(uid int64) dispatch.Filter[*api.PreCheckoutQuery] {
|
||||||
|
return func(q *api.PreCheckoutQuery) bool {
|
||||||
|
return q != nil && q.From.ID == uid
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
package precheckoutquery_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
pcqfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/precheckoutquery"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func pcq(currency string, fromID int64) *api.PreCheckoutQuery {
|
||||||
|
return &api.PreCheckoutQuery{
|
||||||
|
ID: "q",
|
||||||
|
Currency: currency,
|
||||||
|
From: api.User{ID: fromID},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCurrency_Matches(t *testing.T) {
|
||||||
|
f := pcqfilter.Currency("USD")
|
||||||
|
require.True(t, f(pcq("USD", 1)))
|
||||||
|
require.False(t, f(pcq("EUR", 1)))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromUser_Matches(t *testing.T) {
|
||||||
|
f := pcqfilter.FromUser(5)
|
||||||
|
require.True(t, f(pcq("USD", 5)))
|
||||||
|
require.False(t, f(pcq("USD", 9)))
|
||||||
|
require.False(t, f(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComposedFilters(t *testing.T) {
|
||||||
|
f := pcqfilter.Currency("XTR").And(pcqfilter.FromUser(42))
|
||||||
|
require.True(t, f(pcq("XTR", 42)))
|
||||||
|
require.False(t, f(pcq("XTR", 99)))
|
||||||
|
require.False(t, f(pcq("USD", 42)))
|
||||||
|
}
|
||||||
@@ -0,0 +1,186 @@
|
|||||||
|
package dispatch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrEndGroups stops dispatch from running any further handlers in any
|
||||||
|
// group for this update when returned by a handler. Use it to indicate
|
||||||
|
// the update has been definitively handled.
|
||||||
|
//
|
||||||
|
// errors.Is(err, ErrEndGroups) is the canonical check, though dispatch
|
||||||
|
// itself recognises it by exact identity.
|
||||||
|
var ErrEndGroups = errors.New("dispatch: end groups")
|
||||||
|
|
||||||
|
// ErrContinueGroups signals that this group's handler should be treated
|
||||||
|
// as not-matching when returned by a handler: dispatch moves on to the
|
||||||
|
// next handler in the same group, then to subsequent groups.
|
||||||
|
//
|
||||||
|
// Without ErrContinueGroups, a non-error return from a matched handler
|
||||||
|
// stops dispatch (default first-match-wins semantics).
|
||||||
|
var ErrContinueGroups = errors.New("dispatch: continue groups")
|
||||||
|
|
||||||
|
// RouterScope registers handlers into a specific priority group on its parent
|
||||||
|
// Router. Group 0 runs first, then group 1, etc. Within a group, handlers run
|
||||||
|
// in registration order; the first non-skipped match terminates dispatch
|
||||||
|
// unless the handler returns ErrContinueGroups.
|
||||||
|
type RouterScope struct {
|
||||||
|
router *Router
|
||||||
|
group int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group returns a RouterScope that registers handlers in the given group.
|
||||||
|
// Group 0 (the default) runs first, then group 1, etc. Within a group,
|
||||||
|
// handlers run in registration order; the first non-skipped match
|
||||||
|
// terminates dispatch unless the handler returns ErrContinueGroups.
|
||||||
|
func (r *Router) Group(group int) *RouterScope {
|
||||||
|
return &RouterScope{router: r, group: group}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnCommand registers a command handler in this group.
|
||||||
|
func (s *RouterScope) OnCommand(cmd string, h Handler[*api.Message]) {
|
||||||
|
s.router.groupCommands = append(s.router.groupCommands, groupCommandRoute{
|
||||||
|
cmd: cmd, group: s.group, handler: h,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnText registers a regex text handler in this group.
|
||||||
|
// Panics at registration time if pattern is not a valid regular expression.
|
||||||
|
func (s *RouterScope) OnText(pattern string, h Handler[*api.Message]) {
|
||||||
|
s.router.groupTexts = append(s.router.groupTexts, groupTextRoute{
|
||||||
|
re: regexp.MustCompile(pattern), group: s.group, handler: h,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnMessageFilter registers a filter-based message handler in this group.
|
||||||
|
func (s *RouterScope) OnMessageFilter(f Filter[*api.Message], h Handler[*api.Message]) {
|
||||||
|
s.router.groupMessageFilters = append(s.router.groupMessageFilters, groupMessageFilterRoute{
|
||||||
|
filter: f, group: s.group, handler: h,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// group-aware route types
|
||||||
|
|
||||||
|
type groupCommandRoute struct {
|
||||||
|
cmd string
|
||||||
|
group int
|
||||||
|
handler Handler[*api.Message]
|
||||||
|
}
|
||||||
|
|
||||||
|
type groupTextRoute struct {
|
||||||
|
re *regexp.Regexp
|
||||||
|
group int
|
||||||
|
handler Handler[*api.Message]
|
||||||
|
}
|
||||||
|
|
||||||
|
type groupMessageFilterRoute struct {
|
||||||
|
filter Filter[*api.Message]
|
||||||
|
group int
|
||||||
|
handler Handler[*api.Message]
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatchGroups runs message handlers registered via RouterScope.Group().
|
||||||
|
// It collects all matching groups, sorts by group number, and applies
|
||||||
|
// first-match-wins semantics within each group. Handlers may return
|
||||||
|
// ErrContinueGroups (skip to next handler/group) or ErrEndGroups (stop all groups).
|
||||||
|
// A non-sentinel error stops dispatch and is returned to the caller.
|
||||||
|
func (r *Router) dispatchGroups(c *Context, m *api.Message) error {
|
||||||
|
// Collect group numbers present.
|
||||||
|
groupSet := map[int]struct{}{}
|
||||||
|
for _, gr := range r.groupCommands {
|
||||||
|
groupSet[gr.group] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, gr := range r.groupTexts {
|
||||||
|
groupSet[gr.group] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, gr := range r.groupMessageFilters {
|
||||||
|
groupSet[gr.group] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(groupSet) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
groups := make([]int, 0, len(groupSet))
|
||||||
|
for g := range groupSet {
|
||||||
|
groups = append(groups, g)
|
||||||
|
}
|
||||||
|
sort.Ints(groups)
|
||||||
|
|
||||||
|
for _, g := range groups {
|
||||||
|
matched, err := r.runGroupHandlers(c, m, g)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrEndGroups) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if matched {
|
||||||
|
// First-match-wins: stop further groups.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// No match or ErrContinueGroups from all handlers: try next group.
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runGroupHandlers runs all handlers in group g against m, in registration
|
||||||
|
// order. Returns (true, nil) when a handler matched (returned nil). Returns
|
||||||
|
// (false, nil) when all handlers returned ErrContinueGroups. Returns
|
||||||
|
// (false, err) for ErrEndGroups or any non-sentinel error.
|
||||||
|
func (r *Router) runGroupHandlers(c *Context, m *api.Message, g int) (matched bool, err error) {
|
||||||
|
// Commands.
|
||||||
|
if cmd, args, ok := extractCommand(m); ok {
|
||||||
|
for _, route := range r.groupCommands {
|
||||||
|
if route.group != g || route.cmd != cmd {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.Values["command"] = cmd
|
||||||
|
c.Values["command_args"] = args
|
||||||
|
if err := route.handler(c, m); err != nil {
|
||||||
|
if errors.Is(err, ErrContinueGroups) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Text regex.
|
||||||
|
if m.Text != "" {
|
||||||
|
for _, route := range r.groupTexts {
|
||||||
|
if route.group != g {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
subs := route.re.FindStringSubmatch(m.Text)
|
||||||
|
if subs == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.Values["regex_match"] = subs
|
||||||
|
if err := route.handler(c, m); err != nil {
|
||||||
|
if errors.Is(err, ErrContinueGroups) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Filter-based.
|
||||||
|
for _, route := range r.groupMessageFilters {
|
||||||
|
if route.group != g || !route.filter(m) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := route.handler(c, m); err != nil {
|
||||||
|
if errors.Is(err, ErrContinueGroups) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
package dispatch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/client"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// msgUpdate builds a simple private message update.
|
||||||
|
func msgUpdate(id int64, text string) api.Update {
|
||||||
|
return api.Update{
|
||||||
|
UpdateID: id,
|
||||||
|
Message: &api.Message{
|
||||||
|
MessageID: id,
|
||||||
|
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||||
|
Text: text,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cmdUpdate builds a command message update.
|
||||||
|
func cmdUpdate(id int64, cmd string) api.Update {
|
||||||
|
return api.Update{
|
||||||
|
UpdateID: id,
|
||||||
|
Message: &api.Message{
|
||||||
|
MessageID: id,
|
||||||
|
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||||
|
Text: cmd,
|
||||||
|
Entities: []api.MessageEntity{
|
||||||
|
{Type: string(api.EntityBotCommand), Offset: 0, Length: int64(len(cmd))},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runSingle fires one update through the router and waits for it to complete.
|
||||||
|
func runSingle(t *testing.T, r *Router, up api.Update) {
|
||||||
|
t.Helper()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_ = r.Run(ctx, newFake(up))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroup_Order verifies group 0 fires before group 1.
|
||||||
|
func TestGroup_Order(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
var order []int
|
||||||
|
|
||||||
|
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||||
|
order = append(order, 0)
|
||||||
|
return ErrContinueGroups // let group 1 also run
|
||||||
|
})
|
||||||
|
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||||
|
order = append(order, 1)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
runSingle(t, r, msgUpdate(1, "hello"))
|
||||||
|
require.Equal(t, []int{0, 1}, order)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroup_FirstMatchWins verifies group 0 match stops group 1 by default.
|
||||||
|
func TestGroup_FirstMatchWins(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
var fired []int
|
||||||
|
|
||||||
|
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||||
|
fired = append(fired, 0)
|
||||||
|
return nil // matched — group 1 must NOT run
|
||||||
|
})
|
||||||
|
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||||
|
fired = append(fired, 1)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
runSingle(t, r, msgUpdate(1, "hello"))
|
||||||
|
require.Equal(t, []int{0}, fired)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroup_ErrContinueGroups lets group 1 run when group 0 returns ErrContinueGroups.
|
||||||
|
func TestGroup_ErrContinueGroups(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
g1Hit := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||||
|
return ErrContinueGroups
|
||||||
|
})
|
||||||
|
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||||
|
g1Hit <- struct{}{}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(msgUpdate(1, "ping"))) }()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-g1Hit:
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Fatal("group 1 handler never fired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroup_ErrEndGroups stops all further groups.
|
||||||
|
func TestGroup_ErrEndGroups(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
var fired []int
|
||||||
|
|
||||||
|
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||||
|
fired = append(fired, 0)
|
||||||
|
return ErrEndGroups
|
||||||
|
})
|
||||||
|
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||||
|
fired = append(fired, 1)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
runSingle(t, r, msgUpdate(1, "hello"))
|
||||||
|
require.Equal(t, []int{0}, fired)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroup_NonSentinelError propagates error and stops further groups.
|
||||||
|
func TestGroup_NonSentinelError(t *testing.T) {
|
||||||
|
r := New(client.New("t"), WithMaxConcurrency(0))
|
||||||
|
var fired []int
|
||||||
|
|
||||||
|
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||||
|
fired = append(fired, 0)
|
||||||
|
return context.DeadlineExceeded // non-sentinel real error
|
||||||
|
})
|
||||||
|
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||||
|
fired = append(fired, 1)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
runSingle(t, r, msgUpdate(1, "hello"))
|
||||||
|
// group 1 must not fire
|
||||||
|
require.Equal(t, []int{0}, fired)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroup_Command verifies OnCommand in a group works.
|
||||||
|
func TestGroup_Command(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
|
||||||
|
r.Group(0).OnCommand("/start", func(c *Context, m *api.Message) error {
|
||||||
|
hit <- "g0-start"
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
r.Group(1).OnCommand("/start", func(c *Context, m *api.Message) error {
|
||||||
|
hit <- "g1-start"
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(cmdUpdate(1, "/start"))) }()
|
||||||
|
|
||||||
|
got := <-hit
|
||||||
|
require.Equal(t, "g0-start", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroup_MessageFilter verifies OnMessageFilter in a group works.
|
||||||
|
func TestGroup_MessageFilter(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan bool, 1)
|
||||||
|
|
||||||
|
r.Group(0).OnMessageFilter(
|
||||||
|
Filter[*api.Message](func(m *api.Message) bool { return m != nil && m.Text == "ok" }),
|
||||||
|
func(c *Context, m *api.Message) error {
|
||||||
|
hit <- true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(msgUpdate(1, "ok"))) }()
|
||||||
|
|
||||||
|
require.True(t, <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGroup_ErrContinueGroups_WithCommand verifies ErrContinueGroups works for commands across groups.
|
||||||
|
func TestGroup_ErrContinueGroups_WithCommand(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
var count atomic.Int32
|
||||||
|
|
||||||
|
r.Group(0).OnCommand("/ping", func(c *Context, m *api.Message) error {
|
||||||
|
count.Add(1)
|
||||||
|
return ErrContinueGroups
|
||||||
|
})
|
||||||
|
r.Group(1).OnCommand("/ping", func(c *Context, m *api.Message) error {
|
||||||
|
count.Add(10)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(cmdUpdate(1, "/ping"))) }()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
require.Equal(t, int32(11), count.Load())
|
||||||
|
}
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
package dispatch
|
||||||
|
|
||||||
|
// Handler is a generic handler over update payload type T. T is typically
|
||||||
|
// *api.Message, *api.CallbackQuery, *api.InlineQuery, or *api.Update for
|
||||||
|
// global middleware.
|
||||||
|
type Handler[T any] func(ctx *Context, payload T) error
|
||||||
|
|
||||||
|
// Middleware wraps a Handler[T] with cross-cutting behaviour (logging,
|
||||||
|
// recovery, auth). Middleware composition is left-to-right: Use(a,b,c)
|
||||||
|
// runs as a(b(c(handler))).
|
||||||
|
type Middleware[T any] func(Handler[T]) Handler[T]
|
||||||
|
|
||||||
|
// Chain composes a slice of middleware into a single Middleware[T].
|
||||||
|
func Chain[T any](mws ...Middleware[T]) Middleware[T] {
|
||||||
|
return func(h Handler[T]) Handler[T] {
|
||||||
|
for i := len(mws) - 1; i >= 0; i-- {
|
||||||
|
h = mws[i](h)
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
package dispatch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime/debug"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Recovery returns middleware that recovers from panics in downstream
|
||||||
|
// handlers, converting them into a returned error and logging via the
|
||||||
|
// bot's configured logger. Registered automatically by NewRouter.
|
||||||
|
func Recovery() Middleware[*api.Update] {
|
||||||
|
return func(next Handler[*api.Update]) Handler[*api.Update] {
|
||||||
|
return func(c *Context, u *api.Update) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("panic in handler: %v\n%s", r, debug.Stack())
|
||||||
|
if c.Bot != nil {
|
||||||
|
c.Bot.Logger().Error("dispatch recovered panic", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return next(c, u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,101 @@
|
|||||||
|
package dispatch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NamedHandlers manages handlers by string name, allowing runtime
|
||||||
|
// registration, replacement, and removal. This complements the Router's
|
||||||
|
// registration methods: each registration via Named*() also gets a name
|
||||||
|
// for later lookup.
|
||||||
|
//
|
||||||
|
// Use case: a plugin system that loads/unloads command handlers without
|
||||||
|
// restarting the bot.
|
||||||
|
type NamedHandlers[T any] struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
handlers map[string]Handler[T]
|
||||||
|
order []string // preserves registration order
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNamedHandlers returns a new, empty NamedHandlers[T].
|
||||||
|
func NewNamedHandlers[T any]() *NamedHandlers[T] {
|
||||||
|
return &NamedHandlers[T]{handlers: map[string]Handler[T]{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set registers or replaces the handler under name. If name is new, it is
|
||||||
|
// appended to the end of the registration order.
|
||||||
|
func (n *NamedHandlers[T]) Set(name string, h Handler[T]) {
|
||||||
|
n.mu.Lock()
|
||||||
|
defer n.mu.Unlock()
|
||||||
|
if _, exists := n.handlers[name]; !exists {
|
||||||
|
n.order = append(n.order, name)
|
||||||
|
}
|
||||||
|
n.handlers[name] = h
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove unregisters the handler under name. Returns true if it existed.
|
||||||
|
func (n *NamedHandlers[T]) Remove(name string) bool {
|
||||||
|
n.mu.Lock()
|
||||||
|
defer n.mu.Unlock()
|
||||||
|
if _, ok := n.handlers[name]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
delete(n.handlers, name)
|
||||||
|
for i, k := range n.order {
|
||||||
|
if k == name {
|
||||||
|
n.order = append(n.order[:i], n.order[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Has reports whether name is registered.
|
||||||
|
func (n *NamedHandlers[T]) Has(name string) bool {
|
||||||
|
n.mu.RLock()
|
||||||
|
defer n.mu.RUnlock()
|
||||||
|
_, ok := n.handlers[name]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Names returns the registered names in registration order.
|
||||||
|
func (n *NamedHandlers[T]) Names() []string {
|
||||||
|
n.mu.RLock()
|
||||||
|
defer n.mu.RUnlock()
|
||||||
|
out := make([]string, len(n.order))
|
||||||
|
copy(out, n.order)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler returns a single Handler[T] that runs each registered handler
|
||||||
|
// in registration order, first non-nil error stops the chain. Use this
|
||||||
|
// to wire NamedHandlers into a Router.OnXxx call:
|
||||||
|
//
|
||||||
|
// names := dispatch.NewNamedHandlers[*api.Message]()
|
||||||
|
// names.Set("logger", loggingHandler)
|
||||||
|
// names.Set("audit", auditHandler)
|
||||||
|
// router.OnCommand("/admin", names.Handler())
|
||||||
|
//
|
||||||
|
// Subsequent Set/Remove calls take effect on the next dispatch.
|
||||||
|
func (n *NamedHandlers[T]) Handler() Handler[T] {
|
||||||
|
return func(c *Context, payload T) error {
|
||||||
|
n.mu.RLock()
|
||||||
|
names := make([]string, len(n.order))
|
||||||
|
copy(names, n.order)
|
||||||
|
n.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, name := range names {
|
||||||
|
n.mu.RLock()
|
||||||
|
h, ok := n.handlers[name]
|
||||||
|
n.mu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := h(c, payload); err != nil {
|
||||||
|
return fmt.Errorf("named handler %q: %w", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,153 @@
|
|||||||
|
package dispatch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// makeMsg returns a minimal *api.Message for use in handler tests.
|
||||||
|
func makeMsg() *api.Message {
|
||||||
|
return &api.Message{MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeCtx returns a minimal *Context (nil bot is fine for unit tests).
|
||||||
|
func makeCtx() *Context {
|
||||||
|
return NewContext(context.Background(), nil, &api.Update{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamedHandlers_SetAndHas(t *testing.T) {
|
||||||
|
n := NewNamedHandlers[*api.Message]()
|
||||||
|
require.False(t, n.Has("a"))
|
||||||
|
n.Set("a", func(c *Context, m *api.Message) error { return nil })
|
||||||
|
require.True(t, n.Has("a"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamedHandlers_Names_RegistrationOrder(t *testing.T) {
|
||||||
|
n := NewNamedHandlers[*api.Message]()
|
||||||
|
n.Set("first", func(c *Context, m *api.Message) error { return nil })
|
||||||
|
n.Set("second", func(c *Context, m *api.Message) error { return nil })
|
||||||
|
n.Set("third", func(c *Context, m *api.Message) error { return nil })
|
||||||
|
require.Equal(t, []string{"first", "second", "third"}, n.Names())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamedHandlers_Remove(t *testing.T) {
|
||||||
|
n := NewNamedHandlers[*api.Message]()
|
||||||
|
n.Set("a", func(c *Context, m *api.Message) error { return nil })
|
||||||
|
n.Set("b", func(c *Context, m *api.Message) error { return nil })
|
||||||
|
|
||||||
|
removed := n.Remove("a")
|
||||||
|
require.True(t, removed)
|
||||||
|
require.False(t, n.Has("a"))
|
||||||
|
require.Equal(t, []string{"b"}, n.Names())
|
||||||
|
|
||||||
|
// Remove non-existent returns false.
|
||||||
|
require.False(t, n.Remove("nonexistent"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamedHandlers_Replacement_SameOrderSlot(t *testing.T) {
|
||||||
|
n := NewNamedHandlers[*api.Message]()
|
||||||
|
n.Set("a", func(c *Context, m *api.Message) error { return nil })
|
||||||
|
n.Set("b", func(c *Context, m *api.Message) error { return nil })
|
||||||
|
|
||||||
|
var called string
|
||||||
|
n.Set("a", func(c *Context, m *api.Message) error {
|
||||||
|
called = "replaced-a"
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Order must not change; "a" stays first.
|
||||||
|
require.Equal(t, []string{"a", "b"}, n.Names())
|
||||||
|
|
||||||
|
h := n.Handler()
|
||||||
|
_ = h(makeCtx(), makeMsg())
|
||||||
|
require.Equal(t, "replaced-a", called)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamedHandlers_Handler_RunsInOrder(t *testing.T) {
|
||||||
|
n := NewNamedHandlers[*api.Message]()
|
||||||
|
var calls []string
|
||||||
|
|
||||||
|
n.Set("first", func(c *Context, m *api.Message) error {
|
||||||
|
calls = append(calls, "first")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
n.Set("second", func(c *Context, m *api.Message) error {
|
||||||
|
calls = append(calls, "second")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
h := n.Handler()
|
||||||
|
require.NoError(t, h(makeCtx(), makeMsg()))
|
||||||
|
require.Equal(t, []string{"first", "second"}, calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamedHandlers_Handler_ErrorWrappedAndStops(t *testing.T) {
|
||||||
|
n := NewNamedHandlers[*api.Message]()
|
||||||
|
sentinel := errors.New("boom")
|
||||||
|
|
||||||
|
n.Set("ok", func(c *Context, m *api.Message) error { return nil })
|
||||||
|
n.Set("fail", func(c *Context, m *api.Message) error { return sentinel })
|
||||||
|
n.Set("never", func(c *Context, m *api.Message) error {
|
||||||
|
t.Fatal("should not be called after an error")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
h := n.Handler()
|
||||||
|
err := h(makeCtx(), makeMsg())
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, errors.Is(err, sentinel))
|
||||||
|
require.Contains(t, err.Error(), `named handler "fail"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamedHandlers_Concurrent_SetRemove(t *testing.T) {
|
||||||
|
n := NewNamedHandlers[*api.Message]()
|
||||||
|
|
||||||
|
// Pre-populate so Handler() has something to iterate.
|
||||||
|
for i := range 5 {
|
||||||
|
name := fmt.Sprintf("h%d", i)
|
||||||
|
n.Set(name, func(c *Context, m *api.Message) error { return nil })
|
||||||
|
}
|
||||||
|
|
||||||
|
h := n.Handler()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Concurrent readers (invoke handler).
|
||||||
|
for range 20 {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = h(makeCtx(), makeMsg())
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concurrent writers.
|
||||||
|
for i := range 5 {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
name := fmt.Sprintf("new%d", i)
|
||||||
|
n.Set(name, func(c *Context, m *api.Message) error { return nil })
|
||||||
|
n.Remove(fmt.Sprintf("h%d", i))
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamedHandlers_RemoveAndReinstate(t *testing.T) {
|
||||||
|
n := NewNamedHandlers[*api.Message]()
|
||||||
|
n.Set("a", func(c *Context, m *api.Message) error { return nil })
|
||||||
|
n.Remove("a")
|
||||||
|
require.False(t, n.Has("a"))
|
||||||
|
|
||||||
|
// Re-register after removal; should be added at end.
|
||||||
|
n.Set("b", func(c *Context, m *api.Message) error { return nil })
|
||||||
|
n.Set("a", func(c *Context, m *api.Message) error { return nil })
|
||||||
|
require.Equal(t, []string{"b", "a"}, n.Names())
|
||||||
|
}
|
||||||
@@ -0,0 +1,582 @@
|
|||||||
|
package dispatch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/client"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/transport"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Router dispatches updates from any Updater to typed handlers.
|
||||||
|
//
|
||||||
|
// Matchers run in registration order; first match wins. A panic-recovery
|
||||||
|
// middleware is attached automatically and runs around every dispatch.
|
||||||
|
type Router struct {
|
||||||
|
bot *client.Bot
|
||||||
|
|
||||||
|
commands []commandRoute
|
||||||
|
texts []textRoute
|
||||||
|
callbacks []callbackRoute
|
||||||
|
inlines []Handler[*api.InlineQuery]
|
||||||
|
editedMsg []Handler[*api.Message]
|
||||||
|
channelPosts []Handler[*api.Message]
|
||||||
|
editedChannelPosts []Handler[*api.Message]
|
||||||
|
|
||||||
|
messageFilters []messageFilterRoute
|
||||||
|
callbackFilters []callbackFilterRoute
|
||||||
|
inlineFilters []inlineFilterRoute
|
||||||
|
|
||||||
|
// typed update handlers
|
||||||
|
myChatMember []Handler[*api.ChatMemberUpdated]
|
||||||
|
chatMember []Handler[*api.ChatMemberUpdated]
|
||||||
|
chatJoinRequest []Handler[*api.ChatJoinRequest]
|
||||||
|
preCheckoutQuery []Handler[*api.PreCheckoutQuery]
|
||||||
|
shippingQuery []Handler[*api.ShippingQuery]
|
||||||
|
poll []Handler[*api.Poll]
|
||||||
|
pollAnswer []Handler[*api.PollAnswer]
|
||||||
|
chosenInlineResult []Handler[*api.ChosenInlineResult]
|
||||||
|
messageReaction []Handler[*api.MessageReactionUpdated]
|
||||||
|
messageReactionCnt []Handler[*api.MessageReactionCountUpdated]
|
||||||
|
chatBoost []Handler[*api.ChatBoostUpdated]
|
||||||
|
removedChatBoost []Handler[*api.ChatBoostRemoved]
|
||||||
|
businessConn []Handler[*api.BusinessConnection]
|
||||||
|
purchasedPaidMedia []Handler[*api.PaidMediaPurchased]
|
||||||
|
|
||||||
|
myChatMemberFilters []chatMemberFilterRoute
|
||||||
|
chatMemberFilters []chatMemberFilterRoute
|
||||||
|
chatJoinRequestFilters []chatJoinRequestFilterRoute
|
||||||
|
preCheckoutFilters []preCheckoutFilterRoute
|
||||||
|
|
||||||
|
// group-priority routes (registered via Router.Group())
|
||||||
|
groupCommands []groupCommandRoute
|
||||||
|
groupTexts []groupTextRoute
|
||||||
|
groupMessageFilters []groupMessageFilterRoute
|
||||||
|
|
||||||
|
globalMW []Middleware[*api.Update]
|
||||||
|
|
||||||
|
maxConcurrency int // default 50; 0 = serial (legacy)
|
||||||
|
sem chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type messageFilterRoute struct {
|
||||||
|
filter Filter[*api.Message]
|
||||||
|
handler Handler[*api.Message]
|
||||||
|
}
|
||||||
|
|
||||||
|
type callbackFilterRoute struct {
|
||||||
|
filter Filter[*api.CallbackQuery]
|
||||||
|
handler Handler[*api.CallbackQuery]
|
||||||
|
}
|
||||||
|
|
||||||
|
type inlineFilterRoute struct {
|
||||||
|
filter Filter[*api.InlineQuery]
|
||||||
|
handler Handler[*api.InlineQuery]
|
||||||
|
}
|
||||||
|
|
||||||
|
type chatMemberFilterRoute struct {
|
||||||
|
filter Filter[*api.ChatMemberUpdated]
|
||||||
|
handler Handler[*api.ChatMemberUpdated]
|
||||||
|
}
|
||||||
|
|
||||||
|
type chatJoinRequestFilterRoute struct {
|
||||||
|
filter Filter[*api.ChatJoinRequest]
|
||||||
|
handler Handler[*api.ChatJoinRequest]
|
||||||
|
}
|
||||||
|
|
||||||
|
type preCheckoutFilterRoute struct {
|
||||||
|
filter Filter[*api.PreCheckoutQuery]
|
||||||
|
handler Handler[*api.PreCheckoutQuery]
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterOption configures a Router at construction time.
|
||||||
|
type RouterOption func(*Router)
|
||||||
|
|
||||||
|
// WithMaxConcurrency sets the maximum number of updates processed in parallel.
|
||||||
|
// Default is 50. Pass 0 to dispatch serially (one update at a time, in the
|
||||||
|
// calling goroutine — the legacy behaviour before v1.1.0).
|
||||||
|
//
|
||||||
|
// Note: concurrent dispatch means handlers for different updates may run
|
||||||
|
// simultaneously. Handlers that mutate shared state must be safe for concurrent
|
||||||
|
// access.
|
||||||
|
func WithMaxConcurrency(n int) RouterOption {
|
||||||
|
return func(r *Router) { r.maxConcurrency = n }
|
||||||
|
}
|
||||||
|
|
||||||
|
type commandRoute struct {
|
||||||
|
cmd string
|
||||||
|
handler Handler[*api.Message]
|
||||||
|
}
|
||||||
|
|
||||||
|
type textRoute struct {
|
||||||
|
re *regexp.Regexp
|
||||||
|
handler Handler[*api.Message]
|
||||||
|
}
|
||||||
|
|
||||||
|
type callbackRoute struct {
|
||||||
|
re *regexp.Regexp
|
||||||
|
handler Handler[*api.CallbackQuery]
|
||||||
|
}
|
||||||
|
|
||||||
|
// New constructs a Router. Recovery middleware is added by default; users
|
||||||
|
// can disable it by passing WithoutRecovery (not implemented here, but
|
||||||
|
// the hook is in place via Use).
|
||||||
|
func New(b *client.Bot, opts ...RouterOption) *Router {
|
||||||
|
r := &Router{bot: b, maxConcurrency: 50}
|
||||||
|
for _, o := range opts {
|
||||||
|
o(r)
|
||||||
|
}
|
||||||
|
if r.maxConcurrency > 0 {
|
||||||
|
r.sem = make(chan struct{}, r.maxConcurrency)
|
||||||
|
}
|
||||||
|
r.Use(Recovery())
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use registers a global middleware applied to every Update dispatch.
|
||||||
|
func (r *Router) Use(mw Middleware[*api.Update]) { r.globalMW = append(r.globalMW, mw) }
|
||||||
|
|
||||||
|
// OnCommand registers a handler for a slash command. The command string
|
||||||
|
// includes the leading slash (e.g. "/start"). Matching strips an optional
|
||||||
|
// "@BotName" suffix.
|
||||||
|
func (r *Router) OnCommand(cmd string, h Handler[*api.Message]) {
|
||||||
|
r.commands = append(r.commands, commandRoute{cmd: cmd, handler: h})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnText registers a handler for messages whose Text matches the regex.
|
||||||
|
//
|
||||||
|
// Panics at registration time if pattern is not a valid regular expression.
|
||||||
|
func (r *Router) OnText(pattern string, h Handler[*api.Message]) {
|
||||||
|
r.texts = append(r.texts, textRoute{re: regexp.MustCompile(pattern), handler: h})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnCallback registers a handler for callback queries whose Data matches
|
||||||
|
// the regex.
|
||||||
|
//
|
||||||
|
// Panics at registration time if pattern is not a valid regular expression.
|
||||||
|
func (r *Router) OnCallback(pattern string, h Handler[*api.CallbackQuery]) {
|
||||||
|
r.callbacks = append(r.callbacks, callbackRoute{re: regexp.MustCompile(pattern), handler: h})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnInlineQuery registers a handler for inline queries (one matcher only;
|
||||||
|
// inline queries are not partitioned by content here).
|
||||||
|
func (r *Router) OnInlineQuery(h Handler[*api.InlineQuery]) {
|
||||||
|
r.inlines = append(r.inlines, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnEditedMessage registers a handler for edited message updates.
|
||||||
|
func (r *Router) OnEditedMessage(h Handler[*api.Message]) {
|
||||||
|
r.editedMsg = append(r.editedMsg, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnChannelPost registers a handler for channel post updates.
|
||||||
|
func (r *Router) OnChannelPost(h Handler[*api.Message]) {
|
||||||
|
r.channelPosts = append(r.channelPosts, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnEditedChannelPost registers a handler for edited channel post updates.
|
||||||
|
func (r *Router) OnEditedChannelPost(h Handler[*api.Message]) {
|
||||||
|
r.editedChannelPosts = append(r.editedChannelPosts, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnMessageFilter registers a typed message handler gated by filter f.
|
||||||
|
// Filter routes are checked after command and text routes; first match wins.
|
||||||
|
func (r *Router) OnMessageFilter(f Filter[*api.Message], h Handler[*api.Message]) {
|
||||||
|
r.messageFilters = append(r.messageFilters, messageFilterRoute{filter: f, handler: h})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnCallbackFilter registers a typed callback-query handler gated by filter f.
|
||||||
|
// Filter routes are checked after pattern-based OnCallback routes; first match wins.
|
||||||
|
func (r *Router) OnCallbackFilter(f Filter[*api.CallbackQuery], h Handler[*api.CallbackQuery]) {
|
||||||
|
r.callbackFilters = append(r.callbackFilters, callbackFilterRoute{filter: f, handler: h})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnInlineQueryFilter registers an inline-query handler gated by filter f.
|
||||||
|
// Filter routes are checked after bare OnInlineQuery handlers; first match wins.
|
||||||
|
func (r *Router) OnInlineQueryFilter(f Filter[*api.InlineQuery], h Handler[*api.InlineQuery]) {
|
||||||
|
r.inlineFilters = append(r.inlineFilters, inlineFilterRoute{filter: f, handler: h})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnMyChatMember registers a handler for bot's own chat member status changes.
|
||||||
|
func (r *Router) OnMyChatMember(h Handler[*api.ChatMemberUpdated]) {
|
||||||
|
r.myChatMember = append(r.myChatMember, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnMyChatMemberFilter registers a filtered handler for bot's own chat member status changes.
|
||||||
|
func (r *Router) OnMyChatMemberFilter(f Filter[*api.ChatMemberUpdated], h Handler[*api.ChatMemberUpdated]) {
|
||||||
|
r.myChatMemberFilters = append(r.myChatMemberFilters, chatMemberFilterRoute{filter: f, handler: h})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnChatMember registers a handler for chat member status changes.
|
||||||
|
func (r *Router) OnChatMember(h Handler[*api.ChatMemberUpdated]) {
|
||||||
|
r.chatMember = append(r.chatMember, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnChatMemberFilter registers a filtered handler for chat member status changes.
|
||||||
|
func (r *Router) OnChatMemberFilter(f Filter[*api.ChatMemberUpdated], h Handler[*api.ChatMemberUpdated]) {
|
||||||
|
r.chatMemberFilters = append(r.chatMemberFilters, chatMemberFilterRoute{filter: f, handler: h})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnChatJoinRequest registers a handler for chat join requests.
|
||||||
|
func (r *Router) OnChatJoinRequest(h Handler[*api.ChatJoinRequest]) {
|
||||||
|
r.chatJoinRequest = append(r.chatJoinRequest, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnChatJoinRequestFilter registers a filtered handler for chat join requests.
|
||||||
|
func (r *Router) OnChatJoinRequestFilter(f Filter[*api.ChatJoinRequest], h Handler[*api.ChatJoinRequest]) {
|
||||||
|
r.chatJoinRequestFilters = append(r.chatJoinRequestFilters, chatJoinRequestFilterRoute{filter: f, handler: h})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPreCheckoutQuery registers a handler for pre-checkout queries.
|
||||||
|
func (r *Router) OnPreCheckoutQuery(h Handler[*api.PreCheckoutQuery]) {
|
||||||
|
r.preCheckoutQuery = append(r.preCheckoutQuery, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPreCheckoutQueryFilter registers a filtered handler for pre-checkout queries.
|
||||||
|
func (r *Router) OnPreCheckoutQueryFilter(f Filter[*api.PreCheckoutQuery], h Handler[*api.PreCheckoutQuery]) {
|
||||||
|
r.preCheckoutFilters = append(r.preCheckoutFilters, preCheckoutFilterRoute{filter: f, handler: h})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnShippingQuery registers a handler for shipping queries.
|
||||||
|
func (r *Router) OnShippingQuery(h Handler[*api.ShippingQuery]) {
|
||||||
|
r.shippingQuery = append(r.shippingQuery, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPoll registers a handler for poll state updates.
|
||||||
|
func (r *Router) OnPoll(h Handler[*api.Poll]) {
|
||||||
|
r.poll = append(r.poll, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPollAnswer registers a handler for poll answer updates.
|
||||||
|
func (r *Router) OnPollAnswer(h Handler[*api.PollAnswer]) {
|
||||||
|
r.pollAnswer = append(r.pollAnswer, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnChosenInlineResult registers a handler for chosen inline results.
|
||||||
|
func (r *Router) OnChosenInlineResult(h Handler[*api.ChosenInlineResult]) {
|
||||||
|
r.chosenInlineResult = append(r.chosenInlineResult, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnMessageReaction registers a handler for message reaction updates.
|
||||||
|
func (r *Router) OnMessageReaction(h Handler[*api.MessageReactionUpdated]) {
|
||||||
|
r.messageReaction = append(r.messageReaction, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnMessageReactionCount registers a handler for anonymous message reaction count updates.
|
||||||
|
func (r *Router) OnMessageReactionCount(h Handler[*api.MessageReactionCountUpdated]) {
|
||||||
|
r.messageReactionCnt = append(r.messageReactionCnt, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnChatBoost registers a handler for chat boost updates.
|
||||||
|
func (r *Router) OnChatBoost(h Handler[*api.ChatBoostUpdated]) {
|
||||||
|
r.chatBoost = append(r.chatBoost, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnRemovedChatBoost registers a handler for removed chat boost updates.
|
||||||
|
func (r *Router) OnRemovedChatBoost(h Handler[*api.ChatBoostRemoved]) {
|
||||||
|
r.removedChatBoost = append(r.removedChatBoost, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnBusinessConnection registers a handler for business connection updates.
|
||||||
|
func (r *Router) OnBusinessConnection(h Handler[*api.BusinessConnection]) {
|
||||||
|
r.businessConn = append(r.businessConn, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPurchasedPaidMedia registers a handler for purchased paid media updates.
|
||||||
|
func (r *Router) OnPurchasedPaidMedia(h Handler[*api.PaidMediaPurchased]) {
|
||||||
|
r.purchasedPaidMedia = append(r.purchasedPaidMedia, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run consumes the Updater and dispatches each update. It blocks until
|
||||||
|
// the Updater's channel is closed or ctx is cancelled.
|
||||||
|
//
|
||||||
|
// By default updates are processed concurrently (up to WithMaxConcurrency(50)
|
||||||
|
// goroutines). Handlers for different updates may therefore run simultaneously;
|
||||||
|
// shared state must be protected. Pass WithMaxConcurrency(0) to New to restore
|
||||||
|
// serial (legacy) behaviour.
|
||||||
|
//
|
||||||
|
// Run waits for all in-flight handlers to finish before returning.
|
||||||
|
func (r *Router) Run(ctx context.Context, u transport.Updater) error {
|
||||||
|
runErr := make(chan error, 1)
|
||||||
|
go func() { runErr <- u.Run(ctx) }()
|
||||||
|
|
||||||
|
root := r.dispatch
|
||||||
|
for i := len(r.globalMW) - 1; i >= 0; i-- {
|
||||||
|
root = r.globalMW[i](root)
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
defer wg.Wait()
|
||||||
|
|
||||||
|
dispatch := func(up api.Update) {
|
||||||
|
c := NewContext(ctx, r.bot, &up)
|
||||||
|
if err := root(c, &up); err != nil {
|
||||||
|
if r.bot != nil {
|
||||||
|
r.bot.Logger().Error("dispatch handler error", "err", err, "update_id", up.UpdateID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case err := <-runErr:
|
||||||
|
return err
|
||||||
|
case up, ok := <-u.Updates():
|
||||||
|
if !ok {
|
||||||
|
// Channel closed; consume the run error if pending.
|
||||||
|
select {
|
||||||
|
case err := <-runErr:
|
||||||
|
return err
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.sem == nil {
|
||||||
|
// Serial mode (legacy / WithMaxConcurrency(0)).
|
||||||
|
dispatch(up)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concurrent mode: acquire semaphore slot then launch goroutine.
|
||||||
|
select {
|
||||||
|
case r.sem <- struct{}{}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(up api.Update) {
|
||||||
|
defer func() {
|
||||||
|
<-r.sem
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
dispatch(up)
|
||||||
|
}(up)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) dispatch(c *Context, u *api.Update) error {
|
||||||
|
switch {
|
||||||
|
case u.Message != nil:
|
||||||
|
return r.handleMessage(c, u.Message)
|
||||||
|
case u.EditedMessage != nil:
|
||||||
|
return runHandlers(r.editedMsg, c, u.EditedMessage)
|
||||||
|
case u.ChannelPost != nil:
|
||||||
|
return runHandlers(r.channelPosts, c, u.ChannelPost)
|
||||||
|
case u.EditedChannelPost != nil:
|
||||||
|
return runHandlers(r.editedChannelPosts, c, u.EditedChannelPost)
|
||||||
|
case u.CallbackQuery != nil:
|
||||||
|
return r.handleCallback(c, u.CallbackQuery)
|
||||||
|
case u.InlineQuery != nil:
|
||||||
|
if err := runHandlers(r.inlines, c, u.InlineQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, route := range r.inlineFilters {
|
||||||
|
if route.filter(u.InlineQuery) {
|
||||||
|
return route.handler(c, u.InlineQuery)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case u.MyChatMember != nil:
|
||||||
|
return r.handleChatMemberUpdate(c, u.MyChatMember, r.myChatMember, r.myChatMemberFilters)
|
||||||
|
case u.ChatMember != nil:
|
||||||
|
return r.handleChatMemberUpdate(c, u.ChatMember, r.chatMember, r.chatMemberFilters)
|
||||||
|
case u.ChatJoinRequest != nil:
|
||||||
|
return r.handleChatJoinRequest(c, u.ChatJoinRequest)
|
||||||
|
case u.PreCheckoutQuery != nil:
|
||||||
|
return r.handlePreCheckoutQuery(c, u.PreCheckoutQuery)
|
||||||
|
case u.ShippingQuery != nil:
|
||||||
|
return runHandlers(r.shippingQuery, c, u.ShippingQuery)
|
||||||
|
case u.Poll != nil:
|
||||||
|
return runHandlers(r.poll, c, u.Poll)
|
||||||
|
case u.PollAnswer != nil:
|
||||||
|
return runHandlers(r.pollAnswer, c, u.PollAnswer)
|
||||||
|
case u.ChosenInlineResult != nil:
|
||||||
|
return runHandlers(r.chosenInlineResult, c, u.ChosenInlineResult)
|
||||||
|
case u.MessageReaction != nil:
|
||||||
|
return runHandlers(r.messageReaction, c, u.MessageReaction)
|
||||||
|
case u.MessageReactionCount != nil:
|
||||||
|
return runHandlers(r.messageReactionCnt, c, u.MessageReactionCount)
|
||||||
|
case u.ChatBoost != nil:
|
||||||
|
return runHandlers(r.chatBoost, c, u.ChatBoost)
|
||||||
|
case u.RemovedChatBoost != nil:
|
||||||
|
return runHandlers(r.removedChatBoost, c, u.RemovedChatBoost)
|
||||||
|
case u.BusinessConnection != nil:
|
||||||
|
return runHandlers(r.businessConn, c, u.BusinessConnection)
|
||||||
|
case u.PurchasedPaidMedia != nil:
|
||||||
|
return runHandlers(r.purchasedPaidMedia, c, u.PurchasedPaidMedia)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) handleChatMemberUpdate(c *Context, payload *api.ChatMemberUpdated, handlers []Handler[*api.ChatMemberUpdated], filters []chatMemberFilterRoute) error {
|
||||||
|
if err := runHandlers(handlers, c, payload); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, route := range filters {
|
||||||
|
if route.filter(payload) {
|
||||||
|
return route.handler(c, payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) handleChatJoinRequest(c *Context, payload *api.ChatJoinRequest) error {
|
||||||
|
if err := runHandlers(r.chatJoinRequest, c, payload); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, route := range r.chatJoinRequestFilters {
|
||||||
|
if route.filter(payload) {
|
||||||
|
return route.handler(c, payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) handlePreCheckoutQuery(c *Context, payload *api.PreCheckoutQuery) error {
|
||||||
|
if err := runHandlers(r.preCheckoutQuery, c, payload); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, route := range r.preCheckoutFilters {
|
||||||
|
if route.filter(payload) {
|
||||||
|
return route.handler(c, payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runHandlers invokes each handler in order; returns the first non-nil error.
|
||||||
|
func runHandlers[T any](handlers []Handler[T], c *Context, payload T) error {
|
||||||
|
for _, h := range handlers {
|
||||||
|
if err := h(c, payload); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) handleMessage(c *Context, m *api.Message) error {
|
||||||
|
// Try command first (entity-aware).
|
||||||
|
if cmd, args, ok := extractCommand(m); ok {
|
||||||
|
for _, route := range r.commands {
|
||||||
|
if route.cmd == cmd {
|
||||||
|
c.Values["command"] = cmd
|
||||||
|
c.Values["command_args"] = args
|
||||||
|
return route.handler(c, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Then text regex matchers.
|
||||||
|
if m.Text != "" {
|
||||||
|
for _, route := range r.texts {
|
||||||
|
if subs := route.re.FindStringSubmatch(m.Text); subs != nil {
|
||||||
|
c.Values["regex_match"] = subs
|
||||||
|
return route.handler(c, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Filter-based routes.
|
||||||
|
for _, route := range r.messageFilters {
|
||||||
|
if route.filter(m) {
|
||||||
|
return route.handler(c, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Group-priority routes (registered via RouterScope.Group()).
|
||||||
|
return r.dispatchGroups(c, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) handleCallback(c *Context, q *api.CallbackQuery) error {
|
||||||
|
for _, route := range r.callbacks {
|
||||||
|
if subs := route.re.FindStringSubmatch(q.Data); subs != nil {
|
||||||
|
c.Values["regex_match"] = subs
|
||||||
|
return route.handler(c, q)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Filter-based routes checked after pattern routes.
|
||||||
|
for _, route := range r.callbackFilters {
|
||||||
|
if route.filter(q) {
|
||||||
|
return route.handler(c, q)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractCommand returns the command (e.g. "/start") and the remaining
|
||||||
|
// argument string, when m carries a leading bot_command entity. It strips
|
||||||
|
// optional "@BotName" suffix on the command itself.
|
||||||
|
func extractCommand(m *api.Message) (cmd, args string, ok bool) {
|
||||||
|
if len(m.Entities) == 0 || m.Text == "" {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
first := m.Entities[0]
|
||||||
|
if first.Type != string(api.EntityBotCommand) || first.Offset != 0 {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
cmd, sliceOk := utf16Slice(m.Text, int(first.Offset), int(first.Length))
|
||||||
|
if !sliceOk {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
if i := strings.Index(cmd, "@"); i >= 0 {
|
||||||
|
cmd = cmd[:i]
|
||||||
|
}
|
||||||
|
end := int(first.Offset) + int(first.Length)
|
||||||
|
rest, _ := utf16Slice(m.Text, end, utf16Len(m.Text)-end)
|
||||||
|
args = strings.TrimSpace(rest)
|
||||||
|
return cmd, args, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// utf16Slice returns the substring of s identified by a UTF-16 offset/length
|
||||||
|
// pair, as Telegram's MessageEntity uses. ok is false if the indices fall
|
||||||
|
// outside s's UTF-16 length.
|
||||||
|
func utf16Slice(s string, offset, length int) (string, bool) {
|
||||||
|
runes := []rune(s)
|
||||||
|
var startBytes, endBytes int
|
||||||
|
var u16 int
|
||||||
|
found := false
|
||||||
|
for i, r := range runes {
|
||||||
|
if u16 == offset {
|
||||||
|
startBytes = byteIndex(runes, i)
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
if u16 == offset+length {
|
||||||
|
endBytes = byteIndex(runes, i)
|
||||||
|
return s[startBytes:endBytes], true
|
||||||
|
}
|
||||||
|
if r > 0xFFFF {
|
||||||
|
u16 += 2
|
||||||
|
} else {
|
||||||
|
u16++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if found && u16 == offset+length {
|
||||||
|
return s[startBytes:], true
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func byteIndex(runes []rune, runeIdx int) int {
|
||||||
|
n := 0
|
||||||
|
for i := 0; i < runeIdx; i++ {
|
||||||
|
n += utf8.RuneLen(runes[i])
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func utf16Len(s string) int {
|
||||||
|
n := 0
|
||||||
|
for _, r := range s {
|
||||||
|
if r > 0xFFFF {
|
||||||
|
n += 2
|
||||||
|
} else {
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
@@ -0,0 +1,940 @@
|
|||||||
|
package dispatch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/go-telegram/api"
|
||||||
|
"github.com/lukaszraczylo/go-telegram/client"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeUpdater feeds a fixed slice of updates then closes.
|
||||||
|
type fakeUpdater struct{ ch chan api.Update }
|
||||||
|
|
||||||
|
func newFake(ups ...api.Update) *fakeUpdater {
|
||||||
|
ch := make(chan api.Update, len(ups))
|
||||||
|
for _, u := range ups {
|
||||||
|
ch <- u
|
||||||
|
}
|
||||||
|
close(ch)
|
||||||
|
return &fakeUpdater{ch: ch}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeUpdater) Updates() <-chan api.Update { return f.ch }
|
||||||
|
func (f *fakeUpdater) Run(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }
|
||||||
|
func (f *fakeUpdater) Stop(ctx context.Context) error { return nil }
|
||||||
|
|
||||||
|
func cmdMessage(text string) api.Update {
|
||||||
|
return api.Update{
|
||||||
|
UpdateID: 1,
|
||||||
|
Message: &api.Message{
|
||||||
|
MessageID: 1, Date: 0, Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||||
|
Text: text,
|
||||||
|
Entities: []api.MessageEntity{{Type: string(api.EntityBotCommand), Offset: 0, Length: int64(indexEnd(text))}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexEnd(text string) int {
|
||||||
|
for i, r := range text {
|
||||||
|
if r == ' ' {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnCommandMatches(t *testing.T) {
|
||||||
|
b := client.New("t")
|
||||||
|
r := New(b)
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnCommand("/start", func(c *Context, m *api.Message) error {
|
||||||
|
hit <- c.Values["command"].(string)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(cmdMessage("/start"))) }()
|
||||||
|
|
||||||
|
require.Equal(t, "/start", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnCommandStripsBotName(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnCommand("/start", func(c *Context, m *api.Message) error {
|
||||||
|
hit <- "matched"
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(cmdMessage("/start@MyBot hello"))) }()
|
||||||
|
|
||||||
|
require.Equal(t, "matched", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnText(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan []string, 1)
|
||||||
|
r.OnText(`^hello (\w+)$`, func(c *Context, m *api.Message) error {
|
||||||
|
hit <- c.Values["regex_match"].([]string)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
u := api.Update{UpdateID: 1, Message: &api.Message{
|
||||||
|
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "hello world",
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||||
|
|
||||||
|
subs := <-hit
|
||||||
|
require.Equal(t, "world", subs[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnCallback(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnCallback(`^like:(\d+)$`, func(c *Context, q *api.CallbackQuery) error {
|
||||||
|
hit <- q.Data
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
u := api.Update{UpdateID: 1, CallbackQuery: &api.CallbackQuery{
|
||||||
|
ID: "x", From: api.User{ID: 1}, ChatInstance: "y", Data: "like:42",
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||||
|
|
||||||
|
require.Equal(t, "like:42", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_NoMatch(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
called := false
|
||||||
|
r.OnCommand("/start", func(c *Context, m *api.Message) error {
|
||||||
|
called = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
u := api.Update{UpdateID: 1, Message: &api.Message{Text: "no command"}}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_ = r.Run(ctx, newFake(u))
|
||||||
|
require.False(t, called)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_PanicRecovery(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
r.OnCommand("/boom", func(c *Context, m *api.Message) error {
|
||||||
|
panic("kaboom")
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
// Should not propagate panic to Run.
|
||||||
|
require.NotPanics(t, func() { _ = r.Run(ctx, newFake(cmdMessage("/boom"))) })
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouter_NonASCIICommand verifies that UTF-16 entity offsets are used
|
||||||
|
// correctly when the command contains non-ASCII runes. "/старт" is 6 runes,
|
||||||
|
// each a BMP code point, so UTF-16 length == 6.
|
||||||
|
func TestRouter_NonASCIICommand(t *testing.T) {
|
||||||
|
const text = "/старт аргумент"
|
||||||
|
// "/старт" = 1 + 5 runes, all BMP → UTF-16 length 6
|
||||||
|
const cmdU16Len = int64(6)
|
||||||
|
u := api.Update{
|
||||||
|
UpdateID: 1,
|
||||||
|
Message: &api.Message{
|
||||||
|
MessageID: 1,
|
||||||
|
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||||
|
Text: text,
|
||||||
|
Entities: []api.MessageEntity{
|
||||||
|
{Type: string(api.EntityBotCommand), Offset: 0, Length: cmdU16Len},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan [2]string, 1)
|
||||||
|
r.OnCommand("/старт", func(c *Context, m *api.Message) error {
|
||||||
|
hit <- [2]string{
|
||||||
|
c.Values["command"].(string),
|
||||||
|
c.Values["command_args"].(string),
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||||
|
|
||||||
|
got := <-hit
|
||||||
|
require.Equal(t, "/старт", got[0])
|
||||||
|
require.Equal(t, "аргумент", got[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouter_CommandValuesNotLeakedOnNoMatch verifies that c.Values["command"]
|
||||||
|
// is not set when a command entity is present but no route matches, so a
|
||||||
|
// subsequent text handler doesn't see stale values.
|
||||||
|
func TestRouter_CommandValuesNotLeakedOnNoMatch(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
// Register a text handler that should fire as fallback.
|
||||||
|
leaked := make(chan bool, 1)
|
||||||
|
r.OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||||
|
_, hasCmd := c.Values["command"]
|
||||||
|
leaked <- hasCmd
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
// No OnCommand registered, so the command entity won't match any route.
|
||||||
|
u := api.Update{UpdateID: 1, Message: &api.Message{
|
||||||
|
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"},
|
||||||
|
Text: "/unknown",
|
||||||
|
Entities: []api.MessageEntity{{Type: string(api.EntityBotCommand), Offset: 0, Length: 8}},
|
||||||
|
}}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||||
|
|
||||||
|
require.False(t, <-leaked, "command value must not leak into text handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_MiddlewareOrder(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
var order []string
|
||||||
|
r.Use(func(next Handler[*api.Update]) Handler[*api.Update] {
|
||||||
|
return func(c *Context, u *api.Update) error {
|
||||||
|
order = append(order, "before-1")
|
||||||
|
err := next(c, u)
|
||||||
|
order = append(order, "after-1")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
})
|
||||||
|
r.Use(func(next Handler[*api.Update]) Handler[*api.Update] {
|
||||||
|
return func(c *Context, u *api.Update) error {
|
||||||
|
order = append(order, "before-2")
|
||||||
|
err := next(c, u)
|
||||||
|
order = append(order, "after-2")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
})
|
||||||
|
r.OnCommand("/x", func(c *Context, m *api.Message) error {
|
||||||
|
order = append(order, "handler")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_ = r.Run(ctx, newFake(cmdMessage("/x")))
|
||||||
|
|
||||||
|
require.Equal(t,
|
||||||
|
[]string{"before-1", "before-2", "handler", "after-2", "after-1"},
|
||||||
|
order)
|
||||||
|
}
|
||||||
|
func TestRouter_OnChannelPost(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan int64, 1)
|
||||||
|
r.OnChannelPost(func(c *Context, m *api.Message) error {
|
||||||
|
hit <- m.MessageID
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
u := api.Update{UpdateID: 1, ChannelPost: &api.Message{
|
||||||
|
MessageID: 99, Chat: api.Chat{ID: -100, Type: string(api.ChatTypeChannel)},
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||||
|
|
||||||
|
require.Equal(t, int64(99), <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_RunsAllHandlersForEditedMessage(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
var hits []string
|
||||||
|
r.OnEditedMessage(func(c *Context, m *api.Message) error {
|
||||||
|
hits = append(hits, "first")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
r.OnEditedMessage(func(c *Context, m *api.Message) error {
|
||||||
|
hits = append(hits, "second")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
u := api.Update{UpdateID: 1, EditedMessage: &api.Message{
|
||||||
|
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "edited",
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_ = r.Run(ctx, newFake(u))
|
||||||
|
|
||||||
|
require.Equal(t, []string{"first", "second"}, hits)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Filter-route tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRouter_OnMessageFilter_Matches(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnMessageFilter(
|
||||||
|
Filter[*api.Message](func(m *api.Message) bool { return m != nil && m.Text == "ping" }),
|
||||||
|
func(c *Context, m *api.Message) error { hit <- m.Text; return nil },
|
||||||
|
)
|
||||||
|
|
||||||
|
u := api.Update{UpdateID: 1, Message: &api.Message{
|
||||||
|
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "ping",
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||||
|
|
||||||
|
require.Equal(t, "ping", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnMessageFilter_NoMatch(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
called := false
|
||||||
|
r.OnMessageFilter(
|
||||||
|
Filter[*api.Message](func(m *api.Message) bool { return false }),
|
||||||
|
func(c *Context, m *api.Message) error { called = true; return nil },
|
||||||
|
)
|
||||||
|
|
||||||
|
u := api.Update{UpdateID: 1, Message: &api.Message{
|
||||||
|
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "any",
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_ = r.Run(ctx, newFake(u))
|
||||||
|
require.False(t, called)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Command routes must take priority over filter routes.
|
||||||
|
func TestRouter_OnMessageFilter_CommandWins(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
var winner string
|
||||||
|
r.OnCommand("/start", func(c *Context, m *api.Message) error { winner = "command"; return nil })
|
||||||
|
r.OnMessageFilter(
|
||||||
|
Filter[*api.Message](func(m *api.Message) bool { return true }),
|
||||||
|
func(c *Context, m *api.Message) error { winner = "filter"; return nil },
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_ = r.Run(ctx, newFake(cmdMessage("/start")))
|
||||||
|
|
||||||
|
require.Equal(t, "command", winner)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnCallbackFilter_Matches(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnCallbackFilter(
|
||||||
|
Filter[*api.CallbackQuery](func(q *api.CallbackQuery) bool { return q != nil && q.Data == "yes" }),
|
||||||
|
func(c *Context, q *api.CallbackQuery) error { hit <- q.Data; return nil },
|
||||||
|
)
|
||||||
|
|
||||||
|
u := api.Update{UpdateID: 1, CallbackQuery: &api.CallbackQuery{
|
||||||
|
ID: "x", From: api.User{ID: 1}, ChatInstance: "y", Data: "yes",
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||||
|
|
||||||
|
require.Equal(t, "yes", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern-based OnCallback wins over OnCallbackFilter when both match.
|
||||||
|
func TestRouter_OnCallbackFilter_PatternWins(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
var winner string
|
||||||
|
r.OnCallback(`^yes$`, func(c *Context, q *api.CallbackQuery) error { winner = "pattern"; return nil })
|
||||||
|
r.OnCallbackFilter(
|
||||||
|
Filter[*api.CallbackQuery](func(q *api.CallbackQuery) bool { return true }),
|
||||||
|
func(c *Context, q *api.CallbackQuery) error { winner = "filter"; return nil },
|
||||||
|
)
|
||||||
|
|
||||||
|
u := api.Update{UpdateID: 1, CallbackQuery: &api.CallbackQuery{
|
||||||
|
ID: "x", From: api.User{ID: 1}, ChatInstance: "y", Data: "yes",
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_ = r.Run(ctx, newFake(u))
|
||||||
|
|
||||||
|
require.Equal(t, "pattern", winner)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnInlineQueryFilter_Matches(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnInlineQueryFilter(
|
||||||
|
Filter[*api.InlineQuery](func(q *api.InlineQuery) bool { return q != nil && q.Query == "find" }),
|
||||||
|
func(c *Context, q *api.InlineQuery) error { hit <- q.Query; return nil },
|
||||||
|
)
|
||||||
|
|
||||||
|
u := api.Update{UpdateID: 1, InlineQuery: &api.InlineQuery{
|
||||||
|
ID: "i", From: api.User{ID: 1}, Query: "find",
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||||
|
|
||||||
|
require.Equal(t, "find", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_FilterChain_Composition(t *testing.T) {
|
||||||
|
// Filter: private chat AND text contains "hello"
|
||||||
|
privateChat := Filter[*api.Message](func(m *api.Message) bool {
|
||||||
|
return m != nil && m.Chat.Type == string(api.ChatTypePrivate)
|
||||||
|
})
|
||||||
|
hasHello := Filter[*api.Message](func(m *api.Message) bool {
|
||||||
|
return m != nil && len(m.Text) > 0 && containsStr(m.Text, "hello")
|
||||||
|
})
|
||||||
|
combined := privateChat.And(hasHello)
|
||||||
|
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnMessageFilter(combined, func(c *Context, m *api.Message) error { hit <- m.Text; return nil })
|
||||||
|
|
||||||
|
match := api.Update{UpdateID: 1, Message: &api.Message{
|
||||||
|
MessageID: 1, Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)}, Text: "say hello",
|
||||||
|
}}
|
||||||
|
noMatch := api.Update{UpdateID: 2, Message: &api.Message{
|
||||||
|
MessageID: 2, Chat: api.Chat{ID: 2, Type: string(api.ChatTypeGroup)}, Text: "say hello",
|
||||||
|
}}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(match, noMatch)) }()
|
||||||
|
|
||||||
|
require.Equal(t, "say hello", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// containsStr is a helper to avoid importing strings in test file unnecessarily.
|
||||||
|
func containsStr(s, sub string) bool {
|
||||||
|
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsSubstr(s, sub))
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsSubstr(s, sub string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(sub); i++ {
|
||||||
|
if s[i:i+len(sub)] == sub {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Concurrent dispatch tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// fakeSlowUpdater feeds n updates then blocks until ctx cancel.
|
||||||
|
type fakeSlowUpdater struct {
|
||||||
|
ch chan api.Update
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSlowFake(ups ...api.Update) *fakeSlowUpdater {
|
||||||
|
ch := make(chan api.Update, len(ups))
|
||||||
|
for _, u := range ups {
|
||||||
|
ch <- u
|
||||||
|
}
|
||||||
|
close(ch)
|
||||||
|
return &fakeSlowUpdater{ch: ch}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeSlowUpdater) Updates() <-chan api.Update { return f.ch }
|
||||||
|
func (f *fakeSlowUpdater) Run(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }
|
||||||
|
func (f *fakeSlowUpdater) Stop(ctx context.Context) error { return nil }
|
||||||
|
|
||||||
|
func TestRouter_ConcurrentDispatch_AllHandlersFire(t *testing.T) {
|
||||||
|
const n = 100
|
||||||
|
var fired atomic.Int64
|
||||||
|
|
||||||
|
ups := make([]api.Update, n)
|
||||||
|
for i := range ups {
|
||||||
|
ups[i] = api.Update{UpdateID: int64(i + 1), Message: &api.Message{
|
||||||
|
MessageID: int64(i + 1),
|
||||||
|
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||||
|
Text: "hi",
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
r := New(client.New("t"), WithMaxConcurrency(20))
|
||||||
|
r.OnMessageFilter(
|
||||||
|
Filter[*api.Message](func(m *api.Message) bool { return true }),
|
||||||
|
func(c *Context, m *api.Message) error { fired.Add(1); return nil },
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_ = r.Run(ctx, newSlowFake(ups...))
|
||||||
|
|
||||||
|
require.Equal(t, int64(n), fired.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_ConcurrentDispatch_SemaphoreBoundsConcurrency(t *testing.T) {
|
||||||
|
const limit = 5
|
||||||
|
const n = 30
|
||||||
|
|
||||||
|
var inFlight atomic.Int64
|
||||||
|
var maxSeen atomic.Int64
|
||||||
|
ready := make(chan struct{}) // signals handler to proceed
|
||||||
|
started := make(chan struct{}) // first handler signals it's running
|
||||||
|
|
||||||
|
ups := make([]api.Update, n)
|
||||||
|
for i := range ups {
|
||||||
|
ups[i] = api.Update{UpdateID: int64(i + 1), Message: &api.Message{
|
||||||
|
MessageID: int64(i + 1),
|
||||||
|
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||||
|
Text: "hi",
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
once := atomic.Bool{}
|
||||||
|
r := New(client.New("t"), WithMaxConcurrency(limit))
|
||||||
|
r.OnMessageFilter(
|
||||||
|
Filter[*api.Message](func(m *api.Message) bool { return true }),
|
||||||
|
func(c *Context, m *api.Message) error {
|
||||||
|
cur := inFlight.Add(1)
|
||||||
|
for {
|
||||||
|
old := maxSeen.Load()
|
||||||
|
if cur <= old || maxSeen.CompareAndSwap(old, cur) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if once.CompareAndSwap(false, true) {
|
||||||
|
close(started)
|
||||||
|
}
|
||||||
|
<-ready
|
||||||
|
inFlight.Add(-1)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() { _ = r.Run(ctx, newSlowFake(ups...)) }()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-started:
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Fatal("timed out waiting for first handler")
|
||||||
|
}
|
||||||
|
// Give the pool a moment to fill up.
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
close(ready)
|
||||||
|
|
||||||
|
// Wait for Run to drain by cancelling context after a short wait.
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
require.LessOrEqual(t, maxSeen.Load(), int64(limit),
|
||||||
|
"in-flight goroutines exceeded semaphore limit")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_ConcurrentDispatch_WaitsForInFlight(t *testing.T) {
|
||||||
|
unblock := make(chan struct{})
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
r := New(client.New("t"), WithMaxConcurrency(10))
|
||||||
|
r.OnMessageFilter(
|
||||||
|
Filter[*api.Message](func(m *api.Message) bool { return true }),
|
||||||
|
func(c *Context, m *api.Message) error {
|
||||||
|
<-unblock
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
u := api.Update{UpdateID: 1, Message: &api.Message{
|
||||||
|
MessageID: 1, Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)}, Text: "hi",
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_ = r.Run(ctx, newSlowFake(u))
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Give Run time to pick up the update and launch the goroutine.
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
cancel() // trigger Run to exit its loop
|
||||||
|
|
||||||
|
// Run should not return until handler unblocks.
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
t.Fatal("Run returned before in-flight handler finished")
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
close(unblock)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Run did not return after handler finished")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_SerialMode_NoRace(t *testing.T) {
|
||||||
|
// WithMaxConcurrency(0) — serial; shared slice is safe without a mutex.
|
||||||
|
var order []int64
|
||||||
|
|
||||||
|
const n = 20
|
||||||
|
ups := make([]api.Update, n)
|
||||||
|
for i := range ups {
|
||||||
|
ups[i] = api.Update{UpdateID: int64(i + 1), Message: &api.Message{
|
||||||
|
MessageID: int64(i + 1),
|
||||||
|
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||||
|
Text: "hi",
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
r := New(client.New("t"), WithMaxConcurrency(0))
|
||||||
|
r.OnMessageFilter(
|
||||||
|
Filter[*api.Message](func(m *api.Message) bool { return true }),
|
||||||
|
func(c *Context, m *api.Message) error {
|
||||||
|
order = append(order, m.MessageID)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_ = r.Run(ctx, newSlowFake(ups...))
|
||||||
|
|
||||||
|
require.Len(t, order, n)
|
||||||
|
for i, v := range order {
|
||||||
|
require.Equal(t, int64(i+1), v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// liveUpdater is an updater whose channel stays open until stopCh is closed.
|
||||||
|
type liveUpdater struct {
|
||||||
|
ch chan api.Update
|
||||||
|
stopCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLiveUpdater() *liveUpdater {
|
||||||
|
return &liveUpdater{ch: make(chan api.Update, 8), stopCh: make(chan struct{})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *liveUpdater) Send(u api.Update) { l.ch <- u }
|
||||||
|
func (l *liveUpdater) Close() { close(l.stopCh) }
|
||||||
|
func (l *liveUpdater) Updates() <-chan api.Update { return l.ch }
|
||||||
|
func (l *liveUpdater) Stop(ctx context.Context) error { return nil }
|
||||||
|
func (l *liveUpdater) Run(ctx context.Context) error {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-l.stopCh:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Typed handler tests (Feature 1)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRouter_OnMyChatMember(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan int64, 1)
|
||||||
|
r.OnMyChatMember(func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.From.ID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, MyChatMember: &api.ChatMemberUpdated{
|
||||||
|
From: api.User{ID: 42},
|
||||||
|
Chat: api.Chat{ID: 1},
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, int64(42), <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnMyChatMemberFilter(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan int64, 1)
|
||||||
|
f := Filter[*api.ChatMemberUpdated](func(u *api.ChatMemberUpdated) bool { return u.From.ID == 99 })
|
||||||
|
r.OnMyChatMemberFilter(f, func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.From.ID; return nil })
|
||||||
|
|
||||||
|
match := api.Update{UpdateID: 1, MyChatMember: &api.ChatMemberUpdated{From: api.User{ID: 99}}}
|
||||||
|
noMatch := api.Update{UpdateID: 2, MyChatMember: &api.ChatMemberUpdated{From: api.User{ID: 1}}}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(noMatch, match)) }()
|
||||||
|
require.Equal(t, int64(99), <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnChatMember(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan int64, 1)
|
||||||
|
r.OnChatMember(func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.Chat.ID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, ChatMember: &api.ChatMemberUpdated{
|
||||||
|
From: api.User{ID: 1},
|
||||||
|
Chat: api.Chat{ID: 77},
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, int64(77), <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnChatMemberFilter(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan int64, 1)
|
||||||
|
f := Filter[*api.ChatMemberUpdated](func(u *api.ChatMemberUpdated) bool { return u.Chat.ID == 55 })
|
||||||
|
r.OnChatMemberFilter(f, func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.Chat.ID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, ChatMember: &api.ChatMemberUpdated{Chat: api.Chat{ID: 55}}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, int64(55), <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnChatJoinRequest(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan int64, 1)
|
||||||
|
r.OnChatJoinRequest(func(c *Context, req *api.ChatJoinRequest) error { hit <- req.From.ID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, ChatJoinRequest: &api.ChatJoinRequest{
|
||||||
|
From: api.User{ID: 11},
|
||||||
|
Chat: api.Chat{ID: 1},
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, int64(11), <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnChatJoinRequestFilter(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan int64, 1)
|
||||||
|
f := Filter[*api.ChatJoinRequest](func(req *api.ChatJoinRequest) bool { return req.Chat.ID == 22 })
|
||||||
|
r.OnChatJoinRequestFilter(f, func(c *Context, req *api.ChatJoinRequest) error { hit <- req.Chat.ID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, ChatJoinRequest: &api.ChatJoinRequest{
|
||||||
|
From: api.User{ID: 1},
|
||||||
|
Chat: api.Chat{ID: 22},
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, int64(22), <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnPreCheckoutQuery(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnPreCheckoutQuery(func(c *Context, q *api.PreCheckoutQuery) error { hit <- q.Currency; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, PreCheckoutQuery: &api.PreCheckoutQuery{
|
||||||
|
ID: "q1", From: api.User{ID: 1}, Currency: "USD",
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, "USD", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnPreCheckoutQueryFilter(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
f := Filter[*api.PreCheckoutQuery](func(q *api.PreCheckoutQuery) bool { return q.Currency == "EUR" })
|
||||||
|
r.OnPreCheckoutQueryFilter(f, func(c *Context, q *api.PreCheckoutQuery) error { hit <- q.Currency; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, PreCheckoutQuery: &api.PreCheckoutQuery{
|
||||||
|
ID: "q1", From: api.User{ID: 1}, Currency: "EUR",
|
||||||
|
}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, "EUR", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnShippingQuery(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnShippingQuery(func(c *Context, q *api.ShippingQuery) error { hit <- q.ID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, ShippingQuery: &api.ShippingQuery{ID: "sq1", From: api.User{ID: 1}}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, "sq1", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnPoll(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnPoll(func(c *Context, p *api.Poll) error { hit <- p.ID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, Poll: &api.Poll{ID: "poll1"}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, "poll1", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnPollAnswer(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnPollAnswer(func(c *Context, a *api.PollAnswer) error { hit <- a.PollID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, PollAnswer: &api.PollAnswer{PollID: "p1", OptionIds: []int64{0}}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, "p1", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnChosenInlineResult(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnChosenInlineResult(func(c *Context, res *api.ChosenInlineResult) error { hit <- res.ResultID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, ChosenInlineResult: &api.ChosenInlineResult{ResultID: "r1", From: api.User{ID: 1}}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, "r1", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnMessageReaction(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan int64, 1)
|
||||||
|
r.OnMessageReaction(func(c *Context, u *api.MessageReactionUpdated) error { hit <- u.Chat.ID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, MessageReaction: &api.MessageReactionUpdated{Chat: api.Chat{ID: 33}}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, int64(33), <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnMessageReactionCount(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan int64, 1)
|
||||||
|
r.OnMessageReactionCount(func(c *Context, u *api.MessageReactionCountUpdated) error { hit <- u.Chat.ID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, MessageReactionCount: &api.MessageReactionCountUpdated{Chat: api.Chat{ID: 44}}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, int64(44), <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnChatBoost(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan int64, 1)
|
||||||
|
r.OnChatBoost(func(c *Context, u *api.ChatBoostUpdated) error { hit <- u.Chat.ID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, ChatBoost: &api.ChatBoostUpdated{Chat: api.Chat{ID: 55}}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, int64(55), <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnRemovedChatBoost(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan int64, 1)
|
||||||
|
r.OnRemovedChatBoost(func(c *Context, u *api.ChatBoostRemoved) error { hit <- u.Chat.ID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, RemovedChatBoost: &api.ChatBoostRemoved{Chat: api.Chat{ID: 66}}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, int64(66), <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnBusinessConnection(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnBusinessConnection(func(c *Context, bc *api.BusinessConnection) error { hit <- bc.ID; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, BusinessConnection: &api.BusinessConnection{ID: "bc1", UserChatID: 1, User: api.User{ID: 1}}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, "bc1", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_OnPurchasedPaidMedia(t *testing.T) {
|
||||||
|
r := New(client.New("t"))
|
||||||
|
hit := make(chan string, 1)
|
||||||
|
r.OnPurchasedPaidMedia(func(c *Context, p *api.PaidMediaPurchased) error { hit <- p.PaidMediaPayload; return nil })
|
||||||
|
|
||||||
|
upd := api.Update{UpdateID: 1, PurchasedPaidMedia: &api.PaidMediaPurchased{From: api.User{ID: 1}, PaidMediaPayload: "payload1"}}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||||
|
require.Equal(t, "payload1", <-hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_ContextCancel_UnblocksWaitingAcquire(t *testing.T) {
|
||||||
|
// Fill the semaphore with slow handlers, send one more update, then cancel
|
||||||
|
// ctx. Run must unblock from the semaphore-acquire select and return.
|
||||||
|
const limit = 2
|
||||||
|
unblock := make(chan struct{})
|
||||||
|
|
||||||
|
slowHandler := func(c *Context, m *api.Message) error {
|
||||||
|
<-unblock
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lu := newLiveUpdater()
|
||||||
|
r := New(client.New("t"), WithMaxConcurrency(limit))
|
||||||
|
r.OnMessageFilter(Filter[*api.Message](func(m *api.Message) bool { return true }), slowHandler)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
runDone := make(chan error, 1)
|
||||||
|
go func() { runDone <- r.Run(ctx, lu) }()
|
||||||
|
|
||||||
|
// Send enough updates to fill semaphore.
|
||||||
|
for i := range limit {
|
||||||
|
lu.Send(api.Update{UpdateID: int64(i + 1), Message: &api.Message{
|
||||||
|
MessageID: int64(i + 1),
|
||||||
|
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||||
|
Text: "hi",
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Give goroutines time to acquire all semaphore slots.
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Send one more update — Run will block trying to acquire the full semaphore.
|
||||||
|
lu.Send(api.Update{UpdateID: int64(limit + 1), Message: &api.Message{
|
||||||
|
MessageID: int64(limit + 1),
|
||||||
|
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||||
|
Text: "extra",
|
||||||
|
}})
|
||||||
|
|
||||||
|
// Give Run a moment to reach the semaphore-acquire select.
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// Unblock handlers so wg.Wait() inside Run can complete, allowing Run to
|
||||||
|
// return (and write to runDone).
|
||||||
|
close(unblock)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-runDone:
|
||||||
|
require.Error(t, err)
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Run did not unblock after context cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// Package gotelegram is the module root.
|
||||||
|
//
|
||||||
|
// The public API lives in the api, client, transport, and dispatch packages.
|
||||||
|
// See https://github.com/lukaszraczylo/go-telegram for documentation.
|
||||||
|
package gotelegram
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
# Examples
|
||||||
|
|
||||||
|
Each subdirectory contains a self-contained sample bot demonstrating one feature area.
|
||||||
|
|
||||||
|
| Example | What it shows |
|
||||||
|
|---|---|
|
||||||
|
| [echo](./echo) | Long-poll bot that echoes text back to the sender |
|
||||||
|
| [webhook](./webhook) | Webhook delivery with secret-token verification |
|
||||||
|
| [callback](./callback) | Inline keyboard with callback queries and counter state |
|
||||||
|
| [conversation](./conversation) | Multi-step conversation flow with `dispatch/conversation` |
|
||||||
|
| [files](./files) | Upload and download files via `api.DownloadFile` |
|
||||||
|
| [inline](./inline) | Inline-mode bot returning search-style results |
|
||||||
|
| [middleware](./middleware) | Custom middleware chains via `Router.Use` |
|
||||||
|
| [stateful](./stateful) | Per-user state managed via closures |
|
||||||
|
| [welcome](./welcome) | Greet new chat members; detect and log departures |
|
||||||
|
| [moderation](./moderation) | `/kick`, `/ban`, `/mute`, `/warn` with admin permission checks |
|
||||||
|
| [polls](./polls) | Create polls and tally answers via `OnPollAnswer` |
|
||||||
|
| [payments](./payments) | Telegram Payments: sendInvoice → pre_checkout_query → successful_payment |
|
||||||
|
| [pagination](./pagination) | Multi-page inline keyboard with stateless prev/next navigation |
|
||||||
|
| [admin](./admin) | Auth middleware allowlisting specific user IDs via `Router.Use` |
|
||||||
|
|
||||||
|
## Running
|
||||||
|
|
||||||
|
All examples follow the same pattern:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export TELEGRAM_BOT_TOKEN=123456:ABC...
|
||||||
|
go run ./examples/<name>
|
||||||
|
```
|
||||||
|
|
||||||
|
Webhook examples need a public HTTPS endpoint (use Cloudflare Tunnel, ngrok, or similar).
|
||||||
|
|
||||||
|
## Common patterns
|
||||||
|
|
||||||
|
**Retry-safe HTTP** — every example wraps the HTTP client with `client.NewRetryDoer`, which automatically honours Telegram's `retry_after` field on 429 responses.
|
||||||
|
|
||||||
|
**Graceful shutdown** — all examples use `signal.NotifyContext` so the bot drains cleanly on `SIGINT`/`SIGTERM`.
|
||||||
|
|
||||||
|
**Structured logging** — for production, wire a logger via `client.WithLogger` and wrap the process in supervision (systemd unit, k8s liveness probe, etc.).
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
# admin
|
||||||
|
|
||||||
|
Authentication middleware that restricts the bot to an allowlist of Telegram user IDs.
|
||||||
|
|
||||||
|
## What it shows
|
||||||
|
|
||||||
|
- `router.Use(...)` to install a global `Middleware[*api.Update]`
|
||||||
|
- Parsing `ALLOWED_USERS` env var into a `map[int64]bool` lookup set
|
||||||
|
- Extracting sender ID from multiple update types in one helper
|
||||||
|
- Silent drop pattern for unauthorized updates (no error, no reply)
|
||||||
|
|
||||||
|
## Environment variables
|
||||||
|
|
||||||
|
| Variable | Required | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `TELEGRAM_BOT_TOKEN` | Yes | Bot token from @BotFather |
|
||||||
|
| `ALLOWED_USERS` | No | Comma-separated numeric user IDs, e.g. `123456,789012`. If unset, all users are permitted. |
|
||||||
|
|
||||||
|
## Finding your user ID
|
||||||
|
|
||||||
|
Send `/whoami` to the bot — it replies with your numeric Telegram user ID. Add that ID to `ALLOWED_USERS` to restrict the bot to you.
|
||||||
|
|
||||||
|
## Extending
|
||||||
|
|
||||||
|
Combine with `examples/moderation` to ensure only group admins can invoke moderation commands:
|
||||||
|
|
||||||
|
```go
|
||||||
|
router.Use(allowlistMiddleware(adminIDs))
|
||||||
|
router.OnCommand("/ban", banHandler)
|
||||||
|
```
|
||||||
|
|
||||||
|
For group-context admin checks (verify the sender is an admin of *that specific group*), use `api.GetChatAdministrators` and check the result dynamically rather than a static ID list.
|
||||||
|
|
||||||
|
## Running
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export TELEGRAM_BOT_TOKEN=123456:ABC...
|
||||||
|
export ALLOWED_USERS=111111,222222
|
||||||
|
go run ./examples/admin
|
||||||
|
```
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user