Files
traefikoidc/main_refresh_test.go
T
lukaszraczylo bde1db1c3b traefik plugin 0.7.7 (#73)
* Automatic discovery of the scopes.

Issue #61 raised valid concerns about users configuring scopes that their provider does not support.
This change introduces automatic discovery of supported scopes by fetching the provider's discovery document and filtering out unsupported scopes.

Before:
User configures: scopes: ["openid", "profile", "email", "offline_access"]
Self-hosted GitLab: "The requested scope is invalid, unknown, or malformed"
Authentication: FAILS

After:
User configures: scopes: ["openid", "profile", "email", "offline_access"]
Middleware checks discovery doc → offline_access not supported
Automatically filters to: ["openid", "profile", "email"]
Authentication: SUCCEEDS

* Resolves issue #74 by allowing users to specify the expected audience in the configuration.

* Fix flaky tests.
2025-10-08 11:44:00 +01:00

687 lines
20 KiB
Go

package traefikoidc
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"
)
// TestGetNewTokenWithRefreshToken drives GetNewTokenWithRefreshToken against a
// local httptest token endpoint. The endpoint picks its behaviour from the
// request path (success, invalid_grant variants, a slow reply, HTTP 500,
// broken JSON, a reply without id_token, HTTP 429 and refresh-token rotation),
// and each table entry aims the middleware at the path it wants to exercise.
func TestGetNewTokenWithRefreshToken(t *testing.T) {
	// shortTimeout must stay below the 200ms delay of the /token/timeout path.
	const (
		defaultTimeout = 10 * time.Second
		shortTimeout   = 100 * time.Millisecond
	)

	// oidcFor returns a setupMock that points a fresh TraefikOidc at path on
	// the test server, using an HTTP client bounded by timeout.
	oidcFor := func(path string, timeout time.Duration) func(*httptest.Server) *TraefikOidc {
		return func(srv *httptest.Server) *TraefikOidc {
			return &TraefikOidc{
				tokenURL:        srv.URL + path,
				clientID:        "test_client",
				audience:        "test_client",
				clientSecret:    "test_secret",
				tokenHTTPClient: &http.Client{Timeout: timeout},
				logger:          NewLogger("debug"),
			}
		}
	}

	// wantErrMentioning returns a validateFunc that requires a non-nil error
	// whose text contains at least one of needles. nilMsg is reported when no
	// error comes back; mismatchFmt (one %v verb, filled with the error) is
	// reported when none of the needles appear.
	wantErrMentioning := func(nilMsg, mismatchFmt string, needles ...string) func(*testing.T, *TokenResponse, error) {
		return func(t *testing.T, _ *TokenResponse, err error) {
			if err == nil {
				t.Error(nilMsg)
				return
			}
			msg := err.Error()
			for _, needle := range needles {
				if strings.Contains(msg, needle) {
					return
				}
			}
			t.Errorf(mismatchFmt, err)
		}
	}

	// replyError sends status followed by a JSON error object; no
	// Content-Type header is set on these replies.
	replyError := func(w http.ResponseWriter, status int, body map[string]string) {
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(body)
	}

	// replyJSON sends an implicit 200 with a JSON body and an explicit
	// application/json Content-Type.
	replyJSON := func(w http.ResponseWriter, body interface{}) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(body)
	}

	// tokenEndpoint emulates the provider's token endpoint. It keeps no state,
	// so a single handler value is shared by every per-case server.
	tokenEndpoint := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}

		// Read/parse errors are deliberately ignored: any field that fails to
		// decode is simply absent from the form.
		raw, _ := io.ReadAll(r.Body)
		form, _ := url.ParseQuery(string(raw))

		if form.Get("grant_type") != "refresh_token" {
			replyError(w, http.StatusBadRequest, map[string]string{
				"error": "unsupported_grant_type",
			})
			return
		}

		switch r.URL.Path {
		case "/token":
			if form.Get("refresh_token") == "" {
				replyError(w, http.StatusBadRequest, map[string]string{
					"error":             "invalid_request",
					"error_description": "The refresh token is missing",
				})
				return
			}
			replyJSON(w, TokenResponse{
				AccessToken:  "refreshed_access_token",
				IDToken:      "refreshed_id_token",
				RefreshToken: "new_refresh_token",
				TokenType:    "Bearer",
				ExpiresIn:    3600,
			})
		case "/token/expired":
			replyError(w, http.StatusBadRequest, map[string]string{
				"error":             "invalid_grant",
				"error_description": "The refresh token has expired",
			})
		case "/token/invalid":
			replyError(w, http.StatusBadRequest, map[string]string{
				"error":             "invalid_grant",
				"error_description": "The refresh token is invalid",
			})
		case "/token/revoked":
			replyError(w, http.StatusBadRequest, map[string]string{
				"error":             "invalid_grant",
				"error_description": "The refresh token has been revoked",
			})
		case "/token/timeout":
			// Outlast the short client timeout used by the timeout case.
			time.Sleep(200 * time.Millisecond)
			w.WriteHeader(http.StatusOK)
		case "/token/error":
			replyError(w, http.StatusInternalServerError, map[string]string{
				"error": "server_error",
			})
		case "/token/malformed":
			w.Header().Set("Content-Type", "application/json")
			w.Write([]byte(`{"access_token": "test", invalid json`))
		case "/token/partial":
			// id_token is intentionally left out of this reply.
			replyJSON(w, map[string]interface{}{
				"access_token":  "partial_access_token",
				"refresh_token": "partial_refresh_token",
				"token_type":    "Bearer",
				"expires_in":    3600,
			})
		case "/token/ratelimit":
			replyError(w, http.StatusTooManyRequests, map[string]string{
				"error": "rate_limit_exceeded",
			})
		case "/token/rotating":
			// Every call hands out a new, time-stamped refresh token.
			replyJSON(w, TokenResponse{
				AccessToken:  "rotated_access_token",
				IDToken:      "rotated_id_token",
				RefreshToken: fmt.Sprintf("rotated_refresh_token_%d", time.Now().UnixNano()),
				TokenType:    "Bearer",
				ExpiresIn:    3600,
			})
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	})

	cases := []struct {
		name          string
		refreshToken  string
		setupMock     func(*httptest.Server) *TraefikOidc
		validateFunc  func(*testing.T, *TokenResponse, error)
		wantErr       bool
		expectedError string
	}{
		{
			name:         "successful token refresh",
			refreshToken: "valid_refresh_token",
			setupMock:    oidcFor("/token", defaultTimeout),
			validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
				switch {
				case err != nil:
					t.Errorf("unexpected error: %v", err)
					return
				case resp == nil:
					t.Error("expected token response, got nil")
					return
				}
				if got := resp.AccessToken; got != "refreshed_access_token" {
					t.Errorf("expected refreshed_access_token, got %s", got)
				}
				if got := resp.IDToken; got != "refreshed_id_token" {
					t.Errorf("expected refreshed_id_token, got %s", got)
				}
				if got := resp.RefreshToken; got != "new_refresh_token" {
					t.Errorf("expected new_refresh_token, got %s", got)
				}
				if got := resp.TokenType; got != "Bearer" {
					t.Errorf("expected token type Bearer, got %s", got)
				}
				if got := resp.ExpiresIn; got != 3600 {
					t.Errorf("expected expires_in 3600, got %d", got)
				}
			},
			wantErr: false,
		},
		{
			name:         "expired refresh token",
			refreshToken: "expired_refresh_token",
			setupMock:    oidcFor("/token/expired", defaultTimeout),
			validateFunc: wantErrMentioning(
				"expected error for expired refresh token, got nil",
				"expected invalid_grant or expired error, got: %v",
				"invalid_grant", "expired",
			),
			wantErr:       true,
			expectedError: "invalid_grant",
		},
		{
			name:         "invalid refresh token",
			refreshToken: "invalid_refresh_token",
			setupMock:    oidcFor("/token/invalid", defaultTimeout),
			validateFunc: wantErrMentioning(
				"expected error for invalid refresh token, got nil",
				"expected invalid_grant error, got: %v",
				"invalid_grant",
			),
			wantErr:       true,
			expectedError: "invalid_grant",
		},
		{
			name:         "revoked refresh token",
			refreshToken: "revoked_refresh_token",
			setupMock:    oidcFor("/token/revoked", defaultTimeout),
			validateFunc: wantErrMentioning(
				"expected error for revoked refresh token, got nil",
				"expected invalid_grant or revoked error, got: %v",
				"invalid_grant", "revoked",
			),
			wantErr:       true,
			expectedError: "invalid_grant",
		},
		{
			name:         "network timeout during refresh",
			refreshToken: "valid_refresh_token",
			setupMock:    oidcFor("/token/timeout", shortTimeout),
			validateFunc: wantErrMentioning(
				"expected timeout error, got nil",
				"expected timeout error, got: %v",
				"timeout", "deadline",
			),
			wantErr:       true,
			expectedError: "timeout",
		},
		{
			name:         "server error during refresh",
			refreshToken: "valid_refresh_token",
			setupMock:    oidcFor("/token/error", defaultTimeout),
			validateFunc: wantErrMentioning(
				"expected server error, got nil",
				"expected server error, got: %v",
				"500", "server_error",
			),
			wantErr:       true,
			expectedError: "server_error",
		},
		{
			name:         "malformed JSON response",
			refreshToken: "valid_refresh_token",
			setupMock:    oidcFor("/token/malformed", defaultTimeout),
			// Several phrasings of a JSON decoding failure are accepted.
			validateFunc: wantErrMentioning(
				"expected JSON parse error, got nil",
				"expected JSON error, got: %v",
				"json", "unmarshal", "invalid character",
			),
			wantErr:       true,
			expectedError: "json",
		},
		{
			name:         "partial token response (missing ID token)",
			refreshToken: "valid_refresh_token",
			setupMock:    oidcFor("/token/partial", defaultTimeout),
			validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
				// Any error is already flagged by the wantErr check in the
				// runner; here it is only logged for context.
				if err != nil {
					t.Logf("got error: %v", err)
				}
				if resp == nil {
					t.Error("expected partial token response, got nil")
					return
				}
				if got := resp.AccessToken; got != "partial_access_token" {
					t.Errorf("expected partial_access_token, got %s", got)
				}
				if got := resp.IDToken; got != "" {
					t.Errorf("expected empty ID token, got %s", got)
				}
			},
			wantErr: false,
		},
		{
			name:         "rate limited refresh request",
			refreshToken: "valid_refresh_token",
			setupMock:    oidcFor("/token/ratelimit", defaultTimeout),
			validateFunc: wantErrMentioning(
				"expected rate limit error, got nil",
				"expected rate limit error, got: %v",
				"429", "rate",
			),
			wantErr:       true,
			expectedError: "rate",
		},
		{
			name:         "empty refresh token",
			refreshToken: "",
			setupMock:    oidcFor("/token", defaultTimeout),
			validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
				// The endpoint answers an empty refresh_token with invalid_request.
				wantErrMentioning(
					"expected error for empty refresh token, got nil",
					"expected invalid_request or missing error, got: %v",
					"invalid_request", "missing",
				)(t, resp, err)
				if err != nil && resp != nil {
					t.Error("expected nil response for empty refresh token")
				}
			},
			wantErr:       true,
			expectedError: "invalid_request",
		},
		{
			name:         "refresh with rotating tokens",
			refreshToken: "rotating_refresh_token",
			setupMock:    oidcFor("/token/rotating", defaultTimeout),
			validateFunc: func(t *testing.T, resp *TokenResponse, err error) {
				switch {
				case err != nil:
					t.Errorf("unexpected error: %v", err)
					return
				case resp == nil:
					t.Error("expected token response, got nil")
					return
				}
				// Rotation means the returned refresh token differs from the input.
				if resp.RefreshToken == "rotating_refresh_token" {
					t.Error("expected new refresh token (rotation), got same token")
				}
				if resp.RefreshToken == "" {
					t.Error("expected new refresh token, got empty")
				}
			},
			wantErr: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			srv := httptest.NewServer(tokenEndpoint)
			defer srv.Close()

			resp, err := tc.setupMock(srv).GetNewTokenWithRefreshToken(tc.refreshToken)

			switch {
			case tc.wantErr && err == nil:
				t.Errorf("expected error containing %q, got nil", tc.expectedError)
			case !tc.wantErr && err != nil:
				t.Errorf("unexpected error: %v", err)
			case tc.wantErr && tc.expectedError != "" && !strings.Contains(err.Error(), tc.expectedError):
				// Soft check only: validateFunc holds the authoritative
				// assertion on the error text.
				t.Logf("Error doesn't contain expected string %q: %v", tc.expectedError, err)
			}

			if tc.validateFunc != nil {
				tc.validateFunc(t, resp, err)
			}
		})
	}
}
// TestGetNewTokenWithRefreshToken_Concurrency tests concurrent refresh scenarios
func TestGetNewTokenWithRefreshToken_Concurrency(t *testing.T) {
t.Run("multiple concurrent refreshes with same token", func(t *testing.T) {
refreshCount := 0
var mu sync.Mutex
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
refreshCount++
count := refreshCount
mu.Unlock()
// Simulate processing time
time.Sleep(50 * time.Millisecond)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
AccessToken: fmt.Sprintf("access_token_%d", count),
IDToken: fmt.Sprintf("id_token_%d", count),
RefreshToken: fmt.Sprintf("refresh_token_%d", count),
TokenType: "Bearer",
ExpiresIn: 3600,
})
}))
defer server.Close()
oidc := &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
// Run multiple concurrent refreshes with the same token
const numRequests = 5
results := make(chan *TokenResponse, numRequests)
errors := make(chan error, numRequests)
var wg sync.WaitGroup
wg.Add(numRequests)
for i := 0; i < numRequests; i++ {
go func() {
defer wg.Done()
resp, err := oidc.GetNewTokenWithRefreshToken("same_refresh_token")
if err != nil {
errors <- err
} else {
results <- resp
}
}()
}
wg.Wait()
close(results)
close(errors)
// Verify all requests completed
successCount := len(results)
errorCount := len(errors)
if successCount != numRequests {
t.Errorf("expected %d successful refreshes, got %d", numRequests, successCount)
}
if errorCount > 0 {
t.Errorf("got %d errors in concurrent refreshes", errorCount)
}
// Verify we actually made concurrent requests
mu.Lock()
finalCount := refreshCount
mu.Unlock()
if finalCount != numRequests {
t.Errorf("expected %d refresh calls, got %d", numRequests, finalCount)
}
})
t.Run("race condition detection", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return successful response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "race_test_access_token",
IDToken: "race_test_id_token",
RefreshToken: "race_test_refresh_token",
TokenType: "Bearer",
ExpiresIn: 3600,
})
}))
defer server.Close()
oidc := &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
},
logger: NewLogger("debug"),
}
// Run with race detector (go test -race will catch issues)
const numGoroutines = 10
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
token := fmt.Sprintf("refresh_token_%d", id)
_, _ = oidc.GetNewTokenWithRefreshToken(token)
}(i)
}
wg.Wait()
})
}
// TestGetNewTokenWithRefreshToken_ErrorRecovery verifies that a failed refresh
// leaves nothing behind that blocks later attempts: the token endpoint answers
// 503 to its first two requests and succeeds afterwards, so the first two
// refresh calls must fail and the third must succeed.
func TestGetNewTokenWithRefreshToken_ErrorRecovery(t *testing.T) {
	t.Run("recovery after temporary failure", func(t *testing.T) {
		// The handler runs on httptest server goroutines, and the counter is
		// read back from the test goroutine at the end, so guard it explicitly
		// (matching the mutex-guarded counter in the concurrency test).
		var (
			mu       sync.Mutex
			attempts int
		)
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			mu.Lock()
			attempts++
			current := attempts
			mu.Unlock()

			// Fail the first two attempts, succeed from the third on.
			if current <= 2 {
				w.WriteHeader(http.StatusServiceUnavailable)
				json.NewEncoder(w).Encode(map[string]string{
					"error": "temporarily_unavailable",
				})
				return
			}
			w.Header().Set("Content-Type", "application/json")
			json.NewEncoder(w).Encode(TokenResponse{
				AccessToken:  "recovered_access_token",
				IDToken:      "recovered_id_token",
				RefreshToken: "recovered_refresh_token",
				TokenType:    "Bearer",
				ExpiresIn:    3600,
			})
		}))
		defer server.Close()

		oidc := &TraefikOidc{
			tokenURL:     server.URL + "/token",
			clientID:     "test_client",
			audience:     "test_client",
			clientSecret: "test_secret",
			tokenHTTPClient: &http.Client{
				Timeout: 10 * time.Second,
			},
			logger: NewLogger("debug"),
		}

		// The first two calls must surface the 503 as an error and must not
		// hand back a response alongside it.
		for i := 0; i < 2; i++ {
			resp, err := oidc.GetNewTokenWithRefreshToken("test_refresh_token")
			if err == nil {
				t.Errorf("expected error on attempt %d, got success", i+1)
			}
			if resp != nil {
				t.Errorf("expected nil response on attempt %d", i+1)
			}
		}

		// The third call must succeed; report the failure once, not twice.
		resp, err := oidc.GetNewTokenWithRefreshToken("test_refresh_token")
		switch {
		case err != nil:
			t.Errorf("unexpected error on recovery attempt: %v", err)
		case resp == nil || resp.AccessToken != "recovered_access_token":
			t.Error("expected successful recovery")
		}

		// Two failing calls plus one successful call should add up to exactly
		// three requests at the endpoint.
		mu.Lock()
		total := attempts
		mu.Unlock()
		if total != 3 {
			t.Errorf("expected 3 token endpoint requests, got %d", total)
		}
	})
}