mirror of
https://github.com/lukaszraczylo/graphql-monitoring-proxy.git
synced 2026-06-05 23:03:48 +00:00
Add tests (#2)
* Initial tests draft. * Add tests for parsing jwt token * Further code optimisations
This commit is contained in:
@@ -10,8 +10,8 @@ help: ## display this help
|
||||
@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m<target>\033[0m\n\nTargets:\n"} /^[a-zA-Z0-9_-]+:.*?##/ { printf " \033[36m%-20s\033[0m %s\n", $$1, $$2 }' $(MAKEFILE_LIST)
|
||||
|
||||
.PHONY: run
|
||||
run: ## run application
|
||||
@LOG_LEVEL=warn BLOCK_SCHEMA_INTROSPECTION=false JWT_ROLE_RATE_LIMIT=false JWT_ROLE_CLAIM_PATH="Hasura.x-hasura-default-role" JWT_USER_CLAIM_PATH="Hasura.x-hasura-user-id" HOST_GRAPHQL=https://hasura8.lan/v1/graphql go run *.go
|
||||
run: build ## run application
|
||||
@LOG_LEVEL=warn BLOCK_SCHEMA_INTROSPECTION=false JWT_ROLE_RATE_LIMIT=false JWT_ROLE_CLAIM_PATH="Hasura.x-hasura-default-role" JWT_USER_CLAIM_PATH="Hasura.x-hasura-user-id" HOST_GRAPHQL=https://hasura8.lan/v1/graphql ./graphql-proxy
|
||||
|
||||
.PHONY: build
|
||||
build: ## build the binary
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (suite *Tests) Test_cacheLookup() {
|
||||
type args struct {
|
||||
hash string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []byte
|
||||
addCache struct {
|
||||
data []byte
|
||||
}
|
||||
}{
|
||||
{
|
||||
name: "test_non_existent",
|
||||
args: args{
|
||||
hash: "00000000000000000000000000000000000000",
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "test_existent",
|
||||
args: args{
|
||||
hash: "00000000000000000000000000000000000001",
|
||||
},
|
||||
want: []byte("it's fine."),
|
||||
addCache: struct {
|
||||
data []byte
|
||||
}{
|
||||
data: []byte("it's fine."),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
suite.T().Run(tt.name, func(t *testing.T) {
|
||||
if tt.addCache.data != nil {
|
||||
cfg.Cache.CacheClient.Set(tt.args.hash, tt.addCache.data, time.Duration(1)*time.Second)
|
||||
}
|
||||
got := cacheLookup(tt.args.hash)
|
||||
assert.Equal(tt.want, got, "Unexpected cache lookup result")
|
||||
})
|
||||
}
|
||||
}
|
||||
+23
-24
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/ask"
|
||||
@@ -9,45 +10,43 @@ import (
|
||||
)
|
||||
|
||||
func extractClaimsFromJWTHeader(authorization string) (usr string, role string) {
|
||||
usr, role = "-", "-"
|
||||
|
||||
handleError := func(msg string, details map[string]interface{}) {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
|
||||
cfg.Logger.Error(msg, details)
|
||||
}
|
||||
|
||||
tokenParts := strings.Split(authorization, ".")
|
||||
if len(tokenParts) != 3 {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
|
||||
cfg.Logger.Error("Can't split the token", map[string]interface{}{"token": authorization})
|
||||
handleError("Can't split the token", map[string]interface{}{"token": authorization})
|
||||
return
|
||||
}
|
||||
|
||||
claim, err := base64.RawURLEncoding.DecodeString(tokenParts[1])
|
||||
if err != nil {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
|
||||
cfg.Logger.Error("Can't decode the token", map[string]interface{}{"token": authorization})
|
||||
handleError("Can't decode the token", map[string]interface{}{"token": authorization})
|
||||
return
|
||||
}
|
||||
|
||||
var claimMap map[string]interface{}
|
||||
err = json.Unmarshal(claim, &claimMap)
|
||||
if err != nil {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
|
||||
cfg.Logger.Error("Can't unmarshal the claim", map[string]interface{}{"token": authorization})
|
||||
if err = json.Unmarshal(claim, &claimMap); err != nil {
|
||||
handleError("Can't unmarshal the claim", map[string]interface{}{"token": authorization})
|
||||
return
|
||||
}
|
||||
|
||||
if len(cfg.Client.JWTUserClaimPath) > 0 {
|
||||
var ok bool
|
||||
usr, ok = ask.For(claimMap, cfg.Client.JWTUserClaimPath).String("-")
|
||||
if !ok {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
|
||||
cfg.Logger.Error("Can't find the user id", map[string]interface{}{"claim_map": claimMap, "path": cfg.Client.JWTUserClaimPath})
|
||||
return
|
||||
extractClaim := func(claimPath string, target *string, name string) {
|
||||
if len(claimPath) > 0 {
|
||||
var ok bool
|
||||
*target, ok = ask.For(claimMap, claimPath).String("-")
|
||||
if !ok {
|
||||
handleError(fmt.Sprintf("Can't find the %s", name), map[string]interface{}{"claim_map": claimMap, "path": claimPath})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(cfg.Client.JWTRoleClaimPath) > 0 {
|
||||
var ok bool
|
||||
role, ok = ask.For(claimMap, cfg.Client.JWTRoleClaimPath).String("-")
|
||||
if !ok {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
|
||||
cfg.Logger.Error("Can't find the role", map[string]interface{}{"claim_map": claimMap, "path": cfg.Client.JWTRoleClaimPath})
|
||||
return
|
||||
}
|
||||
}
|
||||
extractClaim(cfg.Client.JWTUserClaimPath, &usr, "user id")
|
||||
extractClaim(cfg.Client.JWTRoleClaimPath, &role, "role")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func (suite *Tests) Test_extractClaimsFromJWTHeader() {
|
||||
jwt_token_for_tests := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ0b2tlbl90eXBlIjoiYWNjZXNzIiwiSGFzdXJhIjp7IngtaGFzdXJhLWFsbG93ZWQtcm9sZXMiOlsiZ3Vlc3QiLCJ1c2VyIiwiZ3JvdXBhZG1pbiIsInBheWFkbWluIl0sIngtaGFzdXJhLWRlZmF1bHQtcm9sZSI6Imd1ZXN0IiwieC1oYXN1cmEtdXNlci1pZCI6IjE2NyIsIngtaGFzdXJhLXVzZXItdXVpZCI6ImRkM2U2ZTM1LTA0MDktNDNiMC1iZmYxLWNlZjNjNmVkNWYxMCJ9LCJpc3MiOiJBdXRoU2VydmljZSIsImV4cCI6MTY5NjgwMTcyNiwibmJmIjoxNjk2NTg1NzI2LCJpYXQiOjE2OTY1ODU3MjZ9.dsJ5JKzG5tXOlqeZ_Gfe2XC-vyrcwtYwOGfhvt8q9UY"
|
||||
|
||||
type args struct {
|
||||
authorization string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantUsr string
|
||||
wantRole string
|
||||
jwt_token_path string
|
||||
jwt_role_path string
|
||||
}{
|
||||
{
|
||||
name: "test_empty",
|
||||
wantUsr: "-",
|
||||
wantRole: "-",
|
||||
},
|
||||
{
|
||||
name: "test_invalid_path",
|
||||
args: args{
|
||||
authorization: jwt_token_for_tests,
|
||||
},
|
||||
wantUsr: "-",
|
||||
wantRole: "-",
|
||||
jwt_token_path: "invalid",
|
||||
},
|
||||
{
|
||||
name: "test_invalid_role_path",
|
||||
args: args{
|
||||
authorization: jwt_token_for_tests,
|
||||
},
|
||||
wantUsr: "-",
|
||||
wantRole: "-",
|
||||
jwt_role_path: "invalid",
|
||||
},
|
||||
{
|
||||
name: "test_valid",
|
||||
args: args{
|
||||
authorization: jwt_token_for_tests,
|
||||
},
|
||||
wantUsr: "167",
|
||||
wantRole: "guest",
|
||||
jwt_token_path: "Hasura.x-hasura-user-id",
|
||||
jwt_role_path: "Hasura.x-hasura-default-role",
|
||||
},
|
||||
{
|
||||
name: "test_invalid_token",
|
||||
args: args{
|
||||
authorization: "invalid",
|
||||
},
|
||||
wantUsr: "-",
|
||||
wantRole: "-",
|
||||
},
|
||||
{
|
||||
name: "test_invalid_three_part_token",
|
||||
args: args{
|
||||
authorization: "invalid.threepart.token",
|
||||
},
|
||||
wantUsr: "-",
|
||||
wantRole: "-",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
suite.T().Run(tt.name, func(t *testing.T) {
|
||||
if len(tt.jwt_token_path) > 0 {
|
||||
cfg.Client.JWTUserClaimPath = tt.jwt_token_path
|
||||
}
|
||||
if len(tt.jwt_role_path) > 0 {
|
||||
cfg.Client.JWTRoleClaimPath = tt.jwt_role_path
|
||||
}
|
||||
gotUsr, gotRole := extractClaimsFromJWTHeader(tt.args.authorization)
|
||||
assert.Equal(tt.wantUsr, gotUsr, "Unexpected user ID")
|
||||
assert.Equal(tt.wantRole, gotRole, "Unexpected role")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ require (
|
||||
github.com/lukaszraczylo/ask v0.0.0-20230927103145-2ff1123b4415
|
||||
github.com/lukaszraczylo/go-ratecounter v0.1.8
|
||||
github.com/lukaszraczylo/go-simple-graphql v1.1.31
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/telegram-bot-app/libpack v0.0.0-20231008100411-9f7f8bf94315
|
||||
)
|
||||
|
||||
@@ -19,15 +20,18 @@ require (
|
||||
github.com/VictoriaMetrics/metrics v1.24.0 // indirect
|
||||
github.com/andybalholm/brotli v1.0.5 // indirect
|
||||
github.com/avast/retry-go/v4 v4.5.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.3.1 // indirect
|
||||
github.com/gookit/color v1.5.4 // indirect
|
||||
github.com/klauspost/compress v1.17.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/lukaszraczylo/pandati v0.0.29 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.15 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.4 // indirect
|
||||
github.com/rs/zerolog v1.31.0 // indirect
|
||||
github.com/telegram-bot-app/lib-logging v0.0.19 // indirect
|
||||
@@ -43,4 +47,5 @@ require (
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
golang.org/x/term v0.13.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ github.com/avast/retry-go/v4 v4.5.0/go.mod h1:7hLEXp0oku2Nir2xBAsg0PTphp9z71bN5A
|
||||
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -30,6 +31,10 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM=
|
||||
github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lukaszraczylo/ask v0.0.0-20230927103145-2ff1123b4415 h1:lvI8Wlbg4PxkRcg2f10wgoaRpfN19v+YdRek3+dLtlM=
|
||||
github.com/lukaszraczylo/ask v0.0.0-20230927103145-2ff1123b4415/go.mod h1:M+UVdyqZs++xtEPrascaVmZdOMhCnxjZ2SgH+xHpR0c=
|
||||
github.com/lukaszraczylo/go-ratecounter v0.1.8 h1:ZYm6Wkn58ZAlFWRmC7PaD4oAYHWcu8/0MUDWGe3PnJQ=
|
||||
@@ -56,6 +61,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
|
||||
github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A=
|
||||
github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
||||
@@ -97,5 +104,7 @@ golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
assertions "github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type Tests struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
var (
|
||||
assert *assertions.Assertions
|
||||
)
|
||||
|
||||
func (suite *Tests) SetupTest() {
|
||||
assert = assertions.New(suite.T())
|
||||
}
|
||||
|
||||
func (suite *Tests) BeforeTest(suiteName, testName string) {
|
||||
fmt.Println("BeforeTest")
|
||||
cfg = &config{}
|
||||
parseConfig()
|
||||
StartMonitoringServer()
|
||||
}
|
||||
|
||||
// func (suite *Tests) AfterTest(suiteName, testName string) {)
|
||||
|
||||
func TestSuite(t *testing.T) {
|
||||
suite.Run(t, new(Tests))
|
||||
}
|
||||
+63
-34
@@ -25,60 +25,89 @@ var ratelimit_intervals = map[string]time.Duration{
|
||||
}
|
||||
|
||||
func loadRatelimitConfig() error {
|
||||
paths := [3]string{"/app/ratelimit.json", "./ratelimit.json", "./static/default-ratelimit.json"}
|
||||
paths := []string{"/app/ratelimit.json", "./ratelimit.json", "./static/default-ratelimit.json"}
|
||||
|
||||
for _, path := range paths {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
continue
|
||||
err := loadConfigFromPath(path)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
defer file.Close()
|
||||
decoder := json.NewDecoder(file)
|
||||
config := struct {
|
||||
RateLimit map[string]RateLimitConfig `json:"ratelimit"`
|
||||
}{}
|
||||
err = decoder.Decode(&config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for key, value := range config.RateLimit {
|
||||
value.RateCounterTicker = goratecounter.NewRateCounter().WithConfig(goratecounter.RateCounterConfig{
|
||||
Interval: time.Duration(value.Req) * ratelimit_intervals[value.Interval],
|
||||
})
|
||||
cfg.Logger.Debug("Setting ratelimit config for role", map[string]interface{}{"role": key, "interval_provided": value.Interval, "interval_used": ratelimit_intervals[value.Interval], "ratelimit": value.Req})
|
||||
config.RateLimit[key] = value
|
||||
}
|
||||
|
||||
rateLimits = config.RateLimit
|
||||
cfg.Logger.Debug("Rate limit config loaded", map[string]interface{}{"ratelimit": rateLimits})
|
||||
return nil
|
||||
cfg.Logger.Error("Failed to load config", map[string]interface{}{"path": path, "error": err})
|
||||
}
|
||||
|
||||
cfg.Logger.Debug("Rate limit config not found")
|
||||
return os.ErrNotExist
|
||||
}
|
||||
|
||||
func rateLimitedRequest(userId string, userRole string) (shouldAllow bool) {
|
||||
func loadConfigFromPath(path string) error {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
config := struct {
|
||||
RateLimit map[string]RateLimitConfig `json:"ratelimit"`
|
||||
}{}
|
||||
|
||||
decoder := json.NewDecoder(file)
|
||||
if err := decoder.Decode(&config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for key, value := range config.RateLimit {
|
||||
value.RateCounterTicker = goratecounter.NewRateCounter().WithConfig(goratecounter.RateCounterConfig{
|
||||
Interval: time.Duration(value.Req) * ratelimit_intervals[value.Interval],
|
||||
})
|
||||
cfg.Logger.Debug("Setting ratelimit config for role", map[string]interface{}{
|
||||
"role": key,
|
||||
"interval_provided": value.Interval,
|
||||
"interval_used": ratelimit_intervals[value.Interval],
|
||||
"ratelimit": value.Req,
|
||||
})
|
||||
config.RateLimit[key] = value
|
||||
}
|
||||
|
||||
rateLimits = config.RateLimit
|
||||
cfg.Logger.Debug("Rate limit config loaded", map[string]interface{}{"ratelimit": rateLimits})
|
||||
return nil
|
||||
}
|
||||
|
||||
func rateLimitedRequest(userID string, userRole string) (shouldAllow bool) {
|
||||
if rateLimits == nil {
|
||||
cfg.Logger.Debug("Rate limit config not found", map[string]interface{}{"user_role": userRole})
|
||||
return true
|
||||
}
|
||||
// check if userRole is in rateLimits
|
||||
if _, ok := rateLimits[userRole]; !ok {
|
||||
|
||||
// Fetch role config once to avoid multiple map lookups
|
||||
roleConfig, ok := rateLimits[userRole]
|
||||
if !ok {
|
||||
cfg.Logger.Warning("Rate limit role not found", map[string]interface{}{"user_role": userRole})
|
||||
return true
|
||||
}
|
||||
|
||||
if rateLimits[userRole].RateCounterTicker == nil {
|
||||
if roleConfig.RateCounterTicker == nil {
|
||||
cfg.Logger.Warning("Rate limit ticker not found", map[string]interface{}{"user_role": userRole})
|
||||
return true
|
||||
}
|
||||
|
||||
rateLimits[userRole].RateCounterTicker.Incr(1)
|
||||
ticker_rate := rateLimits[userRole].RateCounterTicker.GetRate()
|
||||
cfg.Logger.Debug("Rate limit ticker", map[string]interface{}{"user_role": userRole, "user_id": userId, "rate": ticker_rate, "config_rate": rateLimits[userRole].Req, "interval": rateLimits[userRole].Interval, "interval_duration": rateLimits[userRole].Interval})
|
||||
if ticker_rate > float64(rateLimits[userRole].Req) {
|
||||
cfg.Logger.Debug("Rate limit exceeded", map[string]interface{}{"user_role": userRole, "user_id": userId, "rate": ticker_rate, "config_rate": rateLimits[userRole].Req, "interval": rateLimits[userRole].Interval, "interval_duration": rateLimits[userRole].Interval})
|
||||
roleConfig.RateCounterTicker.Incr(1)
|
||||
tickerRate := roleConfig.RateCounterTicker.GetRate()
|
||||
|
||||
logDetails := map[string]interface{}{
|
||||
"user_role": userRole,
|
||||
"user_id": userID,
|
||||
"rate": tickerRate,
|
||||
"config_rate": roleConfig.Req,
|
||||
"interval": roleConfig.Interval,
|
||||
}
|
||||
|
||||
cfg.Logger.Debug("Rate limit ticker", logDetails)
|
||||
|
||||
if tickerRate > float64(roleConfig.Req) {
|
||||
cfg.Logger.Debug("Rate limit exceeded", logDetails)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -42,70 +42,90 @@ func healthCheck(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
func processGraphQLRequest(c *fiber.Ctx) error {
|
||||
t := time.Now()
|
||||
startTime := time.Now()
|
||||
|
||||
var extracted_user_id string = "-"
|
||||
var extracted_role_name string = "-"
|
||||
var query_cache_hash string
|
||||
// Initialize variables with default values
|
||||
extractedUserID := "-"
|
||||
extractedRoleName := "-"
|
||||
var queryCacheHash string
|
||||
|
||||
authorization := c.Request().Header.Peek("Authorization")
|
||||
if authorization != nil && (len(cfg.Client.JWTUserClaimPath) > 0 || len(cfg.Client.JWTRoleClaimPath) > 0) {
|
||||
extracted_user_id, extracted_role_name = extractClaimsFromJWTHeader(string(authorization))
|
||||
extractedUserID, extractedRoleName = extractClaimsFromJWTHeader(string(authorization))
|
||||
}
|
||||
|
||||
// Implementing rate limiting if enabled
|
||||
if cfg.Client.JWTRoleRateLimit {
|
||||
cfg.Logger.Debug("Rate limiting enabled", map[string]interface{}{"user_id": extracted_user_id, "role_name": extracted_role_name})
|
||||
if !rateLimitedRequest(extracted_user_id, extracted_role_name) {
|
||||
cfg.Logger.Debug("Rate limiting enabled", map[string]interface{}{"user_id": extractedUserID, "role_name": extractedRoleName})
|
||||
if !rateLimitedRequest(extractedUserID, extractedRoleName) {
|
||||
c.Status(429).SendString("Rate limit exceeded, try again later")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
opType, opName, cache_from_query, should_block := parseGraphQLQuery(c)
|
||||
|
||||
if should_block {
|
||||
opType, opName, cacheFromQuery, shouldBlock := parseGraphQLQuery(c)
|
||||
if shouldBlock {
|
||||
return nil
|
||||
}
|
||||
|
||||
was_cached := false
|
||||
wasCached := false
|
||||
|
||||
if cache_from_query || cfg.Cache.CacheEnable {
|
||||
cfg.Logger.Debug("Cache enabled", map[string]interface{}{"via_query": cache_from_query, "via_env": cfg.Cache.CacheEnable})
|
||||
query_cache_hash = calculateHash(c)
|
||||
cachedResponse := cacheLookup(query_cache_hash)
|
||||
if cachedResponse != nil {
|
||||
cfg.Logger.Debug("Cache hit", map[string]interface{}{"hash": query_cache_hash, "user_id": extracted_user_id})
|
||||
// Handling Cache Logic
|
||||
if cacheFromQuery || cfg.Cache.CacheEnable {
|
||||
cfg.Logger.Debug("Cache enabled", map[string]interface{}{"via_query": cacheFromQuery, "via_env": cfg.Cache.CacheEnable})
|
||||
queryCacheHash = calculateHash(c)
|
||||
|
||||
if cachedResponse := cacheLookup(queryCacheHash); cachedResponse != nil {
|
||||
cfg.Logger.Debug("Cache hit", map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID})
|
||||
c.Send(cachedResponse)
|
||||
was_cached = true
|
||||
wasCached = true
|
||||
} else {
|
||||
cfg.Logger.Debug("Cache miss", map[string]interface{}{"hash": query_cache_hash, "user_id": extracted_user_id})
|
||||
proxyTheRequest(c)
|
||||
cfg.Cache.CacheClient.Set(query_cache_hash, c.Response().Body(), time.Duration(cfg.Cache.CacheTTL)*time.Second)
|
||||
c.Send(c.Response().Body())
|
||||
cfg.Logger.Debug("Cache miss", map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID})
|
||||
proxyAndCacheTheRequest(c, queryCacheHash)
|
||||
}
|
||||
} else {
|
||||
proxyTheRequest(c)
|
||||
}
|
||||
time_taken := time.Since(t)
|
||||
|
||||
if cfg.Server.AccessLog {
|
||||
cfg.Logger.Info("Request processed", map[string]interface{}{"ip": c.IP(), "user_id": extracted_user_id, "op_type": opType, "op_name": opName, "time": time_taken, "cache": was_cached})
|
||||
}
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsSucceeded, nil)
|
||||
timeTaken := time.Since(startTime)
|
||||
|
||||
// Logging & Monitoring
|
||||
logAndMonitorRequest(c, extractedUserID, opType, opName, wasCached, timeTaken, startTime)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Additional helper function to avoid code repetition
|
||||
func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string) {
|
||||
proxyTheRequest(c)
|
||||
cfg.Cache.CacheClient.Set(queryCacheHash, c.Response().Body(), time.Duration(cfg.Cache.CacheTTL)*time.Second)
|
||||
c.Send(c.Response().Body())
|
||||
}
|
||||
|
||||
func logAndMonitorRequest(c *fiber.Ctx, userID, opType, opName string, wasCached bool, duration time.Duration, startTime time.Time) {
|
||||
labels := map[string]string{
|
||||
"op_type": opType,
|
||||
"op_name": opName,
|
||||
"cached": fmt.Sprintf("%t", was_cached),
|
||||
"user_id": extracted_user_id,
|
||||
"cached": fmt.Sprintf("%t", wasCached),
|
||||
"user_id": userID,
|
||||
}
|
||||
|
||||
if cfg.Server.AccessLog {
|
||||
cfg.Logger.Info("Request processed", map[string]interface{}{
|
||||
"ip": c.IP(),
|
||||
"user_id": userID,
|
||||
"op_type": opType,
|
||||
"op_name": opName,
|
||||
"time": duration,
|
||||
"cache": wasCached,
|
||||
})
|
||||
}
|
||||
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsSucceeded, nil)
|
||||
cfg.Monitoring.Increment("executed_query", labels)
|
||||
|
||||
if !was_cached {
|
||||
cfg.Monitoring.UpdateDuration("timed_query", labels, t)
|
||||
cfg.Monitoring.Update("timed_query", labels, float64(time_taken.Milliseconds()))
|
||||
if !wasCached {
|
||||
cfg.Monitoring.UpdateDuration("timed_query", labels, startTime)
|
||||
cfg.Monitoring.Update("timed_query", labels, float64(duration.Milliseconds()))
|
||||
}
|
||||
// // cfg.Monitoring.Set("timed_query", time_taken.Milliseconds())
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user