diff --git a/.traefik.yml b/.traefik.yml index 0ce8122..11881e8 100644 --- a/.traefik.yml +++ b/.traefik.yml @@ -4,7 +4,7 @@ type: middleware import: github.com/lukaszraczylo/traefikoidc summary: | - Middleware adding OIDC authentication to traefik routes. + Middleware adding OIDC authentication to traefik routes. Does what it says on the tin. testData: providerURL: https://accounts.google.com @@ -16,10 +16,12 @@ testData: - openid - email - profile + allowedUserDomains: # If not provided - will rely entirely on the OIDC yes/no + - raczylo.com sessionEncryptionKey: potato-secret forceHTTPS: false logLevel: debug # debug, info, warn, error rateLimit: 100 # Simple rate limiter to prevent brute force attacks excludedURLs: # Determines the list of URLs which are NOT a subject to authentication - - /login + - /login # covers /login, /login/me, /login/reminder etc. - /my-public-data diff --git a/README.md b/README.md index c05b93b..1027c18 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,13 @@ This middleware is supposed to replace the need for the forward-auth and oauth2- ### Configuration options +Middleware currently supports following scenarios: + +* Setting custom callback and logout URLs via `callbackURL` and `logoutURL` +* Allowing for access only from the listed domains if `allowedUserDomains` is set, otherwise it relies entirely on the OIDC provider +* Using excluded URLs which do **NOT** require the OIDC authentication +* Rate limiting requests to prevent the bruteforce attacks + #### Docker compose example `docker-compose.yaml` @@ -94,6 +101,8 @@ http: - openid - email - profile + allowedUserDomains: # If not provided - will rely entirely on the OIDC yes/no + - raczylo.com sessionEncryptionKey: potato-secret forceHTTPS: false logLevel: debug # debug, info, warn, error diff --git a/main.go b/main.go index a98e549..c0dd21e 100644 --- a/main.go +++ b/main.go @@ -26,30 +26,31 @@ type JWTVerifier interface { } type TraefikOidc struct { - next http.Handler - name string - store sessions.Store - redirURLPath string - logoutURLPath string - issuerURL string - jwkCache *JWKCache - tokenBlacklist *TokenBlacklist - jwksURL string - clientID string - clientSecret string - authURL string - tokenURL string - scopes []string - limiter *rate.Limiter - forceHTTPS bool - scheme string - tokenCache *TokenCache - httpClient *http.Client - logger *Logger - redirectURL string - tokenVerifier TokenVerifier - jwtVerifier JWTVerifier - excludedURLs map[string]struct{} + next http.Handler + name string + store sessions.Store + redirURLPath string + logoutURLPath string + issuerURL string + jwkCache *JWKCache + tokenBlacklist *TokenBlacklist + jwksURL string + clientID string + clientSecret string + authURL string + tokenURL string + scopes []string + limiter *rate.Limiter + forceHTTPS bool + scheme string + tokenCache *TokenCache + httpClient *http.Client + logger *Logger + redirectURL string + tokenVerifier TokenVerifier + jwtVerifier JWTVerifier + excludedURLs map[string]struct{} + allowedUserDomains map[string]struct{} } type ProviderMetadata struct { @@ -191,6 +192,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h return m }(), redirectURL: "", + allowedUserDomains: func() map[string]struct{} { + m := make(map[string]struct{}) + for _, domain := range config.AllowedUserDomains { + m[domain] = struct{}{} + } + return m + }(), } // add defaultExcludedURLs to excludedURLs for k, v := range defaultExcludedURLs { @@ -289,28 +297,33 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if authenticated { - if needsRefresh { - // Attempt to refresh the token silently - if refreshed := t.refreshToken(rw, req, session); !refreshed { - // If refresh failed, re-authenticate - t.initiateAuthentication(rw, req, session, t.redirectURL) - return - } - } - idToken, ok := session.Values["id_token"].(string) if !ok || idToken == "" { + t.logger.Errorf("No id_token found in session") + t.initiateAuthentication(rw, req, session, t.redirectURL) return } claims, err := extractClaims(idToken) if err != nil { t.logger.Errorf("Failed to extract claims: %v", err) + t.initiateAuthentication(rw, req, session, t.redirectURL) return } - // Add authenticated user email to the header X-Forwarded-User email, _ := claims["email"].(string) + if email == "" { + t.logger.Errorf("No email found in token claims") + t.initiateAuthentication(rw, req, session, t.redirectURL) + return + } + + if !t.isAllowedDomain(email) { + t.logger.Infof("User with email %s is not from an allowed domain", email) + http.Error(rw, "Access denied: Your email domain is not allowed", http.StatusForbidden) + return + } + req.Header.Set("X-Forwarded-User", email) t.next.ServeHTTP(rw, req) @@ -508,3 +521,18 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se return true } + +func (t *TraefikOidc) isAllowedDomain(email string) bool { + if len(t.allowedUserDomains) == 0 { + return true // If no domains are specified, all are allowed + } + + parts := strings.Split(email, "@") + if len(parts) != 2 { + return false // Invalid email format + } + + domain := parts[1] + _, ok := t.allowedUserDomains[domain] + return ok +} diff --git a/main_test.go b/main_test.go index 85ce27c..72e2c2d 100644 --- a/main_test.go +++ b/main_test.go @@ -120,36 +120,71 @@ func (suite *TraefikOidcTestSuite) SetupTest() { } func (suite *TraefikOidcTestSuite) TestServeHTTP_AuthenticatedUser() { - req := httptest.NewRequest("GET", "http://example.com", nil) - rw := httptest.NewRecorder() - - session := sessions.NewSession(suite.mockStore, cookieName) - session.Values["authenticated"] = true - - claims := map[string]interface{}{ - "exp": float64(time.Now().Add(time.Hour).Unix()), + testCases := []struct { + name string + setupClaims func() map[string]interface{} + expectedStatus int + expectedBody string + expectedHeader string + }{ + { + name: "Valid authenticated user", + setupClaims: func() map[string]interface{} { + return map[string]interface{}{ + "exp": float64(time.Now().Add(time.Hour).Unix()), + "email": "user@example.com", + } + }, + expectedStatus: http.StatusOK, + expectedBody: "OK", + expectedHeader: "user@example.com", + }, + { + name: "Authenticated user without email", + setupClaims: func() map[string]interface{} { + return map[string]interface{}{ + "exp": float64(time.Now().Add(time.Hour).Unix()), + } + }, + expectedStatus: http.StatusFound, + expectedBody: "", + expectedHeader: "", + }, } - claimsJSON, _ := json.Marshal(claims) - encodedClaims := base64.RawURLEncoding.EncodeToString(claimsJSON) - mockToken := fmt.Sprintf("header.%s.signature", encodedClaims) - session.Values["id_token"] = mockToken - suite.mockStore.On("Get", req, cookieName).Return(session, nil) - suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil) + for _, tc := range testCases { + suite.Run(tc.name, func() { + req := httptest.NewRequest("GET", "http://example.com", nil) + rw := httptest.NewRecorder() - suite.mockTokenVerifier.On("VerifyToken", mockToken).Return(nil) + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["authenticated"] = true - nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("OK")) - }) + claims := tc.setupClaims() + claimsJSON, _ := json.Marshal(claims) + encodedClaims := base64.RawURLEncoding.EncodeToString(claimsJSON) + mockToken := fmt.Sprintf("header.%s.signature", encodedClaims) + session.Values["id_token"] = mockToken - suite.oidc.next = nextHandler + suite.mockStore.On("Get", req, cookieName).Return(session, nil) + suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil) - suite.oidc.ServeHTTP(rw, req) + suite.mockTokenVerifier.On("VerifyToken", mockToken).Return(nil) - suite.Equal(http.StatusOK, rw.Code) - suite.Equal("OK", rw.Body.String()) + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + + suite.oidc.next = nextHandler + + suite.oidc.ServeHTTP(rw, req) + + suite.Equal(tc.expectedStatus, rw.Code) + suite.Contains(rw.Body.String(), tc.expectedBody) + suite.Equal(tc.expectedHeader, req.Header.Get("X-Forwarded-User")) + }) + } } func (suite *TraefikOidcTestSuite) TestServeHTTP_CallbackPath() { @@ -533,18 +568,48 @@ func (suite *TraefikOidcTestSuite) TestServeHTTP_ExpiredToken() { } func (suite *TraefikOidcTestSuite) TestHandleCallback_InvalidState() { - req := httptest.NewRequest("GET", "http://example.com"+suite.oidc.redirURLPath+"?code=test_code&state=invalid_state", nil) - rw := httptest.NewRecorder() + testCases := []struct { + name string + setupSession func() *sessions.Session + expectedStatus int + expectedBody string + }{ + { + name: "Invalid state", + setupSession: func() *sessions.Session { + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["csrf"] = "valid_state" + return session + }, + expectedStatus: http.StatusFound, + expectedBody: "Authentication failed", + }, + { + name: "No CSRF in session", + setupSession: func() *sessions.Session { + return sessions.NewSession(suite.mockStore, cookieName) + }, + expectedStatus: http.StatusFound, + expectedBody: "Authentication failed", + }, + } - session := sessions.NewSession(suite.mockStore, cookieName) - session.Values["csrf"] = "valid_state" + for _, tc := range testCases { + suite.Run(tc.name, func() { + req := httptest.NewRequest("GET", "http://example.com"+suite.oidc.redirURLPath+"?code=test_code&state=invalid_state", nil) + rw := httptest.NewRecorder() - suite.mockStore.On("Get", req, cookieName).Return(session, nil) + session := tc.setupSession() - suite.oidc.ServeHTTP(rw, req) + suite.mockStore.On("Get", req, cookieName).Return(session, nil) + suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil) - suite.Equal(http.StatusBadRequest, rw.Code) - suite.Contains(rw.Body.String(), "Invalid state parameter") + suite.oidc.ServeHTTP(rw, req) + + suite.Equal(tc.expectedStatus, rw.Code) + suite.Contains(rw.Body.String(), tc.expectedBody) + }) + } } func (suite *TraefikOidcTestSuite) TestHandleCallback_TokenExchangeError() { @@ -815,3 +880,133 @@ func (suite *TraefikOidcTestSuite) TestServeHTTP_ExcludedURLs() { }) } } + +func (suite *TraefikOidcTestSuite) TestServeHTTP_DomainRestriction() { + testCases := []struct { + name string + allowedDomains map[string]struct{} + email string + expectedStatus int + expectedBody string + expectedHeader string + }{ + { + name: "Allowed domain", + allowedDomains: map[string]struct{}{"example.com": {}}, + email: "user@example.com", + expectedStatus: http.StatusOK, + expectedBody: "OK", + expectedHeader: "user@example.com", + }, + { + name: "Not allowed domain", + allowedDomains: map[string]struct{}{"example.com": {}}, + email: "user@notallowed.com", + expectedStatus: http.StatusForbidden, + expectedBody: "Access denied: Your email domain is not allowed", + expectedHeader: "", + }, + { + name: "No domain restriction", + allowedDomains: map[string]struct{}{}, + email: "user@anydomain.com", + expectedStatus: http.StatusOK, + expectedBody: "OK", + expectedHeader: "user@anydomain.com", + }, + { + name: "No email claim", + allowedDomains: map[string]struct{}{"example.com": {}}, + email: "", + expectedStatus: http.StatusFound, + expectedBody: "", + expectedHeader: "", + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + suite.oidc.allowedUserDomains = tc.allowedDomains + + req := httptest.NewRequest("GET", "http://example.com", nil) + rw := httptest.NewRecorder() + + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["authenticated"] = true + + claims := map[string]interface{}{ + "exp": float64(time.Now().Add(time.Hour).Unix()), + } + if tc.email != "" { + claims["email"] = tc.email + } + claimsJSON, _ := json.Marshal(claims) + encodedClaims := base64.RawURLEncoding.EncodeToString(claimsJSON) + mockToken := fmt.Sprintf("header.%s.signature", encodedClaims) + session.Values["id_token"] = mockToken + + suite.mockStore.On("Get", req, cookieName).Return(session, nil).Once() + suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + + suite.mockTokenVerifier.On("VerifyToken", mockToken).Return(nil).Once() + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + + suite.oidc.next = nextHandler + + suite.oidc.ServeHTTP(rw, req) + + suite.Equal(tc.expectedStatus, rw.Code) + suite.Contains(rw.Body.String(), tc.expectedBody) + suite.Equal(tc.expectedHeader, req.Header.Get("X-Forwarded-User")) + + suite.mockStore.AssertExpectations(suite.T()) + suite.mockTokenVerifier.AssertExpectations(suite.T()) + }) + } +} + +func (suite *TraefikOidcTestSuite) TestIsAllowedDomain() { + testCases := []struct { + name string + email string + allowedDomains map[string]struct{} + expectedAllowed bool + }{ + { + name: "Allowed domain", + email: "user@example.com", + allowedDomains: map[string]struct{}{"example.com": {}, "test.com": {}}, + expectedAllowed: true, + }, + { + name: "Not allowed domain", + email: "user@notallowed.com", + allowedDomains: map[string]struct{}{"example.com": {}, "test.com": {}}, + expectedAllowed: false, + }, + { + name: "Empty allowed domains", + email: "user@anydomainallowed.com", + allowedDomains: map[string]struct{}{}, + expectedAllowed: true, + }, + { + name: "Invalid email format", + email: "invalidemail", + allowedDomains: map[string]struct{}{"example.com": {}}, + expectedAllowed: false, + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + suite.oidc.allowedUserDomains = tc.allowedDomains + allowed := suite.oidc.isAllowedDomain(tc.email) + suite.Equal(tc.expectedAllowed, allowed) + }) + } +} diff --git a/settings.go b/settings.go index 3687d17..9052868 100644 --- a/settings.go +++ b/settings.go @@ -24,6 +24,7 @@ type Config struct { ForceHTTPS bool `json:"forceHTTPS"` RateLimit int `json:"rateLimit"` ExcludedURLs []string `json:"excludedURLs"` + AllowedUserDomains []string `json:"allowedUserDomains"` } func CreateConfig() *Config {