Files
go-telegram/dispatch/conversation/conversation_test.go
T
lukaszraczylo ac7cae8fa7 Initial release of go-telegram
A fully-generated, strongly-typed Go client for the Telegram Bot API.

* 176 methods + 301 types generated from Bot API v10.0
* 1408 auto-generated tests (8 scenarios per method)
* Typed unions throughout — no 'any' in the public surface
* Pluggable HTTP transport and JSON codec (default goccy/go-json)
* Built-in retry middleware honouring Telegram's retry_after
* Generic dispatcher with filters and conversation handlers
* Self-verifying codegen pipeline (regen → audit → emit → run tests)
* 14 example bots covering common patterns
2026-05-09 13:09:27 +01:00

495 lines
15 KiB
Go

package conversation_test
import (
"context"
"strings"
"sync"
"sync/atomic"
"testing"
"github.com/lukaszraczylo/go-telegram/api"
"github.com/lukaszraczylo/go-telegram/client"
"github.com/lukaszraczylo/go-telegram/dispatch"
"github.com/lukaszraczylo/go-telegram/dispatch/conversation"
"github.com/stretchr/testify/require"
)
// ---- helpers ---------------------------------------------------------------
// msgUpd builds a minimal text-message update sent by userID in chatID.
// UpdateID and MessageID are both fixed at 1.
func msgUpd(userID, chatID int64, text string) api.Update {
	msg := &api.Message{
		MessageID: 1,
		From:      &api.User{ID: userID},
		Chat:      api.Chat{ID: chatID},
		Text:      text,
	}
	return api.Update{UpdateID: 1, Message: msg}
}
// makeCtx wraps u in a dispatch.Context built on context.Background() and a
// client created with the placeholder token "t".
func makeCtx(u *api.Update) *dispatch.Context {
	bot := client.New("t")
	return dispatch.NewContext(context.Background(), bot, u)
}
// anyMsg is a filter that accepts every update carrying a Message.
var anyMsg = func(upd *api.Update) bool {
	return upd.Message != nil
}
// hasPrefix builds a filter that accepts updates carrying a Message whose
// Text begins with p. Updates without a Message are rejected.
func hasPrefix(p string) dispatch.Filter[*api.Update] {
	return func(upd *api.Update) bool {
		if upd.Message == nil {
			return false
		}
		return strings.HasPrefix(upd.Message.Text, p)
	}
}
// fakeUpdater is a stub updater serving a pre-loaded, already-closed channel
// of updates (mirrors router_test.go).
type fakeUpdater struct {
	ch chan api.Update
}

// newFake returns a fakeUpdater whose channel holds ups in order and is
// closed, so a receiver sees every update followed by end-of-stream.
func newFake(ups ...api.Update) *fakeUpdater {
	f := &fakeUpdater{ch: make(chan api.Update, len(ups))}
	for i := range ups {
		f.ch <- ups[i]
	}
	close(f.ch)
	return f
}

// Updates exposes the pre-loaded channel for receiving.
func (f *fakeUpdater) Updates() <-chan api.Update { return f.ch }

// Run blocks until ctx is done, then returns ctx.Err().
func (f *fakeUpdater) Run(ctx context.Context) error {
	<-ctx.Done()
	return ctx.Err()
}

// Stop does nothing and always reports success.
func (f *fakeUpdater) Stop(ctx context.Context) error { return nil }
// ---- Storage tests ---------------------------------------------------------
// TestStorage_ErrKeyNotFound checks that Get on a key that was never set
// reports conversation.ErrKeyNotFound.
func TestStorage_ErrKeyNotFound(t *testing.T) {
	store := conversation.NewMemoryStorage()

	_, getErr := store.Get(context.Background(), "missing")

	require.ErrorIs(t, getErr, conversation.ErrKeyNotFound)
}
// TestStorage_SetAndGet checks that a state written with Set is returned
// unchanged by a following Get on the same key.
func TestStorage_SetAndGet(t *testing.T) {
	var (
		bg    = context.Background()
		store = conversation.NewMemoryStorage()
	)

	require.NoError(t, store.Set(bg, "k", "state_a"))

	got, getErr := store.Get(bg, "k")
	require.NoError(t, getErr)
	require.Equal(t, conversation.State("state_a"), got)
}
// TestStorage_Delete checks that a key removed with Delete is reported as
// conversation.ErrKeyNotFound by a following Get.
func TestStorage_Delete(t *testing.T) {
	var (
		bg    = context.Background()
		store = conversation.NewMemoryStorage()
	)

	require.NoError(t, store.Set(bg, "k", "state_a"))
	require.NoError(t, store.Delete(bg, "k"))

	_, getErr := store.Get(bg, "k")
	require.ErrorIs(t, getErr, conversation.ErrKeyNotFound)
}
func TestStorage_DeleteNonExistentIsNoop(t *testing.T) {
require.NoError(t, conversation.NewMemoryStorage().Delete(context.Background(), "gone"))
}
// ---- Key strategy tests ----------------------------------------------------
func TestKeyByUser_Variants(t *testing.T) {
t.Run("message", func(t *testing.T) {
u := msgUpd(42, 100, "hi")
require.Equal(t, "u:42", conversation.KeyByUser(&u))
})
t.Run("edited_message", func(t *testing.T) {
u := api.Update{EditedMessage: &api.Message{From: &api.User{ID: 7}, Chat: api.Chat{ID: 1}}}
require.Equal(t, "u:7", conversation.KeyByUser(&u))
})
t.Run("callback_query", func(t *testing.T) {
u := api.Update{CallbackQuery: &api.CallbackQuery{From: api.User{ID: 99}}}
require.Equal(t, "u:99", conversation.KeyByUser(&u))
})
t.Run("inline_query", func(t *testing.T) {
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
require.Equal(t, "u:5", conversation.KeyByUser(&u))
})
t.Run("empty", func(t *testing.T) {
require.Equal(t, "", conversation.KeyByUser(&api.Update{}))
})
}
func TestKeyByChat_Variants(t *testing.T) {
t.Run("message", func(t *testing.T) {
u := msgUpd(1, 200, "")
require.Equal(t, "c:200", conversation.KeyByChat(&u))
})
t.Run("inline_has_no_chat", func(t *testing.T) {
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
require.Equal(t, "", conversation.KeyByChat(&u))
})
}
// TestKeyByUserAndChat checks that the combined key has the form
// "uc:<chatID>:<userID>".
func TestKeyByUserAndChat(t *testing.T) {
	upd := msgUpd(42, 100, "")
	got := conversation.KeyByUserAndChat(&upd)
	require.Equal(t, "uc:100:42", got)
}
// ---- Handler / state machine tests -----------------------------------------
// buildConv returns the three-step conversation shared by several tests:
//
//	"/start"  -> await_name
//	any msg   -> await_age   (while in await_name)
//	any msg   -> End         (while in await_age)
//	"/cancel" -> End         (exit step)
func buildConv() *conversation.Conversation {
	toAwaitName := func(_ *dispatch.Context, _ *api.Update) error {
		return conversation.Next("await_name")
	}
	toAwaitAge := func(_ *dispatch.Context, _ *api.Update) error {
		return conversation.Next("await_age")
	}
	finish := func(_ *dispatch.Context, _ *api.Update) error {
		return conversation.End()
	}

	conv := &conversation.Conversation{}
	conv.EntryPoints = []conversation.Step{
		{Filter: hasPrefix("/start"), Handler: toAwaitName},
	}
	conv.States = map[conversation.State][]conversation.Step{
		"await_name": {{Filter: anyMsg, Handler: toAwaitAge}},
		"await_age":  {{Filter: anyMsg, Handler: finish}},
	}
	conv.Exits = []conversation.Step{
		{Filter: hasPrefix("/cancel"), Handler: finish},
	}
	return conv
}
// TestConversation_FullFlow drives the buildConv conversation from "/start"
// through both states to End, checking the stored state after every update.
func TestConversation_FullFlow(t *testing.T) {
	const key = "uc:1:42"
	bg := context.Background()

	conv := buildConv()
	downstreamCalls := 0
	mw := conv.Dispatch(dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
		downstreamCalls++
		return nil
	}))

	// Step 1: "/start" enters the conversation and lands in await_name
	// without reaching the downstream handler.
	start := msgUpd(42, 1, "/start")
	require.NoError(t, mw(makeCtx(&start), &start))
	state, err := conv.Storage.Get(bg, key)
	require.NoError(t, err)
	require.Equal(t, conversation.State("await_name"), state)
	require.Equal(t, 0, downstreamCalls, "entry claimed update")

	// Step 2: the name moves the conversation on to await_age.
	name := msgUpd(42, 1, "Alice")
	require.NoError(t, mw(makeCtx(&name), &name))
	state, err = conv.Storage.Get(bg, key)
	require.NoError(t, err)
	require.Equal(t, conversation.State("await_age"), state)

	// Step 3: the age ends the conversation, which removes the key.
	age := msgUpd(42, 1, "30")
	require.NoError(t, mw(makeCtx(&age), &age))
	_, err = conv.Storage.Get(bg, key)
	require.ErrorIs(t, err, conversation.ErrKeyNotFound)
}
// TestConversation_ExitsCancelMidFlow checks that the "/cancel" exit step
// clears the stored state of a conversation that is under way.
func TestConversation_ExitsCancelMidFlow(t *testing.T) {
	const key = "uc:1:42"
	bg := context.Background()

	conv := buildConv()
	mw := conv.Dispatch(dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
		return nil
	}))

	// Enter the conversation; some state must now be stored under key.
	start := msgUpd(42, 1, "/start")
	require.NoError(t, mw(makeCtx(&start), &start))
	_, err := conv.Storage.Get(bg, key)
	require.NoError(t, err)

	// Cancel part-way through; the key must be gone afterwards.
	abort := msgUpd(42, 1, "/cancel")
	require.NoError(t, mw(makeCtx(&abort), &abort))
	_, err = conv.Storage.Get(bg, key)
	require.ErrorIs(t, err, conversation.ErrKeyNotFound, "exit should clear state")
}
// TestConversation_FallbackFiresWhenNoStateStepMatches checks that a
// Fallbacks step runs when the active state has no step matching the update.
func TestConversation_FallbackFiresWhenNoStateStepMatches(t *testing.T) {
	var fallbackHit bool

	enter := conversation.Step{
		Filter: hasPrefix("/start"),
		Handler: func(_ *dispatch.Context, _ *api.Update) error {
			return conversation.Next("waiting")
		},
	}
	fallback := conversation.Step{
		Filter: anyMsg,
		Handler: func(_ *dispatch.Context, _ *api.Update) error {
			fallbackHit = true
			return nil
		},
	}
	conv := &conversation.Conversation{
		EntryPoints: []conversation.Step{enter},
		States: map[conversation.State][]conversation.Step{
			// "waiting" has no steps at all, so no update can match
			// inside the state itself.
			"waiting": {},
		},
		Fallbacks: []conversation.Step{fallback},
	}

	mw := conv.Dispatch(dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
		return nil
	}))

	start := msgUpd(42, 1, "/start")
	require.NoError(t, mw(makeCtx(&start), &start))

	stray := msgUpd(42, 1, "unexpected text")
	require.NoError(t, mw(makeCtx(&stray), &stray))

	require.True(t, fallbackHit, "fallback should have fired")
}
func TestConversation_NoActiveConv_PassesToDownstream(t *testing.T) {
conv := buildConv()
downstreamHit := false
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
downstreamHit = true
return nil
})
mw := conv.Dispatch(downstream)
// Random message that doesn't match /start
u := msgUpd(42, 1, "hello")
require.NoError(t, mw(makeCtx(&u), &u))
require.True(t, downstreamHit, "unmatched update should reach downstream")
}
func TestConversation_EmptyKey_PassesThrough(t *testing.T) {
// InlineQuery has no chatID → KeyByUserAndChat returns "" → pass through.
conv := buildConv()
downstreamHit := false
downstream := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
downstreamHit = true
return nil
})
mw := conv.Dispatch(downstream)
u := api.Update{InlineQuery: &api.InlineQuery{From: api.User{ID: 5}}}
require.NoError(t, mw(makeCtx(&u), &u))
require.True(t, downstreamHit)
}
// TestConversation_AllowReEntry checks that, with AllowReEntry set, an entry
// point matching mid-conversation restarts the flow instead of being handled
// by the current state's steps.
//
// Storage.Get errors are asserted rather than discarded, so a storage
// failure is reported as such and not as a confusing state mismatch.
func TestConversation_AllowReEntry(t *testing.T) {
	const key = "uc:1:42"
	bg := context.Background()

	conv := buildConv()
	conv.AllowReEntry = true
	noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
	mw := conv.Dispatch(noop)

	// Start.
	u1 := msgUpd(42, 1, "/start")
	require.NoError(t, mw(makeCtx(&u1), &u1))
	v, err := conv.Storage.Get(bg, key)
	require.NoError(t, err)
	require.Equal(t, conversation.State("await_name"), v)

	// Advance once.
	u2 := msgUpd(42, 1, "Alice")
	require.NoError(t, mw(makeCtx(&u2), &u2))
	v, err = conv.Storage.Get(bg, key)
	require.NoError(t, err)
	require.Equal(t, conversation.State("await_age"), v)

	// Re-enter with /start: the flow must restart at await_name even though
	// the conversation was mid-flow.
	u3 := msgUpd(42, 1, "/start")
	require.NoError(t, mw(makeCtx(&u3), &u3))
	v, err = conv.Storage.Get(bg, key)
	require.NoError(t, err)
	require.Equal(t, conversation.State("await_name"), v, "AllowReEntry should restart")
}
// TestConversation_NoReEntry_EntryIgnoredWhenActive checks that, with
// AllowReEntry disabled, "/start" sent mid-conversation does not restart the
// flow: it is handled by the current state's steps like any other message.
//
// The Storage.Get error is asserted rather than discarded, so a storage
// failure cannot masquerade as a state mismatch.
func TestConversation_NoReEntry_EntryIgnoredWhenActive(t *testing.T) {
	const key = "uc:1:42"
	bg := context.Background()

	conv := buildConv()
	conv.AllowReEntry = false
	noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
	mw := conv.Dispatch(noop)

	// Start.
	u1 := msgUpd(42, 1, "/start")
	require.NoError(t, mw(makeCtx(&u1), &u1))

	// Advance to await_age.
	u2 := msgUpd(42, 1, "Alice")
	require.NoError(t, mw(makeCtx(&u2), &u2))
	v, err := conv.Storage.Get(bg, key)
	require.NoError(t, err)
	require.Equal(t, conversation.State("await_age"), v)

	// "/start" again. The entry point must not fire; instead the await_age
	// step (filter anyMsg) consumes the update and returns End(), so the
	// conversation finishes rather than restarting at await_name.
	u3 := msgUpd(42, 1, "/start")
	require.NoError(t, mw(makeCtx(&u3), &u3))

	// End() removes the key.
	_, err = conv.Storage.Get(bg, key)
	require.ErrorIs(t, err, conversation.ErrKeyNotFound, "state step should have consumed /start when AllowReEntry=false")
}
// TestConversation_StayInState_NilReturn checks that a state handler
// returning nil leaves the stored state unchanged.
//
// The Storage.Get error is asserted rather than discarded, and the flag is
// named for what it records: that the state handler ran.
func TestConversation_StayInState_NilReturn(t *testing.T) {
	handlerRan := false
	conv := &conversation.Conversation{
		EntryPoints: []conversation.Step{{
			Filter: hasPrefix("/start"),
			Handler: func(c *dispatch.Context, u *api.Update) error {
				return conversation.Next("waiting")
			},
		}},
		States: map[conversation.State][]conversation.Step{
			"waiting": {{
				Filter: anyMsg,
				Handler: func(c *dispatch.Context, u *api.Update) error {
					handlerRan = true
					return nil // stay in current state
				},
			}},
		},
	}
	noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
	mw := conv.Dispatch(noop)

	u1 := msgUpd(42, 1, "/start")
	require.NoError(t, mw(makeCtx(&u1), &u1))

	u2 := msgUpd(42, 1, "something")
	require.NoError(t, mw(makeCtx(&u2), &u2))
	require.True(t, handlerRan)

	v, err := conv.Storage.Get(context.Background(), "uc:1:42")
	require.NoError(t, err)
	require.Equal(t, conversation.State("waiting"), v, "nil return should leave state unchanged")
}
// TestConversation_ActiveNoMatch_Swallows checks that, while a conversation
// is active, an update matching neither a state step nor a fallback is
// dropped instead of being forwarded downstream.
func TestConversation_ActiveNoMatch_Swallows(t *testing.T) {
	enter := conversation.Step{
		Filter: hasPrefix("/start"),
		Handler: func(_ *dispatch.Context, _ *api.Update) error {
			return conversation.Next("waiting")
		},
	}
	// The only step in "waiting" accepts "/done" and nothing else.
	done := conversation.Step{
		Filter: hasPrefix("/done"),
		Handler: func(_ *dispatch.Context, _ *api.Update) error {
			return conversation.End()
		},
	}
	conv := &conversation.Conversation{
		EntryPoints: []conversation.Step{enter},
		States: map[conversation.State][]conversation.Step{
			"waiting": {done},
		},
	}

	var reached bool
	next := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error {
		reached = true
		return nil
	})
	mw := conv.Dispatch(next)

	start := msgUpd(42, 1, "/start")
	require.NoError(t, mw(makeCtx(&start), &start))

	// "random" is not "/done" and no fallback is configured.
	stray := msgUpd(42, 1, "random")
	require.NoError(t, mw(makeCtx(&stray), &stray))

	require.False(t, reached, "active conv with no matching step should swallow")
}
// ---- Via Router.Run --------------------------------------------------------
// TestConversation_ViaRouter installs the conversation as router middleware
// and feeds a complete three-update flow through Router.Run, checking that
// every step handler fired exactly once.
//
// Run is called directly: the test needs its result before asserting, so a
// goroutine plus channel added nothing. A returned error is accepted only if
// it is (or wraps) context.Canceled, checked with ErrorIs instead of an
// identity comparison that a wrapped error would fail.
func TestConversation_ViaRouter(t *testing.T) {
	var steps atomic.Int32
	conv := &conversation.Conversation{
		EntryPoints: []conversation.Step{{
			Filter: hasPrefix("/start"),
			Handler: func(c *dispatch.Context, u *api.Update) error {
				steps.Add(1)
				return conversation.Next("await_name")
			},
		}},
		States: map[conversation.State][]conversation.Step{
			"await_name": {{
				Filter: anyMsg,
				Handler: func(c *dispatch.Context, u *api.Update) error {
					steps.Add(1)
					return conversation.Next("await_age")
				},
			}},
			"await_age": {{
				Filter: anyMsg,
				Handler: func(c *dispatch.Context, u *api.Update) error {
					steps.Add(1)
					return conversation.End()
				},
			}},
		},
	}

	router := dispatch.New(client.New("t"), dispatch.WithMaxConcurrency(0)) // serial
	router.Use(conv.Dispatch)

	ups := []api.Update{
		msgUpd(42, 1, "/start"),
		msgUpd(42, 1, "Alice"),
		msgUpd(42, 1, "30"),
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// The fake updater's channel is pre-filled and already closed; the test
	// relies on Run returning once that channel has been drained.
	if err := router.Run(ctx, newFake(ups...)); err != nil {
		require.ErrorIs(t, err, context.Canceled, "Run error")
	}

	require.Equal(t, int32(3), steps.Load(), "all three steps should have fired")
}
// ---- Concurrent storage safety ---------------------------------------------
// TestConversation_ConcurrentStorageAccess runs numUsers goroutines, each
// driving a full /start → name → age flow against the same shared storage
// but under a different key (one user/chat pair per goroutine). It is meant
// to be run with the race detector.
//
// Middleware errors are reported instead of discarded, and once all
// goroutines finish every user's key must be gone, because each flow ends
// with End().
func TestConversation_ConcurrentStorageAccess(t *testing.T) {
	const numUsers = 15
	conv := buildConv()
	noop := dispatch.Handler[*api.Update](func(_ *dispatch.Context, _ *api.Update) error { return nil })
	mw := conv.Dispatch(noop)

	var wg sync.WaitGroup
	wg.Add(numUsers)
	for i := 0; i < numUsers; i++ {
		go func(uid int64) {
			defer wg.Done()
			for _, text := range []string{"/start", "Alice", "30"} {
				u := msgUpd(uid, uid, text)
				if err := mw(makeCtx(&u), &u); err != nil {
					// t.Errorf may be called from any goroutine; the
					// require helpers call FailNow, which may not.
					t.Errorf("user %d, update %q: %v", uid, text, err)
					return
				}
			}
		}(int64(i + 1))
	}
	wg.Wait()

	// Assert on the test goroutine, after all workers are done.
	for i := 0; i < numUsers; i++ {
		uid := int64(i + 1)
		u := msgUpd(uid, uid, "")
		_, err := conv.Storage.Get(context.Background(), conversation.KeyByUserAndChat(&u))
		require.ErrorIs(t, err, conversation.ErrKeyNotFound, "user %d should have no state left", uid)
	}
}
// TestConversation_ConcurrentSameKey has several goroutines Set, Get and
// (every fourth one) Delete the same key at once. The storage must not
// panic; run with the race detector to validate its locking.
//
// Set and Delete errors are reported instead of dropped. Get's result stays
// ignored on purpose, as explained at the call.
func TestConversation_ConcurrentSameKey(t *testing.T) {
	const goroutines = 12
	s := conversation.NewMemoryStorage()
	ctx := context.Background()

	var wg sync.WaitGroup
	wg.Add(goroutines)
	for i := 0; i < goroutines; i++ {
		go func(i int) {
			defer wg.Done()
			// t.Errorf is used because it may be called from any goroutine.
			if err := s.Set(ctx, "shared", conversation.State("step")); err != nil {
				t.Errorf("goroutine %d: Set: %v", i, err)
			}
			// Another goroutine may have deleted the key between the Set
			// above and this Get, so both a hit and a miss are valid here.
			_, _ = s.Get(ctx, "shared")
			if i%4 == 0 {
				if err := s.Delete(ctx, "shared"); err != nil {
					t.Errorf("goroutine %d: Delete: %v", i, err)
				}
			}
		}(i)
	}
	wg.Wait()
}