Compare commits

..

22 Commits

Author SHA1 Message Date
lukaszraczylo 0be45f06a5 Merge pull request #10 from lukaszraczylo/code-improvements
Code improvements
2024-10-09 09:34:45 +01:00
lukaszraczylo 2cc89f0f31 fixup! Update dependencies. 2024-10-09 09:33:20 +01:00
lukaszraczylo b9928f2f0c Update dependencies. 2024-10-09 09:32:25 +01:00
lukaszraczylo 2e1a3a9320 Add ability to verify default ECDSA keys provided by logto as well. 2024-10-09 09:30:42 +01:00
lukaszraczylo 9dabd0e5cf Revert "Update go mod dependencies."
This reverts commit dedbdf63c3.
2024-10-09 09:11:07 +01:00
lukaszraczylo dedbdf63c3 Update go mod dependencies. 2024-10-09 09:07:25 +01:00
lukaszraczylo af032c6cd3 Add simple benchmark to track the allocations and speed for future improvements. 2024-10-08 14:41:43 +01:00
lukaszraczylo 9938cff053 fixup! Cleanup and optimise the code. 2024-10-08 14:26:26 +01:00
lukaszraczylo 7a404ef76f Cleanup and optimise the code. 2024-10-08 14:14:47 +01:00
lukaszraczylo 63922f362f fixup! Add support for more algorithms. 2024-10-07 16:07:07 +01:00
lukaszraczylo 2de9297ab6 Add support for more algorithms. 2024-10-07 16:01:07 +01:00
lukaszraczylo 971c84f762 Abstract filling up maps. 2024-10-07 15:56:24 +01:00
lukaszraczylo d2a0d2167e Fix the bug with user not being redirected to originally requested URL post authentication. 2024-10-05 09:33:56 +01:00
lukaszraczylo c46d958397 Update documentation - setting secrets in kubernetes. 2024-10-04 17:15:43 +01:00
lukaszraczylo 95cf0034d6 Fix the tests hanging on the open channel. 2024-10-04 14:39:37 +01:00
lukaszraczylo 380ef96571 Improvement - startup time.
Previous implementations blocked the traefik startup until OIDC plugin was loaded.
This caused chicken-or-egg issue when called OIDC endpoint was hosted by the same traefik as well,
generating rather ridiculous situation when traefik couldn't come up because plugin tried to call the
discovery endpoint which was hosted by the same traefik.

This version resolves the issue allowing for quickstart and lazy loading of the provider metadata.
Disadvantage is - until discovery is done, the plugin will not provide any access to the client.
2024-10-04 14:22:16 +01:00
lukaszraczylo 1886396dc1 First step in improvement of caching mechanism. 2024-10-04 14:05:12 +01:00
lukaszraczylo 24ecf00053 Add support for roles and groups. 2024-10-03 19:48:43 +01:00
lukaszraczylo e338992f84 Update tests and additional fixups. 2024-10-03 14:00:47 +01:00
lukaszraczylo a9a596031b Update the tests to handle nonce 2024-10-03 14:00:44 +01:00
lukaszraczylo 23afcad2ba Support additional verification of the token to ensure OIDC compliance 2024-10-03 14:00:44 +01:00
lukaszraczylo d06f9fcf90 Update dependencies. 2024-10-03 14:00:44 +01:00
4 changed files with 70 additions and 616 deletions
+28 -55
View File
@@ -97,74 +97,47 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
return tokenResponse, nil
}
// handleLogout handles the logout process
func (t *TraefikOidc) handleLogout(w http.ResponseWriter, r *http.Request) {
session, err := t.store.Get(r, cookieName)
// handleLogout handles the user logout
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
t.logger.Debugf("Logging out user")
if err != nil {
handleError(w, fmt.Sprintf("Error getting session: %v", err), http.StatusInternalServerError, t.logger)
handleError(rw, "Session error", http.StatusInternalServerError, t.logger)
return
}
// Get tokens from session
idToken, _ := session.Values["id_token"].(string)
refreshToken, _ := session.Values["refresh_token"].(string)
accessToken, _ := session.Values["access_token"].(string)
// Revoke tokens if they exist
if refreshToken != "" {
t.RevokeTokenWithProvider(refreshToken, "refresh_token")
// Revoke tokens if available
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
if err := t.RevokeTokenWithProvider(refreshToken, "refresh_token"); err != nil {
t.logger.Errorf("Failed to revoke refresh token: %v", err)
}
t.RevokeToken(refreshToken)
}
if accessToken != "" {
t.RevokeTokenWithProvider(accessToken, "access_token")
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
if err := t.RevokeTokenWithProvider(accessToken, "access_token"); err != nil {
t.logger.Errorf("Failed to revoke access token: %v", err)
}
t.RevokeToken(accessToken)
}
// Clear session
// Remove tokens from session
delete(session.Values, "id_token")
delete(session.Values, "refresh_token")
delete(session.Values, "access_token")
delete(session.Values, "authenticated")
// Set session options to delete the session
session.Options = defaultSessionOptions
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if err := session.Save(r, w); err != nil {
handleError(w, fmt.Sprintf("Error saving session: %v", err), http.StatusInternalServerError, t.logger)
if err := session.Save(req, rw); err != nil {
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
return
}
// Determine redirect URL
host := r.Header.Get("X-Forwarded-Host")
if host == "" {
host = r.Host
}
scheme := "http"
if r.Header.Get("X-Forwarded-Proto") == "https" || t.forceHTTPS {
scheme = "https"
}
baseURL := fmt.Sprintf("%s://%s/", scheme, host)
if t.endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, baseURL)
if err != nil {
handleError(w, fmt.Sprintf("Invalid end session URL: %v", err), http.StatusInternalServerError, t.logger)
return
}
http.Redirect(w, r, logoutURL, http.StatusFound)
return
}
http.Redirect(w, r, baseURL, http.StatusFound)
}
// BuildLogoutURL constructs the logout URL with proper encoding
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
u, err := url.Parse(endSessionURL)
if err != nil {
return "", fmt.Errorf("invalid end session URL: %v", err)
}
q := u.Query()
q.Set("id_token_hint", idToken)
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
u.RawQuery = q.Encode()
return u.String(), nil
// Redirect or display logout message
rw.WriteHeader(http.StatusOK)
rw.Write([]byte("Logged out successfully"))
}
// handleExpiredToken handles the case when a token has expired
+42 -57
View File
@@ -62,19 +62,17 @@ type TraefikOidc struct {
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
exchangeCodeForTokenFunc func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initOnce sync.Once
initComplete chan struct{}
endSessionURL string
baseURL string
}
// ProviderMetadata holds OIDC provider metadata
type ProviderMetadata struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
EndSessionURL string `json:"end_session_endpoint"`
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
}
// defaultExcludedURLs are the paths that are excluded from authentication
@@ -84,14 +82,6 @@ var defaultExcludedURLs = map[string]struct{}{
var newTicker = time.NewTicker
var (
globalMetadataCache struct {
sync.Once
metadata *ProviderMetadata
err error
}
)
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token")
@@ -264,26 +254,20 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
globalMetadataCache.Once.Do(func() {
t.logger.Debug("Starting global provider metadata discovery")
t.initOnce.Do(func() {
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
globalMetadataCache.metadata = metadata
globalMetadataCache.err = err
if err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v", err)
} else {
t.logger.Debug("Provider metadata discovered successfully")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
}
close(t.initComplete)
})
if globalMetadataCache.err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v", globalMetadataCache.err)
} else if globalMetadataCache.metadata != nil {
t.logger.Debug("Using cached provider metadata")
t.jwksURL = globalMetadataCache.metadata.JWKSURL
t.authURL = globalMetadataCache.metadata.AuthURL
t.tokenURL = globalMetadataCache.metadata.TokenURL
t.issuerURL = globalMetadataCache.metadata.Issuer
t.revocationURL = globalMetadataCache.metadata.RevokeURL
t.endSessionURL = globalMetadataCache.metadata.EndSessionURL
}
close(t.initComplete)
}
// discoverProviderMetadata fetches the OIDC provider metadata
@@ -636,9 +620,14 @@ func (t *TraefikOidc) RevokeToken(token string) {
// Remove from cache
t.tokenCache.Delete(token)
// Add to blacklist with default expiration
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
t.tokenBlacklist.Add(token, expiry)
// Add to blacklist
claims, err := extractClaims(token)
if err == nil {
if exp, ok := claims["exp"].(float64); ok {
expTime := time.Unix(int64(exp), 0)
t.tokenBlacklist.Add(token, expTime)
}
}
}
// RevokeTokenWithProvider revokes the token with the provider
@@ -737,30 +726,26 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
var groups []string
var roles []string
// Extract groups with type checking
if groupsClaim, exists := claims["groups"]; exists {
groupsSlice, ok := groupsClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("groups claim is not an array")
}
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
// Check for groups claim
if groupsClaim, ok := claims["groups"]; ok {
if groupsSlice, ok := groupsClaim.([]interface{}); ok {
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
}
}
}
}
// Extract roles with type checking
if rolesClaim, exists := claims["roles"]; exists {
rolesSlice, ok := rolesClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("roles claim is not an array")
}
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
// Check for roles claim
if rolesClaim, ok := claims["roles"]; ok {
if rolesSlice, ok := rolesClaim.([]interface{}); ok {
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
}
}
}
}
-503
View File
@@ -820,506 +820,3 @@ func TestOIDCHandler(t *testing.T) {
})
}
}
// TestHandleLogout tests the logout functionality
func TestHandleLogout(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Create mock revocation endpoint server
mockRevocationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Expected POST request, got %s", r.Method)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify the required parameters are present
if r.Form.Get("token") == "" {
t.Error("Missing token parameter")
}
if r.Form.Get("token_type_hint") == "" {
t.Error("Missing token_type_hint parameter")
}
w.WriteHeader(http.StatusOK)
}))
defer mockRevocationServer.Close()
tests := []struct {
name string
setupSession func(*sessions.Session)
endSessionURL string
expectedStatus int
expectedURL string
host string
}{
{
name: "Successful logout with end session endpoint",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
session.Values["refresh_token"] = "test-refresh-token"
session.Values["access_token"] = "test-access-token"
},
endSessionURL: "https://provider/end-session",
expectedStatus: http.StatusFound,
// Fix: The entire URL should be URL-encoded
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
host: "test-host",
},
{
name: "Successful logout without end session endpoint",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
session.Values["refresh_token"] = "test-refresh-token"
session.Values["access_token"] = "test-access-token"
},
endSessionURL: "",
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
host: "test-host",
},
{
name: "Logout with empty session",
setupSession: func(session *sessions.Session) {},
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
host: "test-host",
},
{
name: "Logout with invalid end session URL",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
},
endSessionURL: ":\\invalid-url",
expectedStatus: http.StatusInternalServerError,
host: "test-host",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a new TraefikOidc instance for each test
tOidc := &TraefikOidc{
store: sessions.NewCookieStore([]byte("test-secret-key")),
revocationURL: mockRevocationServer.URL,
endSessionURL: tc.endSessionURL,
scheme: "http",
logger: NewLogger("info"),
tokenBlacklist: NewTokenBlacklist(),
httpClient: &http.Client{},
clientID: "test-client-id",
clientSecret: "test-client-secret",
tokenCache: NewTokenCache(),
forceHTTPS: false,
}
// Create request with proper headers
req := httptest.NewRequest("GET", "/logout", nil)
req.Header.Set("Host", tc.host)
// Create a response recorder
rr := httptest.NewRecorder()
// Get a session
session, err := tOidc.store.Get(req, cookieName)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Setup session
tc.setupSession(session)
session.Save(req, rr)
// Copy session cookie to request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset response recorder
rr = httptest.NewRecorder()
// Handle logout
tOidc.handleLogout(rr, req)
// Check response
if rr.Code != tc.expectedStatus {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
}
// Check redirect URL if expected
if tc.expectedURL != "" {
location := rr.Header().Get("Location")
if location != tc.expectedURL {
t.Errorf("Expected redirect to %q, got %q", tc.expectedURL, location)
}
}
// Verify session is cleared
newSession, _ := tOidc.store.Get(req, cookieName)
if len(newSession.Values) > 0 {
t.Error("Session was not cleared")
}
if newSession.Options.MaxAge != -1 {
t.Error("Session MaxAge was not set to -1")
}
// Check token blacklist
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(refreshToken) {
t.Error("Refresh token was not blacklisted")
}
}
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(accessToken) {
t.Error("Access token was not blacklisted")
}
}
})
}
}
// TestRevokeTokenWithProvider tests the token revocation with provider
func TestRevokeTokenWithProvider(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
token string
tokenType string
statusCode int
expectError bool
}{
{
name: "Successful token revocation",
token: "valid-token",
tokenType: "refresh_token",
statusCode: http.StatusOK,
expectError: false,
},
{
name: "Failed token revocation",
token: "invalid-token",
tokenType: "refresh_token",
statusCode: http.StatusBadRequest,
expectError: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify request method and content type
if r.Method != "POST" {
t.Errorf("Expected POST request, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" {
t.Errorf("Expected Content-Type application/x-www-form-urlencoded, got %s", ct)
}
// Verify form values
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
if got := r.Form.Get("token"); got != tc.token {
t.Errorf("Expected token %s, got %s", tc.token, got)
}
if got := r.Form.Get("token_type_hint"); got != tc.tokenType {
t.Errorf("Expected token_type_hint %s, got %s", tc.tokenType, got)
}
if got := r.Form.Get("client_id"); got != ts.tOidc.clientID {
t.Errorf("Expected client_id %s, got %s", ts.tOidc.clientID, got)
}
if got := r.Form.Get("client_secret"); got != ts.tOidc.clientSecret {
t.Errorf("Expected client_secret %s, got %s", ts.tOidc.clientSecret, got)
}
w.WriteHeader(tc.statusCode)
}))
defer server.Close()
// Set revocation URL to test server
ts.tOidc.revocationURL = server.URL
// Test token revocation
err := ts.tOidc.RevokeTokenWithProvider(tc.token, tc.tokenType)
if tc.expectError && err == nil {
t.Error("Expected error but got nil")
}
if !tc.expectError && err != nil {
t.Errorf("Unexpected error: %v", err)
}
})
}
}
// TestRevokeToken tests the token revocation functionality
func TestRevokeToken(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
token := "test.token.with.claims"
claims := map[string]interface{}{
"exp": float64(time.Now().Add(time.Hour).Unix()),
}
// Test token revocation
t.Run("Token revocation", func(t *testing.T) {
// Create a new instance for this specific test
tOidc := &TraefikOidc{
tokenBlacklist: NewTokenBlacklist(),
tokenCache: NewTokenCache(),
}
// Cache the token
tOidc.tokenCache.Set(token, claims, time.Hour)
// Revoke the token
tOidc.RevokeToken(token)
// Verify token was removed from cache
if _, exists := tOidc.tokenCache.Get(token); exists {
t.Error("Token was not removed from cache")
}
// Verify token was added to blacklist
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
t.Error("Token was not added to blacklist")
}
})
}
// Add this new test function
func TestBuildLogoutURL(t *testing.T) {
tests := []struct {
name string
endSessionURL string
idToken string
postLogoutRedirect string
expectedURL string
expectError bool
}{
{
name: "Valid URL",
endSessionURL: "https://provider/end-session",
idToken: "test.id.token",
postLogoutRedirect: "http://example.com/",
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
expectError: false,
},
{
name: "Invalid URL",
endSessionURL: "://invalid-url",
idToken: "test.id.token",
postLogoutRedirect: "http://example.com/",
expectError: true,
},
{
name: "URL with existing query parameters",
endSessionURL: "https://provider/end-session?existing=param",
idToken: "test.id.token",
postLogoutRedirect: "http://example.com/",
expectedURL: "https://provider/end-session?existing=param&id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
expectError: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
url, err := BuildLogoutURL(tc.endSessionURL, tc.idToken, tc.postLogoutRedirect)
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if url != tc.expectedURL {
t.Errorf("Expected URL %q, got %q", tc.expectedURL, url)
}
}
})
}
}
// Add this new test function
func TestHandleExpiredToken(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
setupSession func(*sessions.Session)
expectedPath string
}{
{
name: "Basic expired token",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "expired.token"
session.Values["email"] = "test@example.com"
},
expectedPath: "/original/path",
},
{
name: "Session with additional values",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "expired.token"
session.Values["custom_value"] = "should-be-cleared"
},
expectedPath: "/another/path",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a new TraefikOidc instance for each test
tOidc := &TraefikOidc{
store: sessions.NewCookieStore([]byte("test-secret-key")),
logger: NewLogger("info"),
redirectURL: "http://example.com/callback",
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
initComplete: make(chan struct{}),
// Add this initialization of initiateAuthenticationFunc
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Mock implementation for test
http.Redirect(rw, req, "/login", http.StatusFound)
},
}
close(tOidc.initComplete)
// Create request
req := httptest.NewRequest("GET", tc.expectedPath, nil)
rr := httptest.NewRecorder()
// Get session
session, _ := tOidc.store.New(req, cookieName)
tc.setupSession(session)
// Handle expired token
tOidc.handleExpiredToken(rr, req, session)
// Verify session is cleaned
if len(session.Values) != 3 { // Should only have csrf, incoming_path, and nonce
t.Errorf("Expected 3 session values, got %d", len(session.Values))
}
// Verify required values are set
if _, ok := session.Values["csrf"].(string); !ok {
t.Error("CSRF token not set")
}
if path, ok := session.Values["incoming_path"].(string); !ok || path != tc.expectedPath {
t.Errorf("Expected path %s, got %s", tc.expectedPath, path)
}
if _, ok := session.Values["nonce"].(string); !ok {
t.Error("Nonce not set")
}
// Verify session options
if session.Options.MaxAge != defaultSessionOptions.MaxAge {
t.Error("Session MaxAge not set correctly")
}
// Verify redirect status
if rr.Code != http.StatusFound {
t.Errorf("Expected status %d, got %d", http.StatusFound, rr.Code)
}
})
}
}
// Add this new test function
func TestExtractGroupsAndRoles(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
claims map[string]interface{}
expectGroups []string
expectRoles []string
expectError bool
}{
{
name: "Valid groups and roles",
claims: map[string]interface{}{
"groups": []interface{}{"group1", "group2"},
"roles": []interface{}{"role1", "role2"},
},
expectGroups: []string{"group1", "group2"},
expectRoles: []string{"role1", "role2"},
expectError: false,
},
{
name: "Empty groups and roles",
claims: map[string]interface{}{
"groups": []interface{}{},
"roles": []interface{}{},
},
expectGroups: []string{},
expectRoles: []string{},
expectError: false,
},
{
name: "Invalid groups format",
claims: map[string]interface{}{
"groups": "not-an-array",
"roles": []interface{}{"role1"},
},
expectError: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a test token with the claims
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Compare groups
if !stringSliceEqual(groups, tc.expectGroups) {
t.Errorf("Expected groups %v, got %v", tc.expectGroups, groups)
}
// Compare roles
if !stringSliceEqual(roles, tc.expectRoles) {
t.Errorf("Expected roles %v, got %v", tc.expectRoles, roles)
}
}
})
}
}
// Helper function to compare string slices
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
-1
View File
@@ -30,7 +30,6 @@ type Config struct {
ExcludedURLs []string `json:"excludedURLs"`
AllowedUserDomains []string `json:"allowedUserDomains"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
HTTPClient *http.Client
}