mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
405 lines
15 KiB
Go
405 lines
15 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/cookiejar"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// generateNonce returns a fresh random value to be used as the OIDC nonce.
// The nonce ties an ID token to the authentication request that produced it,
// which allows replayed tokens to be detected.
//
// 32 bytes are read from crypto/rand and encoded with the padded base64 URL
// alphabet, which yields a 44-character string.
//
// Returns:
//   - The base64 URL encoded nonce.
//   - An error if the random source cannot be read.
func generateNonce() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("could not generate nonce: %w", err)
	}
	return base64.URLEncoding.EncodeToString(raw[:]), nil
}
|
|
|
|
// generateCodeVerifier returns a random PKCE code verifier.
// RFC 7636 requires a high-entropy string of 43 to 128 characters; 32 random
// bytes encoded with the unpadded base64 URL alphabet give exactly 43.
//
// Returns:
//   - The base64 URL encoded (unpadded) code verifier.
//   - An error if the random source cannot be read.
func generateCodeVerifier() (string, error) {
	// 32 bytes = 256 bits of entropy -> 43 base64url characters.
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("could not generate code verifier: %w", err)
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
|
|
|
|
// deriveCodeChallenge turns a PKCE code verifier into its code challenge
// using the S256 method from RFC 7636: the SHA-256 digest of the verifier,
// encoded with the unpadded base64 URL alphabet.
//
// Parameters:
//   - codeVerifier: The verifier string, e.g. from generateCodeVerifier.
//
// Returns:
//   - The code challenge corresponding to codeVerifier.
func deriveCodeChallenge(codeVerifier string) string {
	digest := sha256.Sum256([]byte(codeVerifier))
	return base64.RawURLEncoding.EncodeToString(digest[:])
}
|
|
|
|
// TokenResponse mirrors the JSON document returned by the OIDC token
// endpoint after an authorization-code exchange or a token refresh.
type TokenResponse struct {
	// IDToken holds the "id_token" member: the OIDC ID token with user claims.
	IDToken string `json:"id_token"`

	// AccessToken holds the "access_token" member: the OAuth 2.0 access token.
	AccessToken string `json:"access_token"`

	// RefreshToken holds the "refresh_token" member, used to obtain new tokens.
	RefreshToken string `json:"refresh_token"`

	// ExpiresIn holds the "expires_in" member: access token lifetime in seconds.
	ExpiresIn int `json:"expires_in"`

	// TokenType holds the "token_type" member, typically "Bearer".
	TokenType string `json:"token_type"`
}
|
|
|
|
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider's token endpoint.
|
|
// It handles both the "authorization_code" grant type (exchanging an authorization code for tokens)
|
|
// and the "refresh_token" grant type (using a refresh token to obtain new tokens).
|
|
// It includes necessary parameters like client credentials and handles PKCE verification if applicable.
|
|
// The function follows redirects and handles potential errors during the exchange.
|
|
//
|
|
// Parameters:
|
|
// - ctx: The context for the outgoing HTTP request.
|
|
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token").
|
|
// - codeOrToken: The authorization code (for "authorization_code" grant) or the refresh token (for "refresh_token" grant).
|
|
// - redirectURL: The redirect URI that was used in the initial authorization request (required for "authorization_code" grant).
|
|
// - codeVerifier: The PKCE code verifier (required for "authorization_code" grant if PKCE was used).
|
|
//
|
|
// Returns:
|
|
// - A TokenResponse containing the obtained tokens (ID, access, refresh).
|
|
// - An error if the token exchange fails (e.g., network error, provider error, invalid grant).
|
|
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
data := url.Values{
|
|
"grant_type": {grantType},
|
|
"client_id": {t.clientID},
|
|
"client_secret": {t.clientSecret},
|
|
}
|
|
|
|
if grantType == "authorization_code" {
|
|
data.Set("code", codeOrToken)
|
|
data.Set("redirect_uri", redirectURL)
|
|
|
|
// Add code_verifier if PKCE is being used
|
|
if codeVerifier != "" {
|
|
data.Set("code_verifier", codeVerifier)
|
|
}
|
|
} else if grantType == "refresh_token" {
|
|
data.Set("refresh_token", codeOrToken)
|
|
}
|
|
|
|
// Create a cookie jar for this request to handle redirects with cookies
|
|
jar, _ := cookiejar.New(nil)
|
|
client := &http.Client{
|
|
Transport: t.httpClient.Transport,
|
|
Timeout: t.httpClient.Timeout,
|
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
// Always follow redirects for OIDC endpoints
|
|
if len(via) >= 50 {
|
|
return fmt.Errorf("stopped after 50 redirects")
|
|
}
|
|
return nil
|
|
},
|
|
Jar: jar,
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create token request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
|
|
}
|
|
|
|
var tokenResponse TokenResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
|
|
return nil, fmt.Errorf("failed to decode token response: %w", err)
|
|
}
|
|
|
|
return &tokenResponse, nil
|
|
}
|
|
|
|
// getNewTokenWithRefreshToken uses a refresh token to obtain a new set of tokens (ID, access, refresh)
|
|
// from the OIDC provider's token endpoint. It wraps the exchangeTokens function with the
|
|
// "refresh_token" grant type.
|
|
//
|
|
// Parameters:
|
|
// - refreshToken: The refresh token previously obtained during authentication or a prior refresh.
|
|
//
|
|
// Returns:
|
|
// - A TokenResponse containing the newly obtained tokens.
|
|
// - An error if the refresh operation fails.
|
|
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
|
ctx := context.Background()
|
|
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to refresh token: %w", err)
|
|
}
|
|
|
|
t.logger.Debugf("Token response: %+v", tokenResponse)
|
|
return tokenResponse, nil
|
|
}
|
|
|
|
// extractClaims returns the claims set carried in the payload of a JWT.
// The token must consist of exactly three dot-separated segments; the middle
// one is decoded as unpadded base64 URL data and parsed as a JSON object.
//
// Note: nothing is verified here - neither the signature nor any claim.
//
// Parameters:
//   - tokenString: The raw JWT string.
//
// Returns:
//   - The claims decoded from the token payload.
//   - An error if the token does not have three segments, the payload is not
//     valid base64 URL data, or the payload is not valid JSON for a map.
func extractClaims(tokenString string) (map[string]interface{}, error) {
	segments := strings.Split(tokenString, ".")
	if len(segments) != 3 {
		return nil, fmt.Errorf("invalid token format")
	}

	// segments[0] is the header, segments[2] the signature; only the
	// payload in the middle is of interest.
	rawPayload, decodeErr := base64.RawURLEncoding.DecodeString(segments[1])
	if decodeErr != nil {
		return nil, fmt.Errorf("failed to decode token payload: %w", decodeErr)
	}

	var claims map[string]interface{}
	if jsonErr := json.Unmarshal(rawPayload, &claims); jsonErr != nil {
		return nil, fmt.Errorf("failed to unmarshal claims: %w", jsonErr)
	}

	return claims, nil
}
|
|
|
|
// TokenCache provides a caching mechanism for validated tokens.
|
|
// It stores token claims to avoid repeated validation of the
|
|
// same token, improving performance for frequently used tokens.
|
|
type TokenCache struct {
|
|
// cache is the underlying cache implementation
|
|
cache *Cache
|
|
}
|
|
|
|
// NewTokenCache creates and initializes a new TokenCache.
|
|
// It internally creates a new generic Cache instance for storage.
|
|
func NewTokenCache() *TokenCache {
|
|
return &TokenCache{
|
|
cache: NewCache(),
|
|
}
|
|
}
|
|
|
|
// Set stores the claims associated with a specific token string in the cache.
|
|
// It prefixes the token string to avoid potential collisions with other cache types
|
|
// and sets the provided expiration duration.
|
|
//
|
|
// Parameters:
|
|
// - token: The raw token string (used as the key).
|
|
// - claims: The map of claims associated with the token.
|
|
// - expiration: The duration for which the cache entry should be valid.
|
|
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
|
token = "t-" + token
|
|
tc.cache.Set(token, claims, expiration)
|
|
}
|
|
|
|
// Get retrieves the cached claims for a given token string.
|
|
// It prefixes the token string before querying the underlying cache.
|
|
//
|
|
// Parameters:
|
|
// - token: The raw token string to look up.
|
|
//
|
|
// Returns:
|
|
// - The cached claims map if found and valid.
|
|
// - A boolean indicating whether the token was found in the cache (true if found, false otherwise).
|
|
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
|
token = "t-" + token
|
|
value, found := tc.cache.Get(token)
|
|
if !found {
|
|
return nil, false
|
|
}
|
|
claims, ok := value.(map[string]interface{})
|
|
return claims, ok
|
|
}
|
|
|
|
// Delete removes the cached entry for a specific token string.
|
|
// It prefixes the token string before calling the underlying cache's Delete method.
|
|
//
|
|
// Parameters:
|
|
// - token: The raw token string to remove from the cache.
|
|
func (tc *TokenCache) Delete(token string) {
|
|
token = "t-" + token
|
|
tc.cache.Delete(token)
|
|
}
|
|
|
|
// Cleanup triggers the cleanup process for the underlying generic cache,
|
|
// removing expired token entries.
|
|
func (tc *TokenCache) Cleanup() {
|
|
tc.cache.Cleanup()
|
|
}
|
|
|
|
// exchangeCodeForToken is a convenience function that wraps exchangeTokens specifically
|
|
// for the "authorization_code" grant type. It handles the conditional inclusion of the
|
|
// PKCE code verifier based on the middleware's configuration (t.enablePKCE).
|
|
//
|
|
// Parameters:
|
|
// - code: The authorization code received from the OIDC provider.
|
|
// - redirectURL: The redirect URI used in the initial authorization request.
|
|
// - codeVerifier: The PKCE code verifier stored in the session (if PKCE is enabled).
|
|
//
|
|
// Returns:
|
|
// - A TokenResponse containing the obtained tokens.
|
|
// - An error if the code exchange fails.
|
|
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
ctx := context.Background()
|
|
|
|
// Only include code verifier if PKCE is enabled
|
|
effectiveCodeVerifier := ""
|
|
if t.enablePKCE && codeVerifier != "" {
|
|
effectiveCodeVerifier = codeVerifier
|
|
}
|
|
|
|
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL, effectiveCodeVerifier)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
|
}
|
|
return tokenResponse, nil
|
|
}
|
|
|
|
// createStringMap builds a set (map[string]struct{}) from a slice of strings,
// giving O(1) average membership checks for things like allowed domains,
// roles, or groups. Duplicate inputs collapse into a single entry.
//
// Parameters:
//   - keys: The strings to add to the set; may be nil or empty.
//
// Returns:
//   - A non-nil map whose keys are the distinct strings from keys.
func createStringMap(keys []string) map[string]struct{} {
	// len(keys) is an upper bound on the final size, so the map is allocated
	// once instead of growing while it is filled.
	result := make(map[string]struct{}, len(keys))
	for _, key := range keys {
		result[key] = struct{}{}
	}
	return result
}
|
|
|
|
// handleLogout processes requests to the configured logout path.
|
|
// It performs the following steps:
|
|
// 1. Retrieves the current user session.
|
|
// 2. Gets the access token (ID token hint) from the session.
|
|
// 3. Clears all authentication-related data from the session cookies.
|
|
// 4. Determines the final post-logout redirect URI.
|
|
// 5. If an OIDC end_session_endpoint is configured and an ID token hint is available,
|
|
// it builds the OIDC logout URL and redirects the user agent to the provider for logout.
|
|
// 6. Otherwise, it redirects the user agent directly to the post-logout redirect URI.
|
|
//
|
|
// It handles potential errors during session retrieval or clearing.
|
|
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
|
session, err := t.sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.logger.Errorf("Error getting session: %v", err)
|
|
http.Error(rw, "Session error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
accessToken := session.GetAccessToken()
|
|
|
|
if err := session.Clear(req, rw); err != nil {
|
|
t.logger.Errorf("Error clearing session: %v", err)
|
|
http.Error(rw, "Session error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
host := t.determineHost(req)
|
|
scheme := t.determineScheme(req)
|
|
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
|
|
|
postLogoutRedirectURI := t.postLogoutRedirectURI
|
|
if postLogoutRedirectURI == "" {
|
|
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
|
|
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
|
|
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
|
}
|
|
|
|
if t.endSessionURL != "" && accessToken != "" {
|
|
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
|
|
if err != nil {
|
|
t.logger.Errorf("Failed to build logout URL: %v", err)
|
|
http.Error(rw, "Logout error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
http.Redirect(rw, req, logoutURL, http.StatusFound)
|
|
return
|
|
}
|
|
|
|
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
|
}
|
|
|
|
// BuildLogoutURL returns the URL to which the user agent is redirected to
// log out at the OIDC provider. It starts from the provider's
// end_session_endpoint, keeps any query parameters already present, sets
// id_token_hint, and sets post_logout_redirect_uri when one is given.
//
// Parameters:
//   - endSessionURL: The URL of the OIDC provider's end session endpoint.
//   - idToken: The token sent as id_token_hint.
//   - postLogoutRedirectURI: Where the provider should send the user agent
//     after logout; left out of the URL when empty.
//
// Returns:
//   - The fully constructed logout URL string.
//   - An error if endSessionURL cannot be parsed.
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
	endpoint, parseErr := url.Parse(endSessionURL)
	if parseErr != nil {
		return "", fmt.Errorf("failed to parse end session URL: %w", parseErr)
	}

	params := endpoint.Query()
	params.Set("id_token_hint", idToken)
	if postLogoutRedirectURI != "" {
		params.Set("post_logout_redirect_uri", postLogoutRedirectURI)
	}

	// Encode re-serialises every parameter, sorted by key.
	endpoint.RawQuery = params.Encode()
	return endpoint.String(), nil
}
|