mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Add tests
This commit is contained in:
+6
-2
@@ -78,6 +78,9 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusForbidden)
|
||||
rw.Write([]byte("Logged out"))
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) {
|
||||
@@ -97,15 +100,16 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
|
||||
|
||||
code := req.URL.Query().Get("code")
|
||||
redirectURL := buildFullURL(t.scheme, req.Host, t.redirURLPath)
|
||||
|
||||
oauth2Token, err := t.exchangeCodeForToken(ctx, code, redirectURL)
|
||||
if err != nil {
|
||||
handleError(rw, "Failed to exchange token", http.StatusInternalServerError, t.logger)
|
||||
handleError(rw, "Failed to exchange token", http.StatusUnauthorized, t.logger)
|
||||
return false, ""
|
||||
}
|
||||
|
||||
rawIDToken, ok := oauth2Token["id_token"].(string)
|
||||
if !ok {
|
||||
handleError(rw, "No id_token field in oauth2 token", http.StatusInternalServerError, t.logger)
|
||||
handleError(rw, "No id_token field in oauth2 token", http.StatusUnauthorized, t.logger)
|
||||
return false, ""
|
||||
}
|
||||
|
||||
|
||||
@@ -132,57 +132,11 @@ func decodeSegment(seg string) (map[string]interface{}, error) {
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) verifyAndCacheToken(token string) error {
|
||||
t.logger.Debugf("Verifying token")
|
||||
if !t.limiter.Allow() {
|
||||
return fmt.Errorf("rate limit exceeded")
|
||||
}
|
||||
|
||||
if t.tokenBlacklist.IsBlacklisted(token) {
|
||||
return fmt.Errorf("token is blacklisted")
|
||||
}
|
||||
|
||||
if _, exists := t.tokenCache.Get(token); exists {
|
||||
t.logger.Debugf("Token is valid and cached")
|
||||
return nil // Token is valid and cached
|
||||
}
|
||||
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
|
||||
if err := t.verifyJWTSignatureAndClaims(jwt, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
|
||||
t.tokenCache.Set(token, expirationTime)
|
||||
|
||||
return nil
|
||||
return t.tokenVerifier.VerifyToken(token)
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) verifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
t.logger.Debugf("Verifying JWT signature and claims")
|
||||
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
|
||||
kid, ok := jwt.Header["kid"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing key ID in token header")
|
||||
}
|
||||
|
||||
publicKeyPEM, err := getPublicKeyPEM(jwks, kid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := verifySignature(token, publicKeyPEM); err != nil {
|
||||
return fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
return jwt.Verify(t.issuerURL, t.clientID)
|
||||
return t.jwtVerifier.VerifyJWTSignatureAndClaims(jwt, token)
|
||||
}
|
||||
|
||||
func getPublicKeyPEM(jwks *JWKSet, kid string) ([]byte, error) {
|
||||
|
||||
@@ -14,6 +14,14 @@ import (
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type TokenVerifier interface {
|
||||
VerifyToken(token string) error
|
||||
}
|
||||
|
||||
type JWTVerifier interface {
|
||||
VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
}
|
||||
|
||||
type TraefikOidc struct {
|
||||
next http.Handler
|
||||
name string
|
||||
@@ -36,6 +44,8 @@ type TraefikOidc struct {
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
redirectURL string
|
||||
tokenVerifier TokenVerifier
|
||||
jwtVerifier JWTVerifier
|
||||
}
|
||||
|
||||
type ProviderMetadata struct {
|
||||
@@ -45,6 +55,59 @@ type ProviderMetadata struct {
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
t.logger.Debugf("Verifying token")
|
||||
if !t.limiter.Allow() {
|
||||
return fmt.Errorf("rate limit exceeded")
|
||||
}
|
||||
|
||||
if t.tokenBlacklist.IsBlacklisted(token) {
|
||||
return fmt.Errorf("token is blacklisted")
|
||||
}
|
||||
|
||||
if _, exists := t.tokenCache.Get(token); exists {
|
||||
t.logger.Debugf("Token is valid and cached")
|
||||
return nil // Token is valid and cached
|
||||
}
|
||||
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
|
||||
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
|
||||
t.tokenCache.Set(token, expirationTime)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
|
||||
kid, ok := jwt.Header["kid"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing key ID in token header")
|
||||
}
|
||||
|
||||
publicKeyPEM, err := getPublicKeyPEM(jwks, kid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := verifySignature(token, publicKeyPEM); err != nil {
|
||||
return fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
return jwt.Verify(t.issuerURL, t.clientID)
|
||||
}
|
||||
|
||||
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
||||
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
|
||||
store.Options = &sessions.Options{
|
||||
@@ -82,6 +145,9 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
logger: NewLogger(config.LogLevel),
|
||||
redirectURL: "",
|
||||
}
|
||||
|
||||
t.tokenVerifier = t
|
||||
t.jwtVerifier = t
|
||||
t.startTokenCleanup()
|
||||
return t, nil
|
||||
}
|
||||
@@ -92,6 +158,9 @@ func discoverProviderMetadata(providerURL string, httpClient http.Client) (*Prov
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("received nil response from provider")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
@@ -150,32 +219,23 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
||||
scheme := req.URL.Scheme
|
||||
if scheme == "" {
|
||||
scheme = req.Header.Get("X-Forwarded-Proto")
|
||||
}
|
||||
if scheme == "" {
|
||||
if req.TLS != nil {
|
||||
scheme = "https"
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
if t.forceHTTPS {
|
||||
scheme = "https"
|
||||
return "https"
|
||||
}
|
||||
return scheme
|
||||
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
||||
return scheme
|
||||
}
|
||||
if req.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) determineHost(req *http.Request) string {
|
||||
host := req.URL.Host
|
||||
if host == "" {
|
||||
host = req.Header.Get("X-Forwarded-Host")
|
||||
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
|
||||
return host
|
||||
}
|
||||
if host == "" {
|
||||
host = req.Host
|
||||
}
|
||||
return host
|
||||
return req.Host
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) bool {
|
||||
@@ -232,7 +292,7 @@ func (t *TraefikOidc) initiateAuthentication(rw http.ResponseWriter, req *http.R
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) verifyToken(token string) error {
|
||||
return t.verifyAndCacheToken(token)
|
||||
return t.tokenVerifier.VerifyToken(token)
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
||||
|
||||
+586
@@ -0,0 +1,586 @@
|
||||
// main_test.go
|
||||
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type MockHTTPClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
args := m.Called(req)
|
||||
return args.Get(0).(*http.Response), args.Error(1)
|
||||
}
|
||||
|
||||
type MockSessionStore struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSessionStore) Get(r *http.Request, name string) (*sessions.Session, error) {
|
||||
args := m.Called(r, name)
|
||||
if session, ok := args.Get(0).(*sessions.Session); ok {
|
||||
return session, args.Error(1)
|
||||
}
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionStore) New(r *http.Request, name string) (*sessions.Session, error) {
|
||||
args := m.Called(r, name)
|
||||
return args.Get(0).(*sessions.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionStore) Save(r *http.Request, w http.ResponseWriter, s *sessions.Session) error {
|
||||
args := m.Called(r, w, s)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockTokenVerifier struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockTokenVerifier) VerifyToken(token string) error {
|
||||
args := m.Called(token)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockJWTVerifier struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockJWTVerifier) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
args := m.Called(jwt, token)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type TraefikOidcTestSuite struct {
|
||||
suite.Suite
|
||||
oidc *TraefikOidc
|
||||
mockHTTPClient *MockHTTPClient
|
||||
mockStore *MockSessionStore
|
||||
mockTokenVerifier *MockTokenVerifier
|
||||
mockJWTVerifier *MockJWTVerifier
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) SetupTest() {
|
||||
suite.mockHTTPClient = new(MockHTTPClient)
|
||||
suite.mockStore = new(MockSessionStore)
|
||||
suite.mockTokenVerifier = new(MockTokenVerifier)
|
||||
suite.mockJWTVerifier = new(MockJWTVerifier)
|
||||
|
||||
config := &Config{
|
||||
ProviderURL: "https://example.com",
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
CallbackURL: "/callback",
|
||||
LogoutURL: "/logout",
|
||||
SessionEncryptionKey: "test-encryption-key",
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
}
|
||||
|
||||
suite.oidc = &TraefikOidc{
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
redirURLPath: config.CallbackURL,
|
||||
logoutURLPath: config.LogoutURL,
|
||||
store: suite.mockStore,
|
||||
httpClient: &http.Client{Transport: suite.mockHTTPClient},
|
||||
jwkCache: &JWKCache{},
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
tokenCache: NewTokenCache(),
|
||||
logger: NewLogger("debug"),
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 100),
|
||||
authURL: "https://example.com/auth",
|
||||
tokenURL: "https://example.com/token",
|
||||
jwksURL: "https://example.com/.well-known/jwks.json",
|
||||
tokenVerifier: suite.mockTokenVerifier,
|
||||
jwtVerifier: suite.mockJWTVerifier,
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestServeHTTP_AuthenticatedUser() {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session := sessions.NewSession(suite.mockStore, cookieName)
|
||||
session.Values["authenticated"] = true
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
}
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
encodedClaims := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
mockToken := fmt.Sprintf("header.%s.signature", encodedClaims)
|
||||
session.Values["id_token"] = mockToken
|
||||
|
||||
suite.mockStore.On("Get", req, cookieName).Return(session, nil)
|
||||
suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
suite.mockTokenVerifier.On("VerifyToken", mockToken).Return(nil)
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
suite.oidc.next = nextHandler
|
||||
|
||||
suite.oidc.ServeHTTP(rw, req)
|
||||
|
||||
suite.Equal(http.StatusOK, rw.Code)
|
||||
suite.Equal("OK", rw.Body.String())
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestServeHTTP_CallbackPath() {
|
||||
req := httptest.NewRequest("GET", "http://example.com"+suite.oidc.redirURLPath+"?code=test_code&state=test_state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session := sessions.NewSession(suite.mockStore, cookieName)
|
||||
session.Values["csrf"] = "test_state"
|
||||
session.Values["incoming_path"] = "/original_path"
|
||||
|
||||
suite.mockStore.On("Get", req, cookieName).Return(session, nil)
|
||||
suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
"email": "test@example.com",
|
||||
}
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
encodedClaims := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
mockToken := fmt.Sprintf("header.%s.signature", encodedClaims)
|
||||
|
||||
tokenResponse := map[string]interface{}{
|
||||
"id_token": mockToken,
|
||||
}
|
||||
tokenResponseJSON, _ := json.Marshal(tokenResponse)
|
||||
|
||||
suite.mockHTTPClient.On("RoundTrip", mock.MatchedBy(func(req *http.Request) bool {
|
||||
return strings.Contains(req.URL.String(), "token")
|
||||
})).Return(&http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(string(tokenResponseJSON))),
|
||||
}, nil)
|
||||
|
||||
suite.mockTokenVerifier.On("VerifyToken", mockToken).Return(nil)
|
||||
|
||||
suite.oidc.ServeHTTP(rw, req)
|
||||
|
||||
suite.Equal(http.StatusFound, rw.Code)
|
||||
suite.Equal("/original_path", rw.Header().Get("Location"))
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestVerifyToken() {
|
||||
token := "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Rfa2lkIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJleHAiOjE1MTYyMzkxMjJ9.ZmFrZV9zaWduYXR1cmU"
|
||||
|
||||
suite.mockTokenVerifier.On("VerifyToken", token).Return(nil)
|
||||
|
||||
err := suite.oidc.verifyToken(token)
|
||||
suite.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestBuildAuthURL() {
|
||||
authURL := suite.oidc.buildAuthURL("http://example.com/callback", "test_state", "test_nonce")
|
||||
suite.Contains(authURL, suite.oidc.authURL)
|
||||
suite.Contains(authURL, "client_id="+suite.oidc.clientID)
|
||||
suite.Contains(authURL, "redirect_uri=http%3A%2F%2Fexample.com%2Fcallback")
|
||||
suite.Contains(authURL, "state=test_state")
|
||||
suite.Contains(authURL, "nonce=test_nonce")
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestJWKToPEM() {
|
||||
jwk := &JWK{
|
||||
N: base64.RawURLEncoding.EncodeToString(big.NewInt(12345).Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()),
|
||||
}
|
||||
pem, err := jwkToPEM(jwk)
|
||||
suite.Require().NoError(err)
|
||||
suite.NotEmpty(pem)
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestTokenBlacklist() {
|
||||
tb := NewTokenBlacklist()
|
||||
token := "test_token"
|
||||
expiration := time.Now().Add(time.Hour)
|
||||
|
||||
tb.Add(token, expiration)
|
||||
suite.True(tb.IsBlacklisted(token))
|
||||
|
||||
tb.Cleanup()
|
||||
suite.True(tb.IsBlacklisted(token))
|
||||
|
||||
tb.Add("expired_token", time.Now().Add(-time.Hour))
|
||||
tb.Cleanup()
|
||||
suite.False(tb.IsBlacklisted("expired_token"))
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestTokenCache() {
|
||||
tc := NewTokenCache()
|
||||
token := "test_token"
|
||||
expiration := time.Now().Add(time.Hour)
|
||||
|
||||
tc.Set(token, expiration)
|
||||
info, exists := tc.Get(token)
|
||||
suite.True(exists)
|
||||
suite.Equal(token, info.Token)
|
||||
suite.Equal(expiration, info.ExpiresAt)
|
||||
|
||||
tc.Delete(token)
|
||||
_, exists = tc.Get(token)
|
||||
suite.False(exists)
|
||||
|
||||
tc.Set("expired_token", time.Now().Add(-time.Hour))
|
||||
tc.Cleanup()
|
||||
_, exists = tc.Get("expired_token")
|
||||
suite.False(exists)
|
||||
}
|
||||
|
||||
func TestTraefikOidcSuite(t *testing.T) {
|
||||
suite.Run(t, new(TraefikOidcTestSuite))
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestGenerateNonce() {
|
||||
nonce, err := generateNonce()
|
||||
suite.NoError(err)
|
||||
suite.Len(nonce, 44) // Base64 encoded 32 bytes
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestBuildFullURL() {
|
||||
url := buildFullURL("https", "example.com", "/path")
|
||||
suite.Equal("https://example.com/path", url)
|
||||
|
||||
url = buildFullURL("", "example.com", "/path")
|
||||
suite.Equal("http://example.com/path", url)
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestExchangeCodeForToken() {
|
||||
ctx := context.Background()
|
||||
code := "test_code"
|
||||
redirectURL := "http://example.com/callback"
|
||||
|
||||
expectedToken := map[string]interface{}{
|
||||
"access_token": "test_access_token",
|
||||
"id_token": "test_id_token",
|
||||
}
|
||||
tokenJSON, _ := json.Marshal(expectedToken)
|
||||
|
||||
suite.mockHTTPClient.On("RoundTrip", mock.Anything).Return(&http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(tokenJSON)),
|
||||
}, nil).Once()
|
||||
|
||||
token, err := suite.oidc.exchangeCodeForToken(ctx, code, redirectURL)
|
||||
suite.NoError(err)
|
||||
suite.Equal(expectedToken, token)
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestHandleLogout() {
|
||||
req := httptest.NewRequest("GET", "http://example.com/logout", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session := sessions.NewSession(suite.mockStore, cookieName)
|
||||
session.Values["id_token"] = "test_token"
|
||||
|
||||
suite.mockStore.On("Get", req, cookieName).Return(session, nil)
|
||||
suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
suite.oidc.handleLogout(rw, req)
|
||||
|
||||
suite.Equal(http.StatusForbidden, rw.Code)
|
||||
suite.Equal("Logged out", rw.Body.String())
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestExtractClaims() {
|
||||
tokenString := "header.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.signature"
|
||||
claims, err := extractClaims(tokenString)
|
||||
suite.NoError(err)
|
||||
suite.Equal("1234567890", claims["sub"])
|
||||
suite.Equal("John Doe", claims["name"])
|
||||
suite.Equal(float64(1516239022), claims["iat"])
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestDiscoverProviderMetadata() {
|
||||
providerURL := "https://example.com"
|
||||
expectedMetadata := &ProviderMetadata{
|
||||
Issuer: "https://example.com",
|
||||
AuthURL: "https://example.com/auth",
|
||||
TokenURL: "https://example.com/token",
|
||||
JWKSURL: "https://example.com/.well-known/jwks.json",
|
||||
}
|
||||
metadataJSON, _ := json.Marshal(expectedMetadata)
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: suite.mockHTTPClient,
|
||||
}
|
||||
|
||||
suite.mockHTTPClient.On("RoundTrip", mock.Anything).Return(&http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(metadataJSON)),
|
||||
}, nil)
|
||||
|
||||
metadata, err := discoverProviderMetadata(providerURL, *httpClient)
|
||||
suite.NoError(err)
|
||||
suite.Equal(expectedMetadata, metadata)
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestDetermineScheme() {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
scheme := suite.oidc.determineScheme(req)
|
||||
suite.Equal("http", scheme)
|
||||
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
scheme = suite.oidc.determineScheme(req)
|
||||
suite.Equal("https", scheme)
|
||||
|
||||
suite.oidc.forceHTTPS = true
|
||||
scheme = suite.oidc.determineScheme(req)
|
||||
suite.Equal("https", scheme)
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestDetermineHost() {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
host := suite.oidc.determineHost(req)
|
||||
suite.Equal("example.com", host)
|
||||
|
||||
req.Header.Set("X-Forwarded-Host", "forwarded.example.com")
|
||||
host = suite.oidc.determineHost(req)
|
||||
suite.Equal("forwarded.example.com", host)
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestIsUserAuthenticated() {
|
||||
session := sessions.NewSession(suite.mockStore, cookieName)
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = "valid.eyJleHAiOjk5OTk5OTk5OTl9.signature"
|
||||
|
||||
suite.mockTokenVerifier.On("VerifyToken", "valid.eyJleHAiOjk5OTk5OTk5OTl9.signature").Return(nil)
|
||||
|
||||
authenticated := suite.oidc.isUserAuthenticated(session)
|
||||
suite.True(authenticated)
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestInitiateAuthentication() {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
session := sessions.NewSession(suite.mockStore, cookieName)
|
||||
|
||||
suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
suite.oidc.initiateAuthentication(rw, req, session, "http://example.com/callback")
|
||||
|
||||
suite.Equal(http.StatusFound, rw.Code)
|
||||
location := rw.Header().Get("Location")
|
||||
suite.Contains(location, suite.oidc.authURL)
|
||||
suite.Contains(location, "redirect_uri=http%3A%2F%2Fexample.com%2Fcallback")
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestRevokeToken() {
|
||||
token := "valid.eyJleHAiOjk5OTk5OTk5OTl9.signature"
|
||||
suite.oidc.RevokeToken(token)
|
||||
|
||||
_, exists := suite.oidc.tokenCache.Get(token)
|
||||
suite.False(exists)
|
||||
suite.True(suite.oidc.tokenBlacklist.IsBlacklisted(token))
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestServeHTTP_InvalidSession() {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
suite.mockStore.On("Get", req, cookieName).Return((*sessions.Session)(nil), fmt.Errorf("invalid session"))
|
||||
|
||||
suite.oidc.ServeHTTP(rw, req)
|
||||
|
||||
suite.Equal(http.StatusInternalServerError, rw.Code)
|
||||
suite.Contains(rw.Body.String(), "Session error")
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestServeHTTP_ExpiredToken() {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session := sessions.NewSession(suite.mockStore, cookieName)
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = "expired.eyJleHAiOjF9.signature" // expired token
|
||||
|
||||
suite.mockStore.On("Get", req, cookieName).Return(session, nil)
|
||||
suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
suite.oidc.ServeHTTP(rw, req)
|
||||
|
||||
suite.Equal(http.StatusFound, rw.Code) // Should redirect to authentication
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestHandleCallback_InvalidState() {
|
||||
req := httptest.NewRequest("GET", "http://example.com"+suite.oidc.redirURLPath+"?code=test_code&state=invalid_state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session := sessions.NewSession(suite.mockStore, cookieName)
|
||||
session.Values["csrf"] = "valid_state"
|
||||
|
||||
suite.mockStore.On("Get", req, cookieName).Return(session, nil)
|
||||
|
||||
suite.oidc.ServeHTTP(rw, req)
|
||||
|
||||
suite.Equal(http.StatusBadRequest, rw.Code)
|
||||
suite.Contains(rw.Body.String(), "Invalid state parameter")
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestHandleCallback_TokenExchangeError() {
|
||||
req := httptest.NewRequest("GET", "http://example.com"+suite.oidc.redirURLPath+"?code=invalid_code&state=test_state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session := sessions.NewSession(suite.mockStore, cookieName)
|
||||
session.Values["csrf"] = "test_state"
|
||||
|
||||
suite.mockStore.On("Get", req, cookieName).Return(session, nil)
|
||||
suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
suite.mockHTTPClient.On("RoundTrip", mock.Anything).Return(&http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: io.NopCloser(strings.NewReader(`{"error": "invalid_grant"}`)),
|
||||
}, nil)
|
||||
|
||||
suite.oidc.ServeHTTP(rw, req)
|
||||
|
||||
suite.Equal(http.StatusUnauthorized, rw.Code)
|
||||
suite.Contains(rw.Body.String(), "Authentication failed")
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestVerifyToken_RateLimitExceeded() {
|
||||
suite.oidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 1) // Set a very low limit
|
||||
|
||||
// Use up the only allowed request
|
||||
suite.oidc.limiter.Allow()
|
||||
|
||||
err := suite.oidc.VerifyToken("some_token")
|
||||
suite.Error(err)
|
||||
suite.Contains(err.Error(), "rate limit exceeded")
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestVerifyToken_BlacklistedToken() {
|
||||
token := "blacklisted_token"
|
||||
suite.oidc.tokenBlacklist.Add(token, time.Now().Add(time.Hour))
|
||||
|
||||
err := suite.oidc.VerifyToken(token)
|
||||
suite.Error(err)
|
||||
suite.Contains(err.Error(), "token is blacklisted")
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestExtractClaims_InvalidToken() {
|
||||
invalidToken := "invalid.token.format"
|
||||
claims, err := extractClaims(invalidToken)
|
||||
suite.Error(err)
|
||||
suite.Nil(claims)
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestDiscoverProviderMetadata_HTTPError() {
|
||||
providerURL := "https://example.com"
|
||||
httpClient := &http.Client{
|
||||
Transport: suite.mockHTTPClient,
|
||||
}
|
||||
|
||||
suite.mockHTTPClient.On("RoundTrip", mock.Anything).Return(&http.Response{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Body: io.NopCloser(strings.NewReader("Internal Server Error")),
|
||||
}, nil)
|
||||
|
||||
metadata, err := discoverProviderMetadata(providerURL, *httpClient)
|
||||
suite.Error(err)
|
||||
suite.Nil(metadata)
|
||||
suite.Contains(err.Error(), "failed to fetch provider metadata: status code 500")
|
||||
}
|
||||
|
||||
func (suite *TraefikOidcTestSuite) TestRevokeToken_InvalidToken() {
|
||||
invalidToken := "invalid.token"
|
||||
suite.oidc.RevokeToken(invalidToken)
|
||||
|
||||
// Check that the invalid token is not added to the blacklist
|
||||
suite.False(suite.oidc.tokenBlacklist.IsBlacklisted(invalidToken))
|
||||
}
|
||||
|
||||
func TestTraefikOidc_ServeHTTP(t *testing.T) {
|
||||
type fields struct {
|
||||
next http.Handler
|
||||
name string
|
||||
store sessions.Store
|
||||
redirURLPath string
|
||||
logoutURLPath string
|
||||
issuerURL string
|
||||
jwkCache *JWKCache
|
||||
tokenBlacklist *TokenBlacklist
|
||||
jwksURL string
|
||||
clientID string
|
||||
clientSecret string
|
||||
authURL string
|
||||
tokenURL string
|
||||
scopes []string
|
||||
limiter *rate.Limiter
|
||||
forceHTTPS bool
|
||||
scheme string
|
||||
tokenCache *TokenCache
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
redirectURL string
|
||||
tokenVerifier TokenVerifier
|
||||
jwtVerifier JWTVerifier
|
||||
}
|
||||
type args struct {
|
||||
rw http.ResponseWriter
|
||||
req *http.Request
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tr := &TraefikOidc{
|
||||
next: tt.fields.next,
|
||||
name: tt.fields.name,
|
||||
store: tt.fields.store,
|
||||
redirURLPath: tt.fields.redirURLPath,
|
||||
logoutURLPath: tt.fields.logoutURLPath,
|
||||
issuerURL: tt.fields.issuerURL,
|
||||
jwkCache: tt.fields.jwkCache,
|
||||
tokenBlacklist: tt.fields.tokenBlacklist,
|
||||
jwksURL: tt.fields.jwksURL,
|
||||
clientID: tt.fields.clientID,
|
||||
clientSecret: tt.fields.clientSecret,
|
||||
authURL: tt.fields.authURL,
|
||||
tokenURL: tt.fields.tokenURL,
|
||||
scopes: tt.fields.scopes,
|
||||
limiter: tt.fields.limiter,
|
||||
forceHTTPS: tt.fields.forceHTTPS,
|
||||
scheme: tt.fields.scheme,
|
||||
tokenCache: tt.fields.tokenCache,
|
||||
httpClient: tt.fields.httpClient,
|
||||
logger: tt.fields.logger,
|
||||
redirectURL: tt.fields.redirectURL,
|
||||
tokenVerifier: tt.fields.tokenVerifier,
|
||||
jwtVerifier: tt.fields.jwtVerifier,
|
||||
}
|
||||
tr.ServeHTTP(tt.args.rw, tt.args.req)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user