From 794cb1ddf4a026c5bec8960c60bee52c6c93592c Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Mon, 5 Feb 2024 14:24:17 +0000 Subject: [PATCH] Add the prefixed environment variables to avoid potential conflicts. --- README.md | 4 ++++ main.go | 63 ++++++++++++++++++++++++++++++++---------------- main_test.go | 68 +++++++++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 111 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 24b8dca..9616dca 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,10 @@ In this case, both proxy and websockets will be available under the `/v1/graphql ### Configuration +All the environment variables **should** be prefixed with `GMP_` to avoid conflicts with other applications. +If `GMP_` prefixed environment variable is present - it will take precedence over the non-prefixed one. +You can still use the non-prefixed environment variables in the spirit of the backward compatibility, but it's not recommended. + | Parameter | Description | Default Value | |---------------------------|------------------------------------------|----------------------------| | `MONITORING_PORT` | The port to expose the metrics endpoint | `9393` | diff --git a/main.go b/main.go index b4b65bc..51e8892 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "os" "strings" "github.com/gookit/goutil/envutil" @@ -11,47 +12,67 @@ import ( var cfg *config +// function get value from the env where the value can be anything +func getDetailsFromEnv[T any](key string, defaultValue T) T { + var result any + if _, ok := os.LookupEnv("GMP_" + key); ok { + key = "GMP_" + key + } + switch v := any(defaultValue).(type) { + case string: + result = envutil.Getenv(key, v) + case int: + result = envutil.GetInt(key, v) + case bool: + result = envutil.GetBool(key, v) + default: + result = defaultValue + } + return result.(T) +} + func parseConfig() { libpack_config.PKG_NAME = "graphql_proxy" c := config{} - c.Server.PortGraphQL = envutil.GetInt("PORT_GRAPHQL", 8080) - c.Server.PortMonitoring = envutil.GetInt("MONITORING_PORT", 9393) - c.Server.HostGraphQL = envutil.Getenv("HOST_GRAPHQL", "http://localhost/") - c.Client.JWTUserClaimPath = envutil.Getenv("JWT_USER_CLAIM_PATH", "") - c.Client.JWTRoleClaimPath = envutil.Getenv("JWT_ROLE_CLAIM_PATH", "") - c.Client.RoleFromHeader = envutil.Getenv("ROLE_FROM_HEADER", "") - c.Client.RoleRateLimit = envutil.GetBool("ROLE_RATE_LIMIT", false) - c.Cache.CacheEnable = envutil.GetBool("ENABLE_GLOBAL_CACHE", false) - c.Cache.CacheTTL = envutil.GetInt("CACHE_TTL", 60) - c.Security.BlockIntrospection = envutil.GetBool("BLOCK_SCHEMA_INTROSPECTION", false) + c.Server.PortGraphQL = getDetailsFromEnv("PORT_GRAPHQL", 8080) + c.Server.PortMonitoring = getDetailsFromEnv("MONITORING_PORT", 9393) + c.Server.HostGraphQL = getDetailsFromEnv("HOST_GRAPHQL", "http://localhost/") + c.Client.JWTUserClaimPath = getDetailsFromEnv("JWT_USER_CLAIM_PATH", "") + c.Client.JWTRoleClaimPath = getDetailsFromEnv("JWT_ROLE_CLAIM_PATH", "") + c.Client.RoleFromHeader = getDetailsFromEnv("ROLE_FROM_HEADER", "") + c.Client.RoleRateLimit = getDetailsFromEnv("ROLE_RATE_LIMIT", false) + c.Cache.CacheEnable = getDetailsFromEnv("ENABLE_GLOBAL_CACHE", false) + c.Cache.CacheTTL = getDetailsFromEnv("CACHE_TTL", 60) + c.Security.BlockIntrospection = getDetailsFromEnv("BLOCK_SCHEMA_INTROSPECTION", false) c.Security.IntrospectionAllowed = func() []string { - urls := envutil.Getenv("ALLOWED_INTROSPECTION", "") + urls := getDetailsFromEnv("ALLOWED_INTROSPECTION", "") if urls == "" { return nil } return strings.Split(urls, ",") }() c.Logger = libpack_logging.NewLogger() - c.Server.HealthcheckGraphQL = envutil.Getenv("HEALTHCHECK_GRAPHQL_URL", "") + c.Server.HealthcheckGraphQL = getDetailsFromEnv("HEALTHCHECK_GRAPHQL_URL", "") c.Client.GQLClient = graphql.NewConnection() c.Client.GQLClient.SetEndpoint(c.Server.HealthcheckGraphQL) - c.Server.AccessLog = envutil.GetBool("ENABLE_ACCESS_LOG", false) - c.Server.ReadOnlyMode = envutil.GetBool("READ_ONLY_MODE", false) + c.Server.AccessLog = getDetailsFromEnv("ENABLE_ACCESS_LOG", false) + c.Server.ReadOnlyMode = getDetailsFromEnv("READ_ONLY_MODE", false) c.Server.AllowURLs = func() []string { - urls := envutil.Getenv("ALLOWED_URLS", "") + urls := getDetailsFromEnv("ALLOWED_URLS", "") if urls == "" { return nil } return strings.Split(urls, ",") }() - c.Client.ClientTimeout = envutil.GetInt("PROXIED_CLIENT_TIMEOUT", 120) + c.Client.ClientTimeout = getDetailsFromEnv("PROXIED_CLIENT_TIMEOUT", 120) c.Client.FastProxyClient = createFasthttpClient(c.Client.ClientTimeout) - c.Server.EnableApi = envutil.GetBool("ENABLE_API", false) - c.Server.ApiPort = envutil.GetInt("API_PORT", 9090) - c.Api.BannedUsersFile = envutil.Getenv("BANNED_USERS_FILE", "/go/src/app/banned_users.json") - c.Server.PurgeOnCrawl = envutil.GetBool("PURGE_METRICS_ON_CRAWL", false) - c.Server.PurgeEvery = envutil.GetInt("PURGE_METRICS_ON_TIMER", 0) + c.Server.EnableApi = getDetailsFromEnv("ENABLE_API", false) + c.Server.ApiPort = getDetailsFromEnv("API_PORT", 9090) + c.Api.BannedUsersFile = getDetailsFromEnv("BANNED_USERS_FILE", "/go/src/app/banned_users.json") + c.Server.PurgeOnCrawl = getDetailsFromEnv("PURGE_METRICS_ON_CRAWL", false) + c.Server.PurgeEvery = getDetailsFromEnv("PURGE_METRICS_ON_TIMER", 0) cfg = &c + enableCache() // takes close to no resources, but can be used with dynamic query cache loadRatelimitConfig() enableApi() diff --git a/main_test.go b/main_test.go index 44ad6d7..2e89a78 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package main import ( + "os" "testing" assertions "github.com/stretchr/testify/assert" @@ -15,11 +16,25 @@ var ( assert *assertions.Assertions ) -func (suite *Tests) SetupTest() { - assert = assertions.New(suite.T()) +func (suite *Tests) BeforeTest(suiteName, testName string) { } -func (suite *Tests) BeforeTest(suiteName, testName string) { +func (suite *Tests) SetupTest() { + assert = assertions.New(suite.T()) + // Setup environment variables here if needed + os.Setenv("GMP_TEST_STRING", "testValue") + os.Setenv("GMP_TEST_INT", "123") + os.Setenv("GMP_TEST_BOOL", "true") + os.Setenv("NON_GMP_TEST_INT", "31337") +} + +// TearDownTest is run after each test to clean up +func (suite *Tests) TearDownTest() { + // Clean up environment variables here if needed + os.Unsetenv("GMP_TEST_STRING") + os.Unsetenv("GMP_TEST_INT") + os.Unsetenv("GMP_TEST_BOOL") + os.Unsetenv("NON_GMP_TEST_INT") } // func (suite *Tests) AfterTest(suiteName, testName string) {) @@ -30,3 +45,50 @@ func TestSuite(t *testing.T) { StartMonitoringServer() suite.Run(t, new(Tests)) } + +func (suite *Tests) Test_envVariableSetting() { + tests := []struct { + name string + envKey string + defaultValue any + expected any + }{ + { + name: "test_string", + envKey: "TEST_STRING", + defaultValue: "default", + expected: "testValue", + }, + { + name: "test_int", + envKey: "TEST_INT", + defaultValue: 0, + expected: 123, + }, + { + name: "test_bool", + envKey: "TEST_BOOL", + defaultValue: false, + expected: true, + }, + { + name: "test_non_prefixed", + envKey: "NON_GMP_TEST_INT", + defaultValue: 0, + expected: 31337, + }, + { + name: "test_non_existing", + envKey: "NON_EXISTING", + defaultValue: "default_val", + expected: "default_val", + }, + } + + for _, tt := range tests { + suite.T().Run(tt.name, func(t *testing.T) { + result := getDetailsFromEnv(tt.envKey, tt.defaultValue) + assert.Equal(tt.expected, result) + }) + } +}