// Package auth provides authentication-related functionality for the OIDC middleware.
package auth

import (
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strings"

	"github.com/google/uuid"
)

// AuthHandler provides core authentication functionality for OIDC flows.
// It holds the static provider configuration needed to start an authorization
// request and to build the provider's authorization URL.
type AuthHandler struct {
	// logger receives the debug and error messages emitted by the handler.
	logger Logger
	// enablePKCE turns on generation of a code verifier/challenge and the
	// code_challenge parameters on the authorization URL.
	enablePKCE bool
	// isGoogleProv reports whether the configured provider is Google; it is
	// consulted each time an authorization URL is built.
	isGoogleProv func() bool
	// isAzureProv reports whether the configured provider is Azure AD; it is
	// only consulted when isGoogleProv returns false.
	isAzureProv func() bool
	// clientID is sent as the client_id query parameter.
	clientID string
	// authURL is the provider's authorization endpoint. Values without an
	// http:// or https:// prefix are resolved against issuerURL.
	authURL string
	// issuerURL is the base used to resolve a relative authURL.
	issuerURL string
	// scopes are the scopes requested from the provider.
	scopes []string
	// overrideScopes, when true and scopes is non-empty, stops offline_access
	// from being appended automatically.
	overrideScopes bool
}

// Logger is the logging interface injected into AuthHandler.
// Both methods take a printf-style format string and arguments.
type Logger interface {
	Debugf(format string, args ...interface{})
	Errorf(format string, args ...interface{})
}

// NewAuthHandler creates a new AuthHandler instance.
// Every argument is stored as-is in the field of the same name; the scopes
// slice is not copied here.
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool, clientID, authURL, issuerURL string, scopes []string, overrideScopes bool) *AuthHandler {
	return &AuthHandler{
		logger:         logger,
		enablePKCE:     enablePKCE,
		isGoogleProv:   isGoogleProv,
		isAzureProv:    isAzureProv,
		clientID:       clientID,
		authURL:        authURL,
		issuerURL:      issuerURL,
		scopes:         scopes,
		overrideScopes: overrideScopes,
	}
}

// InitiateAuthentication initiates the OIDC authentication flow.
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
// stores authentication state, and redirects the user to the OIDC provider.
func (h *AuthHandler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request, session SessionData, redirectURL string, generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) { h.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI()) const maxRedirects = 5 redirectCount := session.GetRedirectCount() if redirectCount >= maxRedirects { h.logger.Errorf("Maximum redirect limit (%d) exceeded, possible redirect loop detected", maxRedirects) session.ResetRedirectCount() http.Error(rw, "Authentication failed: Too many redirects", http.StatusLoopDetected) return } session.IncrementRedirectCount() csrfToken := uuid.NewString() nonce, err := generateNonce() if err != nil { h.logger.Errorf("Failed to generate nonce: %v", err) http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError) return } // Generate PKCE code verifier and challenge if PKCE is enabled var codeVerifier, codeChallenge string if h.enablePKCE { codeVerifier, err = generateCodeVerifier() if err != nil { h.logger.Errorf("Failed to generate code verifier: %v", err) http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError) return } codeChallenge, err = deriveCodeChallenge() if err != nil { h.logger.Errorf("Failed to generate code challenge: %v", err) http.Error(rw, "Failed to generate code challenge", http.StatusInternalServerError) return } h.logger.Debugf("PKCE enabled, generated code challenge") } session.SetAuthenticated(false) session.SetEmail("") session.SetAccessToken("") session.SetRefreshToken("") session.SetIDToken("") session.SetNonce("") session.SetCodeVerifier("") session.SetCSRF(csrfToken) session.SetNonce(nonce) if h.enablePKCE { session.SetCodeVerifier(codeVerifier) } session.SetIncomingPath(req.URL.RequestURI()) h.logger.Debugf("Storing incoming path: %s", req.URL.RequestURI()) session.MarkDirty() if err := session.Save(req, rw); err != nil { h.logger.Errorf("Failed to save session 
before redirecting to provider: %v", err) http.Error(rw, "Failed to save session", http.StatusInternalServerError) return } h.logger.Debugf("Session saved before redirect. CSRF: %s, Nonce: %s", csrfToken, nonce) authURL := h.BuildAuthURL(redirectURL, csrfToken, nonce, codeChallenge) h.logger.Debugf("Redirecting user to OIDC provider: %s", authURL) http.Redirect(rw, req, authURL, http.StatusFound) } // BuildAuthURL constructs the OIDC provider authorization URL. // It builds the URL with all necessary parameters including client_id, scopes, // PKCE parameters, and provider-specific parameters for Google and Azure. func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge string) string { params := url.Values{} params.Set("client_id", h.clientID) params.Set("response_type", "code") params.Set("redirect_uri", redirectURL) params.Set("state", state) params.Set("nonce", nonce) if h.enablePKCE && codeChallenge != "" { params.Set("code_challenge", codeChallenge) params.Set("code_challenge_method", "S256") } scopes := make([]string, len(h.scopes)) copy(scopes, h.scopes) if h.isGoogleProv() { params.Set("access_type", "offline") h.logger.Debugf("Google OIDC provider detected, added access_type=offline for refresh tokens") params.Set("prompt", "consent") h.logger.Debugf("Google OIDC provider detected, added prompt=consent to ensure refresh tokens") } else if h.isAzureProv() { params.Set("response_mode", "query") h.logger.Debugf("Azure AD provider detected, added response_mode=query") hasOfflineAccess := false for _, scope := range scopes { if scope == "offline_access" { hasOfflineAccess = true break } } if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) { if !hasOfflineAccess { scopes = append(scopes, "offline_access") h.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes)) } } else { h.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), 
offline_access not automatically added.", len(h.scopes)) } } else { if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) { hasOfflineAccess := false for _, scope := range scopes { if scope == "offline_access" { hasOfflineAccess = true break } } if !hasOfflineAccess { scopes = append(scopes, "offline_access") h.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes)) } } else { h.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes)) } } if len(scopes) > 0 { finalScopeString := strings.Join(scopes, " ") params.Set("scope", finalScopeString) h.logger.Debugf("AuthHandler.BuildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString) } return h.buildURLWithParams(h.authURL, params) } // buildURLWithParams constructs a URL by combining a base URL with query parameters. // It handles both relative and absolute URLs, validates URL security, // and properly encodes query parameters. func (h *AuthHandler) buildURLWithParams(baseURL string, params url.Values) string { if baseURL != "" { if strings.HasPrefix(baseURL, "http://") || strings.HasPrefix(baseURL, "https://") { if err := h.validateURL(baseURL); err != nil { h.logger.Errorf("URL validation failed for %s: %v", baseURL, err) return "" } } } if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { issuerURLParsed, err := url.Parse(h.issuerURL) if err != nil { h.logger.Errorf("Could not parse issuerURL: %s. Error: %v", h.issuerURL, err) return "" } baseURLParsed, err := url.Parse(baseURL) if err != nil { h.logger.Errorf("Could not parse baseURL: %s. 
Error: %v", baseURL, err) return "" } resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed) if err := h.validateURL(resolvedURL.String()); err != nil { h.logger.Errorf("Resolved URL validation failed for %s: %v", resolvedURL.String(), err) return "" } resolvedURL.RawQuery = params.Encode() return resolvedURL.String() } u, err := url.Parse(baseURL) if err != nil { h.logger.Errorf("Could not parse absolute baseURL: %s. Error: %v", baseURL, err) return "" } if err := h.validateParsedURL(u); err != nil { h.logger.Errorf("Parsed URL validation failed for %s: %v", baseURL, err) return "" } u.RawQuery = params.Encode() return u.String() } // validateURL performs security validation on URLs to prevent SSRF attacks. // It checks for allowed schemes, validates hosts, and prevents access to private networks. func (h *AuthHandler) validateURL(urlStr string) error { if urlStr == "" { return fmt.Errorf("empty URL") } u, err := url.Parse(urlStr) if err != nil { return fmt.Errorf("invalid URL format: %w", err) } return h.validateParsedURL(u) } // validateParsedURL validates a parsed URL structure for security. // It checks schemes, hosts, and paths to prevent malicious URLs. func (h *AuthHandler) validateParsedURL(u *url.URL) error { allowedSchemes := map[string]bool{ "https": true, "http": true, } if !allowedSchemes[u.Scheme] { return fmt.Errorf("disallowed URL scheme: %s", u.Scheme) } if u.Scheme == "http" { h.logger.Debugf("Warning: Using HTTP scheme for URL: %s", u.String()) } if u.Host == "" { return fmt.Errorf("missing host in URL") } if err := h.validateHost(u.Host); err != nil { return fmt.Errorf("invalid host: %w", err) } if strings.Contains(u.Path, "..") { return fmt.Errorf("path traversal detected in URL path") } return nil } // validateHost validates a hostname for security and reachability. // It prevents access to private networks and localhost addresses. 
func (h *AuthHandler) validateHost(host string) error { if host == "" { return fmt.Errorf("empty host") } // Strip port if present if strings.Contains(host, ":") { var err error host, _, err = net.SplitHostPort(host) if err != nil { return fmt.Errorf("invalid host:port format: %w", err) } } // Check for localhost variations localhostVariations := []string{ "localhost", "127.0.0.1", "::1", "0.0.0.0", } for _, localhost := range localhostVariations { if strings.EqualFold(host, localhost) { return fmt.Errorf("localhost access not allowed: %s", host) } } // Try to parse as IP address if ip := net.ParseIP(host); ip != nil { if ip.IsLoopback() { return fmt.Errorf("loopback IP not allowed: %s", host) } if ip.IsPrivate() { return fmt.Errorf("private IP not allowed: %s", host) } if ip.IsLinkLocalUnicast() { return fmt.Errorf("link-local IP not allowed: %s", host) } if ip.IsMulticast() { return fmt.Errorf("multicast IP not allowed: %s", host) } } return nil } // SessionData interface for dependency injection type SessionData interface { GetRedirectCount() int ResetRedirectCount() IncrementRedirectCount() SetAuthenticated(bool) SetEmail(string) SetAccessToken(string) SetRefreshToken(string) SetIDToken(string) SetNonce(string) SetCodeVerifier(string) SetCSRF(string) SetIncomingPath(string) MarkDirty() Save(req *http.Request, rw http.ResponseWriter) error }