Files
traefikoidc/google_session_test.go
T

600 lines
20 KiB
Go

package traefikoidc
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"fmt"
"math/big"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"golang.org/x/time/rate"
)
// MockTokenVerifier implements the TokenVerifier interface for testing
type MockTokenVerifier struct {
VerifyFunc func(token string) error
}
// MockJWTVerifier implements the JWTVerifier interface for testing
type MockJWTVerifier struct {
VerifyJWTFunc func(jwt *JWT, token string) error
}
func (m *MockJWTVerifier) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
if m.VerifyJWTFunc != nil {
return m.VerifyJWTFunc(jwt, token)
}
return nil
}
func (m *MockTokenVerifier) VerifyToken(token string) error {
if m.VerifyFunc != nil {
return m.VerifyFunc(token)
}
return nil
}
func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
// Create a mocked TraefikOidc instance that simulates Google provider behavior
mockLogger := NewLogger("debug")
// Create a test instance with a Google-like issuer URL
tOidc := &TraefikOidc{
issuerURL: "https://accounts.google.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
logger: mockLogger,
scopes: []string{"openid", "profile", "email"},
refreshGracePeriod: 60,
}
// Create a session manager
sessionManager, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, mockLogger)
tOidc.sessionManager = sessionManager
t.Run("Google provider detection adds required parameters", func(t *testing.T) {
// Test buildAuthURL to ensure it adds access_type=offline and prompt=consent for Google
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
// Check that access_type=offline was added (not offline_access scope for Google)
if !strings.Contains(authURL, "access_type=offline") {
t.Errorf("access_type=offline not added to Google auth URL: %s", authURL)
}
// Verify offline_access scope is NOT included for Google providers
if strings.Contains(authURL, "offline_access") {
t.Errorf("offline_access scope incorrectly added to Google auth URL: %s", authURL)
}
// Check that prompt=consent was added
if !strings.Contains(authURL, "prompt=consent") {
t.Errorf("prompt=consent not added to Google auth URL: %s", authURL)
}
})
t.Run("Non-Google provider doesn't add Google-specific params", func(t *testing.T) {
// Create a test instance with a non-Google issuer URL
nonGoogleOidc := &TraefikOidc{
issuerURL: "https://auth.example.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
logger: mockLogger,
scopes: []string{"openid", "profile", "email"},
}
// Test buildAuthURL without Google-specific parameters
authURL := nonGoogleOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
// Check that prompt=consent is not automatically added
if strings.Contains(authURL, "prompt=consent") {
t.Errorf("prompt=consent added to non-Google auth URL: %s", authURL)
}
})
t.Run("Session refresh with Google provider", func(t *testing.T) {
// Create a request and response recorder
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
// Create a session and set a refresh token
session, _ := sessionManager.GetSession(req)
session.SetAuthenticated(true)
session.SetEmail("test@example.com")
session.SetAccessToken("old-access-token")
session.SetRefreshToken("valid-refresh-token")
// Create a mock token exchanger that simulates Google's behavior
mockTokenExchanger := &MockTokenExchanger{
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
// Check that the refresh token is passed correctly
if refreshToken != "valid-refresh-token" {
t.Errorf("Incorrect refresh token passed: %s", refreshToken)
return nil, fmt.Errorf("invalid token")
}
// Return a simulated Google token response with a new access token
// but without a new refresh token (Google doesn't always return a new refresh token)
return &TokenResponse{
IDToken: "new-id-token-from-google",
AccessToken: "new-access-token-from-google",
RefreshToken: "", // Google often doesn't return a new refresh token
ExpiresIn: 3600,
}, nil
},
}
// Set the mock token exchanger
tOidc.tokenExchanger = mockTokenExchanger
// Create a struct that implements the TokenVerifier interface
tOidc.tokenVerifier = &MockTokenVerifier{
VerifyFunc: func(token string) error {
return nil
},
}
tOidc.extractClaimsFunc = func(token string) (map[string]interface{}, error) {
// Return mock claims
return map[string]interface{}{
"email": "test@example.com",
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
}, nil
}
// Attempt to refresh the token
refreshed := tOidc.refreshToken(rw, req, session)
// Verify the refresh was successful
if !refreshed {
t.Error("Token refresh failed for Google provider")
}
// Check that we kept the original refresh token since Google didn't provide a new one
if session.GetRefreshToken() != "valid-refresh-token" {
t.Errorf("Original refresh token not preserved: got %s, expected 'valid-refresh-token'",
session.GetRefreshToken())
}
// Check that the access token was updated
if session.GetAccessToken() != "new-id-token-from-google" {
t.Errorf("Access token not updated: got %s, expected 'new-id-token-from-google'",
session.GetAccessToken())
}
})
// Test that our fix specifically addresses the reported Google error
t.Run("Google provider handles offline access correctly", func(t *testing.T) {
// Build the auth URL with Google provider detection
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
// Parse the URL to examine its parameters
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
params := parsedURL.Query()
// Verify that access_type=offline is set (Google's way of requesting refresh tokens)
if params.Get("access_type") != "offline" {
t.Errorf("access_type=offline not set in Google auth URL")
}
// Verify that the scope parameter doesn't contain offline_access
// (which Google reports as invalid: {invalid=[offline_access]})
scope := params.Get("scope")
if strings.Contains(scope, "offline_access") {
t.Errorf("offline_access incorrectly included in scope for Google provider: %s", scope)
}
// Verify that the necessary scopes are still included
for _, requiredScope := range []string{"openid", "profile", "email"} {
if !strings.Contains(scope, requiredScope) {
t.Errorf("Required scope '%s' missing from auth URL", requiredScope)
}
}
})
// Enhanced test for verifying non-Google provider includes offline_access scope
t.Run("Non-Google provider includes offline_access scope", func(t *testing.T) {
// Create a test instance with a non-Google issuer URL
nonGoogleOidc := &TraefikOidc{
issuerURL: "https://auth.example.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
logger: mockLogger,
scopes: []string{"openid", "profile", "email"},
}
// Test buildAuthURL for a non-Google provider
authURL := nonGoogleOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
// Parse the URL to examine its parameters
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
params := parsedURL.Query()
// Verify that access_type=offline is NOT set for non-Google providers
if params.Get("access_type") == "offline" {
t.Errorf("access_type=offline incorrectly added to non-Google auth URL")
}
// Verify that offline_access scope IS included for non-Google providers
scope := params.Get("scope")
if !strings.Contains(scope, "offline_access") {
t.Errorf("offline_access scope missing from non-Google auth URL scope: %s", scope)
}
// Verify that the necessary scopes are still included
for _, requiredScope := range []string{"openid", "profile", "email"} {
if !strings.Contains(scope, requiredScope) {
t.Errorf("Required scope '%s' missing from non-Google auth URL", requiredScope)
}
}
})
// Additional test for complete URL construction for Google provider
t.Run("Complete Google auth URL construction", func(t *testing.T) {
// Build the auth URL with additional parameters
redirectURL := "https://example.com/callback"
state := "state123"
nonce := "nonce123"
codeChallenge := "code_challenge_value" // For PKCE
// Enable PKCE for this test
tOidc.enablePKCE = true
// Build auth URL
authURL := tOidc.buildAuthURL(redirectURL, state, nonce, codeChallenge)
// Parse the URL to examine its structure and parameters
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
// Verify the base URL
expectedBaseURL := "https://accounts.google.com/o/oauth2/v2/auth"
if !strings.HasPrefix(authURL, expectedBaseURL) && !strings.Contains(authURL, "accounts.google.com") {
t.Errorf("Auth URL doesn't start with expected Google OAuth endpoint: %s", authURL)
}
// Check all required parameters
params := parsedURL.Query()
expectedParams := map[string]string{
"client_id": "test-client-id",
"response_type": "code",
"redirect_uri": redirectURL,
"state": state,
"nonce": nonce,
"access_type": "offline",
"prompt": "consent",
}
// Also check PKCE parameters if enabled
if tOidc.enablePKCE {
expectedParams["code_challenge"] = codeChallenge
expectedParams["code_challenge_method"] = "S256"
}
for key, expectedValue := range expectedParams {
if value := params.Get(key); value != expectedValue {
t.Errorf("Parameter %s has incorrect value. Expected: %s, Got: %s",
key, expectedValue, value)
}
}
// Verify scope parameter separately due to it being space-separated values
scope := params.Get("scope")
if scope == "" {
t.Error("Scope parameter missing from Google auth URL")
}
// Check that all required scopes are present
scopeList := strings.Split(scope, " ")
expectedScopes := []string{"openid", "profile", "email"}
for _, expectedScope := range expectedScopes {
found := false
for _, actualScope := range scopeList {
if actualScope == expectedScope {
found = true
break
}
}
if !found {
t.Errorf("Expected scope '%s' not found in scope parameter: %s", expectedScope, scope)
}
}
// Verify offline_access is NOT in the scope list
for _, actualScope := range scopeList {
if actualScope == "offline_access" {
t.Errorf("offline_access scope incorrectly included in Google auth URL: %s", scope)
}
}
})
// Integration test with mocked Google provider
t.Run("Integration test with mocked Google provider", func(t *testing.T) {
// Generate an RSA key for signing the test JWTs
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
// Create JWK for the RSA public key
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPrivateKey.PublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(rsaPrivateKey.PublicKey.E)))),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
// Create a mock JWK cache
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
// Create a complete test instance with all required fields
mockLogger := NewLogger("debug")
googleTOidc := &TraefikOidc{
issuerURL: "https://accounts.google.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
logger: mockLogger,
scopes: []string{"openid", "profile", "email"},
refreshGracePeriod: 60,
tokenCache: NewTokenCache(), // Initialize tokenCache
tokenBlacklist: NewCache(), // Initialize tokenBlacklist
enablePKCE: false,
limiter: rate.NewLimiter(rate.Inf, 0), // No rate limiting for tests
jwkCache: mockJWKCache,
jwksURL: "https://accounts.google.com/jwks",
}
// Create a session manager
sessionManager, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, mockLogger)
googleTOidc.sessionManager = sessionManager
// Create a mock token verifier
mockTokenVerifier := &MockTokenVerifier{
VerifyFunc: func(token string) error {
return nil // Always verify successfully for this test
},
}
googleTOidc.tokenVerifier = mockTokenVerifier
// Create JWT tokens for the test
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
// Create initial ID token
initialIDToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://accounts.google.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "nonce123", // For initial authentication verification
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create test ID token: %v", err)
}
// Create refresh ID token
refreshedIDToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://accounts.google.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create refreshed test ID token: %v", err)
}
// Set up token verifier with mock
googleTOidc.tokenVerifier = &MockTokenVerifier{
VerifyFunc: func(token string) error {
return nil // Always verify successfully for this test
},
}
// Set up JWT verifier with mock
googleTOidc.jwtVerifier = &MockJWTVerifier{
VerifyJWTFunc: func(jwt *JWT, token string) error {
return nil // Always verify successfully for this test
},
}
// Create a mock token exchanger that simulates Google's OAuth behavior
mockTokenExchanger := &MockTokenExchanger{
ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
// Verify the correct parameters are passed
if grantType != "authorization_code" {
t.Errorf("Expected grant_type=authorization_code, got %s", grantType)
}
if codeOrToken != "test_auth_code" {
t.Errorf("Expected code=test_auth_code, got %s", codeOrToken)
}
if redirectURL != "https://example.com/callback" {
t.Errorf("Expected redirect_uri=https://example.com/callback, got %s", redirectURL)
}
// Return a successful token response with a proper JWT
return &TokenResponse{
IDToken: initialIDToken,
AccessToken: "google_access_token",
RefreshToken: "google_refresh_token",
ExpiresIn: 3600,
}, nil
},
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
// Verify the correct refresh token is passed
if refreshToken != "google_refresh_token" {
t.Errorf("Expected refresh_token=google_refresh_token, got %s", refreshToken)
}
// Return a successful refresh response with a proper JWT
return &TokenResponse{
IDToken: refreshedIDToken,
AccessToken: "new_google_access_token",
RefreshToken: "", // Google doesn't always return a new refresh token
ExpiresIn: 3600,
}, nil
},
}
googleTOidc.tokenExchanger = mockTokenExchanger
// Use the real extractClaimsFunc to parse the proper JWT tokens
googleTOidc.extractClaimsFunc = extractClaims
// 1. Test building the authorization URL
authURL := googleTOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
// Verify Google-specific parameters
if !strings.Contains(authURL, "access_type=offline") {
t.Errorf("Google auth URL missing access_type=offline: %s", authURL)
}
if !strings.Contains(authURL, "prompt=consent") {
t.Errorf("Google auth URL missing prompt=consent: %s", authURL)
}
if strings.Contains(authURL, "offline_access") {
t.Errorf("Google auth URL incorrectly includes offline_access scope: %s", authURL)
}
// 2. Test handling the callback and token exchange
// Create a request and response recorder for the callback
req := httptest.NewRequest("GET", "/callback?code=test_auth_code&state=state123", nil)
rw := httptest.NewRecorder()
// Create a session and set the necessary values
session, _ := googleTOidc.sessionManager.GetSession(req)
session.SetCSRF("state123") // Must match the state parameter
session.SetNonce("nonce123")
// Save the session to the request
if err := session.Save(req, rw); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Get cookies from the response and add them to a new request
cookies := rw.Result().Cookies()
callbackReq := httptest.NewRequest("GET", "/callback?code=test_auth_code&state=state123", nil)
for _, cookie := range cookies {
callbackReq.AddCookie(cookie)
}
callbackRw := httptest.NewRecorder()
// Handle the callback
googleTOidc.handleCallback(callbackRw, callbackReq, "https://example.com/callback")
// Verify the response is a redirect (302 Found)
if callbackRw.Code != 302 {
t.Errorf("Expected 302 redirect, got %d", callbackRw.Code)
}
// Create a new request to get the updated session
newReq := httptest.NewRequest("GET", "/", nil)
for _, cookie := range callbackRw.Result().Cookies() {
newReq.AddCookie(cookie)
}
// Get the updated session
newSession, err := googleTOidc.sessionManager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get session after callback: %v", err)
}
// Verify the session contains the expected values
if !newSession.GetAuthenticated() {
t.Error("Session not marked as authenticated after callback")
}
if newSession.GetEmail() != "user@example.com" {
t.Errorf("Session email incorrect: got %s, expected user@example.com",
newSession.GetEmail())
}
// Check for non-empty access token that can be parsed as JWT
accessToken := newSession.GetAccessToken()
if accessToken == "" {
t.Error("Session access token is empty")
} else {
claims, err := extractClaims(accessToken)
if err != nil {
t.Errorf("Failed to parse access token as JWT: %v", err)
} else if email, ok := claims["email"].(string); !ok || email != "user@example.com" {
t.Errorf("Access token JWT doesn't contain expected email claim")
}
}
// Check refresh token
if newSession.GetRefreshToken() != "google_refresh_token" {
t.Errorf("Session refresh token incorrect: got %s, expected google_refresh_token",
newSession.GetRefreshToken())
}
// 3. Test token refresh
refreshReq := httptest.NewRequest("GET", "/", nil)
for _, cookie := range callbackRw.Result().Cookies() {
refreshReq.AddCookie(cookie)
}
refreshRw := httptest.NewRecorder()
// Get the session for refresh
refreshSession, _ := googleTOidc.sessionManager.GetSession(refreshReq)
// Refresh the token
refreshed := googleTOidc.refreshToken(refreshRw, refreshReq, refreshSession)
// Verify refresh was successful
if !refreshed {
t.Error("Token refresh failed")
}
// Verify the session data after refresh
// Check for non-empty refreshed access token that can be parsed as JWT
refreshedAccessToken := refreshSession.GetAccessToken()
if refreshedAccessToken == "" {
t.Error("Session access token is empty after refresh")
} else {
claims, err := extractClaims(refreshedAccessToken)
if err != nil {
t.Errorf("Failed to parse refreshed access token as JWT: %v", err)
} else if email, ok := claims["email"].(string); !ok || email != "user@example.com" {
t.Errorf("Refreshed access token JWT doesn't contain expected email claim")
}
}
// Since Google didn't return a new refresh token, the original should be preserved
if refreshSession.GetRefreshToken() != "google_refresh_token" {
t.Errorf("Original refresh token not preserved: got %s, expected google_refresh_token",
refreshSession.GetRefreshToken())
}
})
}
// No need to redefine MockTokenExchanger - it's already defined in main_test.go