mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
ae59a5e88a
* Add ability to disable replay protection. - This is useful for runs with multiple traefik replicas to avoid false positives and tokens re-creation. * Enhance the CI/CD pipelines * Increase test coverage. * Update vendored dependencies. * Update behaviour on forceHTTPS as per issue #82
849 lines
20 KiB
Go
849 lines
20 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// Test Circuit Breaker State Transitions
|
|
|
|
func TestCircuitBreakerStateTransitions(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
failures int
|
|
maxFailures int
|
|
expectedStateBefore string
|
|
expectedStateAfter string
|
|
}{
|
|
{
|
|
name: "stays closed below threshold",
|
|
failures: 1,
|
|
maxFailures: 3,
|
|
expectedStateBefore: "closed",
|
|
expectedStateAfter: "closed",
|
|
},
|
|
{
|
|
name: "opens at threshold",
|
|
failures: 3,
|
|
maxFailures: 3,
|
|
expectedStateBefore: "closed",
|
|
expectedStateAfter: "open",
|
|
},
|
|
{
|
|
name: "opens above threshold",
|
|
failures: 5,
|
|
maxFailures: 3,
|
|
expectedStateBefore: "closed",
|
|
expectedStateAfter: "open",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
|
MaxFailures: tt.maxFailures,
|
|
Timeout: time.Second,
|
|
ResetTimeout: time.Second,
|
|
}, nil)
|
|
|
|
// Verify initial state
|
|
if state := circuitBreakerStateToString(cb.GetState()); state != tt.expectedStateBefore {
|
|
t.Errorf("Expected initial state %s, got %s", tt.expectedStateBefore, state)
|
|
}
|
|
|
|
// Trigger failures
|
|
for i := 0; i < tt.failures; i++ {
|
|
_ = cb.Execute(func() error {
|
|
return errors.New("test failure")
|
|
})
|
|
}
|
|
|
|
// Verify final state
|
|
if state := circuitBreakerStateToString(cb.GetState()); state != tt.expectedStateAfter {
|
|
t.Errorf("Expected final state %s, got %s", tt.expectedStateAfter, state)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCircuitBreakerHalfOpenTransition(t *testing.T) {
|
|
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
|
MaxFailures: 2,
|
|
Timeout: 100 * time.Millisecond,
|
|
ResetTimeout: 50 * time.Millisecond,
|
|
}, nil)
|
|
|
|
// Open the circuit
|
|
_ = cb.Execute(func() error { return errors.New("fail") })
|
|
_ = cb.Execute(func() error { return errors.New("fail") })
|
|
|
|
if cb.GetState() != CircuitBreakerOpen {
|
|
t.Error("Circuit should be open after failures")
|
|
}
|
|
|
|
// Wait for timeout to trigger half-open
|
|
time.Sleep(150 * time.Millisecond)
|
|
|
|
// Next request should be allowed (half-open)
|
|
allowed := false
|
|
_ = cb.Execute(func() error {
|
|
allowed = true
|
|
return nil
|
|
})
|
|
|
|
if !allowed {
|
|
t.Error("Request should be allowed in half-open state")
|
|
}
|
|
|
|
// Successful request should close the circuit
|
|
if cb.GetState() != CircuitBreakerClosed {
|
|
t.Errorf("Circuit should be closed after successful half-open request, got %v", cb.GetState())
|
|
}
|
|
}
|
|
|
|
func TestCircuitBreakerHalfOpenFailure(t *testing.T) {
|
|
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
|
MaxFailures: 2,
|
|
Timeout: 100 * time.Millisecond,
|
|
ResetTimeout: 50 * time.Millisecond,
|
|
}, nil)
|
|
|
|
// Open the circuit
|
|
_ = cb.Execute(func() error { return errors.New("fail") })
|
|
_ = cb.Execute(func() error { return errors.New("fail") })
|
|
|
|
// Wait for half-open
|
|
time.Sleep(150 * time.Millisecond)
|
|
|
|
// Fail in half-open state
|
|
_ = cb.Execute(func() error {
|
|
return errors.New("fail again")
|
|
})
|
|
|
|
// Should return to open state
|
|
if cb.GetState() != CircuitBreakerOpen {
|
|
t.Errorf("Circuit should be open after half-open failure, got %v", cb.GetState())
|
|
}
|
|
}
|
|
|
|
func TestCircuitBreakerConcurrency(t *testing.T) {
|
|
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
|
MaxFailures: 10,
|
|
Timeout: time.Second,
|
|
ResetTimeout: time.Second,
|
|
}, nil)
|
|
|
|
var wg sync.WaitGroup
|
|
successCount := int64(0)
|
|
failureCount := int64(0)
|
|
|
|
// Concurrent successful requests
|
|
for i := 0; i < 100; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
err := cb.Execute(func() error {
|
|
return nil
|
|
})
|
|
if err == nil {
|
|
atomic.AddInt64(&successCount, 1)
|
|
} else {
|
|
atomic.AddInt64(&failureCount, 1)
|
|
}
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
if successCount != 100 {
|
|
t.Errorf("Expected 100 successful requests, got %d", successCount)
|
|
}
|
|
|
|
metrics := cb.GetMetrics()
|
|
if metrics["total_requests"].(int64) != 100 {
|
|
t.Errorf("Expected 100 total requests, got %d", metrics["total_requests"])
|
|
}
|
|
}
|
|
|
|
func TestCircuitBreakerReset(t *testing.T) {
|
|
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
|
MaxFailures: 2,
|
|
Timeout: time.Second,
|
|
ResetTimeout: time.Second,
|
|
}, nil)
|
|
|
|
// Open the circuit
|
|
_ = cb.Execute(func() error { return errors.New("fail") })
|
|
_ = cb.Execute(func() error { return errors.New("fail") })
|
|
|
|
if cb.GetState() != CircuitBreakerOpen {
|
|
t.Error("Circuit should be open")
|
|
}
|
|
|
|
// Reset
|
|
cb.Reset()
|
|
|
|
if cb.GetState() != CircuitBreakerClosed {
|
|
t.Error("Circuit should be closed after reset")
|
|
}
|
|
|
|
// Should allow requests after reset
|
|
err := cb.Execute(func() error {
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
t.Errorf("Should allow requests after reset, got error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestCircuitBreakerMetrics(t *testing.T) {
|
|
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
|
MaxFailures: 3,
|
|
Timeout: time.Second,
|
|
ResetTimeout: time.Second,
|
|
}, nil)
|
|
|
|
// Execute some requests
|
|
_ = cb.Execute(func() error { return nil })
|
|
_ = cb.Execute(func() error { return errors.New("fail") })
|
|
_ = cb.Execute(func() error { return nil })
|
|
|
|
metrics := cb.GetMetrics()
|
|
|
|
if metrics["total_requests"].(int64) != 3 {
|
|
t.Errorf("Expected 3 requests, got %d", metrics["total_requests"])
|
|
}
|
|
|
|
if metrics["total_successes"].(int64) != 2 {
|
|
t.Errorf("Expected 2 successes, got %d", metrics["total_successes"])
|
|
}
|
|
|
|
if metrics["total_failures"].(int64) != 1 {
|
|
t.Errorf("Expected 1 failure, got %d", metrics["total_failures"])
|
|
}
|
|
|
|
if metrics["state"] != "closed" {
|
|
t.Errorf("Expected state 'closed', got %v", metrics["state"])
|
|
}
|
|
}
|
|
|
|
func TestCircuitBreakerIsAvailable(t *testing.T) {
|
|
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
|
MaxFailures: 2,
|
|
Timeout: 100 * time.Millisecond,
|
|
ResetTimeout: 50 * time.Millisecond,
|
|
}, nil)
|
|
|
|
// Should be available initially
|
|
if !cb.IsAvailable() {
|
|
t.Error("Circuit should be available initially")
|
|
}
|
|
|
|
// Open the circuit
|
|
_ = cb.Execute(func() error { return errors.New("fail") })
|
|
_ = cb.Execute(func() error { return errors.New("fail") })
|
|
|
|
// Should not be available when open
|
|
if cb.IsAvailable() {
|
|
t.Error("Circuit should not be available when open")
|
|
}
|
|
|
|
// Wait for timeout
|
|
time.Sleep(150 * time.Millisecond)
|
|
|
|
// Should be available in half-open
|
|
if !cb.IsAvailable() {
|
|
t.Error("Circuit should be available in half-open state")
|
|
}
|
|
}
|
|
|
|
// Test Retry Executor
|
|
|
|
func TestRetryExecutorSuccess(t *testing.T) {
|
|
re := NewRetryExecutor(RetryConfig{
|
|
MaxAttempts: 3,
|
|
InitialDelay: 10 * time.Millisecond,
|
|
MaxDelay: time.Second,
|
|
BackoffFactor: 2.0,
|
|
EnableJitter: false,
|
|
}, nil)
|
|
|
|
attempts := 0
|
|
err := re.ExecuteWithContext(context.Background(), func() error {
|
|
attempts++
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
|
|
if attempts != 1 {
|
|
t.Errorf("Expected 1 attempt for immediate success, got %d", attempts)
|
|
}
|
|
}
|
|
|
|
func TestRetryExecutorEventualSuccess(t *testing.T) {
|
|
re := NewRetryExecutor(RetryConfig{
|
|
MaxAttempts: 3,
|
|
InitialDelay: 10 * time.Millisecond,
|
|
MaxDelay: time.Second,
|
|
BackoffFactor: 2.0,
|
|
EnableJitter: false,
|
|
RetryableErrors: []string{"temporary failure"},
|
|
}, nil)
|
|
|
|
attempts := 0
|
|
err := re.ExecuteWithContext(context.Background(), func() error {
|
|
attempts++
|
|
if attempts < 3 {
|
|
return errors.New("temporary failure")
|
|
}
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
t.Errorf("Expected success after retries, got %v", err)
|
|
}
|
|
|
|
if attempts != 3 {
|
|
t.Errorf("Expected 3 attempts, got %d", attempts)
|
|
}
|
|
}
|
|
|
|
func TestRetryExecutorMaxAttemptsExceeded(t *testing.T) {
|
|
re := NewRetryExecutor(RetryConfig{
|
|
MaxAttempts: 3,
|
|
InitialDelay: 10 * time.Millisecond,
|
|
MaxDelay: time.Second,
|
|
BackoffFactor: 2.0,
|
|
EnableJitter: false,
|
|
RetryableErrors: []string{"temporary failure"},
|
|
}, nil)
|
|
|
|
attempts := 0
|
|
err := re.ExecuteWithContext(context.Background(), func() error {
|
|
attempts++
|
|
return errors.New("temporary failure")
|
|
})
|
|
|
|
if err == nil {
|
|
t.Error("Expected error after max attempts")
|
|
}
|
|
|
|
if attempts != 3 {
|
|
t.Errorf("Expected 3 attempts, got %d", attempts)
|
|
}
|
|
}
|
|
|
|
func TestRetryExecutorNonRetryableError(t *testing.T) {
|
|
re := NewRetryExecutor(RetryConfig{
|
|
MaxAttempts: 3,
|
|
InitialDelay: 10 * time.Millisecond,
|
|
MaxDelay: time.Second,
|
|
BackoffFactor: 2.0,
|
|
EnableJitter: false,
|
|
RetryableErrors: []string{"temporary failure"},
|
|
}, nil)
|
|
|
|
attempts := 0
|
|
err := re.ExecuteWithContext(context.Background(), func() error {
|
|
attempts++
|
|
return errors.New("permanent failure")
|
|
})
|
|
|
|
if err == nil {
|
|
t.Error("Expected error for non-retryable failure")
|
|
}
|
|
|
|
if attempts != 1 {
|
|
t.Errorf("Expected 1 attempt for non-retryable error, got %d", attempts)
|
|
}
|
|
}
|
|
|
|
func TestRetryExecutorContextCancellation(t *testing.T) {
|
|
re := NewRetryExecutor(RetryConfig{
|
|
MaxAttempts: 5,
|
|
InitialDelay: 100 * time.Millisecond,
|
|
MaxDelay: time.Second,
|
|
BackoffFactor: 2.0,
|
|
EnableJitter: false,
|
|
RetryableErrors: []string{"temporary failure"},
|
|
}, nil)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
attempts := 0
|
|
done := make(chan error, 1)
|
|
|
|
go func() {
|
|
done <- re.ExecuteWithContext(ctx, func() error {
|
|
attempts++
|
|
return errors.New("temporary failure")
|
|
})
|
|
}()
|
|
|
|
// Cancel after short delay
|
|
time.Sleep(150 * time.Millisecond)
|
|
cancel()
|
|
|
|
err := <-done
|
|
|
|
if err != context.Canceled {
|
|
t.Errorf("Expected context.Canceled error, got %v", err)
|
|
}
|
|
|
|
if attempts == 0 {
|
|
t.Error("Should have attempted at least once")
|
|
}
|
|
|
|
if attempts >= 5 {
|
|
t.Error("Should not have completed all attempts after cancellation")
|
|
}
|
|
}
|
|
|
|
func TestRetryExecutorExponentialBackoff(t *testing.T) {
|
|
re := NewRetryExecutor(RetryConfig{
|
|
MaxAttempts: 4,
|
|
InitialDelay: 100 * time.Millisecond,
|
|
MaxDelay: time.Second,
|
|
BackoffFactor: 2.0,
|
|
EnableJitter: false,
|
|
RetryableErrors: []string{"temporary failure"},
|
|
}, nil)
|
|
|
|
attempts := 0
|
|
startTime := time.Now()
|
|
|
|
_ = re.ExecuteWithContext(context.Background(), func() error {
|
|
attempts++
|
|
return errors.New("temporary failure")
|
|
})
|
|
|
|
elapsed := time.Since(startTime)
|
|
|
|
// Should have delays: 100ms, 200ms, 400ms = 700ms total (approx)
|
|
if elapsed < 650*time.Millisecond || elapsed > 850*time.Millisecond {
|
|
t.Errorf("Expected ~700ms elapsed with exponential backoff, got %v", elapsed)
|
|
}
|
|
|
|
if attempts != 4 {
|
|
t.Errorf("Expected 4 attempts, got %d", attempts)
|
|
}
|
|
}
|
|
|
|
func TestRetryExecutorWithJitter(t *testing.T) {
|
|
re := NewRetryExecutor(RetryConfig{
|
|
MaxAttempts: 3,
|
|
InitialDelay: 100 * time.Millisecond,
|
|
MaxDelay: time.Second,
|
|
BackoffFactor: 2.0,
|
|
EnableJitter: true,
|
|
RetryableErrors: []string{"temporary failure"},
|
|
}, nil)
|
|
|
|
// Run multiple times to verify jitter adds variability
|
|
durations := make([]time.Duration, 5)
|
|
for i := 0; i < 5; i++ {
|
|
startTime := time.Now()
|
|
_ = re.ExecuteWithContext(context.Background(), func() error {
|
|
return errors.New("temporary failure")
|
|
})
|
|
durations[i] = time.Since(startTime)
|
|
}
|
|
|
|
// Check that not all durations are identical (jitter should add variance)
|
|
allSame := true
|
|
for i := 1; i < len(durations); i++ {
|
|
if durations[i] != durations[0] {
|
|
allSame = false
|
|
break
|
|
}
|
|
}
|
|
|
|
if allSame {
|
|
t.Error("Expected jitter to add variability to retry delays")
|
|
}
|
|
}
|
|
|
|
func TestRetryExecutorNetworkErrors(t *testing.T) {
|
|
re := NewRetryExecutor(RetryConfig{
|
|
MaxAttempts: 3,
|
|
InitialDelay: 10 * time.Millisecond,
|
|
MaxDelay: time.Second,
|
|
BackoffFactor: 2.0,
|
|
EnableJitter: false,
|
|
}, nil)
|
|
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
shouldRetry bool
|
|
}{
|
|
{
|
|
name: "timeout error",
|
|
err: &mockNetError{timeout: true, temporary: true},
|
|
shouldRetry: true,
|
|
},
|
|
{
|
|
name: "temporary network error",
|
|
err: &mockNetError{timeout: false, temporary: true, msg: "temporary failure"},
|
|
shouldRetry: true,
|
|
},
|
|
{
|
|
name: "connection refused",
|
|
err: &mockNetError{timeout: false, temporary: false, msg: "connection refused"},
|
|
shouldRetry: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
attempts := 0
|
|
_ = re.ExecuteWithContext(context.Background(), func() error {
|
|
attempts++
|
|
return tt.err
|
|
})
|
|
|
|
expectedAttempts := 1
|
|
if tt.shouldRetry {
|
|
expectedAttempts = 3
|
|
}
|
|
|
|
if attempts != expectedAttempts {
|
|
t.Errorf("Expected %d attempts, got %d", expectedAttempts, attempts)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRetryExecutorHTTPErrors(t *testing.T) {
|
|
re := NewRetryExecutor(RetryConfig{
|
|
MaxAttempts: 3,
|
|
InitialDelay: 10 * time.Millisecond,
|
|
MaxDelay: time.Second,
|
|
BackoffFactor: 2.0,
|
|
EnableJitter: false,
|
|
}, nil)
|
|
|
|
tests := []struct {
|
|
name string
|
|
statusCode int
|
|
shouldRetry bool
|
|
}{
|
|
{"500 Internal Server Error", 500, true},
|
|
{"502 Bad Gateway", 502, true},
|
|
{"503 Service Unavailable", 503, true},
|
|
{"429 Too Many Requests", 429, true},
|
|
{"400 Bad Request", 400, false},
|
|
{"404 Not Found", 404, false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
attempts := 0
|
|
_ = re.ExecuteWithContext(context.Background(), func() error {
|
|
attempts++
|
|
return &HTTPError{StatusCode: tt.statusCode, Message: "test"}
|
|
})
|
|
|
|
expectedAttempts := 1
|
|
if tt.shouldRetry {
|
|
expectedAttempts = 3
|
|
}
|
|
|
|
if attempts != expectedAttempts {
|
|
t.Errorf("Expected %d attempts, got %d", expectedAttempts, attempts)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRetryExecutorMetrics(t *testing.T) {
|
|
re := NewRetryExecutor(RetryConfig{
|
|
MaxAttempts: 3,
|
|
InitialDelay: 10 * time.Millisecond,
|
|
MaxDelay: time.Second,
|
|
BackoffFactor: 2.0,
|
|
EnableJitter: true,
|
|
}, nil)
|
|
|
|
_ = re.ExecuteWithContext(context.Background(), func() error {
|
|
return nil
|
|
})
|
|
|
|
metrics := re.GetMetrics()
|
|
|
|
if metrics["max_attempts"] != 3 {
|
|
t.Errorf("Expected max_attempts 3, got %v", metrics["max_attempts"])
|
|
}
|
|
|
|
if metrics["backoff_factor"] != 2.0 {
|
|
t.Errorf("Expected backoff_factor 2.0, got %v", metrics["backoff_factor"])
|
|
}
|
|
|
|
if metrics["enable_jitter"] != true {
|
|
t.Errorf("Expected enable_jitter true, got %v", metrics["enable_jitter"])
|
|
}
|
|
}
|
|
|
|
// Test Error Types
|
|
|
|
func TestOIDCErrorCreation(t *testing.T) {
|
|
err := NewOIDCError("invalid_token", "Token is expired", nil)
|
|
|
|
if err.Code != "invalid_token" {
|
|
t.Errorf("Expected code 'invalid_token', got %s", err.Code)
|
|
}
|
|
|
|
if err.Message != "Token is expired" {
|
|
t.Errorf("Expected message 'Token is expired', got %s", err.Message)
|
|
}
|
|
|
|
expectedMsg := "OIDC error [invalid_token]: Token is expired"
|
|
if err.Error() != expectedMsg {
|
|
t.Errorf("Expected error string '%s', got '%s'", expectedMsg, err.Error())
|
|
}
|
|
}
|
|
|
|
func TestOIDCErrorWithCause(t *testing.T) {
|
|
cause := errors.New("underlying error")
|
|
err := NewOIDCError("token_error", "Failed to validate", cause)
|
|
|
|
if err.Unwrap() != cause {
|
|
t.Error("Expected unwrap to return underlying cause")
|
|
}
|
|
|
|
if err.Error() == "" {
|
|
t.Error("Error string should include cause")
|
|
}
|
|
}
|
|
|
|
func TestOIDCErrorWithContext(t *testing.T) {
|
|
err := NewOIDCError("auth_failed", "Authentication failed", nil).
|
|
WithContext("provider", "google").
|
|
WithContext("user_id", "12345")
|
|
|
|
if err.Context["provider"] != "google" {
|
|
t.Errorf("Expected provider 'google', got %v", err.Context["provider"])
|
|
}
|
|
|
|
if err.Context["user_id"] != "12345" {
|
|
t.Errorf("Expected user_id '12345', got %v", err.Context["user_id"])
|
|
}
|
|
}
|
|
|
|
func TestSessionErrorCreation(t *testing.T) {
|
|
err := NewSessionError("save", "Failed to save session", nil)
|
|
|
|
if err.Operation != "save" {
|
|
t.Errorf("Expected operation 'save', got %s", err.Operation)
|
|
}
|
|
|
|
expectedMsg := "Session error in save: Failed to save session"
|
|
if err.Error() != expectedMsg {
|
|
t.Errorf("Expected error string '%s', got '%s'", expectedMsg, err.Error())
|
|
}
|
|
}
|
|
|
|
func TestSessionErrorWithSessionID(t *testing.T) {
|
|
err := NewSessionError("load", "Session not found", nil).
|
|
WithSessionID("sess_12345")
|
|
|
|
if err.SessionID != "sess_12345" {
|
|
t.Errorf("Expected session ID 'sess_12345', got %s", err.SessionID)
|
|
}
|
|
}
|
|
|
|
func TestTokenErrorCreation(t *testing.T) {
|
|
err := NewTokenError("id_token", "expired", "Token has expired", nil)
|
|
|
|
if err.TokenType != "id_token" {
|
|
t.Errorf("Expected token type 'id_token', got %s", err.TokenType)
|
|
}
|
|
|
|
if err.Reason != "expired" {
|
|
t.Errorf("Expected reason 'expired', got %s", err.Reason)
|
|
}
|
|
|
|
expectedMsg := "Token error (id_token) - expired: Token has expired"
|
|
if err.Error() != expectedMsg {
|
|
t.Errorf("Expected error string '%s', got '%s'", expectedMsg, err.Error())
|
|
}
|
|
}
|
|
|
|
// Test Base Recovery Mechanism
|
|
|
|
func TestBaseRecoveryMechanismMetrics(t *testing.T) {
|
|
base := NewBaseRecoveryMechanism("test-mechanism", nil)
|
|
|
|
base.RecordRequest()
|
|
base.RecordSuccess()
|
|
base.RecordRequest()
|
|
base.RecordFailure()
|
|
|
|
metrics := base.GetBaseMetrics()
|
|
|
|
if metrics["total_requests"].(int64) != 2 {
|
|
t.Errorf("Expected 2 requests, got %d", metrics["total_requests"])
|
|
}
|
|
|
|
if metrics["total_successes"].(int64) != 1 {
|
|
t.Errorf("Expected 1 success, got %d", metrics["total_successes"])
|
|
}
|
|
|
|
if metrics["total_failures"].(int64) != 1 {
|
|
t.Errorf("Expected 1 failure, got %d", metrics["total_failures"])
|
|
}
|
|
|
|
if metrics["success_rate"].(float64) != 0.5 {
|
|
t.Errorf("Expected success rate 0.5, got %v", metrics["success_rate"])
|
|
}
|
|
}
|
|
|
|
func TestBaseRecoveryMechanismConcurrentUpdates(t *testing.T) {
|
|
base := NewBaseRecoveryMechanism("concurrent-test", nil)
|
|
|
|
var wg sync.WaitGroup
|
|
iterations := 1000
|
|
|
|
// Concurrent requests
|
|
for i := 0; i < iterations; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
base.RecordRequest()
|
|
if i%2 == 0 {
|
|
base.RecordSuccess()
|
|
} else {
|
|
base.RecordFailure()
|
|
}
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
metrics := base.GetBaseMetrics()
|
|
|
|
if metrics["total_requests"].(int64) != int64(iterations) {
|
|
t.Errorf("Expected %d requests, got %d", iterations, metrics["total_requests"])
|
|
}
|
|
|
|
totalSuccessesAndFailures := metrics["total_successes"].(int64) + metrics["total_failures"].(int64)
|
|
if totalSuccessesAndFailures != int64(iterations) {
|
|
t.Errorf("Expected %d total successes+failures, got %d", iterations, totalSuccessesAndFailures)
|
|
}
|
|
}
|
|
|
|
// Test Error Recovery Manager
|
|
|
|
func TestErrorRecoveryManagerCreation(t *testing.T) {
|
|
erm := NewErrorRecoveryManager(nil)
|
|
|
|
if erm == nil {
|
|
t.Fatal("Expected non-nil error recovery manager")
|
|
}
|
|
|
|
if erm.retryExecutor == nil {
|
|
t.Error("Expected retry executor to be initialized")
|
|
}
|
|
|
|
if erm.gracefulDegradation == nil {
|
|
t.Error("Expected graceful degradation to be initialized")
|
|
}
|
|
}
|
|
|
|
func TestErrorRecoveryManagerGetCircuitBreaker(t *testing.T) {
|
|
erm := NewErrorRecoveryManager(nil)
|
|
|
|
cb1 := erm.GetCircuitBreaker("service1")
|
|
cb2 := erm.GetCircuitBreaker("service1")
|
|
cb3 := erm.GetCircuitBreaker("service2")
|
|
|
|
if cb1 == nil || cb2 == nil || cb3 == nil {
|
|
t.Fatal("Expected non-nil circuit breakers")
|
|
}
|
|
|
|
// Should return same instance for same service
|
|
if cb1 != cb2 {
|
|
t.Error("Expected same circuit breaker instance for same service")
|
|
}
|
|
|
|
// Should return different instances for different services
|
|
if cb1 == cb3 {
|
|
t.Error("Expected different circuit breaker instances for different services")
|
|
}
|
|
}
|
|
|
|
func TestErrorRecoveryManagerExecuteWithRecovery(t *testing.T) {
|
|
erm := NewErrorRecoveryManager(nil)
|
|
|
|
success := false
|
|
err := erm.ExecuteWithRecovery(context.Background(), "test-service", func() error {
|
|
success = true
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
|
|
if !success {
|
|
t.Error("Expected function to execute")
|
|
}
|
|
}
|
|
|
|
func TestErrorRecoveryManagerMetrics(t *testing.T) {
|
|
erm := NewErrorRecoveryManager(nil)
|
|
|
|
// Create some circuit breakers
|
|
_ = erm.GetCircuitBreaker("service1")
|
|
_ = erm.GetCircuitBreaker("service2")
|
|
|
|
metrics := erm.GetRecoveryMetrics()
|
|
|
|
cbMetrics, ok := metrics["circuit_breakers"].(map[string]interface{})
|
|
if !ok {
|
|
t.Fatal("Expected circuit_breakers in metrics")
|
|
}
|
|
|
|
if len(cbMetrics) != 2 {
|
|
t.Errorf("Expected 2 circuit breakers in metrics, got %d", len(cbMetrics))
|
|
}
|
|
}
|
|
|
|
// Helper functions and types
|
|
|
|
func circuitBreakerStateToString(state CircuitBreakerState) string {
|
|
switch state {
|
|
case CircuitBreakerClosed:
|
|
return "closed"
|
|
case CircuitBreakerOpen:
|
|
return "open"
|
|
case CircuitBreakerHalfOpen:
|
|
return "half-open"
|
|
default:
|
|
return "unknown"
|
|
}
|
|
}
|
|
|
|
// Mock network error for testing
|
|
type mockNetError struct {
|
|
timeout bool
|
|
temporary bool
|
|
msg string
|
|
}
|
|
|
|
func (e *mockNetError) Error() string { return e.msg }
|
|
func (e *mockNetError) Timeout() bool { return e.timeout }
|
|
func (e *mockNetError) Temporary() bool { return e.temporary }
|
|
|
|
// Ensure mockNetError implements net.Error
|
|
var _ net.Error = (*mockNetError)(nil)
|