Fix: Wrong IdToken passed when AccessToken was configured

This commit is contained in:
2025-05-06 20:21:00 +01:00
parent 2583266738
commit 075476792f
6 changed files with 481 additions and 23 deletions
+11 -6
View File
@@ -161,9 +161,14 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
session.GetRefreshToken())
}
// Check that the access token was updated
if session.GetAccessToken() != "new-id-token-from-google" {
t.Errorf("Access token not updated: got %s, expected 'new-id-token-from-google'",
// Check that the tokens were updated correctly
if session.GetIDToken() != "new-id-token-from-google" {
t.Errorf("ID token not updated: got %s, expected 'new-id-token-from-google'",
session.GetIDToken())
}
if session.GetAccessToken() != "new-access-token-from-google" {
t.Errorf("Access token not updated: got %s, expected 'new-access-token-from-google'",
session.GetAccessToken())
}
})
@@ -445,7 +450,7 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
// Return a successful token response with a proper JWT
return &TokenResponse{
IDToken: initialIDToken,
AccessToken: "google_access_token",
AccessToken: initialIDToken, // Use a valid JWT as the access token too
RefreshToken: "google_refresh_token",
ExpiresIn: 3600,
}, nil
@@ -459,8 +464,8 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
// Return a successful refresh response with a proper JWT
return &TokenResponse{
IDToken: refreshedIDToken,
AccessToken: "new_google_access_token",
RefreshToken: "", // Google doesn't always return a new refresh token
AccessToken: refreshedIDToken, // Use a valid JWT as the access token
RefreshToken: "", // Google doesn't always return a new refresh token
ExpiresIn: 3600,
}, nil
},
+15 -12
View File
@@ -769,7 +769,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
return
}
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken()) // Using the actual access token
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
// Continue without group/role headers if extraction fails
@@ -805,7 +805,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
// Set OIDC-specific headers
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetAccessToken(); idToken != "" {
if idToken := session.GetIDToken(); idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
}
@@ -826,8 +826,8 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
RefreshToken string
Claims map[string]interface{}
}{
AccessToken: accessToken,
IdToken: accessToken, // Using access token as ID token
AccessToken: session.GetAccessToken(),
IdToken: session.GetIDToken(),
RefreshToken: refreshToken,
Claims: claims,
}
@@ -887,6 +887,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.")
// Clear authentication data but preserve CSRF state if possible (though Clear might remove it)
session.SetAuthenticated(false)
session.SetIDToken("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
@@ -983,7 +984,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Verify tokens and claims
// Verify ID token and claims
if err := t.VerifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token during callback: %v", err)
t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
@@ -1039,7 +1040,8 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
session.SetEmail(email)
session.SetAccessToken(tokenResponse.IDToken)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
// Clear CSRF, Nonce, CodeVerifier after use
@@ -1569,13 +1571,13 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return false
}
// Verify the new access token (ID token)
// Verify the new ID token
if err := t.verifyToken(newToken.IDToken); err != nil {
truncatedNewToken := newToken.IDToken
truncatedToken := newToken.IDToken
if len(newToken.IDToken) > 10 {
truncatedNewToken = newToken.IDToken[:10]
truncatedToken = newToken.IDToken[:10]
}
t.logger.Errorf("refreshToken failed: Failed to verify newly obtained ID token starting with %s...: %v", truncatedNewToken, err)
t.logger.Errorf("refreshToken failed: Failed to verify newly obtained ID token starting with %s...: %v", truncatedToken, err)
return false
}
@@ -1614,8 +1616,9 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
t.logger.Debugf("New token expires at: %v (in %v)", expiryTime, time.Until(expiryTime))
}
// Set the new access token
session.SetAccessToken(newToken.IDToken)
// Set the new tokens
session.SetIDToken(newToken.IDToken)
session.SetAccessToken(newToken.AccessToken)
// Handle the refresh token
if newToken.RefreshToken != "" {
+36
View File
@@ -757,3 +757,39 @@ func (sd *SessionData) GetIncomingPath() string {
func (sd *SessionData) SetIncomingPath(path string) {
sd.mainSession.Values["incoming_path"] = path
}
// GetIDToken retrieves the ID token stored in the session.
// It handles reassembling the token from multiple cookie chunks if necessary
// and decompresses it if it was stored compressed.
//
// Returns:
// - The complete, decompressed ID token string, or an empty string if not found.
func (sd *SessionData) GetIDToken() string {
token, _ := sd.mainSession.Values["id_token"].(string)
if token != "" {
compressed, _ := sd.mainSession.Values["id_token_compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
return ""
}
// SetIDToken stores the provided ID token in the session.
//
// Parameters:
// - token: The ID token string to store.
func (sd *SessionData) SetIDToken(token string) {
if token == "" {
sd.mainSession.Values["id_token"] = ""
sd.mainSession.Values["id_token_compressed"] = false
return
}
// Compress token
compressed := compressToken(token)
sd.mainSession.Values["id_token"] = compressed
sd.mainSession.Values["id_token_compressed"] = true
}
+4 -4
View File
@@ -192,14 +192,14 @@ func TestTemplateExecutionContext(t *testing.T) {
expectedValue string
}{
{
name: "Access and ID token identity",
name: "Access and ID token distinction",
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
data: templateData{
AccessToken: "access-token",
IdToken: "access-token", // Same as AccessToken in processAuthorizedRequest
AccessToken: "access-token-value",
IdToken: "id-token-value", // Now these should be distinct values
Claims: map[string]interface{}{},
},
expectedValue: "Access: access-token ID: access-token",
expectedValue: "Access: access-token-value ID: id-token-value",
},
{
name: "Combining tokens and claims",
+106 -1
View File
@@ -66,6 +66,28 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
"Authorization": "",
},
},
{
name: "ID Token Header",
headers: []TemplatedHeader{
{Name: "X-ID-Token", Value: "{{.IdToken}}"},
},
expectedHeaders: map[string]string{
// We'll update this dynamically after generating the token
"X-ID-Token": "",
},
},
{
name: "Both Token Types",
headers: []TemplatedHeader{
{Name: "X-Access-Token", Value: "{{.AccessToken}}"},
{Name: "X-ID-Token", Value: "{{.IdToken}}"},
},
expectedHeaders: map[string]string{
// We'll update these dynamically after generating the tokens
"X-Access-Token": "",
"X-ID-Token": "",
},
},
{
name: "Missing Claim",
headers: []TemplatedHeader{
@@ -210,9 +232,78 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
// Setup the session with authentication data
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
// For most tests, set the same token for both ID and Access token for backward compatibility
session.SetIDToken(token)
session.SetAccessToken(token)
session.SetRefreshToken("test-refresh-token")
// For tests specifically testing token distinction, set different tokens
if tc.name == "ID Token Header" || tc.name == "Both Token Types" {
// Both tokens need to use the same key ID to avoid verification issues
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": float64(3000000000),
"iat": float64(1000000000),
"nbf": float64(1000000000),
"sub": "test-subject",
"nonce": "test-nonce",
"jti": generateRandomString(16),
"type": "id_token",
"email": "user@example.com", // Use the standard test email directly
})
if err != nil {
t.Fatalf("Failed to create test ID JWT: %v", err)
}
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": float64(3000000000),
"iat": float64(1000000000),
"nbf": float64(1000000000),
"sub": "test-subject",
"jti": generateRandomString(16),
"type": "access_token",
"scope": "openid email profile",
"email": "user@example.com", // Include email in access token too
})
if err != nil {
t.Fatalf("Failed to create test access JWT: %v", err)
}
// Create a proper token exchanger that won't cause nil pointer issues
tOidc.tokenExchanger = &MockTokenExchanger{
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: 3600,
}, nil
},
}
// Also add a mock token verifier to skip verification
tOidc.tokenVerifier = &MockTokenVerifier{
VerifyFunc: func(token string) error {
return nil
},
}
// Set both tokens in the session
session.SetIDToken(idToken)
session.SetAccessToken(accessToken)
// Update expectedHeaders for the token tests
if tc.name == "ID Token Header" {
tc.expectedHeaders["X-ID-Token"] = idToken
} else if tc.name == "Both Token Types" {
tc.expectedHeaders["X-ID-Token"] = idToken
tc.expectedHeaders["X-Access-Token"] = accessToken
}
}
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
@@ -361,9 +452,23 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
// Setup the session with authentication data
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(token)
session.SetIDToken(token) // Use the new method
session.SetAccessToken(token) // Also set access token to match
session.SetRefreshToken("test-refresh-token")
// Make sure these properties are set so refreshToken won't panic
tOidc.extractClaimsFunc = extractClaims // Ensure claims extraction works
tOidc.tokenExchanger = &MockTokenExchanger{ // Add a mock token exchanger
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: token,
AccessToken: token,
RefreshToken: refreshToken,
ExpiresIn: 3600,
}, nil
},
}
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
+309
View File
@@ -0,0 +1,309 @@
package traefikoidc
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"text/template"
"time"
"golang.org/x/time/rate"
)
// TestTokenTypeDistinction tests that AccessToken and IdToken are correctly distinguished in templates
func TestTokenTypeDistinction(t *testing.T) {
// Define test data where AccessToken and IdToken are deliberately different
type templateData struct {
AccessToken string
IdToken string
RefreshToken string
Claims map[string]interface{}
}
testData := templateData{
AccessToken: "test-access-token-abc123",
IdToken: "test-id-token-xyz789",
RefreshToken: "test-refresh-token",
Claims: map[string]interface{}{
"sub": "test-subject",
"email": "user@example.com",
},
}
// Test cases
tests := []struct {
name string
templateText string
expectedValue string
}{
{
name: "Access Token Only",
templateText: "Bearer {{.AccessToken}}",
expectedValue: "Bearer test-access-token-abc123",
},
{
name: "ID Token Only",
templateText: "ID: {{.IdToken}}",
expectedValue: "ID: test-id-token-xyz789",
},
{
name: "Both Tokens",
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
expectedValue: "Access: test-access-token-abc123 ID: test-id-token-xyz789",
},
{
name: "Both Tokens in Authorization Format",
templateText: "Bearer {{.AccessToken}} and Bearer {{.IdToken}}",
expectedValue: "Bearer test-access-token-abc123 and Bearer test-id-token-xyz789",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
var buf bytes.Buffer
err = tmpl.Execute(&buf, testData)
if err != nil {
t.Fatalf("Failed to execute template: %v", err)
}
result := buf.String()
if result != tc.expectedValue {
t.Errorf("Expected template output %q, got %q", tc.expectedValue, result)
}
})
}
}
// TestTokenTypeIntegration tests the integration of ID and access tokens with the middleware
func TestTokenTypeIntegration(t *testing.T) {
// Create a TestSuite to use its helper methods and fields
ts := &TestSuite{t: t}
ts.Setup()
// Create different tokens for ID and access tokens
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": float64(3000000000),
"iat": float64(1000000000),
"nbf": float64(1000000000),
"sub": "test-subject",
"nonce": "test-nonce",
"jti": generateRandomString(16),
"token_type": "id_token",
"email": "user@example.com",
})
if err != nil {
t.Fatalf("Failed to create test ID JWT: %v", err)
}
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": float64(3000000000),
"iat": float64(1000000000),
"nbf": float64(1000000000),
"sub": "test-subject",
"jti": generateRandomString(16),
"token_type": "access_token",
"scope": "openid profile email",
"email": "user@example.com", // Add email to access token so it's available in claims
})
if err != nil {
t.Fatalf("Failed to create test access JWT: %v", err)
}
// Define test headers that use both token types
headers := []TemplatedHeader{
{Name: "X-ID-Token", Value: "{{.IdToken}}"},
{Name: "X-Access-Token", Value: "{{.AccessToken}}"},
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
{Name: "X-Email-From-Claims", Value: "{{.Claims.email}}"},
}
// Store intercepted headers for verification
interceptedHeaders := make(map[string]string)
// Create a test next handler that captures the headers
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture headers for verification
for _, header := range headers {
if value := r.Header.Get(header.Name); value != "" {
interceptedHeaders[header.Name] = value
}
}
w.WriteHeader(http.StatusOK)
})
// Create the TraefikOidc instance
tOidc := &TraefikOidc{
next: nextHandler,
name: "test",
redirURLPath: "/callback",
logoutURLPath: "/callback/logout",
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
tokenBlacklist: NewCache(),
tokenCache: NewTokenCache(),
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: NewLogger("debug"),
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
initComplete: make(chan struct{}),
sessionManager: ts.sessionManager,
extractClaimsFunc: extractClaims,
headerTemplates: make(map[string]*template.Template),
}
// Initialize and parse header templates
for _, header := range headers {
tmpl, err := template.New(header.Name).Parse(header.Value)
if err != nil {
t.Fatalf("Failed to parse header template for %s: %v", header.Name, err)
}
tOidc.headerTemplates[header.Name] = tmpl
}
// Close the initComplete channel to bypass the waiting
close(tOidc.initComplete)
// Create a test request
req := httptest.NewRequest("GET", "/protected", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "example.com")
rr := httptest.NewRecorder()
// Create a session
session, err := tOidc.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Setup the session with authentication data
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetIDToken(idToken) // Set the ID token
session.SetAccessToken(accessToken) // Set the access token
session.SetRefreshToken("test-refresh-token")
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Add session cookies to the request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset the response recorder for the main test
rr = httptest.NewRecorder()
// Process the request
tOidc.ServeHTTP(rr, req)
// Check status code
if rr.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code)
}
// Verify headers were set correctly
expectedHeaders := map[string]string{
"X-ID-Token": idToken,
"X-Access-Token": accessToken,
"Authorization": "Bearer " + accessToken,
"X-Email-From-Claims": "user@example.com",
}
for name, expectedValue := range expectedHeaders {
if value, exists := interceptedHeaders[name]; !exists {
t.Errorf("Expected header %s was not set", name)
} else if value != expectedValue {
t.Errorf("Header %s expected value %q, got %q", name, expectedValue, value)
}
}
}
// TestSessionIDTokenAccessToken tests that the SessionData correctly stores and retrieves
// both ID tokens and access tokens separately
func TestSessionIDTokenAccessToken(t *testing.T) {
// Create a logger for the session manager
logger := NewLogger("debug")
// Create a session manager
sessionManager, err := NewSessionManager("test-session-encryption-key-at-least-32-bytes", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
// Create a test request
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
// Get a session
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set test tokens
idToken := "test-id-token-123"
accessToken := "test-access-token-456"
refreshToken := "test-refresh-token-789"
// Store tokens in session
session.SetIDToken(idToken)
session.SetAccessToken(accessToken)
session.SetRefreshToken(refreshToken)
// Save the session
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Get cookies from response
cookies := rr.Result().Cookies()
// Create a new request with those cookies
req2 := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
req2.AddCookie(cookie)
}
// Get the session again
session2, err := sessionManager.GetSession(req2)
if err != nil {
t.Fatalf("Failed to get session from request with cookies: %v", err)
}
// Verify that the tokens were correctly stored and retrieved
retrievedIDToken := session2.GetIDToken()
retrievedAccessToken := session2.GetAccessToken()
retrievedRefreshToken := session2.GetRefreshToken()
if retrievedIDToken != idToken {
t.Errorf("ID token mismatch: expected %q, got %q", idToken, retrievedIDToken)
}
if retrievedAccessToken != accessToken {
t.Errorf("Access token mismatch: expected %q, got %q", accessToken, retrievedAccessToken)
}
if retrievedRefreshToken != refreshToken {
t.Errorf("Refresh token mismatch: expected %q, got %q", refreshToken, retrievedRefreshToken)
}
// Verify that the tokens are distinct
if retrievedIDToken == retrievedAccessToken {
t.Errorf("ID token and Access token should be different, but both are %q", retrievedIDToken)
}
}