mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Fix: Wrong IdToken passed when AccessToken was configured
This commit is contained in:
+11
-6
@@ -161,9 +161,14 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
session.GetRefreshToken())
|
||||
}
|
||||
|
||||
// Check that the access token was updated
|
||||
if session.GetAccessToken() != "new-id-token-from-google" {
|
||||
t.Errorf("Access token not updated: got %s, expected 'new-id-token-from-google'",
|
||||
// Check that the tokens were updated correctly
|
||||
if session.GetIDToken() != "new-id-token-from-google" {
|
||||
t.Errorf("ID token not updated: got %s, expected 'new-id-token-from-google'",
|
||||
session.GetIDToken())
|
||||
}
|
||||
|
||||
if session.GetAccessToken() != "new-access-token-from-google" {
|
||||
t.Errorf("Access token not updated: got %s, expected 'new-access-token-from-google'",
|
||||
session.GetAccessToken())
|
||||
}
|
||||
})
|
||||
@@ -445,7 +450,7 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
// Return a successful token response with a proper JWT
|
||||
return &TokenResponse{
|
||||
IDToken: initialIDToken,
|
||||
AccessToken: "google_access_token",
|
||||
AccessToken: initialIDToken, // Use a valid JWT as the access token too
|
||||
RefreshToken: "google_refresh_token",
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
@@ -459,8 +464,8 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
// Return a successful refresh response with a proper JWT
|
||||
return &TokenResponse{
|
||||
IDToken: refreshedIDToken,
|
||||
AccessToken: "new_google_access_token",
|
||||
RefreshToken: "", // Google doesn't always return a new refresh token
|
||||
AccessToken: refreshedIDToken, // Use a valid JWT as the access token
|
||||
RefreshToken: "", // Google doesn't always return a new refresh token
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
},
|
||||
|
||||
@@ -769,7 +769,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
return
|
||||
}
|
||||
|
||||
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
|
||||
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken()) // Using the actual access token
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
// Continue without group/role headers if extraction fails
|
||||
@@ -805,7 +805,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
// Set OIDC-specific headers
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetAccessToken(); idToken != "" {
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
}
|
||||
|
||||
@@ -826,8 +826,8 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
RefreshToken string
|
||||
Claims map[string]interface{}
|
||||
}{
|
||||
AccessToken: accessToken,
|
||||
IdToken: accessToken, // Using access token as ID token
|
||||
AccessToken: session.GetAccessToken(),
|
||||
IdToken: session.GetIDToken(),
|
||||
RefreshToken: refreshToken,
|
||||
Claims: claims,
|
||||
}
|
||||
@@ -887,6 +887,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
|
||||
t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.")
|
||||
// Clear authentication data but preserve CSRF state if possible (though Clear might remove it)
|
||||
session.SetAuthenticated(false)
|
||||
session.SetIDToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
@@ -983,7 +984,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Verify tokens and claims
|
||||
// Verify ID token and claims
|
||||
if err := t.VerifyToken(tokenResponse.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token during callback: %v", err)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
|
||||
@@ -1039,7 +1040,8 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetAccessToken(tokenResponse.IDToken)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
// Clear CSRF, Nonce, CodeVerifier after use
|
||||
@@ -1569,13 +1571,13 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify the new access token (ID token)
|
||||
// Verify the new ID token
|
||||
if err := t.verifyToken(newToken.IDToken); err != nil {
|
||||
truncatedNewToken := newToken.IDToken
|
||||
truncatedToken := newToken.IDToken
|
||||
if len(newToken.IDToken) > 10 {
|
||||
truncatedNewToken = newToken.IDToken[:10]
|
||||
truncatedToken = newToken.IDToken[:10]
|
||||
}
|
||||
t.logger.Errorf("refreshToken failed: Failed to verify newly obtained ID token starting with %s...: %v", truncatedNewToken, err)
|
||||
t.logger.Errorf("refreshToken failed: Failed to verify newly obtained ID token starting with %s...: %v", truncatedToken, err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1614,8 +1616,9 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
t.logger.Debugf("New token expires at: %v (in %v)", expiryTime, time.Until(expiryTime))
|
||||
}
|
||||
|
||||
// Set the new access token
|
||||
session.SetAccessToken(newToken.IDToken)
|
||||
// Set the new tokens
|
||||
session.SetIDToken(newToken.IDToken)
|
||||
session.SetAccessToken(newToken.AccessToken)
|
||||
|
||||
// Handle the refresh token
|
||||
if newToken.RefreshToken != "" {
|
||||
|
||||
+36
@@ -757,3 +757,39 @@ func (sd *SessionData) GetIncomingPath() string {
|
||||
func (sd *SessionData) SetIncomingPath(path string) {
|
||||
sd.mainSession.Values["incoming_path"] = path
|
||||
}
|
||||
|
||||
// GetIDToken retrieves the ID token stored in the session.
|
||||
// It handles reassembling the token from multiple cookie chunks if necessary
|
||||
// and decompresses it if it was stored compressed.
|
||||
//
|
||||
// Returns:
|
||||
// - The complete, decompressed ID token string, or an empty string if not found.
|
||||
func (sd *SessionData) GetIDToken() string {
|
||||
token, _ := sd.mainSession.Values["id_token"].(string)
|
||||
if token != "" {
|
||||
compressed, _ := sd.mainSession.Values["id_token_compressed"].(bool)
|
||||
if compressed {
|
||||
return decompressToken(token)
|
||||
}
|
||||
return token
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetIDToken stores the provided ID token in the session.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The ID token string to store.
|
||||
func (sd *SessionData) SetIDToken(token string) {
|
||||
if token == "" {
|
||||
sd.mainSession.Values["id_token"] = ""
|
||||
sd.mainSession.Values["id_token_compressed"] = false
|
||||
return
|
||||
}
|
||||
|
||||
// Compress token
|
||||
compressed := compressToken(token)
|
||||
|
||||
sd.mainSession.Values["id_token"] = compressed
|
||||
sd.mainSession.Values["id_token_compressed"] = true
|
||||
}
|
||||
|
||||
@@ -192,14 +192,14 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
expectedValue string
|
||||
}{
|
||||
{
|
||||
name: "Access and ID token identity",
|
||||
name: "Access and ID token distinction",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
|
||||
data: templateData{
|
||||
AccessToken: "access-token",
|
||||
IdToken: "access-token", // Same as AccessToken in processAuthorizedRequest
|
||||
AccessToken: "access-token-value",
|
||||
IdToken: "id-token-value", // Now these should be distinct values
|
||||
Claims: map[string]interface{}{},
|
||||
},
|
||||
expectedValue: "Access: access-token ID: access-token",
|
||||
expectedValue: "Access: access-token-value ID: id-token-value",
|
||||
},
|
||||
{
|
||||
name: "Combining tokens and claims",
|
||||
|
||||
@@ -66,6 +66,28 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
"Authorization": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ID Token Header",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-ID-Token", Value: "{{.IdToken}}"},
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
// We'll update this dynamically after generating the token
|
||||
"X-ID-Token": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Both Token Types",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-Access-Token", Value: "{{.AccessToken}}"},
|
||||
{Name: "X-ID-Token", Value: "{{.IdToken}}"},
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
// We'll update these dynamically after generating the tokens
|
||||
"X-Access-Token": "",
|
||||
"X-ID-Token": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Missing Claim",
|
||||
headers: []TemplatedHeader{
|
||||
@@ -210,9 +232,78 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
// Setup the session with authentication data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
// For most tests, set the same token for both ID and Access token for backward compatibility
|
||||
session.SetIDToken(token)
|
||||
session.SetAccessToken(token)
|
||||
session.SetRefreshToken("test-refresh-token")
|
||||
|
||||
// For tests specifically testing token distinction, set different tokens
|
||||
if tc.name == "ID Token Header" || tc.name == "Both Token Types" {
|
||||
// Both tokens need to use the same key ID to avoid verification issues
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000),
|
||||
"iat": float64(1000000000),
|
||||
"nbf": float64(1000000000),
|
||||
"sub": "test-subject",
|
||||
"nonce": "test-nonce",
|
||||
"jti": generateRandomString(16),
|
||||
"type": "id_token",
|
||||
"email": "user@example.com", // Use the standard test email directly
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test ID JWT: %v", err)
|
||||
}
|
||||
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000),
|
||||
"iat": float64(1000000000),
|
||||
"nbf": float64(1000000000),
|
||||
"sub": "test-subject",
|
||||
"jti": generateRandomString(16),
|
||||
"type": "access_token",
|
||||
"scope": "openid email profile",
|
||||
"email": "user@example.com", // Include email in access token too
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test access JWT: %v", err)
|
||||
}
|
||||
|
||||
// Create a proper token exchanger that won't cause nil pointer issues
|
||||
tOidc.tokenExchanger = &MockTokenExchanger{
|
||||
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
// Also add a mock token verifier to skip verification
|
||||
tOidc.tokenVerifier = &MockTokenVerifier{
|
||||
VerifyFunc: func(token string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Set both tokens in the session
|
||||
session.SetIDToken(idToken)
|
||||
session.SetAccessToken(accessToken)
|
||||
|
||||
// Update expectedHeaders for the token tests
|
||||
if tc.name == "ID Token Header" {
|
||||
tc.expectedHeaders["X-ID-Token"] = idToken
|
||||
} else if tc.name == "Both Token Types" {
|
||||
tc.expectedHeaders["X-ID-Token"] = idToken
|
||||
tc.expectedHeaders["X-Access-Token"] = accessToken
|
||||
}
|
||||
}
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
@@ -361,9 +452,23 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
|
||||
// Setup the session with authentication data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(token)
|
||||
session.SetIDToken(token) // Use the new method
|
||||
session.SetAccessToken(token) // Also set access token to match
|
||||
session.SetRefreshToken("test-refresh-token")
|
||||
|
||||
// Make sure these properties are set so refreshToken won't panic
|
||||
tOidc.extractClaimsFunc = extractClaims // Ensure claims extraction works
|
||||
tOidc.tokenExchanger = &MockTokenExchanger{ // Add a mock token exchanger
|
||||
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: token,
|
||||
AccessToken: token,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,309 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestTokenTypeDistinction tests that AccessToken and IdToken are correctly distinguished in templates
|
||||
func TestTokenTypeDistinction(t *testing.T) {
|
||||
// Define test data where AccessToken and IdToken are deliberately different
|
||||
type templateData struct {
|
||||
AccessToken string
|
||||
IdToken string
|
||||
RefreshToken string
|
||||
Claims map[string]interface{}
|
||||
}
|
||||
|
||||
testData := templateData{
|
||||
AccessToken: "test-access-token-abc123",
|
||||
IdToken: "test-id-token-xyz789",
|
||||
RefreshToken: "test-refresh-token",
|
||||
Claims: map[string]interface{}{
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
},
|
||||
}
|
||||
|
||||
// Test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
templateText string
|
||||
expectedValue string
|
||||
}{
|
||||
{
|
||||
name: "Access Token Only",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
expectedValue: "Bearer test-access-token-abc123",
|
||||
},
|
||||
{
|
||||
name: "ID Token Only",
|
||||
templateText: "ID: {{.IdToken}}",
|
||||
expectedValue: "ID: test-id-token-xyz789",
|
||||
},
|
||||
{
|
||||
name: "Both Tokens",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
|
||||
expectedValue: "Access: test-access-token-abc123 ID: test-id-token-xyz789",
|
||||
},
|
||||
{
|
||||
name: "Both Tokens in Authorization Format",
|
||||
templateText: "Bearer {{.AccessToken}} and Bearer {{.IdToken}}",
|
||||
expectedValue: "Bearer test-access-token-abc123 and Bearer test-id-token-xyz789",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, testData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute template: %v", err)
|
||||
}
|
||||
|
||||
result := buf.String()
|
||||
if result != tc.expectedValue {
|
||||
t.Errorf("Expected template output %q, got %q", tc.expectedValue, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenTypeIntegration tests the integration of ID and access tokens with the middleware
|
||||
func TestTokenTypeIntegration(t *testing.T) {
|
||||
// Create a TestSuite to use its helper methods and fields
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Create different tokens for ID and access tokens
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000),
|
||||
"iat": float64(1000000000),
|
||||
"nbf": float64(1000000000),
|
||||
"sub": "test-subject",
|
||||
"nonce": "test-nonce",
|
||||
"jti": generateRandomString(16),
|
||||
"token_type": "id_token",
|
||||
"email": "user@example.com",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test ID JWT: %v", err)
|
||||
}
|
||||
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000),
|
||||
"iat": float64(1000000000),
|
||||
"nbf": float64(1000000000),
|
||||
"sub": "test-subject",
|
||||
"jti": generateRandomString(16),
|
||||
"token_type": "access_token",
|
||||
"scope": "openid profile email",
|
||||
"email": "user@example.com", // Add email to access token so it's available in claims
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test access JWT: %v", err)
|
||||
}
|
||||
|
||||
// Define test headers that use both token types
|
||||
headers := []TemplatedHeader{
|
||||
{Name: "X-ID-Token", Value: "{{.IdToken}}"},
|
||||
{Name: "X-Access-Token", Value: "{{.AccessToken}}"},
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
{Name: "X-Email-From-Claims", Value: "{{.Claims.email}}"},
|
||||
}
|
||||
|
||||
// Store intercepted headers for verification
|
||||
interceptedHeaders := make(map[string]string)
|
||||
|
||||
// Create a test next handler that captures the headers
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Capture headers for verification
|
||||
for _, header := range headers {
|
||||
if value := r.Header.Get(header.Name); value != "" {
|
||||
interceptedHeaders[header.Name] = value
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create the TraefikOidc instance
|
||||
tOidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: ts.mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
tokenBlacklist: NewCache(),
|
||||
tokenCache: NewTokenCache(),
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: NewLogger("debug"),
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
excludedURLs: map[string]struct{}{"/favicon": {}},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: ts.sessionManager,
|
||||
extractClaimsFunc: extractClaims,
|
||||
headerTemplates: make(map[string]*template.Template),
|
||||
}
|
||||
|
||||
// Initialize and parse header templates
|
||||
for _, header := range headers {
|
||||
tmpl, err := template.New(header.Name).Parse(header.Value)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse header template for %s: %v", header.Name, err)
|
||||
}
|
||||
tOidc.headerTemplates[header.Name] = tmpl
|
||||
}
|
||||
|
||||
// Close the initComplete channel to bypass the waiting
|
||||
close(tOidc.initComplete)
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "example.com")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Create a session
|
||||
session, err := tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Setup the session with authentication data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetIDToken(idToken) // Set the ID token
|
||||
session.SetAccessToken(accessToken) // Set the access token
|
||||
session.SetRefreshToken("test-refresh-token")
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Add session cookies to the request
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Reset the response recorder for the main test
|
||||
rr = httptest.NewRecorder()
|
||||
|
||||
// Process the request
|
||||
tOidc.ServeHTTP(rr, req)
|
||||
|
||||
// Check status code
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
// Verify headers were set correctly
|
||||
expectedHeaders := map[string]string{
|
||||
"X-ID-Token": idToken,
|
||||
"X-Access-Token": accessToken,
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
"X-Email-From-Claims": "user@example.com",
|
||||
}
|
||||
|
||||
for name, expectedValue := range expectedHeaders {
|
||||
if value, exists := interceptedHeaders[name]; !exists {
|
||||
t.Errorf("Expected header %s was not set", name)
|
||||
} else if value != expectedValue {
|
||||
t.Errorf("Header %s expected value %q, got %q", name, expectedValue, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionIDTokenAccessToken tests that the SessionData correctly stores and retrieves
|
||||
// both ID tokens and access tokens separately
|
||||
func TestSessionIDTokenAccessToken(t *testing.T) {
|
||||
// Create a logger for the session manager
|
||||
logger := NewLogger("debug")
|
||||
|
||||
// Create a session manager
|
||||
sessionManager, err := NewSessionManager("test-session-encryption-key-at-least-32-bytes", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Get a session
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set test tokens
|
||||
idToken := "test-id-token-123"
|
||||
accessToken := "test-access-token-456"
|
||||
refreshToken := "test-refresh-token-789"
|
||||
|
||||
// Store tokens in session
|
||||
session.SetIDToken(idToken)
|
||||
session.SetAccessToken(accessToken)
|
||||
session.SetRefreshToken(refreshToken)
|
||||
|
||||
// Save the session
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get cookies from response
|
||||
cookies := rr.Result().Cookies()
|
||||
|
||||
// Create a new request with those cookies
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range cookies {
|
||||
req2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Get the session again
|
||||
session2, err := sessionManager.GetSession(req2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session from request with cookies: %v", err)
|
||||
}
|
||||
|
||||
// Verify that the tokens were correctly stored and retrieved
|
||||
retrievedIDToken := session2.GetIDToken()
|
||||
retrievedAccessToken := session2.GetAccessToken()
|
||||
retrievedRefreshToken := session2.GetRefreshToken()
|
||||
|
||||
if retrievedIDToken != idToken {
|
||||
t.Errorf("ID token mismatch: expected %q, got %q", idToken, retrievedIDToken)
|
||||
}
|
||||
|
||||
if retrievedAccessToken != accessToken {
|
||||
t.Errorf("Access token mismatch: expected %q, got %q", accessToken, retrievedAccessToken)
|
||||
}
|
||||
|
||||
if retrievedRefreshToken != refreshToken {
|
||||
t.Errorf("Refresh token mismatch: expected %q, got %q", refreshToken, retrievedRefreshToken)
|
||||
}
|
||||
|
||||
// Verify that the tokens are distinct
|
||||
if retrievedIDToken == retrievedAccessToken {
|
||||
t.Errorf("ID token and Access token should be different, but both are %q", retrievedIDToken)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user