mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
17dea67229
JWT Token Security: Protected against algorithm-switching attacks by validating and whitelisting algorithms (RS256, RS384, RS512, PS256, PS384, PS512, ES256, ES384, ES512). Added 2-minute clock-skew tolerance for time-based validations. Added "not before" (nbf) claim validation with clock-skew tolerance. Required the JWT ID (jti) claim to prevent replay attacks. Added strict algorithm validation to prevent downgrade attacks. Session Management Security: Implemented cryptographically secure random cookie names to prevent targeting. Added automatic session ID rotation after successful login to prevent session fixation. Enforced a 24-hour absolute session timeout. Added strict encryption key length validation (minimum 32 bytes). Added comprehensive session validation, including timeout checks. Implemented session pooling for secure resource management. Added secure session cleanup on expiration. Configuration and URL Security: Enforced HTTPS for all provider URLs and external endpoints. Added a minimum rate limit (10 req/sec) to prevent DoS attacks. Added strict validation for excluded URLs: they must start with "/", must not contain path traversal (..), and must not contain wildcards (*). Made ForceHTTPS true by default for secure cookies. Added validation for secure redirect URIs. Added validation for all OIDC endpoints (must be HTTPS). Added secure defaults in configuration. Test Coverage: Added comprehensive test cases verifying all security validations. Added test cases for HTTPS enforcement on all endpoints. Added test cases for minimum rate limits. Added test cases for secure session management. Added test cases for token validation with clock skew. Added test cases for secure configuration defaults. All security improvements have been verified through passing test cases, protecting against: session fixation attacks, token replay attacks, algorithm-switching attacks, path traversal attacks, session hijacking, timing attacks, DoS attacks, and man-in-the-middle attacks (through enforced HTTPS).
398 lines
10 KiB
Go
398 lines
10 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"bytes"
|
|
"log"
|
|
"net/http"
|
|
"testing"
|
|
)
|
|
|
|
func TestCreateConfig(t *testing.T) {
|
|
t.Run("Default Values", func(t *testing.T) {
|
|
config := CreateConfig()
|
|
|
|
// Check default scopes
|
|
expectedScopes := []string{"openid", "profile", "email"}
|
|
if len(config.Scopes) != len(expectedScopes) {
|
|
t.Errorf("Expected %d default scopes, got %d", len(expectedScopes), len(config.Scopes))
|
|
}
|
|
for i, scope := range expectedScopes {
|
|
if config.Scopes[i] != scope {
|
|
t.Errorf("Expected scope %s at position %d, got %s", scope, i, config.Scopes[i])
|
|
}
|
|
}
|
|
|
|
// Check default log level
|
|
if config.LogLevel != DefaultLogLevel {
|
|
t.Errorf("Expected default log level '%s', got '%s'", DefaultLogLevel, config.LogLevel)
|
|
}
|
|
|
|
// Check default rate limit
|
|
if config.RateLimit != DefaultRateLimit {
|
|
t.Errorf("Expected default rate limit %d, got %d", DefaultRateLimit, config.RateLimit)
|
|
}
|
|
|
|
// Check ForceHTTPS default
|
|
if !config.ForceHTTPS {
|
|
t.Error("Expected ForceHTTPS to be true by default")
|
|
}
|
|
})
|
|
|
|
t.Run("Custom Values Preserved", func(t *testing.T) {
|
|
config := CreateConfig()
|
|
config.Scopes = []string{"custom_scope"}
|
|
config.LogLevel = "debug"
|
|
config.RateLimit = 50
|
|
config.ForceHTTPS = false
|
|
|
|
// Verify custom values are not overwritten
|
|
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
|
|
t.Error("Custom scopes were overwritten")
|
|
}
|
|
if config.LogLevel != "debug" {
|
|
t.Error("Custom log level was overwritten")
|
|
}
|
|
if config.RateLimit != 50 {
|
|
t.Error("Custom rate limit was overwritten")
|
|
}
|
|
if config.ForceHTTPS {
|
|
t.Error("Custom ForceHTTPS value was overwritten")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestConfigValidate(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
config *Config
|
|
expectedError string
|
|
}{
|
|
{
|
|
name: "Empty Config",
|
|
config: &Config{},
|
|
expectedError: "providerURL is required",
|
|
},
|
|
{
|
|
name: "Missing CallbackURL",
|
|
config: &Config{
|
|
ProviderURL: "https://provider.com",
|
|
},
|
|
expectedError: "callbackURL is required",
|
|
},
|
|
{
|
|
name: "Missing ClientID",
|
|
config: &Config{
|
|
ProviderURL: "https://provider.com",
|
|
CallbackURL: "/callback",
|
|
},
|
|
expectedError: "clientID is required",
|
|
},
|
|
{
|
|
name: "Missing ClientSecret",
|
|
config: &Config{
|
|
ProviderURL: "https://provider.com",
|
|
CallbackURL: "/callback",
|
|
ClientID: "client-id",
|
|
},
|
|
expectedError: "clientSecret is required",
|
|
},
|
|
{
|
|
name: "Missing SessionEncryptionKey",
|
|
config: &Config{
|
|
ProviderURL: "https://provider.com",
|
|
CallbackURL: "/callback",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
},
|
|
expectedError: "sessionEncryptionKey is required",
|
|
},
|
|
{
|
|
name: "Non-HTTPS ProviderURL",
|
|
config: &Config{
|
|
ProviderURL: "http://provider.com",
|
|
CallbackURL: "/callback",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
SessionEncryptionKey: "encryption-key",
|
|
},
|
|
expectedError: "providerURL must be a valid HTTPS URL",
|
|
},
|
|
{
|
|
name: "Invalid CallbackURL",
|
|
config: &Config{
|
|
ProviderURL: "https://provider.com",
|
|
CallbackURL: "callback", // Missing leading slash
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
SessionEncryptionKey: "encryption-key",
|
|
},
|
|
expectedError: "callbackURL must start with /",
|
|
},
|
|
{
|
|
name: "Short SessionEncryptionKey",
|
|
config: &Config{
|
|
ProviderURL: "https://provider.com",
|
|
CallbackURL: "/callback",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
SessionEncryptionKey: "short",
|
|
},
|
|
expectedError: "sessionEncryptionKey must be at least 32 characters long",
|
|
},
|
|
{
|
|
name: "Low RateLimit",
|
|
config: &Config{
|
|
ProviderURL: "https://provider.com",
|
|
CallbackURL: "/callback",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
|
RateLimit: 5,
|
|
},
|
|
expectedError: "rateLimit must be at least 10",
|
|
},
|
|
{
|
|
name: "Invalid LogLevel",
|
|
config: &Config{
|
|
ProviderURL: "https://provider.com",
|
|
CallbackURL: "/callback",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
|
LogLevel: "invalid",
|
|
},
|
|
expectedError: "logLevel must be one of: debug, info, error",
|
|
},
|
|
{
|
|
name: "Non-HTTPS RevocationURL",
|
|
config: &Config{
|
|
ProviderURL: "https://provider.com",
|
|
CallbackURL: "/callback",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
|
RevocationURL: "http://revoke.com",
|
|
},
|
|
expectedError: "revocationURL must be a valid HTTPS URL",
|
|
},
|
|
{
|
|
name: "Non-HTTPS OIDCEndSessionURL",
|
|
config: &Config{
|
|
ProviderURL: "https://provider.com",
|
|
CallbackURL: "/callback",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
|
OIDCEndSessionURL: "http://endsession.com",
|
|
},
|
|
expectedError: "oidcEndSessionURL must be a valid HTTPS URL",
|
|
},
|
|
{
|
|
name: "Valid Config",
|
|
config: &Config{
|
|
ProviderURL: "https://provider.com",
|
|
CallbackURL: "/callback",
|
|
ClientID: "client-id",
|
|
ClientSecret: "client-secret",
|
|
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
|
LogLevel: "debug",
|
|
RateLimit: 100,
|
|
RevocationURL: "https://revoke.com",
|
|
OIDCEndSessionURL: "https://endsession.com",
|
|
},
|
|
expectedError: "",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := tc.config.Validate()
|
|
if tc.expectedError == "" {
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got: %v", err)
|
|
}
|
|
} else {
|
|
if err == nil {
|
|
t.Errorf("Expected error containing '%s', got nil", tc.expectedError)
|
|
} else if err.Error() != tc.expectedError {
|
|
t.Errorf("Expected error '%s', got '%s'", tc.expectedError, err.Error())
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestLogger(t *testing.T) {
|
|
// Capture log output
|
|
var debugBuf, infoBuf, errorBuf bytes.Buffer
|
|
|
|
tests := []struct {
|
|
name string
|
|
logLevel string
|
|
testFunc func(*Logger)
|
|
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
|
|
}{
|
|
{
|
|
name: "Debug Level",
|
|
logLevel: "debug",
|
|
testFunc: func(l *Logger) {
|
|
l.Debug("debug message")
|
|
l.Info("info message")
|
|
l.Error("error message")
|
|
},
|
|
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
|
if debugOut == "" {
|
|
t.Error("Expected debug message in output")
|
|
}
|
|
if infoOut == "" {
|
|
t.Error("Expected info message in output")
|
|
}
|
|
if errorOut == "" {
|
|
t.Error("Expected error message in output")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "Info Level",
|
|
logLevel: "info",
|
|
testFunc: func(l *Logger) {
|
|
l.Debug("debug message")
|
|
l.Info("info message")
|
|
l.Error("error message")
|
|
},
|
|
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
|
if debugOut != "" {
|
|
t.Error("Did not expect debug message in output")
|
|
}
|
|
if infoOut == "" {
|
|
t.Error("Expected info message in output")
|
|
}
|
|
if errorOut == "" {
|
|
t.Error("Expected error message in output")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "Error Level",
|
|
logLevel: "error",
|
|
testFunc: func(l *Logger) {
|
|
l.Debug("debug message")
|
|
l.Info("info message")
|
|
l.Error("error message")
|
|
},
|
|
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
|
if debugOut != "" {
|
|
t.Error("Did not expect debug message in output")
|
|
}
|
|
if infoOut != "" {
|
|
t.Error("Did not expect info message in output")
|
|
}
|
|
if errorOut == "" {
|
|
t.Error("Expected error message in output")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "Printf Methods",
|
|
logLevel: "debug",
|
|
testFunc: func(l *Logger) {
|
|
l.Debugf("debug %s", "formatted")
|
|
l.Infof("info %s", "formatted")
|
|
l.Errorf("error %s", "formatted")
|
|
},
|
|
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
|
if !bytes.Contains([]byte(debugOut), []byte("debug formatted")) {
|
|
t.Error("Expected formatted debug message")
|
|
}
|
|
if !bytes.Contains([]byte(infoOut), []byte("info formatted")) {
|
|
t.Error("Expected formatted info message")
|
|
}
|
|
if !bytes.Contains([]byte(errorOut), []byte("error formatted")) {
|
|
t.Error("Expected formatted error message")
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Reset buffers
|
|
debugBuf.Reset()
|
|
infoBuf.Reset()
|
|
errorBuf.Reset()
|
|
|
|
// Create logger with test buffers
|
|
logger := NewLogger(tc.logLevel)
|
|
logger.logError.SetOutput(&errorBuf)
|
|
|
|
if tc.logLevel == "debug" || tc.logLevel == "info" {
|
|
logger.logInfo.SetOutput(&infoBuf)
|
|
}
|
|
if tc.logLevel == "debug" {
|
|
logger.logDebug.SetOutput(&debugBuf)
|
|
}
|
|
|
|
// Run test
|
|
tc.testFunc(logger)
|
|
|
|
// Check results
|
|
tc.checkFunc(t, debugBuf.String(), infoBuf.String(), errorBuf.String())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleError(t *testing.T) {
|
|
// Create a test logger with captured output
|
|
var errorBuf bytes.Buffer
|
|
logger := &Logger{
|
|
logError: log.New(&errorBuf, "ERROR: ", log.Ldate|log.Ltime),
|
|
}
|
|
logger.logError.SetOutput(&errorBuf)
|
|
|
|
// Create a test response recorder
|
|
rr := &testResponseRecorder{
|
|
headers: make(map[string][]string),
|
|
}
|
|
|
|
// Test error handling
|
|
message := "test error message"
|
|
code := 400
|
|
handleError(rr, message, code, logger)
|
|
|
|
// Check response code
|
|
if rr.statusCode != code {
|
|
t.Errorf("Expected status code %d, got %d", code, rr.statusCode)
|
|
}
|
|
|
|
// Check response body
|
|
expectedBody := message + "\n"
|
|
if rr.body != expectedBody {
|
|
t.Errorf("Expected body %q, got %q", expectedBody, rr.body)
|
|
}
|
|
|
|
// Check error was logged
|
|
if !bytes.Contains(errorBuf.Bytes(), []byte(message)) {
|
|
t.Error("Error message was not logged")
|
|
}
|
|
}
|
|
|
|
// Test helper types

// testResponseRecorder is a minimal http.ResponseWriter that records what a
// handler sends, so tests can inspect it afterwards.
type testResponseRecorder struct {
	statusCode int                 // last code passed to WriteHeader, or 200 if Write ran first
	body       string              // concatenation of every Write call
	headers    map[string][]string // backing map handed out by Header
}

// Header returns the recorder's header map. It allocates the map on first use,
// so a zero-value recorder can have headers set without panicking.
func (r *testResponseRecorder) Header() http.Header {
	if r.headers == nil {
		r.headers = make(map[string][]string)
	}
	return r.headers
}

// Write appends b to the recorded body rather than replacing it, so responses
// written in several chunks are captured in full. As with net/http, writing
// before any WriteHeader call implies a 200 OK status.
func (r *testResponseRecorder) Write(b []byte) (int, error) {
	if r.statusCode == 0 {
		r.statusCode = http.StatusOK
	}
	r.body += string(b)
	return len(b), nil
}

// WriteHeader records code as the response status. Unlike net/http, a later
// call overwrites an earlier one instead of being ignored.
func (r *testResponseRecorder) WriteHeader(code int) {
	r.statusCode = code
}
|