mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
6efb78b7a8
* Smarter approach to the cookies - Single maxCookieSize = 1400 constant with clear documentation - Combined cookie storage for ~40-45% size reduction - Backward compatible migration from legacy cookies * Tuneup the code.
2117 lines
57 KiB
Go
2117 lines
57 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"text/template"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// =============================================================================
|
|
// TOKEN TEST CONSTANTS AND TYPES
|
|
// =============================================================================
|
|
|
|
// Test tokens used across multiple test files
|
|
var (
|
|
ValidAccessToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIn0.eyJpc3MiOiJodHRwczovL3Rlc3QtaXNzdWVyLmNvbSIsImF1ZCI6InRlc3QtY2xpZW50LWlkIiwiZXhwIjozMDAwMDAwMDAwLCJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.dGVzdC1zaWduYXR1cmU" // trufflehog:ignore
|
|
ValidIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIn0.eyJpc3MiOiJodHRwczovL3Rlc3QtaXNzdWVyLmNvbSIsImF1ZCI6InRlc3QtY2xpZW50LWlkIiwiZXhwIjozMDAwMDAwMDAwLCJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.dGVzdC1zaWduYXR1cmU" // trufflehog:ignore
|
|
ValidRefreshToken = "refresh_token_abc123"
|
|
MinimalValidJWT = "eyJhbGciOiJub25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIn0." // trufflehog:ignore
|
|
InvalidTokenOneDot = "invalid.token"
|
|
InvalidTokenNoDots = "invalidtoken"
|
|
InvalidTokenThreeDots = "invalid..token"
|
|
)
|
|
|
|
// TestTokens provides test JWT tokens
|
|
type TestTokens struct {
|
|
validJWT string
|
|
expiredJWT string
|
|
}
|
|
|
|
func NewTestTokens() *TestTokens {
|
|
return &TestTokens{
|
|
validJWT: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIn0.eyJpc3MiOiJodHRwczovL3Rlc3QtaXNzdWVyLmNvbSIsImF1ZCI6InRlc3QtY2xpZW50LWlkIiwiZXhwIjozMDAwMDAwMDAwLCJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.dGVzdC1zaWduYXR1cmU", // trufflehog:ignore
|
|
expiredJWT: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIn0.eyJpc3MiOiJodHRwczovL3Rlc3QtaXNzdWVyLmNvbSIsImF1ZCI6InRlc3QtY2xpZW50LWlkIiwiZXhwIjoxMDAwMDAwMDAwLCJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.dGVzdC1zaWduYXR1cmU", // trufflehog:ignore
|
|
}
|
|
}
|
|
|
|
func (tt *TestTokens) CreateValidJWT() string {
|
|
return tt.validJWT
|
|
}
|
|
|
|
// TokenSet represents a complete set of tokens with proper field names
|
|
type TokenSet struct {
|
|
AccessToken string
|
|
IDToken string
|
|
RefreshToken string
|
|
}
|
|
|
|
func (tt *TestTokens) GetValidTokenSet() *TokenSet {
|
|
return &TokenSet{
|
|
AccessToken: tt.validJWT,
|
|
IDToken: tt.validJWT,
|
|
RefreshToken: ValidRefreshToken,
|
|
}
|
|
}
|
|
|
|
func (tt *TestTokens) CreateIncompressibleToken(size int) string {
|
|
return "incompressible." + generateRandomString(size) + ".signature"
|
|
}
|
|
|
|
func (tt *TestTokens) CreateUniqueValidJWT(suffix string) string {
|
|
return tt.validJWT + "_" + suffix
|
|
}
|
|
|
|
func (tt *TestTokens) GetLargeTokenSet() *TokenSet {
|
|
return &TokenSet{
|
|
AccessToken: tt.CreateIncompressibleToken(2000),
|
|
IDToken: tt.CreateIncompressibleToken(2000),
|
|
RefreshToken: ValidRefreshToken,
|
|
}
|
|
}
|
|
|
|
func (tt *TestTokens) CreateExpiredJWT() string {
|
|
return tt.expiredJWT
|
|
}
|
|
|
|
func (tt *TestTokens) CreateLargeValidJWT(claimSize int) string {
|
|
largeClaim := generateRandomString(claimSize)
|
|
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","kid":"test-key-id"}`))
|
|
payload := fmt.Sprintf(`{"iss":"https://test-issuer.com","aud":"test-client-id","exp":3000000000,"sub":"test-subject","email":"test@example.com","large_claim":"%s"}`, largeClaim)
|
|
encodedPayload := base64.RawURLEncoding.EncodeToString([]byte(payload))
|
|
signature := base64.RawURLEncoding.EncodeToString([]byte("test-signature"))
|
|
return fmt.Sprintf("%s.%s.%s", header, encodedPayload, signature)
|
|
}
|
|
|
|
// TestCache is a simple in-memory cache for testing
|
|
type TestCache struct {
|
|
data map[string]interface{}
|
|
}
|
|
|
|
func NewTestCache() *TestCache {
|
|
return &TestCache{
|
|
data: make(map[string]interface{}),
|
|
}
|
|
}
|
|
|
|
func (c *TestCache) Set(key string, value interface{}, ttl time.Duration) {
|
|
c.data[key] = value
|
|
}
|
|
|
|
func (c *TestCache) Get(key string) (interface{}, bool) {
|
|
val, ok := c.data[key]
|
|
return val, ok
|
|
}
|
|
|
|
func (c *TestCache) Delete(key string) {
|
|
delete(c.data, key)
|
|
}
|
|
|
|
func (c *TestCache) SetMaxSize(size int) {}
|
|
func (c *TestCache) Size() int { return len(c.data) }
|
|
func (c *TestCache) Clear() { c.data = make(map[string]interface{}) }
|
|
func (c *TestCache) Cleanup() {}
|
|
func (c *TestCache) Close() {}
|
|
func (c *TestCache) GetStats() map[string]interface{} {
|
|
return map[string]interface{}{"size": len(c.data)}
|
|
}
|
|
|
|
// =============================================================================
|
|
// OPAQUE TOKEN TESTS
|
|
// =============================================================================
|
|
|
|
func TestOpaqueTokenDetection(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
description string
|
|
isOpaque bool
|
|
}{
|
|
{
|
|
name: "JWT token with 3 parts",
|
|
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", // trufflehog:ignore
|
|
isOpaque: false,
|
|
description: "Standard JWT with header.payload.signature",
|
|
},
|
|
{
|
|
name: "Auth0 opaque token",
|
|
token: "8n3d84nd92nf92nf92nf92nf923nf923nf923nf9",
|
|
isOpaque: true,
|
|
description: "Auth0 opaque access token",
|
|
},
|
|
{
|
|
name: "Okta opaque token",
|
|
token: "00Otkjhgt5Rfasde12345678901234567890",
|
|
isOpaque: true,
|
|
description: "Okta opaque access token",
|
|
},
|
|
{
|
|
name: "AWS Cognito opaque token",
|
|
token: "AGPAYJhZmU3NzI5YTQtNGQ0Yy00YTU5LWJjYTQtYzdlMzQ0MmQ3ZDJl",
|
|
isOpaque: true,
|
|
description: "AWS Cognito opaque access token",
|
|
},
|
|
{
|
|
name: "Invalid single dot token",
|
|
token: "invalid.token",
|
|
isOpaque: true,
|
|
description: "Invalid format with single dot",
|
|
},
|
|
{
|
|
name: "Token with no dots",
|
|
token: "opaquetoken1234567890abcdefghijklmnop",
|
|
isOpaque: true,
|
|
description: "Pure opaque token with no dots",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
dotCount := strings.Count(tt.token, ".")
|
|
isOpaqueToken := dotCount != 2
|
|
|
|
if isOpaqueToken != tt.isOpaque {
|
|
t.Errorf("Token detection failed for %s: expected opaque=%v, got opaque=%v (dots=%d)",
|
|
tt.name, tt.isOpaque, isOpaqueToken, dotCount)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOpaqueTokenValidation(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
cm := NewChunkManager(logger)
|
|
defer cm.Shutdown()
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
wantError bool
|
|
}{
|
|
{
|
|
name: "Valid opaque token",
|
|
token: "opaquetoken1234567890abcdefghijklmnop",
|
|
wantError: false,
|
|
},
|
|
{
|
|
name: "Too short opaque token",
|
|
token: "short",
|
|
wantError: true,
|
|
},
|
|
{
|
|
name: "Opaque token with spaces",
|
|
token: "opaque token with spaces 1234567890",
|
|
wantError: true,
|
|
},
|
|
{
|
|
name: "Valid JWT token",
|
|
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", // trufflehog:ignore
|
|
wantError: false,
|
|
},
|
|
}
|
|
|
|
config := TokenConfig{
|
|
Type: "access",
|
|
MinLength: 5,
|
|
MaxLength: 100 * 1024,
|
|
MaxChunks: 25,
|
|
MaxChunkSize: maxCookieSize,
|
|
AllowOpaqueTokens: true,
|
|
RequireJWTFormat: false,
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := cm.validateToken(tt.token, config)
|
|
hasError := result.Error != nil
|
|
|
|
if hasError != tt.wantError {
|
|
if tt.wantError {
|
|
t.Errorf("Expected error for %s but got none", tt.name)
|
|
} else {
|
|
t.Errorf("Unexpected error for %s: %v", tt.name, result.Error)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOpaqueTokenStorage(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
description string
|
|
shouldStore bool
|
|
}{
|
|
{
|
|
name: "Valid opaque token",
|
|
token: "auth0_opaque_token_1234567890abcdefghijklmnop",
|
|
shouldStore: true,
|
|
description: "Opaque token with sufficient length and no dots",
|
|
},
|
|
{
|
|
name: "Valid JWT token",
|
|
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", // trufflehog:ignore
|
|
shouldStore: true,
|
|
description: "Standard JWT with three parts",
|
|
},
|
|
{
|
|
name: "Invalid single-dot token",
|
|
token: "invalid.token",
|
|
shouldStore: false,
|
|
description: "Token with single dot - invalid format",
|
|
},
|
|
{
|
|
name: "Too short opaque token",
|
|
token: "short",
|
|
shouldStore: false,
|
|
description: "Opaque token too short (less than 20 chars)",
|
|
},
|
|
{
|
|
name: "Multi-dot invalid token",
|
|
token: "too.many.dots.here",
|
|
shouldStore: false,
|
|
description: "Token with more than 2 dots - invalid format",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
shouldStore := true
|
|
if tt.token != "" {
|
|
dotCount := strings.Count(tt.token, ".")
|
|
if dotCount == 1 {
|
|
shouldStore = false
|
|
}
|
|
if dotCount == 0 && len(tt.token) < 20 {
|
|
shouldStore = false
|
|
}
|
|
if dotCount > 2 {
|
|
shouldStore = false
|
|
}
|
|
}
|
|
|
|
if shouldStore != tt.shouldStore {
|
|
t.Errorf("Token storage decision failed for %s: expected store=%v, got store=%v",
|
|
tt.name, tt.shouldStore, shouldStore)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// TOKEN INTROSPECTION TESTS
|
|
// =============================================================================
|
|
|
|
func TestIntrospectToken_Success(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
cacheManager := GetUniversalCacheManager(logger)
|
|
defer ResetUniversalCacheManagerForTesting()
|
|
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
t.Errorf("Expected POST request, got %s", r.Method)
|
|
}
|
|
if r.Header.Get("Content-Type") != "application/x-www-form-urlencoded" {
|
|
t.Errorf("Expected application/x-www-form-urlencoded, got %s", r.Header.Get("Content-Type"))
|
|
}
|
|
|
|
username, password, ok := r.BasicAuth()
|
|
if !ok || username != "test-client" || password != "test-secret" {
|
|
t.Errorf("Invalid basic auth: username=%s, password=%s, ok=%v", username, password, ok)
|
|
}
|
|
|
|
body, _ := io.ReadAll(r.Body)
|
|
values, _ := url.ParseQuery(string(body))
|
|
|
|
if values.Get("token") != "test-opaque-token" {
|
|
t.Errorf("Expected token=test-opaque-token, got %s", values.Get("token"))
|
|
}
|
|
if values.Get("token_type_hint") != "access_token" {
|
|
t.Errorf("Expected token_type_hint=access_token, got %s", values.Get("token_type_hint"))
|
|
}
|
|
|
|
resp := IntrospectionResponse{
|
|
Active: true,
|
|
Scope: "openid profile email",
|
|
ClientID: "test-client",
|
|
Username: "testuser",
|
|
TokenType: "Bearer",
|
|
Exp: time.Now().Add(1 * time.Hour).Unix(),
|
|
Iat: time.Now().Add(-5 * time.Minute).Unix(),
|
|
Nbf: time.Now().Add(-5 * time.Minute).Unix(),
|
|
Sub: "user123",
|
|
Aud: "test-audience",
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
tOidc := &TraefikOidc{
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
introspectionURL: mockServer.URL,
|
|
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
|
logger: logger,
|
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
}
|
|
|
|
resp, err := tOidc.introspectToken("test-opaque-token")
|
|
if err != nil {
|
|
t.Fatalf("introspectToken failed: %v", err)
|
|
}
|
|
|
|
if !resp.Active {
|
|
t.Error("Expected token to be active")
|
|
}
|
|
if resp.ClientID != "test-client" {
|
|
t.Errorf("Expected clientID=test-client, got %s", resp.ClientID)
|
|
}
|
|
if resp.Username != "testuser" {
|
|
t.Errorf("Expected username=testuser, got %s", resp.Username)
|
|
}
|
|
if resp.Scope != "openid profile email" {
|
|
t.Errorf("Expected scope='openid profile email', got %s", resp.Scope)
|
|
}
|
|
}
|
|
|
|
func TestIntrospectToken_CachedResult(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
cacheManager := GetUniversalCacheManager(logger)
|
|
defer ResetUniversalCacheManagerForTesting()
|
|
|
|
requestCount := 0
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
requestCount++
|
|
resp := IntrospectionResponse{
|
|
Active: true,
|
|
ClientID: "test-client",
|
|
Exp: time.Now().Add(1 * time.Hour).Unix(),
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
tOidc := &TraefikOidc{
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
introspectionURL: mockServer.URL,
|
|
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
|
logger: logger,
|
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
}
|
|
|
|
resp1, err := tOidc.introspectToken("cached-token")
|
|
if err != nil {
|
|
t.Fatalf("First introspectToken failed: %v", err)
|
|
}
|
|
if !resp1.Active {
|
|
t.Error("Expected first token to be active")
|
|
}
|
|
if requestCount != 1 {
|
|
t.Errorf("Expected 1 request after first call, got %d", requestCount)
|
|
}
|
|
|
|
resp2, err := tOidc.introspectToken("cached-token")
|
|
if err != nil {
|
|
t.Fatalf("Second introspectToken failed: %v", err)
|
|
}
|
|
if !resp2.Active {
|
|
t.Error("Expected second token to be active")
|
|
}
|
|
if requestCount != 1 {
|
|
t.Errorf("Expected 1 request after cache hit, got %d", requestCount)
|
|
}
|
|
}
|
|
|
|
func TestIntrospectToken_MissingEndpoint(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
cacheManager := GetUniversalCacheManager(logger)
|
|
defer ResetUniversalCacheManagerForTesting()
|
|
|
|
tOidc := &TraefikOidc{
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
introspectionURL: "",
|
|
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
|
logger: logger,
|
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
}
|
|
|
|
_, err := tOidc.introspectToken("test-token")
|
|
if err == nil {
|
|
t.Error("Expected error for missing introspection endpoint")
|
|
}
|
|
if !strings.Contains(err.Error(), "introspection endpoint not available") {
|
|
t.Errorf("Expected 'introspection endpoint not available' error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestIntrospectToken_HTTPError(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
cacheManager := GetUniversalCacheManager(logger)
|
|
defer ResetUniversalCacheManagerForTesting()
|
|
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
w.Write([]byte(`{"error": "invalid_client"}`))
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
tOidc := &TraefikOidc{
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
introspectionURL: mockServer.URL,
|
|
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
|
logger: logger,
|
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
}
|
|
|
|
_, err := tOidc.introspectToken("test-token")
|
|
if err == nil {
|
|
t.Error("Expected error for HTTP 401 response")
|
|
}
|
|
if !strings.Contains(err.Error(), "401") {
|
|
t.Errorf("Expected error mentioning status 401, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestIntrospectToken_InvalidJSON(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
cacheManager := GetUniversalCacheManager(logger)
|
|
defer ResetUniversalCacheManagerForTesting()
|
|
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{invalid json`))
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
tOidc := &TraefikOidc{
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
introspectionURL: mockServer.URL,
|
|
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
|
logger: logger,
|
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
}
|
|
|
|
_, err := tOidc.introspectToken("test-token")
|
|
if err == nil {
|
|
t.Error("Expected error for invalid JSON response")
|
|
}
|
|
if !strings.Contains(err.Error(), "failed to decode") {
|
|
t.Errorf("Expected 'failed to decode' error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateOpaqueToken_OpaqueTokensDisabled(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
cacheManager := GetUniversalCacheManager(logger)
|
|
defer ResetUniversalCacheManagerForTesting()
|
|
|
|
tOidc := &TraefikOidc{
|
|
allowOpaqueTokens: false,
|
|
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
|
logger: logger,
|
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
}
|
|
|
|
err := tOidc.validateOpaqueToken("test-token")
|
|
if err == nil {
|
|
t.Error("Expected error when opaque tokens are disabled")
|
|
}
|
|
if !strings.Contains(err.Error(), "opaque tokens are not enabled") {
|
|
t.Errorf("Expected 'opaque tokens are not enabled' error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateOpaqueToken_InactiveToken(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
cacheManager := GetUniversalCacheManager(logger)
|
|
defer ResetUniversalCacheManagerForTesting()
|
|
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := IntrospectionResponse{
|
|
Active: false,
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
tOidc := &TraefikOidc{
|
|
allowOpaqueTokens: true,
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
introspectionURL: mockServer.URL,
|
|
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
|
logger: logger,
|
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
}
|
|
|
|
err := tOidc.validateOpaqueToken("inactive-token")
|
|
if err == nil {
|
|
t.Error("Expected error for inactive token")
|
|
}
|
|
if !strings.Contains(err.Error(), "not active") {
|
|
t.Errorf("Expected 'not active' error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateOpaqueToken_ExpiredToken(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
cacheManager := GetUniversalCacheManager(logger)
|
|
defer ResetUniversalCacheManagerForTesting()
|
|
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := IntrospectionResponse{
|
|
Active: true,
|
|
Exp: time.Now().Add(-1 * time.Hour).Unix(),
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
tOidc := &TraefikOidc{
|
|
allowOpaqueTokens: true,
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
introspectionURL: mockServer.URL,
|
|
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
|
logger: logger,
|
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
}
|
|
|
|
err := tOidc.validateOpaqueToken("expired-token")
|
|
if err == nil {
|
|
t.Error("Expected error for expired token")
|
|
}
|
|
if !strings.Contains(err.Error(), "expired") {
|
|
t.Errorf("Expected 'expired' error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateOpaqueToken_InvalidAudience(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
cacheManager := GetUniversalCacheManager(logger)
|
|
defer ResetUniversalCacheManagerForTesting()
|
|
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := IntrospectionResponse{
|
|
Active: true,
|
|
Aud: "wrong-audience",
|
|
Exp: time.Now().Add(1 * time.Hour).Unix(),
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
tOidc := &TraefikOidc{
|
|
allowOpaqueTokens: true,
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
audience: "expected-audience",
|
|
introspectionURL: mockServer.URL,
|
|
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
|
logger: logger,
|
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
}
|
|
|
|
err := tOidc.validateOpaqueToken("wrong-aud-token")
|
|
if err == nil {
|
|
t.Error("Expected error for invalid audience")
|
|
}
|
|
if !strings.Contains(err.Error(), "invalid audience") {
|
|
t.Errorf("Expected 'invalid audience' error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateOpaqueToken_SuccessfulValidation(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
cacheManager := GetUniversalCacheManager(logger)
|
|
defer ResetUniversalCacheManagerForTesting()
|
|
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resp := IntrospectionResponse{
|
|
Active: true,
|
|
ClientID: "test-client",
|
|
Aud: "test-audience",
|
|
Exp: time.Now().Add(1 * time.Hour).Unix(),
|
|
Nbf: time.Now().Add(-5 * time.Minute).Unix(),
|
|
Scope: "openid profile",
|
|
Sub: "user123",
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
tOidc := &TraefikOidc{
|
|
allowOpaqueTokens: true,
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
audience: "test-audience",
|
|
introspectionURL: mockServer.URL,
|
|
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
|
logger: logger,
|
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
}
|
|
|
|
err := tOidc.validateOpaqueToken("valid-token")
|
|
if err != nil {
|
|
t.Errorf("Expected successful validation, got error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestIntrospectToken_ConcurrentCalls(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
cacheManager := GetUniversalCacheManager(logger)
|
|
defer ResetUniversalCacheManagerForTesting()
|
|
|
|
var requestCount int
|
|
var mu sync.Mutex
|
|
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
mu.Lock()
|
|
requestCount++
|
|
mu.Unlock()
|
|
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
resp := IntrospectionResponse{
|
|
Active: true,
|
|
ClientID: "test-client",
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(resp)
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
tOidc := &TraefikOidc{
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
introspectionURL: mockServer.URL,
|
|
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
|
logger: logger,
|
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
concurrency := 10
|
|
wg.Add(concurrency)
|
|
|
|
for i := 0; i < concurrency; i++ {
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
token := fmt.Sprintf("concurrent-token-%d", id)
|
|
_, err := tOidc.introspectToken(token)
|
|
if err != nil {
|
|
t.Errorf("Concurrent introspection %d failed: %v", id, err)
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
mu.Lock()
|
|
finalCount := requestCount
|
|
mu.Unlock()
|
|
|
|
if finalCount != concurrency {
|
|
t.Errorf("Expected %d requests for %d concurrent calls, got %d", concurrency, concurrency, finalCount)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// TOKEN TYPE DETECTION TESTS
|
|
// =============================================================================
|
|
|
|
func TestDetectTokenType(t *testing.T) {
|
|
tr := &TraefikOidc{
|
|
clientID: "test-client-id",
|
|
suppressDiagnosticLogs: true,
|
|
tokenTypeCache: NewTestCache(),
|
|
}
|
|
|
|
testCases := []struct {
|
|
jwt *JWT
|
|
name string
|
|
token string
|
|
description string
|
|
expectedID bool
|
|
}{
|
|
{
|
|
name: "ID token with nonce",
|
|
jwt: &JWT{
|
|
Header: map[string]interface{}{"alg": "RS256"},
|
|
Claims: map[string]interface{}{
|
|
"nonce": "test-nonce",
|
|
"aud": "test-client-id",
|
|
},
|
|
},
|
|
token: "test-token-with-nonce",
|
|
expectedID: true,
|
|
description: "Should detect ID token via nonce claim",
|
|
},
|
|
{
|
|
name: "RFC 9068 access token",
|
|
jwt: &JWT{
|
|
Header: map[string]interface{}{
|
|
"alg": "RS256",
|
|
"typ": "at+jwt",
|
|
},
|
|
Claims: map[string]interface{}{
|
|
"scope": "openid profile",
|
|
},
|
|
},
|
|
token: "test-access-token-rfc9068",
|
|
expectedID: false,
|
|
description: "Should detect access token via typ=at+jwt header",
|
|
},
|
|
{
|
|
name: "Token with token_use=id",
|
|
jwt: &JWT{
|
|
Header: map[string]interface{}{"alg": "RS256"},
|
|
Claims: map[string]interface{}{
|
|
"token_use": "id",
|
|
"aud": "test-client-id",
|
|
},
|
|
},
|
|
token: "test-token-use-id",
|
|
expectedID: true,
|
|
description: "Should detect ID token via token_use claim",
|
|
},
|
|
{
|
|
name: "Token with token_use=access",
|
|
jwt: &JWT{
|
|
Header: map[string]interface{}{"alg": "RS256"},
|
|
Claims: map[string]interface{}{
|
|
"token_use": "access",
|
|
"scope": "read write",
|
|
},
|
|
},
|
|
token: "test-token-use-access",
|
|
expectedID: false,
|
|
description: "Should detect access token via token_use claim",
|
|
},
|
|
{
|
|
name: "Access token with scope",
|
|
jwt: &JWT{
|
|
Header: map[string]interface{}{"alg": "RS256"},
|
|
Claims: map[string]interface{}{
|
|
"scope": "openid profile email",
|
|
"aud": "some-api-audience",
|
|
},
|
|
},
|
|
token: "test-access-token-with-scope",
|
|
expectedID: false,
|
|
description: "Should detect access token via scope claim",
|
|
},
|
|
{
|
|
name: "ID token with client_id audience",
|
|
jwt: &JWT{
|
|
Header: map[string]interface{}{"alg": "RS256"},
|
|
Claims: map[string]interface{}{
|
|
"aud": "test-client-id",
|
|
"sub": "user123",
|
|
},
|
|
},
|
|
token: "test-id-token-client-aud",
|
|
expectedID: true,
|
|
description: "Should detect ID token via audience matching client_id",
|
|
},
|
|
{
|
|
name: "Default to access token",
|
|
jwt: &JWT{
|
|
Header: map[string]interface{}{"alg": "RS256"},
|
|
Claims: map[string]interface{}{
|
|
"aud": "different-audience",
|
|
"sub": "user123",
|
|
},
|
|
},
|
|
token: "test-default-access-token",
|
|
expectedID: false,
|
|
description: "Should default to access token when no clear indicators",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
result := tr.detectTokenType(tc.jwt, tc.token)
|
|
if result != tc.expectedID {
|
|
t.Errorf("%s: expected isIDToken=%v, got %v", tc.description, tc.expectedID, result)
|
|
}
|
|
|
|
result2 := tr.detectTokenType(tc.jwt, tc.token)
|
|
if result2 != tc.expectedID {
|
|
t.Errorf("%s (cached): expected isIDToken=%v, got %v", tc.description, tc.expectedID, result2)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDetectTokenTypeCaching(t *testing.T) {
|
|
cache := NewTestCache()
|
|
tr := &TraefikOidc{
|
|
clientID: "test-client-id",
|
|
suppressDiagnosticLogs: true,
|
|
tokenTypeCache: cache,
|
|
}
|
|
|
|
jwt := &JWT{
|
|
Header: map[string]interface{}{"alg": "RS256"},
|
|
Claims: map[string]interface{}{
|
|
"nonce": "test-nonce",
|
|
},
|
|
}
|
|
token := "test-token-for-caching-with-enough-characters-for-key"
|
|
cacheKey := token
|
|
if len(token) > 32 {
|
|
cacheKey = token[:32]
|
|
}
|
|
|
|
result := tr.detectTokenType(jwt, token)
|
|
if !result {
|
|
t.Error("Expected ID token detection via nonce")
|
|
}
|
|
|
|
if cached, found := cache.Get(cacheKey); !found {
|
|
t.Error("Expected token type to be cached")
|
|
} else if cachedBool, ok := cached.(bool); !ok || !cachedBool {
|
|
t.Error("Expected cached value to be true (ID token)")
|
|
}
|
|
|
|
jwt.Claims = map[string]interface{}{
|
|
"scope": "openid profile",
|
|
}
|
|
|
|
result2 := tr.detectTokenType(jwt, token)
|
|
if !result2 {
|
|
t.Error("Expected cached ID token result, ignoring modified JWT")
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// TOKEN VALIDATOR TESTS
|
|
// =============================================================================
|
|
|
|
func TestNewTokenValidator(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
if validator == nil {
|
|
t.Fatal("Expected non-nil token validator")
|
|
}
|
|
|
|
if validator.logger == nil {
|
|
t.Error("Expected logger to be initialized")
|
|
}
|
|
}
|
|
|
|
func TestNewTokenValidatorWithLogger(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
validator := NewTokenValidator(logger)
|
|
|
|
if validator == nil {
|
|
t.Fatal("Expected non-nil token validator")
|
|
}
|
|
|
|
if validator.logger != logger {
|
|
t.Error("Expected provided logger to be used")
|
|
}
|
|
}
|
|
|
|
func TestValidateTokenEmpty(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
result := validator.ValidateToken("", false)
|
|
|
|
if result.Valid {
|
|
t.Error("Expected invalid result for empty token")
|
|
}
|
|
|
|
if result.Error == nil {
|
|
t.Error("Expected error for empty token")
|
|
}
|
|
|
|
if !strings.Contains(result.Error.Error(), "empty") {
|
|
t.Errorf("Expected 'empty' in error, got: %v", result.Error)
|
|
}
|
|
}
|
|
|
|
func TestValidateTokenRequireJWT(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
result := validator.ValidateToken("opaque_token_value_here", true)
|
|
|
|
if result.Valid {
|
|
t.Error("Expected invalid result for opaque token when JWT required")
|
|
}
|
|
|
|
if result.Error == nil {
|
|
t.Error("Expected error when JWT required but opaque token provided")
|
|
}
|
|
}
|
|
|
|
func TestValidateJWTValidFormat(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
claims := map[string]interface{}{
|
|
"sub": "user123",
|
|
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(),
|
|
}
|
|
|
|
token := createTestJWTSimple(claims)
|
|
result := validator.ValidateToken(token, false)
|
|
|
|
if !result.Valid {
|
|
t.Errorf("Expected valid result, got error: %v", result.Error)
|
|
}
|
|
|
|
if result.TokenType != "JWT" {
|
|
t.Errorf("Expected TokenType 'JWT', got %s", result.TokenType)
|
|
}
|
|
|
|
if result.Claims == nil {
|
|
t.Error("Expected claims to be parsed")
|
|
}
|
|
|
|
if result.Expiry == nil {
|
|
t.Error("Expected expiry to be extracted")
|
|
}
|
|
|
|
if result.IssuedAt == nil {
|
|
t.Error("Expected issued at to be extracted")
|
|
}
|
|
}
|
|
|
|
func TestValidateJWTExpiredToken(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
claims := map[string]interface{}{
|
|
"sub": "user123",
|
|
"exp": time.Now().Add(-1 * time.Hour).Unix(),
|
|
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
|
}
|
|
|
|
token := createTestJWTSimple(claims)
|
|
result := validator.ValidateToken(token, false)
|
|
|
|
if result.Valid {
|
|
t.Error("Expected invalid result for expired token")
|
|
}
|
|
|
|
if result.Error == nil {
|
|
t.Error("Expected error for expired token")
|
|
}
|
|
|
|
if !strings.Contains(result.Error.Error(), "expired") {
|
|
t.Errorf("Expected 'expired' in error, got: %v", result.Error)
|
|
}
|
|
}
|
|
|
|
func TestValidateJWTFutureIssuedAt(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
claims := map[string]interface{}{
|
|
"sub": "user123",
|
|
"exp": time.Now().Add(2 * time.Hour).Unix(),
|
|
"iat": time.Now().Add(10 * time.Minute).Unix(),
|
|
}
|
|
|
|
token := createTestJWTSimple(claims)
|
|
result := validator.ValidateToken(token, false)
|
|
|
|
if result.Valid {
|
|
t.Error("Expected invalid result for future iat")
|
|
}
|
|
|
|
if result.Error == nil {
|
|
t.Error("Expected error for future iat")
|
|
}
|
|
|
|
if !strings.Contains(result.Error.Error(), "future") {
|
|
t.Errorf("Expected 'future' in error, got: %v", result.Error)
|
|
}
|
|
}
|
|
|
|
func TestValidateJWTNotBeforeClaim(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
claims := map[string]interface{}{
|
|
"sub": "user123",
|
|
"exp": time.Now().Add(2 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(),
|
|
"nbf": time.Now().Add(1 * time.Hour).Unix(),
|
|
}
|
|
|
|
token := createTestJWTSimple(claims)
|
|
result := validator.ValidateToken(token, false)
|
|
|
|
if result.Valid {
|
|
t.Error("Expected invalid result for nbf in future")
|
|
}
|
|
|
|
if result.Error == nil {
|
|
t.Error("Expected error for nbf in future")
|
|
}
|
|
|
|
if !strings.Contains(result.Error.Error(), "not yet valid") {
|
|
t.Errorf("Expected 'not yet valid' in error, got: %v", result.Error)
|
|
}
|
|
}
|
|
|
|
func TestValidateJWTInvalidFormat(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
}{
|
|
{"single part", "eyJhbGciOiJIUzI1NiJ9"},
|
|
{"two parts", "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0In0"},
|
|
{"four parts", "part1.part2.part3.part4"},
|
|
{"empty part", "eyJhbGciOiJIUzI1NiJ9..signature"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := validator.ValidateToken(tt.token, true)
|
|
|
|
if result.Valid {
|
|
t.Error("Expected invalid result for malformed JWT")
|
|
}
|
|
|
|
if result.Error == nil {
|
|
t.Error("Expected error for malformed JWT")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestValidateOpaqueTokenValid(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
token := "sk_live_abcdef123456GHIJKL789"
|
|
result := validator.ValidateToken(token, false)
|
|
|
|
if !result.Valid {
|
|
t.Errorf("Expected valid result, got error: %v", result.Error)
|
|
}
|
|
|
|
if result.TokenType != "Opaque" {
|
|
t.Errorf("Expected TokenType 'Opaque', got %s", result.TokenType)
|
|
}
|
|
}
|
|
|
|
func TestValidateOpaqueTokenTooShort(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
token := "short"
|
|
result := validator.ValidateToken(token, false)
|
|
|
|
if result.Valid {
|
|
t.Error("Expected invalid result for short token")
|
|
}
|
|
|
|
if result.Error == nil {
|
|
t.Error("Expected error for short token")
|
|
}
|
|
|
|
if !strings.Contains(result.Error.Error(), "too short") {
|
|
t.Errorf("Expected 'too short' in error, got: %v", result.Error)
|
|
}
|
|
}
|
|
|
|
func TestValidateOpaqueTokenWithSpaces(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
token := "this token has spaces in it"
|
|
result := validator.ValidateToken(token, false)
|
|
|
|
if result.Valid {
|
|
t.Error("Expected invalid result for token with spaces")
|
|
}
|
|
|
|
if result.Error == nil {
|
|
t.Error("Expected error for token with spaces")
|
|
}
|
|
|
|
if !strings.Contains(result.Error.Error(), "spaces") {
|
|
t.Errorf("Expected 'spaces' in error, got: %v", result.Error)
|
|
}
|
|
}
|
|
|
|
func TestValidateOpaqueTokenControlCharacters(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
token := "token_with\x00control_char"
|
|
result := validator.ValidateToken(token, false)
|
|
|
|
if result.Valid {
|
|
t.Error("Expected invalid result for token with control characters")
|
|
}
|
|
|
|
if result.Error == nil {
|
|
t.Error("Expected error for token with control characters")
|
|
}
|
|
|
|
if !strings.Contains(result.Error.Error(), "control character") {
|
|
t.Errorf("Expected 'control character' in error, got: %v", result.Error)
|
|
}
|
|
}
|
|
|
|
func TestValidateOpaqueTokenInsufficientEntropy(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
token := "aaaaaabbbbbbccccccdddd"
|
|
result := validator.ValidateToken(token, false)
|
|
|
|
if result.Valid {
|
|
t.Error("Expected invalid result for low entropy token")
|
|
}
|
|
|
|
if result.Error == nil {
|
|
t.Error("Expected error for low entropy token")
|
|
}
|
|
|
|
if !strings.Contains(result.Error.Error(), "entropy") {
|
|
t.Errorf("Expected 'entropy' in error, got: %v", result.Error)
|
|
}
|
|
}
|
|
|
|
func TestIsValidBase64URL(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expected bool
|
|
}{
|
|
{"valid uppercase", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", true},
|
|
{"valid lowercase", "abcdefghijklmnopqrstuvwxyz", true},
|
|
{"valid numbers", "0123456789", true},
|
|
{"valid dash", "abc-def", true},
|
|
{"valid underscore", "abc_def", true},
|
|
{"valid equals", "abc=", true},
|
|
{"invalid at sign", "abc@def", false},
|
|
{"invalid space", "abc def", false},
|
|
{"invalid plus", "abc+def", false},
|
|
{"invalid slash", "abc/def", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := validator.isValidBase64URL(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("Expected %v for %s, got %v", tt.expected, tt.input, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtractTime(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
tests := []struct {
|
|
claim interface{}
|
|
name string
|
|
expected bool
|
|
}{
|
|
{name: "float64", claim: float64(1609459200), expected: true},
|
|
{name: "int64", claim: int64(1609459200), expected: true},
|
|
{name: "int", claim: int(1609459200), expected: true},
|
|
{name: "string", claim: "not a timestamp", expected: false},
|
|
{name: "nil", claim: nil, expected: false},
|
|
{name: "map", claim: map[string]interface{}{}, expected: false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := validator.extractTime(tt.claim)
|
|
|
|
if tt.expected && result == nil {
|
|
t.Error("Expected non-nil time")
|
|
}
|
|
|
|
if !tt.expected && result != nil {
|
|
t.Error("Expected nil time")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestValidateTokenSize(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
maxSize int
|
|
expectError bool
|
|
}{
|
|
{"within limit", "short_token", 20, false},
|
|
{"at limit", "exactly_twenty_c", 16, false},
|
|
{"exceeds limit", "this_token_is_too_long", 10, true},
|
|
{"empty token", "", 10, false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := validator.ValidateTokenSize(tt.token, tt.maxSize)
|
|
|
|
if tt.expectError && err == nil {
|
|
t.Error("Expected error for oversized token")
|
|
}
|
|
|
|
if !tt.expectError && err != nil {
|
|
t.Errorf("Expected no error, got: %v", err)
|
|
}
|
|
|
|
if err != nil && !strings.Contains(err.Error(), "exceeds") {
|
|
t.Errorf("Expected 'exceeds' in error, got: %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtractClaims(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
claims := map[string]interface{}{
|
|
"sub": "user123",
|
|
"email": "user@example.com",
|
|
"exp": float64(1609459200),
|
|
}
|
|
|
|
token := createTestJWTSimple(claims)
|
|
extracted, err := validator.ExtractClaims(token)
|
|
|
|
if err != nil {
|
|
t.Fatalf("Expected no error, got: %v", err)
|
|
}
|
|
|
|
if extracted == nil {
|
|
t.Fatal("Expected non-nil claims")
|
|
}
|
|
|
|
if extracted["sub"] != "user123" {
|
|
t.Errorf("Expected sub 'user123', got %v", extracted["sub"])
|
|
}
|
|
|
|
if extracted["email"] != "user@example.com" {
|
|
t.Errorf("Expected email 'user@example.com', got %v", extracted["email"])
|
|
}
|
|
}
|
|
|
|
func TestExtractClaimsInvalidFormat(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
}{
|
|
{"single part", "onlyonepart"},
|
|
{"two parts", "two.parts"},
|
|
{"four parts", "one.two.three.four"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
_, err := validator.ExtractClaims(tt.token)
|
|
|
|
if err == nil {
|
|
t.Error("Expected error for invalid format")
|
|
}
|
|
|
|
if !strings.Contains(err.Error(), "invalid JWT format") {
|
|
t.Errorf("Expected 'invalid JWT format' in error, got: %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCompareTokensEqual(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
token1 := "secret_token_12345"
|
|
token2 := "secret_token_12345"
|
|
|
|
if !validator.CompareTokens(token1, token2) {
|
|
t.Error("Expected tokens to be equal")
|
|
}
|
|
}
|
|
|
|
func TestCompareTokensDifferent(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
token1 := "secret_token_12345"
|
|
token2 := "secret_token_54321"
|
|
|
|
if validator.CompareTokens(token1, token2) {
|
|
t.Error("Expected tokens to be different")
|
|
}
|
|
}
|
|
|
|
func TestCompareTokensDifferentLength(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
token1 := "short"
|
|
token2 := "much_longer_token"
|
|
|
|
if validator.CompareTokens(token1, token2) {
|
|
t.Error("Expected tokens to be different (different lengths)")
|
|
}
|
|
}
|
|
|
|
func TestCompareTokensEmpty(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
token1 := ""
|
|
token2 := ""
|
|
|
|
if !validator.CompareTokens(token1, token2) {
|
|
t.Error("Expected empty tokens to be equal")
|
|
}
|
|
}
|
|
|
|
func TestValidateTokenMaliciousPayloads(t *testing.T) {
|
|
validator := NewTokenValidator(nil)
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
}{
|
|
{"sql injection attempt", "'; DROP TABLE users; --"},
|
|
{"xss attempt", "<script>alert('xss')</script>"},
|
|
{"path traversal", "../../../etc/passwd"},
|
|
{"null bytes", "token\x00with\x00nulls"},
|
|
{"unicode exploit", "token\u0000\u0001\u0002"},
|
|
{"extremely long", strings.Repeat("a", 100000)},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := validator.ValidateToken(tt.token, false)
|
|
|
|
if result.Valid {
|
|
if result.Claims != nil {
|
|
t.Logf("Token considered valid: %s", tt.name)
|
|
}
|
|
} else {
|
|
if result.Error == nil {
|
|
t.Error("Expected error for malicious payload")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// CONSOLIDATED TOKEN TESTS
|
|
// =============================================================================
|
|
|
|
func TestTokenTypes(t *testing.T) {
|
|
t.Run("TokenTypeDistinction", func(t *testing.T) {
|
|
type templateData struct {
|
|
Claims map[string]interface{}
|
|
AccessToken string
|
|
IDToken string
|
|
RefreshToken string
|
|
}
|
|
|
|
testData := templateData{
|
|
AccessToken: "test-access-token-abc123",
|
|
IDToken: "test-id-token-xyz789",
|
|
RefreshToken: "test-refresh-token",
|
|
Claims: map[string]interface{}{
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
templateText string
|
|
expectedValue string
|
|
}{
|
|
{
|
|
name: "Access Token Only",
|
|
templateText: "Bearer {{.AccessToken}}",
|
|
expectedValue: "Bearer test-access-token-abc123",
|
|
},
|
|
{
|
|
name: "ID Token Only",
|
|
templateText: "ID: {{.IDToken}}",
|
|
expectedValue: "ID: test-id-token-xyz789",
|
|
},
|
|
{
|
|
name: "Both Tokens",
|
|
templateText: "Access: {{.AccessToken}} ID: {{.IDToken}}",
|
|
expectedValue: "Access: test-access-token-abc123 ID: test-id-token-xyz789",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
tmpl, err := template.New("test").Parse(tc.templateText)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse template: %v", err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
err = tmpl.Execute(&buf, testData)
|
|
if err != nil {
|
|
t.Fatalf("Failed to execute template: %v", err)
|
|
}
|
|
|
|
result := buf.String()
|
|
if result != tc.expectedValue {
|
|
t.Errorf("Expected template output %q, got %q", tc.expectedValue, result)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("TokenTypeIntegration", func(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(3000000000),
|
|
"sub": "id-token-subject",
|
|
"email": "id@example.com",
|
|
"nonce": "test-nonce",
|
|
"token_type": "id",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create ID token: %v", err)
|
|
}
|
|
|
|
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": float64(3000000000),
|
|
"sub": "access-token-subject",
|
|
"email": "access@example.com",
|
|
"scope": "openid email profile",
|
|
"token_type": "access",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create access token: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com", nil)
|
|
session, err := ts.sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
defer session.ReturnToPool()
|
|
|
|
session.SetIDToken(idToken)
|
|
session.SetAccessToken(accessToken)
|
|
|
|
retrievedID := session.GetIDToken()
|
|
retrievedAccess := session.GetAccessToken()
|
|
|
|
if retrievedID != idToken {
|
|
t.Errorf("ID token mismatch: expected %q, got %q", idToken, retrievedID)
|
|
}
|
|
if retrievedAccess != accessToken {
|
|
t.Errorf("Access token mismatch: expected %q, got %q", accessToken, retrievedAccess)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestTokenCorruption(t *testing.T) {
|
|
t.Run("TokenCorruptionScenario", func(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
testTokens := NewTestTokens()
|
|
validJWT := testTokens.CreateLargeValidJWT(100)
|
|
|
|
tests := []struct {
|
|
corruptionScenario func(*SessionData)
|
|
name string
|
|
tokenSize int
|
|
iterations int
|
|
expectConsistent bool
|
|
}{
|
|
{
|
|
name: "Small token - multiple retrievals",
|
|
tokenSize: len(validJWT),
|
|
iterations: 10,
|
|
expectConsistent: true,
|
|
},
|
|
{
|
|
name: "Large chunked token - multiple retrievals",
|
|
tokenSize: 5000,
|
|
iterations: 10,
|
|
expectConsistent: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
defer session.ReturnToPool()
|
|
|
|
token := createTokenOfSize(validJWT, tt.tokenSize)
|
|
session.SetAccessToken(token)
|
|
|
|
var retrievedTokens []string
|
|
for i := 0; i < tt.iterations; i++ {
|
|
retrieved := session.GetAccessToken()
|
|
retrievedTokens = append(retrievedTokens, retrieved)
|
|
|
|
if tt.expectConsistent && retrieved != token {
|
|
t.Errorf("Iteration %d: Token changed unexpectedly", i)
|
|
}
|
|
}
|
|
|
|
if tt.expectConsistent {
|
|
for i, retrievedToken := range retrievedTokens {
|
|
if retrievedToken != token {
|
|
t.Errorf("Iteration %d: Token mismatch", i)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("Base64CorruptionHandling", func(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expectError bool
|
|
}{
|
|
{"Valid base64", "eyJhbGciOiJSUzI1NiJ9", false},
|
|
{"Invalid characters", "eyJ!@#$%^&*()", true},
|
|
{"Missing padding", "eyJhbGc", false},
|
|
{"Empty string", "", false},
|
|
{"Spaces in base64", "eyJ hbG ciOi JSU zI1 NiJ9", true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
_, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(tt.input))
|
|
hasError := err != nil
|
|
if hasError != tt.expectError {
|
|
t.Errorf("Expected error=%v, got error=%v (err: %v)", tt.expectError, hasError, err)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestTokenResilience(t *testing.T) {
|
|
t.Run("ConcurrentTokenAccess", func(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com", nil)
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
defer session.ReturnToPool()
|
|
|
|
testToken := "test-token-" + generateRandomString(100)
|
|
session.SetAccessToken(testToken)
|
|
|
|
var wg sync.WaitGroup
|
|
errors := make(chan error, 100)
|
|
successCount := int32(0)
|
|
|
|
for i := 0; i < 100; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
retrieved := session.GetAccessToken()
|
|
if retrieved == testToken {
|
|
atomic.AddInt32(&successCount, 1)
|
|
} else {
|
|
errors <- fmt.Errorf("token mismatch: expected %q, got %q", testToken, retrieved)
|
|
}
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errors)
|
|
|
|
for err := range errors {
|
|
t.Error(err)
|
|
}
|
|
|
|
if successCount != 100 {
|
|
t.Errorf("Expected 100 successful retrievals, got %d", successCount)
|
|
}
|
|
})
|
|
|
|
t.Run("TokenSizeHandling", func(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
sizes := []int{
|
|
100,
|
|
1000,
|
|
4000,
|
|
5000,
|
|
10000,
|
|
}
|
|
|
|
for _, size := range sizes {
|
|
t.Run(fmt.Sprintf("Size_%d", size), func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "http://example.com", nil)
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
defer session.ReturnToPool()
|
|
|
|
token := createTokenOfSize(ValidAccessToken, size)
|
|
session.SetAccessToken(token)
|
|
|
|
retrieved := session.GetAccessToken()
|
|
if size > 15000 && retrieved == "" {
|
|
t.Logf("Token size %d exceeds chunk limits (expected)", size)
|
|
} else if retrieved != token {
|
|
t.Errorf("Token mismatch for size %d", size)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("RateLimitedTokenRefresh", func(t *testing.T) {
|
|
limiter := rate.NewLimiter(rate.Limit(10), 1)
|
|
|
|
var wg sync.WaitGroup
|
|
successCount := int32(0)
|
|
deniedCount := int32(0)
|
|
|
|
for i := 0; i < 50; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
if limiter.Allow() {
|
|
atomic.AddInt32(&successCount, 1)
|
|
} else {
|
|
atomic.AddInt32(&deniedCount, 1)
|
|
}
|
|
}()
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
t.Logf("Allowed: %d, Denied: %d", successCount, deniedCount)
|
|
if successCount == 0 {
|
|
t.Error("No requests were allowed")
|
|
}
|
|
if successCount == 50 {
|
|
t.Error("All requests were allowed, rate limiting not working")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestTokenValidation(t *testing.T) {
|
|
t.Run("JWTStructureValidation", func(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
expectValid bool
|
|
}{
|
|
{
|
|
name: "Valid JWT structure",
|
|
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.signature",
|
|
expectValid: true,
|
|
},
|
|
{
|
|
name: "Missing signature",
|
|
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0",
|
|
expectValid: false,
|
|
},
|
|
{
|
|
name: "Missing payload",
|
|
token: "eyJhbGciOiJSUzI1NiJ9..signature",
|
|
expectValid: true,
|
|
},
|
|
{
|
|
name: "Only header",
|
|
token: "eyJhbGciOiJSUzI1NiJ9",
|
|
expectValid: false,
|
|
},
|
|
{
|
|
name: "Too many parts",
|
|
token: "header.payload.signature.extra",
|
|
expectValid: false,
|
|
},
|
|
{
|
|
name: "Empty token",
|
|
token: "",
|
|
expectValid: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
parts := strings.Split(tt.token, ".")
|
|
isValid := len(parts) == 3
|
|
if isValid != tt.expectValid {
|
|
t.Errorf("Expected valid=%v, got %v", tt.expectValid, isValid)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("TokenExpiryValidation", func(t *testing.T) {
|
|
now := time.Now()
|
|
tests := []struct {
|
|
exp time.Time
|
|
name string
|
|
expectValid bool
|
|
}{
|
|
{name: "Future expiry", exp: now.Add(time.Hour), expectValid: true},
|
|
{name: "Just expired", exp: now.Add(-time.Second), expectValid: false},
|
|
{name: "Long expired", exp: now.Add(-24 * time.Hour), expectValid: false},
|
|
{name: "Far future", exp: now.Add(365 * 24 * time.Hour), expectValid: true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
isValid := tt.exp.After(now)
|
|
if isValid != tt.expectValid {
|
|
t.Errorf("Expected valid=%v, got %v", tt.expectValid, isValid)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestTokenChunking(t *testing.T) {
|
|
t.Run("ChunkSplitting", func(t *testing.T) {
|
|
chunkSize := 4000
|
|
tests := []struct {
|
|
name string
|
|
tokenSize int
|
|
expectedChunks int
|
|
}{
|
|
{"Small token", 100, 1},
|
|
{"Just under chunk size", 3999, 1},
|
|
{"Exactly chunk size", 4000, 1},
|
|
{"Just over chunk size", 4100, 2},
|
|
{"Multiple chunks", 10000, 3},
|
|
{"Large token", 50000, 13},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
token := generateRandomString(tt.tokenSize)
|
|
chunks := (len(token) + chunkSize - 1) / chunkSize
|
|
if chunks != tt.expectedChunks {
|
|
t.Errorf("Expected %d chunks, got %d", tt.expectedChunks, chunks)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("ChunkReassembly", func(t *testing.T) {
|
|
originalToken := generateRandomString(10000)
|
|
chunkSize := 4000
|
|
|
|
var chunks []string
|
|
for i := 0; i < len(originalToken); i += chunkSize {
|
|
end := i + chunkSize
|
|
if end > len(originalToken) {
|
|
end = len(originalToken)
|
|
}
|
|
chunks = append(chunks, originalToken[i:end])
|
|
}
|
|
|
|
var reassembled strings.Builder
|
|
for _, chunk := range chunks {
|
|
reassembled.WriteString(chunk)
|
|
}
|
|
|
|
if reassembled.String() != originalToken {
|
|
t.Error("Token reassembly failed")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestTokenCompression(t *testing.T) {
|
|
t.Run("CompressionEfficiency", func(t *testing.T) {
|
|
repetitiveToken := strings.Repeat("AAAA", 1000)
|
|
|
|
var compressed bytes.Buffer
|
|
gz := gzip.NewWriter(&compressed)
|
|
_, err := gz.Write([]byte(repetitiveToken))
|
|
if err != nil {
|
|
t.Fatalf("Compression failed: %v", err)
|
|
}
|
|
gz.Close()
|
|
|
|
compressionRatio := float64(len(repetitiveToken)) / float64(compressed.Len())
|
|
t.Logf("Compression ratio: %.2fx (original: %d, compressed: %d)",
|
|
compressionRatio, len(repetitiveToken), compressed.Len())
|
|
|
|
if compressionRatio < 10 {
|
|
t.Error("Expected better compression for repetitive data")
|
|
}
|
|
})
|
|
|
|
t.Run("CompressionDecompression", func(t *testing.T) {
|
|
tokens := []string{
|
|
generateRandomString(100),
|
|
generateRandomString(1000),
|
|
generateRandomString(10000),
|
|
strings.Repeat("A", 5000),
|
|
}
|
|
|
|
for i, token := range tokens {
|
|
t.Run(fmt.Sprintf("Token_%d", i), func(t *testing.T) {
|
|
var compressed bytes.Buffer
|
|
gz := gzip.NewWriter(&compressed)
|
|
_, err := gz.Write([]byte(token))
|
|
if err != nil {
|
|
t.Fatalf("Compression failed: %v", err)
|
|
}
|
|
gz.Close()
|
|
|
|
reader, err := gzip.NewReader(&compressed)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create decompressor: %v", err)
|
|
}
|
|
var decompressed bytes.Buffer
|
|
_, err = decompressed.ReadFrom(reader)
|
|
if err != nil {
|
|
t.Fatalf("Decompression failed: %v", err)
|
|
}
|
|
reader.Close()
|
|
|
|
if decompressed.String() != token {
|
|
t.Error("Token changed after compression/decompression")
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAjaxTokenExpiry(t *testing.T) {
|
|
t.Run("AjaxExpiryDetection", func(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
isAjax bool
|
|
tokenExpired bool
|
|
expectedStatus int
|
|
}{
|
|
{"Regular request, valid token", false, false, http.StatusOK},
|
|
{"Regular request, expired token", false, true, http.StatusFound},
|
|
{"Ajax request, valid token", true, false, http.StatusOK},
|
|
{"Ajax request, expired token", true, true, http.StatusUnauthorized},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "http://example.com", nil)
|
|
if tt.isAjax {
|
|
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
if tt.tokenExpired {
|
|
if tt.isAjax {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
w.Write([]byte(`{"error": "token_expired", "message": "Your session has expired"}`))
|
|
} else {
|
|
w.WriteHeader(http.StatusFound)
|
|
w.Header().Set("Location", "/auth/login")
|
|
}
|
|
} else {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("Success"))
|
|
}
|
|
|
|
if w.Code != tt.expectedStatus {
|
|
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
|
}
|
|
|
|
if tt.isAjax && tt.tokenExpired {
|
|
body := w.Body.String()
|
|
if !strings.Contains(body, "token_expired") {
|
|
t.Error("Expected token_expired error in response")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestTestTokens_CreateValidJWT(t *testing.T) {
|
|
tokens := NewTestTokens()
|
|
jwt := tokens.CreateValidJWT()
|
|
|
|
parts := strings.Split(jwt, ".")
|
|
if len(parts) != 3 {
|
|
t.Errorf("Expected 3 JWT parts, got %d", len(parts))
|
|
}
|
|
|
|
headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
|
|
if err != nil {
|
|
t.Fatalf("Failed to decode header: %v", err)
|
|
}
|
|
|
|
var header map[string]interface{}
|
|
if err := json.Unmarshal(headerJSON, &header); err != nil {
|
|
t.Fatalf("Failed to parse header: %v", err)
|
|
}
|
|
|
|
if header["alg"] != "RS256" {
|
|
t.Errorf("Expected RS256 algorithm, got %v", header["alg"])
|
|
}
|
|
}
|
|
|
|
func TestTestTokens_CreateLargeValidJWT(t *testing.T) {
|
|
tokens := NewTestTokens()
|
|
sizes := []int{10, 100, 1000}
|
|
|
|
for _, size := range sizes {
|
|
t.Run(fmt.Sprintf("Size_%d", size), func(t *testing.T) {
|
|
jwt := tokens.CreateLargeValidJWT(size)
|
|
|
|
parts := strings.Split(jwt, ".")
|
|
if len(parts) != 3 {
|
|
t.Errorf("Expected 3 JWT parts, got %d", len(parts))
|
|
}
|
|
|
|
minExpectedSize := size + 200
|
|
if len(jwt) < minExpectedSize {
|
|
t.Errorf("JWT seems too small for requested claim size: got %d, expected at least %d", len(jwt), minExpectedSize)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTestTokens_CreateExpiredJWT(t *testing.T) {
|
|
tokens := NewTestTokens()
|
|
jwt := tokens.CreateExpiredJWT()
|
|
|
|
parts := strings.Split(jwt, ".")
|
|
if len(parts) != 3 {
|
|
t.Errorf("Expected 3 JWT parts, got %d", len(parts))
|
|
}
|
|
|
|
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
if err != nil {
|
|
t.Fatalf("Failed to decode payload: %v", err)
|
|
}
|
|
|
|
var payload map[string]interface{}
|
|
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
|
|
t.Fatalf("Failed to parse payload: %v", err)
|
|
}
|
|
|
|
exp, ok := payload["exp"].(float64)
|
|
if !ok {
|
|
t.Fatal("Expected exp claim in payload")
|
|
}
|
|
|
|
if exp >= float64(time.Now().Unix()) {
|
|
t.Error("Token should be expired")
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// HELPER FUNCTIONS
|
|
// =============================================================================
|
|
|
|
// Mock implementations for testing
|
|
type MockJWTVerifier struct {
|
|
valid bool
|
|
}
|
|
|
|
func (v *MockJWTVerifier) Verify(token string) error {
|
|
if !v.valid {
|
|
return fmt.Errorf("invalid token")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func equalSlices(a, b []string) bool {
|
|
if len(a) != len(b) {
|
|
return false
|
|
}
|
|
for i, v := range a {
|
|
if v != b[i] {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func createTokenOfSize(baseToken string, targetSize int) string {
|
|
if targetSize > 1000 {
|
|
testTokens := NewTestTokens()
|
|
claimSize := targetSize - 230
|
|
if claimSize < 0 {
|
|
claimSize = 10
|
|
}
|
|
return testTokens.CreateLargeValidJWT(claimSize)
|
|
}
|
|
|
|
return baseToken
|
|
}
|
|
|
|
func createTestJWTSimple(claims map[string]interface{}) string {
|
|
header := map[string]interface{}{
|
|
"alg": "HS256",
|
|
"typ": "JWT",
|
|
}
|
|
|
|
headerJSON, _ := json.Marshal(header)
|
|
claimsJSON, _ := json.Marshal(claims)
|
|
|
|
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
|
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
|
signature := base64.RawURLEncoding.EncodeToString([]byte("fake_signature"))
|
|
|
|
return headerB64 + "." + claimsB64 + "." + signature
|
|
}
|