mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Refactor codebase for clarity and consistency.
This commit is contained in:
+141
-191
@@ -5,7 +5,6 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -15,256 +14,207 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func generateNonce() (string, error) {
|
func generateNonce() (string, error) {
|
||||||
nonceBytes := make([]byte, 32)
|
nonceBytes := make([]byte, 32)
|
||||||
_, err := rand.Read(nonceBytes)
|
_, err := rand.Read(nonceBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("could not generate nonce")
|
return "", fmt.Errorf("could not generate nonce: %w", err)
|
||||||
}
|
}
|
||||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func assembleRedirectURL(scheme, host, path string) string {
|
func buildFullURL(scheme, host, path string) string {
|
||||||
if scheme == "" {
|
if scheme == "" {
|
||||||
// infoLogger.Println("Scheme is empty, defaulting to http")
|
scheme = "http"
|
||||||
scheme = "http"
|
}
|
||||||
}
|
return fmt.Sprintf("%s://%s%s", scheme, host, path)
|
||||||
return scheme + "://" + host + path
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code string, redirectURL string) (map[string]interface{}, error) {
|
func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code, redirectURL string) (map[string]interface{}, error) {
|
||||||
data := url.Values{}
|
data := url.Values{
|
||||||
data.Set("grant_type", "authorization_code")
|
"grant_type": {"authorization_code"},
|
||||||
data.Set("code", code)
|
"code": {code},
|
||||||
data.Set("client_id", t.clientID)
|
"client_id": {t.clientID},
|
||||||
data.Set("client_secret", t.clientSecret)
|
"client_secret": {t.clientSecret},
|
||||||
data.Set("redirect_uri", redirectURL) // Use the full redirect URL
|
"redirect_uri": {redirectURL},
|
||||||
|
}
|
||||||
|
|
||||||
// infoLogger.Printf("Exchanging code for token with redirect_uri: %s", redirectURL)
|
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
|
resp, err := t.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
defer resp.Body.Close()
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
var result map[string]interface{}
|
||||||
if err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var result map[string]interface{}
|
return result, nil
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// infoLogger.Printf("Token response: %+v", result)
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) {
|
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) {
|
||||||
ctx := req.Context()
|
ctx := req.Context()
|
||||||
session, err := t.store.Get(req, cookie_name)
|
session, err := t.store.Get(req, cookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// infoLogger.Printf("Error getting session: %v", err)
|
handleError(rw, "Session error", http.StatusInternalServerError)
|
||||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
return false, ""
|
||||||
return false, ""
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// infoLogger.Printf("Session values: %+v", session.Values)
|
callbackState := req.URL.Query().Get("state")
|
||||||
|
sessionState, ok := session.Values["csrf"].(string)
|
||||||
|
if !ok || callbackState != sessionState {
|
||||||
|
handleError(rw, "Invalid state parameter", http.StatusBadRequest)
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
callbackState := req.URL.Query().Get("state")
|
code := req.URL.Query().Get("code")
|
||||||
sessionState, ok := session.Values["csrf"].(string)
|
redirectURL := buildFullURL(t.scheme, req.Host, t.redirURLPath)
|
||||||
// infoLogger.Printf("Callback state: %s, Session state: %s, Match: %v", callbackState, sessionState, ok && callbackState == sessionState)
|
oauth2Token, err := t.exchangeCodeForToken(ctx, code, redirectURL)
|
||||||
|
if err != nil {
|
||||||
|
handleError(rw, "Failed to exchange token", http.StatusInternalServerError)
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
if !ok || callbackState != sessionState {
|
rawIDToken, ok := oauth2Token["id_token"].(string)
|
||||||
// infoLogger.Printf("Invalid state parameter: callback=%s, session=%s", callbackState, sessionState)
|
if !ok {
|
||||||
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
|
handleError(rw, "No id_token field in oauth2 token", http.StatusInternalServerError)
|
||||||
return false, ""
|
return false, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
code := req.URL.Query().Get("code")
|
if err := t.verifyToken(rawIDToken); err != nil {
|
||||||
redirectURL := assembleRedirectURL(t.scheme, req.Host, t.redirURLPath)
|
handleError(rw, "Failed to verify token", http.StatusUnauthorized)
|
||||||
oauth2Token, err := t.exchangeCodeForToken(ctx, code, redirectURL)
|
return false, ""
|
||||||
if err != nil {
|
}
|
||||||
// infoLogger.Printf("Failed to exchange token: %v", err)
|
|
||||||
http.Error(rw, "Failed to exchange token", http.StatusInternalServerError)
|
|
||||||
return false, ""
|
|
||||||
}
|
|
||||||
|
|
||||||
rawIDToken, ok := oauth2Token["id_token"].(string)
|
claims, err := extractClaims(rawIDToken)
|
||||||
if !ok {
|
if err != nil {
|
||||||
// infoLogger.Printf("No id_token field in oauth2 token")
|
handleError(rw, "Failed to extract claims", http.StatusInternalServerError)
|
||||||
http.Error(rw, "No id_token field in oauth2 token", http.StatusInternalServerError)
|
return false, ""
|
||||||
return false, ""
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if err := t.verifyToken(rawIDToken); err != nil {
|
email, _ := claims["email"].(string)
|
||||||
// infoLogger.Printf("Token verification failed: %v", err)
|
|
||||||
http.Error(rw, "Failed to verify token", http.StatusUnauthorized)
|
|
||||||
return false, ""
|
|
||||||
}
|
|
||||||
// infoLogger.Printf("Token verification successful")
|
|
||||||
|
|
||||||
claims, err := extractClaims(rawIDToken)
|
session.Values["authenticated"] = true
|
||||||
if err != nil {
|
session.Values["id_token"] = rawIDToken
|
||||||
// infoLogger.Printf("Failed to extract claims: %v", err)
|
session.Values["email"] = email
|
||||||
http.Error(rw, "Failed to extract claims", http.StatusInternalServerError)
|
if err := session.Save(req, rw); err != nil {
|
||||||
return false, ""
|
handleError(rw, "Failed to save session", http.StatusInternalServerError)
|
||||||
}
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
email, _ := claims["email"].(string)
|
originalPath, ok := session.Values["incoming_path"].(string)
|
||||||
|
if !ok {
|
||||||
|
originalPath = "/"
|
||||||
|
}
|
||||||
|
delete(session.Values, "incoming_path")
|
||||||
|
|
||||||
session.Values["authenticated"] = true
|
return true, originalPath
|
||||||
session.Values["id_token"] = rawIDToken
|
|
||||||
session.Values["email"] = email
|
|
||||||
if err := session.Save(req, rw); err != nil {
|
|
||||||
// infoLogger.Printf("Failed to save session: %v", err)
|
|
||||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
|
||||||
return false, ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// infoLogger.Printf("User %s authenticated\n", email)
|
|
||||||
originalPath, ok := session.Values["incoming_path"].(string)
|
|
||||||
if !ok {
|
|
||||||
originalPath = "/"
|
|
||||||
}
|
|
||||||
delete(session.Values, "incoming_path")
|
|
||||||
|
|
||||||
return true, originalPath
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractClaims(tokenString string) (map[string]interface{}, error) {
|
func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||||
parts := strings.Split(tokenString, ".")
|
parts := strings.Split(tokenString, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return nil, errors.New("invalid token format")
|
return nil, fmt.Errorf("invalid token format")
|
||||||
}
|
}
|
||||||
|
|
||||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to decode token payload: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var claims map[string]interface{}
|
var claims map[string]interface{}
|
||||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
|
||||||
|
|
||||||
func verifyToken(token string, publicKey []byte) (map[string]interface{}, error) {
|
|
||||||
parts := strings.Split(token, ".")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
return nil, errors.New("invalid token format")
|
|
||||||
}
|
|
||||||
|
|
||||||
payloadJson, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var claims map[string]interface{}
|
|
||||||
err = json.Unmarshal(payloadJson, &claims)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if exp, ok := claims["exp"].(float64); ok {
|
|
||||||
if time.Now().Unix() > int64(exp) {
|
|
||||||
return nil, errors.New("token expired")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Placeholder for signature verification
|
|
||||||
// err = verifySignature(parts[0]+"."+parts[1], parts[2], publicKey)
|
|
||||||
// if err != nil {
|
|
||||||
// return nil, err
|
|
||||||
// }
|
|
||||||
|
|
||||||
return claims, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UsedTokens struct {
|
type UsedTokens struct {
|
||||||
tokens map[string]bool
|
tokens map[string]bool
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenBlacklist struct {
|
type TokenBlacklist struct {
|
||||||
blacklist map[string]time.Time
|
blacklist map[string]time.Time
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTokenBlacklist() *TokenBlacklist {
|
func NewTokenBlacklist() *TokenBlacklist {
|
||||||
return &TokenBlacklist{
|
return &TokenBlacklist{
|
||||||
blacklist: make(map[string]time.Time),
|
blacklist: make(map[string]time.Time),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
|
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
|
||||||
tb.mutex.Lock()
|
tb.mutex.Lock()
|
||||||
defer tb.mutex.Unlock()
|
defer tb.mutex.Unlock()
|
||||||
tb.blacklist[tokenID] = expiration
|
tb.blacklist[tokenID] = expiration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
|
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
|
||||||
tb.mutex.RLock()
|
tb.mutex.RLock()
|
||||||
defer tb.mutex.RUnlock()
|
defer tb.mutex.RUnlock()
|
||||||
expiration, exists := tb.blacklist[tokenID]
|
expiration, exists := tb.blacklist[tokenID]
|
||||||
return exists && time.Now().Before(expiration)
|
return exists && time.Now().Before(expiration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tb *TokenBlacklist) Cleanup() {
|
func (tb *TokenBlacklist) Cleanup() {
|
||||||
tb.mutex.Lock()
|
tb.mutex.Lock()
|
||||||
defer tb.mutex.Unlock()
|
defer tb.mutex.Unlock()
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
for tokenID, expiration := range tb.blacklist {
|
for tokenID, expiration := range tb.blacklist {
|
||||||
if now.After(expiration) {
|
if now.After(expiration) {
|
||||||
delete(tb.blacklist, tokenID)
|
delete(tb.blacklist, tokenID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenCache struct {
|
type TokenCache struct {
|
||||||
cache map[string]*TokenInfo
|
cache map[string]*TokenInfo
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenInfo struct {
|
type TokenInfo struct {
|
||||||
Token string
|
Token string
|
||||||
ExpiresAt time.Time
|
ExpiresAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTokenCache() *TokenCache {
|
func NewTokenCache() *TokenCache {
|
||||||
return &TokenCache{
|
return &TokenCache{
|
||||||
cache: make(map[string]*TokenInfo),
|
cache: make(map[string]*TokenInfo),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tc *TokenCache) Set(token string, expiresAt time.Time) {
|
func (tc *TokenCache) Set(token string, expiresAt time.Time) {
|
||||||
tc.mutex.Lock()
|
tc.mutex.Lock()
|
||||||
defer tc.mutex.Unlock()
|
defer tc.mutex.Unlock()
|
||||||
tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt}
|
tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tc *TokenCache) Get(token string) (*TokenInfo, bool) {
|
func (tc *TokenCache) Get(token string) (*TokenInfo, bool) {
|
||||||
tc.mutex.RLock()
|
tc.mutex.RLock()
|
||||||
defer tc.mutex.RUnlock()
|
defer tc.mutex.RUnlock()
|
||||||
info, exists := tc.cache[token]
|
info, exists := tc.cache[token]
|
||||||
if exists && time.Now().Before(info.ExpiresAt) {
|
if exists && time.Now().Before(info.ExpiresAt) {
|
||||||
return info, true
|
return info, true
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tc *TokenCache) Cleanup() {
|
func (tc *TokenCache) Cleanup() {
|
||||||
tc.mutex.Lock()
|
tc.mutex.Lock()
|
||||||
defer tc.mutex.Unlock()
|
defer tc.mutex.Unlock()
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
for token, info := range tc.cache {
|
for token, info := range tc.cache {
|
||||||
if now.After(info.ExpiresAt) {
|
if now.After(info.ExpiresAt) {
|
||||||
delete(tc.cache, token)
|
delete(tc.cache, token)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -15,135 +14,135 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type JWK struct {
|
type JWK struct {
|
||||||
Kty string `json:"kty"`
|
Kty string `json:"kty"`
|
||||||
Kid string `json:"kid"`
|
Kid string `json:"kid"`
|
||||||
Use string `json:"use"`
|
Use string `json:"use"`
|
||||||
N string `json:"n"`
|
N string `json:"n"`
|
||||||
E string `json:"e"`
|
E string `json:"e"`
|
||||||
Alg string `json:"alg"`
|
Alg string `json:"alg"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type JWKSet struct {
|
type JWKSet struct {
|
||||||
Keys []JWK `json:"keys"`
|
Keys []JWK `json:"keys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type JWKCache struct {
|
type JWKCache struct {
|
||||||
jwks *JWKSet
|
jwks *JWKSet
|
||||||
expiresAt time.Time
|
expiresAt time.Time
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *JWKCache) GetJWKS(jwksURL string) (*JWKSet, error) {
|
func (c *JWKCache) GetJWKS(jwksURL string, httpClient HTTPClient) (*JWKSet, error) {
|
||||||
c.mutex.RLock()
|
c.mutex.RLock()
|
||||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||||
defer c.mutex.RUnlock()
|
defer c.mutex.RUnlock()
|
||||||
return c.jwks, nil
|
return c.jwks, nil
|
||||||
}
|
}
|
||||||
c.mutex.RUnlock()
|
c.mutex.RUnlock()
|
||||||
|
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||||
return c.jwks, nil
|
return c.jwks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
jwks, err := fetchJWKS(jwksURL)
|
jwks, err := fetchJWKS(jwksURL, httpClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.jwks = jwks
|
c.jwks = jwks
|
||||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||||
|
|
||||||
return jwks, nil
|
return jwks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchJWKS(jwksURL string) (*JWKSet, error) {
|
func fetchJWKS(jwksURL string, httpClient HTTPClient) (*JWKSet, error) {
|
||||||
resp, err := http.Get(jwksURL)
|
resp, err := httpClient.Get(jwksURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, errors.New("failed to fetch JWKS")
|
return nil, fmt.Errorf("failed to fetch JWKS: unexpected status code %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var jwks JWKSet
|
var jwks JWKSet
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to decode JWKS: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &jwks, nil
|
return &jwks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyNonce(tokenNonce, expectedNonce string) error {
|
func verifyNonce(tokenNonce, expectedNonce string) error {
|
||||||
if tokenNonce != expectedNonce {
|
if tokenNonce != expectedNonce {
|
||||||
return errors.New("invalid nonce")
|
return fmt.Errorf("invalid nonce")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyAudience(tokenAudience, expectedAudience string) error {
|
func verifyAudience(tokenAudience, expectedAudience string) error {
|
||||||
if tokenAudience != expectedAudience {
|
if tokenAudience != expectedAudience {
|
||||||
return errors.New("invalid audience")
|
return fmt.Errorf("invalid audience")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyTokenTimes(issuedAt, expiration int64, allowedClockSkew time.Duration) error {
|
func verifyTokenTimes(issuedAt, expiration int64, allowedClockSkew time.Duration) error {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
if now < issuedAt-int64(allowedClockSkew.Seconds()) {
|
if now < issuedAt-int64(allowedClockSkew.Seconds()) {
|
||||||
return errors.New("token used before issued")
|
return fmt.Errorf("token used before issued")
|
||||||
}
|
}
|
||||||
if now > expiration+int64(allowedClockSkew.Seconds()) {
|
if now > expiration+int64(allowedClockSkew.Seconds()) {
|
||||||
return errors.New("token is expired")
|
return fmt.Errorf("token is expired")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||||
if tokenIssuer != expectedIssuer {
|
if tokenIssuer != expectedIssuer {
|
||||||
return errors.New("invalid issuer")
|
return fmt.Errorf("invalid issuer")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateClaims(claims map[string]interface{}) error {
|
func validateClaims(claims map[string]interface{}) error {
|
||||||
requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"}
|
requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"}
|
||||||
for _, claim := range requiredClaims {
|
for _, claim := range requiredClaims {
|
||||||
if _, ok := claims[claim]; !ok {
|
if _, ok := claims[claim]; !ok {
|
||||||
return fmt.Errorf("missing required claim: %s", claim)
|
return fmt.Errorf("missing required claim: %s", claim)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||||
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
|
||||||
}
|
}
|
||||||
e, err := base64.RawURLEncoding.DecodeString(jwk.E)
|
e, err := base64.RawURLEncoding.DecodeString(jwk.E)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKey := &rsa.PublicKey{
|
publicKey := &rsa.PublicKey{
|
||||||
N: new(big.Int).SetBytes(n),
|
N: new(big.Int).SetBytes(n),
|
||||||
E: int(new(big.Int).SetBytes(e).Int64()),
|
E: int(new(big.Int).SetBytes(e).Int64()),
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
|
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "RSA PUBLIC KEY",
|
Type: "RSA PUBLIC KEY",
|
||||||
Bytes: publicKeyBytes,
|
Bytes: publicKeyBytes,
|
||||||
})
|
})
|
||||||
|
|
||||||
return publicKeyPEM, nil
|
return publicKeyPEM, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,122 +8,125 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JWT struct {
|
type JWT struct {
|
||||||
Header map[string]interface{}
|
Header map[string]interface{}
|
||||||
Claims map[string]interface{}
|
Claims map[string]interface{}
|
||||||
Signature string
|
Signature string
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseJWT(token string) (*JWT, error) {
|
func parseJWT(token string) (*JWT, error) {
|
||||||
parts := strings.Split(token, ".")
|
parts := strings.Split(token, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return nil, errors.New("invalid token format")
|
return nil, fmt.Errorf("invalid token format")
|
||||||
}
|
}
|
||||||
|
|
||||||
header, err := decodeSegment(parts[0])
|
header, err := decodeSegment(parts[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to decode header: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, err := decodeSegment(parts[1])
|
claims, err := decodeSegment(parts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to decode claims: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &JWT{
|
return &JWT{
|
||||||
Header: header,
|
Header: header,
|
||||||
Claims: claims,
|
Claims: claims,
|
||||||
Signature: parts[2],
|
Signature: parts[2],
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||||
claims := j.Claims
|
claims := j.Claims
|
||||||
|
|
||||||
if err := verifyIssuer(claims["iss"].(string), issuerURL); err != nil {
|
if err := verifyIssuer(claims["iss"].(string), issuerURL); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := verifyAudience(claims["aud"].(string), clientID); err != nil {
|
if err := verifyAudience(claims["aud"].(string), clientID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := verifyExpiration(claims["exp"].(float64)); err != nil {
|
if err := verifyExpiration(claims["exp"].(float64)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := verifyIssuedAt(claims["iat"].(float64)); err != nil {
|
if err := verifyIssuedAt(claims["iat"].(float64)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyExpiration(expiration float64) error {
|
func verifyExpiration(expiration float64) error {
|
||||||
expirationTime := time.Unix(int64(expiration), 0)
|
expirationTime := time.Unix(int64(expiration), 0)
|
||||||
if time.Now().After(expirationTime) {
|
if time.Now().After(expirationTime) {
|
||||||
return errors.New("token has expired")
|
return fmt.Errorf("token has expired")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifySignature(token string, publicKeyPEM []byte) error {
|
func verifySignature(token string, publicKeyPEM []byte) error {
|
||||||
parts := strings.Split(token, ".")
|
parts := strings.Split(token, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return errors.New("invalid token format")
|
return fmt.Errorf("invalid token format")
|
||||||
}
|
}
|
||||||
|
|
||||||
block, _ := pem.Decode(publicKeyPEM)
|
block, _ := pem.Decode(publicKeyPEM)
|
||||||
if block == nil {
|
if block == nil {
|
||||||
return errors.New("failed to parse PEM block containing the public key")
|
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to parse public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rsaPublicKey, ok := pub.(*rsa.PublicKey)
|
rsaPublicKey, ok := pub.(*rsa.PublicKey)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("not an RSA public key")
|
return fmt.Errorf("not an RSA public key")
|
||||||
}
|
}
|
||||||
|
|
||||||
signedContent := parts[0] + "." + parts[1]
|
signedContent := parts[0] + "." + parts[1]
|
||||||
signature, _ := base64.RawURLEncoding.DecodeString(parts[2])
|
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to decode signature: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
hash := sha256.Sum256([]byte(signedContent))
|
hash := sha256.Sum256([]byte(signedContent))
|
||||||
err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hash[:], signature)
|
err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hash[:], signature)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("invalid token signature")
|
return fmt.Errorf("invalid token signature: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyIssuedAt(issuedAt float64) error {
|
func verifyIssuedAt(issuedAt float64) error {
|
||||||
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
||||||
if time.Now().Before(issuedAtTime) {
|
if time.Now().Before(issuedAtTime) {
|
||||||
return errors.New("token used before issued")
|
return fmt.Errorf("token used before issued")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeSegment(seg string) (map[string]interface{}, error) {
|
func decodeSegment(seg string) (map[string]interface{}, error) {
|
||||||
data, err := base64.RawURLEncoding.DecodeString(seg)
|
data, err := base64.RawURLEncoding.DecodeString(seg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to decode segment: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result map[string]interface{}
|
var result map[string]interface{}
|
||||||
err = json.Unmarshal(data, &result)
|
err = json.Unmarshal(data, &result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to unmarshal segment: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,10 +3,7 @@ package traefikoidc
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -17,14 +14,10 @@ import (
|
|||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
infoLogger = log.New(io.Discard, "INFO: traefikoidc: ", log.Ldate|log.Ltime)
|
|
||||||
)
|
|
||||||
|
|
||||||
type TraefikOidc struct {
|
type TraefikOidc struct {
|
||||||
next http.Handler
|
next http.Handler
|
||||||
name string
|
name string
|
||||||
store *sessions.CookieStore
|
store sessions.Store
|
||||||
redirURLPath string
|
redirURLPath string
|
||||||
issuerURL string
|
issuerURL string
|
||||||
jwkCache *JWKCache
|
jwkCache *JWKCache
|
||||||
@@ -39,6 +32,8 @@ type TraefikOidc struct {
|
|||||||
forceHTTPS bool
|
forceHTTPS bool
|
||||||
scheme string
|
scheme string
|
||||||
tokenCache *TokenCache
|
tokenCache *TokenCache
|
||||||
|
httpClient HTTPClient
|
||||||
|
logger Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProviderMetadata struct {
|
type ProviderMetadata struct {
|
||||||
@@ -58,10 +53,11 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
|||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata, err := discoverProviderMetadata(config.ProviderURL)
|
metadata, err := discoverProviderMetadata(config.ProviderURL, &http.Client{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to discover provider metadata: %v", err)
|
return nil, fmt.Errorf("failed to discover provider metadata: %w", err)
|
||||||
}
|
}
|
||||||
|
logger := NewLogger(config.LogLevel)
|
||||||
|
|
||||||
t := &TraefikOidc{
|
t := &TraefikOidc{
|
||||||
next: next,
|
next: next,
|
||||||
@@ -80,17 +76,19 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
|||||||
scopes: config.Scopes,
|
scopes: config.Scopes,
|
||||||
limiter: rate.NewLimiter(rate.Every(time.Second), 100),
|
limiter: rate.NewLimiter(rate.Every(time.Second), 100),
|
||||||
tokenCache: NewTokenCache(),
|
tokenCache: NewTokenCache(),
|
||||||
|
httpClient: &http.Client{},
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
t.startTokenCleanup()
|
t.startTokenCleanup()
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func discoverProviderMetadata(providerURL string) (*ProviderMetadata, error) {
|
func discoverProviderMetadata(providerURL string, httpClient HTTPClient) (*ProviderMetadata, error) {
|
||||||
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
|
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
|
||||||
resp, err := http.Get(wellKnownURL)
|
resp, err := httpClient.Get(wellKnownURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
@@ -100,13 +98,47 @@ func discoverProviderMetadata(providerURL string) (*ProviderMetadata, error) {
|
|||||||
|
|
||||||
var metadata ProviderMetadata
|
var metadata ProviderMetadata
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to decode provider metadata: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &metadata, nil
|
return &metadata, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
t.scheme = t.determineScheme(req)
|
||||||
|
host := t.determineHost(req)
|
||||||
|
|
||||||
|
redirectURL := buildFullURL(t.scheme, host, t.redirURLPath)
|
||||||
|
t.logger.Infof("Final redirect URL: %s", redirectURL)
|
||||||
|
|
||||||
|
session, err := t.store.Get(req, cookieName)
|
||||||
|
if err != nil {
|
||||||
|
t.logger.Errorf("Error getting session: %v", err)
|
||||||
|
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.URL.Path == t.redirURLPath {
|
||||||
|
t.logger.Infof("Handling callback, URL: %s", req.URL.String())
|
||||||
|
authSuccess, originalPath := t.handleCallback(rw, req)
|
||||||
|
if authSuccess {
|
||||||
|
http.Redirect(rw, req, originalPath, http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Error(rw, "Authentication failed", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.isUserAuthenticated(session) {
|
||||||
|
t.next.ServeHTTP(rw, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// User is not authenticated, start the auth process
|
||||||
|
t.initiateAuthentication(rw, req, session, redirectURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
||||||
scheme := req.URL.Scheme
|
scheme := req.URL.Scheme
|
||||||
if scheme == "" {
|
if scheme == "" {
|
||||||
scheme = req.Header.Get("X-Forwarded-Proto")
|
scheme = req.Header.Get("X-Forwarded-Proto")
|
||||||
@@ -121,8 +153,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
if t.forceHTTPS {
|
if t.forceHTTPS {
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
}
|
}
|
||||||
t.scheme = scheme
|
return scheme
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TraefikOidc) determineHost(req *http.Request) string {
|
||||||
host := req.URL.Host
|
host := req.URL.Host
|
||||||
if host == "" {
|
if host == "" {
|
||||||
host = req.Header.Get("X-Forwarded-Host")
|
host = req.Header.Get("X-Forwarded-Host")
|
||||||
@@ -130,74 +164,36 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
if host == "" {
|
if host == "" {
|
||||||
host = req.Host
|
host = req.Host
|
||||||
}
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
// infoLogger.Printf("Scheme: %s, Host: %s, Path: %s", scheme, host, t.redirURLPath)
|
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) bool {
|
||||||
// infoLogger.Printf("X-Forwarded-Proto: %s", req.Header.Get("X-Forwarded-Proto"))
|
|
||||||
// infoLogger.Printf("X-Forwarded-Host: %s", req.Header.Get("X-Forwarded-Host"))
|
|
||||||
redirectURL := assembleRedirectURL(t.scheme, host, t.redirURLPath)
|
|
||||||
// infoLogger.Printf("Final redirect URL: %s", redirectURL)
|
|
||||||
|
|
||||||
session, err := t.store.Get(req, cookie_name)
|
|
||||||
if err != nil {
|
|
||||||
// infoLogger.Printf("Error getting session: %v", err)
|
|
||||||
http.Error(rw, "Session error: "+err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.URL.Path == t.redirURLPath {
|
|
||||||
// infoLogger.Printf("Handling callback, URL: %s", req.URL.String())
|
|
||||||
authSuccess, originalPath := t.handleCallback(rw, req)
|
|
||||||
if authSuccess {
|
|
||||||
http.Redirect(rw, req, originalPath, http.StatusFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// If auth was not successful, return an error instead of re-authenticating
|
|
||||||
http.Error(rw, "Authentication failed", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
authenticated, _ := session.Values["authenticated"].(bool)
|
authenticated, _ := session.Values["authenticated"].(bool)
|
||||||
if authenticated {
|
if authenticated {
|
||||||
idToken, ok := session.Values["id_token"].(string)
|
idToken, ok := session.Values["id_token"].(string)
|
||||||
if !ok || idToken == "" {
|
if !ok || idToken == "" {
|
||||||
http.Error(rw, "Invalid session", http.StatusUnauthorized)
|
return false
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return t.verifyToken(idToken) == nil
|
||||||
if err := t.verifyToken(idToken); err != nil {
|
|
||||||
http.Error(rw, "Invalid token", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proceed with the request
|
|
||||||
t.next.ServeHTTP(rw, req)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// User is not authenticated, start the auth process
|
func (t *TraefikOidc) initiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
|
||||||
csrfToken := uuid.New().String()
|
csrfToken := uuid.New().String()
|
||||||
session.Values["csrf"] = csrfToken
|
session.Values["csrf"] = csrfToken
|
||||||
session.Values["incoming_path"] = req.URL.Path
|
session.Values["incoming_path"] = req.URL.Path
|
||||||
// infoLogger.Printf("Setting CSRF token: %s", csrfToken)
|
t.logger.Infof("Setting CSRF token: %s", csrfToken)
|
||||||
err = session.Save(req, rw)
|
|
||||||
if err != nil {
|
|
||||||
// infoLogger.Printf("Failed to save session: %v", err)
|
|
||||||
http.Error(rw, "Failed to save session: "+err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the session was saved correctly
|
if err := session.Save(req, rw); err != nil {
|
||||||
verifySession, _ := t.store.Get(req, cookie_name)
|
t.logger.Errorf("Failed to save session: %v", err)
|
||||||
savedCSRF, ok := verifySession.Values["csrf"].(string)
|
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||||
if !ok || savedCSRF != csrfToken {
|
|
||||||
// infoLogger.Printf("Failed to save CSRF token. Saved: %s, Expected: %s", savedCSRF, csrfToken)
|
|
||||||
http.Error(rw, "Failed to save CSRF token", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce, err := generateNonce()
|
nonce, err := generateNonce()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(rw, "Failed to generate nonce: "+err.Error(), http.StatusInternalServerError)
|
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,22 +201,9 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TraefikOidc) isUserAuthenticated(req *http.Request) bool {
|
|
||||||
session, err := t.store.Get(req, cookie_name)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if auth, ok := session.Values["authenticated"].(bool); !ok || !auth {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TraefikOidc) verifyToken(token string) error {
|
func (t *TraefikOidc) verifyToken(token string) error {
|
||||||
if !t.limiter.Allow() {
|
if !t.limiter.Allow() {
|
||||||
return errors.New("rate limit exceeded")
|
return fmt.Errorf("rate limit exceeded")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, exists := t.tokenCache.Get(token); exists {
|
if _, exists := t.tokenCache.Get(token); exists {
|
||||||
@@ -229,17 +212,17 @@ func (t *TraefikOidc) verifyToken(token string) error {
|
|||||||
|
|
||||||
jwt, err := parseJWT(token)
|
jwt, err := parseJWT(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to parse JWT: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwks, err := t.jwkCache.GetJWKS(t.jwksURL)
|
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
kid, ok := jwt.Header["kid"].(string)
|
kid, ok := jwt.Header["kid"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("missing key ID in token header")
|
return fmt.Errorf("missing key ID in token header")
|
||||||
}
|
}
|
||||||
|
|
||||||
var publicKeyPEM []byte
|
var publicKeyPEM []byte
|
||||||
@@ -247,38 +230,22 @@ func (t *TraefikOidc) verifyToken(token string) error {
|
|||||||
if key.Kid == kid {
|
if key.Kid == kid {
|
||||||
publicKeyPEM, err = jwkToPEM(&key)
|
publicKeyPEM, err = jwkToPEM(&key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if publicKeyPEM == nil {
|
if publicKeyPEM == nil {
|
||||||
return errors.New("unable to find matching public key")
|
return fmt.Errorf("unable to find matching public key")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := verifySignature(token, publicKeyPEM); err != nil {
|
if err := verifySignature(token, publicKeyPEM); err != nil {
|
||||||
return err
|
return fmt.Errorf("signature verification failed: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
if err := verifyAudience(jwt.Claims["aud"].(string), t.clientID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
|
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
|
||||||
return err
|
return fmt.Errorf("JWT verification failed: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
if err := verifyTokenTimes(
|
|
||||||
int64(jwt.Claims["iat"].(float64)),
|
|
||||||
int64(jwt.Claims["exp"].(float64)),
|
|
||||||
5*time.Minute, // Allowed clock skew
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validateClaims(jwt.Claims); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
|
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
|
||||||
@@ -288,17 +255,16 @@ func (t *TraefikOidc) verifyToken(token string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
||||||
params := url.Values{}
|
params := url.Values{
|
||||||
params.Add("client_id", t.clientID)
|
"client_id": {t.clientID},
|
||||||
params.Add("response_type", "code")
|
"response_type": {"code"},
|
||||||
params.Add("redirect_uri", redirectURL)
|
"redirect_uri": {redirectURL},
|
||||||
params.Add("scope", strings.Join(t.scopes, " "))
|
"scope": {strings.Join(t.scopes, " ")},
|
||||||
params.Add("state", state)
|
"state": {state},
|
||||||
params.Add("nonce", nonce)
|
"nonce": {nonce},
|
||||||
|
}
|
||||||
|
|
||||||
authURL := t.authURL + "?" + params.Encode()
|
return fmt.Sprintf("%s?%s", t.authURL, params.Encode())
|
||||||
// infoLogger.Printf("Built auth URL: %s", authURL)
|
|
||||||
return authURL
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TraefikOidc) startTokenCleanup() {
|
func (t *TraefikOidc) startTokenCleanup() {
|
||||||
@@ -306,6 +272,7 @@ func (t *TraefikOidc) startTokenCleanup() {
|
|||||||
go func() {
|
go func() {
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
t.tokenCache.Cleanup()
|
t.tokenCache.Cleanup()
|
||||||
|
t.tokenBlacklist.Cleanup()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|||||||
+61
-5
@@ -1,10 +1,13 @@
|
|||||||
package traefikoidc
|
package traefikoidc
|
||||||
|
|
||||||
import "os"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
// constants
|
|
||||||
const (
|
const (
|
||||||
cookie_name = "_raczylo_oidc"
|
cookieName = "_raczylo_oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -19,6 +22,59 @@ type Config struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateConfig() *Config {
|
func CreateConfig() *Config {
|
||||||
infoLogger.SetOutput(os.Stdout)
|
return &Config{
|
||||||
return &Config{}
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
LogLevel: "info",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) Validate() error {
|
||||||
|
if c.ProviderURL == "" {
|
||||||
|
return fmt.Errorf("providerURL is required")
|
||||||
|
}
|
||||||
|
if c.CallbackURL == "" {
|
||||||
|
return fmt.Errorf("callbackURL is required")
|
||||||
|
}
|
||||||
|
if c.ClientID == "" {
|
||||||
|
return fmt.Errorf("clientID is required")
|
||||||
|
}
|
||||||
|
if c.ClientSecret == "" {
|
||||||
|
return fmt.Errorf("clientSecret is required")
|
||||||
|
}
|
||||||
|
if c.SessionEncryptionKey == "" {
|
||||||
|
return fmt.Errorf("sessionEncryptionKey is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type defaultLogger struct {
|
||||||
|
level string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLogger(level string) Logger {
|
||||||
|
return &defaultLogger{level: level}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *defaultLogger) Infof(format string, args ...interface{}) {
|
||||||
|
if l.level == "info" || l.level == "debug" {
|
||||||
|
fmt.Printf("INFO: "+format+"\n", args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
|
||||||
|
fmt.Fprintf(os.Stderr, "ERROR: "+format+"\n", args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTTPClient interface {
|
||||||
|
Get(url string) (*http.Response, error)
|
||||||
|
Do(req *http.Request) (*http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Infof(format string, args ...interface{})
|
||||||
|
Errorf(format string, args ...interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleError(w http.ResponseWriter, message string, code int) {
|
||||||
|
http.Error(w, message, code)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user