diff --git a/jwt.go b/jwt.go index 67e1c6d..338e0be 100644 --- a/jwt.go +++ b/jwt.go @@ -83,11 +83,28 @@ func parseJWT(tokenString string) (*JWT, error) { // It checks: // - issuer (iss) matches the expected issuer URL // - audience (aud) includes the client ID -// - expiration time (exp) is in the future -// - issued at time (iat) is in the past +// - expiration time (exp) is in the future (with clock skew tolerance) +// - issued at time (iat) is in the past (with clock skew tolerance) +// - not before time (nbf) is in the past (with clock skew tolerance) // - subject (sub) is present and not empty +// - algorithm matches expected value to prevent algorithm switching attacks // Returns an error if any validation fails. func (j *JWT) Verify(issuerURL, clientID string) error { + // Validate algorithm to prevent algorithm switching attacks + alg, ok := j.Header["alg"].(string) + if !ok { + return fmt.Errorf("missing 'alg' header") + } + // List of supported algorithms - should match those in verifySignature + supportedAlgs := map[string]bool{ + "RS256": true, "RS384": true, "RS512": true, + "PS256": true, "PS384": true, "PS512": true, + "ES256": true, "ES384": true, "ES512": true, + } + if !supportedAlgs[alg] { + return fmt.Errorf("unsupported algorithm") + } + claims := j.Claims iss, ok := claims["iss"].(string) @@ -122,6 +139,18 @@ func (j *JWT) Verify(issuerURL, clientID string) error { return err } + // Validate nbf (not before) claim if present + if nbf, ok := claims["nbf"].(float64); ok { + if err := verifyNotBefore(nbf); err != nil { + return err + } + } + + // Validate jti (JWT ID) claim if present to prevent replay attacks + if _, ok := claims["jti"].(string); !ok { + return fmt.Errorf("missing 'jti' claim") + } + sub, ok := claims["sub"].(string) if !ok || sub == "" { return fmt.Errorf("missing or empty 'sub' claim") @@ -173,33 +202,48 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error { return nil } +// Clock skew tolerance for time-based validations +const clockSkewTolerance = 2 * time.Minute + // verifyExpiration checks if the token's expiration time has passed. -// The expiration time is compared against the current time. +// The expiration time is compared against the current time with clock skew tolerance. // Parameters: // - expiration: The expiration timestamp from the token // Returns an error if the token has expired. func verifyExpiration(expiration float64) error { expirationTime := time.Unix(int64(expiration), 0) - if time.Now().After(expirationTime) { + if time.Now().Add(clockSkewTolerance).After(expirationTime) { return fmt.Errorf("token has expired") } return nil } // verifyIssuedAt validates the token's issued-at time. -// Ensures the token wasn't issued in the future, which could -// indicate clock skew or a malicious token. +// Ensures the token wasn't issued in the future, accounting for clock skew. // Parameters: // - issuedAt: The issued-at timestamp from the token // Returns an error if the token was issued in the future. func verifyIssuedAt(issuedAt float64) error { issuedAtTime := time.Unix(int64(issuedAt), 0) - if time.Now().Before(issuedAtTime) { + if time.Now().Add(-clockSkewTolerance).Before(issuedAtTime) { return fmt.Errorf("token used before issued") } return nil } +// verifyNotBefore validates the token's not-before time if present. +// Ensures the token is not used before its valid time period, accounting for clock skew. +// Parameters: +// - notBefore: The not-before timestamp from the token +// Returns an error if the token is not yet valid. +func verifyNotBefore(notBefore float64) error { + notBeforeTime := time.Unix(int64(notBefore), 0) + if time.Now().Add(-clockSkewTolerance).Before(notBeforeTime) { + return fmt.Errorf("token not yet valid") + } + return nil +} + // verifySignature validates the token's cryptographic signature. // Supports multiple signature algorithms: // - RSA: RS256, RS384, RS512 (PKCS#1 v1.5) diff --git a/main_test.go b/main_test.go index ea445f2..17b80c4 100644 --- a/main_test.go +++ b/main_test.go @@ -67,21 +67,24 @@ func (ts *TestSuite) Setup() { } // Create a test JWT token signed with the RSA private key + now := time.Now() ts.token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", - "exp": time.Now().Add(1 * time.Hour).Unix(), - "iat": time.Now().Unix(), + "exp": now.Add(1 * time.Hour).Unix(), + "iat": now.Add(-5 * time.Minute).Unix(), // Set issued time in the past to handle clock skew + "nbf": now.Add(-5 * time.Minute).Unix(), // Set not before time in the past "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce", + "jti": generateRandomString(16), // Add JWT ID for replay protection }) if err != nil { ts.t.Fatalf("Failed to create test JWT: %v", err) } logger := NewLogger("info") - ts.sessionManager = NewSessionManager("test-secret-key", false, logger) + ts.sessionManager = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) // Common TraefikOidc instance ts.tOidc = &TraefikOidc{ @@ -611,7 +614,7 @@ func TestHandleCallback(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { logger := NewLogger("info") - sessionManager := NewSessionManager("test-secret-key", false, logger) + sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) // Create a new instance for each test to avoid state carryover tOidc := &TraefikOidc{ @@ -916,7 +919,7 @@ func TestHandleLogout(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { logger := NewLogger("info") - sessionManager := NewSessionManager("test-secret-key", false, logger) + sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) tOidc := &TraefikOidc{ revocationURL: mockRevocationServer.URL, endSessionURL: tc.endSessionURL, @@ -1205,7 +1208,7 @@ func TestHandleExpiredToken(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { logger := NewLogger("info") - sessionManager := NewSessionManager("test-secret-key", false, logger) + sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) tOidc := &TraefikOidc{ sessionManager: sessionManager, @@ -1457,10 +1460,12 @@ func TestServeHTTPRolesAndGroups(t *testing.T) { "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), - "iat": time.Now().Unix(), + "iat": time.Now().Add(-5 * time.Minute).Unix(), + "nbf": time.Now().Add(-5 * time.Minute).Unix(), "sub": "test-subject", "roles": []interface{}{"admin", "user"}, "groups": []interface{}{"group1"}, + "jti": generateRandomString(16), }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) @@ -1481,10 +1486,12 @@ func TestServeHTTPRolesAndGroups(t *testing.T) { "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), - "iat": time.Now().Unix(), + "iat": time.Now().Add(-5 * time.Minute).Unix(), + "nbf": time.Now().Add(-5 * time.Minute).Unix(), "sub": "test-subject", "roles": []interface{}{"user"}, "groups": []interface{}{"allowed-group"}, + "jti": generateRandomString(16), }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) @@ -1506,10 +1513,12 @@ func TestServeHTTPRolesAndGroups(t *testing.T) { "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), - "iat": time.Now().Unix(), + "iat": time.Now().Add(-5 * time.Minute).Unix(), + "nbf": time.Now().Add(-5 * time.Minute).Unix(), "sub": "test-subject", "roles": []interface{}{"user"}, "groups": []interface{}{"regular-group"}, + "jti": generateRandomString(16), }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) @@ -1524,10 +1533,12 @@ func TestServeHTTPRolesAndGroups(t *testing.T) { "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), - "iat": time.Now().Unix(), + "iat": time.Now().Add(-5 * time.Minute).Unix(), + "nbf": time.Now().Add(-5 * time.Minute).Unix(), "sub": "test-subject", "roles": []interface{}{"user"}, "groups": []interface{}{"regular-group"}, + "jti": generateRandomString(16), }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) @@ -1546,8 +1557,10 @@ func TestServeHTTPRolesAndGroups(t *testing.T) { "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), - "iat": time.Now().Unix(), + "iat": time.Now().Add(-5 * time.Minute).Unix(), + "nbf": time.Now().Add(-5 * time.Minute).Unix(), "sub": "test-subject", + "jti": generateRandomString(16), }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) diff --git a/session.go b/session.go index 99524cc..669b555 100644 --- a/session.go +++ b/session.go @@ -3,30 +3,37 @@ package traefikoidc import ( "bytes" "compress/gzip" + "crypto/rand" "encoding/base64" + "encoding/hex" "fmt" "io" "net/http" "strings" "sync" + "time" "github.com/gorilla/sessions" ) +// generateSecureRandomString creates a cryptographically secure random string of specified length +func generateSecureRandomString(length int) string { + bytes := make([]byte, length) + if _, err := rand.Read(bytes); err != nil { + panic("failed to generate random string") + } + return hex.EncodeToString(bytes) +} + // Cookie names and configuration constants used for session management +var ( + // Using random prefixes to make cookie names less predictable + mainCookieName = "_oidc_m_" + generateSecureRandomString(8) + accessTokenCookie = "_oidc_a_" + generateSecureRandomString(8) + refreshTokenCookie = "_oidc_r_" + generateSecureRandomString(8) +) + const ( - // mainCookieName is the name of the main session cookie that stores authentication state - // and basic user information like email and CSRF tokens - mainCookieName = "_raczylo_oidc" - - // accessTokenCookie is the name of the cookie that stores the OIDC access token - // This may be split into multiple cookies if the token is large - accessTokenCookie = "_raczylo_oidc_access" - - // refreshTokenCookie is the name of the cookie that stores the OIDC refresh token - // This may be split into multiple cookies if the token is large - refreshTokenCookie = "_raczylo_oidc_refresh" - // maxCookieSize is the maximum size for each cookie chunk. // This value is calculated to ensure the final cookie size stays within browser limits: // 1. Browser cookie size limit is typically 4096 bytes @@ -39,6 +46,13 @@ const ( // - Solving for x: x ≤ 3044 // 4. We use 2000 as a conservative limit to account for cookie metadata maxCookieSize = 2000 + + // absoluteSessionTimeout defines the maximum lifetime of a session + // regardless of activity (24 hours) + absoluteSessionTimeout = 24 * time.Hour + + // minEncryptionKeyLength defines the minimum length for the encryption key + minEncryptionKeyLength = 32 ) // compressToken compresses a token using gzip and base64 encodes it @@ -99,6 +113,11 @@ type SessionManager struct { // - logger: Logger instance for recording session-related events // The manager handles session creation, storage, and cookie security settings. func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager { + // Validate encryption key length + if len(encryptionKey) < minEncryptionKeyLength { + panic(fmt.Sprintf("encryption key must be at least %d bytes long", minEncryptionKeyLength)) + } + sm := &SessionManager{ store: sessions.NewCookieStore([]byte(encryptionKey)), forceHTTPS: forceHTTPS, @@ -130,7 +149,7 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options { HttpOnly: true, Secure: isSecure || sm.forceHTTPS, SameSite: http.SameSiteLaxMode, - MaxAge: ConstSessionTimeout, + MaxAge: int(absoluteSessionTimeout.Seconds()), Path: "/", } } @@ -151,6 +170,15 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { return nil, fmt.Errorf("failed to get main session: %w", err) } + // Check for absolute session timeout + if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok { + if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout { + sessionData.Clear(r, nil) // Clear expired session + sm.sessionPool.Put(sessionData) + return nil, fmt.Errorf("session expired") + } + } + sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie) if err != nil { sm.sessionPool.Put(sessionData) @@ -294,7 +322,10 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { sd.clearTokenChunks(r, sd.accessTokenChunks) sd.clearTokenChunks(r, sd.refreshTokenChunks) - err := sd.Save(r, w) + var err error + if w != nil { + err = sd.Save(r, w) + } // Return session to pool sd.manager.sessionPool.Put(sd) @@ -315,16 +346,31 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session } // GetAuthenticated returns whether the current session is authenticated. -// Returns true if the user has successfully completed OIDC authentication, -// false otherwise or if the authentication status cannot be determined. +// Returns true if the user has successfully completed OIDC authentication +// and the session hasn't expired, false otherwise. func (sd *SessionData) GetAuthenticated() bool { auth, _ := sd.mainSession.Values["authenticated"].(bool) - return auth + if !auth { + return false + } + + // Check session expiration + createdAt, ok := sd.mainSession.Values["created_at"].(int64) + if !ok { + return false + } + return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout } -// SetAuthenticated updates the session's authentication status. +// SetAuthenticated updates the session's authentication status and rotates session ID. // This should be called after successful OIDC authentication or during logout. +// Session ID rotation helps prevent session fixation attacks. func (sd *SessionData) SetAuthenticated(value bool) { + if value { + // Generate new session ID and set creation time + sd.mainSession.ID = generateSecureRandomString(32) + sd.mainSession.Values["created_at"] = time.Now().Unix() + } sd.mainSession.Values["authenticated"] = value } diff --git a/settings.go b/settings.go index 2ae5209..700b1c7 100644 --- a/settings.go +++ b/settings.go @@ -10,10 +10,6 @@ import ( "strings" ) -const ( - cookieName = "_raczylo_oidc" -) - // Config holds the configuration for the OIDC middleware. // It provides all necessary settings to configure OpenID Connect authentication // with various providers like Auth0, Logto, or any standard OIDC provider. @@ -85,30 +81,34 @@ type Config struct { HTTPClient *http.Client } -// CreateConfig creates a new Config with sensible default values. +const ( + // DefaultRateLimit defines the default rate limit for requests per second + DefaultRateLimit = 100 + + // MinRateLimit defines the minimum allowed rate limit to prevent DOS + MinRateLimit = 10 + + // DefaultLogLevel defines the default logging level + DefaultLogLevel = "info" + + // MinSessionEncryptionKeyLength defines the minimum length for session encryption key + MinSessionEncryptionKeyLength = 32 +) + +// CreateConfig creates a new Config with secure default values. // Default values are set for optional fields: // - Scopes: ["openid", "profile", "email"] // - LogLevel: "info" // - LogoutURL: CallbackURL + "/logout" // - RateLimit: 100 requests per second // - PostLogoutRedirectURI: "/" +// - ForceHTTPS: true (for security) func CreateConfig() *Config { - c := &Config{} - - if c.Scopes == nil { - c.Scopes = []string{"openid", "profile", "email"} - } - - if c.LogLevel == "" { - c.LogLevel = "info" - } - - if c.LogoutURL == "" { - c.LogoutURL = c.CallbackURL + "/logout" - } - - if c.RateLimit == 0 { - c.RateLimit = 100 + c := &Config{ + Scopes: []string{"openid", "profile", "email"}, + LogLevel: DefaultLogLevel, + RateLimit: DefaultRateLimit, + ForceHTTPS: true, // Secure by default } return c @@ -118,43 +118,85 @@ func CreateConfig() *Config { // It ensures all required fields are set and have valid values. // Returns an error if any validation check fails. func (c *Config) Validate() error { + // Validate provider URL if c.ProviderURL == "" { return fmt.Errorf("providerURL is required") } - if !isValidURL(c.ProviderURL) { - return fmt.Errorf("providerURL must be a valid URL") + if !isValidSecureURL(c.ProviderURL) { + return fmt.Errorf("providerURL must be a valid HTTPS URL") } + + // Validate callback URL if c.CallbackURL == "" { return fmt.Errorf("callbackURL is required") } if !strings.HasPrefix(c.CallbackURL, "/") { return fmt.Errorf("callbackURL must start with /") } + + // Validate client credentials if c.ClientID == "" { return fmt.Errorf("clientID is required") } if c.ClientSecret == "" { return fmt.Errorf("clientSecret is required") } + + // Validate session encryption key if c.SessionEncryptionKey == "" { return fmt.Errorf("sessionEncryptionKey is required") } - if len(c.SessionEncryptionKey) < 32 { - return fmt.Errorf("sessionEncryptionKey must be at least 32 characters long") - } - if c.RateLimit < 0 { - return fmt.Errorf("rateLimit must be non-negative") + if len(c.SessionEncryptionKey) < MinSessionEncryptionKeyLength { + return fmt.Errorf("sessionEncryptionKey must be at least %d characters long", MinSessionEncryptionKeyLength) } + + // Validate log level if c.LogLevel != "" && !isValidLogLevel(c.LogLevel) { return fmt.Errorf("logLevel must be one of: debug, info, error") } + + // Validate excluded URLs + for _, url := range c.ExcludedURLs { + if !strings.HasPrefix(url, "/") { + return fmt.Errorf("excluded URL must start with /: %s", url) + } + if strings.Contains(url, "..") { + return fmt.Errorf("excluded URL must not contain path traversal: %s", url) + } + if strings.Contains(url, "*") { + return fmt.Errorf("excluded URL must not contain wildcards: %s", url) + } + } + + // Validate revocation URL if set + if c.RevocationURL != "" && !isValidSecureURL(c.RevocationURL) { + return fmt.Errorf("revocationURL must be a valid HTTPS URL") + } + + // Validate end session URL if set + if c.OIDCEndSessionURL != "" && !isValidSecureURL(c.OIDCEndSessionURL) { + return fmt.Errorf("oidcEndSessionURL must be a valid HTTPS URL") + } + + // Validate post-logout redirect URI if set + if c.PostLogoutRedirectURI != "" && c.PostLogoutRedirectURI != "/" { + if !isValidSecureURL(c.PostLogoutRedirectURI) && !strings.HasPrefix(c.PostLogoutRedirectURI, "/") { + return fmt.Errorf("postLogoutRedirectURI must be either a valid HTTPS URL or start with /") + } + } + + // Validate rate limit + if c.RateLimit < MinRateLimit { + return fmt.Errorf("rateLimit must be at least %d", MinRateLimit) + } + return nil } -// isValidURL checks if the provided string is a valid URL -func isValidURL(s string) bool { +// isValidSecureURL checks if the provided string is a valid HTTPS URL +func isValidSecureURL(s string) bool { u, err := url.Parse(s) - return err == nil && u.Scheme != "" && u.Host != "" + return err == nil && u.Scheme == "https" && u.Host != "" } // isValidLogLevel checks if the provided log level is valid diff --git a/settings_test.go b/settings_test.go index 0662c97..83ae05a 100644 --- a/settings_test.go +++ b/settings_test.go @@ -23,13 +23,18 @@ func TestCreateConfig(t *testing.T) { } // Check default log level - if config.LogLevel != "info" { - t.Errorf("Expected default log level 'info', got '%s'", config.LogLevel) + if config.LogLevel != DefaultLogLevel { + t.Errorf("Expected default log level '%s', got '%s'", DefaultLogLevel, config.LogLevel) } // Check default rate limit - if config.RateLimit != 100 { - t.Errorf("Expected default rate limit 100, got %d", config.RateLimit) + if config.RateLimit != DefaultRateLimit { + t.Errorf("Expected default rate limit %d, got %d", DefaultRateLimit, config.RateLimit) + } + + // Check ForceHTTPS default + if !config.ForceHTTPS { + t.Error("Expected ForceHTTPS to be true by default") } }) @@ -38,6 +43,7 @@ func TestCreateConfig(t *testing.T) { config.Scopes = []string{"custom_scope"} config.LogLevel = "debug" config.RateLimit = 50 + config.ForceHTTPS = false // Verify custom values are not overwritten if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" { @@ -49,6 +55,9 @@ func TestCreateConfig(t *testing.T) { if config.RateLimit != 50 { t.Error("Custom rate limit was overwritten") } + if config.ForceHTTPS { + t.Error("Custom ForceHTTPS value was overwritten") + } }) } @@ -98,15 +107,15 @@ func TestConfigValidate(t *testing.T) { expectedError: "sessionEncryptionKey is required", }, { - name: "Invalid ProviderURL", + name: "Non-HTTPS ProviderURL", config: &Config{ - ProviderURL: "not-a-url", + ProviderURL: "http://provider.com", CallbackURL: "/callback", ClientID: "client-id", ClientSecret: "client-secret", SessionEncryptionKey: "encryption-key", }, - expectedError: "providerURL must be a valid URL", + expectedError: "providerURL must be a valid HTTPS URL", }, { name: "Invalid CallbackURL", @@ -131,16 +140,16 @@ func TestConfigValidate(t *testing.T) { expectedError: "sessionEncryptionKey must be at least 32 characters long", }, { - name: "Negative RateLimit", + name: "Low RateLimit", config: &Config{ ProviderURL: "https://provider.com", CallbackURL: "/callback", ClientID: "client-id", ClientSecret: "client-secret", SessionEncryptionKey: "this-is-a-long-enough-encryption-key", - RateLimit: -1, + RateLimit: 5, }, - expectedError: "rateLimit must be non-negative", + expectedError: "rateLimit must be at least 10", }, { name: "Invalid LogLevel", @@ -154,6 +163,30 @@ func TestConfigValidate(t *testing.T) { }, expectedError: "logLevel must be one of: debug, info, error", }, + { + name: "Non-HTTPS RevocationURL", + config: &Config{ + ProviderURL: "https://provider.com", + CallbackURL: "/callback", + ClientID: "client-id", + ClientSecret: "client-secret", + SessionEncryptionKey: "this-is-a-long-enough-encryption-key", + RevocationURL: "http://revoke.com", + }, + expectedError: "revocationURL must be a valid HTTPS URL", + }, + { + name: "Non-HTTPS OIDCEndSessionURL", + config: &Config{ + ProviderURL: "https://provider.com", + CallbackURL: "/callback", + ClientID: "client-id", + ClientSecret: "client-secret", + SessionEncryptionKey: "this-is-a-long-enough-encryption-key", + OIDCEndSessionURL: "http://endsession.com", + }, + expectedError: "oidcEndSessionURL must be a valid HTTPS URL", + }, { name: "Valid Config", config: &Config{ @@ -164,6 +197,8 @@ func TestConfigValidate(t *testing.T) { SessionEncryptionKey: "this-is-a-long-enough-encryption-key", LogLevel: "debug", RateLimit: 100, + RevocationURL: "https://revoke.com", + OIDCEndSessionURL: "https://endsession.com", }, expectedError: "", }, @@ -192,9 +227,9 @@ func TestLogger(t *testing.T) { var debugBuf, infoBuf, errorBuf bytes.Buffer tests := []struct { - name string - logLevel string - testFunc func(*Logger) + name string + logLevel string + testFunc func(*Logger) checkFunc func(t *testing.T, debugOut, infoOut, errorOut string) }{ { @@ -289,7 +324,7 @@ func TestLogger(t *testing.T) { // Create logger with test buffers logger := NewLogger(tc.logLevel) logger.logError.SetOutput(&errorBuf) - + if tc.logLevel == "debug" || tc.logLevel == "info" { logger.logInfo.SetOutput(&infoBuf) }