From 448392e9bdd31ac582868e20f6c15f0a0e94ba74 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Fri, 2 Aug 2024 00:37:51 +0100 Subject: [PATCH] Update: Don't refresh token / issue cookie on every request. --- helpers.go | 51 ++++++++++++++++--- main.go | 62 +++++++++++++++++------ main_test.go | 140 +++++++++++++++++++++++++++++++++++++++++++-------- 3 files changed, 210 insertions(+), 43 deletions(-) diff --git a/helpers.go b/helpers.go index aabd568..d487a0c 100644 --- a/helpers.go +++ b/helpers.go @@ -31,13 +31,18 @@ func buildFullURL(scheme, host, path string) string { return fmt.Sprintf("%s://%s%s", scheme, host, path) } -func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code, redirectURL string) (map[string]interface{}, error) { +func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (map[string]interface{}, error) { data := url.Values{ - "grant_type": {"authorization_code"}, - "code": {code}, + "grant_type": {grantType}, "client_id": {t.clientID}, "client_secret": {t.clientSecret}, - "redirect_uri": {redirectURL}, + } + + if grantType == "authorization_code" { + data.Set("code", codeOrToken) + data.Set("redirect_uri", redirectURL) + } else if grantType == "refresh_token" { + data.Set("refresh_token", codeOrToken) } req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode())) @@ -48,7 +53,7 @@ func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code, redirectUR resp, err := t.httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("failed to exchange code for token: %w", err) + return nil, fmt.Errorf("failed to exchange tokens: %w", err) } defer resp.Body.Close() @@ -60,6 +65,38 @@ func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code, redirectUR return result, nil } +type TokenResponse struct { + IDToken string `json:"id_token"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + TokenType string `json:"token_type"` +} + +func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) { + ctx := context.Background() + result, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "") + if err != nil { + return nil, fmt.Errorf("failed to refresh token: %w", err) + } + + response := &TokenResponse{ + IDToken: result["id_token"].(string), + AccessToken: result["access_token"].(string), + ExpiresIn: int(result["expires_in"].(float64)), + TokenType: result["token_type"].(string), + } + + // The refresh token might not be returned if it hasn't changed + if newRefreshToken, ok := result["refresh_token"].(string); ok { + response.RefreshToken = newRefreshToken + } else { + response.RefreshToken = refreshToken + } + + return response, nil +} + func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { session, err := t.store.Get(req, cookieName) t.logger.Debugf("Logging out user") @@ -96,7 +133,6 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque } func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) { - ctx := req.Context() session, err := t.store.Get(req, cookieName) if err != nil { handleError(rw, "Session error", http.StatusInternalServerError, t.logger) @@ -113,7 +149,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) code := req.URL.Query().Get("code") redirectURL := buildFullURL(t.scheme, req.Host, t.redirURLPath) - oauth2Token, err := t.exchangeCodeForToken(ctx, code, redirectURL) + oauth2Token, err := t.exchangeTokens(req.Context(), "authorization_code", code, redirectURL) if err != nil { handleError(rw, "Failed to exchange token", http.StatusUnauthorized, t.logger) return false, "" @@ -140,6 +176,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) session.Values["authenticated"] = true session.Values["id_token"] = rawIDToken + session.Values["refresh_token"] = oauth2Token["refresh_token"] session.Values["email"] = email if err := session.Save(req, rw); err != nil { handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger) diff --git a/main.go b/main.go index 9b8df99..6de6dfa 100644 --- a/main.go +++ b/main.go @@ -242,17 +242,23 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - authenticated, tokenExpired := t.isUserAuthenticated(session) + authenticated, tokenExpired, needsRefresh := t.isUserAuthenticated(session) if authenticated { - t.refreshSession(rw, req) - t.logger.Debugf("User is authenticated, serving content") + if needsRefresh { + // Attempt to refresh the token silently + if refreshed := t.refreshToken(rw, req, session); !refreshed { + // If refresh failed, re-authenticate + t.initiateAuthentication(rw, req, session, t.redirectURL) + return + } + } t.next.ServeHTTP(rw, req) return } if tokenExpired { t.logger.Debugf("Token has expired, initiating reauthentication") - t.handleExpiredToken(rw, req, session) + t.initiateAuthentication(rw, req, session, t.redirectURL) return } @@ -280,35 +286,41 @@ func (t *TraefikOidc) determineHost(req *http.Request) string { return req.Host } -func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool) { +func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) { authenticated, _ := session.Values["authenticated"].(bool) if authenticated { idToken, ok := session.Values["id_token"].(string) if !ok || idToken == "" { - return false, false + return false, false, false } claims, err := extractClaims(idToken) if err != nil { t.logger.Errorf("Failed to extract claims: %v", err) - return false, false + return false, false, false } exp, ok := claims["exp"].(float64) if !ok { t.logger.Errorf("Failed to get expiration time from claims") - return false, false + return false, false, false } - gracePeriod := time.Minute * 1 - if time.Now().Add(gracePeriod).Unix() > int64(exp) { - t.logger.Debugf("Session has expired or will expire soon") - return false, true // Token expired or will expire soon + now := time.Now().Unix() + expTime := int64(exp) + + if now > expTime { + return false, true, false // Token has expired } - return t.verifyToken(idToken) == nil, false + gracePeriod := time.Minute * 5 + if time.Now().Add(gracePeriod).Unix() > expTime { + return true, false, true // Token will expire soon, needs refresh + } + + return t.verifyToken(idToken) == nil, false, false } - return false, false + return false, false, false } func (t *TraefikOidc) initiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) { @@ -401,3 +413,25 @@ func (t *TraefikOidc) refreshSession(w http.ResponseWriter, r *http.Request) { } } } + +func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool { + refreshToken, ok := session.Values["refresh_token"].(string) + if !ok || refreshToken == "" { + return false + } + + newToken, err := t.getNewTokenWithRefreshToken(refreshToken) + if err != nil { + t.logger.Errorf("Failed to refresh token: %v", err) + return false + } + + session.Values["id_token"] = newToken.IDToken + session.Values["refresh_token"] = newToken.RefreshToken + if err := session.Save(req, rw); err != nil { + t.logger.Errorf("Failed to save refreshed session: %v", err) + return false + } + + return true +} diff --git a/main_test.go b/main_test.go index 6e35156..27b732b 100644 --- a/main_test.go +++ b/main_test.go @@ -274,25 +274,61 @@ func (suite *TraefikOidcTestSuite) TestBuildFullURL() { suite.Equal("http://example.com/path", url) } -func (suite *TraefikOidcTestSuite) TestExchangeCodeForToken() { +func (suite *TraefikOidcTestSuite) TestExchangeTokens() { ctx := context.Background() - code := "test_code" - redirectURL := "http://example.com/callback" - expectedToken := map[string]interface{}{ - "access_token": "test_access_token", - "id_token": "test_id_token", + testCases := []struct { + name string + grantType string + codeOrToken string + redirectURL string + expectedToken map[string]interface{} + }{ + { + name: "Authorization Code Exchange", + grantType: "authorization_code", + codeOrToken: "test_code", + redirectURL: "http://example.com/callback", + expectedToken: map[string]interface{}{ + "access_token": "test_access_token", + "id_token": "test_id_token", + "refresh_token": "test_refresh_token", + "expires_in": float64(3600), + "token_type": "Bearer", + }, + }, + { + name: "Refresh Token Exchange", + grantType: "refresh_token", + codeOrToken: "test_refresh_token", + redirectURL: "", + expectedToken: map[string]interface{}{ + "access_token": "new_access_token", + "id_token": "new_id_token", + "refresh_token": "new_refresh_token", + "expires_in": float64(3600), + "token_type": "Bearer", + }, + }, } - tokenJSON, _ := json.Marshal(expectedToken) - suite.mockHTTPClient.On("RoundTrip", mock.Anything).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(tokenJSON)), - }, nil).Once() + for _, tc := range testCases { + suite.Run(tc.name, func() { + tokenJSON, _ := json.Marshal(tc.expectedToken) - token, err := suite.oidc.exchangeCodeForToken(ctx, code, redirectURL) - suite.NoError(err) - suite.Equal(expectedToken, token) + // Set up the mock HTTP client + suite.mockHTTPClient.On("RoundTrip", mock.AnythingOfType("*http.Request")).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(tokenJSON)), + }, nil).Once() + + token, err := suite.oidc.exchangeTokens(ctx, tc.grantType, tc.codeOrToken, tc.redirectURL) + suite.NoError(err) + suite.Equal(tc.expectedToken, token) + + suite.mockHTTPClient.AssertExpectations(suite.T()) + }) + } } func (suite *TraefikOidcTestSuite) TestHandleLogout() { @@ -369,15 +405,75 @@ func (suite *TraefikOidcTestSuite) TestDetermineHost() { } func (suite *TraefikOidcTestSuite) TestIsUserAuthenticated() { - session := sessions.NewSession(suite.mockStore, cookieName) - session.Values["authenticated"] = true - session.Values["id_token"] = "valid.eyJleHAiOjk5OTk5OTk5OTl9.signature" + testCases := []struct { + name string + setupSession func() *sessions.Session + expectedAuth bool + expectedExpired bool + expectedRefresh bool + }{ + { + name: "Valid Token", + setupSession: func() *sessions.Session { + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["authenticated"] = true + session.Values["id_token"] = "valid.eyJleHAiOjk5OTk5OTk5OTl9.signature" + return session + }, + expectedAuth: true, + expectedExpired: false, + expectedRefresh: false, + }, + { + name: "Expired Token", + setupSession: func() *sessions.Session { + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["authenticated"] = true + session.Values["id_token"] = "expired.eyJleHAiOjE1OTM1NjE2MDB9.signature" + return session + }, + expectedAuth: false, + expectedExpired: true, + expectedRefresh: false, + }, + { + name: "Token Needs Refresh", + setupSession: func() *sessions.Session { + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["authenticated"] = true + // Set expiration to 4 minutes from now + exp := time.Now().Add(4 * time.Minute).Unix() + token := fmt.Sprintf("needsrefresh.%s.signature", base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf(`{"exp":%d}`, exp)))) + session.Values["id_token"] = token + return session + }, + expectedAuth: true, + expectedExpired: false, + expectedRefresh: true, + }, + { + name: "Not Authenticated", + setupSession: func() *sessions.Session { + session := sessions.NewSession(suite.mockStore, cookieName) + session.Values["authenticated"] = false + return session + }, + expectedAuth: false, + expectedExpired: false, + expectedRefresh: false, + }, + } - suite.mockTokenVerifier.On("VerifyToken", "valid.eyJleHAiOjk5OTk5OTk5OTl9.signature").Return(nil) - - authenticated, tokenExpired := suite.oidc.isUserAuthenticated(session) - suite.True(authenticated) - suite.False(tokenExpired) + for _, tc := range testCases { + suite.Run(tc.name, func() { + session := tc.setupSession() + suite.mockTokenVerifier.On("VerifyToken", mock.AnythingOfType("string")).Return(nil).Maybe() + authenticated, tokenExpired, needsRefresh := suite.oidc.isUserAuthenticated(session) + suite.Equal(tc.expectedAuth, authenticated) + suite.Equal(tc.expectedExpired, tokenExpired) + suite.Equal(tc.expectedRefresh, needsRefresh) + }) + } } func (suite *TraefikOidcTestSuite) TestInitiateAuthentication() {