package traefikoidc

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/cookiejar"
	"net/url"
	"strings"
	"time"
)

// generateNonce returns a fresh, cryptographically random value for the OIDC
// "nonce" parameter.
//
// The nonce ties an ID token to the authentication request that produced it,
// which helps detect replayed tokens. It is built from 32 bytes read from
// crypto/rand and encoded with padded base64url (base64.URLEncoding), so the
// result is always 44 characters long.
//
// Returns:
//   - The encoded nonce.
//   - An error if reading random bytes fails.
func generateNonce() (string, error) {
	var buf [32]byte
	if _, err := rand.Read(buf[:]); err != nil {
		return "", fmt.Errorf("could not generate nonce: %w", err)
	}
	return base64.URLEncoding.EncodeToString(buf[:]), nil
}

// generateCodeVerifier returns a cryptographically random PKCE code verifier.
//
// RFC 7636 requires a high-entropy verifier between 43 and 128 characters.
// Encoding 32 random bytes (256 bits) with unpadded base64url
// (base64.RawURLEncoding) yields exactly 43 characters, the minimum allowed.
//
// Returns:
//   - The encoded code verifier.
//   - An error if reading random bytes fails.
func generateCodeVerifier() (string, error) {
	var buf [32]byte
	if _, err := rand.Read(buf[:]); err != nil {
		return "", fmt.Errorf("could not generate code verifier: %w", err)
	}
	return base64.RawURLEncoding.EncodeToString(buf[:]), nil
}

// deriveCodeChallenge computes the PKCE code challenge from a given code verifier.
// It uses the S256 challenge method (SHA-256 hash followed by base64 URL encoding)
// as defined in RFC 7636.
//
// Parameters:
//   - codeVerifier: The high-entropy string generated by generateCodeVerifier.
// // Returns: // - The base64 URL encoded SHA-256 hash of the code verifier (code challenge). func deriveCodeChallenge(codeVerifier string) string { // Calculate SHA-256 hash of the code verifier hasher := sha256.New() hasher.Write([]byte(codeVerifier)) hash := hasher.Sum(nil) // Base64url encode the hash to get the code challenge return base64.RawURLEncoding.EncodeToString(hash) } // TokenResponse represents the response from the OIDC token endpoint. // It contains the various tokens and metadata returned after successful // code exchange or token refresh operations. type TokenResponse struct { // IDToken is the OIDC ID token containing user claims IDToken string `json:"id_token"` // AccessToken is the OAuth 2.0 access token for API access AccessToken string `json:"access_token"` // RefreshToken is the OAuth 2.0 refresh token for obtaining new tokens RefreshToken string `json:"refresh_token"` // ExpiresIn is the lifetime in seconds of the access token ExpiresIn int `json:"expires_in"` // TokenType is the type of token, typically "Bearer" TokenType string `json:"token_type"` } // exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider's token endpoint. // It handles both the "authorization_code" grant type (exchanging an authorization code for tokens) // and the "refresh_token" grant type (using a refresh token to obtain new tokens). // It includes necessary parameters like client credentials and handles PKCE verification if applicable. // The function follows redirects and handles potential errors during the exchange. // // Parameters: // - ctx: The context for the outgoing HTTP request. // - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token"). // - codeOrToken: The authorization code (for "authorization_code" grant) or the refresh token (for "refresh_token" grant). // - redirectURL: The redirect URI that was used in the initial authorization request (required for "authorization_code" grant). 
// - codeVerifier: The PKCE code verifier (required for "authorization_code" grant if PKCE was used). // // Returns: // - A TokenResponse containing the obtained tokens (ID, access, refresh). // - An error if the token exchange fails (e.g., network error, provider error, invalid grant). func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) { data := url.Values{ "grant_type": {grantType}, "client_id": {t.clientID}, "client_secret": {t.clientSecret}, } if grantType == "authorization_code" { data.Set("code", codeOrToken) data.Set("redirect_uri", redirectURL) // Add code_verifier if PKCE is being used if codeVerifier != "" { data.Set("code_verifier", codeVerifier) } } else if grantType == "refresh_token" { data.Set("refresh_token", codeOrToken) } // Create a cookie jar for this request to handle redirects with cookies jar, _ := cookiejar.New(nil) client := &http.Client{ Transport: t.httpClient.Transport, Timeout: t.httpClient.Timeout, CheckRedirect: func(req *http.Request, via []*http.Request) error { // Always follow redirects for OIDC endpoints if len(via) >= 50 { return fmt.Errorf("stopped after 50 redirects") } return nil }, Jar: jar, } req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode())) if err != nil { return nil, fmt.Errorf("failed to create token request: %w", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("failed to exchange tokens: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes)) } var tokenResponse TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { return nil, fmt.Errorf("failed to decode token response: %w", err) } 
return &tokenResponse, nil } // getNewTokenWithRefreshToken uses a refresh token to obtain a new set of tokens (ID, access, refresh) // from the OIDC provider's token endpoint. It wraps the exchangeTokens function with the // "refresh_token" grant type. // // Parameters: // - refreshToken: The refresh token previously obtained during authentication or a prior refresh. // // Returns: // - A TokenResponse containing the newly obtained tokens. // - An error if the refresh operation fails. func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) { ctx := context.Background() tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "") if err != nil { return nil, fmt.Errorf("failed to refresh token: %w", err) } t.logger.Debugf("Token response: %+v", tokenResponse) return tokenResponse, nil } // extractClaims decodes the payload (claims set) part of a JWT string. // It splits the JWT into its three parts, base64 URL decodes the second part (payload), // and unmarshals the resulting JSON into a map. // Note: This function does *not* validate the token's signature or claims. // // Parameters: // - tokenString: The raw JWT string. // // Returns: // - A map representing the JSON claims extracted from the token payload. // - An error if the token format is invalid, decoding fails, or JSON unmarshaling fails. func extractClaims(tokenString string) (map[string]interface{}, error) { parts := strings.Split(tokenString, ".") if len(parts) != 3 { return nil, fmt.Errorf("invalid token format") } payload, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return nil, fmt.Errorf("failed to decode token payload: %w", err) } var claims map[string]interface{} if err := json.Unmarshal(payload, &claims); err != nil { return nil, fmt.Errorf("failed to unmarshal claims: %w", err) } return claims, nil } // TokenCache provides a caching mechanism for validated tokens. 
// It stores token claims to avoid repeated validation of the // same token, improving performance for frequently used tokens. type TokenCache struct { // cache is the underlying cache implementation cache *Cache } // NewTokenCache creates and initializes a new TokenCache. // It internally creates a new generic Cache instance for storage. func NewTokenCache() *TokenCache { return &TokenCache{ cache: NewCache(), } } // Set stores the claims associated with a specific token string in the cache. // It prefixes the token string to avoid potential collisions with other cache types // and sets the provided expiration duration. // // Parameters: // - token: The raw token string (used as the key). // - claims: The map of claims associated with the token. // - expiration: The duration for which the cache entry should be valid. func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) { token = "t-" + token tc.cache.Set(token, claims, expiration) } // Get retrieves the cached claims for a given token string. // It prefixes the token string before querying the underlying cache. // // Parameters: // - token: The raw token string to look up. // // Returns: // - The cached claims map if found and valid. // - A boolean indicating whether the token was found in the cache (true if found, false otherwise). func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) { token = "t-" + token value, found := tc.cache.Get(token) if !found { return nil, false } claims, ok := value.(map[string]interface{}) return claims, ok } // Delete removes the cached entry for a specific token string. // It prefixes the token string before calling the underlying cache's Delete method. // // Parameters: // - token: The raw token string to remove from the cache. 
func (tc *TokenCache) Delete(token string) { token = "t-" + token tc.cache.Delete(token) } // Cleanup triggers the cleanup process for the underlying generic cache, // removing expired token entries. func (tc *TokenCache) Cleanup() { tc.cache.Cleanup() } // exchangeCodeForToken is a convenience function that wraps exchangeTokens specifically // for the "authorization_code" grant type. It handles the conditional inclusion of the // PKCE code verifier based on the middleware's configuration (t.enablePKCE). // // Parameters: // - code: The authorization code received from the OIDC provider. // - redirectURL: The redirect URI used in the initial authorization request. // - codeVerifier: The PKCE code verifier stored in the session (if PKCE is enabled). // // Returns: // - A TokenResponse containing the obtained tokens. // - An error if the code exchange fails. func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { ctx := context.Background() // Only include code verifier if PKCE is enabled effectiveCodeVerifier := "" if t.enablePKCE && codeVerifier != "" { effectiveCodeVerifier = codeVerifier } tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL, effectiveCodeVerifier) if err != nil { return nil, fmt.Errorf("failed to exchange code for token: %w", err) } return tokenResponse, nil } // createStringMap converts a slice of strings into a map[string]struct{} (a set). // This is useful for creating efficient lookups (O(1) average time complexity) // for checking the presence of items like allowed domains, roles, or groups. // // Parameters: // - keys: A slice of strings to be added to the set. // // Returns: // - A map where the keys are the strings from the input slice and the values are empty structs. 
func createStringMap(keys []string) map[string]struct{} { result := make(map[string]struct{}) for _, key := range keys { result[key] = struct{}{} } return result } // handleLogout processes requests to the configured logout path. // It performs the following steps: // 1. Retrieves the current user session. // 2. Gets the access token (ID token hint) from the session. // 3. Clears all authentication-related data from the session cookies. // 4. Determines the final post-logout redirect URI. // 5. If an OIDC end_session_endpoint is configured and an ID token hint is available, // it builds the OIDC logout URL and redirects the user agent to the provider for logout. // 6. Otherwise, it redirects the user agent directly to the post-logout redirect URI. // // It handles potential errors during session retrieval or clearing. func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { session, err := t.sessionManager.GetSession(req) if err != nil { t.logger.Errorf("Error getting session: %v", err) http.Error(rw, "Session error", http.StatusInternalServerError) return } accessToken := session.GetAccessToken() if err := session.Clear(req, rw); err != nil { t.logger.Errorf("Error clearing session: %v", err) http.Error(rw, "Session error", http.StatusInternalServerError) return } host := t.determineHost(req) scheme := t.determineScheme(req) baseURL := fmt.Sprintf("%s://%s", scheme, host) postLogoutRedirectURI := t.postLogoutRedirectURI if postLogoutRedirectURI == "" { postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL) } else if !strings.HasPrefix(postLogoutRedirectURI, "http") { postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI) } if t.endSessionURL != "" && accessToken != "" { logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI) if err != nil { t.logger.Errorf("Failed to build logout URL: %v", err) http.Error(rw, "Logout error", http.StatusInternalServerError) return } http.Redirect(rw, req, 
logoutURL, http.StatusFound) return } http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound) } // BuildLogoutURL constructs the URL for redirecting the user agent to the OIDC provider's // end_session_endpoint, including the required id_token_hint and optional // post_logout_redirect_uri parameters as query arguments. // // Parameters: // - endSessionURL: The URL of the OIDC provider's end session endpoint. // - idToken: The ID token previously issued to the user (used as id_token_hint). // - postLogoutRedirectURI: The optional URI where the provider should redirect the user agent after logout. // // Returns: // - The fully constructed logout URL string. // - An error if the provided endSessionURL is invalid. func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) { u, err := url.Parse(endSessionURL) if err != nil { return "", fmt.Errorf("failed to parse end session URL: %w", err) } q := u.Query() q.Set("id_token_hint", idToken) if postLogoutRedirectURI != "" { q.Set("post_logout_redirect_uri", postLogoutRedirectURI) } u.RawQuery = q.Encode() return u.String(), nil }