Initial release of go-telegram

A fully-generated, strongly-typed Go client for the Telegram Bot API.

* 176 methods + 301 types generated from Bot API v10.0
* 1408 auto-generated tests (8 scenarios per method)
* Typed unions throughout — no 'any' in the public surface
* Pluggable HTTP transport and JSON codec (default goccy/go-json)
* Built-in retry middleware honouring Telegram's retry_after
* Generic dispatcher with filters and conversation handlers
* Self-verifying codegen pipeline (regen → audit → emit → run tests)
* 14 example bots covering common patterns
This commit is contained in:
2026-05-09 13:09:27 +01:00
commit ac7cae8fa7
164 changed files with 100239 additions and 0 deletions
+171
View File
@@ -0,0 +1,171 @@
package client
import (
"bytes"
"context"
"errors"
"github.com/goccy/go-json"
"io"
"net/http"
"reflect"
)
// Call is the single point through which every Telegram Bot API method
// invocation flows. It marshals the request, signs the URL with the bot
// token, dispatches via HTTPDoer, decodes the Result envelope, and
// translates non-OK responses into typed errors.
//
// It is generic over both request and response types. Methods with no
// parameters may pass a nil Req; the helper sends "{}" in that case so
// Telegram receives a syntactically valid empty object.
//
// Call is exported because the api package (which lives outside this one)
// invokes it from generated method wrappers. User code should not normally
// call it directly — use the typed wrappers in package api instead.
func Call[Req any, Resp any](ctx context.Context, b *Bot, method string, req Req) (Resp, error) {
var zero Resp
if mp, ok := any(req).(multipartRequest); ok {
if mp == nil {
return zero, &ParseError{Err: errors.New("client: nil multipart request")}
}
if mp.HasFile() {
return callMultipart[Resp](ctx, b, method, mp)
}
}
body, err := encodeJSONBody(b.codec, req)
if err != nil {
return zero, err
}
url := b.base + "/bot" + b.token + "/" + method
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
if err != nil {
return zero, &NetworkError{Err: err}
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "application/json")
resp, err := b.http.Do(httpReq)
if err != nil {
// Surface ctx errors faithfully so callers can errors.Is(err, ctx.Err()).
if ctxErr := ctx.Err(); ctxErr != nil {
return zero, ctxErr
}
return zero, &NetworkError{Err: err}
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return zero, &NetworkError{Err: err}
}
return decodeResult[Resp](b.codec, raw)
}
// CallRaw is like Call but returns the raw JSON of the result field
// instead of decoding it into a typed value. Generated method wrappers
// for sealed-interface return types (ChatMember, MenuButton, etc.) use
// this helper, then dispatch through the union's UnmarshalXxx function.
//
// CallRaw still translates non-OK responses into *APIError just like Call.
func CallRaw[Req any](ctx context.Context, b *Bot, method string, req Req) (json.RawMessage, error) {
if mp, ok := any(req).(multipartRequest); ok {
if mp == nil {
return nil, &ParseError{Err: errors.New("client: nil multipart request")}
}
if mp.HasFile() {
return callMultipartRaw(ctx, b, method, mp)
}
}
body, err := encodeJSONBody(b.codec, req)
if err != nil {
return nil, err
}
url := b.base + "/bot" + b.token + "/" + method
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
if err != nil {
return nil, &NetworkError{Err: err}
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "application/json")
resp, err := b.http.Do(httpReq)
if err != nil {
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
}
return nil, &NetworkError{Err: err}
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return nil, &NetworkError{Err: err}
}
return decodeResultRaw(b.codec, raw)
}
// decodeResultRaw is decodeResult's sibling that returns the raw result
// field instead of typing it.
func decodeResultRaw(codec Codec, raw []byte) (json.RawMessage, error) {
var env Result[json.RawMessage]
if err := codec.Unmarshal(raw, &env); err != nil {
return nil, &ParseError{Err: err, Body: copyBody(raw)}
}
if !env.OK {
return nil, mapAPIError(env.ErrorCode, env.Description, env.Parameters)
}
return env.Result, nil
}
// encodeJSONBody marshals req to a JSON body. A nil interface or nil
// pointer req yields "{}" so Telegram receives a valid empty object.
func encodeJSONBody(codec Codec, req any) (io.Reader, error) {
if req == nil || isNilPointer(req) {
return bytes.NewBufferString("{}"), nil
}
data, err := codec.Marshal(req)
if err != nil {
return nil, &ParseError{Err: err}
}
return bytes.NewReader(data), nil
}
// decodeResult unmarshals raw into Result[Resp] and translates non-OK
// responses into *APIError.
func decodeResult[Resp any](codec Codec, raw []byte) (Resp, error) {
var zero Resp
var env Result[Resp]
if err := codec.Unmarshal(raw, &env); err != nil {
return zero, &ParseError{Err: err, Body: copyBody(raw)}
}
if !env.OK {
return zero, mapAPIError(env.ErrorCode, env.Description, env.Parameters)
}
return env.Result, nil
}
// isNilPointer returns true when v is a typed nil pointer (the interface
// itself is non-nil because it carries a type, but the underlying value
// is nil). One reflect call per request; not on a hot path that demands
// allocation-freedom.
func isNilPointer(v any) bool {
rv := reflect.ValueOf(v)
return rv.Kind() == reflect.Ptr && rv.IsNil()
}
func copyBody(b []byte) []byte {
const max = 4096
if len(b) > max {
b = b[:max]
}
out := make([]byte, len(b))
copy(out, b)
return out
}
+121
View File
@@ -0,0 +1,121 @@
package client
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type mockDoer struct{ mock.Mock }
func (m *mockDoer) Do(r *http.Request) (*http.Response, error) {
args := m.Called(r)
if v := args.Get(0); v != nil {
return v.(*http.Response), args.Error(1)
}
return nil, args.Error(1)
}
func newResp(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Body: io.NopCloser(bytes.NewBufferString(body)),
Header: http.Header{"Content-Type": []string{"application/json"}},
}
}
type echoReq struct {
ChatID int64 `json:"chat_id"`
Text string `json:"text"`
}
type echoResp struct {
MessageID int64 `json:"message_id"`
}
func TestCall_Success(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
if !strings.HasSuffix(r.URL.Path, "/bot123:abc/sendEcho") {
return false
}
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(r.Body)
return strings.Contains(buf.String(), `"chat_id":42`)
})).Return(newResp(200, `{"ok":true,"result":{"message_id":7}}`), nil)
b := New("123:abc", WithHTTPClient(m))
out, err := Call[*echoReq, *echoResp](context.Background(), b, "sendEcho", &echoReq{ChatID: 42, Text: "hi"})
require.NoError(t, err)
require.Equal(t, int64(7), out.MessageID)
m.AssertExpectations(t)
}
func TestCall_APIError(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.Anything).Return(
newResp(200, `{"ok":false,"error_code":429,"description":"Too Many Requests: retry after 3","parameters":{"retry_after":3}}`), nil)
b := New("t", WithHTTPClient(m))
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", &echoReq{})
require.Error(t, err)
var ae *APIError
require.ErrorAs(t, err, &ae)
require.Equal(t, 429, ae.Code)
require.True(t, ae.IsRetryable())
require.True(t, errors.Is(err, ErrTooManyRequests))
}
func TestCall_NetworkError(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.Anything).Return(nil, errors.New("dial timeout"))
b := New("t", WithHTTPClient(m))
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", &echoReq{})
require.Error(t, err)
var ne *NetworkError
require.ErrorAs(t, err, &ne)
}
func TestCall_ParseError(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.Anything).Return(newResp(200, `not json`), nil)
b := New("t", WithHTTPClient(m))
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", &echoReq{})
require.Error(t, err)
var pe *ParseError
require.ErrorAs(t, err, &pe)
}
func TestCall_ContextCanceled(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.Anything).Return(nil, context.Canceled).Maybe()
ctx, cancel := context.WithCancel(context.Background())
cancel()
b := New("t", WithHTTPClient(m))
_, err := Call[*echoReq, *echoResp](ctx, b, "x", &echoReq{})
require.ErrorIs(t, err, context.Canceled)
}
func TestCall_NilRequest(t *testing.T) {
// Methods with no params (e.g. getMe) may pass a nil Req value.
m := &mockDoer{}
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(r.Body)
return buf.String() == "{}"
})).Return(newResp(200, `{"ok":true,"result":{"message_id":0}}`), nil)
b := New("t", WithHTTPClient(m))
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", nil)
require.NoError(t, err)
}
+48
View File
@@ -0,0 +1,48 @@
package client
const defaultBaseURL = "https://api.telegram.org"
// Bot is the Telegram Bot API client. Construct via New. All API methods
// (declared in package api) hang off *Bot via thin wrappers around call.
type Bot struct {
token string
base string
http HTTPDoer
codec Codec
logger Logger
}
// Token returns the bot token. Exposed for advanced use cases (custom
// transports, manual URL building); ordinary code does not need it.
func (b *Bot) Token() string { return b.token }
// BaseURL returns the configured Telegram API base URL.
func (b *Bot) BaseURL() string { return b.base }
// HTTP returns the underlying HTTPDoer. Exposed for adapters that need
// to share connection pools or for diagnostic checks.
func (b *Bot) HTTP() HTTPDoer { return b.http }
// Codec returns the configured Codec.
func (b *Bot) Codec() Codec { return b.codec }
// Logger returns the configured Logger.
func (b *Bot) Logger() Logger { return b.logger }
// New constructs a Bot with the given token and optional configuration.
// The default HTTP client is tuned for long-poll workloads (see
// NewDefaultHTTPDoer); the default codec wraps encoding/json; the default
// logger discards records.
func New(token string, opts ...Option) *Bot {
b := &Bot{
token: token,
base: defaultBaseURL,
http: NewDefaultHTTPDoer(),
codec: DefaultCodec{},
logger: NoopLogger{},
}
for _, o := range opts {
o(b)
}
return b
}
+42
View File
@@ -0,0 +1,42 @@
package client
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestNew_Defaults(t *testing.T) {
b := New("123:abc")
require.Equal(t, "123:abc", b.token)
require.Equal(t, defaultBaseURL, b.base)
require.NotNil(t, b.http)
require.NotNil(t, b.codec)
require.NotNil(t, b.logger)
}
func TestNew_OptionsApplied(t *testing.T) {
custom := &http.Client{}
type fakeCodec struct{ DefaultCodec }
c := fakeCodec{}
b := New("t",
WithHTTPClient(custom),
WithCodec(c),
WithBaseURL("https://example.test"),
WithLogger(NoopLogger{}),
)
require.Same(t, custom, b.http)
require.Equal(t, c, b.codec)
require.Equal(t, "https://example.test", b.base)
}
func TestResultRoundTrip(t *testing.T) {
in := Result[int64]{OK: true, Result: 42}
data, err := DefaultCodec{}.Marshal(in)
require.NoError(t, err)
var out Result[int64]
require.NoError(t, DefaultCodec{}.Unmarshal(data, &out))
require.Equal(t, in, out)
}
+22
View File
@@ -0,0 +1,22 @@
// Package client provides HTTP client primitives for the Telegram Bot API.
package client
import "github.com/goccy/go-json"
// Codec encodes/decodes JSON payloads exchanged with the Telegram Bot API.
// The default implementation wraps goccy/go-json. Users may plug in
// bytedance/sonic or any compatible encoder by passing
// WithCodec to New.
type Codec interface {
Marshal(v any) ([]byte, error)
Unmarshal(data []byte, v any) error
}
// DefaultCodec wraps goccy/go-json. It is the zero-value safe default.
type DefaultCodec struct{}
// Marshal calls json.Marshal.
func (DefaultCodec) Marshal(v any) ([]byte, error) { return json.Marshal(v) }
// Unmarshal calls json.Unmarshal.
func (DefaultCodec) Unmarshal(data []byte, v any) error { return json.Unmarshal(data, v) }
+29
View File
@@ -0,0 +1,29 @@
package client
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestDefaultCodec_RoundTrip(t *testing.T) {
c := DefaultCodec{}
type payload struct {
Name string `json:"name"`
N int `json:"n"`
}
in := payload{Name: "x", N: 7}
data, err := c.Marshal(in)
require.NoError(t, err)
require.JSONEq(t, `{"name":"x","n":7}`, string(data))
var out payload
require.NoError(t, c.Unmarshal(data, &out))
require.Equal(t, in, out)
}
func TestDefaultCodec_UnmarshalError(t *testing.T) {
var v map[string]any
err := DefaultCodec{}.Unmarshal([]byte(`not json`), &v)
require.Error(t, err)
}
+107
View File
@@ -0,0 +1,107 @@
package client
import (
"errors"
"fmt"
"strings"
"time"
)
// APIError represents a non-OK Telegram Bot API response.
// It satisfies error and unwraps to a sentinel (ErrUnauthorized, etc.)
// where the description matches a known prefix, enabling errors.Is checks.
type APIError struct {
Code int
Description string
Parameters *ResponseParameters
// sentinel, if non-nil, is the wrapped sentinel error returned by
// Unwrap. It is set by mapAPIError based on Code+Description.
sentinel error
}
// Error implements error.
func (e *APIError) Error() string {
return fmt.Sprintf("telegram: %d %s", e.Code, e.Description)
}
// Unwrap returns the matched sentinel error, if any.
func (e *APIError) Unwrap() error { return e.sentinel }
// IsRetryable returns true for transient HTTP statuses (429, 5xx).
func (e *APIError) IsRetryable() bool {
return e.Code == 429 || (e.Code >= 500 && e.Code < 600)
}
// RetryAfter returns the recommended back-off duration. It honours the
// Telegram-supplied retry_after parameter; if absent, returns 0.
func (e *APIError) RetryAfter() time.Duration {
if e.Parameters == nil {
return 0
}
return time.Duration(e.Parameters.RetryAfter) * time.Second
}
// NetworkError wraps a transport-level failure (DNS, TCP, TLS, timeout
// short of an HTTP response).
type NetworkError struct{ Err error }
func (e *NetworkError) Error() string { return "telegram: network: " + redactToken(e.Err.Error()) }
func (e *NetworkError) Unwrap() error { return e.Err }
// ParseError wraps a JSON decode failure on a response body. Body is
// retained (truncated to 4 KiB); Error() displays up to 256 bytes for diagnostics.
type ParseError struct {
Err error
Body []byte
}
func (e *ParseError) Error() string {
body := e.Body
if len(body) > 256 {
body = body[:256]
}
return fmt.Sprintf("telegram: parse: %s (body=%q)", redactToken(e.Err.Error()), body)
}
func (e *ParseError) Unwrap() error { return e.Err }
// Sentinel errors returned via APIError.Unwrap when the description matches.
// Compare with errors.Is.
var (
ErrUnauthorized = errors.New("telegram: unauthorized")
ErrChatNotFound = errors.New("telegram: chat not found")
ErrMessageNotModified = errors.New("telegram: message is not modified")
ErrTooManyRequests = errors.New("telegram: too many requests")
ErrBadRequest = errors.New("telegram: bad request")
ErrForbidden = errors.New("telegram: forbidden")
ErrUserNotFound = errors.New("telegram: user not found")
ErrMessageNotFound = errors.New("telegram: message not found")
)
// mapAPIError builds an *APIError and attaches the appropriate sentinel
// based on Code+Description. It is the single point where wire-level
// failures are translated into the Go error taxonomy.
func mapAPIError(code int, description string, params *ResponseParameters) *APIError {
e := &APIError{Code: code, Description: description, Parameters: params}
switch {
case code == 401:
e.sentinel = ErrUnauthorized
case code == 403:
e.sentinel = ErrForbidden
case code == 429:
e.sentinel = ErrTooManyRequests
case code == 400 && strings.Contains(description, "user not found"):
e.sentinel = ErrUserNotFound
case code == 400 && strings.Contains(description, "message to") && strings.Contains(description, "not found"):
e.sentinel = ErrMessageNotFound
case code == 400 && strings.Contains(description, "chat not found"):
e.sentinel = ErrChatNotFound
case code == 400 && strings.Contains(description, "message is not modified"):
e.sentinel = ErrMessageNotModified
case code == 400:
e.sentinel = ErrBadRequest
}
return e
}
+58
View File
@@ -0,0 +1,58 @@
package client
import (
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestAPIError_FieldsAndMethods(t *testing.T) {
e := &APIError{
Code: 429,
Description: "Too Many Requests: retry after 5",
Parameters: &ResponseParameters{RetryAfter: 5},
}
require.Equal(t, "telegram: 429 Too Many Requests: retry after 5", e.Error())
require.True(t, e.IsRetryable())
require.Equal(t, 5*time.Second, e.RetryAfter())
}
func TestAPIError_Sentinels(t *testing.T) {
cases := []struct {
code int
desc string
sentinel error
}{
{401, "Unauthorized", ErrUnauthorized},
{400, "Bad Request: chat not found", ErrChatNotFound},
{400, "Bad Request: message is not modified", ErrMessageNotModified},
{429, "Too Many Requests: retry after 1", ErrTooManyRequests},
{400, "Bad Request: user not found", ErrUserNotFound},
{400, "Bad Request: message to delete not found", ErrMessageNotFound},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
e := mapAPIError(c.code, c.desc, nil)
require.True(t, errors.Is(e, c.sentinel), "expected %v to wrap %v", e, c.sentinel)
})
}
}
func TestAPIError_IsRetryable(t *testing.T) {
require.True(t, (&APIError{Code: 500}).IsRetryable())
require.True(t, (&APIError{Code: 502}).IsRetryable())
require.True(t, (&APIError{Code: 429}).IsRetryable())
require.False(t, (&APIError{Code: 400}).IsRetryable())
require.False(t, (&APIError{Code: 401}).IsRetryable())
}
func TestNetworkAndParseErrorWrapping(t *testing.T) {
inner := errors.New("dial tcp: timeout")
ne := &NetworkError{Err: inner}
require.ErrorIs(t, ne, inner)
pe := &ParseError{Err: errors.New("unexpected EOF"), Body: []byte("garbage")}
require.Contains(t, pe.Error(), "garbage")
}
+226
View File
@@ -0,0 +1,226 @@
package client
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// client.go option getters
// ---------------------------------------------------------------------------
func TestBot_Getters(t *testing.T) {
b := New("mytoken",
WithBaseURL("http://localhost:9999"),
WithCodec(DefaultCodec{}),
WithLogger(NoopLogger{}),
)
require.Equal(t, "mytoken", b.Token())
require.Equal(t, "http://localhost:9999", b.BaseURL())
require.NotNil(t, b.HTTP())
require.NotNil(t, b.Codec())
require.NotNil(t, b.Logger())
}
func TestWithLogger_NilBecomesNoop(t *testing.T) {
b := New("t", WithLogger(nil))
require.IsType(t, NoopLogger{}, b.Logger())
}
func TestNoopLogger_AllMethods(t *testing.T) {
l := NoopLogger{}
// None of these should panic.
l.Debug("msg")
l.Info("msg", "k", "v")
l.Warn("msg")
l.Error("msg", "err", "oops")
}
// ---------------------------------------------------------------------------
// RetryOption setters
// ---------------------------------------------------------------------------
func TestRetryOptions_Applied(t *testing.T) {
d := NewRetryDoer(nil,
WithMaxAttempts(7),
WithBaseBackoff(1*time.Second),
WithMaxBackoff(60*time.Second),
WithBackoffFactor(3.0),
WithJitter(0.5),
)
require.Equal(t, 7, d.maxAttempts)
require.Equal(t, 1*time.Second, d.base)
require.Equal(t, 60*time.Second, d.max)
require.Equal(t, 3.0, d.factor)
require.Equal(t, 0.5, d.jitter)
}
// ---------------------------------------------------------------------------
// RetryDoer.delay — override path
// ---------------------------------------------------------------------------
func TestRetryDoer_DelayOverride(t *testing.T) {
d := NewRetryDoer(nil)
got := d.delay(1, 5*time.Second)
require.Equal(t, 5*time.Second, got)
}
func TestRetryDoer_DelayExponential(t *testing.T) {
d := NewRetryDoer(nil,
WithBaseBackoff(100*time.Millisecond),
WithMaxBackoff(10*time.Second),
WithJitter(0), // no jitter for deterministic test
WithBackoffFactor(2.0),
)
d1 := d.delay(1, 0)
d2 := d.delay(2, 0)
require.Greater(t, int64(d2), int64(d1), "backoff should grow")
}
func TestRetryDoer_DelayMaxCap(t *testing.T) {
d := NewRetryDoer(nil,
WithBaseBackoff(1*time.Second),
WithMaxBackoff(2*time.Second),
WithJitter(0),
WithBackoffFactor(100.0),
)
delay := d.delay(10, 0)
require.LessOrEqual(t, delay, 2*time.Second)
}
// ---------------------------------------------------------------------------
// errors.go — RetryAfter nil parameters + ParseError.Unwrap
// ---------------------------------------------------------------------------
func TestAPIError_RetryAfterNilParams(t *testing.T) {
e := &APIError{Code: 429, Description: "Too Many Requests", Parameters: nil}
require.Equal(t, time.Duration(0), e.RetryAfter())
}
func TestParseError_Unwrap(t *testing.T) {
inner := errors.New("decode error")
pe := &ParseError{Err: inner, Body: []byte("body")}
require.ErrorIs(t, pe, inner)
}
func TestParseError_LongBodyTruncated(t *testing.T) {
body := bytes.Repeat([]byte("x"), 1000)
pe := &ParseError{Err: errors.New("e"), Body: body}
msg := pe.Error()
// Error() truncates body to 256 for display — should not include all 1000 chars
require.Less(t, len(msg), 800, "should truncate body in Error()")
}
func TestNetworkError_Unwrap(t *testing.T) {
inner := errors.New("tcp error")
ne := &NetworkError{Err: inner}
require.ErrorIs(t, ne, inner)
}
// ---------------------------------------------------------------------------
// mapAPIError — missing sentinel branches (generic 400, unmapped 500)
// ---------------------------------------------------------------------------
func TestMapAPIError_Generic400(t *testing.T) {
e := mapAPIError(400, "Bad Request: some unknown thing", nil)
require.True(t, errors.Is(e, ErrBadRequest))
}
func TestMapAPIError_Unmapped500(t *testing.T) {
e := mapAPIError(500, "Internal Server Error", nil)
require.Nil(t, e.sentinel)
require.Equal(t, 500, e.Code)
}
func TestMapAPIError_403(t *testing.T) {
e := mapAPIError(403, "Forbidden: bot was blocked", nil)
require.True(t, errors.Is(e, ErrForbidden))
}
// ---------------------------------------------------------------------------
// callMultipart — ctx cancelled
// ---------------------------------------------------------------------------
func TestCallMultipart_ContextCancelled(t *testing.T) {
// A doer that blocks then returns context error.
blocker := &extraBlockingDoer{done: make(chan struct{})}
b := New("t", WithHTTPClient(blocker))
ctx, cancel := context.WithCancel(context.Background())
mp := &extraFakeMultipartReq{
fields: map[string]string{"chat_id": "1"},
files: []MultipartFile{
{FieldName: "document", Filename: "f.txt", Reader: bytes.NewReader([]byte("data"))},
},
}
go func() {
time.Sleep(10 * time.Millisecond)
cancel()
close(blocker.done)
}()
_, err := callMultipart[*struct{}](ctx, b, "sendDocument", mp)
require.Error(t, err)
}
type extraBlockingDoer struct{ done chan struct{} }
func (b *extraBlockingDoer) Do(r *http.Request) (*http.Response, error) {
<-b.done
return nil, r.Context().Err()
}
type extraFakeMultipartReq struct {
fields map[string]string
files []MultipartFile
}
func (f *extraFakeMultipartReq) HasFile() bool { return len(f.files) > 0 }
func (f *extraFakeMultipartReq) MultipartFiles() []MultipartFile { return f.files }
func (f *extraFakeMultipartReq) MultipartFields() map[string]string { return f.fields }
// ---------------------------------------------------------------------------
// copyBody size cap
// ---------------------------------------------------------------------------
func TestCopyBody_LargeBodyCapped(t *testing.T) {
big := bytes.Repeat([]byte("a"), 8000)
out := copyBody(big)
require.Len(t, out, 4096)
}
func TestCopyBody_SmallBody(t *testing.T) {
small := []byte("hello")
out := copyBody(small)
require.Equal(t, small, out)
}
// ---------------------------------------------------------------------------
// Call — 5xx non-200 HTTP status (transport level)
// ---------------------------------------------------------------------------
func TestCall_5xxHTTPStatus(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.Anything).Return(&http.Response{
StatusCode: 500,
Body: io.NopCloser(bytes.NewBufferString(`{"ok":false,"error_code":500,"description":"Internal"}`)),
Header: http.Header{"Content-Type": []string{"application/json"}},
}, nil)
b := New("t", WithHTTPClient(m))
_, err := Call[*echoReq, *echoResp](context.Background(), b, "x", &echoReq{})
require.Error(t, err)
var ae *APIError
require.ErrorAs(t, err, &ae)
require.Equal(t, 500, ae.Code)
}
+40
View File
@@ -0,0 +1,40 @@
package client
import (
"net"
"net/http"
"time"
)
// HTTPDoer abstracts the HTTP transport. The default is a net/http client
// tuned for Telegram's long-poll usage. Users may plug in valyala/fasthttp
// (via an adapter), or any custom retry/circuit-breaker client by passing
// WithHTTPClient to New.
type HTTPDoer interface {
Do(req *http.Request) (*http.Response, error)
}
// NewDefaultHTTPDoer returns an *http.Client with sensible defaults for
// Telegram Bot API usage:
// - 60s overall timeout (longer than typical long-poll Timeout=30s).
// - Connection pooling sized for a small number of long-lived hosts.
// - HTTP/2 enabled (default in net/http).
func NewDefaultHTTPDoer() *http.Client {
t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConns: 16,
MaxIdleConnsPerHost: 8,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ForceAttemptHTTP2: true,
}
return &http.Client{
Transport: t,
Timeout: 60 * time.Second,
}
}
+24
View File
@@ -0,0 +1,24 @@
package client
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestDefaultHTTPClient_Do(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
}))
t.Cleanup(srv.Close)
doer := NewDefaultHTTPDoer()
req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
require.NoError(t, err)
resp, err := doer.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusTeapot, resp.StatusCode)
}
+19
View File
@@ -0,0 +1,19 @@
package client
// Logger is a slog-shaped logging interface. Users pass any compatible
// implementation via WithLogger. The default is NoopLogger, which discards
// everything.
type Logger interface {
Debug(msg string, attrs ...any)
Info(msg string, attrs ...any)
Warn(msg string, attrs ...any)
Error(msg string, attrs ...any)
}
// NoopLogger discards all log records. It is the zero-value safe default.
type NoopLogger struct{}
func (NoopLogger) Debug(string, ...any) {}
func (NoopLogger) Info(string, ...any) {}
func (NoopLogger) Warn(string, ...any) {}
func (NoopLogger) Error(string, ...any) {}
+11
View File
@@ -0,0 +1,11 @@
package client
import "testing"
func TestNoopLogger_DoesNotPanic(t *testing.T) {
var l Logger = NoopLogger{}
l.Debug("d", "k", "v")
l.Info("i")
l.Warn("w")
l.Error("e")
}
+146
View File
@@ -0,0 +1,146 @@
package client
import (
"context"
"github.com/goccy/go-json"
"io"
"mime/multipart"
"net/http"
)
// multipartRequest is implemented by request structs that may carry an
// InputFile. The codegen emits this interface for any method whose IR
// MethodDecl.HasFiles is true.
//
// HasFile returns true if at least one file field is set; if false, the
// request is sent as plain JSON via the regular Call path.
//
// MultipartFiles returns one entry per file field that should be uploaded.
// The accompanying scalar/object fields are returned by MultipartFields.
type multipartRequest interface {
HasFile() bool
MultipartFiles() []MultipartFile
MultipartFields() map[string]string
}
// MultipartFile describes a single file part in a multipart upload.
type MultipartFile struct {
FieldName string
Filename string
Reader io.Reader
}
// callMultipart performs a multipart/form-data POST. It is invoked by Call
// when the request implements multipartRequest and HasFile() is true.
func callMultipart[Resp any](ctx context.Context, b *Bot, method string, mp multipartRequest) (Resp, error) {
var zero Resp
pr, pw := io.Pipe()
mw := multipart.NewWriter(pw)
// Stream-write the multipart body in a goroutine so we don't buffer
// large files in memory.
go func() {
defer func() { _ = pw.Close() }()
defer func() { _ = mw.Close() }()
for k, v := range mp.MultipartFields() {
if err := mw.WriteField(k, v); err != nil {
_ = pw.CloseWithError(err)
return
}
}
for _, f := range mp.MultipartFiles() {
part, err := mw.CreateFormFile(f.FieldName, f.Filename)
if err != nil {
_ = pw.CloseWithError(err)
return
}
if _, err := io.Copy(part, f.Reader); err != nil {
_ = pw.CloseWithError(err)
return
}
}
}()
url := b.base + "/bot" + b.token + "/" + method
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, pr)
if err != nil {
_ = pr.CloseWithError(err)
return zero, &NetworkError{Err: err}
}
req.Header.Set("Content-Type", mw.FormDataContentType())
req.Header.Set("Accept", "application/json")
resp, err := b.http.Do(req)
if err != nil {
_ = pr.CloseWithError(err)
if ctxErr := ctx.Err(); ctxErr != nil {
return zero, ctxErr
}
return zero, &NetworkError{Err: err}
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
_ = pr.CloseWithError(err)
return zero, &NetworkError{Err: err}
}
return decodeResult[Resp](b.codec, raw)
}
// callMultipartRaw is callMultipart's sibling that returns the raw result
// JSON instead of decoding into a typed value. Used by generated method
// wrappers whose return type is a sealed-interface union.
func callMultipartRaw(ctx context.Context, b *Bot, method string, mp multipartRequest) (json.RawMessage, error) {
pr, pw := io.Pipe()
mw := multipart.NewWriter(pw)
go func() {
defer func() { _ = pw.Close() }()
defer func() { _ = mw.Close() }()
for k, v := range mp.MultipartFields() {
if err := mw.WriteField(k, v); err != nil {
_ = pw.CloseWithError(err)
return
}
}
for _, f := range mp.MultipartFiles() {
part, err := mw.CreateFormFile(f.FieldName, f.Filename)
if err != nil {
_ = pw.CloseWithError(err)
return
}
if _, err := io.Copy(part, f.Reader); err != nil {
_ = pw.CloseWithError(err)
return
}
}
}()
url := b.base + "/bot" + b.token + "/" + method
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, pr)
if err != nil {
_ = pr.CloseWithError(err)
return nil, &NetworkError{Err: err}
}
req.Header.Set("Content-Type", mw.FormDataContentType())
req.Header.Set("Accept", "application/json")
resp, err := b.http.Do(req)
if err != nil {
_ = pr.CloseWithError(err)
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
}
return nil, &NetworkError{Err: err}
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(resp.Body)
if err != nil {
_ = pr.CloseWithError(err)
return nil, &NetworkError{Err: err}
}
return decodeResultRaw(b.codec, raw)
}
+103
View File
@@ -0,0 +1,103 @@
package client
import (
"context"
"errors"
"io"
"mime"
"mime/multipart"
"net/http"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type fakeMultipartReq struct {
chatID int64
body string
}
func (f *fakeMultipartReq) HasFile() bool { return true }
func (f *fakeMultipartReq) MultipartFields() map[string]string {
return map[string]string{"chat_id": "42"}
}
func (f *fakeMultipartReq) MultipartFiles() []MultipartFile {
return []MultipartFile{{
FieldName: "document",
Filename: "hello.txt",
Reader: strings.NewReader(f.body),
}}
}
type fileResp struct {
MessageID int64 `json:"message_id"`
}
func TestCallMultipart_Success(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.MatchedBy(func(r *http.Request) bool {
ct := r.Header.Get("Content-Type")
if !strings.HasPrefix(ct, "multipart/form-data") {
return false
}
_, params, err := mime.ParseMediaType(ct)
if err != nil {
return false
}
mr := multipart.NewReader(r.Body, params["boundary"])
seenChat := false
seenFile := false
for {
p, err := mr.NextPart()
if err == io.EOF {
break
}
if err != nil {
return false
}
switch p.FormName() {
case "chat_id":
body, _ := io.ReadAll(p)
seenChat = string(body) == "42"
case "document":
body, _ := io.ReadAll(p)
seenFile = string(body) == "hello world"
}
}
return seenChat && seenFile
})).Return(newResp(200, `{"ok":true,"result":{"message_id":99}}`), nil)
b := New("t", WithHTTPClient(m))
out, err := Call[*fakeMultipartReq, *fileResp](context.Background(), b, "sendDocument", &fakeMultipartReq{chatID: 42, body: "hello world"})
require.NoError(t, err)
require.Equal(t, int64(99), out.MessageID)
}
func TestCallMultipart_NoGoroutineLeakOnError(t *testing.T) {
m := &mockDoer{}
m.On("Do", mock.Anything).Return(nil, errors.New("dial timeout"))
b := New("t", WithHTTPClient(m))
before := runtime.NumGoroutine()
for i := 0; i < 50; i++ {
_, _ = Call[*fakeMultipartReq, *fileResp](
context.Background(), b, "sendDocument",
&fakeMultipartReq{chatID: 42, body: strings.Repeat("x", 1<<14)},
)
}
// Allow goroutines to finish exiting after Close propagates.
time.Sleep(50 * time.Millisecond)
runtime.GC()
after := runtime.NumGoroutine()
// A small drift is normal (timers, finalizers); 5 is generous.
if after-before > 5 {
t.Fatalf("goroutine leak: before=%d after=%d", before, after)
}
}
+29
View File
@@ -0,0 +1,29 @@
package client
// Option configures a Bot at construction time. Per-call configuration is
// expressed via typed parameter structs (e.g. SendMessageParams), not options.
type Option func(*Bot)
// WithHTTPClient overrides the HTTP transport. Pass any HTTPDoer
// implementation (e.g. an *http.Client wrapping a custom RoundTripper, or
// a fasthttp adapter).
func WithHTTPClient(c HTTPDoer) Option { return func(b *Bot) { b.http = c } }
// WithCodec overrides the JSON codec. Pass goccy/go-json, sonic, or any
// type implementing Codec to swap out encoding/json.
func WithCodec(c Codec) Option { return func(b *Bot) { b.codec = c } }
// WithBaseURL overrides the API base URL. Useful for testing against a
// local httptest.Server, or for self-hosted Bot API servers.
func WithBaseURL(url string) Option { return func(b *Bot) { b.base = url } }
// WithLogger sets the logger used for diagnostic events. Passing nil
// silently disables logging.
func WithLogger(l Logger) Option {
return func(b *Bot) {
if l == nil {
l = NoopLogger{}
}
b.logger = l
}
}
+15
View File
@@ -0,0 +1,15 @@
package client
import "regexp"
// tokenInURL matches a Telegram bot token segment in a URL path. Tokens
// have the form <bot_id>:<api_key>, where bot_id is digits and api_key
// is 35 base64-url characters. The pattern is conservative: matches
// /bot<id>:<key>/ to avoid false positives.
var tokenInURL = regexp.MustCompile(`/bot(\d{5,15}):([A-Za-z0-9_-]{30,40})/`)
// redactToken replaces any bot token in s with /bot<REDACTED>/. Used by
// error formatters so logs don't leak credentials.
func redactToken(s string) string {
return tokenInURL.ReplaceAllString(s, "/bot<REDACTED>/")
}
+46
View File
@@ -0,0 +1,46 @@
package client
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestRedactToken(t *testing.T) {
cases := []struct {
name string
in string
want string
}{
{"plain bot URL", "https://api.telegram.org/bot123456789:ABCdefGHIjklMNOpqrSTUvwxYZ0123456789/getMe",
"https://api.telegram.org/bot<REDACTED>/getMe"},
{"in net/http error", `Post "https://api.telegram.org/bot987654321:Z9YxWvUtSrQpOnMlKjIhGfEdCbA9876543210/sendMessage": dial tcp: lookup api.telegram.org: no such host`,
`Post "https://api.telegram.org/bot<REDACTED>/sendMessage": dial tcp: lookup api.telegram.org: no such host`},
{"no token", "regular error message", "regular error message"},
{"underscore + dash in token", "/bot123456789:abc-def_ghi-jkl_mno-pqr_stu-vwx_yz/sendDocument",
"/bot<REDACTED>/sendDocument"},
{"too short id (no match)", "/bot123:abc/getMe", "/bot123:abc/getMe"},
{"too short key (no match)", "/bot123456789:short/getMe", "/bot123456789:short/getMe"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
require.Equal(t, c.want, redactToken(c.in))
})
}
}
func TestNetworkError_RedactsToken(t *testing.T) {
inner := errors.New(`Post "https://api.telegram.org/bot1234567890:ABCdefGHIjklMNOpqrSTUvwxYZ0123456789/getMe": dial tcp: timeout`)
e := &NetworkError{Err: inner}
require.NotContains(t, e.Error(), "ABCdefGHIjklMNOpqrSTUvwxYZ")
require.Contains(t, e.Error(), "<REDACTED>")
}
func TestParseError_RedactsToken(t *testing.T) {
inner := fmt.Errorf(`unexpected response from /bot1234567890:ABCdefGHIjklMNOpqrSTUvwxYZ0123456789/getMe`)
e := &ParseError{Err: inner, Body: []byte("garbage")}
require.NotContains(t, e.Error(), "ABCdefGHI")
require.Contains(t, e.Error(), "<REDACTED>")
}
+27
View File
@@ -0,0 +1,27 @@
package client
// Result is the universal Telegram API response envelope. Every successful
// response is shaped {"ok":true,"result":T,...}; failure responses set ok
// to false and populate ErrorCode / Description / Parameters.
//
// Result is generic over T so generated method wrappers can decode the
// strongly-typed payload directly. Users do not normally construct or
// inspect Result values; method wrappers unwrap them and return either
// the typed payload or a *APIError.
type Result[T any] struct {
OK bool `json:"ok"`
Result T `json:"result,omitempty"`
ErrorCode int `json:"error_code,omitempty"`
Description string `json:"description,omitempty"`
Parameters *ResponseParameters `json:"parameters,omitempty"`
}
// ResponseParameters is the optional metadata Telegram includes on certain
// failures. The most common is RetryAfter (seconds) on 429 responses.
//
// This type is duplicated in package api for users; keeping a copy here
// avoids an import cycle (api imports client, not vice versa).
type ResponseParameters struct {
MigrateToChatID int64 `json:"migrate_to_chat_id,omitempty"`
RetryAfter int `json:"retry_after,omitempty"`
}
+225
View File
@@ -0,0 +1,225 @@
package client
import (
"bytes"
"context"
crand "crypto/rand"
"encoding/binary"
"github.com/goccy/go-json"
"io"
"math"
"net/http"
"time"
)
// RetryDoer is an HTTPDoer that retries transient failures (429, 5xx,
// and network errors) with exponential backoff. It honours the
// retry_after value Telegram supplies on rate-limit responses.
//
// Wrap any HTTPDoer to add retry behaviour:
//
// bot := client.New(token, client.WithHTTPClient(
// client.NewRetryDoer(client.NewDefaultHTTPDoer())))
type RetryDoer struct {
inner HTTPDoer
maxAttempts int
base time.Duration
max time.Duration
factor float64
jitter float64
}
// RetryOption configures a RetryDoer.
type RetryOption func(*RetryDoer)
// WithMaxAttempts sets the maximum number of attempts (including the
// initial one). Default 4 (one initial + three retries).
func WithMaxAttempts(n int) RetryOption {
return func(d *RetryDoer) { d.maxAttempts = n }
}
// WithBaseBackoff sets the initial backoff duration. Default 500ms.
func WithBaseBackoff(d time.Duration) RetryOption {
return func(r *RetryDoer) { r.base = d }
}
// WithMaxBackoff caps the backoff at max. Default 30s.
func WithMaxBackoff(d time.Duration) RetryOption {
return func(r *RetryDoer) { r.max = d }
}
// WithBackoffFactor sets the exponential growth factor. Default 2.0.
func WithBackoffFactor(f float64) RetryOption {
return func(r *RetryDoer) { r.factor = f }
}
// WithJitter sets the jitter fraction (0..1) applied to each backoff.
// Default 0.2.
func WithJitter(j float64) RetryOption {
return func(r *RetryDoer) { r.jitter = j }
}
// NewRetryDoer wraps inner with retry behaviour.
func NewRetryDoer(inner HTTPDoer, opts ...RetryOption) *RetryDoer {
d := &RetryDoer{
inner: inner,
maxAttempts: 4,
base: 500 * time.Millisecond,
max: 30 * time.Second,
factor: 2.0,
jitter: 0.2,
}
for _, o := range opts {
o(d)
}
return d
}
// Do dispatches via the inner HTTPDoer and retries on transient failures.
// The request body is buffered on first attempt so it can be replayed.
func (d *RetryDoer) Do(req *http.Request) (*http.Response, error) {
// Buffer the body so we can replay it across attempts.
var body []byte
if req.Body != nil {
b, err := io.ReadAll(req.Body)
if err != nil {
return nil, &NetworkError{Err: err}
}
_ = req.Body.Close()
body = b
}
var lastResp *http.Response
var lastErr error
for attempt := 1; attempt <= d.maxAttempts; attempt++ {
if body != nil {
req.Body = io.NopCloser(bytes.NewReader(body))
}
resp, err := d.inner.Do(req)
// Network errors: maybe retry.
if err != nil {
// Honour ctx cancellation.
if ctxErr := req.Context().Err(); ctxErr != nil {
return nil, ctxErr
}
lastErr = err
if attempt < d.maxAttempts {
if !d.sleep(req.Context(), d.delay(attempt, 0)) {
return nil, req.Context().Err()
}
continue
}
return nil, err
}
// HTTP 200: Telegram almost always returns 200 even for errors.
// Peek the body to detect retryable Telegram error payloads.
if resp.StatusCode == http.StatusOK {
data, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, &NetworkError{Err: readErr}
}
// Re-attach the buffered body for the caller.
resp.Body = io.NopCloser(bytes.NewReader(data))
if isRetryablePayload(data) && attempt < d.maxAttempts {
lastResp = resp
wait := retryAfterFromPayload(data)
if !d.sleep(req.Context(), d.delay(attempt, wait)) {
return nil, req.Context().Err()
}
continue
}
return resp, nil
}
// Non-200 status (rare with Telegram; usually 200 + ok:false).
// Treat 5xx and 429 as retryable.
if (resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode >= http.StatusInternalServerError) && attempt < d.maxAttempts {
_ = resp.Body.Close()
lastResp = resp
if !d.sleep(req.Context(), d.delay(attempt, 0)) {
return nil, req.Context().Err()
}
continue
}
return resp, nil
}
if lastErr != nil {
return nil, lastErr
}
return lastResp, nil
}
// delay computes the wait duration for the given attempt (1-based).
// override, when non-zero, takes precedence (used to honour Telegram's
// retry_after value).
func (d *RetryDoer) delay(attempt int, override time.Duration) time.Duration {
if override > 0 {
return override
}
delay := float64(d.base) * math.Pow(d.factor, float64(attempt-1))
if d.jitter > 0 {
var b [8]byte
_, _ = crand.Read(b[:])
f := float64(binary.LittleEndian.Uint64(b[:])) / (1 << 64)
delay *= 1 + (f*2-1)*d.jitter
}
if delay > float64(d.max) {
delay = float64(d.max)
}
if delay < 0 {
delay = 0
}
return time.Duration(delay)
}
// sleep waits for dur or ctx cancellation. Returns false if cancelled.
func (d *RetryDoer) sleep(ctx context.Context, dur time.Duration) bool {
if dur <= 0 {
return true
}
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
return true
case <-ctx.Done():
return false
}
}
// isRetryablePayload reports whether body is a Telegram error response
// indicating a retryable failure (429 or 5xx error_code).
func isRetryablePayload(body []byte) bool {
var env struct {
OK bool `json:"ok"`
ErrorCode int `json:"error_code"`
}
if err := json.Unmarshal(body, &env); err != nil {
return false
}
if env.OK {
return false
}
return env.ErrorCode == 429 || (env.ErrorCode >= 500 && env.ErrorCode < 600)
}
// retryAfterFromPayload extracts the retry_after value from a Telegram
// error response body and returns it as a duration. Returns 0 if absent.
func retryAfterFromPayload(body []byte) time.Duration {
var env struct {
Parameters struct {
RetryAfter int `json:"retry_after"`
} `json:"parameters"`
}
if err := json.Unmarshal(body, &env); err != nil {
return 0
}
return time.Duration(env.Parameters.RetryAfter) * time.Second
}
+144
View File
@@ -0,0 +1,144 @@
package client
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type retryMockDoer struct{ mock.Mock }
func (m *retryMockDoer) Do(r *http.Request) (*http.Response, error) {
args := m.Called(r)
if v := args.Get(0); v != nil {
return v.(*http.Response), args.Error(1)
}
return nil, args.Error(1)
}
func okResp(body string) *http.Response {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewBufferString(body)),
Header: http.Header{"Content-Type": []string{"application/json"}},
}
}
func TestRetryDoer_HappyPath(t *testing.T) {
m := &retryMockDoer{}
m.On("Do", mock.Anything).Return(okResp(`{"ok":true,"result":"hi"}`), nil).Once()
d := NewRetryDoer(m)
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
resp, err := d.Do(req)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
m.AssertExpectations(t)
}
func TestRetryDoer_RetriesOnNetworkError(t *testing.T) {
m := &retryMockDoer{}
m.On("Do", mock.Anything).Return(nil, errors.New("dial timeout")).Once()
m.On("Do", mock.Anything).Return(okResp(`{"ok":true,"result":"hi"}`), nil).Once()
d := NewRetryDoer(m, WithBaseBackoff(time.Millisecond))
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
resp, err := d.Do(req)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
m.AssertExpectations(t)
}
func TestRetryDoer_HonoursRetryAfter(t *testing.T) {
m := &retryMockDoer{}
m.On("Do", mock.Anything).Return(
okResp(`{"ok":false,"error_code":429,"description":"Too Many","parameters":{"retry_after":1}}`), nil).Once()
m.On("Do", mock.Anything).Return(okResp(`{"ok":true,"result":1}`), nil).Once()
// base is 10s — retry_after=1s should override it (much shorter wait).
d := NewRetryDoer(m, WithBaseBackoff(10*time.Second))
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
start := time.Now()
resp, err := d.Do(req)
elapsed := time.Since(start)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
require.GreaterOrEqual(t, elapsed, 900*time.Millisecond, "should honour retry_after=1s")
require.Less(t, elapsed, 3*time.Second, "should NOT use base backoff (10s)")
m.AssertExpectations(t)
}
func TestRetryDoer_Retries5xx(t *testing.T) {
m := &retryMockDoer{}
m.On("Do", mock.Anything).Return(
okResp(`{"ok":false,"error_code":500,"description":"Internal Server Error"}`), nil).Once()
m.On("Do", mock.Anything).Return(okResp(`{"ok":true,"result":1}`), nil).Once()
d := NewRetryDoer(m, WithBaseBackoff(time.Millisecond))
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
resp, err := d.Do(req)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
m.AssertExpectations(t)
}
func TestRetryDoer_AllAttemptsFail(t *testing.T) {
m := &retryMockDoer{}
m.On("Do", mock.Anything).Return(nil, errors.New("dial timeout"))
d := NewRetryDoer(m, WithMaxAttempts(3), WithBaseBackoff(time.Millisecond))
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{}`))
_, err := d.Do(req)
require.Error(t, err)
require.Contains(t, err.Error(), "dial timeout")
require.Equal(t, 3, len(m.Calls))
}
func TestRetryDoer_ContextCancellationAborts(t *testing.T) {
m := &retryMockDoer{}
m.On("Do", mock.Anything).Return(
okResp(`{"ok":false,"error_code":500,"description":"server error"}`), nil).Maybe()
d := NewRetryDoer(m, WithBaseBackoff(100*time.Millisecond))
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
req, _ := http.NewRequestWithContext(ctx, "POST", "http://x", strings.NewReader(`{}`))
_, err := d.Do(req)
require.Error(t, err)
require.True(t, errors.Is(err, context.DeadlineExceeded))
}
func TestRetryDoer_ReplaysBody(t *testing.T) {
m := &retryMockDoer{}
var seen []string
// First call: capture body, return 500 to trigger retry.
m.On("Do", mock.Anything).Return(okResp(`{"ok":false,"error_code":500}`), nil).Once().Run(func(args mock.Arguments) {
r := args.Get(0).(*http.Request)
body, _ := io.ReadAll(r.Body)
seen = append(seen, string(body))
})
// Second call: capture body, return success.
m.On("Do", mock.Anything).Return(okResp(`{"ok":true}`), nil).Once().Run(func(args mock.Arguments) {
r := args.Get(0).(*http.Request)
body, _ := io.ReadAll(r.Body)
seen = append(seen, string(body))
})
d := NewRetryDoer(m, WithBaseBackoff(time.Millisecond))
req, _ := http.NewRequest("POST", "http://x", strings.NewReader(`{"chat_id":42}`))
_, err := d.Do(req)
require.NoError(t, err)
require.Len(t, seen, 2)
require.Equal(t, seen[0], seen[1])
require.Equal(t, `{"chat_id":42}`, seen[0])
m.AssertExpectations(t)
}