mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bef4212c57 | |||
| 1fee2f9e9a | |||
| 11bc6f3e31 |
@@ -430,6 +430,34 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
} else {
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
}
|
||||
if len(roles) > 0 {
|
||||
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
|
||||
}
|
||||
}
|
||||
|
||||
// Check allowed roles and groups
|
||||
if len(t.allowedRolesAndGroups) > 0 {
|
||||
allowed := false
|
||||
for _, roleOrGroup := range append(groups, roles...) {
|
||||
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set user information in headers
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
|
||||
|
||||
+185
@@ -1342,6 +1342,191 @@ func TestExtractGroupsAndRoles(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedRolesAndGroups map[string]struct{}
|
||||
claims map[string]interface{}
|
||||
setupSession func(*SessionData)
|
||||
expectedStatus int
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "User with allowed role",
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"admin": {},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"admin", "user"},
|
||||
"groups": []interface{}{"group1"},
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Roles": "admin,user",
|
||||
"X-User-Groups": "group1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "User with allowed group",
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"allowed-group": {},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"allowed-group"},
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Roles": "user",
|
||||
"X-User-Groups": "allowed-group",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "User without allowed roles or groups",
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"admin": {},
|
||||
"allowed-group": {},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"regular-group"},
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "No role/group restrictions",
|
||||
allowedRolesAndGroups: map[string]struct{}{},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"regular-group"},
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Roles": "user",
|
||||
"X-User-Groups": "regular-group",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Claims without roles and groups",
|
||||
allowedRolesAndGroups: map[string]struct{}{},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create token with claims
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
// Create test handler
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Configure OIDC middleware
|
||||
tOidc := ts.tOidc
|
||||
tOidc.next = nextHandler
|
||||
tOidc.allowedRolesAndGroups = tc.allowedRolesAndGroups
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Set up session
|
||||
session, err := tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
tc.setupSession(session)
|
||||
session.SetAccessToken(token)
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Copy cookies to the new request
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Reset response recorder
|
||||
rr = httptest.NewRecorder()
|
||||
|
||||
// Serve request
|
||||
tOidc.ServeHTTP(rr, req)
|
||||
|
||||
// Check status code
|
||||
if rr.Code != tc.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
||||
}
|
||||
|
||||
// Check headers if status is OK
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
for header, expectedValue := range tc.expectedHeaders {
|
||||
if value := req.Header.Get(header); value != expectedValue {
|
||||
t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to compare string slices
|
||||
func stringSliceEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
|
||||
+170
-11
@@ -12,6 +12,20 @@ const (
|
||||
mainCookieName = "_raczylo_oidc" // Main session cookie
|
||||
accessTokenCookie = "_raczylo_oidc_access" // Access token cookie
|
||||
refreshTokenCookie = "_raczylo_oidc_refresh" // Refresh token cookie
|
||||
maxCookieSize = 2000 // Max size for each chunk to stay within 4096-byte cookie limit
|
||||
|
||||
// REASON:
|
||||
// Let x be the maximum size of the chunk (maxCookieSize).
|
||||
// Encrypted size = x + 28 bytes
|
||||
// Base64-encoded size = ((x + 28) * 4) / 3 bytes
|
||||
// ((x + 28) * 4) / 3 <= 4096
|
||||
// Multiply both sides by 3:
|
||||
// 4 * (x + 28) <= 4096 * 3
|
||||
// 4 * (x + 28) <= 12288
|
||||
// Divide both sides by 4:
|
||||
// x + 28 <= 3072
|
||||
// Subtract 28 from both sides:
|
||||
// x <= 3044
|
||||
)
|
||||
|
||||
// SessionManager handles multiple session cookies
|
||||
@@ -60,20 +74,44 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
|
||||
sessionData := &SessionData{
|
||||
manager: sm,
|
||||
request: r,
|
||||
mainSession: mainSession,
|
||||
accessSession: accessSession,
|
||||
refreshSession: refreshSession,
|
||||
}
|
||||
|
||||
// Retrieve chunked access token sessions
|
||||
sessionData.accessTokenChunks = sm.getTokenChunkSessions(r, accessTokenCookie)
|
||||
// Retrieve chunked refresh token sessions
|
||||
sessionData.refreshTokenChunks = sm.getTokenChunkSessions(r, refreshTokenCookie)
|
||||
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
// getTokenChunkSessions retrieves sessions for token chunks
|
||||
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string) map[int]*sessions.Session {
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", baseName, i)
|
||||
session, err := sm.store.Get(r, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
// No more sessions
|
||||
break
|
||||
}
|
||||
chunks[i] = session
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
// SessionData holds all session information
|
||||
type SessionData struct {
|
||||
manager *SessionManager
|
||||
mainSession *sessions.Session
|
||||
accessSession *sessions.Session
|
||||
refreshSession *sessions.Session
|
||||
manager *SessionManager
|
||||
request *http.Request
|
||||
mainSession *sessions.Session
|
||||
accessSession *sessions.Session
|
||||
refreshSession *sessions.Session
|
||||
accessTokenChunks map[int]*sessions.Session
|
||||
refreshTokenChunks map[int]*sessions.Session
|
||||
}
|
||||
|
||||
// Save saves all session data
|
||||
@@ -81,20 +119,42 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
|
||||
|
||||
// Set options for all sessions
|
||||
sd.mainSession.Options = sd.manager.getSessionOptions(isSecure)
|
||||
sd.accessSession.Options = sd.manager.getSessionOptions(isSecure)
|
||||
sd.refreshSession.Options = sd.manager.getSessionOptions(isSecure)
|
||||
options := sd.manager.getSessionOptions(isSecure)
|
||||
sd.mainSession.Options = options
|
||||
sd.accessSession.Options = options
|
||||
sd.refreshSession.Options = options
|
||||
|
||||
// Save main session
|
||||
if err := sd.mainSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save main session: %w", err)
|
||||
}
|
||||
|
||||
// Save access token session
|
||||
if err := sd.accessSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save access token session: %w", err)
|
||||
}
|
||||
|
||||
// Save refresh token session
|
||||
if err := sd.refreshSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save refresh token session: %w", err)
|
||||
}
|
||||
|
||||
// Save access token chunks
|
||||
for _, session := range sd.accessTokenChunks {
|
||||
session.Options = options
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save access token chunk session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save refresh token chunks
|
||||
for _, session := range sd.refreshTokenChunks {
|
||||
session.Options = options
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save refresh token chunk session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -115,9 +175,23 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
delete(sd.refreshSession.Values, k)
|
||||
}
|
||||
|
||||
// Clear chunk sessions
|
||||
sd.clearTokenChunks(r, sd.accessTokenChunks)
|
||||
sd.clearTokenChunks(r, sd.refreshTokenChunks)
|
||||
|
||||
return sd.Save(r, w)
|
||||
}
|
||||
|
||||
// clearTokenChunks clears chunked token sessions
|
||||
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
|
||||
for _, session := range chunks {
|
||||
session.Options.MaxAge = -1
|
||||
for k := range session.Values {
|
||||
delete(session.Values, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthenticated returns authentication status
|
||||
func (sd *SessionData) GetAuthenticated() bool {
|
||||
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
||||
@@ -132,23 +206,108 @@ func (sd *SessionData) SetAuthenticated(value bool) {
|
||||
// GetAccessToken returns the access token
|
||||
func (sd *SessionData) GetAccessToken() string {
|
||||
token, _ := sd.accessSession.Values["token"].(string)
|
||||
return token
|
||||
if token != "" {
|
||||
return token
|
||||
}
|
||||
|
||||
// Reassemble token from chunks
|
||||
if len(sd.accessTokenChunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
for i := 0; ; i++ {
|
||||
session, ok := sd.accessTokenChunks[i]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
chunk, _ := session.Values["token_chunk"].(string)
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
return strings.Join(chunks, "")
|
||||
}
|
||||
|
||||
// SetAccessToken sets the access token
|
||||
func (sd *SessionData) SetAccessToken(token string) {
|
||||
sd.accessSession.Values["token"] = token
|
||||
// Clear existing chunks
|
||||
sd.clearTokenChunks(sd.request, sd.accessTokenChunks)
|
||||
sd.accessTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
if len(token) <= maxCookieSize {
|
||||
sd.accessSession.Values["token"] = token
|
||||
} else {
|
||||
// Split token into chunks
|
||||
sd.accessSession.Values["token"] = ""
|
||||
chunks := splitIntoChunks(token, maxCookieSize)
|
||||
for i, chunk := range chunks {
|
||||
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
||||
session, _ := sd.manager.store.Get(sd.request, sessionName)
|
||||
session.Values["token_chunk"] = chunk
|
||||
sd.accessTokenChunks[i] = session
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetRefreshToken returns the refresh token
|
||||
func (sd *SessionData) GetRefreshToken() string {
|
||||
token, _ := sd.refreshSession.Values["token"].(string)
|
||||
return token
|
||||
if token != "" {
|
||||
return token
|
||||
}
|
||||
|
||||
// Reassemble token from chunks
|
||||
if len(sd.refreshTokenChunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
for i := 0; ; i++ {
|
||||
session, ok := sd.refreshTokenChunks[i]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
chunk, _ := session.Values["token_chunk"].(string)
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
return strings.Join(chunks, "")
|
||||
}
|
||||
|
||||
// SetRefreshToken sets the refresh token
|
||||
func (sd *SessionData) SetRefreshToken(token string) {
|
||||
sd.refreshSession.Values["token"] = token
|
||||
// Clear existing chunks
|
||||
sd.clearTokenChunks(sd.request, sd.refreshTokenChunks)
|
||||
sd.refreshTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
if len(token) <= maxCookieSize {
|
||||
sd.refreshSession.Values["token"] = token
|
||||
} else {
|
||||
// Split token into chunks
|
||||
sd.refreshSession.Values["token"] = ""
|
||||
chunks := splitIntoChunks(token, maxCookieSize)
|
||||
for i, chunk := range chunks {
|
||||
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
||||
session, _ := sd.manager.store.Get(sd.request, sessionName)
|
||||
session.Values["token_chunk"] = chunk
|
||||
sd.refreshTokenChunks[i] = session
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// splitIntoChunks splits a string into chunks of specified size
|
||||
func splitIntoChunks(s string, chunkSize int) []string {
|
||||
var chunks []string
|
||||
for len(s) > 0 {
|
||||
if len(s) > chunkSize {
|
||||
chunks = append(chunks, s[:chunkSize])
|
||||
s = s[chunkSize:]
|
||||
} else {
|
||||
chunks = append(chunks, s)
|
||||
break
|
||||
}
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
// GetCSRF returns the CSRF token
|
||||
|
||||
+111
-42
@@ -2,59 +2,128 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSessionManager tests the SessionManager functionality
|
||||
func TestSessionManager(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
manager := NewSessionManager("test-secret-key", false, logger)
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
session, err := manager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
tests := []struct {
|
||||
name string
|
||||
authenticated bool
|
||||
email string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
expectedCookieCount int
|
||||
}{
|
||||
{
|
||||
name: "Short tokens",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: "shortaccesstoken",
|
||||
refreshToken: "shortrefreshtoken",
|
||||
expectedCookieCount: 3, // main, access, refresh
|
||||
},
|
||||
{
|
||||
name: "Long tokens exceeding 4096 bytes",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: strings.Repeat("x", 5000),
|
||||
refreshToken: strings.Repeat("y", 6000),
|
||||
// Recalculate expected cookies based on new maxCookieSize
|
||||
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
|
||||
},
|
||||
{
|
||||
name: "REALLY long tokens, exceeding 25000 bytes",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: strings.Repeat("x", 25000),
|
||||
refreshToken: strings.Repeat("y", 25000),
|
||||
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated session",
|
||||
authenticated: false,
|
||||
email: "",
|
||||
accessToken: "",
|
||||
refreshToken: "",
|
||||
expectedCookieCount: 3, // main, access, refresh
|
||||
},
|
||||
}
|
||||
|
||||
// Test setting and getting values
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetAccessToken("test.access.token")
|
||||
session.SetRefreshToken("test.refresh.token")
|
||||
for _, tc := range tests {
|
||||
tc := tc // Capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
session, err := ts.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Verify cookies are set
|
||||
cookies := rr.Result().Cookies()
|
||||
if len(cookies) != 3 {
|
||||
t.Errorf("Expected 3 cookies, got %d", len(cookies))
|
||||
}
|
||||
// Set session values
|
||||
session.SetAuthenticated(tc.authenticated)
|
||||
session.SetEmail(tc.email)
|
||||
session.SetAccessToken(tc.accessToken)
|
||||
session.SetRefreshToken(tc.refreshToken)
|
||||
|
||||
// Create a new request with the cookies
|
||||
newReq := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range cookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
// Save session
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get the session again and verify values
|
||||
newSession, err := manager.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get new session: %v", err)
|
||||
}
|
||||
// Verify cookies are set
|
||||
cookies := rr.Result().Cookies()
|
||||
if len(cookies) != tc.expectedCookieCount {
|
||||
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
|
||||
}
|
||||
|
||||
if !newSession.GetAuthenticated() {
|
||||
t.Error("Authentication status not preserved")
|
||||
}
|
||||
if email := newSession.GetEmail(); email != "test@example.com" {
|
||||
t.Errorf("Expected email test@example.com, got %s", email)
|
||||
}
|
||||
if token := newSession.GetAccessToken(); token != "test.access.token" {
|
||||
t.Errorf("Expected access token test.access.token, got %s", token)
|
||||
}
|
||||
if token := newSession.GetRefreshToken(); token != "test.refresh.token" {
|
||||
t.Errorf("Expected refresh token test.refresh.token, got %s", token)
|
||||
// Create a new request with the cookies
|
||||
newReq := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range cookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Get the session again and verify values
|
||||
newSession, err := ts.sessionManager.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get new session: %v", err)
|
||||
}
|
||||
|
||||
if newSession.GetAuthenticated() != tc.authenticated {
|
||||
t.Errorf("Authentication status not preserved")
|
||||
}
|
||||
if email := newSession.GetEmail(); email != tc.email {
|
||||
t.Errorf("Expected email %s, got %s", tc.email, email)
|
||||
}
|
||||
if token := newSession.GetAccessToken(); token != tc.accessToken {
|
||||
t.Errorf("Access token not preserved")
|
||||
}
|
||||
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
|
||||
t.Errorf("Refresh token not preserved")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func calculateExpectedCookieCount(accessToken, refreshToken string) int {
|
||||
count := 3 // main, access, refresh
|
||||
|
||||
// Calculate number of chunks for access token
|
||||
accessChunks := len(splitIntoChunks(accessToken, maxCookieSize))
|
||||
if accessChunks > 1 {
|
||||
count += accessChunks
|
||||
}
|
||||
|
||||
// Calculate number of chunks for refresh token
|
||||
refreshChunks := len(splitIntoChunks(refreshToken, maxCookieSize))
|
||||
if refreshChunks > 1 {
|
||||
count += refreshChunks
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
Reference in New Issue
Block a user