mirror of
https://github.com/lukaszraczylo/go-telegram.git
synced 2026-06-10 23:09:04 +00:00
Initial release of go-telegram
A fully-generated, strongly-typed Go client for the Telegram Bot API. * 176 methods + 301 types generated from Bot API v10.0 * 1408 auto-generated tests (8 scenarios per method) * Typed unions throughout — no 'any' in the public surface * Pluggable HTTP transport and JSON codec (default goccy/go-json) * Built-in retry middleware honouring Telegram's retry_after * Generic dispatcher with filters and conversation handlers * Self-verifying codegen pipeline (regen → audit → emit → run tests) * 14 example bots covering common patterns
This commit is contained in:
@@ -0,0 +1,494 @@
|
||||
package conversation_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/client"
|
||||
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||
"github.com/lukaszraczylo/go-telegram/dispatch/conversation"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---- helpers ---------------------------------------------------------------
|
||||
|
||||
func msgUpd(userID, chatID int64, text string) api.Update {
|
||||
return api.Update{
|
||||
UpdateID: 1,
|
||||
Message: &api.Message{
|
||||
MessageID: 1,
|
||||
From: &api.User{ID: userID},
|
||||
Chat: api.Chat{ID: chatID},
|
||||
Text: text,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func makeCtx(u *api.Update) *dispatch.Context {
|
||||
return dispatch.NewContext(context.Background(), client.New("t"), u)
|
||||
}
|
||||
|
||||
// anyMsg matches any update that has a Message.
|
||||
var anyMsg = func(u *api.Update) bool { return u.Message != nil }
|
||||
|
||||
// hasPrefix returns a filter matching updates whose Message.Text has prefix p.
|
||||
func hasPrefix(p string) dispatch.Filter[*api.Update] {
|
||||
return func(u *api.Update) bool {
|
||||
return u.Message != nil && strings.HasPrefix(u.Message.Text, p)
|
||||
}
|
||||
}
|
||||
|
||||
// fakeUpdater feeds a fixed set of updates then closes (mirrors router_test.go).
|
||||
type fakeUpdater struct{ ch chan api.Update }
|
||||
|
||||
func newFake(ups ...api.Update) *fakeUpdater {
|
||||
ch := make(chan api.Update, len(ups))
|
||||
for _, u := range ups {
|
||||
ch <- u
|
||||
}
|
||||
close(ch)
|
||||
return &fakeUpdater{ch: ch}
|
||||
}
|
||||
|
||||
func (f *fakeUpdater) Updates() <-chan api.Update { return f.ch }
|
||||
func (f *fakeUpdater) Run(ctx context.Context) error { <-ctx.Done(); return ctx.Err() }
|
||||
func (f *fakeUpdater) Stop(ctx context.Context) error { return nil }
|
||||
|
||||
// ---- Storage tests ---------------------------------------------------------
|
||||
|
||||
func TestStorage_ErrKeyNotFound(t *testing.T) {
|
||||
s := conversation.NewMemoryStorage()
|
||||
_, err := s.Get(context.Background(), "missing")
|
||||
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
|
||||
}
|
||||
|
||||
func TestStorage_SetAndGet(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := conversation.NewMemoryStorage()
|
||||
require.NoError(t, s.Set(ctx, "k", "state_a"))
|
||||
v, err := s.Get(ctx, "k")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conversation.State("state_a"), v)
|
||||
}
|
||||
|
||||
func TestStorage_Delete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := conversation.NewMemoryStorage()
|
||||
require.NoError(t, s.Set(ctx, "k", "state_a"))
|
||||
require.NoError(t, s.Delete(ctx, "k"))
|
||||
_, err := s.Get(ctx, "k")
|
||||
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
|
||||
}
|
||||
|
||||
func TestStorage_DeleteNonExistentIsNoop(t *testing.T) {
|
||||
require.NoError(t, conversation.NewMemoryStorage().Delete(context.Background(), "gone"))
|
||||
}
|
||||
|
||||
// ---- Key strategy tests ----------------------------------------------------
|
||||
|
||||
func TestKeyByUser_Variants(t *testing.T) {
|
||||
t.Run("message", func(t *testing.T) {
|
||||
u := msgUpd(42, 100, "hi")
|
||||
require.Equal(t, "u:42", conversation.KeyByUser(&u))
|
||||
})
|
||||
t.Run("edited_message", func(t *testing.T) {
|
||||
u := api.Update{EditedMessage: &api.Message{From: &api.User{ID: 7}, Chat: api.Chat{ID: 1}}}
|
||||
require.Equal(t, "u:7", conversation.KeyByUser(&u))
|
||||
})
|
||||
t.Run("callback_query", func(t *testing.T) {
|
||||
u := api.Update{CallbackQuery: &api.CallbackQuery{From: api.User{ID: 99}}}
|
||||
require.Equal(t, "u:99", conversation.KeyByUser(&u))
|
||||
})
|
||||
t.Run("inline_query", func(t *testing.T) {
|
||||
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
|
||||
require.Equal(t, "u:5", conversation.KeyByUser(&u))
|
||||
})
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
require.Equal(t, "", conversation.KeyByUser(&api.Update{}))
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyByChat_Variants(t *testing.T) {
|
||||
t.Run("message", func(t *testing.T) {
|
||||
u := msgUpd(1, 200, "")
|
||||
require.Equal(t, "c:200", conversation.KeyByChat(&u))
|
||||
})
|
||||
t.Run("inline_has_no_chat", func(t *testing.T) {
|
||||
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
|
||||
require.Equal(t, "", conversation.KeyByChat(&u))
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyByUserAndChat(t *testing.T) {
|
||||
u := msgUpd(42, 100, "")
|
||||
require.Equal(t, "uc:100:42", conversation.KeyByUserAndChat(&u))
|
||||
}
|
||||
|
||||
// ---- Handler / state machine tests -----------------------------------------
|
||||
|
||||
func buildConv() *conversation.Conversation {
|
||||
return &conversation.Conversation{
|
||||
EntryPoints: []conversation.Step{{
|
||||
Filter: hasPrefix("/start"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
return conversation.Next("await_name")
|
||||
},
|
||||
}},
|
||||
States: map[conversation.State][]conversation.Step{
|
||||
"await_name": {{
|
||||
Filter: anyMsg,
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("await_age") },
|
||||
}},
|
||||
"await_age": {{
|
||||
Filter: anyMsg,
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
|
||||
}},
|
||||
},
|
||||
Exits: []conversation.Step{{
|
||||
Filter: hasPrefix("/cancel"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestConversation_FullFlow(t *testing.T) {
|
||||
conv := buildConv()
|
||||
|
||||
var downstream int
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
|
||||
downstream++
|
||||
return nil
|
||||
})
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
key := "uc:1:42"
|
||||
|
||||
// 1. /start → enters, state = await_name
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
v, err := conv.Storage.Get(context.Background(), key)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conversation.State("await_name"), v)
|
||||
require.Equal(t, 0, downstream, "entry claimed update")
|
||||
|
||||
// 2. name → state = await_age
|
||||
u2 := msgUpd(42, 1, "Alice")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
v, err = conv.Storage.Get(context.Background(), key)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, conversation.State("await_age"), v)
|
||||
|
||||
// 3. age → End, key deleted
|
||||
u3 := msgUpd(42, 1, "30")
|
||||
require.NoError(t, mw(makeCtx(&u3), &u3))
|
||||
_, err = conv.Storage.Get(context.Background(), key)
|
||||
require.ErrorIs(t, err, conversation.ErrKeyNotFound)
|
||||
}
|
||||
|
||||
func TestConversation_ExitsCancelMidFlow(t *testing.T) {
|
||||
conv := buildConv()
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
// Start conversation.
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
_, err := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cancel mid-flow.
|
||||
u2 := msgUpd(42, 1, "/cancel")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
_, err = conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.ErrorIs(t, err, conversation.ErrKeyNotFound, "exit should clear state")
|
||||
}
|
||||
|
||||
func TestConversation_FallbackFiresWhenNoStateStepMatches(t *testing.T) {
|
||||
fallbackHit := false
|
||||
conv := &conversation.Conversation{
|
||||
EntryPoints: []conversation.Step{{
|
||||
Filter: hasPrefix("/start"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("waiting") },
|
||||
}},
|
||||
States: map[conversation.State][]conversation.Step{
|
||||
// No steps for "waiting" that match a callback query.
|
||||
"waiting": {},
|
||||
},
|
||||
Fallbacks: []conversation.Step{{
|
||||
Filter: anyMsg,
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
fallbackHit = true
|
||||
return nil
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
|
||||
u2 := msgUpd(42, 1, "unexpected text")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
require.True(t, fallbackHit, "fallback should have fired")
|
||||
}
|
||||
|
||||
func TestConversation_NoActiveConv_PassesToDownstream(t *testing.T) {
|
||||
conv := buildConv()
|
||||
downstreamHit := false
|
||||
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
|
||||
downstreamHit = true
|
||||
return nil
|
||||
})
|
||||
mw := conv.Dispatch(downstream)
|
||||
|
||||
// Random message that doesn't match /start
|
||||
u := msgUpd(42, 1, "hello")
|
||||
require.NoError(t, mw(makeCtx(&u), &u))
|
||||
require.True(t, downstreamHit, "unmatched update should reach downstream")
|
||||
}
|
||||
|
||||
func TestConversation_EmptyKey_PassesThrough(t *testing.T) {
|
||||
// InlineQuery has no chatID → KeyByUserAndChat returns "" → pass through.
|
||||
conv := buildConv()
|
||||
downstreamHit := false
|
||||
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
|
||||
downstreamHit = true
|
||||
return nil
|
||||
})
|
||||
mw := conv.Dispatch(downstream)
|
||||
|
||||
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
|
||||
require.NoError(t, mw(makeCtx(&u), &u))
|
||||
require.True(t, downstreamHit)
|
||||
}
|
||||
|
||||
func TestConversation_AllowReEntry(t *testing.T) {
|
||||
conv := buildConv()
|
||||
conv.AllowReEntry = true
|
||||
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
// Start.
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.Equal(t, conversation.State("await_name"), v)
|
||||
|
||||
// Advance once.
|
||||
u2 := msgUpd(42, 1, "Alice")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
v, _ = conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.Equal(t, conversation.State("await_age"), v)
|
||||
|
||||
// Re-enter with /start — should restart to await_name even though mid-flow.
|
||||
u3 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u3), &u3))
|
||||
v, _ = conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.Equal(t, conversation.State("await_name"), v, "AllowReEntry should restart")
|
||||
}
|
||||
|
||||
func TestConversation_NoReEntry_EntryIgnoredWhenActive(t *testing.T) {
|
||||
conv := buildConv()
|
||||
conv.AllowReEntry = false
|
||||
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
// Start.
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
|
||||
// Advance to await_age.
|
||||
u2 := msgUpd(42, 1, "Alice")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.Equal(t, conversation.State("await_age"), v)
|
||||
|
||||
// /start again — should NOT restart; state should stay await_age since
|
||||
// /start matches the state step filter (anyMsg) and advances.
|
||||
// Actually /start is handled by "await_age" anyMsg step → End().
|
||||
u3 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u3), &u3))
|
||||
// State ended (End() called by await_age step).
|
||||
_, err := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.ErrorIs(t, err, conversation.ErrKeyNotFound, "state step should have consumed /start when AllowReEntry=false")
|
||||
}
|
||||
|
||||
func TestConversation_StayInState_NilReturn(t *testing.T) {
|
||||
// Handler returning nil keeps state unchanged.
|
||||
stored := false
|
||||
conv := &conversation.Conversation{
|
||||
EntryPoints: []conversation.Step{{
|
||||
Filter: hasPrefix("/start"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
return conversation.Next("waiting")
|
||||
},
|
||||
}},
|
||||
States: map[conversation.State][]conversation.Step{
|
||||
"waiting": {{
|
||||
Filter: anyMsg,
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
stored = true
|
||||
return nil // stay in current state
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
|
||||
u2 := msgUpd(42, 1, "something")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
require.True(t, stored)
|
||||
v, _ := conv.Storage.Get(context.Background(), "uc:1:42")
|
||||
require.Equal(t, conversation.State("waiting"), v, "nil return should leave state unchanged")
|
||||
}
|
||||
|
||||
func TestConversation_ActiveNoMatch_Swallows(t *testing.T) {
|
||||
// Active conversation with no matching state step and no fallback:
|
||||
// update is swallowed (not passed downstream).
|
||||
conv := &conversation.Conversation{
|
||||
EntryPoints: []conversation.Step{{
|
||||
Filter: hasPrefix("/start"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.Next("waiting") },
|
||||
}},
|
||||
States: map[conversation.State][]conversation.Step{
|
||||
"waiting": {{
|
||||
// Only matches /done specifically.
|
||||
Filter: hasPrefix("/done"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error { return conversation.End() },
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
downstreamHit := false
|
||||
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
|
||||
downstreamHit = true
|
||||
return nil
|
||||
})
|
||||
mw := conv.Dispatch(downstream)
|
||||
|
||||
u1 := msgUpd(42, 1, "/start")
|
||||
require.NoError(t, mw(makeCtx(&u1), &u1))
|
||||
|
||||
// Random text doesn't match /done and there's no fallback → swallowed.
|
||||
u2 := msgUpd(42, 1, "random")
|
||||
require.NoError(t, mw(makeCtx(&u2), &u2))
|
||||
require.False(t, downstreamHit, "active conv with no matching step should swallow")
|
||||
}
|
||||
|
||||
// ---- Via Router.Run --------------------------------------------------------
|
||||
|
||||
func TestConversation_ViaRouter(t *testing.T) {
|
||||
var steps atomic.Int32
|
||||
conv := &conversation.Conversation{
|
||||
EntryPoints: []conversation.Step{{
|
||||
Filter: hasPrefix("/start"),
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
steps.Add(1)
|
||||
return conversation.Next("await_name")
|
||||
},
|
||||
}},
|
||||
States: map[conversation.State][]conversation.Step{
|
||||
"await_name": {{
|
||||
Filter: anyMsg,
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
steps.Add(1)
|
||||
return conversation.Next("await_age")
|
||||
},
|
||||
}},
|
||||
"await_age": {{
|
||||
Filter: anyMsg,
|
||||
Handler: func(c *dispatch.Context, u *api.Update) error {
|
||||
steps.Add(1)
|
||||
return conversation.End()
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
router := dispatch.New(client.New("t"), dispatch.WithMaxConcurrency(0)) // serial
|
||||
router.Use(conv.Dispatch)
|
||||
|
||||
ups := []api.Update{
|
||||
msgUpd(42, 1, "/start"),
|
||||
msgUpd(42, 1, "Alice"),
|
||||
msgUpd(42, 1, "30"),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- router.Run(ctx, newFake(ups...)) }()
|
||||
|
||||
// Wait for updater channel to drain (Run returns when closed).
|
||||
err := <-errCh
|
||||
if err != nil && err != context.Canceled {
|
||||
t.Fatalf("Run error: %v", err)
|
||||
}
|
||||
|
||||
require.Equal(t, int32(3), steps.Load(), "all three steps should have fired")
|
||||
}
|
||||
|
||||
// ---- Concurrent storage safety ---------------------------------------------
|
||||
|
||||
func TestConversation_ConcurrentStorageAccess(t *testing.T) {
|
||||
// 15 goroutines each running a full /start → name → age flow against the
|
||||
// same shared storage but DIFFERENT keys (one per goroutine). Validates
|
||||
// no data races.
|
||||
const numUsers = 15
|
||||
|
||||
conv := buildConv()
|
||||
noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
|
||||
mw := conv.Dispatch(noop)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numUsers)
|
||||
for i := 0; i < numUsers; i++ {
|
||||
go func(uid int64) {
|
||||
defer wg.Done()
|
||||
u1 := msgUpd(uid, uid, "/start")
|
||||
_ = mw(makeCtx(&u1), &u1)
|
||||
u2 := msgUpd(uid, uid, "Alice")
|
||||
_ = mw(makeCtx(&u2), &u2)
|
||||
u3 := msgUpd(uid, uid, "30")
|
||||
_ = mw(makeCtx(&u3), &u3)
|
||||
}(int64(i + 1))
|
||||
}
|
||||
wg.Wait()
|
||||
// Race detector catches bugs; no assertion needed beyond clean finish.
|
||||
}
|
||||
|
||||
func TestConversation_ConcurrentSameKey(t *testing.T) {
|
||||
// 12 goroutines hammer the same key concurrently. Storage must not panic
|
||||
// or corrupt state. Race detector validates lock discipline.
|
||||
const goroutines = 12
|
||||
s := conversation.NewMemoryStorage()
|
||||
ctx := context.Background()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
_ = s.Set(ctx, "shared", conversation.State("step"))
|
||||
_, _ = s.Get(ctx, "shared")
|
||||
if i%4 == 0 {
|
||||
_ = s.Delete(ctx, "shared")
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/dispatch"
|
||||
)
|
||||
|
||||
// stateTransition is a sentinel error type carrying a state transition
|
||||
// or end signal. Conversation handlers return one of these (via Next or
|
||||
// End helpers below) to drive the state machine.
|
||||
type stateTransition struct {
|
||||
next State
|
||||
end bool
|
||||
}
|
||||
|
||||
func (e *stateTransition) Error() string {
|
||||
if e.end {
|
||||
return "conversation: end"
|
||||
}
|
||||
return "conversation: → " + string(e.next)
|
||||
}
|
||||
|
||||
// Next signals the conversation should advance to the given state.
|
||||
// Conversation handlers return Next("state_name") to transition.
|
||||
func Next(s State) error {
|
||||
return &stateTransition{next: s}
|
||||
}
|
||||
|
||||
// End signals the conversation has finished and state should be cleared.
|
||||
// Conversation handlers return End() to terminate.
|
||||
func End() error {
|
||||
return &stateTransition{end: true}
|
||||
}
|
||||
|
||||
// Handler defines a step in the conversation. Receives the dispatch context
|
||||
// and the raw update. Returns:
|
||||
// - nil to stay in the current state
|
||||
// - Next("state") to transition to a different state
|
||||
// - End() to end the conversation
|
||||
// - any other non-nil error to surface to the dispatcher (state unchanged)
|
||||
type Handler func(ctx *dispatch.Context, u *api.Update) error
|
||||
|
||||
// Step pairs a filter with a handler for one conversation step.
|
||||
type Step struct {
|
||||
Filter dispatch.Filter[*api.Update]
|
||||
Handler Handler
|
||||
}
|
||||
|
||||
// Conversation is a stateful handler with entry, per-state, exit and
|
||||
// fallback steps. A conversation is keyed by KeyStrategy (default
|
||||
// KeyByUserAndChat) and persisted by Storage (default in-memory).
|
||||
type Conversation struct {
|
||||
// EntryPoints starts a new conversation when a matching filter fires
|
||||
// and no conversation is already active for the key.
|
||||
EntryPoints []Step
|
||||
|
||||
// States maps each state to the steps that handle it.
|
||||
States map[State][]Step
|
||||
|
||||
// Exits, if any match, end the active conversation early. Useful for
|
||||
// /cancel-style commands.
|
||||
Exits []Step
|
||||
|
||||
// Fallbacks run when no state step matches the current update.
|
||||
Fallbacks []Step
|
||||
|
||||
// Storage persists conversation state. Defaults to NewMemoryStorage.
|
||||
Storage Storage
|
||||
|
||||
// KeyStrategy derives the persistence key. Defaults to KeyByUserAndChat.
|
||||
KeyStrategy KeyStrategy
|
||||
|
||||
// AllowReEntry, when true, lets entry-point steps fire even while a
|
||||
// conversation is already active for the key (effectively restarting it).
|
||||
AllowReEntry bool
|
||||
}
|
||||
|
||||
// Dispatch is a global middleware-shaped Handler that consumes updates
|
||||
// and routes them through the conversation graph. Register via
|
||||
// router.Use(conv.Dispatch).
|
||||
//
|
||||
// If the conversation claims an update, downstream handlers are skipped.
|
||||
// If the conversation does not claim it, downstream handlers run as normal.
|
||||
func (c *Conversation) Dispatch(next dispatch.Handler[*api.Update]) dispatch.Handler[*api.Update] {
|
||||
if c.Storage == nil {
|
||||
c.Storage = NewMemoryStorage()
|
||||
}
|
||||
if c.KeyStrategy == nil {
|
||||
c.KeyStrategy = KeyByUserAndChat
|
||||
}
|
||||
return func(dctx *dispatch.Context, u *api.Update) error {
|
||||
key := c.KeyStrategy(u)
|
||||
if key == "" {
|
||||
return next(dctx, u)
|
||||
}
|
||||
|
||||
ctx := dctx.Ctx
|
||||
current, err := c.Storage.Get(ctx, key)
|
||||
if err != nil && !errors.Is(err, ErrKeyNotFound) {
|
||||
return err
|
||||
}
|
||||
active := !errors.Is(err, ErrKeyNotFound)
|
||||
|
||||
// Try exits first (always allowed if conversation is active).
|
||||
if active {
|
||||
for _, step := range c.Exits {
|
||||
if step.Filter(u) {
|
||||
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try entry points (only if no active conversation, or AllowReEntry).
|
||||
if !active || c.AllowReEntry {
|
||||
for _, step := range c.EntryPoints {
|
||||
if step.Filter(u) {
|
||||
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !active {
|
||||
return next(dctx, u)
|
||||
}
|
||||
|
||||
// Active conversation: try state steps.
|
||||
for _, step := range c.States[current] {
|
||||
if step.Filter(u) {
|
||||
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallbacks if no state step matched.
|
||||
for _, step := range c.Fallbacks {
|
||||
if step.Filter(u) {
|
||||
if err := c.runStep(ctx, dctx, u, key, step.Handler); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Active conversation but no step matched and no fallback: swallow the
|
||||
// update (do NOT pass to downstream handlers, since the user is
|
||||
// mid-conversation and an unrelated handler would surprise them).
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// runStep invokes the handler and applies its return-value state transition.
|
||||
func (c *Conversation) runStep(ctx context.Context, dctx *dispatch.Context, u *api.Update, key string, h Handler) error {
|
||||
err := h(dctx, u)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var trans *stateTransition
|
||||
if errors.As(err, &trans) {
|
||||
if trans.end {
|
||||
return c.Storage.Delete(ctx, key)
|
||||
}
|
||||
return c.Storage.Set(ctx, key, trans.next)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MemoryStorage is the default in-process Storage. It is safe for
|
||||
// concurrent use. Conversation state is lost on process restart; use
|
||||
// a custom Storage backed by a database for persistent flows.
|
||||
type MemoryStorage struct {
|
||||
mu sync.RWMutex
|
||||
state map[string]State
|
||||
}
|
||||
|
||||
// NewMemoryStorage constructs an empty in-memory storage.
|
||||
func NewMemoryStorage() *MemoryStorage {
|
||||
return &MemoryStorage{state: map[string]State{}}
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) Get(_ context.Context, key string) (State, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.state[key]
|
||||
if !ok {
|
||||
return "", ErrKeyNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) Set(_ context.Context, key string, state State) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.state[key] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MemoryStorage) Delete(_ context.Context, key string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.state, key)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMemoryStorage_GetSetDelete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := NewMemoryStorage()
|
||||
|
||||
// Get on empty key returns ErrKeyNotFound.
|
||||
_, err := s.Get(ctx, "k1")
|
||||
require.ErrorIs(t, err, ErrKeyNotFound)
|
||||
|
||||
// Set then Get returns the stored state.
|
||||
require.NoError(t, s.Set(ctx, "k1", "step_a"))
|
||||
v, err := s.Get(ctx, "k1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, State("step_a"), v)
|
||||
|
||||
// Overwrite works.
|
||||
require.NoError(t, s.Set(ctx, "k1", "step_b"))
|
||||
v, err = s.Get(ctx, "k1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, State("step_b"), v)
|
||||
|
||||
// Delete removes the key.
|
||||
require.NoError(t, s.Delete(ctx, "k1"))
|
||||
_, err = s.Get(ctx, "k1")
|
||||
require.ErrorIs(t, err, ErrKeyNotFound)
|
||||
|
||||
// Delete of non-existent key is a no-op (no error).
|
||||
require.NoError(t, s.Delete(ctx, "nonexistent"))
|
||||
}
|
||||
|
||||
func TestMemoryStorage_MultipleKeys(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := NewMemoryStorage()
|
||||
|
||||
require.NoError(t, s.Set(ctx, "a", "stateA"))
|
||||
require.NoError(t, s.Set(ctx, "b", "stateB"))
|
||||
|
||||
va, err := s.Get(ctx, "a")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, State("stateA"), va)
|
||||
|
||||
vb, err := s.Get(ctx, "b")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, State("stateB"), vb)
|
||||
|
||||
// Delete one key; the other remains.
|
||||
require.NoError(t, s.Delete(ctx, "a"))
|
||||
_, err = s.Get(ctx, "a")
|
||||
require.ErrorIs(t, err, ErrKeyNotFound)
|
||||
|
||||
vb, err = s.Get(ctx, "b")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, State("stateB"), vb)
|
||||
}
|
||||
|
||||
func TestMemoryStorage_Concurrent(t *testing.T) {
|
||||
// 20 goroutines hammering the same key concurrently — no data race.
|
||||
ctx := context.Background()
|
||||
s := NewMemoryStorage()
|
||||
|
||||
const goroutines = 20
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
key := "shared"
|
||||
_ = s.Set(ctx, key, State("step"))
|
||||
_, _ = s.Get(ctx, key)
|
||||
if i%3 == 0 {
|
||||
_ = s.Delete(ctx, key)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
// No assertion needed — race detector catches the bug if present.
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
)
|
||||
|
||||
// KeyStrategy derives a persistence key from an update. Strategies
|
||||
// determine how conversation scope works — per-user, per-chat, or
|
||||
// per-user-and-chat. Implementations must return a stable string for
|
||||
// the same logical scope across updates.
|
||||
//
|
||||
// Returns the empty string if the update doesn't have enough context
|
||||
// to derive a key (in which case the conversation handler skips it).
|
||||
type KeyStrategy func(u *api.Update) string
|
||||
|
||||
// KeyByUser derives a key from the sending user's ID. Useful for DM
|
||||
// conversations and any flow that should follow the user across chats.
|
||||
var KeyByUser KeyStrategy = func(u *api.Update) string {
|
||||
if uid := userID(u); uid != 0 {
|
||||
return fmt.Sprintf("u:%d", uid)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// KeyByChat derives a key from the chat ID. Useful for group flows where
|
||||
// any user in the chat can drive the conversation.
|
||||
var KeyByChat KeyStrategy = func(u *api.Update) string {
|
||||
if cid := chatID(u); cid != 0 {
|
||||
return fmt.Sprintf("c:%d", cid)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// KeyByUserAndChat derives a key from both user and chat IDs. The most
|
||||
// common strategy: each user has their own conversation per chat.
|
||||
var KeyByUserAndChat KeyStrategy = func(u *api.Update) string {
|
||||
uid := userID(u)
|
||||
cid := chatID(u)
|
||||
if uid == 0 || cid == 0 {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("uc:%d:%d", cid, uid)
|
||||
}
|
||||
|
||||
// userID extracts the sending user's ID from any update payload.
|
||||
func userID(u *api.Update) int64 {
|
||||
switch {
|
||||
case u.Message != nil && u.Message.From != nil:
|
||||
return u.Message.From.ID
|
||||
case u.EditedMessage != nil && u.EditedMessage.From != nil:
|
||||
return u.EditedMessage.From.ID
|
||||
case u.CallbackQuery != nil:
|
||||
return u.CallbackQuery.From.ID
|
||||
case u.InlineQuery != nil:
|
||||
return u.InlineQuery.From.ID
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// chatID extracts the relevant chat ID.
|
||||
func chatID(u *api.Update) int64 {
|
||||
switch {
|
||||
case u.Message != nil:
|
||||
return u.Message.Chat.ID
|
||||
case u.EditedMessage != nil:
|
||||
return u.EditedMessage.Chat.ID
|
||||
case u.ChannelPost != nil:
|
||||
return u.ChannelPost.Chat.ID
|
||||
case u.EditedChannelPost != nil:
|
||||
return u.EditedChannelPost.Chat.ID
|
||||
case u.CallbackQuery != nil && u.CallbackQuery.Message != nil:
|
||||
if msg, ok := u.CallbackQuery.Message.(*api.Message); ok {
|
||||
return msg.Chat.ID
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// helpers to build api.Update variants.
|
||||
|
||||
func msgUpdate(userID, chatID int64) *api.Update {
|
||||
return &api.Update{
|
||||
Message: &api.Message{
|
||||
From: &api.User{ID: userID},
|
||||
Chat: api.Chat{ID: chatID},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func editedMsgUpdate(userID, chatID int64) *api.Update {
|
||||
return &api.Update{
|
||||
EditedMessage: &api.Message{
|
||||
From: &api.User{ID: userID},
|
||||
Chat: api.Chat{ID: chatID},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func callbackUpdate(userID, chatID int64) *api.Update {
|
||||
return &api.Update{
|
||||
CallbackQuery: &api.CallbackQuery{
|
||||
From: api.User{ID: userID},
|
||||
Message: &api.Message{Chat: api.Chat{ID: chatID}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func inlineUpdate(userID int64) *api.Update {
|
||||
return &api.Update{
|
||||
InlineQuery: &api.InlineQuery{
|
||||
From: api.User{ID: userID},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func emptyUpdate() *api.Update { return &api.Update{} }
|
||||
|
||||
func TestKeyByUser(t *testing.T) {
|
||||
t.Run("message update", func(t *testing.T) {
|
||||
require.Equal(t, "u:42", KeyByUser(msgUpdate(42, 100)))
|
||||
})
|
||||
t.Run("edited message", func(t *testing.T) {
|
||||
require.Equal(t, "u:7", KeyByUser(editedMsgUpdate(7, 100)))
|
||||
})
|
||||
t.Run("callback query", func(t *testing.T) {
|
||||
require.Equal(t, "u:99", KeyByUser(callbackUpdate(99, 100)))
|
||||
})
|
||||
t.Run("inline query", func(t *testing.T) {
|
||||
require.Equal(t, "u:5", KeyByUser(inlineUpdate(5)))
|
||||
})
|
||||
t.Run("empty update returns empty string", func(t *testing.T) {
|
||||
require.Equal(t, "", KeyByUser(emptyUpdate()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyByChat(t *testing.T) {
|
||||
t.Run("message update", func(t *testing.T) {
|
||||
require.Equal(t, "c:100", KeyByChat(msgUpdate(42, 100)))
|
||||
})
|
||||
t.Run("edited message", func(t *testing.T) {
|
||||
require.Equal(t, "c:200", KeyByChat(editedMsgUpdate(7, 200)))
|
||||
})
|
||||
t.Run("callback with accessible message", func(t *testing.T) {
|
||||
require.Equal(t, "c:300", KeyByChat(callbackUpdate(99, 300)))
|
||||
})
|
||||
t.Run("inline query has no chat → empty", func(t *testing.T) {
|
||||
require.Equal(t, "", KeyByChat(inlineUpdate(5)))
|
||||
})
|
||||
t.Run("empty update returns empty string", func(t *testing.T) {
|
||||
require.Equal(t, "", KeyByChat(emptyUpdate()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyByUserAndChat(t *testing.T) {
|
||||
t.Run("message update", func(t *testing.T) {
|
||||
require.Equal(t, "uc:100:42", KeyByUserAndChat(msgUpdate(42, 100)))
|
||||
})
|
||||
t.Run("edited message", func(t *testing.T) {
|
||||
require.Equal(t, "uc:200:7", KeyByUserAndChat(editedMsgUpdate(7, 200)))
|
||||
})
|
||||
t.Run("callback query", func(t *testing.T) {
|
||||
require.Equal(t, "uc:300:99", KeyByUserAndChat(callbackUpdate(99, 300)))
|
||||
})
|
||||
t.Run("inline query has no chat → empty", func(t *testing.T) {
|
||||
require.Equal(t, "", KeyByUserAndChat(inlineUpdate(5)))
|
||||
})
|
||||
t.Run("empty update returns empty string", func(t *testing.T) {
|
||||
require.Equal(t, "", KeyByUserAndChat(emptyUpdate()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyByUserAndChat_CallbackInaccessibleMessage(t *testing.T) {
|
||||
// CallbackQuery.Message is InaccessibleMessage (not *Message) — chatID returns 0.
|
||||
u := &api.Update{
|
||||
CallbackQuery: &api.CallbackQuery{
|
||||
From: api.User{ID: 10},
|
||||
Message: &api.InaccessibleMessage{}, // implements MaybeInaccessibleMessage, not *api.Message
|
||||
},
|
||||
}
|
||||
// userID picks up From.ID=10 but chatID fails type assertion → 0
|
||||
require.Equal(t, "", KeyByUserAndChat(u), "no key when message inaccessible")
|
||||
// KeyByUser still works since From is set.
|
||||
require.Equal(t, "u:10", KeyByUser(u))
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
// Package conversation implements a stateful conversation handler for the
|
||||
// go-telegram dispatch router. It provides a state-machine abstraction over
|
||||
// multi-step Telegram bot interactions, with pluggable storage and flexible
|
||||
// key strategies.
|
||||
package conversation
|
||||
|
||||
// State is a label identifying a node in the conversation graph.
|
||||
// The empty string is the implicit "no active conversation" state.
|
||||
type State string
|
||||
@@ -0,0 +1,20 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// ErrKeyNotFound is returned by Storage.Get when no conversation is active
|
||||
// for the given key.
|
||||
var ErrKeyNotFound = errors.New("conversation: key not found")
|
||||
|
||||
// Storage persists per-user (or per-chat, per-message — depending on the
|
||||
// KeyStrategy in use) conversation state across update deliveries.
|
||||
//
|
||||
// Implementations must be safe for concurrent use.
|
||||
type Storage interface {
|
||||
Get(ctx context.Context, key string) (State, error)
|
||||
Set(ctx context.Context, key string, state State) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
}
|
||||
Reference in New Issue
Block a user