mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Complete rebuild of the plugin
* Fix bug affecting Azure OIDC authentication ( and most likely others ) * Fixes issue #51 * Ensure that appended roles are unique. Update the documentation. * Improvements targetting possible memory usage spikes. * Additional fixes and cleanup * Refactoring code to fix the issues identified by the users. * Modernize run * Fieldalignment * Multiple changes to improve performance and reduce complexity. - Optimise the errors and recovery. - Deduplicate code in metadata cache. - Remove unused performance monitoring code. - Simplify session management and settings handling. * Fix claims issue. * Add ability to overwrite the default scopes in the settings file * Well.. that escalated quickly. Completely forgot that Traefik uses outdated Yaegi and requires compatibility with 1.20 ( pre-generic Go code ). * Bugfix #51: Ensures that user provided scopes overrides work. * fixup! Bugfix #51: Ensures that user provided scopes overrides work. * fixup! fixup! Bugfix #51: Ensures that user provided scopes overrides work. * Abstract the provider logic into a separate package. * Additional micro fixes and cleanups. * Simplify all the things. * fixup! Simplify all the things. * fixup! fixup! Simplify all the things. * fixup! fixup! fixup! Simplify all the things. * fixup! fixup! fixup! fixup! Simplify all the things. * ... * Cleanup tests. * fixup! Cleanup tests. * fixup! fixup! fixup! Cleanup tests. * fixup! fixup! fixup! fixup! Cleanup tests. * fixup! fixup! fixup! fixup! fixup! Cleanup tests. * Issue #53: Fix CSRF token handling in reverse proxy 1. ✅ HTTPS Detection Fixed (session.go:723) - Now uses X-Forwarded-Proto header instead of r.URL.Scheme - Properly detects HTTPS in reverse proxy environments 2. ✅ SameSite Cookie Attribute Fixed - Removed automatic SameSiteStrictMode for HTTPS (would break OAuth) - Keeps SameSiteLaxMode to allow OAuth callbacks from external domains - Only uses Strict for AJAX requests which don't involve OAuth redirects 3. ✅ Cookie Domain Handling Fixed - Now respects X-Forwarded-Host header for cookie domain - Ensures cookies are set for the public domain, not internal proxy domain 4. ✅ EnhanceSessionSecurity Properly Integrated - Function is now actually called during session save - Applies security enhancements without breaking OAuth flow Why Issue #53 Failed Before: 1. Cookies were not marked Secure in HTTPS environments (browser wouldn't send them back) 2. If they had been Secure with SameSite=Strict, Azure callbacks would still fail 3. Cookie domain might have been wrong (internal vs public domain) Why It Works Now: 1. Cookies are properly marked Secure for HTTPS 2. Uses SameSite=Lax to allow OAuth provider callbacks 3. Cookie domain uses public domain from X-Forwarded-Host 4. CSRF token persists through the entire OAuth flow * Next set of enhancements together with memory usage improvements. * Memory leak fixes and optimisations. * CSRF and Cookie Domain fixes * fixup! CSRF and Cookie Domain fixes * Metadata cache leak fix + profiling * fixup! Metadata cache leak fix + profiling * Memory leaks hunting, part 1337. * Further pursue of perfection. * fixup! Further pursue of perfection. * fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * Clear race conditions * fixup! Clear race conditions * Weekend fun with memory leaks * Splitting code into multiple files with reasonable testing coverage. ``` ok github.com/lukaszraczylo/traefikoidc 117.017s coverage: 72.6% of statements ok github.com/lukaszraczylo/traefikoidc/auth 0.505s coverage: 87.1% of statements ok github.com/lukaszraczylo/traefikoidc/circuit_breaker 0.283s coverage: 99.0% of statements github.com/lukaszraczylo/traefikoidc/config coverage: 0.0% of statements ok github.com/lukaszraczylo/traefikoidc/handlers 0.349s coverage: 98.2% of statements ok github.com/lukaszraczylo/traefikoidc/internal/providers (cached) coverage: 94.3% of statements ok github.com/lukaszraczylo/traefikoidc/middleware 0.808s coverage: 78.0% of statements ok github.com/lukaszraczylo/traefikoidc/recovery 0.653s coverage: 100.0% of statements ok github.com/lukaszraczylo/traefikoidc/session/chunking (cached) coverage: 87.8% of statements ok github.com/lukaszraczylo/traefikoidc/session/core (cached) coverage: 85.6% of statements ok github.com/lukaszraczylo/traefikoidc/session/crypto (cached) coverage: 81.8% of statements ok github.com/lukaszraczylo/traefikoidc/session/storage (cached) coverage: 93.5% of statements ok github.com/lukaszraczylo/traefikoidc/session/validators (cached) coverage: 98.8% of statements ```` * fixup! Splitting code into multiple files with reasonable testing coverage. * fixup! fixup! Splitting code into multiple files with reasonable testing coverage. * Weekend fun with further optimisations. * fixup! Weekend fun with further optimisations. * fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! fixup! fixup! Weekend fun with further optimisations. * Pre-release cleanup. * Enhance test coverage. * fixup! Enhance test coverage. * fixup! fixup! Enhance test coverage. * fixup! fixup! fixup! Enhance test coverage.
This commit is contained in:
@@ -0,0 +1,453 @@
|
||||
// Package chunking provides session chunking functionality for large tokens
|
||||
package chunking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
const (
|
||||
maxCookieSize = 1200
|
||||
)
|
||||
|
||||
// TokenConfig defines validation and storage parameters for different token types.
|
||||
// It specifies size limits, format requirements, and security constraints to ensure
|
||||
// tokens can be safely stored in browser cookies while maintaining security.
|
||||
type TokenConfig struct {
|
||||
Type string
|
||||
MinLength int
|
||||
MaxLength int
|
||||
MaxChunks int
|
||||
MaxChunkSize int
|
||||
AllowOpaqueTokens bool
|
||||
RequireJWTFormat bool
|
||||
}
|
||||
|
||||
// Global session tracking to prevent memory leaks across all instances
|
||||
var (
|
||||
globalSessionCount int64 = 0
|
||||
globalMaxSessions int64 = 5000 // CRITICAL FIX: Global limit of 5000 total sessions
|
||||
)
|
||||
|
||||
// Predefined configurations for each token type
|
||||
var (
|
||||
AccessTokenConfig = TokenConfig{
|
||||
Type: "access",
|
||||
MinLength: 5,
|
||||
MaxLength: 100 * 1024,
|
||||
MaxChunks: 25,
|
||||
MaxChunkSize: maxCookieSize,
|
||||
AllowOpaqueTokens: true,
|
||||
RequireJWTFormat: false,
|
||||
}
|
||||
|
||||
RefreshTokenConfig = TokenConfig{
|
||||
Type: "refresh",
|
||||
MinLength: 5,
|
||||
MaxLength: 50 * 1024,
|
||||
MaxChunks: 15,
|
||||
MaxChunkSize: maxCookieSize,
|
||||
AllowOpaqueTokens: true,
|
||||
RequireJWTFormat: false,
|
||||
}
|
||||
|
||||
IDTokenConfig = TokenConfig{
|
||||
Type: "id",
|
||||
MinLength: 5,
|
||||
MaxLength: 75 * 1024,
|
||||
MaxChunks: 20,
|
||||
MaxChunkSize: maxCookieSize,
|
||||
AllowOpaqueTokens: false,
|
||||
RequireJWTFormat: true,
|
||||
}
|
||||
)
|
||||
|
||||
// TokenRetrievalResult represents the outcome of a token retrieval operation.
|
||||
// It contains either the successfully retrieved token or an error describing
|
||||
// what went wrong during retrieval.
|
||||
type TokenRetrievalResult struct {
|
||||
Error error
|
||||
Token string
|
||||
}
|
||||
|
||||
// SessionEntry represents a session with expiration tracking
|
||||
type SessionEntry struct {
|
||||
Session *sessions.Session
|
||||
ExpiresAt time.Time
|
||||
LastUsed time.Time
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(format string, args ...interface{})
|
||||
Error(msg string)
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// ChunkManager handles the complex logic of storing and retrieving large tokens
|
||||
// across multiple HTTP cookies. It provides comprehensive validation, security checks,
|
||||
// and error handling to ensure data integrity and prevent security vulnerabilities
|
||||
// throughout the process.
|
||||
type ChunkManager struct {
|
||||
logger Logger
|
||||
mutex *sync.RWMutex
|
||||
// sessionMap provides bounded session storage to prevent memory leaks
|
||||
sessionMap map[string]*SessionEntry
|
||||
maxSessions int
|
||||
sessionTTL time.Duration
|
||||
lastCleanup time.Time
|
||||
}
|
||||
|
||||
// NewChunkManager creates a new ChunkManager instance with proper initialization.
|
||||
// It sets up logging and synchronization primitives for safe concurrent access.
|
||||
func NewChunkManager(logger Logger) *ChunkManager {
|
||||
if logger == nil {
|
||||
logger = NewNoOpLogger()
|
||||
}
|
||||
|
||||
return &ChunkManager{
|
||||
logger: logger,
|
||||
mutex: &sync.RWMutex{},
|
||||
sessionMap: make(map[string]*SessionEntry),
|
||||
maxSessions: 200, // CRITICAL FIX: Reduced from 1000 to 200 per instance
|
||||
sessionTTL: 15 * time.Minute, // CRITICAL FIX: Reduced from 24h to 15 minutes
|
||||
lastCleanup: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetToken retrieves a token from either a single cookie or multiple chunk cookies.
|
||||
// It handles both compressed and uncompressed tokens and performs comprehensive
|
||||
// validation throughout the retrieval process.
|
||||
func (cm *ChunkManager) GetToken(
|
||||
mainSession *sessions.Session,
|
||||
chunks map[int]*sessions.Session,
|
||||
config TokenConfig,
|
||||
compressor TokenCompressor,
|
||||
) TokenRetrievalResult {
|
||||
|
||||
// Try to get token from main session first
|
||||
if mainSession != nil {
|
||||
if tokenValue, ok := mainSession.Values[config.Type+"_token"].(string); ok && tokenValue != "" {
|
||||
cm.logger.Debugf("Found %s token in main session", config.Type)
|
||||
|
||||
// Check if token is compressed
|
||||
decompressed := compressor.DecompressToken(tokenValue)
|
||||
if decompressed != tokenValue {
|
||||
cm.logger.Debugf("Decompressed %s token", config.Type)
|
||||
return cm.processSingleToken(decompressed, true, config)
|
||||
}
|
||||
|
||||
return cm.processSingleToken(tokenValue, false, config)
|
||||
}
|
||||
}
|
||||
|
||||
// If not in main session, try chunks
|
||||
if len(chunks) == 0 {
|
||||
return TokenRetrievalResult{
|
||||
Error: nil,
|
||||
Token: "",
|
||||
}
|
||||
}
|
||||
|
||||
cm.logger.Debugf("Found %d chunks for %s token, processing", len(chunks), config.Type)
|
||||
return cm.processChunkedToken(chunks, config, compressor)
|
||||
}
|
||||
|
||||
// processSingleToken validates and processes a single token
|
||||
func (cm *ChunkManager) processSingleToken(token string, compressed bool, config TokenConfig) TokenRetrievalResult {
|
||||
if compressed {
|
||||
cm.logger.Debugf("Processing compressed %s token (length: %d)", config.Type, len(token))
|
||||
} else {
|
||||
cm.logger.Debugf("Processing single %s token (length: %d)", config.Type, len(token))
|
||||
}
|
||||
|
||||
return cm.validateToken(token, config)
|
||||
}
|
||||
|
||||
// validateToken performs comprehensive validation on a token
|
||||
func (cm *ChunkManager) validateToken(token string, config TokenConfig) TokenRetrievalResult {
|
||||
if token == "" {
|
||||
return TokenRetrievalResult{Error: nil, Token: ""}
|
||||
}
|
||||
|
||||
validator := NewTokenValidator()
|
||||
|
||||
// Basic validation
|
||||
if err := validator.ValidateTokenSize(token, config); err != nil {
|
||||
cm.logger.Errorf("Token size validation failed for %s: %v", config.Type, err)
|
||||
return TokenRetrievalResult{Error: err, Token: ""}
|
||||
}
|
||||
|
||||
// Format validation
|
||||
if config.RequireJWTFormat {
|
||||
if err := validator.ValidateJWTFormat(token, config.Type); err != nil {
|
||||
cm.logger.Errorf("JWT format validation failed for %s: %v", config.Type, err)
|
||||
return TokenRetrievalResult{Error: err, Token: ""}
|
||||
}
|
||||
} else if !config.AllowOpaqueTokens {
|
||||
if err := validator.ValidateJWTFormat(token, config.Type); err != nil {
|
||||
cm.logger.Errorf("Token format validation failed for %s: %v", config.Type, err)
|
||||
return TokenRetrievalResult{Error: err, Token: ""}
|
||||
}
|
||||
}
|
||||
|
||||
// Content validation
|
||||
if err := validator.ValidateTokenContent(token, config); err != nil {
|
||||
cm.logger.Errorf("Token content validation failed for %s: %v", config.Type, err)
|
||||
return TokenRetrievalResult{Error: err, Token: ""}
|
||||
}
|
||||
|
||||
cm.logger.Debugf("Successfully validated %s token", config.Type)
|
||||
return TokenRetrievalResult{Error: nil, Token: token}
|
||||
}
|
||||
|
||||
// processChunkedToken reconstructs a token from multiple chunks
|
||||
func (cm *ChunkManager) processChunkedToken(chunks map[int]*sessions.Session, config TokenConfig, compressor TokenCompressor) TokenRetrievalResult {
|
||||
if len(chunks) > config.MaxChunks {
|
||||
return TokenRetrievalResult{
|
||||
Error: &ChunkError{
|
||||
Type: config.Type,
|
||||
Reason: "too many chunks",
|
||||
Details: "chunk count exceeds maximum allowed",
|
||||
},
|
||||
Token: "",
|
||||
}
|
||||
}
|
||||
|
||||
// Reconstruct token from chunks
|
||||
reconstructedToken, err := cm.reconstructTokenFromChunks(chunks, config)
|
||||
if err != nil {
|
||||
cm.logger.Errorf("Failed to reconstruct %s token from chunks: %v", config.Type, err)
|
||||
return TokenRetrievalResult{Error: err, Token: ""}
|
||||
}
|
||||
|
||||
// Try decompression
|
||||
decompressedToken := compressor.DecompressToken(reconstructedToken)
|
||||
if decompressedToken != reconstructedToken {
|
||||
cm.logger.Debugf("Decompressed reconstructed %s token", config.Type)
|
||||
return cm.validateToken(decompressedToken, config)
|
||||
}
|
||||
|
||||
return cm.validateToken(reconstructedToken, config)
|
||||
}
|
||||
|
||||
// reconstructTokenFromChunks reconstructs a token from ordered chunks
|
||||
func (cm *ChunkManager) reconstructTokenFromChunks(chunks map[int]*sessions.Session, config TokenConfig) (string, error) {
|
||||
if len(chunks) == 0 {
|
||||
return "", &ChunkError{
|
||||
Type: config.Type,
|
||||
Reason: "no chunks found",
|
||||
Details: "no chunk sessions available for reconstruction",
|
||||
}
|
||||
}
|
||||
|
||||
// Find the maximum chunk index to determine total chunks
|
||||
maxIndex := -1
|
||||
for index := range chunks {
|
||||
if index > maxIndex {
|
||||
maxIndex = index
|
||||
}
|
||||
}
|
||||
|
||||
if maxIndex < 0 {
|
||||
return "", &ChunkError{
|
||||
Type: config.Type,
|
||||
Reason: "invalid chunk indices",
|
||||
Details: "no valid chunk indices found",
|
||||
}
|
||||
}
|
||||
|
||||
// Reconstruct token by concatenating chunks in order
|
||||
var tokenBuilder strings.Builder
|
||||
for i := 0; i <= maxIndex; i++ {
|
||||
chunk, exists := chunks[i]
|
||||
if !exists || chunk == nil {
|
||||
return "", &ChunkError{
|
||||
Type: config.Type,
|
||||
Reason: "missing chunk",
|
||||
Details: fmt.Sprintf("chunk %d is missing", i),
|
||||
}
|
||||
}
|
||||
|
||||
chunkValue, ok := chunk.Values["value"].(string)
|
||||
if !ok || chunkValue == "" {
|
||||
return "", &ChunkError{
|
||||
Type: config.Type,
|
||||
Reason: "empty chunk",
|
||||
Details: fmt.Sprintf("chunk %d has no value", i),
|
||||
}
|
||||
}
|
||||
|
||||
tokenBuilder.WriteString(chunkValue)
|
||||
}
|
||||
|
||||
reconstructed := tokenBuilder.String()
|
||||
if reconstructed == "" {
|
||||
return "", &ChunkError{
|
||||
Type: config.Type,
|
||||
Reason: "empty reconstructed token",
|
||||
Details: "all chunks were present but resulted in empty token",
|
||||
}
|
||||
}
|
||||
|
||||
cm.logger.Debugf("Successfully reconstructed %s token from %d chunks (length: %d)",
|
||||
config.Type, len(chunks), len(reconstructed))
|
||||
|
||||
return reconstructed, nil
|
||||
}
|
||||
|
||||
// CleanupExpiredSessions removes expired sessions from the session map
|
||||
func (cm *ChunkManager) CleanupExpiredSessions() {
|
||||
cm.mutex.Lock()
|
||||
defer cm.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Only cleanup if enough time has passed
|
||||
if now.Sub(cm.lastCleanup) < time.Hour {
|
||||
return
|
||||
}
|
||||
|
||||
cm.lastCleanup = now
|
||||
cleaned := 0
|
||||
|
||||
for key, entry := range cm.sessionMap {
|
||||
if now.After(entry.ExpiresAt) || now.Sub(entry.LastUsed) > cm.sessionTTL {
|
||||
delete(cm.sessionMap, key)
|
||||
cleaned++
|
||||
}
|
||||
}
|
||||
|
||||
if cleaned > 0 {
|
||||
cm.logger.Debugf("Cleaned up %d expired sessions", cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
// StoreSession stores a session in the session map with expiration tracking
|
||||
func (cm *ChunkManager) StoreSession(key string, session *sessions.Session) {
|
||||
cm.mutex.Lock()
|
||||
defer cm.mutex.Unlock()
|
||||
|
||||
// CRITICAL FIX: Aggressive session limit enforcement
|
||||
currentLocal := len(cm.sessionMap)
|
||||
currentGlobal := atomic.LoadInt64(&globalSessionCount)
|
||||
|
||||
shouldEvict := false
|
||||
targetCapacity := cm.maxSessions
|
||||
|
||||
// Check global limit first (more critical)
|
||||
if currentGlobal >= globalMaxSessions {
|
||||
shouldEvict = true
|
||||
targetCapacity = cm.maxSessions / 4 // Aggressive reduction to 25%
|
||||
} else if currentGlobal >= globalMaxSessions*8/10 { // 80% of global
|
||||
shouldEvict = true
|
||||
targetCapacity = cm.maxSessions / 2 // Reduce to 50%
|
||||
} else if currentLocal >= cm.maxSessions {
|
||||
shouldEvict = true
|
||||
targetCapacity = cm.maxSessions * 3 / 4 // Reduce to 75%
|
||||
}
|
||||
|
||||
if shouldEvict {
|
||||
// Find oldest sessions to remove
|
||||
type sessionAge struct {
|
||||
key string
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
sessions := make([]sessionAge, 0, currentLocal)
|
||||
for k, entry := range cm.sessionMap {
|
||||
sessions = append(sessions, sessionAge{key: k, lastUsed: entry.LastUsed})
|
||||
}
|
||||
|
||||
// Sort by last used time (oldest first)
|
||||
for i := 0; i < len(sessions)-1; i++ {
|
||||
for j := i + 1; j < len(sessions); j++ {
|
||||
if sessions[i].lastUsed.After(sessions[j].lastUsed) {
|
||||
sessions[i], sessions[j] = sessions[j], sessions[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove excess sessions
|
||||
excessCount := currentLocal - targetCapacity
|
||||
if excessCount < 0 {
|
||||
excessCount = 0
|
||||
}
|
||||
|
||||
removedCount := int64(0)
|
||||
for i := 0; i < excessCount && i < len(sessions); i++ {
|
||||
delete(cm.sessionMap, sessions[i].key)
|
||||
removedCount++
|
||||
}
|
||||
|
||||
if removedCount > 0 {
|
||||
atomic.AddInt64(&globalSessionCount, -removedCount)
|
||||
}
|
||||
}
|
||||
|
||||
cm.sessionMap[key] = &SessionEntry{
|
||||
Session: session,
|
||||
ExpiresAt: time.Now().Add(cm.sessionTTL),
|
||||
LastUsed: time.Now(),
|
||||
}
|
||||
atomic.AddInt64(&globalSessionCount, 1) // CRITICAL FIX: Track addition
|
||||
}
|
||||
|
||||
// GetSession retrieves a session from the session map
|
||||
func (cm *ChunkManager) GetSession(key string) *sessions.Session {
|
||||
cm.mutex.Lock()
|
||||
defer cm.mutex.Unlock()
|
||||
|
||||
entry, exists := cm.sessionMap[key]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update last used time
|
||||
entry.LastUsed = time.Now()
|
||||
return entry.Session
|
||||
}
|
||||
|
||||
// TokenCompressor interface for token compression operations
|
||||
type TokenCompressor interface {
|
||||
CompressToken(token string) string
|
||||
DecompressToken(compressed string) string
|
||||
}
|
||||
|
||||
// ChunkError represents errors that occur during chunk operations
|
||||
type ChunkError struct {
|
||||
Type string
|
||||
Reason string
|
||||
Details string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (ce *ChunkError) Error() string {
|
||||
return fmt.Sprintf("%s chunk error: %s - %s", ce.Type, ce.Reason, ce.Details)
|
||||
}
|
||||
|
||||
// NoOpLogger provides a no-op logger implementation
|
||||
type NoOpLogger struct{}
|
||||
|
||||
// NewNoOpLogger creates a new no-op logger
|
||||
func NewNoOpLogger() *NoOpLogger {
|
||||
return &NoOpLogger{}
|
||||
}
|
||||
|
||||
// Debug does nothing
|
||||
func (l *NoOpLogger) Debug(msg string) {}
|
||||
|
||||
// Debugf does nothing
|
||||
func (l *NoOpLogger) Debugf(format string, args ...interface{}) {}
|
||||
|
||||
// Error does nothing
|
||||
func (l *NoOpLogger) Error(msg string) {}
|
||||
|
||||
// Errorf does nothing
|
||||
func (l *NoOpLogger) Errorf(format string, args ...interface{}) {}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,279 @@
|
||||
// Package chunking provides chunk serialization functionality
|
||||
package chunking
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ChunkSerializer handles serialization and deserialization of token chunks
|
||||
type ChunkSerializer struct {
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// NewChunkSerializer creates a new chunk serializer
|
||||
func NewChunkSerializer(logger Logger) *ChunkSerializer {
|
||||
return &ChunkSerializer{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SerializeTokenToChunks splits a token into chunks suitable for cookie storage
|
||||
func (cs *ChunkSerializer) SerializeTokenToChunks(token string, config TokenConfig) ([]ChunkData, error) {
|
||||
if token == "" {
|
||||
return nil, fmt.Errorf("cannot serialize empty token")
|
||||
}
|
||||
|
||||
if len(token) < config.MinLength {
|
||||
return nil, fmt.Errorf("token too short: %d < %d", len(token), config.MinLength)
|
||||
}
|
||||
|
||||
if len(token) > config.MaxLength {
|
||||
return nil, fmt.Errorf("token too long: %d > %d", len(token), config.MaxLength)
|
||||
}
|
||||
|
||||
// Calculate optimal chunk size
|
||||
chunkSize := config.MaxChunkSize
|
||||
if chunkSize <= 0 {
|
||||
chunkSize = maxCookieSize
|
||||
}
|
||||
|
||||
// Estimate number of chunks needed
|
||||
estimatedChunks := (len(token) + chunkSize - 1) / chunkSize
|
||||
if estimatedChunks > config.MaxChunks {
|
||||
return nil, fmt.Errorf("token requires too many chunks: %d > %d", estimatedChunks, config.MaxChunks)
|
||||
}
|
||||
|
||||
// Split token into chunks
|
||||
chunks := make([]ChunkData, 0, estimatedChunks)
|
||||
remaining := token
|
||||
|
||||
chunkIndex := 0
|
||||
for len(remaining) > 0 {
|
||||
if chunkIndex >= config.MaxChunks {
|
||||
return nil, fmt.Errorf("exceeded maximum chunk count during serialization")
|
||||
}
|
||||
|
||||
// Determine chunk size for this iteration
|
||||
currentChunkSize := chunkSize
|
||||
if len(remaining) < currentChunkSize {
|
||||
currentChunkSize = len(remaining)
|
||||
}
|
||||
|
||||
// Extract chunk
|
||||
chunkContent := remaining[:currentChunkSize]
|
||||
remaining = remaining[currentChunkSize:]
|
||||
|
||||
// Create chunk data
|
||||
chunkData := ChunkData{
|
||||
Index: chunkIndex,
|
||||
Content: chunkContent,
|
||||
Total: estimatedChunks, // Will be updated after all chunks are created
|
||||
Checksum: cs.calculateChecksum(chunkContent),
|
||||
}
|
||||
|
||||
chunks = append(chunks, chunkData)
|
||||
chunkIndex++
|
||||
}
|
||||
|
||||
// Update total count in all chunks
|
||||
actualChunks := len(chunks)
|
||||
for i := range chunks {
|
||||
chunks[i].Total = actualChunks
|
||||
}
|
||||
|
||||
cs.logger.Debugf("Serialized %s token into %d chunks", config.Type, len(chunks))
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// DeserializeTokenFromChunks reconstructs a token from chunk data
|
||||
func (cs *ChunkSerializer) DeserializeTokenFromChunks(chunks []ChunkData, config TokenConfig) (string, error) {
|
||||
if len(chunks) == 0 {
|
||||
return "", fmt.Errorf("no chunks provided for deserialization")
|
||||
}
|
||||
|
||||
if len(chunks) > config.MaxChunks {
|
||||
return "", fmt.Errorf("too many chunks: %d > %d", len(chunks), config.MaxChunks)
|
||||
}
|
||||
|
||||
// Validate chunk consistency
|
||||
expectedTotal := chunks[0].Total
|
||||
for i, chunk := range chunks {
|
||||
if chunk.Total != expectedTotal {
|
||||
return "", fmt.Errorf("chunk %d has inconsistent total count: %d != %d", i, chunk.Total, expectedTotal)
|
||||
}
|
||||
}
|
||||
|
||||
if len(chunks) != expectedTotal {
|
||||
return "", fmt.Errorf("chunk count mismatch: got %d, expected %d", len(chunks), expectedTotal)
|
||||
}
|
||||
|
||||
// Sort chunks by index
|
||||
orderedChunks := make([]ChunkData, expectedTotal)
|
||||
for _, chunk := range chunks {
|
||||
if chunk.Index < 0 || chunk.Index >= expectedTotal {
|
||||
return "", fmt.Errorf("invalid chunk index: %d (total: %d)", chunk.Index, expectedTotal)
|
||||
}
|
||||
|
||||
if orderedChunks[chunk.Index].Content != "" {
|
||||
return "", fmt.Errorf("duplicate chunk index: %d", chunk.Index)
|
||||
}
|
||||
|
||||
orderedChunks[chunk.Index] = chunk
|
||||
}
|
||||
|
||||
// Verify all chunks are present
|
||||
for i, chunk := range orderedChunks {
|
||||
if chunk.Content == "" {
|
||||
return "", fmt.Errorf("missing chunk at index: %d", i)
|
||||
}
|
||||
|
||||
// Verify checksum
|
||||
expectedChecksum := cs.calculateChecksum(chunk.Content)
|
||||
if chunk.Checksum != expectedChecksum {
|
||||
return "", fmt.Errorf("chunk %d checksum mismatch", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Reconstruct token
|
||||
var tokenBuilder strings.Builder
|
||||
tokenBuilder.Grow(len(chunks) * config.MaxChunkSize) // Pre-allocate capacity
|
||||
|
||||
for _, chunk := range orderedChunks {
|
||||
tokenBuilder.WriteString(chunk.Content)
|
||||
}
|
||||
|
||||
reconstructedToken := tokenBuilder.String()
|
||||
|
||||
// Final validation
|
||||
if len(reconstructedToken) < config.MinLength {
|
||||
return "", fmt.Errorf("reconstructed token too short: %d < %d", len(reconstructedToken), config.MinLength)
|
||||
}
|
||||
|
||||
if len(reconstructedToken) > config.MaxLength {
|
||||
return "", fmt.Errorf("reconstructed token too long: %d > %d", len(reconstructedToken), config.MaxLength)
|
||||
}
|
||||
|
||||
cs.logger.Debugf("Deserialized %s token from %d chunks (length: %d)", config.Type, len(chunks), len(reconstructedToken))
|
||||
return reconstructedToken, nil
|
||||
}
|
||||
|
||||
// EncodeChunk encodes chunk data for cookie storage
|
||||
func (cs *ChunkSerializer) EncodeChunk(chunk ChunkData) (string, error) {
|
||||
// Create a simple format: index:total:checksum:content
|
||||
encoded := fmt.Sprintf("%d:%d:%s:%s", chunk.Index, chunk.Total, chunk.Checksum, chunk.Content)
|
||||
|
||||
// Base64 encode the entire chunk for safe cookie storage
|
||||
return base64.StdEncoding.EncodeToString([]byte(encoded)), nil
|
||||
}
|
||||
|
||||
// DecodeChunk decodes chunk data from cookie storage
|
||||
func (cs *ChunkSerializer) DecodeChunk(encoded string) (ChunkData, error) {
|
||||
// Base64 decode
|
||||
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return ChunkData{}, fmt.Errorf("failed to base64 decode chunk: %w", err)
|
||||
}
|
||||
|
||||
// Parse the format: index:total:checksum:content
|
||||
parts := strings.SplitN(string(decoded), ":", 4)
|
||||
if len(parts) != 4 {
|
||||
return ChunkData{}, fmt.Errorf("invalid chunk format: expected 4 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
var index, total int
|
||||
if _, err := fmt.Sscanf(parts[0], "%d", &index); err != nil {
|
||||
return ChunkData{}, fmt.Errorf("invalid chunk index: %w", err)
|
||||
}
|
||||
|
||||
if _, err := fmt.Sscanf(parts[1], "%d", &total); err != nil {
|
||||
return ChunkData{}, fmt.Errorf("invalid chunk total: %w", err)
|
||||
}
|
||||
|
||||
checksum := parts[2]
|
||||
content := parts[3]
|
||||
|
||||
return ChunkData{
|
||||
Index: index,
|
||||
Total: total,
|
||||
Content: content,
|
||||
Checksum: checksum,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateChunkIntegrity validates the integrity of chunk data
|
||||
func (cs *ChunkSerializer) ValidateChunkIntegrity(chunk ChunkData) error {
|
||||
if chunk.Index < 0 {
|
||||
return fmt.Errorf("negative chunk index: %d", chunk.Index)
|
||||
}
|
||||
|
||||
if chunk.Total <= 0 {
|
||||
return fmt.Errorf("invalid total chunks: %d", chunk.Total)
|
||||
}
|
||||
|
||||
if chunk.Index >= chunk.Total {
|
||||
return fmt.Errorf("chunk index %d exceeds total %d", chunk.Index, chunk.Total)
|
||||
}
|
||||
|
||||
if chunk.Content == "" {
|
||||
return fmt.Errorf("empty chunk content at index %d", chunk.Index)
|
||||
}
|
||||
|
||||
if chunk.Checksum == "" {
|
||||
return fmt.Errorf("empty chunk checksum at index %d", chunk.Index)
|
||||
}
|
||||
|
||||
// Verify checksum
|
||||
expectedChecksum := cs.calculateChecksum(chunk.Content)
|
||||
if chunk.Checksum != expectedChecksum {
|
||||
return fmt.Errorf("chunk %d checksum mismatch: expected %s, got %s",
|
||||
chunk.Index, expectedChecksum, chunk.Checksum)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateChecksum calculates a simple checksum for chunk content
|
||||
func (cs *ChunkSerializer) calculateChecksum(content string) string {
|
||||
// Simple checksum using length and first/last characters
|
||||
if len(content) == 0 {
|
||||
return "empty"
|
||||
}
|
||||
|
||||
checksum := fmt.Sprintf("len%d", len(content))
|
||||
if len(content) >= 1 {
|
||||
checksum += fmt.Sprintf("_first%d", int(content[0]))
|
||||
}
|
||||
if len(content) >= 2 {
|
||||
checksum += fmt.Sprintf("_last%d", int(content[len(content)-1]))
|
||||
}
|
||||
|
||||
return checksum
|
||||
}
|
||||
|
||||
// ChunkData represents a single chunk of token data
|
||||
type ChunkData struct {
|
||||
Index int // Position of this chunk in the sequence
|
||||
Total int // Total number of chunks for this token
|
||||
Content string // The actual chunk content
|
||||
Checksum string // Simple checksum for integrity verification
|
||||
}
|
||||
|
||||
// EstimateChunkCount estimates how many chunks a token will need
|
||||
func (cs *ChunkSerializer) EstimateChunkCount(tokenLength int, chunkSize int) int {
|
||||
if chunkSize <= 0 {
|
||||
chunkSize = maxCookieSize
|
||||
}
|
||||
|
||||
return (tokenLength + chunkSize - 1) / chunkSize
|
||||
}
|
||||
|
||||
// MaxTokenSizeForChunks calculates the maximum token size that can fit in the given number of chunks
|
||||
func (cs *ChunkSerializer) MaxTokenSizeForChunks(maxChunks int, chunkSize int) int {
|
||||
if chunkSize <= 0 {
|
||||
chunkSize = maxCookieSize
|
||||
}
|
||||
|
||||
return maxChunks * chunkSize
|
||||
}
|
||||
@@ -0,0 +1,429 @@
|
||||
// Package chunking provides chunk validation functionality
|
||||
package chunking
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// TokenValidator provides comprehensive validation for tokens and chunks
|
||||
type TokenValidator struct{}
|
||||
|
||||
// NewTokenValidator creates a new token validator
|
||||
func NewTokenValidator() *TokenValidator {
|
||||
return &TokenValidator{}
|
||||
}
|
||||
|
||||
// ValidateTokenSize validates that a token is within size limits
|
||||
func (tv *TokenValidator) ValidateTokenSize(token string, config TokenConfig) error {
|
||||
if len(token) == 0 {
|
||||
return nil // Empty token is allowed
|
||||
}
|
||||
|
||||
if len(token) < config.MinLength {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "token too short",
|
||||
Details: fmt.Sprintf("length %d < minimum %d", len(token), config.MinLength),
|
||||
}
|
||||
}
|
||||
|
||||
if len(token) > config.MaxLength {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "token too long",
|
||||
Details: fmt.Sprintf("length %d > maximum %d", len(token), config.MaxLength),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateJWTFormat validates that a token has proper JWT format
|
||||
func (tv *TokenValidator) ValidateJWTFormat(token string, tokenType string) error {
|
||||
if token == "" {
|
||||
return nil // Empty token is not an error
|
||||
}
|
||||
|
||||
// JWT tokens must have exactly 3 parts separated by dots
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return &ValidationError{
|
||||
Type: tokenType,
|
||||
Reason: "invalid JWT format",
|
||||
Details: fmt.Sprintf("expected 3 parts, got %d", len(parts)),
|
||||
}
|
||||
}
|
||||
|
||||
// Each part must be non-empty
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
return &ValidationError{
|
||||
Type: tokenType,
|
||||
Reason: "empty JWT part",
|
||||
Details: fmt.Sprintf("part %d is empty", i+1),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate each part is valid base64
|
||||
for i, part := range parts {
|
||||
if err := tv.validateBase64JWT(part); err != nil {
|
||||
return &ValidationError{
|
||||
Type: tokenType,
|
||||
Reason: "invalid base64 in JWT part",
|
||||
Details: fmt.Sprintf("part %d: %v", i+1, err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateTokenContent performs comprehensive content validation
|
||||
func (tv *TokenValidator) ValidateTokenContent(token string, config TokenConfig) error {
|
||||
if token == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate character set
|
||||
if err := tv.validateCharacterSet(token, config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate token structure based on type
|
||||
if config.RequireJWTFormat {
|
||||
return tv.validateJWTContent(token, config)
|
||||
} else if config.AllowOpaqueTokens {
|
||||
return tv.validateOpaqueTokenContent(token, config)
|
||||
} else {
|
||||
// Try JWT first, then fall back to opaque validation
|
||||
if err := tv.validateJWTContent(token, config); err != nil {
|
||||
return tv.validateOpaqueTokenContent(token, config)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// validateCharacterSet validates the character set of a token
|
||||
func (tv *TokenValidator) validateCharacterSet(token string, config TokenConfig) error {
|
||||
for i, r := range token {
|
||||
if !tv.isValidTokenCharacter(r) {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "invalid character",
|
||||
Details: fmt.Sprintf("invalid character at position %d: %c (0x%X)", i, r, r),
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidTokenCharacter checks if a character is valid in a token
|
||||
func (tv *TokenValidator) isValidTokenCharacter(r rune) bool {
|
||||
// Allow alphanumeric characters
|
||||
if unicode.IsLetter(r) || unicode.IsNumber(r) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Allow common token characters
|
||||
validChars := ".-_~:/?#[]@!$&'()*+,;="
|
||||
return strings.ContainsRune(validChars, r)
|
||||
}
|
||||
|
||||
// validateJWTContent validates the content of a JWT token
|
||||
func (tv *TokenValidator) validateJWTContent(token string, config TokenConfig) error {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "invalid JWT structure",
|
||||
Details: "JWT must have exactly 3 parts",
|
||||
}
|
||||
}
|
||||
|
||||
// Validate header
|
||||
if err := tv.validateJWTHeader(parts[0], config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate payload
|
||||
if err := tv.validateJWTPayload(parts[1], config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate signature
|
||||
if err := tv.validateJWTSignature(parts[2], config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateJWTHeader validates a JWT header
|
||||
func (tv *TokenValidator) validateJWTHeader(header string, config TokenConfig) error {
|
||||
decoded, err := tv.base64URLDecode(header)
|
||||
if err != nil {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "invalid header encoding",
|
||||
Details: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
var headerData map[string]interface{}
|
||||
if err := json.Unmarshal(decoded, &headerData); err != nil {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "invalid header JSON",
|
||||
Details: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// Check required fields
|
||||
if _, ok := headerData["alg"]; !ok {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "missing algorithm",
|
||||
Details: "JWT header must contain 'alg' field",
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := headerData["typ"]; !ok {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "missing type",
|
||||
Details: "JWT header must contain 'typ' field",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateJWTPayload validates a JWT payload
|
||||
func (tv *TokenValidator) validateJWTPayload(payload string, config TokenConfig) error {
|
||||
decoded, err := tv.base64URLDecode(payload)
|
||||
if err != nil {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "invalid payload encoding",
|
||||
Details: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
var payloadData map[string]interface{}
|
||||
if err := json.Unmarshal(decoded, &payloadData); err != nil {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "invalid payload JSON",
|
||||
Details: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// For ID tokens, check required claims
|
||||
if config.Type == "id" {
|
||||
requiredClaims := []string{"iss", "sub", "aud", "exp", "iat"}
|
||||
for _, claim := range requiredClaims {
|
||||
if _, ok := payloadData[claim]; !ok {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "missing required claim",
|
||||
Details: fmt.Sprintf("ID token must contain '%s' claim", claim),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateJWTSignature validates a JWT signature part
|
||||
func (tv *TokenValidator) validateJWTSignature(signature string, config TokenConfig) error {
|
||||
if signature == "" {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "empty signature",
|
||||
Details: "JWT signature cannot be empty",
|
||||
}
|
||||
}
|
||||
|
||||
// Just validate it's valid base64URL
|
||||
_, err := tv.base64URLDecode(signature)
|
||||
if err != nil {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "invalid signature encoding",
|
||||
Details: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateOpaqueTokenContent validates opaque token content
|
||||
func (tv *TokenValidator) validateOpaqueTokenContent(token string, config TokenConfig) error {
|
||||
if token == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Basic sanity checks for opaque tokens
|
||||
if len(token) < 8 {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "token too short for opaque token",
|
||||
Details: "opaque tokens should be at least 8 characters",
|
||||
}
|
||||
}
|
||||
|
||||
// Check for reasonable entropy
|
||||
if tv.hasLowEntropy(token) {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "low entropy",
|
||||
Details: "token appears to have low entropy",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasLowEntropy checks if a token has suspiciously low entropy
|
||||
func (tv *TokenValidator) hasLowEntropy(token string) bool {
|
||||
if len(token) < 8 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Count unique characters
|
||||
uniqueChars := make(map[rune]bool)
|
||||
for _, r := range token {
|
||||
uniqueChars[r] = true
|
||||
}
|
||||
|
||||
// If less than 50% of characters are unique, consider it low entropy
|
||||
entropyRatio := float64(len(uniqueChars)) / float64(len(token))
|
||||
return entropyRatio < 0.5
|
||||
}
|
||||
|
||||
// validateBase64JWT validates base64URL encoding
|
||||
func (tv *TokenValidator) validateBase64JWT(data string) error {
|
||||
_, err := tv.base64URLDecode(data)
|
||||
return err
|
||||
}
|
||||
|
||||
// base64URLDecode decodes base64URL encoded data
|
||||
func (tv *TokenValidator) base64URLDecode(data string) ([]byte, error) {
|
||||
// Add padding if needed
|
||||
switch len(data) % 4 {
|
||||
case 2:
|
||||
data += "=="
|
||||
case 3:
|
||||
data += "="
|
||||
}
|
||||
|
||||
// Replace URL-safe characters
|
||||
data = strings.ReplaceAll(data, "-", "+")
|
||||
data = strings.ReplaceAll(data, "_", "/")
|
||||
|
||||
return base64.StdEncoding.DecodeString(data)
|
||||
}
|
||||
|
||||
// ValidateChunkStructure validates the structure of chunk data
|
||||
func (tv *TokenValidator) ValidateChunkStructure(chunks []ChunkData, config TokenConfig) error {
|
||||
if len(chunks) == 0 {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "no chunks provided",
|
||||
Details: "chunk list is empty",
|
||||
}
|
||||
}
|
||||
|
||||
if len(chunks) > config.MaxChunks {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "too many chunks",
|
||||
Details: fmt.Sprintf("got %d chunks, maximum is %d", len(chunks), config.MaxChunks),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate each chunk
|
||||
expectedTotal := chunks[0].Total
|
||||
seenIndices := make(map[int]bool)
|
||||
|
||||
for i, chunk := range chunks {
|
||||
// Check for duplicate indices
|
||||
if seenIndices[chunk.Index] {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "duplicate chunk index",
|
||||
Details: fmt.Sprintf("chunk index %d appears multiple times", chunk.Index),
|
||||
}
|
||||
}
|
||||
seenIndices[chunk.Index] = true
|
||||
|
||||
// Validate individual chunk
|
||||
if err := tv.validateChunkData(chunk, expectedTotal, config); err != nil {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "invalid chunk data",
|
||||
Details: fmt.Sprintf("chunk %d: %v", i, err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for missing indices
|
||||
for i := 0; i < expectedTotal; i++ {
|
||||
if !seenIndices[i] {
|
||||
return &ValidationError{
|
||||
Type: config.Type,
|
||||
Reason: "missing chunk index",
|
||||
Details: fmt.Sprintf("chunk with index %d is missing", i),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateChunkData validates individual chunk data
|
||||
func (tv *TokenValidator) validateChunkData(chunk ChunkData, expectedTotal int, config TokenConfig) error {
|
||||
if chunk.Index < 0 {
|
||||
return fmt.Errorf("negative index: %d", chunk.Index)
|
||||
}
|
||||
|
||||
if chunk.Total != expectedTotal {
|
||||
return fmt.Errorf("inconsistent total: got %d, expected %d", chunk.Total, expectedTotal)
|
||||
}
|
||||
|
||||
if chunk.Index >= chunk.Total {
|
||||
return fmt.Errorf("index %d exceeds total %d", chunk.Index, chunk.Total)
|
||||
}
|
||||
|
||||
if chunk.Content == "" {
|
||||
return fmt.Errorf("empty content")
|
||||
}
|
||||
|
||||
if len(chunk.Content) > config.MaxChunkSize {
|
||||
return fmt.Errorf("chunk too large: %d > %d", len(chunk.Content), config.MaxChunkSize)
|
||||
}
|
||||
|
||||
if chunk.Checksum == "" {
|
||||
return fmt.Errorf("empty checksum")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidationError represents a validation error
|
||||
type ValidationError struct {
|
||||
Type string
|
||||
Reason string
|
||||
Details string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (ve *ValidationError) Error() string {
|
||||
return fmt.Sprintf("%s validation error: %s - %s", ve.Type, ve.Reason, ve.Details)
|
||||
}
|
||||
Reference in New Issue
Block a user