mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
546ceb949c
* fix(security): encrypt session cookies + fail closed on invalid config
Batch 1 of security audit remediation (ranks 1, 2, 6).
- session.go: derive independent HMAC + AES-256 keys via stdlib HKDF-SHA256
and build the gorilla cookie store with both, so session cookies are now
encrypted, not merely signed. The single-key store previously left OIDC
access/refresh/ID tokens recoverable from raw cookie bytes. Cookie format
changes, so existing sessions are invalidated on deploy (one-time re-login).
- main.go: call config.Validate() at construction and error out on failure,
instead of silently substituting a public hardcoded encryption key for
empty/short keys (which allowed session forgery). The yaegi analyzer
passes via .traefik.yml testData.
- settings.go: isValidSecureURL permits plaintext HTTP for loopback hosts
only (RFC 8252); remote providers must still use HTTPS.
- tests: complete configs that did not satisfy Validate(); add regression
tests in security_audit_fixes_test.go.
Configs below documented minimums (rateLimit < 10, key < 32 chars) are now
rejected at startup (fail closed).
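
The loopback-only HTTP rule from the settings.go item above can be sketched as follows. This is a minimal illustration, not the plugin's code: `isValidSecureURLSketch` and `isLoopbackHost` are hypothetical names, and treating the literal name "localhost" as loopback is an assumption.

```go
package main

import (
	"fmt"
	"net"
	"net/url"
	"strings"
)

// isLoopbackHost reports whether host is "localhost" or a loopback IP literal.
// Hypothetical helper for illustration; the plugin's own check may differ.
func isLoopbackHost(host string) bool {
	if strings.EqualFold(host, "localhost") {
		return true
	}
	ip := net.ParseIP(host)
	return ip != nil && ip.IsLoopback()
}

// isValidSecureURLSketch accepts any absolute https URL, and accepts http
// only when the host is loopback. Every other scheme is rejected.
func isValidSecureURLSketch(raw string) bool {
	u, err := url.Parse(raw)
	if err != nil || u.Hostname() == "" {
		return false
	}
	switch u.Scheme {
	case "https":
		return true
	case "http":
		return isLoopbackHost(u.Hostname())
	default:
		return false
	}
}

func main() {
	for _, raw := range []string{
		"https://idp.example.com",
		"http://127.0.0.1:8080",
		"http://[::1]:8080",
		"http://idp.example.com",
	} {
		fmt.Printf("%-26s allowed=%v\n", raw, isValidSecureURLSketch(raw))
	}
}
```

Comparing the parsed hostname rather than a string prefix means a host such as localhost.evil.example is not mistaken for loopback.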
* fix(security): validate discovered OIDC endpoints + pin introspection host
Batch 2 of security audit remediation (ranks 3, 4).
- url_helpers.go: add validateDiscoveredEndpoint, an SSRF screen for endpoints
taken from the provider discovery document (jwks_uri, token, authorization,
revocation, end_session, introspection, registration). Blocks link-local
(cloud metadata 169.254.169.254), multicast, unspecified and private
addresses (unless allowPrivateIPAddresses); blocks loopback unless the
configured providerURL is itself loopback (dev/test). Cross-domain JWKS
hosts (e.g. Google) stay allowed. Add sameHost helper.
- main.go: updateMetadataEndpoints screens every discovered endpoint and
blanks any that fail (fail closed downstream). The introspection endpoint
carries the client secret via HTTP Basic, so it is additionally pinned to
the providerURL host to stop a poisoned discovery document exfiltrating the
secret to an attacker-controlled host.
- tests: regression tests for the SSRF guard and the host pin.
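
A rough sketch of the address screen and host pin described in this batch, under stated assumptions: `screenEndpoint` and `sameHostSketch` are hypothetical stand-ins, only IP-literal hosts are classified (no DNS resolution), and comparing host plus port is a guess at what sameHost does.

```go
package main

import (
	"errors"
	"fmt"
	"net"
	"net/url"
	"strings"
)

// screenEndpoint is a simplified, hypothetical version of the screen described
// above. Hostnames that are not IP literals pass through unchecked here.
func screenEndpoint(raw string, allowPrivate, providerIsLoopback bool) error {
	u, err := url.Parse(raw)
	if err != nil || u.Hostname() == "" {
		return errors.New("unparseable endpoint")
	}
	ip := net.ParseIP(u.Hostname())
	if ip == nil {
		return nil // a DNS name: out of scope for this sketch
	}
	switch {
	case ip.IsLinkLocalUnicast(), ip.IsMulticast(), ip.IsUnspecified():
		return errors.New("blocked address class")
	case ip.IsLoopback() && !providerIsLoopback:
		return errors.New("loopback endpoint with a non-loopback provider")
	case ip.IsPrivate() && !allowPrivate:
		return errors.New("private address not allowed")
	}
	return nil
}

// sameHostSketch compares the host (including port) of two URLs,
// case-insensitively. Assumption: the real helper may normalize differently.
func sameHostSketch(a, b string) bool {
	ua, errA := url.Parse(a)
	ub, errB := url.Parse(b)
	if errA != nil || errB != nil || ua.Host == "" {
		return false
	}
	return strings.EqualFold(ua.Host, ub.Host)
}

func main() {
	fmt.Println(screenEndpoint("http://169.254.169.254/latest/meta-data", true, true))
	fmt.Println(screenEndpoint("https://10.0.0.5/token", false, false))
	fmt.Println(screenEndpoint("https://www.googleapis.com/oauth2/v3/certs", false, false))
	fmt.Println(sameHostSketch("https://evil.example/introspect", "https://idp.example.com"))
}
```

In this sketch the link-local check runs before the private-address allowance, so enabling private addresses never unblocks the cloud metadata address.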
* fix(security): close open redirects + anchor excluded-URL matching
Batch 3 of security audit remediation (ranks 5, 14, 15).
- auth_flow.go: run the stored incoming path through normalizeLogoutPath
before using it as the post-login redirect, so //evil.com and /\evil.com
payloads become host-relative (open-redirect, rank 5).
- url_helpers.go: excluded-URL matching is anchored at a natural boundary
(exact, sub-path "/", or file extension "."), so excluding "/public" no
longer also bypasses auth on "/publicsecret"; "/favicon" still matches
"/favicon.ico" (rank 14).
- internal/utils: X-Forwarded-Host is sanitized (first value only; reject
CRLF/whitespace/multi-value) before building redirect URLs (rank 15).
- helpers.go: the logout redirect used when there is no provider end-session
endpoint is host-relative, never an absolute URL derived from the
client-controllable request host (logout open-redirect, rank 15).
- tests: update two logout cases that asserted the old absolute redirect;
add regression tests.
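
The boundary-anchored excluded-URL match from the url_helpers.go item can be illustrated with a short sketch. `matchesExcluded` is a hypothetical name, and accepting patterns that already end in "/" is an assumption that the commit message does not state.

```go
package main

import (
	"fmt"
	"strings"
)

// matchesExcluded reports whether path is covered by the excluded prefix,
// accepting only an exact match or a match that ends at "/" or ".".
func matchesExcluded(path, excluded string) bool {
	if excluded == "" || !strings.HasPrefix(path, excluded) {
		return false
	}
	if len(path) == len(excluded) {
		return true // exact match
	}
	if strings.HasSuffix(excluded, "/") {
		return true // assumption: the pattern itself ends at a segment boundary
	}
	next := path[len(excluded)]
	return next == '/' || next == '.'
}

func main() {
	for _, p := range []string{"/public", "/public/logo.png", "/publicsecret", "/favicon.ico"} {
		fmt.Printf("%-18s public=%v favicon=%v\n",
			p, matchesExcluded(p, "/public"), matchesExcluded(p, "/favicon"))
	}
}
```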
* fix(security): reject unverified Azure tokens; fix transport TLS reuse
Batch 4 of security audit remediation (ranks 7, 11).
- token_validation_rs.go: an Azure nonce-bearing access token that cannot be
cryptographically verified no longer returns "authenticated" when there is
no ID token to corroborate it; it refreshes (if possible) or forces
re-authentication instead of failing open (rank 7).
- http_client_pool.go: the at-limit transport-reuse path now takes the write
lock before mutating refCount (fixes a data race) and only reuses a
transport whose TLS settings (CA pool + InsecureSkipVerify) match the
caller's, never one with a different trust store; if none matches it returns
nil so the caller falls back to a verifying default transport (rank 11).
- tests: add a transport-pool TLS-isolation regression test.
* fix(security): stop logging templated header values (token leak)
Batch 5 of security audit remediation (rank 16).
middleware.go: templated downstream headers commonly carry the access token
(e.g. "Authorization: Bearer {{.AccessToken}}"). The debug log line printed
the full header value, leaking credentials into logs. Log the header name and
byte length instead.
* fix(security): cache-key collision, cache-config divergence, fleet cleanup
Batch 6 of security audit remediation (ranks 9, 10, 12).
- token_manager.go: detectTokenType keys its cache on a SHA-256 hash of the
full token instead of the first 32 chars (which are only the base64url JWT
header). Distinct tokens sharing alg+kid no longer collide and get
mis-classified (rank 10).
- cache_manager.go: the process-global cache manager is initialized once and
shared across plugin instances; it now logs a loud warning when a later
instance requests a different explicit Redis backend that is silently
ignored, surfacing the cross-instance state-isolation hazard (rank 9).
- singleton_resources.go / main.go / utilities.go: track a process-global live
instance count; the shared singleton-token-cleanup task is stopped only when
the LAST instance shuts down, so one instance's Close() (e.g. a config reload)
no longer kills cleanup for surviving instances (rank 12).
- tests: update TestDetectTokenTypeCaching for the new key; add regression tests.
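
To see why keying the cache on the first 32 characters collides, consider two tokens that share a JWT header. The sketch below uses hypothetical helper names (`prefixKey`, `hashKey`) and illustrative payload segments; the hex encoding of the digest is an assumption about the key format.

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// prefixKey mimics the old scheme: the first 32 characters of the token.
func prefixKey(token string) string {
	if len(token) > 32 {
		return token[:32]
	}
	return token
}

// hashKey mimics the new scheme: a SHA-256 digest of the whole token.
func hashKey(token string) string {
	sum := sha256.Sum256([]byte(token))
	return hex.EncodeToString(sum[:])
}

func main() {
	// Base64url of {"alg":"RS256","typ":"JWT"}; it is 36 characters long, so
	// a 32-character prefix never reaches the payload segment.
	header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
	a := header + ".payload-for-alice.sig-a"
	b := header + ".payload-for-bob.sig-b"

	fmt.Println("prefix keys collide:", prefixKey(a) == prefixKey(b))
	fmt.Println("hash keys collide:  ", hashKey(a) == hashKey(b))
}
```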
* fix(security): bound introspection cache + cookie lifetime to config
Batch 7 of security audit remediation (ranks 8, 13).
- token_introspection.go: when requireTokenIntrospection is enabled, cap the
positive introspection-result cache at 30s (instead of 5m) so a token
revoked at the provider stops passing within ~30s, matching the operator's
near-real-time revocation expectation (rank 8).
- session.go: bind the cookie store's MaxAge to the configured sessionMaxAge,
so the cookie codec's cryptographic timestamp validity is no longer fixed at
gorilla's 30-day default; a stolen cookie is valid only for the configured
session lifetime (rank 13).
- tests: add a cookie-lifetime regression test.
* fix(security): low-severity hardening (cache, DoS caps, PKCE, throttle)
Batch 8 of security audit remediation — low severity
(ranks 24, 25, 27, 29, 31, 36, 37, 41, 45, 46, 49).
- universal_cache.go: updateLocalCache updates an existing key in place instead
of orphaning its LRU element and double-counting currentSize/currentMemory
(rank 36 — the only production-reachable bug in this batch).
- jwk.go / metadata_cache.go / token_introspection.go: bound response bodies
with io.LimitReader (1 MiB) to prevent memory exhaustion from a hostile or
buggy provider (ranks 24, 25).
- jwk.go: skip JWKs not usable for signature verification (use != sig, or
key_ops without "verify") when building the key set (rank 49).
- auth_flow.go: fail closed at the callback when PKCE is enabled but the code
verifier is missing, instead of silently dropping it (rank 27).
- utilities.go / main.go: match allowedUserDomains case-insensitively (rank 31).
- bearer_auth.go: a single success no longer wipes an active per-IP penalty;
the counter resets only when no penalty is in effect (rank 29).
- main.go: handle (not discard) the NewSessionManager error (rank 37).
- error_recovery.go: take a write lock in isServiceDegraded (it deletes from a
map); compare retryable-error substrings case-insensitively (ranks 45, 46).
- singleton_resources.go: bind the generic-cache cleanup goroutine to the
resource-manager shutdown channel so it cannot outlive its owner (rank 41).
- tests: update the bearer throttle test to the corrected penalty semantics.
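
The io.LimitReader bound on provider responses might look roughly like this. `readBounded` and `readBoundedString` are hypothetical helpers, and reading one byte past the limit to report an oversized body (rather than silently truncating it) is a design choice of this sketch, not something the commit message specifies.

```go
package main

import (
	"errors"
	"fmt"
	"io"
	"strings"
)

const maxBodyBytes = 1 << 20 // 1 MiB, the cap named in the commit message

var errBodyTooLarge = errors.New("response body exceeds limit")

// readBounded reads at most limit bytes from r. It asks for limit+1 bytes so
// that an oversized body is detected and reported instead of being truncated.
func readBounded(r io.Reader, limit int64) ([]byte, error) {
	data, err := io.ReadAll(io.LimitReader(r, limit+1))
	if err != nil {
		return nil, err
	}
	if int64(len(data)) > limit {
		return nil, errBodyTooLarge
	}
	return data, nil
}

// readBoundedString is a convenience wrapper used by the demonstration.
func readBoundedString(s string, limit int64) (string, error) {
	data, err := readBounded(strings.NewReader(s), limit)
	return string(data), err
}

func main() {
	body, err := readBoundedString(`{"keys":[]}`, maxBodyBytes)
	fmt.Println(body, err)

	_, err = readBoundedString(strings.Repeat("x", 2048), 1024)
	fmt.Println(err)
}
```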
* fix(security): header sanitization, issuer pinning, fail-closed paths
Batch 9 of security audit remediation (ranks 18, 19, 20, 21, 22, 30, 33, 34).
- middleware.go / bearer_auth.go: sanitize claim-derived values on the cookie
auth path before injecting them into downstream headers. Drop group/role and
identifier values containing control chars, bidi-override runes, or the
, ; = delimiters (a comma would inject phantom entries into X-User-Groups);
reject control/bidi/over-length in rendered templated header output (but
permit , ; = in free-form values such as a bearer token). The bearer path
already sanitized; the cookie path did not (ranks 33, 34).
- main.go / metadata_cache.go: pin the discovered issuer to the configured
provider host (sameHost) and refuse/never-cache a mismatch, so a poisoned
discovery document cannot redefine the JWT trust anchor (ranks 21, 22).
- token_introspection.go: when a distinct API audience is configured, fail
closed on a missing or mismatched introspection audience; aud parsed as
string-or-array per RFC 7662 (rank 19).
- logout.go: front-channel logout requires a matching issuer; an empty iss is
rejected (blocks unauthenticated forced-logout via a known sid) (rank 30).
- token_validation_rs.go: an opaque access token with no ID token and no
successful introspection fails closed (re-auth) instead of authenticating
(ranks 18, 20).
- tests: realistic same-host provider mocks; regression tests for the header
sanitization distinction and the fail-closed paths.
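
The string-or-array handling of `aud` in an introspection response, together with the fail-closed audience check, can be sketched as below. The type and function names are hypothetical, and the response struct carries only the two fields this example needs.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// audience accepts either a JSON string or a JSON array of strings.
type audience []string

func (a *audience) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		return nil // treat null like a missing claim
	}
	var single string
	if err := json.Unmarshal(data, &single); err == nil {
		*a = audience{single}
		return nil
	}
	var many []string
	if err := json.Unmarshal(data, &many); err != nil {
		return fmt.Errorf("aud must be a string or an array of strings: %w", err)
	}
	*a = many
	return nil
}

// introspectionResponse holds only the fields this sketch needs.
type introspectionResponse struct {
	Active bool     `json:"active"`
	Aud    audience `json:"aud"`
}

// audienceAllowed fails closed: an inactive token, a parse error, or a missing
// or mismatched aud all return false.
func audienceAllowed(body, want string) bool {
	var resp introspectionResponse
	if err := json.Unmarshal([]byte(body), &resp); err != nil || !resp.Active {
		return false
	}
	for _, v := range resp.Aud {
		if v == want {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(audienceAllowed(`{"active":true,"aud":"api"}`, "api"))
	fmt.Println(audienceAllowed(`{"active":true,"aud":["web","api"]}`, "api"))
	fmt.Println(audienceAllowed(`{"active":true}`, "api"))
}
```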
* chore(security): remove unwired dead code with latent footguns
Batch 10 of security audit remediation — delete confirmed-dead, unwired
subsystems (ranks 26, 35, 50). None had a production caller (grep-verified);
removal eliminates the latent footguns and ~2.1k lines of dead code.
- token_validator.go (deleted): an unused *TokenValidator whose validateJWT set
Valid=true with NO signature verification — a severe footgun if ever wired
(rank 50). The wired RS-aware validators are unaffected.
- security_monitoring.go (deleted): an unused *SecurityMonitor / ExtractClientIP
that trusted spoofable X-Forwarded-For / X-Real-IP. The live bearer throttle
uses clientIPForBearer (RemoteAddr-only), unchanged (rank 35).
- dynamic_client_registration.go: removed the RFC 7592 management methods
(Update/Read/DeleteClientRegistration) that dereferenced an attacker-
influenced RegistrationClientURI with the registration token attached and no
HTTPS/SSRF gate, and had no callers. The wired RFC 7591 RegisterClient and
credential-store helpers are kept (rank 26).
- tests: removed the tests covering the deleted code.
* chore: add Makefile with yaegi load validation
No Makefile existed. The new `yaegi-validate` target interprets the plugin
under the yaegi interpreter the same way Traefik loads it, catching yaegi-only
incompatibilities (unsupported stdlib symbols, reflection edge cases) that the
native `go build` / `go test` toolchain does not. Importing the plugin forces
yaegi to interpret every file plus its vendored deps; CreateConfig + New
exercise the instantiation path.
- cmd/yaegicheck/main.go: the load driver, marked //go:build ignore so it is
excluded from `go build ./...` (avoids VCS-stamping a main binary, which
fails in git-worktree layouts) yet is run explicitly by yaegi.
- Makefile: build / fmt / vet / lint / test / vendor / yaegi-validate / check
targets; `make check` runs vet + tests + yaegi-validate.
Verified: `make yaegi-validate` passes on this branch — the HKDF cookie
encryption, net-based endpoint validation, and claim sanitizers all interpret
and instantiate cleanly under yaegi.
* ci: bump workflow Go toolchain to 1.25; pin yaegi-validate to v0.16.1
Traefik v3.7.1 (the deployed version) is built with `go 1.25.0`, so the PR and
release workflows now use Go 1.25.x to match the toolchain Traefik uses.
Important distinction: the CI Go version is the build TOOLCHAIN. The plugin's
actual interpreter-compatibility ceiling is the yaegi version Traefik bundles
(v0.16.1, which declares go 1.21 and ships a ~Go 1.22 stdlib symbol surface),
NOT the CI Go version. That ceiling is enforced by `make yaegi-validate` plus
the go.mod language directive — e.g. it is why HKDF is hand-rolled with
hmac+sha256 rather than Go 1.24's crypto/hkdf, which yaegi v0.16.1 lacks.
Also pin Makefile YAEGI_VERSION to v0.16.1 (what Traefik v3.7.1 vendors) so
yaegi-validate exercises the real deployed interpreter instead of @latest,
which could pass on a newer yaegi that supports symbols the deployed one does
not.
* docs: align README/CONFIGURATION with branch behavior changes
- excludedURLs: documented as segment/extension-boundary matching (was
"prefix-matched") — "/public" no longer also matches "/publicsecret" (rank 14).
- Front-channel logout now requires a matching `iss`; requests without one are
rejected with 400 (rank 30).
- Add an "Upgrading from an earlier release" note: session cookies are now
AES-256 encrypted with lifetime tracking sessionMaxAge (one-time re-login on
upgrade), and invalid configuration (rateLimit < 10, key < 32 bytes, missing
callbackURL, non-HTTPS remote providerURL) now fails closed at startup.
* fix: remove staticcheck-flagged unused functions; wire staticcheck into make check
CI Static Analysis (standalone staticcheck) failed with U1000 "unused":
- dynamic_client_registration.go: deleteCredentialsFromStore — its only caller
was the RFC 7592 DeleteClientRegistration removed in the dead-code batch.
- token_test.go: createTestJWTSimple — its only callers were the TokenValidator
tests removed in the same batch.
Both confirmed to have zero remaining callers and removed. build / vet /
go test ./... / staticcheck ./... all green.
The pre-commit hook runs golangci-lint, but CI runs standalone staticcheck
(which flags U1000). Add a `staticcheck` Makefile target and include it in
`make check` so this class of finding is caught locally before push.
* fix(test): stabilize flaky TestWorkerPool_TaskPanic
tasksFailed is incremented in the worker's deferred recover(), which runs after the panicking task's own defer wg.Done(). wg.Wait() could therefore return before the failure was recorded, so reading the counter immediately raced and flaked on slow CI runners. Poll until the failure lands (2s budget) instead. Verified 200x plain + 50x under -race/GOMAXPROCS=1.
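
The polling approach described for the flaky test can be sketched generically. `waitFor` is a hypothetical helper, and the goroutine in `main` only stands in for the worker's deferred recover(); the real test's names and timings may differ.

```go
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// waitFor polls cond until it returns true or the budget elapses.
// It always evaluates cond at least once.
func waitFor(cond func() bool, budget, interval time.Duration) bool {
	deadline := time.Now().Add(budget)
	for {
		if cond() {
			return true
		}
		if time.Now().After(deadline) {
			return false
		}
		time.Sleep(interval)
	}
}

func main() {
	var failed int64

	// Stands in for a counter that is updated shortly after wg.Wait() returns.
	go func() {
		time.Sleep(20 * time.Millisecond)
		atomic.AddInt64(&failed, 1)
	}()

	ok := waitFor(func() bool {
		return atomic.LoadInt64(&failed) == 1
	}, 2*time.Second, 5*time.Millisecond)
	fmt.Println("failure recorded:", ok)
}
```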
2873 lines
91 KiB
Go
package traefikoidc

import (
	"bytes"
	"compress/gzip"
	"context"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gorilla/sessions"
	"github.com/lukaszraczylo/traefikoidc/internal/pool"
)

// constantTimeStringCompare performs a constant-time comparison of two strings
// to prevent timing attacks. Returns true if the strings are equal.
func constantTimeStringCompare(a, b string) bool {
	if len(a) != len(b) {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}

// deriveCookieKeys derives an independent 64-byte HMAC authentication key and a
// 32-byte AES-256 encryption key from the operator-provided session encryption
// key using HKDF-SHA256 (RFC 5869).
//
// gorilla/securecookie only ENCRYPTS the cookie payload when a block
// (encryption) key is supplied; constructing the store with a single key leaves
// sessions signed-but-plaintext, so the OIDC access/refresh/ID tokens stored in
// the cookie are recoverable by anyone who can read the raw cookie bytes. Two
// independent keys are derived here so the cookie is both encrypted and
// authenticated. HKDF is implemented with stdlib hmac+sha256 so it runs under
// Traefik's yaegi interpreter, which may not export crypto/hkdf.
func deriveCookieKeys(secret string) (authKey, encKey []byte) {
	okm := hkdfSHA256([]byte(secret), nil, []byte("traefikoidc session cookie keys v1"), 96)
	return okm[:64], okm[64:96]
}

// hkdfSHA256 performs HKDF-Extract followed by HKDF-Expand (RFC 5869) using
// HMAC-SHA256 and returns length bytes of output keying material.
func hkdfSHA256(ikm, salt, info []byte, length int) []byte {
	if len(salt) == 0 {
		salt = make([]byte, sha256.Size)
	}
	// Extract: PRK = HMAC-SHA256(salt, IKM)
	ext := hmac.New(sha256.New, salt)
	ext.Write(ikm)
	prk := ext.Sum(nil)
	// Expand: T(i) = HMAC-SHA256(PRK, T(i-1) | info | i)
	var out, t []byte
	for i := byte(1); len(out) < length; i++ {
		exp := hmac.New(sha256.New, prk)
		exp.Write(t)
		exp.Write(info)
		exp.Write([]byte{i})
		t = exp.Sum(nil)
		out = append(out, t...)
	}
	return out[:length]
}

// min returns the minimum of two integers.
// This is a utility function used throughout the session management code.
// Parameters:
//   - a: The first integer to compare.
//   - b: The second integer to compare.
//
// Returns:
//   - The smaller of the two integers.
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// generateSecureRandomString creates a cryptographically secure random string.
// It generates random bytes using crypto/rand and encodes them as hexadecimal.
// This is used for session IDs and other security-sensitive random values.
// Parameters:
//   - length: The number of random bytes to generate (output will be 2x this length in hex).
//
// Returns:
//   - The hex-encoded random string.
//   - An error if reading random bytes fails.
func generateSecureRandomString(length int) (string, error) {
	bytes := make([]byte, length)
	if _, err := rand.Read(bytes); err != nil {
		return "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	return hex.EncodeToString(bytes), nil
}

// Cookie name suffixes used for session management.
// These are appended to the cookiePrefix to create full cookie names.
// #nosec G101 -- These are cookie name suffixes, not hardcoded credentials
const (
	mainCookieSuffix     = "m"
	accessTokenSuffix    = "a"
	refreshTokenSuffix   = "r"
	idTokenSuffix        = "id"
	combinedCookieSuffix = "s" // Combined session cookie suffix
	defaultCookiePrefix  = "_oidc_raczylo_"
)

const (
	// maxCookieSize is the maximum raw data size per chunk before securecookie encoding.
	//
	// Browser cookie limit: 4096 bytes
	// Securecookie overhead: ~2x (gob + AES encryption + MAC + double base64)
	// Safety headroom: 1000 bytes for cookie metadata, headers, edge cases
	//
	// Math: 1400 raw → ~2800 encoded → safely under (4096 - 1000 = 3096) limit
	maxCookieSize = 1400

	// maxCombinedChunks is the maximum number of chunks allowed for combined session
	maxCombinedChunks = 10

	// absoluteSessionTimeout is the default session lifetime, used when
	// NewSessionManager is called with a sessionMaxAge of 0.
	absoluteSessionTimeout = 24 * time.Hour

	// minEncryptionKeyLength is the minimum accepted length, in bytes, of the
	// session encryption key.
	minEncryptionKeyLength = 32
)

// combinedSessionPayload is the JSON structure for combined cookie storage.
// Uses short field names to minimize size.
type combinedSessionPayload struct {
	X  map[string]interface{} `json:"x,omitempty"`
	A  string                 `json:"a,omitempty"`
	R  string                 `json:"r,omitempty"`
	I  string                 `json:"i,omitempty"`
	Ui string                 `json:"ui,omitempty"`
	Cs string                 `json:"cs,omitempty"`
	N  string                 `json:"n,omitempty"`
	Cv string                 `json:"cv,omitempty"`
	Ip string                 `json:"ip,omitempty"`
	Ca int64                  `json:"ca,omitempty"`
	Rc int                    `json:"rc,omitempty"`
	Au bool                   `json:"au,omitempty"`
}

// knownSessionKeys are the standard keys that are handled explicitly in the combined payload.
// All other mainSession.Values keys are stored in the X (extra) field.
var knownSessionKeys = map[string]bool{
	"access_token":    true,
	"refresh_token":   true,
	"id_token":        true,
	"user_identifier": true,
	"authenticated":   true,
	"csrf":            true,
	"nonce":           true,
	"code_verifier":   true,
	"incoming_path":   true,
	"created_at":      true,
	"redirect_count":  true,
}

// compressCombinedPayload compresses the combined session payload using gzip.
// It serializes the payload to JSON, compresses it, and returns base64-encoded data.
// Returns the compressed string and any error encountered.
func compressCombinedPayload(payload *combinedSessionPayload) (string, error) {
	jsonData, err := json.Marshal(payload)
	if err != nil {
		return "", fmt.Errorf("failed to marshal combined payload: %w", err)
	}

	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	if _, err := gz.Write(jsonData); err != nil {
		return "", fmt.Errorf("failed to compress combined payload: %w", err)
	}
	if err := gz.Close(); err != nil {
		return "", fmt.Errorf("failed to close gzip writer: %w", err)
	}

	compressed := base64.StdEncoding.EncodeToString(buf.Bytes())
	return compressed, nil
}

// decompressCombinedPayload decompresses a base64+gzip encoded combined session payload.
// Returns the deserialized payload and any error encountered.
func decompressCombinedPayload(compressed string) (*combinedSessionPayload, error) {
	if compressed == "" {
		return nil, fmt.Errorf("empty compressed data")
	}

	data, err := base64.StdEncoding.DecodeString(compressed)
	if err != nil {
		return nil, fmt.Errorf("failed to decode base64: %w", err)
	}

	gr, err := gzip.NewReader(bytes.NewReader(data))
	if err != nil {
		return nil, fmt.Errorf("failed to create gzip reader: %w", err)
	}
	defer func() { _ = gr.Close() }()

	// Limit decompressed size to prevent zip bombs
	limitedReader := io.LimitReader(gr, 512*1024) // 512KB max
	decompressed, err := io.ReadAll(limitedReader)
	if err != nil {
		return nil, fmt.Errorf("failed to decompress: %w", err)
	}

	var payload combinedSessionPayload
	if err := json.Unmarshal(decompressed, &payload); err != nil {
		return nil, fmt.Errorf("failed to unmarshal combined payload: %w", err)
	}

	return &payload, nil
}

// splitCombinedIntoChunks splits compressed data into chunks of at most
// chunkSize bytes. It returns the chunks in order; data that already fits in
// one chunk is returned as a single-element slice.
func splitCombinedIntoChunks(data string, chunkSize int) []string {
	if len(data) <= chunkSize {
		return []string{data}
	}

	var chunks []string
	for i := 0; i < len(data); i += chunkSize {
		end := i + chunkSize
		if end > len(data) {
			end = len(data)
		}
		chunks = append(chunks, data[i:end])
	}
	return chunks
}

// assembleCombinedChunks reassembles chunks back into the original compressed data.
// It reads the "d" value from each chunk session in slice order and stops at
// the first nil session or missing/empty chunk.
func assembleCombinedChunks(sessions []*sessions.Session) string {
	if len(sessions) == 0 {
		return ""
	}

	var parts []string
	for i := 0; i < len(sessions); i++ {
		session := sessions[i]
		if session == nil {
			break
		}
		chunk, ok := session.Values["d"].(string) // "d" for data
		if !ok || chunk == "" {
			break
		}
		parts = append(parts, chunk)
	}
	return strings.Join(parts, "")
}

// compressToken compresses a JWT token using gzip compression if beneficial.
// It validates the token format, attempts compression, and verifies the compressed
// data can be decompressed correctly. Only compresses if it reduces size.
// Parameters:
//   - token: The JWT token string to potentially compress.
//
// Returns:
//   - The base64 encoded, gzipped string, or the original string if compression fails.
func compressToken(token string) string {
	if token == "" {
		return token
	}

	// Only JWT-shaped input (exactly two dots) is compressed.
	dotCount := strings.Count(token, ".")
	if dotCount != 2 {
		return token
	}

	// Tokens larger than 50 KiB are returned unchanged.
	if len(token) > 50*1024 {
		return token
	}

	pm := pool.Get()
	b := pm.GetBuffer(4096)
	defer pm.PutBuffer(b)

	gz := gzip.NewWriter(b)

	written, err := gz.Write([]byte(token))
	if err != nil || written != len(token) {
		return token
	}

	if err := gz.Close(); err != nil {
		return token
	}

	compressedBytes := b.Bytes()
	if len(compressedBytes) == 0 {
		return token
	}

	compressed := base64.StdEncoding.EncodeToString(compressedBytes)

	// Keep the original when compression does not reduce the size.
	if len(compressed) >= len(token) {
		return token
	}

	// Round-trip check: the compressed form must decompress to the same token.
	decompressed := decompressTokenInternal(compressed)
	if decompressed != token {
		return token
	}

	if strings.Count(decompressed, ".") != 2 {
		return token
	}

	return compressed
}

// decompressToken decompresses a previously compressed token string.
// It decodes the base64 data, validates gzip headers, and decompresses safely
// with size limits to prevent compression bombs.
// Parameters:
//   - compressed: The base64-encoded compressed token string.
//
// Returns:
//   - The decompressed original string, or the input string if decompression fails.
func decompressToken(compressed string) string {
	return decompressTokenInternal(compressed)
}

// decompressTokenInternal is the internal decompression function.
// It is kept separate from decompressToken so that compressToken can call it
// to verify the integrity of freshly compressed data.
// It performs the actual decompression logic with proper resource management.
// Parameters:
//   - compressed: The compressed token string to decompress.
//
// Returns:
//   - The decompressed token or the original string if decompression fails.
func decompressTokenInternal(compressed string) string {
	if compressed == "" {
		return compressed
	}

	// Inputs larger than 100 KiB are returned unchanged.
	if len(compressed) > 100*1024 {
		return compressed
	}

	data, err := base64.StdEncoding.DecodeString(compressed)
	if err != nil {
		return compressed
	}

	if len(data) == 0 {
		return compressed
	}

	// Require the gzip magic bytes (0x1f 0x8b) before attempting to decompress.
	if len(data) < 2 || data[0] != 0x1f || data[1] != 0x8b {
		return compressed
	}

	pm := pool.Get()
	readerBuf := pm.GetHTTPResponseBuffer()
	defer pm.PutHTTPResponseBuffer(readerBuf)

	gz, err := gzip.NewReader(bytes.NewReader(data))
	if err != nil {
		return compressed
	}

	defer func() {
		if closeErr := gz.Close(); closeErr != nil {
			_ = closeErr
		}
	}()

	// Cap the decompressed size to guard against compression bombs.
	limitedReader := io.LimitReader(gz, 500*1024)

	if cap(readerBuf) >= 512*1024 {
		readerBuf = readerBuf[:cap(readerBuf)]
		// io.ReadFull keeps reading until the stream ends; a single Read call
		// may return before the whole token has been decompressed. The buffer
		// is larger than the 500 KiB limit, so a complete read always ends
		// with io.EOF or io.ErrUnexpectedEOF.
		n, err := io.ReadFull(limitedReader, readerBuf)
		if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
			return compressed
		}
		decompressed := readerBuf[:n]
		return string(decompressed)
	}

	decompressed, err := io.ReadAll(limitedReader)
	if err != nil {
		return compressed
	}

	if len(decompressed) == 0 {
		return compressed
	}

	decompressedStr := string(decompressed)

	if decompressedStr != "" && strings.Count(decompressedStr, ".") != 2 {
		return compressed
	}

	return decompressedStr
}

// SessionManager manages OIDC session state and cookie-based storage.
// It provides secure session management with token compression, chunked
// storage for large tokens, and pooled session objects for reuse, and it
// supports both HTTP and HTTPS schemes.
type SessionManager struct {
	sessionPool    sync.Pool
	ctx            context.Context
	store          sessions.Store
	logger         *Logger
	chunkManager   *ChunkManager
	memoryMonitor  *TaskMemoryMonitor
	cancel         context.CancelFunc
	cookieDomain   string
	cookiePrefix   string
	sessionMaxAge  time.Duration
	activeSessions int64
	poolHits       int64
	poolMisses     int64
	cleanupMutex   sync.RWMutex
	shutdownOnce   sync.Once
	forceHTTPS     bool
	cleanupDone    bool
}

// NewSessionManager creates a new SessionManager instance with secure defaults.
// It initializes the cookie store with encryption, sets up session pooling,
// and configures chunk management for large tokens.
// Parameters:
//   - encryptionKey: The key for encrypting session cookies (minimum 32 bytes).
//   - forceHTTPS: Whether to force HTTPS-only cookies regardless of request scheme.
//   - cookieDomain: The domain for session cookies (empty for auto-detection).
//   - cookiePrefix: Prefix for session cookie names (empty for default "_oidc_raczylo_").
//   - sessionMaxAge: Maximum session age duration (0 for default 24 hours).
//   - logger: Logger instance for debug and error logging.
//
// Returns:
//   - The configured SessionManager instance.
//   - An error if the encryption key does not meet minimum length requirements.
func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain string, cookiePrefix string, sessionMaxAge time.Duration, logger *Logger) (*SessionManager, error) {
	if len(encryptionKey) < minEncryptionKeyLength {
		return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
	}

	// Set default cookie prefix if not provided
	if cookiePrefix == "" {
		cookiePrefix = defaultCookiePrefix
	}

	// Set default session max age if not provided (24 hours for backward compatibility)
	if sessionMaxAge == 0 {
		sessionMaxAge = absoluteSessionTimeout
	}

	ctx, cancel := context.WithCancel(context.Background())

	// Derive independent authentication + encryption keys so the session cookie
	// is AES-256 encrypted and HMAC authenticated, not merely signed. See
	// deriveCookieKeys: a single key would leave the stored tokens in plaintext.
	authKey, encKey := deriveCookieKeys(encryptionKey)

	sm := &SessionManager{
		store:         sessions.NewCookieStore(authKey, encKey),
		forceHTTPS:    forceHTTPS,
		cookieDomain:  cookieDomain,
		cookiePrefix:  cookiePrefix,
		sessionMaxAge: sessionMaxAge,
		logger:        logger,
		chunkManager:  NewChunkManager(logger),
		ctx:           ctx,
		cancel:        cancel,
	}

	// Bind the cookie codec's timestamp validity (and the cookie Max-Age) to the
	// configured session lifetime instead of gorilla's 30-day default, so a
	// stolen cookie is not cryptographically valid for up to 30 days regardless
	// of the (possibly much shorter) configured sessionMaxAge (rank 13).
	if cs, ok := sm.store.(*sessions.CookieStore); ok {
		cs.MaxAge(int(sessionMaxAge.Seconds()))
	}

	// Initialize global memory monitoring (singleton)
	sm.memoryMonitor = GetGlobalTaskMemoryMonitor(logger)

	// Start memory monitoring every 30 seconds (will skip if already started)
	if err := sm.memoryMonitor.Start(30 * time.Second); err != nil {
		logger.Debugf("Failed to start memory monitoring: %v", err)
	}

	sm.sessionPool.New = func() interface{} {
		atomic.AddInt64(&sm.poolMisses, 1)
		sd := &SessionData{
			manager:            sm,
			accessTokenChunks:  make(map[int]*sessions.Session),
			refreshTokenChunks: make(map[int]*sessions.Session),
			idTokenChunks:      make(map[int]*sessions.Session),
			combinedChunks:     make(map[int]*sessions.Session),
			useCombinedStorage: true, // Use combined storage by default for new sessions
			refreshMutex:       sync.Mutex{},
			sessionMutex:       sync.RWMutex{},
			dirty:              false,
			inUse:              false,
		}
		sd.Reset()
		return sd
	}

	// Start background cleanup routine
	go sm.backgroundCleanup()

	return sm, nil
}

// Cookie name helper methods - build cookie names using the configured prefix

// mainCookieName returns the main session cookie name with the configured prefix
func (sm *SessionManager) mainCookieName() string {
	return sm.cookiePrefix + mainCookieSuffix
}

// accessTokenCookieName returns the access token cookie name with the configured prefix
func (sm *SessionManager) accessTokenCookieName() string {
	return sm.cookiePrefix + accessTokenSuffix
}

// refreshTokenCookieName returns the refresh token cookie name with the configured prefix
func (sm *SessionManager) refreshTokenCookieName() string {
	return sm.cookiePrefix + refreshTokenSuffix
}

// idTokenCookieName returns the ID token cookie name with the configured prefix
func (sm *SessionManager) idTokenCookieName() string {
	return sm.cookiePrefix + idTokenSuffix
}

// combinedCookieName returns the combined session cookie base name with the configured prefix
// Chunk cookies are named: prefix + "s" + "_" + chunkIndex (e.g., "_oidc_raczylo_s_0")
func (sm *SessionManager) combinedCookieName() string {
	return sm.cookiePrefix + combinedCookieSuffix
}

// combinedChunkCookieName returns the name for a specific combined session chunk
func (sm *SessionManager) combinedChunkCookieName(chunkIndex int) string {
	return fmt.Sprintf("%s_%d", sm.combinedCookieName(), chunkIndex)
}

// GetCookiePrefix returns the cookie prefix used for all OIDC session cookies.
func (sm *SessionManager) GetCookiePrefix() string {
	return sm.cookiePrefix
}

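/*
Illustrative sketch, not part of this package: the naming scheme described
for combined chunk cookies (prefix + suffix + "_" + index) can be reproduced
in isolation as follows. The function name and the literal prefix and suffix
are assumptions taken from the example in the comment above; the real values
come from the configured cookiePrefix and the package constants.

```go
package main

import "fmt"

// chunkCookieName joins a cookie prefix, a suffix and a chunk index.
// All three arguments are hypothetical stand-ins for configured values.
func chunkCookieName(prefix, suffix string, index int) string {
	return fmt.Sprintf("%s%s_%d", prefix, suffix, index)
}

func main() {
	fmt.Println(chunkCookieName("_oidc_raczylo_", "s", 0)) // prints "_oidc_raczylo_s_0"
}
```

Keeping the index as the last underscore-separated field means the chunks of
one session share a common base name and can be enumerated by counting up
from zero.
*/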
// Shutdown gracefully shuts down the SessionManager and all its background tasks
func (sm *SessionManager) Shutdown() error {
	var shutdownErr error
	sm.shutdownOnce.Do(func() {
		if sm.logger != nil {
			sm.logger.Debug("SessionManager shutdown initiated")
		}

		// Cancel context to stop all background operations
		if sm.cancel != nil {
			sm.cancel()
		}

		// Stop memory monitor
		if sm.memoryMonitor != nil {
			sm.memoryMonitor.Stop()
		}

		// Stop chunk manager
		if sm.chunkManager != nil {
			sm.chunkManager.Shutdown()
		}

		// Force garbage collection to help cleanup
		runtime.GC()

		if sm.logger != nil {
			sm.logger.Debug("SessionManager shutdown completed")
		}
	})
	return shutdownErr
}

// backgroundCleanup runs periodic cleanup tasks for session management
func (sm *SessionManager) backgroundCleanup() {
	ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes
	defer ticker.Stop()

	for {
		select {
		case <-sm.ctx.Done():
			if sm.logger != nil {
				sm.logger.Debug("Background cleanup routine terminated")
			}
			return
		case <-ticker.C:
			sm.performCleanupCycle()
		}
	}
}

// performCleanupCycle executes a complete cleanup cycle
func (sm *SessionManager) performCleanupCycle() {
	if sm.logger != nil {
		sm.logger.Debug("Starting background cleanup cycle")
	}

	startTime := time.Now()

	// Run periodic chunk cleanup
	sm.PeriodicChunkCleanup()

	// Reset idle session objects held in the session pool
	sm.cleanupSessionPool()

	// Force garbage collection if memory usage is high
	var m runtime.MemStats
	runtime.ReadMemStats(&m)
	if m.HeapAlloc > 50*1024*1024 { // 50MB threshold
		runtime.GC()
		if sm.logger != nil {
			sm.logger.Debug("Forced garbage collection due to high memory usage")
		}
	}

	duration := time.Since(startTime)
	if sm.logger != nil && sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() {
		sm.logger.Debugf("Cleanup cycle completed in %v", duration)
	}
}

// cleanupSessionPool resets up to maxCleanup idle objects drawn from the session pool.
// Because sessionPool.New is set, Get allocates a fresh object (and counts a
// pool miss) when the pool is empty instead of returning nil.
func (sm *SessionManager) cleanupSessionPool() {
	cleaned := 0
	const maxCleanup = 20 // Limit cleanup per cycle to avoid performance impact

	for i := 0; i < maxCleanup; i++ {
		select {
		case <-sm.ctx.Done():
			return
		default:
		}

		if poolSession := sm.sessionPool.Get(); poolSession != nil {
			sessionData, ok := poolSession.(*SessionData)
			if ok && sessionData != nil && !sessionData.inUse {
				sessionData.Reset()
				cleaned++
			}
			sm.sessionPool.Put(poolSession)
		} else {
			break // Get returned nil; only possible when sessionPool.New is unset
		}
	}

	if cleaned > 0 && sm.logger != nil && sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() {
		sm.logger.Debugf("Cleaned %d session pool objects", cleaned)
	}
}

// GetSessionStats returns statistics about session management
func (sm *SessionManager) GetSessionStats() map[string]interface{} {
	stats := make(map[string]interface{})
	stats["active_sessions"] = atomic.LoadInt64(&sm.activeSessions)
	stats["pool_hits"] = atomic.LoadInt64(&sm.poolHits)
	stats["pool_misses"] = atomic.LoadInt64(&sm.poolMisses)

	if sm.memoryMonitor != nil {
		if currentStats, err := sm.memoryMonitor.GetCurrentStats(); err == nil {
			stats["goroutines"] = currentStats.Goroutines
			stats["heap_alloc"] = currentStats.HeapAlloc
			stats["num_gc"] = currentStats.NumGC
		}
	}

	return stats
}

// PeriodicChunkCleanup performs periodic session maintenance.
// It asks the chunk manager to drop expired sessions and resets idle pool objects.
// The orphaned_chunks, expired_sessions and errors counters in the final log
// line are reported but never incremented by this function.
// The whole cycle is skipped when the context is canceled or in test mode.
func (sm *SessionManager) PeriodicChunkCleanup() {
	if sm == nil || sm.logger == nil {
		return
	}

	// Skip the cycle if the context is canceled or we're in test mode, so that
	// nothing is logged after test completion
	if sm.ctx == nil || sm.ctx.Err() != nil || isTestMode() {
		return
	}

	sm.logger.Debug("Starting comprehensive session cleanup cycle")

	cleanupStart := time.Now()
	var orphanedChunks, expiredSessions, cleanupErrors int

	if sm.store != nil {
		if cookieStore, ok := sm.store.(*sessions.CookieStore); ok {
			// Check context again before logging
			if sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() {
				sm.logger.Debug("Running session store cleanup")
			}
			_ = cookieStore
		}
	}

	// Cleanup expired sessions in chunk manager to prevent memory leaks
	if sm.chunkManager != nil {
		sm.chunkManager.CleanupExpiredSessions()
	}

	poolCleaned := 0
	for i := 0; i < 10; i++ {
		if poolSession := sm.sessionPool.Get(); poolSession != nil {
			sessionData, ok := poolSession.(*SessionData)
			if ok && sessionData != nil && !sessionData.inUse {
				sessionData.Reset()
				poolCleaned++
			}
			sm.sessionPool.Put(poolSession)
		}
	}

	// Check context before final logging
	if sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() {
		cleanupDuration := time.Since(cleanupStart)
		sm.logger.Debugf("Session cleanup completed in %v: pool_cleaned=%d, orphaned_chunks=%d, expired_sessions=%d, errors=%d",
			cleanupDuration, poolCleaned, orphanedChunks, expiredSessions, cleanupErrors)
	}
}

// ValidateSessionHealth performs comprehensive validation of session integrity.
// It checks authentication state, validates token formats, and detects
// potential tampering or corruption in session data.
// Parameters:
//   - sessionData: The session data to validate.
//
// Returns:
//   - An error describing any validation failures, nil if session is healthy.
func (sm *SessionManager) ValidateSessionHealth(sessionData *SessionData) error {
	if sessionData == nil {
		return fmt.Errorf("session data is nil")
	}

	if !sessionData.GetAuthenticated() {
		return fmt.Errorf("session is not authenticated or has expired")
	}

	accessToken := sessionData.GetAccessToken()
	refreshToken := sessionData.GetRefreshToken()
	idToken := sessionData.GetIDToken()

	if accessToken != "" {
		if err := sm.validateTokenFormat(accessToken, "access_token"); err != nil {
			return fmt.Errorf("access token validation failed: %w", err)
		}
	}

	if refreshToken != "" {
		if err := sm.validateTokenFormat(refreshToken, "refresh_token"); err != nil {
			return fmt.Errorf("refresh token validation failed: %w", err)
		}
	}

	if idToken != "" {
		if err := sm.validateTokenFormat(idToken, "id_token"); err != nil {
			return fmt.Errorf("ID token validation failed: %w", err)
		}
	}

	if err := sm.detectSessionTampering(sessionData); err != nil {
		return fmt.Errorf("session tampering detected: %w", err)
	}

	return nil
}

// validateTokenFormat validates the structure and format of authentication tokens.
// It rejects corruption markers and, for JWT-shaped tokens (exactly two dots),
// rejects empty segments. Segments that look like standard base64 rather than
// base64url are only logged, not rejected. Opaque tokens pass unchanged.
// Parameters:
//   - token: The token string to validate.
//   - tokenType: The type of token being validated (for error messages).
//
// Returns:
//   - An error if the token is a corruption marker or has an empty JWT segment.
func (sm *SessionManager) validateTokenFormat(token, tokenType string) error {
	if token == "" {
		return nil
	}

	if isCorruptionMarker(token) {
		return fmt.Errorf("%s contains corruption marker", tokenType)
	}

	if strings.Count(token, ".") == 2 {
		parts := strings.Split(token, ".")
		for i, part := range parts {
			if part == "" {
				return fmt.Errorf("%s has empty part %d in JWT format", tokenType, i)
			}
			if strings.ContainsAny(part, "+/=") && !strings.ContainsAny(part, "-_") {
				sm.logger.Debugf("Token %s part %d uses base64 instead of base64url encoding", tokenType, i)
			}
		}
	}

	return nil
}

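/*
Illustrative sketch, not part of this package: the JWT shape check performed
above (exactly two dots, no empty segment) can be exercised on its own as
shown here. checkJWTShape is a hypothetical name; it omits the corruption
marker test and the base64 logging, which depend on package internals.

```go
package main

import (
	"fmt"
	"strings"
)

// checkJWTShape returns an error when a JWT-shaped token (exactly two dots)
// has an empty segment. Tokens without two dots are treated as opaque and
// accepted, as in the function above.
func checkJWTShape(token string) error {
	if strings.Count(token, ".") != 2 {
		return nil
	}
	for i, part := range strings.Split(token, ".") {
		if part == "" {
			return fmt.Errorf("empty JWT segment %d", i)
		}
	}
	return nil
}

func main() {
	fmt.Println(checkJWTShape("aaa.bbb.ccc")) // prints "<nil>"
	fmt.Println(checkJWTShape("aaa..ccc"))    // prints "empty JWT segment 1"
}
```

This is a structural check only. It says nothing about the signature or the
claims, which must be verified separately.
*/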
// detectSessionTampering checks for indicators of session tampering.
// It examines session values for path traversal attempts, XSS payloads,
// and suspicious data patterns that might indicate malicious modification.
// Parameters:
//   - sessionData: The session data to examine for tampering.
//
// Returns:
//   - An error if tampering is detected, nil if session appears safe.
func (sm *SessionManager) detectSessionTampering(sessionData *SessionData) error {
	if sessionData.mainSession == nil {
		return fmt.Errorf("main session is missing")
	}

	for key, value := range sessionData.mainSession.Values {
		if str, ok := value.(string); ok {
			if strings.Contains(str, "../") || strings.Contains(str, "..\\") {
				return fmt.Errorf("potential path traversal attempt in session key %v", key)
			}
			if strings.Contains(str, "<script") || strings.Contains(str, "javascript:") {
				return fmt.Errorf("potential XSS attempt in session key %v", key)
			}
			if len(str) > 10000 {
				return fmt.Errorf("suspiciously long session value for key %v (length: %d)", key, len(str))
			}
		}
	}

	return nil
}

// GetSessionMetrics returns metrics about session management for monitoring purposes.
// It provides information about session configuration, security settings,
// and internal state for debugging and monitoring.
// Returns:
//   - A map containing session metrics and configuration information.
func (sm *SessionManager) GetSessionMetrics() map[string]interface{} {
	metrics := make(map[string]interface{})
	metrics["session_manager_type"] = "CookieStore"
	metrics["force_https"] = sm.forceHTTPS
	metrics["absolute_timeout_hours"] = sm.sessionMaxAge.Hours()
	metrics["max_cookie_size"] = maxCookieSize
	metrics["max_encoded_cookie_size"] = maxCookieSize * 2 // ~2x encoding overhead

	// has_encryption only reports that the cookie store has at least one codec;
	// it does not inspect whether that codec was built with an encryption key.
	if cookieStore, ok := sm.store.(*sessions.CookieStore); ok && len(cookieStore.Codecs) > 0 {
		metrics["has_encryption"] = true
		metrics["codec_count"] = len(cookieStore.Codecs)
	} else {
		metrics["has_encryption"] = false
	}

	metrics["pool_implementation"] = "sync.Pool"

	return metrics
}

// EnhanceSessionSecurity applies additional security measures to session options.
// It configures secure cookies, domain detection, SameSite policies, and
// adapts security settings based on request context and client characteristics.
// Parameters:
//   - options: The base session options to enhance (can be nil).
//   - r: The HTTP request context for security decisions.
//
// Returns:
//   - Enhanced sessions.Options with additional security measures.
func (sm *SessionManager) EnhanceSessionSecurity(options *sessions.Options, r *http.Request) *sessions.Options {
	if options == nil {
		options = &sessions.Options{}
	}

	if r != nil {
		userAgent := r.Header.Get("User-Agent")
		if userAgent == "" {
			sm.logger.Debugf("Request from %s missing User-Agent header", r.RemoteAddr)
			options.MaxAge = int((sm.sessionMaxAge / 2).Seconds())
		}

		if r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sm.forceHTTPS {
			options.Secure = true
		}

		// Keep SameSite=Lax consistently for OAuth flows
		// Removed dynamic switching based on XMLHttpRequest header to prevent redirect loop
		options.SameSite = http.SameSiteLaxMode
	}

	options.HttpOnly = true
	options.Path = "/" // Ensure cookies are available on all paths for OAuth flow

	if sm.cookieDomain != "" {
		options.Domain = sm.cookieDomain
		sm.logger.Debugf("Using configured cookie domain: %s", sm.cookieDomain)
	} else if options.Domain == "" && r != nil {
		host := r.Host

		if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
			host = forwardedHost
		}

		if host != "" && !strings.Contains(host, "localhost") && !strings.Contains(host, "127.0.0.1") {
			if colonIndex := strings.Index(host, ":"); colonIndex != -1 {
				host = host[:colonIndex]
			}
			options.Domain = host
			sm.logger.Debugf("Auto-detected cookie domain: %s", host)
		}
	}

	return options
}

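/*
Illustrative sketch, not part of this package: the domain auto-detection
above strips the port by cutting at the first colon, which would truncate a
bracketed IPv6 host such as "[::1]:8080". One possible alternative, shown
here under the assumption that IPv6 hosts matter for a deployment, is the
standard library's net.SplitHostPort. stripPort is a hypothetical helper.

```go
package main

import (
	"fmt"
	"net"
	"strings"
)

// stripPort returns the host part of "host:port". When no port is present
// (or the value cannot be parsed) it returns the input with any IPv6
// brackets removed.
func stripPort(hostport string) string {
	host, _, err := net.SplitHostPort(hostport)
	if err != nil {
		return strings.Trim(hostport, "[]")
	}
	return host
}

func main() {
	fmt.Println(stripPort("example.com:8443")) // prints "example.com"
	fmt.Println(stripPort("[::1]:8080"))       // prints "::1"
	fmt.Println(stripPort("example.com"))      // prints "example.com"
}
```

Whether an IP literal should ever be used as a cookie Domain is a separate
question; the sketch only covers the parsing step.
*/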
// getSessionOptions creates base session options with security settings.
// It configures cookie security, lifetime, path, and domain settings
// based on the HTTPS status and manager configuration.
// Parameters:
//   - isSecure: Whether the request is over HTTPS or should be treated as secure.
//
// Returns:
//   - A pointer to a configured sessions.Options struct.
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
	baseOptions := &sessions.Options{
		HttpOnly: true,
		Secure:   isSecure || sm.forceHTTPS,
		SameSite: http.SameSiteLaxMode,
		MaxAge:   int(sm.sessionMaxAge.Seconds()),
		Path:     "/",
		Domain:   sm.cookieDomain,
	}
	return baseOptions
}

// CleanupOldCookies removes stale session cookies from the client browser.
// When a cookie domain is configured, it expires this middleware's cookies on
// other candidate domains (the request host, its parent domain, and their
// dot-prefixed forms) so that cookies set under a previous domain
// configuration do not linger. The cleanup is attempted once per manager.
// Parameters:
//   - w: The HTTP response writer for setting cookie deletion headers.
//   - r: The HTTP request containing cookies to examine and clean up.
func (sm *SessionManager) CleanupOldCookies(w http.ResponseWriter, r *http.Request) {
	cookies := r.Cookies()

	currentDomain := sm.cookieDomain
	host := r.Host
	if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
		host = forwardedHost
	}
	if colonIndex := strings.Index(host, ":"); colonIndex != -1 {
		host = host[:colonIndex]
	}

	// Candidate domains under which old cookies may have been set
	var domainsToClean []string

	if host != "" && !strings.Contains(host, "localhost") && !strings.Contains(host, "127.0.0.1") {
		domainsToClean = append(domainsToClean,
			host,
			"."+host,
		)

		parts := strings.Split(host, ".")
		if len(parts) > 2 {
			parentDomain := strings.Join(parts[len(parts)-2:], ".")
			domainsToClean = append(domainsToClean,
				parentDomain,
				"."+parentDomain,
			)
		}
	}

	processedCookies := make(map[string]bool)

	for _, cookie := range cookies {
		// Check if cookie belongs to this middleware instance using the configured prefix
		if strings.HasPrefix(cookie.Name, sm.cookiePrefix) ||
			strings.HasPrefix(cookie.Name, "access_token_chunk_") ||
			strings.HasPrefix(cookie.Name, "refresh_token_chunk_") {

			processedCookies[cookie.Name] = true

			sm.cleanupMutex.RLock()
			shouldCleanup := currentDomain != "" && !sm.cleanupDone
			sm.cleanupMutex.RUnlock()

			if shouldCleanup {
				for _, domain := range domainsToClean {
					if domain == currentDomain || domain == "."+currentDomain || "."+domain == currentDomain {
						continue
					}

					deleteCookie := &http.Cookie{
						Name:     cookie.Name,
						Value:    "",
						Path:     "/",
						Domain:   domain,
						MaxAge:   -1,
						HttpOnly: true,
						Secure:   r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sm.forceHTTPS,
						SameSite: http.SameSiteLaxMode,
					}
					http.SetCookie(w, deleteCookie)
					sm.logger.Debugf("Attempting to clean up cookie %s with domain %s", cookie.Name, domain)
				}
			}
		}
	}

	if len(processedCookies) > 0 {
		sm.cleanupMutex.Lock()
		if !sm.cleanupDone {
			sm.cleanupDone = true
		}
		sm.cleanupMutex.Unlock()
	}
}

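/*
Illustrative sketch, not part of this package: the list of candidate domains
built above (host, dot-prefixed host, and the last two labels as a parent
domain) can be computed by a small pure function. domainVariants is a
hypothetical name used only for this example.

```go
package main

import (
	"fmt"
	"strings"
)

// domainVariants returns the host, its dot-prefixed form and, for hosts with
// more than two labels, the last two labels with and without a leading dot.
func domainVariants(host string) []string {
	if host == "" {
		return nil
	}
	variants := []string{host, "." + host}
	labels := strings.Split(host, ".")
	if len(labels) > 2 {
		parent := strings.Join(labels[len(labels)-2:], ".")
		variants = append(variants, parent, "."+parent)
	}
	return variants
}

func main() {
	fmt.Println(domainVariants("app.example.com"))
	// prints "[app.example.com .app.example.com example.com .example.com]"
}
```

Taking the last two labels is a heuristic: for a host under a multi-label
public suffix (for example "app.example.co.uk") it yields the suffix itself,
which browsers are expected to reject as a cookie domain.
*/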
// GetSession retrieves or creates session data from the HTTP request.
// It first tries to load from combined cookies (new format), falling back to legacy
// cookies if combined cookies don't exist. Performs validation and timeout checks.
// The returned session must be explicitly returned to the pool by calling
// returnToPoolSafely() to prevent memory leaks.
// MEMORY LEAK FIX: Session is NOT returned to pool here - caller must call ReturnToPool() when done.
// Parameters:
//   - r: The HTTP request containing session cookies.
//
// Returns:
//   - The loaded SessionData instance.
//   - An error if session loading or validation fails.
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
	// Guard the type assertion: a nil or foreign pool entry would otherwise
	// cause a nil pointer dereference below.
	sessionData, ok := sm.sessionPool.Get().(*SessionData)
	if !ok || sessionData == nil {
		return nil, fmt.Errorf("session pool returned an unexpected value")
	}
	atomic.AddInt64(&sm.poolHits, 1)
	atomic.AddInt64(&sm.activeSessions, 1)

	sessionData.inUse = true
	sessionData.request = r
	sessionData.dirty = false

	handleError := func(err error, message string) (*SessionData, error) {
		if sessionData != nil {
			sessionData.inUse = false
			sessionData.Reset()
			sm.sessionPool.Put(sessionData)
			atomic.AddInt64(&sm.activeSessions, -1)
		}
		return nil, fmt.Errorf("%s: %w", message, err)
	}

	// Try to load from combined cookies first (new format)
	if sm.loadFromCombinedCookies(r, sessionData) {
		sessionData.useCombinedStorage = true
		sm.logger.Debug("Loaded session from combined cookies")

		// Check session timeout
		if sessionData.getCreatedAtUnsafe() > 0 {
			if time.Since(time.Unix(sessionData.getCreatedAtUnsafe(), 0)) > sm.sessionMaxAge {
				_ = sessionData.Clear(r, nil) // Safe to ignore: session is being invalidated
				return handleError(fmt.Errorf("session timeout"), "session expired")
			}
		}

		return sessionData, nil
	}

	// Fall back to legacy cookies
	sessionData.useCombinedStorage = false
	sm.logger.Debug("Loading session from legacy cookies")

	var err error
	sessionData.mainSession, err = sm.store.Get(r, sm.mainCookieName())
	if err != nil {
		return handleError(err, "failed to get main session")
	}

	if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
		if time.Since(time.Unix(createdAt, 0)) > sm.sessionMaxAge {
			_ = sessionData.Clear(r, nil) // Safe to ignore: session is being invalidated
			return handleError(fmt.Errorf("session timeout"), "session expired")
		}
	}

	sessionData.accessSession, err = sm.store.Get(r, sm.accessTokenCookieName())
	if err != nil {
		return handleError(err, "failed to get access token session")
	}

	sessionData.refreshSession, err = sm.store.Get(r, sm.refreshTokenCookieName())
	if err != nil {
		return handleError(err, "failed to get refresh token session")
	}

	sessionData.idTokenSession, err = sm.store.Get(r, sm.idTokenCookieName())
	if err != nil {
		return handleError(err, "failed to get ID token session")
	}

	for k := range sessionData.accessTokenChunks {
		delete(sessionData.accessTokenChunks, k)
	}
	for k := range sessionData.refreshTokenChunks {
		delete(sessionData.refreshTokenChunks, k)
	}
	for k := range sessionData.idTokenChunks {
		delete(sessionData.idTokenChunks, k)
	}

	sm.getTokenChunkSessions(r, sm.accessTokenCookieName(), sessionData.accessTokenChunks)
	sm.getTokenChunkSessions(r, sm.refreshTokenCookieName(), sessionData.refreshTokenChunks)
	sm.getTokenChunkSessions(r, sm.idTokenCookieName(), sessionData.idTokenChunks)

	// If legacy session has data, migrate to combined storage on next save
	if !sessionData.mainSession.IsNew {
		sessionData.useCombinedStorage = true
		sessionData.dirty = true // Mark dirty to trigger migration on save
		sm.logger.Debug("Legacy session found, will migrate to combined storage on save")
	}

	return sessionData, nil
}

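/*
Illustrative sketch, not part of this package: the absolute timeout applied
in GetSession compares the session's creation time against sessionMaxAge.
The check can be written as a pure function that takes the current time as
an argument, which makes it easy to test. sessionExpired and its parameters
are hypothetical names.

```go
package main

import (
	"fmt"
	"time"
)

// sessionExpired reports whether a session created at createdAtUnix (Unix
// seconds) is older than maxAge at the instant nowUnix. A non-positive
// creation time is treated as unknown and never expires, mirroring the
// "created at > 0" guard on the combined-cookie path.
func sessionExpired(createdAtUnix, nowUnix int64, maxAge time.Duration) bool {
	if createdAtUnix <= 0 {
		return false
	}
	age := time.Unix(nowUnix, 0).Sub(time.Unix(createdAtUnix, 0))
	return age > maxAge
}

func main() {
	now := time.Now().Unix()
	fmt.Println(sessionExpired(now-7200, now, time.Hour)) // prints "true"
	fmt.Println(sessionExpired(now-60, now, time.Hour))   // prints "false"
}
```

Because the creation time travels inside the cookie payload, this check only
has value when the cookie is authenticated, so that a client cannot rewrite
the timestamp.
*/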
// loadFromCombinedCookies attempts to load session data from combined cookies.
// Returns true if combined cookies were found and successfully loaded.
func (sm *SessionManager) loadFromCombinedCookies(r *http.Request, sessionData *SessionData) bool {
	// Check if first combined chunk exists
	firstChunk, err := sm.store.Get(r, sm.combinedChunkCookieName(0))
	if err != nil || firstChunk.IsNew {
		return false
	}

	// Get total chunk count from first chunk
	totalChunks, ok := firstChunk.Values["n"].(int)
	if !ok || totalChunks < 1 || totalChunks > maxCombinedChunks {
		sm.logger.Debugf("Invalid combined cookie chunk count: %v", firstChunk.Values["n"])
		return false
	}

	// Load all chunks
	chunkSessions := make([]*sessions.Session, totalChunks)
	chunkSessions[0] = firstChunk
	sessionData.combinedChunks[0] = firstChunk

	for i := 1; i < totalChunks; i++ {
		chunk, err := sm.store.Get(r, sm.combinedChunkCookieName(i))
		if err != nil || chunk.IsNew {
			sm.logger.Debugf("Missing combined cookie chunk %d", i)
			return false
		}
		chunkSessions[i] = chunk
		sessionData.combinedChunks[i] = chunk
	}

	// Assemble and decompress
	compressed := assembleCombinedChunks(chunkSessions)
	if compressed == "" {
		sm.logger.Debug("Failed to assemble combined chunks")
		return false
	}

	payload, err := decompressCombinedPayload(compressed)
	if err != nil {
		sm.logger.Debugf("Failed to decompress combined payload: %v", err)
		return false
	}

	// Hydrate the legacy session objects for compatibility with existing getter methods
	// We need to initialize them even though we're using combined storage
	sessionData.mainSession, _ = sm.store.Get(r, sm.mainCookieName())
	sessionData.accessSession, _ = sm.store.Get(r, sm.accessTokenCookieName())
	sessionData.refreshSession, _ = sm.store.Get(r, sm.refreshTokenCookieName())
	sessionData.idTokenSession, _ = sm.store.Get(r, sm.idTokenCookieName())

	// Populate legacy session values from combined payload
	sessionData.mainSession.Values["user_identifier"] = payload.Ui
	sessionData.mainSession.Values["authenticated"] = payload.Au
	sessionData.mainSession.Values["csrf"] = payload.Cs
	sessionData.mainSession.Values["nonce"] = payload.N
	sessionData.mainSession.Values["code_verifier"] = payload.Cv
	sessionData.mainSession.Values["incoming_path"] = payload.Ip
	sessionData.mainSession.Values["created_at"] = payload.Ca
	sessionData.mainSession.Values["redirect_count"] = payload.Rc

	// Restore extra custom session values
	for key, val := range payload.X {
		sessionData.mainSession.Values[key] = val
	}

	sessionData.accessSession.Values["token"] = payload.A
	sessionData.accessSession.Values["compressed"] = false

	sessionData.refreshSession.Values["token"] = payload.R
	sessionData.refreshSession.Values["compressed"] = false

	sessionData.idTokenSession.Values["token"] = payload.I
	sessionData.idTokenSession.Values["compressed"] = false

	return true
}

// getTokenChunkSessions loads all available token chunk sessions for a given token type.
// It iterates through numbered chunk sessions until no more are found,
// populating the provided chunks map with the loaded sessions.
// Parameters:
//   - r: The HTTP request containing chunk cookies.
//   - baseName: The base cookie name for the token type (e.g., "_oidc_raczylo_a").
//   - chunks: The map (SessionData.accessTokenChunks, refreshTokenChunks, or
//     idTokenChunks) to populate with the found session chunks.
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) {
	for i := 0; ; i++ {
		sessionName := fmt.Sprintf("%s_%d", baseName, i)
		session, err := sm.store.Get(r, sessionName)
		if err != nil || session.IsNew {
			break
		}
		chunks[i] = session
	}
}

// SessionData represents a user's authentication session with comprehensive token management.
// It handles main session data and supports large tokens that need to be
// split across multiple cookies due to browser size limitations.
// Supports both legacy (separate cookies) and combined (single compressed cookie) storage.
type SessionData struct {
	manager *SessionManager

	request *http.Request

	// Legacy storage (kept for backward compatibility during migration)
	mainSession *sessions.Session

	accessSession *sessions.Session

	refreshSession *sessions.Session

	idTokenSession *sessions.Session

	accessTokenChunks map[int]*sessions.Session

	refreshTokenChunks map[int]*sessions.Session

	idTokenChunks map[int]*sessions.Session

	// Combined storage (new approach - single compressed cookie)
	combinedChunks map[int]*sessions.Session

	// useCombinedStorage indicates whether to use the new combined storage format
	useCombinedStorage bool

	refreshMutex sync.Mutex

	sessionMutex sync.RWMutex

	dirty bool

	inUse bool

	// cachedClaimsToken is the ID token string whose claims were last parsed and
	// cached. A lazy, per-request cache to avoid re-parsing the JWT on every
	// authenticated request (e.g. for headerTemplates). Protected by sessionMutex.
	cachedClaimsToken string

	// cachedClaims holds the parsed claims for cachedClaimsToken.
	cachedClaims map[string]interface{}

	// cachedClaimsErr holds the parse error (if any) for cachedClaimsToken so
	// failures are not retried within the same request.
	cachedClaimsErr error
}

// IsDirty returns true if the session data has been modified since it was last loaded or saved.
// This is used to optimize session saves by only writing when necessary.
// Returns:
//   - true if the session has pending changes, false otherwise.
func (sd *SessionData) IsDirty() bool {
	return sd.dirty
}

// MarkDirty marks the session as having pending changes that need to be saved.
// This is used when session data hasn't changed in content but should still
// trigger a session save (e.g., to ensure the cookie is re-issued).
func (sd *SessionData) MarkDirty() {
	sd.dirty = true
}

// Save persists all session data including main session and token chunks.
// It applies security options, saves all session components, and handles
// errors gracefully by continuing to save other components even if one fails.
// Uses combined cookie storage for efficiency when useCombinedStorage is true.
// Parameters:
//   - r: The HTTP request context for security option configuration.
//   - w: The HTTP response writer for setting session cookies.
//
// Returns:
//   - An error if saving any of the session components fails.
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
	isSecure := r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sd.manager.forceHTTPS

	options := sd.manager.getSessionOptions(isSecure)
	options = sd.manager.EnhanceSessionSecurity(options, r)

	// Use combined storage for new sessions
	if sd.useCombinedStorage {
		return sd.saveCombined(r, w, options)
	}

	// Legacy storage path (for backward compatibility)
	return sd.saveLegacy(r, w, options)
}

// saveCombined saves all session data in a single compressed, chunked cookie.
// This reduces cookie count and total size through combined compression.
func (sd *SessionData) saveCombined(r *http.Request, w http.ResponseWriter, options *sessions.Options) error {
	// Build the combined payload
	payload := &combinedSessionPayload{
		A:  sd.getAccessTokenUnsafe(),
		R:  sd.getRefreshTokenUnsafe(),
		I:  sd.getIDTokenUnsafe(),
		Ui: sd.getUserIdentifierUnsafe(),
		Au: sd.getAuthenticatedUnsafe(),
		Cs: sd.getCSRFUnsafe(),
		N:  sd.getNonceUnsafe(),
		Cv: sd.getCodeVerifierUnsafe(),
		Ip: sd.getIncomingPathUnsafe(),
		Ca: sd.getCreatedAtUnsafe(),
		Rc: sd.getRedirectCountUnsafe(),
	}

	// Collect extra session values not handled by the standard fields
	sd.sessionMutex.RLock()
	if sd.mainSession != nil && len(sd.mainSession.Values) > 0 {
		extra := make(map[string]interface{})
		for key, val := range sd.mainSession.Values {
			keyStr, ok := key.(string)
			if !ok {
				continue
			}
			// Skip known session keys that are already in the payload
			if knownSessionKeys[keyStr] {
				continue
			}
			// Store the extra value (must be JSON-serializable)
			extra[keyStr] = val
		}
		if len(extra) > 0 {
			payload.X = extra
		}
	}
	sd.sessionMutex.RUnlock()

	// Compress the payload
	compressed, err := compressCombinedPayload(payload)
	if err != nil {
		sd.manager.logger.Errorf("Failed to compress combined payload: %v", err)
		// Fall back to legacy storage on compression failure
		return sd.saveLegacy(r, w, options)
	}

	sd.manager.logger.Debugf("Combined session: raw payload compressed to %d bytes", len(compressed))

	// Split into chunks
	chunks := splitCombinedIntoChunks(compressed, maxCookieSize)
	if len(chunks) > maxCombinedChunks {
		sd.manager.logger.Errorf("Combined session requires %d chunks, exceeds max %d", len(chunks), maxCombinedChunks)
		return fmt.Errorf("session data too large: requires %d chunks, max is %d", len(chunks), maxCombinedChunks)
	}

	sd.manager.logger.Debugf("Combined session split into %d chunks", len(chunks))

	var firstErr error

	// Save each chunk
	for i, chunkData := range chunks {
		cookieName := sd.manager.combinedChunkCookieName(i)
		session, err := sd.manager.store.Get(r, cookieName)
		if err != nil {
			sd.manager.logger.Errorf("Failed to get combined chunk session %s: %v", cookieName, err)
			if firstErr == nil {
				firstErr = err
			}
			continue
		}

		session.Values["d"] = chunkData   // "d" for data
		session.Values["n"] = len(chunks) // "n" for total number of chunks
		session.Values["i"] = i           // "i" for index
		session.Options = options

		if err := session.Save(r, w); err != nil {
			sd.manager.logger.Errorf("Failed to save combined chunk %d: %v", i, err)
			if firstErr == nil {
				firstErr = err
			}
		}
		sd.combinedChunks[i] = session
	}

	// Expire old combined chunks that are no longer needed
	sd.expireOldCombinedChunks(r, w, options, len(chunks))

	// Expire legacy cookies if they exist (migration)
	sd.expireLegacyCookies(r, w, options)

	if firstErr == nil {
		sd.dirty = false
	}
	return firstErr
}

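/*
Illustrative sketch, not part of this package: saveCombined relies on
splitCombinedIntoChunks, whose body is not shown in this part of the file.
As an assumption about its general shape only, a fixed-size split of an
ASCII (for example base64) string could look like the hypothetical
splitIntoChunks below. It is not the package's implementation.

```go
package main

import "fmt"

// splitIntoChunks cuts data into consecutive pieces of at most size bytes.
// It returns nil for empty input or a non-positive size. The split is
// byte-based, so it is only suitable for single-byte encodings.
func splitIntoChunks(data string, size int) []string {
	if size <= 0 || data == "" {
		return nil
	}
	chunks := make([]string, 0, (len(data)+size-1)/size)
	for len(data) > size {
		chunks = append(chunks, data[:size])
		data = data[size:]
	}
	return append(chunks, data)
}

func main() {
	fmt.Println(splitIntoChunks("abcdefg", 3)) // prints "[abc def g]"
}
```

Concatenating the pieces in index order restores the input, which is why
each chunk cookie above also stores its index "i" and the total count "n".
*/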
// saveLegacy saves session data using the old separate cookie approach.
|
|
// Kept for backward compatibility during migration.
|
|
func (sd *SessionData) saveLegacy(r *http.Request, w http.ResponseWriter, options *sessions.Options) error {
|
|
sd.mainSession.Options = options
|
|
sd.accessSession.Options = options
|
|
sd.refreshSession.Options = options
|
|
sd.idTokenSession.Options = options
|
|
|
|
var firstErr error
|
|
saveOrLogError := func(s *sessions.Session, name string) {
|
|
if s == nil {
|
|
sd.manager.logger.Errorf("Attempted to save nil session: %s", name)
|
|
if firstErr == nil {
|
|
firstErr = fmt.Errorf("attempted to save nil session: %s", name)
|
|
}
|
|
return
|
|
}
|
|
if err := s.Save(r, w); err != nil {
|
|
errMsg := fmt.Errorf("failed to save %s session: %w", name, err)
|
|
sd.manager.logger.Errorf("%s", errMsg.Error())
|
|
if firstErr == nil {
|
|
firstErr = errMsg
|
|
}
|
|
}
|
|
}
|
|
|
|
saveOrLogError(sd.mainSession, "main")
|
|
saveOrLogError(sd.accessSession, "access token")
|
|
saveOrLogError(sd.refreshSession, "refresh token")
|
|
saveOrLogError(sd.idTokenSession, "ID token")
|
|
|
|
for i, sessionChunk := range sd.accessTokenChunks {
|
|
sessionChunk.Options = options
|
|
saveOrLogError(sessionChunk, fmt.Sprintf("access token chunk %d", i))
|
|
}
|
|
|
|
for i, sessionChunk := range sd.refreshTokenChunks {
|
|
sessionChunk.Options = options
|
|
saveOrLogError(sessionChunk, fmt.Sprintf("refresh token chunk %d", i))
|
|
}
|
|
|
|
for i, sessionChunk := range sd.idTokenChunks {
|
|
sessionChunk.Options = options
|
|
saveOrLogError(sessionChunk, fmt.Sprintf("ID token chunk %d", i))
|
|
}
|
|
|
|
if firstErr == nil {
|
|
sd.dirty = false
|
|
}
|
|
return firstErr
|
|
}
|
|
|
|
// expireOldCombinedChunks expires combined cookie chunks that are no longer needed.
|
|
func (sd *SessionData) expireOldCombinedChunks(r *http.Request, w http.ResponseWriter, options *sessions.Options, currentChunks int) {
|
|
// Expire chunks beyond the current count
|
|
for i := currentChunks; i < maxCombinedChunks; i++ {
|
|
cookieName := sd.manager.combinedChunkCookieName(i)
|
|
session, err := sd.manager.store.Get(r, cookieName)
|
|
if err != nil || session.IsNew {
|
|
// No more old chunks
|
|
break
|
|
}
|
|
// Expire this chunk
|
|
expireOptions := *options
|
|
expireOptions.MaxAge = -1
|
|
session.Options = &expireOptions
|
|
for k := range session.Values {
|
|
delete(session.Values, k)
|
|
}
|
|
if err := session.Save(r, w); err != nil {
|
|
sd.manager.logger.Debugf("Failed to expire old combined chunk %d: %v", i, err)
|
|
}
|
|
}
|
|
}
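// Illustrative sketch (not part of this package): the function above expires a
// cookie by saving it with MaxAge set to -1. The snippet below shows the same
// idea with only net/http, assuming hypothetical helper names
// expiringCookieHeader and isExpiring; the plugin itself goes through the
// session store instead.

```go
package main

import (
	"net/http"
	"strings"
)

// expiringCookieHeader builds the Set-Cookie value that asks a browser to drop
// the named cookie. net/http serializes a negative MaxAge as "Max-Age=0".
func expiringCookieHeader(name, path string) string {
	c := &http.Cookie{Name: name, Value: "", Path: path, MaxAge: -1}
	return c.String()
}

// isExpiring reports whether a Set-Cookie value requests immediate deletion.
func isExpiring(header string) bool {
	return strings.Contains(header, "Max-Age=0")
}
```

// The Path (and Domain, if set) must match the original cookie, otherwise the
// browser treats the expiring header as a different cookie and keeps the old one.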
|
|
|
|
// expireLegacyCookies expires old legacy format cookies during migration.
|
|
func (sd *SessionData) expireLegacyCookies(r *http.Request, w http.ResponseWriter, options *sessions.Options) {
|
|
expireOptions := *options
|
|
expireOptions.MaxAge = -1
|
|
|
|
// Helper to expire a legacy session cookie without clearing in-memory values
|
|
// IMPORTANT: We must NOT clear values from sessions that sd is holding,
|
|
// as store.Get() returns the same cached session object
|
|
expireLegacyChunk := func(cookieName string) {
|
|
session, err := sd.manager.store.Get(r, cookieName)
|
|
if err != nil || session.IsNew {
|
|
return // Cookie doesn't exist
|
|
}
|
|
session.Options = &expireOptions
|
|
// Clear values from chunk cookies (not the main session objects)
|
|
for k := range session.Values {
|
|
delete(session.Values, k)
|
|
}
|
|
_ = session.Save(r, w) // Best effort
|
|
}
|
|
|
|
// For main session cookies, only set expiration WITHOUT clearing values
|
|
// because sd.mainSession, sd.accessSession, etc. point to these same objects
|
|
expireLegacyMain := func(cookieName string) {
|
|
session, err := sd.manager.store.Get(r, cookieName)
|
|
if err != nil || session.IsNew {
|
|
return // Cookie doesn't exist
|
|
}
|
|
// Just expire the cookie, don't clear values (they're still needed in memory)
|
|
session.Options = &expireOptions
|
|
_ = session.Save(r, w) // Best effort
|
|
}
|
|
|
|
// Expire main legacy cookies (don't clear in-memory values)
|
|
expireLegacyMain(sd.manager.mainCookieName())
|
|
expireLegacyMain(sd.manager.accessTokenCookieName())
|
|
expireLegacyMain(sd.manager.refreshTokenCookieName())
|
|
expireLegacyMain(sd.manager.idTokenCookieName())
|
|
|
|
// Expire legacy chunk cookies (safe to clear values, they're separate from main sessions)
|
|
// Probe each token type independently: a missing access token chunk must not
// stop the cleanup of refresh or ID token chunks.
legacyChunkPrefixes := []string{
	sd.manager.accessTokenCookieName(),
	sd.manager.refreshTokenCookieName(),
	sd.manager.idTokenCookieName(),
}
for _, prefix := range legacyChunkPrefixes {
	for i := 0; i < 50; i++ { // Max legacy chunks was 50
		chunkName := fmt.Sprintf("%s_%d", prefix, i)
		session, err := sd.manager.store.Get(r, chunkName)
		if err != nil || session.IsNew {
			break // No more chunks for this token type
		}
		expireLegacyChunk(chunkName)
	}
}
|
|
}
|
|
|
|
// clearSessionValues removes all values from a session and optionally expires it.
|
|
// This is used during session cleanup and logout operations.
|
|
// Parameters:
|
|
// - session: The session to clear values from.
|
|
// - expire: If true, sets MaxAge to -1 to expire the cookie.
|
|
func clearSessionValues(session *sessions.Session, expire bool) {
|
|
if session == nil {
|
|
return
|
|
}
|
|
|
|
for k := range session.Values {
|
|
delete(session.Values, k)
|
|
}
|
|
|
|
if expire {
|
|
session.Options.MaxAge = -1
|
|
}
|
|
}
|
|
|
|
// clearAllSessionData clears all session data including main session and token chunks.
|
|
// It removes all session values and optionally expires all associated cookies.
|
|
// Parameters:
|
|
// - r: The HTTP request context (used for chunk cleanup).
|
|
// - expire: Whether to expire the cookies (set MaxAge to -1).
|
|
func (sd *SessionData) clearAllSessionData(r *http.Request, expire bool) {
|
|
clearSessionValues(sd.mainSession, expire)
|
|
clearSessionValues(sd.accessSession, expire)
|
|
clearSessionValues(sd.refreshSession, expire)
|
|
clearSessionValues(sd.idTokenSession, expire)
|
|
|
|
if expire && r != nil {
|
|
sd.clearTokenChunks(r, sd.accessTokenChunks)
|
|
sd.clearTokenChunks(r, sd.refreshTokenChunks)
|
|
sd.clearTokenChunks(r, sd.idTokenChunks)
|
|
} else {
|
|
for k := range sd.accessTokenChunks {
|
|
delete(sd.accessTokenChunks, k)
|
|
}
|
|
for k := range sd.refreshTokenChunks {
|
|
delete(sd.refreshTokenChunks, k)
|
|
}
|
|
for k := range sd.idTokenChunks {
|
|
delete(sd.idTokenChunks, k)
|
|
}
|
|
}
|
|
|
|
if expire {
|
|
sd.dirty = true
|
|
}
|
|
}
|
|
|
|
// Clear completely clears all session data and safely returns the session to the pool.
|
|
// It removes all authentication data, expires cookies, and handles panic recovery.
|
|
// This method ensures the SessionData object is always returned to the pool.
|
|
// Parameters:
|
|
// - r: The HTTP request context.
|
|
// - w: The HTTP response writer for cookie expiration (can be nil).
|
|
//
|
|
// Returns:
|
|
// - An error if session saving fails during cleanup.
|
|
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
|
defer func() {
|
|
sd.returnToPoolSafely()
|
|
}()
|
|
|
|
sd.sessionMutex.Lock()
|
|
sd.clearAllSessionData(r, true)
|
|
|
|
// Release the lock before calling Save to prevent deadlock
|
|
sd.sessionMutex.Unlock()
|
|
|
|
// This is primarily for testing - in production w will often be nil
|
|
var err error
|
|
if w != nil {
|
|
if r != nil && r.Header.Get("X-Test-Error") == "true" {
|
|
// Return a test error without trying to save problematic data
|
|
err = fmt.Errorf("test error triggered by X-Test-Error header")
|
|
} else {
|
|
err = sd.Save(r, w)
|
|
}
|
|
}
|
|
|
|
sd.request = nil
|
|
|
|
return err
|
|
}
|
|
|
|
// returnToPoolSafely safely returns the session to the object pool.
|
|
// It is a no-op when sd or its manager is nil, or when the session is not marked in use.
|
|
// It ensures the session is marked as not in use and properly reset before pooling.
|
|
func (sd *SessionData) returnToPoolSafely() {
|
|
if sd != nil && sd.manager != nil {
|
|
if sd.inUse {
|
|
sd.inUse = false
|
|
sd.Reset()
|
|
sd.manager.sessionPool.Put(sd)
|
|
atomic.AddInt64(&sd.manager.activeSessions, -1)
|
|
}
|
|
}
|
|
}
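// Illustrative sketch (not part of this package): the reset-before-pooling
// pattern described above, reduced to a standalone sync.Pool example. The
// scratch type and the acquire/release helpers are hypothetical names, and the
// sketch assumes each object is owned by one goroutine at a time.

```go
package main

import "sync"

// scratch stands in for a pooled per-request object.
type scratch struct {
	Values map[string]string
	inUse  bool
}

var scratchPool = sync.Pool{
	New: func() interface{} { return &scratch{Values: map[string]string{}} },
}

// acquire takes an object from the pool and marks it in use.
func acquire() *scratch {
	s := scratchPool.Get().(*scratch)
	s.inUse = true
	return s
}

// release wipes the object before pooling it. Objects that are not marked in
// use are ignored, so a double release cannot pool the same pointer twice.
func release(s *scratch) {
	if s == nil || !s.inUse {
		return
	}
	s.inUse = false
	for k := range s.Values {
		delete(s.Values, k)
	}
	scratchPool.Put(s)
}
```

// Wiping the object before Put, rather than after Get, means stale data never
// sits in the pool where another request could pick it up.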
|
|
|
|
// clearTokenChunks clears and expires all token chunk sessions.
|
|
// This is used during logout and session cleanup to ensure
|
|
// all token data is properly removed from the client.
|
|
// Parameters:
|
|
// - r: The HTTP request context.
|
|
// - chunks: The map of session chunks (e.g., sd.accessTokenChunks) to clear and expire.
|
|
func (sd *SessionData) clearTokenChunks(_ *http.Request, chunks map[int]*sessions.Session) {
|
|
for _, session := range chunks {
|
|
clearSessionValues(session, true)
|
|
}
|
|
}
|
|
|
|
// GetAuthenticated returns whether the user is currently authenticated.
|
|
// It checks both the authentication flag and session timeout.
|
|
// Returns:
|
|
// - true if the user is authenticated and the session is not expired.
|
|
// - false otherwise.
|
|
func (sd *SessionData) GetAuthenticated() bool {
|
|
sd.sessionMutex.RLock()
|
|
defer sd.sessionMutex.RUnlock()
|
|
|
|
return sd.getAuthenticatedUnsafe()
|
|
}
|
|
|
|
// getAuthenticatedUnsafe checks authentication status without acquiring locks.
|
|
// Used when the mutex is already held to avoid deadlocks.
|
|
// It validates both the authentication flag and session creation time.
|
|
// Returns:
|
|
// - true if authenticated and not expired, false otherwise.
|
|
func (sd *SessionData) getAuthenticatedUnsafe() bool {
|
|
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
|
if !auth {
|
|
return false
|
|
}
|
|
|
|
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return time.Since(time.Unix(createdAt, 0)) <= sd.manager.sessionMaxAge
|
|
}
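// Illustrative sketch (not part of this package): the two checks performed
// above, written against plain values so they can be exercised without a
// session store. sessionStillValid is a hypothetical name, and time is passed
// in as Unix seconds purely to keep the example deterministic.

```go
package main

// sessionStillValid mirrors the checks described above: the "authenticated"
// flag must be true, and "created_at" must be no older than maxAgeSeconds.
func sessionStillValid(values map[interface{}]interface{}, maxAgeSeconds, nowUnix int64) bool {
	auth, _ := values["authenticated"].(bool)
	if !auth {
		return false
	}
	createdAt, ok := values["created_at"].(int64)
	if !ok {
		return false // a missing or mistyped timestamp fails closed
	}
	return nowUnix-createdAt <= maxAgeSeconds
}
```

// A session with the flag set but no usable timestamp is rejected, so a
// partially written or tampered cookie cannot count as authenticated.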
|
|
|
|
// SetAuthenticated sets the authentication status and manages session security.
|
|
// When setting to true, it generates a new secure session ID and updates timestamps.
|
|
// This prevents session fixation attacks by regenerating the session identifier.
|
|
// Parameters:
|
|
// - value: The authentication status to set.
|
|
//
|
|
// Returns:
|
|
// - An error if generating a new session ID fails when setting value to true.
|
|
func (sd *SessionData) SetAuthenticated(value bool) error {
|
|
sd.sessionMutex.Lock()
|
|
defer sd.sessionMutex.Unlock()
|
|
|
|
currentAuth := sd.getAuthenticatedUnsafe()
|
|
changed := false
|
|
|
|
if currentAuth != value {
|
|
changed = true
|
|
}
|
|
|
|
if value {
|
|
id, err := generateSecureRandomString(64)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate secure session id: %w", err)
|
|
}
|
|
|
|
maxRetries := 5
|
|
for retry := 0; retry < maxRetries; retry++ {
|
|
if sd.mainSession.ID != id {
|
|
break
|
|
}
|
|
id, err = generateSecureRandomString(64)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate secure session id on retry %d: %w", retry, err)
|
|
}
|
|
}
|
|
|
|
if sd.mainSession.ID != id {
|
|
changed = true
|
|
}
|
|
sd.mainSession.ID = id
|
|
newCreationTime := time.Now().Unix()
|
|
if oldTime, ok := sd.mainSession.Values["created_at"].(int64); !ok || oldTime != newCreationTime {
|
|
changed = true
|
|
}
|
|
sd.mainSession.Values["created_at"] = newCreationTime
|
|
if oldAuth, ok := sd.mainSession.Values["authenticated"].(bool); !ok || oldAuth != value {
|
|
changed = true
|
|
}
|
|
} else {
|
|
if oldAuth, ok := sd.mainSession.Values["authenticated"].(bool); !ok || oldAuth != value {
|
|
changed = true
|
|
}
|
|
}
|
|
|
|
sd.mainSession.Values["authenticated"] = value
|
|
if changed {
|
|
sd.dirty = true
|
|
}
|
|
return nil
|
|
}
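// Illustrative sketch (not part of this package): regenerating a session
// identifier on login, as described above. newSessionID and rotateID are
// hypothetical stand-ins for generateSecureRandomString and the retry loop;
// hex encoding and the byte count are assumptions of this example.

```go
package main

import (
	"crypto/rand"
	"encoding/hex"
	"errors"
)

// newSessionID returns a hex-encoded identifier built from nBytes of
// cryptographically secure randomness (2*nBytes characters long).
func newSessionID(nBytes int) (string, error) {
	buf := make([]byte, nBytes)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return hex.EncodeToString(buf), nil
}

// rotateID returns a fresh identifier that differs from current, retrying in
// the extremely unlikely case that the new value equals the old one.
func rotateID(current string, nBytes, maxRetries int) (string, error) {
	for attempt := 0; attempt <= maxRetries; attempt++ {
		id, err := newSessionID(nBytes)
		if err != nil {
			return "", err
		}
		if id != current {
			return id, nil
		}
	}
	return "", errors.New("could not generate a distinct session id")
}
```

// Issuing a new identifier at the moment of authentication means an identifier
// planted before login is useless afterwards, which is the session fixation
// defense the doc comment refers to.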
|
|
|
|
// resetSession prepares a session for reuse by clearing its state.
|
|
// This is specifically for pool reuse preparation to ensure
|
|
// no data leaks between different user sessions.
|
|
// Parameters:
|
|
// - session: The session to reset for reuse.
|
|
func resetSession(session *sessions.Session) {
|
|
if session == nil {
|
|
return
|
|
}
|
|
|
|
clearSessionValues(session, false)
|
|
|
|
session.ID = ""
|
|
session.IsNew = true
|
|
}
|
|
|
|
// Reset clears all session data and prepares the SessionData for reuse.
|
|
// It ensures no authentication data persists when the object is reused
|
|
// between different users/sessions.
|
|
func (sd *SessionData) Reset() {
|
|
sd.sessionMutex.Lock()
|
|
defer sd.sessionMutex.Unlock()
|
|
|
|
sd.clearAllSessionData(nil, false)
|
|
|
|
resetSession(sd.mainSession)
|
|
resetSession(sd.accessSession)
|
|
resetSession(sd.refreshSession)
|
|
resetSession(sd.idTokenSession)
|
|
|
|
// Clear combined chunks
|
|
for k, session := range sd.combinedChunks {
|
|
resetSession(session)
|
|
delete(sd.combinedChunks, k)
|
|
}
|
|
|
|
// Clear redirect count to prevent leaking between sessions
|
|
if sd.mainSession != nil && sd.mainSession.Values != nil {
|
|
delete(sd.mainSession.Values, "redirect_count")
|
|
}
|
|
|
|
sd.dirty = false
|
|
sd.inUse = false
|
|
sd.request = nil
|
|
sd.useCombinedStorage = true // Reset to use combined storage by default
|
|
|
|
// Drop any cached claims so pooled SessionData does not leak claim data
|
|
// between requests/users.
|
|
sd.cachedClaimsToken = ""
|
|
sd.cachedClaims = nil
|
|
sd.cachedClaimsErr = nil
|
|
|
|
// The refresh mutex is intentionally left untouched here: sessionMutex is already
// held and this session is not in use by any request.
|
|
}
|
|
|
|
// ReturnToPool manually returns the session to the object pool.
|
|
// This is used in cleanup paths where Clear() is not called, to prevent memory leaks.
|
|
// It only returns the session if it's not currently in use.
|
|
func (sd *SessionData) ReturnToPool() {
|
|
if sd != nil && sd.manager != nil {
|
|
if !sd.inUse {
|
|
sd.Reset()
|
|
sd.manager.sessionPool.Put(sd)
|
|
atomic.AddInt64(&sd.manager.activeSessions, -1)
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetAccessToken retrieves the user's access token from session storage.
|
|
// It handles both single-cookie storage and chunked storage for large tokens,
|
|
// with automatic decompression if the token was compressed.
|
|
// Returns:
|
|
// - The complete, decompressed access token string, or an empty string if not found.
|
|
func (sd *SessionData) GetAccessToken() string {
|
|
sd.sessionMutex.RLock()
|
|
defer sd.sessionMutex.RUnlock()
|
|
|
|
return sd.getAccessTokenUnsafe()
|
|
}
|
|
|
|
// getAccessTokenUnsafe retrieves the access token without acquiring locks.
|
|
// If the stored value fails validation but is unchunked, uncompressed, and lacks the
// three-part JWT shape, it is treated as an opaque token and returned as-is.
|
|
// Used when the session mutex is already held to prevent deadlocks.
|
|
// Returns:
|
|
// - The complete access token string or empty string on error.
|
|
func (sd *SessionData) getAccessTokenUnsafe() string {
|
|
token, _ := sd.accessSession.Values["token"].(string)
|
|
compressed, _ := sd.accessSession.Values["compressed"].(bool)
|
|
|
|
// Debug: Check if manager/chunkManager is nil
|
|
if sd.manager == nil || sd.manager.chunkManager == nil {
|
|
// Direct return if no chunk manager (test scenario)
|
|
return token
|
|
}
|
|
|
|
result := sd.manager.chunkManager.GetToken(
|
|
token,
|
|
compressed,
|
|
sd.accessTokenChunks,
|
|
AccessTokenConfig,
|
|
)
|
|
|
|
if result.Error != nil {
|
|
// Check if we have a raw token available
|
|
// This handles cases where the token exists but doesn't validate as JWT
|
|
if token != "" && !compressed && len(sd.accessTokenChunks) == 0 {
|
|
// We have a non-chunked, non-compressed token that failed validation
|
|
// Check if it's an opaque token (doesn't have JWT structure)
|
|
dotCount := strings.Count(token, ".")
|
|
if dotCount != 2 {
|
|
// This is likely an opaque token that failed JWT validation
|
|
// Return the raw token as-is since opaque tokens are valid
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debugf("Returning opaque access token (dots: %d) despite validation error: %v", dotCount, result.Error)
|
|
}
|
|
return token
|
|
}
|
|
}
|
|
|
|
// For JWT validation errors or other issues, log and return empty
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debugf("ChunkManager.GetToken error: %v", result.Error)
|
|
}
|
|
return ""
|
|
}
|
|
|
|
return result.Token
|
|
}
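// Illustrative sketch (not part of this package): the opaque-token fallback
// used above, isolated from the chunk manager. hasJWTShape and fallbackToken
// are hypothetical names; the check looks at shape only and validates nothing.

```go
package main

import "strings"

// hasJWTShape reports whether token has the three dot-separated segments of a
// compact JWT. It performs no signature or claim validation.
func hasJWTShape(token string) bool {
	return strings.Count(token, ".") == 2
}

// fallbackToken returns the raw token when validation failed but the value is
// plausibly opaque (unchunked, uncompressed, not JWT-shaped), and "" otherwise.
func fallbackToken(raw string, compressed bool, chunkCount int) string {
	if raw != "" && !compressed && chunkCount == 0 && !hasJWTShape(raw) {
		return raw
	}
	return ""
}
```

// A JWT-shaped value that failed validation is still dropped, so the fallback
// only widens acceptance for tokens that were never JWTs in the first place.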
|
|
|
|
// SetAccessToken stores an access token with automatic compression and chunking.
|
|
// It rejects non-empty tokens shorter than 20 characters, compresses when beneficial,
// and splits the result into chunks if it exceeds cookie size limits.
// Compression and chunk reassembly are each verified by a round trip.
|
|
// Parameters:
|
|
// - token: The access token string to store.
|
|
func (sd *SessionData) SetAccessToken(token string) {
|
|
sd.sessionMutex.Lock()
|
|
defer sd.sessionMutex.Unlock()
|
|
|
|
if token != "" {
|
|
if len(token) < 20 {
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debugf("Token too short for opaque token (length: %d) - rejecting", len(token))
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
currentAccessToken := sd.getAccessTokenUnsafe()
|
|
if constantTimeStringCompare(currentAccessToken, token) {
|
|
return
|
|
}
|
|
sd.dirty = true
|
|
|
|
// Debug: Check if accessSession is properly initialized
|
|
if sd.accessSession == nil {
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Errorf("CRITICAL: accessSession is nil when trying to store token")
|
|
}
|
|
return
|
|
}
|
|
|
|
if sd.request != nil {
|
|
sd.expireAccessTokenChunksEnhanced(nil)
|
|
}
|
|
|
|
for k := range sd.accessTokenChunks {
|
|
delete(sd.accessTokenChunks, k)
|
|
}
|
|
|
|
if token == "" {
|
|
if sd.accessSession != nil {
|
|
sd.accessSession.Values["token"] = ""
|
|
sd.accessSession.Values["compressed"] = false
|
|
}
|
|
return
|
|
}
|
|
|
|
compressed := compressToken(token)
|
|
|
|
// Log the compression result for diagnostics
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debugf("Token compression: original %d bytes, compressed %d bytes", len(token), len(compressed))
|
|
}
|
|
|
|
if len(compressed) > 100*1024 {
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Infof("Access token too large after compression (%d bytes) - rejecting", len(compressed))
|
|
}
|
|
return
|
|
}
|
|
|
|
if compressed != token {
|
|
testDecompressed := decompressToken(compressed)
|
|
if testDecompressed != token {
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debug("Access token compression verification failed - storing uncompressed")
|
|
}
|
|
compressed = token
|
|
}
|
|
}
|
|
|
|
if len(compressed) <= maxCookieSize {
|
|
if sd.accessSession != nil {
|
|
sd.accessSession.Values["token"] = compressed
|
|
sd.accessSession.Values["compressed"] = (compressed != token)
|
|
// Log what was stored for diagnostics
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debugf("Stored token in session: compressed=%v, token_len=%d",
|
|
compressed != token, len(compressed))
|
|
}
|
|
}
|
|
} else {
|
|
if sd.accessSession != nil {
|
|
sd.accessSession.Values["token"] = ""
|
|
sd.accessSession.Values["compressed"] = (compressed != token)
|
|
}
|
|
|
|
chunks := splitIntoChunks(compressed, maxCookieSize)
|
|
|
|
if len(chunks) == 0 {
|
|
sd.manager.logger.Error("Failed to create chunks for access token")
|
|
return
|
|
}
|
|
|
|
if len(chunks) > 50 {
|
|
sd.manager.logger.Infof("Too many chunks (%d) for access token", len(chunks))
|
|
return
|
|
}
|
|
|
|
testReassembled := strings.Join(chunks, "")
|
|
if testReassembled != compressed {
|
|
sd.manager.logger.Debug("Access token chunk reassembly test failed")
|
|
return
|
|
}
|
|
|
|
for i, chunkData := range chunks {
|
|
sessionName := fmt.Sprintf("%s_%d", sd.manager.accessTokenCookieName(), i)
|
|
|
|
if sd.request == nil {
|
|
sd.manager.logger.Errorf("SetAccessToken: sd.request is nil, cannot create chunk session %s", sessionName)
|
|
return
|
|
}
|
|
|
|
if chunkData == "" {
|
|
sd.manager.logger.Debugf("Empty chunk data at index %d", i)
|
|
return
|
|
}
|
|
|
|
if len(chunkData) > maxCookieSize {
|
|
sd.manager.logger.Infof("Chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize)
|
|
return
|
|
}
|
|
|
|
if !validateChunkSize(chunkData) {
|
|
sd.manager.logger.Errorf("CRITICAL: Chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData))
|
|
return
|
|
}
|
|
|
|
session, err := sd.manager.store.Get(sd.request, sessionName)
|
|
if err != nil {
|
|
sd.manager.logger.Errorf("CRITICAL: Failed to get chunk session %s: %v", sessionName, err)
|
|
return
|
|
}
|
|
|
|
session.Values["token_chunk"] = chunkData
|
|
session.Values["compressed"] = (compressed != token)
|
|
session.Values["chunk_created_at"] = time.Now().Unix()
|
|
sd.accessTokenChunks[i] = session
|
|
}
|
|
|
|
sd.manager.logger.Debugf("SUCCESS: Stored access token in %d chunks", len(chunks))
|
|
}
|
|
}
|
|
|
|
// GetRefreshToken retrieves the user's refresh token from session storage.
|
|
// It handles both single-cookie storage and chunked storage for large tokens,
|
|
// with automatic decompression if the token was compressed.
|
|
// Returns:
|
|
// - The complete, decompressed refresh token string, or an empty string if not found.
|
|
func (sd *SessionData) GetRefreshToken() string {
|
|
sd.sessionMutex.RLock()
|
|
defer sd.sessionMutex.RUnlock()
|
|
|
|
token, _ := sd.refreshSession.Values["token"].(string)
|
|
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
|
|
|
|
result := sd.manager.chunkManager.GetToken(
|
|
token,
|
|
compressed,
|
|
sd.refreshTokenChunks,
|
|
RefreshTokenConfig,
|
|
)
|
|
|
|
if result.Error != nil {
|
|
// Check if we have a raw token available
|
|
// This handles cases where the token exists but doesn't validate as JWT
|
|
if token != "" && !compressed && len(sd.refreshTokenChunks) == 0 {
|
|
// We have a non-chunked, non-compressed token that failed validation
|
|
// Check if it's an opaque token (doesn't have JWT structure)
|
|
dotCount := strings.Count(token, ".")
|
|
if dotCount != 2 {
|
|
// This is likely an opaque token that failed JWT validation
|
|
// Return the raw token as-is since opaque tokens are valid
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debugf("Returning opaque refresh token (dots: %d) despite validation error: %v", dotCount, result.Error)
|
|
}
|
|
return token
|
|
}
|
|
}
|
|
// For JWT validation errors or other issues, log and return empty
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debugf("ChunkManager.GetToken error for refresh token: %v", result.Error)
|
|
}
|
|
return ""
|
|
}
|
|
|
|
return result.Token
|
|
}
|
|
|
|
// SetRefreshToken stores a refresh token with automatic compression and chunking.
|
|
// It validates token size, compresses if beneficial, and splits into chunks
|
|
// if needed. Includes comprehensive error checking and integrity verification.
|
|
// Parameters:
|
|
// - token: The refresh token string to store.
|
|
func (sd *SessionData) SetRefreshToken(token string) {
|
|
sd.sessionMutex.Lock()
|
|
defer sd.sessionMutex.Unlock()
|
|
|
|
if len(token) > 50*1024 {
|
|
sd.manager.logger.Errorf("CRITICAL: Refresh token too large (%d bytes) - possible corruption, rejecting", len(token))
|
|
return
|
|
}
|
|
|
|
// Get current refresh token without mutex to avoid deadlock since we already hold the lock
|
|
var currentRefreshToken string
|
|
sessionToken, _ := sd.refreshSession.Values["token"].(string)
|
|
if sessionToken != "" {
|
|
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
|
|
if compressed {
|
|
decompressed := decompressToken(sessionToken)
|
|
currentRefreshToken = decompressed
|
|
} else {
|
|
currentRefreshToken = sessionToken
|
|
}
|
|
} else if len(sd.refreshTokenChunks) > 0 {
|
|
// Simplified chunked token retrieval for deadlock prevention
|
|
var chunks []string
|
|
for i := 0; i < len(sd.refreshTokenChunks); i++ {
|
|
if session, ok := sd.refreshTokenChunks[i]; ok {
|
|
if chunk, chunkOk := session.Values["token_chunk"].(string); chunkOk && chunk != "" {
|
|
chunks = append(chunks, chunk)
|
|
}
|
|
}
|
|
}
|
|
if len(chunks) == len(sd.refreshTokenChunks) {
|
|
reassembled := strings.Join(chunks, "")
|
|
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
|
|
if compressed {
|
|
currentRefreshToken = decompressToken(reassembled)
|
|
} else {
|
|
currentRefreshToken = reassembled
|
|
}
|
|
}
|
|
}
|
|
if constantTimeStringCompare(currentRefreshToken, token) {
|
|
return
|
|
}
|
|
sd.dirty = true
|
|
|
|
if sd.request != nil {
|
|
sd.expireRefreshTokenChunksEnhanced(nil)
|
|
}
|
|
|
|
for k := range sd.refreshTokenChunks {
|
|
delete(sd.refreshTokenChunks, k)
|
|
}
|
|
|
|
if token == "" {
|
|
sd.refreshSession.Values["token"] = ""
|
|
sd.refreshSession.Values["compressed"] = false
|
|
return
|
|
}
|
|
|
|
compressed := compressToken(token)
|
|
|
|
if compressed != token {
|
|
testDecompressed := decompressToken(compressed)
|
|
if testDecompressed != token {
|
|
sd.manager.logger.Errorf("CRITICAL: Refresh token compression verification failed - storing uncompressed")
|
|
compressed = token
|
|
}
|
|
}
|
|
|
|
if len(compressed) <= maxCookieSize {
|
|
sd.refreshSession.Values["token"] = compressed
|
|
sd.refreshSession.Values["compressed"] = (compressed != token)
|
|
sd.refreshSession.Values["issued_at"] = time.Now().Unix()
|
|
} else {
|
|
sd.refreshSession.Values["token"] = ""
|
|
sd.refreshSession.Values["compressed"] = (compressed != token)
|
|
sd.refreshSession.Values["issued_at"] = time.Now().Unix()
|
|
|
|
chunks := splitIntoChunks(compressed, maxCookieSize)
|
|
|
|
if len(chunks) == 0 {
|
|
sd.manager.logger.Errorf("CRITICAL: Failed to create chunks for refresh token")
|
|
return
|
|
}
|
|
|
|
if len(chunks) > 50 {
|
|
sd.manager.logger.Errorf("CRITICAL: Too many chunks (%d) for refresh token - possible corruption", len(chunks))
|
|
return
|
|
}
|
|
|
|
testReassembled := strings.Join(chunks, "")
|
|
if testReassembled != compressed {
|
|
sd.manager.logger.Errorf("CRITICAL: Refresh token chunk reassembly test failed")
|
|
return
|
|
}
|
|
|
|
for i, chunkData := range chunks {
|
|
sessionName := fmt.Sprintf("%s_%d", sd.manager.refreshTokenCookieName(), i)
|
|
|
|
if sd.request == nil {
|
|
sd.manager.logger.Errorf("CRITICAL: SetRefreshToken: sd.request is nil, cannot create chunk session %s", sessionName)
|
|
return
|
|
}
|
|
|
|
if chunkData == "" {
|
|
sd.manager.logger.Errorf("CRITICAL: Empty refresh token chunk data at index %d", i)
|
|
return
|
|
}
|
|
|
|
if len(chunkData) > maxCookieSize {
|
|
sd.manager.logger.Errorf("CRITICAL: Refresh token chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize)
|
|
return
|
|
}
|
|
|
|
if !validateChunkSize(chunkData) {
|
|
sd.manager.logger.Errorf("CRITICAL: Refresh token chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData))
|
|
return
|
|
}
|
|
|
|
session, err := sd.manager.store.Get(sd.request, sessionName)
|
|
if err != nil {
|
|
sd.manager.logger.Errorf("CRITICAL: Failed to get refresh token chunk session %s: %v", sessionName, err)
|
|
return
|
|
}
|
|
|
|
session.Values["token_chunk"] = chunkData
|
|
session.Values["compressed"] = (compressed != token)
|
|
session.Values["chunk_created_at"] = time.Now().Unix()
|
|
sd.refreshTokenChunks[i] = session
|
|
}
|
|
|
|
sd.manager.logger.Debugf("SUCCESS: Stored refresh token in %d chunks", len(chunks))
|
|
}
|
|
}
|
|
|
|
// GetRefreshTokenIssuedAt retrieves the timestamp when the refresh token was issued/stored.
|
|
// Returns the time when the current refresh token was obtained, or zero time if not available.
|
|
func (sd *SessionData) GetRefreshTokenIssuedAt() time.Time {
|
|
sd.sessionMutex.RLock()
|
|
defer sd.sessionMutex.RUnlock()
|
|
|
|
if issuedAtUnix, ok := sd.refreshSession.Values["issued_at"].(int64); ok {
|
|
return time.Unix(issuedAtUnix, 0)
|
|
}
|
|
|
|
// For chunked tokens, check the first chunk for timestamp
|
|
if len(sd.refreshTokenChunks) > 0 {
|
|
if session, exists := sd.refreshTokenChunks[0]; exists {
|
|
if chunkCreatedAt, ok := session.Values["chunk_created_at"].(int64); ok {
|
|
return time.Unix(chunkCreatedAt, 0)
|
|
}
|
|
}
|
|
}
|
|
|
|
return time.Time{}
|
|
}
|
|
|
|
// expireAccessTokenChunksEnhanced expires all access token chunks and detects orphaned chunks.
|
|
// It searches for all existing chunks, identifies orphaned or expired chunks,
|
|
// and properly expires them to prevent cookie accumulation.
|
|
// Parameters:
|
|
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
|
func (sd *SessionData) expireAccessTokenChunksEnhanced(w http.ResponseWriter) {
|
|
const maxChunkSearchLimit = 50
|
|
orphanedChunks := 0
|
|
|
|
for i := 0; i < maxChunkSearchLimit; i++ {
|
|
sessionName := fmt.Sprintf("%s_%d", sd.manager.accessTokenCookieName(), i)
|
|
session, err := sd.manager.store.Get(sd.request, sessionName)
|
|
if err != nil {
|
|
break
|
|
}
|
|
if session.IsNew {
|
|
break
|
|
}
|
|
|
|
if chunk, exists := session.Values["token_chunk"]; exists {
|
|
if createdAt, ok := session.Values["chunk_created_at"].(int64); ok {
|
|
chunkAge := time.Since(time.Unix(createdAt, 0))
|
|
if chunkAge > 24*time.Hour {
|
|
orphanedChunks++
|
|
sd.manager.logger.Debugf("Found orphaned access token chunk %d (age: %v)", i, chunkAge)
|
|
}
|
|
} else if chunk != nil {
|
|
orphanedChunks++
|
|
sd.manager.logger.Debugf("Found access token chunk %d without timestamp, treating as orphaned", i)
|
|
}
|
|
}
|
|
|
|
session.Options.MaxAge = -1
|
|
session.Values = make(map[interface{}]interface{})
|
|
if w != nil {
|
|
if err := session.Save(sd.request, w); err != nil {
|
|
sd.manager.logger.Errorf("failed to save expired access token chunk %d: %v", i, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
if orphanedChunks > 0 {
|
|
sd.manager.logger.Infof("Cleaned up %d orphaned access token chunks", orphanedChunks)
|
|
}
|
|
}
|
|
|
|
// expireRefreshTokenChunksEnhanced expires all refresh token chunks and detects orphaned chunks.
|
|
// It searches for all existing chunks, identifies orphaned or expired chunks,
|
|
// and properly expires them to prevent cookie accumulation.
|
|
// Parameters:
|
|
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
|
func (sd *SessionData) expireRefreshTokenChunksEnhanced(w http.ResponseWriter) {
|
|
const maxChunkSearchLimit = 50
|
|
orphanedChunks := 0
|
|
|
|
for i := 0; i < maxChunkSearchLimit; i++ {
|
|
sessionName := fmt.Sprintf("%s_%d", sd.manager.refreshTokenCookieName(), i)
|
|
session, err := sd.manager.store.Get(sd.request, sessionName)
|
|
if err != nil {
|
|
break
|
|
}
|
|
if session.IsNew {
|
|
break
|
|
}
|
|
|
|
if chunk, exists := session.Values["token_chunk"]; exists {
|
|
if createdAt, ok := session.Values["chunk_created_at"].(int64); ok {
|
|
chunkAge := time.Since(time.Unix(createdAt, 0))
|
|
if chunkAge > 24*time.Hour {
|
|
orphanedChunks++
|
|
sd.manager.logger.Debugf("Found orphaned refresh token chunk %d (age: %v)", i, chunkAge)
|
|
}
|
|
} else if chunk != nil {
|
|
orphanedChunks++
|
|
sd.manager.logger.Debugf("Found refresh token chunk %d without timestamp, treating as orphaned", i)
|
|
}
|
|
}
|
|
|
|
session.Options.MaxAge = -1
|
|
session.Values = make(map[interface{}]interface{})
|
|
if w != nil {
|
|
if err := session.Save(sd.request, w); err != nil {
|
|
sd.manager.logger.Errorf("failed to save expired refresh token chunk %d: %v", i, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
if orphanedChunks > 0 {
|
|
sd.manager.logger.Infof("Cleaned up %d orphaned refresh token chunks", orphanedChunks)
|
|
}
|
|
}
|
|
|
|
// expireIDTokenChunksEnhanced expires all ID token chunks and detects orphaned chunks.
// It searches for all existing chunks, identifies orphaned or expired chunks,
// and properly expires them to prevent cookie accumulation.
// Parameters:
//   - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
func (sd *SessionData) expireIDTokenChunksEnhanced(w http.ResponseWriter) {
	const maxChunkSearchLimit = 50
	orphanedChunks := 0

	for i := 0; i < maxChunkSearchLimit; i++ {
		sessionName := fmt.Sprintf("%s_%d", sd.manager.idTokenCookieName(), i)
		session, err := sd.manager.store.Get(sd.request, sessionName)
		if err != nil {
			break
		}
		if session.IsNew {
			break
		}

		if chunk, exists := session.Values["token_chunk"]; exists {
			if createdAt, ok := session.Values["chunk_created_at"].(int64); ok {
				chunkAge := time.Since(time.Unix(createdAt, 0))
				if chunkAge > 24*time.Hour {
					orphanedChunks++
					sd.manager.logger.Debugf("Found orphaned ID token chunk %d (age: %v)", i, chunkAge)
				}
			} else if chunk != nil {
				orphanedChunks++
				sd.manager.logger.Debugf("Found ID token chunk %d without timestamp, treating as orphaned", i)
			}
		}

		session.Options.MaxAge = -1
		session.Values = make(map[interface{}]interface{})
		if w != nil {
			if err := session.Save(sd.request, w); err != nil {
				sd.manager.logger.Errorf("failed to save expired ID token chunk %d: %v", i, err)
			}
		}
	}

	if orphanedChunks > 0 {
		sd.manager.logger.Infof("Cleaned up %d orphaned ID token chunks", orphanedChunks)
	}
}

// splitIntoChunks divides a string into chunks of specified maximum size.
// It ensures chunks don't exceed browser cookie limits and handles
// the string splitting logic for large token storage.
// Parameters:
//   - s: The string to split into chunks.
//   - chunkSize: The maximum size for each chunk; values above maxCookieSize
//     are capped at maxCookieSize.
//
// Returns:
//   - A slice of strings representing the chunks, or nil if s is empty or
//     chunkSize is not positive.
func splitIntoChunks(s string, chunkSize int) []string {
	effectiveChunkSize := min(chunkSize, maxCookieSize)
	// Guard against a non-positive size: zero would never shrink s (infinite
	// loop) and a negative value would panic on the slice expression.
	if effectiveChunkSize <= 0 {
		return nil
	}

	var chunks []string
	for len(s) > 0 {
		if len(s) > effectiveChunkSize {
			chunks = append(chunks, s[:effectiveChunkSize])
			s = s[effectiveChunkSize:]
		} else {
			chunks = append(chunks, s)
			break
		}
	}
	return chunks
}

// validateChunkSize checks if a chunk will fit within browser cookie limits.
// It estimates the encoded size including cookie overhead and headers
// to ensure the chunk won't exceed browser-imposed cookie size limits.
// Parameters:
//   - chunkData: The chunk data to validate.
//
// Returns:
//   - true if the chunk is safe to store, false if it may exceed browser limits.
func validateChunkSize(chunkData string) bool {
	// Estimate ~50% overhead for encoding, compare against ~2x maxCookieSize limit
	estimatedEncodedSize := len(chunkData) + (len(chunkData) * 50 / 100)
	return estimatedEncodedSize <= maxCookieSize*2
}

// isCorruptionMarker detects if data contains known corruption indicators.
// It checks for specific corruption markers and invalid characters
// that indicate the data has been tampered with or corrupted.
// Parameters:
//   - data: The data string to check for corruption markers.
//
// Returns:
//   - true if the data contains corruption markers, false otherwise.
func isCorruptionMarker(data string) bool {
	if data == "" {
		return false
	}

	corruptionMarkers := []string{
		"__CORRUPTION_MARKER_TEST__",
		"__INVALID_BASE64_DATA__",
		"__CORRUPTED_CHUNK_DATA__",
		"!@#$%^&*()",
		"<<<CORRUPTED>>>",
	}

	for _, marker := range corruptionMarkers {
		if data == marker {
			return true
		}
	}

	if len(data) > 10 {
		invalidChars := "!@#$%^&*(){}[]|\\:;\"'<>?,`~"
		for _, char := range invalidChars {
			if strings.ContainsRune(data, char) {
				return true
			}
		}
	}

	return false
}

// GetCSRF retrieves the CSRF token for state validation.
// This token is used to prevent cross-site request forgery attacks
// during the OIDC authentication flow.
// Returns:
//   - The CSRF token string, or an empty string if not set.
func (sd *SessionData) GetCSRF() string {
	csrf, _ := sd.mainSession.Values["csrf"].(string)
	return csrf
}

// SetCSRF stores the CSRF token for state validation.
// The token is used to validate the state parameter in OAuth callbacks.
// Parameters:
//   - token: The CSRF token to store.
func (sd *SessionData) SetCSRF(token string) {
	currentVal, _ := sd.mainSession.Values["csrf"].(string)
	if currentVal != token {
		sd.mainSession.Values["csrf"] = token
		sd.dirty = true
	}
}

// GetNonce retrieves the nonce for ID token validation.
// The nonce prevents replay attacks by ensuring ID tokens
// were issued in response to the specific authentication request.
// Returns:
//   - The nonce string, or an empty string if not set.
func (sd *SessionData) GetNonce() string {
	nonce, _ := sd.mainSession.Values["nonce"].(string)
	return nonce
}

// SetNonce stores the nonce for ID token validation.
// The nonce will be validated against the nonce claim in received ID tokens.
// Parameters:
//   - nonce: The nonce string to store.
func (sd *SessionData) SetNonce(nonce string) {
	currentVal, _ := sd.mainSession.Values["nonce"].(string)
	if currentVal != nonce {
		sd.mainSession.Values["nonce"] = nonce
		sd.dirty = true
	}
}

// GetCodeVerifier retrieves the PKCE code verifier.
// This is used in the PKCE (Proof Key for Code Exchange) flow
// to enhance security for public clients.
// Returns:
//   - The code verifier string, or an empty string if not set or PKCE is disabled.
func (sd *SessionData) GetCodeVerifier() string {
	codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
	return codeVerifier
}

// SetCodeVerifier stores the PKCE code verifier.
// The code verifier is used to generate the code challenge sent to the
// authorization server and validated during token exchange.
// Parameters:
//   - codeVerifier: The PKCE code verifier string to store.
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
	currentVal, _ := sd.mainSession.Values["code_verifier"].(string)
	if currentVal != codeVerifier {
		sd.mainSession.Values["code_verifier"] = codeVerifier
		sd.dirty = true
	}
}

// GetUserIdentifier retrieves the authenticated user's identifier as extracted
// from the configured userIdentifierClaim of the ID token (email, sub, oid,
// upn, preferred_username, etc.). The value is used for authorization
// decisions and header injection.
// Returns:
//   - The user identifier string, or an empty string if not set.
func (sd *SessionData) GetUserIdentifier() string {
	sd.sessionMutex.RLock()
	defer sd.sessionMutex.RUnlock()

	userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
	return userIdentifier
}

// SetUserIdentifier stores the authenticated user's identifier value.
// Parameters:
//   - userIdentifier: The user identifier to store (email, sub, or other claim value).
func (sd *SessionData) SetUserIdentifier(userIdentifier string) {
	sd.sessionMutex.Lock()
	defer sd.sessionMutex.Unlock()

	currentVal, _ := sd.mainSession.Values["user_identifier"].(string)
	if currentVal != userIdentifier {
		sd.mainSession.Values["user_identifier"] = userIdentifier
		sd.dirty = true
	}
}

// GetIncomingPath retrieves the original request URI that triggered authentication.
// This path is used to redirect the user back to their intended destination
// after successful authentication.
// Returns:
//   - The original request URI string, or an empty string if not set.
func (sd *SessionData) GetIncomingPath() string {
	path, _ := sd.mainSession.Values["incoming_path"].(string)
	return path
}

// SetIncomingPath stores the original request URI for post-authentication redirect.
// This allows the user to be redirected to their originally requested resource
// after completing the authentication flow.
// Parameters:
//   - path: The original request URI string (e.g., "/protected/resource?id=123").
func (sd *SessionData) SetIncomingPath(path string) {
	currentVal, _ := sd.mainSession.Values["incoming_path"].(string)
	if currentVal != path {
		sd.mainSession.Values["incoming_path"] = path
		sd.dirty = true
	}
}

// GetIDToken retrieves the user's ID token from session storage.
// The ID token contains user claims and is used for user identification
// and authorization decisions. Handles compression and chunking automatically.
// Returns:
//   - The complete, decompressed ID token string, or an empty string if not found.
func (sd *SessionData) GetIDToken() string {
	sd.sessionMutex.RLock()
	defer sd.sessionMutex.RUnlock()

	return sd.getIDTokenUnsafe()
}

// GetIDTokenClaims returns claims parsed from the current ID token, caching
// the result on the SessionData so repeated callers within the same request
// do not re-parse the JWT. The cache is keyed on the ID token string and is
// cleared when the SessionData is reset (see Reset) or when the ID token
// changes (e.g. after a refresh). A parse error is cached as well, so a token
// that fails to parse is not re-parsed on every call.
//
// The parser parameter is typically the TraefikOidc.extractClaimsFunc, which
// lets tests inject mocks just like the direct call it replaces.
//
// Returns a nil claims map and a nil error when the session has no ID
// token, matching the existing "no-op" behavior of the caller sites.
func (sd *SessionData) GetIDTokenClaims(parser func(string) (map[string]interface{}, error)) (map[string]interface{}, error) {
	sd.sessionMutex.Lock()
	defer sd.sessionMutex.Unlock()

	token := sd.getIDTokenUnsafe()
	if token == "" {
		// Invalidate any stale cache without running the parser.
		sd.cachedClaimsToken = ""
		sd.cachedClaims = nil
		sd.cachedClaimsErr = nil
		return nil, nil
	}

	if sd.cachedClaimsToken == token && (sd.cachedClaims != nil || sd.cachedClaimsErr != nil) {
		return sd.cachedClaims, sd.cachedClaimsErr
	}

	claims, err := parser(token)
	sd.cachedClaimsToken = token
	sd.cachedClaims = claims
	sd.cachedClaimsErr = err
	return claims, err
}

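/*
Example: a minimal standalone sketch of the token-keyed caching that
GetIDTokenClaims describes, kept in a comment so it does not affect the build.
The claimsCache type and its get method are hypothetical names introduced for
illustration; locking and session storage are left out.

```go
package main

import "fmt"

// claimsCache is a hypothetical stand-in for the cache fields on SessionData.
type claimsCache struct {
	token  string
	claims map[string]interface{}
	err    error
}

// get returns cached claims while the token is unchanged and runs the
// parser only when the token differs from the cached one.
func (c *claimsCache) get(token string, parser func(string) (map[string]interface{}, error)) (map[string]interface{}, error) {
	if token == "" {
		// No token: drop any stale entry without running the parser.
		*c = claimsCache{}
		return nil, nil
	}
	if c.token == token && (c.claims != nil || c.err != nil) {
		return c.claims, c.err
	}
	claims, err := parser(token)
	c.token, c.claims, c.err = token, claims, err
	return claims, err
}

func main() {
	calls := 0
	parser := func(string) (map[string]interface{}, error) {
		calls++
		return map[string]interface{}{"sub": "user-1"}, nil
	}

	var cache claimsCache
	cache.get("a.b.c", parser) // parsed
	cache.get("a.b.c", parser) // served from the cache
	cache.get("d.e.f", parser) // token changed, parsed again
	fmt.Println(calls)
}
```

The counting parser shows the intent: two distinct tokens cause two parser
runs, however often the claims are requested.
*/
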
// getIDTokenUnsafe retrieves the ID token without acquiring locks.
// It delegates chunk reassembly, decompression, and integrity checks to the
// chunk manager. Call it only when the session mutex is already held, to
// prevent deadlocks.
// Returns:
//   - The complete ID token string or empty string on error.
func (sd *SessionData) getIDTokenUnsafe() string {
	token, _ := sd.idTokenSession.Values["token"].(string)
	compressed, _ := sd.idTokenSession.Values["compressed"].(bool)

	// Without a manager or chunk manager (test scenario), return the stored
	// value as-is.
	if sd.manager == nil || sd.manager.chunkManager == nil {
		return token
	}

	result := sd.manager.chunkManager.GetToken(
		token,
		compressed,
		sd.idTokenChunks,
		IDTokenConfig,
	)

	if result.Error != nil {
		return ""
	}

	return result.Token
}

// getRefreshTokenUnsafe retrieves the refresh token without acquiring locks.
// Used when the session mutex is already held to prevent deadlocks.
func (sd *SessionData) getRefreshTokenUnsafe() string {
	token, _ := sd.refreshSession.Values["token"].(string)
	compressed, _ := sd.refreshSession.Values["compressed"].(bool)

	if sd.manager == nil || sd.manager.chunkManager == nil {
		return token
	}

	result := sd.manager.chunkManager.GetToken(
		token,
		compressed,
		sd.refreshTokenChunks,
		RefreshTokenConfig,
	)

	if result.Error != nil {
		// Handle opaque tokens: an uncompressed, unchunked value that is not
		// JWT-shaped (three dot-separated segments) is returned as stored.
		if token != "" && !compressed && len(sd.refreshTokenChunks) == 0 {
			if strings.Count(token, ".") != 2 {
				return token
			}
		}
		return ""
	}

	return result.Token
}

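/*
Example: a minimal standalone sketch of the dot-count test that separates
JWT-shaped values from opaque tokens in the fallback above, kept in a comment
so it does not affect the build. looksLikeJWT is a hypothetical helper name
used only here; it checks shape, not validity.

```go
package main

import (
	"fmt"
	"strings"
)

// looksLikeJWT reports whether token has the three dot-separated segments
// of a compact JWT (header.payload.signature). It does not verify anything.
func looksLikeJWT(token string) bool {
	return strings.Count(token, ".") == 2
}

func main() {
	fmt.Println(looksLikeJWT("header.payload.signature")) // JWT-shaped
	fmt.Println(looksLikeJWT("8xLOxBtZp8"))               // opaque value
}
```
*/
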
// getUserIdentifierUnsafe retrieves the user identifier without acquiring locks.
func (sd *SessionData) getUserIdentifierUnsafe() string {
	userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
	return userIdentifier
}

// getCSRFUnsafe retrieves the CSRF token without acquiring locks.
func (sd *SessionData) getCSRFUnsafe() string {
	csrf, _ := sd.mainSession.Values["csrf"].(string)
	return csrf
}

// getNonceUnsafe retrieves the nonce without acquiring locks.
func (sd *SessionData) getNonceUnsafe() string {
	nonce, _ := sd.mainSession.Values["nonce"].(string)
	return nonce
}

// getCodeVerifierUnsafe retrieves the code verifier without acquiring locks.
func (sd *SessionData) getCodeVerifierUnsafe() string {
	codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
	return codeVerifier
}

// getIncomingPathUnsafe retrieves the incoming path without acquiring locks.
func (sd *SessionData) getIncomingPathUnsafe() string {
	path, _ := sd.mainSession.Values["incoming_path"].(string)
	return path
}

// getCreatedAtUnsafe retrieves the created_at timestamp without acquiring locks.
func (sd *SessionData) getCreatedAtUnsafe() int64 {
	createdAt, _ := sd.mainSession.Values["created_at"].(int64)
	return createdAt
}

// getRedirectCountUnsafe retrieves the redirect count without acquiring locks.
func (sd *SessionData) getRedirectCountUnsafe() int {
	count, _ := sd.mainSession.Values["redirect_count"].(int)
	return count
}

// SetIDToken stores an ID token with automatic compression and chunking.
// It validates the JWT format, compresses if beneficial, and splits into chunks
// if the token exceeds cookie size limits. Includes comprehensive validation.
// Parameters:
//   - token: The ID token string to store.
func (sd *SessionData) SetIDToken(token string) {
	sd.sessionMutex.Lock()
	defer sd.sessionMutex.Unlock()

	if token != "" {
		dotCount := strings.Count(token, ".")
		if dotCount != 2 {
			sd.manager.logger.Errorf("CRITICAL: Attempt to store invalid JWT ID token format (dots: %d) - rejecting", dotCount)
			return
		}
	}

	if len(token) > 50*1024 {
		sd.manager.logger.Errorf("CRITICAL: ID token too large (%d bytes) - possible corruption, rejecting", len(token))
		return
	}
	currentIDToken := sd.getIDTokenUnsafe()
	if constantTimeStringCompare(currentIDToken, token) {
		return
	}
	sd.dirty = true

	if sd.request != nil {
		sd.expireIDTokenChunksEnhanced(nil)
	}

	for k := range sd.idTokenChunks {
		delete(sd.idTokenChunks, k)
	}

	if token == "" {
		if sd.idTokenSession != nil {
			sd.idTokenSession.Values["token"] = ""
			sd.idTokenSession.Values["compressed"] = false
		}
		return
	}

	compressed := compressToken(token)

	if compressed != token {
		testDecompressed := decompressToken(compressed)
		if testDecompressed != token {
			sd.manager.logger.Errorf("CRITICAL: ID token compression verification failed - storing uncompressed")
			compressed = token
		}
	}

	if len(compressed) <= maxCookieSize {
		if sd.idTokenSession != nil {
			sd.idTokenSession.Values["token"] = compressed
			sd.idTokenSession.Values["compressed"] = (compressed != token)
		}
	} else {
		if sd.idTokenSession != nil {
			sd.idTokenSession.Values["token"] = ""
			sd.idTokenSession.Values["compressed"] = (compressed != token)
		}

		chunks := splitIntoChunks(compressed, maxCookieSize)

		if len(chunks) == 0 {
			sd.manager.logger.Errorf("CRITICAL: Failed to create chunks for ID token")
			return
		}

		if len(chunks) > 50 {
			sd.manager.logger.Errorf("CRITICAL: Too many chunks (%d) for ID token - possible corruption", len(chunks))
			return
		}

		testReassembled := strings.Join(chunks, "")
		if testReassembled != compressed {
			sd.manager.logger.Errorf("CRITICAL: ID token chunk reassembly test failed")
			return
		}

		for i, chunkData := range chunks {
			sessionName := fmt.Sprintf("%s_%d", sd.manager.idTokenCookieName(), i)

			if sd.request == nil {
				sd.manager.logger.Errorf("CRITICAL: SetIDToken: sd.request is nil, cannot create chunk session %s", sessionName)
				return
			}

			if chunkData == "" {
				// Debugf, not Debug: the message contains a format verb.
				sd.manager.logger.Debugf("Empty chunk data at index %d", i)
				return
			}

			if len(chunkData) > maxCookieSize {
				// Infof, not Info: the message contains format verbs.
				sd.manager.logger.Infof("Chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize)
				return
			}

			if !validateChunkSize(chunkData) {
				sd.manager.logger.Errorf("CRITICAL: ID token chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData))
				return
			}

			session, err := sd.manager.store.Get(sd.request, sessionName)
			if err != nil {
				sd.manager.logger.Errorf("CRITICAL: Failed to get chunk session %s: %v", sessionName, err)
				return
			}

			session.Values["token_chunk"] = chunkData
			session.Values["compressed"] = (compressed != token)
			session.Values["chunk_created_at"] = time.Now().Unix()
			sd.idTokenChunks[i] = session
		}

		sd.manager.logger.Debugf("SUCCESS: Stored ID token in %d chunks", len(chunks))
	}
}

// GetRedirectCount returns the number of redirects in the current authentication flow.
// STABILITY FIX: Prevents infinite redirect loops by tracking redirect attempts.
// Returns:
//   - The current redirect count, 0 if not set.
func (sd *SessionData) GetRedirectCount() int {
	if count, ok := sd.mainSession.Values["redirect_count"].(int); ok {
		return count
	}
	return 0
}

// IncrementRedirectCount increases the redirect counter by one.
// STABILITY FIX: Prevents infinite redirect loops by tracking successive redirects.
// Used to detect potential redirect loops and abort authentication if too many occur.
func (sd *SessionData) IncrementRedirectCount() {
	currentCount := sd.GetRedirectCount()
	sd.mainSession.Values["redirect_count"] = currentCount + 1
	sd.dirty = true
}

// ResetRedirectCount resets the redirect counter to zero.
// STABILITY FIX: Prevents infinite redirect loops by clearing the counter
// when authentication completes successfully or when starting a new flow.
func (sd *SessionData) ResetRedirectCount() {
	sd.mainSession.Values["redirect_count"] = 0
	sd.dirty = true
}