Files
traefikoidc/main_refresh_test.go
T
lukaszraczylo c3f23cb99b Release 0.7.5 (#70)
* Resolve issue with opaque tokens not being parsed correctly

* Increase test coverage

* Further improvements to test coverage and code quality

* Add new providers.

* fixup! Add new providers.

* Cleanup.

* fixup! Cleanup.

* fixup! fixup! Cleanup.

* fixup! fixup! fixup! Cleanup.

* fixup! fixup! fixup! fixup! Cleanup.

* Memory management optimisation

24 bytes per Put < 256-4096 bytes per buffer allocation avoided (10-170x difference)

* Pooling cleanup.
2025-10-01 12:13:10 +01:00

673 lines
19 KiB
Go

package traefikoidc
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"
)
// TestGetNewTokenWithRefreshToken tests the GetNewTokenWithRefreshToken function
func TestGetNewTokenWithRefreshToken(t *testing.T) {
tests := []struct {
name string
refreshToken string
setupMock func(*httptest.Server) *TraefikOidc
validateFunc func(*testing.T, *TokenResponse, error)
wantErr bool
expectedError string
}{
{
name: "successful token refresh",
refreshToken: "valid_refresh_token",
setupMock: func(server *httptest.Server) *TraefikOidc {
return &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
},
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if resp == nil {
t.Error("expected token response, got nil")
return
}
if resp.AccessToken != "refreshed_access_token" {
t.Errorf("expected refreshed_access_token, got %s", resp.AccessToken)
}
if resp.IDToken != "refreshed_id_token" {
t.Errorf("expected refreshed_id_token, got %s", resp.IDToken)
}
if resp.RefreshToken != "new_refresh_token" {
t.Errorf("expected new_refresh_token, got %s", resp.RefreshToken)
}
if resp.TokenType != "Bearer" {
t.Errorf("expected token type Bearer, got %s", resp.TokenType)
}
if resp.ExpiresIn != 3600 {
t.Errorf("expected expires_in 3600, got %d", resp.ExpiresIn)
}
},
wantErr: false,
},
{
name: "expired refresh token",
refreshToken: "expired_refresh_token",
setupMock: func(server *httptest.Server) *TraefikOidc {
return &TraefikOidc{
tokenURL: server.URL + "/token/expired",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
},
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
if err == nil {
t.Error("expected error for expired refresh token, got nil")
return
}
if !strings.Contains(err.Error(), "invalid_grant") && !strings.Contains(err.Error(), "expired") {
t.Errorf("expected invalid_grant or expired error, got: %v", err)
}
},
wantErr: true,
expectedError: "invalid_grant",
},
{
name: "invalid refresh token",
refreshToken: "invalid_refresh_token",
setupMock: func(server *httptest.Server) *TraefikOidc {
return &TraefikOidc{
tokenURL: server.URL + "/token/invalid",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
},
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
if err == nil {
t.Error("expected error for invalid refresh token, got nil")
return
}
if !strings.Contains(err.Error(), "invalid_grant") {
t.Errorf("expected invalid_grant error, got: %v", err)
}
},
wantErr: true,
expectedError: "invalid_grant",
},
{
name: "revoked refresh token",
refreshToken: "revoked_refresh_token",
setupMock: func(server *httptest.Server) *TraefikOidc {
return &TraefikOidc{
tokenURL: server.URL + "/token/revoked",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
},
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
if err == nil {
t.Error("expected error for revoked refresh token, got nil")
return
}
if !strings.Contains(err.Error(), "invalid_grant") && !strings.Contains(err.Error(), "revoked") {
t.Errorf("expected invalid_grant or revoked error, got: %v", err)
}
},
wantErr: true,
expectedError: "invalid_grant",
},
{
name: "network timeout during refresh",
refreshToken: "valid_refresh_token",
setupMock: func(server *httptest.Server) *TraefikOidc {
return &TraefikOidc{
tokenURL: server.URL + "/token/timeout",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 100 * time.Millisecond,
},
logger: NewLogger("debug"),
}
},
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
if err == nil {
t.Error("expected timeout error, got nil")
return
}
if !strings.Contains(err.Error(), "timeout") && !strings.Contains(err.Error(), "deadline") {
t.Errorf("expected timeout error, got: %v", err)
}
},
wantErr: true,
expectedError: "timeout",
},
{
name: "server error during refresh",
refreshToken: "valid_refresh_token",
setupMock: func(server *httptest.Server) *TraefikOidc {
return &TraefikOidc{
tokenURL: server.URL + "/token/error",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
},
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
if err == nil {
t.Error("expected server error, got nil")
return
}
if !strings.Contains(err.Error(), "500") && !strings.Contains(err.Error(), "server_error") {
t.Errorf("expected server error, got: %v", err)
}
},
wantErr: true,
expectedError: "server_error",
},
{
name: "malformed JSON response",
refreshToken: "valid_refresh_token",
setupMock: func(server *httptest.Server) *TraefikOidc {
return &TraefikOidc{
tokenURL: server.URL + "/token/malformed",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
},
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
if err == nil {
t.Error("expected JSON parse error, got nil")
return
}
// Accept various JSON parsing error messages
if !strings.Contains(err.Error(), "json") && !strings.Contains(err.Error(), "unmarshal") && !strings.Contains(err.Error(), "invalid character") {
t.Errorf("expected JSON error, got: %v", err)
}
},
wantErr: true,
expectedError: "json",
},
{
name: "partial token response (missing ID token)",
refreshToken: "valid_refresh_token",
setupMock: func(server *httptest.Server) *TraefikOidc {
return &TraefikOidc{
tokenURL: server.URL + "/token/partial",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
},
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
if err != nil {
t.Logf("got error: %v", err)
}
if resp == nil {
t.Error("expected partial token response, got nil")
return
}
if resp.AccessToken != "partial_access_token" {
t.Errorf("expected partial_access_token, got %s", resp.AccessToken)
}
if resp.IDToken != "" {
t.Errorf("expected empty ID token, got %s", resp.IDToken)
}
},
wantErr: false,
},
{
name: "rate limited refresh request",
refreshToken: "valid_refresh_token",
setupMock: func(server *httptest.Server) *TraefikOidc {
return &TraefikOidc{
tokenURL: server.URL + "/token/ratelimit",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
},
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
if err == nil {
t.Error("expected rate limit error, got nil")
return
}
if !strings.Contains(err.Error(), "429") && !strings.Contains(err.Error(), "rate") {
t.Errorf("expected rate limit error, got: %v", err)
}
},
wantErr: true,
expectedError: "rate",
},
{
name: "empty refresh token",
refreshToken: "",
setupMock: func(server *httptest.Server) *TraefikOidc {
return &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
},
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
if err == nil {
t.Error("expected error for empty refresh token, got nil")
return
}
// The actual error should contain invalid_request
if !strings.Contains(err.Error(), "invalid_request") && !strings.Contains(err.Error(), "missing") {
t.Errorf("expected invalid_request or missing error, got: %v", err)
}
if resp != nil {
t.Error("expected nil response for empty refresh token")
}
},
wantErr: true,
expectedError: "invalid_request",
},
{
name: "refresh with rotating tokens",
refreshToken: "rotating_refresh_token",
setupMock: func(server *httptest.Server) *TraefikOidc {
return &TraefikOidc{
tokenURL: server.URL + "/token/rotating",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
},
validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if resp == nil {
t.Error("expected token response, got nil")
return
}
// Verify we got a different refresh token (rotation)
if resp.RefreshToken == "rotating_refresh_token" {
t.Error("expected new refresh token (rotation), got same token")
}
if resp.RefreshToken == "" {
t.Error("expected new refresh token, got empty")
}
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create test server with various endpoints
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify request method
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
// Parse request body
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
// Verify grant type for refresh
if values.Get("grant_type") != "refresh_token" {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{
"error": "unsupported_grant_type",
})
return
}
// Handle different test scenarios based on path
switch r.URL.Path {
case "/token":
// Check for empty refresh token
if values.Get("refresh_token") == "" {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{
"error": "invalid_request",
"error_description": "The refresh token is missing",
})
return
}
// Successful refresh
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "refreshed_access_token",
IDToken: "refreshed_id_token",
RefreshToken: "new_refresh_token",
TokenType: "Bearer",
ExpiresIn: 3600,
})
case "/token/expired":
// Expired refresh token
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{
"error": "invalid_grant",
"error_description": "The refresh token has expired",
})
case "/token/invalid":
// Invalid refresh token
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{
"error": "invalid_grant",
"error_description": "The refresh token is invalid",
})
case "/token/revoked":
// Revoked refresh token
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{
"error": "invalid_grant",
"error_description": "The refresh token has been revoked",
})
case "/token/timeout":
// Simulate timeout
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
case "/token/error":
// Server error
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{
"error": "server_error",
})
case "/token/malformed":
// Malformed JSON
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"access_token": "test", invalid json`))
case "/token/partial":
// Partial response (missing ID token)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"access_token": "partial_access_token",
"refresh_token": "partial_refresh_token",
"token_type": "Bearer",
"expires_in": 3600,
// ID token intentionally missing
})
case "/token/ratelimit":
// Rate limiting
w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]string{
"error": "rate_limit_exceeded",
})
case "/token/rotating":
// Token rotation - return different refresh token
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "rotated_access_token",
IDToken: "rotated_id_token",
RefreshToken: fmt.Sprintf("rotated_refresh_token_%d", time.Now().UnixNano()),
TokenType: "Bearer",
ExpiresIn: 3600,
})
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer server.Close()
// Setup TraefikOidc instance
oidc := tt.setupMock(server)
// Execute the function
resp, err := oidc.GetNewTokenWithRefreshToken(tt.refreshToken)
// Validate results
if tt.wantErr && err == nil {
t.Errorf("expected error containing %q, got nil", tt.expectedError)
} else if !tt.wantErr && err != nil {
t.Errorf("unexpected error: %v", err)
} else if tt.wantErr && err != nil && tt.expectedError != "" {
// Check if error message contains expected string
if !strings.Contains(err.Error(), tt.expectedError) {
t.Logf("Error doesn't contain expected string %q: %v", tt.expectedError, err)
}
}
// Run custom validation
if tt.validateFunc != nil {
tt.validateFunc(t, resp, err)
}
})
}
}
// TestGetNewTokenWithRefreshToken_Concurrency tests concurrent refresh scenarios
func TestGetNewTokenWithRefreshToken_Concurrency(t *testing.T) {
t.Run("multiple concurrent refreshes with same token", func(t *testing.T) {
refreshCount := 0
var mu sync.Mutex
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
refreshCount++
count := refreshCount
mu.Unlock()
// Simulate processing time
time.Sleep(50 * time.Millisecond)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
AccessToken: fmt.Sprintf("access_token_%d", count),
IDToken: fmt.Sprintf("id_token_%d", count),
RefreshToken: fmt.Sprintf("refresh_token_%d", count),
TokenType: "Bearer",
ExpiresIn: 3600,
})
}))
defer server.Close()
oidc := &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
// Run multiple concurrent refreshes with the same token
const numRequests = 5
results := make(chan *TokenResponse, numRequests)
errors := make(chan error, numRequests)
var wg sync.WaitGroup
wg.Add(numRequests)
for i := 0; i < numRequests; i++ {
go func() {
defer wg.Done()
resp, err := oidc.GetNewTokenWithRefreshToken("same_refresh_token")
if err != nil {
errors <- err
} else {
results <- resp
}
}()
}
wg.Wait()
close(results)
close(errors)
// Verify all requests completed
successCount := len(results)
errorCount := len(errors)
if successCount != numRequests {
t.Errorf("expected %d successful refreshes, got %d", numRequests, successCount)
}
if errorCount > 0 {
t.Errorf("got %d errors in concurrent refreshes", errorCount)
}
// Verify we actually made concurrent requests
mu.Lock()
finalCount := refreshCount
mu.Unlock()
if finalCount != numRequests {
t.Errorf("expected %d refresh calls, got %d", numRequests, finalCount)
}
})
t.Run("race condition detection", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return successful response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "race_test_access_token",
IDToken: "race_test_id_token",
RefreshToken: "race_test_refresh_token",
TokenType: "Bearer",
ExpiresIn: 3600,
})
}))
defer server.Close()
oidc := &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
// Run with race detector (go test -race will catch issues)
const numGoroutines = 10
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
token := fmt.Sprintf("refresh_token_%d", id)
_, _ = oidc.GetNewTokenWithRefreshToken(token)
}(i)
}
wg.Wait()
})
}
// TestGetNewTokenWithRefreshToken_ErrorRecovery tests error recovery scenarios
func TestGetNewTokenWithRefreshToken_ErrorRecovery(t *testing.T) {
t.Run("recovery after temporary failure", func(t *testing.T) {
attemptCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount++
// Fail first two attempts, succeed on third
if attemptCount <= 2 {
w.WriteHeader(http.StatusServiceUnavailable)
json.NewEncoder(w).Encode(map[string]string{
"error": "temporarily_unavailable",
})
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "recovered_access_token",
IDToken: "recovered_id_token",
RefreshToken: "recovered_refresh_token",
TokenType: "Bearer",
ExpiresIn: 3600,
})
}))
defer server.Close()
oidc := &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
// First two attempts should fail
for i := 0; i < 2; i++ {
resp, err := oidc.GetNewTokenWithRefreshToken("test_refresh_token")
if err == nil {
t.Errorf("expected error on attempt %d, got success", i+1)
}
if resp != nil {
t.Errorf("expected nil response on attempt %d", i+1)
}
}
// Third attempt should succeed
resp, err := oidc.GetNewTokenWithRefreshToken("test_refresh_token")
if err != nil {
t.Errorf("unexpected error on recovery attempt: %v", err)
}
if resp == nil || resp.AccessToken != "recovered_access_token" {
t.Error("expected successful recovery")
}
})
}