diff --git a/main.go b/main.go index f67c71f..ac1df78 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "net/http" "net/url" "strings" + "sync" "time" "github.com/google/uuid" @@ -59,6 +60,8 @@ type TraefikOidc struct { initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) exchangeCodeForTokenFunc func(code string) (map[string]interface{}, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error) + initOnce sync.Once + initComplete chan struct{} } type ProviderMetadata struct { @@ -203,11 +206,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h } } - metadata, err := discoverProviderMetadata(config.ProviderURL, httpClient, NewLogger(config.LogLevel)) - if err != nil { - return nil, fmt.Errorf("failed to discover provider metadata: %w", err) - } - t := &TraefikOidc{ next: next, name: name, @@ -219,21 +217,17 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h } return config.LogoutURL }(), - issuerURL: metadata.Issuer, - revocationURL: metadata.RevokeURL, tokenBlacklist: NewTokenBlacklist(), jwkCache: &JWKCache{}, - jwksURL: metadata.JWKSURL, - clientID: config.ClientID, - clientSecret: config.ClientSecret, - forceHTTPS: config.ForceHTTPS, - authURL: metadata.AuthURL, - tokenURL: metadata.TokenURL, - scopes: config.Scopes, - limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit), - tokenCache: NewTokenCache(), - httpClient: httpClient, - logger: NewLogger(config.LogLevel), + + clientID: config.ClientID, + clientSecret: config.ClientSecret, + forceHTTPS: config.ForceHTTPS, + scopes: config.Scopes, + limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit), + tokenCache: NewTokenCache(), + httpClient: httpClient, + logger: NewLogger(config.LogLevel), excludedURLs: func() map[string]struct{} { m := make(map[string]struct{}) for _, url := range config.ExcludedURLs { @@ -256,6 +250,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h } return m }(), + initComplete: make(chan struct{}), } t.initiateAuthenticationFunc = t.defaultInitiateAuthentication @@ -270,9 +265,28 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h t.tokenVerifier = t t.jwtVerifier = t t.startTokenCleanup() + go t.initializeMetadata(config.ProviderURL) + return t, nil } +func (t *TraefikOidc) initializeMetadata(providerURL string) { + t.initOnce.Do(func() { + metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger) + if err != nil { + t.logger.Error("Failed to discover provider metadata: %v", err) + } else { + t.logger.Debug("Provider metadata discovered successfully") + t.jwksURL = metadata.JWKSURL + t.authURL = metadata.AuthURL + t.tokenURL = metadata.TokenURL + t.issuerURL = metadata.Issuer + t.revocationURL = metadata.RevokeURL + } + close(t.initComplete) + }) +} + func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Logger) (*ProviderMetadata, error) { wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration" @@ -333,6 +347,20 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad } func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + select { + case <-t.initComplete: + if t.issuerURL == "" { + t.logger.Debug("OIDC middleware not yet initialized") + http.Error(rw, "OIDC middleware not yet initialized", http.StatusInternalServerError) + return + } + // Process the request as normal + case <-req.Context().Done(): + t.logger.Debug("Request cancelled") + http.Error(rw, "Request cancelled", http.StatusServiceUnavailable) + return + } + if t.determineExcludedURL(req.URL.Path) { t.next.ServeHTTP(rw, req) return