mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Release 0.7.5 (#70)
* Resolve issue with opaque tokens not being parsed correctly * Increase test coverage * Further improvements to test coverage and code quality * Add new providers. * fixup! Add new providers. * Cleanup. * fixup! Cleanup. * fixup! fixup! Cleanup. * fixup! fixup! fixup! Cleanup. * fixup! fixup! fixup! fixup! Cleanup. * Memory management optimisation 24 bytes per Put < 256-4096 bytes per buffer allocation avoided (10-170x difference) * Pooling cleanup.
This commit is contained in:
@@ -0,0 +1,672 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestGetNewTokenWithRefreshToken tests the GetNewTokenWithRefreshToken function
|
||||
func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
refreshToken string
|
||||
setupMock func(*httptest.Server) *TraefikOidc
|
||||
validateFunc func(*testing.T, *TokenResponse, error)
|
||||
wantErr bool
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "successful token refresh",
|
||||
refreshToken: "valid_refresh_token",
|
||||
setupMock: func(server *httptest.Server) *TraefikOidc {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
},
|
||||
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if resp == nil {
|
||||
t.Error("expected token response, got nil")
|
||||
return
|
||||
}
|
||||
if resp.AccessToken != "refreshed_access_token" {
|
||||
t.Errorf("expected refreshed_access_token, got %s", resp.AccessToken)
|
||||
}
|
||||
if resp.IDToken != "refreshed_id_token" {
|
||||
t.Errorf("expected refreshed_id_token, got %s", resp.IDToken)
|
||||
}
|
||||
if resp.RefreshToken != "new_refresh_token" {
|
||||
t.Errorf("expected new_refresh_token, got %s", resp.RefreshToken)
|
||||
}
|
||||
if resp.TokenType != "Bearer" {
|
||||
t.Errorf("expected token type Bearer, got %s", resp.TokenType)
|
||||
}
|
||||
if resp.ExpiresIn != 3600 {
|
||||
t.Errorf("expected expires_in 3600, got %d", resp.ExpiresIn)
|
||||
}
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "expired refresh token",
|
||||
refreshToken: "expired_refresh_token",
|
||||
setupMock: func(server *httptest.Server) *TraefikOidc {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/expired",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
},
|
||||
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
||||
if err == nil {
|
||||
t.Error("expected error for expired refresh token, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid_grant") && !strings.Contains(err.Error(), "expired") {
|
||||
t.Errorf("expected invalid_grant or expired error, got: %v", err)
|
||||
}
|
||||
},
|
||||
wantErr: true,
|
||||
expectedError: "invalid_grant",
|
||||
},
|
||||
{
|
||||
name: "invalid refresh token",
|
||||
refreshToken: "invalid_refresh_token",
|
||||
setupMock: func(server *httptest.Server) *TraefikOidc {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/invalid",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
},
|
||||
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid refresh token, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid_grant") {
|
||||
t.Errorf("expected invalid_grant error, got: %v", err)
|
||||
}
|
||||
},
|
||||
wantErr: true,
|
||||
expectedError: "invalid_grant",
|
||||
},
|
||||
{
|
||||
name: "revoked refresh token",
|
||||
refreshToken: "revoked_refresh_token",
|
||||
setupMock: func(server *httptest.Server) *TraefikOidc {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/revoked",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
},
|
||||
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
||||
if err == nil {
|
||||
t.Error("expected error for revoked refresh token, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid_grant") && !strings.Contains(err.Error(), "revoked") {
|
||||
t.Errorf("expected invalid_grant or revoked error, got: %v", err)
|
||||
}
|
||||
},
|
||||
wantErr: true,
|
||||
expectedError: "invalid_grant",
|
||||
},
|
||||
{
|
||||
name: "network timeout during refresh",
|
||||
refreshToken: "valid_refresh_token",
|
||||
setupMock: func(server *httptest.Server) *TraefikOidc {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/timeout",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 100 * time.Millisecond,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
},
|
||||
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
||||
if err == nil {
|
||||
t.Error("expected timeout error, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), "timeout") && !strings.Contains(err.Error(), "deadline") {
|
||||
t.Errorf("expected timeout error, got: %v", err)
|
||||
}
|
||||
},
|
||||
wantErr: true,
|
||||
expectedError: "timeout",
|
||||
},
|
||||
{
|
||||
name: "server error during refresh",
|
||||
refreshToken: "valid_refresh_token",
|
||||
setupMock: func(server *httptest.Server) *TraefikOidc {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/error",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
},
|
||||
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
||||
if err == nil {
|
||||
t.Error("expected server error, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), "500") && !strings.Contains(err.Error(), "server_error") {
|
||||
t.Errorf("expected server error, got: %v", err)
|
||||
}
|
||||
},
|
||||
wantErr: true,
|
||||
expectedError: "server_error",
|
||||
},
|
||||
{
|
||||
name: "malformed JSON response",
|
||||
refreshToken: "valid_refresh_token",
|
||||
setupMock: func(server *httptest.Server) *TraefikOidc {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/malformed",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
},
|
||||
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
||||
if err == nil {
|
||||
t.Error("expected JSON parse error, got nil")
|
||||
return
|
||||
}
|
||||
// Accept various JSON parsing error messages
|
||||
if !strings.Contains(err.Error(), "json") && !strings.Contains(err.Error(), "unmarshal") && !strings.Contains(err.Error(), "invalid character") {
|
||||
t.Errorf("expected JSON error, got: %v", err)
|
||||
}
|
||||
},
|
||||
wantErr: true,
|
||||
expectedError: "json",
|
||||
},
|
||||
{
|
||||
name: "partial token response (missing ID token)",
|
||||
refreshToken: "valid_refresh_token",
|
||||
setupMock: func(server *httptest.Server) *TraefikOidc {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/partial",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
},
|
||||
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
||||
if err != nil {
|
||||
t.Logf("got error: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Error("expected partial token response, got nil")
|
||||
return
|
||||
}
|
||||
if resp.AccessToken != "partial_access_token" {
|
||||
t.Errorf("expected partial_access_token, got %s", resp.AccessToken)
|
||||
}
|
||||
if resp.IDToken != "" {
|
||||
t.Errorf("expected empty ID token, got %s", resp.IDToken)
|
||||
}
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "rate limited refresh request",
|
||||
refreshToken: "valid_refresh_token",
|
||||
setupMock: func(server *httptest.Server) *TraefikOidc {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/ratelimit",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
},
|
||||
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
||||
if err == nil {
|
||||
t.Error("expected rate limit error, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), "429") && !strings.Contains(err.Error(), "rate") {
|
||||
t.Errorf("expected rate limit error, got: %v", err)
|
||||
}
|
||||
},
|
||||
wantErr: true,
|
||||
expectedError: "rate",
|
||||
},
|
||||
{
|
||||
name: "empty refresh token",
|
||||
refreshToken: "",
|
||||
setupMock: func(server *httptest.Server) *TraefikOidc {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
},
|
||||
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
||||
if err == nil {
|
||||
t.Error("expected error for empty refresh token, got nil")
|
||||
return
|
||||
}
|
||||
// The actual error should contain invalid_request
|
||||
if !strings.Contains(err.Error(), "invalid_request") && !strings.Contains(err.Error(), "missing") {
|
||||
t.Errorf("expected invalid_request or missing error, got: %v", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Error("expected nil response for empty refresh token")
|
||||
}
|
||||
},
|
||||
wantErr: true,
|
||||
expectedError: "invalid_request",
|
||||
},
|
||||
{
|
||||
name: "refresh with rotating tokens",
|
||||
refreshToken: "rotating_refresh_token",
|
||||
setupMock: func(server *httptest.Server) *TraefikOidc {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/rotating",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
},
|
||||
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if resp == nil {
|
||||
t.Error("expected token response, got nil")
|
||||
return
|
||||
}
|
||||
// Verify we got a different refresh token (rotation)
|
||||
if resp.RefreshToken == "rotating_refresh_token" {
|
||||
t.Error("expected new refresh token (rotation), got same token")
|
||||
}
|
||||
if resp.RefreshToken == "" {
|
||||
t.Error("expected new refresh token, got empty")
|
||||
}
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create test server with various endpoints
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify request method
|
||||
if r.Method != http.MethodPost {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
values, _ := url.ParseQuery(string(body))
|
||||
|
||||
// Verify grant type for refresh
|
||||
if values.Get("grant_type") != "refresh_token" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unsupported_grant_type",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Handle different test scenarios based on path
|
||||
switch r.URL.Path {
|
||||
case "/token":
|
||||
// Check for empty refresh token
|
||||
if values.Get("refresh_token") == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "invalid_request",
|
||||
"error_description": "The refresh token is missing",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Successful refresh
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "refreshed_access_token",
|
||||
IDToken: "refreshed_id_token",
|
||||
RefreshToken: "new_refresh_token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
|
||||
case "/token/expired":
|
||||
// Expired refresh token
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "The refresh token has expired",
|
||||
})
|
||||
|
||||
case "/token/invalid":
|
||||
// Invalid refresh token
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "The refresh token is invalid",
|
||||
})
|
||||
|
||||
case "/token/revoked":
|
||||
// Revoked refresh token
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "invalid_grant",
|
||||
"error_description": "The refresh token has been revoked",
|
||||
})
|
||||
|
||||
case "/token/timeout":
|
||||
// Simulate timeout
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
case "/token/error":
|
||||
// Server error
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "server_error",
|
||||
})
|
||||
|
||||
case "/token/malformed":
|
||||
// Malformed JSON
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"access_token": "test", invalid json`))
|
||||
|
||||
case "/token/partial":
|
||||
// Partial response (missing ID token)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"access_token": "partial_access_token",
|
||||
"refresh_token": "partial_refresh_token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
// ID token intentionally missing
|
||||
})
|
||||
|
||||
case "/token/ratelimit":
|
||||
// Rate limiting
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "rate_limit_exceeded",
|
||||
})
|
||||
|
||||
case "/token/rotating":
|
||||
// Token rotation - return different refresh token
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "rotated_access_token",
|
||||
IDToken: "rotated_id_token",
|
||||
RefreshToken: fmt.Sprintf("rotated_refresh_token_%d", time.Now().UnixNano()),
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Setup TraefikOidc instance
|
||||
oidc := tt.setupMock(server)
|
||||
|
||||
// Execute the function
|
||||
resp, err := oidc.GetNewTokenWithRefreshToken(tt.refreshToken)
|
||||
|
||||
// Validate results
|
||||
if tt.wantErr && err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.expectedError)
|
||||
} else if !tt.wantErr && err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
} else if tt.wantErr && err != nil && tt.expectedError != "" {
|
||||
// Check if error message contains expected string
|
||||
if !strings.Contains(err.Error(), tt.expectedError) {
|
||||
t.Logf("Error doesn't contain expected string %q: %v", tt.expectedError, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Run custom validation
|
||||
if tt.validateFunc != nil {
|
||||
tt.validateFunc(t, resp, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetNewTokenWithRefreshToken_Concurrency tests concurrent refresh scenarios
|
||||
func TestGetNewTokenWithRefreshToken_Concurrency(t *testing.T) {
|
||||
t.Run("multiple concurrent refreshes with same token", func(t *testing.T) {
|
||||
refreshCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
refreshCount++
|
||||
count := refreshCount
|
||||
mu.Unlock()
|
||||
|
||||
// Simulate processing time
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: fmt.Sprintf("access_token_%d", count),
|
||||
IDToken: fmt.Sprintf("id_token_%d", count),
|
||||
RefreshToken: fmt.Sprintf("refresh_token_%d", count),
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
// Run multiple concurrent refreshes with the same token
|
||||
const numRequests = 5
|
||||
results := make(chan *TokenResponse, numRequests)
|
||||
errors := make(chan error, numRequests)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numRequests)
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
resp, err := oidc.GetNewTokenWithRefreshToken("same_refresh_token")
|
||||
if err != nil {
|
||||
errors <- err
|
||||
} else {
|
||||
results <- resp
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
close(errors)
|
||||
|
||||
// Verify all requests completed
|
||||
successCount := len(results)
|
||||
errorCount := len(errors)
|
||||
|
||||
if successCount != numRequests {
|
||||
t.Errorf("expected %d successful refreshes, got %d", numRequests, successCount)
|
||||
}
|
||||
if errorCount > 0 {
|
||||
t.Errorf("got %d errors in concurrent refreshes", errorCount)
|
||||
}
|
||||
|
||||
// Verify we actually made concurrent requests
|
||||
mu.Lock()
|
||||
finalCount := refreshCount
|
||||
mu.Unlock()
|
||||
|
||||
if finalCount != numRequests {
|
||||
t.Errorf("expected %d refresh calls, got %d", numRequests, finalCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("race condition detection", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Return successful response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "race_test_access_token",
|
||||
IDToken: "race_test_id_token",
|
||||
RefreshToken: "race_test_refresh_token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
// Run with race detector (go test -race will catch issues)
|
||||
const numGoroutines = 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
token := fmt.Sprintf("refresh_token_%d", id)
|
||||
_, _ = oidc.GetNewTokenWithRefreshToken(token)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
// TestGetNewTokenWithRefreshToken_ErrorRecovery tests error recovery scenarios
|
||||
func TestGetNewTokenWithRefreshToken_ErrorRecovery(t *testing.T) {
|
||||
t.Run("recovery after temporary failure", func(t *testing.T) {
|
||||
attemptCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attemptCount++
|
||||
|
||||
// Fail first two attempts, succeed on third
|
||||
if attemptCount <= 2 {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "temporarily_unavailable",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "recovered_access_token",
|
||||
IDToken: "recovered_id_token",
|
||||
RefreshToken: "recovered_refresh_token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
// First two attempts should fail
|
||||
for i := 0; i < 2; i++ {
|
||||
resp, err := oidc.GetNewTokenWithRefreshToken("test_refresh_token")
|
||||
if err == nil {
|
||||
t.Errorf("expected error on attempt %d, got success", i+1)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Errorf("expected nil response on attempt %d", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
// Third attempt should succeed
|
||||
resp, err := oidc.GetNewTokenWithRefreshToken("test_refresh_token")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error on recovery attempt: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "recovered_access_token" {
|
||||
t.Error("expected successful recovery")
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user