mirror of
https://github.com/lukaszraczylo/go-telegram.git
synced 2026-06-05 22:43:59 +00:00
Initial release of go-telegram
A fully-generated, strongly-typed Go client for the Telegram Bot API. * 176 methods + 301 types generated from Bot API v10.0 * 1408 auto-generated tests (8 scenarios per method) * Typed unions throughout — no 'any' in the public surface * Pluggable HTTP transport and JSON codec (default goccy/go-json) * Built-in retry middleware honouring Telegram's retry_after * Generic dispatcher with filters and conversation handlers * Self-verifying codegen pipeline (regen → audit → emit → run tests) * 14 example bots covering common patterns
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BackoffStrategy returns the duration to wait before the next attempt
|
||||
// after `attempt` consecutive failures (1-based). Implementations must
|
||||
// be safe to call from a single goroutine.
|
||||
type BackoffStrategy interface {
|
||||
NextDelay(attempt int) time.Duration
|
||||
}
|
||||
|
||||
// ExponentialBackoff implements capped exponential back-off with jitter.
|
||||
// Defaults: Base=500ms, Max=30s, Factor=2.0, Jitter=0.2.
|
||||
type ExponentialBackoff struct {
|
||||
Base time.Duration
|
||||
Max time.Duration
|
||||
Factor float64
|
||||
Jitter float64 // 0..1; fraction of computed delay added/subtracted at random
|
||||
}
|
||||
|
||||
// DefaultBackoff returns an ExponentialBackoff with library defaults.
|
||||
func DefaultBackoff() *ExponentialBackoff {
|
||||
return &ExponentialBackoff{
|
||||
Base: 500 * time.Millisecond,
|
||||
Max: 30 * time.Second,
|
||||
Factor: 2.0,
|
||||
Jitter: 0.2,
|
||||
}
|
||||
}
|
||||
|
||||
// NextDelay implements BackoffStrategy.
|
||||
func (b *ExponentialBackoff) NextDelay(attempt int) time.Duration {
|
||||
if attempt < 1 {
|
||||
attempt = 1
|
||||
}
|
||||
d := float64(b.Base) * math.Pow(b.Factor, float64(attempt-1))
|
||||
if b.Jitter > 0 {
|
||||
d *= 1 + (rand.Float64()*2-1)*b.Jitter
|
||||
}
|
||||
if d > float64(b.Max) {
|
||||
d = float64(b.Max)
|
||||
}
|
||||
if d < 0 {
|
||||
d = 0
|
||||
}
|
||||
return time.Duration(d)
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestExponentialBackoff_MaxCapAfterJitter verifies that the Max cap is applied
|
||||
// after jitter so no delay can exceed Max regardless of jitter magnitude.
|
||||
func TestExponentialBackoff_MaxCapAfterJitter(t *testing.T) {
|
||||
b := &ExponentialBackoff{
|
||||
Base: 10 * time.Second,
|
||||
Max: 20 * time.Second,
|
||||
Factor: 2.0,
|
||||
Jitter: 0.5,
|
||||
}
|
||||
|
||||
// Run many times to account for randomness.
|
||||
for i := 0; i < 10_000; i++ {
|
||||
d := b.NextDelay(10)
|
||||
if d > b.Max {
|
||||
t.Fatalf("attempt 10: got %v, want ≤ %v (jitter exceeded Max cap)", d, b.Max)
|
||||
}
|
||||
if d < 0 {
|
||||
t.Fatalf("attempt 10: got negative delay %v", d)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestExponentialBackoff_ZeroAttemptClamped ensures attempt < 1 is treated as 1.
|
||||
func TestExponentialBackoff_ZeroAttemptClamped(t *testing.T) {
|
||||
b := DefaultBackoff()
|
||||
d0 := b.NextDelay(0)
|
||||
d1 := b.NextDelay(1)
|
||||
// Both should be in the same ballpark (Base ± Jitter*Base).
|
||||
maxBase := float64(b.Base) * (1 + b.Jitter)
|
||||
if float64(d0) > maxBase || float64(d1) > maxBase {
|
||||
t.Fatalf("unexpected delay: d0=%v d1=%v maxBase=%v", d0, d1, time.Duration(maxBase))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/client"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// LongPoller — unauthorized error causes immediate return
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLongPoller_UnauthorizedExits(t *testing.T) {
|
||||
m := &mockDoer{}
|
||||
m.On("Do", mock.Anything).Return(
|
||||
resp(`{"ok":false,"error_code":401,"description":"Unauthorized"}`), nil,
|
||||
).Once()
|
||||
|
||||
b := client.New("bad-token", client.WithHTTPClient(m))
|
||||
p := NewLongPoller(b)
|
||||
p.Timeout = 0
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := p.Run(ctx)
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, client.ErrUnauthorized), "expected unauthorized: %v", err)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// LongPoller — ctx cancelled while waiting for retry backoff
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLongPoller_CtxCancelledDuringBackoff(t *testing.T) {
|
||||
m := &mockDoer{}
|
||||
var callCount int
|
||||
m.On("Do", mock.Anything).Run(func(args mock.Arguments) {
|
||||
callCount++
|
||||
}).Return(nil, errors.New("network error")).Maybe()
|
||||
|
||||
b := client.New("t", client.WithHTTPClient(m))
|
||||
p := NewLongPoller(b)
|
||||
p.Timeout = 0
|
||||
// Long backoff ensures ctx cancels before retry fires.
|
||||
p.Backoff = &ExponentialBackoff{Base: 5 * time.Second, Max: 5 * time.Second, Factor: 1}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := p.Run(ctx)
|
||||
require.Error(t, err)
|
||||
// Should fail fast, not wait the full 5s backoff.
|
||||
require.LessOrEqual(t, callCount, 3)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// LongPoller — AllowedTypes field
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLongPoller_AllowedTypes(t *testing.T) {
|
||||
m := &mockDoer{}
|
||||
var seenBody string
|
||||
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
|
||||
b, _ := io.ReadAll(r.Body)
|
||||
seenBody = string(b)
|
||||
return true
|
||||
})).Return(resp(`{"ok":true,"result":[]}`), nil).Once()
|
||||
m.On("Do", mock.Anything).Return(resp(`{"ok":true,"result":[]}`), nil).Maybe()
|
||||
|
||||
b := client.New("t", client.WithHTTPClient(m))
|
||||
p := NewLongPoller(b)
|
||||
p.Timeout = 0
|
||||
p.AllowedTypes = []api.UpdateType{"message", "callback_query"}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = p.Run(ctx)
|
||||
|
||||
require.Contains(t, seenBody, "allowed_updates")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WebhookServer — ListenAndServe error (bind on in-use port)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestWebhookServer_ListenAndServeError(t *testing.T) {
|
||||
// Bind a port to block the webhook server.
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = l.Close() })
|
||||
addr := l.Addr().String()
|
||||
|
||||
b := client.New("t")
|
||||
w := NewWebhookServer(b)
|
||||
|
||||
ctx := context.Background()
|
||||
err = w.ListenAndServe(ctx, addr)
|
||||
require.Error(t, err, "should error when port is in use")
|
||||
require.False(t, errors.Is(err, http.ErrServerClosed))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WebhookServer — body too large (> 1 MiB)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestWebhookServer_BodyTooLarge(t *testing.T) {
|
||||
b := client.New("t")
|
||||
w := NewWebhookServer(b)
|
||||
|
||||
// Construct a body slightly over 1 MiB.
|
||||
bigBody := bytes.Repeat([]byte("x"), 1<<20+1)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/", bytes.NewReader(bigBody))
|
||||
rw := newTestResponseWriter()
|
||||
w.ServeHTTP(rw, req)
|
||||
require.Equal(t, http.StatusRequestEntityTooLarge, rw.code)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WebhookServer — Stop when srv is nil (no ListenAndServe called)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestWebhookServer_StopNoServer(t *testing.T) {
|
||||
b := client.New("t")
|
||||
w := NewWebhookServer(b)
|
||||
require.NoError(t, w.Stop(context.Background()))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WebhookServer — no secret token, any POST accepted
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestWebhookServer_NoSecretAllowsAnyPost(t *testing.T) {
|
||||
b := client.New("t")
|
||||
w := NewWebhookServer(b)
|
||||
|
||||
body := `{"update_id":99}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
// No secret header set.
|
||||
rw := newTestResponseWriter()
|
||||
// ServeHTTP would block on w.out <- u unless we drain it.
|
||||
go func() {
|
||||
for range w.Updates() {
|
||||
}
|
||||
}()
|
||||
w.ServeHTTP(rw, req)
|
||||
// Bad JSON would return 400; update_id-only body is valid enough for Update.
|
||||
require.NotEqual(t, http.StatusUnauthorized, rw.code)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExponentialBackoff — negative attempt clamped to 1
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExponentialBackoff_NegativeAttempt(t *testing.T) {
|
||||
b := DefaultBackoff()
|
||||
d := b.NextDelay(-5)
|
||||
require.GreaterOrEqual(t, d, time.Duration(0))
|
||||
require.LessOrEqual(t, d, b.Max)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type testResponseWriter struct {
|
||||
code int
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func newTestResponseWriter() *testResponseWriter {
|
||||
return &testResponseWriter{code: http.StatusOK, header: http.Header{}}
|
||||
}
|
||||
|
||||
func (r *testResponseWriter) Header() http.Header { return r.header }
|
||||
func (r *testResponseWriter) Write(b []byte) (int, error) { return len(b), nil }
|
||||
func (r *testResponseWriter) WriteHeader(statusCode int) { r.code = statusCode }
|
||||
@@ -0,0 +1,129 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/client"
|
||||
)
|
||||
|
||||
// LongPoller pulls updates via Bot.GetUpdates in a loop, advancing the
|
||||
// offset cursor after each batch. It applies BackoffStrategy on transient
|
||||
// errors (network failures, 5xx, 429).
|
||||
//
|
||||
// At-least-once semantics on shutdown: when ctx is cancelled or Stop is
|
||||
// called mid-batch, any updates already fetched but not yet dispatched are
|
||||
// dropped without advancing the offset. On the next restart those updates
|
||||
// will be re-delivered by Telegram.
|
||||
type LongPoller struct {
|
||||
Bot *client.Bot
|
||||
Timeout int // seconds, default 30
|
||||
Limit int // 1..100, default 100
|
||||
AllowedTypes []api.UpdateType
|
||||
Backoff BackoffStrategy
|
||||
|
||||
out chan api.Update
|
||||
once sync.Once
|
||||
stop chan struct{}
|
||||
}
|
||||
|
||||
// NewLongPoller constructs a LongPoller with sensible defaults.
|
||||
func NewLongPoller(b *client.Bot) *LongPoller {
|
||||
return &LongPoller{
|
||||
Bot: b,
|
||||
Timeout: 30,
|
||||
Limit: 100,
|
||||
Backoff: DefaultBackoff(),
|
||||
out: make(chan api.Update, 64),
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Updates implements Updater.
|
||||
func (p *LongPoller) Updates() <-chan api.Update { return p.out }
|
||||
|
||||
// Run implements Updater. It blocks until ctx is cancelled, Stop is
|
||||
// called, or a fatal error occurs (e.g. unauthorized). See LongPoller
|
||||
// for at-least-once delivery semantics on shutdown.
|
||||
func (p *LongPoller) Run(ctx context.Context) error {
|
||||
defer close(p.out)
|
||||
|
||||
var offset int64
|
||||
failures := 0
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-p.stop:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
params := &api.GetUpdatesParams{Offset: &offset}
|
||||
if p.Limit > 0 {
|
||||
lim := int64(p.Limit)
|
||||
params.Limit = &lim
|
||||
}
|
||||
if p.Timeout > 0 {
|
||||
to := int64(p.Timeout)
|
||||
params.Timeout = &to
|
||||
}
|
||||
if len(p.AllowedTypes) > 0 {
|
||||
allowed := make([]string, len(p.AllowedTypes))
|
||||
for i, t := range p.AllowedTypes {
|
||||
allowed[i] = string(t)
|
||||
}
|
||||
params.AllowedUpdates = allowed
|
||||
}
|
||||
ups, err := api.GetUpdates(ctx, p.Bot, params)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return err
|
||||
}
|
||||
// Fatal: unauthorized -> bail.
|
||||
if errors.Is(err, client.ErrUnauthorized) {
|
||||
return err
|
||||
}
|
||||
var ae *client.APIError
|
||||
var delay time.Duration
|
||||
if errors.As(err, &ae) && ae.RetryAfter() > 0 {
|
||||
delay = ae.RetryAfter()
|
||||
// Don't escalate failures count — Telegram is dictating the wait.
|
||||
} else {
|
||||
failures++
|
||||
delay = p.Backoff.NextDelay(failures)
|
||||
}
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-p.stop:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
failures = 0
|
||||
|
||||
for _, u := range ups {
|
||||
select {
|
||||
case p.out <- u:
|
||||
if u.UpdateID >= offset {
|
||||
offset = u.UpdateID + 1
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-p.stop:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop implements Updater.
|
||||
func (p *LongPoller) Stop(ctx context.Context) error {
|
||||
p.once.Do(func() { close(p.stop) })
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/client"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockDoer struct{ mock.Mock }
|
||||
|
||||
func (m *mockDoer) Do(r *http.Request) (*http.Response, error) {
|
||||
args := m.Called(r)
|
||||
if v := args.Get(0); v != nil {
|
||||
return v.(*http.Response), args.Error(1)
|
||||
}
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
func resp(body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(bytes.NewBufferString(body)),
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestLongPoller_DeliversUpdatesAndAdvancesOffset(t *testing.T) {
|
||||
m := &mockDoer{}
|
||||
m.On("Do", mock.Anything).Return(
|
||||
resp(`{"ok":true,"result":[{"update_id":10,"message":{"message_id":1,"date":0,"chat":{"id":1,"type":"private"},"text":"hi"}}]}`),
|
||||
nil,
|
||||
).Once()
|
||||
m.On("Do", mock.Anything).Return(
|
||||
resp(`{"ok":true,"result":[{"update_id":11,"message":{"message_id":2,"date":0,"chat":{"id":1,"type":"private"},"text":"there"}}]}`),
|
||||
nil,
|
||||
).Once()
|
||||
m.On("Do", mock.Anything).Return(
|
||||
resp(`{"ok":true,"result":[]}`),
|
||||
nil,
|
||||
).Maybe()
|
||||
|
||||
b := client.New("t", client.WithHTTPClient(m))
|
||||
p := NewLongPoller(b)
|
||||
p.Timeout = 0
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
go func() { _ = p.Run(ctx) }()
|
||||
|
||||
u1 := <-p.Updates()
|
||||
require.Equal(t, int64(10), u1.UpdateID)
|
||||
u2 := <-p.Updates()
|
||||
require.Equal(t, int64(11), u2.UpdateID)
|
||||
}
|
||||
|
||||
func TestLongPoller_BackoffOnNetworkError(t *testing.T) {
|
||||
m := &mockDoer{}
|
||||
var attempts atomic.Int32
|
||||
m.On("Do", mock.Anything).Run(func(args mock.Arguments) {
|
||||
attempts.Add(1)
|
||||
}).Return(nil, io.ErrUnexpectedEOF).Maybe()
|
||||
|
||||
b := client.New("t", client.WithHTTPClient(m))
|
||||
p := NewLongPoller(b)
|
||||
p.Timeout = 0
|
||||
p.Backoff = &ExponentialBackoff{Base: 5 * time.Millisecond, Max: 5 * time.Millisecond}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_ = p.Run(ctx)
|
||||
require.GreaterOrEqual(t, attempts.Load(), int32(2), "should retry at least once")
|
||||
}
|
||||
|
||||
func TestLongPoller_StopCloses(t *testing.T) {
|
||||
m := &mockDoer{}
|
||||
m.On("Do", mock.Anything).Return(resp(`{"ok":true,"result":[]}`), nil).Maybe()
|
||||
|
||||
b := client.New("t", client.WithHTTPClient(m))
|
||||
p := NewLongPoller(b)
|
||||
p.Timeout = 0
|
||||
|
||||
ctx := context.Background()
|
||||
done := make(chan struct{})
|
||||
go func() { _ = p.Run(ctx); close(done) }()
|
||||
|
||||
require.NoError(t, p.Stop(ctx))
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Run did not exit after Stop")
|
||||
}
|
||||
|
||||
// Channel must be closed.
|
||||
_, ok := <-p.Updates()
|
||||
require.False(t, ok, "expected closed channel after Stop")
|
||||
}
|
||||
func TestLongPoller_HonoursRetryAfterOn429(t *testing.T) {
|
||||
m := &mockDoer{}
|
||||
var requestTimes []time.Time
|
||||
var mu sync.Mutex
|
||||
|
||||
record := func(args mock.Arguments) {
|
||||
mu.Lock()
|
||||
requestTimes = append(requestTimes, time.Now())
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// First call: 429 with retry_after=1.
|
||||
m.On("Do", mock.Anything).
|
||||
Run(record).
|
||||
Return(resp(`{"ok":false,"error_code":429,"description":"Too Many Requests","parameters":{"retry_after":1}}`), nil).
|
||||
Once()
|
||||
// Subsequent calls: empty success.
|
||||
m.On("Do", mock.Anything).
|
||||
Run(record).
|
||||
Return(resp(`{"ok":true,"result":[]}`), nil).
|
||||
Maybe()
|
||||
|
||||
b := client.New("t", client.WithHTTPClient(m))
|
||||
p := NewLongPoller(b)
|
||||
p.Timeout = 0
|
||||
// Backoff base is huge so if it were used we'd see >>1s delay.
|
||||
p.Backoff = &ExponentialBackoff{Base: 10 * time.Second, Max: 30 * time.Second}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2500*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = p.Run(ctx)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.GreaterOrEqual(t, len(requestTimes), 2, "expected at least 2 requests")
|
||||
gap := requestTimes[1].Sub(requestTimes[0])
|
||||
require.GreaterOrEqual(t, gap, 900*time.Millisecond, "should have waited ~1s per retry_after, got %v", gap)
|
||||
require.Less(t, gap, 3*time.Second, "should not have waited backoff base (10s), got %v", gap)
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
)
|
||||
|
||||
// Updater is the abstraction over update sources. Implementations must:
|
||||
// - return a channel from Updates() that receives every Update they read.
|
||||
// - close the channel after Run returns.
|
||||
// - honour ctx cancellation in Run.
|
||||
type Updater interface {
|
||||
// Updates returns the channel updates flow into. Multiple readers
|
||||
// is implementation-defined; users should treat it as single-reader.
|
||||
Updates() <-chan api.Update
|
||||
// Run blocks until ctx is cancelled or a fatal error occurs. It is
|
||||
// the user's responsibility to call Run in a goroutine if needed.
|
||||
Run(ctx context.Context) error
|
||||
// Stop signals Run to exit and waits for the channel to drain.
|
||||
// Implementations must be idempotent.
|
||||
Stop(ctx context.Context) error
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
// Package transport provides update delivery mechanisms (long-poll and
|
||||
// webhook) that feed updates into the dispatch package's Router.
|
||||
//
|
||||
// All implementations satisfy the Updater interface so user code can
|
||||
// swap one for the other without touching handler logic.
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/api"
|
||||
"github.com/lukaszraczylo/go-telegram/client"
|
||||
)
|
||||
|
||||
// WebhookServer implements Updater by exposing an http.Handler that
|
||||
// receives updates from Telegram. It can be mounted on the user's own
|
||||
// HTTP server (via ServeHTTP) or run standalone (via ListenAndServe).
|
||||
type WebhookServer struct {
|
||||
Bot *client.Bot
|
||||
SecretToken string // verify X-Telegram-Bot-Api-Secret-Token; empty disables
|
||||
|
||||
out chan api.Update
|
||||
once sync.Once
|
||||
stop chan struct{}
|
||||
mu sync.Mutex
|
||||
handlers sync.WaitGroup
|
||||
|
||||
srv *http.Server
|
||||
}
|
||||
|
||||
// WebhookOption configures a WebhookServer at construction time.
|
||||
type WebhookOption func(*webhookOptions)
|
||||
|
||||
type webhookOptions struct {
|
||||
bufferSize int
|
||||
}
|
||||
|
||||
// WithBufferSize sets the size of the updates channel buffer.
|
||||
// Default is 64.
|
||||
func WithBufferSize(n int) WebhookOption {
|
||||
return func(o *webhookOptions) { o.bufferSize = n }
|
||||
}
|
||||
|
||||
// NewWebhookServer constructs a WebhookServer with default buffer size (64).
|
||||
// Use WithBufferSize to override.
|
||||
func NewWebhookServer(b *client.Bot, opts ...WebhookOption) *WebhookServer {
|
||||
cfg := webhookOptions{bufferSize: 64}
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
return &WebhookServer{
|
||||
Bot: b,
|
||||
out: make(chan api.Update, cfg.bufferSize),
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Updates implements Updater.
|
||||
func (w *WebhookServer) Updates() <-chan api.Update { return w.out }
|
||||
|
||||
// Run implements Updater. It blocks until Stop is called or ctx is
|
||||
// cancelled. If the server has not been started via ListenAndServe, Run
|
||||
// only watches for shutdown — the user is expected to mount ServeHTTP
|
||||
// on their own router.
|
||||
func (w *WebhookServer) Run(ctx context.Context) error {
|
||||
defer close(w.out)
|
||||
defer w.handlers.Wait() // drain in-flight ServeHTTP calls before closing out (LIFO: runs first)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-w.stop:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Stop implements Updater.
|
||||
func (w *WebhookServer) Stop(ctx context.Context) error {
|
||||
w.once.Do(func() { close(w.stop) })
|
||||
w.mu.Lock()
|
||||
srv := w.srv
|
||||
w.mu.Unlock()
|
||||
if srv != nil {
|
||||
return srv.Shutdown(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler. Telegram POSTs each update as JSON
|
||||
// to this endpoint. Non-POST requests get 405; bad bodies get 400; secret
|
||||
// token mismatches get 401.
|
||||
func (w *WebhookServer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
rw.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if w.SecretToken != "" {
|
||||
got := r.Header.Get("X-Telegram-Bot-Api-Secret-Token")
|
||||
if subtle.ConstantTimeCompare([]byte(got), []byte(w.SecretToken)) != 1 {
|
||||
rw.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
w.handlers.Add(1)
|
||||
defer w.handlers.Done()
|
||||
defer func() { _ = r.Body.Close() }()
|
||||
|
||||
const max = 1 << 20 // 1 MiB cap on body
|
||||
buf := make([]byte, 0, 1024)
|
||||
tmp := make([]byte, 4096)
|
||||
for {
|
||||
n, err := r.Body.Read(tmp)
|
||||
if n > 0 {
|
||||
buf = append(buf, tmp[:n]...)
|
||||
if len(buf) > max {
|
||||
rw.WriteHeader(http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
}
|
||||
if errors.Is(err, http.ErrBodyReadAfterClose) || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var u api.Update
|
||||
codec := w.Bot.Codec()
|
||||
if err := codec.Unmarshal(buf, &u); err != nil {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case w.out <- u:
|
||||
case <-w.stop:
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// ListenAndServe starts an HTTP server on addr and blocks until Stop is
|
||||
// called (which triggers Shutdown with the caller's context) or the server
|
||||
// returns an error other than http.ErrServerClosed. Callers must invoke
|
||||
// Stop(ctx) to cleanly shut down the server; the ctx passed here is only
|
||||
// used as the server's base context for incoming requests.
|
||||
func (w *WebhookServer) ListenAndServe(ctx context.Context, addr string) error {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/", w)
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
BaseContext: func(net.Listener) context.Context { return ctx },
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
w.mu.Lock()
|
||||
w.srv = srv
|
||||
w.mu.Unlock()
|
||||
err := srv.ListenAndServe()
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/go-telegram/client"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWebhook_DeliversUpdate(t *testing.T) {
|
||||
b := client.New("t")
|
||||
w := NewWebhookServer(b)
|
||||
w.SecretToken = "secret"
|
||||
|
||||
srv := httptest.NewServer(w)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
body := `{"update_id":1,"message":{"message_id":1,"date":0,"chat":{"id":1,"type":"private"},"text":"hi"}}`
|
||||
req, _ := http.NewRequest(http.MethodPost, srv.URL, strings.NewReader(body))
|
||||
req.Header.Set("X-Telegram-Bot-Api-Secret-Token", "secret")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
_ = resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
select {
|
||||
case u := <-w.Updates():
|
||||
require.Equal(t, int64(1), u.UpdateID)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("update not delivered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhook_RejectsBadSecret(t *testing.T) {
|
||||
b := client.New("t")
|
||||
w := NewWebhookServer(b)
|
||||
w.SecretToken = "secret"
|
||||
|
||||
srv := httptest.NewServer(w)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, srv.URL, strings.NewReader(`{}`))
|
||||
req.Header.Set("X-Telegram-Bot-Api-Secret-Token", "wrong")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
_ = resp.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestWebhook_RejectsNonPOST(t *testing.T) {
|
||||
w := NewWebhookServer(client.New("t"))
|
||||
srv := httptest.NewServer(w)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
resp, err := http.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
_ = resp.Body.Close()
|
||||
require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestWebhook_RejectsBadJSON(t *testing.T) {
|
||||
w := NewWebhookServer(client.New("t"))
|
||||
srv := httptest.NewServer(w)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
resp, err := http.Post(srv.URL, "application/json", bytes.NewBufferString("not json"))
|
||||
require.NoError(t, err)
|
||||
_ = resp.Body.Close()
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestWebhook_StopExitsRun(t *testing.T) {
|
||||
w := NewWebhookServer(client.New("t"))
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() { _ = w.Run(context.Background()); close(done) }()
|
||||
|
||||
require.NoError(t, w.Stop(context.Background()))
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Run did not exit after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebhook_ConcurrentStopNoPanic fires many concurrent requests while
|
||||
// simultaneously calling Stop, and asserts no panic (send on closed channel).
|
||||
// Run under -race to verify mutex and WaitGroup correctness.
|
||||
func TestWebhook_ConcurrentStopNoPanic(t *testing.T) {
|
||||
body := `{"update_id":1,"message":{"message_id":1,"date":0,"chat":{"id":1,"type":"private"},"text":"hi"}}`
|
||||
|
||||
for range 20 {
|
||||
w := NewWebhookServer(client.New("t"), WithBufferSize(256))
|
||||
srv := httptest.NewServer(w)
|
||||
|
||||
// Drain updates so ServeHTTP doesn't block on a full channel.
|
||||
go func() {
|
||||
for range w.Updates() {
|
||||
}
|
||||
}()
|
||||
|
||||
// Run in background.
|
||||
go func() { _ = w.Run(context.Background()) }()
|
||||
|
||||
// Fire concurrent requests.
|
||||
const goroutines = 20
|
||||
ready := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-ready
|
||||
for j := 0; j < 5; j++ {
|
||||
req, _ := http.NewRequest(http.MethodPost, srv.URL, strings.NewReader(body))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err == nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
close(ready)
|
||||
time.Sleep(5 * time.Millisecond) // let some requests land before Stop
|
||||
srv.Close()
|
||||
_ = w.Stop(context.Background())
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user