mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Refactor codebase for clarity and consistency.
This commit is contained in:
+141
-191
@@ -5,7 +5,6 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -15,256 +14,207 @@ import (
|
||||
)
|
||||
|
||||
func generateNonce() (string, error) {
|
||||
nonceBytes := make([]byte, 32)
|
||||
_, err := rand.Read(nonceBytes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not generate nonce")
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||
nonceBytes := make([]byte, 32)
|
||||
_, err := rand.Read(nonceBytes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not generate nonce: %w", err)
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||
}
|
||||
|
||||
func assembleRedirectURL(scheme, host, path string) string {
|
||||
if scheme == "" {
|
||||
// infoLogger.Println("Scheme is empty, defaulting to http")
|
||||
scheme = "http"
|
||||
}
|
||||
return scheme + "://" + host + path
|
||||
func buildFullURL(scheme, host, path string) string {
|
||||
if scheme == "" {
|
||||
scheme = "http"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s%s", scheme, host, path)
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code string, redirectURL string) (map[string]interface{}, error) {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("client_id", t.clientID)
|
||||
data.Set("client_secret", t.clientSecret)
|
||||
data.Set("redirect_uri", redirectURL) // Use the full redirect URL
|
||||
func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code, redirectURL string) (map[string]interface{}, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {code},
|
||||
"client_id": {t.clientID},
|
||||
"client_secret": {t.clientSecret},
|
||||
"redirect_uri": {redirectURL},
|
||||
}
|
||||
|
||||
// infoLogger.Printf("Exchanging code for token with redirect_uri: %s", redirectURL)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
resp, err := t.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// infoLogger.Printf("Token response: %+v", result)
|
||||
|
||||
return result, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) {
|
||||
ctx := req.Context()
|
||||
session, err := t.store.Get(req, cookie_name)
|
||||
if err != nil {
|
||||
// infoLogger.Printf("Error getting session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return false, ""
|
||||
}
|
||||
ctx := req.Context()
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
if err != nil {
|
||||
handleError(rw, "Session error", http.StatusInternalServerError)
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// infoLogger.Printf("Session values: %+v", session.Values)
|
||||
callbackState := req.URL.Query().Get("state")
|
||||
sessionState, ok := session.Values["csrf"].(string)
|
||||
if !ok || callbackState != sessionState {
|
||||
handleError(rw, "Invalid state parameter", http.StatusBadRequest)
|
||||
return false, ""
|
||||
}
|
||||
|
||||
callbackState := req.URL.Query().Get("state")
|
||||
sessionState, ok := session.Values["csrf"].(string)
|
||||
// infoLogger.Printf("Callback state: %s, Session state: %s, Match: %v", callbackState, sessionState, ok && callbackState == sessionState)
|
||||
code := req.URL.Query().Get("code")
|
||||
redirectURL := buildFullURL(t.scheme, req.Host, t.redirURLPath)
|
||||
oauth2Token, err := t.exchangeCodeForToken(ctx, code, redirectURL)
|
||||
if err != nil {
|
||||
handleError(rw, "Failed to exchange token", http.StatusInternalServerError)
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if !ok || callbackState != sessionState {
|
||||
// infoLogger.Printf("Invalid state parameter: callback=%s, session=%s", callbackState, sessionState)
|
||||
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
|
||||
return false, ""
|
||||
}
|
||||
rawIDToken, ok := oauth2Token["id_token"].(string)
|
||||
if !ok {
|
||||
handleError(rw, "No id_token field in oauth2 token", http.StatusInternalServerError)
|
||||
return false, ""
|
||||
}
|
||||
|
||||
code := req.URL.Query().Get("code")
|
||||
redirectURL := assembleRedirectURL(t.scheme, req.Host, t.redirURLPath)
|
||||
oauth2Token, err := t.exchangeCodeForToken(ctx, code, redirectURL)
|
||||
if err != nil {
|
||||
// infoLogger.Printf("Failed to exchange token: %v", err)
|
||||
http.Error(rw, "Failed to exchange token", http.StatusInternalServerError)
|
||||
return false, ""
|
||||
}
|
||||
if err := t.verifyToken(rawIDToken); err != nil {
|
||||
handleError(rw, "Failed to verify token", http.StatusUnauthorized)
|
||||
return false, ""
|
||||
}
|
||||
|
||||
rawIDToken, ok := oauth2Token["id_token"].(string)
|
||||
if !ok {
|
||||
// infoLogger.Printf("No id_token field in oauth2 token")
|
||||
http.Error(rw, "No id_token field in oauth2 token", http.StatusInternalServerError)
|
||||
return false, ""
|
||||
}
|
||||
claims, err := extractClaims(rawIDToken)
|
||||
if err != nil {
|
||||
handleError(rw, "Failed to extract claims", http.StatusInternalServerError)
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if err := t.verifyToken(rawIDToken); err != nil {
|
||||
// infoLogger.Printf("Token verification failed: %v", err)
|
||||
http.Error(rw, "Failed to verify token", http.StatusUnauthorized)
|
||||
return false, ""
|
||||
}
|
||||
// infoLogger.Printf("Token verification successful")
|
||||
email, _ := claims["email"].(string)
|
||||
|
||||
claims, err := extractClaims(rawIDToken)
|
||||
if err != nil {
|
||||
// infoLogger.Printf("Failed to extract claims: %v", err)
|
||||
http.Error(rw, "Failed to extract claims", http.StatusInternalServerError)
|
||||
return false, ""
|
||||
}
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = rawIDToken
|
||||
session.Values["email"] = email
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
handleError(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return false, ""
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
originalPath, ok := session.Values["incoming_path"].(string)
|
||||
if !ok {
|
||||
originalPath = "/"
|
||||
}
|
||||
delete(session.Values, "incoming_path")
|
||||
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = rawIDToken
|
||||
session.Values["email"] = email
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
// infoLogger.Printf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// infoLogger.Printf("User %s authenticated\n", email)
|
||||
originalPath, ok := session.Values["incoming_path"].(string)
|
||||
if !ok {
|
||||
originalPath = "/"
|
||||
}
|
||||
delete(session.Values, "incoming_path")
|
||||
|
||||
return true, originalPath
|
||||
return true, originalPath
|
||||
}
|
||||
|
||||
func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, errors.New("invalid token format")
|
||||
}
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token payload: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func verifyToken(token string, publicKey []byte) (map[string]interface{}, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, errors.New("invalid token format")
|
||||
}
|
||||
|
||||
payloadJson, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
err = json.Unmarshal(payloadJson, &claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if time.Now().Unix() > int64(exp) {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
}
|
||||
|
||||
// Placeholder for signature verification
|
||||
// err = verifySignature(parts[0]+"."+parts[1], parts[2], publicKey)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
return claims, nil
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
type UsedTokens struct {
|
||||
tokens map[string]bool
|
||||
mutex sync.RWMutex
|
||||
tokens map[string]bool
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type TokenBlacklist struct {
|
||||
blacklist map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
blacklist map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewTokenBlacklist() *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
blacklist: make(map[string]time.Time),
|
||||
}
|
||||
return &TokenBlacklist{
|
||||
blacklist: make(map[string]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
tb.blacklist[tokenID] = expiration
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
tb.blacklist[tokenID] = expiration
|
||||
}
|
||||
|
||||
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
|
||||
tb.mutex.RLock()
|
||||
defer tb.mutex.RUnlock()
|
||||
expiration, exists := tb.blacklist[tokenID]
|
||||
return exists && time.Now().Before(expiration)
|
||||
tb.mutex.RLock()
|
||||
defer tb.mutex.RUnlock()
|
||||
expiration, exists := tb.blacklist[tokenID]
|
||||
return exists && time.Now().Before(expiration)
|
||||
}
|
||||
|
||||
func (tb *TokenBlacklist) Cleanup() {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
now := time.Now()
|
||||
for tokenID, expiration := range tb.blacklist {
|
||||
if now.After(expiration) {
|
||||
delete(tb.blacklist, tokenID)
|
||||
}
|
||||
}
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
now := time.Now()
|
||||
for tokenID, expiration := range tb.blacklist {
|
||||
if now.After(expiration) {
|
||||
delete(tb.blacklist, tokenID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type TokenCache struct {
|
||||
cache map[string]*TokenInfo
|
||||
mutex sync.RWMutex
|
||||
cache map[string]*TokenInfo
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type TokenInfo struct {
|
||||
Token string
|
||||
ExpiresAt time.Time
|
||||
Token string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
func NewTokenCache() *TokenCache {
|
||||
return &TokenCache{
|
||||
cache: make(map[string]*TokenInfo),
|
||||
}
|
||||
return &TokenCache{
|
||||
cache: make(map[string]*TokenInfo),
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *TokenCache) Set(token string, expiresAt time.Time) {
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt}
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt}
|
||||
}
|
||||
|
||||
func (tc *TokenCache) Get(token string) (*TokenInfo, bool) {
|
||||
tc.mutex.RLock()
|
||||
defer tc.mutex.RUnlock()
|
||||
info, exists := tc.cache[token]
|
||||
if exists && time.Now().Before(info.ExpiresAt) {
|
||||
return info, true
|
||||
}
|
||||
return nil, false
|
||||
tc.mutex.RLock()
|
||||
defer tc.mutex.RUnlock()
|
||||
info, exists := tc.cache[token]
|
||||
if exists && time.Now().Before(info.ExpiresAt) {
|
||||
return info, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (tc *TokenCache) Cleanup() {
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
now := time.Now()
|
||||
for token, info := range tc.cache {
|
||||
if now.After(info.ExpiresAt) {
|
||||
delete(tc.cache, token)
|
||||
}
|
||||
}
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
now := time.Now()
|
||||
for token, info := range tc.cache {
|
||||
if now.After(info.ExpiresAt) {
|
||||
delete(tc.cache, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user