diff --git a/ratelimit.go b/ratelimit.go index 3b9950d..0034615 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -2,6 +2,7 @@ package main import ( "os" + "sync" "time" "github.com/goccy/go-json" @@ -16,58 +17,61 @@ type RateLimitConfig struct { Req int `json:"req"` } -var rateLimits map[string]RateLimitConfig -var ratelimit_intervals = map[string]time.Duration{ - "milli": time.Millisecond, - "micro": time.Microsecond, - "nano": time.Nanosecond, - "second": time.Second, - "minute": time.Minute, - "hour": time.Hour, - "day": time.Hour * 24, -} +var ( + rateLimits map[string]RateLimitConfig + ratelimitIntervals = map[string]time.Duration{ + "milli": time.Millisecond, + "micro": time.Microsecond, + "nano": time.Nanosecond, + "second": time.Second, + "minute": time.Minute, + "hour": time.Hour, + "day": 24 * time.Hour, + } + configPaths = []string{"/go/src/app/ratelimit.json", "./ratelimit.json", "./static/app/default-ratelimit.json"} + mu sync.RWMutex +) func loadRatelimitConfig() error { - paths := []string{"/go/src/app/ratelimit.json", "./ratelimit.json", "./static/app/default-ratelimit.json"} - - for _, path := range paths { - err := loadConfigFromPath(path) - if err == nil { + for _, path := range configPaths { + if err := loadConfigFromPath(path); err == nil { + cfg.Logger.Debug(&libpack_logger.LogMessage{ + Message: "Failed to load config", + Pairs: map[string]interface{}{"path": path, "error": err}, + }) return nil } - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Failed to load config", - Pairs: map[string]interface{}{"path": path, "error": err}, - }) } cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Rate limit config not found", - Pairs: map[string]interface{}{"paths": paths}, + Pairs: map[string]interface{}{"paths": configPaths}, }) return os.ErrNotExist } func loadConfigFromPath(path string) error { - file, err := os.Open(path) + file, err := os.ReadFile(path) if err != nil { return err } - defer file.Close() - config := struct { + var config struct { RateLimit map[string]RateLimitConfig `json:"ratelimit"` - }{} + } - decoder := json.NewDecoder(file) - if err := decoder.Decode(&config); err != nil { + if err := json.Unmarshal(file, &config); err != nil { return err } + mu.Lock() + defer mu.Unlock() + + rateLimits = make(map[string]RateLimitConfig, len(config.RateLimit)) for key, value := range config.RateLimit { value.RateCounterTicker = goratecounter.NewRateCounter().WithConfig(goratecounter.RateCounterConfig{ - Interval: time.Duration(value.Req) * ratelimit_intervals[value.Interval], + Interval: time.Duration(value.Req) * ratelimitIntervals[value.Interval], }) if cfg.LogLevel == "debug" { @@ -76,15 +80,14 @@ func loadConfigFromPath(path string) error { Pairs: map[string]interface{}{ "role": key, "interval_provided": value.Interval, - "interval_used": ratelimit_intervals[value.Interval], + "interval_used": ratelimitIntervals[value.Interval], "ratelimit": value.Req, }, }) } - config.RateLimit[key] = value + rateLimits[key] = value } - rateLimits = config.RateLimit cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Rate limit config loaded", Pairs: map[string]interface{}{"ratelimit": rateLimits}, @@ -92,7 +95,10 @@ func loadConfigFromPath(path string) error { return nil } -func rateLimitedRequest(userID string, userRole string) (shouldAllow bool) { +func rateLimitedRequest(userID, userRole string) bool { + mu.RLock() + defer mu.RUnlock() + if rateLimits == nil { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Rate limit config not found", @@ -101,19 +107,10 @@ func rateLimitedRequest(userID string, userRole string) (shouldAllow bool) { return true } - // Fetch role config once to avoid multiple map lookups roleConfig, ok := rateLimits[userRole] - if !ok { + if !ok || roleConfig.RateCounterTicker == nil { cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Rate limit role not found", - Pairs: map[string]interface{}{"user_role": userRole}, - }) - return true - } - - if roleConfig.RateCounterTicker == nil { - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Rate limit ticker not found", + Message: "Rate limit role or ticker not found", Pairs: map[string]interface{}{"user_role": userRole}, }) return true @@ -122,23 +119,29 @@ func rateLimitedRequest(userID string, userRole string) (shouldAllow bool) { roleConfig.RateCounterTicker.Incr(1) tickerRate := roleConfig.RateCounterTicker.GetRate() - logDetails := map[string]interface{}{ - "user_role": userRole, - "user_id": userID, - "rate": tickerRate, - "config_rate": roleConfig.Req, - "interval": roleConfig.Interval, + if cfg.LogLevel == "debug" { + cfg.Logger.Debug(&libpack_logger.LogMessage{ + Message: "Rate limit ticker", + Pairs: map[string]interface{}{ + "user_role": userRole, + "user_id": userID, + "rate": tickerRate, + "config_rate": roleConfig.Req, + "interval": roleConfig.Interval, + }, + }) } - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Rate limit ticker", - Pairs: map[string]interface{}{"log_details": logDetails}, - }) - if tickerRate > float64(roleConfig.Req) { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Rate limit exceeded", - Pairs: map[string]interface{}{"log_details": logDetails}, + Pairs: map[string]interface{}{ + "user_role": userRole, + "user_id": userID, + "rate": tickerRate, + "config_rate": roleConfig.Req, + "interval": roleConfig.Interval, + }, }) return false }