mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Re-introduce user roles separation with additional tests.
This commit is contained in:
@@ -430,6 +430,34 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
} else {
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
}
|
||||
if len(roles) > 0 {
|
||||
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
|
||||
}
|
||||
}
|
||||
|
||||
// Check allowed roles and groups
|
||||
if len(t.allowedRolesAndGroups) > 0 {
|
||||
allowed := false
|
||||
for _, roleOrGroup := range append(groups, roles...) {
|
||||
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set user information in headers
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
|
||||
|
||||
+185
@@ -1342,6 +1342,191 @@ func TestExtractGroupsAndRoles(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedRolesAndGroups map[string]struct{}
|
||||
claims map[string]interface{}
|
||||
setupSession func(*SessionData)
|
||||
expectedStatus int
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "User with allowed role",
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"admin": {},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"admin", "user"},
|
||||
"groups": []interface{}{"group1"},
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Roles": "admin,user",
|
||||
"X-User-Groups": "group1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "User with allowed group",
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"allowed-group": {},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"allowed-group"},
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Roles": "user",
|
||||
"X-User-Groups": "allowed-group",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "User without allowed roles or groups",
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"admin": {},
|
||||
"allowed-group": {},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"regular-group"},
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "No role/group restrictions",
|
||||
allowedRolesAndGroups: map[string]struct{}{},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"regular-group"},
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Roles": "user",
|
||||
"X-User-Groups": "regular-group",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Claims without roles and groups",
|
||||
allowedRolesAndGroups: map[string]struct{}{},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create token with claims
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
// Create test handler
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Configure OIDC middleware
|
||||
tOidc := ts.tOidc
|
||||
tOidc.next = nextHandler
|
||||
tOidc.allowedRolesAndGroups = tc.allowedRolesAndGroups
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Set up session
|
||||
session, err := tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
tc.setupSession(session)
|
||||
session.SetAccessToken(token)
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Copy cookies to the new request
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Reset response recorder
|
||||
rr = httptest.NewRecorder()
|
||||
|
||||
// Serve request
|
||||
tOidc.ServeHTTP(rr, req)
|
||||
|
||||
// Check status code
|
||||
if rr.Code != tc.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
||||
}
|
||||
|
||||
// Check headers if status is OK
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
for header, expectedValue := range tc.expectedHeaders {
|
||||
if value := req.Header.Get(header); value != expectedValue {
|
||||
t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to compare string slices
|
||||
func stringSliceEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
|
||||
Reference in New Issue
Block a user