Initial release of go-telegram

A fully-generated, strongly-typed Go client for the Telegram Bot API.

* 176 methods + 301 types generated from Bot API v10.0
* 1408 auto-generated tests (8 scenarios per method)
* Typed unions throughout — no 'any' in the public surface
* Pluggable HTTP transport and JSON codec (default goccy/go-json)
* Built-in retry middleware honouring Telegram's retry_after
* Generic dispatcher with filters and conversation handlers
* Self-verifying codegen pipeline (regen → audit → emit → run tests)
* 14 example bots covering common patterns
This commit is contained in:
2026-05-09 13:09:27 +01:00
commit ac7cae8fa7
164 changed files with 100239 additions and 0 deletions
+40
View File
@@ -0,0 +1,40 @@
// Package dispatch provides a typed router for Telegram updates. It
// consumes any transport.Updater and dispatches updates to handlers
// registered by command, regex, or update-payload kind.
package dispatch
import (
"context"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/client"
)
// Context bundles the per-update state every handler receives.
//
// Ctx is the request context propagated from Router.Run; cancelling the
// run cancels every handler.
//
// Bot is the API client. Handlers reply by calling api.SendMessage(c.Ctx,
// c.Bot, ...) etc.
//
// Update is the raw update; payload-typed handlers also receive a
// narrowed pointer to one of its sub-fields.
//
// Values is a per-update bag matchers populate. Conventional keys:
//
// "command": string, the matched bot command (e.g. "/start")
// "command_args": string, everything after the command
// "regex_match": []string, regex sub-matches when OnText matches
type Context struct {
Ctx context.Context
Bot *client.Bot
Update *api.Update
Values map[string]any
}
// NewContext constructs a Context. Used by Router internally; exposed for
// custom test harnesses.
func NewContext(ctx context.Context, b *client.Bot, u *api.Update) *Context {
return &Context{Ctx: ctx, Bot: b, Update: u, Values: map[string]any{}}
}
+494
View File
@@ -0,0 +1,494 @@
package conversation_test
import (
"context"
"strings"
"sync"
"sync/atomic"
"testing"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/client"
"github.com/lukaszraczylo/go-telegram/dispatch"
"github.com/lukaszraczylo/go-telegram/dispatch/conversation"
"github.com/stretchr/testify/require"
)
// ---- helpers ---------------------------------------------------------------
func msgUpd(userID, chatID int64, text string) api.Update {
return api.Update{
UpdateID: 1,
Message: &api.Message{
MessageID: 1,
From: &api.User{ID: userID},
Chat: api.Chat{ID: chatID},
Text: text,
},
}
}
func makeCtx(u *api.Update) *dispatch.Context {
return dispatch.NewContext(context.Background(), client.New("t"), u)
}
// anyMsg matches any update that has a Message.
var anyMsg = func(u *api.Update) bool { return u.Message != nil }
// hasPrefix returns a filter matching updates whose Message.Text has prefix p.
func hasPrefix(p string) dispatch.Filter[*api.Update] {
return func(u *api.Update) bool {
return u.Message != nil && strings.HasPrefix(u.Message.Text, p)
}
}
// fakeUpdater feeds a fixed set of updates then closes (mirrors router_test.go).
type fakeUpdater struct{ ch chan api.Update }
func newFake(ups ...api.Update) *fakeUpdater {
ch := make(chan api.Update, len(ups))
for _, u := range ups {
ch <- u
}
close(ch)
return &fakeUpdater{ch: ch}
}
func (f *fakeUpdater) Updates() <-chan api.Update { return f.ch }
func (f *fakeUpdater) Run(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }
func (f *fakeUpdater) Stop(ctx context.Context) error { return nil }
// ---- Storage tests ---------------------------------------------------------
func TestStorage_ErrKeyNotFound(t *testing.T) {
s := conversation.NewMemoryStorage()
_, err := s.Get(context.Background(), "missing")
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
}
func TestStorage_SetAndGet(t *testing.T) {
ctx := context.Background()
s := conversation.NewMemoryStorage()
require.NoError(t, s.Set(ctx, "k", "state_a"))
v, err := s.Get(ctx, "k")
require.NoError(t, err)
require.Equal(t, conversation.State("state_a"), v)
}
func TestStorage_Delete(t *testing.T) {
ctx := context.Background()
s := conversation.NewMemoryStorage()
require.NoError(t, s.Set(ctx, "k", "state_a"))
require.NoError(t, s.Delete(ctx, "k"))
_, err := s.Get(ctx, "k")
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
}
func TestStorage_DeleteNonExistentIsNoop(t *testing.T) {
require.NoError(t, conversation.NewMemoryStorage().Delete(context.Background(), "gone"))
}
// ---- Key strategy tests ----------------------------------------------------
func TestKeyByUser_Variants(t *testing.T) {
t.Run("message", func(t *testing.T) {
u := msgUpd(42, 100, "hi")
require.Equal(t, "u:42", conversation.KeyByUser(&u))
})
t.Run("edited_message", func(t *testing.T) {
u := api.Update{EditedMessage: &api.Message{From: &api.User{ID: 7}, Chat: api.Chat{ID: 1}}}
require.Equal(t, "u:7", conversation.KeyByUser(&u))
})
t.Run("callback_query", func(t *testing.T) {
u := api.Update{CallbackQuery: &api.CallbackQuery{From: api.User{ID: 99}}}
require.Equal(t, "u:99", conversation.KeyByUser(&u))
})
t.Run("inline_query", func(t *testing.T) {
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
require.Equal(t, "u:5", conversation.KeyByUser(&u))
})
t.Run("empty", func(t *testing.T) {
require.Equal(t, "", conversation.KeyByUser(&api.Update{}))
})
}
func TestKeyByChat_Variants(t *testing.T) {
t.Run("message", func(t *testing.T) {
u := msgUpd(1, 200, "")
require.Equal(t, "c:200", conversation.KeyByChat(&u))
})
t.Run("inline_has_no_chat", func(t *testing.T) {
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
require.Equal(t, "", conversation.KeyByChat(&u))
})
}
func TestKeyByUserAndChat(t *testing.T) {
u := msgUpd(42, 100, "")
require.Equal(t, "uc:100:42", conversation.KeyByUserAndChat(&u))
}
// ---- Handler / state machine tests -----------------------------------------
func buildConv() *conversation.Conversation {
return &conversation.Conversation{
EntryPoints: []conversation.Step{{
Filter: hasPrefix("/start"),
Handler: func(c *dispatch.Context, u *api.Update) error {
return conversation.Next("await_name")
},
}},
States: map[conversation.State][]conversation.Step{
"await_name": {{
Filter: anyMsg,
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("await_age") },
}},
"await_age": {{
Filter: anyMsg,
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
}},
},
Exits: []conversation.Step{{
Filter: hasPrefix("/cancel"),
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
}},
}
}
func TestConversation_FullFlow(t *testing.T) {
conv := buildConv()
var downstream int
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
downstream++
return nil
})
mw := conv.Dispatch(noop)
key := "uc:1:42"
// 1. /start → enters, state = await_name
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
v, err := conv.Storage.Get(context.Background(), key)
require.NoError(t, err)
require.Equal(t, conversation.State("await_name"), v)
require.Equal(t, 0, downstream, "entry claimed update")
// 2. name → state = await_age
u2 := msgUpd(42, 1, "Alice")
require.NoError(t, mw(makeCtx(&u2), &u2))
v, err = conv.Storage.Get(context.Background(), key)
require.NoError(t, err)
require.Equal(t, conversation.State("await_age"), v)
// 3. age → End, key deleted
u3 := msgUpd(42, 1, "30")
require.NoError(t, mw(makeCtx(&u3), &u3))
_, err = conv.Storage.Get(context.Background(), key)
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
}
func TestConversation_ExitsCancelMidFlow(t *testing.T) {
conv := buildConv()
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
mw := conv.Dispatch(noop)
// Start conversation.
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
_, err := conv.Storage.Get(context.Background(), "uc:1:42")
require.NoError(t, err)
// Cancel mid-flow.
u2 := msgUpd(42, 1, "/cancel")
require.NoError(t, mw(makeCtx(&u2), &u2))
_, err = conv.Storage.Get(context.Background(), "uc:1:42")
require.ErrorIs(t, err, conversation.ErrKeyNotFound, "exit should clear state")
}
func TestConversation_FallbackFiresWhenNoStateStepMatches(t *testing.T) {
fallbackHit := false
conv := &conversation.Conversation{
EntryPoints: []conversation.Step{{
Filter: hasPrefix("/start"),
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("waiting") },
}},
States: map[conversation.State][]conversation.Step{
// No steps for "waiting" that match a callback query.
"waiting": {},
},
Fallbacks: []conversation.Step{{
Filter: anyMsg,
Handler: func(c *dispatch.Context, u *api.Update) error {
fallbackHit = true
return nil
},
}},
}
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
mw := conv.Dispatch(noop)
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
u2 := msgUpd(42, 1, "unexpected text")
require.NoError(t, mw(makeCtx(&u2), &u2))
require.True(t, fallbackHit, "fallback should have fired")
}
func TestConversation_NoActiveConv_PassesToDownstream(t *testing.T) {
conv := buildConv()
downstreamHit := false
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
downstreamHit = true
return nil
})
mw := conv.Dispatch(downstream)
// Random message that doesn't match /start
u := msgUpd(42, 1, "hello")
require.NoError(t, mw(makeCtx(&u), &u))
require.True(t, downstreamHit, "unmatched update should reach downstream")
}
func TestConversation_EmptyKey_PassesThrough(t *testing.T) {
// InlineQuery has no chatID → KeyByUserAndChat returns "" → pass through.
conv := buildConv()
downstreamHit := false
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
downstreamHit = true
return nil
})
mw := conv.Dispatch(downstream)
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
require.NoError(t, mw(makeCtx(&u), &u))
require.True(t, downstreamHit)
}
func TestConversation_AllowReEntry(t *testing.T) {
conv := buildConv()
conv.AllowReEntry = true
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
mw := conv.Dispatch(noop)
// Start.
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
require.Equal(t, conversation.State("await_name"), v)
// Advance once.
u2 := msgUpd(42, 1, "Alice")
require.NoError(t, mw(makeCtx(&u2), &u2))
v, _ = conv.Storage.Get(context.Background(), "uc:1:42")
require.Equal(t, conversation.State("await_age"), v)
// Re-enter with /start — should restart to await_name even though mid-flow.
u3 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u3), &u3))
v, _ = conv.Storage.Get(context.Background(), "uc:1:42")
require.Equal(t, conversation.State("await_name"), v, "AllowReEntry should restart")
}
func TestConversation_NoReEntry_EntryIgnoredWhenActive(t *testing.T) {
conv := buildConv()
conv.AllowReEntry = false
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
mw := conv.Dispatch(noop)
// Start.
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
// Advance to await_age.
u2 := msgUpd(42, 1, "Alice")
require.NoError(t, mw(makeCtx(&u2), &u2))
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
require.Equal(t, conversation.State("await_age"), v)
// /start again — should NOT restart; state should stay await_age since
// /start matches the state step filter (anyMsg) and advances.
// Actually /start is handled by "await_age" anyMsg step → End().
u3 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u3), &u3))
// State ended (End() called by await_age step).
_, err := conv.Storage.Get(context.Background(), "uc:1:42")
require.ErrorIs(t, err, conversation.ErrKeyNotFound, "state step should have consumed /start when AllowReEntry=false")
}
func TestConversation_StayInState_NilReturn(t *testing.T) {
// Handler returning nil keeps state unchanged.
stored := false
conv := &conversation.Conversation{
EntryPoints: []conversation.Step{{
Filter: hasPrefix("/start"),
Handler: func(c *dispatch.Context, u *api.Update) error {
return conversation.Next("waiting")
},
}},
States: map[conversation.State][]conversation.Step{
"waiting": {{
Filter: anyMsg,
Handler: func(c *dispatch.Context, u *api.Update) error {
stored = true
return nil // stay in current state
},
}},
},
}
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
mw := conv.Dispatch(noop)
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
u2 := msgUpd(42, 1, "something")
require.NoError(t, mw(makeCtx(&u2), &u2))
require.True(t, stored)
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
require.Equal(t, conversation.State("waiting"), v, "nil return should leave state unchanged")
}
func TestConversation_ActiveNoMatch_Swallows(t *testing.T) {
// Active conversation with no matching state step and no fallback:
// update is swallowed (not passed downstream).
conv := &conversation.Conversation{
EntryPoints: []conversation.Step{{
Filter: hasPrefix("/start"),
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("waiting") },
}},
States: map[conversation.State][]conversation.Step{
"waiting": {{
// Only matches /done specifically.
Filter: hasPrefix("/done"),
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
}},
},
}
downstreamHit := false
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
downstreamHit = true
return nil
})
mw := conv.Dispatch(downstream)
u1 := msgUpd(42, 1, "/start")
require.NoError(t, mw(makeCtx(&u1), &u1))
// Random text doesn't match /done and there's no fallback → swallowed.
u2 := msgUpd(42, 1, "random")
require.NoError(t, mw(makeCtx(&u2), &u2))
require.False(t, downstreamHit, "active conv with no matching step should swallow")
}
// ---- Via Router.Run --------------------------------------------------------
func TestConversation_ViaRouter(t *testing.T) {
var steps atomic.Int32
conv := &conversation.Conversation{
EntryPoints: []conversation.Step{{
Filter: hasPrefix("/start"),
Handler: func(c *dispatch.Context, u *api.Update) error {
steps.Add(1)
return conversation.Next("await_name")
},
}},
States: map[conversation.State][]conversation.Step{
"await_name": {{
Filter: anyMsg,
Handler: func(c *dispatch.Context, u *api.Update) error {
steps.Add(1)
return conversation.Next("await_age")
},
}},
"await_age": {{
Filter: anyMsg,
Handler: func(c *dispatch.Context, u *api.Update) error {
steps.Add(1)
return conversation.End()
},
}},
},
}
router := dispatch.New(client.New("t"), dispatch.WithMaxConcurrency(0)) // serial
router.Use(conv.Dispatch)
ups := []api.Update{
msgUpd(42, 1, "/start"),
msgUpd(42, 1, "Alice"),
msgUpd(42, 1, "30"),
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error, 1)
go func() { errCh <- router.Run(ctx, newFake(ups...)) }()
// Wait for updater channel to drain (Run returns when closed).
err := <-errCh
if err != nil && err != context.Canceled {
t.Fatalf("Run error: %v", err)
}
require.Equal(t, int32(3), steps.Load(), "all three steps should have fired")
}
// ---- Concurrent storage safety ---------------------------------------------
func TestConversation_ConcurrentStorageAccess(t *testing.T) {
// 15 goroutines each running a full /start → name → age flow against the
// same shared storage but DIFFERENT keys (one per goroutine). Validates
// no data races.
const numUsers = 15
conv := buildConv()
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
mw := conv.Dispatch(noop)
var wg sync.WaitGroup
wg.Add(numUsers)
for i := 0; i < numUsers; i++ {
go func(uid int64) {
defer wg.Done()
u1 := msgUpd(uid, uid, "/start")
_ = mw(makeCtx(&u1), &u1)
u2 := msgUpd(uid, uid, "Alice")
_ = mw(makeCtx(&u2), &u2)
u3 := msgUpd(uid, uid, "30")
_ = mw(makeCtx(&u3), &u3)
}(int64(i + 1))
}
wg.Wait()
// Race detector catches bugs; no assertion needed beyond clean finish.
}
func TestConversation_ConcurrentSameKey(t *testing.T) {
// 12 goroutines hammer the same key concurrently. Storage must not panic
// or corrupt state. Race detector validates lock discipline.
const goroutines = 12
s := conversation.NewMemoryStorage()
ctx := context.Background()
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(i int) {
defer wg.Done()
_ = s.Set(ctx, "shared", conversation.State("step"))
_, _ = s.Get(ctx, "shared")
if i%4 == 0 {
_ = s.Delete(ctx, "shared")
}
}(i)
}
wg.Wait()
}
+176
View File
@@ -0,0 +1,176 @@
package conversation
import (
"context"
"errors"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// stateTransition is a sentinel error type carrying a state transition
// or end signal. Conversation handlers return one of these (via Next or
// End helpers below) to drive the state machine.
type stateTransition struct {
next State
end bool
}
func (e *stateTransition) Error() string {
if e.end {
return "conversation: end"
}
return "conversation: → " + string(e.next)
}
// Next signals the conversation should advance to the given state.
// Conversation handlers return Next("state_name") to transition.
func Next(s State) error {
return &stateTransition{next: s}
}
// End signals the conversation has finished and state should be cleared.
// Conversation handlers return End() to terminate.
func End() error {
return &stateTransition{end: true}
}
// Handler defines a step in the conversation. Receives the dispatch context
// and the raw update. Returns:
// - nil to stay in the current state
// - Next("state") to transition to a different state
// - End() to end the conversation
// - any other non-nil error to surface to the dispatcher (state unchanged)
type Handler func(ctx *dispatch.Context, u *api.Update) error
// Step pairs a filter with a handler for one conversation step.
type Step struct {
Filter dispatch.Filter[*api.Update]
Handler Handler
}
// Conversation is a stateful handler with entry, per-state, exit and
// fallback steps. A conversation is keyed by KeyStrategy (default
// KeyByUserAndChat) and persisted by Storage (default in-memory).
type Conversation struct {
// EntryPoints starts a new conversation when a matching filter fires
// and no conversation is already active for the key.
EntryPoints []Step
// States maps each state to the steps that handle it.
States map[State][]Step
// Exits, if any match, end the active conversation early. Useful for
// /cancel-style commands.
Exits []Step
// Fallbacks run when no state step matches the current update.
Fallbacks []Step
// Storage persists conversation state. Defaults to NewMemoryStorage.
Storage Storage
// KeyStrategy derives the persistence key. Defaults to KeyByUserAndChat.
KeyStrategy KeyStrategy
// AllowReEntry, when true, lets entry-point steps fire even while a
// conversation is already active for the key (effectively restarting it).
AllowReEntry bool
}
// Dispatch is a global middleware-shaped Handler that consumes updates
// and routes them through the conversation graph. Register via
// router.Use(conv.Dispatch).
//
// If the conversation claims an update, downstream handlers are skipped.
// If the conversation does not claim it, downstream handlers run as normal.
func (c *Conversation) Dispatch(next dispatch.Handler[*api.Update]) dispatch.Handler[*api.Update] {
if c.Storage == nil {
c.Storage = NewMemoryStorage()
}
if c.KeyStrategy == nil {
c.KeyStrategy = KeyByUserAndChat
}
return func(dctx *dispatch.Context, u *api.Update) error {
key := c.KeyStrategy(u)
if key == "" {
return next(dctx, u)
}
ctx := dctx.Ctx
current, err := c.Storage.Get(ctx, key)
if err != nil && !errors.Is(err, ErrKeyNotFound) {
return err
}
active := !errors.Is(err, ErrKeyNotFound)
// Try exits first (always allowed if conversation is active).
if active {
for _, step := range c.Exits {
if step.Filter(u) {
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
return err
}
return nil
}
}
}
// Try entry points (only if no active conversation, or AllowReEntry).
if !active || c.AllowReEntry {
for _, step := range c.EntryPoints {
if step.Filter(u) {
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
return err
}
return nil
}
}
}
if !active {
return next(dctx, u)
}
// Active conversation: try state steps.
for _, step := range c.States[current] {
if step.Filter(u) {
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
return err
}
return nil
}
}
// Fallbacks if no state step matched.
for _, step := range c.Fallbacks {
if step.Filter(u) {
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
return err
}
return nil
}
}
// Active conversation but no step matched and no fallback: swallow the
// update (do NOT pass to downstream handlers, since the user is
// mid-conversation and an unrelated handler would surprise them).
return nil
}
}
// runStep invokes the handler and applies its return-value state transition.
func (c *Conversation) runStep(ctx context.Context, dctx *dispatch.Context, u *api.Update, key string, h Handler) error {
err := h(dctx, u)
if err == nil {
return nil
}
var trans *stateTransition
if errors.As(err, &trans) {
if trans.end {
return c.Storage.Delete(ctx, key)
}
return c.Storage.Set(ctx, key, trans.next)
}
return err
}
+43
View File
@@ -0,0 +1,43 @@
package conversation
import (
"context"
"sync"
)
// MemoryStorage is the default in-process Storage. It is safe for
// concurrent use. Conversation state is lost on process restart; use
// a custom Storage backed by a database for persistent flows.
type MemoryStorage struct {
mu sync.RWMutex
state map[string]State
}
// NewMemoryStorage constructs an empty in-memory storage.
func NewMemoryStorage() *MemoryStorage {
return &MemoryStorage{state: map[string]State{}}
}
func (s *MemoryStorage) Get(_ context.Context, key string) (State, error) {
s.mu.RLock()
defer s.mu.RUnlock()
v, ok := s.state[key]
if !ok {
return "", ErrKeyNotFound
}
return v, nil
}
func (s *MemoryStorage) Set(_ context.Context, key string, state State) error {
s.mu.Lock()
defer s.mu.Unlock()
s.state[key] = state
return nil
}
func (s *MemoryStorage) Delete(_ context.Context, key string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.state, key)
return nil
}
+87
View File
@@ -0,0 +1,87 @@
package conversation
import (
"context"
"sync"
"testing"
"github.com/stretchr/testify/require"
)
func TestMemoryStorage_GetSetDelete(t *testing.T) {
ctx := context.Background()
s := NewMemoryStorage()
// Get on empty key returns ErrKeyNotFound.
_, err := s.Get(ctx, "k1")
require.ErrorIs(t, err, ErrKeyNotFound)
// Set then Get returns the stored state.
require.NoError(t, s.Set(ctx, "k1", "step_a"))
v, err := s.Get(ctx, "k1")
require.NoError(t, err)
require.Equal(t, State("step_a"), v)
// Overwrite works.
require.NoError(t, s.Set(ctx, "k1", "step_b"))
v, err = s.Get(ctx, "k1")
require.NoError(t, err)
require.Equal(t, State("step_b"), v)
// Delete removes the key.
require.NoError(t, s.Delete(ctx, "k1"))
_, err = s.Get(ctx, "k1")
require.ErrorIs(t, err, ErrKeyNotFound)
// Delete of non-existent key is a no-op (no error).
require.NoError(t, s.Delete(ctx, "nonexistent"))
}
func TestMemoryStorage_MultipleKeys(t *testing.T) {
ctx := context.Background()
s := NewMemoryStorage()
require.NoError(t, s.Set(ctx, "a", "stateA"))
require.NoError(t, s.Set(ctx, "b", "stateB"))
va, err := s.Get(ctx, "a")
require.NoError(t, err)
require.Equal(t, State("stateA"), va)
vb, err := s.Get(ctx, "b")
require.NoError(t, err)
require.Equal(t, State("stateB"), vb)
// Delete one key; the other remains.
require.NoError(t, s.Delete(ctx, "a"))
_, err = s.Get(ctx, "a")
require.ErrorIs(t, err, ErrKeyNotFound)
vb, err = s.Get(ctx, "b")
require.NoError(t, err)
require.Equal(t, State("stateB"), vb)
}
func TestMemoryStorage_Concurrent(t *testing.T) {
// 20 goroutines hammering the same key concurrently — no data race.
ctx := context.Background()
s := NewMemoryStorage()
const goroutines = 20
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(i int) {
defer wg.Done()
key := "shared"
_ = s.Set(ctx, key, State("step"))
_, _ = s.Get(ctx, key)
if i%3 == 0 {
_ = s.Delete(ctx, key)
}
}(i)
}
wg.Wait()
// No assertion needed — race detector catches the bug if present.
}
+79
View File
@@ -0,0 +1,79 @@
package conversation
import (
"fmt"
"github.com/lukaszraczylo/go-telegram/api"
)
// KeyStrategy derives a persistence key from an update. Strategies
// determine how conversation scope works — per-user, per-chat, or
// per-user-and-chat. Implementations must return a stable string for
// the same logical scope across updates.
//
// Returns the empty string if the update doesn't have enough context
// to derive a key (in which case the conversation handler skips it).
type KeyStrategy func(u *api.Update) string
// KeyByUser derives a key from the sending user's ID. Useful for DM
// conversations and any flow that should follow the user across chats.
var KeyByUser KeyStrategy = func(u *api.Update) string {
if uid := userID(u); uid != 0 {
return fmt.Sprintf("u:%d", uid)
}
return ""
}
// KeyByChat derives a key from the chat ID. Useful for group flows where
// any user in the chat can drive the conversation.
var KeyByChat KeyStrategy = func(u *api.Update) string {
if cid := chatID(u); cid != 0 {
return fmt.Sprintf("c:%d", cid)
}
return ""
}
// KeyByUserAndChat derives a key from both user and chat IDs. The most
// common strategy: each user has their own conversation per chat.
var KeyByUserAndChat KeyStrategy = func(u *api.Update) string {
uid := userID(u)
cid := chatID(u)
if uid == 0 || cid == 0 {
return ""
}
return fmt.Sprintf("uc:%d:%d", cid, uid)
}
// userID extracts the sending user's ID from any update payload.
func userID(u *api.Update) int64 {
switch {
case u.Message != nil && u.Message.From != nil:
return u.Message.From.ID
case u.EditedMessage != nil && u.EditedMessage.From != nil:
return u.EditedMessage.From.ID
case u.CallbackQuery != nil:
return u.CallbackQuery.From.ID
case u.InlineQuery != nil:
return u.InlineQuery.From.ID
}
return 0
}
// chatID extracts the relevant chat ID.
func chatID(u *api.Update) int64 {
switch {
case u.Message != nil:
return u.Message.Chat.ID
case u.EditedMessage != nil:
return u.EditedMessage.Chat.ID
case u.ChannelPost != nil:
return u.ChannelPost.Chat.ID
case u.EditedChannelPost != nil:
return u.EditedChannelPost.Chat.ID
case u.CallbackQuery != nil && u.CallbackQuery.Message != nil:
if msg, ok := u.CallbackQuery.Message.(*api.Message); ok {
return msg.Chat.ID
}
}
return 0
}
+115
View File
@@ -0,0 +1,115 @@
package conversation
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/stretchr/testify/require"
)
// helpers to build api.Update variants.
func msgUpdate(userID, chatID int64) *api.Update {
return &api.Update{
Message: &api.Message{
From: &api.User{ID: userID},
Chat: api.Chat{ID: chatID},
},
}
}
func editedMsgUpdate(userID, chatID int64) *api.Update {
return &api.Update{
EditedMessage: &api.Message{
From: &api.User{ID: userID},
Chat: api.Chat{ID: chatID},
},
}
}
func callbackUpdate(userID, chatID int64) *api.Update {
return &api.Update{
CallbackQuery: &api.CallbackQuery{
From: api.User{ID: userID},
Message: &api.Message{Chat: api.Chat{ID: chatID}},
},
}
}
func inlineUpdate(userID int64) *api.Update {
return &api.Update{
InlineQuery: &api.InlineQuery{
From: api.User{ID: userID},
},
}
}
func emptyUpdate() *api.Update { return &api.Update{} }
func TestKeyByUser(t *testing.T) {
t.Run("message update", func(t *testing.T) {
require.Equal(t, "u:42", KeyByUser(msgUpdate(42, 100)))
})
t.Run("edited message", func(t *testing.T) {
require.Equal(t, "u:7", KeyByUser(editedMsgUpdate(7, 100)))
})
t.Run("callback query", func(t *testing.T) {
require.Equal(t, "u:99", KeyByUser(callbackUpdate(99, 100)))
})
t.Run("inline query", func(t *testing.T) {
require.Equal(t, "u:5", KeyByUser(inlineUpdate(5)))
})
t.Run("empty update returns empty string", func(t *testing.T) {
require.Equal(t, "", KeyByUser(emptyUpdate()))
})
}
func TestKeyByChat(t *testing.T) {
t.Run("message update", func(t *testing.T) {
require.Equal(t, "c:100", KeyByChat(msgUpdate(42, 100)))
})
t.Run("edited message", func(t *testing.T) {
require.Equal(t, "c:200", KeyByChat(editedMsgUpdate(7, 200)))
})
t.Run("callback with accessible message", func(t *testing.T) {
require.Equal(t, "c:300", KeyByChat(callbackUpdate(99, 300)))
})
t.Run("inline query has no chat → empty", func(t *testing.T) {
require.Equal(t, "", KeyByChat(inlineUpdate(5)))
})
t.Run("empty update returns empty string", func(t *testing.T) {
require.Equal(t, "", KeyByChat(emptyUpdate()))
})
}
func TestKeyByUserAndChat(t *testing.T) {
t.Run("message update", func(t *testing.T) {
require.Equal(t, "uc:100:42", KeyByUserAndChat(msgUpdate(42, 100)))
})
t.Run("edited message", func(t *testing.T) {
require.Equal(t, "uc:200:7", KeyByUserAndChat(editedMsgUpdate(7, 200)))
})
t.Run("callback query", func(t *testing.T) {
require.Equal(t, "uc:300:99", KeyByUserAndChat(callbackUpdate(99, 300)))
})
t.Run("inline query has no chat → empty", func(t *testing.T) {
require.Equal(t, "", KeyByUserAndChat(inlineUpdate(5)))
})
t.Run("empty update returns empty string", func(t *testing.T) {
require.Equal(t, "", KeyByUserAndChat(emptyUpdate()))
})
}
func TestKeyByUserAndChat_CallbackInaccessibleMessage(t *testing.T) {
// CallbackQuery.Message is InaccessibleMessage (not *Message) — chatID returns 0.
u := &api.Update{
CallbackQuery: &api.CallbackQuery{
From: api.User{ID: 10},
Message: &api.InaccessibleMessage{}, // implements MaybeInaccessibleMessage, not *api.Message
},
}
// userID picks up From.ID=10 but chatID fails type assertion → 0
require.Equal(t, "", KeyByUserAndChat(u), "no key when message inaccessible")
// KeyByUser still works since From is set.
require.Equal(t, "u:10", KeyByUser(u))
}
+9
View File
@@ -0,0 +1,9 @@
// Package conversation implements a stateful conversation handler for the
// go-telegram dispatch router. It provides a state-machine abstraction over
// multi-step Telegram bot interactions, with pluggable storage and flexible
// key strategies.
package conversation
// State is a label identifying a node in the conversation graph.
// The empty string is the implicit "no active conversation" state.
type State string
+20
View File
@@ -0,0 +1,20 @@
package conversation
import (
"context"
"errors"
)
// ErrKeyNotFound is returned by Storage.Get when no conversation is active
// for the given key.
var ErrKeyNotFound = errors.New("conversation: key not found")
// Storage persists per-user (or per-chat, per-message — depending on the
// KeyStrategy in use) conversation state across update deliveries.
//
// Implementations must be safe for concurrent use.
type Storage interface {
Get(ctx context.Context, key string) (State, error)
Set(ctx context.Context, key string, state State) error
Delete(ctx context.Context, key string) error
}
+70
View File
@@ -0,0 +1,70 @@
package dispatch
// Filter is a predicate over a typed payload (e.g. *api.Message). Filters
// compose via And/Or/Not for multi-condition matching.
//
// Example:
//
// f := message.HasPhoto().And(message.InChat(-100123456789))
type Filter[T any] func(payload T) bool
// And returns a Filter that matches iff f and every one of others matches.
func (f Filter[T]) And(others ...Filter[T]) Filter[T] {
return func(payload T) bool {
if !f(payload) {
return false
}
for _, o := range others {
if !o(payload) {
return false
}
}
return true
}
}
// Or returns a Filter that matches iff f matches OR any of others matches.
func (f Filter[T]) Or(others ...Filter[T]) Filter[T] {
return func(payload T) bool {
if f(payload) {
return true
}
for _, o := range others {
if o(payload) {
return true
}
}
return false
}
}
// Not returns a Filter that inverts f.
func (f Filter[T]) Not() Filter[T] {
return func(payload T) bool { return !f(payload) }
}
// All combines filters with AND. Returns a Filter that matches when all match.
// Returns a filter that always matches when filters is empty.
func All[T any](filters ...Filter[T]) Filter[T] {
return func(payload T) bool {
for _, f := range filters {
if !f(payload) {
return false
}
}
return true
}
}
// Any combines filters with OR. Returns a Filter that matches when at least
// one matches. Returns a filter that never matches when filters is empty.
func Any[T any](filters ...Filter[T]) Filter[T] {
return func(payload T) bool {
for _, f := range filters {
if f(payload) {
return true
}
}
return false
}
}
+87
View File
@@ -0,0 +1,87 @@
package dispatch
import (
"testing"
"github.com/stretchr/testify/require"
)
func alwaysTrue[T any]() Filter[T] { return func(_ T) bool { return true } }
func alwaysFalse[T any]() Filter[T] { return func(_ T) bool { return false } }
func TestFilter_And(t *testing.T) {
t.Run("all true", func(t *testing.T) {
f := alwaysTrue[int]().And(alwaysTrue[int](), alwaysTrue[int]())
require.True(t, f(0))
})
t.Run("first false", func(t *testing.T) {
f := alwaysFalse[int]().And(alwaysTrue[int]())
require.False(t, f(0))
})
t.Run("other false", func(t *testing.T) {
f := alwaysTrue[int]().And(alwaysFalse[int]())
require.False(t, f(0))
})
t.Run("no others — acts as identity", func(t *testing.T) {
require.True(t, alwaysTrue[int]().And()(0))
require.False(t, alwaysFalse[int]().And()(0))
})
}
func TestFilter_Or(t *testing.T) {
t.Run("first true", func(t *testing.T) {
f := alwaysTrue[int]().Or(alwaysFalse[int]())
require.True(t, f(0))
})
t.Run("other true", func(t *testing.T) {
f := alwaysFalse[int]().Or(alwaysTrue[int]())
require.True(t, f(0))
})
t.Run("all false", func(t *testing.T) {
f := alwaysFalse[int]().Or(alwaysFalse[int]())
require.False(t, f(0))
})
t.Run("no others", func(t *testing.T) {
require.True(t, alwaysTrue[int]().Or()(0))
require.False(t, alwaysFalse[int]().Or()(0))
})
}
func TestFilter_Not(t *testing.T) {
require.False(t, alwaysTrue[int]().Not()(0))
require.True(t, alwaysFalse[int]().Not()(0))
}
func TestAll(t *testing.T) {
t.Run("all true", func(t *testing.T) {
require.True(t, All(alwaysTrue[int](), alwaysTrue[int]())(0))
})
t.Run("one false", func(t *testing.T) {
require.False(t, All(alwaysTrue[int](), alwaysFalse[int]())(0))
})
t.Run("empty — always true", func(t *testing.T) {
require.True(t, All[int]()(0))
})
}
func TestAny(t *testing.T) {
t.Run("one true", func(t *testing.T) {
require.True(t, Any(alwaysFalse[int](), alwaysTrue[int]())(0))
})
t.Run("all false", func(t *testing.T) {
require.False(t, Any(alwaysFalse[int](), alwaysFalse[int]())(0))
})
t.Run("empty — always false", func(t *testing.T) {
require.False(t, Any[int]()(0))
})
}
func TestFilter_Composition(t *testing.T) {
// (true AND false) OR true == true
f := alwaysTrue[int]().And(alwaysFalse[int]()).Or(alwaysTrue[int]())
require.True(t, f(0))
// NOT (true OR false) == false
g := alwaysTrue[int]().Or(alwaysFalse[int]()).Not()
require.False(t, g(0))
}
+43
View File
@@ -0,0 +1,43 @@
// Package callback provides Filter helpers for *api.CallbackQuery payloads.
package callback
import (
"regexp"
"strings"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// Data returns a Filter that matches callback queries whose Data matches
// pattern (regex). Panics at registration time on an invalid pattern.
func Data(pattern string) dispatch.Filter[*api.CallbackQuery] {
re := regexp.MustCompile(pattern)
return func(q *api.CallbackQuery) bool {
return q != nil && re.MatchString(q.Data)
}
}
// DataEquals returns a Filter that matches callback queries whose Data equals
// s exactly.
func DataEquals(s string) dispatch.Filter[*api.CallbackQuery] {
return func(q *api.CallbackQuery) bool {
return q != nil && q.Data == s
}
}
// DataPrefix returns a Filter that matches callback queries whose Data starts
// with prefix.
func DataPrefix(prefix string) dispatch.Filter[*api.CallbackQuery] {
return func(q *api.CallbackQuery) bool {
return q != nil && strings.HasPrefix(q.Data, prefix)
}
}
// FromUser returns a Filter that matches callback queries whose From.ID equals
// userID.
func FromUser(userID int64) dispatch.Filter[*api.CallbackQuery] {
return func(q *api.CallbackQuery) bool {
return q != nil && q.From.ID == userID
}
}
@@ -0,0 +1,56 @@
package callback_test
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
cbfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/callback"
"github.com/stretchr/testify/require"
)
func cq(data string, userID int64) *api.CallbackQuery {
return &api.CallbackQuery{
ID: "q",
From: api.User{ID: userID},
Data: data,
}
}
func TestData(t *testing.T) {
f := cbfilter.Data(`^like:\d+$`)
require.True(t, f(cq("like:42", 1)))
require.False(t, f(cq("dislike:42", 1)))
require.False(t, f(nil))
}
func TestData_PanicsOnBadPattern(t *testing.T) {
require.Panics(t, func() { cbfilter.Data(`[bad`) })
}
func TestDataEquals(t *testing.T) {
f := cbfilter.DataEquals("yes")
require.True(t, f(cq("yes", 1)))
require.False(t, f(cq("yes please", 1)))
require.False(t, f(nil))
}
func TestDataPrefix(t *testing.T) {
f := cbfilter.DataPrefix("vote:")
require.True(t, f(cq("vote:up", 1)))
require.False(t, f(cq("novote:up", 1)))
require.False(t, f(nil))
}
func TestFromUser(t *testing.T) {
f := cbfilter.FromUser(7)
require.True(t, f(cq("data", 7)))
require.False(t, f(cq("data", 8)))
require.False(t, f(nil))
}
func TestComposedCallbackFilters(t *testing.T) {
f := cbfilter.DataPrefix("vote:").And(cbfilter.FromUser(7))
require.True(t, f(cq("vote:up", 7)))
require.False(t, f(cq("vote:up", 8)))
require.False(t, f(cq("other", 7)))
}
@@ -0,0 +1,23 @@
// Package chatjoinrequest provides Filter helpers for *api.ChatJoinRequest payloads.
package chatjoinrequest
import (
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// FromUser returns a Filter that matches join requests where the requesting
// user's ID equals uid.
func FromUser(uid int64) dispatch.Filter[*api.ChatJoinRequest] {
return func(r *api.ChatJoinRequest) bool {
return r != nil && r.From.ID == uid
}
}
// InChat returns a Filter that matches join requests directed at the chat
// with the given chat ID.
func InChat(cid int64) dispatch.Filter[*api.ChatJoinRequest] {
return func(r *api.ChatJoinRequest) bool {
return r != nil && r.Chat.ID == cid
}
}
@@ -0,0 +1,37 @@
package chatjoinrequest_test
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
cjrfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/chatjoinrequest"
"github.com/stretchr/testify/require"
)
func joinRequest(fromID, chatID int64) *api.ChatJoinRequest {
return &api.ChatJoinRequest{
Chat: api.Chat{ID: chatID},
From: api.User{ID: fromID},
}
}
func TestFromUser_Matches(t *testing.T) {
f := cjrfilter.FromUser(10)
require.True(t, f(joinRequest(10, 100)))
require.False(t, f(joinRequest(99, 100)))
require.False(t, f(nil))
}
func TestInChat_Matches(t *testing.T) {
f := cjrfilter.InChat(100)
require.True(t, f(joinRequest(10, 100)))
require.False(t, f(joinRequest(10, 200)))
require.False(t, f(nil))
}
func TestComposedFilters(t *testing.T) {
f := cjrfilter.FromUser(10).And(cjrfilter.InChat(100))
require.True(t, f(joinRequest(10, 100)))
require.False(t, f(joinRequest(10, 200)))
require.False(t, f(joinRequest(99, 100)))
}
+41
View File
@@ -0,0 +1,41 @@
// Package chatmember provides Filter helpers for *api.ChatMemberUpdated payloads.
package chatmember
import (
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// NewStatus returns a Filter that matches updates where the new chat member
// status equals s (e.g. "member", "administrator", "kicked", "left").
func NewStatus(s string) dispatch.Filter[*api.ChatMemberUpdated] {
return func(u *api.ChatMemberUpdated) bool {
if u == nil {
return false
}
switch m := u.NewChatMember.(type) {
case *api.ChatMemberOwner:
return m.Status == s
case *api.ChatMemberAdministrator:
return m.Status == s
case *api.ChatMemberMember:
return m.Status == s
case *api.ChatMemberRestricted:
return m.Status == s
case *api.ChatMemberLeft:
return m.Status == s
case *api.ChatMemberBanned:
return m.Status == s
default:
return false
}
}
}
// FromUser returns a Filter that matches updates where the acting user
// (From.ID) equals uid.
func FromUser(uid int64) dispatch.Filter[*api.ChatMemberUpdated] {
return func(u *api.ChatMemberUpdated) bool {
return u != nil && u.From.ID == uid
}
}
@@ -0,0 +1,95 @@
package chatmember_test
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
cmfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/chatmember"
"github.com/stretchr/testify/require"
)
func memberUpdate(status string, fromID int64) *api.ChatMemberUpdated {
var newMember api.ChatMember
switch status {
case "member":
newMember = &api.ChatMemberMember{Status: status}
case "administrator":
newMember = &api.ChatMemberAdministrator{Status: status}
case "kicked":
newMember = &api.ChatMemberBanned{Status: status}
case "left":
newMember = &api.ChatMemberLeft{Status: status}
default:
newMember = &api.ChatMemberMember{Status: status}
}
return &api.ChatMemberUpdated{
From: api.User{ID: fromID},
NewChatMember: newMember,
}
}
func TestNewStatus_Matches(t *testing.T) {
f := cmfilter.NewStatus("member")
require.True(t, f(memberUpdate("member", 1)))
require.False(t, f(memberUpdate("kicked", 1)))
require.False(t, f(nil))
}
func TestNewStatus_Administrator(t *testing.T) {
f := cmfilter.NewStatus("administrator")
require.True(t, f(memberUpdate("administrator", 1)))
require.False(t, f(memberUpdate("member", 1)))
}
func TestNewStatus_Kicked(t *testing.T) {
f := cmfilter.NewStatus("kicked")
require.True(t, f(memberUpdate("kicked", 1)))
require.False(t, f(memberUpdate("left", 1)))
}
func TestNewStatus_Left(t *testing.T) {
f := cmfilter.NewStatus("left")
require.True(t, f(memberUpdate("left", 1)))
require.False(t, f(memberUpdate("member", 1)))
}
func TestFromUser_Matches(t *testing.T) {
f := cmfilter.FromUser(42)
require.True(t, f(memberUpdate("member", 42)))
require.False(t, f(memberUpdate("member", 99)))
require.False(t, f(nil))
}
func TestComposedFilters(t *testing.T) {
f := cmfilter.NewStatus("member").And(cmfilter.FromUser(7))
require.True(t, f(memberUpdate("member", 7)))
require.False(t, f(memberUpdate("member", 8)))
require.False(t, f(memberUpdate("kicked", 7)))
}
func TestNewStatus_Owner(t *testing.T) {
u := &api.ChatMemberUpdated{
From: api.User{ID: 1},
NewChatMember: &api.ChatMemberOwner{Status: "creator"},
}
require.True(t, cmfilter.NewStatus("creator")(u))
require.False(t, cmfilter.NewStatus("member")(u))
}
func TestNewStatus_Restricted(t *testing.T) {
u := &api.ChatMemberUpdated{
From: api.User{ID: 1},
NewChatMember: &api.ChatMemberRestricted{Status: "restricted"},
}
require.True(t, cmfilter.NewStatus("restricted")(u))
require.False(t, cmfilter.NewStatus("member")(u))
}
func TestNewStatus_UnknownType(t *testing.T) {
// nil NewChatMember → default branch → false
u := &api.ChatMemberUpdated{
From: api.User{ID: 1},
NewChatMember: nil,
}
require.False(t, cmfilter.NewStatus("member")(u))
}
+35
View File
@@ -0,0 +1,35 @@
// Package inline provides Filter helpers for *api.InlineQuery payloads.
package inline
import (
"regexp"
"strings"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// Query returns a Filter that matches inline queries whose Query field matches
// pattern (regex). Panics at registration time on an invalid pattern.
func Query(pattern string) dispatch.Filter[*api.InlineQuery] {
re := regexp.MustCompile(pattern)
return func(q *api.InlineQuery) bool {
return q != nil && re.MatchString(q.Query)
}
}
// QueryEquals returns a Filter that matches inline queries whose Query equals
// s exactly.
func QueryEquals(s string) dispatch.Filter[*api.InlineQuery] {
return func(q *api.InlineQuery) bool {
return q != nil && q.Query == s
}
}
// QueryPrefix returns a Filter that matches inline queries whose Query starts
// with prefix.
func QueryPrefix(prefix string) dispatch.Filter[*api.InlineQuery] {
return func(q *api.InlineQuery) bool {
return q != nil && strings.HasPrefix(q.Query, prefix)
}
}
+45
View File
@@ -0,0 +1,45 @@
package inline_test
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
ilfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/inline"
"github.com/stretchr/testify/require"
)
func iq(query string) *api.InlineQuery {
return &api.InlineQuery{ID: "i", From: api.User{ID: 1}, Query: query}
}
func TestQuery(t *testing.T) {
f := ilfilter.Query(`^find`)
require.True(t, f(iq("find me")))
require.False(t, f(iq("search me")))
require.False(t, f(nil))
}
func TestQuery_PanicsOnBadPattern(t *testing.T) {
require.Panics(t, func() { ilfilter.Query(`[bad`) })
}
func TestQueryEquals(t *testing.T) {
f := ilfilter.QueryEquals("exact")
require.True(t, f(iq("exact")))
require.False(t, f(iq("exact match")))
require.False(t, f(nil))
}
func TestQueryPrefix(t *testing.T) {
f := ilfilter.QueryPrefix("@user")
require.True(t, f(iq("@username")))
require.False(t, f(iq("no prefix")))
require.False(t, f(nil))
}
func TestComposedInlineFilters(t *testing.T) {
f := ilfilter.QueryPrefix("find").Or(ilfilter.QueryEquals("help"))
require.True(t, f(iq("find me")))
require.True(t, f(iq("help")))
require.False(t, f(iq("other")))
}
+142
View File
@@ -0,0 +1,142 @@
// Package message provides Filter helpers for *api.Message payloads.
package message
import (
"regexp"
"strings"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// Text returns a Filter that matches messages whose Text matches pattern (regex).
// Panics at registration time on an invalid pattern.
func Text(pattern string) dispatch.Filter[*api.Message] {
re := regexp.MustCompile(pattern)
return func(m *api.Message) bool {
return m != nil && re.MatchString(m.Text)
}
}
// TextEquals returns a Filter that matches messages whose Text equals s exactly.
func TextEquals(s string) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.Text == s
}
}
// TextPrefix returns a Filter that matches messages whose Text starts with prefix.
func TextPrefix(prefix string) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && strings.HasPrefix(m.Text, prefix)
}
}
// TextContains returns a Filter that matches messages whose Text contains sub.
func TextContains(sub string) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && strings.Contains(m.Text, sub)
}
}
// Command returns a Filter that matches messages whose first entity is a
// bot_command equal to "/<name>" (with or without "@BotName" suffix).
func Command(name string) dispatch.Filter[*api.Message] {
want := "/" + strings.TrimPrefix(name, "/")
return func(m *api.Message) bool {
if m == nil || len(m.Entities) == 0 || m.Text == "" {
return false
}
first := m.Entities[0]
if first.Type != string(api.EntityBotCommand) || first.Offset != 0 {
return false
}
end := int(first.Length)
runes := []rune(m.Text)
if end > len(runes) {
return false
}
cmd := string(runes[:end])
if i := strings.Index(cmd, "@"); i >= 0 {
cmd = cmd[:i]
}
return cmd == want
}
}
// AnyCommand returns a Filter that matches any message starting with a
// bot_command entity at offset 0.
func AnyCommand() dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
if m == nil || len(m.Entities) == 0 {
return false
}
first := m.Entities[0]
return first.Type == string(api.EntityBotCommand) && first.Offset == 0
}
}
// IsReply returns a Filter that matches messages that have ReplyToMessage set.
func IsReply() dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.ReplyToMessage != nil
}
}
// IsForward returns a Filter that matches messages that have ForwardOrigin set.
func IsForward() dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.ForwardOrigin != nil
}
}
// HasPhoto returns a Filter that matches messages with a Photo attachment.
func HasPhoto() dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && len(m.Photo) > 0
}
}
// HasDocument returns a Filter that matches messages with a Document attachment.
func HasDocument() dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.Document != nil
}
}
// HasEntity returns a Filter that matches messages whose Entities contain at
// least one entity of type t (e.g. string(api.EntityBotCommand)).
func HasEntity(t string) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
if m == nil {
return false
}
for _, e := range m.Entities {
if e.Type == t {
return true
}
}
return false
}
}
// ChatType returns a Filter that matches messages whose Chat.Type equals t.
func ChatType(t api.ChatType) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.Chat.Type == string(t)
}
}
// FromUser returns a Filter that matches messages whose From.ID equals userID.
func FromUser(userID int64) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.From != nil && m.From.ID == userID
}
}
// InChat returns a Filter that matches messages whose Chat.ID equals chatID.
func InChat(chatID int64) dispatch.Filter[*api.Message] {
return func(m *api.Message) bool {
return m != nil && m.Chat.ID == chatID
}
}
+188
View File
@@ -0,0 +1,188 @@
package message_test
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
msgfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/message"
"github.com/stretchr/testify/require"
)
func msg(text string) *api.Message {
return &api.Message{
MessageID: 1,
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: text,
}
}
func cmdMsg(cmd string) *api.Message {
text := cmd
return &api.Message{
MessageID: 1,
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: text,
Entities: []api.MessageEntity{
{Type: string(api.EntityBotCommand), Offset: 0, Length: int64(len([]rune(text)))},
},
}
}
func TestText(t *testing.T) {
f := msgfilter.Text(`^hello`)
require.True(t, f(msg("hello world")))
require.False(t, f(msg("world hello")))
require.False(t, f(nil))
}
func TestText_PanicsOnBadPattern(t *testing.T) {
require.Panics(t, func() { msgfilter.Text(`[invalid`) })
}
func TestTextEquals(t *testing.T) {
f := msgfilter.TextEquals("hi")
require.True(t, f(msg("hi")))
require.False(t, f(msg("hi there")))
require.False(t, f(nil))
}
func TestTextPrefix(t *testing.T) {
f := msgfilter.TextPrefix("/start")
require.True(t, f(msg("/start now")))
require.False(t, f(msg("no prefix")))
require.False(t, f(nil))
}
func TestTextContains(t *testing.T) {
f := msgfilter.TextContains("bot")
require.True(t, f(msg("my bot is cool")))
require.False(t, f(msg("nothing here")))
require.False(t, f(nil))
}
func TestCommand(t *testing.T) {
t.Run("matches exact command", func(t *testing.T) {
f := msgfilter.Command("/start")
require.True(t, f(cmdMsg("/start")))
})
t.Run("matches without leading slash", func(t *testing.T) {
f := msgfilter.Command("start")
require.True(t, f(cmdMsg("/start")))
})
t.Run("strips BotName suffix", func(t *testing.T) {
m := &api.Message{
Text: "/start@MyBot",
Entities: []api.MessageEntity{{Type: string(api.EntityBotCommand), Offset: 0, Length: 12}},
}
f := msgfilter.Command("/start")
require.True(t, f(m))
})
t.Run("no match different command", func(t *testing.T) {
f := msgfilter.Command("/stop")
require.False(t, f(cmdMsg("/start")))
})
t.Run("nil message", func(t *testing.T) {
require.False(t, msgfilter.Command("/start")(nil))
})
t.Run("no entities", func(t *testing.T) {
require.False(t, msgfilter.Command("/start")(msg("/start")))
})
}
func TestAnyCommand(t *testing.T) {
f := msgfilter.AnyCommand()
require.True(t, f(cmdMsg("/anything")))
require.False(t, f(msg("plain text")))
require.False(t, f(nil))
}
func TestIsReply(t *testing.T) {
f := msgfilter.IsReply()
m := msg("reply")
m.ReplyToMessage = &api.Message{MessageID: 2}
require.True(t, f(m))
require.False(t, f(msg("no reply")))
require.False(t, f(nil))
}
func TestIsForward(t *testing.T) {
// ForwardOrigin is a MessageOrigin interface; set via a concrete type.
f := msgfilter.IsForward()
m := msg("fwd")
m.ForwardOrigin = &api.MessageOriginUser{Type: "user"}
require.True(t, f(m))
require.False(t, f(msg("no fwd")))
require.False(t, f(nil))
}
func TestHasPhoto(t *testing.T) {
f := msgfilter.HasPhoto()
m := msg("")
m.Photo = []api.PhotoSize{{FileID: "x", Width: 100, Height: 100}}
require.True(t, f(m))
require.False(t, f(msg("no photo")))
require.False(t, f(nil))
}
func TestHasDocument(t *testing.T) {
f := msgfilter.HasDocument()
m := msg("")
m.Document = &api.Document{FileID: "doc1"}
require.True(t, f(m))
require.False(t, f(msg("no doc")))
require.False(t, f(nil))
}
func TestHasEntity(t *testing.T) {
f := msgfilter.HasEntity(string(api.EntityURL))
m := msg("check https://example.com")
m.Entities = []api.MessageEntity{{Type: string(api.EntityURL), Offset: 6, Length: 19}}
require.True(t, f(m))
require.False(t, f(msg("plain")))
require.False(t, f(nil))
}
func TestChatType(t *testing.T) {
f := msgfilter.ChatType(api.ChatTypePrivate)
private := msg("hi")
require.True(t, f(private))
group := msg("hi")
group.Chat.Type = string(api.ChatTypeGroup)
require.False(t, f(group))
require.False(t, f(nil))
}
func TestFromUser(t *testing.T) {
f := msgfilter.FromUser(42)
m := msg("hi")
m.From = &api.User{ID: 42}
require.True(t, f(m))
m2 := msg("hi")
m2.From = &api.User{ID: 99}
require.False(t, f(m2))
require.False(t, f(msg("no from")))
require.False(t, f(nil))
}
func TestInChat(t *testing.T) {
f := msgfilter.InChat(1)
require.True(t, f(msg("hi")))
m2 := msg("hi")
m2.Chat.ID = 2
require.False(t, f(m2))
require.False(t, f(nil))
}
func TestComposedMessageFilters(t *testing.T) {
// private chat AND contains "hello"
f := msgfilter.ChatType(api.ChatTypePrivate).And(msgfilter.TextContains("hello"))
m := msg("say hello")
require.True(t, f(m))
m2 := msg("say hello")
m2.Chat.Type = string(api.ChatTypeGroup)
require.False(t, f(m2))
}
@@ -0,0 +1,23 @@
// Package precheckoutquery provides Filter helpers for *api.PreCheckoutQuery payloads.
package precheckoutquery
import (
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/dispatch"
)
// Currency returns a Filter that matches pre-checkout queries with the given
// ISO 4217 currency code (e.g. "USD", "EUR", "XTR").
func Currency(c string) dispatch.Filter[*api.PreCheckoutQuery] {
return func(q *api.PreCheckoutQuery) bool {
return q != nil && q.Currency == c
}
}
// FromUser returns a Filter that matches pre-checkout queries sent by the
// user with the given ID.
func FromUser(uid int64) dispatch.Filter[*api.PreCheckoutQuery] {
return func(q *api.PreCheckoutQuery) bool {
return q != nil && q.From.ID == uid
}
}
@@ -0,0 +1,38 @@
package precheckoutquery_test
import (
"testing"
"github.com/lukaszraczylo/go-telegram/api"
pcqfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/precheckoutquery"
"github.com/stretchr/testify/require"
)
func pcq(currency string, fromID int64) *api.PreCheckoutQuery {
return &api.PreCheckoutQuery{
ID: "q",
Currency: currency,
From: api.User{ID: fromID},
}
}
func TestCurrency_Matches(t *testing.T) {
f := pcqfilter.Currency("USD")
require.True(t, f(pcq("USD", 1)))
require.False(t, f(pcq("EUR", 1)))
require.False(t, f(nil))
}
func TestFromUser_Matches(t *testing.T) {
f := pcqfilter.FromUser(5)
require.True(t, f(pcq("USD", 5)))
require.False(t, f(pcq("USD", 9)))
require.False(t, f(nil))
}
func TestComposedFilters(t *testing.T) {
f := pcqfilter.Currency("XTR").And(pcqfilter.FromUser(42))
require.True(t, f(pcq("XTR", 42)))
require.False(t, f(pcq("XTR", 99)))
require.False(t, f(pcq("USD", 42)))
}
+186
View File
@@ -0,0 +1,186 @@
package dispatch
import (
"errors"
"regexp"
"sort"
"github.com/lukaszraczylo/go-telegram/api"
)
// ErrEndGroups stops dispatch from running any further handlers in any
// group for this update when returned by a handler. Use it to indicate
// the update has been definitively handled.
//
// errors.Is(err, ErrEndGroups) is the canonical check, though dispatch
// itself recognises it by exact identity.
var ErrEndGroups = errors.New("dispatch: end groups")
// ErrContinueGroups signals that this group's handler should be treated
// as not-matching when returned by a handler: dispatch moves on to the
// next handler in the same group, then to subsequent groups.
//
// Without ErrContinueGroups, a non-error return from a matched handler
// stops dispatch (default first-match-wins semantics).
var ErrContinueGroups = errors.New("dispatch: continue groups")
// RouterScope registers handlers into a specific priority group on its parent
// Router. Group 0 runs first, then group 1, etc. Within a group, handlers run
// in registration order; the first non-skipped match terminates dispatch
// unless the handler returns ErrContinueGroups.
type RouterScope struct {
router *Router
group int
}
// Group returns a RouterScope that registers handlers in the given group.
// Group 0 (the default) runs first, then group 1, etc. Within a group,
// handlers run in registration order; the first non-skipped match
// terminates dispatch unless the handler returns ErrContinueGroups.
func (r *Router) Group(group int) *RouterScope {
return &RouterScope{router: r, group: group}
}
// OnCommand registers a command handler in this group.
func (s *RouterScope) OnCommand(cmd string, h Handler[*api.Message]) {
s.router.groupCommands = append(s.router.groupCommands, groupCommandRoute{
cmd: cmd, group: s.group, handler: h,
})
}
// OnText registers a regex text handler in this group.
// Panics at registration time if pattern is not a valid regular expression.
func (s *RouterScope) OnText(pattern string, h Handler[*api.Message]) {
s.router.groupTexts = append(s.router.groupTexts, groupTextRoute{
re: regexp.MustCompile(pattern), group: s.group, handler: h,
})
}
// OnMessageFilter registers a filter-based message handler in this group.
func (s *RouterScope) OnMessageFilter(f Filter[*api.Message], h Handler[*api.Message]) {
s.router.groupMessageFilters = append(s.router.groupMessageFilters, groupMessageFilterRoute{
filter: f, group: s.group, handler: h,
})
}
// group-aware route types
type groupCommandRoute struct {
cmd string
group int
handler Handler[*api.Message]
}
type groupTextRoute struct {
re *regexp.Regexp
group int
handler Handler[*api.Message]
}
type groupMessageFilterRoute struct {
filter Filter[*api.Message]
group int
handler Handler[*api.Message]
}
// dispatchGroups runs message handlers registered via RouterScope.Group().
// It collects all matching groups, sorts by group number, and applies
// first-match-wins semantics within each group. Handlers may return
// ErrContinueGroups (skip to next handler/group) or ErrEndGroups (stop all groups).
// A non-sentinel error stops dispatch and is returned to the caller.
func (r *Router) dispatchGroups(c *Context, m *api.Message) error {
// Collect group numbers present.
groupSet := map[int]struct{}{}
for _, gr := range r.groupCommands {
groupSet[gr.group] = struct{}{}
}
for _, gr := range r.groupTexts {
groupSet[gr.group] = struct{}{}
}
for _, gr := range r.groupMessageFilters {
groupSet[gr.group] = struct{}{}
}
if len(groupSet) == 0 {
return nil
}
groups := make([]int, 0, len(groupSet))
for g := range groupSet {
groups = append(groups, g)
}
sort.Ints(groups)
for _, g := range groups {
matched, err := r.runGroupHandlers(c, m, g)
if err != nil {
if errors.Is(err, ErrEndGroups) {
return nil
}
return err
}
if matched {
// First-match-wins: stop further groups.
return nil
}
// No match or ErrContinueGroups from all handlers: try next group.
}
return nil
}
// runGroupHandlers runs all handlers in group g against m, in registration
// order. Returns (true, nil) when a handler matched (returned nil). Returns
// (false, nil) when all handlers returned ErrContinueGroups. Returns
// (false, err) for ErrEndGroups or any non-sentinel error.
func (r *Router) runGroupHandlers(c *Context, m *api.Message, g int) (matched bool, err error) {
// Commands.
if cmd, args, ok := extractCommand(m); ok {
for _, route := range r.groupCommands {
if route.group != g || route.cmd != cmd {
continue
}
c.Values["command"] = cmd
c.Values["command_args"] = args
if err := route.handler(c, m); err != nil {
if errors.Is(err, ErrContinueGroups) {
continue
}
return false, err
}
return true, nil
}
}
// Text regex.
if m.Text != "" {
for _, route := range r.groupTexts {
if route.group != g {
continue
}
subs := route.re.FindStringSubmatch(m.Text)
if subs == nil {
continue
}
c.Values["regex_match"] = subs
if err := route.handler(c, m); err != nil {
if errors.Is(err, ErrContinueGroups) {
continue
}
return false, err
}
return true, nil
}
}
// Filter-based.
for _, route := range r.groupMessageFilters {
if route.group != g || !route.filter(m) {
continue
}
if err := route.handler(c, m); err != nil {
if errors.Is(err, ErrContinueGroups) {
continue
}
return false, err
}
return true, nil
}
return false, nil
}
+209
View File
@@ -0,0 +1,209 @@
package dispatch
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/client"
"github.com/stretchr/testify/require"
)
// msgUpdate builds a simple private message update.
func msgUpdate(id int64, text string) api.Update {
return api.Update{
UpdateID: id,
Message: &api.Message{
MessageID: id,
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: text,
},
}
}
// cmdUpdate builds a command message update.
func cmdUpdate(id int64, cmd string) api.Update {
return api.Update{
UpdateID: id,
Message: &api.Message{
MessageID: id,
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: cmd,
Entities: []api.MessageEntity{
{Type: string(api.EntityBotCommand), Offset: 0, Length: int64(len(cmd))},
},
},
}
}
// runSingle fires one update through the router and waits for it to complete.
func runSingle(t *testing.T, r *Router, up api.Update) {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(up))
}
// TestGroup_Order verifies group 0 fires before group 1.
func TestGroup_Order(t *testing.T) {
r := New(client.New("t"))
var order []int
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
order = append(order, 0)
return ErrContinueGroups // let group 1 also run
})
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
order = append(order, 1)
return nil
})
runSingle(t, r, msgUpdate(1, "hello"))
require.Equal(t, []int{0, 1}, order)
}
// TestGroup_FirstMatchWins verifies group 0 match stops group 1 by default.
func TestGroup_FirstMatchWins(t *testing.T) {
r := New(client.New("t"))
var fired []int
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
fired = append(fired, 0)
return nil // matched — group 1 must NOT run
})
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
fired = append(fired, 1)
return nil
})
runSingle(t, r, msgUpdate(1, "hello"))
require.Equal(t, []int{0}, fired)
}
// TestGroup_ErrContinueGroups lets group 1 run when group 0 returns ErrContinueGroups.
func TestGroup_ErrContinueGroups(t *testing.T) {
r := New(client.New("t"))
g1Hit := make(chan struct{}, 1)
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
return ErrContinueGroups
})
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
g1Hit <- struct{}{}
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(msgUpdate(1, "ping"))) }()
select {
case <-g1Hit:
case <-ctx.Done():
t.Fatal("group 1 handler never fired")
}
}
// TestGroup_ErrEndGroups stops all further groups.
func TestGroup_ErrEndGroups(t *testing.T) {
r := New(client.New("t"))
var fired []int
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
fired = append(fired, 0)
return ErrEndGroups
})
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
fired = append(fired, 1)
return nil
})
runSingle(t, r, msgUpdate(1, "hello"))
require.Equal(t, []int{0}, fired)
}
// TestGroup_NonSentinelError propagates error and stops further groups.
func TestGroup_NonSentinelError(t *testing.T) {
r := New(client.New("t"), WithMaxConcurrency(0))
var fired []int
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
fired = append(fired, 0)
return context.DeadlineExceeded // non-sentinel real error
})
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
fired = append(fired, 1)
return nil
})
runSingle(t, r, msgUpdate(1, "hello"))
// group 1 must not fire
require.Equal(t, []int{0}, fired)
}
// TestGroup_Command verifies OnCommand in a group works.
func TestGroup_Command(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.Group(0).OnCommand("/start", func(c *Context, m *api.Message) error {
hit <- "g0-start"
return nil
})
r.Group(1).OnCommand("/start", func(c *Context, m *api.Message) error {
hit <- "g1-start"
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(cmdUpdate(1, "/start"))) }()
got := <-hit
require.Equal(t, "g0-start", got)
}
// TestGroup_MessageFilter verifies OnMessageFilter in a group works.
func TestGroup_MessageFilter(t *testing.T) {
r := New(client.New("t"))
hit := make(chan bool, 1)
r.Group(0).OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return m != nil && m.Text == "ok" }),
func(c *Context, m *api.Message) error {
hit <- true
return nil
},
)
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(msgUpdate(1, "ok"))) }()
require.True(t, <-hit)
}
// TestGroup_ErrContinueGroups_WithCommand verifies ErrContinueGroups works for commands across groups.
func TestGroup_ErrContinueGroups_WithCommand(t *testing.T) {
r := New(client.New("t"))
var count atomic.Int32
r.Group(0).OnCommand("/ping", func(c *Context, m *api.Message) error {
count.Add(1)
return ErrContinueGroups
})
r.Group(1).OnCommand("/ping", func(c *Context, m *api.Message) error {
count.Add(10)
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(cmdUpdate(1, "/ping"))) }()
time.Sleep(100 * time.Millisecond)
cancel()
require.Equal(t, int32(11), count.Load())
}
+21
View File
@@ -0,0 +1,21 @@
package dispatch
// Handler is a generic handler over update payload type T. T is typically
// *api.Message, *api.CallbackQuery, *api.InlineQuery, or *api.Update for
// global middleware.
type Handler[T any] func(ctx *Context, payload T) error
// Middleware wraps a Handler[T] with cross-cutting behaviour (logging,
// recovery, auth). Middleware composition is left-to-right: Use(a,b,c)
// runs as a(b(c(handler))).
type Middleware[T any] func(Handler[T]) Handler[T]
// Chain composes a slice of middleware into a single Middleware[T].
func Chain[T any](mws ...Middleware[T]) Middleware[T] {
return func(h Handler[T]) Handler[T] {
for i := len(mws) - 1; i >= 0; i-- {
h = mws[i](h)
}
return h
}
}
+27
View File
@@ -0,0 +1,27 @@
package dispatch
import (
"fmt"
"runtime/debug"
"github.com/lukaszraczylo/go-telegram/api"
)
// Recovery returns middleware that recovers from panics in downstream
// handlers, converting them into a returned error and logging via the
// bot's configured logger. Registered automatically by NewRouter.
func Recovery() Middleware[*api.Update] {
return func(next Handler[*api.Update]) Handler[*api.Update] {
return func(c *Context, u *api.Update) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in handler: %v\n%s", r, debug.Stack())
if c.Bot != nil {
c.Bot.Logger().Error("dispatch recovered panic", "err", err)
}
}
}()
return next(c, u)
}
}
}
+101
View File
@@ -0,0 +1,101 @@
package dispatch
import (
"fmt"
"sync"
)
// NamedHandlers manages handlers by string name, allowing runtime
// registration, replacement, and removal. This complements the Router's
// registration methods: each registration via Named*() also gets a name
// for later lookup.
//
// Use case: a plugin system that loads/unloads command handlers without
// restarting the bot.
type NamedHandlers[T any] struct {
mu sync.RWMutex
handlers map[string]Handler[T]
order []string // preserves registration order
}
// NewNamedHandlers returns a new, empty NamedHandlers[T].
func NewNamedHandlers[T any]() *NamedHandlers[T] {
return &NamedHandlers[T]{handlers: map[string]Handler[T]{}}
}
// Set registers or replaces the handler under name. If name is new, it is
// appended to the end of the registration order.
func (n *NamedHandlers[T]) Set(name string, h Handler[T]) {
n.mu.Lock()
defer n.mu.Unlock()
if _, exists := n.handlers[name]; !exists {
n.order = append(n.order, name)
}
n.handlers[name] = h
}
// Remove unregisters the handler under name. Returns true if it existed.
func (n *NamedHandlers[T]) Remove(name string) bool {
n.mu.Lock()
defer n.mu.Unlock()
if _, ok := n.handlers[name]; !ok {
return false
}
delete(n.handlers, name)
for i, k := range n.order {
if k == name {
n.order = append(n.order[:i], n.order[i+1:]...)
break
}
}
return true
}
// Has reports whether name is registered.
func (n *NamedHandlers[T]) Has(name string) bool {
n.mu.RLock()
defer n.mu.RUnlock()
_, ok := n.handlers[name]
return ok
}
// Names returns the registered names in registration order.
func (n *NamedHandlers[T]) Names() []string {
n.mu.RLock()
defer n.mu.RUnlock()
out := make([]string, len(n.order))
copy(out, n.order)
return out
}
// Handler returns a single Handler[T] that runs each registered handler
// in registration order, first non-nil error stops the chain. Use this
// to wire NamedHandlers into a Router.OnXxx call:
//
// names := dispatch.NewNamedHandlers[*api.Message]()
// names.Set("logger", loggingHandler)
// names.Set("audit", auditHandler)
// router.OnCommand("/admin", names.Handler())
//
// Subsequent Set/Remove calls take effect on the next dispatch.
func (n *NamedHandlers[T]) Handler() Handler[T] {
return func(c *Context, payload T) error {
n.mu.RLock()
names := make([]string, len(n.order))
copy(names, n.order)
n.mu.RUnlock()
for _, name := range names {
n.mu.RLock()
h, ok := n.handlers[name]
n.mu.RUnlock()
if !ok {
continue
}
if err := h(c, payload); err != nil {
return fmt.Errorf("named handler %q: %w", name, err)
}
}
return nil
}
}
+153
View File
@@ -0,0 +1,153 @@
package dispatch
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/stretchr/testify/require"
)
// makeMsg returns a minimal *api.Message for use in handler tests.
func makeMsg() *api.Message {
return &api.Message{MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}}
}
// makeCtx returns a minimal *Context (nil bot is fine for unit tests).
func makeCtx() *Context {
return NewContext(context.Background(), nil, &api.Update{})
}
func TestNamedHandlers_SetAndHas(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
require.False(t, n.Has("a"))
n.Set("a", func(c *Context, m *api.Message) error { return nil })
require.True(t, n.Has("a"))
}
func TestNamedHandlers_Names_RegistrationOrder(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
n.Set("first", func(c *Context, m *api.Message) error { return nil })
n.Set("second", func(c *Context, m *api.Message) error { return nil })
n.Set("third", func(c *Context, m *api.Message) error { return nil })
require.Equal(t, []string{"first", "second", "third"}, n.Names())
}
func TestNamedHandlers_Remove(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
n.Set("a", func(c *Context, m *api.Message) error { return nil })
n.Set("b", func(c *Context, m *api.Message) error { return nil })
removed := n.Remove("a")
require.True(t, removed)
require.False(t, n.Has("a"))
require.Equal(t, []string{"b"}, n.Names())
// Remove non-existent returns false.
require.False(t, n.Remove("nonexistent"))
}
func TestNamedHandlers_Replacement_SameOrderSlot(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
n.Set("a", func(c *Context, m *api.Message) error { return nil })
n.Set("b", func(c *Context, m *api.Message) error { return nil })
var called string
n.Set("a", func(c *Context, m *api.Message) error {
called = "replaced-a"
return nil
})
// Order must not change; "a" stays first.
require.Equal(t, []string{"a", "b"}, n.Names())
h := n.Handler()
_ = h(makeCtx(), makeMsg())
require.Equal(t, "replaced-a", called)
}
func TestNamedHandlers_Handler_RunsInOrder(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
var calls []string
n.Set("first", func(c *Context, m *api.Message) error {
calls = append(calls, "first")
return nil
})
n.Set("second", func(c *Context, m *api.Message) error {
calls = append(calls, "second")
return nil
})
h := n.Handler()
require.NoError(t, h(makeCtx(), makeMsg()))
require.Equal(t, []string{"first", "second"}, calls)
}
func TestNamedHandlers_Handler_ErrorWrappedAndStops(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
sentinel := errors.New("boom")
n.Set("ok", func(c *Context, m *api.Message) error { return nil })
n.Set("fail", func(c *Context, m *api.Message) error { return sentinel })
n.Set("never", func(c *Context, m *api.Message) error {
t.Fatal("should not be called after an error")
return nil
})
h := n.Handler()
err := h(makeCtx(), makeMsg())
require.Error(t, err)
require.True(t, errors.Is(err, sentinel))
require.Contains(t, err.Error(), `named handler "fail"`)
}
func TestNamedHandlers_Concurrent_SetRemove(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
// Pre-populate so Handler() has something to iterate.
for i := range 5 {
name := fmt.Sprintf("h%d", i)
n.Set(name, func(c *Context, m *api.Message) error { return nil })
}
h := n.Handler()
var wg sync.WaitGroup
// Concurrent readers (invoke handler).
for range 20 {
wg.Add(1)
go func() {
defer wg.Done()
_ = h(makeCtx(), makeMsg())
}()
}
// Concurrent writers.
for i := range 5 {
wg.Add(1)
go func(i int) {
defer wg.Done()
name := fmt.Sprintf("new%d", i)
n.Set(name, func(c *Context, m *api.Message) error { return nil })
n.Remove(fmt.Sprintf("h%d", i))
}(i)
}
wg.Wait()
}
func TestNamedHandlers_RemoveAndReinstate(t *testing.T) {
n := NewNamedHandlers[*api.Message]()
n.Set("a", func(c *Context, m *api.Message) error { return nil })
n.Remove("a")
require.False(t, n.Has("a"))
// Re-register after removal; should be added at end.
n.Set("b", func(c *Context, m *api.Message) error { return nil })
n.Set("a", func(c *Context, m *api.Message) error { return nil })
require.Equal(t, []string{"b", "a"}, n.Names())
}
+582
View File
@@ -0,0 +1,582 @@
package dispatch
import (
"context"
"regexp"
"strings"
"sync"
"unicode/utf8"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/client"
"github.com/lukaszraczylo/go-telegram/transport"
)
// Router dispatches updates from any Updater to typed handlers.
//
// Matchers run in registration order; first match wins. A panic-recovery
// middleware is attached automatically and runs around every dispatch.
type Router struct {
bot *client.Bot
commands []commandRoute
texts []textRoute
callbacks []callbackRoute
inlines []Handler[*api.InlineQuery]
editedMsg []Handler[*api.Message]
channelPosts []Handler[*api.Message]
editedChannelPosts []Handler[*api.Message]
messageFilters []messageFilterRoute
callbackFilters []callbackFilterRoute
inlineFilters []inlineFilterRoute
// typed update handlers
myChatMember []Handler[*api.ChatMemberUpdated]
chatMember []Handler[*api.ChatMemberUpdated]
chatJoinRequest []Handler[*api.ChatJoinRequest]
preCheckoutQuery []Handler[*api.PreCheckoutQuery]
shippingQuery []Handler[*api.ShippingQuery]
poll []Handler[*api.Poll]
pollAnswer []Handler[*api.PollAnswer]
chosenInlineResult []Handler[*api.ChosenInlineResult]
messageReaction []Handler[*api.MessageReactionUpdated]
messageReactionCnt []Handler[*api.MessageReactionCountUpdated]
chatBoost []Handler[*api.ChatBoostUpdated]
removedChatBoost []Handler[*api.ChatBoostRemoved]
businessConn []Handler[*api.BusinessConnection]
purchasedPaidMedia []Handler[*api.PaidMediaPurchased]
myChatMemberFilters []chatMemberFilterRoute
chatMemberFilters []chatMemberFilterRoute
chatJoinRequestFilters []chatJoinRequestFilterRoute
preCheckoutFilters []preCheckoutFilterRoute
// group-priority routes (registered via Router.Group())
groupCommands []groupCommandRoute
groupTexts []groupTextRoute
groupMessageFilters []groupMessageFilterRoute
globalMW []Middleware[*api.Update]
maxConcurrency int // default 50; 0 = serial (legacy)
sem chan struct{}
}
type messageFilterRoute struct {
filter Filter[*api.Message]
handler Handler[*api.Message]
}
type callbackFilterRoute struct {
filter Filter[*api.CallbackQuery]
handler Handler[*api.CallbackQuery]
}
type inlineFilterRoute struct {
filter Filter[*api.InlineQuery]
handler Handler[*api.InlineQuery]
}
type chatMemberFilterRoute struct {
filter Filter[*api.ChatMemberUpdated]
handler Handler[*api.ChatMemberUpdated]
}
type chatJoinRequestFilterRoute struct {
filter Filter[*api.ChatJoinRequest]
handler Handler[*api.ChatJoinRequest]
}
type preCheckoutFilterRoute struct {
filter Filter[*api.PreCheckoutQuery]
handler Handler[*api.PreCheckoutQuery]
}
// RouterOption configures a Router at construction time.
type RouterOption func(*Router)
// WithMaxConcurrency sets the maximum number of updates processed in parallel.
// Default is 50. Pass 0 to dispatch serially (one update at a time, in the
// calling goroutine — the legacy behaviour before v1.1.0).
//
// Note: concurrent dispatch means handlers for different updates may run
// simultaneously. Handlers that mutate shared state must be safe for concurrent
// access.
func WithMaxConcurrency(n int) RouterOption {
return func(r *Router) { r.maxConcurrency = n }
}
type commandRoute struct {
cmd string
handler Handler[*api.Message]
}
type textRoute struct {
re *regexp.Regexp
handler Handler[*api.Message]
}
type callbackRoute struct {
re *regexp.Regexp
handler Handler[*api.CallbackQuery]
}
// New constructs a Router. Recovery middleware is added by default; users
// can disable it by passing WithoutRecovery (not implemented here, but
// the hook is in place via Use).
func New(b *client.Bot, opts ...RouterOption) *Router {
r := &Router{bot: b, maxConcurrency: 50}
for _, o := range opts {
o(r)
}
if r.maxConcurrency > 0 {
r.sem = make(chan struct{}, r.maxConcurrency)
}
r.Use(Recovery())
return r
}
// Use registers a global middleware applied to every Update dispatch.
func (r *Router) Use(mw Middleware[*api.Update]) { r.globalMW = append(r.globalMW, mw) }
// OnCommand registers a handler for a slash command. The command string
// includes the leading slash (e.g. "/start"). Matching strips an optional
// "@BotName" suffix.
func (r *Router) OnCommand(cmd string, h Handler[*api.Message]) {
r.commands = append(r.commands, commandRoute{cmd: cmd, handler: h})
}
// OnText registers a handler for messages whose Text matches the regex.
//
// Panics at registration time if pattern is not a valid regular expression.
func (r *Router) OnText(pattern string, h Handler[*api.Message]) {
r.texts = append(r.texts, textRoute{re: regexp.MustCompile(pattern), handler: h})
}
// OnCallback registers a handler for callback queries whose Data matches
// the regex.
//
// Panics at registration time if pattern is not a valid regular expression.
func (r *Router) OnCallback(pattern string, h Handler[*api.CallbackQuery]) {
r.callbacks = append(r.callbacks, callbackRoute{re: regexp.MustCompile(pattern), handler: h})
}
// OnInlineQuery registers a handler for inline queries (one matcher only;
// inline queries are not partitioned by content here).
func (r *Router) OnInlineQuery(h Handler[*api.InlineQuery]) {
r.inlines = append(r.inlines, h)
}
// OnEditedMessage registers a handler for edited message updates.
func (r *Router) OnEditedMessage(h Handler[*api.Message]) {
r.editedMsg = append(r.editedMsg, h)
}
// OnChannelPost registers a handler for channel post updates.
func (r *Router) OnChannelPost(h Handler[*api.Message]) {
r.channelPosts = append(r.channelPosts, h)
}
// OnEditedChannelPost registers a handler for edited channel post updates.
func (r *Router) OnEditedChannelPost(h Handler[*api.Message]) {
r.editedChannelPosts = append(r.editedChannelPosts, h)
}
// OnMessageFilter registers a typed message handler gated by filter f.
// Filter routes are checked after command and text routes; first match wins.
func (r *Router) OnMessageFilter(f Filter[*api.Message], h Handler[*api.Message]) {
r.messageFilters = append(r.messageFilters, messageFilterRoute{filter: f, handler: h})
}
// OnCallbackFilter registers a typed callback-query handler gated by filter f.
// Filter routes are checked after pattern-based OnCallback routes; first match wins.
func (r *Router) OnCallbackFilter(f Filter[*api.CallbackQuery], h Handler[*api.CallbackQuery]) {
r.callbackFilters = append(r.callbackFilters, callbackFilterRoute{filter: f, handler: h})
}
// OnInlineQueryFilter registers an inline-query handler gated by filter f.
// Filter routes are checked after bare OnInlineQuery handlers; first match wins.
func (r *Router) OnInlineQueryFilter(f Filter[*api.InlineQuery], h Handler[*api.InlineQuery]) {
r.inlineFilters = append(r.inlineFilters, inlineFilterRoute{filter: f, handler: h})
}
// OnMyChatMember registers a handler for bot's own chat member status changes.
func (r *Router) OnMyChatMember(h Handler[*api.ChatMemberUpdated]) {
r.myChatMember = append(r.myChatMember, h)
}
// OnMyChatMemberFilter registers a filtered handler for bot's own chat member status changes.
func (r *Router) OnMyChatMemberFilter(f Filter[*api.ChatMemberUpdated], h Handler[*api.ChatMemberUpdated]) {
r.myChatMemberFilters = append(r.myChatMemberFilters, chatMemberFilterRoute{filter: f, handler: h})
}
// OnChatMember registers a handler for chat member status changes.
func (r *Router) OnChatMember(h Handler[*api.ChatMemberUpdated]) {
r.chatMember = append(r.chatMember, h)
}
// OnChatMemberFilter registers a filtered handler for chat member status changes.
func (r *Router) OnChatMemberFilter(f Filter[*api.ChatMemberUpdated], h Handler[*api.ChatMemberUpdated]) {
r.chatMemberFilters = append(r.chatMemberFilters, chatMemberFilterRoute{filter: f, handler: h})
}
// OnChatJoinRequest registers a handler for chat join requests.
func (r *Router) OnChatJoinRequest(h Handler[*api.ChatJoinRequest]) {
r.chatJoinRequest = append(r.chatJoinRequest, h)
}
// OnChatJoinRequestFilter registers a filtered handler for chat join requests.
func (r *Router) OnChatJoinRequestFilter(f Filter[*api.ChatJoinRequest], h Handler[*api.ChatJoinRequest]) {
r.chatJoinRequestFilters = append(r.chatJoinRequestFilters, chatJoinRequestFilterRoute{filter: f, handler: h})
}
// OnPreCheckoutQuery registers a handler for pre-checkout queries.
func (r *Router) OnPreCheckoutQuery(h Handler[*api.PreCheckoutQuery]) {
r.preCheckoutQuery = append(r.preCheckoutQuery, h)
}
// OnPreCheckoutQueryFilter registers a filtered handler for pre-checkout queries.
func (r *Router) OnPreCheckoutQueryFilter(f Filter[*api.PreCheckoutQuery], h Handler[*api.PreCheckoutQuery]) {
r.preCheckoutFilters = append(r.preCheckoutFilters, preCheckoutFilterRoute{filter: f, handler: h})
}
// OnShippingQuery registers a handler for shipping queries.
func (r *Router) OnShippingQuery(h Handler[*api.ShippingQuery]) {
r.shippingQuery = append(r.shippingQuery, h)
}
// OnPoll registers a handler for poll state updates.
func (r *Router) OnPoll(h Handler[*api.Poll]) {
r.poll = append(r.poll, h)
}
// OnPollAnswer registers a handler for poll answer updates.
func (r *Router) OnPollAnswer(h Handler[*api.PollAnswer]) {
r.pollAnswer = append(r.pollAnswer, h)
}
// OnChosenInlineResult registers a handler for chosen inline results.
func (r *Router) OnChosenInlineResult(h Handler[*api.ChosenInlineResult]) {
r.chosenInlineResult = append(r.chosenInlineResult, h)
}
// OnMessageReaction registers a handler for message reaction updates.
func (r *Router) OnMessageReaction(h Handler[*api.MessageReactionUpdated]) {
r.messageReaction = append(r.messageReaction, h)
}
// OnMessageReactionCount registers a handler for anonymous message reaction count updates.
func (r *Router) OnMessageReactionCount(h Handler[*api.MessageReactionCountUpdated]) {
r.messageReactionCnt = append(r.messageReactionCnt, h)
}
// OnChatBoost registers a handler for chat boost updates.
func (r *Router) OnChatBoost(h Handler[*api.ChatBoostUpdated]) {
r.chatBoost = append(r.chatBoost, h)
}
// OnRemovedChatBoost registers a handler for removed chat boost updates.
func (r *Router) OnRemovedChatBoost(h Handler[*api.ChatBoostRemoved]) {
r.removedChatBoost = append(r.removedChatBoost, h)
}
// OnBusinessConnection registers a handler for business connection updates.
func (r *Router) OnBusinessConnection(h Handler[*api.BusinessConnection]) {
r.businessConn = append(r.businessConn, h)
}
// OnPurchasedPaidMedia registers a handler for purchased paid media updates.
func (r *Router) OnPurchasedPaidMedia(h Handler[*api.PaidMediaPurchased]) {
r.purchasedPaidMedia = append(r.purchasedPaidMedia, h)
}
// Run consumes the Updater and dispatches each update. It blocks until
// the Updater's channel is closed or ctx is cancelled.
//
// By default updates are processed concurrently (up to WithMaxConcurrency(50)
// goroutines). Handlers for different updates may therefore run simultaneously;
// shared state must be protected. Pass WithMaxConcurrency(0) to New to restore
// serial (legacy) behaviour.
//
// Run waits for all in-flight handlers to finish before returning.
func (r *Router) Run(ctx context.Context, u transport.Updater) error {
runErr := make(chan error, 1)
go func() { runErr <- u.Run(ctx) }()
root := r.dispatch
for i := len(r.globalMW) - 1; i >= 0; i-- {
root = r.globalMW[i](root)
}
var wg sync.WaitGroup
defer wg.Wait()
dispatch := func(up api.Update) {
c := NewContext(ctx, r.bot, &up)
if err := root(c, &up); err != nil {
if r.bot != nil {
r.bot.Logger().Error("dispatch handler error", "err", err, "update_id", up.UpdateID)
}
}
}
for {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-runErr:
return err
case up, ok := <-u.Updates():
if !ok {
// Channel closed; consume the run error if pending.
select {
case err := <-runErr:
return err
default:
}
return nil
}
if r.sem == nil {
// Serial mode (legacy / WithMaxConcurrency(0)).
dispatch(up)
continue
}
// Concurrent mode: acquire semaphore slot then launch goroutine.
select {
case r.sem <- struct{}{}:
case <-ctx.Done():
return ctx.Err()
}
wg.Add(1)
go func(up api.Update) {
defer func() {
<-r.sem
wg.Done()
}()
dispatch(up)
}(up)
}
}
}
func (r *Router) dispatch(c *Context, u *api.Update) error {
switch {
case u.Message != nil:
return r.handleMessage(c, u.Message)
case u.EditedMessage != nil:
return runHandlers(r.editedMsg, c, u.EditedMessage)
case u.ChannelPost != nil:
return runHandlers(r.channelPosts, c, u.ChannelPost)
case u.EditedChannelPost != nil:
return runHandlers(r.editedChannelPosts, c, u.EditedChannelPost)
case u.CallbackQuery != nil:
return r.handleCallback(c, u.CallbackQuery)
case u.InlineQuery != nil:
if err := runHandlers(r.inlines, c, u.InlineQuery); err != nil {
return err
}
for _, route := range r.inlineFilters {
if route.filter(u.InlineQuery) {
return route.handler(c, u.InlineQuery)
}
}
return nil
case u.MyChatMember != nil:
return r.handleChatMemberUpdate(c, u.MyChatMember, r.myChatMember, r.myChatMemberFilters)
case u.ChatMember != nil:
return r.handleChatMemberUpdate(c, u.ChatMember, r.chatMember, r.chatMemberFilters)
case u.ChatJoinRequest != nil:
return r.handleChatJoinRequest(c, u.ChatJoinRequest)
case u.PreCheckoutQuery != nil:
return r.handlePreCheckoutQuery(c, u.PreCheckoutQuery)
case u.ShippingQuery != nil:
return runHandlers(r.shippingQuery, c, u.ShippingQuery)
case u.Poll != nil:
return runHandlers(r.poll, c, u.Poll)
case u.PollAnswer != nil:
return runHandlers(r.pollAnswer, c, u.PollAnswer)
case u.ChosenInlineResult != nil:
return runHandlers(r.chosenInlineResult, c, u.ChosenInlineResult)
case u.MessageReaction != nil:
return runHandlers(r.messageReaction, c, u.MessageReaction)
case u.MessageReactionCount != nil:
return runHandlers(r.messageReactionCnt, c, u.MessageReactionCount)
case u.ChatBoost != nil:
return runHandlers(r.chatBoost, c, u.ChatBoost)
case u.RemovedChatBoost != nil:
return runHandlers(r.removedChatBoost, c, u.RemovedChatBoost)
case u.BusinessConnection != nil:
return runHandlers(r.businessConn, c, u.BusinessConnection)
case u.PurchasedPaidMedia != nil:
return runHandlers(r.purchasedPaidMedia, c, u.PurchasedPaidMedia)
}
return nil
}
func (r *Router) handleChatMemberUpdate(c *Context, payload *api.ChatMemberUpdated, handlers []Handler[*api.ChatMemberUpdated], filters []chatMemberFilterRoute) error {
if err := runHandlers(handlers, c, payload); err != nil {
return err
}
for _, route := range filters {
if route.filter(payload) {
return route.handler(c, payload)
}
}
return nil
}
func (r *Router) handleChatJoinRequest(c *Context, payload *api.ChatJoinRequest) error {
if err := runHandlers(r.chatJoinRequest, c, payload); err != nil {
return err
}
for _, route := range r.chatJoinRequestFilters {
if route.filter(payload) {
return route.handler(c, payload)
}
}
return nil
}
func (r *Router) handlePreCheckoutQuery(c *Context, payload *api.PreCheckoutQuery) error {
if err := runHandlers(r.preCheckoutQuery, c, payload); err != nil {
return err
}
for _, route := range r.preCheckoutFilters {
if route.filter(payload) {
return route.handler(c, payload)
}
}
return nil
}
// runHandlers invokes each handler in order; returns the first non-nil error.
func runHandlers[T any](handlers []Handler[T], c *Context, payload T) error {
for _, h := range handlers {
if err := h(c, payload); err != nil {
return err
}
}
return nil
}
func (r *Router) handleMessage(c *Context, m *api.Message) error {
// Try command first (entity-aware).
if cmd, args, ok := extractCommand(m); ok {
for _, route := range r.commands {
if route.cmd == cmd {
c.Values["command"] = cmd
c.Values["command_args"] = args
return route.handler(c, m)
}
}
}
// Then text regex matchers.
if m.Text != "" {
for _, route := range r.texts {
if subs := route.re.FindStringSubmatch(m.Text); subs != nil {
c.Values["regex_match"] = subs
return route.handler(c, m)
}
}
}
// Filter-based routes.
for _, route := range r.messageFilters {
if route.filter(m) {
return route.handler(c, m)
}
}
// Group-priority routes (registered via RouterScope.Group()).
return r.dispatchGroups(c, m)
}
func (r *Router) handleCallback(c *Context, q *api.CallbackQuery) error {
for _, route := range r.callbacks {
if subs := route.re.FindStringSubmatch(q.Data); subs != nil {
c.Values["regex_match"] = subs
return route.handler(c, q)
}
}
// Filter-based routes checked after pattern routes.
for _, route := range r.callbackFilters {
if route.filter(q) {
return route.handler(c, q)
}
}
return nil
}
// extractCommand returns the command (e.g. "/start") and the remaining
// argument string, when m carries a leading bot_command entity. It strips
// optional "@BotName" suffix on the command itself.
func extractCommand(m *api.Message) (cmd, args string, ok bool) {
if len(m.Entities) == 0 || m.Text == "" {
return "", "", false
}
first := m.Entities[0]
if first.Type != string(api.EntityBotCommand) || first.Offset != 0 {
return "", "", false
}
cmd, sliceOk := utf16Slice(m.Text, int(first.Offset), int(first.Length))
if !sliceOk {
return "", "", false
}
if i := strings.Index(cmd, "@"); i >= 0 {
cmd = cmd[:i]
}
end := int(first.Offset) + int(first.Length)
rest, _ := utf16Slice(m.Text, end, utf16Len(m.Text)-end)
args = strings.TrimSpace(rest)
return cmd, args, true
}
// utf16Slice returns the substring of s identified by a UTF-16 offset/length
// pair, as Telegram's MessageEntity uses. ok is false if the indices fall
// outside s's UTF-16 length.
func utf16Slice(s string, offset, length int) (string, bool) {
runes := []rune(s)
var startBytes, endBytes int
var u16 int
found := false
for i, r := range runes {
if u16 == offset {
startBytes = byteIndex(runes, i)
found = true
}
if u16 == offset+length {
endBytes = byteIndex(runes, i)
return s[startBytes:endBytes], true
}
if r > 0xFFFF {
u16 += 2
} else {
u16++
}
}
if found && u16 == offset+length {
return s[startBytes:], true
}
return "", false
}
func byteIndex(runes []rune, runeIdx int) int {
n := 0
for i := 0; i < runeIdx; i++ {
n += utf8.RuneLen(runes[i])
}
return n
}
func utf16Len(s string) int {
n := 0
for _, r := range s {
if r > 0xFFFF {
n += 2
} else {
n++
}
}
return n
}
+940
View File
@@ -0,0 +1,940 @@
package dispatch
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/client"
"github.com/stretchr/testify/require"
)
// fakeUpdater feeds a fixed slice of updates then closes.
type fakeUpdater struct{ ch chan api.Update }
func newFake(ups ...api.Update) *fakeUpdater {
ch := make(chan api.Update, len(ups))
for _, u := range ups {
ch <- u
}
close(ch)
return &fakeUpdater{ch: ch}
}
func (f *fakeUpdater) Updates() <-chan api.Update { return f.ch }
func (f *fakeUpdater) Run(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }
func (f *fakeUpdater) Stop(ctx context.Context) error { return nil }
func cmdMessage(text string) api.Update {
return api.Update{
UpdateID: 1,
Message: &api.Message{
MessageID: 1, Date: 0, Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: text,
Entities: []api.MessageEntity{{Type: string(api.EntityBotCommand), Offset: 0, Length: int64(indexEnd(text))}},
},
}
}
func indexEnd(text string) int {
for i, r := range text {
if r == ' ' {
return i
}
}
return len(text)
}
func TestRouter_OnCommandMatches(t *testing.T) {
b := client.New("t")
r := New(b)
hit := make(chan string, 1)
r.OnCommand("/start", func(c *Context, m *api.Message) error {
hit <- c.Values["command"].(string)
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(cmdMessage("/start"))) }()
require.Equal(t, "/start", <-hit)
}
func TestRouter_OnCommandStripsBotName(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnCommand("/start", func(c *Context, m *api.Message) error {
hit <- "matched"
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(cmdMessage("/start@MyBot hello"))) }()
require.Equal(t, "matched", <-hit)
}
func TestRouter_OnText(t *testing.T) {
r := New(client.New("t"))
hit := make(chan []string, 1)
r.OnText(`^hello (\w+)$`, func(c *Context, m *api.Message) error {
hit <- c.Values["regex_match"].([]string)
return nil
})
u := api.Update{UpdateID: 1, Message: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "hello world",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
subs := <-hit
require.Equal(t, "world", subs[1])
}
func TestRouter_OnCallback(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnCallback(`^like:(\d+)$`, func(c *Context, q *api.CallbackQuery) error {
hit <- q.Data
return nil
})
u := api.Update{UpdateID: 1, CallbackQuery: &api.CallbackQuery{
ID: "x", From: api.User{ID: 1}, ChatInstance: "y", Data: "like:42",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
require.Equal(t, "like:42", <-hit)
}
func TestRouter_NoMatch(t *testing.T) {
r := New(client.New("t"))
called := false
r.OnCommand("/start", func(c *Context, m *api.Message) error {
called = true
return nil
})
u := api.Update{UpdateID: 1, Message: &api.Message{Text: "no command"}}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(u))
require.False(t, called)
}
func TestRouter_PanicRecovery(t *testing.T) {
r := New(client.New("t"))
r.OnCommand("/boom", func(c *Context, m *api.Message) error {
panic("kaboom")
})
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
// Should not propagate panic to Run.
require.NotPanics(t, func() { _ = r.Run(ctx, newFake(cmdMessage("/boom"))) })
}
// TestRouter_NonASCIICommand verifies that UTF-16 entity offsets are used
// correctly when the command contains non-ASCII runes. "/старт" is 6 runes,
// each a BMP code point, so UTF-16 length == 6.
func TestRouter_NonASCIICommand(t *testing.T) {
const text = "/старт аргумент"
// "/старт" = 1 + 5 runes, all BMP → UTF-16 length 6
const cmdU16Len = int64(6)
u := api.Update{
UpdateID: 1,
Message: &api.Message{
MessageID: 1,
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: text,
Entities: []api.MessageEntity{
{Type: string(api.EntityBotCommand), Offset: 0, Length: cmdU16Len},
},
},
}
r := New(client.New("t"))
hit := make(chan [2]string, 1)
r.OnCommand("/старт", func(c *Context, m *api.Message) error {
hit <- [2]string{
c.Values["command"].(string),
c.Values["command_args"].(string),
}
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
got := <-hit
require.Equal(t, "/старт", got[0])
require.Equal(t, "аргумент", got[1])
}
// TestRouter_CommandValuesNotLeakedOnNoMatch verifies that c.Values["command"]
// is not set when a command entity is present but no route matches, so a
// subsequent text handler doesn't see stale values.
func TestRouter_CommandValuesNotLeakedOnNoMatch(t *testing.T) {
r := New(client.New("t"))
// Register a text handler that should fire as fallback.
leaked := make(chan bool, 1)
r.OnText(`.*`, func(c *Context, m *api.Message) error {
_, hasCmd := c.Values["command"]
leaked <- hasCmd
return nil
})
// No OnCommand registered, so the command entity won't match any route.
u := api.Update{UpdateID: 1, Message: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"},
Text: "/unknown",
Entities: []api.MessageEntity{{Type: string(api.EntityBotCommand), Offset: 0, Length: 8}},
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
require.False(t, <-leaked, "command value must not leak into text handler")
}
func TestRouter_MiddlewareOrder(t *testing.T) {
r := New(client.New("t"))
var order []string
r.Use(func(next Handler[*api.Update]) Handler[*api.Update] {
return func(c *Context, u *api.Update) error {
order = append(order, "before-1")
err := next(c, u)
order = append(order, "after-1")
return err
}
})
r.Use(func(next Handler[*api.Update]) Handler[*api.Update] {
return func(c *Context, u *api.Update) error {
order = append(order, "before-2")
err := next(c, u)
order = append(order, "after-2")
return err
}
})
r.OnCommand("/x", func(c *Context, m *api.Message) error {
order = append(order, "handler")
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(cmdMessage("/x")))
require.Equal(t,
[]string{"before-1", "before-2", "handler", "after-2", "after-1"},
order)
}
func TestRouter_OnChannelPost(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnChannelPost(func(c *Context, m *api.Message) error {
hit <- m.MessageID
return nil
})
u := api.Update{UpdateID: 1, ChannelPost: &api.Message{
MessageID: 99, Chat: api.Chat{ID: -100, Type: string(api.ChatTypeChannel)},
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
require.Equal(t, int64(99), <-hit)
}
func TestRouter_RunsAllHandlersForEditedMessage(t *testing.T) {
r := New(client.New("t"))
var hits []string
r.OnEditedMessage(func(c *Context, m *api.Message) error {
hits = append(hits, "first")
return nil
})
r.OnEditedMessage(func(c *Context, m *api.Message) error {
hits = append(hits, "second")
return nil
})
u := api.Update{UpdateID: 1, EditedMessage: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "edited",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(u))
require.Equal(t, []string{"first", "second"}, hits)
}
// ---------------------------------------------------------------------------
// Filter-route tests
// ---------------------------------------------------------------------------
func TestRouter_OnMessageFilter_Matches(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return m != nil && m.Text == "ping" }),
func(c *Context, m *api.Message) error { hit <- m.Text; return nil },
)
u := api.Update{UpdateID: 1, Message: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "ping",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
require.Equal(t, "ping", <-hit)
}
func TestRouter_OnMessageFilter_NoMatch(t *testing.T) {
r := New(client.New("t"))
called := false
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return false }),
func(c *Context, m *api.Message) error { called = true; return nil },
)
u := api.Update{UpdateID: 1, Message: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "any",
}}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(u))
require.False(t, called)
}
// Command routes must take priority over filter routes.
func TestRouter_OnMessageFilter_CommandWins(t *testing.T) {
r := New(client.New("t"))
var winner string
r.OnCommand("/start", func(c *Context, m *api.Message) error { winner = "command"; return nil })
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return true }),
func(c *Context, m *api.Message) error { winner = "filter"; return nil },
)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(cmdMessage("/start")))
require.Equal(t, "command", winner)
}
func TestRouter_OnCallbackFilter_Matches(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnCallbackFilter(
Filter[*api.CallbackQuery](func(q *api.CallbackQuery) bool { return q != nil && q.Data == "yes" }),
func(c *Context, q *api.CallbackQuery) error { hit <- q.Data; return nil },
)
u := api.Update{UpdateID: 1, CallbackQuery: &api.CallbackQuery{
ID: "x", From: api.User{ID: 1}, ChatInstance: "y", Data: "yes",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
require.Equal(t, "yes", <-hit)
}
// Pattern-based OnCallback wins over OnCallbackFilter when both match.
func TestRouter_OnCallbackFilter_PatternWins(t *testing.T) {
r := New(client.New("t"))
var winner string
r.OnCallback(`^yes$`, func(c *Context, q *api.CallbackQuery) error { winner = "pattern"; return nil })
r.OnCallbackFilter(
Filter[*api.CallbackQuery](func(q *api.CallbackQuery) bool { return true }),
func(c *Context, q *api.CallbackQuery) error { winner = "filter"; return nil },
)
u := api.Update{UpdateID: 1, CallbackQuery: &api.CallbackQuery{
ID: "x", From: api.User{ID: 1}, ChatInstance: "y", Data: "yes",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
_ = r.Run(ctx, newFake(u))
require.Equal(t, "pattern", winner)
}
func TestRouter_OnInlineQueryFilter_Matches(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnInlineQueryFilter(
Filter[*api.InlineQuery](func(q *api.InlineQuery) bool { return q != nil && q.Query == "find" }),
func(c *Context, q *api.InlineQuery) error { hit <- q.Query; return nil },
)
u := api.Update{UpdateID: 1, InlineQuery: &api.InlineQuery{
ID: "i", From: api.User{ID: 1}, Query: "find",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(u)) }()
require.Equal(t, "find", <-hit)
}
func TestRouter_FilterChain_Composition(t *testing.T) {
// Filter: private chat AND text contains "hello"
privateChat := Filter[*api.Message](func(m *api.Message) bool {
return m != nil && m.Chat.Type == string(api.ChatTypePrivate)
})
hasHello := Filter[*api.Message](func(m *api.Message) bool {
return m != nil && len(m.Text) > 0 && containsStr(m.Text, "hello")
})
combined := privateChat.And(hasHello)
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnMessageFilter(combined, func(c *Context, m *api.Message) error { hit <- m.Text; return nil })
match := api.Update{UpdateID: 1, Message: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)}, Text: "say hello",
}}
noMatch := api.Update{UpdateID: 2, Message: &api.Message{
MessageID: 2, Chat: api.Chat{ID: 2, Type: string(api.ChatTypeGroup)}, Text: "say hello",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(match, noMatch)) }()
require.Equal(t, "say hello", <-hit)
}
// containsStr is a helper to avoid importing strings in test file unnecessarily.
func containsStr(s, sub string) bool {
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsSubstr(s, sub))
}
func containsSubstr(s, sub string) bool {
for i := 0; i <= len(s)-len(sub); i++ {
if s[i:i+len(sub)] == sub {
return true
}
}
return false
}
// ---------------------------------------------------------------------------
// Concurrent dispatch tests
// ---------------------------------------------------------------------------
// fakeSlowUpdater feeds n updates then blocks until ctx cancel.
type fakeSlowUpdater struct {
ch chan api.Update
}
func newSlowFake(ups ...api.Update) *fakeSlowUpdater {
ch := make(chan api.Update, len(ups))
for _, u := range ups {
ch <- u
}
close(ch)
return &fakeSlowUpdater{ch: ch}
}
func (f *fakeSlowUpdater) Updates() <-chan api.Update { return f.ch }
func (f *fakeSlowUpdater) Run(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }
func (f *fakeSlowUpdater) Stop(ctx context.Context) error { return nil }
func TestRouter_ConcurrentDispatch_AllHandlersFire(t *testing.T) {
const n = 100
var fired atomic.Int64
ups := make([]api.Update, n)
for i := range ups {
ups[i] = api.Update{UpdateID: int64(i + 1), Message: &api.Message{
MessageID: int64(i + 1),
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: "hi",
}}
}
r := New(client.New("t"), WithMaxConcurrency(20))
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return true }),
func(c *Context, m *api.Message) error { fired.Add(1); return nil },
)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_ = r.Run(ctx, newSlowFake(ups...))
require.Equal(t, int64(n), fired.Load())
}
func TestRouter_ConcurrentDispatch_SemaphoreBoundsConcurrency(t *testing.T) {
const limit = 5
const n = 30
var inFlight atomic.Int64
var maxSeen atomic.Int64
ready := make(chan struct{}) // signals handler to proceed
started := make(chan struct{}) // first handler signals it's running
ups := make([]api.Update, n)
for i := range ups {
ups[i] = api.Update{UpdateID: int64(i + 1), Message: &api.Message{
MessageID: int64(i + 1),
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: "hi",
}}
}
once := atomic.Bool{}
r := New(client.New("t"), WithMaxConcurrency(limit))
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return true }),
func(c *Context, m *api.Message) error {
cur := inFlight.Add(1)
for {
old := maxSeen.Load()
if cur <= old || maxSeen.CompareAndSwap(old, cur) {
break
}
}
if once.CompareAndSwap(false, true) {
close(started)
}
<-ready
inFlight.Add(-1)
return nil
},
)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go func() { _ = r.Run(ctx, newSlowFake(ups...)) }()
select {
case <-started:
case <-ctx.Done():
t.Fatal("timed out waiting for first handler")
}
// Give the pool a moment to fill up.
time.Sleep(50 * time.Millisecond)
close(ready)
// Wait for Run to drain by cancelling context after a short wait.
time.Sleep(200 * time.Millisecond)
cancel()
require.LessOrEqual(t, maxSeen.Load(), int64(limit),
"in-flight goroutines exceeded semaphore limit")
}
func TestRouter_ConcurrentDispatch_WaitsForInFlight(t *testing.T) {
unblock := make(chan struct{})
done := make(chan struct{})
r := New(client.New("t"), WithMaxConcurrency(10))
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return true }),
func(c *Context, m *api.Message) error {
<-unblock
return nil
},
)
u := api.Update{UpdateID: 1, Message: &api.Message{
MessageID: 1, Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)}, Text: "hi",
}}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
go func() {
_ = r.Run(ctx, newSlowFake(u))
close(done)
}()
// Give Run time to pick up the update and launch the goroutine.
time.Sleep(30 * time.Millisecond)
cancel() // trigger Run to exit its loop
// Run should not return until handler unblocks.
select {
case <-done:
t.Fatal("Run returned before in-flight handler finished")
case <-time.After(50 * time.Millisecond):
}
close(unblock)
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("Run did not return after handler finished")
}
}
func TestRouter_SerialMode_NoRace(t *testing.T) {
// WithMaxConcurrency(0) — serial; shared slice is safe without a mutex.
var order []int64
const n = 20
ups := make([]api.Update, n)
for i := range ups {
ups[i] = api.Update{UpdateID: int64(i + 1), Message: &api.Message{
MessageID: int64(i + 1),
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: "hi",
}}
}
r := New(client.New("t"), WithMaxConcurrency(0))
r.OnMessageFilter(
Filter[*api.Message](func(m *api.Message) bool { return true }),
func(c *Context, m *api.Message) error {
order = append(order, m.MessageID)
return nil
},
)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = r.Run(ctx, newSlowFake(ups...))
require.Len(t, order, n)
for i, v := range order {
require.Equal(t, int64(i+1), v)
}
}
// liveUpdater is an updater whose channel stays open until stopCh is closed.
type liveUpdater struct {
ch chan api.Update
stopCh chan struct{}
}
func newLiveUpdater() *liveUpdater {
return &liveUpdater{ch: make(chan api.Update, 8), stopCh: make(chan struct{})}
}
func (l *liveUpdater) Send(u api.Update) { l.ch <- u }
func (l *liveUpdater) Close() { close(l.stopCh) }
func (l *liveUpdater) Updates() <-chan api.Update { return l.ch }
func (l *liveUpdater) Stop(ctx context.Context) error { return nil }
func (l *liveUpdater) Run(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-l.stopCh:
return nil
}
}
// ---------------------------------------------------------------------------
// Typed handler tests (Feature 1)
// ---------------------------------------------------------------------------
func TestRouter_OnMyChatMember(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnMyChatMember(func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.From.ID; return nil })
upd := api.Update{UpdateID: 1, MyChatMember: &api.ChatMemberUpdated{
From: api.User{ID: 42},
Chat: api.Chat{ID: 1},
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(42), <-hit)
}
func TestRouter_OnMyChatMemberFilter(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
f := Filter[*api.ChatMemberUpdated](func(u *api.ChatMemberUpdated) bool { return u.From.ID == 99 })
r.OnMyChatMemberFilter(f, func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.From.ID; return nil })
match := api.Update{UpdateID: 1, MyChatMember: &api.ChatMemberUpdated{From: api.User{ID: 99}}}
noMatch := api.Update{UpdateID: 2, MyChatMember: &api.ChatMemberUpdated{From: api.User{ID: 1}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(noMatch, match)) }()
require.Equal(t, int64(99), <-hit)
}
func TestRouter_OnChatMember(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnChatMember(func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, ChatMember: &api.ChatMemberUpdated{
From: api.User{ID: 1},
Chat: api.Chat{ID: 77},
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(77), <-hit)
}
func TestRouter_OnChatMemberFilter(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
f := Filter[*api.ChatMemberUpdated](func(u *api.ChatMemberUpdated) bool { return u.Chat.ID == 55 })
r.OnChatMemberFilter(f, func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, ChatMember: &api.ChatMemberUpdated{Chat: api.Chat{ID: 55}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(55), <-hit)
}
func TestRouter_OnChatJoinRequest(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnChatJoinRequest(func(c *Context, req *api.ChatJoinRequest) error { hit <- req.From.ID; return nil })
upd := api.Update{UpdateID: 1, ChatJoinRequest: &api.ChatJoinRequest{
From: api.User{ID: 11},
Chat: api.Chat{ID: 1},
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(11), <-hit)
}
func TestRouter_OnChatJoinRequestFilter(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
f := Filter[*api.ChatJoinRequest](func(req *api.ChatJoinRequest) bool { return req.Chat.ID == 22 })
r.OnChatJoinRequestFilter(f, func(c *Context, req *api.ChatJoinRequest) error { hit <- req.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, ChatJoinRequest: &api.ChatJoinRequest{
From: api.User{ID: 1},
Chat: api.Chat{ID: 22},
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(22), <-hit)
}
func TestRouter_OnPreCheckoutQuery(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnPreCheckoutQuery(func(c *Context, q *api.PreCheckoutQuery) error { hit <- q.Currency; return nil })
upd := api.Update{UpdateID: 1, PreCheckoutQuery: &api.PreCheckoutQuery{
ID: "q1", From: api.User{ID: 1}, Currency: "USD",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "USD", <-hit)
}
func TestRouter_OnPreCheckoutQueryFilter(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
f := Filter[*api.PreCheckoutQuery](func(q *api.PreCheckoutQuery) bool { return q.Currency == "EUR" })
r.OnPreCheckoutQueryFilter(f, func(c *Context, q *api.PreCheckoutQuery) error { hit <- q.Currency; return nil })
upd := api.Update{UpdateID: 1, PreCheckoutQuery: &api.PreCheckoutQuery{
ID: "q1", From: api.User{ID: 1}, Currency: "EUR",
}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "EUR", <-hit)
}
func TestRouter_OnShippingQuery(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnShippingQuery(func(c *Context, q *api.ShippingQuery) error { hit <- q.ID; return nil })
upd := api.Update{UpdateID: 1, ShippingQuery: &api.ShippingQuery{ID: "sq1", From: api.User{ID: 1}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "sq1", <-hit)
}
func TestRouter_OnPoll(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnPoll(func(c *Context, p *api.Poll) error { hit <- p.ID; return nil })
upd := api.Update{UpdateID: 1, Poll: &api.Poll{ID: "poll1"}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "poll1", <-hit)
}
func TestRouter_OnPollAnswer(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnPollAnswer(func(c *Context, a *api.PollAnswer) error { hit <- a.PollID; return nil })
upd := api.Update{UpdateID: 1, PollAnswer: &api.PollAnswer{PollID: "p1", OptionIds: []int64{0}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "p1", <-hit)
}
func TestRouter_OnChosenInlineResult(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnChosenInlineResult(func(c *Context, res *api.ChosenInlineResult) error { hit <- res.ResultID; return nil })
upd := api.Update{UpdateID: 1, ChosenInlineResult: &api.ChosenInlineResult{ResultID: "r1", From: api.User{ID: 1}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "r1", <-hit)
}
func TestRouter_OnMessageReaction(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnMessageReaction(func(c *Context, u *api.MessageReactionUpdated) error { hit <- u.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, MessageReaction: &api.MessageReactionUpdated{Chat: api.Chat{ID: 33}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(33), <-hit)
}
func TestRouter_OnMessageReactionCount(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnMessageReactionCount(func(c *Context, u *api.MessageReactionCountUpdated) error { hit <- u.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, MessageReactionCount: &api.MessageReactionCountUpdated{Chat: api.Chat{ID: 44}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(44), <-hit)
}
func TestRouter_OnChatBoost(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnChatBoost(func(c *Context, u *api.ChatBoostUpdated) error { hit <- u.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, ChatBoost: &api.ChatBoostUpdated{Chat: api.Chat{ID: 55}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(55), <-hit)
}
func TestRouter_OnRemovedChatBoost(t *testing.T) {
r := New(client.New("t"))
hit := make(chan int64, 1)
r.OnRemovedChatBoost(func(c *Context, u *api.ChatBoostRemoved) error { hit <- u.Chat.ID; return nil })
upd := api.Update{UpdateID: 1, RemovedChatBoost: &api.ChatBoostRemoved{Chat: api.Chat{ID: 66}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, int64(66), <-hit)
}
func TestRouter_OnBusinessConnection(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnBusinessConnection(func(c *Context, bc *api.BusinessConnection) error { hit <- bc.ID; return nil })
upd := api.Update{UpdateID: 1, BusinessConnection: &api.BusinessConnection{ID: "bc1", UserChatID: 1, User: api.User{ID: 1}}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "bc1", <-hit)
}
func TestRouter_OnPurchasedPaidMedia(t *testing.T) {
r := New(client.New("t"))
hit := make(chan string, 1)
r.OnPurchasedPaidMedia(func(c *Context, p *api.PaidMediaPurchased) error { hit <- p.PaidMediaPayload; return nil })
upd := api.Update{UpdateID: 1, PurchasedPaidMedia: &api.PaidMediaPurchased{From: api.User{ID: 1}, PaidMediaPayload: "payload1"}}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() { _ = r.Run(ctx, newFake(upd)) }()
require.Equal(t, "payload1", <-hit)
}
func TestRouter_ContextCancel_UnblocksWaitingAcquire(t *testing.T) {
// Fill the semaphore with slow handlers, send one more update, then cancel
// ctx. Run must unblock from the semaphore-acquire select and return.
const limit = 2
unblock := make(chan struct{})
slowHandler := func(c *Context, m *api.Message) error {
<-unblock
return nil
}
lu := newLiveUpdater()
r := New(client.New("t"), WithMaxConcurrency(limit))
r.OnMessageFilter(Filter[*api.Message](func(m *api.Message) bool { return true }), slowHandler)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
runDone := make(chan error, 1)
go func() { runDone <- r.Run(ctx, lu) }()
// Send enough updates to fill semaphore.
for i := range limit {
lu.Send(api.Update{UpdateID: int64(i + 1), Message: &api.Message{
MessageID: int64(i + 1),
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: "hi",
}})
}
// Give goroutines time to acquire all semaphore slots.
time.Sleep(50 * time.Millisecond)
// Send one more update — Run will block trying to acquire the full semaphore.
lu.Send(api.Update{UpdateID: int64(limit + 1), Message: &api.Message{
MessageID: int64(limit + 1),
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
Text: "extra",
}})
// Give Run a moment to reach the semaphore-acquire select.
time.Sleep(30 * time.Millisecond)
cancel()
// Unblock handlers so wg.Wait() inside Run can complete, allowing Run to
// return (and write to runDone).
close(unblock)
select {
case err := <-runDone:
require.Error(t, err)
case <-time.After(2 * time.Second):
t.Fatal("Run did not unblock after context cancel")
}
}