mirror of
https://github.com/lukaszraczylo/go-telegram.git
synced 2026-06-05 22:43:59 +00:00
Initial release of go-telegram
A fully-generated, strongly-typed Go client for the Telegram Bot API. * 176 methods + 301 types generated from Bot API v10.0 * 1408 auto-generated tests (8 scenarios per method) * Typed unions throughout — no 'any' in the public surface * Pluggable HTTP transport and JSON codec (default goccy/go-json) * Built-in retry middleware honouring Telegram's retry_after * Generic dispatcher with filters and conversation handlers * Self-verifying codegen pipeline (regen → audit → emit → run tests) * 14 example bots covering common patterns
This commit is contained in:
@@ -0,0 +1,40 @@
|
||||
// Package dispatch provides a typed router for Telegram updates. It
|
||||
// consumes any transport.Updater and dispatches updates to handlers
|
||||
// registered by command, regex, or update-payload kind.
|
||||
package dispatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/client"
|
||||
)
|
||||
|
||||
// Context bundles the per-update state every handler receives.
|
||||
//
|
||||
// Ctx is the request context propagated from Router.Run; cancelling the
|
||||
// run cancels every handler.
|
||||
//
|
||||
// Bot is the API client. Handlers reply by calling api.SendMessage(c.Ctx,
|
||||
// c.Bot, ...) etc.
|
||||
//
|
||||
// Update is the raw update; payload-typed handlers also receive a
|
||||
// narrowed pointer to one of its sub-fields.
|
||||
//
|
||||
// Values is a per-update bag matchers populate. Conventional keys:
|
||||
//
|
||||
// "command": string, the matched bot command (e.g. "/start")
|
||||
// "command_args": string, everything after the command
|
||||
// "regex_match": []string, regex sub-matches when OnText matches
|
||||
type Context struct {
|
||||
Ctx context.Context
|
||||
Bot *client.Bot
|
||||
Update *api.Update
|
||||
Values map[string]any
|
||||
}
|
||||
|
||||
// NewContext constructs a Context. Used by Router internally; exposed for
|
||||
// custom test harnesses.
|
||||
func NewContext(ctx context.Context, b *client.Bot, u *api.Update) *Context {
|
||||
return &Context{Ctx: ctx, Bot: b, Update: u, Values: map[string]any{}}
|
||||
}
|
||||
@@ -0,0 +1,494 @@
|
||||
package conversation_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/client"
|
||||
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||
"github.com/lukaszraczylo/go-telegram/dispatch/conversation"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---- helpers ---------------------------------------------------------------
|
||||
|
||||
func msgUpd(userID, chatID int64, text string) api.Update {
|
||||
return api.Update{
|
||||
UpdateID: 1,
|
||||
Message: &api.Message{
|
||||
MessageID: 1,
|
||||
From: &api.User{ID: userID},
|
||||
Chat: api.Chat{ID: chatID},
|
||||
Text: text,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func makeCtx(u *api.Update) *dispatch.Context {
|
||||
return dispatch.NewContext(context.Background(), client.New("t"), u)
|
||||
}
|
||||
|
||||
// anyMsg matches any update that has a Message.
|
||||
var anyMsg = func(u *api.Update) bool { return u.Message != nil }
|
||||
|
||||
// hasPrefix returns a filter matching updates whose Message.Text has prefix p.
|
||||
func hasPrefix(p string) dispatch.Filter[*api.Update] {
|
||||
return func(u *api.Update) bool {
|
||||
return u.Message != nil && strings.HasPrefix(u.Message.Text, p)
|
||||
}
|
||||
}
|
||||
|
||||
// fakeUpdater feeds a fixed set of updates then closes (mirrors router_test.go).
|
||||
type fakeUpdater struct{ ch chan api.Update }
|
||||
|
||||
func newFake(ups ...api.Update) *fakeUpdater {
|
||||
ch := make(chan api.Update, len(ups))
|
||||
for _, u := range ups {
|
||||
ch <- u
|
||||
}
|
||||
close(ch)
|
||||
return &fakeUpdater{ch: ch}
|
||||
}
|
||||
|
||||
func (f *fakeUpdater) Updates() <-chan api.Update { return f.ch }
|
||||
func (f *fakeUpdater) Run(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }
|
||||
func (f *fakeUpdater) Stop(ctx context.Context) error { return nil }
|
||||
|
||||
// ---- Storage tests ---------------------------------------------------------
|
||||
|
||||
func TestStorage_ErrKeyNotFound(t *testing.T) {
|
||||
s := conversation.NewMemoryStorage()
|
||||
_, err := s.Get(context.Background(), "missing")
|
||||
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
|
||||
}
|
||||
|
||||
func TestStorage_SetAndGet(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := conversation.NewMemoryStorage()
|
||||
require.NoError(t, s.Set(ctx, "k", "state_a"))
|
||||
v, err := s.Get(ctx, "k")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conversation.State("state_a"), v)
|
||||
}
|
||||
|
||||
func TestStorage_Delete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := conversation.NewMemoryStorage()
|
||||
require.NoError(t, s.Set(ctx, "k", "state_a"))
|
||||
require.NoError(t, s.Delete(ctx, "k"))
|
||||
_, err := s.Get(ctx, "k")
|
||||
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
|
||||
}
|
||||
|
||||
func TestStorage_DeleteNonExistentIsNoop(t *testing.T) {
|
||||
require.NoError(t, conversation.NewMemoryStorage().Delete(context.Background(), "gone"))
|
||||
}
|
||||
|
||||
// ---- Key strategy tests ----------------------------------------------------
|
||||
|
||||
func TestKeyByUser_Variants(t *testing.T) {
|
||||
t.Run("message", func(t *testing.T) {
|
||||
u := msgUpd(42, 100, "hi")
|
||||
require.Equal(t, "u:42", conversation.KeyByUser(&u))
|
||||
})
|
||||
t.Run("edited_message", func(t *testing.T) {
|
||||
u := api.Update{EditedMessage: &api.Message{From: &api.User{ID: 7}, Chat: api.Chat{ID: 1}}}
|
||||
require.Equal(t, "u:7", conversation.KeyByUser(&u))
|
||||
})
|
||||
t.Run("callback_query", func(t *testing.T) {
|
||||
u := api.Update{CallbackQuery: &api.CallbackQuery{From: api.User{ID: 99}}}
|
||||
require.Equal(t, "u:99", conversation.KeyByUser(&u))
|
||||
})
|
||||
t.Run("inline_query", func(t *testing.T) {
|
||||
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
|
||||
require.Equal(t, "u:5", conversation.KeyByUser(&u))
|
||||
})
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
require.Equal(t, "", conversation.KeyByUser(&api.Update{}))
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyByChat_Variants(t *testing.T) {
|
||||
t.Run("message", func(t *testing.T) {
|
||||
u := msgUpd(1, 200, "")
|
||||
require.Equal(t, "c:200", conversation.KeyByChat(&u))
|
||||
})
|
||||
t.Run("inline_has_no_chat", func(t *testing.T) {
|
||||
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
|
||||
require.Equal(t, "", conversation.KeyByChat(&u))
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyByUserAndChat(t *testing.T) {
|
||||
u := msgUpd(42, 100, "")
|
||||
require.Equal(t, "uc:100:42", conversation.KeyByUserAndChat(&u))
|
||||
}
|
||||
|
||||
// ---- Handler / state machine tests -----------------------------------------
|
||||
|
||||
func buildConv() *conversation.Conversation {
|
||||
return &conversation.Conversation{
|
||||
EntryPoints: []conversation.Step{{
|
||||
Filter: hasPrefix("/start"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
return conversation.Next("await_name")
|
||||
},
|
||||
}},
|
||||
States: map[conversation.State][]conversation.Step{
|
||||
"await_name": {{
|
||||
Filter: anyMsg,
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("await_age") },
|
||||
}},
|
||||
"await_age": {{
|
||||
Filter: anyMsg,
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
|
||||
}},
|
||||
},
|
||||
Exits: []conversation.Step{{
|
||||
Filter: hasPrefix("/cancel"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestConversation_FullFlow(t *testing.T) {
|
||||
conv := buildConv()
|
||||
|
||||
var downstream int
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
|
||||
downstream++
|
||||
return nil
|
||||
})
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
key := "uc:1:42"
|
||||
|
||||
// 1. /start → enters, state = await_name
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
v, err := conv.Storage.Get(context.Background(), key)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conversation.State("await_name"), v)
|
||||
require.Equal(t, 0, downstream, "entry claimed update")
|
||||
|
||||
// 2. name → state = await_age
|
||||
u2 := msgUpd(42, 1, "Alice")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
v, err = conv.Storage.Get(context.Background(), key)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conversation.State("await_age"), v)
|
||||
|
||||
// 3. age → End, key deleted
|
||||
u3 := msgUpd(42, 1, "30")
|
||||
require.NoError(t, mw(makeCtx(&u3), &u3))
|
||||
_, err = conv.Storage.Get(context.Background(), key)
|
||||
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
|
||||
}
|
||||
|
||||
func TestConversation_ExitsCancelMidFlow(t *testing.T) {
|
||||
conv := buildConv()
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
// Start conversation.
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
_, err := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cancel mid-flow.
|
||||
u2 := msgUpd(42, 1, "/cancel")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
_, err = conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.ErrorIs(t, err, conversation.ErrKeyNotFound, "exit should clear state")
|
||||
}
|
||||
|
||||
func TestConversation_FallbackFiresWhenNoStateStepMatches(t *testing.T) {
|
||||
fallbackHit := false
|
||||
conv := &conversation.Conversation{
|
||||
EntryPoints: []conversation.Step{{
|
||||
Filter: hasPrefix("/start"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("waiting") },
|
||||
}},
|
||||
States: map[conversation.State][]conversation.Step{
|
||||
// No steps for "waiting" that match a callback query.
|
||||
"waiting": {},
|
||||
},
|
||||
Fallbacks: []conversation.Step{{
|
||||
Filter: anyMsg,
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
fallbackHit = true
|
||||
return nil
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
|
||||
u2 := msgUpd(42, 1, "unexpected text")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
require.True(t, fallbackHit, "fallback should have fired")
|
||||
}
|
||||
|
||||
func TestConversation_NoActiveConv_PassesToDownstream(t *testing.T) {
|
||||
conv := buildConv()
|
||||
downstreamHit := false
|
||||
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
|
||||
downstreamHit = true
|
||||
return nil
|
||||
})
|
||||
mw := conv.Dispatch(downstream)
|
||||
|
||||
// Random message that doesn't match /start
|
||||
u := msgUpd(42, 1, "hello")
|
||||
require.NoError(t, mw(makeCtx(&u), &u))
|
||||
require.True(t, downstreamHit, "unmatched update should reach downstream")
|
||||
}
|
||||
|
||||
func TestConversation_EmptyKey_PassesThrough(t *testing.T) {
|
||||
// InlineQuery has no chatID → KeyByUserAndChat returns "" → pass through.
|
||||
conv := buildConv()
|
||||
downstreamHit := false
|
||||
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
|
||||
downstreamHit = true
|
||||
return nil
|
||||
})
|
||||
mw := conv.Dispatch(downstream)
|
||||
|
||||
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
|
||||
require.NoError(t, mw(makeCtx(&u), &u))
|
||||
require.True(t, downstreamHit)
|
||||
}
|
||||
|
||||
func TestConversation_AllowReEntry(t *testing.T) {
|
||||
conv := buildConv()
|
||||
conv.AllowReEntry = true
|
||||
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
// Start.
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.Equal(t, conversation.State("await_name"), v)
|
||||
|
||||
// Advance once.
|
||||
u2 := msgUpd(42, 1, "Alice")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
v, _ = conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.Equal(t, conversation.State("await_age"), v)
|
||||
|
||||
// Re-enter with /start — should restart to await_name even though mid-flow.
|
||||
u3 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u3), &u3))
|
||||
v, _ = conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.Equal(t, conversation.State("await_name"), v, "AllowReEntry should restart")
|
||||
}
|
||||
|
||||
func TestConversation_NoReEntry_EntryIgnoredWhenActive(t *testing.T) {
|
||||
conv := buildConv()
|
||||
conv.AllowReEntry = false
|
||||
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
// Start.
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
|
||||
// Advance to await_age.
|
||||
u2 := msgUpd(42, 1, "Alice")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.Equal(t, conversation.State("await_age"), v)
|
||||
|
||||
// /start again — should NOT restart; state should stay await_age since
|
||||
// /start matches the state step filter (anyMsg) and advances.
|
||||
// Actually /start is handled by "await_age" anyMsg step → End().
|
||||
u3 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u3), &u3))
|
||||
// State ended (End() called by await_age step).
|
||||
_, err := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.ErrorIs(t, err, conversation.ErrKeyNotFound, "state step should have consumed /start when AllowReEntry=false")
|
||||
}
|
||||
|
||||
func TestConversation_StayInState_NilReturn(t *testing.T) {
|
||||
// Handler returning nil keeps state unchanged.
|
||||
stored := false
|
||||
conv := &conversation.Conversation{
|
||||
EntryPoints: []conversation.Step{{
|
||||
Filter: hasPrefix("/start"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
return conversation.Next("waiting")
|
||||
},
|
||||
}},
|
||||
States: map[conversation.State][]conversation.Step{
|
||||
"waiting": {{
|
||||
Filter: anyMsg,
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
stored = true
|
||||
return nil // stay in current state
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
|
||||
u2 := msgUpd(42, 1, "something")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
require.True(t, stored)
|
||||
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.Equal(t, conversation.State("waiting"), v, "nil return should leave state unchanged")
|
||||
}
|
||||
|
||||
func TestConversation_ActiveNoMatch_Swallows(t *testing.T) {
|
||||
// Active conversation with no matching state step and no fallback:
|
||||
// update is swallowed (not passed downstream).
|
||||
conv := &conversation.Conversation{
|
||||
EntryPoints: []conversation.Step{{
|
||||
Filter: hasPrefix("/start"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("waiting") },
|
||||
}},
|
||||
States: map[conversation.State][]conversation.Step{
|
||||
"waiting": {{
|
||||
// Only matches /done specifically.
|
||||
Filter: hasPrefix("/done"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
downstreamHit := false
|
||||
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
|
||||
downstreamHit = true
|
||||
return nil
|
||||
})
|
||||
mw := conv.Dispatch(downstream)
|
||||
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
|
||||
// Random text doesn't match /done and there's no fallback → swallowed.
|
||||
u2 := msgUpd(42, 1, "random")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
require.False(t, downstreamHit, "active conv with no matching step should swallow")
|
||||
}
|
||||
|
||||
// ---- Via Router.Run --------------------------------------------------------
|
||||
|
||||
func TestConversation_ViaRouter(t *testing.T) {
|
||||
var steps atomic.Int32
|
||||
conv := &conversation.Conversation{
|
||||
EntryPoints: []conversation.Step{{
|
||||
Filter: hasPrefix("/start"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
steps.Add(1)
|
||||
return conversation.Next("await_name")
|
||||
},
|
||||
}},
|
||||
States: map[conversation.State][]conversation.Step{
|
||||
"await_name": {{
|
||||
Filter: anyMsg,
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
steps.Add(1)
|
||||
return conversation.Next("await_age")
|
||||
},
|
||||
}},
|
||||
"await_age": {{
|
||||
Filter: anyMsg,
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
steps.Add(1)
|
||||
return conversation.End()
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
router := dispatch.New(client.New("t"), dispatch.WithMaxConcurrency(0)) // serial
|
||||
router.Use(conv.Dispatch)
|
||||
|
||||
ups := []api.Update{
|
||||
msgUpd(42, 1, "/start"),
|
||||
msgUpd(42, 1, "Alice"),
|
||||
msgUpd(42, 1, "30"),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- router.Run(ctx, newFake(ups...)) }()
|
||||
|
||||
// Wait for updater channel to drain (Run returns when closed).
|
||||
err := <-errCh
|
||||
if err != nil && err != context.Canceled {
|
||||
t.Fatalf("Run error: %v", err)
|
||||
}
|
||||
|
||||
require.Equal(t, int32(3), steps.Load(), "all three steps should have fired")
|
||||
}
|
||||
|
||||
// ---- Concurrent storage safety ---------------------------------------------
|
||||
|
||||
func TestConversation_ConcurrentStorageAccess(t *testing.T) {
|
||||
// 15 goroutines each running a full /start → name → age flow against the
|
||||
// same shared storage but DIFFERENT keys (one per goroutine). Validates
|
||||
// no data races.
|
||||
const numUsers = 15
|
||||
|
||||
conv := buildConv()
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numUsers)
|
||||
for i := 0; i < numUsers; i++ {
|
||||
go func(uid int64) {
|
||||
defer wg.Done()
|
||||
u1 := msgUpd(uid, uid, "/start")
|
||||
_ = mw(makeCtx(&u1), &u1)
|
||||
u2 := msgUpd(uid, uid, "Alice")
|
||||
_ = mw(makeCtx(&u2), &u2)
|
||||
u3 := msgUpd(uid, uid, "30")
|
||||
_ = mw(makeCtx(&u3), &u3)
|
||||
}(int64(i + 1))
|
||||
}
|
||||
wg.Wait()
|
||||
// Race detector catches bugs; no assertion needed beyond clean finish.
|
||||
}
|
||||
|
||||
func TestConversation_ConcurrentSameKey(t *testing.T) {
|
||||
// 12 goroutines hammer the same key concurrently. Storage must not panic
|
||||
// or corrupt state. Race detector validates lock discipline.
|
||||
const goroutines = 12
|
||||
s := conversation.NewMemoryStorage()
|
||||
ctx := context.Background()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
_ = s.Set(ctx, "shared", conversation.State("step"))
|
||||
_, _ = s.Get(ctx, "shared")
|
||||
if i%4 == 0 {
|
||||
_ = s.Delete(ctx, "shared")
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||
)
|
||||
|
||||
// stateTransition is a sentinel error type carrying a state transition
|
||||
// or end signal. Conversation handlers return one of these (via Next or
|
||||
// End helpers below) to drive the state machine.
|
||||
type stateTransition struct {
|
||||
next State
|
||||
end bool
|
||||
}
|
||||
|
||||
func (e *stateTransition) Error() string {
|
||||
if e.end {
|
||||
return "conversation: end"
|
||||
}
|
||||
return "conversation: → " + string(e.next)
|
||||
}
|
||||
|
||||
// Next signals the conversation should advance to the given state.
|
||||
// Conversation handlers return Next("state_name") to transition.
|
||||
func Next(s State) error {
|
||||
return &stateTransition{next: s}
|
||||
}
|
||||
|
||||
// End signals the conversation has finished and state should be cleared.
|
||||
// Conversation handlers return End() to terminate.
|
||||
func End() error {
|
||||
return &stateTransition{end: true}
|
||||
}
|
||||
|
||||
// Handler defines a step in the conversation. Receives the dispatch context
|
||||
// and the raw update. Returns:
|
||||
// - nil to stay in the current state
|
||||
// - Next("state") to transition to a different state
|
||||
// - End() to end the conversation
|
||||
// - any other non-nil error to surface to the dispatcher (state unchanged)
|
||||
type Handler func(ctx *dispatch.Context, u *api.Update) error
|
||||
|
||||
// Step pairs a filter with a handler for one conversation step.
|
||||
type Step struct {
|
||||
Filter dispatch.Filter[*api.Update]
|
||||
Handler Handler
|
||||
}
|
||||
|
||||
// Conversation is a stateful handler with entry, per-state, exit and
|
||||
// fallback steps. A conversation is keyed by KeyStrategy (default
|
||||
// KeyByUserAndChat) and persisted by Storage (default in-memory).
|
||||
type Conversation struct {
|
||||
// EntryPoints starts a new conversation when a matching filter fires
|
||||
// and no conversation is already active for the key.
|
||||
EntryPoints []Step
|
||||
|
||||
// States maps each state to the steps that handle it.
|
||||
States map[State][]Step
|
||||
|
||||
// Exits, if any match, end the active conversation early. Useful for
|
||||
// /cancel-style commands.
|
||||
Exits []Step
|
||||
|
||||
// Fallbacks run when no state step matches the current update.
|
||||
Fallbacks []Step
|
||||
|
||||
// Storage persists conversation state. Defaults to NewMemoryStorage.
|
||||
Storage Storage
|
||||
|
||||
// KeyStrategy derives the persistence key. Defaults to KeyByUserAndChat.
|
||||
KeyStrategy KeyStrategy
|
||||
|
||||
// AllowReEntry, when true, lets entry-point steps fire even while a
|
||||
// conversation is already active for the key (effectively restarting it).
|
||||
AllowReEntry bool
|
||||
}
|
||||
|
||||
// Dispatch is a global middleware-shaped Handler that consumes updates
|
||||
// and routes them through the conversation graph. Register via
|
||||
// router.Use(conv.Dispatch).
|
||||
//
|
||||
// If the conversation claims an update, downstream handlers are skipped.
|
||||
// If the conversation does not claim it, downstream handlers run as normal.
|
||||
func (c *Conversation) Dispatch(next dispatch.Handler[*api.Update]) dispatch.Handler[*api.Update] {
|
||||
if c.Storage == nil {
|
||||
c.Storage = NewMemoryStorage()
|
||||
}
|
||||
if c.KeyStrategy == nil {
|
||||
c.KeyStrategy = KeyByUserAndChat
|
||||
}
|
||||
return func(dctx *dispatch.Context, u *api.Update) error {
|
||||
key := c.KeyStrategy(u)
|
||||
if key == "" {
|
||||
return next(dctx, u)
|
||||
}
|
||||
|
||||
ctx := dctx.Ctx
|
||||
current, err := c.Storage.Get(ctx, key)
|
||||
if err != nil && !errors.Is(err, ErrKeyNotFound) {
|
||||
return err
|
||||
}
|
||||
active := !errors.Is(err, ErrKeyNotFound)
|
||||
|
||||
// Try exits first (always allowed if conversation is active).
|
||||
if active {
|
||||
for _, step := range c.Exits {
|
||||
if step.Filter(u) {
|
||||
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try entry points (only if no active conversation, or AllowReEntry).
|
||||
if !active || c.AllowReEntry {
|
||||
for _, step := range c.EntryPoints {
|
||||
if step.Filter(u) {
|
||||
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !active {
|
||||
return next(dctx, u)
|
||||
}
|
||||
|
||||
// Active conversation: try state steps.
|
||||
for _, step := range c.States[current] {
|
||||
if step.Filter(u) {
|
||||
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallbacks if no state step matched.
|
||||
for _, step := range c.Fallbacks {
|
||||
if step.Filter(u) {
|
||||
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Active conversation but no step matched and no fallback: swallow the
|
||||
// update (do NOT pass to downstream handlers, since the user is
|
||||
// mid-conversation and an unrelated handler would surprise them).
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// runStep invokes the handler and applies its return-value state transition.
|
||||
func (c *Conversation) runStep(ctx context.Context, dctx *dispatch.Context, u *api.Update, key string, h Handler) error {
|
||||
err := h(dctx, u)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var trans *stateTransition
|
||||
if errors.As(err, &trans) {
|
||||
if trans.end {
|
||||
return c.Storage.Delete(ctx, key)
|
||||
}
|
||||
return c.Storage.Set(ctx, key, trans.next)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MemoryStorage is the default in-process Storage. It is safe for
|
||||
// concurrent use. Conversation state is lost on process restart; use
|
||||
// a custom Storage backed by a database for persistent flows.
|
||||
type MemoryStorage struct {
|
||||
mu sync.RWMutex
|
||||
state map[string]State
|
||||
}
|
||||
|
||||
// NewMemoryStorage constructs an empty in-memory storage.
|
||||
func NewMemoryStorage() *MemoryStorage {
|
||||
return &MemoryStorage{state: map[string]State{}}
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) Get(_ context.Context, key string) (State, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.state[key]
|
||||
if !ok {
|
||||
return "", ErrKeyNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) Set(_ context.Context, key string, state State) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.state[key] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) Delete(_ context.Context, key string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.state, key)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMemoryStorage_GetSetDelete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := NewMemoryStorage()
|
||||
|
||||
// Get on empty key returns ErrKeyNotFound.
|
||||
_, err := s.Get(ctx, "k1")
|
||||
require.ErrorIs(t, err, ErrKeyNotFound)
|
||||
|
||||
// Set then Get returns the stored state.
|
||||
require.NoError(t, s.Set(ctx, "k1", "step_a"))
|
||||
v, err := s.Get(ctx, "k1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, State("step_a"), v)
|
||||
|
||||
// Overwrite works.
|
||||
require.NoError(t, s.Set(ctx, "k1", "step_b"))
|
||||
v, err = s.Get(ctx, "k1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, State("step_b"), v)
|
||||
|
||||
// Delete removes the key.
|
||||
require.NoError(t, s.Delete(ctx, "k1"))
|
||||
_, err = s.Get(ctx, "k1")
|
||||
require.ErrorIs(t, err, ErrKeyNotFound)
|
||||
|
||||
// Delete of non-existent key is a no-op (no error).
|
||||
require.NoError(t, s.Delete(ctx, "nonexistent"))
|
||||
}
|
||||
|
||||
func TestMemoryStorage_MultipleKeys(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := NewMemoryStorage()
|
||||
|
||||
require.NoError(t, s.Set(ctx, "a", "stateA"))
|
||||
require.NoError(t, s.Set(ctx, "b", "stateB"))
|
||||
|
||||
va, err := s.Get(ctx, "a")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, State("stateA"), va)
|
||||
|
||||
vb, err := s.Get(ctx, "b")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, State("stateB"), vb)
|
||||
|
||||
// Delete one key; the other remains.
|
||||
require.NoError(t, s.Delete(ctx, "a"))
|
||||
_, err = s.Get(ctx, "a")
|
||||
require.ErrorIs(t, err, ErrKeyNotFound)
|
||||
|
||||
vb, err = s.Get(ctx, "b")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, State("stateB"), vb)
|
||||
}
|
||||
|
||||
func TestMemoryStorage_Concurrent(t *testing.T) {
|
||||
// 20 goroutines hammering the same key concurrently — no data race.
|
||||
ctx := context.Background()
|
||||
s := NewMemoryStorage()
|
||||
|
||||
const goroutines = 20
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
key := "shared"
|
||||
_ = s.Set(ctx, key, State("step"))
|
||||
_, _ = s.Get(ctx, key)
|
||||
if i%3 == 0 {
|
||||
_ = s.Delete(ctx, key)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
// No assertion needed — race detector catches the bug if present.
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
)
|
||||
|
||||
// KeyStrategy derives a persistence key from an update. Strategies
|
||||
// determine how conversation scope works — per-user, per-chat, or
|
||||
// per-user-and-chat. Implementations must return a stable string for
|
||||
// the same logical scope across updates.
|
||||
//
|
||||
// Returns the empty string if the update doesn't have enough context
|
||||
// to derive a key (in which case the conversation handler skips it).
|
||||
type KeyStrategy func(u *api.Update) string
|
||||
|
||||
// KeyByUser derives a key from the sending user's ID. Useful for DM
|
||||
// conversations and any flow that should follow the user across chats.
|
||||
var KeyByUser KeyStrategy = func(u *api.Update) string {
|
||||
if uid := userID(u); uid != 0 {
|
||||
return fmt.Sprintf("u:%d", uid)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// KeyByChat derives a key from the chat ID. Useful for group flows where
|
||||
// any user in the chat can drive the conversation.
|
||||
var KeyByChat KeyStrategy = func(u *api.Update) string {
|
||||
if cid := chatID(u); cid != 0 {
|
||||
return fmt.Sprintf("c:%d", cid)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// KeyByUserAndChat derives a key from both user and chat IDs. The most
|
||||
// common strategy: each user has their own conversation per chat.
|
||||
var KeyByUserAndChat KeyStrategy = func(u *api.Update) string {
|
||||
uid := userID(u)
|
||||
cid := chatID(u)
|
||||
if uid == 0 || cid == 0 {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("uc:%d:%d", cid, uid)
|
||||
}
|
||||
|
||||
// userID extracts the sending user's ID from any update payload.
|
||||
func userID(u *api.Update) int64 {
|
||||
switch {
|
||||
case u.Message != nil && u.Message.From != nil:
|
||||
return u.Message.From.ID
|
||||
case u.EditedMessage != nil && u.EditedMessage.From != nil:
|
||||
return u.EditedMessage.From.ID
|
||||
case u.CallbackQuery != nil:
|
||||
return u.CallbackQuery.From.ID
|
||||
case u.InlineQuery != nil:
|
||||
return u.InlineQuery.From.ID
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// chatID extracts the relevant chat ID.
|
||||
func chatID(u *api.Update) int64 {
|
||||
switch {
|
||||
case u.Message != nil:
|
||||
return u.Message.Chat.ID
|
||||
case u.EditedMessage != nil:
|
||||
return u.EditedMessage.Chat.ID
|
||||
case u.ChannelPost != nil:
|
||||
return u.ChannelPost.Chat.ID
|
||||
case u.EditedChannelPost != nil:
|
||||
return u.EditedChannelPost.Chat.ID
|
||||
case u.CallbackQuery != nil && u.CallbackQuery.Message != nil:
|
||||
if msg, ok := u.CallbackQuery.Message.(*api.Message); ok {
|
||||
return msg.Chat.ID
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// helpers to build api.Update variants.
|
||||
|
||||
func msgUpdate(userID, chatID int64) *api.Update {
|
||||
return &api.Update{
|
||||
Message: &api.Message{
|
||||
From: &api.User{ID: userID},
|
||||
Chat: api.Chat{ID: chatID},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func editedMsgUpdate(userID, chatID int64) *api.Update {
|
||||
return &api.Update{
|
||||
EditedMessage: &api.Message{
|
||||
From: &api.User{ID: userID},
|
||||
Chat: api.Chat{ID: chatID},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func callbackUpdate(userID, chatID int64) *api.Update {
|
||||
return &api.Update{
|
||||
CallbackQuery: &api.CallbackQuery{
|
||||
From: api.User{ID: userID},
|
||||
Message: &api.Message{Chat: api.Chat{ID: chatID}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func inlineUpdate(userID int64) *api.Update {
|
||||
return &api.Update{
|
||||
InlineQuery: &api.InlineQuery{
|
||||
From: api.User{ID: userID},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func emptyUpdate() *api.Update { return &api.Update{} }
|
||||
|
||||
func TestKeyByUser(t *testing.T) {
|
||||
t.Run("message update", func(t *testing.T) {
|
||||
require.Equal(t, "u:42", KeyByUser(msgUpdate(42, 100)))
|
||||
})
|
||||
t.Run("edited message", func(t *testing.T) {
|
||||
require.Equal(t, "u:7", KeyByUser(editedMsgUpdate(7, 100)))
|
||||
})
|
||||
t.Run("callback query", func(t *testing.T) {
|
||||
require.Equal(t, "u:99", KeyByUser(callbackUpdate(99, 100)))
|
||||
})
|
||||
t.Run("inline query", func(t *testing.T) {
|
||||
require.Equal(t, "u:5", KeyByUser(inlineUpdate(5)))
|
||||
})
|
||||
t.Run("empty update returns empty string", func(t *testing.T) {
|
||||
require.Equal(t, "", KeyByUser(emptyUpdate()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyByChat(t *testing.T) {
|
||||
t.Run("message update", func(t *testing.T) {
|
||||
require.Equal(t, "c:100", KeyByChat(msgUpdate(42, 100)))
|
||||
})
|
||||
t.Run("edited message", func(t *testing.T) {
|
||||
require.Equal(t, "c:200", KeyByChat(editedMsgUpdate(7, 200)))
|
||||
})
|
||||
t.Run("callback with accessible message", func(t *testing.T) {
|
||||
require.Equal(t, "c:300", KeyByChat(callbackUpdate(99, 300)))
|
||||
})
|
||||
t.Run("inline query has no chat → empty", func(t *testing.T) {
|
||||
require.Equal(t, "", KeyByChat(inlineUpdate(5)))
|
||||
})
|
||||
t.Run("empty update returns empty string", func(t *testing.T) {
|
||||
require.Equal(t, "", KeyByChat(emptyUpdate()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyByUserAndChat(t *testing.T) {
|
||||
t.Run("message update", func(t *testing.T) {
|
||||
require.Equal(t, "uc:100:42", KeyByUserAndChat(msgUpdate(42, 100)))
|
||||
})
|
||||
t.Run("edited message", func(t *testing.T) {
|
||||
require.Equal(t, "uc:200:7", KeyByUserAndChat(editedMsgUpdate(7, 200)))
|
||||
})
|
||||
t.Run("callback query", func(t *testing.T) {
|
||||
require.Equal(t, "uc:300:99", KeyByUserAndChat(callbackUpdate(99, 300)))
|
||||
})
|
||||
t.Run("inline query has no chat → empty", func(t *testing.T) {
|
||||
require.Equal(t, "", KeyByUserAndChat(inlineUpdate(5)))
|
||||
})
|
||||
t.Run("empty update returns empty string", func(t *testing.T) {
|
||||
require.Equal(t, "", KeyByUserAndChat(emptyUpdate()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyByUserAndChat_CallbackInaccessibleMessage(t *testing.T) {
|
||||
// CallbackQuery.Message is InaccessibleMessage (not *Message) — chatID returns 0.
|
||||
u := &api.Update{
|
||||
CallbackQuery: &api.CallbackQuery{
|
||||
From: api.User{ID: 10},
|
||||
Message: &api.InaccessibleMessage{}, // implements MaybeInaccessibleMessage, not *api.Message
|
||||
},
|
||||
}
|
||||
// userID picks up From.ID=10 but chatID fails type assertion → 0
|
||||
require.Equal(t, "", KeyByUserAndChat(u), "no key when message inaccessible")
|
||||
// KeyByUser still works since From is set.
|
||||
require.Equal(t, "u:10", KeyByUser(u))
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
// Package conversation implements a stateful conversation handler for the
|
||||
// go-telegram dispatch router. It provides a state-machine abstraction over
|
||||
// multi-step Telegram bot interactions, with pluggable storage and flexible
|
||||
// key strategies.
|
||||
package conversation
|
||||
|
||||
// State is a label identifying a node in the conversation graph.
|
||||
// The empty string is the implicit "no active conversation" state.
|
||||
type State string
|
||||
@@ -0,0 +1,20 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// ErrKeyNotFound is returned by Storage.Get when no conversation is active
|
||||
// for the given key.
|
||||
var ErrKeyNotFound = errors.New("conversation: key not found")
|
||||
|
||||
// Storage persists per-user (or per-chat, per-message — depending on the
|
||||
// KeyStrategy in use) conversation state across update deliveries.
|
||||
//
|
||||
// Implementations must be safe for concurrent use.
|
||||
type Storage interface {
|
||||
Get(ctx context.Context, key string) (State, error)
|
||||
Set(ctx context.Context, key string, state State) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package dispatch
|
||||
|
||||
// Filter is a predicate over a typed payload (e.g. *api.Message). Filters
|
||||
// compose via And/Or/Not for multi-condition matching.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// f := message.HasPhoto().And(message.InChat(-100123456789))
|
||||
type Filter[T any] func(payload T) bool
|
||||
|
||||
// And returns a Filter that matches iff f and every one of others matches.
|
||||
func (f Filter[T]) And(others ...Filter[T]) Filter[T] {
|
||||
return func(payload T) bool {
|
||||
if !f(payload) {
|
||||
return false
|
||||
}
|
||||
for _, o := range others {
|
||||
if !o(payload) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Or returns a Filter that matches iff f matches OR any of others matches.
|
||||
func (f Filter[T]) Or(others ...Filter[T]) Filter[T] {
|
||||
return func(payload T) bool {
|
||||
if f(payload) {
|
||||
return true
|
||||
}
|
||||
for _, o := range others {
|
||||
if o(payload) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Not returns a Filter that inverts f.
|
||||
func (f Filter[T]) Not() Filter[T] {
|
||||
return func(payload T) bool { return !f(payload) }
|
||||
}
|
||||
|
||||
// All combines filters with AND. Returns a Filter that matches when all match.
|
||||
// Returns a filter that always matches when filters is empty.
|
||||
func All[T any](filters ...Filter[T]) Filter[T] {
|
||||
return func(payload T) bool {
|
||||
for _, f := range filters {
|
||||
if !f(payload) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Any combines filters with OR. Returns a Filter that matches when at least
|
||||
// one matches. Returns a filter that never matches when filters is empty.
|
||||
func Any[T any](filters ...Filter[T]) Filter[T] {
|
||||
return func(payload T) bool {
|
||||
for _, f := range filters {
|
||||
if f(payload) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package dispatch
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func alwaysTrue[T any]() Filter[T] { return func(_ T) bool { return true } }
|
||||
func alwaysFalse[T any]() Filter[T] { return func(_ T) bool { return false } }
|
||||
|
||||
func TestFilter_And(t *testing.T) {
|
||||
t.Run("all true", func(t *testing.T) {
|
||||
f := alwaysTrue[int]().And(alwaysTrue[int](), alwaysTrue[int]())
|
||||
require.True(t, f(0))
|
||||
})
|
||||
t.Run("first false", func(t *testing.T) {
|
||||
f := alwaysFalse[int]().And(alwaysTrue[int]())
|
||||
require.False(t, f(0))
|
||||
})
|
||||
t.Run("other false", func(t *testing.T) {
|
||||
f := alwaysTrue[int]().And(alwaysFalse[int]())
|
||||
require.False(t, f(0))
|
||||
})
|
||||
t.Run("no others — acts as identity", func(t *testing.T) {
|
||||
require.True(t, alwaysTrue[int]().And()(0))
|
||||
require.False(t, alwaysFalse[int]().And()(0))
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilter_Or(t *testing.T) {
|
||||
t.Run("first true", func(t *testing.T) {
|
||||
f := alwaysTrue[int]().Or(alwaysFalse[int]())
|
||||
require.True(t, f(0))
|
||||
})
|
||||
t.Run("other true", func(t *testing.T) {
|
||||
f := alwaysFalse[int]().Or(alwaysTrue[int]())
|
||||
require.True(t, f(0))
|
||||
})
|
||||
t.Run("all false", func(t *testing.T) {
|
||||
f := alwaysFalse[int]().Or(alwaysFalse[int]())
|
||||
require.False(t, f(0))
|
||||
})
|
||||
t.Run("no others", func(t *testing.T) {
|
||||
require.True(t, alwaysTrue[int]().Or()(0))
|
||||
require.False(t, alwaysFalse[int]().Or()(0))
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilter_Not(t *testing.T) {
|
||||
require.False(t, alwaysTrue[int]().Not()(0))
|
||||
require.True(t, alwaysFalse[int]().Not()(0))
|
||||
}
|
||||
|
||||
func TestAll(t *testing.T) {
|
||||
t.Run("all true", func(t *testing.T) {
|
||||
require.True(t, All(alwaysTrue[int](), alwaysTrue[int]())(0))
|
||||
})
|
||||
t.Run("one false", func(t *testing.T) {
|
||||
require.False(t, All(alwaysTrue[int](), alwaysFalse[int]())(0))
|
||||
})
|
||||
t.Run("empty — always true", func(t *testing.T) {
|
||||
require.True(t, All[int]()(0))
|
||||
})
|
||||
}
|
||||
|
||||
func TestAny(t *testing.T) {
|
||||
t.Run("one true", func(t *testing.T) {
|
||||
require.True(t, Any(alwaysFalse[int](), alwaysTrue[int]())(0))
|
||||
})
|
||||
t.Run("all false", func(t *testing.T) {
|
||||
require.False(t, Any(alwaysFalse[int](), alwaysFalse[int]())(0))
|
||||
})
|
||||
t.Run("empty — always false", func(t *testing.T) {
|
||||
require.False(t, Any[int]()(0))
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilter_Composition(t *testing.T) {
|
||||
// (true AND false) OR true == true
|
||||
f := alwaysTrue[int]().And(alwaysFalse[int]()).Or(alwaysTrue[int]())
|
||||
require.True(t, f(0))
|
||||
|
||||
// NOT (true OR false) == false
|
||||
g := alwaysTrue[int]().Or(alwaysFalse[int]()).Not()
|
||||
require.False(t, g(0))
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
// Package callback provides Filter helpers for *api.CallbackQuery payloads.
|
||||
package callback
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||
)
|
||||
|
||||
// Data returns a Filter that matches callback queries whose Data matches
|
||||
// pattern (regex). Panics at registration time on an invalid pattern.
|
||||
func Data(pattern string) dispatch.Filter[*api.CallbackQuery] {
|
||||
re := regexp.MustCompile(pattern)
|
||||
return func(q *api.CallbackQuery) bool {
|
||||
return q != nil && re.MatchString(q.Data)
|
||||
}
|
||||
}
|
||||
|
||||
// DataEquals returns a Filter that matches callback queries whose Data equals
|
||||
// s exactly.
|
||||
func DataEquals(s string) dispatch.Filter[*api.CallbackQuery] {
|
||||
return func(q *api.CallbackQuery) bool {
|
||||
return q != nil && q.Data == s
|
||||
}
|
||||
}
|
||||
|
||||
// DataPrefix returns a Filter that matches callback queries whose Data starts
|
||||
// with prefix.
|
||||
func DataPrefix(prefix string) dispatch.Filter[*api.CallbackQuery] {
|
||||
return func(q *api.CallbackQuery) bool {
|
||||
return q != nil && strings.HasPrefix(q.Data, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
// FromUser returns a Filter that matches callback queries whose From.ID equals
|
||||
// userID.
|
||||
func FromUser(userID int64) dispatch.Filter[*api.CallbackQuery] {
|
||||
return func(q *api.CallbackQuery) bool {
|
||||
return q != nil && q.From.ID == userID
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package callback_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
cbfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/callback"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func cq(data string, userID int64) *api.CallbackQuery {
|
||||
return &api.CallbackQuery{
|
||||
ID: "q",
|
||||
From: api.User{ID: userID},
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func TestData(t *testing.T) {
|
||||
f := cbfilter.Data(`^like:\d+$`)
|
||||
require.True(t, f(cq("like:42", 1)))
|
||||
require.False(t, f(cq("dislike:42", 1)))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestData_PanicsOnBadPattern(t *testing.T) {
|
||||
require.Panics(t, func() { cbfilter.Data(`[bad`) })
|
||||
}
|
||||
|
||||
func TestDataEquals(t *testing.T) {
|
||||
f := cbfilter.DataEquals("yes")
|
||||
require.True(t, f(cq("yes", 1)))
|
||||
require.False(t, f(cq("yes please", 1)))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestDataPrefix(t *testing.T) {
|
||||
f := cbfilter.DataPrefix("vote:")
|
||||
require.True(t, f(cq("vote:up", 1)))
|
||||
require.False(t, f(cq("novote:up", 1)))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestFromUser(t *testing.T) {
|
||||
f := cbfilter.FromUser(7)
|
||||
require.True(t, f(cq("data", 7)))
|
||||
require.False(t, f(cq("data", 8)))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestComposedCallbackFilters(t *testing.T) {
|
||||
f := cbfilter.DataPrefix("vote:").And(cbfilter.FromUser(7))
|
||||
require.True(t, f(cq("vote:up", 7)))
|
||||
require.False(t, f(cq("vote:up", 8)))
|
||||
require.False(t, f(cq("other", 7)))
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
// Package chatjoinrequest provides Filter helpers for *api.ChatJoinRequest payloads.
|
||||
package chatjoinrequest
|
||||
|
||||
import (
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||
)
|
||||
|
||||
// FromUser returns a Filter that matches join requests where the requesting
|
||||
// user's ID equals uid.
|
||||
func FromUser(uid int64) dispatch.Filter[*api.ChatJoinRequest] {
|
||||
return func(r *api.ChatJoinRequest) bool {
|
||||
return r != nil && r.From.ID == uid
|
||||
}
|
||||
}
|
||||
|
||||
// InChat returns a Filter that matches join requests directed at the chat
|
||||
// with the given chat ID.
|
||||
func InChat(cid int64) dispatch.Filter[*api.ChatJoinRequest] {
|
||||
return func(r *api.ChatJoinRequest) bool {
|
||||
return r != nil && r.Chat.ID == cid
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package chatjoinrequest_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
cjrfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/chatjoinrequest"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func joinRequest(fromID, chatID int64) *api.ChatJoinRequest {
|
||||
return &api.ChatJoinRequest{
|
||||
Chat: api.Chat{ID: chatID},
|
||||
From: api.User{ID: fromID},
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromUser_Matches(t *testing.T) {
|
||||
f := cjrfilter.FromUser(10)
|
||||
require.True(t, f(joinRequest(10, 100)))
|
||||
require.False(t, f(joinRequest(99, 100)))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestInChat_Matches(t *testing.T) {
|
||||
f := cjrfilter.InChat(100)
|
||||
require.True(t, f(joinRequest(10, 100)))
|
||||
require.False(t, f(joinRequest(10, 200)))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestComposedFilters(t *testing.T) {
|
||||
f := cjrfilter.FromUser(10).And(cjrfilter.InChat(100))
|
||||
require.True(t, f(joinRequest(10, 100)))
|
||||
require.False(t, f(joinRequest(10, 200)))
|
||||
require.False(t, f(joinRequest(99, 100)))
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Package chatmember provides Filter helpers for *api.ChatMemberUpdated payloads.
|
||||
package chatmember
|
||||
|
||||
import (
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||
)
|
||||
|
||||
// NewStatus returns a Filter that matches updates where the new chat member
|
||||
// status equals s (e.g. "member", "administrator", "kicked", "left").
|
||||
func NewStatus(s string) dispatch.Filter[*api.ChatMemberUpdated] {
|
||||
return func(u *api.ChatMemberUpdated) bool {
|
||||
if u == nil {
|
||||
return false
|
||||
}
|
||||
switch m := u.NewChatMember.(type) {
|
||||
case *api.ChatMemberOwner:
|
||||
return m.Status == s
|
||||
case *api.ChatMemberAdministrator:
|
||||
return m.Status == s
|
||||
case *api.ChatMemberMember:
|
||||
return m.Status == s
|
||||
case *api.ChatMemberRestricted:
|
||||
return m.Status == s
|
||||
case *api.ChatMemberLeft:
|
||||
return m.Status == s
|
||||
case *api.ChatMemberBanned:
|
||||
return m.Status == s
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FromUser returns a Filter that matches updates where the acting user
|
||||
// (From.ID) equals uid.
|
||||
func FromUser(uid int64) dispatch.Filter[*api.ChatMemberUpdated] {
|
||||
return func(u *api.ChatMemberUpdated) bool {
|
||||
return u != nil && u.From.ID == uid
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package chatmember_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
cmfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/chatmember"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func memberUpdate(status string, fromID int64) *api.ChatMemberUpdated {
|
||||
var newMember api.ChatMember
|
||||
switch status {
|
||||
case "member":
|
||||
newMember = &api.ChatMemberMember{Status: status}
|
||||
case "administrator":
|
||||
newMember = &api.ChatMemberAdministrator{Status: status}
|
||||
case "kicked":
|
||||
newMember = &api.ChatMemberBanned{Status: status}
|
||||
case "left":
|
||||
newMember = &api.ChatMemberLeft{Status: status}
|
||||
default:
|
||||
newMember = &api.ChatMemberMember{Status: status}
|
||||
}
|
||||
return &api.ChatMemberUpdated{
|
||||
From: api.User{ID: fromID},
|
||||
NewChatMember: newMember,
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewStatus_Matches(t *testing.T) {
|
||||
f := cmfilter.NewStatus("member")
|
||||
require.True(t, f(memberUpdate("member", 1)))
|
||||
require.False(t, f(memberUpdate("kicked", 1)))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestNewStatus_Administrator(t *testing.T) {
|
||||
f := cmfilter.NewStatus("administrator")
|
||||
require.True(t, f(memberUpdate("administrator", 1)))
|
||||
require.False(t, f(memberUpdate("member", 1)))
|
||||
}
|
||||
|
||||
func TestNewStatus_Kicked(t *testing.T) {
|
||||
f := cmfilter.NewStatus("kicked")
|
||||
require.True(t, f(memberUpdate("kicked", 1)))
|
||||
require.False(t, f(memberUpdate("left", 1)))
|
||||
}
|
||||
|
||||
func TestNewStatus_Left(t *testing.T) {
|
||||
f := cmfilter.NewStatus("left")
|
||||
require.True(t, f(memberUpdate("left", 1)))
|
||||
require.False(t, f(memberUpdate("member", 1)))
|
||||
}
|
||||
|
||||
func TestFromUser_Matches(t *testing.T) {
|
||||
f := cmfilter.FromUser(42)
|
||||
require.True(t, f(memberUpdate("member", 42)))
|
||||
require.False(t, f(memberUpdate("member", 99)))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestComposedFilters(t *testing.T) {
|
||||
f := cmfilter.NewStatus("member").And(cmfilter.FromUser(7))
|
||||
require.True(t, f(memberUpdate("member", 7)))
|
||||
require.False(t, f(memberUpdate("member", 8)))
|
||||
require.False(t, f(memberUpdate("kicked", 7)))
|
||||
}
|
||||
|
||||
func TestNewStatus_Owner(t *testing.T) {
|
||||
u := &api.ChatMemberUpdated{
|
||||
From: api.User{ID: 1},
|
||||
NewChatMember: &api.ChatMemberOwner{Status: "creator"},
|
||||
}
|
||||
require.True(t, cmfilter.NewStatus("creator")(u))
|
||||
require.False(t, cmfilter.NewStatus("member")(u))
|
||||
}
|
||||
|
||||
func TestNewStatus_Restricted(t *testing.T) {
|
||||
u := &api.ChatMemberUpdated{
|
||||
From: api.User{ID: 1},
|
||||
NewChatMember: &api.ChatMemberRestricted{Status: "restricted"},
|
||||
}
|
||||
require.True(t, cmfilter.NewStatus("restricted")(u))
|
||||
require.False(t, cmfilter.NewStatus("member")(u))
|
||||
}
|
||||
|
||||
func TestNewStatus_UnknownType(t *testing.T) {
|
||||
// nil NewChatMember → default branch → false
|
||||
u := &api.ChatMemberUpdated{
|
||||
From: api.User{ID: 1},
|
||||
NewChatMember: nil,
|
||||
}
|
||||
require.False(t, cmfilter.NewStatus("member")(u))
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
// Package inline provides Filter helpers for *api.InlineQuery payloads.
|
||||
package inline
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||
)
|
||||
|
||||
// Query returns a Filter that matches inline queries whose Query field matches
|
||||
// pattern (regex). Panics at registration time on an invalid pattern.
|
||||
func Query(pattern string) dispatch.Filter[*api.InlineQuery] {
|
||||
re := regexp.MustCompile(pattern)
|
||||
return func(q *api.InlineQuery) bool {
|
||||
return q != nil && re.MatchString(q.Query)
|
||||
}
|
||||
}
|
||||
|
||||
// QueryEquals returns a Filter that matches inline queries whose Query equals
|
||||
// s exactly.
|
||||
func QueryEquals(s string) dispatch.Filter[*api.InlineQuery] {
|
||||
return func(q *api.InlineQuery) bool {
|
||||
return q != nil && q.Query == s
|
||||
}
|
||||
}
|
||||
|
||||
// QueryPrefix returns a Filter that matches inline queries whose Query starts
|
||||
// with prefix.
|
||||
func QueryPrefix(prefix string) dispatch.Filter[*api.InlineQuery] {
|
||||
return func(q *api.InlineQuery) bool {
|
||||
return q != nil && strings.HasPrefix(q.Query, prefix)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package inline_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
ilfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/inline"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func iq(query string) *api.InlineQuery {
|
||||
return &api.InlineQuery{ID: "i", From: api.User{ID: 1}, Query: query}
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
f := ilfilter.Query(`^find`)
|
||||
require.True(t, f(iq("find me")))
|
||||
require.False(t, f(iq("search me")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestQuery_PanicsOnBadPattern(t *testing.T) {
|
||||
require.Panics(t, func() { ilfilter.Query(`[bad`) })
|
||||
}
|
||||
|
||||
func TestQueryEquals(t *testing.T) {
|
||||
f := ilfilter.QueryEquals("exact")
|
||||
require.True(t, f(iq("exact")))
|
||||
require.False(t, f(iq("exact match")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestQueryPrefix(t *testing.T) {
|
||||
f := ilfilter.QueryPrefix("@user")
|
||||
require.True(t, f(iq("@username")))
|
||||
require.False(t, f(iq("no prefix")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestComposedInlineFilters(t *testing.T) {
|
||||
f := ilfilter.QueryPrefix("find").Or(ilfilter.QueryEquals("help"))
|
||||
require.True(t, f(iq("find me")))
|
||||
require.True(t, f(iq("help")))
|
||||
require.False(t, f(iq("other")))
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
// Package message provides Filter helpers for *api.Message payloads.
|
||||
package message
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||
)
|
||||
|
||||
// Text returns a Filter that matches messages whose Text matches pattern (regex).
|
||||
// Panics at registration time on an invalid pattern.
|
||||
func Text(pattern string) dispatch.Filter[*api.Message] {
|
||||
re := regexp.MustCompile(pattern)
|
||||
return func(m *api.Message) bool {
|
||||
return m != nil && re.MatchString(m.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// TextEquals returns a Filter that matches messages whose Text equals s exactly.
|
||||
func TextEquals(s string) dispatch.Filter[*api.Message] {
|
||||
return func(m *api.Message) bool {
|
||||
return m != nil && m.Text == s
|
||||
}
|
||||
}
|
||||
|
||||
// TextPrefix returns a Filter that matches messages whose Text starts with prefix.
|
||||
func TextPrefix(prefix string) dispatch.Filter[*api.Message] {
|
||||
return func(m *api.Message) bool {
|
||||
return m != nil && strings.HasPrefix(m.Text, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
// TextContains returns a Filter that matches messages whose Text contains sub.
|
||||
func TextContains(sub string) dispatch.Filter[*api.Message] {
|
||||
return func(m *api.Message) bool {
|
||||
return m != nil && strings.Contains(m.Text, sub)
|
||||
}
|
||||
}
|
||||
|
||||
// Command returns a Filter that matches messages whose first entity is a
|
||||
// bot_command equal to "/<name>" (with or without "@BotName" suffix).
|
||||
func Command(name string) dispatch.Filter[*api.Message] {
|
||||
want := "/" + strings.TrimPrefix(name, "/")
|
||||
return func(m *api.Message) bool {
|
||||
if m == nil || len(m.Entities) == 0 || m.Text == "" {
|
||||
return false
|
||||
}
|
||||
first := m.Entities[0]
|
||||
if first.Type != string(api.EntityBotCommand) || first.Offset != 0 {
|
||||
return false
|
||||
}
|
||||
end := int(first.Length)
|
||||
runes := []rune(m.Text)
|
||||
if end > len(runes) {
|
||||
return false
|
||||
}
|
||||
cmd := string(runes[:end])
|
||||
if i := strings.Index(cmd, "@"); i >= 0 {
|
||||
cmd = cmd[:i]
|
||||
}
|
||||
return cmd == want
|
||||
}
|
||||
}
|
||||
|
||||
// AnyCommand returns a Filter that matches any message starting with a
|
||||
// bot_command entity at offset 0.
|
||||
func AnyCommand() dispatch.Filter[*api.Message] {
|
||||
return func(m *api.Message) bool {
|
||||
if m == nil || len(m.Entities) == 0 {
|
||||
return false
|
||||
}
|
||||
first := m.Entities[0]
|
||||
return first.Type == string(api.EntityBotCommand) && first.Offset == 0
|
||||
}
|
||||
}
|
||||
|
||||
// IsReply returns a Filter that matches messages that have ReplyToMessage set.
|
||||
func IsReply() dispatch.Filter[*api.Message] {
|
||||
return func(m *api.Message) bool {
|
||||
return m != nil && m.ReplyToMessage != nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsForward returns a Filter that matches messages that have ForwardOrigin set.
|
||||
func IsForward() dispatch.Filter[*api.Message] {
|
||||
return func(m *api.Message) bool {
|
||||
return m != nil && m.ForwardOrigin != nil
|
||||
}
|
||||
}
|
||||
|
||||
// HasPhoto returns a Filter that matches messages with a Photo attachment.
|
||||
func HasPhoto() dispatch.Filter[*api.Message] {
|
||||
return func(m *api.Message) bool {
|
||||
return m != nil && len(m.Photo) > 0
|
||||
}
|
||||
}
|
||||
|
||||
// HasDocument returns a Filter that matches messages with a Document attachment.
|
||||
func HasDocument() dispatch.Filter[*api.Message] {
|
||||
return func(m *api.Message) bool {
|
||||
return m != nil && m.Document != nil
|
||||
}
|
||||
}
|
||||
|
||||
// HasEntity returns a Filter that matches messages whose Entities contain at
|
||||
// least one entity of type t (e.g. string(api.EntityBotCommand)).
|
||||
func HasEntity(t string) dispatch.Filter[*api.Message] {
|
||||
return func(m *api.Message) bool {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
for _, e := range m.Entities {
|
||||
if e.Type == t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ChatType returns a Filter that matches messages whose Chat.Type equals t.
|
||||
func ChatType(t api.ChatType) dispatch.Filter[*api.Message] {
|
||||
return func(m *api.Message) bool {
|
||||
return m != nil && m.Chat.Type == string(t)
|
||||
}
|
||||
}
|
||||
|
||||
// FromUser returns a Filter that matches messages whose From.ID equals userID.
|
||||
func FromUser(userID int64) dispatch.Filter[*api.Message] {
|
||||
return func(m *api.Message) bool {
|
||||
return m != nil && m.From != nil && m.From.ID == userID
|
||||
}
|
||||
}
|
||||
|
||||
// InChat returns a Filter that matches messages whose Chat.ID equals chatID.
|
||||
func InChat(chatID int64) dispatch.Filter[*api.Message] {
|
||||
return func(m *api.Message) bool {
|
||||
return m != nil && m.Chat.ID == chatID
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
package message_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
msgfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/message"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func msg(text string) *api.Message {
|
||||
return &api.Message{
|
||||
MessageID: 1,
|
||||
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||
Text: text,
|
||||
}
|
||||
}
|
||||
|
||||
func cmdMsg(cmd string) *api.Message {
|
||||
text := cmd
|
||||
return &api.Message{
|
||||
MessageID: 1,
|
||||
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||
Text: text,
|
||||
Entities: []api.MessageEntity{
|
||||
{Type: string(api.EntityBotCommand), Offset: 0, Length: int64(len([]rune(text)))},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestText(t *testing.T) {
|
||||
f := msgfilter.Text(`^hello`)
|
||||
require.True(t, f(msg("hello world")))
|
||||
require.False(t, f(msg("world hello")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestText_PanicsOnBadPattern(t *testing.T) {
|
||||
require.Panics(t, func() { msgfilter.Text(`[invalid`) })
|
||||
}
|
||||
|
||||
func TestTextEquals(t *testing.T) {
|
||||
f := msgfilter.TextEquals("hi")
|
||||
require.True(t, f(msg("hi")))
|
||||
require.False(t, f(msg("hi there")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestTextPrefix(t *testing.T) {
|
||||
f := msgfilter.TextPrefix("/start")
|
||||
require.True(t, f(msg("/start now")))
|
||||
require.False(t, f(msg("no prefix")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestTextContains(t *testing.T) {
|
||||
f := msgfilter.TextContains("bot")
|
||||
require.True(t, f(msg("my bot is cool")))
|
||||
require.False(t, f(msg("nothing here")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestCommand(t *testing.T) {
|
||||
t.Run("matches exact command", func(t *testing.T) {
|
||||
f := msgfilter.Command("/start")
|
||||
require.True(t, f(cmdMsg("/start")))
|
||||
})
|
||||
t.Run("matches without leading slash", func(t *testing.T) {
|
||||
f := msgfilter.Command("start")
|
||||
require.True(t, f(cmdMsg("/start")))
|
||||
})
|
||||
t.Run("strips BotName suffix", func(t *testing.T) {
|
||||
m := &api.Message{
|
||||
Text: "/start@MyBot",
|
||||
Entities: []api.MessageEntity{{Type: string(api.EntityBotCommand), Offset: 0, Length: 12}},
|
||||
}
|
||||
f := msgfilter.Command("/start")
|
||||
require.True(t, f(m))
|
||||
})
|
||||
t.Run("no match different command", func(t *testing.T) {
|
||||
f := msgfilter.Command("/stop")
|
||||
require.False(t, f(cmdMsg("/start")))
|
||||
})
|
||||
t.Run("nil message", func(t *testing.T) {
|
||||
require.False(t, msgfilter.Command("/start")(nil))
|
||||
})
|
||||
t.Run("no entities", func(t *testing.T) {
|
||||
require.False(t, msgfilter.Command("/start")(msg("/start")))
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnyCommand(t *testing.T) {
|
||||
f := msgfilter.AnyCommand()
|
||||
require.True(t, f(cmdMsg("/anything")))
|
||||
require.False(t, f(msg("plain text")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestIsReply(t *testing.T) {
|
||||
f := msgfilter.IsReply()
|
||||
m := msg("reply")
|
||||
m.ReplyToMessage = &api.Message{MessageID: 2}
|
||||
require.True(t, f(m))
|
||||
require.False(t, f(msg("no reply")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestIsForward(t *testing.T) {
|
||||
// ForwardOrigin is a MessageOrigin interface; set via a concrete type.
|
||||
f := msgfilter.IsForward()
|
||||
m := msg("fwd")
|
||||
m.ForwardOrigin = &api.MessageOriginUser{Type: "user"}
|
||||
require.True(t, f(m))
|
||||
require.False(t, f(msg("no fwd")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestHasPhoto(t *testing.T) {
|
||||
f := msgfilter.HasPhoto()
|
||||
m := msg("")
|
||||
m.Photo = []api.PhotoSize{{FileID: "x", Width: 100, Height: 100}}
|
||||
require.True(t, f(m))
|
||||
require.False(t, f(msg("no photo")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestHasDocument(t *testing.T) {
|
||||
f := msgfilter.HasDocument()
|
||||
m := msg("")
|
||||
m.Document = &api.Document{FileID: "doc1"}
|
||||
require.True(t, f(m))
|
||||
require.False(t, f(msg("no doc")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestHasEntity(t *testing.T) {
|
||||
f := msgfilter.HasEntity(string(api.EntityURL))
|
||||
m := msg("check https://example.com")
|
||||
m.Entities = []api.MessageEntity{{Type: string(api.EntityURL), Offset: 6, Length: 19}}
|
||||
require.True(t, f(m))
|
||||
require.False(t, f(msg("plain")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestChatType(t *testing.T) {
|
||||
f := msgfilter.ChatType(api.ChatTypePrivate)
|
||||
private := msg("hi")
|
||||
require.True(t, f(private))
|
||||
|
||||
group := msg("hi")
|
||||
group.Chat.Type = string(api.ChatTypeGroup)
|
||||
require.False(t, f(group))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestFromUser(t *testing.T) {
|
||||
f := msgfilter.FromUser(42)
|
||||
m := msg("hi")
|
||||
m.From = &api.User{ID: 42}
|
||||
require.True(t, f(m))
|
||||
|
||||
m2 := msg("hi")
|
||||
m2.From = &api.User{ID: 99}
|
||||
require.False(t, f(m2))
|
||||
|
||||
require.False(t, f(msg("no from")))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestInChat(t *testing.T) {
|
||||
f := msgfilter.InChat(1)
|
||||
require.True(t, f(msg("hi")))
|
||||
m2 := msg("hi")
|
||||
m2.Chat.ID = 2
|
||||
require.False(t, f(m2))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestComposedMessageFilters(t *testing.T) {
|
||||
// private chat AND contains "hello"
|
||||
f := msgfilter.ChatType(api.ChatTypePrivate).And(msgfilter.TextContains("hello"))
|
||||
m := msg("say hello")
|
||||
require.True(t, f(m))
|
||||
|
||||
m2 := msg("say hello")
|
||||
m2.Chat.Type = string(api.ChatTypeGroup)
|
||||
require.False(t, f(m2))
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
// Package precheckoutquery provides Filter helpers for *api.PreCheckoutQuery payloads.
|
||||
package precheckoutquery
|
||||
|
||||
import (
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||
)
|
||||
|
||||
// Currency returns a Filter that matches pre-checkout queries with the given
|
||||
// ISO 4217 currency code (e.g. "USD", "EUR", "XTR").
|
||||
func Currency(c string) dispatch.Filter[*api.PreCheckoutQuery] {
|
||||
return func(q *api.PreCheckoutQuery) bool {
|
||||
return q != nil && q.Currency == c
|
||||
}
|
||||
}
|
||||
|
||||
// FromUser returns a Filter that matches pre-checkout queries sent by the
|
||||
// user with the given ID.
|
||||
func FromUser(uid int64) dispatch.Filter[*api.PreCheckoutQuery] {
|
||||
return func(q *api.PreCheckoutQuery) bool {
|
||||
return q != nil && q.From.ID == uid
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package precheckoutquery_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
pcqfilter "github.com/lukaszraczylo/go-telegram/dispatch/filters/precheckoutquery"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func pcq(currency string, fromID int64) *api.PreCheckoutQuery {
|
||||
return &api.PreCheckoutQuery{
|
||||
ID: "q",
|
||||
Currency: currency,
|
||||
From: api.User{ID: fromID},
|
||||
}
|
||||
}
|
||||
|
||||
func TestCurrency_Matches(t *testing.T) {
|
||||
f := pcqfilter.Currency("USD")
|
||||
require.True(t, f(pcq("USD", 1)))
|
||||
require.False(t, f(pcq("EUR", 1)))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestFromUser_Matches(t *testing.T) {
|
||||
f := pcqfilter.FromUser(5)
|
||||
require.True(t, f(pcq("USD", 5)))
|
||||
require.False(t, f(pcq("USD", 9)))
|
||||
require.False(t, f(nil))
|
||||
}
|
||||
|
||||
func TestComposedFilters(t *testing.T) {
|
||||
f := pcqfilter.Currency("XTR").And(pcqfilter.FromUser(42))
|
||||
require.True(t, f(pcq("XTR", 42)))
|
||||
require.False(t, f(pcq("XTR", 99)))
|
||||
require.False(t, f(pcq("USD", 42)))
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package dispatch
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
"sort"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
)
|
||||
|
||||
// ErrEndGroups stops dispatch from running any further handlers in any
|
||||
// group for this update when returned by a handler. Use it to indicate
|
||||
// the update has been definitively handled.
|
||||
//
|
||||
// errors.Is(err, ErrEndGroups) is the canonical check, though dispatch
|
||||
// itself recognises it by exact identity.
|
||||
var ErrEndGroups = errors.New("dispatch: end groups")
|
||||
|
||||
// ErrContinueGroups signals that this group's handler should be treated
|
||||
// as not-matching when returned by a handler: dispatch moves on to the
|
||||
// next handler in the same group, then to subsequent groups.
|
||||
//
|
||||
// Without ErrContinueGroups, a non-error return from a matched handler
|
||||
// stops dispatch (default first-match-wins semantics).
|
||||
var ErrContinueGroups = errors.New("dispatch: continue groups")
|
||||
|
||||
// RouterScope registers handlers into a specific priority group on its parent
|
||||
// Router. Group 0 runs first, then group 1, etc. Within a group, handlers run
|
||||
// in registration order; the first non-skipped match terminates dispatch
|
||||
// unless the handler returns ErrContinueGroups.
|
||||
type RouterScope struct {
|
||||
router *Router
|
||||
group int
|
||||
}
|
||||
|
||||
// Group returns a RouterScope that registers handlers in the given group.
|
||||
// Group 0 (the default) runs first, then group 1, etc. Within a group,
|
||||
// handlers run in registration order; the first non-skipped match
|
||||
// terminates dispatch unless the handler returns ErrContinueGroups.
|
||||
func (r *Router) Group(group int) *RouterScope {
|
||||
return &RouterScope{router: r, group: group}
|
||||
}
|
||||
|
||||
// OnCommand registers a command handler in this group.
|
||||
func (s *RouterScope) OnCommand(cmd string, h Handler[*api.Message]) {
|
||||
s.router.groupCommands = append(s.router.groupCommands, groupCommandRoute{
|
||||
cmd: cmd, group: s.group, handler: h,
|
||||
})
|
||||
}
|
||||
|
||||
// OnText registers a regex text handler in this group.
|
||||
// Panics at registration time if pattern is not a valid regular expression.
|
||||
func (s *RouterScope) OnText(pattern string, h Handler[*api.Message]) {
|
||||
s.router.groupTexts = append(s.router.groupTexts, groupTextRoute{
|
||||
re: regexp.MustCompile(pattern), group: s.group, handler: h,
|
||||
})
|
||||
}
|
||||
|
||||
// OnMessageFilter registers a filter-based message handler in this group.
|
||||
func (s *RouterScope) OnMessageFilter(f Filter[*api.Message], h Handler[*api.Message]) {
|
||||
s.router.groupMessageFilters = append(s.router.groupMessageFilters, groupMessageFilterRoute{
|
||||
filter: f, group: s.group, handler: h,
|
||||
})
|
||||
}
|
||||
|
||||
// group-aware route types
|
||||
|
||||
type groupCommandRoute struct {
|
||||
cmd string
|
||||
group int
|
||||
handler Handler[*api.Message]
|
||||
}
|
||||
|
||||
type groupTextRoute struct {
|
||||
re *regexp.Regexp
|
||||
group int
|
||||
handler Handler[*api.Message]
|
||||
}
|
||||
|
||||
type groupMessageFilterRoute struct {
|
||||
filter Filter[*api.Message]
|
||||
group int
|
||||
handler Handler[*api.Message]
|
||||
}
|
||||
|
||||
// dispatchGroups runs message handlers registered via RouterScope.Group().
|
||||
// It collects all matching groups, sorts by group number, and applies
|
||||
// first-match-wins semantics within each group. Handlers may return
|
||||
// ErrContinueGroups (skip to next handler/group) or ErrEndGroups (stop all groups).
|
||||
// A non-sentinel error stops dispatch and is returned to the caller.
|
||||
func (r *Router) dispatchGroups(c *Context, m *api.Message) error {
|
||||
// Collect group numbers present.
|
||||
groupSet := map[int]struct{}{}
|
||||
for _, gr := range r.groupCommands {
|
||||
groupSet[gr.group] = struct{}{}
|
||||
}
|
||||
for _, gr := range r.groupTexts {
|
||||
groupSet[gr.group] = struct{}{}
|
||||
}
|
||||
for _, gr := range r.groupMessageFilters {
|
||||
groupSet[gr.group] = struct{}{}
|
||||
}
|
||||
if len(groupSet) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
groups := make([]int, 0, len(groupSet))
|
||||
for g := range groupSet {
|
||||
groups = append(groups, g)
|
||||
}
|
||||
sort.Ints(groups)
|
||||
|
||||
for _, g := range groups {
|
||||
matched, err := r.runGroupHandlers(c, m, g)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrEndGroups) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if matched {
|
||||
// First-match-wins: stop further groups.
|
||||
return nil
|
||||
}
|
||||
// No match or ErrContinueGroups from all handlers: try next group.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runGroupHandlers runs all handlers in group g against m, in registration
|
||||
// order. Returns (true, nil) when a handler matched (returned nil). Returns
|
||||
// (false, nil) when all handlers returned ErrContinueGroups. Returns
|
||||
// (false, err) for ErrEndGroups or any non-sentinel error.
|
||||
func (r *Router) runGroupHandlers(c *Context, m *api.Message, g int) (matched bool, err error) {
|
||||
// Commands.
|
||||
if cmd, args, ok := extractCommand(m); ok {
|
||||
for _, route := range r.groupCommands {
|
||||
if route.group != g || route.cmd != cmd {
|
||||
continue
|
||||
}
|
||||
c.Values["command"] = cmd
|
||||
c.Values["command_args"] = args
|
||||
if err := route.handler(c, m); err != nil {
|
||||
if errors.Is(err, ErrContinueGroups) {
|
||||
continue
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
// Text regex.
|
||||
if m.Text != "" {
|
||||
for _, route := range r.groupTexts {
|
||||
if route.group != g {
|
||||
continue
|
||||
}
|
||||
subs := route.re.FindStringSubmatch(m.Text)
|
||||
if subs == nil {
|
||||
continue
|
||||
}
|
||||
c.Values["regex_match"] = subs
|
||||
if err := route.handler(c, m); err != nil {
|
||||
if errors.Is(err, ErrContinueGroups) {
|
||||
continue
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
// Filter-based.
|
||||
for _, route := range r.groupMessageFilters {
|
||||
if route.group != g || !route.filter(m) {
|
||||
continue
|
||||
}
|
||||
if err := route.handler(c, m); err != nil {
|
||||
if errors.Is(err, ErrContinueGroups) {
|
||||
continue
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package dispatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/client"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// msgUpdate builds a simple private message update.
|
||||
func msgUpdate(id int64, text string) api.Update {
|
||||
return api.Update{
|
||||
UpdateID: id,
|
||||
Message: &api.Message{
|
||||
MessageID: id,
|
||||
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||
Text: text,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// cmdUpdate builds a command message update.
|
||||
func cmdUpdate(id int64, cmd string) api.Update {
|
||||
return api.Update{
|
||||
UpdateID: id,
|
||||
Message: &api.Message{
|
||||
MessageID: id,
|
||||
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||
Text: cmd,
|
||||
Entities: []api.MessageEntity{
|
||||
{Type: string(api.EntityBotCommand), Offset: 0, Length: int64(len(cmd))},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// runSingle fires one update through the router and waits for it to complete.
|
||||
func runSingle(t *testing.T, r *Router, up api.Update) {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = r.Run(ctx, newFake(up))
|
||||
}
|
||||
|
||||
// TestGroup_Order verifies group 0 fires before group 1.
|
||||
func TestGroup_Order(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
var order []int
|
||||
|
||||
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||
order = append(order, 0)
|
||||
return ErrContinueGroups // let group 1 also run
|
||||
})
|
||||
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||
order = append(order, 1)
|
||||
return nil
|
||||
})
|
||||
|
||||
runSingle(t, r, msgUpdate(1, "hello"))
|
||||
require.Equal(t, []int{0, 1}, order)
|
||||
}
|
||||
|
||||
// TestGroup_FirstMatchWins verifies group 0 match stops group 1 by default.
|
||||
func TestGroup_FirstMatchWins(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
var fired []int
|
||||
|
||||
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||
fired = append(fired, 0)
|
||||
return nil // matched — group 1 must NOT run
|
||||
})
|
||||
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||
fired = append(fired, 1)
|
||||
return nil
|
||||
})
|
||||
|
||||
runSingle(t, r, msgUpdate(1, "hello"))
|
||||
require.Equal(t, []int{0}, fired)
|
||||
}
|
||||
|
||||
// TestGroup_ErrContinueGroups lets group 1 run when group 0 returns ErrContinueGroups.
|
||||
func TestGroup_ErrContinueGroups(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
g1Hit := make(chan struct{}, 1)
|
||||
|
||||
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||
return ErrContinueGroups
|
||||
})
|
||||
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||
g1Hit <- struct{}{}
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(msgUpdate(1, "ping"))) }()
|
||||
|
||||
select {
|
||||
case <-g1Hit:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("group 1 handler never fired")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGroup_ErrEndGroups stops all further groups.
|
||||
func TestGroup_ErrEndGroups(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
var fired []int
|
||||
|
||||
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||
fired = append(fired, 0)
|
||||
return ErrEndGroups
|
||||
})
|
||||
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||
fired = append(fired, 1)
|
||||
return nil
|
||||
})
|
||||
|
||||
runSingle(t, r, msgUpdate(1, "hello"))
|
||||
require.Equal(t, []int{0}, fired)
|
||||
}
|
||||
|
||||
// TestGroup_NonSentinelError propagates error and stops further groups.
|
||||
func TestGroup_NonSentinelError(t *testing.T) {
|
||||
r := New(client.New("t"), WithMaxConcurrency(0))
|
||||
var fired []int
|
||||
|
||||
r.Group(0).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||
fired = append(fired, 0)
|
||||
return context.DeadlineExceeded // non-sentinel real error
|
||||
})
|
||||
r.Group(1).OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||
fired = append(fired, 1)
|
||||
return nil
|
||||
})
|
||||
|
||||
runSingle(t, r, msgUpdate(1, "hello"))
|
||||
// group 1 must not fire
|
||||
require.Equal(t, []int{0}, fired)
|
||||
}
|
||||
|
||||
// TestGroup_Command verifies OnCommand in a group works.
|
||||
func TestGroup_Command(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
|
||||
r.Group(0).OnCommand("/start", func(c *Context, m *api.Message) error {
|
||||
hit <- "g0-start"
|
||||
return nil
|
||||
})
|
||||
r.Group(1).OnCommand("/start", func(c *Context, m *api.Message) error {
|
||||
hit <- "g1-start"
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(cmdUpdate(1, "/start"))) }()
|
||||
|
||||
got := <-hit
|
||||
require.Equal(t, "g0-start", got)
|
||||
}
|
||||
|
||||
// TestGroup_MessageFilter verifies OnMessageFilter in a group works.
|
||||
func TestGroup_MessageFilter(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan bool, 1)
|
||||
|
||||
r.Group(0).OnMessageFilter(
|
||||
Filter[*api.Message](func(m *api.Message) bool { return m != nil && m.Text == "ok" }),
|
||||
func(c *Context, m *api.Message) error {
|
||||
hit <- true
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(msgUpdate(1, "ok"))) }()
|
||||
|
||||
require.True(t, <-hit)
|
||||
}
|
||||
|
||||
// TestGroup_ErrContinueGroups_WithCommand verifies ErrContinueGroups works for commands across groups.
|
||||
func TestGroup_ErrContinueGroups_WithCommand(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
var count atomic.Int32
|
||||
|
||||
r.Group(0).OnCommand("/ping", func(c *Context, m *api.Message) error {
|
||||
count.Add(1)
|
||||
return ErrContinueGroups
|
||||
})
|
||||
r.Group(1).OnCommand("/ping", func(c *Context, m *api.Message) error {
|
||||
count.Add(10)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(cmdUpdate(1, "/ping"))) }()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
require.Equal(t, int32(11), count.Load())
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package dispatch
|
||||
|
||||
// Handler is a generic handler over update payload type T. T is typically
|
||||
// *api.Message, *api.CallbackQuery, *api.InlineQuery, or *api.Update for
|
||||
// global middleware.
|
||||
type Handler[T any] func(ctx *Context, payload T) error
|
||||
|
||||
// Middleware wraps a Handler[T] with cross-cutting behaviour (logging,
|
||||
// recovery, auth). Middleware composition is left-to-right: Use(a,b,c)
|
||||
// runs as a(b(c(handler))).
|
||||
type Middleware[T any] func(Handler[T]) Handler[T]
|
||||
|
||||
// Chain composes a slice of middleware into a single Middleware[T].
|
||||
func Chain[T any](mws ...Middleware[T]) Middleware[T] {
|
||||
return func(h Handler[T]) Handler[T] {
|
||||
for i := len(mws) - 1; i >= 0; i-- {
|
||||
h = mws[i](h)
|
||||
}
|
||||
return h
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package dispatch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
)
|
||||
|
||||
// Recovery returns middleware that recovers from panics in downstream
|
||||
// handlers, converting them into a returned error and logging via the
|
||||
// bot's configured logger. Registered automatically by NewRouter.
|
||||
func Recovery() Middleware[*api.Update] {
|
||||
return func(next Handler[*api.Update]) Handler[*api.Update] {
|
||||
return func(c *Context, u *api.Update) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in handler: %v\n%s", r, debug.Stack())
|
||||
if c.Bot != nil {
|
||||
c.Bot.Logger().Error("dispatch recovered panic", "err", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return next(c, u)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package dispatch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// NamedHandlers manages handlers by string name, allowing runtime
|
||||
// registration, replacement, and removal. This complements the Router's
|
||||
// registration methods: each registration via Named*() also gets a name
|
||||
// for later lookup.
|
||||
//
|
||||
// Use case: a plugin system that loads/unloads command handlers without
|
||||
// restarting the bot.
|
||||
type NamedHandlers[T any] struct {
|
||||
mu sync.RWMutex
|
||||
handlers map[string]Handler[T]
|
||||
order []string // preserves registration order
|
||||
}
|
||||
|
||||
// NewNamedHandlers returns a new, empty NamedHandlers[T].
|
||||
func NewNamedHandlers[T any]() *NamedHandlers[T] {
|
||||
return &NamedHandlers[T]{handlers: map[string]Handler[T]{}}
|
||||
}
|
||||
|
||||
// Set registers or replaces the handler under name. If name is new, it is
|
||||
// appended to the end of the registration order.
|
||||
func (n *NamedHandlers[T]) Set(name string, h Handler[T]) {
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
if _, exists := n.handlers[name]; !exists {
|
||||
n.order = append(n.order, name)
|
||||
}
|
||||
n.handlers[name] = h
|
||||
}
|
||||
|
||||
// Remove unregisters the handler under name. Returns true if it existed.
|
||||
func (n *NamedHandlers[T]) Remove(name string) bool {
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
if _, ok := n.handlers[name]; !ok {
|
||||
return false
|
||||
}
|
||||
delete(n.handlers, name)
|
||||
for i, k := range n.order {
|
||||
if k == name {
|
||||
n.order = append(n.order[:i], n.order[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Has reports whether name is registered.
|
||||
func (n *NamedHandlers[T]) Has(name string) bool {
|
||||
n.mu.RLock()
|
||||
defer n.mu.RUnlock()
|
||||
_, ok := n.handlers[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Names returns the registered names in registration order.
|
||||
func (n *NamedHandlers[T]) Names() []string {
|
||||
n.mu.RLock()
|
||||
defer n.mu.RUnlock()
|
||||
out := make([]string, len(n.order))
|
||||
copy(out, n.order)
|
||||
return out
|
||||
}
|
||||
|
||||
// Handler returns a single Handler[T] that runs each registered handler
|
||||
// in registration order, first non-nil error stops the chain. Use this
|
||||
// to wire NamedHandlers into a Router.OnXxx call:
|
||||
//
|
||||
// names := dispatch.NewNamedHandlers[*api.Message]()
|
||||
// names.Set("logger", loggingHandler)
|
||||
// names.Set("audit", auditHandler)
|
||||
// router.OnCommand("/admin", names.Handler())
|
||||
//
|
||||
// Subsequent Set/Remove calls take effect on the next dispatch.
|
||||
func (n *NamedHandlers[T]) Handler() Handler[T] {
|
||||
return func(c *Context, payload T) error {
|
||||
n.mu.RLock()
|
||||
names := make([]string, len(n.order))
|
||||
copy(names, n.order)
|
||||
n.mu.RUnlock()
|
||||
|
||||
for _, name := range names {
|
||||
n.mu.RLock()
|
||||
h, ok := n.handlers[name]
|
||||
n.mu.RUnlock()
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := h(c, payload); err != nil {
|
||||
return fmt.Errorf("named handler %q: %w", name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
package dispatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// makeMsg returns a minimal *api.Message for use in handler tests.
|
||||
func makeMsg() *api.Message {
|
||||
return &api.Message{MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}}
|
||||
}
|
||||
|
||||
// makeCtx returns a minimal *Context (nil bot is fine for unit tests).
|
||||
func makeCtx() *Context {
|
||||
return NewContext(context.Background(), nil, &api.Update{})
|
||||
}
|
||||
|
||||
func TestNamedHandlers_SetAndHas(t *testing.T) {
|
||||
n := NewNamedHandlers[*api.Message]()
|
||||
require.False(t, n.Has("a"))
|
||||
n.Set("a", func(c *Context, m *api.Message) error { return nil })
|
||||
require.True(t, n.Has("a"))
|
||||
}
|
||||
|
||||
func TestNamedHandlers_Names_RegistrationOrder(t *testing.T) {
|
||||
n := NewNamedHandlers[*api.Message]()
|
||||
n.Set("first", func(c *Context, m *api.Message) error { return nil })
|
||||
n.Set("second", func(c *Context, m *api.Message) error { return nil })
|
||||
n.Set("third", func(c *Context, m *api.Message) error { return nil })
|
||||
require.Equal(t, []string{"first", "second", "third"}, n.Names())
|
||||
}
|
||||
|
||||
func TestNamedHandlers_Remove(t *testing.T) {
|
||||
n := NewNamedHandlers[*api.Message]()
|
||||
n.Set("a", func(c *Context, m *api.Message) error { return nil })
|
||||
n.Set("b", func(c *Context, m *api.Message) error { return nil })
|
||||
|
||||
removed := n.Remove("a")
|
||||
require.True(t, removed)
|
||||
require.False(t, n.Has("a"))
|
||||
require.Equal(t, []string{"b"}, n.Names())
|
||||
|
||||
// Remove non-existent returns false.
|
||||
require.False(t, n.Remove("nonexistent"))
|
||||
}
|
||||
|
||||
func TestNamedHandlers_Replacement_SameOrderSlot(t *testing.T) {
|
||||
n := NewNamedHandlers[*api.Message]()
|
||||
n.Set("a", func(c *Context, m *api.Message) error { return nil })
|
||||
n.Set("b", func(c *Context, m *api.Message) error { return nil })
|
||||
|
||||
var called string
|
||||
n.Set("a", func(c *Context, m *api.Message) error {
|
||||
called = "replaced-a"
|
||||
return nil
|
||||
})
|
||||
|
||||
// Order must not change; "a" stays first.
|
||||
require.Equal(t, []string{"a", "b"}, n.Names())
|
||||
|
||||
h := n.Handler()
|
||||
_ = h(makeCtx(), makeMsg())
|
||||
require.Equal(t, "replaced-a", called)
|
||||
}
|
||||
|
||||
func TestNamedHandlers_Handler_RunsInOrder(t *testing.T) {
|
||||
n := NewNamedHandlers[*api.Message]()
|
||||
var calls []string
|
||||
|
||||
n.Set("first", func(c *Context, m *api.Message) error {
|
||||
calls = append(calls, "first")
|
||||
return nil
|
||||
})
|
||||
n.Set("second", func(c *Context, m *api.Message) error {
|
||||
calls = append(calls, "second")
|
||||
return nil
|
||||
})
|
||||
|
||||
h := n.Handler()
|
||||
require.NoError(t, h(makeCtx(), makeMsg()))
|
||||
require.Equal(t, []string{"first", "second"}, calls)
|
||||
}
|
||||
|
||||
func TestNamedHandlers_Handler_ErrorWrappedAndStops(t *testing.T) {
|
||||
n := NewNamedHandlers[*api.Message]()
|
||||
sentinel := errors.New("boom")
|
||||
|
||||
n.Set("ok", func(c *Context, m *api.Message) error { return nil })
|
||||
n.Set("fail", func(c *Context, m *api.Message) error { return sentinel })
|
||||
n.Set("never", func(c *Context, m *api.Message) error {
|
||||
t.Fatal("should not be called after an error")
|
||||
return nil
|
||||
})
|
||||
|
||||
h := n.Handler()
|
||||
err := h(makeCtx(), makeMsg())
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, sentinel))
|
||||
require.Contains(t, err.Error(), `named handler "fail"`)
|
||||
}
|
||||
|
||||
func TestNamedHandlers_Concurrent_SetRemove(t *testing.T) {
|
||||
n := NewNamedHandlers[*api.Message]()
|
||||
|
||||
// Pre-populate so Handler() has something to iterate.
|
||||
for i := range 5 {
|
||||
name := fmt.Sprintf("h%d", i)
|
||||
n.Set(name, func(c *Context, m *api.Message) error { return nil })
|
||||
}
|
||||
|
||||
h := n.Handler()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent readers (invoke handler).
|
||||
for range 20 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = h(makeCtx(), makeMsg())
|
||||
}()
|
||||
}
|
||||
|
||||
// Concurrent writers.
|
||||
for i := range 5 {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
name := fmt.Sprintf("new%d", i)
|
||||
n.Set(name, func(c *Context, m *api.Message) error { return nil })
|
||||
n.Remove(fmt.Sprintf("h%d", i))
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestNamedHandlers_RemoveAndReinstate(t *testing.T) {
|
||||
n := NewNamedHandlers[*api.Message]()
|
||||
n.Set("a", func(c *Context, m *api.Message) error { return nil })
|
||||
n.Remove("a")
|
||||
require.False(t, n.Has("a"))
|
||||
|
||||
// Re-register after removal; should be added at end.
|
||||
n.Set("b", func(c *Context, m *api.Message) error { return nil })
|
||||
n.Set("a", func(c *Context, m *api.Message) error { return nil })
|
||||
require.Equal(t, []string{"b", "a"}, n.Names())
|
||||
}
|
||||
@@ -0,0 +1,582 @@
|
||||
package dispatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/client"
|
||||
"github.com/lukaszraczylo/go-telegram/transport"
|
||||
)
|
||||
|
||||
// Router dispatches updates from any Updater to typed handlers.
|
||||
//
|
||||
// Matchers run in registration order; first match wins. A panic-recovery
|
||||
// middleware is attached automatically and runs around every dispatch.
|
||||
type Router struct {
|
||||
bot *client.Bot
|
||||
|
||||
commands []commandRoute
|
||||
texts []textRoute
|
||||
callbacks []callbackRoute
|
||||
inlines []Handler[*api.InlineQuery]
|
||||
editedMsg []Handler[*api.Message]
|
||||
channelPosts []Handler[*api.Message]
|
||||
editedChannelPosts []Handler[*api.Message]
|
||||
|
||||
messageFilters []messageFilterRoute
|
||||
callbackFilters []callbackFilterRoute
|
||||
inlineFilters []inlineFilterRoute
|
||||
|
||||
// typed update handlers
|
||||
myChatMember []Handler[*api.ChatMemberUpdated]
|
||||
chatMember []Handler[*api.ChatMemberUpdated]
|
||||
chatJoinRequest []Handler[*api.ChatJoinRequest]
|
||||
preCheckoutQuery []Handler[*api.PreCheckoutQuery]
|
||||
shippingQuery []Handler[*api.ShippingQuery]
|
||||
poll []Handler[*api.Poll]
|
||||
pollAnswer []Handler[*api.PollAnswer]
|
||||
chosenInlineResult []Handler[*api.ChosenInlineResult]
|
||||
messageReaction []Handler[*api.MessageReactionUpdated]
|
||||
messageReactionCnt []Handler[*api.MessageReactionCountUpdated]
|
||||
chatBoost []Handler[*api.ChatBoostUpdated]
|
||||
removedChatBoost []Handler[*api.ChatBoostRemoved]
|
||||
businessConn []Handler[*api.BusinessConnection]
|
||||
purchasedPaidMedia []Handler[*api.PaidMediaPurchased]
|
||||
|
||||
myChatMemberFilters []chatMemberFilterRoute
|
||||
chatMemberFilters []chatMemberFilterRoute
|
||||
chatJoinRequestFilters []chatJoinRequestFilterRoute
|
||||
preCheckoutFilters []preCheckoutFilterRoute
|
||||
|
||||
// group-priority routes (registered via Router.Group())
|
||||
groupCommands []groupCommandRoute
|
||||
groupTexts []groupTextRoute
|
||||
groupMessageFilters []groupMessageFilterRoute
|
||||
|
||||
globalMW []Middleware[*api.Update]
|
||||
|
||||
maxConcurrency int // default 50; 0 = serial (legacy)
|
||||
sem chan struct{}
|
||||
}
|
||||
|
||||
type messageFilterRoute struct {
|
||||
filter Filter[*api.Message]
|
||||
handler Handler[*api.Message]
|
||||
}
|
||||
|
||||
type callbackFilterRoute struct {
|
||||
filter Filter[*api.CallbackQuery]
|
||||
handler Handler[*api.CallbackQuery]
|
||||
}
|
||||
|
||||
type inlineFilterRoute struct {
|
||||
filter Filter[*api.InlineQuery]
|
||||
handler Handler[*api.InlineQuery]
|
||||
}
|
||||
|
||||
type chatMemberFilterRoute struct {
|
||||
filter Filter[*api.ChatMemberUpdated]
|
||||
handler Handler[*api.ChatMemberUpdated]
|
||||
}
|
||||
|
||||
type chatJoinRequestFilterRoute struct {
|
||||
filter Filter[*api.ChatJoinRequest]
|
||||
handler Handler[*api.ChatJoinRequest]
|
||||
}
|
||||
|
||||
type preCheckoutFilterRoute struct {
|
||||
filter Filter[*api.PreCheckoutQuery]
|
||||
handler Handler[*api.PreCheckoutQuery]
|
||||
}
|
||||
|
||||
// RouterOption configures a Router at construction time.
|
||||
type RouterOption func(*Router)
|
||||
|
||||
// WithMaxConcurrency sets the maximum number of updates processed in parallel.
|
||||
// Default is 50. Pass 0 to dispatch serially (one update at a time, in the
|
||||
// calling goroutine — the legacy behaviour before v1.1.0).
|
||||
//
|
||||
// Note: concurrent dispatch means handlers for different updates may run
|
||||
// simultaneously. Handlers that mutate shared state must be safe for concurrent
|
||||
// access.
|
||||
func WithMaxConcurrency(n int) RouterOption {
|
||||
return func(r *Router) { r.maxConcurrency = n }
|
||||
}
|
||||
|
||||
type commandRoute struct {
|
||||
cmd string
|
||||
handler Handler[*api.Message]
|
||||
}
|
||||
|
||||
type textRoute struct {
|
||||
re *regexp.Regexp
|
||||
handler Handler[*api.Message]
|
||||
}
|
||||
|
||||
type callbackRoute struct {
|
||||
re *regexp.Regexp
|
||||
handler Handler[*api.CallbackQuery]
|
||||
}
|
||||
|
||||
// New constructs a Router. Recovery middleware is added by default; users
|
||||
// can disable it by passing WithoutRecovery (not implemented here, but
|
||||
// the hook is in place via Use).
|
||||
func New(b *client.Bot, opts ...RouterOption) *Router {
|
||||
r := &Router{bot: b, maxConcurrency: 50}
|
||||
for _, o := range opts {
|
||||
o(r)
|
||||
}
|
||||
if r.maxConcurrency > 0 {
|
||||
r.sem = make(chan struct{}, r.maxConcurrency)
|
||||
}
|
||||
r.Use(Recovery())
|
||||
return r
|
||||
}
|
||||
|
||||
// Use registers a global middleware applied to every Update dispatch.
|
||||
func (r *Router) Use(mw Middleware[*api.Update]) { r.globalMW = append(r.globalMW, mw) }
|
||||
|
||||
// OnCommand registers a handler for a slash command. The command string
|
||||
// includes the leading slash (e.g. "/start"). Matching strips an optional
|
||||
// "@BotName" suffix.
|
||||
func (r *Router) OnCommand(cmd string, h Handler[*api.Message]) {
|
||||
r.commands = append(r.commands, commandRoute{cmd: cmd, handler: h})
|
||||
}
|
||||
|
||||
// OnText registers a handler for messages whose Text matches the regex.
|
||||
//
|
||||
// Panics at registration time if pattern is not a valid regular expression.
|
||||
func (r *Router) OnText(pattern string, h Handler[*api.Message]) {
|
||||
r.texts = append(r.texts, textRoute{re: regexp.MustCompile(pattern), handler: h})
|
||||
}
|
||||
|
||||
// OnCallback registers a handler for callback queries whose Data matches
|
||||
// the regex.
|
||||
//
|
||||
// Panics at registration time if pattern is not a valid regular expression.
|
||||
func (r *Router) OnCallback(pattern string, h Handler[*api.CallbackQuery]) {
|
||||
r.callbacks = append(r.callbacks, callbackRoute{re: regexp.MustCompile(pattern), handler: h})
|
||||
}
|
||||
|
||||
// OnInlineQuery registers a handler for inline queries (one matcher only;
|
||||
// inline queries are not partitioned by content here).
|
||||
func (r *Router) OnInlineQuery(h Handler[*api.InlineQuery]) {
|
||||
r.inlines = append(r.inlines, h)
|
||||
}
|
||||
|
||||
// OnEditedMessage registers a handler for edited message updates.
|
||||
func (r *Router) OnEditedMessage(h Handler[*api.Message]) {
|
||||
r.editedMsg = append(r.editedMsg, h)
|
||||
}
|
||||
|
||||
// OnChannelPost registers a handler for channel post updates.
|
||||
func (r *Router) OnChannelPost(h Handler[*api.Message]) {
|
||||
r.channelPosts = append(r.channelPosts, h)
|
||||
}
|
||||
|
||||
// OnEditedChannelPost registers a handler for edited channel post updates.
|
||||
func (r *Router) OnEditedChannelPost(h Handler[*api.Message]) {
|
||||
r.editedChannelPosts = append(r.editedChannelPosts, h)
|
||||
}
|
||||
|
||||
// OnMessageFilter registers a typed message handler gated by filter f.
|
||||
// Filter routes are checked after command and text routes; first match wins.
|
||||
func (r *Router) OnMessageFilter(f Filter[*api.Message], h Handler[*api.Message]) {
|
||||
r.messageFilters = append(r.messageFilters, messageFilterRoute{filter: f, handler: h})
|
||||
}
|
||||
|
||||
// OnCallbackFilter registers a typed callback-query handler gated by filter f.
|
||||
// Filter routes are checked after pattern-based OnCallback routes; first match wins.
|
||||
func (r *Router) OnCallbackFilter(f Filter[*api.CallbackQuery], h Handler[*api.CallbackQuery]) {
|
||||
r.callbackFilters = append(r.callbackFilters, callbackFilterRoute{filter: f, handler: h})
|
||||
}
|
||||
|
||||
// OnInlineQueryFilter registers an inline-query handler gated by filter f.
|
||||
// Filter routes are checked after bare OnInlineQuery handlers; first match wins.
|
||||
func (r *Router) OnInlineQueryFilter(f Filter[*api.InlineQuery], h Handler[*api.InlineQuery]) {
|
||||
r.inlineFilters = append(r.inlineFilters, inlineFilterRoute{filter: f, handler: h})
|
||||
}
|
||||
|
||||
// OnMyChatMember registers a handler for bot's own chat member status changes.
|
||||
func (r *Router) OnMyChatMember(h Handler[*api.ChatMemberUpdated]) {
|
||||
r.myChatMember = append(r.myChatMember, h)
|
||||
}
|
||||
|
||||
// OnMyChatMemberFilter registers a filtered handler for bot's own chat member status changes.
|
||||
func (r *Router) OnMyChatMemberFilter(f Filter[*api.ChatMemberUpdated], h Handler[*api.ChatMemberUpdated]) {
|
||||
r.myChatMemberFilters = append(r.myChatMemberFilters, chatMemberFilterRoute{filter: f, handler: h})
|
||||
}
|
||||
|
||||
// OnChatMember registers a handler for chat member status changes.
|
||||
func (r *Router) OnChatMember(h Handler[*api.ChatMemberUpdated]) {
|
||||
r.chatMember = append(r.chatMember, h)
|
||||
}
|
||||
|
||||
// OnChatMemberFilter registers a filtered handler for chat member status changes.
|
||||
func (r *Router) OnChatMemberFilter(f Filter[*api.ChatMemberUpdated], h Handler[*api.ChatMemberUpdated]) {
|
||||
r.chatMemberFilters = append(r.chatMemberFilters, chatMemberFilterRoute{filter: f, handler: h})
|
||||
}
|
||||
|
||||
// OnChatJoinRequest registers a handler for chat join requests.
|
||||
func (r *Router) OnChatJoinRequest(h Handler[*api.ChatJoinRequest]) {
|
||||
r.chatJoinRequest = append(r.chatJoinRequest, h)
|
||||
}
|
||||
|
||||
// OnChatJoinRequestFilter registers a filtered handler for chat join requests.
|
||||
func (r *Router) OnChatJoinRequestFilter(f Filter[*api.ChatJoinRequest], h Handler[*api.ChatJoinRequest]) {
|
||||
r.chatJoinRequestFilters = append(r.chatJoinRequestFilters, chatJoinRequestFilterRoute{filter: f, handler: h})
|
||||
}
|
||||
|
||||
// OnPreCheckoutQuery registers a handler for pre-checkout queries.
|
||||
func (r *Router) OnPreCheckoutQuery(h Handler[*api.PreCheckoutQuery]) {
|
||||
r.preCheckoutQuery = append(r.preCheckoutQuery, h)
|
||||
}
|
||||
|
||||
// OnPreCheckoutQueryFilter registers a filtered handler for pre-checkout queries.
|
||||
func (r *Router) OnPreCheckoutQueryFilter(f Filter[*api.PreCheckoutQuery], h Handler[*api.PreCheckoutQuery]) {
|
||||
r.preCheckoutFilters = append(r.preCheckoutFilters, preCheckoutFilterRoute{filter: f, handler: h})
|
||||
}
|
||||
|
||||
// OnShippingQuery registers a handler for shipping queries.
|
||||
func (r *Router) OnShippingQuery(h Handler[*api.ShippingQuery]) {
|
||||
r.shippingQuery = append(r.shippingQuery, h)
|
||||
}
|
||||
|
||||
// OnPoll registers a handler for poll state updates.
|
||||
func (r *Router) OnPoll(h Handler[*api.Poll]) {
|
||||
r.poll = append(r.poll, h)
|
||||
}
|
||||
|
||||
// OnPollAnswer registers a handler for poll answer updates.
|
||||
func (r *Router) OnPollAnswer(h Handler[*api.PollAnswer]) {
|
||||
r.pollAnswer = append(r.pollAnswer, h)
|
||||
}
|
||||
|
||||
// OnChosenInlineResult registers a handler for chosen inline results.
|
||||
func (r *Router) OnChosenInlineResult(h Handler[*api.ChosenInlineResult]) {
|
||||
r.chosenInlineResult = append(r.chosenInlineResult, h)
|
||||
}
|
||||
|
||||
// OnMessageReaction registers a handler for message reaction updates.
|
||||
func (r *Router) OnMessageReaction(h Handler[*api.MessageReactionUpdated]) {
|
||||
r.messageReaction = append(r.messageReaction, h)
|
||||
}
|
||||
|
||||
// OnMessageReactionCount registers a handler for anonymous message reaction count updates.
|
||||
func (r *Router) OnMessageReactionCount(h Handler[*api.MessageReactionCountUpdated]) {
|
||||
r.messageReactionCnt = append(r.messageReactionCnt, h)
|
||||
}
|
||||
|
||||
// OnChatBoost registers a handler for chat boost updates.
|
||||
func (r *Router) OnChatBoost(h Handler[*api.ChatBoostUpdated]) {
|
||||
r.chatBoost = append(r.chatBoost, h)
|
||||
}
|
||||
|
||||
// OnRemovedChatBoost registers a handler for removed chat boost updates.
|
||||
func (r *Router) OnRemovedChatBoost(h Handler[*api.ChatBoostRemoved]) {
|
||||
r.removedChatBoost = append(r.removedChatBoost, h)
|
||||
}
|
||||
|
||||
// OnBusinessConnection registers a handler for business connection updates.
|
||||
func (r *Router) OnBusinessConnection(h Handler[*api.BusinessConnection]) {
|
||||
r.businessConn = append(r.businessConn, h)
|
||||
}
|
||||
|
||||
// OnPurchasedPaidMedia registers a handler for purchased paid media updates.
|
||||
func (r *Router) OnPurchasedPaidMedia(h Handler[*api.PaidMediaPurchased]) {
|
||||
r.purchasedPaidMedia = append(r.purchasedPaidMedia, h)
|
||||
}
|
||||
|
||||
// Run consumes the Updater and dispatches each update. It blocks until
|
||||
// the Updater's channel is closed or ctx is cancelled.
|
||||
//
|
||||
// By default updates are processed concurrently (up to WithMaxConcurrency(50)
|
||||
// goroutines). Handlers for different updates may therefore run simultaneously;
|
||||
// shared state must be protected. Pass WithMaxConcurrency(0) to New to restore
|
||||
// serial (legacy) behaviour.
|
||||
//
|
||||
// Run waits for all in-flight handlers to finish before returning.
|
||||
func (r *Router) Run(ctx context.Context, u transport.Updater) error {
|
||||
runErr := make(chan error, 1)
|
||||
go func() { runErr <- u.Run(ctx) }()
|
||||
|
||||
root := r.dispatch
|
||||
for i := len(r.globalMW) - 1; i >= 0; i-- {
|
||||
root = r.globalMW[i](root)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
dispatch := func(up api.Update) {
|
||||
c := NewContext(ctx, r.bot, &up)
|
||||
if err := root(c, &up); err != nil {
|
||||
if r.bot != nil {
|
||||
r.bot.Logger().Error("dispatch handler error", "err", err, "update_id", up.UpdateID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-runErr:
|
||||
return err
|
||||
case up, ok := <-u.Updates():
|
||||
if !ok {
|
||||
// Channel closed; consume the run error if pending.
|
||||
select {
|
||||
case err := <-runErr:
|
||||
return err
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.sem == nil {
|
||||
// Serial mode (legacy / WithMaxConcurrency(0)).
|
||||
dispatch(up)
|
||||
continue
|
||||
}
|
||||
|
||||
// Concurrent mode: acquire semaphore slot then launch goroutine.
|
||||
select {
|
||||
case r.sem <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(up api.Update) {
|
||||
defer func() {
|
||||
<-r.sem
|
||||
wg.Done()
|
||||
}()
|
||||
dispatch(up)
|
||||
}(up)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Router) dispatch(c *Context, u *api.Update) error {
|
||||
switch {
|
||||
case u.Message != nil:
|
||||
return r.handleMessage(c, u.Message)
|
||||
case u.EditedMessage != nil:
|
||||
return runHandlers(r.editedMsg, c, u.EditedMessage)
|
||||
case u.ChannelPost != nil:
|
||||
return runHandlers(r.channelPosts, c, u.ChannelPost)
|
||||
case u.EditedChannelPost != nil:
|
||||
return runHandlers(r.editedChannelPosts, c, u.EditedChannelPost)
|
||||
case u.CallbackQuery != nil:
|
||||
return r.handleCallback(c, u.CallbackQuery)
|
||||
case u.InlineQuery != nil:
|
||||
if err := runHandlers(r.inlines, c, u.InlineQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, route := range r.inlineFilters {
|
||||
if route.filter(u.InlineQuery) {
|
||||
return route.handler(c, u.InlineQuery)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case u.MyChatMember != nil:
|
||||
return r.handleChatMemberUpdate(c, u.MyChatMember, r.myChatMember, r.myChatMemberFilters)
|
||||
case u.ChatMember != nil:
|
||||
return r.handleChatMemberUpdate(c, u.ChatMember, r.chatMember, r.chatMemberFilters)
|
||||
case u.ChatJoinRequest != nil:
|
||||
return r.handleChatJoinRequest(c, u.ChatJoinRequest)
|
||||
case u.PreCheckoutQuery != nil:
|
||||
return r.handlePreCheckoutQuery(c, u.PreCheckoutQuery)
|
||||
case u.ShippingQuery != nil:
|
||||
return runHandlers(r.shippingQuery, c, u.ShippingQuery)
|
||||
case u.Poll != nil:
|
||||
return runHandlers(r.poll, c, u.Poll)
|
||||
case u.PollAnswer != nil:
|
||||
return runHandlers(r.pollAnswer, c, u.PollAnswer)
|
||||
case u.ChosenInlineResult != nil:
|
||||
return runHandlers(r.chosenInlineResult, c, u.ChosenInlineResult)
|
||||
case u.MessageReaction != nil:
|
||||
return runHandlers(r.messageReaction, c, u.MessageReaction)
|
||||
case u.MessageReactionCount != nil:
|
||||
return runHandlers(r.messageReactionCnt, c, u.MessageReactionCount)
|
||||
case u.ChatBoost != nil:
|
||||
return runHandlers(r.chatBoost, c, u.ChatBoost)
|
||||
case u.RemovedChatBoost != nil:
|
||||
return runHandlers(r.removedChatBoost, c, u.RemovedChatBoost)
|
||||
case u.BusinessConnection != nil:
|
||||
return runHandlers(r.businessConn, c, u.BusinessConnection)
|
||||
case u.PurchasedPaidMedia != nil:
|
||||
return runHandlers(r.purchasedPaidMedia, c, u.PurchasedPaidMedia)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) handleChatMemberUpdate(c *Context, payload *api.ChatMemberUpdated, handlers []Handler[*api.ChatMemberUpdated], filters []chatMemberFilterRoute) error {
|
||||
if err := runHandlers(handlers, c, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, route := range filters {
|
||||
if route.filter(payload) {
|
||||
return route.handler(c, payload)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) handleChatJoinRequest(c *Context, payload *api.ChatJoinRequest) error {
|
||||
if err := runHandlers(r.chatJoinRequest, c, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, route := range r.chatJoinRequestFilters {
|
||||
if route.filter(payload) {
|
||||
return route.handler(c, payload)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) handlePreCheckoutQuery(c *Context, payload *api.PreCheckoutQuery) error {
|
||||
if err := runHandlers(r.preCheckoutQuery, c, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, route := range r.preCheckoutFilters {
|
||||
if route.filter(payload) {
|
||||
return route.handler(c, payload)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runHandlers invokes each handler in order; returns the first non-nil error.
|
||||
func runHandlers[T any](handlers []Handler[T], c *Context, payload T) error {
|
||||
for _, h := range handlers {
|
||||
if err := h(c, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) handleMessage(c *Context, m *api.Message) error {
|
||||
// Try command first (entity-aware).
|
||||
if cmd, args, ok := extractCommand(m); ok {
|
||||
for _, route := range r.commands {
|
||||
if route.cmd == cmd {
|
||||
c.Values["command"] = cmd
|
||||
c.Values["command_args"] = args
|
||||
return route.handler(c, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Then text regex matchers.
|
||||
if m.Text != "" {
|
||||
for _, route := range r.texts {
|
||||
if subs := route.re.FindStringSubmatch(m.Text); subs != nil {
|
||||
c.Values["regex_match"] = subs
|
||||
return route.handler(c, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Filter-based routes.
|
||||
for _, route := range r.messageFilters {
|
||||
if route.filter(m) {
|
||||
return route.handler(c, m)
|
||||
}
|
||||
}
|
||||
// Group-priority routes (registered via RouterScope.Group()).
|
||||
return r.dispatchGroups(c, m)
|
||||
}
|
||||
|
||||
func (r *Router) handleCallback(c *Context, q *api.CallbackQuery) error {
|
||||
for _, route := range r.callbacks {
|
||||
if subs := route.re.FindStringSubmatch(q.Data); subs != nil {
|
||||
c.Values["regex_match"] = subs
|
||||
return route.handler(c, q)
|
||||
}
|
||||
}
|
||||
// Filter-based routes checked after pattern routes.
|
||||
for _, route := range r.callbackFilters {
|
||||
if route.filter(q) {
|
||||
return route.handler(c, q)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractCommand returns the command (e.g. "/start") and the remaining
|
||||
// argument string, when m carries a leading bot_command entity. It strips
|
||||
// optional "@BotName" suffix on the command itself.
|
||||
func extractCommand(m *api.Message) (cmd, args string, ok bool) {
|
||||
if len(m.Entities) == 0 || m.Text == "" {
|
||||
return "", "", false
|
||||
}
|
||||
first := m.Entities[0]
|
||||
if first.Type != string(api.EntityBotCommand) || first.Offset != 0 {
|
||||
return "", "", false
|
||||
}
|
||||
cmd, sliceOk := utf16Slice(m.Text, int(first.Offset), int(first.Length))
|
||||
if !sliceOk {
|
||||
return "", "", false
|
||||
}
|
||||
if i := strings.Index(cmd, "@"); i >= 0 {
|
||||
cmd = cmd[:i]
|
||||
}
|
||||
end := int(first.Offset) + int(first.Length)
|
||||
rest, _ := utf16Slice(m.Text, end, utf16Len(m.Text)-end)
|
||||
args = strings.TrimSpace(rest)
|
||||
return cmd, args, true
|
||||
}
|
||||
|
||||
// utf16Slice returns the substring of s identified by a UTF-16 offset/length
|
||||
// pair, as Telegram's MessageEntity uses. ok is false if the indices fall
|
||||
// outside s's UTF-16 length.
|
||||
func utf16Slice(s string, offset, length int) (string, bool) {
|
||||
runes := []rune(s)
|
||||
var startBytes, endBytes int
|
||||
var u16 int
|
||||
found := false
|
||||
for i, r := range runes {
|
||||
if u16 == offset {
|
||||
startBytes = byteIndex(runes, i)
|
||||
found = true
|
||||
}
|
||||
if u16 == offset+length {
|
||||
endBytes = byteIndex(runes, i)
|
||||
return s[startBytes:endBytes], true
|
||||
}
|
||||
if r > 0xFFFF {
|
||||
u16 += 2
|
||||
} else {
|
||||
u16++
|
||||
}
|
||||
}
|
||||
if found && u16 == offset+length {
|
||||
return s[startBytes:], true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func byteIndex(runes []rune, runeIdx int) int {
|
||||
n := 0
|
||||
for i := 0; i < runeIdx; i++ {
|
||||
n += utf8.RuneLen(runes[i])
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func utf16Len(s string) int {
|
||||
n := 0
|
||||
for _, r := range s {
|
||||
if r > 0xFFFF {
|
||||
n += 2
|
||||
} else {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -0,0 +1,940 @@
|
||||
package dispatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/client"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeUpdater feeds a fixed slice of updates then closes.
|
||||
type fakeUpdater struct{ ch chan api.Update }
|
||||
|
||||
func newFake(ups ...api.Update) *fakeUpdater {
|
||||
ch := make(chan api.Update, len(ups))
|
||||
for _, u := range ups {
|
||||
ch <- u
|
||||
}
|
||||
close(ch)
|
||||
return &fakeUpdater{ch: ch}
|
||||
}
|
||||
|
||||
func (f *fakeUpdater) Updates() <-chan api.Update { return f.ch }
|
||||
func (f *fakeUpdater) Run(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }
|
||||
func (f *fakeUpdater) Stop(ctx context.Context) error { return nil }
|
||||
|
||||
func cmdMessage(text string) api.Update {
|
||||
return api.Update{
|
||||
UpdateID: 1,
|
||||
Message: &api.Message{
|
||||
MessageID: 1, Date: 0, Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||
Text: text,
|
||||
Entities: []api.MessageEntity{{Type: string(api.EntityBotCommand), Offset: 0, Length: int64(indexEnd(text))}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func indexEnd(text string) int {
|
||||
for i, r := range text {
|
||||
if r == ' ' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return len(text)
|
||||
}
|
||||
|
||||
func TestRouter_OnCommandMatches(t *testing.T) {
|
||||
b := client.New("t")
|
||||
r := New(b)
|
||||
hit := make(chan string, 1)
|
||||
r.OnCommand("/start", func(c *Context, m *api.Message) error {
|
||||
hit <- c.Values["command"].(string)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(cmdMessage("/start"))) }()
|
||||
|
||||
require.Equal(t, "/start", <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnCommandStripsBotName(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
r.OnCommand("/start", func(c *Context, m *api.Message) error {
|
||||
hit <- "matched"
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(cmdMessage("/start@MyBot hello"))) }()
|
||||
|
||||
require.Equal(t, "matched", <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnText(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan []string, 1)
|
||||
r.OnText(`^hello (\w+)$`, func(c *Context, m *api.Message) error {
|
||||
hit <- c.Values["regex_match"].([]string)
|
||||
return nil
|
||||
})
|
||||
|
||||
u := api.Update{UpdateID: 1, Message: &api.Message{
|
||||
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "hello world",
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||
|
||||
subs := <-hit
|
||||
require.Equal(t, "world", subs[1])
|
||||
}
|
||||
|
||||
func TestRouter_OnCallback(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
r.OnCallback(`^like:(\d+)$`, func(c *Context, q *api.CallbackQuery) error {
|
||||
hit <- q.Data
|
||||
return nil
|
||||
})
|
||||
|
||||
u := api.Update{UpdateID: 1, CallbackQuery: &api.CallbackQuery{
|
||||
ID: "x", From: api.User{ID: 1}, ChatInstance: "y", Data: "like:42",
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||
|
||||
require.Equal(t, "like:42", <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_NoMatch(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
called := false
|
||||
r.OnCommand("/start", func(c *Context, m *api.Message) error {
|
||||
called = true
|
||||
return nil
|
||||
})
|
||||
u := api.Update{UpdateID: 1, Message: &api.Message{Text: "no command"}}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = r.Run(ctx, newFake(u))
|
||||
require.False(t, called)
|
||||
}
|
||||
|
||||
func TestRouter_PanicRecovery(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
r.OnCommand("/boom", func(c *Context, m *api.Message) error {
|
||||
panic("kaboom")
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
// Should not propagate panic to Run.
|
||||
require.NotPanics(t, func() { _ = r.Run(ctx, newFake(cmdMessage("/boom"))) })
|
||||
}
|
||||
|
||||
// TestRouter_NonASCIICommand verifies that UTF-16 entity offsets are used
|
||||
// correctly when the command contains non-ASCII runes. "/старт" is 6 runes,
|
||||
// each a BMP code point, so UTF-16 length == 6.
|
||||
func TestRouter_NonASCIICommand(t *testing.T) {
|
||||
const text = "/старт аргумент"
|
||||
// "/старт" = 1 + 5 runes, all BMP → UTF-16 length 6
|
||||
const cmdU16Len = int64(6)
|
||||
u := api.Update{
|
||||
UpdateID: 1,
|
||||
Message: &api.Message{
|
||||
MessageID: 1,
|
||||
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||
Text: text,
|
||||
Entities: []api.MessageEntity{
|
||||
{Type: string(api.EntityBotCommand), Offset: 0, Length: cmdU16Len},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan [2]string, 1)
|
||||
r.OnCommand("/старт", func(c *Context, m *api.Message) error {
|
||||
hit <- [2]string{
|
||||
c.Values["command"].(string),
|
||||
c.Values["command_args"].(string),
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||
|
||||
got := <-hit
|
||||
require.Equal(t, "/старт", got[0])
|
||||
require.Equal(t, "аргумент", got[1])
|
||||
}
|
||||
|
||||
// TestRouter_CommandValuesNotLeakedOnNoMatch verifies that c.Values["command"]
|
||||
// is not set when a command entity is present but no route matches, so a
|
||||
// subsequent text handler doesn't see stale values.
|
||||
func TestRouter_CommandValuesNotLeakedOnNoMatch(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
// Register a text handler that should fire as fallback.
|
||||
leaked := make(chan bool, 1)
|
||||
r.OnText(`.*`, func(c *Context, m *api.Message) error {
|
||||
_, hasCmd := c.Values["command"]
|
||||
leaked <- hasCmd
|
||||
return nil
|
||||
})
|
||||
// No OnCommand registered, so the command entity won't match any route.
|
||||
u := api.Update{UpdateID: 1, Message: &api.Message{
|
||||
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"},
|
||||
Text: "/unknown",
|
||||
Entities: []api.MessageEntity{{Type: string(api.EntityBotCommand), Offset: 0, Length: 8}},
|
||||
}}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||
|
||||
require.False(t, <-leaked, "command value must not leak into text handler")
|
||||
}
|
||||
|
||||
func TestRouter_MiddlewareOrder(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
var order []string
|
||||
r.Use(func(next Handler[*api.Update]) Handler[*api.Update] {
|
||||
return func(c *Context, u *api.Update) error {
|
||||
order = append(order, "before-1")
|
||||
err := next(c, u)
|
||||
order = append(order, "after-1")
|
||||
return err
|
||||
}
|
||||
})
|
||||
r.Use(func(next Handler[*api.Update]) Handler[*api.Update] {
|
||||
return func(c *Context, u *api.Update) error {
|
||||
order = append(order, "before-2")
|
||||
err := next(c, u)
|
||||
order = append(order, "after-2")
|
||||
return err
|
||||
}
|
||||
})
|
||||
r.OnCommand("/x", func(c *Context, m *api.Message) error {
|
||||
order = append(order, "handler")
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = r.Run(ctx, newFake(cmdMessage("/x")))
|
||||
|
||||
require.Equal(t,
|
||||
[]string{"before-1", "before-2", "handler", "after-2", "after-1"},
|
||||
order)
|
||||
}
|
||||
func TestRouter_OnChannelPost(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan int64, 1)
|
||||
r.OnChannelPost(func(c *Context, m *api.Message) error {
|
||||
hit <- m.MessageID
|
||||
return nil
|
||||
})
|
||||
|
||||
u := api.Update{UpdateID: 1, ChannelPost: &api.Message{
|
||||
MessageID: 99, Chat: api.Chat{ID: -100, Type: string(api.ChatTypeChannel)},
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||
|
||||
require.Equal(t, int64(99), <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_RunsAllHandlersForEditedMessage(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
var hits []string
|
||||
r.OnEditedMessage(func(c *Context, m *api.Message) error {
|
||||
hits = append(hits, "first")
|
||||
return nil
|
||||
})
|
||||
r.OnEditedMessage(func(c *Context, m *api.Message) error {
|
||||
hits = append(hits, "second")
|
||||
return nil
|
||||
})
|
||||
|
||||
u := api.Update{UpdateID: 1, EditedMessage: &api.Message{
|
||||
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "edited",
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = r.Run(ctx, newFake(u))
|
||||
|
||||
require.Equal(t, []string{"first", "second"}, hits)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Filter-route tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRouter_OnMessageFilter_Matches(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
r.OnMessageFilter(
|
||||
Filter[*api.Message](func(m *api.Message) bool { return m != nil && m.Text == "ping" }),
|
||||
func(c *Context, m *api.Message) error { hit <- m.Text; return nil },
|
||||
)
|
||||
|
||||
u := api.Update{UpdateID: 1, Message: &api.Message{
|
||||
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "ping",
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||
|
||||
require.Equal(t, "ping", <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnMessageFilter_NoMatch(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
called := false
|
||||
r.OnMessageFilter(
|
||||
Filter[*api.Message](func(m *api.Message) bool { return false }),
|
||||
func(c *Context, m *api.Message) error { called = true; return nil },
|
||||
)
|
||||
|
||||
u := api.Update{UpdateID: 1, Message: &api.Message{
|
||||
MessageID: 1, Chat: api.Chat{ID: 1, Type: "private"}, Text: "any",
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = r.Run(ctx, newFake(u))
|
||||
require.False(t, called)
|
||||
}
|
||||
|
||||
// Command routes must take priority over filter routes.
|
||||
func TestRouter_OnMessageFilter_CommandWins(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
var winner string
|
||||
r.OnCommand("/start", func(c *Context, m *api.Message) error { winner = "command"; return nil })
|
||||
r.OnMessageFilter(
|
||||
Filter[*api.Message](func(m *api.Message) bool { return true }),
|
||||
func(c *Context, m *api.Message) error { winner = "filter"; return nil },
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = r.Run(ctx, newFake(cmdMessage("/start")))
|
||||
|
||||
require.Equal(t, "command", winner)
|
||||
}
|
||||
|
||||
func TestRouter_OnCallbackFilter_Matches(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
r.OnCallbackFilter(
|
||||
Filter[*api.CallbackQuery](func(q *api.CallbackQuery) bool { return q != nil && q.Data == "yes" }),
|
||||
func(c *Context, q *api.CallbackQuery) error { hit <- q.Data; return nil },
|
||||
)
|
||||
|
||||
u := api.Update{UpdateID: 1, CallbackQuery: &api.CallbackQuery{
|
||||
ID: "x", From: api.User{ID: 1}, ChatInstance: "y", Data: "yes",
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||
|
||||
require.Equal(t, "yes", <-hit)
|
||||
}
|
||||
|
||||
// Pattern-based OnCallback wins over OnCallbackFilter when both match.
|
||||
func TestRouter_OnCallbackFilter_PatternWins(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
var winner string
|
||||
r.OnCallback(`^yes$`, func(c *Context, q *api.CallbackQuery) error { winner = "pattern"; return nil })
|
||||
r.OnCallbackFilter(
|
||||
Filter[*api.CallbackQuery](func(q *api.CallbackQuery) bool { return true }),
|
||||
func(c *Context, q *api.CallbackQuery) error { winner = "filter"; return nil },
|
||||
)
|
||||
|
||||
u := api.Update{UpdateID: 1, CallbackQuery: &api.CallbackQuery{
|
||||
ID: "x", From: api.User{ID: 1}, ChatInstance: "y", Data: "yes",
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = r.Run(ctx, newFake(u))
|
||||
|
||||
require.Equal(t, "pattern", winner)
|
||||
}
|
||||
|
||||
func TestRouter_OnInlineQueryFilter_Matches(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
r.OnInlineQueryFilter(
|
||||
Filter[*api.InlineQuery](func(q *api.InlineQuery) bool { return q != nil && q.Query == "find" }),
|
||||
func(c *Context, q *api.InlineQuery) error { hit <- q.Query; return nil },
|
||||
)
|
||||
|
||||
u := api.Update{UpdateID: 1, InlineQuery: &api.InlineQuery{
|
||||
ID: "i", From: api.User{ID: 1}, Query: "find",
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(u)) }()
|
||||
|
||||
require.Equal(t, "find", <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_FilterChain_Composition(t *testing.T) {
|
||||
// Filter: private chat AND text contains "hello"
|
||||
privateChat := Filter[*api.Message](func(m *api.Message) bool {
|
||||
return m != nil && m.Chat.Type == string(api.ChatTypePrivate)
|
||||
})
|
||||
hasHello := Filter[*api.Message](func(m *api.Message) bool {
|
||||
return m != nil && len(m.Text) > 0 && containsStr(m.Text, "hello")
|
||||
})
|
||||
combined := privateChat.And(hasHello)
|
||||
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
r.OnMessageFilter(combined, func(c *Context, m *api.Message) error { hit <- m.Text; return nil })
|
||||
|
||||
match := api.Update{UpdateID: 1, Message: &api.Message{
|
||||
MessageID: 1, Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)}, Text: "say hello",
|
||||
}}
|
||||
noMatch := api.Update{UpdateID: 2, Message: &api.Message{
|
||||
MessageID: 2, Chat: api.Chat{ID: 2, Type: string(api.ChatTypeGroup)}, Text: "say hello",
|
||||
}}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(match, noMatch)) }()
|
||||
|
||||
require.Equal(t, "say hello", <-hit)
|
||||
}
|
||||
|
||||
// containsStr is a helper to avoid importing strings in test file unnecessarily.
|
||||
func containsStr(s, sub string) bool {
|
||||
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsSubstr(s, sub))
|
||||
}
|
||||
|
||||
func containsSubstr(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Concurrent dispatch tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// fakeSlowUpdater feeds n updates then blocks until ctx cancel.
|
||||
type fakeSlowUpdater struct {
|
||||
ch chan api.Update
|
||||
}
|
||||
|
||||
func newSlowFake(ups ...api.Update) *fakeSlowUpdater {
|
||||
ch := make(chan api.Update, len(ups))
|
||||
for _, u := range ups {
|
||||
ch <- u
|
||||
}
|
||||
close(ch)
|
||||
return &fakeSlowUpdater{ch: ch}
|
||||
}
|
||||
|
||||
func (f *fakeSlowUpdater) Updates() <-chan api.Update { return f.ch }
|
||||
func (f *fakeSlowUpdater) Run(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }
|
||||
func (f *fakeSlowUpdater) Stop(ctx context.Context) error { return nil }
|
||||
|
||||
func TestRouter_ConcurrentDispatch_AllHandlersFire(t *testing.T) {
|
||||
const n = 100
|
||||
var fired atomic.Int64
|
||||
|
||||
ups := make([]api.Update, n)
|
||||
for i := range ups {
|
||||
ups[i] = api.Update{UpdateID: int64(i + 1), Message: &api.Message{
|
||||
MessageID: int64(i + 1),
|
||||
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||
Text: "hi",
|
||||
}}
|
||||
}
|
||||
|
||||
r := New(client.New("t"), WithMaxConcurrency(20))
|
||||
r.OnMessageFilter(
|
||||
Filter[*api.Message](func(m *api.Message) bool { return true }),
|
||||
func(c *Context, m *api.Message) error { fired.Add(1); return nil },
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
_ = r.Run(ctx, newSlowFake(ups...))
|
||||
|
||||
require.Equal(t, int64(n), fired.Load())
|
||||
}
|
||||
|
||||
func TestRouter_ConcurrentDispatch_SemaphoreBoundsConcurrency(t *testing.T) {
|
||||
const limit = 5
|
||||
const n = 30
|
||||
|
||||
var inFlight atomic.Int64
|
||||
var maxSeen atomic.Int64
|
||||
ready := make(chan struct{}) // signals handler to proceed
|
||||
started := make(chan struct{}) // first handler signals it's running
|
||||
|
||||
ups := make([]api.Update, n)
|
||||
for i := range ups {
|
||||
ups[i] = api.Update{UpdateID: int64(i + 1), Message: &api.Message{
|
||||
MessageID: int64(i + 1),
|
||||
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||
Text: "hi",
|
||||
}}
|
||||
}
|
||||
|
||||
once := atomic.Bool{}
|
||||
r := New(client.New("t"), WithMaxConcurrency(limit))
|
||||
r.OnMessageFilter(
|
||||
Filter[*api.Message](func(m *api.Message) bool { return true }),
|
||||
func(c *Context, m *api.Message) error {
|
||||
cur := inFlight.Add(1)
|
||||
for {
|
||||
old := maxSeen.Load()
|
||||
if cur <= old || maxSeen.CompareAndSwap(old, cur) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if once.CompareAndSwap(false, true) {
|
||||
close(started)
|
||||
}
|
||||
<-ready
|
||||
inFlight.Add(-1)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go func() { _ = r.Run(ctx, newSlowFake(ups...)) }()
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for first handler")
|
||||
}
|
||||
// Give the pool a moment to fill up.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
close(ready)
|
||||
|
||||
// Wait for Run to drain by cancelling context after a short wait.
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
require.LessOrEqual(t, maxSeen.Load(), int64(limit),
|
||||
"in-flight goroutines exceeded semaphore limit")
|
||||
}
|
||||
|
||||
func TestRouter_ConcurrentDispatch_WaitsForInFlight(t *testing.T) {
|
||||
unblock := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
|
||||
r := New(client.New("t"), WithMaxConcurrency(10))
|
||||
r.OnMessageFilter(
|
||||
Filter[*api.Message](func(m *api.Message) bool { return true }),
|
||||
func(c *Context, m *api.Message) error {
|
||||
<-unblock
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
u := api.Update{UpdateID: 1, Message: &api.Message{
|
||||
MessageID: 1, Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)}, Text: "hi",
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
_ = r.Run(ctx, newSlowFake(u))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Give Run time to pick up the update and launch the goroutine.
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
cancel() // trigger Run to exit its loop
|
||||
|
||||
// Run should not return until handler unblocks.
|
||||
select {
|
||||
case <-done:
|
||||
t.Fatal("Run returned before in-flight handler finished")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(unblock)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Run did not return after handler finished")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SerialMode_NoRace(t *testing.T) {
|
||||
// WithMaxConcurrency(0) — serial; shared slice is safe without a mutex.
|
||||
var order []int64
|
||||
|
||||
const n = 20
|
||||
ups := make([]api.Update, n)
|
||||
for i := range ups {
|
||||
ups[i] = api.Update{UpdateID: int64(i + 1), Message: &api.Message{
|
||||
MessageID: int64(i + 1),
|
||||
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||
Text: "hi",
|
||||
}}
|
||||
}
|
||||
|
||||
r := New(client.New("t"), WithMaxConcurrency(0))
|
||||
r.OnMessageFilter(
|
||||
Filter[*api.Message](func(m *api.Message) bool { return true }),
|
||||
func(c *Context, m *api.Message) error {
|
||||
order = append(order, m.MessageID)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = r.Run(ctx, newSlowFake(ups...))
|
||||
|
||||
require.Len(t, order, n)
|
||||
for i, v := range order {
|
||||
require.Equal(t, int64(i+1), v)
|
||||
}
|
||||
}
|
||||
|
||||
// liveUpdater is an updater whose channel stays open until stopCh is closed.
|
||||
type liveUpdater struct {
|
||||
ch chan api.Update
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func newLiveUpdater() *liveUpdater {
|
||||
return &liveUpdater{ch: make(chan api.Update, 8), stopCh: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (l *liveUpdater) Send(u api.Update) { l.ch <- u }
|
||||
func (l *liveUpdater) Close() { close(l.stopCh) }
|
||||
func (l *liveUpdater) Updates() <-chan api.Update { return l.ch }
|
||||
func (l *liveUpdater) Stop(ctx context.Context) error { return nil }
|
||||
func (l *liveUpdater) Run(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-l.stopCh:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Typed handler tests (Feature 1)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRouter_OnMyChatMember(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan int64, 1)
|
||||
r.OnMyChatMember(func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.From.ID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, MyChatMember: &api.ChatMemberUpdated{
|
||||
From: api.User{ID: 42},
|
||||
Chat: api.Chat{ID: 1},
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, int64(42), <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnMyChatMemberFilter(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan int64, 1)
|
||||
f := Filter[*api.ChatMemberUpdated](func(u *api.ChatMemberUpdated) bool { return u.From.ID == 99 })
|
||||
r.OnMyChatMemberFilter(f, func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.From.ID; return nil })
|
||||
|
||||
match := api.Update{UpdateID: 1, MyChatMember: &api.ChatMemberUpdated{From: api.User{ID: 99}}}
|
||||
noMatch := api.Update{UpdateID: 2, MyChatMember: &api.ChatMemberUpdated{From: api.User{ID: 1}}}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(noMatch, match)) }()
|
||||
require.Equal(t, int64(99), <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnChatMember(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan int64, 1)
|
||||
r.OnChatMember(func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.Chat.ID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, ChatMember: &api.ChatMemberUpdated{
|
||||
From: api.User{ID: 1},
|
||||
Chat: api.Chat{ID: 77},
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, int64(77), <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnChatMemberFilter(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan int64, 1)
|
||||
f := Filter[*api.ChatMemberUpdated](func(u *api.ChatMemberUpdated) bool { return u.Chat.ID == 55 })
|
||||
r.OnChatMemberFilter(f, func(c *Context, u *api.ChatMemberUpdated) error { hit <- u.Chat.ID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, ChatMember: &api.ChatMemberUpdated{Chat: api.Chat{ID: 55}}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, int64(55), <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnChatJoinRequest(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan int64, 1)
|
||||
r.OnChatJoinRequest(func(c *Context, req *api.ChatJoinRequest) error { hit <- req.From.ID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, ChatJoinRequest: &api.ChatJoinRequest{
|
||||
From: api.User{ID: 11},
|
||||
Chat: api.Chat{ID: 1},
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, int64(11), <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnChatJoinRequestFilter(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan int64, 1)
|
||||
f := Filter[*api.ChatJoinRequest](func(req *api.ChatJoinRequest) bool { return req.Chat.ID == 22 })
|
||||
r.OnChatJoinRequestFilter(f, func(c *Context, req *api.ChatJoinRequest) error { hit <- req.Chat.ID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, ChatJoinRequest: &api.ChatJoinRequest{
|
||||
From: api.User{ID: 1},
|
||||
Chat: api.Chat{ID: 22},
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, int64(22), <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnPreCheckoutQuery(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
r.OnPreCheckoutQuery(func(c *Context, q *api.PreCheckoutQuery) error { hit <- q.Currency; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, PreCheckoutQuery: &api.PreCheckoutQuery{
|
||||
ID: "q1", From: api.User{ID: 1}, Currency: "USD",
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, "USD", <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnPreCheckoutQueryFilter(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
f := Filter[*api.PreCheckoutQuery](func(q *api.PreCheckoutQuery) bool { return q.Currency == "EUR" })
|
||||
r.OnPreCheckoutQueryFilter(f, func(c *Context, q *api.PreCheckoutQuery) error { hit <- q.Currency; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, PreCheckoutQuery: &api.PreCheckoutQuery{
|
||||
ID: "q1", From: api.User{ID: 1}, Currency: "EUR",
|
||||
}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, "EUR", <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnShippingQuery(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
r.OnShippingQuery(func(c *Context, q *api.ShippingQuery) error { hit <- q.ID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, ShippingQuery: &api.ShippingQuery{ID: "sq1", From: api.User{ID: 1}}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, "sq1", <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnPoll(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
r.OnPoll(func(c *Context, p *api.Poll) error { hit <- p.ID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, Poll: &api.Poll{ID: "poll1"}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, "poll1", <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnPollAnswer(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
r.OnPollAnswer(func(c *Context, a *api.PollAnswer) error { hit <- a.PollID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, PollAnswer: &api.PollAnswer{PollID: "p1", OptionIds: []int64{0}}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, "p1", <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnChosenInlineResult(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
r.OnChosenInlineResult(func(c *Context, res *api.ChosenInlineResult) error { hit <- res.ResultID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, ChosenInlineResult: &api.ChosenInlineResult{ResultID: "r1", From: api.User{ID: 1}}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, "r1", <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnMessageReaction(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan int64, 1)
|
||||
r.OnMessageReaction(func(c *Context, u *api.MessageReactionUpdated) error { hit <- u.Chat.ID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, MessageReaction: &api.MessageReactionUpdated{Chat: api.Chat{ID: 33}}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, int64(33), <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnMessageReactionCount(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan int64, 1)
|
||||
r.OnMessageReactionCount(func(c *Context, u *api.MessageReactionCountUpdated) error { hit <- u.Chat.ID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, MessageReactionCount: &api.MessageReactionCountUpdated{Chat: api.Chat{ID: 44}}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, int64(44), <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnChatBoost(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan int64, 1)
|
||||
r.OnChatBoost(func(c *Context, u *api.ChatBoostUpdated) error { hit <- u.Chat.ID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, ChatBoost: &api.ChatBoostUpdated{Chat: api.Chat{ID: 55}}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, int64(55), <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnRemovedChatBoost(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan int64, 1)
|
||||
r.OnRemovedChatBoost(func(c *Context, u *api.ChatBoostRemoved) error { hit <- u.Chat.ID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, RemovedChatBoost: &api.ChatBoostRemoved{Chat: api.Chat{ID: 66}}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, int64(66), <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnBusinessConnection(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
r.OnBusinessConnection(func(c *Context, bc *api.BusinessConnection) error { hit <- bc.ID; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, BusinessConnection: &api.BusinessConnection{ID: "bc1", UserChatID: 1, User: api.User{ID: 1}}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, "bc1", <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_OnPurchasedPaidMedia(t *testing.T) {
|
||||
r := New(client.New("t"))
|
||||
hit := make(chan string, 1)
|
||||
r.OnPurchasedPaidMedia(func(c *Context, p *api.PaidMediaPurchased) error { hit <- p.PaidMediaPayload; return nil })
|
||||
|
||||
upd := api.Update{UpdateID: 1, PurchasedPaidMedia: &api.PaidMediaPurchased{From: api.User{ID: 1}, PaidMediaPayload: "payload1"}}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
go func() { _ = r.Run(ctx, newFake(upd)) }()
|
||||
require.Equal(t, "payload1", <-hit)
|
||||
}
|
||||
|
||||
func TestRouter_ContextCancel_UnblocksWaitingAcquire(t *testing.T) {
|
||||
// Fill the semaphore with slow handlers, send one more update, then cancel
|
||||
// ctx. Run must unblock from the semaphore-acquire select and return.
|
||||
const limit = 2
|
||||
unblock := make(chan struct{})
|
||||
|
||||
slowHandler := func(c *Context, m *api.Message) error {
|
||||
<-unblock
|
||||
return nil
|
||||
}
|
||||
|
||||
lu := newLiveUpdater()
|
||||
r := New(client.New("t"), WithMaxConcurrency(limit))
|
||||
r.OnMessageFilter(Filter[*api.Message](func(m *api.Message) bool { return true }), slowHandler)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
runDone := make(chan error, 1)
|
||||
go func() { runDone <- r.Run(ctx, lu) }()
|
||||
|
||||
// Send enough updates to fill semaphore.
|
||||
for i := range limit {
|
||||
lu.Send(api.Update{UpdateID: int64(i + 1), Message: &api.Message{
|
||||
MessageID: int64(i + 1),
|
||||
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||
Text: "hi",
|
||||
}})
|
||||
}
|
||||
|
||||
// Give goroutines time to acquire all semaphore slots.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Send one more update — Run will block trying to acquire the full semaphore.
|
||||
lu.Send(api.Update{UpdateID: int64(limit + 1), Message: &api.Message{
|
||||
MessageID: int64(limit + 1),
|
||||
Chat: api.Chat{ID: 1, Type: string(api.ChatTypePrivate)},
|
||||
Text: "extra",
|
||||
}})
|
||||
|
||||
// Give Run a moment to reach the semaphore-acquire select.
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
// Unblock handlers so wg.Wait() inside Run can complete, allowing Run to
|
||||
// return (and write to runDone).
|
||||
close(unblock)
|
||||
|
||||
select {
|
||||
case err := <-runDone:
|
||||
require.Error(t, err)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Run did not unblock after context cancel")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user