mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
bde1db1c3b
* Automatic discovery of supported scopes. Issue #61 raised valid concerns about users configuring scopes that the provider does not support. This change introduces automatic scope discovery: the middleware fetches the provider's discovery document and filters out unsupported scopes. Before: the user configures scopes: ["openid", "profile", "email", "offline_access"]; self-hosted GitLab responds "The requested scope is invalid, unknown, or malformed"; authentication ❌ FAILS. After: the user configures scopes: ["openid", "profile", "email", "offline_access"]; the middleware checks the discovery document → offline_access is not supported; the scopes are automatically filtered to ["openid", "profile", "email"]; authentication ✅ SUCCEEDS. * Resolves issue #74 by enabling users to specify the expected audience in the configuration. * Fix flaky tests.
687 lines
20 KiB
Go
687 lines
20 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// TestGetNewTokenWithRefreshToken tests the GetNewTokenWithRefreshToken function
|
|
func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
refreshToken string
|
|
setupMock func(*httptest.Server) *TraefikOidc
|
|
validateFunc func(*testing.T, *TokenResponse, error)
|
|
wantErr bool
|
|
expectedError string
|
|
}{
|
|
{
|
|
name: "successful token refresh",
|
|
refreshToken: "valid_refresh_token",
|
|
setupMock: func(server *httptest.Server) *TraefikOidc {
|
|
return &TraefikOidc{
|
|
tokenURL: server.URL + "/token",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
},
|
|
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
return
|
|
}
|
|
if resp == nil {
|
|
t.Error("expected token response, got nil")
|
|
return
|
|
}
|
|
if resp.AccessToken != "refreshed_access_token" {
|
|
t.Errorf("expected refreshed_access_token, got %s", resp.AccessToken)
|
|
}
|
|
if resp.IDToken != "refreshed_id_token" {
|
|
t.Errorf("expected refreshed_id_token, got %s", resp.IDToken)
|
|
}
|
|
if resp.RefreshToken != "new_refresh_token" {
|
|
t.Errorf("expected new_refresh_token, got %s", resp.RefreshToken)
|
|
}
|
|
if resp.TokenType != "Bearer" {
|
|
t.Errorf("expected token type Bearer, got %s", resp.TokenType)
|
|
}
|
|
if resp.ExpiresIn != 3600 {
|
|
t.Errorf("expected expires_in 3600, got %d", resp.ExpiresIn)
|
|
}
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "expired refresh token",
|
|
refreshToken: "expired_refresh_token",
|
|
setupMock: func(server *httptest.Server) *TraefikOidc {
|
|
return &TraefikOidc{
|
|
tokenURL: server.URL + "/token/expired",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
},
|
|
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
|
if err == nil {
|
|
t.Error("expected error for expired refresh token, got nil")
|
|
return
|
|
}
|
|
if !strings.Contains(err.Error(), "invalid_grant") && !strings.Contains(err.Error(), "expired") {
|
|
t.Errorf("expected invalid_grant or expired error, got: %v", err)
|
|
}
|
|
},
|
|
wantErr: true,
|
|
expectedError: "invalid_grant",
|
|
},
|
|
{
|
|
name: "invalid refresh token",
|
|
refreshToken: "invalid_refresh_token",
|
|
setupMock: func(server *httptest.Server) *TraefikOidc {
|
|
return &TraefikOidc{
|
|
tokenURL: server.URL + "/token/invalid",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
},
|
|
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
|
if err == nil {
|
|
t.Error("expected error for invalid refresh token, got nil")
|
|
return
|
|
}
|
|
if !strings.Contains(err.Error(), "invalid_grant") {
|
|
t.Errorf("expected invalid_grant error, got: %v", err)
|
|
}
|
|
},
|
|
wantErr: true,
|
|
expectedError: "invalid_grant",
|
|
},
|
|
{
|
|
name: "revoked refresh token",
|
|
refreshToken: "revoked_refresh_token",
|
|
setupMock: func(server *httptest.Server) *TraefikOidc {
|
|
return &TraefikOidc{
|
|
tokenURL: server.URL + "/token/revoked",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
},
|
|
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
|
if err == nil {
|
|
t.Error("expected error for revoked refresh token, got nil")
|
|
return
|
|
}
|
|
if !strings.Contains(err.Error(), "invalid_grant") && !strings.Contains(err.Error(), "revoked") {
|
|
t.Errorf("expected invalid_grant or revoked error, got: %v", err)
|
|
}
|
|
},
|
|
wantErr: true,
|
|
expectedError: "invalid_grant",
|
|
},
|
|
{
|
|
name: "network timeout during refresh",
|
|
refreshToken: "valid_refresh_token",
|
|
setupMock: func(server *httptest.Server) *TraefikOidc {
|
|
return &TraefikOidc{
|
|
tokenURL: server.URL + "/token/timeout",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 100 * time.Millisecond,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
},
|
|
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
|
if err == nil {
|
|
t.Error("expected timeout error, got nil")
|
|
return
|
|
}
|
|
if !strings.Contains(err.Error(), "timeout") && !strings.Contains(err.Error(), "deadline") {
|
|
t.Errorf("expected timeout error, got: %v", err)
|
|
}
|
|
},
|
|
wantErr: true,
|
|
expectedError: "timeout",
|
|
},
|
|
{
|
|
name: "server error during refresh",
|
|
refreshToken: "valid_refresh_token",
|
|
setupMock: func(server *httptest.Server) *TraefikOidc {
|
|
return &TraefikOidc{
|
|
tokenURL: server.URL + "/token/error",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
},
|
|
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
|
if err == nil {
|
|
t.Error("expected server error, got nil")
|
|
return
|
|
}
|
|
if !strings.Contains(err.Error(), "500") && !strings.Contains(err.Error(), "server_error") {
|
|
t.Errorf("expected server error, got: %v", err)
|
|
}
|
|
},
|
|
wantErr: true,
|
|
expectedError: "server_error",
|
|
},
|
|
{
|
|
name: "malformed JSON response",
|
|
refreshToken: "valid_refresh_token",
|
|
setupMock: func(server *httptest.Server) *TraefikOidc {
|
|
return &TraefikOidc{
|
|
tokenURL: server.URL + "/token/malformed",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
},
|
|
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
|
if err == nil {
|
|
t.Error("expected JSON parse error, got nil")
|
|
return
|
|
}
|
|
// Accept various JSON parsing error messages
|
|
if !strings.Contains(err.Error(), "json") && !strings.Contains(err.Error(), "unmarshal") && !strings.Contains(err.Error(), "invalid character") {
|
|
t.Errorf("expected JSON error, got: %v", err)
|
|
}
|
|
},
|
|
wantErr: true,
|
|
expectedError: "json",
|
|
},
|
|
{
|
|
name: "partial token response (missing ID token)",
|
|
refreshToken: "valid_refresh_token",
|
|
setupMock: func(server *httptest.Server) *TraefikOidc {
|
|
return &TraefikOidc{
|
|
tokenURL: server.URL + "/token/partial",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
},
|
|
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
|
if err != nil {
|
|
t.Logf("got error: %v", err)
|
|
}
|
|
if resp == nil {
|
|
t.Error("expected partial token response, got nil")
|
|
return
|
|
}
|
|
if resp.AccessToken != "partial_access_token" {
|
|
t.Errorf("expected partial_access_token, got %s", resp.AccessToken)
|
|
}
|
|
if resp.IDToken != "" {
|
|
t.Errorf("expected empty ID token, got %s", resp.IDToken)
|
|
}
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "rate limited refresh request",
|
|
refreshToken: "valid_refresh_token",
|
|
setupMock: func(server *httptest.Server) *TraefikOidc {
|
|
return &TraefikOidc{
|
|
tokenURL: server.URL + "/token/ratelimit",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
},
|
|
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
|
if err == nil {
|
|
t.Error("expected rate limit error, got nil")
|
|
return
|
|
}
|
|
if !strings.Contains(err.Error(), "429") && !strings.Contains(err.Error(), "rate") {
|
|
t.Errorf("expected rate limit error, got: %v", err)
|
|
}
|
|
},
|
|
wantErr: true,
|
|
expectedError: "rate",
|
|
},
|
|
{
|
|
name: "empty refresh token",
|
|
refreshToken: "",
|
|
setupMock: func(server *httptest.Server) *TraefikOidc {
|
|
return &TraefikOidc{
|
|
tokenURL: server.URL + "/token",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
},
|
|
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
|
if err == nil {
|
|
t.Error("expected error for empty refresh token, got nil")
|
|
return
|
|
}
|
|
// The actual error should contain invalid_request
|
|
if !strings.Contains(err.Error(), "invalid_request") && !strings.Contains(err.Error(), "missing") {
|
|
t.Errorf("expected invalid_request or missing error, got: %v", err)
|
|
}
|
|
if resp != nil {
|
|
t.Error("expected nil response for empty refresh token")
|
|
}
|
|
},
|
|
wantErr: true,
|
|
expectedError: "invalid_request",
|
|
},
|
|
{
|
|
name: "refresh with rotating tokens",
|
|
refreshToken: "rotating_refresh_token",
|
|
setupMock: func(server *httptest.Server) *TraefikOidc {
|
|
return &TraefikOidc{
|
|
tokenURL: server.URL + "/token/rotating",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
},
|
|
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
return
|
|
}
|
|
if resp == nil {
|
|
t.Error("expected token response, got nil")
|
|
return
|
|
}
|
|
// Verify we got a different refresh token (rotation)
|
|
if resp.RefreshToken == "rotating_refresh_token" {
|
|
t.Error("expected new refresh token (rotation), got same token")
|
|
}
|
|
if resp.RefreshToken == "" {
|
|
t.Error("expected new refresh token, got empty")
|
|
}
|
|
},
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create test server with various endpoints
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Verify request method
|
|
if r.Method != http.MethodPost {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
// Parse request body
|
|
body, _ := io.ReadAll(r.Body)
|
|
values, _ := url.ParseQuery(string(body))
|
|
|
|
// Verify grant type for refresh
|
|
if values.Get("grant_type") != "refresh_token" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"error": "unsupported_grant_type",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Handle different test scenarios based on path
|
|
switch r.URL.Path {
|
|
case "/token":
|
|
// Check for empty refresh token
|
|
if values.Get("refresh_token") == "" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"error": "invalid_request",
|
|
"error_description": "The refresh token is missing",
|
|
})
|
|
return
|
|
}
|
|
// Successful refresh
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(TokenResponse{
|
|
AccessToken: "refreshed_access_token",
|
|
IDToken: "refreshed_id_token",
|
|
RefreshToken: "new_refresh_token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
})
|
|
|
|
case "/token/expired":
|
|
// Expired refresh token
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"error": "invalid_grant",
|
|
"error_description": "The refresh token has expired",
|
|
})
|
|
|
|
case "/token/invalid":
|
|
// Invalid refresh token
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"error": "invalid_grant",
|
|
"error_description": "The refresh token is invalid",
|
|
})
|
|
|
|
case "/token/revoked":
|
|
// Revoked refresh token
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"error": "invalid_grant",
|
|
"error_description": "The refresh token has been revoked",
|
|
})
|
|
|
|
case "/token/timeout":
|
|
// Simulate timeout
|
|
time.Sleep(200 * time.Millisecond)
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
case "/token/error":
|
|
// Server error
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"error": "server_error",
|
|
})
|
|
|
|
case "/token/malformed":
|
|
// Malformed JSON
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write([]byte(`{"access_token": "test", invalid json`))
|
|
|
|
case "/token/partial":
|
|
// Partial response (missing ID token)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"access_token": "partial_access_token",
|
|
"refresh_token": "partial_refresh_token",
|
|
"token_type": "Bearer",
|
|
"expires_in": 3600,
|
|
// ID token intentionally missing
|
|
})
|
|
|
|
case "/token/ratelimit":
|
|
// Rate limiting
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"error": "rate_limit_exceeded",
|
|
})
|
|
|
|
case "/token/rotating":
|
|
// Token rotation - return different refresh token
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(TokenResponse{
|
|
AccessToken: "rotated_access_token",
|
|
IDToken: "rotated_id_token",
|
|
RefreshToken: fmt.Sprintf("rotated_refresh_token_%d", time.Now().UnixNano()),
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
})
|
|
|
|
default:
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Setup TraefikOidc instance
|
|
oidc := tt.setupMock(server)
|
|
|
|
// Execute the function
|
|
resp, err := oidc.GetNewTokenWithRefreshToken(tt.refreshToken)
|
|
|
|
// Validate results
|
|
if tt.wantErr && err == nil {
|
|
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
|
} else if !tt.wantErr && err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
} else if tt.wantErr && err != nil && tt.expectedError != "" {
|
|
// Check if error message contains expected string
|
|
if !strings.Contains(err.Error(), tt.expectedError) {
|
|
t.Logf("Error doesn't contain expected string %q: %v", tt.expectedError, err)
|
|
}
|
|
}
|
|
|
|
// Run custom validation
|
|
if tt.validateFunc != nil {
|
|
tt.validateFunc(t, resp, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestGetNewTokenWithRefreshToken_Concurrency tests concurrent refresh scenarios
|
|
func TestGetNewTokenWithRefreshToken_Concurrency(t *testing.T) {
|
|
t.Run("multiple concurrent refreshes with same token", func(t *testing.T) {
|
|
refreshCount := 0
|
|
var mu sync.Mutex
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
mu.Lock()
|
|
refreshCount++
|
|
count := refreshCount
|
|
mu.Unlock()
|
|
|
|
// Simulate processing time
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(TokenResponse{
|
|
AccessToken: fmt.Sprintf("access_token_%d", count),
|
|
IDToken: fmt.Sprintf("id_token_%d", count),
|
|
RefreshToken: fmt.Sprintf("refresh_token_%d", count),
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
})
|
|
}))
|
|
defer server.Close()
|
|
|
|
oidc := &TraefikOidc{
|
|
tokenURL: server.URL + "/token",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
|
|
// Run multiple concurrent refreshes with the same token
|
|
const numRequests = 5
|
|
results := make(chan *TokenResponse, numRequests)
|
|
errors := make(chan error, numRequests)
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(numRequests)
|
|
|
|
for i := 0; i < numRequests; i++ {
|
|
go func() {
|
|
defer wg.Done()
|
|
resp, err := oidc.GetNewTokenWithRefreshToken("same_refresh_token")
|
|
if err != nil {
|
|
errors <- err
|
|
} else {
|
|
results <- resp
|
|
}
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
close(results)
|
|
close(errors)
|
|
|
|
// Verify all requests completed
|
|
successCount := len(results)
|
|
errorCount := len(errors)
|
|
|
|
if successCount != numRequests {
|
|
t.Errorf("expected %d successful refreshes, got %d", numRequests, successCount)
|
|
}
|
|
if errorCount > 0 {
|
|
t.Errorf("got %d errors in concurrent refreshes", errorCount)
|
|
}
|
|
|
|
// Verify we actually made concurrent requests
|
|
mu.Lock()
|
|
finalCount := refreshCount
|
|
mu.Unlock()
|
|
|
|
if finalCount != numRequests {
|
|
t.Errorf("expected %d refresh calls, got %d", numRequests, finalCount)
|
|
}
|
|
})
|
|
|
|
t.Run("race condition detection", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Return successful response
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(TokenResponse{
|
|
AccessToken: "race_test_access_token",
|
|
IDToken: "race_test_id_token",
|
|
RefreshToken: "race_test_refresh_token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
})
|
|
}))
|
|
defer server.Close()
|
|
|
|
oidc := &TraefikOidc{
|
|
tokenURL: server.URL + "/token",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
|
|
// Run with race detector (go test -race will catch issues)
|
|
const numGoroutines = 10
|
|
var wg sync.WaitGroup
|
|
wg.Add(numGoroutines)
|
|
|
|
for i := 0; i < numGoroutines; i++ {
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
token := fmt.Sprintf("refresh_token_%d", id)
|
|
_, _ = oidc.GetNewTokenWithRefreshToken(token)
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
})
|
|
}
|
|
|
|
// TestGetNewTokenWithRefreshToken_ErrorRecovery tests error recovery scenarios
|
|
func TestGetNewTokenWithRefreshToken_ErrorRecovery(t *testing.T) {
|
|
t.Run("recovery after temporary failure", func(t *testing.T) {
|
|
attemptCount := 0
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
attemptCount++
|
|
|
|
// Fail first two attempts, succeed on third
|
|
if attemptCount <= 2 {
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"error": "temporarily_unavailable",
|
|
})
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(TokenResponse{
|
|
AccessToken: "recovered_access_token",
|
|
IDToken: "recovered_id_token",
|
|
RefreshToken: "recovered_refresh_token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
})
|
|
}))
|
|
defer server.Close()
|
|
|
|
oidc := &TraefikOidc{
|
|
tokenURL: server.URL + "/token",
|
|
clientID: "test_client",
|
|
audience: "test_client",
|
|
clientSecret: "test_secret",
|
|
tokenHTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
logger: NewLogger("debug"),
|
|
}
|
|
|
|
// First two attempts should fail
|
|
for i := 0; i < 2; i++ {
|
|
resp, err := oidc.GetNewTokenWithRefreshToken("test_refresh_token")
|
|
if err == nil {
|
|
t.Errorf("expected error on attempt %d, got success", i+1)
|
|
}
|
|
if resp != nil {
|
|
t.Errorf("expected nil response on attempt %d", i+1)
|
|
}
|
|
}
|
|
|
|
// Third attempt should succeed
|
|
resp, err := oidc.GetNewTokenWithRefreshToken("test_refresh_token")
|
|
if err != nil {
|
|
t.Errorf("unexpected error on recovery attempt: %v", err)
|
|
}
|
|
if resp == nil || resp.AccessToken != "recovered_access_token" {
|
|
t.Error("expected successful recovery")
|
|
}
|
|
})
|
|
}
|