From 11bc6f3e3142ed4ae33d502fe900609a2437d9f2 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Wed, 11 Dec 2024 09:08:50 +0000 Subject: [PATCH] Re-introduce user roles separation with additional tests. --- main.go | 28 ++++++++ main_test.go | 185 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 213 insertions(+) diff --git a/main.go b/main.go index cbe0628..739be11 100644 --- a/main.go +++ b/main.go @@ -430,6 +430,34 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } + groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken()) + if err != nil { + t.logger.Errorf("Failed to extract groups and roles: %v", err) + } else { + if len(groups) > 0 { + req.Header.Set("X-User-Groups", strings.Join(groups, ",")) + } + if len(roles) > 0 { + req.Header.Set("X-User-Roles", strings.Join(roles, ",")) + } + } + + // Check allowed roles and groups + if len(t.allowedRolesAndGroups) > 0 { + allowed := false + for _, roleOrGroup := range append(groups, roles...) { + if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok { + allowed = true + break + } + } + if !allowed { + t.logger.Infof("User with email %s does not have any allowed roles or groups", email) + http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden) + return + } + } + // Set user information in headers req.Header.Set("X-Forwarded-User", email) diff --git a/main_test.go b/main_test.go index b41eea8..a151b9b 100644 --- a/main_test.go +++ b/main_test.go @@ -1342,6 +1342,191 @@ func TestExtractGroupsAndRoles(t *testing.T) { } } +func TestServeHTTPRolesAndGroups(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + tests := []struct { + name string + allowedRolesAndGroups map[string]struct{} + claims map[string]interface{} + setupSession func(*SessionData) + expectedStatus int + expectedHeaders map[string]string + }{ + { + name: "User with allowed role", + allowedRolesAndGroups: map[string]struct{}{ + "admin": {}, + }, + claims: map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), + "sub": "test-subject", + "roles": []interface{}{"admin", "user"}, + "groups": []interface{}{"group1"}, + }, + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + }, + expectedStatus: http.StatusOK, + expectedHeaders: map[string]string{ + "X-User-Roles": "admin,user", + "X-User-Groups": "group1", + }, + }, + { + name: "User with allowed group", + allowedRolesAndGroups: map[string]struct{}{ + "allowed-group": {}, + }, + claims: map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), + "sub": "test-subject", + "roles": []interface{}{"user"}, + "groups": []interface{}{"allowed-group"}, + }, + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + }, + expectedStatus: http.StatusOK, + expectedHeaders: map[string]string{ + "X-User-Roles": "user", + "X-User-Groups": "allowed-group", + }, + }, + { + name: "User without allowed roles or groups", + allowedRolesAndGroups: map[string]struct{}{ + "admin": {}, + "allowed-group": {}, + }, + claims: map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), + "sub": "test-subject", + "roles": []interface{}{"user"}, + "groups": []interface{}{"regular-group"}, + }, + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + }, + expectedStatus: http.StatusForbidden, + }, + { + name: "No role/group restrictions", + allowedRolesAndGroups: map[string]struct{}{}, + claims: map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), + "sub": "test-subject", + "roles": []interface{}{"user"}, + "groups": []interface{}{"regular-group"}, + }, + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + }, + expectedStatus: http.StatusOK, + expectedHeaders: map[string]string{ + "X-User-Roles": "user", + "X-User-Groups": "regular-group", + }, + }, + { + name: "Claims without roles and groups", + allowedRolesAndGroups: map[string]struct{}{}, + claims: map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), + "sub": "test-subject", + }, + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + }, + expectedStatus: http.StatusOK, + expectedHeaders: map[string]string{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create token with claims + token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims) + if err != nil { + t.Fatalf("Failed to create test token: %v", err) + } + + // Create test handler + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Configure OIDC middleware + tOidc := ts.tOidc + tOidc.next = nextHandler + tOidc.allowedRolesAndGroups = tc.allowedRolesAndGroups + + // Create request + req := httptest.NewRequest("GET", "/protected", nil) + rr := httptest.NewRecorder() + + // Set up session + session, err := tOidc.sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + tc.setupSession(session) + session.SetAccessToken(token) + + if err := session.Save(req, rr); err != nil { + t.Fatalf("Failed to save session: %v", err) + } + + // Copy cookies to the new request + for _, cookie := range rr.Result().Cookies() { + req.AddCookie(cookie) + } + + // Reset response recorder + rr = httptest.NewRecorder() + + // Serve request + tOidc.ServeHTTP(rr, req) + + // Check status code + if rr.Code != tc.expectedStatus { + t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code) + } + + // Check headers if status is OK + if tc.expectedStatus == http.StatusOK { + for header, expectedValue := range tc.expectedHeaders { + if value := req.Header.Get(header); value != expectedValue { + t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value) + } + } + } + }) + } +} + // Helper function to compare string slices func stringSliceEqual(a, b []string) bool { if len(a) != len(b) {