diff --git a/helpers.go b/helpers.go index 740c154..b63b501 100644 --- a/helpers.go +++ b/helpers.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "encoding/base64" "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -15,256 +14,207 @@ import ( ) func generateNonce() (string, error) { - nonceBytes := make([]byte, 32) - _, err := rand.Read(nonceBytes) - if err != nil { - return "", fmt.Errorf("could not generate nonce") - } - return base64.URLEncoding.EncodeToString(nonceBytes), nil + nonceBytes := make([]byte, 32) + _, err := rand.Read(nonceBytes) + if err != nil { + return "", fmt.Errorf("could not generate nonce: %w", err) + } + return base64.URLEncoding.EncodeToString(nonceBytes), nil } -func assembleRedirectURL(scheme, host, path string) string { - if scheme == "" { - // infoLogger.Println("Scheme is empty, defaulting to http") - scheme = "http" - } - return scheme + "://" + host + path +func buildFullURL(scheme, host, path string) string { + if scheme == "" { + scheme = "http" + } + return fmt.Sprintf("%s://%s%s", scheme, host, path) } -func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code string, redirectURL string) (map[string]interface{}, error) { - data := url.Values{} - data.Set("grant_type", "authorization_code") - data.Set("code", code) - data.Set("client_id", t.clientID) - data.Set("client_secret", t.clientSecret) - data.Set("redirect_uri", redirectURL) // Use the full redirect URL +func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code, redirectURL string) (map[string]interface{}, error) { + data := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "client_id": {t.clientID}, + "client_secret": {t.clientSecret}, + "redirect_uri": {redirectURL}, + } - // infoLogger.Printf("Exchanging code for token with redirect_uri: %s", redirectURL) + req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode())) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp, err := t.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to exchange code for token: %w", err) + } + defer resp.Body.Close() - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode token response: %w", err) + } - var result map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, err - } - - // infoLogger.Printf("Token response: %+v", result) - - return result, nil + return result, nil } func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) { - ctx := req.Context() - session, err := t.store.Get(req, cookie_name) - if err != nil { - // infoLogger.Printf("Error getting session: %v", err) - http.Error(rw, "Session error", http.StatusInternalServerError) - return false, "" - } + ctx := req.Context() + session, err := t.store.Get(req, cookieName) + if err != nil { + handleError(rw, "Session error", http.StatusInternalServerError) + return false, "" + } - // infoLogger.Printf("Session values: %+v", session.Values) + callbackState := req.URL.Query().Get("state") + sessionState, ok := session.Values["csrf"].(string) + if !ok || callbackState != sessionState { + handleError(rw, "Invalid state parameter", http.StatusBadRequest) + return false, "" + } - callbackState := req.URL.Query().Get("state") - sessionState, ok := session.Values["csrf"].(string) - // infoLogger.Printf("Callback state: %s, Session state: %s, Match: %v", callbackState, sessionState, ok && callbackState == sessionState) + code := req.URL.Query().Get("code") + redirectURL := buildFullURL(t.scheme, req.Host, t.redirURLPath) + oauth2Token, err := t.exchangeCodeForToken(ctx, code, redirectURL) + if err != nil { + handleError(rw, "Failed to exchange token", http.StatusInternalServerError) + return false, "" + } - if !ok || callbackState != sessionState { - // infoLogger.Printf("Invalid state parameter: callback=%s, session=%s", callbackState, sessionState) - http.Error(rw, "Invalid state parameter", http.StatusBadRequest) - return false, "" - } + rawIDToken, ok := oauth2Token["id_token"].(string) + if !ok { + handleError(rw, "No id_token field in oauth2 token", http.StatusInternalServerError) + return false, "" + } - code := req.URL.Query().Get("code") - redirectURL := assembleRedirectURL(t.scheme, req.Host, t.redirURLPath) - oauth2Token, err := t.exchangeCodeForToken(ctx, code, redirectURL) - if err != nil { - // infoLogger.Printf("Failed to exchange token: %v", err) - http.Error(rw, "Failed to exchange token", http.StatusInternalServerError) - return false, "" - } + if err := t.verifyToken(rawIDToken); err != nil { + handleError(rw, "Failed to verify token", http.StatusUnauthorized) + return false, "" + } - rawIDToken, ok := oauth2Token["id_token"].(string) - if !ok { - // infoLogger.Printf("No id_token field in oauth2 token") - http.Error(rw, "No id_token field in oauth2 token", http.StatusInternalServerError) - return false, "" - } + claims, err := extractClaims(rawIDToken) + if err != nil { + handleError(rw, "Failed to extract claims", http.StatusInternalServerError) + return false, "" + } - if err := t.verifyToken(rawIDToken); err != nil { - // infoLogger.Printf("Token verification failed: %v", err) - http.Error(rw, "Failed to verify token", http.StatusUnauthorized) - return false, "" - } - // infoLogger.Printf("Token verification successful") + email, _ := claims["email"].(string) - claims, err := extractClaims(rawIDToken) - if err != nil { - // infoLogger.Printf("Failed to extract claims: %v", err) - http.Error(rw, "Failed to extract claims", http.StatusInternalServerError) - return false, "" - } + session.Values["authenticated"] = true + session.Values["id_token"] = rawIDToken + session.Values["email"] = email + if err := session.Save(req, rw); err != nil { + handleError(rw, "Failed to save session", http.StatusInternalServerError) + return false, "" + } - email, _ := claims["email"].(string) + originalPath, ok := session.Values["incoming_path"].(string) + if !ok { + originalPath = "/" + } + delete(session.Values, "incoming_path") - session.Values["authenticated"] = true - session.Values["id_token"] = rawIDToken - session.Values["email"] = email - if err := session.Save(req, rw); err != nil { - // infoLogger.Printf("Failed to save session: %v", err) - http.Error(rw, "Failed to save session", http.StatusInternalServerError) - return false, "" - } - - // infoLogger.Printf("User %s authenticated\n", email) - originalPath, ok := session.Values["incoming_path"].(string) - if !ok { - originalPath = "/" - } - delete(session.Values, "incoming_path") - - return true, originalPath + return true, originalPath } func extractClaims(tokenString string) (map[string]interface{}, error) { - parts := strings.Split(tokenString, ".") - if len(parts) != 3 { - return nil, errors.New("invalid token format") - } + parts := strings.Split(tokenString, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid token format") + } - payload, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return nil, err - } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("failed to decode token payload: %w", err) + } - var claims map[string]interface{} - if err := json.Unmarshal(payload, &claims); err != nil { - return nil, err - } + var claims map[string]interface{} + if err := json.Unmarshal(payload, &claims); err != nil { + return nil, fmt.Errorf("failed to unmarshal claims: %w", err) + } - return claims, nil -} - -func verifyToken(token string, publicKey []byte) (map[string]interface{}, error) { - parts := strings.Split(token, ".") - if len(parts) != 3 { - return nil, errors.New("invalid token format") - } - - payloadJson, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return nil, err - } - - var claims map[string]interface{} - err = json.Unmarshal(payloadJson, &claims) - if err != nil { - return nil, err - } - - if exp, ok := claims["exp"].(float64); ok { - if time.Now().Unix() > int64(exp) { - return nil, errors.New("token expired") - } - } - - // Placeholder for signature verification - // err = verifySignature(parts[0]+"."+parts[1], parts[2], publicKey) - // if err != nil { - // return nil, err - // } - - return claims, nil + return claims, nil } type UsedTokens struct { - tokens map[string]bool - mutex sync.RWMutex + tokens map[string]bool + mutex sync.RWMutex } type TokenBlacklist struct { - blacklist map[string]time.Time - mutex sync.RWMutex + blacklist map[string]time.Time + mutex sync.RWMutex } func NewTokenBlacklist() *TokenBlacklist { - return &TokenBlacklist{ - blacklist: make(map[string]time.Time), - } + return &TokenBlacklist{ + blacklist: make(map[string]time.Time), + } } func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) { - tb.mutex.Lock() - defer tb.mutex.Unlock() - tb.blacklist[tokenID] = expiration + tb.mutex.Lock() + defer tb.mutex.Unlock() + tb.blacklist[tokenID] = expiration } func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool { - tb.mutex.RLock() - defer tb.mutex.RUnlock() - expiration, exists := tb.blacklist[tokenID] - return exists && time.Now().Before(expiration) + tb.mutex.RLock() + defer tb.mutex.RUnlock() + expiration, exists := tb.blacklist[tokenID] + return exists && time.Now().Before(expiration) } func (tb *TokenBlacklist) Cleanup() { - tb.mutex.Lock() - defer tb.mutex.Unlock() - now := time.Now() - for tokenID, expiration := range tb.blacklist { - if now.After(expiration) { - delete(tb.blacklist, tokenID) - } - } + tb.mutex.Lock() + defer tb.mutex.Unlock() + now := time.Now() + for tokenID, expiration := range tb.blacklist { + if now.After(expiration) { + delete(tb.blacklist, tokenID) + } + } } type TokenCache struct { - cache map[string]*TokenInfo - mutex sync.RWMutex + cache map[string]*TokenInfo + mutex sync.RWMutex } type TokenInfo struct { - Token string - ExpiresAt time.Time + Token string + ExpiresAt time.Time } func NewTokenCache() *TokenCache { - return &TokenCache{ - cache: make(map[string]*TokenInfo), - } + return &TokenCache{ + cache: make(map[string]*TokenInfo), + } } func (tc *TokenCache) Set(token string, expiresAt time.Time) { - tc.mutex.Lock() - defer tc.mutex.Unlock() - tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt} + tc.mutex.Lock() + defer tc.mutex.Unlock() + tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt} } func (tc *TokenCache) Get(token string) (*TokenInfo, bool) { - tc.mutex.RLock() - defer tc.mutex.RUnlock() - info, exists := tc.cache[token] - if exists && time.Now().Before(info.ExpiresAt) { - return info, true - } - return nil, false + tc.mutex.RLock() + defer tc.mutex.RUnlock() + info, exists := tc.cache[token] + if exists && time.Now().Before(info.ExpiresAt) { + return info, true + } + return nil, false } func (tc *TokenCache) Cleanup() { - tc.mutex.Lock() - defer tc.mutex.Unlock() - now := time.Now() - for token, info := range tc.cache { - if now.After(info.ExpiresAt) { - delete(tc.cache, token) - } - } + tc.mutex.Lock() + defer tc.mutex.Unlock() + now := time.Now() + for token, info := range tc.cache { + if now.After(info.ExpiresAt) { + delete(tc.cache, token) + } + } } diff --git a/jwk.go b/jwk.go index 63f0509..816fdfe 100644 --- a/jwk.go +++ b/jwk.go @@ -6,7 +6,6 @@ import ( "encoding/base64" "encoding/json" "encoding/pem" - "errors" "fmt" "math/big" "net/http" @@ -15,135 +14,135 @@ import ( ) type JWK struct { - Kty string `json:"kty"` - Kid string `json:"kid"` - Use string `json:"use"` - N string `json:"n"` - E string `json:"e"` - Alg string `json:"alg"` + Kty string `json:"kty"` + Kid string `json:"kid"` + Use string `json:"use"` + N string `json:"n"` + E string `json:"e"` + Alg string `json:"alg"` } type JWKSet struct { - Keys []JWK `json:"keys"` + Keys []JWK `json:"keys"` } type JWKCache struct { - jwks *JWKSet - expiresAt time.Time - mutex sync.RWMutex + jwks *JWKSet + expiresAt time.Time + mutex sync.RWMutex } -func (c *JWKCache) GetJWKS(jwksURL string) (*JWKSet, error) { - c.mutex.RLock() - if c.jwks != nil && time.Now().Before(c.expiresAt) { - defer c.mutex.RUnlock() - return c.jwks, nil - } - c.mutex.RUnlock() +func (c *JWKCache) GetJWKS(jwksURL string, httpClient HTTPClient) (*JWKSet, error) { + c.mutex.RLock() + if c.jwks != nil && time.Now().Before(c.expiresAt) { + defer c.mutex.RUnlock() + return c.jwks, nil + } + c.mutex.RUnlock() - c.mutex.Lock() - defer c.mutex.Unlock() + c.mutex.Lock() + defer c.mutex.Unlock() - if c.jwks != nil && time.Now().Before(c.expiresAt) { - return c.jwks, nil - } + if c.jwks != nil && time.Now().Before(c.expiresAt) { + return c.jwks, nil + } - jwks, err := fetchJWKS(jwksURL) - if err != nil { - return nil, err - } + jwks, err := fetchJWKS(jwksURL, httpClient) + if err != nil { + return nil, err + } - c.jwks = jwks - c.expiresAt = time.Now().Add(1 * time.Hour) + c.jwks = jwks + c.expiresAt = time.Now().Add(1 * time.Hour) - return jwks, nil + return jwks, nil } -func fetchJWKS(jwksURL string) (*JWKSet, error) { - resp, err := http.Get(jwksURL) - if err != nil { - return nil, err - } - defer resp.Body.Close() +func fetchJWKS(jwksURL string, httpClient HTTPClient) (*JWKSet, error) { + resp, err := httpClient.Get(jwksURL) + if err != nil { + return nil, fmt.Errorf("failed to fetch JWKS: %w", err) + } + defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, errors.New("failed to fetch JWKS") - } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to fetch JWKS: unexpected status code %d", resp.StatusCode) + } - var jwks JWKSet - if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { - return nil, err - } + var jwks JWKSet + if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { + return nil, fmt.Errorf("failed to decode JWKS: %w", err) + } - return &jwks, nil + return &jwks, nil } func verifyNonce(tokenNonce, expectedNonce string) error { - if tokenNonce != expectedNonce { - return errors.New("invalid nonce") - } - return nil + if tokenNonce != expectedNonce { + return fmt.Errorf("invalid nonce") + } + return nil } func verifyAudience(tokenAudience, expectedAudience string) error { - if tokenAudience != expectedAudience { - return errors.New("invalid audience") - } - return nil + if tokenAudience != expectedAudience { + return fmt.Errorf("invalid audience") + } + return nil } func verifyTokenTimes(issuedAt, expiration int64, allowedClockSkew time.Duration) error { - now := time.Now().Unix() - if now < issuedAt-int64(allowedClockSkew.Seconds()) { - return errors.New("token used before issued") - } - if now > expiration+int64(allowedClockSkew.Seconds()) { - return errors.New("token is expired") - } - return nil + now := time.Now().Unix() + if now < issuedAt-int64(allowedClockSkew.Seconds()) { + return fmt.Errorf("token used before issued") + } + if now > expiration+int64(allowedClockSkew.Seconds()) { + return fmt.Errorf("token is expired") + } + return nil } func verifyIssuer(tokenIssuer, expectedIssuer string) error { - if tokenIssuer != expectedIssuer { - return errors.New("invalid issuer") - } - return nil + if tokenIssuer != expectedIssuer { + return fmt.Errorf("invalid issuer") + } + return nil } func validateClaims(claims map[string]interface{}) error { - requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"} - for _, claim := range requiredClaims { - if _, ok := claims[claim]; !ok { - return fmt.Errorf("missing required claim: %s", claim) - } - } - return nil + requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"} + for _, claim := range requiredClaims { + if _, ok := claims[claim]; !ok { + return fmt.Errorf("missing required claim: %s", claim) + } + } + return nil } func jwkToPEM(jwk *JWK) ([]byte, error) { - n, err := base64.RawURLEncoding.DecodeString(jwk.N) - if err != nil { - return nil, err - } - e, err := base64.RawURLEncoding.DecodeString(jwk.E) - if err != nil { - return nil, err - } + n, err := base64.RawURLEncoding.DecodeString(jwk.N) + if err != nil { + return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err) + } + e, err := base64.RawURLEncoding.DecodeString(jwk.E) + if err != nil { + return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err) + } - publicKey := &rsa.PublicKey{ - N: new(big.Int).SetBytes(n), - E: int(new(big.Int).SetBytes(e).Int64()), - } + publicKey := &rsa.PublicKey{ + N: new(big.Int).SetBytes(n), + E: int(new(big.Int).SetBytes(e).Int64()), + } - publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) - if err != nil { - return nil, err - } + publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return nil, fmt.Errorf("failed to marshal public key: %w", err) + } - publicKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PUBLIC KEY", - Bytes: publicKeyBytes, - }) + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: publicKeyBytes, + }) - return publicKeyPEM, nil + return publicKeyPEM, nil } diff --git a/jwt.go b/jwt.go index b26cf4b..2b1d4eb 100644 --- a/jwt.go +++ b/jwt.go @@ -8,122 +8,125 @@ import ( "encoding/base64" "encoding/json" "encoding/pem" - "errors" + "fmt" "strings" "time" ) type JWT struct { - Header map[string]interface{} - Claims map[string]interface{} - Signature string + Header map[string]interface{} + Claims map[string]interface{} + Signature string } func parseJWT(token string) (*JWT, error) { - parts := strings.Split(token, ".") - if len(parts) != 3 { - return nil, errors.New("invalid token format") - } + parts := strings.Split(token, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid token format") + } - header, err := decodeSegment(parts[0]) - if err != nil { - return nil, err - } + header, err := decodeSegment(parts[0]) + if err != nil { + return nil, fmt.Errorf("failed to decode header: %w", err) + } - claims, err := decodeSegment(parts[1]) - if err != nil { - return nil, err - } + claims, err := decodeSegment(parts[1]) + if err != nil { + return nil, fmt.Errorf("failed to decode claims: %w", err) + } - return &JWT{ - Header: header, - Claims: claims, - Signature: parts[2], - }, nil + return &JWT{ + Header: header, + Claims: claims, + Signature: parts[2], + }, nil } func (j *JWT) Verify(issuerURL, clientID string) error { - claims := j.Claims + claims := j.Claims - if err := verifyIssuer(claims["iss"].(string), issuerURL); err != nil { - return err - } + if err := verifyIssuer(claims["iss"].(string), issuerURL); err != nil { + return err + } - if err := verifyAudience(claims["aud"].(string), clientID); err != nil { - return err - } + if err := verifyAudience(claims["aud"].(string), clientID); err != nil { + return err + } - if err := verifyExpiration(claims["exp"].(float64)); err != nil { - return err - } + if err := verifyExpiration(claims["exp"].(float64)); err != nil { + return err + } - if err := verifyIssuedAt(claims["iat"].(float64)); err != nil { - return err - } + if err := verifyIssuedAt(claims["iat"].(float64)); err != nil { + return err + } - return nil + return nil } func verifyExpiration(expiration float64) error { - expirationTime := time.Unix(int64(expiration), 0) - if time.Now().After(expirationTime) { - return errors.New("token has expired") - } - return nil + expirationTime := time.Unix(int64(expiration), 0) + if time.Now().After(expirationTime) { + return fmt.Errorf("token has expired") + } + return nil } func verifySignature(token string, publicKeyPEM []byte) error { - parts := strings.Split(token, ".") - if len(parts) != 3 { - return errors.New("invalid token format") - } + parts := strings.Split(token, ".") + if len(parts) != 3 { + return fmt.Errorf("invalid token format") + } - block, _ := pem.Decode(publicKeyPEM) - if block == nil { - return errors.New("failed to parse PEM block containing the public key") - } + block, _ := pem.Decode(publicKeyPEM) + if block == nil { + return fmt.Errorf("failed to parse PEM block containing the public key") + } - pub, err := x509.ParsePKIXPublicKey(block.Bytes) - if err != nil { - return err - } + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return fmt.Errorf("failed to parse public key: %w", err) + } - rsaPublicKey, ok := pub.(*rsa.PublicKey) - if !ok { - return errors.New("not an RSA public key") - } + rsaPublicKey, ok := pub.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("not an RSA public key") + } - signedContent := parts[0] + "." + parts[1] - signature, _ := base64.RawURLEncoding.DecodeString(parts[2]) + signedContent := parts[0] + "." + parts[1] + signature, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return fmt.Errorf("failed to decode signature: %w", err) + } - hash := sha256.Sum256([]byte(signedContent)) - err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hash[:], signature) - if err != nil { - return errors.New("invalid token signature") - } + hash := sha256.Sum256([]byte(signedContent)) + err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hash[:], signature) + if err != nil { + return fmt.Errorf("invalid token signature: %w", err) + } - return nil + return nil } func verifyIssuedAt(issuedAt float64) error { - issuedAtTime := time.Unix(int64(issuedAt), 0) - if time.Now().Before(issuedAtTime) { - return errors.New("token used before issued") - } - return nil + issuedAtTime := time.Unix(int64(issuedAt), 0) + if time.Now().Before(issuedAtTime) { + return fmt.Errorf("token used before issued") + } + return nil } func decodeSegment(seg string) (map[string]interface{}, error) { - data, err := base64.RawURLEncoding.DecodeString(seg) - if err != nil { - return nil, err - } + data, err := base64.RawURLEncoding.DecodeString(seg) + if err != nil { + return nil, fmt.Errorf("failed to decode segment: %w", err) + } - var result map[string]interface{} - err = json.Unmarshal(data, &result) - if err != nil { - return nil, err - } + var result map[string]interface{} + err = json.Unmarshal(data, &result) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal segment: %w", err) + } - return result, nil + return result, nil } diff --git a/main.go b/main.go index 8327fb3..27d571e 100644 --- a/main.go +++ b/main.go @@ -3,10 +3,7 @@ package traefikoidc import ( "context" "encoding/json" - "errors" "fmt" - "io" - "log" "net/http" "net/url" "strings" @@ -17,14 +14,10 @@ import ( "golang.org/x/time/rate" ) -var ( - infoLogger = log.New(io.Discard, "INFO: traefikoidc: ", log.Ldate|log.Ltime) -) - type TraefikOidc struct { next http.Handler name string - store *sessions.CookieStore + store sessions.Store redirURLPath string issuerURL string jwkCache *JWKCache @@ -39,6 +32,8 @@ type TraefikOidc struct { forceHTTPS bool scheme string tokenCache *TokenCache + httpClient HTTPClient + logger Logger } type ProviderMetadata struct { @@ -58,10 +53,11 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h SameSite: http.SameSiteLaxMode, } - metadata, err := discoverProviderMetadata(config.ProviderURL) + metadata, err := discoverProviderMetadata(config.ProviderURL, &http.Client{}) if err != nil { - return nil, fmt.Errorf("failed to discover provider metadata: %v", err) + return nil, fmt.Errorf("failed to discover provider metadata: %w", err) } + logger := NewLogger(config.LogLevel) t := &TraefikOidc{ next: next, @@ -80,17 +76,19 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h scopes: config.Scopes, limiter: rate.NewLimiter(rate.Every(time.Second), 100), tokenCache: NewTokenCache(), + httpClient: &http.Client{}, + logger: logger, } t.startTokenCleanup() return t, nil } -func discoverProviderMetadata(providerURL string) (*ProviderMetadata, error) { +func discoverProviderMetadata(providerURL string, httpClient HTTPClient) (*ProviderMetadata, error) { wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration" - resp, err := http.Get(wellKnownURL) + resp, err := httpClient.Get(wellKnownURL) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to fetch provider metadata: %w", err) } defer resp.Body.Close() @@ -100,13 +98,47 @@ func discoverProviderMetadata(providerURL string) (*ProviderMetadata, error) { var metadata ProviderMetadata if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { - return nil, err + return nil, fmt.Errorf("failed to decode provider metadata: %w", err) } return &metadata, nil } func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + t.scheme = t.determineScheme(req) + host := t.determineHost(req) + + redirectURL := buildFullURL(t.scheme, host, t.redirURLPath) + t.logger.Infof("Final redirect URL: %s", redirectURL) + + session, err := t.store.Get(req, cookieName) + if err != nil { + t.logger.Errorf("Error getting session: %v", err) + http.Error(rw, "Session error", http.StatusInternalServerError) + return + } + + if req.URL.Path == t.redirURLPath { + t.logger.Infof("Handling callback, URL: %s", req.URL.String()) + authSuccess, originalPath := t.handleCallback(rw, req) + if authSuccess { + http.Redirect(rw, req, originalPath, http.StatusFound) + return + } + http.Error(rw, "Authentication failed", http.StatusUnauthorized) + return + } + + if t.isUserAuthenticated(session) { + t.next.ServeHTTP(rw, req) + return + } + + // User is not authenticated, start the auth process + t.initiateAuthentication(rw, req, session, redirectURL) +} + +func (t *TraefikOidc) determineScheme(req *http.Request) string { scheme := req.URL.Scheme if scheme == "" { scheme = req.Header.Get("X-Forwarded-Proto") @@ -121,8 +153,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if t.forceHTTPS { scheme = "https" } - t.scheme = scheme + return scheme +} +func (t *TraefikOidc) determineHost(req *http.Request) string { host := req.URL.Host if host == "" { host = req.Header.Get("X-Forwarded-Host") @@ -130,74 +164,36 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if host == "" { host = req.Host } + return host +} - // infoLogger.Printf("Scheme: %s, Host: %s, Path: %s", scheme, host, t.redirURLPath) - // infoLogger.Printf("X-Forwarded-Proto: %s", req.Header.Get("X-Forwarded-Proto")) - // infoLogger.Printf("X-Forwarded-Host: %s", req.Header.Get("X-Forwarded-Host")) - redirectURL := assembleRedirectURL(t.scheme, host, t.redirURLPath) - // infoLogger.Printf("Final redirect URL: %s", redirectURL) - - session, err := t.store.Get(req, cookie_name) - if err != nil { - // infoLogger.Printf("Error getting session: %v", err) - http.Error(rw, "Session error: "+err.Error(), http.StatusInternalServerError) - return - } - - if req.URL.Path == t.redirURLPath { - // infoLogger.Printf("Handling callback, URL: %s", req.URL.String()) - authSuccess, originalPath := t.handleCallback(rw, req) - if authSuccess { - http.Redirect(rw, req, originalPath, http.StatusFound) - return - } - // If auth was not successful, return an error instead of re-authenticating - http.Error(rw, "Authentication failed", http.StatusUnauthorized) - return - } - +func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) bool { authenticated, _ := session.Values["authenticated"].(bool) if authenticated { idToken, ok := session.Values["id_token"].(string) if !ok || idToken == "" { - http.Error(rw, "Invalid session", http.StatusUnauthorized) - return + return false } - - if err := t.verifyToken(idToken); err != nil { - http.Error(rw, "Invalid token", http.StatusUnauthorized) - return - } - - // Proceed with the request - t.next.ServeHTTP(rw, req) - return + return t.verifyToken(idToken) == nil } + return false +} - // User is not authenticated, start the auth process +func (t *TraefikOidc) initiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) { csrfToken := uuid.New().String() session.Values["csrf"] = csrfToken session.Values["incoming_path"] = req.URL.Path - // infoLogger.Printf("Setting CSRF token: %s", csrfToken) - err = session.Save(req, rw) - if err != nil { - // infoLogger.Printf("Failed to save session: %v", err) - http.Error(rw, "Failed to save session: "+err.Error(), http.StatusInternalServerError) - return - } + t.logger.Infof("Setting CSRF token: %s", csrfToken) - // Verify the session was saved correctly - verifySession, _ := t.store.Get(req, cookie_name) - savedCSRF, ok := verifySession.Values["csrf"].(string) - if !ok || savedCSRF != csrfToken { - // infoLogger.Printf("Failed to save CSRF token. Saved: %s, Expected: %s", savedCSRF, csrfToken) - http.Error(rw, "Failed to save CSRF token", http.StatusInternalServerError) + if err := session.Save(req, rw); err != nil { + t.logger.Errorf("Failed to save session: %v", err) + http.Error(rw, "Failed to save session", http.StatusInternalServerError) return } nonce, err := generateNonce() if err != nil { - http.Error(rw, "Failed to generate nonce: "+err.Error(), http.StatusInternalServerError) + http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError) return } @@ -205,22 +201,9 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { http.Redirect(rw, req, authURL, http.StatusFound) } -func (t *TraefikOidc) isUserAuthenticated(req *http.Request) bool { - session, err := t.store.Get(req, cookie_name) - if err != nil { - return false - } - - if auth, ok := session.Values["authenticated"].(bool); !ok || !auth { - return false - } - - return true -} - func (t *TraefikOidc) verifyToken(token string) error { if !t.limiter.Allow() { - return errors.New("rate limit exceeded") + return fmt.Errorf("rate limit exceeded") } if _, exists := t.tokenCache.Get(token); exists { @@ -229,17 +212,17 @@ func (t *TraefikOidc) verifyToken(token string) error { jwt, err := parseJWT(token) if err != nil { - return err + return fmt.Errorf("failed to parse JWT: %w", err) } - jwks, err := t.jwkCache.GetJWKS(t.jwksURL) + jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient) if err != nil { - return err + return fmt.Errorf("failed to get JWKS: %w", err) } kid, ok := jwt.Header["kid"].(string) if !ok { - return errors.New("missing key ID in token header") + return fmt.Errorf("missing key ID in token header") } var publicKeyPEM []byte @@ -247,38 +230,22 @@ func (t *TraefikOidc) verifyToken(token string) error { if key.Kid == kid { publicKeyPEM, err = jwkToPEM(&key) if err != nil { - return err + return fmt.Errorf("failed to convert JWK to PEM: %w", err) } break } } if publicKeyPEM == nil { - return errors.New("unable to find matching public key") + return fmt.Errorf("unable to find matching public key") } if err := verifySignature(token, publicKeyPEM); err != nil { - return err - } - - if err := verifyAudience(jwt.Claims["aud"].(string), t.clientID); err != nil { - return err + return fmt.Errorf("signature verification failed: %w", err) } if err := jwt.Verify(t.issuerURL, t.clientID); err != nil { - return err - } - - if err := verifyTokenTimes( - int64(jwt.Claims["iat"].(float64)), - int64(jwt.Claims["exp"].(float64)), - 5*time.Minute, // Allowed clock skew - ); err != nil { - return err - } - - if err := validateClaims(jwt.Claims); err != nil { - return err + return fmt.Errorf("JWT verification failed: %w", err) } expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0) @@ -288,17 +255,16 @@ func (t *TraefikOidc) verifyToken(token string) error { } func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string { - params := url.Values{} - params.Add("client_id", t.clientID) - params.Add("response_type", "code") - params.Add("redirect_uri", redirectURL) - params.Add("scope", strings.Join(t.scopes, " ")) - params.Add("state", state) - params.Add("nonce", nonce) + params := url.Values{ + "client_id": {t.clientID}, + "response_type": {"code"}, + "redirect_uri": {redirectURL}, + "scope": {strings.Join(t.scopes, " ")}, + "state": {state}, + "nonce": {nonce}, + } - authURL := t.authURL + "?" + params.Encode() - // infoLogger.Printf("Built auth URL: %s", authURL) - return authURL + return fmt.Sprintf("%s?%s", t.authURL, params.Encode()) } func (t *TraefikOidc) startTokenCleanup() { @@ -306,6 +272,7 @@ func (t *TraefikOidc) startTokenCleanup() { go func() { for range ticker.C { t.tokenCache.Cleanup() + t.tokenBlacklist.Cleanup() } }() } diff --git a/settings.go b/settings.go index 3c88856..bdb0daf 100644 --- a/settings.go +++ b/settings.go @@ -1,10 +1,13 @@ package traefikoidc -import "os" +import ( + "fmt" + "net/http" + "os" +) -// constants const ( - cookie_name = "_raczylo_oidc" + cookieName = "_raczylo_oidc" ) type Config struct { @@ -19,6 +22,59 @@ type Config struct { } func CreateConfig() *Config { - infoLogger.SetOutput(os.Stdout) - return &Config{} + return &Config{ + Scopes: []string{"openid", "profile", "email"}, + LogLevel: "info", + } +} + +func (c *Config) Validate() error { + if c.ProviderURL == "" { + return fmt.Errorf("providerURL is required") + } + if c.CallbackURL == "" { + return fmt.Errorf("callbackURL is required") + } + if c.ClientID == "" { + return fmt.Errorf("clientID is required") + } + if c.ClientSecret == "" { + return fmt.Errorf("clientSecret is required") + } + if c.SessionEncryptionKey == "" { + return fmt.Errorf("sessionEncryptionKey is required") + } + return nil +} + +type defaultLogger struct { + level string +} + +func NewLogger(level string) Logger { + return &defaultLogger{level: level} +} + +func (l *defaultLogger) Infof(format string, args ...interface{}) { + if l.level == "info" || l.level == "debug" { + fmt.Printf("INFO: "+format+"\n", args...) + } +} + +func (l *defaultLogger) Errorf(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "ERROR: "+format+"\n", args...) +} + +type HTTPClient interface { + Get(url string) (*http.Response, error) + Do(req *http.Request) (*http.Response, error) +} + +type Logger interface { + Infof(format string, args ...interface{}) + Errorf(format string, args ...interface{}) +} + +func handleError(w http.ResponseWriter, message string, code int) { + http.Error(w, message, code) }