Fix OIDC logout issue, improve test coverage, load provider once.

This commit is contained in:
2024-11-06 10:28:35 +00:00
parent 555164160d
commit 8ca669105b
4 changed files with 616 additions and 70 deletions
+55 -28
View File
@@ -97,47 +97,74 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
return tokenResponse, nil
}
// handleLogout handles the user logout
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
t.logger.Debugf("Logging out user")
// handleLogout handles the logout process
func (t *TraefikOidc) handleLogout(w http.ResponseWriter, r *http.Request) {
session, err := t.store.Get(r, cookieName)
if err != nil {
handleError(rw, "Session error", http.StatusInternalServerError, t.logger)
handleError(w, fmt.Sprintf("Error getting session: %v", err), http.StatusInternalServerError, t.logger)
return
}
// Revoke tokens if available
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
if err := t.RevokeTokenWithProvider(refreshToken, "refresh_token"); err != nil {
t.logger.Errorf("Failed to revoke refresh token: %v", err)
}
// Get tokens from session
idToken, _ := session.Values["id_token"].(string)
refreshToken, _ := session.Values["refresh_token"].(string)
accessToken, _ := session.Values["access_token"].(string)
// Revoke tokens if they exist
if refreshToken != "" {
t.RevokeTokenWithProvider(refreshToken, "refresh_token")
t.RevokeToken(refreshToken)
}
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
if err := t.RevokeTokenWithProvider(accessToken, "access_token"); err != nil {
t.logger.Errorf("Failed to revoke access token: %v", err)
}
if accessToken != "" {
t.RevokeTokenWithProvider(accessToken, "access_token")
t.RevokeToken(accessToken)
}
// Remove tokens from session
delete(session.Values, "id_token")
delete(session.Values, "refresh_token")
delete(session.Values, "access_token")
delete(session.Values, "authenticated")
// Set session options to delete the session
session.Options = defaultSessionOptions
// Clear session
session.Options.MaxAge = -1
if err := session.Save(req, rw); err != nil {
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
session.Values = make(map[interface{}]interface{})
if err := session.Save(r, w); err != nil {
handleError(w, fmt.Sprintf("Error saving session: %v", err), http.StatusInternalServerError, t.logger)
return
}
// Redirect or display logout message
rw.WriteHeader(http.StatusOK)
rw.Write([]byte("Logged out successfully"))
// Determine redirect URL
host := r.Header.Get("X-Forwarded-Host")
if host == "" {
host = r.Host
}
scheme := "http"
if r.Header.Get("X-Forwarded-Proto") == "https" || t.forceHTTPS {
scheme = "https"
}
baseURL := fmt.Sprintf("%s://%s/", scheme, host)
if t.endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, baseURL)
if err != nil {
handleError(w, fmt.Sprintf("Invalid end session URL: %v", err), http.StatusInternalServerError, t.logger)
return
}
http.Redirect(w, r, logoutURL, http.StatusFound)
return
}
http.Redirect(w, r, baseURL, http.StatusFound)
}
// BuildLogoutURL constructs the logout URL with proper encoding
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
u, err := url.Parse(endSessionURL)
if err != nil {
return "", fmt.Errorf("invalid end session URL: %v", err)
}
q := u.Query()
q.Set("id_token_hint", idToken)
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
u.RawQuery = q.Encode()
return u.String(), nil
}
// handleExpiredToken handles the case when a token has expired
+57 -42
View File
@@ -62,17 +62,19 @@ type TraefikOidc struct {
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
exchangeCodeForTokenFunc func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initOnce sync.Once
initComplete chan struct{}
endSessionURL string
baseURL string
}
// ProviderMetadata holds OIDC provider metadata
type ProviderMetadata struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
EndSessionURL string `json:"end_session_endpoint"`
}
// defaultExcludedURLs are the paths that are excluded from authentication
@@ -82,6 +84,14 @@ var defaultExcludedURLs = map[string]struct{}{
var newTicker = time.NewTicker
var (
globalMetadataCache struct {
sync.Once
metadata *ProviderMetadata
err error
}
)
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token")
@@ -254,20 +264,26 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.initOnce.Do(func() {
globalMetadataCache.Once.Do(func() {
t.logger.Debug("Starting global provider metadata discovery")
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v", err)
} else {
t.logger.Debug("Provider metadata discovered successfully")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
}
close(t.initComplete)
globalMetadataCache.metadata = metadata
globalMetadataCache.err = err
})
if globalMetadataCache.err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v", globalMetadataCache.err)
} else if globalMetadataCache.metadata != nil {
t.logger.Debug("Using cached provider metadata")
t.jwksURL = globalMetadataCache.metadata.JWKSURL
t.authURL = globalMetadataCache.metadata.AuthURL
t.tokenURL = globalMetadataCache.metadata.TokenURL
t.issuerURL = globalMetadataCache.metadata.Issuer
t.revocationURL = globalMetadataCache.metadata.RevokeURL
t.endSessionURL = globalMetadataCache.metadata.EndSessionURL
}
close(t.initComplete)
}
// discoverProviderMetadata fetches the OIDC provider metadata
@@ -620,14 +636,9 @@ func (t *TraefikOidc) RevokeToken(token string) {
// Remove from cache
t.tokenCache.Delete(token)
// Add to blacklist
claims, err := extractClaims(token)
if err == nil {
if exp, ok := claims["exp"].(float64); ok {
expTime := time.Unix(int64(exp), 0)
t.tokenBlacklist.Add(token, expTime)
}
}
// Add to blacklist with default expiration
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
t.tokenBlacklist.Add(token, expiry)
}
// RevokeTokenWithProvider revokes the token with the provider
@@ -726,26 +737,30 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
var groups []string
var roles []string
// Check for groups claim
if groupsClaim, ok := claims["groups"]; ok {
if groupsSlice, ok := groupsClaim.([]interface{}); ok {
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
}
// Extract groups with type checking
if groupsClaim, exists := claims["groups"]; exists {
groupsSlice, ok := groupsClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("groups claim is not an array")
}
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
}
}
}
// Check for roles claim
if rolesClaim, ok := claims["roles"]; ok {
if rolesSlice, ok := rolesClaim.([]interface{}); ok {
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
}
// Extract roles with type checking
if rolesClaim, exists := claims["roles"]; exists {
rolesSlice, ok := rolesClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("roles claim is not an array")
}
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
}
}
}
+503
View File
@@ -820,3 +820,506 @@ func TestOIDCHandler(t *testing.T) {
})
}
}
// TestHandleLogout tests the logout functionality
func TestHandleLogout(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Create mock revocation endpoint server
mockRevocationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Expected POST request, got %s", r.Method)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify the required parameters are present
if r.Form.Get("token") == "" {
t.Error("Missing token parameter")
}
if r.Form.Get("token_type_hint") == "" {
t.Error("Missing token_type_hint parameter")
}
w.WriteHeader(http.StatusOK)
}))
defer mockRevocationServer.Close()
tests := []struct {
name string
setupSession func(*sessions.Session)
endSessionURL string
expectedStatus int
expectedURL string
host string
}{
{
name: "Successful logout with end session endpoint",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
session.Values["refresh_token"] = "test-refresh-token"
session.Values["access_token"] = "test-access-token"
},
endSessionURL: "https://provider/end-session",
expectedStatus: http.StatusFound,
// Fix: The entire URL should be URL-encoded
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
host: "test-host",
},
{
name: "Successful logout without end session endpoint",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
session.Values["refresh_token"] = "test-refresh-token"
session.Values["access_token"] = "test-access-token"
},
endSessionURL: "",
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
host: "test-host",
},
{
name: "Logout with empty session",
setupSession: func(session *sessions.Session) {},
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
host: "test-host",
},
{
name: "Logout with invalid end session URL",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
},
endSessionURL: ":\\invalid-url",
expectedStatus: http.StatusInternalServerError,
host: "test-host",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a new TraefikOidc instance for each test
tOidc := &TraefikOidc{
store: sessions.NewCookieStore([]byte("test-secret-key")),
revocationURL: mockRevocationServer.URL,
endSessionURL: tc.endSessionURL,
scheme: "http",
logger: NewLogger("info"),
tokenBlacklist: NewTokenBlacklist(),
httpClient: &http.Client{},
clientID: "test-client-id",
clientSecret: "test-client-secret",
tokenCache: NewTokenCache(),
forceHTTPS: false,
}
// Create request with proper headers
req := httptest.NewRequest("GET", "/logout", nil)
req.Header.Set("Host", tc.host)
// Create a response recorder
rr := httptest.NewRecorder()
// Get a session
session, err := tOidc.store.Get(req, cookieName)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Setup session
tc.setupSession(session)
session.Save(req, rr)
// Copy session cookie to request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset response recorder
rr = httptest.NewRecorder()
// Handle logout
tOidc.handleLogout(rr, req)
// Check response
if rr.Code != tc.expectedStatus {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
}
// Check redirect URL if expected
if tc.expectedURL != "" {
location := rr.Header().Get("Location")
if location != tc.expectedURL {
t.Errorf("Expected redirect to %q, got %q", tc.expectedURL, location)
}
}
// Verify session is cleared
newSession, _ := tOidc.store.Get(req, cookieName)
if len(newSession.Values) > 0 {
t.Error("Session was not cleared")
}
if newSession.Options.MaxAge != -1 {
t.Error("Session MaxAge was not set to -1")
}
// Check token blacklist
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(refreshToken) {
t.Error("Refresh token was not blacklisted")
}
}
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(accessToken) {
t.Error("Access token was not blacklisted")
}
}
})
}
}
// TestRevokeTokenWithProvider tests the token revocation with provider
func TestRevokeTokenWithProvider(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
token string
tokenType string
statusCode int
expectError bool
}{
{
name: "Successful token revocation",
token: "valid-token",
tokenType: "refresh_token",
statusCode: http.StatusOK,
expectError: false,
},
{
name: "Failed token revocation",
token: "invalid-token",
tokenType: "refresh_token",
statusCode: http.StatusBadRequest,
expectError: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify request method and content type
if r.Method != "POST" {
t.Errorf("Expected POST request, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" {
t.Errorf("Expected Content-Type application/x-www-form-urlencoded, got %s", ct)
}
// Verify form values
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
if got := r.Form.Get("token"); got != tc.token {
t.Errorf("Expected token %s, got %s", tc.token, got)
}
if got := r.Form.Get("token_type_hint"); got != tc.tokenType {
t.Errorf("Expected token_type_hint %s, got %s", tc.tokenType, got)
}
if got := r.Form.Get("client_id"); got != ts.tOidc.clientID {
t.Errorf("Expected client_id %s, got %s", ts.tOidc.clientID, got)
}
if got := r.Form.Get("client_secret"); got != ts.tOidc.clientSecret {
t.Errorf("Expected client_secret %s, got %s", ts.tOidc.clientSecret, got)
}
w.WriteHeader(tc.statusCode)
}))
defer server.Close()
// Set revocation URL to test server
ts.tOidc.revocationURL = server.URL
// Test token revocation
err := ts.tOidc.RevokeTokenWithProvider(tc.token, tc.tokenType)
if tc.expectError && err == nil {
t.Error("Expected error but got nil")
}
if !tc.expectError && err != nil {
t.Errorf("Unexpected error: %v", err)
}
})
}
}
// TestRevokeToken tests the token revocation functionality
func TestRevokeToken(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
token := "test.token.with.claims"
claims := map[string]interface{}{
"exp": float64(time.Now().Add(time.Hour).Unix()),
}
// Test token revocation
t.Run("Token revocation", func(t *testing.T) {
// Create a new instance for this specific test
tOidc := &TraefikOidc{
tokenBlacklist: NewTokenBlacklist(),
tokenCache: NewTokenCache(),
}
// Cache the token
tOidc.tokenCache.Set(token, claims, time.Hour)
// Revoke the token
tOidc.RevokeToken(token)
// Verify token was removed from cache
if _, exists := tOidc.tokenCache.Get(token); exists {
t.Error("Token was not removed from cache")
}
// Verify token was added to blacklist
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
t.Error("Token was not added to blacklist")
}
})
}
// Add this new test function
func TestBuildLogoutURL(t *testing.T) {
tests := []struct {
name string
endSessionURL string
idToken string
postLogoutRedirect string
expectedURL string
expectError bool
}{
{
name: "Valid URL",
endSessionURL: "https://provider/end-session",
idToken: "test.id.token",
postLogoutRedirect: "http://example.com/",
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
expectError: false,
},
{
name: "Invalid URL",
endSessionURL: "://invalid-url",
idToken: "test.id.token",
postLogoutRedirect: "http://example.com/",
expectError: true,
},
{
name: "URL with existing query parameters",
endSessionURL: "https://provider/end-session?existing=param",
idToken: "test.id.token",
postLogoutRedirect: "http://example.com/",
expectedURL: "https://provider/end-session?existing=param&id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
expectError: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
url, err := BuildLogoutURL(tc.endSessionURL, tc.idToken, tc.postLogoutRedirect)
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if url != tc.expectedURL {
t.Errorf("Expected URL %q, got %q", tc.expectedURL, url)
}
}
})
}
}
// Add this new test function
func TestHandleExpiredToken(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
setupSession func(*sessions.Session)
expectedPath string
}{
{
name: "Basic expired token",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "expired.token"
session.Values["email"] = "test@example.com"
},
expectedPath: "/original/path",
},
{
name: "Session with additional values",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "expired.token"
session.Values["custom_value"] = "should-be-cleared"
},
expectedPath: "/another/path",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a new TraefikOidc instance for each test
tOidc := &TraefikOidc{
store: sessions.NewCookieStore([]byte("test-secret-key")),
logger: NewLogger("info"),
redirectURL: "http://example.com/callback",
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
initComplete: make(chan struct{}),
// Add this initialization of initiateAuthenticationFunc
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Mock implementation for test
http.Redirect(rw, req, "/login", http.StatusFound)
},
}
close(tOidc.initComplete)
// Create request
req := httptest.NewRequest("GET", tc.expectedPath, nil)
rr := httptest.NewRecorder()
// Get session
session, _ := tOidc.store.New(req, cookieName)
tc.setupSession(session)
// Handle expired token
tOidc.handleExpiredToken(rr, req, session)
// Verify session is cleaned
if len(session.Values) != 3 { // Should only have csrf, incoming_path, and nonce
t.Errorf("Expected 3 session values, got %d", len(session.Values))
}
// Verify required values are set
if _, ok := session.Values["csrf"].(string); !ok {
t.Error("CSRF token not set")
}
if path, ok := session.Values["incoming_path"].(string); !ok || path != tc.expectedPath {
t.Errorf("Expected path %s, got %s", tc.expectedPath, path)
}
if _, ok := session.Values["nonce"].(string); !ok {
t.Error("Nonce not set")
}
// Verify session options
if session.Options.MaxAge != defaultSessionOptions.MaxAge {
t.Error("Session MaxAge not set correctly")
}
// Verify redirect status
if rr.Code != http.StatusFound {
t.Errorf("Expected status %d, got %d", http.StatusFound, rr.Code)
}
})
}
}
// Add this new test function
func TestExtractGroupsAndRoles(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
claims map[string]interface{}
expectGroups []string
expectRoles []string
expectError bool
}{
{
name: "Valid groups and roles",
claims: map[string]interface{}{
"groups": []interface{}{"group1", "group2"},
"roles": []interface{}{"role1", "role2"},
},
expectGroups: []string{"group1", "group2"},
expectRoles: []string{"role1", "role2"},
expectError: false,
},
{
name: "Empty groups and roles",
claims: map[string]interface{}{
"groups": []interface{}{},
"roles": []interface{}{},
},
expectGroups: []string{},
expectRoles: []string{},
expectError: false,
},
{
name: "Invalid groups format",
claims: map[string]interface{}{
"groups": "not-an-array",
"roles": []interface{}{"role1"},
},
expectError: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a test token with the claims
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Compare groups
if !stringSliceEqual(groups, tc.expectGroups) {
t.Errorf("Expected groups %v, got %v", tc.expectGroups, groups)
}
// Compare roles
if !stringSliceEqual(roles, tc.expectRoles) {
t.Errorf("Expected roles %v, got %v", tc.expectRoles, roles)
}
}
})
}
}
// Helper function to compare string slices
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
+1
View File
@@ -30,6 +30,7 @@ type Config struct {
ExcludedURLs []string `json:"excludedURLs"`
AllowedUserDomains []string `json:"allowedUserDomains"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
HTTPClient *http.Client
}