Add allowed domains list.

This commit is contained in:
2024-09-17 09:12:05 +01:00
parent e97d8e15ff
commit 2fbca0a88c
5 changed files with 302 additions and 67 deletions
+226 -31
View File
@@ -120,36 +120,71 @@ func (suite *TraefikOidcTestSuite) SetupTest() {
}
func (suite *TraefikOidcTestSuite) TestServeHTTP_AuthenticatedUser() {
req := httptest.NewRequest("GET", "http://example.com", nil)
rw := httptest.NewRecorder()
session := sessions.NewSession(suite.mockStore, cookieName)
session.Values["authenticated"] = true
claims := map[string]interface{}{
"exp": float64(time.Now().Add(time.Hour).Unix()),
testCases := []struct {
name string
setupClaims func() map[string]interface{}
expectedStatus int
expectedBody string
expectedHeader string
}{
{
name: "Valid authenticated user",
setupClaims: func() map[string]interface{} {
return map[string]interface{}{
"exp": float64(time.Now().Add(time.Hour).Unix()),
"email": "user@example.com",
}
},
expectedStatus: http.StatusOK,
expectedBody: "OK",
expectedHeader: "user@example.com",
},
{
name: "Authenticated user without email",
setupClaims: func() map[string]interface{} {
return map[string]interface{}{
"exp": float64(time.Now().Add(time.Hour).Unix()),
}
},
expectedStatus: http.StatusFound,
expectedBody: "",
expectedHeader: "",
},
}
claimsJSON, _ := json.Marshal(claims)
encodedClaims := base64.RawURLEncoding.EncodeToString(claimsJSON)
mockToken := fmt.Sprintf("header.%s.signature", encodedClaims)
session.Values["id_token"] = mockToken
suite.mockStore.On("Get", req, cookieName).Return(session, nil)
suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil)
for _, tc := range testCases {
suite.Run(tc.name, func() {
req := httptest.NewRequest("GET", "http://example.com", nil)
rw := httptest.NewRecorder()
suite.mockTokenVerifier.On("VerifyToken", mockToken).Return(nil)
session := sessions.NewSession(suite.mockStore, cookieName)
session.Values["authenticated"] = true
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
claims := tc.setupClaims()
claimsJSON, _ := json.Marshal(claims)
encodedClaims := base64.RawURLEncoding.EncodeToString(claimsJSON)
mockToken := fmt.Sprintf("header.%s.signature", encodedClaims)
session.Values["id_token"] = mockToken
suite.oidc.next = nextHandler
suite.mockStore.On("Get", req, cookieName).Return(session, nil)
suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil)
suite.oidc.ServeHTTP(rw, req)
suite.mockTokenVerifier.On("VerifyToken", mockToken).Return(nil)
suite.Equal(http.StatusOK, rw.Code)
suite.Equal("OK", rw.Body.String())
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
suite.oidc.next = nextHandler
suite.oidc.ServeHTTP(rw, req)
suite.Equal(tc.expectedStatus, rw.Code)
suite.Contains(rw.Body.String(), tc.expectedBody)
suite.Equal(tc.expectedHeader, req.Header.Get("X-Forwarded-User"))
})
}
}
func (suite *TraefikOidcTestSuite) TestServeHTTP_CallbackPath() {
@@ -533,18 +568,48 @@ func (suite *TraefikOidcTestSuite) TestServeHTTP_ExpiredToken() {
}
func (suite *TraefikOidcTestSuite) TestHandleCallback_InvalidState() {
req := httptest.NewRequest("GET", "http://example.com"+suite.oidc.redirURLPath+"?code=test_code&state=invalid_state", nil)
rw := httptest.NewRecorder()
testCases := []struct {
name string
setupSession func() *sessions.Session
expectedStatus int
expectedBody string
}{
{
name: "Invalid state",
setupSession: func() *sessions.Session {
session := sessions.NewSession(suite.mockStore, cookieName)
session.Values["csrf"] = "valid_state"
return session
},
expectedStatus: http.StatusFound,
expectedBody: "Authentication failed",
},
{
name: "No CSRF in session",
setupSession: func() *sessions.Session {
return sessions.NewSession(suite.mockStore, cookieName)
},
expectedStatus: http.StatusFound,
expectedBody: "Authentication failed",
},
}
session := sessions.NewSession(suite.mockStore, cookieName)
session.Values["csrf"] = "valid_state"
for _, tc := range testCases {
suite.Run(tc.name, func() {
req := httptest.NewRequest("GET", "http://example.com"+suite.oidc.redirURLPath+"?code=test_code&state=invalid_state", nil)
rw := httptest.NewRecorder()
suite.mockStore.On("Get", req, cookieName).Return(session, nil)
session := tc.setupSession()
suite.oidc.ServeHTTP(rw, req)
suite.mockStore.On("Get", req, cookieName).Return(session, nil)
suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil)
suite.Equal(http.StatusBadRequest, rw.Code)
suite.Contains(rw.Body.String(), "Invalid state parameter")
suite.oidc.ServeHTTP(rw, req)
suite.Equal(tc.expectedStatus, rw.Code)
suite.Contains(rw.Body.String(), tc.expectedBody)
})
}
}
func (suite *TraefikOidcTestSuite) TestHandleCallback_TokenExchangeError() {
@@ -815,3 +880,133 @@ func (suite *TraefikOidcTestSuite) TestServeHTTP_ExcludedURLs() {
})
}
}
func (suite *TraefikOidcTestSuite) TestServeHTTP_DomainRestriction() {
testCases := []struct {
name string
allowedDomains map[string]struct{}
email string
expectedStatus int
expectedBody string
expectedHeader string
}{
{
name: "Allowed domain",
allowedDomains: map[string]struct{}{"example.com": {}},
email: "user@example.com",
expectedStatus: http.StatusOK,
expectedBody: "OK",
expectedHeader: "user@example.com",
},
{
name: "Not allowed domain",
allowedDomains: map[string]struct{}{"example.com": {}},
email: "user@notallowed.com",
expectedStatus: http.StatusForbidden,
expectedBody: "Access denied: Your email domain is not allowed",
expectedHeader: "",
},
{
name: "No domain restriction",
allowedDomains: map[string]struct{}{},
email: "user@anydomain.com",
expectedStatus: http.StatusOK,
expectedBody: "OK",
expectedHeader: "user@anydomain.com",
},
{
name: "No email claim",
allowedDomains: map[string]struct{}{"example.com": {}},
email: "",
expectedStatus: http.StatusFound,
expectedBody: "",
expectedHeader: "",
},
}
for _, tc := range testCases {
suite.Run(tc.name, func() {
suite.oidc.allowedUserDomains = tc.allowedDomains
req := httptest.NewRequest("GET", "http://example.com", nil)
rw := httptest.NewRecorder()
session := sessions.NewSession(suite.mockStore, cookieName)
session.Values["authenticated"] = true
claims := map[string]interface{}{
"exp": float64(time.Now().Add(time.Hour).Unix()),
}
if tc.email != "" {
claims["email"] = tc.email
}
claimsJSON, _ := json.Marshal(claims)
encodedClaims := base64.RawURLEncoding.EncodeToString(claimsJSON)
mockToken := fmt.Sprintf("header.%s.signature", encodedClaims)
session.Values["id_token"] = mockToken
suite.mockStore.On("Get", req, cookieName).Return(session, nil).Once()
suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.mockTokenVerifier.On("VerifyToken", mockToken).Return(nil).Once()
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
suite.oidc.next = nextHandler
suite.oidc.ServeHTTP(rw, req)
suite.Equal(tc.expectedStatus, rw.Code)
suite.Contains(rw.Body.String(), tc.expectedBody)
suite.Equal(tc.expectedHeader, req.Header.Get("X-Forwarded-User"))
suite.mockStore.AssertExpectations(suite.T())
suite.mockTokenVerifier.AssertExpectations(suite.T())
})
}
}
func (suite *TraefikOidcTestSuite) TestIsAllowedDomain() {
testCases := []struct {
name string
email string
allowedDomains map[string]struct{}
expectedAllowed bool
}{
{
name: "Allowed domain",
email: "user@example.com",
allowedDomains: map[string]struct{}{"example.com": {}, "test.com": {}},
expectedAllowed: true,
},
{
name: "Not allowed domain",
email: "user@notallowed.com",
allowedDomains: map[string]struct{}{"example.com": {}, "test.com": {}},
expectedAllowed: false,
},
{
name: "Empty allowed domains",
email: "user@anydomainallowed.com",
allowedDomains: map[string]struct{}{},
expectedAllowed: true,
},
{
name: "Invalid email format",
email: "invalidemail",
allowedDomains: map[string]struct{}{"example.com": {}},
expectedAllowed: false,
},
}
for _, tc := range testCases {
suite.Run(tc.name, func() {
suite.oidc.allowedUserDomains = tc.allowedDomains
allowed := suite.oidc.isAllowedDomain(tc.email)
suite.Equal(tc.expectedAllowed, allowed)
})
}
}