mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Compare commits
52 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e74153b107 | |||
| 025107fe3e | |||
| dfb9c0771e | |||
| 1107df40e7 | |||
| bf294569eb | |||
| 482c346840 | |||
| a462e44896 | |||
| 5eff0dc866 | |||
| dfc534a400 | |||
| 061c12d0a3 | |||
| 4c4fff3613 | |||
| 0dcb44c187 | |||
| cbe773d96a | |||
| 40254888d7 | |||
| ef41870c81 | |||
| 081c32925a | |||
| 17dea67229 | |||
| 8512ad6d68 | |||
| 5aa838c669 | |||
| 6f359e5ef1 | |||
| bd18d6041c | |||
| 74c620ad51 | |||
| 7e3dc46b6e | |||
| 147aa0b169 | |||
| eecb7dfc92 | |||
| a8d65688c4 | |||
| bef4212c57 | |||
| 1fee2f9e9a | |||
| 11bc6f3e31 | |||
| 2b7af88ff9 | |||
| 01ee7c4dc8 | |||
| a6fa4d8789 | |||
| 8101fb2bf6 | |||
| 8ca669105b | |||
| 555164160d | |||
| 3fe537d38f | |||
| 31de2c63b2 | |||
| 7dd9205277 | |||
| f3598e4ab8 | |||
| 218165d365 | |||
| dc4c4824cd | |||
| 345c0c4a11 | |||
| da4f97de04 | |||
| ce916f3ca3 | |||
| 6f2cf65d49 | |||
| 78b9d611f0 | |||
| 2bb1debeb3 | |||
| 93b49b6d17 | |||
| 7a53da6080 | |||
| 66e08755c1 | |||
| d6fd3467c3 | |||
| 6196a72a8e |
+4
-1
@@ -13,13 +13,16 @@ testData:
|
||||
clientSecret: secret
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
|
||||
scopes: # If not provided, default scopes will be used (openid, email, profile)
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains: # If not provided - will rely entirely on the OIDC yes/no
|
||||
- raczylo.com
|
||||
sessionEncryptionKey: potato-secret
|
||||
allowedRolesAndGroups:
|
||||
- guest-endpoints
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
forceHTTPS: false
|
||||
logLevel: debug # debug, info, warn, error
|
||||
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
|
||||
|
||||
@@ -4,6 +4,10 @@ This middleware is supposed to replace the need for the forward-auth and oauth2-
|
||||
|
||||
Middleware has been tested with Auth0 and Logto.
|
||||
|
||||
### Traefik version compatibility
|
||||
|
||||
Code follows closely the current traefik helm chart versions. If plugin fails to load - it's time to update to the latest version of the traefik helm chart.
|
||||
|
||||
### Configuration options
|
||||
|
||||
Middleware currently supports following scenarios:
|
||||
@@ -13,6 +17,127 @@ Middleware currently supports following scenarios:
|
||||
* Using excluded URLs which do **NOT** require the OIDC authentication
|
||||
* Rate limiting requests to prevent the bruteforce attacks
|
||||
|
||||
#### How to configure...
|
||||
|
||||
* `sessionEncryptionKey` should be at least 32 bytes long.
|
||||
|
||||
##### Keeping secrets secret
|
||||
|
||||
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-open-urls
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: urn:k8s:secret:traefik-middleware-oidc:ISSUER
|
||||
clientID: urn:k8s:secret:traefik-middleware-oidc:CLIENT_ID
|
||||
clientSecret: urn:k8s:secret:traefik-middleware-oidc:SECRET
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /cool-oidc/callback
|
||||
logoutURL: /cool-oidc/logout
|
||||
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
```
|
||||
|
||||
##### Excluded URLs with open access
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-open-urls
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: xxx
|
||||
clientID: yyy
|
||||
clientSecret: zzz
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /cool-oidc/callback
|
||||
logoutURL: /cool-oidc/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
```
|
||||
|
||||
|
||||
##### Allowed email domains
|
||||
|
||||
Assuming that your OIDC provider allows anyone to log in, you may want to limit the access to people using emains in specific domain.
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-only-my-users
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: xxx
|
||||
clientID: yyy
|
||||
clientSecret: zzz
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /new-oidc/callback
|
||||
logoutURL: /new-oidc/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains:
|
||||
- raczylo.com
|
||||
```
|
||||
|
||||
|
||||
##### Allowed groups and roles
|
||||
|
||||
In case of multiple roles / groups and access separation for various endpoints you will need to create multiple traefik middlewares.
|
||||
Following example allows access for users who have additional role `guest-endpoints` assigned.
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-guest-endpoints
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: xxx
|
||||
clientID: yyy
|
||||
clientSecret: zzz
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /my-oidc/callback
|
||||
logoutURL: /my-oidc/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # This line queries the OIDC provider for roles
|
||||
forceHTTPS: true
|
||||
allowedRolesAndGroups:
|
||||
- guest-endpoints # This line specifies the roles or groups allowed to access content
|
||||
allowedUserDomains:
|
||||
- raczylo.com
|
||||
```
|
||||
|
||||
|
||||
#### Docker compose example
|
||||
|
||||
`docker-compose.yaml`
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheItem represents an item stored in the cache with its associated metadata.
|
||||
type CacheItem struct {
|
||||
// Value is the cached data of any type
|
||||
Value interface{}
|
||||
|
||||
// ExpiresAt is the timestamp when this item should be considered expired
|
||||
// and removed from the cache during cleanup operations
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
|
||||
// It uses a read-write mutex to ensure safe concurrent access to the cached items.
|
||||
type Cache struct {
|
||||
// items stores the cached data with string keys
|
||||
items map[string]CacheItem
|
||||
|
||||
// mutex protects concurrent access to the items map
|
||||
// Use RLock/RUnlock for reads and Lock/Unlock for writes
|
||||
mutex sync.RWMutex
|
||||
|
||||
// maxSize is the maximum number of items allowed in the cache
|
||||
maxSize int
|
||||
|
||||
// accessList maintains the order of item access for eviction
|
||||
accessList []string
|
||||
}
|
||||
|
||||
// DefaultMaxSize is the default maximum number of items in the cache
|
||||
const DefaultMaxSize = 1000
|
||||
|
||||
// NewCache creates a new empty cache instance.
|
||||
// The cache is immediately ready for use and is thread-safe.
|
||||
func NewCache() *Cache {
|
||||
return &Cache{
|
||||
items: make(map[string]CacheItem),
|
||||
maxSize: DefaultMaxSize,
|
||||
accessList: make([]string, 0, DefaultMaxSize),
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds or updates an item in the cache with the specified expiration duration.
|
||||
// Parameters:
|
||||
// - key: Unique identifier for the cached item
|
||||
// - value: The data to cache (can be of any type)
|
||||
// - expiration: How long the item should remain in the cache
|
||||
//
|
||||
// Thread-safe: Uses write locking to ensure safe concurrent access.
|
||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// If key exists, update it
|
||||
if _, exists := c.items[key]; exists {
|
||||
c.items[key] = CacheItem{
|
||||
Value: value,
|
||||
ExpiresAt: time.Now().Add(expiration),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// If cache is full, remove oldest item
|
||||
if len(c.items) >= c.maxSize {
|
||||
c.evictOldest()
|
||||
}
|
||||
|
||||
// Add new item
|
||||
c.items[key] = CacheItem{
|
||||
Value: value,
|
||||
ExpiresAt: time.Now().Add(expiration),
|
||||
}
|
||||
c.accessList = append(c.accessList, key)
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache if it exists and hasn't expired.
|
||||
// Parameters:
|
||||
// - key: The identifier of the item to retrieve
|
||||
//
|
||||
// Returns:
|
||||
// - value: The cached data (nil if not found or expired)
|
||||
// - found: true if the item was found and is valid, false otherwise
|
||||
//
|
||||
// Thread-safe: Uses read locking to ensure safe concurrent access.
|
||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
c.mutex.RLock()
|
||||
item, found := c.items[key]
|
||||
c.mutex.RUnlock()
|
||||
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if time.Now().After(item.ExpiresAt) {
|
||||
c.mutex.Lock()
|
||||
c.removeItem(key)
|
||||
c.mutex.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Update access order
|
||||
c.mutex.Lock()
|
||||
c.updateAccessOrder(key)
|
||||
c.mutex.Unlock()
|
||||
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache if it exists.
|
||||
// If the item doesn't exist, this operation is a no-op.
|
||||
// Thread-safe: Uses write locking to ensure safe concurrent access.
|
||||
func (c *Cache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
delete(c.items, key)
|
||||
}
|
||||
|
||||
// Cleanup removes all expired items from the cache.
|
||||
// This should be called periodically to prevent memory leaks from
|
||||
// expired items that haven't been accessed (and thus not removed during Get operations).
|
||||
// Thread-safe: Uses write locking to ensure safe concurrent access.
|
||||
func (c *Cache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
var newAccessList []string
|
||||
|
||||
for _, key := range c.accessList {
|
||||
if item, exists := c.items[key]; exists && !now.After(item.ExpiresAt) {
|
||||
newAccessList = append(newAccessList, key)
|
||||
} else {
|
||||
delete(c.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
c.accessList = newAccessList
|
||||
}
|
||||
|
||||
// evictOldest removes the least recently used item from the cache
|
||||
func (c *Cache) evictOldest() {
|
||||
if len(c.accessList) > 0 {
|
||||
oldest := c.accessList[0]
|
||||
c.removeItem(oldest)
|
||||
}
|
||||
}
|
||||
|
||||
// removeItem removes an item from both the cache and access list
|
||||
func (c *Cache) removeItem(key string) {
|
||||
delete(c.items, key)
|
||||
for i, k := range c.accessList {
|
||||
if k == key {
|
||||
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateAccessOrder moves the accessed key to the end of the access list
|
||||
func (c *Cache) updateAccessOrder(key string) {
|
||||
for i, k := range c.accessList {
|
||||
if k == key {
|
||||
c.accessList = append(append(c.accessList[:i], c.accessList[i+1:]...), key)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
+306
@@ -0,0 +1,306 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
t.Run("Basic Set and Get", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Test Set
|
||||
cache.Set(key, value, expiration)
|
||||
|
||||
// Test Get
|
||||
got, found := cache.Get(key)
|
||||
if !found {
|
||||
t.Error("Expected to find key in cache")
|
||||
}
|
||||
if got != value {
|
||||
t.Errorf("Expected value %v, got %v", value, got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Expiration", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
expiration := 10 * time.Millisecond
|
||||
|
||||
// Set with short expiration
|
||||
cache.Set(key, value, expiration)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Should not find expired key
|
||||
_, found := cache.Get(key)
|
||||
if found {
|
||||
t.Error("Expected key to be expired")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Set and then delete
|
||||
cache.Set(key, value, expiration)
|
||||
cache.Delete(key)
|
||||
|
||||
// Should not find deleted key
|
||||
_, found := cache.Get(key)
|
||||
if found {
|
||||
t.Error("Expected key to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Cleanup", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
// Add multiple items with different expirations
|
||||
cache.Set("expired1", "value1", 10*time.Millisecond)
|
||||
cache.Set("expired2", "value2", 10*time.Millisecond)
|
||||
cache.Set("valid", "value3", 1*time.Second)
|
||||
|
||||
// Wait for some items to expire
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Run cleanup
|
||||
cache.Cleanup()
|
||||
|
||||
// Check expired items are removed
|
||||
_, found1 := cache.Get("expired1")
|
||||
_, found2 := cache.Get("expired2")
|
||||
_, found3 := cache.Get("valid")
|
||||
|
||||
if found1 {
|
||||
t.Error("Expected expired1 to be cleaned up")
|
||||
}
|
||||
if found2 {
|
||||
t.Error("Expected expired2 to be cleaned up")
|
||||
}
|
||||
if !found3 {
|
||||
t.Error("Expected valid item to remain in cache")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Concurrent Access", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
done := make(chan bool)
|
||||
|
||||
// Start multiple goroutines to access cache concurrently
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
key := "key"
|
||||
value := "value"
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Perform multiple operations
|
||||
cache.Set(key, value, expiration)
|
||||
cache.Get(key)
|
||||
cache.Delete(key)
|
||||
cache.Cleanup()
|
||||
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Zero Expiration", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
|
||||
// Set with zero expiration
|
||||
cache.Set(key, value, 0)
|
||||
|
||||
// Should not find the key
|
||||
_, found := cache.Get(key)
|
||||
if found {
|
||||
t.Error("Expected key with zero expiration to be immediately expired")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Negative Expiration", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
|
||||
// Set with negative expiration
|
||||
cache.Set(key, value, -1*time.Second)
|
||||
|
||||
// Should not find the key
|
||||
_, found := cache.Get(key)
|
||||
if found {
|
||||
t.Error("Expected key with negative expiration to be immediately expired")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Update Existing Key", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
key := "test-key"
|
||||
value1 := "value1"
|
||||
value2 := "value2"
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Set initial value
|
||||
cache.Set(key, value1, expiration)
|
||||
|
||||
// Update value
|
||||
cache.Set(key, value2, expiration)
|
||||
|
||||
// Check updated value
|
||||
got, found := cache.Get(key)
|
||||
if !found {
|
||||
t.Error("Expected to find key in cache")
|
||||
}
|
||||
if got != value2 {
|
||||
t.Errorf("Expected updated value %v, got %v", value2, got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Different Value Types", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Test with different value types
|
||||
testCases := []struct {
|
||||
key string
|
||||
value interface{}
|
||||
}{
|
||||
{"string", "test"},
|
||||
{"int", 42},
|
||||
{"float", 3.14},
|
||||
{"bool", true},
|
||||
{"slice", []string{"a", "b", "c"}},
|
||||
{"map", map[string]int{"a": 1, "b": 2}},
|
||||
{"struct", struct{ Name string }{"test"}},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.key, func(t *testing.T) {
|
||||
cache.Set(tc.key, tc.value, expiration)
|
||||
got, found := cache.Get(tc.key)
|
||||
if !found {
|
||||
t.Error("Expected to find key in cache")
|
||||
}
|
||||
// Use reflect.DeepEqual for comparing complex types like slices and maps
|
||||
if !reflect.DeepEqual(got, tc.value) {
|
||||
t.Errorf("Expected value %v, got %v", tc.value, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenCache(t *testing.T) {
|
||||
t.Run("Basic Operations", func(t *testing.T) {
|
||||
tc := NewTokenCache()
|
||||
token := "test-token"
|
||||
claims := map[string]interface{}{
|
||||
"sub": "1234567890",
|
||||
"name": "John Doe",
|
||||
"admin": true,
|
||||
}
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Test Set and Get
|
||||
tc.Set(token, claims, expiration)
|
||||
gotClaims, found := tc.Get(token)
|
||||
if !found {
|
||||
t.Error("Expected to find token in cache")
|
||||
}
|
||||
if len(gotClaims) != len(claims) {
|
||||
t.Errorf("Expected %d claims, got %d", len(claims), len(gotClaims))
|
||||
}
|
||||
for k, v := range claims {
|
||||
if gotClaims[k] != v {
|
||||
t.Errorf("Expected claim %s to be %v, got %v", k, v, gotClaims[k])
|
||||
}
|
||||
}
|
||||
|
||||
// Test Delete
|
||||
tc.Delete(token)
|
||||
_, found = tc.Get(token)
|
||||
if found {
|
||||
t.Error("Expected token to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Expiration", func(t *testing.T) {
|
||||
tc := NewTokenCache()
|
||||
token := "test-token"
|
||||
claims := map[string]interface{}{"sub": "1234567890"}
|
||||
expiration := 10 * time.Millisecond
|
||||
|
||||
// Set with short expiration
|
||||
tc.Set(token, claims, expiration)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Should not find expired token
|
||||
_, found := tc.Get(token)
|
||||
if found {
|
||||
t.Error("Expected token to be expired")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Cleanup", func(t *testing.T) {
|
||||
tc := NewTokenCache()
|
||||
|
||||
// Add multiple tokens with different expirations
|
||||
tc.Set("expired1", map[string]interface{}{"sub": "1"}, 10*time.Millisecond)
|
||||
tc.Set("expired2", map[string]interface{}{"sub": "2"}, 10*time.Millisecond)
|
||||
tc.Set("valid", map[string]interface{}{"sub": "3"}, 1*time.Second)
|
||||
|
||||
// Wait for some tokens to expire
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Run cleanup
|
||||
tc.Cleanup()
|
||||
|
||||
// Check expired tokens are removed
|
||||
_, found1 := tc.Get("expired1")
|
||||
_, found2 := tc.Get("expired2")
|
||||
_, found3 := tc.Get("valid")
|
||||
|
||||
if found1 {
|
||||
t.Error("Expected expired1 to be cleaned up")
|
||||
}
|
||||
if found2 {
|
||||
t.Error("Expected expired2 to be cleaned up")
|
||||
}
|
||||
if !found3 {
|
||||
t.Error("Expected valid token to remain in cache")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Token Prefix", func(t *testing.T) {
|
||||
tc := NewTokenCache()
|
||||
token := "test-token"
|
||||
claims := map[string]interface{}{"sub": "1234567890"}
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Set token
|
||||
tc.Set(token, claims, expiration)
|
||||
|
||||
// Verify internal storage uses prefix
|
||||
_, found := tc.cache.Get("t-" + token)
|
||||
if !found {
|
||||
t.Error("Expected to find prefixed token in underlying cache")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,19 +1,13 @@
|
||||
module github.com/lukaszraczylo/traefikoidc
|
||||
|
||||
go 1.22.2
|
||||
go 1.23
|
||||
|
||||
toolchain go1.23.1
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
golang.org/x/sync v0.7.0
|
||||
golang.org/x/time v0.5.0
|
||||
golang.org/x/time v0.7.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
require github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
@@ -8,17 +6,5 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
|
||||
+295
-143
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -15,6 +16,28 @@ import (
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// newSessionOptions creates secure session cookie options.
|
||||
// Parameters:
|
||||
// - isSecure: Whether to set the Secure flag on cookies
|
||||
//
|
||||
// Returns session options configured for security with:
|
||||
// - HttpOnly flag to prevent JavaScript access
|
||||
// - SameSite=Lax for CSRF protection
|
||||
// - Appropriate timeout and path settings
|
||||
func newSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
Secure: isSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: ConstSessionTimeout,
|
||||
Path: "/",
|
||||
}
|
||||
}
|
||||
|
||||
// generateNonce creates a cryptographically secure random nonce
|
||||
// for use in the OIDC authentication flow. The nonce is used to
|
||||
// prevent replay attacks by ensuring the token received matches
|
||||
// the authentication request.
|
||||
func generateNonce() (string, error) {
|
||||
nonceBytes := make([]byte, 32)
|
||||
_, err := rand.Read(nonceBytes)
|
||||
@@ -24,14 +47,34 @@ func generateNonce() (string, error) {
|
||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||
}
|
||||
|
||||
func buildFullURL(scheme, host, path string) string {
|
||||
if scheme == "" {
|
||||
scheme = "http"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s%s", scheme, host, path)
|
||||
// TokenResponse represents the response from the OIDC token endpoint.
|
||||
// It contains the various tokens and metadata returned after successful
|
||||
// code exchange or token refresh operations.
|
||||
type TokenResponse struct {
|
||||
// IDToken is the OIDC ID token containing user claims
|
||||
IDToken string `json:"id_token"`
|
||||
|
||||
// AccessToken is the OAuth 2.0 access token for API access
|
||||
AccessToken string `json:"access_token"`
|
||||
|
||||
// RefreshToken is the OAuth 2.0 refresh token for obtaining new tokens
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
|
||||
// ExpiresIn is the lifetime in seconds of the access token
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
|
||||
// TokenType is the type of token, typically "Bearer"
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (map[string]interface{}, error) {
|
||||
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider.
|
||||
// It supports both authorization code and refresh token grant types.
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request
|
||||
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
|
||||
// - codeOrToken: Either the authorization code or refresh token
|
||||
// - redirectURL: The callback URL for authorization code grant
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {grantType},
|
||||
"client_id": {t.clientID},
|
||||
@@ -57,150 +100,174 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResponse TokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
// getNewTokenWithRefreshToken obtains new tokens using a refresh token.
|
||||
// This is used to refresh access tokens before they expire.
|
||||
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
result, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refresh token: %w", err)
|
||||
}
|
||||
|
||||
response := &TokenResponse{
|
||||
IDToken: result["id_token"].(string),
|
||||
AccessToken: result["access_token"].(string),
|
||||
ExpiresIn: int(result["expires_in"].(float64)),
|
||||
TokenType: result["token_type"].(string),
|
||||
}
|
||||
|
||||
// The refresh token might not be returned if it hasn't changed
|
||||
if newRefreshToken, ok := result["refresh_token"].(string); ok {
|
||||
response.RefreshToken = newRefreshToken
|
||||
} else {
|
||||
response.RefreshToken = refreshToken
|
||||
}
|
||||
|
||||
return response, nil
|
||||
t.logger.Debugf("Token response: %+v", tokenResponse)
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
t.logger.Debugf("Logging out user")
|
||||
if err != nil {
|
||||
handleError(rw, "Session error", http.StatusInternalServerError, t.logger)
|
||||
// handleExpiredToken manages token expiration by clearing the session
|
||||
// and initiating a new authentication flow.
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Clear authentication data but preserve CSRF state
|
||||
session.SetAuthenticated(false)
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
|
||||
// Save the cleared session state
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save cleared session: %v", err)
|
||||
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if idToken, ok := session.Values["id_token"].(string); ok {
|
||||
err := t.RevokeTokenWithProvider(idToken)
|
||||
if err != nil {
|
||||
handleError(rw, "Failed to revoke token", http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
t.RevokeToken(idToken)
|
||||
}
|
||||
|
||||
// Clear the session
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
err = session.Save(req, rw)
|
||||
if err != nil {
|
||||
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(rw, "Logged out", http.StatusForbidden)
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
|
||||
// Clear the existing session
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
err := session.Save(req, rw)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to clear session: %v", err)
|
||||
}
|
||||
|
||||
// Initiate a new authentication flow
|
||||
t.initiateAuthentication(rw, req, session, t.redirectURL)
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) {
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
// handleCallback processes the authentication callback from the OIDC provider.
|
||||
// It validates the callback parameters, exchanges the authorization code for
|
||||
// tokens, verifies the tokens, and establishes the user's session.
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Session error: %v", err)
|
||||
t.initiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return false, ""
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
callbackState := req.URL.Query().Get("state")
|
||||
sessionState, ok := session.Values["csrf"].(string)
|
||||
if !ok || callbackState != sessionState {
|
||||
t.logger.Debug("Invalid state parameter. Session might have expired.")
|
||||
t.initiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return false, ""
|
||||
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
|
||||
// Check for errors in the callback
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate CSRF state
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
t.logger.Error("No state in callback")
|
||||
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
csrfToken := session.GetCSRF()
|
||||
if csrfToken == "" {
|
||||
t.logger.Error("CSRF token missing in session")
|
||||
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state != csrfToken {
|
||||
t.logger.Error("State parameter does not match CSRF token in session")
|
||||
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
code := req.URL.Query().Get("code")
|
||||
redirectURL := buildFullURL(t.scheme, req.Host, t.redirURLPath)
|
||||
if code == "" {
|
||||
t.logger.Error("No code in callback")
|
||||
http.Error(rw, "No code in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
oauth2Token, err := t.exchangeTokens(req.Context(), "authorization_code", code, redirectURL)
|
||||
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
|
||||
if err != nil {
|
||||
handleError(rw, "Failed to exchange token", http.StatusUnauthorized, t.logger)
|
||||
return false, ""
|
||||
t.logger.Errorf("Failed to exchange code for token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
rawIDToken, ok := oauth2Token["id_token"].(string)
|
||||
if !ok {
|
||||
handleError(rw, "No id_token field in oauth2 token", http.StatusUnauthorized, t.logger)
|
||||
return false, ""
|
||||
// Verify tokens and claims
|
||||
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := t.verifyToken(rawIDToken); err != nil {
|
||||
handleError(rw, "Failed to verify token", http.StatusUnauthorized, t.logger)
|
||||
return false, ""
|
||||
}
|
||||
|
||||
claims, err := extractClaims(rawIDToken)
|
||||
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
|
||||
if err != nil {
|
||||
handleError(rw, "Failed to extract claims", http.StatusInternalServerError, t.logger)
|
||||
return false, ""
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify nonce to prevent replay attacks
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
t.logger.Error("Nonce claim missing in id_token")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sessionNonce := session.GetNonce()
|
||||
if sessionNonce == "" {
|
||||
t.logger.Error("Nonce not found in session")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if nonceClaim != sessionNonce {
|
||||
t.logger.Error("Nonce claim does not match session nonce")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate user's email domain
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Invalid or disallowed email: %s", email)
|
||||
http.Error(rw, "Authentication failed: Invalid or disallowed email", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Update session with authentication data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail(email)
|
||||
session.SetAccessToken(tokenResponse.IDToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = rawIDToken
|
||||
session.Values["refresh_token"] = oauth2Token["refresh_token"]
|
||||
session.Values["email"] = email
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
|
||||
return false, ""
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
originalPath, ok := session.Values["incoming_path"].(string)
|
||||
if !ok {
|
||||
originalPath = "/"
|
||||
// Redirect to original path or root
|
||||
redirectPath := "/"
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
}
|
||||
delete(session.Values, "incoming_path")
|
||||
|
||||
return true, originalPath
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
// extractClaims parses a JWT token and extracts its claims.
|
||||
// It handles base64url decoding and JSON parsing of the token payload.
|
||||
func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -220,28 +287,32 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
type UsedTokens struct {
|
||||
tokens map[string]bool
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// TokenBlacklist maintains a thread-safe list of revoked tokens.
|
||||
// It stores tokens with their expiration times and automatically
|
||||
// removes expired entries during cleanup operations.
|
||||
type TokenBlacklist struct {
|
||||
// blacklist maps token IDs to their expiration times
|
||||
blacklist map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
|
||||
// mutex protects concurrent access to the blacklist
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewTokenBlacklist creates a new TokenBlacklist instance.
|
||||
func NewTokenBlacklist() *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
blacklist: make(map[string]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a token to the blacklist with an expiration time.
|
||||
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
tb.blacklist[tokenID] = expiration
|
||||
}
|
||||
|
||||
// IsBlacklisted checks if a token is in the blacklist and not expired.
|
||||
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
|
||||
tb.mutex.RLock()
|
||||
defer tb.mutex.RUnlock()
|
||||
@@ -249,6 +320,7 @@ func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
|
||||
return exists && time.Now().Before(expiration)
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the blacklist.
|
||||
func (tb *TokenBlacklist) Cleanup() {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
@@ -260,51 +332,131 @@ func (tb *TokenBlacklist) Cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// TokenCache provides a caching mechanism for validated tokens.
|
||||
// It stores token claims to avoid repeated validation of the
|
||||
// same token, improving performance for frequently used tokens.
|
||||
type TokenCache struct {
|
||||
cache map[string]*TokenInfo
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type TokenInfo struct {
|
||||
Token string
|
||||
ExpiresAt time.Time
|
||||
// cache is the underlying cache implementation
|
||||
cache *Cache
|
||||
}
|
||||
|
||||
// NewTokenCache creates a new TokenCache instance.
|
||||
func NewTokenCache() *TokenCache {
|
||||
return &TokenCache{
|
||||
cache: make(map[string]*TokenInfo),
|
||||
cache: NewCache(),
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *TokenCache) Set(token string, expiresAt time.Time) {
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt}
|
||||
// Set stores a token's claims in the cache with an expiration time.
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
}
|
||||
|
||||
func (tc *TokenCache) Get(token string) (*TokenInfo, bool) {
|
||||
tc.mutex.RLock()
|
||||
defer tc.mutex.RUnlock()
|
||||
info, exists := tc.cache[token]
|
||||
if exists && time.Now().Before(info.ExpiresAt) {
|
||||
return info, true
|
||||
// Get retrieves a token's claims from the cache.
|
||||
// Returns the claims and a boolean indicating if the token was found.
|
||||
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
token = "t-" + token
|
||||
value, found := tc.cache.Get(token)
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
return nil, false
|
||||
claims, ok := value.(map[string]interface{})
|
||||
return claims, ok
|
||||
}
|
||||
|
||||
// Delete removes a token from the cache.
|
||||
func (tc *TokenCache) Delete(token string) {
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
delete(tc.cache, token)
|
||||
token = "t-" + token
|
||||
tc.cache.Delete(token)
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the cache.
|
||||
func (tc *TokenCache) Cleanup() {
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
now := time.Now()
|
||||
for token, info := range tc.cache {
|
||||
if now.After(info.ExpiresAt) {
|
||||
delete(tc.cache, token)
|
||||
}
|
||||
}
|
||||
tc.cache.Cleanup()
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges an authorization code for tokens.
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
}
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// createStringMap creates a map from a slice of strings.
|
||||
// Used for efficient lookups in allowed domains and roles.
|
||||
func createStringMap(keys []string) map[string]struct{} {
|
||||
result := make(map[string]struct{})
|
||||
for _, key := range keys {
|
||||
result[key] = struct{}{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// handleLogout manages the OIDC logout process.
|
||||
// It clears the session and redirects either to the OIDC provider's
|
||||
// end session endpoint (if available) or to the configured post-logout URL.
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accessToken := session.GetAccessToken()
|
||||
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
host := t.determineHost(req)
|
||||
scheme := t.determineScheme(req)
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
postLogoutRedirectURI := t.postLogoutRedirectURI
|
||||
if postLogoutRedirectURI == "" {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
|
||||
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
}
|
||||
|
||||
if t.endSessionURL != "" && accessToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to build logout URL: %v", err)
|
||||
http.Error(rw, "Logout error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(rw, req, logoutURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs the OIDC end session URL with appropriate parameters.
|
||||
// Parameters:
|
||||
// - endSessionURL: The OIDC provider's end session endpoint
|
||||
// - idToken: The ID token to be invalidated
|
||||
// - postLogoutRedirectURI: Where to redirect after logout completes
|
||||
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
|
||||
u, err := url.Parse(endSessionURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse end session URL: %w", err)
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("id_token_hint", idToken)
|
||||
if postLogoutRedirectURI != "" {
|
||||
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
@@ -4,39 +4,87 @@ import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWK represents a JSON Web Key as defined in RFC 7517.
|
||||
// It contains the cryptographic key information used for token verification.
|
||||
type JWK struct {
|
||||
// Kty is the key type (e.g., "RSA", "EC")
|
||||
Kty string `json:"kty"`
|
||||
|
||||
// Kid is the unique key identifier
|
||||
Kid string `json:"kid"`
|
||||
|
||||
// Use specifies the intended use of the key (e.g., "sig" for signature)
|
||||
Use string `json:"use"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
|
||||
// N is the modulus for RSA keys
|
||||
N string `json:"n"`
|
||||
|
||||
// E is the exponent for RSA keys
|
||||
E string `json:"e"`
|
||||
|
||||
// Alg is the algorithm intended for use with the key
|
||||
Alg string `json:"alg"`
|
||||
|
||||
// Crv is the curve for EC keys (e.g., "P-256", "P-384", "P-521")
|
||||
Crv string `json:"crv"`
|
||||
X string `json:"x"`
|
||||
Y string `json:"y"`
|
||||
|
||||
// X is the x-coordinate for EC keys
|
||||
X string `json:"x"`
|
||||
|
||||
// Y is the y-coordinate for EC keys
|
||||
Y string `json:"y"`
|
||||
}
|
||||
|
||||
// JWKSet represents a set of JSON Web Keys as returned by the JWKS endpoint.
|
||||
// OIDC providers typically expose multiple keys to support key rotation.
|
||||
type JWKSet struct {
|
||||
// Keys is the array of JSON Web Keys
|
||||
Keys []JWK `json:"keys"`
|
||||
}
|
||||
|
||||
// JWKCache provides a thread-safe caching mechanism for JWK sets.
|
||||
// It caches the keys for a configurable duration to reduce load on the OIDC provider
|
||||
// while ensuring keys are refreshed periodically to handle key rotation.
|
||||
type JWKCache struct {
|
||||
jwks *JWKSet
|
||||
// jwks holds the cached set of JSON Web Keys
|
||||
jwks *JWKSet
|
||||
|
||||
// expiresAt is the timestamp when the cached keys should be refreshed
|
||||
expiresAt time.Time
|
||||
mutex sync.RWMutex
|
||||
|
||||
// mutex protects concurrent access to the cache
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// JWKCacheInterface defines the interface for JWK caching operations.
|
||||
// This interface allows for different caching implementations while
|
||||
// maintaining consistent behavior in the token verification process.
|
||||
type JWKCacheInterface interface {
|
||||
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
}
|
||||
|
||||
// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it
|
||||
// from the OIDC provider. It implements a thread-safe double-checked locking
|
||||
// pattern to prevent multiple simultaneous fetches of the same keys.
|
||||
// Parameters:
|
||||
// - jwksURL: The URL of the JWKS endpoint
|
||||
// - httpClient: The HTTP client to use for fetching keys
|
||||
//
|
||||
// Returns:
|
||||
// - The JSON Web Key Set
|
||||
// - An error if the keys cannot be retrieved or parsed
|
||||
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
c.mutex.RLock()
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
@@ -63,6 +111,15 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint.
|
||||
// It handles HTTP communication and JSON parsing of the response.
|
||||
// Parameters:
|
||||
// - jwksURL: The URL of the JWKS endpoint
|
||||
// - httpClient: The HTTP client to use for the request
|
||||
//
|
||||
// Returns:
|
||||
// - The parsed JSON Web Key Set
|
||||
// - An error if the request fails or the response is invalid
|
||||
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
resp, err := httpClient.Get(jwksURL)
|
||||
if err != nil {
|
||||
@@ -82,66 +139,68 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
return &jwks, nil
|
||||
}
|
||||
|
||||
func verifyAudience(tokenAudience, expectedAudience string) error {
|
||||
if tokenAudience != expectedAudience {
|
||||
return fmt.Errorf("invalid audience")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// jwkToPEM converts a JSON Web Key to PEM format for use with standard
|
||||
// cryptographic functions. It supports both RSA and EC keys, delegating
|
||||
// to the appropriate converter based on the key type.
|
||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||
switch jwk.Kty {
|
||||
case "RSA":
|
||||
return rsaJWKToPEM(jwk)
|
||||
case "EC":
|
||||
return ecJWKToPEM(jwk)
|
||||
default:
|
||||
converter, ok := jwkConverters[jwk.Kty]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported key type: %s", jwk.Kty)
|
||||
}
|
||||
return converter(jwk)
|
||||
}
|
||||
|
||||
type jwkToPEMConverter func(*JWK) ([]byte, error)
|
||||
|
||||
var jwkConverters = map[string]jwkToPEMConverter{
|
||||
"RSA": rsaJWKToPEM,
|
||||
"EC": ecJWKToPEM,
|
||||
}
|
||||
|
||||
// rsaJWKToPEM converts an RSA JSON Web Key to PEM format.
|
||||
// It handles base64url decoding of the modulus and exponent,
|
||||
// constructs an RSA public key, and encodes it in PEM format.
|
||||
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
|
||||
}
|
||||
e, err := base64.RawURLEncoding.DecodeString(jwk.E)
|
||||
eBytes, err := base64.RawURLEncoding.DecodeString(jwk.E)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
|
||||
}
|
||||
|
||||
publicKey := &rsa.PublicKey{
|
||||
N: new(big.Int).SetBytes(n),
|
||||
E: int(new(big.Int).SetBytes(e).Int64()),
|
||||
n := new(big.Int).SetBytes(nBytes)
|
||||
e := new(big.Int).SetBytes(eBytes)
|
||||
|
||||
pubKey := &rsa.PublicKey{
|
||||
N: n,
|
||||
E: int(e.Int64()),
|
||||
}
|
||||
|
||||
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||
return nil, fmt.Errorf("failed to marshal RSA public key: %w", err)
|
||||
}
|
||||
|
||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: publicKeyBytes,
|
||||
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: pubKeyBytes,
|
||||
})
|
||||
|
||||
return publicKeyPEM, nil
|
||||
return pubKeyPEM, nil
|
||||
}
|
||||
|
||||
// ecJWKToPEM converts an EC (Elliptic Curve) JSON Web Key to PEM format.
|
||||
// It supports the P-256, P-384, and P-521 curves as defined in the
|
||||
// OIDC specification, decoding the x and y coordinates and encoding
|
||||
// the resulting public key in PEM format.
|
||||
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWK 'x' parameter: %w", err)
|
||||
}
|
||||
|
||||
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
|
||||
yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
|
||||
}
|
||||
@@ -158,21 +217,21 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
|
||||
}
|
||||
|
||||
publicKey := &ecdsa.PublicKey{
|
||||
pubKey := &ecdsa.PublicKey{
|
||||
Curve: curve,
|
||||
X: new(big.Int).SetBytes(x),
|
||||
Y: new(big.Int).SetBytes(y),
|
||||
X: new(big.Int).SetBytes(xBytes),
|
||||
Y: new(big.Int).SetBytes(yBytes),
|
||||
}
|
||||
|
||||
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||
return nil, fmt.Errorf("failed to marshal EC public key: %w", err)
|
||||
}
|
||||
|
||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: publicKeyBytes,
|
||||
Bytes: pubKeyBytes,
|
||||
})
|
||||
|
||||
return publicKeyPEM, nil
|
||||
return pubKeyPEM, nil
|
||||
}
|
||||
|
||||
@@ -4,86 +4,364 @@ import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWT represents a JSON Web Token as defined in RFC 7519.
|
||||
// It contains the three parts of a JWT: header, claims (payload),
|
||||
// and signature, along with the original token string.
|
||||
type JWT struct {
|
||||
Header map[string]interface{}
|
||||
Claims map[string]interface{}
|
||||
Signature string
|
||||
// Header contains the token metadata (algorithm, key ID, etc.)
|
||||
Header map[string]interface{}
|
||||
|
||||
// Claims contains the token claims (subject, expiration, etc.)
|
||||
Claims map[string]interface{}
|
||||
|
||||
// Signature contains the raw signature bytes
|
||||
Signature []byte
|
||||
|
||||
// Token is the original JWT string
|
||||
Token string
|
||||
}
|
||||
|
||||
func parseJWT(token string) (*JWT, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
// parseJWT parses a JWT token string into a JWT struct.
|
||||
// It validates the token format and decodes the three parts
|
||||
// (header, claims, signature) using base64url decoding.
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT token string
|
||||
//
|
||||
// Returns:
|
||||
// - A parsed JWT struct
|
||||
// - An error if the token format is invalid or parsing fails
|
||||
func parseJWT(tokenString string) (*JWT, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid token format")
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
header, err := decodeSegment(parts[0])
|
||||
jwt := &JWT{
|
||||
Token: tokenString,
|
||||
}
|
||||
|
||||
// Decode and unmarshal the header
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode header: %w", err)
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(headerBytes, &jwt.Header); err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
|
||||
}
|
||||
|
||||
claims, err := decodeSegment(parts[1])
|
||||
// Decode and unmarshal the claims
|
||||
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode claims: %w", err)
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(claimsBytes, &jwt.Claims); err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
|
||||
}
|
||||
|
||||
return &JWT{
|
||||
Header: header,
|
||||
Claims: claims,
|
||||
Signature: parts[2],
|
||||
}, nil
|
||||
// Decode the signature
|
||||
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
|
||||
}
|
||||
jwt.Signature = signatureBytes
|
||||
|
||||
return jwt, nil
|
||||
}
|
||||
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
// It checks:
|
||||
// - issuer (iss) matches the expected issuer URL
|
||||
// - audience (aud) includes the client ID
|
||||
// - expiration time (exp) is in the future (with clock skew tolerance)
|
||||
// - issued at time (iat) is in the past (with clock skew tolerance)
|
||||
// - not before time (nbf) is in the past (with clock skew tolerance)
|
||||
// - subject (sub) is present and not empty
|
||||
// - algorithm matches expected value to prevent algorithm switching attacks
|
||||
//
|
||||
// Returns an error if any validation fails.
|
||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
// Debug logging of validation parameters
|
||||
fmt.Printf("Validating token against:\nIssuer: %s\nClient ID: %s\n", issuerURL, clientID)
|
||||
// Debug logging of token header
|
||||
fmt.Printf("Token header: %+v\n", j.Header)
|
||||
|
||||
// Validate algorithm to prevent algorithm switching attacks
|
||||
alg, ok := j.Header["alg"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'alg' header")
|
||||
}
|
||||
// List of supported algorithms - should match those in verifySignature
|
||||
supportedAlgs := map[string]bool{
|
||||
"RS256": true, "RS384": true, "RS512": true,
|
||||
"PS256": true, "PS384": true, "PS512": true,
|
||||
"ES256": true, "ES384": true, "ES512": true,
|
||||
}
|
||||
if !supportedAlgs[alg] {
|
||||
return fmt.Errorf("unsupported algorithm: %s", alg)
|
||||
}
|
||||
|
||||
claims := j.Claims
|
||||
|
||||
if err := verifyIssuer(claims["iss"].(string), issuerURL); err != nil {
|
||||
// Debug logging of all claims
|
||||
fmt.Printf("Token claims: %+v\n", claims)
|
||||
|
||||
iss, ok := claims["iss"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'iss' claim")
|
||||
}
|
||||
if err := verifyIssuer(iss, issuerURL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := verifyAudience(claims["aud"].(string), clientID); err != nil {
|
||||
aud, ok := claims["aud"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'aud' claim")
|
||||
}
|
||||
if err := verifyAudience(aud, clientID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := verifyExpiration(claims["exp"].(float64)); err != nil {
|
||||
exp, ok := claims["exp"].(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing or invalid 'exp' claim")
|
||||
}
|
||||
if err := verifyExpiration(exp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := verifyIssuedAt(claims["iat"].(float64)); err != nil {
|
||||
iat, ok := claims["iat"].(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing or invalid 'iat' claim")
|
||||
}
|
||||
if err := verifyIssuedAt(iat); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate nbf (not before) claim if present
|
||||
if nbf, ok := claims["nbf"].(float64); ok {
|
||||
if err := verifyNotBefore(nbf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate jti (JWT ID) claim if present
|
||||
if jti, ok := claims["jti"].(string); ok {
|
||||
// Could add replay detection here if needed
|
||||
_ = jti
|
||||
}
|
||||
|
||||
sub, ok := claims["sub"].(string)
|
||||
if !ok || sub == "" {
|
||||
return fmt.Errorf("missing or empty 'sub' claim")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyAudience validates the token's audience claim.
|
||||
// The audience can be either a single string or an array of strings.
|
||||
// For array audiences, the expected audience must match any one value.
|
||||
// Parameters:
|
||||
// - tokenAudience: The audience claim from the token
|
||||
// - expectedAudience: The expected audience value
|
||||
//
|
||||
// Returns an error if validation fails.
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
// Debug logging
|
||||
fmt.Printf("Verifying audience:\nToken aud: %+v\nExpected: %s\n",
|
||||
tokenAudience, expectedAudience)
|
||||
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
if aud != expectedAudience {
|
||||
return fmt.Errorf("invalid audience")
|
||||
}
|
||||
case []interface{}:
|
||||
found := false
|
||||
for _, v := range aud {
|
||||
if str, ok := v.(string); ok && str == expectedAudience {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("invalid audience")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("invalid 'aud' claim type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuer validates the token's issuer claim.
|
||||
// The issuer URL must exactly match the expected issuer.
|
||||
// Parameters:
|
||||
// - tokenIssuer: The issuer claim from the token
|
||||
// - expectedIssuer: The expected issuer URL
|
||||
//
|
||||
// Returns an error if validation fails.
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
// Debug logging
|
||||
fmt.Printf("Verifying issuer:\nToken iss: %s\nExpected: %s\n",
|
||||
tokenIssuer, expectedIssuer)
|
||||
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer (token: %s, expected: %s)",
|
||||
tokenIssuer, expectedIssuer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clock skew tolerance for time-based validations
|
||||
const clockSkewTolerance = 2 * time.Minute
|
||||
|
||||
// verifyExpiration checks if the token's expiration time has passed.
|
||||
// The expiration time is compared against the current time with clock skew tolerance.
|
||||
// Parameters:
|
||||
// - expiration: The expiration timestamp from the token
|
||||
//
|
||||
// Returns an error if the token has expired.
|
||||
func verifyExpiration(expiration float64) error {
|
||||
expirationTime := time.Unix(int64(expiration), 0)
|
||||
if time.Now().After(expirationTime) {
|
||||
return fmt.Errorf("token has expired")
|
||||
// Truncate current time to seconds for consistent comparison
|
||||
now := time.Now().Truncate(time.Second)
|
||||
skewedNow := now.Add(clockSkewTolerance)
|
||||
|
||||
// Debug logging
|
||||
fmt.Printf("Token exp: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
||||
expirationTime.UTC(),
|
||||
now.UTC(),
|
||||
skewedNow.UTC(),
|
||||
clockSkewTolerance)
|
||||
|
||||
// Allow tokens that expire exactly now
|
||||
if expirationTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if skewedNow.After(expirationTime) {
|
||||
return fmt.Errorf("token has expired (exp: %v, now: %v)",
|
||||
expirationTime.UTC(), now.UTC())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte, alg string) error {
|
||||
// verifyIssuedAt validates the token's issued-at time.
|
||||
// Ensures the token wasn't issued in the future, accounting for clock skew.
|
||||
// Parameters:
|
||||
// - issuedAt: The issued-at timestamp from the token
|
||||
//
|
||||
// Returns an error if the token was issued in the future.
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
||||
// Truncate current time to seconds for consistent comparison
|
||||
now := time.Now().Truncate(time.Second)
|
||||
skewedNow := now.Add(-clockSkewTolerance)
|
||||
|
||||
// Debug logging
|
||||
fmt.Printf("Token iat: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
||||
issuedAtTime.UTC(),
|
||||
now.UTC(),
|
||||
skewedNow.UTC(),
|
||||
clockSkewTolerance)
|
||||
|
||||
// Allow tokens issued in the same second as current time
|
||||
if issuedAtTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if skewedNow.Before(issuedAtTime) {
|
||||
return fmt.Errorf("token used before issued (iat: %v, now: %v)",
|
||||
issuedAtTime.UTC(), now.UTC())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyNotBefore validates the token's not-before time if present.
|
||||
// Ensures the token is not used before its valid time period, accounting for clock skew.
|
||||
// Parameters:
|
||||
// - notBefore: The not-before timestamp from the token
|
||||
//
|
||||
// Returns an error if the token is not yet valid.
|
||||
func verifyNotBefore(notBefore float64) error {
|
||||
notBeforeTime := time.Unix(int64(notBefore), 0)
|
||||
// Truncate current time to seconds for consistent comparison
|
||||
now := time.Now().Truncate(time.Second)
|
||||
skewedNow := now.Add(-clockSkewTolerance)
|
||||
|
||||
// Debug logging
|
||||
fmt.Printf("Token nbf: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
||||
notBeforeTime.UTC(),
|
||||
now.UTC(),
|
||||
skewedNow.UTC(),
|
||||
clockSkewTolerance)
|
||||
|
||||
// Allow tokens that become valid exactly now
|
||||
if notBeforeTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if skewedNow.Before(notBeforeTime) {
|
||||
return fmt.Errorf("token not yet valid (nbf: %v, now: %v)",
|
||||
notBeforeTime.UTC(), now.UTC())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifySignature validates the token's cryptographic signature.
|
||||
// Supports multiple signature algorithms:
|
||||
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
|
||||
// - RSA-PSS: PS256, PS384, PS512
|
||||
// - ECDSA: ES256, ES384, ES512
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The complete JWT token string
|
||||
// - publicKeyPEM: The PEM-encoded public key for verification
|
||||
// - alg: The signature algorithm identifier
|
||||
//
|
||||
// Returns an error if signature verification fails.
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
// Debug logging
|
||||
fmt.Printf("Verifying signature with algorithm: %s\n", alg)
|
||||
|
||||
// Split the token into its three parts
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
signedContent := parts[0] + "." + parts[1]
|
||||
|
||||
// Decode the signature from the token
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
|
||||
// Decode the PEM-encoded public key
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
|
||||
// Parse the public key
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
|
||||
// Determine the hash function to use based on the algorithm
|
||||
var hashFunc crypto.Hash
|
||||
|
||||
switch alg {
|
||||
@@ -97,60 +375,42 @@ func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte
|
||||
return fmt.Errorf("unsupported algorithm: %s", alg)
|
||||
}
|
||||
|
||||
// Hash the signed content
|
||||
h := hashFunc.New()
|
||||
h.Write([]byte(signedContent))
|
||||
hashed := h.Sum(nil)
|
||||
|
||||
switch pub := pubKey.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "ES") {
|
||||
// ECDSA signature handling
|
||||
keyBytes := (pub.Params().BitSize + 7) / 8
|
||||
if len(signature) != 2*keyBytes {
|
||||
return fmt.Errorf("invalid signature length: expected %d bytes, got %d bytes", 2*keyBytes, len(signature))
|
||||
}
|
||||
r := new(big.Int).SetBytes(signature[:keyBytes])
|
||||
s := new(big.Int).SetBytes(signature[keyBytes:])
|
||||
|
||||
if ecdsa.Verify(pub, hashed, r, s) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("invalid ECDSA signature")
|
||||
}
|
||||
return fmt.Errorf("algorithm %s is not compatible with ECDSA public key", alg)
|
||||
// Verify the signature based on the key type and algorithm
|
||||
switch pubKey := pubKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "RS") {
|
||||
err := rsa.VerifyPKCS1v15(pub, hashFunc, hashed, signature)
|
||||
if err != nil {
|
||||
return fmt.Errorf("RSA signature verification failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
// RSA PKCS#1 v1.5 signature
|
||||
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
|
||||
} else if strings.HasPrefix(alg, "PS") {
|
||||
// RSA PSS signature
|
||||
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
|
||||
} else {
|
||||
return fmt.Errorf("unexpected key type for algorithm %s", alg)
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "ES") {
|
||||
// ECDSA signature
|
||||
var r, s big.Int
|
||||
sigLen := len(signature)
|
||||
if sigLen%2 != 0 {
|
||||
return fmt.Errorf("invalid ECDSA signature length")
|
||||
}
|
||||
r.SetBytes(signature[:sigLen/2])
|
||||
s.SetBytes(signature[sigLen/2:])
|
||||
if ecdsa.Verify(pubKey, hashed, &r, &s) {
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("invalid ECDSA signature")
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("unexpected key type for algorithm %s", alg)
|
||||
}
|
||||
return fmt.Errorf("algorithm %s is not compatible with RSA public key", alg)
|
||||
default:
|
||||
return fmt.Errorf("unsupported public key type: %T", pub)
|
||||
return fmt.Errorf("unsupported public key type: %T", pubKey)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
||||
if time.Now().Before(issuedAtTime) {
|
||||
return fmt.Errorf("token used before issued")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeSegment(seg string) (map[string]interface{}, error) {
|
||||
data, err := base64.RawURLEncoding.DecodeString(seg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode segment: %w", err)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(data, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal segment: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -2,122 +2,149 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"runtime"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const ConstSessionTimeout = 86400
|
||||
const ConstSessionTimeout = 86400 // Session timeout in seconds
|
||||
|
||||
// TokenVerifier interface for token verification
|
||||
type TokenVerifier interface {
|
||||
VerifyToken(token string) error
|
||||
}
|
||||
|
||||
// JWTVerifier interface for JWT verification
|
||||
type JWTVerifier interface {
|
||||
VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
}
|
||||
|
||||
// TraefikOidc is the main struct for the OIDC middleware
|
||||
type TraefikOidc struct {
|
||||
next http.Handler
|
||||
name string
|
||||
store sessions.Store
|
||||
redirURLPath string
|
||||
logoutURLPath string
|
||||
issuerURL string
|
||||
revocationURL string
|
||||
jwkCache *JWKCache
|
||||
tokenBlacklist *TokenBlacklist
|
||||
jwksURL string
|
||||
clientID string
|
||||
clientSecret string
|
||||
authURL string
|
||||
tokenURL string
|
||||
scopes []string
|
||||
limiter *rate.Limiter
|
||||
forceHTTPS bool
|
||||
scheme string
|
||||
tokenCache *TokenCache
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
redirectURL string
|
||||
tokenVerifier TokenVerifier
|
||||
jwtVerifier JWTVerifier
|
||||
excludedURLs map[string]struct{}
|
||||
allowedUserDomains map[string]struct{}
|
||||
next http.Handler
|
||||
name string
|
||||
redirURLPath string
|
||||
logoutURLPath string
|
||||
issuerURL string
|
||||
revocationURL string
|
||||
jwkCache JWKCacheInterface
|
||||
tokenBlacklist *TokenBlacklist
|
||||
jwksURL string
|
||||
clientID string
|
||||
clientSecret string
|
||||
authURL string
|
||||
tokenURL string
|
||||
scopes []string
|
||||
limiter *rate.Limiter
|
||||
forceHTTPS bool
|
||||
scheme string
|
||||
tokenCache *TokenCache
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
tokenVerifier TokenVerifier
|
||||
jwtVerifier JWTVerifier
|
||||
excludedURLs map[string]struct{}
|
||||
allowedUserDomains map[string]struct{}
|
||||
allowedRolesAndGroups map[string]struct{}
|
||||
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
|
||||
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
initComplete chan struct{}
|
||||
endSessionURL string
|
||||
baseURL string
|
||||
postLogoutRedirectURI string
|
||||
sessionManager *SessionManager
|
||||
}
|
||||
|
||||
// ProviderMetadata holds OIDC provider metadata
|
||||
type ProviderMetadata struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
EndSessionURL string `json:"end_session_endpoint"`
|
||||
}
|
||||
|
||||
// defaultExcludedURLs are the paths that are excluded from authentication
|
||||
var defaultExcludedURLs = map[string]struct{}{
|
||||
"/favicon": {},
|
||||
}
|
||||
|
||||
var newTicker = time.NewTicker
|
||||
|
||||
// VerifyToken verifies the provided JWT token
|
||||
func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
t.logger.Debugf("Verifying token")
|
||||
|
||||
// Rate limiting
|
||||
if !t.limiter.Allow() {
|
||||
return fmt.Errorf("rate limit exceeded")
|
||||
}
|
||||
|
||||
// Check if token is blacklisted
|
||||
if t.tokenBlacklist.IsBlacklisted(token) {
|
||||
return fmt.Errorf("token is blacklisted")
|
||||
}
|
||||
|
||||
// Check if token is cached
|
||||
if _, exists := t.tokenCache.Get(token); exists {
|
||||
t.logger.Debugf("Token is valid and cached")
|
||||
return nil // Token is valid and cached
|
||||
}
|
||||
|
||||
// Parse the JWT
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
|
||||
// Verify JWT signature and claims
|
||||
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cache the token until it expires
|
||||
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
|
||||
t.tokenCache.Set(token, expirationTime)
|
||||
now := time.Now()
|
||||
duration := expirationTime.Sub(now)
|
||||
t.tokenCache.Set(token, jwt.Claims, duration)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
|
||||
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
t.logger.Debugf("Verifying JWT. Header: %+v", jwt.Header)
|
||||
t.logger.Debugf("Verifying JWT signature and claims")
|
||||
|
||||
// Get JWKS
|
||||
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
|
||||
// Retrieve key ID and algorithm from JWT header
|
||||
kid, ok := jwt.Header["kid"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing key ID in token header")
|
||||
}
|
||||
t.logger.Debugf("Token kid: %s", kid)
|
||||
|
||||
alg, ok := jwt.Header["alg"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing algorithm in token header")
|
||||
}
|
||||
t.logger.Debugf("Token alg: %s", alg)
|
||||
|
||||
// Find the matching key in JWKS
|
||||
var matchingKey *JWK
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kid == kid {
|
||||
@@ -125,85 +152,88 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matchingKey == nil {
|
||||
return fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
t.logger.Debugf("Matching key found. Type: %s, Algorithm: %s", matchingKey.Kty, matchingKey.Alg)
|
||||
|
||||
// Convert JWK to PEM format
|
||||
publicKeyPEM, err := jwkToPEM(matchingKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
|
||||
}
|
||||
t.logger.Debugf("Public key PEM generated. Length: %d", len(publicKeyPEM))
|
||||
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
signedContent := parts[0] + "." + parts[1]
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
|
||||
if err := verifySignature(signedContent, signature, publicKeyPEM, alg); err != nil {
|
||||
t.logger.Errorf("Signature verification failed: %v", err)
|
||||
// Verify the signature
|
||||
if err := verifySignature(token, publicKeyPEM, alg); err != nil {
|
||||
return fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
t.logger.Debug("Signature verified successfully")
|
||||
|
||||
// Verify standard claims
|
||||
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
|
||||
return fmt.Errorf("standard claim verification failed: %w", err)
|
||||
}
|
||||
t.logger.Debug("Standard claims verified successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// New creates a new instance of the OIDC middleware
|
||||
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
||||
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
|
||||
store.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
MaxAge: ConstSessionTimeout,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
if config == nil {
|
||||
config = CreateConfig()
|
||||
}
|
||||
|
||||
metadata, err := discoverProviderMetadata(config.ProviderURL, http.Client{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to discover provider metadata: %w", err)
|
||||
// Generate default session encryption key if not provided
|
||||
if config.SessionEncryptionKey == "" {
|
||||
// Generate a fixed key for Traefik Hub testing
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
}
|
||||
|
||||
// Initialize logger
|
||||
logger := NewLogger(config.LogLevel)
|
||||
|
||||
// Ensure key meets minimum length requirement
|
||||
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
if runtime.Compiler == "yaegi" {
|
||||
// Set default encryption key for Yaegi (Traefik Plugin Analyzer)
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
logger.Infof("Session encryption key is too short; using default key for analyzer")
|
||||
} else {
|
||||
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
||||
}
|
||||
}
|
||||
|
||||
// Setup HTTP client
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Timeout: 15 * time.Second, // Reduced timeout
|
||||
KeepAlive: 15 * time.Second, // Reduced keepalive
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
|
||||
ExpectContinueTimeout: 0,
|
||||
MaxIdleConns: 30, // Reduced from 100
|
||||
MaxIdleConnsPerHost: 10, // Reduced from 100
|
||||
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
|
||||
DisableKeepAlives: false, // Enable connection reuse
|
||||
MaxConnsPerHost: 50, // Limit max connections
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: time.Second * 30,
|
||||
Transport: transport,
|
||||
var httpClient *http.Client
|
||||
if config.HTTPClient != nil {
|
||||
httpClient = config.HTTPClient
|
||||
} else {
|
||||
httpClient = &http.Client{
|
||||
Timeout: time.Second * 15, // Reduced timeout
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
t := &TraefikOidc{
|
||||
next: next,
|
||||
name: name,
|
||||
store: store,
|
||||
redirURLPath: config.CallbackURL,
|
||||
logoutURLPath: func() string {
|
||||
if config.LogoutURL == "" {
|
||||
@@ -211,38 +241,37 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
}
|
||||
return config.LogoutURL
|
||||
}(),
|
||||
issuerURL: metadata.Issuer,
|
||||
revocationURL: metadata.RevokeURL,
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
jwkCache: &JWKCache{},
|
||||
jwksURL: metadata.JWKSURL,
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
authURL: metadata.AuthURL,
|
||||
tokenURL: metadata.TokenURL,
|
||||
scopes: config.Scopes,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
||||
tokenCache: NewTokenCache(),
|
||||
httpClient: httpClient,
|
||||
logger: NewLogger(config.LogLevel),
|
||||
excludedURLs: func() map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
for _, url := range config.ExcludedURLs {
|
||||
m[url] = struct{}{}
|
||||
postLogoutRedirectURI: func() string {
|
||||
if config.PostLogoutRedirectURI == "" {
|
||||
return "/"
|
||||
}
|
||||
return m
|
||||
}(),
|
||||
redirectURL: "",
|
||||
allowedUserDomains: func() map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
for _, domain := range config.AllowedUserDomains {
|
||||
m[domain] = struct{}{}
|
||||
}
|
||||
return m
|
||||
return config.PostLogoutRedirectURI
|
||||
}(),
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
jwkCache: &JWKCache{},
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
scopes: config.Scopes,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
||||
tokenCache: NewTokenCache(),
|
||||
httpClient: httpClient,
|
||||
excludedURLs: createStringMap(config.ExcludedURLs),
|
||||
allowedUserDomains: createStringMap(config.AllowedUserDomains),
|
||||
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
|
||||
initComplete: make(chan struct{}),
|
||||
}
|
||||
// add defaultExcludedURLs to excludedURLs
|
||||
// Assign the initialized logger
|
||||
t.logger = logger
|
||||
|
||||
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
||||
t.extractClaimsFunc = extractClaims
|
||||
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
// Add default excluded URLs
|
||||
for k, v := range defaultExcludedURLs {
|
||||
t.excludedURLs[k] = v
|
||||
}
|
||||
@@ -250,11 +279,93 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
t.tokenVerifier = t
|
||||
t.jwtVerifier = t
|
||||
t.startTokenCleanup()
|
||||
go t.initializeMetadata(config.ProviderURL)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func discoverProviderMetadata(providerURL string, httpClient http.Client) (*ProviderMetadata, error) {
|
||||
// initializeMetadata discovers and initializes the provider metadata
|
||||
func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
t.logger.Debug("Starting provider metadata discovery")
|
||||
|
||||
// Keep retrying until successful
|
||||
backoff := time.Second
|
||||
maxBackoff := 30 * time.Second
|
||||
for {
|
||||
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
|
||||
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to discover provider metadata: %v, retrying in %v", err, backoff)
|
||||
time.Sleep(backoff)
|
||||
|
||||
// Exponential backoff with max
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if metadata != nil {
|
||||
t.logger.Debug("Successfully initialized provider metadata")
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
t.authURL = metadata.AuthURL
|
||||
t.tokenURL = metadata.TokenURL
|
||||
t.issuerURL = metadata.Issuer
|
||||
t.revocationURL = metadata.RevokeURL
|
||||
t.endSessionURL = metadata.EndSessionURL
|
||||
|
||||
// Only close channel on success
|
||||
close(t.initComplete)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Error("Received nil metadata, retrying")
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
}
|
||||
|
||||
// discoverProviderMetadata fetches the OIDC provider metadata
|
||||
func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Logger) (*ProviderMetadata, error) {
|
||||
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
|
||||
|
||||
maxRetries := 5
|
||||
baseDelay := 1 * time.Second
|
||||
maxDelay := 30 * time.Second
|
||||
totalTimeout := 5 * time.Minute
|
||||
|
||||
start := time.Now()
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if time.Since(start) > totalTimeout {
|
||||
l.Errorf("Timeout exceeded while fetching provider metadata")
|
||||
return nil, fmt.Errorf("timeout exceeded while fetching provider metadata: %w", lastErr)
|
||||
}
|
||||
|
||||
metadata, err := fetchMetadata(wellKnownURL, httpClient)
|
||||
if err == nil {
|
||||
l.Debug("Provider metadata fetched successfully")
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
|
||||
// Exponential backoff
|
||||
delay := time.Duration(math.Pow(2, float64(attempt))) * baseDelay
|
||||
if delay > maxDelay {
|
||||
delay = maxDelay
|
||||
}
|
||||
l.Debugf("Failed to fetch provider metadata, retrying in %s", delay)
|
||||
time.Sleep(delay)
|
||||
}
|
||||
|
||||
l.Errorf("Max retries exceeded while fetching provider metadata")
|
||||
return nil, fmt.Errorf("max retries exceeded while fetching provider metadata: %w", lastErr)
|
||||
}
|
||||
|
||||
// fetchMetadata fetches metadata from the well-known OIDC configuration endpoint
|
||||
func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetadata, error) {
|
||||
resp, err := httpClient.Get(wellKnownURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
|
||||
@@ -276,117 +387,149 @@ func discoverProviderMetadata(providerURL string, httpClient http.Client) (*Prov
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
// ServeHTTP is the main handler for the middleware
|
||||
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// Check if the URL is excluded first
|
||||
select {
|
||||
case <-t.initComplete:
|
||||
if t.issuerURL == "" {
|
||||
t.logger.Error("OIDC provider metadata initialization failed")
|
||||
http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
case <-req.Context().Done():
|
||||
t.logger.Debug("Request cancelled")
|
||||
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
|
||||
return
|
||||
case <-time.After(30 * time.Second):
|
||||
t.logger.Error("Timeout waiting for OIDC initialization")
|
||||
http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if URL is excluded
|
||||
if t.determineExcludedURL(req.URL.Path) {
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
t.scheme = t.determineScheme(req)
|
||||
host := t.determineHost(req)
|
||||
// Get session
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
|
||||
// Obtain a new session and clear any residual session cookies
|
||||
session, _ = t.sessionManager.GetSession(req)
|
||||
session.Clear(req, rw)
|
||||
|
||||
// Build redirect URL
|
||||
scheme := t.determineScheme(req)
|
||||
host := t.determineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
|
||||
// Initiate authentication
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
// Build redirect URL
|
||||
scheme := t.determineScheme(req)
|
||||
host := t.determineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
|
||||
// Handle special URLs
|
||||
if req.URL.Path == t.logoutURLPath {
|
||||
t.handleLogout(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
if t.redirectURL == "" {
|
||||
t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
|
||||
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
|
||||
}
|
||||
|
||||
// Only get or create a session if the URL is not excluded
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if req.URL.Path == t.redirURLPath {
|
||||
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
authSuccess, originalPath := t.handleCallback(rw, req)
|
||||
if authSuccess {
|
||||
http.Redirect(rw, req, originalPath, http.StatusFound)
|
||||
return
|
||||
}
|
||||
if !authSuccess && originalPath == "invalid-state-param" {
|
||||
// redirect to the root path so that the user can try again
|
||||
// this usually happens when user was previously authenticated
|
||||
// and the session was cleared, but user tries to refresh the page
|
||||
// and different traefik instance is used.
|
||||
http.Redirect(rw, req, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
http.Error(rw, "Authentication failed", http.StatusUnauthorized)
|
||||
t.handleCallback(rw, req, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
// Check authentication status
|
||||
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
|
||||
|
||||
if expired || !authenticated {
|
||||
t.initiateAuthentication(rw, req, session, t.redirectURL)
|
||||
if expired {
|
||||
t.handleExpiredToken(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if !authenticated {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if needsRefresh {
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if !refreshed {
|
||||
t.initiateAuthentication(rw, req, session, t.redirectURL)
|
||||
t.handleExpiredToken(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if authenticated {
|
||||
idToken, ok := session.Values["id_token"].(string)
|
||||
if !ok || idToken == "" {
|
||||
t.logger.Errorf("No id_token found in session")
|
||||
t.initiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := extractClaims(idToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
t.initiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Debugf("No email found in token claims")
|
||||
t.initiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Infof("User with email %s is not from an allowed domain", email)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
|
||||
t.next.ServeHTTP(rw, req)
|
||||
// Process authenticated request
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
t.logger.Debug("No email found in session")
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
// If the user is not authenticated, initiate authentication
|
||||
t.initiateAuthentication(rw, req, session, t.redirectURL)
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Infof("User with email %s is not from an allowed domain", email)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
} else {
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
}
|
||||
if len(roles) > 0 {
|
||||
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
|
||||
}
|
||||
}
|
||||
|
||||
// Check allowed roles and groups
|
||||
if len(t.allowedRolesAndGroups) > 0 {
|
||||
allowed := false
|
||||
for _, roleOrGroup := range append(groups, roles...) {
|
||||
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set user information in headers
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
|
||||
// Process the request
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
// determineExcludedURL checks if the current request URL is in the excluded list
|
||||
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
||||
for excludedURL := range t.excludedURLs {
|
||||
if strings.HasPrefix(currentRequest, excludedURL) {
|
||||
t.logger.Debug("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
|
||||
t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
|
||||
return true
|
||||
}
|
||||
}
|
||||
t.logger.Debug("URL is not excluded - got %s", currentRequest)
|
||||
t.logger.Debugf("URL is not excluded - got %s", currentRequest)
|
||||
return false
|
||||
}
|
||||
|
||||
// determineScheme determines the scheme (http or https) of the request
|
||||
func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
||||
if t.forceHTTPS {
|
||||
return "https"
|
||||
@@ -400,6 +543,7 @@ func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
||||
return "http"
|
||||
}
|
||||
|
||||
// determineHost determines the host of the request
|
||||
func (t *TraefikOidc) determineHost(req *http.Request) string {
|
||||
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
|
||||
return host
|
||||
@@ -407,101 +551,103 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
|
||||
return req.Host
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) {
|
||||
authenticated, _ := session.Values["authenticated"].(bool)
|
||||
if !authenticated {
|
||||
// isUserAuthenticated checks if the user is authenticated
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
|
||||
if !session.GetAuthenticated() {
|
||||
t.logger.Debug("User is not authenticated according to session")
|
||||
return false, false, false
|
||||
}
|
||||
|
||||
idToken, ok := session.Values["id_token"].(string)
|
||||
if !ok || idToken == "" {
|
||||
accessToken := session.GetAccessToken()
|
||||
if accessToken == "" {
|
||||
t.logger.Debug("No access token found in session")
|
||||
return false, false, true // Session is invalid, consider it expired
|
||||
}
|
||||
|
||||
// Verify the token
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
if err := t.verifyToken(accessToken); err != nil {
|
||||
t.logger.Errorf("Token verification failed: %v", err)
|
||||
return false, false, true // Token is invalid, consider it expired
|
||||
}
|
||||
|
||||
claims, err := extractClaims(idToken)
|
||||
claims, err := extractClaims(accessToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
return false, false, true // Can't read claims, consider it expired
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
exp, ok := claims["exp"].(float64)
|
||||
expClaim, ok := claims["exp"].(float64)
|
||||
if !ok {
|
||||
t.logger.Errorf("Failed to get expiration time from claims")
|
||||
return false, false, true // No expiration, consider it expired
|
||||
t.logger.Error("Failed to get expiration time from claims")
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
expTime := int64(exp)
|
||||
expTime := int64(expClaim)
|
||||
|
||||
if now > expTime {
|
||||
return false, false, true // Token has expired
|
||||
t.logger.Debug("Token has expired")
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
gracePeriod := time.Minute * 5
|
||||
if time.Now().Add(gracePeriod).Unix() > expTime {
|
||||
if now+int64(gracePeriod.Seconds()) > expTime {
|
||||
t.logger.Debug("Token will expire soon")
|
||||
return true, true, false // Token will expire soon, needs refresh
|
||||
}
|
||||
|
||||
return true, false, false // Token is valid and not expiring soon
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) initiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
|
||||
// defaultInitiateAuthentication initiates the authentication process
|
||||
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Generate CSRF token and nonce
|
||||
csrfToken := uuid.New().String()
|
||||
session.Values["csrf"] = csrfToken
|
||||
session.Values["incoming_path"] = req.URL.Path
|
||||
t.logger.Debugf("Setting CSRF token: %s", csrfToken)
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
nonce, err := generateNonce()
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Set session values
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce(nonce)
|
||||
session.SetIncomingPath(req.URL.RequestURI())
|
||||
|
||||
// Save the session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Build and redirect to auth URL
|
||||
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// verifyToken verifies the token using the token verifier
|
||||
func (t *TraefikOidc) verifyToken(token string) error {
|
||||
return t.tokenVerifier.VerifyToken(token)
|
||||
}
|
||||
|
||||
var authURLBuilder strings.Builder
|
||||
|
||||
// buildAuthURL constructs the authentication URL
|
||||
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
||||
authURLBuilder.Reset()
|
||||
authURLBuilder.Grow(256) // Pre-allocate some space
|
||||
authURLBuilder.WriteString(t.authURL)
|
||||
authURLBuilder.WriteString("?client_id=")
|
||||
authURLBuilder.WriteString(t.clientID)
|
||||
authURLBuilder.WriteString("&response_type=code&redirect_uri=")
|
||||
authURLBuilder.WriteString(url.QueryEscape(redirectURL))
|
||||
authURLBuilder.WriteString("&state=")
|
||||
authURLBuilder.WriteString(state)
|
||||
authURLBuilder.WriteString("&nonce=")
|
||||
authURLBuilder.WriteString(nonce)
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("client_id", t.clientID)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("redirect_uri", redirectURL)
|
||||
params.Set("state", state)
|
||||
params.Set("nonce", nonce)
|
||||
if len(t.scopes) > 0 {
|
||||
authURLBuilder.WriteString("&scope=")
|
||||
authURLBuilder.WriteString(strings.Join(t.scopes, "+"))
|
||||
params.Set("scope", strings.Join(t.scopes, " "))
|
||||
}
|
||||
|
||||
return authURLBuilder.String()
|
||||
return t.authURL + "?" + params.Encode()
|
||||
}
|
||||
|
||||
// startTokenCleanup starts the token cleanup goroutine
|
||||
func (t *TraefikOidc) startTokenCleanup() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
ticker := newTicker(1 * time.Minute)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
t.logger.Debug("Cleaning up token cache")
|
||||
@@ -511,26 +657,23 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
}()
|
||||
}
|
||||
|
||||
// RevokeToken adds the token to the blacklist
|
||||
func (t *TraefikOidc) RevokeToken(token string) {
|
||||
// Remove from cache
|
||||
t.tokenCache.Delete(token)
|
||||
|
||||
// Add to blacklist
|
||||
claims, err := extractClaims(token)
|
||||
if err == nil {
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
expTime := time.Unix(int64(exp), 0)
|
||||
t.tokenBlacklist.Add(token, expTime)
|
||||
}
|
||||
}
|
||||
// Add to blacklist with default expiration
|
||||
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
|
||||
t.tokenBlacklist.Add(token, expiry)
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) RevokeTokenWithProvider(token string) error {
|
||||
// RevokeTokenWithProvider revokes the token with the provider
|
||||
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
t.logger.Debugf("Revoking token with provider")
|
||||
|
||||
data := url.Values{
|
||||
"token": {token},
|
||||
"token_type_hint": {"access_token", "refresh_token"},
|
||||
"token_type_hint": {tokenType},
|
||||
"client_id": {t.clientID},
|
||||
"client_secret": {t.clientSecret},
|
||||
}
|
||||
@@ -561,9 +704,12 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool {
|
||||
refreshToken, ok := session.Values["refresh_token"].(string)
|
||||
if !ok || refreshToken == "" {
|
||||
// refreshToken refreshes the user's token
|
||||
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
|
||||
t.logger.Debug("Refreshing token")
|
||||
refreshToken := session.GetRefreshToken()
|
||||
if refreshToken == "" {
|
||||
t.logger.Debug("No refresh token found in session")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -573,8 +719,17 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
return false
|
||||
}
|
||||
|
||||
session.Values["id_token"] = newToken.IDToken
|
||||
session.Values["refresh_token"] = newToken.RefreshToken
|
||||
// Verify the new access token
|
||||
if err := t.verifyToken(newToken.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify new access token: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Update session with new tokens
|
||||
session.SetAccessToken(newToken.IDToken)
|
||||
session.SetRefreshToken(newToken.RefreshToken)
|
||||
|
||||
// Save the session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save refreshed session: %v", err)
|
||||
return false
|
||||
@@ -583,6 +738,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
return true
|
||||
}
|
||||
|
||||
// isAllowedDomain checks if the user's email domain is allowed
|
||||
func (t *TraefikOidc) isAllowedDomain(email string) bool {
|
||||
if len(t.allowedUserDomains) == 0 {
|
||||
return true // If no domains are specified, all are allowed
|
||||
@@ -597,3 +753,59 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
|
||||
_, ok := t.allowedUserDomains[domain]
|
||||
return ok
|
||||
}
|
||||
|
||||
// extractGroupsAndRoles extracts groups and roles from the id_token
|
||||
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
|
||||
claims, err := t.extractClaimsFunc(idToken)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
|
||||
}
|
||||
|
||||
var groups []string
|
||||
var roles []string
|
||||
|
||||
// Extract groups with type checking
|
||||
if groupsClaim, exists := claims["groups"]; exists {
|
||||
groupsSlice, ok := groupsClaim.([]interface{})
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("groups claim is not an array")
|
||||
}
|
||||
for _, group := range groupsSlice {
|
||||
if groupStr, ok := group.(string); ok {
|
||||
t.logger.Debugf("Found group: %s", groupStr)
|
||||
groups = append(groups, groupStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract roles with type checking
|
||||
if rolesClaim, exists := claims["roles"]; exists {
|
||||
rolesSlice, ok := rolesClaim.([]interface{})
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("roles claim is not an array")
|
||||
}
|
||||
for _, role := range rolesSlice {
|
||||
if roleStr, ok := role.(string); ok {
|
||||
t.logger.Debugf("Found role: %s", roleStr)
|
||||
roles = append(roles, roleStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return groups, roles, nil
|
||||
}
|
||||
|
||||
// buildFullURL constructs a full URL from scheme, host and path
|
||||
func buildFullURL(scheme, host, path string) string {
|
||||
// If the path is already a full URL, return it as-is
|
||||
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
||||
return path
|
||||
}
|
||||
|
||||
// Ensure the path starts with a forward slash
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s%s", scheme, host, path)
|
||||
}
|
||||
|
||||
+40
-161
@@ -1,178 +1,57 @@
|
||||
// main_bench_test.go
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func BenchmarkServeHTTP_AuthenticatedUser(b *testing.B) {
|
||||
suite := new(TraefikOidcTestSuite)
|
||||
suite.SetupTest()
|
||||
// BenchmarkOIDCMiddleware benchmarks the OIDC middleware's ability to handle concurrent requests.
|
||||
func BenchmarkOIDCMiddleware(b *testing.B) {
|
||||
// Setup test environment
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
session := sessions.NewSession(suite.mockStore, cookieName)
|
||||
session.Values["authenticated"] = true
|
||||
ts := &TestSuite{}
|
||||
ts.Setup()
|
||||
ts.token = "valid.jwt.token"
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
}
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
encodedClaims := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
mockToken := fmt.Sprintf("header.%s.signature", encodedClaims)
|
||||
session.Values["id_token"] = mockToken
|
||||
|
||||
suite.mockStore.On("Get", req, cookieName).Return(session, nil)
|
||||
suite.mockStore.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
suite.mockTokenVerifier.On("VerifyToken", mockToken).Return(nil)
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Define the handler with OIDC middleware
|
||||
ts.tOidc.next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
suite.oidc.next = nextHandler
|
||||
|
||||
// Create test server
|
||||
server := httptest.NewServer(ts.tOidc.next)
|
||||
defer server.Close()
|
||||
|
||||
// Prepare HTTP client
|
||||
client := &http.Client{}
|
||||
|
||||
// Reset timer to exclude setup time
|
||||
b.ResetTimer()
|
||||
|
||||
// Run benchmark
|
||||
for i := 0; i < b.N; i++ {
|
||||
rw := httptest.NewRecorder()
|
||||
suite.oidc.ServeHTTP(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkVerifyToken(b *testing.B) {
|
||||
suite := new(TraefikOidcTestSuite)
|
||||
suite.SetupTest()
|
||||
|
||||
token := "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Rfa2lkIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJleHAiOjE1MTYyMzkxMjJ9.ZmFrZV9zaWduYXR1cmU"
|
||||
suite.mockTokenVerifier.On("VerifyToken", token).Return(nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
suite.oidc.verifyToken(token)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBuildAuthURL(b *testing.B) {
|
||||
suite := new(TraefikOidcTestSuite)
|
||||
suite.SetupTest()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
suite.oidc.buildAuthURL("http://example.com/callback", "test_state", "test_nonce")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkJWKToPEM(b *testing.B) {
|
||||
jwk := &JWK{
|
||||
N: base64.RawURLEncoding.EncodeToString(big.NewInt(12345).Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()),
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
jwkToPEM(jwk)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenBlacklist_Add(b *testing.B) {
|
||||
tb := NewTokenBlacklist()
|
||||
token := "test_token"
|
||||
expiration := time.Now().Add(time.Hour)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tb.Add(token, expiration)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenBlacklist_IsBlacklisted(b *testing.B) {
|
||||
tb := NewTokenBlacklist()
|
||||
token := "test_token"
|
||||
expiration := time.Now().Add(time.Hour)
|
||||
tb.Add(token, expiration)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tb.IsBlacklisted(token)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenCache_Set(b *testing.B) {
|
||||
tc := NewTokenCache()
|
||||
token := "test_token"
|
||||
expiration := time.Now().Add(time.Hour)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tc.Set(token, expiration)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenCache_Get(b *testing.B) {
|
||||
tc := NewTokenCache()
|
||||
token := "test_token"
|
||||
expiration := time.Now().Add(time.Hour)
|
||||
tc.Set(token, expiration)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tc.Get(token)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExtractClaims(b *testing.B) {
|
||||
tokenString := "header.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.signature"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
extractClaims(tokenString)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDetermineScheme(b *testing.B) {
|
||||
suite := new(TraefikOidcTestSuite)
|
||||
suite.SetupTest()
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
suite.oidc.determineScheme(req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDetermineHost(b *testing.B) {
|
||||
suite := new(TraefikOidcTestSuite)
|
||||
suite.SetupTest()
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Host", "forwarded.example.com")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
suite.oidc.determineHost(req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsUserAuthenticated(b *testing.B) {
|
||||
suite := new(TraefikOidcTestSuite)
|
||||
suite.SetupTest()
|
||||
session := sessions.NewSession(suite.mockStore, cookieName)
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = "valid.eyJleHAiOjk5OTk5OTk5OTl9.signature"
|
||||
|
||||
suite.mockTokenVerifier.On("VerifyToken", "valid.eyJleHAiOjk5OTk5OTk5OTl9.signature").Return(nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
suite.oidc.isUserAuthenticated(session)
|
||||
// Create new request
|
||||
req, err := http.NewRequest("GET", server.URL, nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Set necessary headers or cookies
|
||||
req.Header.Set("Authorization", "Bearer "+ts.token)
|
||||
|
||||
// Send the request
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Close response body
|
||||
resp.Body.Close()
|
||||
|
||||
// Check response status code
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b.Errorf("Unexpected status code: got %v, want %v", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+1620
-967
File diff suppressed because it is too large
Load Diff
+639
@@ -0,0 +1,639 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// generateSecureRandomString creates a cryptographically secure random string of specified length
|
||||
func generateSecureRandomString(length int) string {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
panic("failed to generate random string")
|
||||
}
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// Cookie names and configuration constants used for session management
|
||||
const (
|
||||
// Using fixed prefixes for consistent cookie naming across restarts
|
||||
mainCookieName = "_oidc_raczylo_m"
|
||||
accessTokenCookie = "_oidc_raczylo_a"
|
||||
refreshTokenCookie = "_oidc_raczylo_r"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxCookieSize is the maximum size for each cookie chunk.
|
||||
// This value is calculated to ensure the final cookie size stays within browser limits:
|
||||
// 1. Browser cookie size limit is typically 4096 bytes
|
||||
// 2. Cookie content undergoes encryption (adds 28 bytes) and base64 encoding (4/3 ratio)
|
||||
// 3. Calculation:
|
||||
// - Let x be the chunk size
|
||||
// - After encryption: x + 28 bytes
|
||||
// - After base64: ((x + 28) * 4/3) bytes
|
||||
// - Must satisfy: ((x + 28) * 4/3) ≤ 4096
|
||||
// - Solving for x: x ≤ 3044
|
||||
// 4. We use 2000 as a conservative limit to account for cookie metadata
|
||||
maxCookieSize = 2000
|
||||
|
||||
// absoluteSessionTimeout defines the maximum lifetime of a session
|
||||
// regardless of activity (24 hours)
|
||||
absoluteSessionTimeout = 24 * time.Hour
|
||||
|
||||
// minEncryptionKeyLength defines the minimum length for the encryption key
|
||||
minEncryptionKeyLength = 32
|
||||
)
|
||||
|
||||
// compressToken compresses a token using gzip and base64 encodes it
|
||||
func compressToken(token string) string {
|
||||
var b bytes.Buffer
|
||||
gz := gzip.NewWriter(&b)
|
||||
if _, err := gz.Write([]byte(token)); err != nil {
|
||||
return token // fallback to uncompressed on error
|
||||
}
|
||||
if err := gz.Close(); err != nil {
|
||||
return token
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(b.Bytes())
|
||||
}
|
||||
|
||||
// decompressToken decompresses a base64 encoded gzipped token
|
||||
func decompressToken(compressed string) string {
|
||||
data, err := base64.StdEncoding.DecodeString(compressed)
|
||||
if err != nil {
|
||||
return compressed // return as-is if not base64
|
||||
}
|
||||
|
||||
gz, err := gzip.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return compressed
|
||||
}
|
||||
defer gz.Close()
|
||||
|
||||
decompressed, err := io.ReadAll(gz)
|
||||
if err != nil {
|
||||
return compressed
|
||||
}
|
||||
|
||||
return string(decompressed)
|
||||
}
|
||||
|
||||
// SessionManager handles the management of multiple session cookies for OIDC authentication.
|
||||
// It provides functionality for storing and retrieving authentication state, tokens,
|
||||
// and other session-related data across multiple cookies to handle large tokens.
|
||||
type SessionManager struct {
|
||||
// store is the underlying session store for cookie management
|
||||
store sessions.Store
|
||||
|
||||
// forceHTTPS enforces secure cookie attributes regardless of request scheme
|
||||
forceHTTPS bool
|
||||
|
||||
// logger provides structured logging capabilities
|
||||
logger *Logger
|
||||
|
||||
// sessionPool is a sync.Pool for reusing SessionData objects
|
||||
sessionPool sync.Pool
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new session manager with the specified configuration.
|
||||
// Parameters:
|
||||
// - encryptionKey: Key used to encrypt session data (must be at least 32 bytes)
|
||||
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
|
||||
// - logger: Logger instance for recording session-related events
|
||||
//
|
||||
// The manager handles session creation, storage, and cookie security settings.
|
||||
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
|
||||
// Validate encryption key length
|
||||
if len(encryptionKey) < minEncryptionKeyLength {
|
||||
panic(fmt.Sprintf("encryption key must be at least %d bytes long", minEncryptionKeyLength))
|
||||
}
|
||||
|
||||
sm := &SessionManager{
|
||||
store: sessions.NewCookieStore([]byte(encryptionKey)),
|
||||
forceHTTPS: forceHTTPS,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Initialize session pool
|
||||
sm.sessionPool.New = func() interface{} {
|
||||
return &SessionData{
|
||||
manager: sm,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
}
|
||||
}
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// getSessionOptions returns secure session options configured for the current request.
|
||||
// Parameters:
|
||||
// - isSecure: Whether the current request is using HTTPS
|
||||
//
|
||||
// The options ensure cookies are:
|
||||
// - HTTP-only (not accessible via JavaScript)
|
||||
// - Secure when using HTTPS or when forceHTTPS is enabled
|
||||
// - Using SameSite=Lax for CSRF protection
|
||||
// - Set with appropriate timeout and path settings
|
||||
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
Secure: isSecure || sm.forceHTTPS,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(absoluteSessionTimeout.Seconds()),
|
||||
Path: "/",
|
||||
}
|
||||
}
|
||||
|
||||
// GetSession retrieves all session data for the current request.
|
||||
// It loads the main session and token sessions, including any chunked token data,
|
||||
// and combines them into a single SessionData structure for easy access.
|
||||
// Returns an error if any session component cannot be loaded.
|
||||
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
// Get session from pool
|
||||
sessionData := sm.sessionPool.Get().(*SessionData)
|
||||
sessionData.request = r
|
||||
|
||||
var err error
|
||||
sessionData.mainSession, err = sm.store.Get(r, mainCookieName)
|
||||
if err != nil {
|
||||
sm.sessionPool.Put(sessionData)
|
||||
return nil, fmt.Errorf("failed to get main session: %w", err)
|
||||
}
|
||||
|
||||
// Check for absolute session timeout
|
||||
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
|
||||
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
|
||||
// Session has expired
|
||||
sm.sessionPool.Put(sessionData)
|
||||
return nil, fmt.Errorf("session expired")
|
||||
}
|
||||
}
|
||||
|
||||
sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie)
|
||||
if err != nil {
|
||||
sm.sessionPool.Put(sessionData)
|
||||
return nil, fmt.Errorf("failed to get access token session: %w", err)
|
||||
}
|
||||
|
||||
sessionData.refreshSession, err = sm.store.Get(r, refreshTokenCookie)
|
||||
if err != nil {
|
||||
sm.sessionPool.Put(sessionData)
|
||||
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
|
||||
}
|
||||
|
||||
// Clear and reuse chunk maps
|
||||
for k := range sessionData.accessTokenChunks {
|
||||
delete(sessionData.accessTokenChunks, k)
|
||||
}
|
||||
for k := range sessionData.refreshTokenChunks {
|
||||
delete(sessionData.refreshTokenChunks, k)
|
||||
}
|
||||
|
||||
// Retrieve chunked token sessions
|
||||
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
|
||||
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
|
||||
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
// getTokenChunkSessions retrieves all session chunks for a given token type.
|
||||
// Parameters:
|
||||
// - r: The HTTP request
|
||||
// - baseName: The base name for the token's session cookies
|
||||
// - chunks: Map to store the chunks in
|
||||
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", baseName, i)
|
||||
session, err := sm.store.Get(r, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
// No more sessions
|
||||
break
|
||||
}
|
||||
chunks[i] = session
|
||||
}
|
||||
}
|
||||
|
||||
// SessionData holds all session information for an authenticated user.
|
||||
// It manages multiple session cookies to handle the main session state
|
||||
// and potentially large access and refresh tokens that may need to be
|
||||
// split across multiple cookies due to browser size limitations.
|
||||
type SessionData struct {
|
||||
// manager is the SessionManager that created this SessionData
|
||||
manager *SessionManager
|
||||
|
||||
// request is the current HTTP request associated with this session
|
||||
request *http.Request
|
||||
|
||||
// mainSession stores authentication state and basic user info
|
||||
mainSession *sessions.Session
|
||||
|
||||
// accessSession stores the primary access token cookie
|
||||
accessSession *sessions.Session
|
||||
|
||||
// refreshSession stores the primary refresh token cookie
|
||||
refreshSession *sessions.Session
|
||||
|
||||
// accessTokenChunks stores additional chunks of the access token
|
||||
// when it exceeds the maximum cookie size
|
||||
accessTokenChunks map[int]*sessions.Session
|
||||
|
||||
// refreshTokenChunks stores additional chunks of the refresh token
|
||||
// when it exceeds the maximum cookie size
|
||||
refreshTokenChunks map[int]*sessions.Session
|
||||
}
|
||||
|
||||
// Save persists all session data to cookies in the HTTP response.
|
||||
// It saves the main session, token sessions, and any token chunks,
|
||||
// applying appropriate security options to each cookie. All cookies
|
||||
// are saved with consistent security settings based on the request scheme.
|
||||
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
|
||||
|
||||
// Set options for all sessions
|
||||
options := sd.manager.getSessionOptions(isSecure)
|
||||
sd.mainSession.Options = options
|
||||
sd.accessSession.Options = options
|
||||
sd.refreshSession.Options = options
|
||||
|
||||
// Save main session
|
||||
if err := sd.mainSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save main session: %w", err)
|
||||
}
|
||||
|
||||
// Save access token session
|
||||
if err := sd.accessSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save access token session: %w", err)
|
||||
}
|
||||
|
||||
// Save refresh token session
|
||||
if err := sd.refreshSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save refresh token session: %w", err)
|
||||
}
|
||||
|
||||
// Save access token chunks
|
||||
for _, session := range sd.accessTokenChunks {
|
||||
session.Options = options
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save access token chunk session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save refresh token chunks
|
||||
for _, session := range sd.refreshTokenChunks {
|
||||
session.Options = options
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save refresh token chunk session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes all session data by expiring all cookies and clearing their values.
|
||||
// This is typically used during logout to ensure all session data is properly cleaned up.
|
||||
// It handles both main session data and any token chunks that may exist.
|
||||
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
// Clear and expire all sessions
|
||||
sd.mainSession.Options.MaxAge = -1
|
||||
sd.accessSession.Options.MaxAge = -1
|
||||
sd.refreshSession.Options.MaxAge = -1
|
||||
|
||||
for k := range sd.mainSession.Values {
|
||||
delete(sd.mainSession.Values, k)
|
||||
}
|
||||
for k := range sd.accessSession.Values {
|
||||
delete(sd.accessSession.Values, k)
|
||||
}
|
||||
for k := range sd.refreshSession.Values {
|
||||
delete(sd.refreshSession.Values, k)
|
||||
}
|
||||
|
||||
// Clear chunk sessions
|
||||
sd.clearTokenChunks(r, sd.accessTokenChunks)
|
||||
sd.clearTokenChunks(r, sd.refreshTokenChunks)
|
||||
|
||||
var err error
|
||||
if w != nil {
|
||||
err = sd.Save(r, w)
|
||||
}
|
||||
|
||||
// Return session to pool
|
||||
sd.manager.sessionPool.Put(sd)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// clearTokenChunks removes all session chunks for a given token type.
|
||||
// It expires the cookies and removes all stored values to ensure
|
||||
// no token data remains after logout or token invalidation.
|
||||
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
|
||||
for _, session := range chunks {
|
||||
session.Options.MaxAge = -1
|
||||
for k := range session.Values {
|
||||
delete(session.Values, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthenticated returns whether the current session is authenticated.
|
||||
// Returns true if the user has successfully completed OIDC authentication
|
||||
// and the session hasn't expired, false otherwise.
|
||||
func (sd *SessionData) GetAuthenticated() bool {
|
||||
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
||||
if !auth {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check session expiration
|
||||
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout
|
||||
}
|
||||
|
||||
// SetAuthenticated updates the session's authentication status and rotates session ID.
|
||||
// This should be called after successful OIDC authentication or during logout.
|
||||
// Session ID rotation helps prevent session fixation attacks.
|
||||
func (sd *SessionData) SetAuthenticated(value bool) {
|
||||
if value {
|
||||
// Generate new session ID and set creation time
|
||||
sd.mainSession.ID = generateSecureRandomString(32)
|
||||
sd.mainSession.Values["created_at"] = time.Now().Unix()
|
||||
}
|
||||
sd.mainSession.Values["authenticated"] = value
|
||||
}
|
||||
|
||||
// GetAccessToken retrieves the complete access token from the session.
|
||||
// If the token was split into chunks due to size limitations, it will
|
||||
// automatically reassemble the complete token from all chunks.
|
||||
// Returns an empty string if no token is found.
|
||||
func (sd *SessionData) GetAccessToken() string {
|
||||
token, _ := sd.accessSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
compressed, _ := sd.accessSession.Values["compressed"].(bool)
|
||||
if compressed {
|
||||
return decompressToken(token)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// Reassemble token from chunks
|
||||
if len(sd.accessTokenChunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
for i := 0; ; i++ {
|
||||
session, ok := sd.accessTokenChunks[i]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
chunk, _ := session.Values["token_chunk"].(string)
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
token = strings.Join(chunks, "")
|
||||
compressed, _ := sd.accessSession.Values["compressed"].(bool)
|
||||
if compressed {
|
||||
return decompressToken(token)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// SetAccessToken stores the access token in the session.
|
||||
// If the token exceeds maxCookieSize, it is automatically compressed and split into
|
||||
// multiple cookie chunks to handle large tokens while staying within
|
||||
// browser cookie size limits. Any existing token or chunks are cleared
|
||||
// before setting the new token.
|
||||
// expireAccessTokenChunks expires any existing access token chunk cookies
|
||||
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
||||
session, err := sd.manager.store.Get(sd.request, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
break
|
||||
}
|
||||
// Expire the cookie
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
// Save expired cookie
|
||||
if err := session.Save(sd.request, w); err != nil {
|
||||
sd.manager.logger.Errorf("Failed to save expired cookie: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sd *SessionData) SetAccessToken(token string) {
|
||||
// Expire any existing chunk cookies first
|
||||
if sd.request != nil {
|
||||
sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called
|
||||
}
|
||||
|
||||
// Clear and prepare chunks map for new token
|
||||
sd.accessTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
// Compress token
|
||||
compressed := compressToken(token)
|
||||
|
||||
if len(compressed) <= maxCookieSize {
|
||||
sd.accessSession.Values["token"] = compressed
|
||||
sd.accessSession.Values["compressed"] = true
|
||||
} else {
|
||||
// Split compressed token into chunks
|
||||
sd.accessSession.Values["token"] = ""
|
||||
sd.accessSession.Values["compressed"] = true
|
||||
chunks := splitIntoChunks(compressed, maxCookieSize)
|
||||
for i, chunk := range chunks {
|
||||
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
||||
session, _ := sd.manager.store.Get(sd.request, sessionName)
|
||||
session.Values["token_chunk"] = chunk
|
||||
sd.accessTokenChunks[i] = session
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetRefreshToken retrieves the complete refresh token from the session.
|
||||
// If the token was split into chunks due to size limitations, it will
|
||||
// automatically reassemble the complete token from all chunks.
|
||||
// Returns an empty string if no token is found.
|
||||
func (sd *SessionData) GetRefreshToken() string {
|
||||
token, _ := sd.refreshSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
|
||||
if compressed {
|
||||
return decompressToken(token)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// Reassemble token from chunks
|
||||
if len(sd.refreshTokenChunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
for i := 0; ; i++ {
|
||||
session, ok := sd.refreshTokenChunks[i]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
chunk, _ := session.Values["token_chunk"].(string)
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
token = strings.Join(chunks, "")
|
||||
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
|
||||
if compressed {
|
||||
return decompressToken(token)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// SetRefreshToken stores the refresh token in the session.
|
||||
// If the token exceeds maxCookieSize, it is automatically compressed and split into
|
||||
// multiple cookie chunks to handle large tokens while staying within
|
||||
// browser cookie size limits. Any existing token or chunks are cleared
|
||||
// before setting the new token.
|
||||
// expireRefreshTokenChunks expires any existing refresh token chunk cookies
|
||||
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
||||
session, err := sd.manager.store.Get(sd.request, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
break
|
||||
}
|
||||
// Expire the cookie
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
// Save expired cookie
|
||||
if err := session.Save(sd.request, w); err != nil {
|
||||
sd.manager.logger.Errorf("Failed to save expired cookie: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sd *SessionData) SetRefreshToken(token string) {
|
||||
// Expire any existing chunk cookies first
|
||||
if sd.request != nil {
|
||||
sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called
|
||||
}
|
||||
|
||||
// Clear and prepare chunks map for new token
|
||||
sd.refreshTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
// Compress token
|
||||
compressed := compressToken(token)
|
||||
|
||||
if len(compressed) <= maxCookieSize {
|
||||
sd.refreshSession.Values["token"] = compressed
|
||||
sd.refreshSession.Values["compressed"] = true
|
||||
} else {
|
||||
// Split compressed token into chunks
|
||||
sd.refreshSession.Values["token"] = ""
|
||||
sd.refreshSession.Values["compressed"] = true
|
||||
chunks := splitIntoChunks(compressed, maxCookieSize)
|
||||
for i, chunk := range chunks {
|
||||
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
||||
session, _ := sd.manager.store.Get(sd.request, sessionName)
|
||||
session.Values["token_chunk"] = chunk
|
||||
sd.refreshTokenChunks[i] = session
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// splitIntoChunks splits a string into chunks of specified size.
|
||||
// This is used internally to handle large tokens that exceed cookie size limits.
|
||||
// Parameters:
|
||||
// - s: The string to split
|
||||
// - chunkSize: Maximum size of each chunk
|
||||
//
|
||||
// Returns an array of string chunks, each no larger than chunkSize.
|
||||
func splitIntoChunks(s string, chunkSize int) []string {
|
||||
var chunks []string
|
||||
for len(s) > 0 {
|
||||
if len(s) > chunkSize {
|
||||
chunks = append(chunks, s[:chunkSize])
|
||||
s = s[chunkSize:]
|
||||
} else {
|
||||
chunks = append(chunks, s)
|
||||
break
|
||||
}
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
// GetCSRF retrieves the CSRF token from the session.
|
||||
// This token is used to prevent cross-site request forgery attacks
|
||||
// by ensuring requests originate from the authenticated user.
|
||||
// Returns an empty string if no CSRF token is found.
|
||||
func (sd *SessionData) GetCSRF() string {
|
||||
csrf, _ := sd.mainSession.Values["csrf"].(string)
|
||||
return csrf
|
||||
}
|
||||
|
||||
// SetCSRF stores a new CSRF token in the session.
|
||||
// This should be called when initiating authentication to generate
|
||||
// a new token for the authentication flow.
|
||||
func (sd *SessionData) SetCSRF(token string) {
|
||||
sd.mainSession.Values["csrf"] = token
|
||||
}
|
||||
|
||||
// GetNonce retrieves the nonce value from the session.
|
||||
// The nonce is used to prevent replay attacks in the OIDC flow
|
||||
// by ensuring the token received matches the authentication request.
|
||||
// Returns an empty string if no nonce is found.
|
||||
func (sd *SessionData) GetNonce() string {
|
||||
nonce, _ := sd.mainSession.Values["nonce"].(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
// SetNonce stores a new nonce value in the session.
|
||||
// This should be called when initiating authentication to generate
|
||||
// a new nonce for the OIDC authentication flow.
|
||||
func (sd *SessionData) SetNonce(nonce string) {
|
||||
sd.mainSession.Values["nonce"] = nonce
|
||||
}
|
||||
|
||||
// GetEmail retrieves the authenticated user's email address from the session.
|
||||
// The email is typically extracted from the OIDC ID token claims.
|
||||
// Returns an empty string if no email is found.
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
// SetEmail stores the user's email address in the session.
|
||||
// This should be called after successful authentication when
|
||||
// processing the OIDC ID token claims.
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
sd.mainSession.Values["email"] = email
|
||||
}
|
||||
|
||||
// GetIncomingPath retrieves the original request path that triggered
|
||||
// the authentication flow. This is used to redirect the user back
|
||||
// to their intended destination after successful authentication.
|
||||
// Returns an empty string if no path was stored.
|
||||
func (sd *SessionData) GetIncomingPath() string {
|
||||
path, _ := sd.mainSession.Values["incoming_path"].(string)
|
||||
return path
|
||||
}
|
||||
|
||||
// SetIncomingPath stores the original request path that triggered
|
||||
// the authentication flow. This should be called before redirecting
|
||||
// to the OIDC provider to remember where to send the user afterward.
|
||||
func (sd *SessionData) SetIncomingPath(path string) {
|
||||
sd.mainSession.Values["incoming_path"] = path
|
||||
}
|
||||
+388
@@ -0,0 +1,388 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Initialize random seed
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// generateRandomString creates a random string of specified length
|
||||
func generateRandomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[rand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// TestTokenCompression tests the token compression functionality
|
||||
func TestTokenCompression(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantSize int // Expected size after compression (approximate)
|
||||
}{
|
||||
{
|
||||
name: "Short token",
|
||||
token: "shorttoken",
|
||||
wantSize: 50, // Base64 encoded gzip has overhead for small content
|
||||
},
|
||||
{
|
||||
name: "Repeating content",
|
||||
token: strings.Repeat("abcdef", 1000),
|
||||
wantSize: 100, // Should compress well due to repetition
|
||||
},
|
||||
{
|
||||
name: "Random content",
|
||||
token: generateRandomString(1000),
|
||||
wantSize: 2000, // Random content won't compress much
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
compressed := compressToken(tt.token)
|
||||
decompressed := decompressToken(compressed)
|
||||
|
||||
// Only verify compression ratio for non-short tokens
|
||||
if len(tt.token) > 100 {
|
||||
compressionRatio := float64(len(compressed)) / float64(len(tt.token))
|
||||
t.Logf("Compression ratio for %s: %.2f", tt.name, compressionRatio)
|
||||
|
||||
if compressionRatio > 1.1 { // Allow up to 10% size increase
|
||||
t.Errorf("Compression increased size too much: original=%d, compressed=%d, ratio=%.2f",
|
||||
len(tt.token), len(compressed), compressionRatio)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify decompression restores original
|
||||
if decompressed != tt.token {
|
||||
t.Error("Decompression failed to restore original token")
|
||||
}
|
||||
|
||||
// Verify approximate compression ratio
|
||||
if len(compressed) > tt.wantSize*2 {
|
||||
t.Errorf("Compression ratio worse than expected: got=%d, want<%d", len(compressed), tt.wantSize*2)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionManager tests the SessionManager functionality
|
||||
|
||||
func TestCookiePrefix(t *testing.T) {
|
||||
// Create a session and verify cookie names
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set some data to ensure cookies are created
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Expire any existing cookies
|
||||
session.expireAccessTokenChunks(rr)
|
||||
session.expireRefreshTokenChunks(rr)
|
||||
|
||||
// Set new tokens
|
||||
session.SetAccessToken("test_token")
|
||||
session.SetRefreshToken("test_refresh_token")
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Check cookie prefixes
|
||||
cookies := rr.Result().Cookies()
|
||||
for _, cookie := range cookies {
|
||||
if !strings.HasPrefix(cookie.Name, "_oidc_raczylo_") {
|
||||
t.Errorf("Cookie %s does not have expected prefix '_oidc_raczylo_'", cookie.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenRefreshCleanup(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set a large token that will be split into chunks
|
||||
largeToken := strings.Repeat("x", 5000)
|
||||
session.SetAccessToken(largeToken)
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get initial cookies
|
||||
initialCookies := rr.Result().Cookies()
|
||||
|
||||
// Create a new request with the initial cookies
|
||||
newReq := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range initialCookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
newRr := httptest.NewRecorder()
|
||||
|
||||
// Get session with cookies and set a new token
|
||||
newSession, err := sm.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get new session: %v", err)
|
||||
}
|
||||
|
||||
// Create a response recorder for expired cookies
|
||||
expiredRr := httptest.NewRecorder()
|
||||
|
||||
// Expire old chunk cookies
|
||||
newSession.expireAccessTokenChunks(expiredRr)
|
||||
|
||||
// Set a smaller token that won't need chunks
|
||||
newSession.SetAccessToken("small_token")
|
||||
|
||||
// Save session with new token
|
||||
if err := newSession.Save(newReq, newRr); err != nil {
|
||||
t.Fatalf("Failed to save new session: %v", err)
|
||||
}
|
||||
|
||||
// Check cookies in response where old cookies are expired
|
||||
intermediateResponse := expiredRr.Result()
|
||||
intermediateCount := 0
|
||||
chunkCount := 0
|
||||
expiredCount := 0
|
||||
|
||||
for _, cookie := range intermediateResponse.Cookies() {
|
||||
if strings.Contains(cookie.Name, "_oidc_raczylo_a_") && strings.Count(cookie.Name, "_") > 3 {
|
||||
chunkCount++
|
||||
if cookie.MaxAge < 0 {
|
||||
expiredCount++
|
||||
t.Logf("Found expired chunk cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
|
||||
}
|
||||
} else if cookie.MaxAge >= 0 {
|
||||
intermediateCount++
|
||||
t.Logf("Found active cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
|
||||
}
|
||||
}
|
||||
|
||||
// All chunk cookies should be expired
|
||||
if chunkCount > 0 && chunkCount != expiredCount {
|
||||
t.Errorf("Not all chunk cookies are expired: %d chunks, %d expired", chunkCount, expiredCount)
|
||||
}
|
||||
|
||||
// Should have fewer active cookies after setting smaller token
|
||||
if intermediateCount >= len(initialCookies) {
|
||||
t.Errorf("Expected fewer active cookies after token refresh, got %d, want less than %d", intermediateCount, len(initialCookies))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionManager(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authenticated bool
|
||||
email string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
expectedCookieCount int
|
||||
wantCompressed bool // Whether tokens should be compressed
|
||||
}{
|
||||
{
|
||||
name: "Short tokens",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: "shortaccesstoken",
|
||||
refreshToken: "shortrefreshtoken",
|
||||
expectedCookieCount: 3, // main, access, refresh
|
||||
wantCompressed: true,
|
||||
},
|
||||
{
|
||||
name: "Long tokens exceeding 4096 bytes",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: strings.Repeat("x", 5000),
|
||||
refreshToken: strings.Repeat("y", 6000),
|
||||
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
|
||||
wantCompressed: true,
|
||||
},
|
||||
{
|
||||
name: "REALLY long tokens, exceeding 25000 bytes",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: strings.Repeat("x", 25000),
|
||||
refreshToken: strings.Repeat("y", 25000),
|
||||
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
|
||||
wantCompressed: true,
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated session",
|
||||
authenticated: false,
|
||||
email: "",
|
||||
accessToken: "",
|
||||
refreshToken: "",
|
||||
expectedCookieCount: 3, // main, access, refresh
|
||||
wantCompressed: false,
|
||||
},
|
||||
{
|
||||
name: "Random content tokens",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: generateRandomString(5000),
|
||||
refreshToken: generateRandomString(5000),
|
||||
expectedCookieCount: calculateExpectedCookieCount(generateRandomString(5000), generateRandomString(5000)),
|
||||
wantCompressed: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc // Capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
session, err := ts.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set session values
|
||||
session.SetAuthenticated(tc.authenticated)
|
||||
session.SetEmail(tc.email)
|
||||
|
||||
// Expire any existing cookies
|
||||
session.expireAccessTokenChunks(rr)
|
||||
session.expireRefreshTokenChunks(rr)
|
||||
|
||||
// Set new tokens
|
||||
session.SetAccessToken(tc.accessToken)
|
||||
session.SetRefreshToken(tc.refreshToken)
|
||||
|
||||
// Save session
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Verify cookies are set and compression is used when appropriate
|
||||
cookies := rr.Result().Cookies()
|
||||
if len(cookies) != tc.expectedCookieCount {
|
||||
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
|
||||
}
|
||||
|
||||
// Verify compression is working by checking token sizes
|
||||
for _, cookie := range cookies {
|
||||
if strings.Contains(cookie.Name, accessTokenCookie) {
|
||||
// Get original and stored sizes
|
||||
originalSize := len(tc.accessToken)
|
||||
storedSize := len(cookie.Value)
|
||||
|
||||
if originalSize > 100 && tc.wantCompressed {
|
||||
// For large tokens, verify some compression occurred
|
||||
compressionRatio := float64(storedSize) / float64(originalSize)
|
||||
t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)",
|
||||
compressionRatio, originalSize, storedSize)
|
||||
|
||||
if compressionRatio > 0.9 { // Allow some overhead, but should see compression
|
||||
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
|
||||
cookie.Name, compressionRatio)
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(cookie.Name, refreshTokenCookie) {
|
||||
originalSize := len(tc.refreshToken)
|
||||
storedSize := len(cookie.Value)
|
||||
|
||||
if originalSize > 100 && tc.wantCompressed {
|
||||
compressionRatio := float64(storedSize) / float64(originalSize)
|
||||
t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)",
|
||||
compressionRatio, originalSize, storedSize)
|
||||
|
||||
if compressionRatio > 0.9 {
|
||||
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
|
||||
cookie.Name, compressionRatio)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new request with the cookies
|
||||
newReq := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range cookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Get the session again and verify values
|
||||
newSession, err := ts.sessionManager.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get new session: %v", err)
|
||||
}
|
||||
|
||||
// Verify session values
|
||||
if newSession.GetAuthenticated() != tc.authenticated {
|
||||
t.Errorf("Authentication status not preserved")
|
||||
}
|
||||
if email := newSession.GetEmail(); email != tc.email {
|
||||
t.Errorf("Expected email %s, got %s", tc.email, email)
|
||||
}
|
||||
if token := newSession.GetAccessToken(); token != tc.accessToken {
|
||||
t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken))
|
||||
}
|
||||
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
|
||||
t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken))
|
||||
}
|
||||
|
||||
// Verify session pooling by checking if the session is reused
|
||||
session2, _ := ts.sessionManager.GetSession(newReq)
|
||||
if session2 == newSession {
|
||||
t.Error("Session not properly pooled")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func calculateExpectedCookieCount(accessToken, refreshToken string) int {
|
||||
count := 3 // main, access, refresh
|
||||
|
||||
// Helper to calculate chunks for compressed token
|
||||
calculateChunks := func(token string) int {
|
||||
// Compress token (matching the actual implementation)
|
||||
compressed := compressToken(token)
|
||||
|
||||
// If compressed token fits in one cookie, no additional chunks needed
|
||||
if len(compressed) <= maxCookieSize {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Calculate chunks needed for compressed token
|
||||
return len(splitIntoChunks(compressed, maxCookieSize))
|
||||
}
|
||||
|
||||
// Add chunks for access token if needed
|
||||
accessChunks := calculateChunks(accessToken)
|
||||
if accessChunks > 0 {
|
||||
count += accessChunks
|
||||
}
|
||||
|
||||
// Add chunks for refresh token if needed
|
||||
refreshChunks := calculateChunks(refreshToken)
|
||||
if refreshChunks > 0 {
|
||||
count += refreshChunks
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
+208
-36
@@ -5,84 +5,235 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
cookieName = "_raczylo_oidc"
|
||||
)
|
||||
|
||||
// Config holds the configuration for the OIDC middleware.
|
||||
// It provides all necessary settings to configure OpenID Connect authentication
|
||||
// with various providers like Auth0, Logto, or any standard OIDC provider.
|
||||
type Config struct {
|
||||
ProviderURL string `json:"providerURL"`
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
Scopes []string `json:"scopes"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
// ProviderURL is the base URL of the OIDC provider (required)
|
||||
// Example: https://accounts.google.com
|
||||
ProviderURL string `json:"providerURL"`
|
||||
|
||||
// RevocationURL is the endpoint for revoking tokens (optional)
|
||||
// If not provided, it will be discovered from provider metadata
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
|
||||
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
|
||||
// Example: /oauth2/callback
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
|
||||
// LogoutURL is the path for handling logout requests (optional)
|
||||
// If not provided, it will be set to CallbackURL + "/logout"
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
|
||||
// ClientID is the OAuth 2.0 client identifier (required)
|
||||
ClientID string `json:"clientID"`
|
||||
|
||||
// ClientSecret is the OAuth 2.0 client secret (required)
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
|
||||
// Scopes defines the OAuth 2.0 scopes to request (optional)
|
||||
// Defaults to ["openid", "profile", "email"] if not provided
|
||||
Scopes []string `json:"scopes"`
|
||||
|
||||
// LogLevel sets the logging verbosity (optional)
|
||||
// Valid values: "debug", "info", "error"
|
||||
// Default: "info"
|
||||
LogLevel string `json:"logLevel"`
|
||||
|
||||
// SessionEncryptionKey is used to encrypt session data (required)
|
||||
// Must be a secure random string
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
|
||||
// ForceHTTPS forces the use of HTTPS for all URLs (optional)
|
||||
// Default: false
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
|
||||
// RateLimit sets the maximum number of requests per second (optional)
|
||||
// Default: 100
|
||||
RateLimit int `json:"rateLimit"`
|
||||
|
||||
// ExcludedURLs lists paths that bypass authentication (optional)
|
||||
// Example: ["/health", "/metrics"]
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
|
||||
// AllowedUserDomains restricts access to specific email domains (optional)
|
||||
// Example: ["company.com", "subsidiary.com"]
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
|
||||
// AllowedRolesAndGroups restricts access to users with specific roles or groups (optional)
|
||||
// Example: ["admin", "developer"]
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
|
||||
// OIDCEndSessionURL is the provider's end session endpoint (optional)
|
||||
// If not provided, it will be discovered from provider metadata
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
|
||||
// PostLogoutRedirectURI is the URL to redirect to after logout (optional)
|
||||
// Default: "/"
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
|
||||
// HTTPClient allows customizing the HTTP client used for OIDC operations (optional)
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
const (
|
||||
// DefaultRateLimit defines the default rate limit for requests per second
|
||||
DefaultRateLimit = 100
|
||||
|
||||
// MinRateLimit defines the minimum allowed rate limit to prevent DOS
|
||||
MinRateLimit = 10
|
||||
|
||||
// DefaultLogLevel defines the default logging level
|
||||
DefaultLogLevel = "info"
|
||||
|
||||
// MinSessionEncryptionKeyLength defines the minimum length for session encryption key
|
||||
MinSessionEncryptionKeyLength = 32
|
||||
)
|
||||
|
||||
// CreateConfig creates a new Config with secure default values.
|
||||
// Default values are set for optional fields:
|
||||
// - Scopes: ["openid", "profile", "email"]
|
||||
// - LogLevel: "info"
|
||||
// - LogoutURL: CallbackURL + "/logout"
|
||||
// - RateLimit: 100 requests per second
|
||||
// - PostLogoutRedirectURI: "/"
|
||||
// - ForceHTTPS: true (for security)
|
||||
func CreateConfig() *Config {
|
||||
c := &Config{}
|
||||
|
||||
if c.Scopes == nil {
|
||||
c.Scopes = []string{"openid", "profile", "email"}
|
||||
}
|
||||
|
||||
if c.LogLevel == "" {
|
||||
c.LogLevel = "info"
|
||||
}
|
||||
|
||||
if c.LogoutURL == "" {
|
||||
c.LogoutURL = c.CallbackURL + "/logout"
|
||||
}
|
||||
|
||||
if c.RateLimit == 0 {
|
||||
c.RateLimit = 100
|
||||
c := &Config{
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
LogLevel: DefaultLogLevel,
|
||||
RateLimit: DefaultRateLimit,
|
||||
ForceHTTPS: true, // Secure by default
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Validate performs validation checks on the Config.
|
||||
// It ensures all required fields are set and have valid values.
|
||||
// Returns an error if any validation check fails.
|
||||
func (c *Config) Validate() error {
|
||||
// Validate provider URL
|
||||
if c.ProviderURL == "" {
|
||||
return fmt.Errorf("providerURL is required")
|
||||
}
|
||||
if !isValidSecureURL(c.ProviderURL) {
|
||||
return fmt.Errorf("providerURL must be a valid HTTPS URL")
|
||||
}
|
||||
|
||||
// Validate callback URL
|
||||
if c.CallbackURL == "" {
|
||||
return fmt.Errorf("callbackURL is required")
|
||||
}
|
||||
if !strings.HasPrefix(c.CallbackURL, "/") {
|
||||
return fmt.Errorf("callbackURL must start with /")
|
||||
}
|
||||
|
||||
// Validate client credentials
|
||||
if c.ClientID == "" {
|
||||
return fmt.Errorf("clientID is required")
|
||||
}
|
||||
if c.ClientSecret == "" {
|
||||
return fmt.Errorf("clientSecret is required")
|
||||
}
|
||||
|
||||
// Validate session encryption key
|
||||
if c.SessionEncryptionKey == "" {
|
||||
return fmt.Errorf("sessionEncryptionKey is required")
|
||||
}
|
||||
if len(c.SessionEncryptionKey) < MinSessionEncryptionKeyLength {
|
||||
return fmt.Errorf("sessionEncryptionKey must be at least %d characters long", MinSessionEncryptionKeyLength)
|
||||
}
|
||||
|
||||
// Validate log level
|
||||
if c.LogLevel != "" && !isValidLogLevel(c.LogLevel) {
|
||||
return fmt.Errorf("logLevel must be one of: debug, info, error")
|
||||
}
|
||||
|
||||
// Validate excluded URLs
|
||||
for _, url := range c.ExcludedURLs {
|
||||
if !strings.HasPrefix(url, "/") {
|
||||
return fmt.Errorf("excluded URL must start with /: %s", url)
|
||||
}
|
||||
if strings.Contains(url, "..") {
|
||||
return fmt.Errorf("excluded URL must not contain path traversal: %s", url)
|
||||
}
|
||||
if strings.Contains(url, "*") {
|
||||
return fmt.Errorf("excluded URL must not contain wildcards: %s", url)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate revocation URL if set
|
||||
if c.RevocationURL != "" && !isValidSecureURL(c.RevocationURL) {
|
||||
return fmt.Errorf("revocationURL must be a valid HTTPS URL")
|
||||
}
|
||||
|
||||
// Validate end session URL if set
|
||||
if c.OIDCEndSessionURL != "" && !isValidSecureURL(c.OIDCEndSessionURL) {
|
||||
return fmt.Errorf("oidcEndSessionURL must be a valid HTTPS URL")
|
||||
}
|
||||
|
||||
// Validate post-logout redirect URI if set
|
||||
if c.PostLogoutRedirectURI != "" && c.PostLogoutRedirectURI != "/" {
|
||||
if !isValidSecureURL(c.PostLogoutRedirectURI) && !strings.HasPrefix(c.PostLogoutRedirectURI, "/") {
|
||||
return fmt.Errorf("postLogoutRedirectURI must be either a valid HTTPS URL or start with /")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate rate limit
|
||||
if c.RateLimit < MinRateLimit {
|
||||
return fmt.Errorf("rateLimit must be at least %d", MinRateLimit)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidSecureURL checks if the provided string is a valid HTTPS URL
|
||||
func isValidSecureURL(s string) bool {
|
||||
u, err := url.Parse(s)
|
||||
return err == nil && u.Scheme == "https" && u.Host != ""
|
||||
}
|
||||
|
||||
// isValidLogLevel checks if the provided log level is valid
|
||||
func isValidLogLevel(level string) bool {
|
||||
return level == "debug" || level == "info" || level == "error"
|
||||
}
|
||||
|
||||
// Logger provides structured logging capabilities with different severity levels.
|
||||
// It supports error, info, and debug levels with appropriate output streams
|
||||
// and formatting for each level.
|
||||
type Logger struct {
|
||||
// logError handles error-level messages, writing to stderr
|
||||
logError *log.Logger
|
||||
logInfo *log.Logger
|
||||
// logInfo handles informational messages, writing to stdout
|
||||
logInfo *log.Logger
|
||||
// logDebug handles debug-level messages, writing to stdout when debug is enabled
|
||||
logDebug *log.Logger
|
||||
}
|
||||
|
||||
// NewLogger creates a new Logger with the specified log level.
|
||||
// The log level determines which messages are output:
|
||||
// - "debug": Outputs all messages (debug, info, error)
|
||||
// - "info": Outputs info and error messages
|
||||
// - "error": Outputs only error messages
|
||||
//
|
||||
// Error messages are always written to stderr, while info and debug
|
||||
// messages are written to stdout when enabled.
|
||||
func NewLogger(logLevel string) *Logger {
|
||||
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
|
||||
logError.SetOutput(os.Stderr)
|
||||
logInfo.SetOutput(os.Stdout)
|
||||
|
||||
if logLevel == "debug" || logLevel == "info" {
|
||||
logInfo.SetOutput(os.Stdout)
|
||||
}
|
||||
if logLevel == "debug" {
|
||||
logDebug.SetOutput(os.Stdout)
|
||||
}
|
||||
@@ -94,31 +245,52 @@ func NewLogger(logLevel string) *Logger {
|
||||
}
|
||||
}
|
||||
|
||||
// Info logs an informational message.
|
||||
// These messages are intended for general operational information
|
||||
// and are written to stdout.
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debug logs a debug message.
|
||||
// These messages are only output when debug level logging is enabled
|
||||
// and are intended for detailed troubleshooting information.
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message.
|
||||
// These messages indicate problems that need attention and are
|
||||
// always written to stderr regardless of the log level.
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Infof logs an informational message using Printf formatting.
|
||||
// These messages are intended for general operational information
|
||||
// and are written to stdout.
|
||||
func (l *Logger) Infof(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debugf logs a debug message using Printf formatting.
|
||||
// These messages are only output when debug level logging is enabled
|
||||
// and are intended for detailed troubleshooting information.
|
||||
func (l *Logger) Debugf(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Errorf logs an error message using Printf formatting.
|
||||
// These messages indicate problems that need attention and are
|
||||
// always written to stderr regardless of the log level.
|
||||
func (l *Logger) Errorf(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// handleError writes an error message to both the HTTP response and the error log.
|
||||
// It ensures consistent error handling across the middleware by logging the error
|
||||
// and sending an appropriate HTTP response to the client.
|
||||
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
|
||||
logger.Errorf(message)
|
||||
logger.Error(message)
|
||||
http.Error(w, message, code)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,397 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCreateConfig(t *testing.T) {
|
||||
t.Run("Default Values", func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
|
||||
// Check default scopes
|
||||
expectedScopes := []string{"openid", "profile", "email"}
|
||||
if len(config.Scopes) != len(expectedScopes) {
|
||||
t.Errorf("Expected %d default scopes, got %d", len(expectedScopes), len(config.Scopes))
|
||||
}
|
||||
for i, scope := range expectedScopes {
|
||||
if config.Scopes[i] != scope {
|
||||
t.Errorf("Expected scope %s at position %d, got %s", scope, i, config.Scopes[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Check default log level
|
||||
if config.LogLevel != DefaultLogLevel {
|
||||
t.Errorf("Expected default log level '%s', got '%s'", DefaultLogLevel, config.LogLevel)
|
||||
}
|
||||
|
||||
// Check default rate limit
|
||||
if config.RateLimit != DefaultRateLimit {
|
||||
t.Errorf("Expected default rate limit %d, got %d", DefaultRateLimit, config.RateLimit)
|
||||
}
|
||||
|
||||
// Check ForceHTTPS default
|
||||
if !config.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS to be true by default")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Custom Values Preserved", func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.Scopes = []string{"custom_scope"}
|
||||
config.LogLevel = "debug"
|
||||
config.RateLimit = 50
|
||||
config.ForceHTTPS = false
|
||||
|
||||
// Verify custom values are not overwritten
|
||||
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
|
||||
t.Error("Custom scopes were overwritten")
|
||||
}
|
||||
if config.LogLevel != "debug" {
|
||||
t.Error("Custom log level was overwritten")
|
||||
}
|
||||
if config.RateLimit != 50 {
|
||||
t.Error("Custom rate limit was overwritten")
|
||||
}
|
||||
if config.ForceHTTPS {
|
||||
t.Error("Custom ForceHTTPS value was overwritten")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Empty Config",
|
||||
config: &Config{},
|
||||
expectedError: "providerURL is required",
|
||||
},
|
||||
{
|
||||
name: "Missing CallbackURL",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
},
|
||||
expectedError: "callbackURL is required",
|
||||
},
|
||||
{
|
||||
name: "Missing ClientID",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
},
|
||||
expectedError: "clientID is required",
|
||||
},
|
||||
{
|
||||
name: "Missing ClientSecret",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedError: "clientSecret is required",
|
||||
},
|
||||
{
|
||||
name: "Missing SessionEncryptionKey",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
},
|
||||
expectedError: "sessionEncryptionKey is required",
|
||||
},
|
||||
{
|
||||
name: "Non-HTTPS ProviderURL",
|
||||
config: &Config{
|
||||
ProviderURL: "http://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "encryption-key",
|
||||
},
|
||||
expectedError: "providerURL must be a valid HTTPS URL",
|
||||
},
|
||||
{
|
||||
name: "Invalid CallbackURL",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "callback", // Missing leading slash
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "encryption-key",
|
||||
},
|
||||
expectedError: "callbackURL must start with /",
|
||||
},
|
||||
{
|
||||
name: "Short SessionEncryptionKey",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "short",
|
||||
},
|
||||
expectedError: "sessionEncryptionKey must be at least 32 characters long",
|
||||
},
|
||||
{
|
||||
name: "Low RateLimit",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
RateLimit: 5,
|
||||
},
|
||||
expectedError: "rateLimit must be at least 10",
|
||||
},
|
||||
{
|
||||
name: "Invalid LogLevel",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
LogLevel: "invalid",
|
||||
},
|
||||
expectedError: "logLevel must be one of: debug, info, error",
|
||||
},
|
||||
{
|
||||
name: "Non-HTTPS RevocationURL",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
RevocationURL: "http://revoke.com",
|
||||
},
|
||||
expectedError: "revocationURL must be a valid HTTPS URL",
|
||||
},
|
||||
{
|
||||
name: "Non-HTTPS OIDCEndSessionURL",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
OIDCEndSessionURL: "http://endsession.com",
|
||||
},
|
||||
expectedError: "oidcEndSessionURL must be a valid HTTPS URL",
|
||||
},
|
||||
{
|
||||
name: "Valid Config",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
LogLevel: "debug",
|
||||
RateLimit: 100,
|
||||
RevocationURL: "https://revoke.com",
|
||||
OIDCEndSessionURL: "https://endsession.com",
|
||||
},
|
||||
expectedError: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.config.Validate()
|
||||
if tc.expectedError == "" {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', got nil", tc.expectedError)
|
||||
} else if err.Error() != tc.expectedError {
|
||||
t.Errorf("Expected error '%s', got '%s'", tc.expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger(t *testing.T) {
|
||||
// Capture log output
|
||||
var debugBuf, infoBuf, errorBuf bytes.Buffer
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
logLevel string
|
||||
testFunc func(*Logger)
|
||||
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
|
||||
}{
|
||||
{
|
||||
name: "Debug Level",
|
||||
logLevel: "debug",
|
||||
testFunc: func(l *Logger) {
|
||||
l.Debug("debug message")
|
||||
l.Info("info message")
|
||||
l.Error("error message")
|
||||
},
|
||||
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
||||
if debugOut == "" {
|
||||
t.Error("Expected debug message in output")
|
||||
}
|
||||
if infoOut == "" {
|
||||
t.Error("Expected info message in output")
|
||||
}
|
||||
if errorOut == "" {
|
||||
t.Error("Expected error message in output")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Info Level",
|
||||
logLevel: "info",
|
||||
testFunc: func(l *Logger) {
|
||||
l.Debug("debug message")
|
||||
l.Info("info message")
|
||||
l.Error("error message")
|
||||
},
|
||||
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
||||
if debugOut != "" {
|
||||
t.Error("Did not expect debug message in output")
|
||||
}
|
||||
if infoOut == "" {
|
||||
t.Error("Expected info message in output")
|
||||
}
|
||||
if errorOut == "" {
|
||||
t.Error("Expected error message in output")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Error Level",
|
||||
logLevel: "error",
|
||||
testFunc: func(l *Logger) {
|
||||
l.Debug("debug message")
|
||||
l.Info("info message")
|
||||
l.Error("error message")
|
||||
},
|
||||
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
||||
if debugOut != "" {
|
||||
t.Error("Did not expect debug message in output")
|
||||
}
|
||||
if infoOut != "" {
|
||||
t.Error("Did not expect info message in output")
|
||||
}
|
||||
if errorOut == "" {
|
||||
t.Error("Expected error message in output")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Printf Methods",
|
||||
logLevel: "debug",
|
||||
testFunc: func(l *Logger) {
|
||||
l.Debugf("debug %s", "formatted")
|
||||
l.Infof("info %s", "formatted")
|
||||
l.Errorf("error %s", "formatted")
|
||||
},
|
||||
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
||||
if !bytes.Contains([]byte(debugOut), []byte("debug formatted")) {
|
||||
t.Error("Expected formatted debug message")
|
||||
}
|
||||
if !bytes.Contains([]byte(infoOut), []byte("info formatted")) {
|
||||
t.Error("Expected formatted info message")
|
||||
}
|
||||
if !bytes.Contains([]byte(errorOut), []byte("error formatted")) {
|
||||
t.Error("Expected formatted error message")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Reset buffers
|
||||
debugBuf.Reset()
|
||||
infoBuf.Reset()
|
||||
errorBuf.Reset()
|
||||
|
||||
// Create logger with test buffers
|
||||
logger := NewLogger(tc.logLevel)
|
||||
logger.logError.SetOutput(&errorBuf)
|
||||
|
||||
if tc.logLevel == "debug" || tc.logLevel == "info" {
|
||||
logger.logInfo.SetOutput(&infoBuf)
|
||||
}
|
||||
if tc.logLevel == "debug" {
|
||||
logger.logDebug.SetOutput(&debugBuf)
|
||||
}
|
||||
|
||||
// Run test
|
||||
tc.testFunc(logger)
|
||||
|
||||
// Check results
|
||||
tc.checkFunc(t, debugBuf.String(), infoBuf.String(), errorBuf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleError(t *testing.T) {
|
||||
// Create a test logger with captured output
|
||||
var errorBuf bytes.Buffer
|
||||
logger := &Logger{
|
||||
logError: log.New(&errorBuf, "ERROR: ", log.Ldate|log.Ltime),
|
||||
}
|
||||
logger.logError.SetOutput(&errorBuf)
|
||||
|
||||
// Create a test response recorder
|
||||
rr := &testResponseRecorder{
|
||||
headers: make(map[string][]string),
|
||||
}
|
||||
|
||||
// Test error handling
|
||||
message := "test error message"
|
||||
code := 400
|
||||
handleError(rr, message, code, logger)
|
||||
|
||||
// Check response code
|
||||
if rr.statusCode != code {
|
||||
t.Errorf("Expected status code %d, got %d", code, rr.statusCode)
|
||||
}
|
||||
|
||||
// Check response body
|
||||
expectedBody := message + "\n"
|
||||
if rr.body != expectedBody {
|
||||
t.Errorf("Expected body %q, got %q", expectedBody, rr.body)
|
||||
}
|
||||
|
||||
// Check error was logged
|
||||
if !bytes.Contains(errorBuf.Bytes(), []byte(message)) {
|
||||
t.Error("Error message was not logged")
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper types
|
||||
type testResponseRecorder struct {
|
||||
statusCode int
|
||||
body string
|
||||
headers map[string][]string
|
||||
}
|
||||
|
||||
func (r *testResponseRecorder) Header() http.Header {
|
||||
return r.headers
|
||||
}
|
||||
|
||||
func (r *testResponseRecorder) Write(b []byte) (int, error) {
|
||||
r.body = string(b)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (r *testResponseRecorder) WriteHeader(code int) {
|
||||
r.statusCode = code
|
||||
}
|
||||
-15
@@ -1,15 +0,0 @@
|
||||
ISC License
|
||||
|
||||
Copyright (c) 2012-2016 Dave Collins <dave@davec.name>
|
||||
|
||||
Permission to use, copy, modify, and/or distribute this software for any
|
||||
purpose with or without fee is hereby granted, provided that the above
|
||||
copyright notice and this permission notice appear in all copies.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
-145
@@ -1,145 +0,0 @@
|
||||
// Copyright (c) 2015-2016 Dave Collins <dave@davec.name>
|
||||
//
|
||||
// Permission to use, copy, modify, and distribute this software for any
|
||||
// purpose with or without fee is hereby granted, provided that the above
|
||||
// copyright notice and this permission notice appear in all copies.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
// NOTE: Due to the following build constraints, this file will only be compiled
|
||||
// when the code is not running on Google App Engine, compiled by GopherJS, and
|
||||
// "-tags safe" is not added to the go build command line. The "disableunsafe"
|
||||
// tag is deprecated and thus should not be used.
|
||||
// Go versions prior to 1.4 are disabled because they use a different layout
|
||||
// for interfaces which make the implementation of unsafeReflectValue more complex.
|
||||
// +build !js,!appengine,!safe,!disableunsafe,go1.4
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// UnsafeDisabled is a build-time constant which specifies whether or
|
||||
// not access to the unsafe package is available.
|
||||
UnsafeDisabled = false
|
||||
|
||||
// ptrSize is the size of a pointer on the current arch.
|
||||
ptrSize = unsafe.Sizeof((*byte)(nil))
|
||||
)
|
||||
|
||||
type flag uintptr
|
||||
|
||||
var (
|
||||
// flagRO indicates whether the value field of a reflect.Value
|
||||
// is read-only.
|
||||
flagRO flag
|
||||
|
||||
// flagAddr indicates whether the address of the reflect.Value's
|
||||
// value may be taken.
|
||||
flagAddr flag
|
||||
)
|
||||
|
||||
// flagKindMask holds the bits that make up the kind
|
||||
// part of the flags field. In all the supported versions,
|
||||
// it is in the lower 5 bits.
|
||||
const flagKindMask = flag(0x1f)
|
||||
|
||||
// Different versions of Go have used different
|
||||
// bit layouts for the flags type. This table
|
||||
// records the known combinations.
|
||||
var okFlags = []struct {
|
||||
ro, addr flag
|
||||
}{{
|
||||
// From Go 1.4 to 1.5
|
||||
ro: 1 << 5,
|
||||
addr: 1 << 7,
|
||||
}, {
|
||||
// Up to Go tip.
|
||||
ro: 1<<5 | 1<<6,
|
||||
addr: 1 << 8,
|
||||
}}
|
||||
|
||||
var flagValOffset = func() uintptr {
|
||||
field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag")
|
||||
if !ok {
|
||||
panic("reflect.Value has no flag field")
|
||||
}
|
||||
return field.Offset
|
||||
}()
|
||||
|
||||
// flagField returns a pointer to the flag field of a reflect.Value.
|
||||
func flagField(v *reflect.Value) *flag {
|
||||
return (*flag)(unsafe.Pointer(uintptr(unsafe.Pointer(v)) + flagValOffset))
|
||||
}
|
||||
|
||||
// unsafeReflectValue converts the passed reflect.Value into a one that bypasses
|
||||
// the typical safety restrictions preventing access to unaddressable and
|
||||
// unexported data. It works by digging the raw pointer to the underlying
|
||||
// value out of the protected value and generating a new unprotected (unsafe)
|
||||
// reflect.Value to it.
|
||||
//
|
||||
// This allows us to check for implementations of the Stringer and error
|
||||
// interfaces to be used for pretty printing ordinarily unaddressable and
|
||||
// inaccessible values such as unexported struct fields.
|
||||
func unsafeReflectValue(v reflect.Value) reflect.Value {
|
||||
if !v.IsValid() || (v.CanInterface() && v.CanAddr()) {
|
||||
return v
|
||||
}
|
||||
flagFieldPtr := flagField(&v)
|
||||
*flagFieldPtr &^= flagRO
|
||||
*flagFieldPtr |= flagAddr
|
||||
return v
|
||||
}
|
||||
|
||||
// Sanity checks against future reflect package changes
|
||||
// to the type or semantics of the Value.flag field.
|
||||
func init() {
|
||||
field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag")
|
||||
if !ok {
|
||||
panic("reflect.Value has no flag field")
|
||||
}
|
||||
if field.Type.Kind() != reflect.TypeOf(flag(0)).Kind() {
|
||||
panic("reflect.Value flag field has changed kind")
|
||||
}
|
||||
type t0 int
|
||||
var t struct {
|
||||
A t0
|
||||
// t0 will have flagEmbedRO set.
|
||||
t0
|
||||
// a will have flagStickyRO set
|
||||
a t0
|
||||
}
|
||||
vA := reflect.ValueOf(t).FieldByName("A")
|
||||
va := reflect.ValueOf(t).FieldByName("a")
|
||||
vt0 := reflect.ValueOf(t).FieldByName("t0")
|
||||
|
||||
// Infer flagRO from the difference between the flags
|
||||
// for the (otherwise identical) fields in t.
|
||||
flagPublic := *flagField(&vA)
|
||||
flagWithRO := *flagField(&va) | *flagField(&vt0)
|
||||
flagRO = flagPublic ^ flagWithRO
|
||||
|
||||
// Infer flagAddr from the difference between a value
|
||||
// taken from a pointer and not.
|
||||
vPtrA := reflect.ValueOf(&t).Elem().FieldByName("A")
|
||||
flagNoPtr := *flagField(&vA)
|
||||
flagPtr := *flagField(&vPtrA)
|
||||
flagAddr = flagNoPtr ^ flagPtr
|
||||
|
||||
// Check that the inferred flags tally with one of the known versions.
|
||||
for _, f := range okFlags {
|
||||
if flagRO == f.ro && flagAddr == f.addr {
|
||||
return
|
||||
}
|
||||
}
|
||||
panic("reflect.Value read-only flag has changed semantics")
|
||||
}
|
||||
-38
@@ -1,38 +0,0 @@
|
||||
// Copyright (c) 2015-2016 Dave Collins <dave@davec.name>
|
||||
//
|
||||
// Permission to use, copy, modify, and distribute this software for any
|
||||
// purpose with or without fee is hereby granted, provided that the above
|
||||
// copyright notice and this permission notice appear in all copies.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
// NOTE: Due to the following build constraints, this file will only be compiled
|
||||
// when the code is running on Google App Engine, compiled by GopherJS, or
|
||||
// "-tags safe" is added to the go build command line. The "disableunsafe"
|
||||
// tag is deprecated and thus should not be used.
|
||||
// +build js appengine safe disableunsafe !go1.4
|
||||
|
||||
package spew
|
||||
|
||||
import "reflect"
|
||||
|
||||
const (
|
||||
// UnsafeDisabled is a build-time constant which specifies whether or
|
||||
// not access to the unsafe package is available.
|
||||
UnsafeDisabled = true
|
||||
)
|
||||
|
||||
// unsafeReflectValue typically converts the passed reflect.Value into a one
|
||||
// that bypasses the typical safety restrictions preventing access to
|
||||
// unaddressable and unexported data. However, doing this relies on access to
|
||||
// the unsafe package. This is a stub version which simply returns the passed
|
||||
// reflect.Value when the unsafe package is not available.
|
||||
func unsafeReflectValue(v reflect.Value) reflect.Value {
|
||||
return v
|
||||
}
|
||||
-341
@@ -1,341 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Some constants in the form of bytes to avoid string overhead. This mirrors
|
||||
// the technique used in the fmt package.
|
||||
var (
|
||||
panicBytes = []byte("(PANIC=")
|
||||
plusBytes = []byte("+")
|
||||
iBytes = []byte("i")
|
||||
trueBytes = []byte("true")
|
||||
falseBytes = []byte("false")
|
||||
interfaceBytes = []byte("(interface {})")
|
||||
commaNewlineBytes = []byte(",\n")
|
||||
newlineBytes = []byte("\n")
|
||||
openBraceBytes = []byte("{")
|
||||
openBraceNewlineBytes = []byte("{\n")
|
||||
closeBraceBytes = []byte("}")
|
||||
asteriskBytes = []byte("*")
|
||||
colonBytes = []byte(":")
|
||||
colonSpaceBytes = []byte(": ")
|
||||
openParenBytes = []byte("(")
|
||||
closeParenBytes = []byte(")")
|
||||
spaceBytes = []byte(" ")
|
||||
pointerChainBytes = []byte("->")
|
||||
nilAngleBytes = []byte("<nil>")
|
||||
maxNewlineBytes = []byte("<max depth reached>\n")
|
||||
maxShortBytes = []byte("<max>")
|
||||
circularBytes = []byte("<already shown>")
|
||||
circularShortBytes = []byte("<shown>")
|
||||
invalidAngleBytes = []byte("<invalid>")
|
||||
openBracketBytes = []byte("[")
|
||||
closeBracketBytes = []byte("]")
|
||||
percentBytes = []byte("%")
|
||||
precisionBytes = []byte(".")
|
||||
openAngleBytes = []byte("<")
|
||||
closeAngleBytes = []byte(">")
|
||||
openMapBytes = []byte("map[")
|
||||
closeMapBytes = []byte("]")
|
||||
lenEqualsBytes = []byte("len=")
|
||||
capEqualsBytes = []byte("cap=")
|
||||
)
|
||||
|
||||
// hexDigits is used to map a decimal value to a hex digit.
|
||||
var hexDigits = "0123456789abcdef"
|
||||
|
||||
// catchPanic handles any panics that might occur during the handleMethods
|
||||
// calls.
|
||||
func catchPanic(w io.Writer, v reflect.Value) {
|
||||
if err := recover(); err != nil {
|
||||
w.Write(panicBytes)
|
||||
fmt.Fprintf(w, "%v", err)
|
||||
w.Write(closeParenBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// handleMethods attempts to call the Error and String methods on the underlying
|
||||
// type the passed reflect.Value represents and outputes the result to Writer w.
|
||||
//
|
||||
// It handles panics in any called methods by catching and displaying the error
|
||||
// as the formatted value.
|
||||
func handleMethods(cs *ConfigState, w io.Writer, v reflect.Value) (handled bool) {
|
||||
// We need an interface to check if the type implements the error or
|
||||
// Stringer interface. However, the reflect package won't give us an
|
||||
// interface on certain things like unexported struct fields in order
|
||||
// to enforce visibility rules. We use unsafe, when it's available,
|
||||
// to bypass these restrictions since this package does not mutate the
|
||||
// values.
|
||||
if !v.CanInterface() {
|
||||
if UnsafeDisabled {
|
||||
return false
|
||||
}
|
||||
|
||||
v = unsafeReflectValue(v)
|
||||
}
|
||||
|
||||
// Choose whether or not to do error and Stringer interface lookups against
|
||||
// the base type or a pointer to the base type depending on settings.
|
||||
// Technically calling one of these methods with a pointer receiver can
|
||||
// mutate the value, however, types which choose to satisify an error or
|
||||
// Stringer interface with a pointer receiver should not be mutating their
|
||||
// state inside these interface methods.
|
||||
if !cs.DisablePointerMethods && !UnsafeDisabled && !v.CanAddr() {
|
||||
v = unsafeReflectValue(v)
|
||||
}
|
||||
if v.CanAddr() {
|
||||
v = v.Addr()
|
||||
}
|
||||
|
||||
// Is it an error or Stringer?
|
||||
switch iface := v.Interface().(type) {
|
||||
case error:
|
||||
defer catchPanic(w, v)
|
||||
if cs.ContinueOnMethod {
|
||||
w.Write(openParenBytes)
|
||||
w.Write([]byte(iface.Error()))
|
||||
w.Write(closeParenBytes)
|
||||
w.Write(spaceBytes)
|
||||
return false
|
||||
}
|
||||
|
||||
w.Write([]byte(iface.Error()))
|
||||
return true
|
||||
|
||||
case fmt.Stringer:
|
||||
defer catchPanic(w, v)
|
||||
if cs.ContinueOnMethod {
|
||||
w.Write(openParenBytes)
|
||||
w.Write([]byte(iface.String()))
|
||||
w.Write(closeParenBytes)
|
||||
w.Write(spaceBytes)
|
||||
return false
|
||||
}
|
||||
w.Write([]byte(iface.String()))
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// printBool outputs a boolean value as true or false to Writer w.
|
||||
func printBool(w io.Writer, val bool) {
|
||||
if val {
|
||||
w.Write(trueBytes)
|
||||
} else {
|
||||
w.Write(falseBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// printInt outputs a signed integer value to Writer w.
|
||||
func printInt(w io.Writer, val int64, base int) {
|
||||
w.Write([]byte(strconv.FormatInt(val, base)))
|
||||
}
|
||||
|
||||
// printUint outputs an unsigned integer value to Writer w.
|
||||
func printUint(w io.Writer, val uint64, base int) {
|
||||
w.Write([]byte(strconv.FormatUint(val, base)))
|
||||
}
|
||||
|
||||
// printFloat outputs a floating point value using the specified precision,
|
||||
// which is expected to be 32 or 64bit, to Writer w.
|
||||
func printFloat(w io.Writer, val float64, precision int) {
|
||||
w.Write([]byte(strconv.FormatFloat(val, 'g', -1, precision)))
|
||||
}
|
||||
|
||||
// printComplex outputs a complex value using the specified float precision
|
||||
// for the real and imaginary parts to Writer w.
|
||||
func printComplex(w io.Writer, c complex128, floatPrecision int) {
|
||||
r := real(c)
|
||||
w.Write(openParenBytes)
|
||||
w.Write([]byte(strconv.FormatFloat(r, 'g', -1, floatPrecision)))
|
||||
i := imag(c)
|
||||
if i >= 0 {
|
||||
w.Write(plusBytes)
|
||||
}
|
||||
w.Write([]byte(strconv.FormatFloat(i, 'g', -1, floatPrecision)))
|
||||
w.Write(iBytes)
|
||||
w.Write(closeParenBytes)
|
||||
}
|
||||
|
||||
// printHexPtr outputs a uintptr formatted as hexadecimal with a leading '0x'
|
||||
// prefix to Writer w.
|
||||
func printHexPtr(w io.Writer, p uintptr) {
|
||||
// Null pointer.
|
||||
num := uint64(p)
|
||||
if num == 0 {
|
||||
w.Write(nilAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// Max uint64 is 16 bytes in hex + 2 bytes for '0x' prefix
|
||||
buf := make([]byte, 18)
|
||||
|
||||
// It's simpler to construct the hex string right to left.
|
||||
base := uint64(16)
|
||||
i := len(buf) - 1
|
||||
for num >= base {
|
||||
buf[i] = hexDigits[num%base]
|
||||
num /= base
|
||||
i--
|
||||
}
|
||||
buf[i] = hexDigits[num]
|
||||
|
||||
// Add '0x' prefix.
|
||||
i--
|
||||
buf[i] = 'x'
|
||||
i--
|
||||
buf[i] = '0'
|
||||
|
||||
// Strip unused leading bytes.
|
||||
buf = buf[i:]
|
||||
w.Write(buf)
|
||||
}
|
||||
|
||||
// valuesSorter implements sort.Interface to allow a slice of reflect.Value
|
||||
// elements to be sorted.
|
||||
type valuesSorter struct {
|
||||
values []reflect.Value
|
||||
strings []string // either nil or same len and values
|
||||
cs *ConfigState
|
||||
}
|
||||
|
||||
// newValuesSorter initializes a valuesSorter instance, which holds a set of
|
||||
// surrogate keys on which the data should be sorted. It uses flags in
|
||||
// ConfigState to decide if and how to populate those surrogate keys.
|
||||
func newValuesSorter(values []reflect.Value, cs *ConfigState) sort.Interface {
|
||||
vs := &valuesSorter{values: values, cs: cs}
|
||||
if canSortSimply(vs.values[0].Kind()) {
|
||||
return vs
|
||||
}
|
||||
if !cs.DisableMethods {
|
||||
vs.strings = make([]string, len(values))
|
||||
for i := range vs.values {
|
||||
b := bytes.Buffer{}
|
||||
if !handleMethods(cs, &b, vs.values[i]) {
|
||||
vs.strings = nil
|
||||
break
|
||||
}
|
||||
vs.strings[i] = b.String()
|
||||
}
|
||||
}
|
||||
if vs.strings == nil && cs.SpewKeys {
|
||||
vs.strings = make([]string, len(values))
|
||||
for i := range vs.values {
|
||||
vs.strings[i] = Sprintf("%#v", vs.values[i].Interface())
|
||||
}
|
||||
}
|
||||
return vs
|
||||
}
|
||||
|
||||
// canSortSimply tests whether a reflect.Kind is a primitive that can be sorted
|
||||
// directly, or whether it should be considered for sorting by surrogate keys
|
||||
// (if the ConfigState allows it).
|
||||
func canSortSimply(kind reflect.Kind) bool {
|
||||
// This switch parallels valueSortLess, except for the default case.
|
||||
switch kind {
|
||||
case reflect.Bool:
|
||||
return true
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
return true
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
return true
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return true
|
||||
case reflect.String:
|
||||
return true
|
||||
case reflect.Uintptr:
|
||||
return true
|
||||
case reflect.Array:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Len returns the number of values in the slice. It is part of the
|
||||
// sort.Interface implementation.
|
||||
func (s *valuesSorter) Len() int {
|
||||
return len(s.values)
|
||||
}
|
||||
|
||||
// Swap swaps the values at the passed indices. It is part of the
|
||||
// sort.Interface implementation.
|
||||
func (s *valuesSorter) Swap(i, j int) {
|
||||
s.values[i], s.values[j] = s.values[j], s.values[i]
|
||||
if s.strings != nil {
|
||||
s.strings[i], s.strings[j] = s.strings[j], s.strings[i]
|
||||
}
|
||||
}
|
||||
|
||||
// valueSortLess returns whether the first value should sort before the second
|
||||
// value. It is used by valueSorter.Less as part of the sort.Interface
|
||||
// implementation.
|
||||
func valueSortLess(a, b reflect.Value) bool {
|
||||
switch a.Kind() {
|
||||
case reflect.Bool:
|
||||
return !a.Bool() && b.Bool()
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
return a.Int() < b.Int()
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
return a.Uint() < b.Uint()
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return a.Float() < b.Float()
|
||||
case reflect.String:
|
||||
return a.String() < b.String()
|
||||
case reflect.Uintptr:
|
||||
return a.Uint() < b.Uint()
|
||||
case reflect.Array:
|
||||
// Compare the contents of both arrays.
|
||||
l := a.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
av := a.Index(i)
|
||||
bv := b.Index(i)
|
||||
if av.Interface() == bv.Interface() {
|
||||
continue
|
||||
}
|
||||
return valueSortLess(av, bv)
|
||||
}
|
||||
}
|
||||
return a.String() < b.String()
|
||||
}
|
||||
|
||||
// Less returns whether the value at index i should sort before the
|
||||
// value at index j. It is part of the sort.Interface implementation.
|
||||
func (s *valuesSorter) Less(i, j int) bool {
|
||||
if s.strings == nil {
|
||||
return valueSortLess(s.values[i], s.values[j])
|
||||
}
|
||||
return s.strings[i] < s.strings[j]
|
||||
}
|
||||
|
||||
// sortValues is a sort function that handles both native types and any type that
|
||||
// can be converted to error or Stringer. Other inputs are sorted according to
|
||||
// their Value.String() value to ensure display stability.
|
||||
func sortValues(values []reflect.Value, cs *ConfigState) {
|
||||
if len(values) == 0 {
|
||||
return
|
||||
}
|
||||
sort.Sort(newValuesSorter(values, cs))
|
||||
}
|
||||
-306
@@ -1,306 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// ConfigState houses the configuration options used by spew to format and
|
||||
// display values. There is a global instance, Config, that is used to control
|
||||
// all top-level Formatter and Dump functionality. Each ConfigState instance
|
||||
// provides methods equivalent to the top-level functions.
|
||||
//
|
||||
// The zero value for ConfigState provides no indentation. You would typically
|
||||
// want to set it to a space or a tab.
|
||||
//
|
||||
// Alternatively, you can use NewDefaultConfig to get a ConfigState instance
|
||||
// with default settings. See the documentation of NewDefaultConfig for default
|
||||
// values.
|
||||
type ConfigState struct {
|
||||
// Indent specifies the string to use for each indentation level. The
|
||||
// global config instance that all top-level functions use set this to a
|
||||
// single space by default. If you would like more indentation, you might
|
||||
// set this to a tab with "\t" or perhaps two spaces with " ".
|
||||
Indent string
|
||||
|
||||
// MaxDepth controls the maximum number of levels to descend into nested
|
||||
// data structures. The default, 0, means there is no limit.
|
||||
//
|
||||
// NOTE: Circular data structures are properly detected, so it is not
|
||||
// necessary to set this value unless you specifically want to limit deeply
|
||||
// nested data structures.
|
||||
MaxDepth int
|
||||
|
||||
// DisableMethods specifies whether or not error and Stringer interfaces are
|
||||
// invoked for types that implement them.
|
||||
DisableMethods bool
|
||||
|
||||
// DisablePointerMethods specifies whether or not to check for and invoke
|
||||
// error and Stringer interfaces on types which only accept a pointer
|
||||
// receiver when the current type is not a pointer.
|
||||
//
|
||||
// NOTE: This might be an unsafe action since calling one of these methods
|
||||
// with a pointer receiver could technically mutate the value, however,
|
||||
// in practice, types which choose to satisify an error or Stringer
|
||||
// interface with a pointer receiver should not be mutating their state
|
||||
// inside these interface methods. As a result, this option relies on
|
||||
// access to the unsafe package, so it will not have any effect when
|
||||
// running in environments without access to the unsafe package such as
|
||||
// Google App Engine or with the "safe" build tag specified.
|
||||
DisablePointerMethods bool
|
||||
|
||||
// DisablePointerAddresses specifies whether to disable the printing of
|
||||
// pointer addresses. This is useful when diffing data structures in tests.
|
||||
DisablePointerAddresses bool
|
||||
|
||||
// DisableCapacities specifies whether to disable the printing of capacities
|
||||
// for arrays, slices, maps and channels. This is useful when diffing
|
||||
// data structures in tests.
|
||||
DisableCapacities bool
|
||||
|
||||
// ContinueOnMethod specifies whether or not recursion should continue once
|
||||
// a custom error or Stringer interface is invoked. The default, false,
|
||||
// means it will print the results of invoking the custom error or Stringer
|
||||
// interface and return immediately instead of continuing to recurse into
|
||||
// the internals of the data type.
|
||||
//
|
||||
// NOTE: This flag does not have any effect if method invocation is disabled
|
||||
// via the DisableMethods or DisablePointerMethods options.
|
||||
ContinueOnMethod bool
|
||||
|
||||
// SortKeys specifies map keys should be sorted before being printed. Use
|
||||
// this to have a more deterministic, diffable output. Note that only
|
||||
// native types (bool, int, uint, floats, uintptr and string) and types
|
||||
// that support the error or Stringer interfaces (if methods are
|
||||
// enabled) are supported, with other types sorted according to the
|
||||
// reflect.Value.String() output which guarantees display stability.
|
||||
SortKeys bool
|
||||
|
||||
// SpewKeys specifies that, as a last resort attempt, map keys should
|
||||
// be spewed to strings and sorted by those strings. This is only
|
||||
// considered if SortKeys is true.
|
||||
SpewKeys bool
|
||||
}
|
||||
|
||||
// Config is the active configuration of the top-level functions.
|
||||
// The configuration can be changed by modifying the contents of spew.Config.
|
||||
var Config = ConfigState{Indent: " "}
|
||||
|
||||
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the formatted string as a value that satisfies error. See NewFormatter
|
||||
// for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Errorf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Errorf(format string, a ...interface{}) (err error) {
|
||||
return fmt.Errorf(format, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprint(w, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Fprint(w io.Writer, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprint(w, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprintf(w, format, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprintf(w, format, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
|
||||
// passed with a Formatter interface returned by c.NewFormatter. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprintln(w, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprintln(w, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Print is a wrapper for fmt.Print that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Print(c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Print(a ...interface{}) (n int, err error) {
|
||||
return fmt.Print(c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Printf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Printf(format string, a ...interface{}) (n int, err error) {
|
||||
return fmt.Printf(format, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Println is a wrapper for fmt.Println that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Println(c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Println(a ...interface{}) (n int, err error) {
|
||||
return fmt.Println(c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprint(c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Sprint(a ...interface{}) string {
|
||||
return fmt.Sprint(c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprintf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Sprintf(format string, a ...interface{}) string {
|
||||
return fmt.Sprintf(format, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
|
||||
// were passed with a Formatter interface returned by c.NewFormatter. It
|
||||
// returns the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprintln(c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Sprintln(a ...interface{}) string {
|
||||
return fmt.Sprintln(c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
/*
|
||||
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
|
||||
interface. As a result, it integrates cleanly with standard fmt package
|
||||
printing functions. The formatter is useful for inline printing of smaller data
|
||||
types similar to the standard %v format specifier.
|
||||
|
||||
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||
addresses), %#v (adds types), and %#+v (adds types and pointer addresses) verb
|
||||
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||
the width and precision arguments (however they will still work on the format
|
||||
specifiers not handled by the custom formatter).
|
||||
|
||||
Typically this function shouldn't be called directly. It is much easier to make
|
||||
use of the custom formatter by calling one of the convenience functions such as
|
||||
c.Printf, c.Println, or c.Printf.
|
||||
*/
|
||||
func (c *ConfigState) NewFormatter(v interface{}) fmt.Formatter {
|
||||
return newFormatter(c, v)
|
||||
}
|
||||
|
||||
// Fdump formats and displays the passed arguments to io.Writer w. It formats
|
||||
// exactly the same as Dump.
|
||||
func (c *ConfigState) Fdump(w io.Writer, a ...interface{}) {
|
||||
fdump(c, w, a...)
|
||||
}
|
||||
|
||||
/*
|
||||
Dump displays the passed parameters to standard out with newlines, customizable
|
||||
indentation, and additional debug information such as complete types and all
|
||||
pointer addresses used to indirect to the final value. It provides the
|
||||
following features over the built-in printing facilities provided by the fmt
|
||||
package:
|
||||
|
||||
* Pointers are dereferenced and followed
|
||||
* Circular data structures are detected and handled properly
|
||||
* Custom Stringer/error interfaces are optionally invoked, including
|
||||
on unexported types
|
||||
* Custom types which only implement the Stringer/error interfaces via
|
||||
a pointer receiver are optionally invoked when passing non-pointer
|
||||
variables
|
||||
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||
includes offsets, byte values in hex, and ASCII output
|
||||
|
||||
The configuration options are controlled by modifying the public members
|
||||
of c. See ConfigState for options documentation.
|
||||
|
||||
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
|
||||
get the formatted result as a string.
|
||||
*/
|
||||
func (c *ConfigState) Dump(a ...interface{}) {
|
||||
fdump(c, os.Stdout, a...)
|
||||
}
|
||||
|
||||
// Sdump returns a string with the passed arguments formatted exactly the same
|
||||
// as Dump.
|
||||
func (c *ConfigState) Sdump(a ...interface{}) string {
|
||||
var buf bytes.Buffer
|
||||
fdump(c, &buf, a...)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// convertArgs accepts a slice of arguments and returns a slice of the same
|
||||
// length with each argument converted to a spew Formatter interface using
|
||||
// the ConfigState associated with s.
|
||||
func (c *ConfigState) convertArgs(args []interface{}) (formatters []interface{}) {
|
||||
formatters = make([]interface{}, len(args))
|
||||
for index, arg := range args {
|
||||
formatters[index] = newFormatter(c, arg)
|
||||
}
|
||||
return formatters
|
||||
}
|
||||
|
||||
// NewDefaultConfig returns a ConfigState with the following default settings.
|
||||
//
|
||||
// Indent: " "
|
||||
// MaxDepth: 0
|
||||
// DisableMethods: false
|
||||
// DisablePointerMethods: false
|
||||
// ContinueOnMethod: false
|
||||
// SortKeys: false
|
||||
func NewDefaultConfig() *ConfigState {
|
||||
return &ConfigState{Indent: " "}
|
||||
}
|
||||
-211
@@ -1,211 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
/*
|
||||
Package spew implements a deep pretty printer for Go data structures to aid in
|
||||
debugging.
|
||||
|
||||
A quick overview of the additional features spew provides over the built-in
|
||||
printing facilities for Go data types are as follows:
|
||||
|
||||
* Pointers are dereferenced and followed
|
||||
* Circular data structures are detected and handled properly
|
||||
* Custom Stringer/error interfaces are optionally invoked, including
|
||||
on unexported types
|
||||
* Custom types which only implement the Stringer/error interfaces via
|
||||
a pointer receiver are optionally invoked when passing non-pointer
|
||||
variables
|
||||
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||
includes offsets, byte values in hex, and ASCII output (only when using
|
||||
Dump style)
|
||||
|
||||
There are two different approaches spew allows for dumping Go data structures:
|
||||
|
||||
* Dump style which prints with newlines, customizable indentation,
|
||||
and additional debug information such as types and all pointer addresses
|
||||
used to indirect to the final value
|
||||
* A custom Formatter interface that integrates cleanly with the standard fmt
|
||||
package and replaces %v, %+v, %#v, and %#+v to provide inline printing
|
||||
similar to the default %v while providing the additional functionality
|
||||
outlined above and passing unsupported format verbs such as %x and %q
|
||||
along to fmt
|
||||
|
||||
Quick Start
|
||||
|
||||
This section demonstrates how to quickly get started with spew. See the
|
||||
sections below for further details on formatting and configuration options.
|
||||
|
||||
To dump a variable with full newlines, indentation, type, and pointer
|
||||
information use Dump, Fdump, or Sdump:
|
||||
spew.Dump(myVar1, myVar2, ...)
|
||||
spew.Fdump(someWriter, myVar1, myVar2, ...)
|
||||
str := spew.Sdump(myVar1, myVar2, ...)
|
||||
|
||||
Alternatively, if you would prefer to use format strings with a compacted inline
|
||||
printing style, use the convenience wrappers Printf, Fprintf, etc with
|
||||
%v (most compact), %+v (adds pointer addresses), %#v (adds types), or
|
||||
%#+v (adds types and pointer addresses):
|
||||
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||
spew.Fprintf(someWriter, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||
spew.Fprintf(someWriter, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||
|
||||
Configuration Options
|
||||
|
||||
Configuration of spew is handled by fields in the ConfigState type. For
|
||||
convenience, all of the top-level functions use a global state available
|
||||
via the spew.Config global.
|
||||
|
||||
It is also possible to create a ConfigState instance that provides methods
|
||||
equivalent to the top-level functions. This allows concurrent configuration
|
||||
options. See the ConfigState documentation for more details.
|
||||
|
||||
The following configuration options are available:
|
||||
* Indent
|
||||
String to use for each indentation level for Dump functions.
|
||||
It is a single space by default. A popular alternative is "\t".
|
||||
|
||||
* MaxDepth
|
||||
Maximum number of levels to descend into nested data structures.
|
||||
There is no limit by default.
|
||||
|
||||
* DisableMethods
|
||||
Disables invocation of error and Stringer interface methods.
|
||||
Method invocation is enabled by default.
|
||||
|
||||
* DisablePointerMethods
|
||||
Disables invocation of error and Stringer interface methods on types
|
||||
which only accept pointer receivers from non-pointer variables.
|
||||
Pointer method invocation is enabled by default.
|
||||
|
||||
* DisablePointerAddresses
|
||||
DisablePointerAddresses specifies whether to disable the printing of
|
||||
pointer addresses. This is useful when diffing data structures in tests.
|
||||
|
||||
* DisableCapacities
|
||||
DisableCapacities specifies whether to disable the printing of
|
||||
capacities for arrays, slices, maps and channels. This is useful when
|
||||
diffing data structures in tests.
|
||||
|
||||
* ContinueOnMethod
|
||||
Enables recursion into types after invoking error and Stringer interface
|
||||
methods. Recursion after method invocation is disabled by default.
|
||||
|
||||
* SortKeys
|
||||
Specifies map keys should be sorted before being printed. Use
|
||||
this to have a more deterministic, diffable output. Note that
|
||||
only native types (bool, int, uint, floats, uintptr and string)
|
||||
and types which implement error or Stringer interfaces are
|
||||
supported with other types sorted according to the
|
||||
reflect.Value.String() output which guarantees display
|
||||
stability. Natural map order is used by default.
|
||||
|
||||
* SpewKeys
|
||||
Specifies that, as a last resort attempt, map keys should be
|
||||
spewed to strings and sorted by those strings. This is only
|
||||
considered if SortKeys is true.
|
||||
|
||||
Dump Usage
|
||||
|
||||
Simply call spew.Dump with a list of variables you want to dump:
|
||||
|
||||
spew.Dump(myVar1, myVar2, ...)
|
||||
|
||||
You may also call spew.Fdump if you would prefer to output to an arbitrary
|
||||
io.Writer. For example, to dump to standard error:
|
||||
|
||||
spew.Fdump(os.Stderr, myVar1, myVar2, ...)
|
||||
|
||||
A third option is to call spew.Sdump to get the formatted output as a string:
|
||||
|
||||
str := spew.Sdump(myVar1, myVar2, ...)
|
||||
|
||||
Sample Dump Output
|
||||
|
||||
See the Dump example for details on the setup of the types and variables being
|
||||
shown here.
|
||||
|
||||
(main.Foo) {
|
||||
unexportedField: (*main.Bar)(0xf84002e210)({
|
||||
flag: (main.Flag) flagTwo,
|
||||
data: (uintptr) <nil>
|
||||
}),
|
||||
ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||
(string) (len=3) "one": (bool) true
|
||||
}
|
||||
}
|
||||
|
||||
Byte (and uint8) arrays and slices are displayed uniquely like the hexdump -C
|
||||
command as shown.
|
||||
([]uint8) (len=32 cap=32) {
|
||||
00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... |
|
||||
00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0|
|
||||
00000020 31 32 |12|
|
||||
}
|
||||
|
||||
Custom Formatter
|
||||
|
||||
Spew provides a custom formatter that implements the fmt.Formatter interface
|
||||
so that it integrates cleanly with standard fmt package printing functions. The
|
||||
formatter is useful for inline printing of smaller data types similar to the
|
||||
standard %v format specifier.
|
||||
|
||||
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
|
||||
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||
the width and precision arguments (however they will still work on the format
|
||||
specifiers not handled by the custom formatter).
|
||||
|
||||
Custom Formatter Usage
|
||||
|
||||
The simplest way to make use of the spew custom formatter is to call one of the
|
||||
convenience functions such as spew.Printf, spew.Println, or spew.Printf. The
|
||||
functions have syntax you are most likely already familiar with:
|
||||
|
||||
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||
spew.Println(myVar, myVar2)
|
||||
spew.Fprintf(os.Stderr, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||
spew.Fprintf(os.Stderr, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||
|
||||
See the Index for the full list convenience functions.
|
||||
|
||||
Sample Formatter Output
|
||||
|
||||
Double pointer to a uint8:
|
||||
%v: <**>5
|
||||
%+v: <**>(0xf8400420d0->0xf8400420c8)5
|
||||
%#v: (**uint8)5
|
||||
%#+v: (**uint8)(0xf8400420d0->0xf8400420c8)5
|
||||
|
||||
Pointer to circular struct with a uint8 field and a pointer to itself:
|
||||
%v: <*>{1 <*><shown>}
|
||||
%+v: <*>(0xf84003e260){ui8:1 c:<*>(0xf84003e260)<shown>}
|
||||
%#v: (*main.circular){ui8:(uint8)1 c:(*main.circular)<shown>}
|
||||
%#+v: (*main.circular)(0xf84003e260){ui8:(uint8)1 c:(*main.circular)(0xf84003e260)<shown>}
|
||||
|
||||
See the Printf example for details on the setup of variables being shown
|
||||
here.
|
||||
|
||||
Errors
|
||||
|
||||
Since it is possible for custom Stringer/error interfaces to panic, spew
|
||||
detects them and handles them internally by printing the panic information
|
||||
inline with the output. Since spew is intended to provide deep pretty printing
|
||||
capabilities on structures, it intentionally does not return any errors.
|
||||
*/
|
||||
package spew
|
||||
-509
@@ -1,509 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
// uint8Type is a reflect.Type representing a uint8. It is used to
|
||||
// convert cgo types to uint8 slices for hexdumping.
|
||||
uint8Type = reflect.TypeOf(uint8(0))
|
||||
|
||||
// cCharRE is a regular expression that matches a cgo char.
|
||||
// It is used to detect character arrays to hexdump them.
|
||||
cCharRE = regexp.MustCompile(`^.*\._Ctype_char$`)
|
||||
|
||||
// cUnsignedCharRE is a regular expression that matches a cgo unsigned
|
||||
// char. It is used to detect unsigned character arrays to hexdump
|
||||
// them.
|
||||
cUnsignedCharRE = regexp.MustCompile(`^.*\._Ctype_unsignedchar$`)
|
||||
|
||||
// cUint8tCharRE is a regular expression that matches a cgo uint8_t.
|
||||
// It is used to detect uint8_t arrays to hexdump them.
|
||||
cUint8tCharRE = regexp.MustCompile(`^.*\._Ctype_uint8_t$`)
|
||||
)
|
||||
|
||||
// dumpState contains information about the state of a dump operation.
|
||||
type dumpState struct {
|
||||
w io.Writer
|
||||
depth int
|
||||
pointers map[uintptr]int
|
||||
ignoreNextType bool
|
||||
ignoreNextIndent bool
|
||||
cs *ConfigState
|
||||
}
|
||||
|
||||
// indent performs indentation according to the depth level and cs.Indent
|
||||
// option.
|
||||
func (d *dumpState) indent() {
|
||||
if d.ignoreNextIndent {
|
||||
d.ignoreNextIndent = false
|
||||
return
|
||||
}
|
||||
d.w.Write(bytes.Repeat([]byte(d.cs.Indent), d.depth))
|
||||
}
|
||||
|
||||
// unpackValue returns values inside of non-nil interfaces when possible.
|
||||
// This is useful for data types like structs, arrays, slices, and maps which
|
||||
// can contain varying types packed inside an interface.
|
||||
func (d *dumpState) unpackValue(v reflect.Value) reflect.Value {
|
||||
if v.Kind() == reflect.Interface && !v.IsNil() {
|
||||
v = v.Elem()
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// dumpPtr handles formatting of pointers by indirecting them as necessary.
|
||||
func (d *dumpState) dumpPtr(v reflect.Value) {
|
||||
// Remove pointers at or below the current depth from map used to detect
|
||||
// circular refs.
|
||||
for k, depth := range d.pointers {
|
||||
if depth >= d.depth {
|
||||
delete(d.pointers, k)
|
||||
}
|
||||
}
|
||||
|
||||
// Keep list of all dereferenced pointers to show later.
|
||||
pointerChain := make([]uintptr, 0)
|
||||
|
||||
// Figure out how many levels of indirection there are by dereferencing
|
||||
// pointers and unpacking interfaces down the chain while detecting circular
|
||||
// references.
|
||||
nilFound := false
|
||||
cycleFound := false
|
||||
indirects := 0
|
||||
ve := v
|
||||
for ve.Kind() == reflect.Ptr {
|
||||
if ve.IsNil() {
|
||||
nilFound = true
|
||||
break
|
||||
}
|
||||
indirects++
|
||||
addr := ve.Pointer()
|
||||
pointerChain = append(pointerChain, addr)
|
||||
if pd, ok := d.pointers[addr]; ok && pd < d.depth {
|
||||
cycleFound = true
|
||||
indirects--
|
||||
break
|
||||
}
|
||||
d.pointers[addr] = d.depth
|
||||
|
||||
ve = ve.Elem()
|
||||
if ve.Kind() == reflect.Interface {
|
||||
if ve.IsNil() {
|
||||
nilFound = true
|
||||
break
|
||||
}
|
||||
ve = ve.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
// Display type information.
|
||||
d.w.Write(openParenBytes)
|
||||
d.w.Write(bytes.Repeat(asteriskBytes, indirects))
|
||||
d.w.Write([]byte(ve.Type().String()))
|
||||
d.w.Write(closeParenBytes)
|
||||
|
||||
// Display pointer information.
|
||||
if !d.cs.DisablePointerAddresses && len(pointerChain) > 0 {
|
||||
d.w.Write(openParenBytes)
|
||||
for i, addr := range pointerChain {
|
||||
if i > 0 {
|
||||
d.w.Write(pointerChainBytes)
|
||||
}
|
||||
printHexPtr(d.w, addr)
|
||||
}
|
||||
d.w.Write(closeParenBytes)
|
||||
}
|
||||
|
||||
// Display dereferenced value.
|
||||
d.w.Write(openParenBytes)
|
||||
switch {
|
||||
case nilFound:
|
||||
d.w.Write(nilAngleBytes)
|
||||
|
||||
case cycleFound:
|
||||
d.w.Write(circularBytes)
|
||||
|
||||
default:
|
||||
d.ignoreNextType = true
|
||||
d.dump(ve)
|
||||
}
|
||||
d.w.Write(closeParenBytes)
|
||||
}
|
||||
|
||||
// dumpSlice handles formatting of arrays and slices. Byte (uint8 under
|
||||
// reflection) arrays and slices are dumped in hexdump -C fashion.
|
||||
func (d *dumpState) dumpSlice(v reflect.Value) {
|
||||
// Determine whether this type should be hex dumped or not. Also,
|
||||
// for types which should be hexdumped, try to use the underlying data
|
||||
// first, then fall back to trying to convert them to a uint8 slice.
|
||||
var buf []uint8
|
||||
doConvert := false
|
||||
doHexDump := false
|
||||
numEntries := v.Len()
|
||||
if numEntries > 0 {
|
||||
vt := v.Index(0).Type()
|
||||
vts := vt.String()
|
||||
switch {
|
||||
// C types that need to be converted.
|
||||
case cCharRE.MatchString(vts):
|
||||
fallthrough
|
||||
case cUnsignedCharRE.MatchString(vts):
|
||||
fallthrough
|
||||
case cUint8tCharRE.MatchString(vts):
|
||||
doConvert = true
|
||||
|
||||
// Try to use existing uint8 slices and fall back to converting
|
||||
// and copying if that fails.
|
||||
case vt.Kind() == reflect.Uint8:
|
||||
// We need an addressable interface to convert the type
|
||||
// to a byte slice. However, the reflect package won't
|
||||
// give us an interface on certain things like
|
||||
// unexported struct fields in order to enforce
|
||||
// visibility rules. We use unsafe, when available, to
|
||||
// bypass these restrictions since this package does not
|
||||
// mutate the values.
|
||||
vs := v
|
||||
if !vs.CanInterface() || !vs.CanAddr() {
|
||||
vs = unsafeReflectValue(vs)
|
||||
}
|
||||
if !UnsafeDisabled {
|
||||
vs = vs.Slice(0, numEntries)
|
||||
|
||||
// Use the existing uint8 slice if it can be
|
||||
// type asserted.
|
||||
iface := vs.Interface()
|
||||
if slice, ok := iface.([]uint8); ok {
|
||||
buf = slice
|
||||
doHexDump = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// The underlying data needs to be converted if it can't
|
||||
// be type asserted to a uint8 slice.
|
||||
doConvert = true
|
||||
}
|
||||
|
||||
// Copy and convert the underlying type if needed.
|
||||
if doConvert && vt.ConvertibleTo(uint8Type) {
|
||||
// Convert and copy each element into a uint8 byte
|
||||
// slice.
|
||||
buf = make([]uint8, numEntries)
|
||||
for i := 0; i < numEntries; i++ {
|
||||
vv := v.Index(i)
|
||||
buf[i] = uint8(vv.Convert(uint8Type).Uint())
|
||||
}
|
||||
doHexDump = true
|
||||
}
|
||||
}
|
||||
|
||||
// Hexdump the entire slice as needed.
|
||||
if doHexDump {
|
||||
indent := strings.Repeat(d.cs.Indent, d.depth)
|
||||
str := indent + hex.Dump(buf)
|
||||
str = strings.Replace(str, "\n", "\n"+indent, -1)
|
||||
str = strings.TrimRight(str, d.cs.Indent)
|
||||
d.w.Write([]byte(str))
|
||||
return
|
||||
}
|
||||
|
||||
// Recursively call dump for each item.
|
||||
for i := 0; i < numEntries; i++ {
|
||||
d.dump(d.unpackValue(v.Index(i)))
|
||||
if i < (numEntries - 1) {
|
||||
d.w.Write(commaNewlineBytes)
|
||||
} else {
|
||||
d.w.Write(newlineBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dump is the main workhorse for dumping a value. It uses the passed reflect
|
||||
// value to figure out what kind of object we are dealing with and formats it
|
||||
// appropriately. It is a recursive function, however circular data structures
|
||||
// are detected and handled properly.
|
||||
func (d *dumpState) dump(v reflect.Value) {
|
||||
// Handle invalid reflect values immediately.
|
||||
kind := v.Kind()
|
||||
if kind == reflect.Invalid {
|
||||
d.w.Write(invalidAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle pointers specially.
|
||||
if kind == reflect.Ptr {
|
||||
d.indent()
|
||||
d.dumpPtr(v)
|
||||
return
|
||||
}
|
||||
|
||||
// Print type information unless already handled elsewhere.
|
||||
if !d.ignoreNextType {
|
||||
d.indent()
|
||||
d.w.Write(openParenBytes)
|
||||
d.w.Write([]byte(v.Type().String()))
|
||||
d.w.Write(closeParenBytes)
|
||||
d.w.Write(spaceBytes)
|
||||
}
|
||||
d.ignoreNextType = false
|
||||
|
||||
// Display length and capacity if the built-in len and cap functions
|
||||
// work with the value's kind and the len/cap itself is non-zero.
|
||||
valueLen, valueCap := 0, 0
|
||||
switch v.Kind() {
|
||||
case reflect.Array, reflect.Slice, reflect.Chan:
|
||||
valueLen, valueCap = v.Len(), v.Cap()
|
||||
case reflect.Map, reflect.String:
|
||||
valueLen = v.Len()
|
||||
}
|
||||
if valueLen != 0 || !d.cs.DisableCapacities && valueCap != 0 {
|
||||
d.w.Write(openParenBytes)
|
||||
if valueLen != 0 {
|
||||
d.w.Write(lenEqualsBytes)
|
||||
printInt(d.w, int64(valueLen), 10)
|
||||
}
|
||||
if !d.cs.DisableCapacities && valueCap != 0 {
|
||||
if valueLen != 0 {
|
||||
d.w.Write(spaceBytes)
|
||||
}
|
||||
d.w.Write(capEqualsBytes)
|
||||
printInt(d.w, int64(valueCap), 10)
|
||||
}
|
||||
d.w.Write(closeParenBytes)
|
||||
d.w.Write(spaceBytes)
|
||||
}
|
||||
|
||||
// Call Stringer/error interfaces if they exist and the handle methods flag
|
||||
// is enabled
|
||||
if !d.cs.DisableMethods {
|
||||
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
|
||||
if handled := handleMethods(d.cs, d.w, v); handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.Invalid:
|
||||
// Do nothing. We should never get here since invalid has already
|
||||
// been handled above.
|
||||
|
||||
case reflect.Bool:
|
||||
printBool(d.w, v.Bool())
|
||||
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
printInt(d.w, v.Int(), 10)
|
||||
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
printUint(d.w, v.Uint(), 10)
|
||||
|
||||
case reflect.Float32:
|
||||
printFloat(d.w, v.Float(), 32)
|
||||
|
||||
case reflect.Float64:
|
||||
printFloat(d.w, v.Float(), 64)
|
||||
|
||||
case reflect.Complex64:
|
||||
printComplex(d.w, v.Complex(), 32)
|
||||
|
||||
case reflect.Complex128:
|
||||
printComplex(d.w, v.Complex(), 64)
|
||||
|
||||
case reflect.Slice:
|
||||
if v.IsNil() {
|
||||
d.w.Write(nilAngleBytes)
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
|
||||
case reflect.Array:
|
||||
d.w.Write(openBraceNewlineBytes)
|
||||
d.depth++
|
||||
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||
d.indent()
|
||||
d.w.Write(maxNewlineBytes)
|
||||
} else {
|
||||
d.dumpSlice(v)
|
||||
}
|
||||
d.depth--
|
||||
d.indent()
|
||||
d.w.Write(closeBraceBytes)
|
||||
|
||||
case reflect.String:
|
||||
d.w.Write([]byte(strconv.Quote(v.String())))
|
||||
|
||||
case reflect.Interface:
|
||||
// The only time we should get here is for nil interfaces due to
|
||||
// unpackValue calls.
|
||||
if v.IsNil() {
|
||||
d.w.Write(nilAngleBytes)
|
||||
}
|
||||
|
||||
case reflect.Ptr:
|
||||
// Do nothing. We should never get here since pointers have already
|
||||
// been handled above.
|
||||
|
||||
case reflect.Map:
|
||||
// nil maps should be indicated as different than empty maps
|
||||
if v.IsNil() {
|
||||
d.w.Write(nilAngleBytes)
|
||||
break
|
||||
}
|
||||
|
||||
d.w.Write(openBraceNewlineBytes)
|
||||
d.depth++
|
||||
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||
d.indent()
|
||||
d.w.Write(maxNewlineBytes)
|
||||
} else {
|
||||
numEntries := v.Len()
|
||||
keys := v.MapKeys()
|
||||
if d.cs.SortKeys {
|
||||
sortValues(keys, d.cs)
|
||||
}
|
||||
for i, key := range keys {
|
||||
d.dump(d.unpackValue(key))
|
||||
d.w.Write(colonSpaceBytes)
|
||||
d.ignoreNextIndent = true
|
||||
d.dump(d.unpackValue(v.MapIndex(key)))
|
||||
if i < (numEntries - 1) {
|
||||
d.w.Write(commaNewlineBytes)
|
||||
} else {
|
||||
d.w.Write(newlineBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
d.depth--
|
||||
d.indent()
|
||||
d.w.Write(closeBraceBytes)
|
||||
|
||||
case reflect.Struct:
|
||||
d.w.Write(openBraceNewlineBytes)
|
||||
d.depth++
|
||||
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||
d.indent()
|
||||
d.w.Write(maxNewlineBytes)
|
||||
} else {
|
||||
vt := v.Type()
|
||||
numFields := v.NumField()
|
||||
for i := 0; i < numFields; i++ {
|
||||
d.indent()
|
||||
vtf := vt.Field(i)
|
||||
d.w.Write([]byte(vtf.Name))
|
||||
d.w.Write(colonSpaceBytes)
|
||||
d.ignoreNextIndent = true
|
||||
d.dump(d.unpackValue(v.Field(i)))
|
||||
if i < (numFields - 1) {
|
||||
d.w.Write(commaNewlineBytes)
|
||||
} else {
|
||||
d.w.Write(newlineBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
d.depth--
|
||||
d.indent()
|
||||
d.w.Write(closeBraceBytes)
|
||||
|
||||
case reflect.Uintptr:
|
||||
printHexPtr(d.w, uintptr(v.Uint()))
|
||||
|
||||
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
|
||||
printHexPtr(d.w, v.Pointer())
|
||||
|
||||
// There were not any other types at the time this code was written, but
|
||||
// fall back to letting the default fmt package handle it in case any new
|
||||
// types are added.
|
||||
default:
|
||||
if v.CanInterface() {
|
||||
fmt.Fprintf(d.w, "%v", v.Interface())
|
||||
} else {
|
||||
fmt.Fprintf(d.w, "%v", v.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fdump is a helper function to consolidate the logic from the various public
|
||||
// methods which take varying writers and config states.
|
||||
func fdump(cs *ConfigState, w io.Writer, a ...interface{}) {
|
||||
for _, arg := range a {
|
||||
if arg == nil {
|
||||
w.Write(interfaceBytes)
|
||||
w.Write(spaceBytes)
|
||||
w.Write(nilAngleBytes)
|
||||
w.Write(newlineBytes)
|
||||
continue
|
||||
}
|
||||
|
||||
d := dumpState{w: w, cs: cs}
|
||||
d.pointers = make(map[uintptr]int)
|
||||
d.dump(reflect.ValueOf(arg))
|
||||
d.w.Write(newlineBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// Fdump formats and displays the passed arguments to io.Writer w. It formats
|
||||
// exactly the same as Dump.
|
||||
func Fdump(w io.Writer, a ...interface{}) {
|
||||
fdump(&Config, w, a...)
|
||||
}
|
||||
|
||||
// Sdump returns a string with the passed arguments formatted exactly the same
|
||||
// as Dump.
|
||||
func Sdump(a ...interface{}) string {
|
||||
var buf bytes.Buffer
|
||||
fdump(&Config, &buf, a...)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
/*
|
||||
Dump displays the passed parameters to standard out with newlines, customizable
|
||||
indentation, and additional debug information such as complete types and all
|
||||
pointer addresses used to indirect to the final value. It provides the
|
||||
following features over the built-in printing facilities provided by the fmt
|
||||
package:
|
||||
|
||||
* Pointers are dereferenced and followed
|
||||
* Circular data structures are detected and handled properly
|
||||
* Custom Stringer/error interfaces are optionally invoked, including
|
||||
on unexported types
|
||||
* Custom types which only implement the Stringer/error interfaces via
|
||||
a pointer receiver are optionally invoked when passing non-pointer
|
||||
variables
|
||||
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||
includes offsets, byte values in hex, and ASCII output
|
||||
|
||||
The configuration options are controlled by an exported package global,
|
||||
spew.Config. See ConfigState for options documentation.
|
||||
|
||||
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
|
||||
get the formatted result as a string.
|
||||
*/
|
||||
func Dump(a ...interface{}) {
|
||||
fdump(&Config, os.Stdout, a...)
|
||||
}
|
||||
-419
@@ -1,419 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// supportedFlags is a list of all the character flags supported by fmt package.
|
||||
const supportedFlags = "0-+# "
|
||||
|
||||
// formatState implements the fmt.Formatter interface and contains information
|
||||
// about the state of a formatting operation. The NewFormatter function can
|
||||
// be used to get a new Formatter which can be used directly as arguments
|
||||
// in standard fmt package printing calls.
|
||||
type formatState struct {
|
||||
value interface{}
|
||||
fs fmt.State
|
||||
depth int
|
||||
pointers map[uintptr]int
|
||||
ignoreNextType bool
|
||||
cs *ConfigState
|
||||
}
|
||||
|
||||
// buildDefaultFormat recreates the original format string without precision
|
||||
// and width information to pass in to fmt.Sprintf in the case of an
|
||||
// unrecognized type. Unless new types are added to the language, this
|
||||
// function won't ever be called.
|
||||
func (f *formatState) buildDefaultFormat() (format string) {
|
||||
buf := bytes.NewBuffer(percentBytes)
|
||||
|
||||
for _, flag := range supportedFlags {
|
||||
if f.fs.Flag(int(flag)) {
|
||||
buf.WriteRune(flag)
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteRune('v')
|
||||
|
||||
format = buf.String()
|
||||
return format
|
||||
}
|
||||
|
||||
// constructOrigFormat recreates the original format string including precision
|
||||
// and width information to pass along to the standard fmt package. This allows
|
||||
// automatic deferral of all format strings this package doesn't support.
|
||||
func (f *formatState) constructOrigFormat(verb rune) (format string) {
|
||||
buf := bytes.NewBuffer(percentBytes)
|
||||
|
||||
for _, flag := range supportedFlags {
|
||||
if f.fs.Flag(int(flag)) {
|
||||
buf.WriteRune(flag)
|
||||
}
|
||||
}
|
||||
|
||||
if width, ok := f.fs.Width(); ok {
|
||||
buf.WriteString(strconv.Itoa(width))
|
||||
}
|
||||
|
||||
if precision, ok := f.fs.Precision(); ok {
|
||||
buf.Write(precisionBytes)
|
||||
buf.WriteString(strconv.Itoa(precision))
|
||||
}
|
||||
|
||||
buf.WriteRune(verb)
|
||||
|
||||
format = buf.String()
|
||||
return format
|
||||
}
|
||||
|
||||
// unpackValue returns values inside of non-nil interfaces when possible and
|
||||
// ensures that types for values which have been unpacked from an interface
|
||||
// are displayed when the show types flag is also set.
|
||||
// This is useful for data types like structs, arrays, slices, and maps which
|
||||
// can contain varying types packed inside an interface.
|
||||
func (f *formatState) unpackValue(v reflect.Value) reflect.Value {
|
||||
if v.Kind() == reflect.Interface {
|
||||
f.ignoreNextType = false
|
||||
if !v.IsNil() {
|
||||
v = v.Elem()
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// formatPtr handles formatting of pointers by indirecting them as necessary.
|
||||
func (f *formatState) formatPtr(v reflect.Value) {
|
||||
// Display nil if top level pointer is nil.
|
||||
showTypes := f.fs.Flag('#')
|
||||
if v.IsNil() && (!showTypes || f.ignoreNextType) {
|
||||
f.fs.Write(nilAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove pointers at or below the current depth from map used to detect
|
||||
// circular refs.
|
||||
for k, depth := range f.pointers {
|
||||
if depth >= f.depth {
|
||||
delete(f.pointers, k)
|
||||
}
|
||||
}
|
||||
|
||||
// Keep list of all dereferenced pointers to possibly show later.
|
||||
pointerChain := make([]uintptr, 0)
|
||||
|
||||
// Figure out how many levels of indirection there are by derferencing
|
||||
// pointers and unpacking interfaces down the chain while detecting circular
|
||||
// references.
|
||||
nilFound := false
|
||||
cycleFound := false
|
||||
indirects := 0
|
||||
ve := v
|
||||
for ve.Kind() == reflect.Ptr {
|
||||
if ve.IsNil() {
|
||||
nilFound = true
|
||||
break
|
||||
}
|
||||
indirects++
|
||||
addr := ve.Pointer()
|
||||
pointerChain = append(pointerChain, addr)
|
||||
if pd, ok := f.pointers[addr]; ok && pd < f.depth {
|
||||
cycleFound = true
|
||||
indirects--
|
||||
break
|
||||
}
|
||||
f.pointers[addr] = f.depth
|
||||
|
||||
ve = ve.Elem()
|
||||
if ve.Kind() == reflect.Interface {
|
||||
if ve.IsNil() {
|
||||
nilFound = true
|
||||
break
|
||||
}
|
||||
ve = ve.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
// Display type or indirection level depending on flags.
|
||||
if showTypes && !f.ignoreNextType {
|
||||
f.fs.Write(openParenBytes)
|
||||
f.fs.Write(bytes.Repeat(asteriskBytes, indirects))
|
||||
f.fs.Write([]byte(ve.Type().String()))
|
||||
f.fs.Write(closeParenBytes)
|
||||
} else {
|
||||
if nilFound || cycleFound {
|
||||
indirects += strings.Count(ve.Type().String(), "*")
|
||||
}
|
||||
f.fs.Write(openAngleBytes)
|
||||
f.fs.Write([]byte(strings.Repeat("*", indirects)))
|
||||
f.fs.Write(closeAngleBytes)
|
||||
}
|
||||
|
||||
// Display pointer information depending on flags.
|
||||
if f.fs.Flag('+') && (len(pointerChain) > 0) {
|
||||
f.fs.Write(openParenBytes)
|
||||
for i, addr := range pointerChain {
|
||||
if i > 0 {
|
||||
f.fs.Write(pointerChainBytes)
|
||||
}
|
||||
printHexPtr(f.fs, addr)
|
||||
}
|
||||
f.fs.Write(closeParenBytes)
|
||||
}
|
||||
|
||||
// Display dereferenced value.
|
||||
switch {
|
||||
case nilFound:
|
||||
f.fs.Write(nilAngleBytes)
|
||||
|
||||
case cycleFound:
|
||||
f.fs.Write(circularShortBytes)
|
||||
|
||||
default:
|
||||
f.ignoreNextType = true
|
||||
f.format(ve)
|
||||
}
|
||||
}
|
||||
|
||||
// format is the main workhorse for providing the Formatter interface. It
|
||||
// uses the passed reflect value to figure out what kind of object we are
|
||||
// dealing with and formats it appropriately. It is a recursive function,
|
||||
// however circular data structures are detected and handled properly.
|
||||
func (f *formatState) format(v reflect.Value) {
|
||||
// Handle invalid reflect values immediately.
|
||||
kind := v.Kind()
|
||||
if kind == reflect.Invalid {
|
||||
f.fs.Write(invalidAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle pointers specially.
|
||||
if kind == reflect.Ptr {
|
||||
f.formatPtr(v)
|
||||
return
|
||||
}
|
||||
|
||||
// Print type information unless already handled elsewhere.
|
||||
if !f.ignoreNextType && f.fs.Flag('#') {
|
||||
f.fs.Write(openParenBytes)
|
||||
f.fs.Write([]byte(v.Type().String()))
|
||||
f.fs.Write(closeParenBytes)
|
||||
}
|
||||
f.ignoreNextType = false
|
||||
|
||||
// Call Stringer/error interfaces if they exist and the handle methods
|
||||
// flag is enabled.
|
||||
if !f.cs.DisableMethods {
|
||||
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
|
||||
if handled := handleMethods(f.cs, f.fs, v); handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.Invalid:
|
||||
// Do nothing. We should never get here since invalid has already
|
||||
// been handled above.
|
||||
|
||||
case reflect.Bool:
|
||||
printBool(f.fs, v.Bool())
|
||||
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
printInt(f.fs, v.Int(), 10)
|
||||
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
printUint(f.fs, v.Uint(), 10)
|
||||
|
||||
case reflect.Float32:
|
||||
printFloat(f.fs, v.Float(), 32)
|
||||
|
||||
case reflect.Float64:
|
||||
printFloat(f.fs, v.Float(), 64)
|
||||
|
||||
case reflect.Complex64:
|
||||
printComplex(f.fs, v.Complex(), 32)
|
||||
|
||||
case reflect.Complex128:
|
||||
printComplex(f.fs, v.Complex(), 64)
|
||||
|
||||
case reflect.Slice:
|
||||
if v.IsNil() {
|
||||
f.fs.Write(nilAngleBytes)
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
|
||||
case reflect.Array:
|
||||
f.fs.Write(openBracketBytes)
|
||||
f.depth++
|
||||
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||
f.fs.Write(maxShortBytes)
|
||||
} else {
|
||||
numEntries := v.Len()
|
||||
for i := 0; i < numEntries; i++ {
|
||||
if i > 0 {
|
||||
f.fs.Write(spaceBytes)
|
||||
}
|
||||
f.ignoreNextType = true
|
||||
f.format(f.unpackValue(v.Index(i)))
|
||||
}
|
||||
}
|
||||
f.depth--
|
||||
f.fs.Write(closeBracketBytes)
|
||||
|
||||
case reflect.String:
|
||||
f.fs.Write([]byte(v.String()))
|
||||
|
||||
case reflect.Interface:
|
||||
// The only time we should get here is for nil interfaces due to
|
||||
// unpackValue calls.
|
||||
if v.IsNil() {
|
||||
f.fs.Write(nilAngleBytes)
|
||||
}
|
||||
|
||||
case reflect.Ptr:
|
||||
// Do nothing. We should never get here since pointers have already
|
||||
// been handled above.
|
||||
|
||||
case reflect.Map:
|
||||
// nil maps should be indicated as different than empty maps
|
||||
if v.IsNil() {
|
||||
f.fs.Write(nilAngleBytes)
|
||||
break
|
||||
}
|
||||
|
||||
f.fs.Write(openMapBytes)
|
||||
f.depth++
|
||||
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||
f.fs.Write(maxShortBytes)
|
||||
} else {
|
||||
keys := v.MapKeys()
|
||||
if f.cs.SortKeys {
|
||||
sortValues(keys, f.cs)
|
||||
}
|
||||
for i, key := range keys {
|
||||
if i > 0 {
|
||||
f.fs.Write(spaceBytes)
|
||||
}
|
||||
f.ignoreNextType = true
|
||||
f.format(f.unpackValue(key))
|
||||
f.fs.Write(colonBytes)
|
||||
f.ignoreNextType = true
|
||||
f.format(f.unpackValue(v.MapIndex(key)))
|
||||
}
|
||||
}
|
||||
f.depth--
|
||||
f.fs.Write(closeMapBytes)
|
||||
|
||||
case reflect.Struct:
|
||||
numFields := v.NumField()
|
||||
f.fs.Write(openBraceBytes)
|
||||
f.depth++
|
||||
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||
f.fs.Write(maxShortBytes)
|
||||
} else {
|
||||
vt := v.Type()
|
||||
for i := 0; i < numFields; i++ {
|
||||
if i > 0 {
|
||||
f.fs.Write(spaceBytes)
|
||||
}
|
||||
vtf := vt.Field(i)
|
||||
if f.fs.Flag('+') || f.fs.Flag('#') {
|
||||
f.fs.Write([]byte(vtf.Name))
|
||||
f.fs.Write(colonBytes)
|
||||
}
|
||||
f.format(f.unpackValue(v.Field(i)))
|
||||
}
|
||||
}
|
||||
f.depth--
|
||||
f.fs.Write(closeBraceBytes)
|
||||
|
||||
case reflect.Uintptr:
|
||||
printHexPtr(f.fs, uintptr(v.Uint()))
|
||||
|
||||
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
|
||||
printHexPtr(f.fs, v.Pointer())
|
||||
|
||||
// There were not any other types at the time this code was written, but
|
||||
// fall back to letting the default fmt package handle it if any get added.
|
||||
default:
|
||||
format := f.buildDefaultFormat()
|
||||
if v.CanInterface() {
|
||||
fmt.Fprintf(f.fs, format, v.Interface())
|
||||
} else {
|
||||
fmt.Fprintf(f.fs, format, v.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Format satisfies the fmt.Formatter interface. See NewFormatter for usage
|
||||
// details.
|
||||
func (f *formatState) Format(fs fmt.State, verb rune) {
|
||||
f.fs = fs
|
||||
|
||||
// Use standard formatting for verbs that are not v.
|
||||
if verb != 'v' {
|
||||
format := f.constructOrigFormat(verb)
|
||||
fmt.Fprintf(fs, format, f.value)
|
||||
return
|
||||
}
|
||||
|
||||
if f.value == nil {
|
||||
if fs.Flag('#') {
|
||||
fs.Write(interfaceBytes)
|
||||
}
|
||||
fs.Write(nilAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
f.format(reflect.ValueOf(f.value))
|
||||
}
|
||||
|
||||
// newFormatter is a helper function to consolidate the logic from the various
|
||||
// public methods which take varying config states.
|
||||
func newFormatter(cs *ConfigState, v interface{}) fmt.Formatter {
|
||||
fs := &formatState{value: v, cs: cs}
|
||||
fs.pointers = make(map[uintptr]int)
|
||||
return fs
|
||||
}
|
||||
|
||||
/*
|
||||
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
|
||||
interface. As a result, it integrates cleanly with standard fmt package
|
||||
printing functions. The formatter is useful for inline printing of smaller data
|
||||
types similar to the standard %v format specifier.
|
||||
|
||||
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
|
||||
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||
the width and precision arguments (however they will still work on the format
|
||||
specifiers not handled by the custom formatter).
|
||||
|
||||
Typically this function shouldn't be called directly. It is much easier to make
|
||||
use of the custom formatter by calling one of the convenience functions such as
|
||||
Printf, Println, or Fprintf.
|
||||
*/
|
||||
func NewFormatter(v interface{}) fmt.Formatter {
|
||||
return newFormatter(&Config, v)
|
||||
}
|
||||
-148
@@ -1,148 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the formatted string as a value that satisfies error. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Errorf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Errorf(format string, a ...interface{}) (err error) {
|
||||
return fmt.Errorf(format, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprint(w, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Fprint(w io.Writer, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprint(w, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprintf(w, format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprintf(w, format, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
|
||||
// passed with a default Formatter interface returned by NewFormatter. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprintln(w, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprintln(w, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Print is a wrapper for fmt.Print that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Print(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Print(a ...interface{}) (n int, err error) {
|
||||
return fmt.Print(convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Printf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Printf(format string, a ...interface{}) (n int, err error) {
|
||||
return fmt.Printf(format, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Println is a wrapper for fmt.Println that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Println(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Println(a ...interface{}) (n int, err error) {
|
||||
return fmt.Println(convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprint(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Sprint(a ...interface{}) string {
|
||||
return fmt.Sprint(convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprintf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Sprintf(format string, a ...interface{}) string {
|
||||
return fmt.Sprintf(format, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
|
||||
// were passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprintln(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Sprintln(a ...interface{}) string {
|
||||
return fmt.Sprintln(convertArgs(a)...)
|
||||
}
|
||||
|
||||
// convertArgs accepts a slice of arguments and returns a slice of the same
|
||||
// length with each argument converted to a default spew Formatter interface.
|
||||
func convertArgs(args []interface{}) (formatters []interface{}) {
|
||||
formatters = make([]interface{}, len(args))
|
||||
for index, arg := range args {
|
||||
formatters[index] = NewFormatter(arg)
|
||||
}
|
||||
return formatters
|
||||
}
|
||||
-27
@@ -1,27 +0,0 @@
|
||||
Copyright (c) 2013, Patrick Mezard
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
The names of its contributors may not be used to endorse or promote
|
||||
products derived from this software without specific prior written
|
||||
permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
|
||||
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
|
||||
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
|
||||
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
|
||||
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
-772
@@ -1,772 +0,0 @@
|
||||
// Package difflib is a partial port of Python difflib module.
|
||||
//
|
||||
// It provides tools to compare sequences of strings and generate textual diffs.
|
||||
//
|
||||
// The following class and functions have been ported:
|
||||
//
|
||||
// - SequenceMatcher
|
||||
//
|
||||
// - unified_diff
|
||||
//
|
||||
// - context_diff
|
||||
//
|
||||
// Getting unified diffs was the main goal of the port. Keep in mind this code
|
||||
// is mostly suitable to output text differences in a human friendly way, there
|
||||
// are no guarantees generated diffs are consumable by patch(1).
|
||||
package difflib
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func calculateRatio(matches, length int) float64 {
|
||||
if length > 0 {
|
||||
return 2.0 * float64(matches) / float64(length)
|
||||
}
|
||||
return 1.0
|
||||
}
|
||||
|
||||
type Match struct {
|
||||
A int
|
||||
B int
|
||||
Size int
|
||||
}
|
||||
|
||||
type OpCode struct {
|
||||
Tag byte
|
||||
I1 int
|
||||
I2 int
|
||||
J1 int
|
||||
J2 int
|
||||
}
|
||||
|
||||
// SequenceMatcher compares sequence of strings. The basic
|
||||
// algorithm predates, and is a little fancier than, an algorithm
|
||||
// published in the late 1980's by Ratcliff and Obershelp under the
|
||||
// hyperbolic name "gestalt pattern matching". The basic idea is to find
|
||||
// the longest contiguous matching subsequence that contains no "junk"
|
||||
// elements (R-O doesn't address junk). The same idea is then applied
|
||||
// recursively to the pieces of the sequences to the left and to the right
|
||||
// of the matching subsequence. This does not yield minimal edit
|
||||
// sequences, but does tend to yield matches that "look right" to people.
|
||||
//
|
||||
// SequenceMatcher tries to compute a "human-friendly diff" between two
|
||||
// sequences. Unlike e.g. UNIX(tm) diff, the fundamental notion is the
|
||||
// longest *contiguous* & junk-free matching subsequence. That's what
|
||||
// catches peoples' eyes. The Windows(tm) windiff has another interesting
|
||||
// notion, pairing up elements that appear uniquely in each sequence.
|
||||
// That, and the method here, appear to yield more intuitive difference
|
||||
// reports than does diff. This method appears to be the least vulnerable
|
||||
// to synching up on blocks of "junk lines", though (like blank lines in
|
||||
// ordinary text files, or maybe "<P>" lines in HTML files). That may be
|
||||
// because this is the only method of the 3 that has a *concept* of
|
||||
// "junk" <wink>.
|
||||
//
|
||||
// Timing: Basic R-O is cubic time worst case and quadratic time expected
|
||||
// case. SequenceMatcher is quadratic time for the worst case and has
|
||||
// expected-case behavior dependent in a complicated way on how many
|
||||
// elements the sequences have in common; best case time is linear.
|
||||
type SequenceMatcher struct {
|
||||
a []string
|
||||
b []string
|
||||
b2j map[string][]int
|
||||
IsJunk func(string) bool
|
||||
autoJunk bool
|
||||
bJunk map[string]struct{}
|
||||
matchingBlocks []Match
|
||||
fullBCount map[string]int
|
||||
bPopular map[string]struct{}
|
||||
opCodes []OpCode
|
||||
}
|
||||
|
||||
func NewMatcher(a, b []string) *SequenceMatcher {
|
||||
m := SequenceMatcher{autoJunk: true}
|
||||
m.SetSeqs(a, b)
|
||||
return &m
|
||||
}
|
||||
|
||||
func NewMatcherWithJunk(a, b []string, autoJunk bool,
|
||||
isJunk func(string) bool) *SequenceMatcher {
|
||||
|
||||
m := SequenceMatcher{IsJunk: isJunk, autoJunk: autoJunk}
|
||||
m.SetSeqs(a, b)
|
||||
return &m
|
||||
}
|
||||
|
||||
// Set two sequences to be compared.
|
||||
func (m *SequenceMatcher) SetSeqs(a, b []string) {
|
||||
m.SetSeq1(a)
|
||||
m.SetSeq2(b)
|
||||
}
|
||||
|
||||
// Set the first sequence to be compared. The second sequence to be compared is
|
||||
// not changed.
|
||||
//
|
||||
// SequenceMatcher computes and caches detailed information about the second
|
||||
// sequence, so if you want to compare one sequence S against many sequences,
|
||||
// use .SetSeq2(s) once and call .SetSeq1(x) repeatedly for each of the other
|
||||
// sequences.
|
||||
//
|
||||
// See also SetSeqs() and SetSeq2().
|
||||
func (m *SequenceMatcher) SetSeq1(a []string) {
|
||||
if &a == &m.a {
|
||||
return
|
||||
}
|
||||
m.a = a
|
||||
m.matchingBlocks = nil
|
||||
m.opCodes = nil
|
||||
}
|
||||
|
||||
// Set the second sequence to be compared. The first sequence to be compared is
|
||||
// not changed.
|
||||
func (m *SequenceMatcher) SetSeq2(b []string) {
|
||||
if &b == &m.b {
|
||||
return
|
||||
}
|
||||
m.b = b
|
||||
m.matchingBlocks = nil
|
||||
m.opCodes = nil
|
||||
m.fullBCount = nil
|
||||
m.chainB()
|
||||
}
|
||||
|
||||
func (m *SequenceMatcher) chainB() {
|
||||
// Populate line -> index mapping
|
||||
b2j := map[string][]int{}
|
||||
for i, s := range m.b {
|
||||
indices := b2j[s]
|
||||
indices = append(indices, i)
|
||||
b2j[s] = indices
|
||||
}
|
||||
|
||||
// Purge junk elements
|
||||
m.bJunk = map[string]struct{}{}
|
||||
if m.IsJunk != nil {
|
||||
junk := m.bJunk
|
||||
for s, _ := range b2j {
|
||||
if m.IsJunk(s) {
|
||||
junk[s] = struct{}{}
|
||||
}
|
||||
}
|
||||
for s, _ := range junk {
|
||||
delete(b2j, s)
|
||||
}
|
||||
}
|
||||
|
||||
// Purge remaining popular elements
|
||||
popular := map[string]struct{}{}
|
||||
n := len(m.b)
|
||||
if m.autoJunk && n >= 200 {
|
||||
ntest := n/100 + 1
|
||||
for s, indices := range b2j {
|
||||
if len(indices) > ntest {
|
||||
popular[s] = struct{}{}
|
||||
}
|
||||
}
|
||||
for s, _ := range popular {
|
||||
delete(b2j, s)
|
||||
}
|
||||
}
|
||||
m.bPopular = popular
|
||||
m.b2j = b2j
|
||||
}
|
||||
|
||||
func (m *SequenceMatcher) isBJunk(s string) bool {
|
||||
_, ok := m.bJunk[s]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Find longest matching block in a[alo:ahi] and b[blo:bhi].
|
||||
//
|
||||
// If IsJunk is not defined:
|
||||
//
|
||||
// Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where
|
||||
// alo <= i <= i+k <= ahi
|
||||
// blo <= j <= j+k <= bhi
|
||||
// and for all (i',j',k') meeting those conditions,
|
||||
// k >= k'
|
||||
// i <= i'
|
||||
// and if i == i', j <= j'
|
||||
//
|
||||
// In other words, of all maximal matching blocks, return one that
|
||||
// starts earliest in a, and of all those maximal matching blocks that
|
||||
// start earliest in a, return the one that starts earliest in b.
|
||||
//
|
||||
// If IsJunk is defined, first the longest matching block is
|
||||
// determined as above, but with the additional restriction that no
|
||||
// junk element appears in the block. Then that block is extended as
|
||||
// far as possible by matching (only) junk elements on both sides. So
|
||||
// the resulting block never matches on junk except as identical junk
|
||||
// happens to be adjacent to an "interesting" match.
|
||||
//
|
||||
// If no blocks match, return (alo, blo, 0).
|
||||
func (m *SequenceMatcher) findLongestMatch(alo, ahi, blo, bhi int) Match {
|
||||
// CAUTION: stripping common prefix or suffix would be incorrect.
|
||||
// E.g.,
|
||||
// ab
|
||||
// acab
|
||||
// Longest matching block is "ab", but if common prefix is
|
||||
// stripped, it's "a" (tied with "b"). UNIX(tm) diff does so
|
||||
// strip, so ends up claiming that ab is changed to acab by
|
||||
// inserting "ca" in the middle. That's minimal but unintuitive:
|
||||
// "it's obvious" that someone inserted "ac" at the front.
|
||||
// Windiff ends up at the same place as diff, but by pairing up
|
||||
// the unique 'b's and then matching the first two 'a's.
|
||||
besti, bestj, bestsize := alo, blo, 0
|
||||
|
||||
// find longest junk-free match
|
||||
// during an iteration of the loop, j2len[j] = length of longest
|
||||
// junk-free match ending with a[i-1] and b[j]
|
||||
j2len := map[int]int{}
|
||||
for i := alo; i != ahi; i++ {
|
||||
// look at all instances of a[i] in b; note that because
|
||||
// b2j has no junk keys, the loop is skipped if a[i] is junk
|
||||
newj2len := map[int]int{}
|
||||
for _, j := range m.b2j[m.a[i]] {
|
||||
// a[i] matches b[j]
|
||||
if j < blo {
|
||||
continue
|
||||
}
|
||||
if j >= bhi {
|
||||
break
|
||||
}
|
||||
k := j2len[j-1] + 1
|
||||
newj2len[j] = k
|
||||
if k > bestsize {
|
||||
besti, bestj, bestsize = i-k+1, j-k+1, k
|
||||
}
|
||||
}
|
||||
j2len = newj2len
|
||||
}
|
||||
|
||||
// Extend the best by non-junk elements on each end. In particular,
|
||||
// "popular" non-junk elements aren't in b2j, which greatly speeds
|
||||
// the inner loop above, but also means "the best" match so far
|
||||
// doesn't contain any junk *or* popular non-junk elements.
|
||||
for besti > alo && bestj > blo && !m.isBJunk(m.b[bestj-1]) &&
|
||||
m.a[besti-1] == m.b[bestj-1] {
|
||||
besti, bestj, bestsize = besti-1, bestj-1, bestsize+1
|
||||
}
|
||||
for besti+bestsize < ahi && bestj+bestsize < bhi &&
|
||||
!m.isBJunk(m.b[bestj+bestsize]) &&
|
||||
m.a[besti+bestsize] == m.b[bestj+bestsize] {
|
||||
bestsize += 1
|
||||
}
|
||||
|
||||
// Now that we have a wholly interesting match (albeit possibly
|
||||
// empty!), we may as well suck up the matching junk on each
|
||||
// side of it too. Can't think of a good reason not to, and it
|
||||
// saves post-processing the (possibly considerable) expense of
|
||||
// figuring out what to do with it. In the case of an empty
|
||||
// interesting match, this is clearly the right thing to do,
|
||||
// because no other kind of match is possible in the regions.
|
||||
for besti > alo && bestj > blo && m.isBJunk(m.b[bestj-1]) &&
|
||||
m.a[besti-1] == m.b[bestj-1] {
|
||||
besti, bestj, bestsize = besti-1, bestj-1, bestsize+1
|
||||
}
|
||||
for besti+bestsize < ahi && bestj+bestsize < bhi &&
|
||||
m.isBJunk(m.b[bestj+bestsize]) &&
|
||||
m.a[besti+bestsize] == m.b[bestj+bestsize] {
|
||||
bestsize += 1
|
||||
}
|
||||
|
||||
return Match{A: besti, B: bestj, Size: bestsize}
|
||||
}
|
||||
|
||||
// Return list of triples describing matching subsequences.
|
||||
//
|
||||
// Each triple is of the form (i, j, n), and means that
|
||||
// a[i:i+n] == b[j:j+n]. The triples are monotonically increasing in
|
||||
// i and in j. It's also guaranteed that if (i, j, n) and (i', j', n') are
|
||||
// adjacent triples in the list, and the second is not the last triple in the
|
||||
// list, then i+n != i' or j+n != j'. IOW, adjacent triples never describe
|
||||
// adjacent equal blocks.
|
||||
//
|
||||
// The last triple is a dummy, (len(a), len(b), 0), and is the only
|
||||
// triple with n==0.
|
||||
func (m *SequenceMatcher) GetMatchingBlocks() []Match {
|
||||
if m.matchingBlocks != nil {
|
||||
return m.matchingBlocks
|
||||
}
|
||||
|
||||
var matchBlocks func(alo, ahi, blo, bhi int, matched []Match) []Match
|
||||
matchBlocks = func(alo, ahi, blo, bhi int, matched []Match) []Match {
|
||||
match := m.findLongestMatch(alo, ahi, blo, bhi)
|
||||
i, j, k := match.A, match.B, match.Size
|
||||
if match.Size > 0 {
|
||||
if alo < i && blo < j {
|
||||
matched = matchBlocks(alo, i, blo, j, matched)
|
||||
}
|
||||
matched = append(matched, match)
|
||||
if i+k < ahi && j+k < bhi {
|
||||
matched = matchBlocks(i+k, ahi, j+k, bhi, matched)
|
||||
}
|
||||
}
|
||||
return matched
|
||||
}
|
||||
matched := matchBlocks(0, len(m.a), 0, len(m.b), nil)
|
||||
|
||||
// It's possible that we have adjacent equal blocks in the
|
||||
// matching_blocks list now.
|
||||
nonAdjacent := []Match{}
|
||||
i1, j1, k1 := 0, 0, 0
|
||||
for _, b := range matched {
|
||||
// Is this block adjacent to i1, j1, k1?
|
||||
i2, j2, k2 := b.A, b.B, b.Size
|
||||
if i1+k1 == i2 && j1+k1 == j2 {
|
||||
// Yes, so collapse them -- this just increases the length of
|
||||
// the first block by the length of the second, and the first
|
||||
// block so lengthened remains the block to compare against.
|
||||
k1 += k2
|
||||
} else {
|
||||
// Not adjacent. Remember the first block (k1==0 means it's
|
||||
// the dummy we started with), and make the second block the
|
||||
// new block to compare against.
|
||||
if k1 > 0 {
|
||||
nonAdjacent = append(nonAdjacent, Match{i1, j1, k1})
|
||||
}
|
||||
i1, j1, k1 = i2, j2, k2
|
||||
}
|
||||
}
|
||||
if k1 > 0 {
|
||||
nonAdjacent = append(nonAdjacent, Match{i1, j1, k1})
|
||||
}
|
||||
|
||||
nonAdjacent = append(nonAdjacent, Match{len(m.a), len(m.b), 0})
|
||||
m.matchingBlocks = nonAdjacent
|
||||
return m.matchingBlocks
|
||||
}
|
||||
|
||||
// Return list of 5-tuples describing how to turn a into b.
|
||||
//
|
||||
// Each tuple is of the form (tag, i1, i2, j1, j2). The first tuple
|
||||
// has i1 == j1 == 0, and remaining tuples have i1 == the i2 from the
|
||||
// tuple preceding it, and likewise for j1 == the previous j2.
|
||||
//
|
||||
// The tags are characters, with these meanings:
|
||||
//
|
||||
// 'r' (replace): a[i1:i2] should be replaced by b[j1:j2]
|
||||
//
|
||||
// 'd' (delete): a[i1:i2] should be deleted, j1==j2 in this case.
|
||||
//
|
||||
// 'i' (insert): b[j1:j2] should be inserted at a[i1:i1], i1==i2 in this case.
|
||||
//
|
||||
// 'e' (equal): a[i1:i2] == b[j1:j2]
|
||||
func (m *SequenceMatcher) GetOpCodes() []OpCode {
|
||||
if m.opCodes != nil {
|
||||
return m.opCodes
|
||||
}
|
||||
i, j := 0, 0
|
||||
matching := m.GetMatchingBlocks()
|
||||
opCodes := make([]OpCode, 0, len(matching))
|
||||
for _, m := range matching {
|
||||
// invariant: we've pumped out correct diffs to change
|
||||
// a[:i] into b[:j], and the next matching block is
|
||||
// a[ai:ai+size] == b[bj:bj+size]. So we need to pump
|
||||
// out a diff to change a[i:ai] into b[j:bj], pump out
|
||||
// the matching block, and move (i,j) beyond the match
|
||||
ai, bj, size := m.A, m.B, m.Size
|
||||
tag := byte(0)
|
||||
if i < ai && j < bj {
|
||||
tag = 'r'
|
||||
} else if i < ai {
|
||||
tag = 'd'
|
||||
} else if j < bj {
|
||||
tag = 'i'
|
||||
}
|
||||
if tag > 0 {
|
||||
opCodes = append(opCodes, OpCode{tag, i, ai, j, bj})
|
||||
}
|
||||
i, j = ai+size, bj+size
|
||||
// the list of matching blocks is terminated by a
|
||||
// sentinel with size 0
|
||||
if size > 0 {
|
||||
opCodes = append(opCodes, OpCode{'e', ai, i, bj, j})
|
||||
}
|
||||
}
|
||||
m.opCodes = opCodes
|
||||
return m.opCodes
|
||||
}
|
||||
|
||||
// Isolate change clusters by eliminating ranges with no changes.
|
||||
//
|
||||
// Return a generator of groups with up to n lines of context.
|
||||
// Each group is in the same format as returned by GetOpCodes().
|
||||
func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode {
|
||||
if n < 0 {
|
||||
n = 3
|
||||
}
|
||||
codes := m.GetOpCodes()
|
||||
if len(codes) == 0 {
|
||||
codes = []OpCode{OpCode{'e', 0, 1, 0, 1}}
|
||||
}
|
||||
// Fixup leading and trailing groups if they show no changes.
|
||||
if codes[0].Tag == 'e' {
|
||||
c := codes[0]
|
||||
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
|
||||
codes[0] = OpCode{c.Tag, max(i1, i2-n), i2, max(j1, j2-n), j2}
|
||||
}
|
||||
if codes[len(codes)-1].Tag == 'e' {
|
||||
c := codes[len(codes)-1]
|
||||
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
|
||||
codes[len(codes)-1] = OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)}
|
||||
}
|
||||
nn := n + n
|
||||
groups := [][]OpCode{}
|
||||
group := []OpCode{}
|
||||
for _, c := range codes {
|
||||
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
|
||||
// End the current group and start a new one whenever
|
||||
// there is a large range with no changes.
|
||||
if c.Tag == 'e' && i2-i1 > nn {
|
||||
group = append(group, OpCode{c.Tag, i1, min(i2, i1+n),
|
||||
j1, min(j2, j1+n)})
|
||||
groups = append(groups, group)
|
||||
group = []OpCode{}
|
||||
i1, j1 = max(i1, i2-n), max(j1, j2-n)
|
||||
}
|
||||
group = append(group, OpCode{c.Tag, i1, i2, j1, j2})
|
||||
}
|
||||
if len(group) > 0 && !(len(group) == 1 && group[0].Tag == 'e') {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
||||
// Return a measure of the sequences' similarity (float in [0,1]).
|
||||
//
|
||||
// Where T is the total number of elements in both sequences, and
|
||||
// M is the number of matches, this is 2.0*M / T.
|
||||
// Note that this is 1 if the sequences are identical, and 0 if
|
||||
// they have nothing in common.
|
||||
//
|
||||
// .Ratio() is expensive to compute if you haven't already computed
|
||||
// .GetMatchingBlocks() or .GetOpCodes(), in which case you may
|
||||
// want to try .QuickRatio() or .RealQuickRation() first to get an
|
||||
// upper bound.
|
||||
func (m *SequenceMatcher) Ratio() float64 {
|
||||
matches := 0
|
||||
for _, m := range m.GetMatchingBlocks() {
|
||||
matches += m.Size
|
||||
}
|
||||
return calculateRatio(matches, len(m.a)+len(m.b))
|
||||
}
|
||||
|
||||
// Return an upper bound on ratio() relatively quickly.
|
||||
//
|
||||
// This isn't defined beyond that it is an upper bound on .Ratio(), and
|
||||
// is faster to compute.
|
||||
func (m *SequenceMatcher) QuickRatio() float64 {
|
||||
// viewing a and b as multisets, set matches to the cardinality
|
||||
// of their intersection; this counts the number of matches
|
||||
// without regard to order, so is clearly an upper bound
|
||||
if m.fullBCount == nil {
|
||||
m.fullBCount = map[string]int{}
|
||||
for _, s := range m.b {
|
||||
m.fullBCount[s] = m.fullBCount[s] + 1
|
||||
}
|
||||
}
|
||||
|
||||
// avail[x] is the number of times x appears in 'b' less the
|
||||
// number of times we've seen it in 'a' so far ... kinda
|
||||
avail := map[string]int{}
|
||||
matches := 0
|
||||
for _, s := range m.a {
|
||||
n, ok := avail[s]
|
||||
if !ok {
|
||||
n = m.fullBCount[s]
|
||||
}
|
||||
avail[s] = n - 1
|
||||
if n > 0 {
|
||||
matches += 1
|
||||
}
|
||||
}
|
||||
return calculateRatio(matches, len(m.a)+len(m.b))
|
||||
}
|
||||
|
||||
// Return an upper bound on ratio() very quickly.
|
||||
//
|
||||
// This isn't defined beyond that it is an upper bound on .Ratio(), and
|
||||
// is faster to compute than either .Ratio() or .QuickRatio().
|
||||
func (m *SequenceMatcher) RealQuickRatio() float64 {
|
||||
la, lb := len(m.a), len(m.b)
|
||||
return calculateRatio(min(la, lb), la+lb)
|
||||
}
|
||||
|
||||
// Convert range to the "ed" format
|
||||
func formatRangeUnified(start, stop int) string {
|
||||
// Per the diff spec at http://www.unix.org/single_unix_specification/
|
||||
beginning := start + 1 // lines start numbering with one
|
||||
length := stop - start
|
||||
if length == 1 {
|
||||
return fmt.Sprintf("%d", beginning)
|
||||
}
|
||||
if length == 0 {
|
||||
beginning -= 1 // empty ranges begin at line just before the range
|
||||
}
|
||||
return fmt.Sprintf("%d,%d", beginning, length)
|
||||
}
|
||||
|
||||
// Unified diff parameters
|
||||
type UnifiedDiff struct {
|
||||
A []string // First sequence lines
|
||||
FromFile string // First file name
|
||||
FromDate string // First file time
|
||||
B []string // Second sequence lines
|
||||
ToFile string // Second file name
|
||||
ToDate string // Second file time
|
||||
Eol string // Headers end of line, defaults to LF
|
||||
Context int // Number of context lines
|
||||
}
|
||||
|
||||
// Compare two sequences of lines; generate the delta as a unified diff.
|
||||
//
|
||||
// Unified diffs are a compact way of showing line changes and a few
|
||||
// lines of context. The number of context lines is set by 'n' which
|
||||
// defaults to three.
|
||||
//
|
||||
// By default, the diff control lines (those with ---, +++, or @@) are
|
||||
// created with a trailing newline. This is helpful so that inputs
|
||||
// created from file.readlines() result in diffs that are suitable for
|
||||
// file.writelines() since both the inputs and outputs have trailing
|
||||
// newlines.
|
||||
//
|
||||
// For inputs that do not have trailing newlines, set the lineterm
|
||||
// argument to "" so that the output will be uniformly newline free.
|
||||
//
|
||||
// The unidiff format normally has a header for filenames and modification
|
||||
// times. Any or all of these may be specified using strings for
|
||||
// 'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'.
|
||||
// The modification times are normally expressed in the ISO 8601 format.
|
||||
func WriteUnifiedDiff(writer io.Writer, diff UnifiedDiff) error {
|
||||
buf := bufio.NewWriter(writer)
|
||||
defer buf.Flush()
|
||||
wf := func(format string, args ...interface{}) error {
|
||||
_, err := buf.WriteString(fmt.Sprintf(format, args...))
|
||||
return err
|
||||
}
|
||||
ws := func(s string) error {
|
||||
_, err := buf.WriteString(s)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(diff.Eol) == 0 {
|
||||
diff.Eol = "\n"
|
||||
}
|
||||
|
||||
started := false
|
||||
m := NewMatcher(diff.A, diff.B)
|
||||
for _, g := range m.GetGroupedOpCodes(diff.Context) {
|
||||
if !started {
|
||||
started = true
|
||||
fromDate := ""
|
||||
if len(diff.FromDate) > 0 {
|
||||
fromDate = "\t" + diff.FromDate
|
||||
}
|
||||
toDate := ""
|
||||
if len(diff.ToDate) > 0 {
|
||||
toDate = "\t" + diff.ToDate
|
||||
}
|
||||
if diff.FromFile != "" || diff.ToFile != "" {
|
||||
err := wf("--- %s%s%s", diff.FromFile, fromDate, diff.Eol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = wf("+++ %s%s%s", diff.ToFile, toDate, diff.Eol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
first, last := g[0], g[len(g)-1]
|
||||
range1 := formatRangeUnified(first.I1, last.I2)
|
||||
range2 := formatRangeUnified(first.J1, last.J2)
|
||||
if err := wf("@@ -%s +%s @@%s", range1, range2, diff.Eol); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, c := range g {
|
||||
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
|
||||
if c.Tag == 'e' {
|
||||
for _, line := range diff.A[i1:i2] {
|
||||
if err := ws(" " + line); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if c.Tag == 'r' || c.Tag == 'd' {
|
||||
for _, line := range diff.A[i1:i2] {
|
||||
if err := ws("-" + line); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if c.Tag == 'r' || c.Tag == 'i' {
|
||||
for _, line := range diff.B[j1:j2] {
|
||||
if err := ws("+" + line); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Like WriteUnifiedDiff but returns the diff a string.
|
||||
func GetUnifiedDiffString(diff UnifiedDiff) (string, error) {
|
||||
w := &bytes.Buffer{}
|
||||
err := WriteUnifiedDiff(w, diff)
|
||||
return string(w.Bytes()), err
|
||||
}
|
||||
|
||||
// Convert range to the "ed" format.
|
||||
func formatRangeContext(start, stop int) string {
|
||||
// Per the diff spec at http://www.unix.org/single_unix_specification/
|
||||
beginning := start + 1 // lines start numbering with one
|
||||
length := stop - start
|
||||
if length == 0 {
|
||||
beginning -= 1 // empty ranges begin at line just before the range
|
||||
}
|
||||
if length <= 1 {
|
||||
return fmt.Sprintf("%d", beginning)
|
||||
}
|
||||
return fmt.Sprintf("%d,%d", beginning, beginning+length-1)
|
||||
}
|
||||
|
||||
type ContextDiff UnifiedDiff
|
||||
|
||||
// Compare two sequences of lines; generate the delta as a context diff.
|
||||
//
|
||||
// Context diffs are a compact way of showing line changes and a few
|
||||
// lines of context. The number of context lines is set by diff.Context
|
||||
// which defaults to three.
|
||||
//
|
||||
// By default, the diff control lines (those with *** or ---) are
|
||||
// created with a trailing newline.
|
||||
//
|
||||
// For inputs that do not have trailing newlines, set the diff.Eol
|
||||
// argument to "" so that the output will be uniformly newline free.
|
||||
//
|
||||
// The context diff format normally has a header for filenames and
|
||||
// modification times. Any or all of these may be specified using
|
||||
// strings for diff.FromFile, diff.ToFile, diff.FromDate, diff.ToDate.
|
||||
// The modification times are normally expressed in the ISO 8601 format.
|
||||
// If not specified, the strings default to blanks.
|
||||
func WriteContextDiff(writer io.Writer, diff ContextDiff) error {
|
||||
buf := bufio.NewWriter(writer)
|
||||
defer buf.Flush()
|
||||
var diffErr error
|
||||
wf := func(format string, args ...interface{}) {
|
||||
_, err := buf.WriteString(fmt.Sprintf(format, args...))
|
||||
if diffErr == nil && err != nil {
|
||||
diffErr = err
|
||||
}
|
||||
}
|
||||
ws := func(s string) {
|
||||
_, err := buf.WriteString(s)
|
||||
if diffErr == nil && err != nil {
|
||||
diffErr = err
|
||||
}
|
||||
}
|
||||
|
||||
if len(diff.Eol) == 0 {
|
||||
diff.Eol = "\n"
|
||||
}
|
||||
|
||||
prefix := map[byte]string{
|
||||
'i': "+ ",
|
||||
'd': "- ",
|
||||
'r': "! ",
|
||||
'e': " ",
|
||||
}
|
||||
|
||||
started := false
|
||||
m := NewMatcher(diff.A, diff.B)
|
||||
for _, g := range m.GetGroupedOpCodes(diff.Context) {
|
||||
if !started {
|
||||
started = true
|
||||
fromDate := ""
|
||||
if len(diff.FromDate) > 0 {
|
||||
fromDate = "\t" + diff.FromDate
|
||||
}
|
||||
toDate := ""
|
||||
if len(diff.ToDate) > 0 {
|
||||
toDate = "\t" + diff.ToDate
|
||||
}
|
||||
if diff.FromFile != "" || diff.ToFile != "" {
|
||||
wf("*** %s%s%s", diff.FromFile, fromDate, diff.Eol)
|
||||
wf("--- %s%s%s", diff.ToFile, toDate, diff.Eol)
|
||||
}
|
||||
}
|
||||
|
||||
first, last := g[0], g[len(g)-1]
|
||||
ws("***************" + diff.Eol)
|
||||
|
||||
range1 := formatRangeContext(first.I1, last.I2)
|
||||
wf("*** %s ****%s", range1, diff.Eol)
|
||||
for _, c := range g {
|
||||
if c.Tag == 'r' || c.Tag == 'd' {
|
||||
for _, cc := range g {
|
||||
if cc.Tag == 'i' {
|
||||
continue
|
||||
}
|
||||
for _, line := range diff.A[cc.I1:cc.I2] {
|
||||
ws(prefix[cc.Tag] + line)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
range2 := formatRangeContext(first.J1, last.J2)
|
||||
wf("--- %s ----%s", range2, diff.Eol)
|
||||
for _, c := range g {
|
||||
if c.Tag == 'r' || c.Tag == 'i' {
|
||||
for _, cc := range g {
|
||||
if cc.Tag == 'd' {
|
||||
continue
|
||||
}
|
||||
for _, line := range diff.B[cc.J1:cc.J2] {
|
||||
ws(prefix[cc.Tag] + line)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return diffErr
|
||||
}
|
||||
|
||||
// Like WriteContextDiff but returns the diff a string.
|
||||
func GetContextDiffString(diff ContextDiff) (string, error) {
|
||||
w := &bytes.Buffer{}
|
||||
err := WriteContextDiff(w, diff)
|
||||
return string(w.Bytes()), err
|
||||
}
|
||||
|
||||
// Split a string on "\n" while preserving them. The output can be used
|
||||
// as input for UnifiedDiff and ContextDiff structures.
|
||||
func SplitLines(s string) []string {
|
||||
lines := strings.SplitAfter(s, "\n")
|
||||
lines[len(lines)-1] += "\n"
|
||||
return lines
|
||||
}
|
||||
-21
@@ -1,21 +0,0 @@
|
||||
engines:
|
||||
gofmt:
|
||||
enabled: true
|
||||
golint:
|
||||
enabled: true
|
||||
govet:
|
||||
enabled: true
|
||||
|
||||
exclude_patterns:
|
||||
- ".github/"
|
||||
- "vendor/"
|
||||
- "codegen/"
|
||||
- "*.yml"
|
||||
- ".*.yml"
|
||||
- "*.md"
|
||||
- "Gopkg.*"
|
||||
- "doc.go"
|
||||
- "type_specific_codegen_test.go"
|
||||
- "type_specific_codegen.go"
|
||||
- ".gitignore"
|
||||
- "LICENSE"
|
||||
-11
@@ -1,11 +0,0 @@
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary, build with `go test -c`
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
-22
@@ -1,22 +0,0 @@
|
||||
The MIT License
|
||||
|
||||
Copyright (c) 2014 Stretchr, Inc.
|
||||
Copyright (c) 2017-2018 objx contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
-80
@@ -1,80 +0,0 @@
|
||||
# Objx
|
||||
[](https://travis-ci.org/stretchr/objx)
|
||||
[](https://goreportcard.com/report/github.com/stretchr/objx)
|
||||
[](https://codeclimate.com/github/stretchr/objx/maintainability)
|
||||
[](https://codeclimate.com/github/stretchr/objx/test_coverage)
|
||||
[](https://sourcegraph.com/github.com/stretchr/objx)
|
||||
[](https://pkg.go.dev/github.com/stretchr/objx)
|
||||
|
||||
Objx - Go package for dealing with maps, slices, JSON and other data.
|
||||
|
||||
Get started:
|
||||
|
||||
- Install Objx with [one line of code](#installation), or [update it with another](#staying-up-to-date)
|
||||
- Check out the API Documentation http://pkg.go.dev/github.com/stretchr/objx
|
||||
|
||||
## Overview
|
||||
Objx provides the `objx.Map` type, which is a `map[string]interface{}` that exposes a powerful `Get` method (among others) that allows you to easily and quickly get access to data within the map, without having to worry too much about type assertions, missing data, default values etc.
|
||||
|
||||
### Pattern
|
||||
Objx uses a predictable pattern to make access data from within `map[string]interface{}` easy. Call one of the `objx.` functions to create your `objx.Map` to get going:
|
||||
|
||||
m, err := objx.FromJSON(json)
|
||||
|
||||
NOTE: Any methods or functions with the `Must` prefix will panic if something goes wrong, the rest will be optimistic and try to figure things out without panicking.
|
||||
|
||||
Use `Get` to access the value you're interested in. You can use dot and array
|
||||
notation too:
|
||||
|
||||
m.Get("places[0].latlng")
|
||||
|
||||
Once you have sought the `Value` you're interested in, you can use the `Is*` methods to determine its type.
|
||||
|
||||
if m.Get("code").IsStr() { // Your code... }
|
||||
|
||||
Or you can just assume the type, and use one of the strong type methods to extract the real value:
|
||||
|
||||
m.Get("code").Int()
|
||||
|
||||
If there's no value there (or if it's the wrong type) then a default value will be returned, or you can be explicit about the default value.
|
||||
|
||||
Get("code").Int(-1)
|
||||
|
||||
If you're dealing with a slice of data as a value, Objx provides many useful methods for iterating, manipulating and selecting that data. You can find out more by exploring the index below.
|
||||
|
||||
### Reading data
|
||||
A simple example of how to use Objx:
|
||||
|
||||
// Use MustFromJSON to make an objx.Map from some JSON
|
||||
m := objx.MustFromJSON(`{"name": "Mat", "age": 30}`)
|
||||
|
||||
// Get the details
|
||||
name := m.Get("name").Str()
|
||||
age := m.Get("age").Int()
|
||||
|
||||
// Get their nickname (or use their name if they don't have one)
|
||||
nickname := m.Get("nickname").Str(name)
|
||||
|
||||
### Ranging
|
||||
Since `objx.Map` is a `map[string]interface{}` you can treat it as such. For example, to `range` the data, do what you would expect:
|
||||
|
||||
m := objx.MustFromJSON(json)
|
||||
for key, value := range m {
|
||||
// Your code...
|
||||
}
|
||||
|
||||
## Installation
|
||||
To install Objx, use go get:
|
||||
|
||||
go get github.com/stretchr/objx
|
||||
|
||||
### Staying up to date
|
||||
To update Objx to the latest version, run:
|
||||
|
||||
go get -u github.com/stretchr/objx
|
||||
|
||||
### Supported go versions
|
||||
We currently support the three recent major Go versions.
|
||||
|
||||
## Contributing
|
||||
Please feel free to submit issues, fork the repository and send pull requests!
|
||||
-27
@@ -1,27 +0,0 @@
|
||||
version: '3'
|
||||
|
||||
tasks:
|
||||
default:
|
||||
deps: [test]
|
||||
|
||||
lint:
|
||||
desc: Checks code style
|
||||
cmds:
|
||||
- gofmt -d -s *.go
|
||||
- go vet ./...
|
||||
silent: true
|
||||
|
||||
lint-fix:
|
||||
desc: Fixes code style
|
||||
cmds:
|
||||
- gofmt -w -s *.go
|
||||
|
||||
test:
|
||||
desc: Runs go tests
|
||||
cmds:
|
||||
- go test -race ./...
|
||||
|
||||
test-coverage:
|
||||
desc: Runs go tests and calculates test coverage
|
||||
cmds:
|
||||
- go test -race -coverprofile=c.out ./...
|
||||
-197
@@ -1,197 +0,0 @@
|
||||
package objx
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// PathSeparator is the character used to separate the elements
|
||||
// of the keypath.
|
||||
//
|
||||
// For example, `location.address.city`
|
||||
PathSeparator string = "."
|
||||
|
||||
// arrayAccessRegexString is the regex used to extract the array number
|
||||
// from the access path
|
||||
arrayAccessRegexString = `^(.+)\[([0-9]+)\]$`
|
||||
|
||||
// mapAccessRegexString is the regex used to extract the map key
|
||||
// from the access path
|
||||
mapAccessRegexString = `^([^\[]*)\[([^\]]+)\](.*)$`
|
||||
)
|
||||
|
||||
// arrayAccessRegex is the compiled arrayAccessRegexString
|
||||
var arrayAccessRegex = regexp.MustCompile(arrayAccessRegexString)
|
||||
|
||||
// mapAccessRegex is the compiled mapAccessRegexString
|
||||
var mapAccessRegex = regexp.MustCompile(mapAccessRegexString)
|
||||
|
||||
// Get gets the value using the specified selector and
|
||||
// returns it inside a new Obj object.
|
||||
//
|
||||
// If it cannot find the value, Get will return a nil
|
||||
// value inside an instance of Obj.
|
||||
//
|
||||
// Get can only operate directly on map[string]interface{} and []interface.
|
||||
//
|
||||
// # Example
|
||||
//
|
||||
// To access the title of the third chapter of the second book, do:
|
||||
//
|
||||
// o.Get("books[1].chapters[2].title")
|
||||
func (m Map) Get(selector string) *Value {
|
||||
rawObj := access(m, selector, nil, false)
|
||||
return &Value{data: rawObj}
|
||||
}
|
||||
|
||||
// Set sets the value using the specified selector and
|
||||
// returns the object on which Set was called.
|
||||
//
|
||||
// Set can only operate directly on map[string]interface{} and []interface
|
||||
//
|
||||
// # Example
|
||||
//
|
||||
// To set the title of the third chapter of the second book, do:
|
||||
//
|
||||
// o.Set("books[1].chapters[2].title","Time to Go")
|
||||
func (m Map) Set(selector string, value interface{}) Map {
|
||||
access(m, selector, value, true)
|
||||
return m
|
||||
}
|
||||
|
||||
// getIndex returns the index, which is hold in s by two branches.
|
||||
// It also returns s without the index part, e.g. name[1] will return (1, name).
|
||||
// If no index is found, -1 is returned
|
||||
func getIndex(s string) (int, string) {
|
||||
arrayMatches := arrayAccessRegex.FindStringSubmatch(s)
|
||||
if len(arrayMatches) > 0 {
|
||||
// Get the key into the map
|
||||
selector := arrayMatches[1]
|
||||
// Get the index into the array at the key
|
||||
// We know this can't fail because arrayMatches[2] is an int for sure
|
||||
index, _ := strconv.Atoi(arrayMatches[2])
|
||||
return index, selector
|
||||
}
|
||||
return -1, s
|
||||
}
|
||||
|
||||
// getKey returns the key which is held in s by two brackets.
|
||||
// It also returns the next selector.
|
||||
func getKey(s string) (string, string) {
|
||||
selSegs := strings.SplitN(s, PathSeparator, 2)
|
||||
thisSel := selSegs[0]
|
||||
nextSel := ""
|
||||
|
||||
if len(selSegs) > 1 {
|
||||
nextSel = selSegs[1]
|
||||
}
|
||||
|
||||
mapMatches := mapAccessRegex.FindStringSubmatch(s)
|
||||
if len(mapMatches) > 0 {
|
||||
if _, err := strconv.Atoi(mapMatches[2]); err != nil {
|
||||
thisSel = mapMatches[1]
|
||||
nextSel = "[" + mapMatches[2] + "]" + mapMatches[3]
|
||||
|
||||
if thisSel == "" {
|
||||
thisSel = mapMatches[2]
|
||||
nextSel = mapMatches[3]
|
||||
}
|
||||
|
||||
if nextSel == "" {
|
||||
selSegs = []string{"", ""}
|
||||
} else if nextSel[0] == '.' {
|
||||
nextSel = nextSel[1:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return thisSel, nextSel
|
||||
}
|
||||
|
||||
// access accesses the object using the selector and performs the
|
||||
// appropriate action.
|
||||
func access(current interface{}, selector string, value interface{}, isSet bool) interface{} {
|
||||
thisSel, nextSel := getKey(selector)
|
||||
|
||||
indexes := []int{}
|
||||
for strings.Contains(thisSel, "[") {
|
||||
prevSel := thisSel
|
||||
index := -1
|
||||
index, thisSel = getIndex(thisSel)
|
||||
indexes = append(indexes, index)
|
||||
if prevSel == thisSel {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if curMap, ok := current.(Map); ok {
|
||||
current = map[string]interface{}(curMap)
|
||||
}
|
||||
// get the object in question
|
||||
switch current.(type) {
|
||||
case map[string]interface{}:
|
||||
curMSI := current.(map[string]interface{})
|
||||
if nextSel == "" && isSet {
|
||||
curMSI[thisSel] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
_, ok := curMSI[thisSel].(map[string]interface{})
|
||||
if !ok {
|
||||
_, ok = curMSI[thisSel].(Map)
|
||||
}
|
||||
|
||||
if (curMSI[thisSel] == nil || !ok) && len(indexes) == 0 && isSet {
|
||||
curMSI[thisSel] = map[string]interface{}{}
|
||||
}
|
||||
|
||||
current = curMSI[thisSel]
|
||||
default:
|
||||
current = nil
|
||||
}
|
||||
|
||||
// do we need to access the item of an array?
|
||||
if len(indexes) > 0 {
|
||||
num := len(indexes)
|
||||
for num > 0 {
|
||||
num--
|
||||
index := indexes[num]
|
||||
indexes = indexes[:num]
|
||||
if array, ok := interSlice(current); ok {
|
||||
if index < len(array) {
|
||||
current = array[index]
|
||||
} else {
|
||||
current = nil
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if nextSel != "" {
|
||||
current = access(current, nextSel, value, isSet)
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
func interSlice(slice interface{}) ([]interface{}, bool) {
|
||||
if array, ok := slice.([]interface{}); ok {
|
||||
return array, ok
|
||||
}
|
||||
|
||||
s := reflect.ValueOf(slice)
|
||||
if s.Kind() != reflect.Slice {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
ret := make([]interface{}, s.Len())
|
||||
|
||||
for i := 0; i < s.Len(); i++ {
|
||||
ret[i] = s.Index(i).Interface()
|
||||
}
|
||||
|
||||
return ret, true
|
||||
}
|
||||
-280
@@ -1,280 +0,0 @@
|
||||
package objx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// SignatureSeparator is the character that is used to
|
||||
// separate the Base64 string from the security signature.
|
||||
const SignatureSeparator = "_"
|
||||
|
||||
// URLValuesSliceKeySuffix is the character that is used to
|
||||
// specify a suffix for slices parsed by URLValues.
|
||||
// If the suffix is set to "[i]", then the index of the slice
|
||||
// is used in place of i
|
||||
// Ex: Suffix "[]" would have the form a[]=b&a[]=c
|
||||
// OR Suffix "[i]" would have the form a[0]=b&a[1]=c
|
||||
// OR Suffix "" would have the form a=b&a=c
|
||||
var urlValuesSliceKeySuffix = "[]"
|
||||
|
||||
const (
|
||||
URLValuesSliceKeySuffixEmpty = ""
|
||||
URLValuesSliceKeySuffixArray = "[]"
|
||||
URLValuesSliceKeySuffixIndex = "[i]"
|
||||
)
|
||||
|
||||
// SetURLValuesSliceKeySuffix sets the character that is used to
|
||||
// specify a suffix for slices parsed by URLValues.
|
||||
// If the suffix is set to "[i]", then the index of the slice
|
||||
// is used in place of i
|
||||
// Ex: Suffix "[]" would have the form a[]=b&a[]=c
|
||||
// OR Suffix "[i]" would have the form a[0]=b&a[1]=c
|
||||
// OR Suffix "" would have the form a=b&a=c
|
||||
func SetURLValuesSliceKeySuffix(s string) error {
|
||||
if s == URLValuesSliceKeySuffixEmpty || s == URLValuesSliceKeySuffixArray || s == URLValuesSliceKeySuffixIndex {
|
||||
urlValuesSliceKeySuffix = s
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("objx: Invalid URLValuesSliceKeySuffix provided.")
|
||||
}
|
||||
|
||||
// JSON converts the contained object to a JSON string
|
||||
// representation
|
||||
func (m Map) JSON() (string, error) {
|
||||
for k, v := range m {
|
||||
m[k] = cleanUp(v)
|
||||
}
|
||||
|
||||
result, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
err = errors.New("objx: JSON encode failed with: " + err.Error())
|
||||
}
|
||||
return string(result), err
|
||||
}
|
||||
|
||||
func cleanUpInterfaceArray(in []interface{}) []interface{} {
|
||||
result := make([]interface{}, len(in))
|
||||
for i, v := range in {
|
||||
result[i] = cleanUp(v)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func cleanUpInterfaceMap(in map[interface{}]interface{}) Map {
|
||||
result := Map{}
|
||||
for k, v := range in {
|
||||
result[fmt.Sprintf("%v", k)] = cleanUp(v)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func cleanUpStringMap(in map[string]interface{}) Map {
|
||||
result := Map{}
|
||||
for k, v := range in {
|
||||
result[k] = cleanUp(v)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func cleanUpMSIArray(in []map[string]interface{}) []Map {
|
||||
result := make([]Map, len(in))
|
||||
for i, v := range in {
|
||||
result[i] = cleanUpStringMap(v)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func cleanUpMapArray(in []Map) []Map {
|
||||
result := make([]Map, len(in))
|
||||
for i, v := range in {
|
||||
result[i] = cleanUpStringMap(v)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func cleanUp(v interface{}) interface{} {
|
||||
switch v := v.(type) {
|
||||
case []interface{}:
|
||||
return cleanUpInterfaceArray(v)
|
||||
case []map[string]interface{}:
|
||||
return cleanUpMSIArray(v)
|
||||
case map[interface{}]interface{}:
|
||||
return cleanUpInterfaceMap(v)
|
||||
case Map:
|
||||
return cleanUpStringMap(v)
|
||||
case []Map:
|
||||
return cleanUpMapArray(v)
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// MustJSON converts the contained object to a JSON string
|
||||
// representation and panics if there is an error
|
||||
func (m Map) MustJSON() string {
|
||||
result, err := m.JSON()
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Base64 converts the contained object to a Base64 string
|
||||
// representation of the JSON string representation
|
||||
func (m Map) Base64() (string, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
jsonData, err := m.JSON()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
encoder := base64.NewEncoder(base64.StdEncoding, &buf)
|
||||
_, _ = encoder.Write([]byte(jsonData))
|
||||
_ = encoder.Close()
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// MustBase64 converts the contained object to a Base64 string
|
||||
// representation of the JSON string representation and panics
|
||||
// if there is an error
|
||||
func (m Map) MustBase64() string {
|
||||
result, err := m.Base64()
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// SignedBase64 converts the contained object to a Base64 string
|
||||
// representation of the JSON string representation and signs it
|
||||
// using the provided key.
|
||||
func (m Map) SignedBase64(key string) (string, error) {
|
||||
base64, err := m.Base64()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
sig := HashWithKey(base64, key)
|
||||
return base64 + SignatureSeparator + sig, nil
|
||||
}
|
||||
|
||||
// MustSignedBase64 converts the contained object to a Base64 string
|
||||
// representation of the JSON string representation and signs it
|
||||
// using the provided key and panics if there is an error
|
||||
func (m Map) MustSignedBase64(key string) string {
|
||||
result, err := m.SignedBase64(key)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
/*
|
||||
URL Query
|
||||
------------------------------------------------
|
||||
*/
|
||||
|
||||
// URLValues creates a url.Values object from an Obj. This
|
||||
// function requires that the wrapped object be a map[string]interface{}
|
||||
func (m Map) URLValues() url.Values {
|
||||
vals := make(url.Values)
|
||||
|
||||
m.parseURLValues(m, vals, "")
|
||||
|
||||
return vals
|
||||
}
|
||||
|
||||
func (m Map) parseURLValues(queryMap Map, vals url.Values, key string) {
|
||||
useSliceIndex := false
|
||||
if urlValuesSliceKeySuffix == "[i]" {
|
||||
useSliceIndex = true
|
||||
}
|
||||
|
||||
for k, v := range queryMap {
|
||||
val := &Value{data: v}
|
||||
switch {
|
||||
case val.IsObjxMap():
|
||||
if key == "" {
|
||||
m.parseURLValues(val.ObjxMap(), vals, k)
|
||||
} else {
|
||||
m.parseURLValues(val.ObjxMap(), vals, key+"["+k+"]")
|
||||
}
|
||||
case val.IsObjxMapSlice():
|
||||
sliceKey := k
|
||||
if key != "" {
|
||||
sliceKey = key + "[" + k + "]"
|
||||
}
|
||||
|
||||
if useSliceIndex {
|
||||
for i, sv := range val.MustObjxMapSlice() {
|
||||
sk := sliceKey + "[" + strconv.FormatInt(int64(i), 10) + "]"
|
||||
m.parseURLValues(sv, vals, sk)
|
||||
}
|
||||
} else {
|
||||
sliceKey = sliceKey + urlValuesSliceKeySuffix
|
||||
for _, sv := range val.MustObjxMapSlice() {
|
||||
m.parseURLValues(sv, vals, sliceKey)
|
||||
}
|
||||
}
|
||||
case val.IsMSISlice():
|
||||
sliceKey := k
|
||||
if key != "" {
|
||||
sliceKey = key + "[" + k + "]"
|
||||
}
|
||||
|
||||
if useSliceIndex {
|
||||
for i, sv := range val.MustMSISlice() {
|
||||
sk := sliceKey + "[" + strconv.FormatInt(int64(i), 10) + "]"
|
||||
m.parseURLValues(New(sv), vals, sk)
|
||||
}
|
||||
} else {
|
||||
sliceKey = sliceKey + urlValuesSliceKeySuffix
|
||||
for _, sv := range val.MustMSISlice() {
|
||||
m.parseURLValues(New(sv), vals, sliceKey)
|
||||
}
|
||||
}
|
||||
case val.IsStrSlice(), val.IsBoolSlice(),
|
||||
val.IsFloat32Slice(), val.IsFloat64Slice(),
|
||||
val.IsIntSlice(), val.IsInt8Slice(), val.IsInt16Slice(), val.IsInt32Slice(), val.IsInt64Slice(),
|
||||
val.IsUintSlice(), val.IsUint8Slice(), val.IsUint16Slice(), val.IsUint32Slice(), val.IsUint64Slice():
|
||||
|
||||
sliceKey := k
|
||||
if key != "" {
|
||||
sliceKey = key + "[" + k + "]"
|
||||
}
|
||||
|
||||
if useSliceIndex {
|
||||
for i, sv := range val.StringSlice() {
|
||||
sk := sliceKey + "[" + strconv.FormatInt(int64(i), 10) + "]"
|
||||
vals.Set(sk, sv)
|
||||
}
|
||||
} else {
|
||||
sliceKey = sliceKey + urlValuesSliceKeySuffix
|
||||
vals[sliceKey] = val.StringSlice()
|
||||
}
|
||||
|
||||
default:
|
||||
if key == "" {
|
||||
vals.Set(k, val.String())
|
||||
} else {
|
||||
vals.Set(key+"["+k+"]", val.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// URLQuery gets an encoded URL query representing the given
|
||||
// Obj. This function requires that the wrapped object be a
|
||||
// map[string]interface{}
|
||||
func (m Map) URLQuery() (string, error) {
|
||||
return m.URLValues().Encode(), nil
|
||||
}
|
||||
-66
@@ -1,66 +0,0 @@
|
||||
/*
|
||||
Package objx provides utilities for dealing with maps, slices, JSON and other data.
|
||||
|
||||
# Overview
|
||||
|
||||
Objx provides the `objx.Map` type, which is a `map[string]interface{}` that exposes
|
||||
a powerful `Get` method (among others) that allows you to easily and quickly get
|
||||
access to data within the map, without having to worry too much about type assertions,
|
||||
missing data, default values etc.
|
||||
|
||||
# Pattern
|
||||
|
||||
Objx uses a predictable pattern to make access data from within `map[string]interface{}` easy.
|
||||
Call one of the `objx.` functions to create your `objx.Map` to get going:
|
||||
|
||||
m, err := objx.FromJSON(json)
|
||||
|
||||
NOTE: Any methods or functions with the `Must` prefix will panic if something goes wrong,
|
||||
the rest will be optimistic and try to figure things out without panicking.
|
||||
|
||||
Use `Get` to access the value you're interested in. You can use dot and array
|
||||
notation too:
|
||||
|
||||
m.Get("places[0].latlng")
|
||||
|
||||
Once you have sought the `Value` you're interested in, you can use the `Is*` methods to determine its type.
|
||||
|
||||
if m.Get("code").IsStr() { // Your code... }
|
||||
|
||||
Or you can just assume the type, and use one of the strong type methods to extract the real value:
|
||||
|
||||
m.Get("code").Int()
|
||||
|
||||
If there's no value there (or if it's the wrong type) then a default value will be returned,
|
||||
or you can be explicit about the default value.
|
||||
|
||||
Get("code").Int(-1)
|
||||
|
||||
If you're dealing with a slice of data as a value, Objx provides many useful methods for iterating,
|
||||
manipulating and selecting that data. You can find out more by exploring the index below.
|
||||
|
||||
# Reading data
|
||||
|
||||
A simple example of how to use Objx:
|
||||
|
||||
// Use MustFromJSON to make an objx.Map from some JSON
|
||||
m := objx.MustFromJSON(`{"name": "Mat", "age": 30}`)
|
||||
|
||||
// Get the details
|
||||
name := m.Get("name").Str()
|
||||
age := m.Get("age").Int()
|
||||
|
||||
// Get their nickname (or use their name if they don't have one)
|
||||
nickname := m.Get("nickname").Str(name)
|
||||
|
||||
# Ranging
|
||||
|
||||
Since `objx.Map` is a `map[string]interface{}` you can treat it as such.
|
||||
For example, to `range` the data, do what you would expect:
|
||||
|
||||
m := objx.MustFromJSON(json)
|
||||
for key, value := range m {
|
||||
// Your code...
|
||||
}
|
||||
*/
|
||||
package objx
|
||||
-214
@@ -1,214 +0,0 @@
|
||||
package objx
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// MSIConvertable is an interface that defines methods for converting your
|
||||
// custom types to a map[string]interface{} representation.
|
||||
type MSIConvertable interface {
|
||||
// MSI gets a map[string]interface{} (msi) representing the
|
||||
// object.
|
||||
MSI() map[string]interface{}
|
||||
}
|
||||
|
||||
// Map provides extended functionality for working with
|
||||
// untyped data, in particular map[string]interface (msi).
|
||||
type Map map[string]interface{}
|
||||
|
||||
// Value returns the internal value instance
|
||||
func (m Map) Value() *Value {
|
||||
return &Value{data: m}
|
||||
}
|
||||
|
||||
// Nil represents a nil Map.
|
||||
var Nil = New(nil)
|
||||
|
||||
// New creates a new Map containing the map[string]interface{} in the data argument.
|
||||
// If the data argument is not a map[string]interface, New attempts to call the
|
||||
// MSI() method on the MSIConvertable interface to create one.
|
||||
func New(data interface{}) Map {
|
||||
if _, ok := data.(map[string]interface{}); !ok {
|
||||
if converter, ok := data.(MSIConvertable); ok {
|
||||
data = converter.MSI()
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return Map(data.(map[string]interface{}))
|
||||
}
|
||||
|
||||
// MSI creates a map[string]interface{} and puts it inside a new Map.
|
||||
//
|
||||
// The arguments follow a key, value pattern.
|
||||
//
|
||||
// Returns nil if any key argument is non-string or if there are an odd number of arguments.
|
||||
//
|
||||
// # Example
|
||||
//
|
||||
// To easily create Maps:
|
||||
//
|
||||
// m := objx.MSI("name", "Mat", "age", 29, "subobj", objx.MSI("active", true))
|
||||
//
|
||||
// // creates an Map equivalent to
|
||||
// m := objx.Map{"name": "Mat", "age": 29, "subobj": objx.Map{"active": true}}
|
||||
func MSI(keyAndValuePairs ...interface{}) Map {
|
||||
newMap := Map{}
|
||||
keyAndValuePairsLen := len(keyAndValuePairs)
|
||||
if keyAndValuePairsLen%2 != 0 {
|
||||
return nil
|
||||
}
|
||||
for i := 0; i < keyAndValuePairsLen; i = i + 2 {
|
||||
key := keyAndValuePairs[i]
|
||||
value := keyAndValuePairs[i+1]
|
||||
|
||||
// make sure the key is a string
|
||||
keyString, keyStringOK := key.(string)
|
||||
if !keyStringOK {
|
||||
return nil
|
||||
}
|
||||
newMap[keyString] = value
|
||||
}
|
||||
return newMap
|
||||
}
|
||||
|
||||
// ****** Conversion Constructors
|
||||
|
||||
// MustFromJSON creates a new Map containing the data specified in the
|
||||
// jsonString.
|
||||
//
|
||||
// Panics if the JSON is invalid.
|
||||
func MustFromJSON(jsonString string) Map {
|
||||
o, err := FromJSON(jsonString)
|
||||
if err != nil {
|
||||
panic("objx: MustFromJSON failed with error: " + err.Error())
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
// MustFromJSONSlice creates a new slice of Map containing the data specified in the
|
||||
// jsonString. Works with jsons with a top level array
|
||||
//
|
||||
// Panics if the JSON is invalid.
|
||||
func MustFromJSONSlice(jsonString string) []Map {
|
||||
slice, err := FromJSONSlice(jsonString)
|
||||
if err != nil {
|
||||
panic("objx: MustFromJSONSlice failed with error: " + err.Error())
|
||||
}
|
||||
return slice
|
||||
}
|
||||
|
||||
// FromJSON creates a new Map containing the data specified in the
|
||||
// jsonString.
|
||||
//
|
||||
// Returns an error if the JSON is invalid.
|
||||
func FromJSON(jsonString string) (Map, error) {
|
||||
var m Map
|
||||
err := json.Unmarshal([]byte(jsonString), &m)
|
||||
if err != nil {
|
||||
return Nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// FromJSONSlice creates a new slice of Map containing the data specified in the
|
||||
// jsonString. Works with jsons with a top level array
|
||||
//
|
||||
// Returns an error if the JSON is invalid.
|
||||
func FromJSONSlice(jsonString string) ([]Map, error) {
|
||||
var slice []Map
|
||||
err := json.Unmarshal([]byte(jsonString), &slice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
// FromBase64 creates a new Obj containing the data specified
|
||||
// in the Base64 string.
|
||||
//
|
||||
// The string is an encoded JSON string returned by Base64
|
||||
func FromBase64(base64String string) (Map, error) {
|
||||
decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(base64String))
|
||||
decoded, err := ioutil.ReadAll(decoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return FromJSON(string(decoded))
|
||||
}
|
||||
|
||||
// MustFromBase64 creates a new Obj containing the data specified
|
||||
// in the Base64 string and panics if there is an error.
|
||||
//
|
||||
// The string is an encoded JSON string returned by Base64
|
||||
func MustFromBase64(base64String string) Map {
|
||||
result, err := FromBase64(base64String)
|
||||
if err != nil {
|
||||
panic("objx: MustFromBase64 failed with error: " + err.Error())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// FromSignedBase64 creates a new Obj containing the data specified
|
||||
// in the Base64 string.
|
||||
//
|
||||
// The string is an encoded JSON string returned by SignedBase64
|
||||
func FromSignedBase64(base64String, key string) (Map, error) {
|
||||
parts := strings.Split(base64String, SignatureSeparator)
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.New("objx: Signed base64 string is malformed")
|
||||
}
|
||||
|
||||
sig := HashWithKey(parts[0], key)
|
||||
if parts[1] != sig {
|
||||
return nil, errors.New("objx: Signature for base64 data does not match")
|
||||
}
|
||||
return FromBase64(parts[0])
|
||||
}
|
||||
|
||||
// MustFromSignedBase64 creates a new Obj containing the data specified
|
||||
// in the Base64 string and panics if there is an error.
|
||||
//
|
||||
// The string is an encoded JSON string returned by Base64
|
||||
func MustFromSignedBase64(base64String, key string) Map {
|
||||
result, err := FromSignedBase64(base64String, key)
|
||||
if err != nil {
|
||||
panic("objx: MustFromSignedBase64 failed with error: " + err.Error())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// FromURLQuery generates a new Obj by parsing the specified
|
||||
// query.
|
||||
//
|
||||
// For queries with multiple values, the first value is selected.
|
||||
func FromURLQuery(query string) (Map, error) {
|
||||
vals, err := url.ParseQuery(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := Map{}
|
||||
for k, vals := range vals {
|
||||
m[k] = vals[0]
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// MustFromURLQuery generates a new Obj by parsing the specified
|
||||
// query.
|
||||
//
|
||||
// For queries with multiple values, the first value is selected.
|
||||
//
|
||||
// Panics if it encounters an error
|
||||
func MustFromURLQuery(query string) Map {
|
||||
o, err := FromURLQuery(query)
|
||||
if err != nil {
|
||||
panic("objx: MustFromURLQuery failed with error: " + err.Error())
|
||||
}
|
||||
return o
|
||||
}
|
||||
-77
@@ -1,77 +0,0 @@
|
||||
package objx
|
||||
|
||||
// Exclude returns a new Map with the keys in the specified []string
|
||||
// excluded.
|
||||
func (m Map) Exclude(exclude []string) Map {
|
||||
excluded := make(Map)
|
||||
for k, v := range m {
|
||||
if !contains(exclude, k) {
|
||||
excluded[k] = v
|
||||
}
|
||||
}
|
||||
return excluded
|
||||
}
|
||||
|
||||
// Copy creates a shallow copy of the Obj.
|
||||
func (m Map) Copy() Map {
|
||||
copied := Map{}
|
||||
for k, v := range m {
|
||||
copied[k] = v
|
||||
}
|
||||
return copied
|
||||
}
|
||||
|
||||
// Merge blends the specified map with a copy of this map and returns the result.
|
||||
//
|
||||
// Keys that appear in both will be selected from the specified map.
|
||||
// This method requires that the wrapped object be a map[string]interface{}
|
||||
func (m Map) Merge(merge Map) Map {
|
||||
return m.Copy().MergeHere(merge)
|
||||
}
|
||||
|
||||
// MergeHere blends the specified map with this map and returns the current map.
|
||||
//
|
||||
// Keys that appear in both will be selected from the specified map. The original map
|
||||
// will be modified. This method requires that
|
||||
// the wrapped object be a map[string]interface{}
|
||||
func (m Map) MergeHere(merge Map) Map {
|
||||
for k, v := range merge {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Transform builds a new Obj giving the transformer a chance
|
||||
// to change the keys and values as it goes. This method requires that
|
||||
// the wrapped object be a map[string]interface{}
|
||||
func (m Map) Transform(transformer func(key string, value interface{}) (string, interface{})) Map {
|
||||
newMap := Map{}
|
||||
for k, v := range m {
|
||||
modifiedKey, modifiedVal := transformer(k, v)
|
||||
newMap[modifiedKey] = modifiedVal
|
||||
}
|
||||
return newMap
|
||||
}
|
||||
|
||||
// TransformKeys builds a new map using the specified key mapping.
|
||||
//
|
||||
// Unspecified keys will be unaltered.
|
||||
// This method requires that the wrapped object be a map[string]interface{}
|
||||
func (m Map) TransformKeys(mapping map[string]string) Map {
|
||||
return m.Transform(func(key string, value interface{}) (string, interface{}) {
|
||||
if newKey, ok := mapping[key]; ok {
|
||||
return newKey, value
|
||||
}
|
||||
return key, value
|
||||
})
|
||||
}
|
||||
|
||||
// Checks if a string slice contains a string
|
||||
func contains(s []string, e string) bool {
|
||||
for _, a := range s {
|
||||
if a == e {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
-12
@@ -1,12 +0,0 @@
|
||||
package objx
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// HashWithKey hashes the specified string using the security key
|
||||
func HashWithKey(data, key string) string {
|
||||
d := sha1.Sum([]byte(data + ":" + key))
|
||||
return hex.EncodeToString(d[:])
|
||||
}
|
||||
-17
@@ -1,17 +0,0 @@
|
||||
package objx
|
||||
|
||||
// Has gets whether there is something at the specified selector
|
||||
// or not.
|
||||
//
|
||||
// If m is nil, Has will always return false.
|
||||
func (m Map) Has(selector string) bool {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
return !m.Get(selector).IsNil()
|
||||
}
|
||||
|
||||
// IsNil gets whether the data is nil or not.
|
||||
func (v *Value) IsNil() bool {
|
||||
return v == nil || v.data == nil
|
||||
}
|
||||
-346
@@ -1,346 +0,0 @@
|
||||
package objx
|
||||
|
||||
/*
|
||||
MSI (map[string]interface{} and []map[string]interface{})
|
||||
*/
|
||||
|
||||
// MSI gets the value as a map[string]interface{}, returns the optionalDefault
|
||||
// value or a system default object if the value is the wrong type.
|
||||
func (v *Value) MSI(optionalDefault ...map[string]interface{}) map[string]interface{} {
|
||||
if s, ok := v.data.(map[string]interface{}); ok {
|
||||
return s
|
||||
}
|
||||
if s, ok := v.data.(Map); ok {
|
||||
return map[string]interface{}(s)
|
||||
}
|
||||
if len(optionalDefault) == 1 {
|
||||
return optionalDefault[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustMSI gets the value as a map[string]interface{}.
|
||||
//
|
||||
// Panics if the object is not a map[string]interface{}.
|
||||
func (v *Value) MustMSI() map[string]interface{} {
|
||||
if s, ok := v.data.(Map); ok {
|
||||
return map[string]interface{}(s)
|
||||
}
|
||||
return v.data.(map[string]interface{})
|
||||
}
|
||||
|
||||
// MSISlice gets the value as a []map[string]interface{}, returns the optionalDefault
|
||||
// value or nil if the value is not a []map[string]interface{}.
|
||||
func (v *Value) MSISlice(optionalDefault ...[]map[string]interface{}) []map[string]interface{} {
|
||||
if s, ok := v.data.([]map[string]interface{}); ok {
|
||||
return s
|
||||
}
|
||||
|
||||
s := v.ObjxMapSlice()
|
||||
if s == nil {
|
||||
if len(optionalDefault) == 1 {
|
||||
return optionalDefault[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, len(s))
|
||||
for i := range s {
|
||||
result[i] = s[i].Value().MSI()
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MustMSISlice gets the value as a []map[string]interface{}.
|
||||
//
|
||||
// Panics if the object is not a []map[string]interface{}.
|
||||
func (v *Value) MustMSISlice() []map[string]interface{} {
|
||||
if s := v.MSISlice(); s != nil {
|
||||
return s
|
||||
}
|
||||
|
||||
return v.data.([]map[string]interface{})
|
||||
}
|
||||
|
||||
// IsMSI gets whether the object contained is a map[string]interface{} or not.
|
||||
func (v *Value) IsMSI() bool {
|
||||
_, ok := v.data.(map[string]interface{})
|
||||
if !ok {
|
||||
_, ok = v.data.(Map)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsMSISlice gets whether the object contained is a []map[string]interface{} or not.
|
||||
func (v *Value) IsMSISlice() bool {
|
||||
_, ok := v.data.([]map[string]interface{})
|
||||
if !ok {
|
||||
_, ok = v.data.([]Map)
|
||||
if !ok {
|
||||
s, ok := v.data.([]interface{})
|
||||
if ok {
|
||||
for i := range s {
|
||||
switch s[i].(type) {
|
||||
case Map:
|
||||
case map[string]interface{}:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// EachMSI calls the specified callback for each object
|
||||
// in the []map[string]interface{}.
|
||||
//
|
||||
// Panics if the object is the wrong type.
|
||||
func (v *Value) EachMSI(callback func(int, map[string]interface{}) bool) *Value {
|
||||
for index, val := range v.MustMSISlice() {
|
||||
carryon := callback(index, val)
|
||||
if !carryon {
|
||||
break
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// WhereMSI uses the specified decider function to select items
|
||||
// from the []map[string]interface{}. The object contained in the result will contain
|
||||
// only the selected items.
|
||||
func (v *Value) WhereMSI(decider func(int, map[string]interface{}) bool) *Value {
|
||||
var selected []map[string]interface{}
|
||||
v.EachMSI(func(index int, val map[string]interface{}) bool {
|
||||
shouldSelect := decider(index, val)
|
||||
if !shouldSelect {
|
||||
selected = append(selected, val)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return &Value{data: selected}
|
||||
}
|
||||
|
||||
// GroupMSI uses the specified grouper function to group the items
|
||||
// keyed by the return of the grouper. The object contained in the
|
||||
// result will contain a map[string][]map[string]interface{}.
|
||||
func (v *Value) GroupMSI(grouper func(int, map[string]interface{}) string) *Value {
|
||||
groups := make(map[string][]map[string]interface{})
|
||||
v.EachMSI(func(index int, val map[string]interface{}) bool {
|
||||
group := grouper(index, val)
|
||||
if _, ok := groups[group]; !ok {
|
||||
groups[group] = make([]map[string]interface{}, 0)
|
||||
}
|
||||
groups[group] = append(groups[group], val)
|
||||
return true
|
||||
})
|
||||
return &Value{data: groups}
|
||||
}
|
||||
|
||||
// ReplaceMSI uses the specified function to replace each map[string]interface{}s
|
||||
// by iterating each item. The data in the returned result will be a
|
||||
// []map[string]interface{} containing the replaced items.
|
||||
func (v *Value) ReplaceMSI(replacer func(int, map[string]interface{}) map[string]interface{}) *Value {
|
||||
arr := v.MustMSISlice()
|
||||
replaced := make([]map[string]interface{}, len(arr))
|
||||
v.EachMSI(func(index int, val map[string]interface{}) bool {
|
||||
replaced[index] = replacer(index, val)
|
||||
return true
|
||||
})
|
||||
return &Value{data: replaced}
|
||||
}
|
||||
|
||||
// CollectMSI uses the specified collector function to collect a value
|
||||
// for each of the map[string]interface{}s in the slice. The data returned will be a
|
||||
// []interface{}.
|
||||
func (v *Value) CollectMSI(collector func(int, map[string]interface{}) interface{}) *Value {
|
||||
arr := v.MustMSISlice()
|
||||
collected := make([]interface{}, len(arr))
|
||||
v.EachMSI(func(index int, val map[string]interface{}) bool {
|
||||
collected[index] = collector(index, val)
|
||||
return true
|
||||
})
|
||||
return &Value{data: collected}
|
||||
}
|
||||
|
||||
/*
|
||||
ObjxMap ((Map) and [](Map))
|
||||
*/
|
||||
|
||||
// ObjxMap gets the value as a (Map), returns the optionalDefault
|
||||
// value or a system default object if the value is the wrong type.
|
||||
func (v *Value) ObjxMap(optionalDefault ...(Map)) Map {
|
||||
if s, ok := v.data.((Map)); ok {
|
||||
return s
|
||||
}
|
||||
if s, ok := v.data.(map[string]interface{}); ok {
|
||||
return s
|
||||
}
|
||||
if len(optionalDefault) == 1 {
|
||||
return optionalDefault[0]
|
||||
}
|
||||
return New(nil)
|
||||
}
|
||||
|
||||
// MustObjxMap gets the value as a (Map).
|
||||
//
|
||||
// Panics if the object is not a (Map).
|
||||
func (v *Value) MustObjxMap() Map {
|
||||
if s, ok := v.data.(map[string]interface{}); ok {
|
||||
return s
|
||||
}
|
||||
return v.data.((Map))
|
||||
}
|
||||
|
||||
// ObjxMapSlice gets the value as a [](Map), returns the optionalDefault
|
||||
// value or nil if the value is not a [](Map).
|
||||
func (v *Value) ObjxMapSlice(optionalDefault ...[](Map)) [](Map) {
|
||||
if s, ok := v.data.([]Map); ok {
|
||||
return s
|
||||
}
|
||||
|
||||
if s, ok := v.data.([]map[string]interface{}); ok {
|
||||
result := make([]Map, len(s))
|
||||
for i := range s {
|
||||
result[i] = s[i]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
s, ok := v.data.([]interface{})
|
||||
if !ok {
|
||||
if len(optionalDefault) == 1 {
|
||||
return optionalDefault[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]Map, len(s))
|
||||
for i := range s {
|
||||
switch s[i].(type) {
|
||||
case Map:
|
||||
result[i] = s[i].(Map)
|
||||
case map[string]interface{}:
|
||||
result[i] = New(s[i])
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MustObjxMapSlice gets the value as a [](Map).
|
||||
//
|
||||
// Panics if the object is not a [](Map).
|
||||
func (v *Value) MustObjxMapSlice() [](Map) {
|
||||
if s := v.ObjxMapSlice(); s != nil {
|
||||
return s
|
||||
}
|
||||
return v.data.([](Map))
|
||||
}
|
||||
|
||||
// IsObjxMap gets whether the object contained is a (Map) or not.
|
||||
func (v *Value) IsObjxMap() bool {
|
||||
_, ok := v.data.((Map))
|
||||
if !ok {
|
||||
_, ok = v.data.(map[string]interface{})
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsObjxMapSlice gets whether the object contained is a [](Map) or not.
|
||||
func (v *Value) IsObjxMapSlice() bool {
|
||||
_, ok := v.data.([](Map))
|
||||
if !ok {
|
||||
_, ok = v.data.([]map[string]interface{})
|
||||
if !ok {
|
||||
s, ok := v.data.([]interface{})
|
||||
if ok {
|
||||
for i := range s {
|
||||
switch s[i].(type) {
|
||||
case Map:
|
||||
case map[string]interface{}:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// EachObjxMap calls the specified callback for each object
|
||||
// in the [](Map).
|
||||
//
|
||||
// Panics if the object is the wrong type.
|
||||
func (v *Value) EachObjxMap(callback func(int, Map) bool) *Value {
|
||||
for index, val := range v.MustObjxMapSlice() {
|
||||
carryon := callback(index, val)
|
||||
if !carryon {
|
||||
break
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// WhereObjxMap uses the specified decider function to select items
|
||||
// from the [](Map). The object contained in the result will contain
|
||||
// only the selected items.
|
||||
func (v *Value) WhereObjxMap(decider func(int, Map) bool) *Value {
|
||||
var selected [](Map)
|
||||
v.EachObjxMap(func(index int, val Map) bool {
|
||||
shouldSelect := decider(index, val)
|
||||
if !shouldSelect {
|
||||
selected = append(selected, val)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return &Value{data: selected}
|
||||
}
|
||||
|
||||
// GroupObjxMap uses the specified grouper function to group the items
|
||||
// keyed by the return of the grouper. The object contained in the
|
||||
// result will contain a map[string][](Map).
|
||||
func (v *Value) GroupObjxMap(grouper func(int, Map) string) *Value {
|
||||
groups := make(map[string][](Map))
|
||||
v.EachObjxMap(func(index int, val Map) bool {
|
||||
group := grouper(index, val)
|
||||
if _, ok := groups[group]; !ok {
|
||||
groups[group] = make([](Map), 0)
|
||||
}
|
||||
groups[group] = append(groups[group], val)
|
||||
return true
|
||||
})
|
||||
return &Value{data: groups}
|
||||
}
|
||||
|
||||
// ReplaceObjxMap uses the specified function to replace each (Map)s
|
||||
// by iterating each item. The data in the returned result will be a
|
||||
// [](Map) containing the replaced items.
|
||||
func (v *Value) ReplaceObjxMap(replacer func(int, Map) Map) *Value {
|
||||
arr := v.MustObjxMapSlice()
|
||||
replaced := make([](Map), len(arr))
|
||||
v.EachObjxMap(func(index int, val Map) bool {
|
||||
replaced[index] = replacer(index, val)
|
||||
return true
|
||||
})
|
||||
return &Value{data: replaced}
|
||||
}
|
||||
|
||||
// CollectObjxMap uses the specified collector function to collect a value
|
||||
// for each of the (Map)s in the slice. The data returned will be a
|
||||
// []interface{}.
|
||||
func (v *Value) CollectObjxMap(collector func(int, Map) interface{}) *Value {
|
||||
arr := v.MustObjxMapSlice()
|
||||
collected := make([]interface{}, len(arr))
|
||||
v.EachObjxMap(func(index int, val Map) bool {
|
||||
collected[index] = collector(index, val)
|
||||
return true
|
||||
})
|
||||
return &Value{data: collected}
|
||||
}
|
||||
-2261
File diff suppressed because it is too large
Load Diff
-159
@@ -1,159 +0,0 @@
|
||||
package objx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Value provides methods for extracting interface{} data in various
|
||||
// types.
|
||||
type Value struct {
|
||||
// data contains the raw data being managed by this Value
|
||||
data interface{}
|
||||
}
|
||||
|
||||
// Data returns the raw data contained by this Value
|
||||
func (v *Value) Data() interface{} {
|
||||
return v.data
|
||||
}
|
||||
|
||||
// String returns the value always as a string
|
||||
func (v *Value) String() string {
|
||||
switch {
|
||||
case v.IsNil():
|
||||
return ""
|
||||
case v.IsStr():
|
||||
return v.Str()
|
||||
case v.IsBool():
|
||||
return strconv.FormatBool(v.Bool())
|
||||
case v.IsFloat32():
|
||||
return strconv.FormatFloat(float64(v.Float32()), 'f', -1, 32)
|
||||
case v.IsFloat64():
|
||||
return strconv.FormatFloat(v.Float64(), 'f', -1, 64)
|
||||
case v.IsInt():
|
||||
return strconv.FormatInt(int64(v.Int()), 10)
|
||||
case v.IsInt8():
|
||||
return strconv.FormatInt(int64(v.Int8()), 10)
|
||||
case v.IsInt16():
|
||||
return strconv.FormatInt(int64(v.Int16()), 10)
|
||||
case v.IsInt32():
|
||||
return strconv.FormatInt(int64(v.Int32()), 10)
|
||||
case v.IsInt64():
|
||||
return strconv.FormatInt(v.Int64(), 10)
|
||||
case v.IsUint():
|
||||
return strconv.FormatUint(uint64(v.Uint()), 10)
|
||||
case v.IsUint8():
|
||||
return strconv.FormatUint(uint64(v.Uint8()), 10)
|
||||
case v.IsUint16():
|
||||
return strconv.FormatUint(uint64(v.Uint16()), 10)
|
||||
case v.IsUint32():
|
||||
return strconv.FormatUint(uint64(v.Uint32()), 10)
|
||||
case v.IsUint64():
|
||||
return strconv.FormatUint(v.Uint64(), 10)
|
||||
}
|
||||
return fmt.Sprintf("%#v", v.Data())
|
||||
}
|
||||
|
||||
// StringSlice returns the value always as a []string
|
||||
func (v *Value) StringSlice(optionalDefault ...[]string) []string {
|
||||
switch {
|
||||
case v.IsStrSlice():
|
||||
return v.MustStrSlice()
|
||||
case v.IsBoolSlice():
|
||||
slice := v.MustBoolSlice()
|
||||
vals := make([]string, len(slice))
|
||||
for i, iv := range slice {
|
||||
vals[i] = strconv.FormatBool(iv)
|
||||
}
|
||||
return vals
|
||||
case v.IsFloat32Slice():
|
||||
slice := v.MustFloat32Slice()
|
||||
vals := make([]string, len(slice))
|
||||
for i, iv := range slice {
|
||||
vals[i] = strconv.FormatFloat(float64(iv), 'f', -1, 32)
|
||||
}
|
||||
return vals
|
||||
case v.IsFloat64Slice():
|
||||
slice := v.MustFloat64Slice()
|
||||
vals := make([]string, len(slice))
|
||||
for i, iv := range slice {
|
||||
vals[i] = strconv.FormatFloat(iv, 'f', -1, 64)
|
||||
}
|
||||
return vals
|
||||
case v.IsIntSlice():
|
||||
slice := v.MustIntSlice()
|
||||
vals := make([]string, len(slice))
|
||||
for i, iv := range slice {
|
||||
vals[i] = strconv.FormatInt(int64(iv), 10)
|
||||
}
|
||||
return vals
|
||||
case v.IsInt8Slice():
|
||||
slice := v.MustInt8Slice()
|
||||
vals := make([]string, len(slice))
|
||||
for i, iv := range slice {
|
||||
vals[i] = strconv.FormatInt(int64(iv), 10)
|
||||
}
|
||||
return vals
|
||||
case v.IsInt16Slice():
|
||||
slice := v.MustInt16Slice()
|
||||
vals := make([]string, len(slice))
|
||||
for i, iv := range slice {
|
||||
vals[i] = strconv.FormatInt(int64(iv), 10)
|
||||
}
|
||||
return vals
|
||||
case v.IsInt32Slice():
|
||||
slice := v.MustInt32Slice()
|
||||
vals := make([]string, len(slice))
|
||||
for i, iv := range slice {
|
||||
vals[i] = strconv.FormatInt(int64(iv), 10)
|
||||
}
|
||||
return vals
|
||||
case v.IsInt64Slice():
|
||||
slice := v.MustInt64Slice()
|
||||
vals := make([]string, len(slice))
|
||||
for i, iv := range slice {
|
||||
vals[i] = strconv.FormatInt(iv, 10)
|
||||
}
|
||||
return vals
|
||||
case v.IsUintSlice():
|
||||
slice := v.MustUintSlice()
|
||||
vals := make([]string, len(slice))
|
||||
for i, iv := range slice {
|
||||
vals[i] = strconv.FormatUint(uint64(iv), 10)
|
||||
}
|
||||
return vals
|
||||
case v.IsUint8Slice():
|
||||
slice := v.MustUint8Slice()
|
||||
vals := make([]string, len(slice))
|
||||
for i, iv := range slice {
|
||||
vals[i] = strconv.FormatUint(uint64(iv), 10)
|
||||
}
|
||||
return vals
|
||||
case v.IsUint16Slice():
|
||||
slice := v.MustUint16Slice()
|
||||
vals := make([]string, len(slice))
|
||||
for i, iv := range slice {
|
||||
vals[i] = strconv.FormatUint(uint64(iv), 10)
|
||||
}
|
||||
return vals
|
||||
case v.IsUint32Slice():
|
||||
slice := v.MustUint32Slice()
|
||||
vals := make([]string, len(slice))
|
||||
for i, iv := range slice {
|
||||
vals[i] = strconv.FormatUint(uint64(iv), 10)
|
||||
}
|
||||
return vals
|
||||
case v.IsUint64Slice():
|
||||
slice := v.MustUint64Slice()
|
||||
vals := make([]string, len(slice))
|
||||
for i, iv := range slice {
|
||||
vals[i] = strconv.FormatUint(iv, 10)
|
||||
}
|
||||
return vals
|
||||
}
|
||||
if len(optionalDefault) == 1 {
|
||||
return optionalDefault[0]
|
||||
}
|
||||
|
||||
return []string{}
|
||||
}
|
||||
-21
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2012-2020 Mat Ryer, Tyler Bunnell and contributors.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
-480
@@ -1,480 +0,0 @@
|
||||
package assert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CompareType int
|
||||
|
||||
const (
|
||||
compareLess CompareType = iota - 1
|
||||
compareEqual
|
||||
compareGreater
|
||||
)
|
||||
|
||||
var (
|
||||
intType = reflect.TypeOf(int(1))
|
||||
int8Type = reflect.TypeOf(int8(1))
|
||||
int16Type = reflect.TypeOf(int16(1))
|
||||
int32Type = reflect.TypeOf(int32(1))
|
||||
int64Type = reflect.TypeOf(int64(1))
|
||||
|
||||
uintType = reflect.TypeOf(uint(1))
|
||||
uint8Type = reflect.TypeOf(uint8(1))
|
||||
uint16Type = reflect.TypeOf(uint16(1))
|
||||
uint32Type = reflect.TypeOf(uint32(1))
|
||||
uint64Type = reflect.TypeOf(uint64(1))
|
||||
|
||||
uintptrType = reflect.TypeOf(uintptr(1))
|
||||
|
||||
float32Type = reflect.TypeOf(float32(1))
|
||||
float64Type = reflect.TypeOf(float64(1))
|
||||
|
||||
stringType = reflect.TypeOf("")
|
||||
|
||||
timeType = reflect.TypeOf(time.Time{})
|
||||
bytesType = reflect.TypeOf([]byte{})
|
||||
)
|
||||
|
||||
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
|
||||
obj1Value := reflect.ValueOf(obj1)
|
||||
obj2Value := reflect.ValueOf(obj2)
|
||||
|
||||
// throughout this switch we try and avoid calling .Convert() if possible,
|
||||
// as this has a pretty big performance impact
|
||||
switch kind {
|
||||
case reflect.Int:
|
||||
{
|
||||
intobj1, ok := obj1.(int)
|
||||
if !ok {
|
||||
intobj1 = obj1Value.Convert(intType).Interface().(int)
|
||||
}
|
||||
intobj2, ok := obj2.(int)
|
||||
if !ok {
|
||||
intobj2 = obj2Value.Convert(intType).Interface().(int)
|
||||
}
|
||||
if intobj1 > intobj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if intobj1 == intobj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if intobj1 < intobj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Int8:
|
||||
{
|
||||
int8obj1, ok := obj1.(int8)
|
||||
if !ok {
|
||||
int8obj1 = obj1Value.Convert(int8Type).Interface().(int8)
|
||||
}
|
||||
int8obj2, ok := obj2.(int8)
|
||||
if !ok {
|
||||
int8obj2 = obj2Value.Convert(int8Type).Interface().(int8)
|
||||
}
|
||||
if int8obj1 > int8obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if int8obj1 == int8obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if int8obj1 < int8obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Int16:
|
||||
{
|
||||
int16obj1, ok := obj1.(int16)
|
||||
if !ok {
|
||||
int16obj1 = obj1Value.Convert(int16Type).Interface().(int16)
|
||||
}
|
||||
int16obj2, ok := obj2.(int16)
|
||||
if !ok {
|
||||
int16obj2 = obj2Value.Convert(int16Type).Interface().(int16)
|
||||
}
|
||||
if int16obj1 > int16obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if int16obj1 == int16obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if int16obj1 < int16obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Int32:
|
||||
{
|
||||
int32obj1, ok := obj1.(int32)
|
||||
if !ok {
|
||||
int32obj1 = obj1Value.Convert(int32Type).Interface().(int32)
|
||||
}
|
||||
int32obj2, ok := obj2.(int32)
|
||||
if !ok {
|
||||
int32obj2 = obj2Value.Convert(int32Type).Interface().(int32)
|
||||
}
|
||||
if int32obj1 > int32obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if int32obj1 == int32obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if int32obj1 < int32obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Int64:
|
||||
{
|
||||
int64obj1, ok := obj1.(int64)
|
||||
if !ok {
|
||||
int64obj1 = obj1Value.Convert(int64Type).Interface().(int64)
|
||||
}
|
||||
int64obj2, ok := obj2.(int64)
|
||||
if !ok {
|
||||
int64obj2 = obj2Value.Convert(int64Type).Interface().(int64)
|
||||
}
|
||||
if int64obj1 > int64obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if int64obj1 == int64obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if int64obj1 < int64obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Uint:
|
||||
{
|
||||
uintobj1, ok := obj1.(uint)
|
||||
if !ok {
|
||||
uintobj1 = obj1Value.Convert(uintType).Interface().(uint)
|
||||
}
|
||||
uintobj2, ok := obj2.(uint)
|
||||
if !ok {
|
||||
uintobj2 = obj2Value.Convert(uintType).Interface().(uint)
|
||||
}
|
||||
if uintobj1 > uintobj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if uintobj1 == uintobj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if uintobj1 < uintobj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Uint8:
|
||||
{
|
||||
uint8obj1, ok := obj1.(uint8)
|
||||
if !ok {
|
||||
uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8)
|
||||
}
|
||||
uint8obj2, ok := obj2.(uint8)
|
||||
if !ok {
|
||||
uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8)
|
||||
}
|
||||
if uint8obj1 > uint8obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if uint8obj1 == uint8obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if uint8obj1 < uint8obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Uint16:
|
||||
{
|
||||
uint16obj1, ok := obj1.(uint16)
|
||||
if !ok {
|
||||
uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16)
|
||||
}
|
||||
uint16obj2, ok := obj2.(uint16)
|
||||
if !ok {
|
||||
uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16)
|
||||
}
|
||||
if uint16obj1 > uint16obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if uint16obj1 == uint16obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if uint16obj1 < uint16obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Uint32:
|
||||
{
|
||||
uint32obj1, ok := obj1.(uint32)
|
||||
if !ok {
|
||||
uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32)
|
||||
}
|
||||
uint32obj2, ok := obj2.(uint32)
|
||||
if !ok {
|
||||
uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32)
|
||||
}
|
||||
if uint32obj1 > uint32obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if uint32obj1 == uint32obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if uint32obj1 < uint32obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Uint64:
|
||||
{
|
||||
uint64obj1, ok := obj1.(uint64)
|
||||
if !ok {
|
||||
uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64)
|
||||
}
|
||||
uint64obj2, ok := obj2.(uint64)
|
||||
if !ok {
|
||||
uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64)
|
||||
}
|
||||
if uint64obj1 > uint64obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if uint64obj1 == uint64obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if uint64obj1 < uint64obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Float32:
|
||||
{
|
||||
float32obj1, ok := obj1.(float32)
|
||||
if !ok {
|
||||
float32obj1 = obj1Value.Convert(float32Type).Interface().(float32)
|
||||
}
|
||||
float32obj2, ok := obj2.(float32)
|
||||
if !ok {
|
||||
float32obj2 = obj2Value.Convert(float32Type).Interface().(float32)
|
||||
}
|
||||
if float32obj1 > float32obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if float32obj1 == float32obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if float32obj1 < float32obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Float64:
|
||||
{
|
||||
float64obj1, ok := obj1.(float64)
|
||||
if !ok {
|
||||
float64obj1 = obj1Value.Convert(float64Type).Interface().(float64)
|
||||
}
|
||||
float64obj2, ok := obj2.(float64)
|
||||
if !ok {
|
||||
float64obj2 = obj2Value.Convert(float64Type).Interface().(float64)
|
||||
}
|
||||
if float64obj1 > float64obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if float64obj1 == float64obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if float64obj1 < float64obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.String:
|
||||
{
|
||||
stringobj1, ok := obj1.(string)
|
||||
if !ok {
|
||||
stringobj1 = obj1Value.Convert(stringType).Interface().(string)
|
||||
}
|
||||
stringobj2, ok := obj2.(string)
|
||||
if !ok {
|
||||
stringobj2 = obj2Value.Convert(stringType).Interface().(string)
|
||||
}
|
||||
if stringobj1 > stringobj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if stringobj1 == stringobj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if stringobj1 < stringobj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
// Check for known struct types we can check for compare results.
|
||||
case reflect.Struct:
|
||||
{
|
||||
// All structs enter here. We're not interested in most types.
|
||||
if !obj1Value.CanConvert(timeType) {
|
||||
break
|
||||
}
|
||||
|
||||
// time.Time can be compared!
|
||||
timeObj1, ok := obj1.(time.Time)
|
||||
if !ok {
|
||||
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
|
||||
}
|
||||
|
||||
timeObj2, ok := obj2.(time.Time)
|
||||
if !ok {
|
||||
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
|
||||
}
|
||||
|
||||
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
|
||||
}
|
||||
case reflect.Slice:
|
||||
{
|
||||
// We only care about the []byte type.
|
||||
if !obj1Value.CanConvert(bytesType) {
|
||||
break
|
||||
}
|
||||
|
||||
// []byte can be compared!
|
||||
bytesObj1, ok := obj1.([]byte)
|
||||
if !ok {
|
||||
bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte)
|
||||
|
||||
}
|
||||
bytesObj2, ok := obj2.([]byte)
|
||||
if !ok {
|
||||
bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
|
||||
}
|
||||
|
||||
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
|
||||
}
|
||||
case reflect.Uintptr:
|
||||
{
|
||||
uintptrObj1, ok := obj1.(uintptr)
|
||||
if !ok {
|
||||
uintptrObj1 = obj1Value.Convert(uintptrType).Interface().(uintptr)
|
||||
}
|
||||
uintptrObj2, ok := obj2.(uintptr)
|
||||
if !ok {
|
||||
uintptrObj2 = obj2Value.Convert(uintptrType).Interface().(uintptr)
|
||||
}
|
||||
if uintptrObj1 > uintptrObj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if uintptrObj1 == uintptrObj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if uintptrObj1 < uintptrObj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return compareEqual, false
|
||||
}
|
||||
|
||||
// Greater asserts that the first element is greater than the second
|
||||
//
|
||||
// assert.Greater(t, 2, 1)
|
||||
// assert.Greater(t, float64(2), float64(1))
|
||||
// assert.Greater(t, "b", "a")
|
||||
func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// GreaterOrEqual asserts that the first element is greater than or equal to the second
|
||||
//
|
||||
// assert.GreaterOrEqual(t, 2, 1)
|
||||
// assert.GreaterOrEqual(t, 2, 2)
|
||||
// assert.GreaterOrEqual(t, "b", "a")
|
||||
// assert.GreaterOrEqual(t, "b", "b")
|
||||
func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// Less asserts that the first element is less than the second
|
||||
//
|
||||
// assert.Less(t, 1, 2)
|
||||
// assert.Less(t, float64(1), float64(2))
|
||||
// assert.Less(t, "a", "b")
|
||||
func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// LessOrEqual asserts that the first element is less than or equal to the second
|
||||
//
|
||||
// assert.LessOrEqual(t, 1, 2)
|
||||
// assert.LessOrEqual(t, 2, 2)
|
||||
// assert.LessOrEqual(t, "a", "b")
|
||||
// assert.LessOrEqual(t, "b", "b")
|
||||
func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// Positive asserts that the specified element is positive
|
||||
//
|
||||
// assert.Positive(t, 1)
|
||||
// assert.Positive(t, 1.23)
|
||||
func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
zero := reflect.Zero(reflect.TypeOf(e))
|
||||
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
|
||||
}
|
||||
|
||||
// Negative asserts that the specified element is negative
|
||||
//
|
||||
// assert.Negative(t, -1)
|
||||
// assert.Negative(t, -1.23)
|
||||
func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
zero := reflect.Zero(reflect.TypeOf(e))
|
||||
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
|
||||
}
|
||||
|
||||
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
e1Kind := reflect.ValueOf(e1).Kind()
|
||||
e2Kind := reflect.ValueOf(e2).Kind()
|
||||
if e1Kind != e2Kind {
|
||||
return Fail(t, "Elements should be the same type", msgAndArgs...)
|
||||
}
|
||||
|
||||
compareResult, isComparable := compare(e1, e2, e1Kind)
|
||||
if !isComparable {
|
||||
return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
|
||||
}
|
||||
|
||||
if !containsValue(allowedComparesResults, compareResult) {
|
||||
return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func containsValue(values []CompareType, value CompareType) bool {
|
||||
for _, v := range values {
|
||||
if v == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
-815
@@ -1,815 +0,0 @@
|
||||
// Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
|
||||
|
||||
package assert
|
||||
|
||||
import (
|
||||
http "net/http"
|
||||
url "net/url"
|
||||
time "time"
|
||||
)
|
||||
|
||||
// Conditionf uses a Comparison to assert a complex condition.
|
||||
func Conditionf(t TestingT, comp Comparison, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Condition(t, comp, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Containsf asserts that the specified string, list(array, slice...) or map contains the
|
||||
// specified substring or element.
|
||||
//
|
||||
// assert.Containsf(t, "Hello World", "World", "error message %s", "formatted")
|
||||
// assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted")
|
||||
// assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted")
|
||||
func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Contains(t, s, contains, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// DirExistsf checks whether a directory exists in the given path. It also fails
|
||||
// if the path is a file rather a directory or there is an error checking whether it exists.
|
||||
func DirExistsf(t TestingT, path string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return DirExists(t, path, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// ElementsMatchf asserts that the specified listA(array, slice...) is equal to specified
|
||||
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
|
||||
// the number of appearances of each of them in both lists should match.
|
||||
//
|
||||
// assert.ElementsMatchf(t, [1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted")
|
||||
func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return ElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either
|
||||
// a slice or a channel with len == 0.
|
||||
//
|
||||
// assert.Emptyf(t, obj, "error message %s", "formatted")
|
||||
func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Empty(t, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Equalf asserts that two objects are equal.
|
||||
//
|
||||
// assert.Equalf(t, 123, 123, "error message %s", "formatted")
|
||||
//
|
||||
// Pointer variable equality is determined based on the equality of the
|
||||
// referenced values (as opposed to the memory addresses). Function equality
|
||||
// cannot be determined and will always fail.
|
||||
func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Equal(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// EqualErrorf asserts that a function returned an error (i.e. not `nil`)
|
||||
// and that it is equal to the provided error.
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted")
|
||||
func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return EqualError(t, theError, errString, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// EqualExportedValuesf asserts that the types of two objects are equal and their public
|
||||
// fields are also equal. This is useful for comparing structs that have private fields
|
||||
// that could potentially differ.
|
||||
//
|
||||
// type S struct {
|
||||
// Exported int
|
||||
// notExported int
|
||||
// }
|
||||
// assert.EqualExportedValuesf(t, S{1, 2}, S{1, 3}, "error message %s", "formatted") => true
|
||||
// assert.EqualExportedValuesf(t, S{1, 2}, S{2, 3}, "error message %s", "formatted") => false
|
||||
func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// EqualValuesf asserts that two objects are equal or convertible to the same types
|
||||
// and equal.
|
||||
//
|
||||
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
|
||||
func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return EqualValues(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Errorf asserts that a function returned an error (i.e. not `nil`).
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// if assert.Errorf(t, err, "error message %s", "formatted") {
|
||||
// assert.Equal(t, expectedErrorf, err)
|
||||
// }
|
||||
func Errorf(t TestingT, err error, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Error(t, err, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// ErrorAsf asserts that at least one of the errors in err's chain matches target, and if so, sets target to that error value.
|
||||
// This is a wrapper for errors.As.
|
||||
func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return ErrorAs(t, err, target, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
|
||||
// and that the error contains the specified substring.
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
|
||||
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return ErrorContains(t, theError, contains, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// ErrorIsf asserts that at least one of the errors in err's chain matches target.
|
||||
// This is a wrapper for errors.Is.
|
||||
func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return ErrorIs(t, err, target, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Eventuallyf asserts that given condition will be met in waitFor time,
|
||||
// periodically checking target function each tick.
|
||||
//
|
||||
// assert.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted")
|
||||
func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Eventually(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// EventuallyWithTf asserts that given condition will be met in waitFor time,
|
||||
// periodically checking target function each tick. In contrast to Eventually,
|
||||
// it supplies a CollectT to the condition function, so that the condition
|
||||
// function can use the CollectT to call other assertions.
|
||||
// The condition is considered "met" if no errors are raised in a tick.
|
||||
// The supplied CollectT collects all errors from one tick (if there are any).
|
||||
// If the condition is not met before waitFor, the collected errors of
|
||||
// the last tick are copied to t.
|
||||
//
|
||||
// externalValue := false
|
||||
// go func() {
|
||||
// time.Sleep(8*time.Second)
|
||||
// externalValue = true
|
||||
// }()
|
||||
// assert.EventuallyWithTf(t, func(c *assert.CollectT, "error message %s", "formatted") {
|
||||
// // add assertions as needed; any assertion failure will fail the current tick
|
||||
// assert.True(c, externalValue, "expected 'externalValue' to be true")
|
||||
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false")
|
||||
func EventuallyWithTf(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return EventuallyWithT(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Exactlyf asserts that two objects are equal in value and type.
|
||||
//
|
||||
// assert.Exactlyf(t, int32(123), int64(123), "error message %s", "formatted")
|
||||
func Exactlyf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Exactly(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Failf reports a failure through
|
||||
func Failf(t TestingT, failureMessage string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Fail(t, failureMessage, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// FailNowf fails test
|
||||
func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return FailNow(t, failureMessage, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Falsef asserts that the specified value is false.
|
||||
//
|
||||
// assert.Falsef(t, myBool, "error message %s", "formatted")
|
||||
func Falsef(t TestingT, value bool, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return False(t, value, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// FileExistsf checks whether a file exists in the given path. It also fails if
|
||||
// the path points to a directory or there is an error when trying to check the file.
|
||||
func FileExistsf(t TestingT, path string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return FileExists(t, path, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Greaterf asserts that the first element is greater than the second
|
||||
//
|
||||
// assert.Greaterf(t, 2, 1, "error message %s", "formatted")
|
||||
// assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted")
|
||||
// assert.Greaterf(t, "b", "a", "error message %s", "formatted")
|
||||
func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Greater(t, e1, e2, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// GreaterOrEqualf asserts that the first element is greater than or equal to the second
|
||||
//
|
||||
// assert.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted")
|
||||
// assert.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted")
|
||||
// assert.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted")
|
||||
// assert.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted")
|
||||
func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return GreaterOrEqual(t, e1, e2, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// HTTPBodyContainsf asserts that a specified handler returns a
|
||||
// body that contains a string.
|
||||
//
|
||||
// assert.HTTPBodyContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted")
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return HTTPBodyContains(t, handler, method, url, values, str, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// HTTPBodyNotContainsf asserts that a specified handler returns a
|
||||
// body that does not contain a string.
|
||||
//
|
||||
// assert.HTTPBodyNotContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted")
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return HTTPBodyNotContains(t, handler, method, url, values, str, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// HTTPErrorf asserts that a specified handler returns an error status code.
|
||||
//
|
||||
// assert.HTTPErrorf(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}}
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return HTTPError(t, handler, method, url, values, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// HTTPRedirectf asserts that a specified handler returns a redirect status code.
|
||||
//
|
||||
// assert.HTTPRedirectf(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}}
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return HTTPRedirect(t, handler, method, url, values, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// HTTPStatusCodef asserts that a specified handler returns a specified status code.
|
||||
//
|
||||
// assert.HTTPStatusCodef(t, myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted")
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
func HTTPStatusCodef(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return HTTPStatusCode(t, handler, method, url, values, statuscode, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// HTTPSuccessf asserts that a specified handler returns a success status code.
|
||||
//
|
||||
// assert.HTTPSuccessf(t, myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted")
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return HTTPSuccess(t, handler, method, url, values, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Implementsf asserts that an object is implemented by the specified interface.
|
||||
//
|
||||
// assert.Implementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
|
||||
func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Implements(t, interfaceObject, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// InDeltaf asserts that the two numerals are within delta of each other.
|
||||
//
|
||||
// assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted")
|
||||
func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return InDelta(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// InDeltaMapValuesf is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys.
|
||||
func InDeltaMapValuesf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return InDeltaMapValues(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// InDeltaSlicef is the same as InDelta, except it compares two slices.
|
||||
func InDeltaSlicef(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return InDeltaSlice(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// InEpsilonf asserts that expected and actual have a relative error less than epsilon
|
||||
func InEpsilonf(t TestingT, expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return InEpsilon(t, expected, actual, epsilon, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// InEpsilonSlicef is the same as InEpsilon, except it compares each value from two slices.
|
||||
func InEpsilonSlicef(t TestingT, expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return InEpsilonSlice(t, expected, actual, epsilon, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// IsDecreasingf asserts that the collection is decreasing
|
||||
//
|
||||
// assert.IsDecreasingf(t, []int{2, 1, 0}, "error message %s", "formatted")
|
||||
// assert.IsDecreasingf(t, []float{2, 1}, "error message %s", "formatted")
|
||||
// assert.IsDecreasingf(t, []string{"b", "a"}, "error message %s", "formatted")
|
||||
func IsDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return IsDecreasing(t, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// IsIncreasingf asserts that the collection is increasing
|
||||
//
|
||||
// assert.IsIncreasingf(t, []int{1, 2, 3}, "error message %s", "formatted")
|
||||
// assert.IsIncreasingf(t, []float{1, 2}, "error message %s", "formatted")
|
||||
// assert.IsIncreasingf(t, []string{"a", "b"}, "error message %s", "formatted")
|
||||
func IsIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return IsIncreasing(t, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// IsNonDecreasingf asserts that the collection is not decreasing
|
||||
//
|
||||
// assert.IsNonDecreasingf(t, []int{1, 1, 2}, "error message %s", "formatted")
|
||||
// assert.IsNonDecreasingf(t, []float{1, 2}, "error message %s", "formatted")
|
||||
// assert.IsNonDecreasingf(t, []string{"a", "b"}, "error message %s", "formatted")
|
||||
func IsNonDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return IsNonDecreasing(t, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// IsNonIncreasingf asserts that the collection is not increasing
|
||||
//
|
||||
// assert.IsNonIncreasingf(t, []int{2, 1, 1}, "error message %s", "formatted")
|
||||
// assert.IsNonIncreasingf(t, []float{2, 1}, "error message %s", "formatted")
|
||||
// assert.IsNonIncreasingf(t, []string{"b", "a"}, "error message %s", "formatted")
|
||||
func IsNonIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return IsNonIncreasing(t, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// IsTypef asserts that the specified objects are of the same type.
|
||||
func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return IsType(t, expectedType, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// JSONEqf asserts that two JSON strings are equivalent.
|
||||
//
|
||||
// assert.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted")
|
||||
func JSONEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return JSONEq(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Lenf asserts that the specified object has specific length.
|
||||
// Lenf also fails if the object has a type that len() not accept.
|
||||
//
|
||||
// assert.Lenf(t, mySlice, 3, "error message %s", "formatted")
|
||||
func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Len(t, object, length, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Lessf asserts that the first element is less than the second
|
||||
//
|
||||
// assert.Lessf(t, 1, 2, "error message %s", "formatted")
|
||||
// assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted")
|
||||
// assert.Lessf(t, "a", "b", "error message %s", "formatted")
|
||||
func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Less(t, e1, e2, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// LessOrEqualf asserts that the first element is less than or equal to the second
|
||||
//
|
||||
// assert.LessOrEqualf(t, 1, 2, "error message %s", "formatted")
|
||||
// assert.LessOrEqualf(t, 2, 2, "error message %s", "formatted")
|
||||
// assert.LessOrEqualf(t, "a", "b", "error message %s", "formatted")
|
||||
// assert.LessOrEqualf(t, "b", "b", "error message %s", "formatted")
|
||||
func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return LessOrEqual(t, e1, e2, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Negativef asserts that the specified element is negative
|
||||
//
|
||||
// assert.Negativef(t, -1, "error message %s", "formatted")
|
||||
// assert.Negativef(t, -1.23, "error message %s", "formatted")
|
||||
func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Negative(t, e, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Neverf asserts that the given condition doesn't satisfy in waitFor time,
|
||||
// periodically checking the target function each tick.
|
||||
//
|
||||
// assert.Neverf(t, func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted")
|
||||
func Neverf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Never(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Nilf asserts that the specified object is nil.
|
||||
//
|
||||
// assert.Nilf(t, err, "error message %s", "formatted")
|
||||
func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Nil(t, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NoDirExistsf checks whether a directory does not exist in the given path.
|
||||
// It fails if the path points to an existing _directory_ only.
|
||||
func NoDirExistsf(t TestingT, path string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NoDirExists(t, path, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NoErrorf asserts that a function returned no error (i.e. `nil`).
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// if assert.NoErrorf(t, err, "error message %s", "formatted") {
|
||||
// assert.Equal(t, expectedObj, actualObj)
|
||||
// }
|
||||
func NoErrorf(t TestingT, err error, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NoError(t, err, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NoFileExistsf checks whether a file does not exist in a given path. It fails
|
||||
// if the path points to an existing _file_ only.
|
||||
func NoFileExistsf(t TestingT, path string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NoFileExists(t, path, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the
|
||||
// specified substring or element.
|
||||
//
|
||||
// assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted")
|
||||
// assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted")
|
||||
// assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted")
|
||||
func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotContains(t, s, contains, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
|
||||
// a slice or a channel with len == 0.
|
||||
//
|
||||
// if assert.NotEmptyf(t, obj, "error message %s", "formatted") {
|
||||
// assert.Equal(t, "two", obj[1])
|
||||
// }
|
||||
func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotEmpty(t, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotEqualf asserts that the specified values are NOT equal.
|
||||
//
|
||||
// assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted")
|
||||
//
|
||||
// Pointer variable equality is determined based on the equality of the
|
||||
// referenced values (as opposed to the memory addresses).
|
||||
func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotEqual(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotEqualValuesf asserts that two objects are not equal even when converted to the same type
|
||||
//
|
||||
// assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted")
|
||||
func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotEqualValues(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotErrorIsf asserts that at none of the errors in err's chain matches target.
|
||||
// This is a wrapper for errors.Is.
|
||||
func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotErrorIs(t, err, target, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotImplementsf asserts that an object does not implement the specified interface.
|
||||
//
|
||||
// assert.NotImplementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
|
||||
func NotImplementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotImplements(t, interfaceObject, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotNilf asserts that the specified object is not nil.
|
||||
//
|
||||
// assert.NotNilf(t, err, "error message %s", "formatted")
|
||||
func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotNil(t, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotPanicsf asserts that the code inside the specified PanicTestFunc does NOT panic.
|
||||
//
|
||||
// assert.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted")
|
||||
func NotPanicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotPanics(t, f, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotRegexpf asserts that a specified regexp does not match a string.
|
||||
//
|
||||
// assert.NotRegexpf(t, regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted")
|
||||
// assert.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted")
|
||||
func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotRegexp(t, rx, str, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotSamef asserts that two pointers do not reference the same object.
|
||||
//
|
||||
// assert.NotSamef(t, ptr1, ptr2, "error message %s", "formatted")
|
||||
//
|
||||
// Both arguments must be pointer variables. Pointer variable sameness is
|
||||
// determined based on the equality of both type and value.
|
||||
func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotSubsetf asserts that the specified list(array, slice...) or map does NOT
|
||||
// contain all elements given in the specified subset list(array, slice...) or
|
||||
// map.
|
||||
//
|
||||
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted")
|
||||
// assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
|
||||
func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotSubset(t, list, subset, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotZerof asserts that i is not the zero value for its type.
|
||||
func NotZerof(t TestingT, i interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotZero(t, i, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Panicsf asserts that the code inside the specified PanicTestFunc panics.
|
||||
//
|
||||
// assert.Panicsf(t, func(){ GoCrazy() }, "error message %s", "formatted")
|
||||
func Panicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Panics(t, f, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// PanicsWithErrorf asserts that the code inside the specified PanicTestFunc
|
||||
// panics, and that the recovered panic value is an error that satisfies the
|
||||
// EqualError comparison.
|
||||
//
|
||||
// assert.PanicsWithErrorf(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted")
|
||||
func PanicsWithErrorf(t TestingT, errString string, f PanicTestFunc, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return PanicsWithError(t, errString, f, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that
|
||||
// the recovered panic value equals the expected panic value.
|
||||
//
|
||||
// assert.PanicsWithValuef(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted")
|
||||
func PanicsWithValuef(t TestingT, expected interface{}, f PanicTestFunc, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return PanicsWithValue(t, expected, f, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Positivef asserts that the specified element is positive
|
||||
//
|
||||
// assert.Positivef(t, 1, "error message %s", "formatted")
|
||||
// assert.Positivef(t, 1.23, "error message %s", "formatted")
|
||||
func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Positive(t, e, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Regexpf asserts that a specified regexp matches a string.
|
||||
//
|
||||
// assert.Regexpf(t, regexp.MustCompile("start"), "it's starting", "error message %s", "formatted")
|
||||
// assert.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted")
|
||||
func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Regexp(t, rx, str, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Samef asserts that two pointers reference the same object.
|
||||
//
|
||||
// assert.Samef(t, ptr1, ptr2, "error message %s", "formatted")
|
||||
//
|
||||
// Both arguments must be pointer variables. Pointer variable sameness is
|
||||
// determined based on the equality of both type and value.
|
||||
func Samef(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Same(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Subsetf asserts that the specified list(array, slice...) or map contains all
|
||||
// elements given in the specified subset list(array, slice...) or map.
|
||||
//
|
||||
// assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted")
|
||||
// assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
|
||||
func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Subset(t, list, subset, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Truef asserts that the specified value is true.
|
||||
//
|
||||
// assert.Truef(t, myBool, "error message %s", "formatted")
|
||||
func Truef(t TestingT, value bool, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return True(t, value, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// WithinDurationf asserts that the two times are within duration delta of each other.
|
||||
//
|
||||
// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted")
|
||||
func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// WithinRangef asserts that a time is within a time range (inclusive).
|
||||
//
|
||||
// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
|
||||
func WithinRangef(t TestingT, actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return WithinRange(t, actual, start, end, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// YAMLEqf asserts that two YAML strings are equivalent.
|
||||
func YAMLEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return YAMLEq(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Zerof asserts that i is the zero value for its type.
|
||||
func Zerof(t TestingT, i interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Zero(t, i, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
-5
@@ -1,5 +0,0 @@
|
||||
{{.CommentFormat}}
|
||||
func {{.DocInfo.Name}}f(t TestingT, {{.ParamsFormat}}) bool {
|
||||
if h, ok := t.(tHelper); ok { h.Helper() }
|
||||
return {{.DocInfo.Name}}(t, {{.ForwardedParamsFormat}})
|
||||
}
|
||||
-1621
File diff suppressed because it is too large
Load Diff
-5
@@ -1,5 +0,0 @@
|
||||
{{.CommentWithoutT "a"}}
|
||||
func (a *Assertions) {{.DocInfo.Name}}({{.Params}}) bool {
|
||||
if h, ok := a.t.(tHelper); ok { h.Helper() }
|
||||
return {{.DocInfo.Name}}(a.t, {{.ForwardedParams}})
|
||||
}
|
||||
-81
@@ -1,81 +0,0 @@
|
||||
package assert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// isOrdered checks that collection contains orderable elements.
|
||||
func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
|
||||
objKind := reflect.TypeOf(object).Kind()
|
||||
if objKind != reflect.Slice && objKind != reflect.Array {
|
||||
return false
|
||||
}
|
||||
|
||||
objValue := reflect.ValueOf(object)
|
||||
objLen := objValue.Len()
|
||||
|
||||
if objLen <= 1 {
|
||||
return true
|
||||
}
|
||||
|
||||
value := objValue.Index(0)
|
||||
valueInterface := value.Interface()
|
||||
firstValueKind := value.Kind()
|
||||
|
||||
for i := 1; i < objLen; i++ {
|
||||
prevValue := value
|
||||
prevValueInterface := valueInterface
|
||||
|
||||
value = objValue.Index(i)
|
||||
valueInterface = value.Interface()
|
||||
|
||||
compareResult, isComparable := compare(prevValueInterface, valueInterface, firstValueKind)
|
||||
|
||||
if !isComparable {
|
||||
return Fail(t, fmt.Sprintf("Can not compare type \"%s\" and \"%s\"", reflect.TypeOf(value), reflect.TypeOf(prevValue)), msgAndArgs...)
|
||||
}
|
||||
|
||||
if !containsValue(allowedComparesResults, compareResult) {
|
||||
return Fail(t, fmt.Sprintf(failMessage, prevValue, value), msgAndArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// IsIncreasing asserts that the collection is increasing
|
||||
//
|
||||
// assert.IsIncreasing(t, []int{1, 2, 3})
|
||||
// assert.IsIncreasing(t, []float{1, 2})
|
||||
// assert.IsIncreasing(t, []string{"a", "b"})
|
||||
func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
|
||||
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// IsNonIncreasing asserts that the collection is not increasing
|
||||
//
|
||||
// assert.IsNonIncreasing(t, []int{2, 1, 1})
|
||||
// assert.IsNonIncreasing(t, []float{2, 1})
|
||||
// assert.IsNonIncreasing(t, []string{"b", "a"})
|
||||
func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
|
||||
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// IsDecreasing asserts that the collection is decreasing
|
||||
//
|
||||
// assert.IsDecreasing(t, []int{2, 1, 0})
|
||||
// assert.IsDecreasing(t, []float{2, 1})
|
||||
// assert.IsDecreasing(t, []string{"b", "a"})
|
||||
func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
|
||||
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// IsNonDecreasing asserts that the collection is not decreasing
|
||||
//
|
||||
// assert.IsNonDecreasing(t, []int{1, 1, 2})
|
||||
// assert.IsNonDecreasing(t, []float{1, 2})
|
||||
// assert.IsNonDecreasing(t, []string{"a", "b"})
|
||||
func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
|
||||
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
|
||||
}
|
||||
-2105
File diff suppressed because it is too large
Load Diff
-46
@@ -1,46 +0,0 @@
|
||||
// Package assert provides a set of comprehensive testing tools for use with the normal Go testing system.
|
||||
//
|
||||
// # Example Usage
|
||||
//
|
||||
// The following is a complete example using assert in a standard test function:
|
||||
//
|
||||
// import (
|
||||
// "testing"
|
||||
// "github.com/stretchr/testify/assert"
|
||||
// )
|
||||
//
|
||||
// func TestSomething(t *testing.T) {
|
||||
//
|
||||
// var a string = "Hello"
|
||||
// var b string = "Hello"
|
||||
//
|
||||
// assert.Equal(t, a, b, "The two words should be the same.")
|
||||
//
|
||||
// }
|
||||
//
|
||||
// if you assert many times, use the format below:
|
||||
//
|
||||
// import (
|
||||
// "testing"
|
||||
// "github.com/stretchr/testify/assert"
|
||||
// )
|
||||
//
|
||||
// func TestSomething(t *testing.T) {
|
||||
// assert := assert.New(t)
|
||||
//
|
||||
// var a string = "Hello"
|
||||
// var b string = "Hello"
|
||||
//
|
||||
// assert.Equal(a, b, "The two words should be the same.")
|
||||
// }
|
||||
//
|
||||
// # Assertions
|
||||
//
|
||||
// Assertions allow you to easily write test code, and are global funcs in the `assert` package.
|
||||
// All assertion functions take, as the first argument, the `*testing.T` object provided by the
|
||||
// testing framework. This allows the assertion funcs to write the failings and other details to
|
||||
// the correct place.
|
||||
//
|
||||
// Every assertion function also takes an optional string message as the final argument,
|
||||
// allowing custom error messages to be appended to the message the assertion method outputs.
|
||||
package assert
|
||||
-10
@@ -1,10 +0,0 @@
|
||||
package assert
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// AnError is an error instance useful for testing. If the code does not care
|
||||
// about error specifics, and only needs to return the error for example, this
|
||||
// error should be used to make the test code more readable.
|
||||
var AnError = errors.New("assert.AnError general error for testing")
|
||||
-16
@@ -1,16 +0,0 @@
|
||||
package assert
|
||||
|
||||
// Assertions provides assertion methods around the
|
||||
// TestingT interface.
|
||||
type Assertions struct {
|
||||
t TestingT
|
||||
}
|
||||
|
||||
// New makes a new Assertions object for the specified TestingT.
|
||||
func New(t TestingT) *Assertions {
|
||||
return &Assertions{
|
||||
t: t,
|
||||
}
|
||||
}
|
||||
|
||||
//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_forward.go.tmpl -include-format-funcs"
|
||||
-165
@@ -1,165 +0,0 @@
|
||||
package assert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// httpCode is a helper that returns HTTP code of the response. It returns -1 and
|
||||
// an error if building a new request fails.
|
||||
func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (int, error) {
|
||||
w := httptest.NewRecorder()
|
||||
req, err := http.NewRequest(method, url, http.NoBody)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
req.URL.RawQuery = values.Encode()
|
||||
handler(w, req)
|
||||
return w.Code, nil
|
||||
}
|
||||
|
||||
// HTTPSuccess asserts that a specified handler returns a success status code.
|
||||
//
|
||||
// assert.HTTPSuccess(t, myHandler, "POST", "http://www.google.com", nil)
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
code, err := httpCode(handler, method, url, values)
|
||||
if err != nil {
|
||||
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
|
||||
}
|
||||
|
||||
isSuccessCode := code >= http.StatusOK && code <= http.StatusPartialContent
|
||||
if !isSuccessCode {
|
||||
Fail(t, fmt.Sprintf("Expected HTTP success status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
|
||||
}
|
||||
|
||||
return isSuccessCode
|
||||
}
|
||||
|
||||
// HTTPRedirect asserts that a specified handler returns a redirect status code.
|
||||
//
|
||||
// assert.HTTPRedirect(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}}
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
code, err := httpCode(handler, method, url, values)
|
||||
if err != nil {
|
||||
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
|
||||
}
|
||||
|
||||
isRedirectCode := code >= http.StatusMultipleChoices && code <= http.StatusTemporaryRedirect
|
||||
if !isRedirectCode {
|
||||
Fail(t, fmt.Sprintf("Expected HTTP redirect status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
|
||||
}
|
||||
|
||||
return isRedirectCode
|
||||
}
|
||||
|
||||
// HTTPError asserts that a specified handler returns an error status code.
|
||||
//
|
||||
// assert.HTTPError(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}}
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
code, err := httpCode(handler, method, url, values)
|
||||
if err != nil {
|
||||
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
|
||||
}
|
||||
|
||||
isErrorCode := code >= http.StatusBadRequest
|
||||
if !isErrorCode {
|
||||
Fail(t, fmt.Sprintf("Expected HTTP error status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
|
||||
}
|
||||
|
||||
return isErrorCode
|
||||
}
|
||||
|
||||
// HTTPStatusCode asserts that a specified handler returns a specified status code.
|
||||
//
|
||||
// assert.HTTPStatusCode(t, myHandler, "GET", "/notImplemented", nil, 501)
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, statuscode int, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
code, err := httpCode(handler, method, url, values)
|
||||
if err != nil {
|
||||
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
|
||||
}
|
||||
|
||||
successful := code == statuscode
|
||||
if !successful {
|
||||
Fail(t, fmt.Sprintf("Expected HTTP status code %d for %q but received %d", statuscode, url+"?"+values.Encode(), code), msgAndArgs...)
|
||||
}
|
||||
|
||||
return successful
|
||||
}
|
||||
|
||||
// HTTPBody is a helper that returns HTTP body of the response. It returns
|
||||
// empty string if building a new request fails.
|
||||
func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) string {
|
||||
w := httptest.NewRecorder()
|
||||
if len(values) > 0 {
|
||||
url += "?" + values.Encode()
|
||||
}
|
||||
req, err := http.NewRequest(method, url, http.NoBody)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
handler(w, req)
|
||||
return w.Body.String()
|
||||
}
|
||||
|
||||
// HTTPBodyContains asserts that a specified handler returns a
|
||||
// body that contains a string.
|
||||
//
|
||||
// assert.HTTPBodyContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
body := HTTPBody(handler, method, url, values)
|
||||
|
||||
contains := strings.Contains(body, fmt.Sprint(str))
|
||||
if !contains {
|
||||
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...)
|
||||
}
|
||||
|
||||
return contains
|
||||
}
|
||||
|
||||
// HTTPBodyNotContains asserts that a specified handler returns a
|
||||
// body that does not contain a string.
|
||||
//
|
||||
// assert.HTTPBodyNotContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
body := HTTPBody(handler, method, url, values)
|
||||
|
||||
contains := strings.Contains(body, fmt.Sprint(str))
|
||||
if contains {
|
||||
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...)
|
||||
}
|
||||
|
||||
return !contains
|
||||
}
|
||||
-44
@@ -1,44 +0,0 @@
|
||||
// Package mock provides a system by which it is possible to mock your objects
|
||||
// and verify calls are happening as expected.
|
||||
//
|
||||
// # Example Usage
|
||||
//
|
||||
// The mock package provides an object, Mock, that tracks activity on another object. It is usually
|
||||
// embedded into a test object as shown below:
|
||||
//
|
||||
// type MyTestObject struct {
|
||||
// // add a Mock object instance
|
||||
// mock.Mock
|
||||
//
|
||||
// // other fields go here as normal
|
||||
// }
|
||||
//
|
||||
// When implementing the methods of an interface, you wire your functions up
|
||||
// to call the Mock.Called(args...) method, and return the appropriate values.
|
||||
//
|
||||
// For example, to mock a method that saves the name and age of a person and returns
|
||||
// the year of their birth or an error, you might write this:
|
||||
//
|
||||
// func (o *MyTestObject) SavePersonDetails(firstname, lastname string, age int) (int, error) {
|
||||
// args := o.Called(firstname, lastname, age)
|
||||
// return args.Int(0), args.Error(1)
|
||||
// }
|
||||
//
|
||||
// The Int, Error and Bool methods are examples of strongly typed getters that take the argument
|
||||
// index position. Given this argument list:
|
||||
//
|
||||
// (12, true, "Something")
|
||||
//
|
||||
// You could read them out strongly typed like this:
|
||||
//
|
||||
// args.Int(0)
|
||||
// args.Bool(1)
|
||||
// args.String(2)
|
||||
//
|
||||
// For objects of your own type, use the generic Arguments.Get(index) method and make a type assertion:
|
||||
//
|
||||
// return args.Get(0).(*MyObject), args.Get(1).(*AnotherObjectOfMine)
|
||||
//
|
||||
// This may cause a panic if the object you are getting is nil (the type assertion will fail), in those
|
||||
// cases you should check for nil first.
|
||||
package mock
|
||||
-1241
File diff suppressed because it is too large
Load Diff
-29
@@ -1,29 +0,0 @@
|
||||
// Package require implements the same assertions as the `assert` package but
|
||||
// stops test execution when a test fails.
|
||||
//
|
||||
// # Example Usage
|
||||
//
|
||||
// The following is a complete example using require in a standard test function:
|
||||
//
|
||||
// import (
|
||||
// "testing"
|
||||
// "github.com/stretchr/testify/require"
|
||||
// )
|
||||
//
|
||||
// func TestSomething(t *testing.T) {
|
||||
//
|
||||
// var a string = "Hello"
|
||||
// var b string = "Hello"
|
||||
//
|
||||
// require.Equal(t, a, b, "The two words should be the same.")
|
||||
//
|
||||
// }
|
||||
//
|
||||
// # Assertions
|
||||
//
|
||||
// The `require` package have same global functions as in the `assert` package,
|
||||
// but instead of returning a boolean result they call `t.FailNow()`.
|
||||
//
|
||||
// Every assertion function also takes an optional string message as the final argument,
|
||||
// allowing custom error messages to be appended to the message the assertion method outputs.
|
||||
package require
|
||||
-16
@@ -1,16 +0,0 @@
|
||||
package require
|
||||
|
||||
// Assertions provides assertion methods around the
|
||||
// TestingT interface.
|
||||
type Assertions struct {
|
||||
t TestingT
|
||||
}
|
||||
|
||||
// New makes a new Assertions object for the specified TestingT.
|
||||
func New(t TestingT) *Assertions {
|
||||
return &Assertions{
|
||||
t: t,
|
||||
}
|
||||
}
|
||||
|
||||
//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=require -template=require_forward.go.tmpl -include-format-funcs"
|
||||
-2060
File diff suppressed because it is too large
Load Diff
-6
@@ -1,6 +0,0 @@
|
||||
{{.Comment}}
|
||||
func {{.DocInfo.Name}}(t TestingT, {{.Params}}) {
|
||||
if h, ok := t.(tHelper); ok { h.Helper() }
|
||||
if assert.{{.DocInfo.Name}}(t, {{.ForwardedParams}}) { return }
|
||||
t.FailNow()
|
||||
}
|
||||
-1622
File diff suppressed because it is too large
Load Diff
-5
@@ -1,5 +0,0 @@
|
||||
{{.CommentWithoutT "a"}}
|
||||
func (a *Assertions) {{.DocInfo.Name}}({{.Params}}) {
|
||||
if h, ok := a.t.(tHelper); ok { h.Helper() }
|
||||
{{.DocInfo.Name}}(a.t, {{.ForwardedParams}})
|
||||
}
|
||||
-29
@@ -1,29 +0,0 @@
|
||||
package require
|
||||
|
||||
// TestingT is an interface wrapper around *testing.T
|
||||
type TestingT interface {
|
||||
Errorf(format string, args ...interface{})
|
||||
FailNow()
|
||||
}
|
||||
|
||||
type tHelper interface {
|
||||
Helper()
|
||||
}
|
||||
|
||||
// ComparisonAssertionFunc is a common function prototype when comparing two values. Can be useful
|
||||
// for table driven tests.
|
||||
type ComparisonAssertionFunc func(TestingT, interface{}, interface{}, ...interface{})
|
||||
|
||||
// ValueAssertionFunc is a common function prototype when validating a single value. Can be useful
|
||||
// for table driven tests.
|
||||
type ValueAssertionFunc func(TestingT, interface{}, ...interface{})
|
||||
|
||||
// BoolAssertionFunc is a common function prototype when validating a bool value. Can be useful
|
||||
// for table driven tests.
|
||||
type BoolAssertionFunc func(TestingT, bool, ...interface{})
|
||||
|
||||
// ErrorAssertionFunc is a common function prototype when validating an error value. Can be useful
|
||||
// for table driven tests.
|
||||
type ErrorAssertionFunc func(TestingT, error, ...interface{})
|
||||
|
||||
//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=require -template=require.go.tmpl -include-format-funcs"
|
||||
-66
@@ -1,66 +0,0 @@
|
||||
// Package suite contains logic for creating testing suite structs
|
||||
// and running the methods on those structs as tests. The most useful
|
||||
// piece of this package is that you can create setup/teardown methods
|
||||
// on your testing suites, which will run before/after the whole suite
|
||||
// or individual tests (depending on which interface(s) you
|
||||
// implement).
|
||||
//
|
||||
// A testing suite is usually built by first extending the built-in
|
||||
// suite functionality from suite.Suite in testify. Alternatively,
|
||||
// you could reproduce that logic on your own if you wanted (you
|
||||
// just need to implement the TestingSuite interface from
|
||||
// suite/interfaces.go).
|
||||
//
|
||||
// After that, you can implement any of the interfaces in
|
||||
// suite/interfaces.go to add setup/teardown functionality to your
|
||||
// suite, and add any methods that start with "Test" to add tests.
|
||||
// Methods that do not match any suite interfaces and do not begin
|
||||
// with "Test" will not be run by testify, and can safely be used as
|
||||
// helper methods.
|
||||
//
|
||||
// Once you've built your testing suite, you need to run the suite
|
||||
// (using suite.Run from testify) inside any function that matches the
|
||||
// identity that "go test" is already looking for (i.e.
|
||||
// func(*testing.T)).
|
||||
//
|
||||
// Regular expression to select test suites specified command-line
|
||||
// argument "-run". Regular expression to select the methods
|
||||
// of test suites specified command-line argument "-m".
|
||||
// Suite object has assertion methods.
|
||||
//
|
||||
// A crude example:
|
||||
//
|
||||
// // Basic imports
|
||||
// import (
|
||||
// "testing"
|
||||
// "github.com/stretchr/testify/assert"
|
||||
// "github.com/stretchr/testify/suite"
|
||||
// )
|
||||
//
|
||||
// // Define the suite, and absorb the built-in basic suite
|
||||
// // functionality from testify - including a T() method which
|
||||
// // returns the current testing context
|
||||
// type ExampleTestSuite struct {
|
||||
// suite.Suite
|
||||
// VariableThatShouldStartAtFive int
|
||||
// }
|
||||
//
|
||||
// // Make sure that VariableThatShouldStartAtFive is set to five
|
||||
// // before each test
|
||||
// func (suite *ExampleTestSuite) SetupTest() {
|
||||
// suite.VariableThatShouldStartAtFive = 5
|
||||
// }
|
||||
//
|
||||
// // All methods that begin with "Test" are run as tests within a
|
||||
// // suite.
|
||||
// func (suite *ExampleTestSuite) TestExample() {
|
||||
// assert.Equal(suite.T(), 5, suite.VariableThatShouldStartAtFive)
|
||||
// suite.Equal(5, suite.VariableThatShouldStartAtFive)
|
||||
// }
|
||||
//
|
||||
// // In order for 'go test' to run this suite, we need to create
|
||||
// // a normal test function and pass our suite to suite.Run
|
||||
// func TestExampleTestSuite(t *testing.T) {
|
||||
// suite.Run(t, new(ExampleTestSuite))
|
||||
// }
|
||||
package suite
|
||||
-66
@@ -1,66 +0,0 @@
|
||||
package suite
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestingSuite can store and return the current *testing.T context
|
||||
// generated by 'go test'.
|
||||
type TestingSuite interface {
|
||||
T() *testing.T
|
||||
SetT(*testing.T)
|
||||
SetS(suite TestingSuite)
|
||||
}
|
||||
|
||||
// SetupAllSuite has a SetupSuite method, which will run before the
|
||||
// tests in the suite are run.
|
||||
type SetupAllSuite interface {
|
||||
SetupSuite()
|
||||
}
|
||||
|
||||
// SetupTestSuite has a SetupTest method, which will run before each
|
||||
// test in the suite.
|
||||
type SetupTestSuite interface {
|
||||
SetupTest()
|
||||
}
|
||||
|
||||
// TearDownAllSuite has a TearDownSuite method, which will run after
|
||||
// all the tests in the suite have been run.
|
||||
type TearDownAllSuite interface {
|
||||
TearDownSuite()
|
||||
}
|
||||
|
||||
// TearDownTestSuite has a TearDownTest method, which will run after
|
||||
// each test in the suite.
|
||||
type TearDownTestSuite interface {
|
||||
TearDownTest()
|
||||
}
|
||||
|
||||
// BeforeTest has a function to be executed right before the test
|
||||
// starts and receives the suite and test names as input
|
||||
type BeforeTest interface {
|
||||
BeforeTest(suiteName, testName string)
|
||||
}
|
||||
|
||||
// AfterTest has a function to be executed right after the test
|
||||
// finishes and receives the suite and test names as input
|
||||
type AfterTest interface {
|
||||
AfterTest(suiteName, testName string)
|
||||
}
|
||||
|
||||
// WithStats implements HandleStats, a function that will be executed
|
||||
// when a test suite is finished. The stats contain information about
|
||||
// the execution of that suite and its tests.
|
||||
type WithStats interface {
|
||||
HandleStats(suiteName string, stats *SuiteInformation)
|
||||
}
|
||||
|
||||
// SetupSubTest has a SetupSubTest method, which will run before each
|
||||
// subtest in the suite.
|
||||
type SetupSubTest interface {
|
||||
SetupSubTest()
|
||||
}
|
||||
|
||||
// TearDownSubTest has a TearDownSubTest method, which will run after
|
||||
// each subtest in the suite have been run.
|
||||
type TearDownSubTest interface {
|
||||
TearDownSubTest()
|
||||
}
|
||||
-46
@@ -1,46 +0,0 @@
|
||||
package suite
|
||||
|
||||
import "time"
|
||||
|
||||
// SuiteInformation stats stores stats for the whole suite execution.
|
||||
type SuiteInformation struct {
|
||||
Start, End time.Time
|
||||
TestStats map[string]*TestInformation
|
||||
}
|
||||
|
||||
// TestInformation stores information about the execution of each test.
|
||||
type TestInformation struct {
|
||||
TestName string
|
||||
Start, End time.Time
|
||||
Passed bool
|
||||
}
|
||||
|
||||
func newSuiteInformation() *SuiteInformation {
|
||||
testStats := make(map[string]*TestInformation)
|
||||
|
||||
return &SuiteInformation{
|
||||
TestStats: testStats,
|
||||
}
|
||||
}
|
||||
|
||||
func (s SuiteInformation) start(testName string) {
|
||||
s.TestStats[testName] = &TestInformation{
|
||||
TestName: testName,
|
||||
Start: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s SuiteInformation) end(testName string, passed bool) {
|
||||
s.TestStats[testName].End = time.Now()
|
||||
s.TestStats[testName].Passed = passed
|
||||
}
|
||||
|
||||
func (s SuiteInformation) Passed() bool {
|
||||
for _, stats := range s.TestStats {
|
||||
if !stats.Passed {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
-253
@@ -1,253 +0,0 @@
|
||||
package suite
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var allTestsFilter = func(_, _ string) (bool, error) { return true, nil }
|
||||
var matchMethod = flag.String("testify.m", "", "regular expression to select tests of the testify suite to run")
|
||||
|
||||
// Suite is a basic testing suite with methods for storing and
|
||||
// retrieving the current *testing.T context.
|
||||
type Suite struct {
|
||||
*assert.Assertions
|
||||
|
||||
mu sync.RWMutex
|
||||
require *require.Assertions
|
||||
t *testing.T
|
||||
|
||||
// Parent suite to have access to the implemented methods of parent struct
|
||||
s TestingSuite
|
||||
}
|
||||
|
||||
// T retrieves the current *testing.T context.
|
||||
func (suite *Suite) T() *testing.T {
|
||||
suite.mu.RLock()
|
||||
defer suite.mu.RUnlock()
|
||||
return suite.t
|
||||
}
|
||||
|
||||
// SetT sets the current *testing.T context.
|
||||
func (suite *Suite) SetT(t *testing.T) {
|
||||
suite.mu.Lock()
|
||||
defer suite.mu.Unlock()
|
||||
suite.t = t
|
||||
suite.Assertions = assert.New(t)
|
||||
suite.require = require.New(t)
|
||||
}
|
||||
|
||||
// SetS needs to set the current test suite as parent
|
||||
// to get access to the parent methods
|
||||
func (suite *Suite) SetS(s TestingSuite) {
|
||||
suite.s = s
|
||||
}
|
||||
|
||||
// Require returns a require context for suite.
|
||||
func (suite *Suite) Require() *require.Assertions {
|
||||
suite.mu.Lock()
|
||||
defer suite.mu.Unlock()
|
||||
if suite.require == nil {
|
||||
panic("'Require' must not be called before 'Run' or 'SetT'")
|
||||
}
|
||||
return suite.require
|
||||
}
|
||||
|
||||
// Assert returns an assert context for suite. Normally, you can call
|
||||
// `suite.NoError(expected, actual)`, but for situations where the embedded
|
||||
// methods are overridden (for example, you might want to override
|
||||
// assert.Assertions with require.Assertions), this method is provided so you
|
||||
// can call `suite.Assert().NoError()`.
|
||||
func (suite *Suite) Assert() *assert.Assertions {
|
||||
suite.mu.Lock()
|
||||
defer suite.mu.Unlock()
|
||||
if suite.Assertions == nil {
|
||||
panic("'Assert' must not be called before 'Run' or 'SetT'")
|
||||
}
|
||||
return suite.Assertions
|
||||
}
|
||||
|
||||
func recoverAndFailOnPanic(t *testing.T) {
|
||||
t.Helper()
|
||||
r := recover()
|
||||
failOnPanic(t, r)
|
||||
}
|
||||
|
||||
func failOnPanic(t *testing.T, r interface{}) {
|
||||
t.Helper()
|
||||
if r != nil {
|
||||
t.Errorf("test panicked: %v\n%s", r, debug.Stack())
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// Run provides suite functionality around golang subtests. It should be
|
||||
// called in place of t.Run(name, func(t *testing.T)) in test suite code.
|
||||
// The passed-in func will be executed as a subtest with a fresh instance of t.
|
||||
// Provides compatibility with go test pkg -run TestSuite/TestName/SubTestName.
|
||||
func (suite *Suite) Run(name string, subtest func()) bool {
|
||||
oldT := suite.T()
|
||||
|
||||
return oldT.Run(name, func(t *testing.T) {
|
||||
suite.SetT(t)
|
||||
defer suite.SetT(oldT)
|
||||
|
||||
defer recoverAndFailOnPanic(t)
|
||||
|
||||
if setupSubTest, ok := suite.s.(SetupSubTest); ok {
|
||||
setupSubTest.SetupSubTest()
|
||||
}
|
||||
|
||||
if tearDownSubTest, ok := suite.s.(TearDownSubTest); ok {
|
||||
defer tearDownSubTest.TearDownSubTest()
|
||||
}
|
||||
|
||||
subtest()
|
||||
})
|
||||
}
|
||||
|
||||
// Run takes a testing suite and runs all of the tests attached
|
||||
// to it.
|
||||
func Run(t *testing.T, suite TestingSuite) {
|
||||
defer recoverAndFailOnPanic(t)
|
||||
|
||||
suite.SetT(t)
|
||||
suite.SetS(suite)
|
||||
|
||||
var suiteSetupDone bool
|
||||
|
||||
var stats *SuiteInformation
|
||||
if _, ok := suite.(WithStats); ok {
|
||||
stats = newSuiteInformation()
|
||||
}
|
||||
|
||||
tests := []testing.InternalTest{}
|
||||
methodFinder := reflect.TypeOf(suite)
|
||||
suiteName := methodFinder.Elem().Name()
|
||||
|
||||
for i := 0; i < methodFinder.NumMethod(); i++ {
|
||||
method := methodFinder.Method(i)
|
||||
|
||||
ok, err := methodFilter(method.Name)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "testify: invalid regexp for -m: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if !suiteSetupDone {
|
||||
if stats != nil {
|
||||
stats.Start = time.Now()
|
||||
}
|
||||
|
||||
if setupAllSuite, ok := suite.(SetupAllSuite); ok {
|
||||
setupAllSuite.SetupSuite()
|
||||
}
|
||||
|
||||
suiteSetupDone = true
|
||||
}
|
||||
|
||||
test := testing.InternalTest{
|
||||
Name: method.Name,
|
||||
F: func(t *testing.T) {
|
||||
parentT := suite.T()
|
||||
suite.SetT(t)
|
||||
defer recoverAndFailOnPanic(t)
|
||||
defer func() {
|
||||
t.Helper()
|
||||
|
||||
r := recover()
|
||||
|
||||
if stats != nil {
|
||||
passed := !t.Failed() && r == nil
|
||||
stats.end(method.Name, passed)
|
||||
}
|
||||
|
||||
if afterTestSuite, ok := suite.(AfterTest); ok {
|
||||
afterTestSuite.AfterTest(suiteName, method.Name)
|
||||
}
|
||||
|
||||
if tearDownTestSuite, ok := suite.(TearDownTestSuite); ok {
|
||||
tearDownTestSuite.TearDownTest()
|
||||
}
|
||||
|
||||
suite.SetT(parentT)
|
||||
failOnPanic(t, r)
|
||||
}()
|
||||
|
||||
if setupTestSuite, ok := suite.(SetupTestSuite); ok {
|
||||
setupTestSuite.SetupTest()
|
||||
}
|
||||
if beforeTestSuite, ok := suite.(BeforeTest); ok {
|
||||
beforeTestSuite.BeforeTest(methodFinder.Elem().Name(), method.Name)
|
||||
}
|
||||
|
||||
if stats != nil {
|
||||
stats.start(method.Name)
|
||||
}
|
||||
|
||||
method.Func.Call([]reflect.Value{reflect.ValueOf(suite)})
|
||||
},
|
||||
}
|
||||
tests = append(tests, test)
|
||||
}
|
||||
if suiteSetupDone {
|
||||
defer func() {
|
||||
if tearDownAllSuite, ok := suite.(TearDownAllSuite); ok {
|
||||
tearDownAllSuite.TearDownSuite()
|
||||
}
|
||||
|
||||
if suiteWithStats, measureStats := suite.(WithStats); measureStats {
|
||||
stats.End = time.Now()
|
||||
suiteWithStats.HandleStats(suiteName, stats)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
runTests(t, tests)
|
||||
}
|
||||
|
||||
// Filtering method according to set regular expression
|
||||
// specified command-line argument -m
|
||||
func methodFilter(name string) (bool, error) {
|
||||
if ok, _ := regexp.MatchString("^Test", name); !ok {
|
||||
return false, nil
|
||||
}
|
||||
return regexp.MatchString(*matchMethod, name)
|
||||
}
|
||||
|
||||
func runTests(t testing.TB, tests []testing.InternalTest) {
|
||||
if len(tests) == 0 {
|
||||
t.Log("warning: no tests to run")
|
||||
return
|
||||
}
|
||||
|
||||
r, ok := t.(runner)
|
||||
if !ok { // backwards compatibility with Go 1.6 and below
|
||||
if !testing.RunTests(allTestsFilter, tests) {
|
||||
t.Fail()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
r.Run(test.Name, test.F)
|
||||
}
|
||||
}
|
||||
|
||||
type runner interface {
|
||||
Run(name string, f func(t *testing.T)) bool
|
||||
}
|
||||
-27
@@ -1,27 +0,0 @@
|
||||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
-22
@@ -1,22 +0,0 @@
|
||||
Additional IP Rights Grant (Patents)
|
||||
|
||||
"This implementation" means the copyrightable works distributed by
|
||||
Google as part of the Go project.
|
||||
|
||||
Google hereby grants to You a perpetual, worldwide, non-exclusive,
|
||||
no-charge, royalty-free, irrevocable (except as stated in this section)
|
||||
patent license to make, have made, use, offer to sell, sell, import,
|
||||
transfer and otherwise run, modify and propagate the contents of this
|
||||
implementation of Go, where such license applies only to those patent
|
||||
claims, both currently owned or controlled by Google and acquired in
|
||||
the future, licensable by Google that are necessarily infringed by this
|
||||
implementation of Go. This grant does not include claims that would be
|
||||
infringed only as a consequence of further modification of this
|
||||
implementation. If you or your agent or exclusive licensee institute or
|
||||
order or agree to the institution of patent litigation against any
|
||||
entity (including a cross-claim or counterclaim in a lawsuit) alleging
|
||||
that this implementation of Go or any code incorporated within this
|
||||
implementation of Go constitutes direct or contributory patent
|
||||
infringement, or inducement of patent infringement, then any patent
|
||||
rights granted to you under this License for this implementation of Go
|
||||
shall terminate as of the date such litigation is filed.
|
||||
-135
@@ -1,135 +0,0 @@
|
||||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package errgroup provides synchronization, error propagation, and Context
|
||||
// cancelation for groups of goroutines working on subtasks of a common task.
|
||||
//
|
||||
// [errgroup.Group] is related to [sync.WaitGroup] but adds handling of tasks
|
||||
// returning errors.
|
||||
package errgroup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type token struct{}
|
||||
|
||||
// A Group is a collection of goroutines working on subtasks that are part of
|
||||
// the same overall task.
|
||||
//
|
||||
// A zero Group is valid, has no limit on the number of active goroutines,
|
||||
// and does not cancel on error.
|
||||
type Group struct {
|
||||
cancel func(error)
|
||||
|
||||
wg sync.WaitGroup
|
||||
|
||||
sem chan token
|
||||
|
||||
errOnce sync.Once
|
||||
err error
|
||||
}
|
||||
|
||||
func (g *Group) done() {
|
||||
if g.sem != nil {
|
||||
<-g.sem
|
||||
}
|
||||
g.wg.Done()
|
||||
}
|
||||
|
||||
// WithContext returns a new Group and an associated Context derived from ctx.
|
||||
//
|
||||
// The derived Context is canceled the first time a function passed to Go
|
||||
// returns a non-nil error or the first time Wait returns, whichever occurs
|
||||
// first.
|
||||
func WithContext(ctx context.Context) (*Group, context.Context) {
|
||||
ctx, cancel := withCancelCause(ctx)
|
||||
return &Group{cancel: cancel}, ctx
|
||||
}
|
||||
|
||||
// Wait blocks until all function calls from the Go method have returned, then
|
||||
// returns the first non-nil error (if any) from them.
|
||||
func (g *Group) Wait() error {
|
||||
g.wg.Wait()
|
||||
if g.cancel != nil {
|
||||
g.cancel(g.err)
|
||||
}
|
||||
return g.err
|
||||
}
|
||||
|
||||
// Go calls the given function in a new goroutine.
|
||||
// It blocks until the new goroutine can be added without the number of
|
||||
// active goroutines in the group exceeding the configured limit.
|
||||
//
|
||||
// The first call to return a non-nil error cancels the group's context, if the
|
||||
// group was created by calling WithContext. The error will be returned by Wait.
|
||||
func (g *Group) Go(f func() error) {
|
||||
if g.sem != nil {
|
||||
g.sem <- token{}
|
||||
}
|
||||
|
||||
g.wg.Add(1)
|
||||
go func() {
|
||||
defer g.done()
|
||||
|
||||
if err := f(); err != nil {
|
||||
g.errOnce.Do(func() {
|
||||
g.err = err
|
||||
if g.cancel != nil {
|
||||
g.cancel(g.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// TryGo calls the given function in a new goroutine only if the number of
|
||||
// active goroutines in the group is currently below the configured limit.
|
||||
//
|
||||
// The return value reports whether the goroutine was started.
|
||||
func (g *Group) TryGo(f func() error) bool {
|
||||
if g.sem != nil {
|
||||
select {
|
||||
case g.sem <- token{}:
|
||||
// Note: this allows barging iff channels in general allow barging.
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
g.wg.Add(1)
|
||||
go func() {
|
||||
defer g.done()
|
||||
|
||||
if err := f(); err != nil {
|
||||
g.errOnce.Do(func() {
|
||||
g.err = err
|
||||
if g.cancel != nil {
|
||||
g.cancel(g.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
return true
|
||||
}
|
||||
|
||||
// SetLimit limits the number of active goroutines in this group to at most n.
|
||||
// A negative value indicates no limit.
|
||||
//
|
||||
// Any subsequent call to the Go method will block until it can add an active
|
||||
// goroutine without exceeding the configured limit.
|
||||
//
|
||||
// The limit must not be modified while any goroutines in the group are active.
|
||||
func (g *Group) SetLimit(n int) {
|
||||
if n < 0 {
|
||||
g.sem = nil
|
||||
return
|
||||
}
|
||||
if len(g.sem) != 0 {
|
||||
panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
|
||||
}
|
||||
g.sem = make(chan token, n)
|
||||
}
|
||||
-13
@@ -1,13 +0,0 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build go1.20
|
||||
|
||||
package errgroup
|
||||
|
||||
import "context"
|
||||
|
||||
func withCancelCause(parent context.Context) (context.Context, func(error)) {
|
||||
return context.WithCancelCause(parent)
|
||||
}
|
||||
-14
@@ -1,14 +0,0 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build !go1.20
|
||||
|
||||
package errgroup
|
||||
|
||||
import "context"
|
||||
|
||||
func withCancelCause(parent context.Context) (context.Context, func(error)) {
|
||||
ctx, cancel := context.WithCancel(parent)
|
||||
return ctx, func(error) { cancel() }
|
||||
}
|
||||
+2
-2
@@ -1,4 +1,4 @@
|
||||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
Copyright 2009 The Go Authors.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
* Neither the name of Google LLC nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
|
||||
+3
-14
@@ -99,8 +99,9 @@ func (lim *Limiter) Tokens() float64 {
|
||||
// bursts of at most b tokens.
|
||||
func NewLimiter(r Limit, b int) *Limiter {
|
||||
return &Limiter{
|
||||
limit: r,
|
||||
burst: b,
|
||||
limit: r,
|
||||
burst: b,
|
||||
tokens: float64(b),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -344,18 +345,6 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration)
|
||||
tokens: n,
|
||||
timeToAct: t,
|
||||
}
|
||||
} else if lim.limit == 0 {
|
||||
var ok bool
|
||||
if lim.burst >= n {
|
||||
ok = true
|
||||
lim.burst -= n
|
||||
}
|
||||
return Reservation{
|
||||
ok: ok,
|
||||
lim: lim,
|
||||
tokens: lim.burst,
|
||||
timeToAct: t,
|
||||
}
|
||||
}
|
||||
|
||||
t, tokens := lim.advance(t)
|
||||
|
||||
-50
@@ -1,50 +0,0 @@
|
||||
|
||||
This project is covered by two different licenses: MIT and Apache.
|
||||
|
||||
#### MIT License ####
|
||||
|
||||
The following files were ported to Go from C files of libyaml, and thus
|
||||
are still covered by their original MIT license, with the additional
|
||||
copyright staring in 2011 when the project was ported over:
|
||||
|
||||
apic.go emitterc.go parserc.go readerc.go scannerc.go
|
||||
writerc.go yamlh.go yamlprivateh.go
|
||||
|
||||
Copyright (c) 2006-2010 Kirill Simonov
|
||||
Copyright (c) 2006-2011 Kirill Simonov
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
||||
of the Software, and to permit persons to whom the Software is furnished to do
|
||||
so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
### Apache License ###
|
||||
|
||||
All the remaining project files are covered by the Apache license:
|
||||
|
||||
Copyright (c) 2011-2019 Canonical Ltd
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-13
@@ -1,13 +0,0 @@
|
||||
Copyright 2011-2016 Canonical Ltd.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-150
@@ -1,150 +0,0 @@
|
||||
# YAML support for the Go language
|
||||
|
||||
Introduction
|
||||
------------
|
||||
|
||||
The yaml package enables Go programs to comfortably encode and decode YAML
|
||||
values. It was developed within [Canonical](https://www.canonical.com) as
|
||||
part of the [juju](https://juju.ubuntu.com) project, and is based on a
|
||||
pure Go port of the well-known [libyaml](http://pyyaml.org/wiki/LibYAML)
|
||||
C library to parse and generate YAML data quickly and reliably.
|
||||
|
||||
Compatibility
|
||||
-------------
|
||||
|
||||
The yaml package supports most of YAML 1.2, but preserves some behavior
|
||||
from 1.1 for backwards compatibility.
|
||||
|
||||
Specifically, as of v3 of the yaml package:
|
||||
|
||||
- YAML 1.1 bools (_yes/no, on/off_) are supported as long as they are being
|
||||
decoded into a typed bool value. Otherwise they behave as a string. Booleans
|
||||
in YAML 1.2 are _true/false_ only.
|
||||
- Octals encode and decode as _0777_ per YAML 1.1, rather than _0o777_
|
||||
as specified in YAML 1.2, because most parsers still use the old format.
|
||||
Octals in the _0o777_ format are supported though, so new files work.
|
||||
- Does not support base-60 floats. These are gone from YAML 1.2, and were
|
||||
actually never supported by this package as it's clearly a poor choice.
|
||||
|
||||
and offers backwards
|
||||
compatibility with YAML 1.1 in some cases.
|
||||
1.2, including support for
|
||||
anchors, tags, map merging, etc. Multi-document unmarshalling is not yet
|
||||
implemented, and base-60 floats from YAML 1.1 are purposefully not
|
||||
supported since they're a poor design and are gone in YAML 1.2.
|
||||
|
||||
Installation and usage
|
||||
----------------------
|
||||
|
||||
The import path for the package is *gopkg.in/yaml.v3*.
|
||||
|
||||
To install it, run:
|
||||
|
||||
go get gopkg.in/yaml.v3
|
||||
|
||||
API documentation
|
||||
-----------------
|
||||
|
||||
If opened in a browser, the import path itself leads to the API documentation:
|
||||
|
||||
- [https://gopkg.in/yaml.v3](https://gopkg.in/yaml.v3)
|
||||
|
||||
API stability
|
||||
-------------
|
||||
|
||||
The package API for yaml v3 will remain stable as described in [gopkg.in](https://gopkg.in).
|
||||
|
||||
|
||||
License
|
||||
-------
|
||||
|
||||
The yaml package is licensed under the MIT and Apache License 2.0 licenses.
|
||||
Please see the LICENSE file for details.
|
||||
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
```Go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var data = `
|
||||
a: Easy!
|
||||
b:
|
||||
c: 2
|
||||
d: [3, 4]
|
||||
`
|
||||
|
||||
// Note: struct fields must be public in order for unmarshal to
|
||||
// correctly populate the data.
|
||||
type T struct {
|
||||
A string
|
||||
B struct {
|
||||
RenamedC int `yaml:"c"`
|
||||
D []int `yaml:",flow"`
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
t := T{}
|
||||
|
||||
err := yaml.Unmarshal([]byte(data), &t)
|
||||
if err != nil {
|
||||
log.Fatalf("error: %v", err)
|
||||
}
|
||||
fmt.Printf("--- t:\n%v\n\n", t)
|
||||
|
||||
d, err := yaml.Marshal(&t)
|
||||
if err != nil {
|
||||
log.Fatalf("error: %v", err)
|
||||
}
|
||||
fmt.Printf("--- t dump:\n%s\n\n", string(d))
|
||||
|
||||
m := make(map[interface{}]interface{})
|
||||
|
||||
err = yaml.Unmarshal([]byte(data), &m)
|
||||
if err != nil {
|
||||
log.Fatalf("error: %v", err)
|
||||
}
|
||||
fmt.Printf("--- m:\n%v\n\n", m)
|
||||
|
||||
d, err = yaml.Marshal(&m)
|
||||
if err != nil {
|
||||
log.Fatalf("error: %v", err)
|
||||
}
|
||||
fmt.Printf("--- m dump:\n%s\n\n", string(d))
|
||||
}
|
||||
```
|
||||
|
||||
This example will generate the following output:
|
||||
|
||||
```
|
||||
--- t:
|
||||
{Easy! {2 [3 4]}}
|
||||
|
||||
--- t dump:
|
||||
a: Easy!
|
||||
b:
|
||||
c: 2
|
||||
d: [3, 4]
|
||||
|
||||
|
||||
--- m:
|
||||
map[a:Easy! b:map[c:2 d:[3 4]]]
|
||||
|
||||
--- m dump:
|
||||
a: Easy!
|
||||
b:
|
||||
c: 2
|
||||
d:
|
||||
- 3
|
||||
- 4
|
||||
```
|
||||
|
||||
-747
@@ -1,747 +0,0 @@
|
||||
//
|
||||
// Copyright (c) 2011-2019 Canonical Ltd
|
||||
// Copyright (c) 2006-2010 Kirill Simonov
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
||||
// of the Software, and to permit persons to whom the Software is furnished to do
|
||||
// so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
|
||||
package yaml
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
func yaml_insert_token(parser *yaml_parser_t, pos int, token *yaml_token_t) {
|
||||
//fmt.Println("yaml_insert_token", "pos:", pos, "typ:", token.typ, "head:", parser.tokens_head, "len:", len(parser.tokens))
|
||||
|
||||
// Check if we can move the queue at the beginning of the buffer.
|
||||
if parser.tokens_head > 0 && len(parser.tokens) == cap(parser.tokens) {
|
||||
if parser.tokens_head != len(parser.tokens) {
|
||||
copy(parser.tokens, parser.tokens[parser.tokens_head:])
|
||||
}
|
||||
parser.tokens = parser.tokens[:len(parser.tokens)-parser.tokens_head]
|
||||
parser.tokens_head = 0
|
||||
}
|
||||
parser.tokens = append(parser.tokens, *token)
|
||||
if pos < 0 {
|
||||
return
|
||||
}
|
||||
copy(parser.tokens[parser.tokens_head+pos+1:], parser.tokens[parser.tokens_head+pos:])
|
||||
parser.tokens[parser.tokens_head+pos] = *token
|
||||
}
|
||||
|
||||
// Create a new parser object.
|
||||
func yaml_parser_initialize(parser *yaml_parser_t) bool {
|
||||
*parser = yaml_parser_t{
|
||||
raw_buffer: make([]byte, 0, input_raw_buffer_size),
|
||||
buffer: make([]byte, 0, input_buffer_size),
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Destroy a parser object.
|
||||
func yaml_parser_delete(parser *yaml_parser_t) {
|
||||
*parser = yaml_parser_t{}
|
||||
}
|
||||
|
||||
// String read handler.
|
||||
func yaml_string_read_handler(parser *yaml_parser_t, buffer []byte) (n int, err error) {
|
||||
if parser.input_pos == len(parser.input) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n = copy(buffer, parser.input[parser.input_pos:])
|
||||
parser.input_pos += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Reader read handler.
|
||||
func yaml_reader_read_handler(parser *yaml_parser_t, buffer []byte) (n int, err error) {
|
||||
return parser.input_reader.Read(buffer)
|
||||
}
|
||||
|
||||
// Set a string input.
|
||||
func yaml_parser_set_input_string(parser *yaml_parser_t, input []byte) {
|
||||
if parser.read_handler != nil {
|
||||
panic("must set the input source only once")
|
||||
}
|
||||
parser.read_handler = yaml_string_read_handler
|
||||
parser.input = input
|
||||
parser.input_pos = 0
|
||||
}
|
||||
|
||||
// Set a file input.
|
||||
func yaml_parser_set_input_reader(parser *yaml_parser_t, r io.Reader) {
|
||||
if parser.read_handler != nil {
|
||||
panic("must set the input source only once")
|
||||
}
|
||||
parser.read_handler = yaml_reader_read_handler
|
||||
parser.input_reader = r
|
||||
}
|
||||
|
||||
// Set the source encoding.
|
||||
func yaml_parser_set_encoding(parser *yaml_parser_t, encoding yaml_encoding_t) {
|
||||
if parser.encoding != yaml_ANY_ENCODING {
|
||||
panic("must set the encoding only once")
|
||||
}
|
||||
parser.encoding = encoding
|
||||
}
|
||||
|
||||
// Create a new emitter object.
|
||||
func yaml_emitter_initialize(emitter *yaml_emitter_t) {
|
||||
*emitter = yaml_emitter_t{
|
||||
buffer: make([]byte, output_buffer_size),
|
||||
raw_buffer: make([]byte, 0, output_raw_buffer_size),
|
||||
states: make([]yaml_emitter_state_t, 0, initial_stack_size),
|
||||
events: make([]yaml_event_t, 0, initial_queue_size),
|
||||
best_width: -1,
|
||||
}
|
||||
}
|
||||
|
||||
// Destroy an emitter object.
|
||||
func yaml_emitter_delete(emitter *yaml_emitter_t) {
|
||||
*emitter = yaml_emitter_t{}
|
||||
}
|
||||
|
||||
// String write handler.
|
||||
func yaml_string_write_handler(emitter *yaml_emitter_t, buffer []byte) error {
|
||||
*emitter.output_buffer = append(*emitter.output_buffer, buffer...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// yaml_writer_write_handler uses emitter.output_writer to write the
|
||||
// emitted text.
|
||||
func yaml_writer_write_handler(emitter *yaml_emitter_t, buffer []byte) error {
|
||||
_, err := emitter.output_writer.Write(buffer)
|
||||
return err
|
||||
}
|
||||
|
||||
// Set a string output.
|
||||
func yaml_emitter_set_output_string(emitter *yaml_emitter_t, output_buffer *[]byte) {
|
||||
if emitter.write_handler != nil {
|
||||
panic("must set the output target only once")
|
||||
}
|
||||
emitter.write_handler = yaml_string_write_handler
|
||||
emitter.output_buffer = output_buffer
|
||||
}
|
||||
|
||||
// Set a file output.
|
||||
func yaml_emitter_set_output_writer(emitter *yaml_emitter_t, w io.Writer) {
|
||||
if emitter.write_handler != nil {
|
||||
panic("must set the output target only once")
|
||||
}
|
||||
emitter.write_handler = yaml_writer_write_handler
|
||||
emitter.output_writer = w
|
||||
}
|
||||
|
||||
// Set the output encoding.
|
||||
func yaml_emitter_set_encoding(emitter *yaml_emitter_t, encoding yaml_encoding_t) {
|
||||
if emitter.encoding != yaml_ANY_ENCODING {
|
||||
panic("must set the output encoding only once")
|
||||
}
|
||||
emitter.encoding = encoding
|
||||
}
|
||||
|
||||
// Set the canonical output style.
|
||||
func yaml_emitter_set_canonical(emitter *yaml_emitter_t, canonical bool) {
|
||||
emitter.canonical = canonical
|
||||
}
|
||||
|
||||
// Set the indentation increment.
|
||||
func yaml_emitter_set_indent(emitter *yaml_emitter_t, indent int) {
|
||||
if indent < 2 || indent > 9 {
|
||||
indent = 2
|
||||
}
|
||||
emitter.best_indent = indent
|
||||
}
|
||||
|
||||
// Set the preferred line width.
|
||||
func yaml_emitter_set_width(emitter *yaml_emitter_t, width int) {
|
||||
if width < 0 {
|
||||
width = -1
|
||||
}
|
||||
emitter.best_width = width
|
||||
}
|
||||
|
||||
// Set if unescaped non-ASCII characters are allowed.
|
||||
func yaml_emitter_set_unicode(emitter *yaml_emitter_t, unicode bool) {
|
||||
emitter.unicode = unicode
|
||||
}
|
||||
|
||||
// Set the preferred line break character.
|
||||
func yaml_emitter_set_break(emitter *yaml_emitter_t, line_break yaml_break_t) {
|
||||
emitter.line_break = line_break
|
||||
}
|
||||
|
||||
///*
|
||||
// * Destroy a token object.
|
||||
// */
|
||||
//
|
||||
//YAML_DECLARE(void)
|
||||
//yaml_token_delete(yaml_token_t *token)
|
||||
//{
|
||||
// assert(token); // Non-NULL token object expected.
|
||||
//
|
||||
// switch (token.type)
|
||||
// {
|
||||
// case YAML_TAG_DIRECTIVE_TOKEN:
|
||||
// yaml_free(token.data.tag_directive.handle);
|
||||
// yaml_free(token.data.tag_directive.prefix);
|
||||
// break;
|
||||
//
|
||||
// case YAML_ALIAS_TOKEN:
|
||||
// yaml_free(token.data.alias.value);
|
||||
// break;
|
||||
//
|
||||
// case YAML_ANCHOR_TOKEN:
|
||||
// yaml_free(token.data.anchor.value);
|
||||
// break;
|
||||
//
|
||||
// case YAML_TAG_TOKEN:
|
||||
// yaml_free(token.data.tag.handle);
|
||||
// yaml_free(token.data.tag.suffix);
|
||||
// break;
|
||||
//
|
||||
// case YAML_SCALAR_TOKEN:
|
||||
// yaml_free(token.data.scalar.value);
|
||||
// break;
|
||||
//
|
||||
// default:
|
||||
// break;
|
||||
// }
|
||||
//
|
||||
// memset(token, 0, sizeof(yaml_token_t));
|
||||
//}
|
||||
//
|
||||
///*
|
||||
// * Check if a string is a valid UTF-8 sequence.
|
||||
// *
|
||||
// * Check 'reader.c' for more details on UTF-8 encoding.
|
||||
// */
|
||||
//
|
||||
//static int
|
||||
//yaml_check_utf8(yaml_char_t *start, size_t length)
|
||||
//{
|
||||
// yaml_char_t *end = start+length;
|
||||
// yaml_char_t *pointer = start;
|
||||
//
|
||||
// while (pointer < end) {
|
||||
// unsigned char octet;
|
||||
// unsigned int width;
|
||||
// unsigned int value;
|
||||
// size_t k;
|
||||
//
|
||||
// octet = pointer[0];
|
||||
// width = (octet & 0x80) == 0x00 ? 1 :
|
||||
// (octet & 0xE0) == 0xC0 ? 2 :
|
||||
// (octet & 0xF0) == 0xE0 ? 3 :
|
||||
// (octet & 0xF8) == 0xF0 ? 4 : 0;
|
||||
// value = (octet & 0x80) == 0x00 ? octet & 0x7F :
|
||||
// (octet & 0xE0) == 0xC0 ? octet & 0x1F :
|
||||
// (octet & 0xF0) == 0xE0 ? octet & 0x0F :
|
||||
// (octet & 0xF8) == 0xF0 ? octet & 0x07 : 0;
|
||||
// if (!width) return 0;
|
||||
// if (pointer+width > end) return 0;
|
||||
// for (k = 1; k < width; k ++) {
|
||||
// octet = pointer[k];
|
||||
// if ((octet & 0xC0) != 0x80) return 0;
|
||||
// value = (value << 6) + (octet & 0x3F);
|
||||
// }
|
||||
// if (!((width == 1) ||
|
||||
// (width == 2 && value >= 0x80) ||
|
||||
// (width == 3 && value >= 0x800) ||
|
||||
// (width == 4 && value >= 0x10000))) return 0;
|
||||
//
|
||||
// pointer += width;
|
||||
// }
|
||||
//
|
||||
// return 1;
|
||||
//}
|
||||
//
|
||||
|
||||
// Create STREAM-START.
|
||||
func yaml_stream_start_event_initialize(event *yaml_event_t, encoding yaml_encoding_t) {
|
||||
*event = yaml_event_t{
|
||||
typ: yaml_STREAM_START_EVENT,
|
||||
encoding: encoding,
|
||||
}
|
||||
}
|
||||
|
||||
// Create STREAM-END.
|
||||
func yaml_stream_end_event_initialize(event *yaml_event_t) {
|
||||
*event = yaml_event_t{
|
||||
typ: yaml_STREAM_END_EVENT,
|
||||
}
|
||||
}
|
||||
|
||||
// Create DOCUMENT-START.
|
||||
func yaml_document_start_event_initialize(
|
||||
event *yaml_event_t,
|
||||
version_directive *yaml_version_directive_t,
|
||||
tag_directives []yaml_tag_directive_t,
|
||||
implicit bool,
|
||||
) {
|
||||
*event = yaml_event_t{
|
||||
typ: yaml_DOCUMENT_START_EVENT,
|
||||
version_directive: version_directive,
|
||||
tag_directives: tag_directives,
|
||||
implicit: implicit,
|
||||
}
|
||||
}
|
||||
|
||||
// Create DOCUMENT-END.
|
||||
func yaml_document_end_event_initialize(event *yaml_event_t, implicit bool) {
|
||||
*event = yaml_event_t{
|
||||
typ: yaml_DOCUMENT_END_EVENT,
|
||||
implicit: implicit,
|
||||
}
|
||||
}
|
||||
|
||||
// Create ALIAS.
|
||||
func yaml_alias_event_initialize(event *yaml_event_t, anchor []byte) bool {
|
||||
*event = yaml_event_t{
|
||||
typ: yaml_ALIAS_EVENT,
|
||||
anchor: anchor,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Create SCALAR.
|
||||
func yaml_scalar_event_initialize(event *yaml_event_t, anchor, tag, value []byte, plain_implicit, quoted_implicit bool, style yaml_scalar_style_t) bool {
|
||||
*event = yaml_event_t{
|
||||
typ: yaml_SCALAR_EVENT,
|
||||
anchor: anchor,
|
||||
tag: tag,
|
||||
value: value,
|
||||
implicit: plain_implicit,
|
||||
quoted_implicit: quoted_implicit,
|
||||
style: yaml_style_t(style),
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Create SEQUENCE-START.
|
||||
func yaml_sequence_start_event_initialize(event *yaml_event_t, anchor, tag []byte, implicit bool, style yaml_sequence_style_t) bool {
|
||||
*event = yaml_event_t{
|
||||
typ: yaml_SEQUENCE_START_EVENT,
|
||||
anchor: anchor,
|
||||
tag: tag,
|
||||
implicit: implicit,
|
||||
style: yaml_style_t(style),
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Create SEQUENCE-END.
|
||||
func yaml_sequence_end_event_initialize(event *yaml_event_t) bool {
|
||||
*event = yaml_event_t{
|
||||
typ: yaml_SEQUENCE_END_EVENT,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Create MAPPING-START.
|
||||
func yaml_mapping_start_event_initialize(event *yaml_event_t, anchor, tag []byte, implicit bool, style yaml_mapping_style_t) {
|
||||
*event = yaml_event_t{
|
||||
typ: yaml_MAPPING_START_EVENT,
|
||||
anchor: anchor,
|
||||
tag: tag,
|
||||
implicit: implicit,
|
||||
style: yaml_style_t(style),
|
||||
}
|
||||
}
|
||||
|
||||
// Create MAPPING-END.
|
||||
func yaml_mapping_end_event_initialize(event *yaml_event_t) {
|
||||
*event = yaml_event_t{
|
||||
typ: yaml_MAPPING_END_EVENT,
|
||||
}
|
||||
}
|
||||
|
||||
// Destroy an event object.
|
||||
func yaml_event_delete(event *yaml_event_t) {
|
||||
*event = yaml_event_t{}
|
||||
}
|
||||
|
||||
///*
|
||||
// * Create a document object.
|
||||
// */
|
||||
//
|
||||
//YAML_DECLARE(int)
|
||||
//yaml_document_initialize(document *yaml_document_t,
|
||||
// version_directive *yaml_version_directive_t,
|
||||
// tag_directives_start *yaml_tag_directive_t,
|
||||
// tag_directives_end *yaml_tag_directive_t,
|
||||
// start_implicit int, end_implicit int)
|
||||
//{
|
||||
// struct {
|
||||
// error yaml_error_type_t
|
||||
// } context
|
||||
// struct {
|
||||
// start *yaml_node_t
|
||||
// end *yaml_node_t
|
||||
// top *yaml_node_t
|
||||
// } nodes = { NULL, NULL, NULL }
|
||||
// version_directive_copy *yaml_version_directive_t = NULL
|
||||
// struct {
|
||||
// start *yaml_tag_directive_t
|
||||
// end *yaml_tag_directive_t
|
||||
// top *yaml_tag_directive_t
|
||||
// } tag_directives_copy = { NULL, NULL, NULL }
|
||||
// value yaml_tag_directive_t = { NULL, NULL }
|
||||
// mark yaml_mark_t = { 0, 0, 0 }
|
||||
//
|
||||
// assert(document) // Non-NULL document object is expected.
|
||||
// assert((tag_directives_start && tag_directives_end) ||
|
||||
// (tag_directives_start == tag_directives_end))
|
||||
// // Valid tag directives are expected.
|
||||
//
|
||||
// if (!STACK_INIT(&context, nodes, INITIAL_STACK_SIZE)) goto error
|
||||
//
|
||||
// if (version_directive) {
|
||||
// version_directive_copy = yaml_malloc(sizeof(yaml_version_directive_t))
|
||||
// if (!version_directive_copy) goto error
|
||||
// version_directive_copy.major = version_directive.major
|
||||
// version_directive_copy.minor = version_directive.minor
|
||||
// }
|
||||
//
|
||||
// if (tag_directives_start != tag_directives_end) {
|
||||
// tag_directive *yaml_tag_directive_t
|
||||
// if (!STACK_INIT(&context, tag_directives_copy, INITIAL_STACK_SIZE))
|
||||
// goto error
|
||||
// for (tag_directive = tag_directives_start
|
||||
// tag_directive != tag_directives_end; tag_directive ++) {
|
||||
// assert(tag_directive.handle)
|
||||
// assert(tag_directive.prefix)
|
||||
// if (!yaml_check_utf8(tag_directive.handle,
|
||||
// strlen((char *)tag_directive.handle)))
|
||||
// goto error
|
||||
// if (!yaml_check_utf8(tag_directive.prefix,
|
||||
// strlen((char *)tag_directive.prefix)))
|
||||
// goto error
|
||||
// value.handle = yaml_strdup(tag_directive.handle)
|
||||
// value.prefix = yaml_strdup(tag_directive.prefix)
|
||||
// if (!value.handle || !value.prefix) goto error
|
||||
// if (!PUSH(&context, tag_directives_copy, value))
|
||||
// goto error
|
||||
// value.handle = NULL
|
||||
// value.prefix = NULL
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// DOCUMENT_INIT(*document, nodes.start, nodes.end, version_directive_copy,
|
||||
// tag_directives_copy.start, tag_directives_copy.top,
|
||||
// start_implicit, end_implicit, mark, mark)
|
||||
//
|
||||
// return 1
|
||||
//
|
||||
//error:
|
||||
// STACK_DEL(&context, nodes)
|
||||
// yaml_free(version_directive_copy)
|
||||
// while (!STACK_EMPTY(&context, tag_directives_copy)) {
|
||||
// value yaml_tag_directive_t = POP(&context, tag_directives_copy)
|
||||
// yaml_free(value.handle)
|
||||
// yaml_free(value.prefix)
|
||||
// }
|
||||
// STACK_DEL(&context, tag_directives_copy)
|
||||
// yaml_free(value.handle)
|
||||
// yaml_free(value.prefix)
|
||||
//
|
||||
// return 0
|
||||
//}
|
||||
//
|
||||
///*
|
||||
// * Destroy a document object.
|
||||
// */
|
||||
//
|
||||
//YAML_DECLARE(void)
|
||||
//yaml_document_delete(document *yaml_document_t)
|
||||
//{
|
||||
// struct {
|
||||
// error yaml_error_type_t
|
||||
// } context
|
||||
// tag_directive *yaml_tag_directive_t
|
||||
//
|
||||
// context.error = YAML_NO_ERROR // Eliminate a compiler warning.
|
||||
//
|
||||
// assert(document) // Non-NULL document object is expected.
|
||||
//
|
||||
// while (!STACK_EMPTY(&context, document.nodes)) {
|
||||
// node yaml_node_t = POP(&context, document.nodes)
|
||||
// yaml_free(node.tag)
|
||||
// switch (node.type) {
|
||||
// case YAML_SCALAR_NODE:
|
||||
// yaml_free(node.data.scalar.value)
|
||||
// break
|
||||
// case YAML_SEQUENCE_NODE:
|
||||
// STACK_DEL(&context, node.data.sequence.items)
|
||||
// break
|
||||
// case YAML_MAPPING_NODE:
|
||||
// STACK_DEL(&context, node.data.mapping.pairs)
|
||||
// break
|
||||
// default:
|
||||
// assert(0) // Should not happen.
|
||||
// }
|
||||
// }
|
||||
// STACK_DEL(&context, document.nodes)
|
||||
//
|
||||
// yaml_free(document.version_directive)
|
||||
// for (tag_directive = document.tag_directives.start
|
||||
// tag_directive != document.tag_directives.end
|
||||
// tag_directive++) {
|
||||
// yaml_free(tag_directive.handle)
|
||||
// yaml_free(tag_directive.prefix)
|
||||
// }
|
||||
// yaml_free(document.tag_directives.start)
|
||||
//
|
||||
// memset(document, 0, sizeof(yaml_document_t))
|
||||
//}
|
||||
//
|
||||
///**
|
||||
// * Get a document node.
|
||||
// */
|
||||
//
|
||||
//YAML_DECLARE(yaml_node_t *)
|
||||
//yaml_document_get_node(document *yaml_document_t, index int)
|
||||
//{
|
||||
// assert(document) // Non-NULL document object is expected.
|
||||
//
|
||||
// if (index > 0 && document.nodes.start + index <= document.nodes.top) {
|
||||
// return document.nodes.start + index - 1
|
||||
// }
|
||||
// return NULL
|
||||
//}
|
||||
//
|
||||
///**
|
||||
// * Get the root object.
|
||||
// */
|
||||
//
|
||||
//YAML_DECLARE(yaml_node_t *)
|
||||
//yaml_document_get_root_node(document *yaml_document_t)
|
||||
//{
|
||||
// assert(document) // Non-NULL document object is expected.
|
||||
//
|
||||
// if (document.nodes.top != document.nodes.start) {
|
||||
// return document.nodes.start
|
||||
// }
|
||||
// return NULL
|
||||
//}
|
||||
//
|
||||
///*
|
||||
// * Add a scalar node to a document.
|
||||
// */
|
||||
//
|
||||
//YAML_DECLARE(int)
|
||||
//yaml_document_add_scalar(document *yaml_document_t,
|
||||
// tag *yaml_char_t, value *yaml_char_t, length int,
|
||||
// style yaml_scalar_style_t)
|
||||
//{
|
||||
// struct {
|
||||
// error yaml_error_type_t
|
||||
// } context
|
||||
// mark yaml_mark_t = { 0, 0, 0 }
|
||||
// tag_copy *yaml_char_t = NULL
|
||||
// value_copy *yaml_char_t = NULL
|
||||
// node yaml_node_t
|
||||
//
|
||||
// assert(document) // Non-NULL document object is expected.
|
||||
// assert(value) // Non-NULL value is expected.
|
||||
//
|
||||
// if (!tag) {
|
||||
// tag = (yaml_char_t *)YAML_DEFAULT_SCALAR_TAG
|
||||
// }
|
||||
//
|
||||
// if (!yaml_check_utf8(tag, strlen((char *)tag))) goto error
|
||||
// tag_copy = yaml_strdup(tag)
|
||||
// if (!tag_copy) goto error
|
||||
//
|
||||
// if (length < 0) {
|
||||
// length = strlen((char *)value)
|
||||
// }
|
||||
//
|
||||
// if (!yaml_check_utf8(value, length)) goto error
|
||||
// value_copy = yaml_malloc(length+1)
|
||||
// if (!value_copy) goto error
|
||||
// memcpy(value_copy, value, length)
|
||||
// value_copy[length] = '\0'
|
||||
//
|
||||
// SCALAR_NODE_INIT(node, tag_copy, value_copy, length, style, mark, mark)
|
||||
// if (!PUSH(&context, document.nodes, node)) goto error
|
||||
//
|
||||
// return document.nodes.top - document.nodes.start
|
||||
//
|
||||
//error:
|
||||
// yaml_free(tag_copy)
|
||||
// yaml_free(value_copy)
|
||||
//
|
||||
// return 0
|
||||
//}
|
||||
//
|
||||
///*
|
||||
// * Add a sequence node to a document.
|
||||
// */
|
||||
//
|
||||
//YAML_DECLARE(int)
|
||||
//yaml_document_add_sequence(document *yaml_document_t,
|
||||
// tag *yaml_char_t, style yaml_sequence_style_t)
|
||||
//{
|
||||
// struct {
|
||||
// error yaml_error_type_t
|
||||
// } context
|
||||
// mark yaml_mark_t = { 0, 0, 0 }
|
||||
// tag_copy *yaml_char_t = NULL
|
||||
// struct {
|
||||
// start *yaml_node_item_t
|
||||
// end *yaml_node_item_t
|
||||
// top *yaml_node_item_t
|
||||
// } items = { NULL, NULL, NULL }
|
||||
// node yaml_node_t
|
||||
//
|
||||
// assert(document) // Non-NULL document object is expected.
|
||||
//
|
||||
// if (!tag) {
|
||||
// tag = (yaml_char_t *)YAML_DEFAULT_SEQUENCE_TAG
|
||||
// }
|
||||
//
|
||||
// if (!yaml_check_utf8(tag, strlen((char *)tag))) goto error
|
||||
// tag_copy = yaml_strdup(tag)
|
||||
// if (!tag_copy) goto error
|
||||
//
|
||||
// if (!STACK_INIT(&context, items, INITIAL_STACK_SIZE)) goto error
|
||||
//
|
||||
// SEQUENCE_NODE_INIT(node, tag_copy, items.start, items.end,
|
||||
// style, mark, mark)
|
||||
// if (!PUSH(&context, document.nodes, node)) goto error
|
||||
//
|
||||
// return document.nodes.top - document.nodes.start
|
||||
//
|
||||
//error:
|
||||
// STACK_DEL(&context, items)
|
||||
// yaml_free(tag_copy)
|
||||
//
|
||||
// return 0
|
||||
//}
|
||||
//
|
||||
///*
|
||||
// * Add a mapping node to a document.
|
||||
// */
|
||||
//
|
||||
//YAML_DECLARE(int)
|
||||
//yaml_document_add_mapping(document *yaml_document_t,
|
||||
// tag *yaml_char_t, style yaml_mapping_style_t)
|
||||
//{
|
||||
// struct {
|
||||
// error yaml_error_type_t
|
||||
// } context
|
||||
// mark yaml_mark_t = { 0, 0, 0 }
|
||||
// tag_copy *yaml_char_t = NULL
|
||||
// struct {
|
||||
// start *yaml_node_pair_t
|
||||
// end *yaml_node_pair_t
|
||||
// top *yaml_node_pair_t
|
||||
// } pairs = { NULL, NULL, NULL }
|
||||
// node yaml_node_t
|
||||
//
|
||||
// assert(document) // Non-NULL document object is expected.
|
||||
//
|
||||
// if (!tag) {
|
||||
// tag = (yaml_char_t *)YAML_DEFAULT_MAPPING_TAG
|
||||
// }
|
||||
//
|
||||
// if (!yaml_check_utf8(tag, strlen((char *)tag))) goto error
|
||||
// tag_copy = yaml_strdup(tag)
|
||||
// if (!tag_copy) goto error
|
||||
//
|
||||
// if (!STACK_INIT(&context, pairs, INITIAL_STACK_SIZE)) goto error
|
||||
//
|
||||
// MAPPING_NODE_INIT(node, tag_copy, pairs.start, pairs.end,
|
||||
// style, mark, mark)
|
||||
// if (!PUSH(&context, document.nodes, node)) goto error
|
||||
//
|
||||
// return document.nodes.top - document.nodes.start
|
||||
//
|
||||
//error:
|
||||
// STACK_DEL(&context, pairs)
|
||||
// yaml_free(tag_copy)
|
||||
//
|
||||
// return 0
|
||||
//}
|
||||
//
|
||||
///*
|
||||
// * Append an item to a sequence node.
|
||||
// */
|
||||
//
|
||||
//YAML_DECLARE(int)
|
||||
//yaml_document_append_sequence_item(document *yaml_document_t,
|
||||
// sequence int, item int)
|
||||
//{
|
||||
// struct {
|
||||
// error yaml_error_type_t
|
||||
// } context
|
||||
//
|
||||
// assert(document) // Non-NULL document is required.
|
||||
// assert(sequence > 0
|
||||
// && document.nodes.start + sequence <= document.nodes.top)
|
||||
// // Valid sequence id is required.
|
||||
// assert(document.nodes.start[sequence-1].type == YAML_SEQUENCE_NODE)
|
||||
// // A sequence node is required.
|
||||
// assert(item > 0 && document.nodes.start + item <= document.nodes.top)
|
||||
// // Valid item id is required.
|
||||
//
|
||||
// if (!PUSH(&context,
|
||||
// document.nodes.start[sequence-1].data.sequence.items, item))
|
||||
// return 0
|
||||
//
|
||||
// return 1
|
||||
//}
|
||||
//
|
||||
///*
|
||||
// * Append a pair of a key and a value to a mapping node.
|
||||
// */
|
||||
//
|
||||
//YAML_DECLARE(int)
|
||||
//yaml_document_append_mapping_pair(document *yaml_document_t,
|
||||
// mapping int, key int, value int)
|
||||
//{
|
||||
// struct {
|
||||
// error yaml_error_type_t
|
||||
// } context
|
||||
//
|
||||
// pair yaml_node_pair_t
|
||||
//
|
||||
// assert(document) // Non-NULL document is required.
|
||||
// assert(mapping > 0
|
||||
// && document.nodes.start + mapping <= document.nodes.top)
|
||||
// // Valid mapping id is required.
|
||||
// assert(document.nodes.start[mapping-1].type == YAML_MAPPING_NODE)
|
||||
// // A mapping node is required.
|
||||
// assert(key > 0 && document.nodes.start + key <= document.nodes.top)
|
||||
// // Valid key id is required.
|
||||
// assert(value > 0 && document.nodes.start + value <= document.nodes.top)
|
||||
// // Valid value id is required.
|
||||
//
|
||||
// pair.key = key
|
||||
// pair.value = value
|
||||
//
|
||||
// if (!PUSH(&context,
|
||||
// document.nodes.start[mapping-1].data.mapping.pairs, pair))
|
||||
// return 0
|
||||
//
|
||||
// return 1
|
||||
//}
|
||||
//
|
||||
//
|
||||
-1000
File diff suppressed because it is too large
Load Diff
-2020
File diff suppressed because it is too large
Load Diff
-577
@@ -1,577 +0,0 @@
|
||||
//
|
||||
// Copyright (c) 2011-2019 Canonical Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package yaml
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type encoder struct {
|
||||
emitter yaml_emitter_t
|
||||
event yaml_event_t
|
||||
out []byte
|
||||
flow bool
|
||||
indent int
|
||||
doneInit bool
|
||||
}
|
||||
|
||||
func newEncoder() *encoder {
|
||||
e := &encoder{}
|
||||
yaml_emitter_initialize(&e.emitter)
|
||||
yaml_emitter_set_output_string(&e.emitter, &e.out)
|
||||
yaml_emitter_set_unicode(&e.emitter, true)
|
||||
return e
|
||||
}
|
||||
|
||||
func newEncoderWithWriter(w io.Writer) *encoder {
|
||||
e := &encoder{}
|
||||
yaml_emitter_initialize(&e.emitter)
|
||||
yaml_emitter_set_output_writer(&e.emitter, w)
|
||||
yaml_emitter_set_unicode(&e.emitter, true)
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *encoder) init() {
|
||||
if e.doneInit {
|
||||
return
|
||||
}
|
||||
if e.indent == 0 {
|
||||
e.indent = 4
|
||||
}
|
||||
e.emitter.best_indent = e.indent
|
||||
yaml_stream_start_event_initialize(&e.event, yaml_UTF8_ENCODING)
|
||||
e.emit()
|
||||
e.doneInit = true
|
||||
}
|
||||
|
||||
func (e *encoder) finish() {
|
||||
e.emitter.open_ended = false
|
||||
yaml_stream_end_event_initialize(&e.event)
|
||||
e.emit()
|
||||
}
|
||||
|
||||
func (e *encoder) destroy() {
|
||||
yaml_emitter_delete(&e.emitter)
|
||||
}
|
||||
|
||||
func (e *encoder) emit() {
|
||||
// This will internally delete the e.event value.
|
||||
e.must(yaml_emitter_emit(&e.emitter, &e.event))
|
||||
}
|
||||
|
||||
func (e *encoder) must(ok bool) {
|
||||
if !ok {
|
||||
msg := e.emitter.problem
|
||||
if msg == "" {
|
||||
msg = "unknown problem generating YAML content"
|
||||
}
|
||||
failf("%s", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *encoder) marshalDoc(tag string, in reflect.Value) {
|
||||
e.init()
|
||||
var node *Node
|
||||
if in.IsValid() {
|
||||
node, _ = in.Interface().(*Node)
|
||||
}
|
||||
if node != nil && node.Kind == DocumentNode {
|
||||
e.nodev(in)
|
||||
} else {
|
||||
yaml_document_start_event_initialize(&e.event, nil, nil, true)
|
||||
e.emit()
|
||||
e.marshal(tag, in)
|
||||
yaml_document_end_event_initialize(&e.event, true)
|
||||
e.emit()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *encoder) marshal(tag string, in reflect.Value) {
|
||||
tag = shortTag(tag)
|
||||
if !in.IsValid() || in.Kind() == reflect.Ptr && in.IsNil() {
|
||||
e.nilv()
|
||||
return
|
||||
}
|
||||
iface := in.Interface()
|
||||
switch value := iface.(type) {
|
||||
case *Node:
|
||||
e.nodev(in)
|
||||
return
|
||||
case Node:
|
||||
if !in.CanAddr() {
|
||||
var n = reflect.New(in.Type()).Elem()
|
||||
n.Set(in)
|
||||
in = n
|
||||
}
|
||||
e.nodev(in.Addr())
|
||||
return
|
||||
case time.Time:
|
||||
e.timev(tag, in)
|
||||
return
|
||||
case *time.Time:
|
||||
e.timev(tag, in.Elem())
|
||||
return
|
||||
case time.Duration:
|
||||
e.stringv(tag, reflect.ValueOf(value.String()))
|
||||
return
|
||||
case Marshaler:
|
||||
v, err := value.MarshalYAML()
|
||||
if err != nil {
|
||||
fail(err)
|
||||
}
|
||||
if v == nil {
|
||||
e.nilv()
|
||||
return
|
||||
}
|
||||
e.marshal(tag, reflect.ValueOf(v))
|
||||
return
|
||||
case encoding.TextMarshaler:
|
||||
text, err := value.MarshalText()
|
||||
if err != nil {
|
||||
fail(err)
|
||||
}
|
||||
in = reflect.ValueOf(string(text))
|
||||
case nil:
|
||||
e.nilv()
|
||||
return
|
||||
}
|
||||
switch in.Kind() {
|
||||
case reflect.Interface:
|
||||
e.marshal(tag, in.Elem())
|
||||
case reflect.Map:
|
||||
e.mapv(tag, in)
|
||||
case reflect.Ptr:
|
||||
e.marshal(tag, in.Elem())
|
||||
case reflect.Struct:
|
||||
e.structv(tag, in)
|
||||
case reflect.Slice, reflect.Array:
|
||||
e.slicev(tag, in)
|
||||
case reflect.String:
|
||||
e.stringv(tag, in)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
e.intv(tag, in)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
e.uintv(tag, in)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
e.floatv(tag, in)
|
||||
case reflect.Bool:
|
||||
e.boolv(tag, in)
|
||||
default:
|
||||
panic("cannot marshal type: " + in.Type().String())
|
||||
}
|
||||
}
|
||||
|
||||
func (e *encoder) mapv(tag string, in reflect.Value) {
|
||||
e.mappingv(tag, func() {
|
||||
keys := keyList(in.MapKeys())
|
||||
sort.Sort(keys)
|
||||
for _, k := range keys {
|
||||
e.marshal("", k)
|
||||
e.marshal("", in.MapIndex(k))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (e *encoder) fieldByIndex(v reflect.Value, index []int) (field reflect.Value) {
|
||||
for _, num := range index {
|
||||
for {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
return reflect.Value{}
|
||||
}
|
||||
v = v.Elem()
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
v = v.Field(num)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (e *encoder) structv(tag string, in reflect.Value) {
|
||||
sinfo, err := getStructInfo(in.Type())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
e.mappingv(tag, func() {
|
||||
for _, info := range sinfo.FieldsList {
|
||||
var value reflect.Value
|
||||
if info.Inline == nil {
|
||||
value = in.Field(info.Num)
|
||||
} else {
|
||||
value = e.fieldByIndex(in, info.Inline)
|
||||
if !value.IsValid() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if info.OmitEmpty && isZero(value) {
|
||||
continue
|
||||
}
|
||||
e.marshal("", reflect.ValueOf(info.Key))
|
||||
e.flow = info.Flow
|
||||
e.marshal("", value)
|
||||
}
|
||||
if sinfo.InlineMap >= 0 {
|
||||
m := in.Field(sinfo.InlineMap)
|
||||
if m.Len() > 0 {
|
||||
e.flow = false
|
||||
keys := keyList(m.MapKeys())
|
||||
sort.Sort(keys)
|
||||
for _, k := range keys {
|
||||
if _, found := sinfo.FieldsMap[k.String()]; found {
|
||||
panic(fmt.Sprintf("cannot have key %q in inlined map: conflicts with struct field", k.String()))
|
||||
}
|
||||
e.marshal("", k)
|
||||
e.flow = false
|
||||
e.marshal("", m.MapIndex(k))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (e *encoder) mappingv(tag string, f func()) {
|
||||
implicit := tag == ""
|
||||
style := yaml_BLOCK_MAPPING_STYLE
|
||||
if e.flow {
|
||||
e.flow = false
|
||||
style = yaml_FLOW_MAPPING_STYLE
|
||||
}
|
||||
yaml_mapping_start_event_initialize(&e.event, nil, []byte(tag), implicit, style)
|
||||
e.emit()
|
||||
f()
|
||||
yaml_mapping_end_event_initialize(&e.event)
|
||||
e.emit()
|
||||
}
|
||||
|
||||
func (e *encoder) slicev(tag string, in reflect.Value) {
|
||||
implicit := tag == ""
|
||||
style := yaml_BLOCK_SEQUENCE_STYLE
|
||||
if e.flow {
|
||||
e.flow = false
|
||||
style = yaml_FLOW_SEQUENCE_STYLE
|
||||
}
|
||||
e.must(yaml_sequence_start_event_initialize(&e.event, nil, []byte(tag), implicit, style))
|
||||
e.emit()
|
||||
n := in.Len()
|
||||
for i := 0; i < n; i++ {
|
||||
e.marshal("", in.Index(i))
|
||||
}
|
||||
e.must(yaml_sequence_end_event_initialize(&e.event))
|
||||
e.emit()
|
||||
}
|
||||
|
||||
// isBase60 returns whether s is in base 60 notation as defined in YAML 1.1.
|
||||
//
|
||||
// The base 60 float notation in YAML 1.1 is a terrible idea and is unsupported
|
||||
// in YAML 1.2 and by this package, but these should be marshalled quoted for
|
||||
// the time being for compatibility with other parsers.
|
||||
func isBase60Float(s string) (result bool) {
|
||||
// Fast path.
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
c := s[0]
|
||||
if !(c == '+' || c == '-' || c >= '0' && c <= '9') || strings.IndexByte(s, ':') < 0 {
|
||||
return false
|
||||
}
|
||||
// Do the full match.
|
||||
return base60float.MatchString(s)
|
||||
}
|
||||
|
||||
// From http://yaml.org/type/float.html, except the regular expression there
|
||||
// is bogus. In practice parsers do not enforce the "\.[0-9_]*" suffix.
|
||||
var base60float = regexp.MustCompile(`^[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+(?:\.[0-9_]*)?$`)
|
||||
|
||||
// isOldBool returns whether s is bool notation as defined in YAML 1.1.
|
||||
//
|
||||
// We continue to force strings that YAML 1.1 would interpret as booleans to be
|
||||
// rendered as quotes strings so that the marshalled output valid for YAML 1.1
|
||||
// parsing.
|
||||
func isOldBool(s string) (result bool) {
|
||||
switch s {
|
||||
case "y", "Y", "yes", "Yes", "YES", "on", "On", "ON",
|
||||
"n", "N", "no", "No", "NO", "off", "Off", "OFF":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (e *encoder) stringv(tag string, in reflect.Value) {
|
||||
var style yaml_scalar_style_t
|
||||
s := in.String()
|
||||
canUsePlain := true
|
||||
switch {
|
||||
case !utf8.ValidString(s):
|
||||
if tag == binaryTag {
|
||||
failf("explicitly tagged !!binary data must be base64-encoded")
|
||||
}
|
||||
if tag != "" {
|
||||
failf("cannot marshal invalid UTF-8 data as %s", shortTag(tag))
|
||||
}
|
||||
// It can't be encoded directly as YAML so use a binary tag
|
||||
// and encode it as base64.
|
||||
tag = binaryTag
|
||||
s = encodeBase64(s)
|
||||
case tag == "":
|
||||
// Check to see if it would resolve to a specific
|
||||
// tag when encoded unquoted. If it doesn't,
|
||||
// there's no need to quote it.
|
||||
rtag, _ := resolve("", s)
|
||||
canUsePlain = rtag == strTag && !(isBase60Float(s) || isOldBool(s))
|
||||
}
|
||||
// Note: it's possible for user code to emit invalid YAML
|
||||
// if they explicitly specify a tag and a string containing
|
||||
// text that's incompatible with that tag.
|
||||
switch {
|
||||
case strings.Contains(s, "\n"):
|
||||
if e.flow {
|
||||
style = yaml_DOUBLE_QUOTED_SCALAR_STYLE
|
||||
} else {
|
||||
style = yaml_LITERAL_SCALAR_STYLE
|
||||
}
|
||||
case canUsePlain:
|
||||
style = yaml_PLAIN_SCALAR_STYLE
|
||||
default:
|
||||
style = yaml_DOUBLE_QUOTED_SCALAR_STYLE
|
||||
}
|
||||
e.emitScalar(s, "", tag, style, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (e *encoder) boolv(tag string, in reflect.Value) {
|
||||
var s string
|
||||
if in.Bool() {
|
||||
s = "true"
|
||||
} else {
|
||||
s = "false"
|
||||
}
|
||||
e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (e *encoder) intv(tag string, in reflect.Value) {
|
||||
s := strconv.FormatInt(in.Int(), 10)
|
||||
e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (e *encoder) uintv(tag string, in reflect.Value) {
|
||||
s := strconv.FormatUint(in.Uint(), 10)
|
||||
e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (e *encoder) timev(tag string, in reflect.Value) {
|
||||
t := in.Interface().(time.Time)
|
||||
s := t.Format(time.RFC3339Nano)
|
||||
e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (e *encoder) floatv(tag string, in reflect.Value) {
|
||||
// Issue #352: When formatting, use the precision of the underlying value
|
||||
precision := 64
|
||||
if in.Kind() == reflect.Float32 {
|
||||
precision = 32
|
||||
}
|
||||
|
||||
s := strconv.FormatFloat(in.Float(), 'g', -1, precision)
|
||||
switch s {
|
||||
case "+Inf":
|
||||
s = ".inf"
|
||||
case "-Inf":
|
||||
s = "-.inf"
|
||||
case "NaN":
|
||||
s = ".nan"
|
||||
}
|
||||
e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (e *encoder) nilv() {
|
||||
e.emitScalar("null", "", "", yaml_PLAIN_SCALAR_STYLE, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (e *encoder) emitScalar(value, anchor, tag string, style yaml_scalar_style_t, head, line, foot, tail []byte) {
|
||||
// TODO Kill this function. Replace all initialize calls by their underlining Go literals.
|
||||
implicit := tag == ""
|
||||
if !implicit {
|
||||
tag = longTag(tag)
|
||||
}
|
||||
e.must(yaml_scalar_event_initialize(&e.event, []byte(anchor), []byte(tag), []byte(value), implicit, implicit, style))
|
||||
e.event.head_comment = head
|
||||
e.event.line_comment = line
|
||||
e.event.foot_comment = foot
|
||||
e.event.tail_comment = tail
|
||||
e.emit()
|
||||
}
|
||||
|
||||
func (e *encoder) nodev(in reflect.Value) {
|
||||
e.node(in.Interface().(*Node), "")
|
||||
}
|
||||
|
||||
func (e *encoder) node(node *Node, tail string) {
|
||||
// Zero nodes behave as nil.
|
||||
if node.Kind == 0 && node.IsZero() {
|
||||
e.nilv()
|
||||
return
|
||||
}
|
||||
|
||||
// If the tag was not explicitly requested, and dropping it won't change the
|
||||
// implicit tag of the value, don't include it in the presentation.
|
||||
var tag = node.Tag
|
||||
var stag = shortTag(tag)
|
||||
var forceQuoting bool
|
||||
if tag != "" && node.Style&TaggedStyle == 0 {
|
||||
if node.Kind == ScalarNode {
|
||||
if stag == strTag && node.Style&(SingleQuotedStyle|DoubleQuotedStyle|LiteralStyle|FoldedStyle) != 0 {
|
||||
tag = ""
|
||||
} else {
|
||||
rtag, _ := resolve("", node.Value)
|
||||
if rtag == stag {
|
||||
tag = ""
|
||||
} else if stag == strTag {
|
||||
tag = ""
|
||||
forceQuoting = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var rtag string
|
||||
switch node.Kind {
|
||||
case MappingNode:
|
||||
rtag = mapTag
|
||||
case SequenceNode:
|
||||
rtag = seqTag
|
||||
}
|
||||
if rtag == stag {
|
||||
tag = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch node.Kind {
|
||||
case DocumentNode:
|
||||
yaml_document_start_event_initialize(&e.event, nil, nil, true)
|
||||
e.event.head_comment = []byte(node.HeadComment)
|
||||
e.emit()
|
||||
for _, node := range node.Content {
|
||||
e.node(node, "")
|
||||
}
|
||||
yaml_document_end_event_initialize(&e.event, true)
|
||||
e.event.foot_comment = []byte(node.FootComment)
|
||||
e.emit()
|
||||
|
||||
case SequenceNode:
|
||||
style := yaml_BLOCK_SEQUENCE_STYLE
|
||||
if node.Style&FlowStyle != 0 {
|
||||
style = yaml_FLOW_SEQUENCE_STYLE
|
||||
}
|
||||
e.must(yaml_sequence_start_event_initialize(&e.event, []byte(node.Anchor), []byte(longTag(tag)), tag == "", style))
|
||||
e.event.head_comment = []byte(node.HeadComment)
|
||||
e.emit()
|
||||
for _, node := range node.Content {
|
||||
e.node(node, "")
|
||||
}
|
||||
e.must(yaml_sequence_end_event_initialize(&e.event))
|
||||
e.event.line_comment = []byte(node.LineComment)
|
||||
e.event.foot_comment = []byte(node.FootComment)
|
||||
e.emit()
|
||||
|
||||
case MappingNode:
|
||||
style := yaml_BLOCK_MAPPING_STYLE
|
||||
if node.Style&FlowStyle != 0 {
|
||||
style = yaml_FLOW_MAPPING_STYLE
|
||||
}
|
||||
yaml_mapping_start_event_initialize(&e.event, []byte(node.Anchor), []byte(longTag(tag)), tag == "", style)
|
||||
e.event.tail_comment = []byte(tail)
|
||||
e.event.head_comment = []byte(node.HeadComment)
|
||||
e.emit()
|
||||
|
||||
// The tail logic below moves the foot comment of prior keys to the following key,
|
||||
// since the value for each key may be a nested structure and the foot needs to be
|
||||
// processed only the entirety of the value is streamed. The last tail is processed
|
||||
// with the mapping end event.
|
||||
var tail string
|
||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||
k := node.Content[i]
|
||||
foot := k.FootComment
|
||||
if foot != "" {
|
||||
kopy := *k
|
||||
kopy.FootComment = ""
|
||||
k = &kopy
|
||||
}
|
||||
e.node(k, tail)
|
||||
tail = foot
|
||||
|
||||
v := node.Content[i+1]
|
||||
e.node(v, "")
|
||||
}
|
||||
|
||||
yaml_mapping_end_event_initialize(&e.event)
|
||||
e.event.tail_comment = []byte(tail)
|
||||
e.event.line_comment = []byte(node.LineComment)
|
||||
e.event.foot_comment = []byte(node.FootComment)
|
||||
e.emit()
|
||||
|
||||
case AliasNode:
|
||||
yaml_alias_event_initialize(&e.event, []byte(node.Value))
|
||||
e.event.head_comment = []byte(node.HeadComment)
|
||||
e.event.line_comment = []byte(node.LineComment)
|
||||
e.event.foot_comment = []byte(node.FootComment)
|
||||
e.emit()
|
||||
|
||||
case ScalarNode:
|
||||
value := node.Value
|
||||
if !utf8.ValidString(value) {
|
||||
if stag == binaryTag {
|
||||
failf("explicitly tagged !!binary data must be base64-encoded")
|
||||
}
|
||||
if stag != "" {
|
||||
failf("cannot marshal invalid UTF-8 data as %s", stag)
|
||||
}
|
||||
// It can't be encoded directly as YAML so use a binary tag
|
||||
// and encode it as base64.
|
||||
tag = binaryTag
|
||||
value = encodeBase64(value)
|
||||
}
|
||||
|
||||
style := yaml_PLAIN_SCALAR_STYLE
|
||||
switch {
|
||||
case node.Style&DoubleQuotedStyle != 0:
|
||||
style = yaml_DOUBLE_QUOTED_SCALAR_STYLE
|
||||
case node.Style&SingleQuotedStyle != 0:
|
||||
style = yaml_SINGLE_QUOTED_SCALAR_STYLE
|
||||
case node.Style&LiteralStyle != 0:
|
||||
style = yaml_LITERAL_SCALAR_STYLE
|
||||
case node.Style&FoldedStyle != 0:
|
||||
style = yaml_FOLDED_SCALAR_STYLE
|
||||
case strings.Contains(value, "\n"):
|
||||
style = yaml_LITERAL_SCALAR_STYLE
|
||||
case forceQuoting:
|
||||
style = yaml_DOUBLE_QUOTED_SCALAR_STYLE
|
||||
}
|
||||
|
||||
e.emitScalar(value, node.Anchor, tag, style, []byte(node.HeadComment), []byte(node.LineComment), []byte(node.FootComment), []byte(tail))
|
||||
default:
|
||||
failf("cannot encode node with unknown kind %d", node.Kind)
|
||||
}
|
||||
}
|
||||
-1258
File diff suppressed because it is too large
Load Diff
-434
@@ -1,434 +0,0 @@
|
||||
//
|
||||
// Copyright (c) 2011-2019 Canonical Ltd
|
||||
// Copyright (c) 2006-2010 Kirill Simonov
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
||||
// of the Software, and to permit persons to whom the Software is furnished to do
|
||||
// so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
|
||||
package yaml
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// Set the reader error and return 0.
|
||||
func yaml_parser_set_reader_error(parser *yaml_parser_t, problem string, offset int, value int) bool {
|
||||
parser.error = yaml_READER_ERROR
|
||||
parser.problem = problem
|
||||
parser.problem_offset = offset
|
||||
parser.problem_value = value
|
||||
return false
|
||||
}
|
||||
|
||||
// Byte order marks.
|
||||
const (
|
||||
bom_UTF8 = "\xef\xbb\xbf"
|
||||
bom_UTF16LE = "\xff\xfe"
|
||||
bom_UTF16BE = "\xfe\xff"
|
||||
)
|
||||
|
||||
// Determine the input stream encoding by checking the BOM symbol. If no BOM is
|
||||
// found, the UTF-8 encoding is assumed. Return 1 on success, 0 on failure.
|
||||
func yaml_parser_determine_encoding(parser *yaml_parser_t) bool {
|
||||
// Ensure that we had enough bytes in the raw buffer.
|
||||
for !parser.eof && len(parser.raw_buffer)-parser.raw_buffer_pos < 3 {
|
||||
if !yaml_parser_update_raw_buffer(parser) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the encoding.
|
||||
buf := parser.raw_buffer
|
||||
pos := parser.raw_buffer_pos
|
||||
avail := len(buf) - pos
|
||||
if avail >= 2 && buf[pos] == bom_UTF16LE[0] && buf[pos+1] == bom_UTF16LE[1] {
|
||||
parser.encoding = yaml_UTF16LE_ENCODING
|
||||
parser.raw_buffer_pos += 2
|
||||
parser.offset += 2
|
||||
} else if avail >= 2 && buf[pos] == bom_UTF16BE[0] && buf[pos+1] == bom_UTF16BE[1] {
|
||||
parser.encoding = yaml_UTF16BE_ENCODING
|
||||
parser.raw_buffer_pos += 2
|
||||
parser.offset += 2
|
||||
} else if avail >= 3 && buf[pos] == bom_UTF8[0] && buf[pos+1] == bom_UTF8[1] && buf[pos+2] == bom_UTF8[2] {
|
||||
parser.encoding = yaml_UTF8_ENCODING
|
||||
parser.raw_buffer_pos += 3
|
||||
parser.offset += 3
|
||||
} else {
|
||||
parser.encoding = yaml_UTF8_ENCODING
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Update the raw buffer.
|
||||
func yaml_parser_update_raw_buffer(parser *yaml_parser_t) bool {
|
||||
size_read := 0
|
||||
|
||||
// Return if the raw buffer is full.
|
||||
if parser.raw_buffer_pos == 0 && len(parser.raw_buffer) == cap(parser.raw_buffer) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Return on EOF.
|
||||
if parser.eof {
|
||||
return true
|
||||
}
|
||||
|
||||
// Move the remaining bytes in the raw buffer to the beginning.
|
||||
if parser.raw_buffer_pos > 0 && parser.raw_buffer_pos < len(parser.raw_buffer) {
|
||||
copy(parser.raw_buffer, parser.raw_buffer[parser.raw_buffer_pos:])
|
||||
}
|
||||
parser.raw_buffer = parser.raw_buffer[:len(parser.raw_buffer)-parser.raw_buffer_pos]
|
||||
parser.raw_buffer_pos = 0
|
||||
|
||||
// Call the read handler to fill the buffer.
|
||||
size_read, err := parser.read_handler(parser, parser.raw_buffer[len(parser.raw_buffer):cap(parser.raw_buffer)])
|
||||
parser.raw_buffer = parser.raw_buffer[:len(parser.raw_buffer)+size_read]
|
||||
if err == io.EOF {
|
||||
parser.eof = true
|
||||
} else if err != nil {
|
||||
return yaml_parser_set_reader_error(parser, "input error: "+err.Error(), parser.offset, -1)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Ensure that the buffer contains at least `length` characters.
|
||||
// Return true on success, false on failure.
|
||||
//
|
||||
// The length is supposed to be significantly less that the buffer size.
|
||||
func yaml_parser_update_buffer(parser *yaml_parser_t, length int) bool {
|
||||
if parser.read_handler == nil {
|
||||
panic("read handler must be set")
|
||||
}
|
||||
|
||||
// [Go] This function was changed to guarantee the requested length size at EOF.
|
||||
// The fact we need to do this is pretty awful, but the description above implies
|
||||
// for that to be the case, and there are tests
|
||||
|
||||
// If the EOF flag is set and the raw buffer is empty, do nothing.
|
||||
if parser.eof && parser.raw_buffer_pos == len(parser.raw_buffer) {
|
||||
// [Go] ACTUALLY! Read the documentation of this function above.
|
||||
// This is just broken. To return true, we need to have the
|
||||
// given length in the buffer. Not doing that means every single
|
||||
// check that calls this function to make sure the buffer has a
|
||||
// given length is Go) panicking; or C) accessing invalid memory.
|
||||
//return true
|
||||
}
|
||||
|
||||
// Return if the buffer contains enough characters.
|
||||
if parser.unread >= length {
|
||||
return true
|
||||
}
|
||||
|
||||
// Determine the input encoding if it is not known yet.
|
||||
if parser.encoding == yaml_ANY_ENCODING {
|
||||
if !yaml_parser_determine_encoding(parser) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Move the unread characters to the beginning of the buffer.
|
||||
buffer_len := len(parser.buffer)
|
||||
if parser.buffer_pos > 0 && parser.buffer_pos < buffer_len {
|
||||
copy(parser.buffer, parser.buffer[parser.buffer_pos:])
|
||||
buffer_len -= parser.buffer_pos
|
||||
parser.buffer_pos = 0
|
||||
} else if parser.buffer_pos == buffer_len {
|
||||
buffer_len = 0
|
||||
parser.buffer_pos = 0
|
||||
}
|
||||
|
||||
// Open the whole buffer for writing, and cut it before returning.
|
||||
parser.buffer = parser.buffer[:cap(parser.buffer)]
|
||||
|
||||
// Fill the buffer until it has enough characters.
|
||||
first := true
|
||||
for parser.unread < length {
|
||||
|
||||
// Fill the raw buffer if necessary.
|
||||
if !first || parser.raw_buffer_pos == len(parser.raw_buffer) {
|
||||
if !yaml_parser_update_raw_buffer(parser) {
|
||||
parser.buffer = parser.buffer[:buffer_len]
|
||||
return false
|
||||
}
|
||||
}
|
||||
first = false
|
||||
|
||||
// Decode the raw buffer.
|
||||
inner:
|
||||
for parser.raw_buffer_pos != len(parser.raw_buffer) {
|
||||
var value rune
|
||||
var width int
|
||||
|
||||
raw_unread := len(parser.raw_buffer) - parser.raw_buffer_pos
|
||||
|
||||
// Decode the next character.
|
||||
switch parser.encoding {
|
||||
case yaml_UTF8_ENCODING:
|
||||
// Decode a UTF-8 character. Check RFC 3629
|
||||
// (http://www.ietf.org/rfc/rfc3629.txt) for more details.
|
||||
//
|
||||
// The following table (taken from the RFC) is used for
|
||||
// decoding.
|
||||
//
|
||||
// Char. number range | UTF-8 octet sequence
|
||||
// (hexadecimal) | (binary)
|
||||
// --------------------+------------------------------------
|
||||
// 0000 0000-0000 007F | 0xxxxxxx
|
||||
// 0000 0080-0000 07FF | 110xxxxx 10xxxxxx
|
||||
// 0000 0800-0000 FFFF | 1110xxxx 10xxxxxx 10xxxxxx
|
||||
// 0001 0000-0010 FFFF | 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
|
||||
//
|
||||
// Additionally, the characters in the range 0xD800-0xDFFF
|
||||
// are prohibited as they are reserved for use with UTF-16
|
||||
// surrogate pairs.
|
||||
|
||||
// Determine the length of the UTF-8 sequence.
|
||||
octet := parser.raw_buffer[parser.raw_buffer_pos]
|
||||
switch {
|
||||
case octet&0x80 == 0x00:
|
||||
width = 1
|
||||
case octet&0xE0 == 0xC0:
|
||||
width = 2
|
||||
case octet&0xF0 == 0xE0:
|
||||
width = 3
|
||||
case octet&0xF8 == 0xF0:
|
||||
width = 4
|
||||
default:
|
||||
// The leading octet is invalid.
|
||||
return yaml_parser_set_reader_error(parser,
|
||||
"invalid leading UTF-8 octet",
|
||||
parser.offset, int(octet))
|
||||
}
|
||||
|
||||
// Check if the raw buffer contains an incomplete character.
|
||||
if width > raw_unread {
|
||||
if parser.eof {
|
||||
return yaml_parser_set_reader_error(parser,
|
||||
"incomplete UTF-8 octet sequence",
|
||||
parser.offset, -1)
|
||||
}
|
||||
break inner
|
||||
}
|
||||
|
||||
// Decode the leading octet.
|
||||
switch {
|
||||
case octet&0x80 == 0x00:
|
||||
value = rune(octet & 0x7F)
|
||||
case octet&0xE0 == 0xC0:
|
||||
value = rune(octet & 0x1F)
|
||||
case octet&0xF0 == 0xE0:
|
||||
value = rune(octet & 0x0F)
|
||||
case octet&0xF8 == 0xF0:
|
||||
value = rune(octet & 0x07)
|
||||
default:
|
||||
value = 0
|
||||
}
|
||||
|
||||
// Check and decode the trailing octets.
|
||||
for k := 1; k < width; k++ {
|
||||
octet = parser.raw_buffer[parser.raw_buffer_pos+k]
|
||||
|
||||
// Check if the octet is valid.
|
||||
if (octet & 0xC0) != 0x80 {
|
||||
return yaml_parser_set_reader_error(parser,
|
||||
"invalid trailing UTF-8 octet",
|
||||
parser.offset+k, int(octet))
|
||||
}
|
||||
|
||||
// Decode the octet.
|
||||
value = (value << 6) + rune(octet&0x3F)
|
||||
}
|
||||
|
||||
// Check the length of the sequence against the value.
|
||||
switch {
|
||||
case width == 1:
|
||||
case width == 2 && value >= 0x80:
|
||||
case width == 3 && value >= 0x800:
|
||||
case width == 4 && value >= 0x10000:
|
||||
default:
|
||||
return yaml_parser_set_reader_error(parser,
|
||||
"invalid length of a UTF-8 sequence",
|
||||
parser.offset, -1)
|
||||
}
|
||||
|
||||
// Check the range of the value.
|
||||
if value >= 0xD800 && value <= 0xDFFF || value > 0x10FFFF {
|
||||
return yaml_parser_set_reader_error(parser,
|
||||
"invalid Unicode character",
|
||||
parser.offset, int(value))
|
||||
}
|
||||
|
||||
case yaml_UTF16LE_ENCODING, yaml_UTF16BE_ENCODING:
|
||||
var low, high int
|
||||
if parser.encoding == yaml_UTF16LE_ENCODING {
|
||||
low, high = 0, 1
|
||||
} else {
|
||||
low, high = 1, 0
|
||||
}
|
||||
|
||||
// The UTF-16 encoding is not as simple as one might
|
||||
// naively think. Check RFC 2781
|
||||
// (http://www.ietf.org/rfc/rfc2781.txt).
|
||||
//
|
||||
// Normally, two subsequent bytes describe a Unicode
|
||||
// character. However a special technique (called a
|
||||
// surrogate pair) is used for specifying character
|
||||
// values larger than 0xFFFF.
|
||||
//
|
||||
// A surrogate pair consists of two pseudo-characters:
|
||||
// high surrogate area (0xD800-0xDBFF)
|
||||
// low surrogate area (0xDC00-0xDFFF)
|
||||
//
|
||||
// The following formulas are used for decoding
|
||||
// and encoding characters using surrogate pairs:
|
||||
//
|
||||
// U = U' + 0x10000 (0x01 00 00 <= U <= 0x10 FF FF)
|
||||
// U' = yyyyyyyyyyxxxxxxxxxx (0 <= U' <= 0x0F FF FF)
|
||||
// W1 = 110110yyyyyyyyyy
|
||||
// W2 = 110111xxxxxxxxxx
|
||||
//
|
||||
// where U is the character value, W1 is the high surrogate
|
||||
// area, W2 is the low surrogate area.
|
||||
|
||||
// Check for incomplete UTF-16 character.
|
||||
if raw_unread < 2 {
|
||||
if parser.eof {
|
||||
return yaml_parser_set_reader_error(parser,
|
||||
"incomplete UTF-16 character",
|
||||
parser.offset, -1)
|
||||
}
|
||||
break inner
|
||||
}
|
||||
|
||||
// Get the character.
|
||||
value = rune(parser.raw_buffer[parser.raw_buffer_pos+low]) +
|
||||
(rune(parser.raw_buffer[parser.raw_buffer_pos+high]) << 8)
|
||||
|
||||
// Check for unexpected low surrogate area.
|
||||
if value&0xFC00 == 0xDC00 {
|
||||
return yaml_parser_set_reader_error(parser,
|
||||
"unexpected low surrogate area",
|
||||
parser.offset, int(value))
|
||||
}
|
||||
|
||||
// Check for a high surrogate area.
|
||||
if value&0xFC00 == 0xD800 {
|
||||
width = 4
|
||||
|
||||
// Check for incomplete surrogate pair.
|
||||
if raw_unread < 4 {
|
||||
if parser.eof {
|
||||
return yaml_parser_set_reader_error(parser,
|
||||
"incomplete UTF-16 surrogate pair",
|
||||
parser.offset, -1)
|
||||
}
|
||||
break inner
|
||||
}
|
||||
|
||||
// Get the next character.
|
||||
value2 := rune(parser.raw_buffer[parser.raw_buffer_pos+low+2]) +
|
||||
(rune(parser.raw_buffer[parser.raw_buffer_pos+high+2]) << 8)
|
||||
|
||||
// Check for a low surrogate area.
|
||||
if value2&0xFC00 != 0xDC00 {
|
||||
return yaml_parser_set_reader_error(parser,
|
||||
"expected low surrogate area",
|
||||
parser.offset+2, int(value2))
|
||||
}
|
||||
|
||||
// Generate the value of the surrogate pair.
|
||||
value = 0x10000 + ((value & 0x3FF) << 10) + (value2 & 0x3FF)
|
||||
} else {
|
||||
width = 2
|
||||
}
|
||||
|
||||
default:
|
||||
panic("impossible")
|
||||
}
|
||||
|
||||
// Check if the character is in the allowed range:
|
||||
// #x9 | #xA | #xD | [#x20-#x7E] (8 bit)
|
||||
// | #x85 | [#xA0-#xD7FF] | [#xE000-#xFFFD] (16 bit)
|
||||
// | [#x10000-#x10FFFF] (32 bit)
|
||||
switch {
|
||||
case value == 0x09:
|
||||
case value == 0x0A:
|
||||
case value == 0x0D:
|
||||
case value >= 0x20 && value <= 0x7E:
|
||||
case value == 0x85:
|
||||
case value >= 0xA0 && value <= 0xD7FF:
|
||||
case value >= 0xE000 && value <= 0xFFFD:
|
||||
case value >= 0x10000 && value <= 0x10FFFF:
|
||||
default:
|
||||
return yaml_parser_set_reader_error(parser,
|
||||
"control characters are not allowed",
|
||||
parser.offset, int(value))
|
||||
}
|
||||
|
||||
// Move the raw pointers.
|
||||
parser.raw_buffer_pos += width
|
||||
parser.offset += width
|
||||
|
||||
// Finally put the character into the buffer.
|
||||
if value <= 0x7F {
|
||||
// 0000 0000-0000 007F . 0xxxxxxx
|
||||
parser.buffer[buffer_len+0] = byte(value)
|
||||
buffer_len += 1
|
||||
} else if value <= 0x7FF {
|
||||
// 0000 0080-0000 07FF . 110xxxxx 10xxxxxx
|
||||
parser.buffer[buffer_len+0] = byte(0xC0 + (value >> 6))
|
||||
parser.buffer[buffer_len+1] = byte(0x80 + (value & 0x3F))
|
||||
buffer_len += 2
|
||||
} else if value <= 0xFFFF {
|
||||
// 0000 0800-0000 FFFF . 1110xxxx 10xxxxxx 10xxxxxx
|
||||
parser.buffer[buffer_len+0] = byte(0xE0 + (value >> 12))
|
||||
parser.buffer[buffer_len+1] = byte(0x80 + ((value >> 6) & 0x3F))
|
||||
parser.buffer[buffer_len+2] = byte(0x80 + (value & 0x3F))
|
||||
buffer_len += 3
|
||||
} else {
|
||||
// 0001 0000-0010 FFFF . 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
|
||||
parser.buffer[buffer_len+0] = byte(0xF0 + (value >> 18))
|
||||
parser.buffer[buffer_len+1] = byte(0x80 + ((value >> 12) & 0x3F))
|
||||
parser.buffer[buffer_len+2] = byte(0x80 + ((value >> 6) & 0x3F))
|
||||
parser.buffer[buffer_len+3] = byte(0x80 + (value & 0x3F))
|
||||
buffer_len += 4
|
||||
}
|
||||
|
||||
parser.unread++
|
||||
}
|
||||
|
||||
// On EOF, put NUL into the buffer and return.
|
||||
if parser.eof {
|
||||
parser.buffer[buffer_len] = 0
|
||||
buffer_len++
|
||||
parser.unread++
|
||||
break
|
||||
}
|
||||
}
|
||||
// [Go] Read the documentation of this function above. To return true,
|
||||
// we need to have the given length in the buffer. Not doing that means
|
||||
// every single check that calls this function to make sure the buffer
|
||||
// has a given length is Go) panicking; or C) accessing invalid memory.
|
||||
// This happens here due to the EOF above breaking early.
|
||||
for buffer_len < length {
|
||||
parser.buffer[buffer_len] = 0
|
||||
buffer_len++
|
||||
}
|
||||
parser.buffer = parser.buffer[:buffer_len]
|
||||
return true
|
||||
}
|
||||
-326
@@ -1,326 +0,0 @@
|
||||
//
|
||||
// Copyright (c) 2011-2019 Canonical Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package yaml
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type resolveMapItem struct {
|
||||
value interface{}
|
||||
tag string
|
||||
}
|
||||
|
||||
var resolveTable = make([]byte, 256)
|
||||
var resolveMap = make(map[string]resolveMapItem)
|
||||
|
||||
func init() {
|
||||
t := resolveTable
|
||||
t[int('+')] = 'S' // Sign
|
||||
t[int('-')] = 'S'
|
||||
for _, c := range "0123456789" {
|
||||
t[int(c)] = 'D' // Digit
|
||||
}
|
||||
for _, c := range "yYnNtTfFoO~" {
|
||||
t[int(c)] = 'M' // In map
|
||||
}
|
||||
t[int('.')] = '.' // Float (potentially in map)
|
||||
|
||||
var resolveMapList = []struct {
|
||||
v interface{}
|
||||
tag string
|
||||
l []string
|
||||
}{
|
||||
{true, boolTag, []string{"true", "True", "TRUE"}},
|
||||
{false, boolTag, []string{"false", "False", "FALSE"}},
|
||||
{nil, nullTag, []string{"", "~", "null", "Null", "NULL"}},
|
||||
{math.NaN(), floatTag, []string{".nan", ".NaN", ".NAN"}},
|
||||
{math.Inf(+1), floatTag, []string{".inf", ".Inf", ".INF"}},
|
||||
{math.Inf(+1), floatTag, []string{"+.inf", "+.Inf", "+.INF"}},
|
||||
{math.Inf(-1), floatTag, []string{"-.inf", "-.Inf", "-.INF"}},
|
||||
{"<<", mergeTag, []string{"<<"}},
|
||||
}
|
||||
|
||||
m := resolveMap
|
||||
for _, item := range resolveMapList {
|
||||
for _, s := range item.l {
|
||||
m[s] = resolveMapItem{item.v, item.tag}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
nullTag = "!!null"
|
||||
boolTag = "!!bool"
|
||||
strTag = "!!str"
|
||||
intTag = "!!int"
|
||||
floatTag = "!!float"
|
||||
timestampTag = "!!timestamp"
|
||||
seqTag = "!!seq"
|
||||
mapTag = "!!map"
|
||||
binaryTag = "!!binary"
|
||||
mergeTag = "!!merge"
|
||||
)
|
||||
|
||||
var longTags = make(map[string]string)
|
||||
var shortTags = make(map[string]string)
|
||||
|
||||
func init() {
|
||||
for _, stag := range []string{nullTag, boolTag, strTag, intTag, floatTag, timestampTag, seqTag, mapTag, binaryTag, mergeTag} {
|
||||
ltag := longTag(stag)
|
||||
longTags[stag] = ltag
|
||||
shortTags[ltag] = stag
|
||||
}
|
||||
}
|
||||
|
||||
const longTagPrefix = "tag:yaml.org,2002:"
|
||||
|
||||
func shortTag(tag string) string {
|
||||
if strings.HasPrefix(tag, longTagPrefix) {
|
||||
if stag, ok := shortTags[tag]; ok {
|
||||
return stag
|
||||
}
|
||||
return "!!" + tag[len(longTagPrefix):]
|
||||
}
|
||||
return tag
|
||||
}
|
||||
|
||||
func longTag(tag string) string {
|
||||
if strings.HasPrefix(tag, "!!") {
|
||||
if ltag, ok := longTags[tag]; ok {
|
||||
return ltag
|
||||
}
|
||||
return longTagPrefix + tag[2:]
|
||||
}
|
||||
return tag
|
||||
}
|
||||
|
||||
func resolvableTag(tag string) bool {
|
||||
switch tag {
|
||||
case "", strTag, boolTag, intTag, floatTag, nullTag, timestampTag:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var yamlStyleFloat = regexp.MustCompile(`^[-+]?(\.[0-9]+|[0-9]+(\.[0-9]*)?)([eE][-+]?[0-9]+)?$`)
|
||||
|
||||
func resolve(tag string, in string) (rtag string, out interface{}) {
|
||||
tag = shortTag(tag)
|
||||
if !resolvableTag(tag) {
|
||||
return tag, in
|
||||
}
|
||||
|
||||
defer func() {
|
||||
switch tag {
|
||||
case "", rtag, strTag, binaryTag:
|
||||
return
|
||||
case floatTag:
|
||||
if rtag == intTag {
|
||||
switch v := out.(type) {
|
||||
case int64:
|
||||
rtag = floatTag
|
||||
out = float64(v)
|
||||
return
|
||||
case int:
|
||||
rtag = floatTag
|
||||
out = float64(v)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
failf("cannot decode %s `%s` as a %s", shortTag(rtag), in, shortTag(tag))
|
||||
}()
|
||||
|
||||
// Any data is accepted as a !!str or !!binary.
|
||||
// Otherwise, the prefix is enough of a hint about what it might be.
|
||||
hint := byte('N')
|
||||
if in != "" {
|
||||
hint = resolveTable[in[0]]
|
||||
}
|
||||
if hint != 0 && tag != strTag && tag != binaryTag {
|
||||
// Handle things we can lookup in a map.
|
||||
if item, ok := resolveMap[in]; ok {
|
||||
return item.tag, item.value
|
||||
}
|
||||
|
||||
// Base 60 floats are a bad idea, were dropped in YAML 1.2, and
|
||||
// are purposefully unsupported here. They're still quoted on
|
||||
// the way out for compatibility with other parser, though.
|
||||
|
||||
switch hint {
|
||||
case 'M':
|
||||
// We've already checked the map above.
|
||||
|
||||
case '.':
|
||||
// Not in the map, so maybe a normal float.
|
||||
floatv, err := strconv.ParseFloat(in, 64)
|
||||
if err == nil {
|
||||
return floatTag, floatv
|
||||
}
|
||||
|
||||
case 'D', 'S':
|
||||
// Int, float, or timestamp.
|
||||
// Only try values as a timestamp if the value is unquoted or there's an explicit
|
||||
// !!timestamp tag.
|
||||
if tag == "" || tag == timestampTag {
|
||||
t, ok := parseTimestamp(in)
|
||||
if ok {
|
||||
return timestampTag, t
|
||||
}
|
||||
}
|
||||
|
||||
plain := strings.Replace(in, "_", "", -1)
|
||||
intv, err := strconv.ParseInt(plain, 0, 64)
|
||||
if err == nil {
|
||||
if intv == int64(int(intv)) {
|
||||
return intTag, int(intv)
|
||||
} else {
|
||||
return intTag, intv
|
||||
}
|
||||
}
|
||||
uintv, err := strconv.ParseUint(plain, 0, 64)
|
||||
if err == nil {
|
||||
return intTag, uintv
|
||||
}
|
||||
if yamlStyleFloat.MatchString(plain) {
|
||||
floatv, err := strconv.ParseFloat(plain, 64)
|
||||
if err == nil {
|
||||
return floatTag, floatv
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(plain, "0b") {
|
||||
intv, err := strconv.ParseInt(plain[2:], 2, 64)
|
||||
if err == nil {
|
||||
if intv == int64(int(intv)) {
|
||||
return intTag, int(intv)
|
||||
} else {
|
||||
return intTag, intv
|
||||
}
|
||||
}
|
||||
uintv, err := strconv.ParseUint(plain[2:], 2, 64)
|
||||
if err == nil {
|
||||
return intTag, uintv
|
||||
}
|
||||
} else if strings.HasPrefix(plain, "-0b") {
|
||||
intv, err := strconv.ParseInt("-"+plain[3:], 2, 64)
|
||||
if err == nil {
|
||||
if true || intv == int64(int(intv)) {
|
||||
return intTag, int(intv)
|
||||
} else {
|
||||
return intTag, intv
|
||||
}
|
||||
}
|
||||
}
|
||||
// Octals as introduced in version 1.2 of the spec.
|
||||
// Octals from the 1.1 spec, spelled as 0777, are still
|
||||
// decoded by default in v3 as well for compatibility.
|
||||
// May be dropped in v4 depending on how usage evolves.
|
||||
if strings.HasPrefix(plain, "0o") {
|
||||
intv, err := strconv.ParseInt(plain[2:], 8, 64)
|
||||
if err == nil {
|
||||
if intv == int64(int(intv)) {
|
||||
return intTag, int(intv)
|
||||
} else {
|
||||
return intTag, intv
|
||||
}
|
||||
}
|
||||
uintv, err := strconv.ParseUint(plain[2:], 8, 64)
|
||||
if err == nil {
|
||||
return intTag, uintv
|
||||
}
|
||||
} else if strings.HasPrefix(plain, "-0o") {
|
||||
intv, err := strconv.ParseInt("-"+plain[3:], 8, 64)
|
||||
if err == nil {
|
||||
if true || intv == int64(int(intv)) {
|
||||
return intTag, int(intv)
|
||||
} else {
|
||||
return intTag, intv
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
panic("internal error: missing handler for resolver table: " + string(rune(hint)) + " (with " + in + ")")
|
||||
}
|
||||
}
|
||||
return strTag, in
|
||||
}
|
||||
|
||||
// encodeBase64 encodes s as base64 that is broken up into multiple lines
|
||||
// as appropriate for the resulting length.
|
||||
func encodeBase64(s string) string {
|
||||
const lineLen = 70
|
||||
encLen := base64.StdEncoding.EncodedLen(len(s))
|
||||
lines := encLen/lineLen + 1
|
||||
buf := make([]byte, encLen*2+lines)
|
||||
in := buf[0:encLen]
|
||||
out := buf[encLen:]
|
||||
base64.StdEncoding.Encode(in, []byte(s))
|
||||
k := 0
|
||||
for i := 0; i < len(in); i += lineLen {
|
||||
j := i + lineLen
|
||||
if j > len(in) {
|
||||
j = len(in)
|
||||
}
|
||||
k += copy(out[k:], in[i:j])
|
||||
if lines > 1 {
|
||||
out[k] = '\n'
|
||||
k++
|
||||
}
|
||||
}
|
||||
return string(out[:k])
|
||||
}
|
||||
|
||||
// This is a subset of the formats allowed by the regular expression
|
||||
// defined at http://yaml.org/type/timestamp.html.
|
||||
var allowedTimestampFormats = []string{
|
||||
"2006-1-2T15:4:5.999999999Z07:00", // RCF3339Nano with short date fields.
|
||||
"2006-1-2t15:4:5.999999999Z07:00", // RFC3339Nano with short date fields and lower-case "t".
|
||||
"2006-1-2 15:4:5.999999999", // space separated with no time zone
|
||||
"2006-1-2", // date only
|
||||
// Notable exception: time.Parse cannot handle: "2001-12-14 21:59:43.10 -5"
|
||||
// from the set of examples.
|
||||
}
|
||||
|
||||
// parseTimestamp parses s as a timestamp string and
|
||||
// returns the timestamp and reports whether it succeeded.
|
||||
// Timestamp formats are defined at http://yaml.org/type/timestamp.html
|
||||
func parseTimestamp(s string) (time.Time, bool) {
|
||||
// TODO write code to check all the formats supported by
|
||||
// http://yaml.org/type/timestamp.html instead of using time.Parse.
|
||||
|
||||
// Quick check: all date formats start with YYYY-.
|
||||
i := 0
|
||||
for ; i < len(s); i++ {
|
||||
if c := s[i]; c < '0' || c > '9' {
|
||||
break
|
||||
}
|
||||
}
|
||||
if i != 4 || i == len(s) || s[i] != '-' {
|
||||
return time.Time{}, false
|
||||
}
|
||||
for _, format := range allowedTimestampFormats {
|
||||
if t, err := time.Parse(format, s); err == nil {
|
||||
return t, true
|
||||
}
|
||||
}
|
||||
return time.Time{}, false
|
||||
}
|
||||
-3038
File diff suppressed because it is too large
Load Diff
-134
@@ -1,134 +0,0 @@
|
||||
//
|
||||
// Copyright (c) 2011-2019 Canonical Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package yaml
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
type keyList []reflect.Value
|
||||
|
||||
func (l keyList) Len() int { return len(l) }
|
||||
func (l keyList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
|
||||
func (l keyList) Less(i, j int) bool {
|
||||
a := l[i]
|
||||
b := l[j]
|
||||
ak := a.Kind()
|
||||
bk := b.Kind()
|
||||
for (ak == reflect.Interface || ak == reflect.Ptr) && !a.IsNil() {
|
||||
a = a.Elem()
|
||||
ak = a.Kind()
|
||||
}
|
||||
for (bk == reflect.Interface || bk == reflect.Ptr) && !b.IsNil() {
|
||||
b = b.Elem()
|
||||
bk = b.Kind()
|
||||
}
|
||||
af, aok := keyFloat(a)
|
||||
bf, bok := keyFloat(b)
|
||||
if aok && bok {
|
||||
if af != bf {
|
||||
return af < bf
|
||||
}
|
||||
if ak != bk {
|
||||
return ak < bk
|
||||
}
|
||||
return numLess(a, b)
|
||||
}
|
||||
if ak != reflect.String || bk != reflect.String {
|
||||
return ak < bk
|
||||
}
|
||||
ar, br := []rune(a.String()), []rune(b.String())
|
||||
digits := false
|
||||
for i := 0; i < len(ar) && i < len(br); i++ {
|
||||
if ar[i] == br[i] {
|
||||
digits = unicode.IsDigit(ar[i])
|
||||
continue
|
||||
}
|
||||
al := unicode.IsLetter(ar[i])
|
||||
bl := unicode.IsLetter(br[i])
|
||||
if al && bl {
|
||||
return ar[i] < br[i]
|
||||
}
|
||||
if al || bl {
|
||||
if digits {
|
||||
return al
|
||||
} else {
|
||||
return bl
|
||||
}
|
||||
}
|
||||
var ai, bi int
|
||||
var an, bn int64
|
||||
if ar[i] == '0' || br[i] == '0' {
|
||||
for j := i - 1; j >= 0 && unicode.IsDigit(ar[j]); j-- {
|
||||
if ar[j] != '0' {
|
||||
an = 1
|
||||
bn = 1
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
for ai = i; ai < len(ar) && unicode.IsDigit(ar[ai]); ai++ {
|
||||
an = an*10 + int64(ar[ai]-'0')
|
||||
}
|
||||
for bi = i; bi < len(br) && unicode.IsDigit(br[bi]); bi++ {
|
||||
bn = bn*10 + int64(br[bi]-'0')
|
||||
}
|
||||
if an != bn {
|
||||
return an < bn
|
||||
}
|
||||
if ai != bi {
|
||||
return ai < bi
|
||||
}
|
||||
return ar[i] < br[i]
|
||||
}
|
||||
return len(ar) < len(br)
|
||||
}
|
||||
|
||||
// keyFloat returns a float value for v if it is a number/bool
|
||||
// and whether it is a number/bool or not.
|
||||
func keyFloat(v reflect.Value) (f float64, ok bool) {
|
||||
switch v.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return float64(v.Int()), true
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return v.Float(), true
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return float64(v.Uint()), true
|
||||
case reflect.Bool:
|
||||
if v.Bool() {
|
||||
return 1, true
|
||||
}
|
||||
return 0, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// numLess returns whether a < b.
|
||||
// a and b must necessarily have the same kind.
|
||||
func numLess(a, b reflect.Value) bool {
|
||||
switch a.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return a.Int() < b.Int()
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return a.Float() < b.Float()
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return a.Uint() < b.Uint()
|
||||
case reflect.Bool:
|
||||
return !a.Bool() && b.Bool()
|
||||
}
|
||||
panic("not a number")
|
||||
}
|
||||
-48
@@ -1,48 +0,0 @@
|
||||
//
|
||||
// Copyright (c) 2011-2019 Canonical Ltd
|
||||
// Copyright (c) 2006-2010 Kirill Simonov
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
||||
// of the Software, and to permit persons to whom the Software is furnished to do
|
||||
// so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
|
||||
package yaml
|
||||
|
||||
// Set the writer error and return false.
|
||||
func yaml_emitter_set_writer_error(emitter *yaml_emitter_t, problem string) bool {
|
||||
emitter.error = yaml_WRITER_ERROR
|
||||
emitter.problem = problem
|
||||
return false
|
||||
}
|
||||
|
||||
// Flush the output buffer.
|
||||
func yaml_emitter_flush(emitter *yaml_emitter_t) bool {
|
||||
if emitter.write_handler == nil {
|
||||
panic("write handler not set")
|
||||
}
|
||||
|
||||
// Check if the buffer is empty.
|
||||
if emitter.buffer_pos == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if err := emitter.write_handler(emitter, emitter.buffer[:emitter.buffer_pos]); err != nil {
|
||||
return yaml_emitter_set_writer_error(emitter, "write error: "+err.Error())
|
||||
}
|
||||
emitter.buffer_pos = 0
|
||||
return true
|
||||
}
|
||||
-698
@@ -1,698 +0,0 @@
|
||||
//
|
||||
// Copyright (c) 2011-2019 Canonical Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package yaml implements YAML support for the Go language.
|
||||
//
|
||||
// Source code and other details for the project are available at GitHub:
|
||||
//
|
||||
// https://github.com/go-yaml/yaml
|
||||
//
|
||||
package yaml
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// The Unmarshaler interface may be implemented by types to customize their
|
||||
// behavior when being unmarshaled from a YAML document.
|
||||
type Unmarshaler interface {
|
||||
UnmarshalYAML(value *Node) error
|
||||
}
|
||||
|
||||
type obsoleteUnmarshaler interface {
|
||||
UnmarshalYAML(unmarshal func(interface{}) error) error
|
||||
}
|
||||
|
||||
// The Marshaler interface may be implemented by types to customize their
|
||||
// behavior when being marshaled into a YAML document. The returned value
|
||||
// is marshaled in place of the original value implementing Marshaler.
|
||||
//
|
||||
// If an error is returned by MarshalYAML, the marshaling procedure stops
|
||||
// and returns with the provided error.
|
||||
type Marshaler interface {
|
||||
MarshalYAML() (interface{}, error)
|
||||
}
|
||||
|
||||
// Unmarshal decodes the first document found within the in byte slice
|
||||
// and assigns decoded values into the out value.
|
||||
//
|
||||
// Maps and pointers (to a struct, string, int, etc) are accepted as out
|
||||
// values. If an internal pointer within a struct is not initialized,
|
||||
// the yaml package will initialize it if necessary for unmarshalling
|
||||
// the provided data. The out parameter must not be nil.
|
||||
//
|
||||
// The type of the decoded values should be compatible with the respective
|
||||
// values in out. If one or more values cannot be decoded due to a type
|
||||
// mismatches, decoding continues partially until the end of the YAML
|
||||
// content, and a *yaml.TypeError is returned with details for all
|
||||
// missed values.
|
||||
//
|
||||
// Struct fields are only unmarshalled if they are exported (have an
|
||||
// upper case first letter), and are unmarshalled using the field name
|
||||
// lowercased as the default key. Custom keys may be defined via the
|
||||
// "yaml" name in the field tag: the content preceding the first comma
|
||||
// is used as the key, and the following comma-separated options are
|
||||
// used to tweak the marshalling process (see Marshal).
|
||||
// Conflicting names result in a runtime error.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// type T struct {
|
||||
// F int `yaml:"a,omitempty"`
|
||||
// B int
|
||||
// }
|
||||
// var t T
|
||||
// yaml.Unmarshal([]byte("a: 1\nb: 2"), &t)
|
||||
//
|
||||
// See the documentation of Marshal for the format of tags and a list of
|
||||
// supported tag options.
|
||||
//
|
||||
func Unmarshal(in []byte, out interface{}) (err error) {
|
||||
return unmarshal(in, out, false)
|
||||
}
|
||||
|
||||
// A Decoder reads and decodes YAML values from an input stream.
|
||||
type Decoder struct {
|
||||
parser *parser
|
||||
knownFields bool
|
||||
}
|
||||
|
||||
// NewDecoder returns a new decoder that reads from r.
|
||||
//
|
||||
// The decoder introduces its own buffering and may read
|
||||
// data from r beyond the YAML values requested.
|
||||
func NewDecoder(r io.Reader) *Decoder {
|
||||
return &Decoder{
|
||||
parser: newParserFromReader(r),
|
||||
}
|
||||
}
|
||||
|
||||
// KnownFields ensures that the keys in decoded mappings to
|
||||
// exist as fields in the struct being decoded into.
|
||||
func (dec *Decoder) KnownFields(enable bool) {
|
||||
dec.knownFields = enable
|
||||
}
|
||||
|
||||
// Decode reads the next YAML-encoded value from its input
|
||||
// and stores it in the value pointed to by v.
|
||||
//
|
||||
// See the documentation for Unmarshal for details about the
|
||||
// conversion of YAML into a Go value.
|
||||
func (dec *Decoder) Decode(v interface{}) (err error) {
|
||||
d := newDecoder()
|
||||
d.knownFields = dec.knownFields
|
||||
defer handleErr(&err)
|
||||
node := dec.parser.parse()
|
||||
if node == nil {
|
||||
return io.EOF
|
||||
}
|
||||
out := reflect.ValueOf(v)
|
||||
if out.Kind() == reflect.Ptr && !out.IsNil() {
|
||||
out = out.Elem()
|
||||
}
|
||||
d.unmarshal(node, out)
|
||||
if len(d.terrors) > 0 {
|
||||
return &TypeError{d.terrors}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode decodes the node and stores its data into the value pointed to by v.
|
||||
//
|
||||
// See the documentation for Unmarshal for details about the
|
||||
// conversion of YAML into a Go value.
|
||||
func (n *Node) Decode(v interface{}) (err error) {
|
||||
d := newDecoder()
|
||||
defer handleErr(&err)
|
||||
out := reflect.ValueOf(v)
|
||||
if out.Kind() == reflect.Ptr && !out.IsNil() {
|
||||
out = out.Elem()
|
||||
}
|
||||
d.unmarshal(n, out)
|
||||
if len(d.terrors) > 0 {
|
||||
return &TypeError{d.terrors}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unmarshal(in []byte, out interface{}, strict bool) (err error) {
|
||||
defer handleErr(&err)
|
||||
d := newDecoder()
|
||||
p := newParser(in)
|
||||
defer p.destroy()
|
||||
node := p.parse()
|
||||
if node != nil {
|
||||
v := reflect.ValueOf(out)
|
||||
if v.Kind() == reflect.Ptr && !v.IsNil() {
|
||||
v = v.Elem()
|
||||
}
|
||||
d.unmarshal(node, v)
|
||||
}
|
||||
if len(d.terrors) > 0 {
|
||||
return &TypeError{d.terrors}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Marshal serializes the value provided into a YAML document. The structure
|
||||
// of the generated document will reflect the structure of the value itself.
|
||||
// Maps and pointers (to struct, string, int, etc) are accepted as the in value.
|
||||
//
|
||||
// Struct fields are only marshalled if they are exported (have an upper case
|
||||
// first letter), and are marshalled using the field name lowercased as the
|
||||
// default key. Custom keys may be defined via the "yaml" name in the field
|
||||
// tag: the content preceding the first comma is used as the key, and the
|
||||
// following comma-separated options are used to tweak the marshalling process.
|
||||
// Conflicting names result in a runtime error.
|
||||
//
|
||||
// The field tag format accepted is:
|
||||
//
|
||||
// `(...) yaml:"[<key>][,<flag1>[,<flag2>]]" (...)`
|
||||
//
|
||||
// The following flags are currently supported:
|
||||
//
|
||||
// omitempty Only include the field if it's not set to the zero
|
||||
// value for the type or to empty slices or maps.
|
||||
// Zero valued structs will be omitted if all their public
|
||||
// fields are zero, unless they implement an IsZero
|
||||
// method (see the IsZeroer interface type), in which
|
||||
// case the field will be excluded if IsZero returns true.
|
||||
//
|
||||
// flow Marshal using a flow style (useful for structs,
|
||||
// sequences and maps).
|
||||
//
|
||||
// inline Inline the field, which must be a struct or a map,
|
||||
// causing all of its fields or keys to be processed as if
|
||||
// they were part of the outer struct. For maps, keys must
|
||||
// not conflict with the yaml keys of other struct fields.
|
||||
//
|
||||
// In addition, if the key is "-", the field is ignored.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// type T struct {
|
||||
// F int `yaml:"a,omitempty"`
|
||||
// B int
|
||||
// }
|
||||
// yaml.Marshal(&T{B: 2}) // Returns "b: 2\n"
|
||||
// yaml.Marshal(&T{F: 1}} // Returns "a: 1\nb: 0\n"
|
||||
//
|
||||
func Marshal(in interface{}) (out []byte, err error) {
|
||||
defer handleErr(&err)
|
||||
e := newEncoder()
|
||||
defer e.destroy()
|
||||
e.marshalDoc("", reflect.ValueOf(in))
|
||||
e.finish()
|
||||
out = e.out
|
||||
return
|
||||
}
|
||||
|
||||
// An Encoder writes YAML values to an output stream.
|
||||
type Encoder struct {
|
||||
encoder *encoder
|
||||
}
|
||||
|
||||
// NewEncoder returns a new encoder that writes to w.
|
||||
// The Encoder should be closed after use to flush all data
|
||||
// to w.
|
||||
func NewEncoder(w io.Writer) *Encoder {
|
||||
return &Encoder{
|
||||
encoder: newEncoderWithWriter(w),
|
||||
}
|
||||
}
|
||||
|
||||
// Encode writes the YAML encoding of v to the stream.
|
||||
// If multiple items are encoded to the stream, the
|
||||
// second and subsequent document will be preceded
|
||||
// with a "---" document separator, but the first will not.
|
||||
//
|
||||
// See the documentation for Marshal for details about the conversion of Go
|
||||
// values to YAML.
|
||||
func (e *Encoder) Encode(v interface{}) (err error) {
|
||||
defer handleErr(&err)
|
||||
e.encoder.marshalDoc("", reflect.ValueOf(v))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes value v and stores its representation in n.
|
||||
//
|
||||
// See the documentation for Marshal for details about the
|
||||
// conversion of Go values into YAML.
|
||||
func (n *Node) Encode(v interface{}) (err error) {
|
||||
defer handleErr(&err)
|
||||
e := newEncoder()
|
||||
defer e.destroy()
|
||||
e.marshalDoc("", reflect.ValueOf(v))
|
||||
e.finish()
|
||||
p := newParser(e.out)
|
||||
p.textless = true
|
||||
defer p.destroy()
|
||||
doc := p.parse()
|
||||
*n = *doc.Content[0]
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetIndent changes the used indentation used when encoding.
|
||||
func (e *Encoder) SetIndent(spaces int) {
|
||||
if spaces < 0 {
|
||||
panic("yaml: cannot indent to a negative number of spaces")
|
||||
}
|
||||
e.encoder.indent = spaces
|
||||
}
|
||||
|
||||
// Close closes the encoder by writing any remaining data.
|
||||
// It does not write a stream terminating string "...".
|
||||
func (e *Encoder) Close() (err error) {
|
||||
defer handleErr(&err)
|
||||
e.encoder.finish()
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleErr(err *error) {
|
||||
if v := recover(); v != nil {
|
||||
if e, ok := v.(yamlError); ok {
|
||||
*err = e.err
|
||||
} else {
|
||||
panic(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type yamlError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func fail(err error) {
|
||||
panic(yamlError{err})
|
||||
}
|
||||
|
||||
func failf(format string, args ...interface{}) {
|
||||
panic(yamlError{fmt.Errorf("yaml: "+format, args...)})
|
||||
}
|
||||
|
||||
// A TypeError is returned by Unmarshal when one or more fields in
|
||||
// the YAML document cannot be properly decoded into the requested
|
||||
// types. When this error is returned, the value is still
|
||||
// unmarshaled partially.
|
||||
type TypeError struct {
|
||||
Errors []string
|
||||
}
|
||||
|
||||
func (e *TypeError) Error() string {
|
||||
return fmt.Sprintf("yaml: unmarshal errors:\n %s", strings.Join(e.Errors, "\n "))
|
||||
}
|
||||
|
||||
type Kind uint32
|
||||
|
||||
const (
|
||||
DocumentNode Kind = 1 << iota
|
||||
SequenceNode
|
||||
MappingNode
|
||||
ScalarNode
|
||||
AliasNode
|
||||
)
|
||||
|
||||
type Style uint32
|
||||
|
||||
const (
|
||||
TaggedStyle Style = 1 << iota
|
||||
DoubleQuotedStyle
|
||||
SingleQuotedStyle
|
||||
LiteralStyle
|
||||
FoldedStyle
|
||||
FlowStyle
|
||||
)
|
||||
|
||||
// Node represents an element in the YAML document hierarchy. While documents
|
||||
// are typically encoded and decoded into higher level types, such as structs
|
||||
// and maps, Node is an intermediate representation that allows detailed
|
||||
// control over the content being decoded or encoded.
|
||||
//
|
||||
// It's worth noting that although Node offers access into details such as
|
||||
// line numbers, colums, and comments, the content when re-encoded will not
|
||||
// have its original textual representation preserved. An effort is made to
|
||||
// render the data plesantly, and to preserve comments near the data they
|
||||
// describe, though.
|
||||
//
|
||||
// Values that make use of the Node type interact with the yaml package in the
|
||||
// same way any other type would do, by encoding and decoding yaml data
|
||||
// directly or indirectly into them.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// var person struct {
|
||||
// Name string
|
||||
// Address yaml.Node
|
||||
// }
|
||||
// err := yaml.Unmarshal(data, &person)
|
||||
//
|
||||
// Or by itself:
|
||||
//
|
||||
// var person Node
|
||||
// err := yaml.Unmarshal(data, &person)
|
||||
//
|
||||
type Node struct {
|
||||
// Kind defines whether the node is a document, a mapping, a sequence,
|
||||
// a scalar value, or an alias to another node. The specific data type of
|
||||
// scalar nodes may be obtained via the ShortTag and LongTag methods.
|
||||
Kind Kind
|
||||
|
||||
// Style allows customizing the apperance of the node in the tree.
|
||||
Style Style
|
||||
|
||||
// Tag holds the YAML tag defining the data type for the value.
|
||||
// When decoding, this field will always be set to the resolved tag,
|
||||
// even when it wasn't explicitly provided in the YAML content.
|
||||
// When encoding, if this field is unset the value type will be
|
||||
// implied from the node properties, and if it is set, it will only
|
||||
// be serialized into the representation if TaggedStyle is used or
|
||||
// the implicit tag diverges from the provided one.
|
||||
Tag string
|
||||
|
||||
// Value holds the unescaped and unquoted represenation of the value.
|
||||
Value string
|
||||
|
||||
// Anchor holds the anchor name for this node, which allows aliases to point to it.
|
||||
Anchor string
|
||||
|
||||
// Alias holds the node that this alias points to. Only valid when Kind is AliasNode.
|
||||
Alias *Node
|
||||
|
||||
// Content holds contained nodes for documents, mappings, and sequences.
|
||||
Content []*Node
|
||||
|
||||
// HeadComment holds any comments in the lines preceding the node and
|
||||
// not separated by an empty line.
|
||||
HeadComment string
|
||||
|
||||
// LineComment holds any comments at the end of the line where the node is in.
|
||||
LineComment string
|
||||
|
||||
// FootComment holds any comments following the node and before empty lines.
|
||||
FootComment string
|
||||
|
||||
// Line and Column hold the node position in the decoded YAML text.
|
||||
// These fields are not respected when encoding the node.
|
||||
Line int
|
||||
Column int
|
||||
}
|
||||
|
||||
// IsZero returns whether the node has all of its fields unset.
|
||||
func (n *Node) IsZero() bool {
|
||||
return n.Kind == 0 && n.Style == 0 && n.Tag == "" && n.Value == "" && n.Anchor == "" && n.Alias == nil && n.Content == nil &&
|
||||
n.HeadComment == "" && n.LineComment == "" && n.FootComment == "" && n.Line == 0 && n.Column == 0
|
||||
}
|
||||
|
||||
|
||||
// LongTag returns the long form of the tag that indicates the data type for
|
||||
// the node. If the Tag field isn't explicitly defined, one will be computed
|
||||
// based on the node properties.
|
||||
func (n *Node) LongTag() string {
|
||||
return longTag(n.ShortTag())
|
||||
}
|
||||
|
||||
// ShortTag returns the short form of the YAML tag that indicates data type for
|
||||
// the node. If the Tag field isn't explicitly defined, one will be computed
|
||||
// based on the node properties.
|
||||
func (n *Node) ShortTag() string {
|
||||
if n.indicatedString() {
|
||||
return strTag
|
||||
}
|
||||
if n.Tag == "" || n.Tag == "!" {
|
||||
switch n.Kind {
|
||||
case MappingNode:
|
||||
return mapTag
|
||||
case SequenceNode:
|
||||
return seqTag
|
||||
case AliasNode:
|
||||
if n.Alias != nil {
|
||||
return n.Alias.ShortTag()
|
||||
}
|
||||
case ScalarNode:
|
||||
tag, _ := resolve("", n.Value)
|
||||
return tag
|
||||
case 0:
|
||||
// Special case to make the zero value convenient.
|
||||
if n.IsZero() {
|
||||
return nullTag
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
return shortTag(n.Tag)
|
||||
}
|
||||
|
||||
func (n *Node) indicatedString() bool {
|
||||
return n.Kind == ScalarNode &&
|
||||
(shortTag(n.Tag) == strTag ||
|
||||
(n.Tag == "" || n.Tag == "!") && n.Style&(SingleQuotedStyle|DoubleQuotedStyle|LiteralStyle|FoldedStyle) != 0)
|
||||
}
|
||||
|
||||
// SetString is a convenience function that sets the node to a string value
|
||||
// and defines its style in a pleasant way depending on its content.
|
||||
func (n *Node) SetString(s string) {
|
||||
n.Kind = ScalarNode
|
||||
if utf8.ValidString(s) {
|
||||
n.Value = s
|
||||
n.Tag = strTag
|
||||
} else {
|
||||
n.Value = encodeBase64(s)
|
||||
n.Tag = binaryTag
|
||||
}
|
||||
if strings.Contains(n.Value, "\n") {
|
||||
n.Style = LiteralStyle
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Maintain a mapping of keys to structure field indexes
|
||||
|
||||
// The code in this section was copied from mgo/bson.
|
||||
|
||||
// structInfo holds details for the serialization of fields of
|
||||
// a given struct.
|
||||
type structInfo struct {
|
||||
FieldsMap map[string]fieldInfo
|
||||
FieldsList []fieldInfo
|
||||
|
||||
// InlineMap is the number of the field in the struct that
|
||||
// contains an ,inline map, or -1 if there's none.
|
||||
InlineMap int
|
||||
|
||||
// InlineUnmarshalers holds indexes to inlined fields that
|
||||
// contain unmarshaler values.
|
||||
InlineUnmarshalers [][]int
|
||||
}
|
||||
|
||||
type fieldInfo struct {
|
||||
Key string
|
||||
Num int
|
||||
OmitEmpty bool
|
||||
Flow bool
|
||||
// Id holds the unique field identifier, so we can cheaply
|
||||
// check for field duplicates without maintaining an extra map.
|
||||
Id int
|
||||
|
||||
// Inline holds the field index if the field is part of an inlined struct.
|
||||
Inline []int
|
||||
}
|
||||
|
||||
var structMap = make(map[reflect.Type]*structInfo)
|
||||
var fieldMapMutex sync.RWMutex
|
||||
var unmarshalerType reflect.Type
|
||||
|
||||
func init() {
|
||||
var v Unmarshaler
|
||||
unmarshalerType = reflect.ValueOf(&v).Elem().Type()
|
||||
}
|
||||
|
||||
func getStructInfo(st reflect.Type) (*structInfo, error) {
|
||||
fieldMapMutex.RLock()
|
||||
sinfo, found := structMap[st]
|
||||
fieldMapMutex.RUnlock()
|
||||
if found {
|
||||
return sinfo, nil
|
||||
}
|
||||
|
||||
n := st.NumField()
|
||||
fieldsMap := make(map[string]fieldInfo)
|
||||
fieldsList := make([]fieldInfo, 0, n)
|
||||
inlineMap := -1
|
||||
inlineUnmarshalers := [][]int(nil)
|
||||
for i := 0; i != n; i++ {
|
||||
field := st.Field(i)
|
||||
if field.PkgPath != "" && !field.Anonymous {
|
||||
continue // Private field
|
||||
}
|
||||
|
||||
info := fieldInfo{Num: i}
|
||||
|
||||
tag := field.Tag.Get("yaml")
|
||||
if tag == "" && strings.Index(string(field.Tag), ":") < 0 {
|
||||
tag = string(field.Tag)
|
||||
}
|
||||
if tag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
inline := false
|
||||
fields := strings.Split(tag, ",")
|
||||
if len(fields) > 1 {
|
||||
for _, flag := range fields[1:] {
|
||||
switch flag {
|
||||
case "omitempty":
|
||||
info.OmitEmpty = true
|
||||
case "flow":
|
||||
info.Flow = true
|
||||
case "inline":
|
||||
inline = true
|
||||
default:
|
||||
return nil, errors.New(fmt.Sprintf("unsupported flag %q in tag %q of type %s", flag, tag, st))
|
||||
}
|
||||
}
|
||||
tag = fields[0]
|
||||
}
|
||||
|
||||
if inline {
|
||||
switch field.Type.Kind() {
|
||||
case reflect.Map:
|
||||
if inlineMap >= 0 {
|
||||
return nil, errors.New("multiple ,inline maps in struct " + st.String())
|
||||
}
|
||||
if field.Type.Key() != reflect.TypeOf("") {
|
||||
return nil, errors.New("option ,inline needs a map with string keys in struct " + st.String())
|
||||
}
|
||||
inlineMap = info.Num
|
||||
case reflect.Struct, reflect.Ptr:
|
||||
ftype := field.Type
|
||||
for ftype.Kind() == reflect.Ptr {
|
||||
ftype = ftype.Elem()
|
||||
}
|
||||
if ftype.Kind() != reflect.Struct {
|
||||
return nil, errors.New("option ,inline may only be used on a struct or map field")
|
||||
}
|
||||
if reflect.PtrTo(ftype).Implements(unmarshalerType) {
|
||||
inlineUnmarshalers = append(inlineUnmarshalers, []int{i})
|
||||
} else {
|
||||
sinfo, err := getStructInfo(ftype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, index := range sinfo.InlineUnmarshalers {
|
||||
inlineUnmarshalers = append(inlineUnmarshalers, append([]int{i}, index...))
|
||||
}
|
||||
for _, finfo := range sinfo.FieldsList {
|
||||
if _, found := fieldsMap[finfo.Key]; found {
|
||||
msg := "duplicated key '" + finfo.Key + "' in struct " + st.String()
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
if finfo.Inline == nil {
|
||||
finfo.Inline = []int{i, finfo.Num}
|
||||
} else {
|
||||
finfo.Inline = append([]int{i}, finfo.Inline...)
|
||||
}
|
||||
finfo.Id = len(fieldsList)
|
||||
fieldsMap[finfo.Key] = finfo
|
||||
fieldsList = append(fieldsList, finfo)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("option ,inline may only be used on a struct or map field")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if tag != "" {
|
||||
info.Key = tag
|
||||
} else {
|
||||
info.Key = strings.ToLower(field.Name)
|
||||
}
|
||||
|
||||
if _, found = fieldsMap[info.Key]; found {
|
||||
msg := "duplicated key '" + info.Key + "' in struct " + st.String()
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
info.Id = len(fieldsList)
|
||||
fieldsList = append(fieldsList, info)
|
||||
fieldsMap[info.Key] = info
|
||||
}
|
||||
|
||||
sinfo = &structInfo{
|
||||
FieldsMap: fieldsMap,
|
||||
FieldsList: fieldsList,
|
||||
InlineMap: inlineMap,
|
||||
InlineUnmarshalers: inlineUnmarshalers,
|
||||
}
|
||||
|
||||
fieldMapMutex.Lock()
|
||||
structMap[st] = sinfo
|
||||
fieldMapMutex.Unlock()
|
||||
return sinfo, nil
|
||||
}
|
||||
|
||||
// IsZeroer is used to check whether an object is zero to
|
||||
// determine whether it should be omitted when marshaling
|
||||
// with the omitempty flag. One notable implementation
|
||||
// is time.Time.
|
||||
type IsZeroer interface {
|
||||
IsZero() bool
|
||||
}
|
||||
|
||||
func isZero(v reflect.Value) bool {
|
||||
kind := v.Kind()
|
||||
if z, ok := v.Interface().(IsZeroer); ok {
|
||||
if (kind == reflect.Ptr || kind == reflect.Interface) && v.IsNil() {
|
||||
return true
|
||||
}
|
||||
return z.IsZero()
|
||||
}
|
||||
switch kind {
|
||||
case reflect.String:
|
||||
return len(v.String()) == 0
|
||||
case reflect.Interface, reflect.Ptr:
|
||||
return v.IsNil()
|
||||
case reflect.Slice:
|
||||
return v.Len() == 0
|
||||
case reflect.Map:
|
||||
return v.Len() == 0
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return v.Int() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return v.Float() == 0
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return v.Uint() == 0
|
||||
case reflect.Bool:
|
||||
return !v.Bool()
|
||||
case reflect.Struct:
|
||||
vt := v.Type()
|
||||
for i := v.NumField() - 1; i >= 0; i-- {
|
||||
if vt.Field(i).PkgPath != "" {
|
||||
continue // Private field
|
||||
}
|
||||
if !isZero(v.Field(i)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
-807
@@ -1,807 +0,0 @@
|
||||
//
|
||||
// Copyright (c) 2011-2019 Canonical Ltd
|
||||
// Copyright (c) 2006-2010 Kirill Simonov
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
||||
// of the Software, and to permit persons to whom the Software is furnished to do
|
||||
// so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
|
||||
package yaml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// The version directive data.
|
||||
type yaml_version_directive_t struct {
|
||||
major int8 // The major version number.
|
||||
minor int8 // The minor version number.
|
||||
}
|
||||
|
||||
// The tag directive data.
|
||||
type yaml_tag_directive_t struct {
|
||||
handle []byte // The tag handle.
|
||||
prefix []byte // The tag prefix.
|
||||
}
|
||||
|
||||
type yaml_encoding_t int
|
||||
|
||||
// The stream encoding.
|
||||
const (
|
||||
// Let the parser choose the encoding.
|
||||
yaml_ANY_ENCODING yaml_encoding_t = iota
|
||||
|
||||
yaml_UTF8_ENCODING // The default UTF-8 encoding.
|
||||
yaml_UTF16LE_ENCODING // The UTF-16-LE encoding with BOM.
|
||||
yaml_UTF16BE_ENCODING // The UTF-16-BE encoding with BOM.
|
||||
)
|
||||
|
||||
type yaml_break_t int
|
||||
|
||||
// Line break types.
|
||||
const (
|
||||
// Let the parser choose the break type.
|
||||
yaml_ANY_BREAK yaml_break_t = iota
|
||||
|
||||
yaml_CR_BREAK // Use CR for line breaks (Mac style).
|
||||
yaml_LN_BREAK // Use LN for line breaks (Unix style).
|
||||
yaml_CRLN_BREAK // Use CR LN for line breaks (DOS style).
|
||||
)
|
||||
|
||||
type yaml_error_type_t int
|
||||
|
||||
// Many bad things could happen with the parser and emitter.
|
||||
const (
|
||||
// No error is produced.
|
||||
yaml_NO_ERROR yaml_error_type_t = iota
|
||||
|
||||
yaml_MEMORY_ERROR // Cannot allocate or reallocate a block of memory.
|
||||
yaml_READER_ERROR // Cannot read or decode the input stream.
|
||||
yaml_SCANNER_ERROR // Cannot scan the input stream.
|
||||
yaml_PARSER_ERROR // Cannot parse the input stream.
|
||||
yaml_COMPOSER_ERROR // Cannot compose a YAML document.
|
||||
yaml_WRITER_ERROR // Cannot write to the output stream.
|
||||
yaml_EMITTER_ERROR // Cannot emit a YAML stream.
|
||||
)
|
||||
|
||||
// The pointer position.
|
||||
type yaml_mark_t struct {
|
||||
index int // The position index.
|
||||
line int // The position line.
|
||||
column int // The position column.
|
||||
}
|
||||
|
||||
// Node Styles
|
||||
|
||||
type yaml_style_t int8
|
||||
|
||||
type yaml_scalar_style_t yaml_style_t
|
||||
|
||||
// Scalar styles.
|
||||
const (
|
||||
// Let the emitter choose the style.
|
||||
yaml_ANY_SCALAR_STYLE yaml_scalar_style_t = 0
|
||||
|
||||
yaml_PLAIN_SCALAR_STYLE yaml_scalar_style_t = 1 << iota // The plain scalar style.
|
||||
yaml_SINGLE_QUOTED_SCALAR_STYLE // The single-quoted scalar style.
|
||||
yaml_DOUBLE_QUOTED_SCALAR_STYLE // The double-quoted scalar style.
|
||||
yaml_LITERAL_SCALAR_STYLE // The literal scalar style.
|
||||
yaml_FOLDED_SCALAR_STYLE // The folded scalar style.
|
||||
)
|
||||
|
||||
type yaml_sequence_style_t yaml_style_t
|
||||
|
||||
// Sequence styles.
|
||||
const (
|
||||
// Let the emitter choose the style.
|
||||
yaml_ANY_SEQUENCE_STYLE yaml_sequence_style_t = iota
|
||||
|
||||
yaml_BLOCK_SEQUENCE_STYLE // The block sequence style.
|
||||
yaml_FLOW_SEQUENCE_STYLE // The flow sequence style.
|
||||
)
|
||||
|
||||
type yaml_mapping_style_t yaml_style_t
|
||||
|
||||
// Mapping styles.
|
||||
const (
|
||||
// Let the emitter choose the style.
|
||||
yaml_ANY_MAPPING_STYLE yaml_mapping_style_t = iota
|
||||
|
||||
yaml_BLOCK_MAPPING_STYLE // The block mapping style.
|
||||
yaml_FLOW_MAPPING_STYLE // The flow mapping style.
|
||||
)
|
||||
|
||||
// Tokens
|
||||
|
||||
type yaml_token_type_t int
|
||||
|
||||
// Token types.
|
||||
const (
|
||||
// An empty token.
|
||||
yaml_NO_TOKEN yaml_token_type_t = iota
|
||||
|
||||
yaml_STREAM_START_TOKEN // A STREAM-START token.
|
||||
yaml_STREAM_END_TOKEN // A STREAM-END token.
|
||||
|
||||
yaml_VERSION_DIRECTIVE_TOKEN // A VERSION-DIRECTIVE token.
|
||||
yaml_TAG_DIRECTIVE_TOKEN // A TAG-DIRECTIVE token.
|
||||
yaml_DOCUMENT_START_TOKEN // A DOCUMENT-START token.
|
||||
yaml_DOCUMENT_END_TOKEN // A DOCUMENT-END token.
|
||||
|
||||
yaml_BLOCK_SEQUENCE_START_TOKEN // A BLOCK-SEQUENCE-START token.
|
||||
yaml_BLOCK_MAPPING_START_TOKEN // A BLOCK-SEQUENCE-END token.
|
||||
yaml_BLOCK_END_TOKEN // A BLOCK-END token.
|
||||
|
||||
yaml_FLOW_SEQUENCE_START_TOKEN // A FLOW-SEQUENCE-START token.
|
||||
yaml_FLOW_SEQUENCE_END_TOKEN // A FLOW-SEQUENCE-END token.
|
||||
yaml_FLOW_MAPPING_START_TOKEN // A FLOW-MAPPING-START token.
|
||||
yaml_FLOW_MAPPING_END_TOKEN // A FLOW-MAPPING-END token.
|
||||
|
||||
yaml_BLOCK_ENTRY_TOKEN // A BLOCK-ENTRY token.
|
||||
yaml_FLOW_ENTRY_TOKEN // A FLOW-ENTRY token.
|
||||
yaml_KEY_TOKEN // A KEY token.
|
||||
yaml_VALUE_TOKEN // A VALUE token.
|
||||
|
||||
yaml_ALIAS_TOKEN // An ALIAS token.
|
||||
yaml_ANCHOR_TOKEN // An ANCHOR token.
|
||||
yaml_TAG_TOKEN // A TAG token.
|
||||
yaml_SCALAR_TOKEN // A SCALAR token.
|
||||
)
|
||||
|
||||
func (tt yaml_token_type_t) String() string {
|
||||
switch tt {
|
||||
case yaml_NO_TOKEN:
|
||||
return "yaml_NO_TOKEN"
|
||||
case yaml_STREAM_START_TOKEN:
|
||||
return "yaml_STREAM_START_TOKEN"
|
||||
case yaml_STREAM_END_TOKEN:
|
||||
return "yaml_STREAM_END_TOKEN"
|
||||
case yaml_VERSION_DIRECTIVE_TOKEN:
|
||||
return "yaml_VERSION_DIRECTIVE_TOKEN"
|
||||
case yaml_TAG_DIRECTIVE_TOKEN:
|
||||
return "yaml_TAG_DIRECTIVE_TOKEN"
|
||||
case yaml_DOCUMENT_START_TOKEN:
|
||||
return "yaml_DOCUMENT_START_TOKEN"
|
||||
case yaml_DOCUMENT_END_TOKEN:
|
||||
return "yaml_DOCUMENT_END_TOKEN"
|
||||
case yaml_BLOCK_SEQUENCE_START_TOKEN:
|
||||
return "yaml_BLOCK_SEQUENCE_START_TOKEN"
|
||||
case yaml_BLOCK_MAPPING_START_TOKEN:
|
||||
return "yaml_BLOCK_MAPPING_START_TOKEN"
|
||||
case yaml_BLOCK_END_TOKEN:
|
||||
return "yaml_BLOCK_END_TOKEN"
|
||||
case yaml_FLOW_SEQUENCE_START_TOKEN:
|
||||
return "yaml_FLOW_SEQUENCE_START_TOKEN"
|
||||
case yaml_FLOW_SEQUENCE_END_TOKEN:
|
||||
return "yaml_FLOW_SEQUENCE_END_TOKEN"
|
||||
case yaml_FLOW_MAPPING_START_TOKEN:
|
||||
return "yaml_FLOW_MAPPING_START_TOKEN"
|
||||
case yaml_FLOW_MAPPING_END_TOKEN:
|
||||
return "yaml_FLOW_MAPPING_END_TOKEN"
|
||||
case yaml_BLOCK_ENTRY_TOKEN:
|
||||
return "yaml_BLOCK_ENTRY_TOKEN"
|
||||
case yaml_FLOW_ENTRY_TOKEN:
|
||||
return "yaml_FLOW_ENTRY_TOKEN"
|
||||
case yaml_KEY_TOKEN:
|
||||
return "yaml_KEY_TOKEN"
|
||||
case yaml_VALUE_TOKEN:
|
||||
return "yaml_VALUE_TOKEN"
|
||||
case yaml_ALIAS_TOKEN:
|
||||
return "yaml_ALIAS_TOKEN"
|
||||
case yaml_ANCHOR_TOKEN:
|
||||
return "yaml_ANCHOR_TOKEN"
|
||||
case yaml_TAG_TOKEN:
|
||||
return "yaml_TAG_TOKEN"
|
||||
case yaml_SCALAR_TOKEN:
|
||||
return "yaml_SCALAR_TOKEN"
|
||||
}
|
||||
return "<unknown token>"
|
||||
}
|
||||
|
||||
// The token structure.
|
||||
type yaml_token_t struct {
|
||||
// The token type.
|
||||
typ yaml_token_type_t
|
||||
|
||||
// The start/end of the token.
|
||||
start_mark, end_mark yaml_mark_t
|
||||
|
||||
// The stream encoding (for yaml_STREAM_START_TOKEN).
|
||||
encoding yaml_encoding_t
|
||||
|
||||
// The alias/anchor/scalar value or tag/tag directive handle
|
||||
// (for yaml_ALIAS_TOKEN, yaml_ANCHOR_TOKEN, yaml_SCALAR_TOKEN, yaml_TAG_TOKEN, yaml_TAG_DIRECTIVE_TOKEN).
|
||||
value []byte
|
||||
|
||||
// The tag suffix (for yaml_TAG_TOKEN).
|
||||
suffix []byte
|
||||
|
||||
// The tag directive prefix (for yaml_TAG_DIRECTIVE_TOKEN).
|
||||
prefix []byte
|
||||
|
||||
// The scalar style (for yaml_SCALAR_TOKEN).
|
||||
style yaml_scalar_style_t
|
||||
|
||||
// The version directive major/minor (for yaml_VERSION_DIRECTIVE_TOKEN).
|
||||
major, minor int8
|
||||
}
|
||||
|
||||
// Events
|
||||
|
||||
type yaml_event_type_t int8
|
||||
|
||||
// Event types.
|
||||
const (
|
||||
// An empty event.
|
||||
yaml_NO_EVENT yaml_event_type_t = iota
|
||||
|
||||
yaml_STREAM_START_EVENT // A STREAM-START event.
|
||||
yaml_STREAM_END_EVENT // A STREAM-END event.
|
||||
yaml_DOCUMENT_START_EVENT // A DOCUMENT-START event.
|
||||
yaml_DOCUMENT_END_EVENT // A DOCUMENT-END event.
|
||||
yaml_ALIAS_EVENT // An ALIAS event.
|
||||
yaml_SCALAR_EVENT // A SCALAR event.
|
||||
yaml_SEQUENCE_START_EVENT // A SEQUENCE-START event.
|
||||
yaml_SEQUENCE_END_EVENT // A SEQUENCE-END event.
|
||||
yaml_MAPPING_START_EVENT // A MAPPING-START event.
|
||||
yaml_MAPPING_END_EVENT // A MAPPING-END event.
|
||||
yaml_TAIL_COMMENT_EVENT
|
||||
)
|
||||
|
||||
var eventStrings = []string{
|
||||
yaml_NO_EVENT: "none",
|
||||
yaml_STREAM_START_EVENT: "stream start",
|
||||
yaml_STREAM_END_EVENT: "stream end",
|
||||
yaml_DOCUMENT_START_EVENT: "document start",
|
||||
yaml_DOCUMENT_END_EVENT: "document end",
|
||||
yaml_ALIAS_EVENT: "alias",
|
||||
yaml_SCALAR_EVENT: "scalar",
|
||||
yaml_SEQUENCE_START_EVENT: "sequence start",
|
||||
yaml_SEQUENCE_END_EVENT: "sequence end",
|
||||
yaml_MAPPING_START_EVENT: "mapping start",
|
||||
yaml_MAPPING_END_EVENT: "mapping end",
|
||||
yaml_TAIL_COMMENT_EVENT: "tail comment",
|
||||
}
|
||||
|
||||
func (e yaml_event_type_t) String() string {
|
||||
if e < 0 || int(e) >= len(eventStrings) {
|
||||
return fmt.Sprintf("unknown event %d", e)
|
||||
}
|
||||
return eventStrings[e]
|
||||
}
|
||||
|
||||
// The event structure.
|
||||
type yaml_event_t struct {
|
||||
|
||||
// The event type.
|
||||
typ yaml_event_type_t
|
||||
|
||||
// The start and end of the event.
|
||||
start_mark, end_mark yaml_mark_t
|
||||
|
||||
// The document encoding (for yaml_STREAM_START_EVENT).
|
||||
encoding yaml_encoding_t
|
||||
|
||||
// The version directive (for yaml_DOCUMENT_START_EVENT).
|
||||
version_directive *yaml_version_directive_t
|
||||
|
||||
// The list of tag directives (for yaml_DOCUMENT_START_EVENT).
|
||||
tag_directives []yaml_tag_directive_t
|
||||
|
||||
// The comments
|
||||
head_comment []byte
|
||||
line_comment []byte
|
||||
foot_comment []byte
|
||||
tail_comment []byte
|
||||
|
||||
// The anchor (for yaml_SCALAR_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT, yaml_ALIAS_EVENT).
|
||||
anchor []byte
|
||||
|
||||
// The tag (for yaml_SCALAR_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT).
|
||||
tag []byte
|
||||
|
||||
// The scalar value (for yaml_SCALAR_EVENT).
|
||||
value []byte
|
||||
|
||||
// Is the document start/end indicator implicit, or the tag optional?
|
||||
// (for yaml_DOCUMENT_START_EVENT, yaml_DOCUMENT_END_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT, yaml_SCALAR_EVENT).
|
||||
implicit bool
|
||||
|
||||
// Is the tag optional for any non-plain style? (for yaml_SCALAR_EVENT).
|
||||
quoted_implicit bool
|
||||
|
||||
// The style (for yaml_SCALAR_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT).
|
||||
style yaml_style_t
|
||||
}
|
||||
|
||||
func (e *yaml_event_t) scalar_style() yaml_scalar_style_t { return yaml_scalar_style_t(e.style) }
|
||||
func (e *yaml_event_t) sequence_style() yaml_sequence_style_t { return yaml_sequence_style_t(e.style) }
|
||||
func (e *yaml_event_t) mapping_style() yaml_mapping_style_t { return yaml_mapping_style_t(e.style) }
|
||||
|
||||
// Nodes
|
||||
|
||||
const (
|
||||
yaml_NULL_TAG = "tag:yaml.org,2002:null" // The tag !!null with the only possible value: null.
|
||||
yaml_BOOL_TAG = "tag:yaml.org,2002:bool" // The tag !!bool with the values: true and false.
|
||||
yaml_STR_TAG = "tag:yaml.org,2002:str" // The tag !!str for string values.
|
||||
yaml_INT_TAG = "tag:yaml.org,2002:int" // The tag !!int for integer values.
|
||||
yaml_FLOAT_TAG = "tag:yaml.org,2002:float" // The tag !!float for float values.
|
||||
yaml_TIMESTAMP_TAG = "tag:yaml.org,2002:timestamp" // The tag !!timestamp for date and time values.
|
||||
|
||||
yaml_SEQ_TAG = "tag:yaml.org,2002:seq" // The tag !!seq is used to denote sequences.
|
||||
yaml_MAP_TAG = "tag:yaml.org,2002:map" // The tag !!map is used to denote mapping.
|
||||
|
||||
// Not in original libyaml.
|
||||
yaml_BINARY_TAG = "tag:yaml.org,2002:binary"
|
||||
yaml_MERGE_TAG = "tag:yaml.org,2002:merge"
|
||||
|
||||
yaml_DEFAULT_SCALAR_TAG = yaml_STR_TAG // The default scalar tag is !!str.
|
||||
yaml_DEFAULT_SEQUENCE_TAG = yaml_SEQ_TAG // The default sequence tag is !!seq.
|
||||
yaml_DEFAULT_MAPPING_TAG = yaml_MAP_TAG // The default mapping tag is !!map.
|
||||
)
|
||||
|
||||
type yaml_node_type_t int
|
||||
|
||||
// Node types.
|
||||
const (
|
||||
// An empty node.
|
||||
yaml_NO_NODE yaml_node_type_t = iota
|
||||
|
||||
yaml_SCALAR_NODE // A scalar node.
|
||||
yaml_SEQUENCE_NODE // A sequence node.
|
||||
yaml_MAPPING_NODE // A mapping node.
|
||||
)
|
||||
|
||||
// An element of a sequence node.
|
||||
type yaml_node_item_t int
|
||||
|
||||
// An element of a mapping node.
|
||||
type yaml_node_pair_t struct {
|
||||
key int // The key of the element.
|
||||
value int // The value of the element.
|
||||
}
|
||||
|
||||
// The node structure.
|
||||
type yaml_node_t struct {
|
||||
typ yaml_node_type_t // The node type.
|
||||
tag []byte // The node tag.
|
||||
|
||||
// The node data.
|
||||
|
||||
// The scalar parameters (for yaml_SCALAR_NODE).
|
||||
scalar struct {
|
||||
value []byte // The scalar value.
|
||||
length int // The length of the scalar value.
|
||||
style yaml_scalar_style_t // The scalar style.
|
||||
}
|
||||
|
||||
// The sequence parameters (for YAML_SEQUENCE_NODE).
|
||||
sequence struct {
|
||||
items_data []yaml_node_item_t // The stack of sequence items.
|
||||
style yaml_sequence_style_t // The sequence style.
|
||||
}
|
||||
|
||||
// The mapping parameters (for yaml_MAPPING_NODE).
|
||||
mapping struct {
|
||||
pairs_data []yaml_node_pair_t // The stack of mapping pairs (key, value).
|
||||
pairs_start *yaml_node_pair_t // The beginning of the stack.
|
||||
pairs_end *yaml_node_pair_t // The end of the stack.
|
||||
pairs_top *yaml_node_pair_t // The top of the stack.
|
||||
style yaml_mapping_style_t // The mapping style.
|
||||
}
|
||||
|
||||
start_mark yaml_mark_t // The beginning of the node.
|
||||
end_mark yaml_mark_t // The end of the node.
|
||||
|
||||
}
|
||||
|
||||
// The document structure.
|
||||
type yaml_document_t struct {
|
||||
|
||||
// The document nodes.
|
||||
nodes []yaml_node_t
|
||||
|
||||
// The version directive.
|
||||
version_directive *yaml_version_directive_t
|
||||
|
||||
// The list of tag directives.
|
||||
tag_directives_data []yaml_tag_directive_t
|
||||
tag_directives_start int // The beginning of the tag directives list.
|
||||
tag_directives_end int // The end of the tag directives list.
|
||||
|
||||
start_implicit int // Is the document start indicator implicit?
|
||||
end_implicit int // Is the document end indicator implicit?
|
||||
|
||||
// The start/end of the document.
|
||||
start_mark, end_mark yaml_mark_t
|
||||
}
|
||||
|
||||
// The prototype of a read handler.
|
||||
//
|
||||
// The read handler is called when the parser needs to read more bytes from the
|
||||
// source. The handler should write not more than size bytes to the buffer.
|
||||
// The number of written bytes should be set to the size_read variable.
|
||||
//
|
||||
// [in,out] data A pointer to an application data specified by
|
||||
// yaml_parser_set_input().
|
||||
// [out] buffer The buffer to write the data from the source.
|
||||
// [in] size The size of the buffer.
|
||||
// [out] size_read The actual number of bytes read from the source.
|
||||
//
|
||||
// On success, the handler should return 1. If the handler failed,
|
||||
// the returned value should be 0. On EOF, the handler should set the
|
||||
// size_read to 0 and return 1.
|
||||
type yaml_read_handler_t func(parser *yaml_parser_t, buffer []byte) (n int, err error)
|
||||
|
||||
// This structure holds information about a potential simple key.
|
||||
type yaml_simple_key_t struct {
|
||||
possible bool // Is a simple key possible?
|
||||
required bool // Is a simple key required?
|
||||
token_number int // The number of the token.
|
||||
mark yaml_mark_t // The position mark.
|
||||
}
|
||||
|
||||
// The states of the parser.
|
||||
type yaml_parser_state_t int
|
||||
|
||||
const (
|
||||
yaml_PARSE_STREAM_START_STATE yaml_parser_state_t = iota
|
||||
|
||||
yaml_PARSE_IMPLICIT_DOCUMENT_START_STATE // Expect the beginning of an implicit document.
|
||||
yaml_PARSE_DOCUMENT_START_STATE // Expect DOCUMENT-START.
|
||||
yaml_PARSE_DOCUMENT_CONTENT_STATE // Expect the content of a document.
|
||||
yaml_PARSE_DOCUMENT_END_STATE // Expect DOCUMENT-END.
|
||||
yaml_PARSE_BLOCK_NODE_STATE // Expect a block node.
|
||||
yaml_PARSE_BLOCK_NODE_OR_INDENTLESS_SEQUENCE_STATE // Expect a block node or indentless sequence.
|
||||
yaml_PARSE_FLOW_NODE_STATE // Expect a flow node.
|
||||
yaml_PARSE_BLOCK_SEQUENCE_FIRST_ENTRY_STATE // Expect the first entry of a block sequence.
|
||||
yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE // Expect an entry of a block sequence.
|
||||
yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE // Expect an entry of an indentless sequence.
|
||||
yaml_PARSE_BLOCK_MAPPING_FIRST_KEY_STATE // Expect the first key of a block mapping.
|
||||
yaml_PARSE_BLOCK_MAPPING_KEY_STATE // Expect a block mapping key.
|
||||
yaml_PARSE_BLOCK_MAPPING_VALUE_STATE // Expect a block mapping value.
|
||||
yaml_PARSE_FLOW_SEQUENCE_FIRST_ENTRY_STATE // Expect the first entry of a flow sequence.
|
||||
yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE // Expect an entry of a flow sequence.
|
||||
yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_KEY_STATE // Expect a key of an ordered mapping.
|
||||
yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE // Expect a value of an ordered mapping.
|
||||
yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE // Expect the and of an ordered mapping entry.
|
||||
yaml_PARSE_FLOW_MAPPING_FIRST_KEY_STATE // Expect the first key of a flow mapping.
|
||||
yaml_PARSE_FLOW_MAPPING_KEY_STATE // Expect a key of a flow mapping.
|
||||
yaml_PARSE_FLOW_MAPPING_VALUE_STATE // Expect a value of a flow mapping.
|
||||
yaml_PARSE_FLOW_MAPPING_EMPTY_VALUE_STATE // Expect an empty value of a flow mapping.
|
||||
yaml_PARSE_END_STATE // Expect nothing.
|
||||
)
|
||||
|
||||
func (ps yaml_parser_state_t) String() string {
|
||||
switch ps {
|
||||
case yaml_PARSE_STREAM_START_STATE:
|
||||
return "yaml_PARSE_STREAM_START_STATE"
|
||||
case yaml_PARSE_IMPLICIT_DOCUMENT_START_STATE:
|
||||
return "yaml_PARSE_IMPLICIT_DOCUMENT_START_STATE"
|
||||
case yaml_PARSE_DOCUMENT_START_STATE:
|
||||
return "yaml_PARSE_DOCUMENT_START_STATE"
|
||||
case yaml_PARSE_DOCUMENT_CONTENT_STATE:
|
||||
return "yaml_PARSE_DOCUMENT_CONTENT_STATE"
|
||||
case yaml_PARSE_DOCUMENT_END_STATE:
|
||||
return "yaml_PARSE_DOCUMENT_END_STATE"
|
||||
case yaml_PARSE_BLOCK_NODE_STATE:
|
||||
return "yaml_PARSE_BLOCK_NODE_STATE"
|
||||
case yaml_PARSE_BLOCK_NODE_OR_INDENTLESS_SEQUENCE_STATE:
|
||||
return "yaml_PARSE_BLOCK_NODE_OR_INDENTLESS_SEQUENCE_STATE"
|
||||
case yaml_PARSE_FLOW_NODE_STATE:
|
||||
return "yaml_PARSE_FLOW_NODE_STATE"
|
||||
case yaml_PARSE_BLOCK_SEQUENCE_FIRST_ENTRY_STATE:
|
||||
return "yaml_PARSE_BLOCK_SEQUENCE_FIRST_ENTRY_STATE"
|
||||
case yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE:
|
||||
return "yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE"
|
||||
case yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE:
|
||||
return "yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE"
|
||||
case yaml_PARSE_BLOCK_MAPPING_FIRST_KEY_STATE:
|
||||
return "yaml_PARSE_BLOCK_MAPPING_FIRST_KEY_STATE"
|
||||
case yaml_PARSE_BLOCK_MAPPING_KEY_STATE:
|
||||
return "yaml_PARSE_BLOCK_MAPPING_KEY_STATE"
|
||||
case yaml_PARSE_BLOCK_MAPPING_VALUE_STATE:
|
||||
return "yaml_PARSE_BLOCK_MAPPING_VALUE_STATE"
|
||||
case yaml_PARSE_FLOW_SEQUENCE_FIRST_ENTRY_STATE:
|
||||
return "yaml_PARSE_FLOW_SEQUENCE_FIRST_ENTRY_STATE"
|
||||
case yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE:
|
||||
return "yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE"
|
||||
case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_KEY_STATE:
|
||||
return "yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_KEY_STATE"
|
||||
case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE:
|
||||
return "yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE"
|
||||
case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE:
|
||||
return "yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE"
|
||||
case yaml_PARSE_FLOW_MAPPING_FIRST_KEY_STATE:
|
||||
return "yaml_PARSE_FLOW_MAPPING_FIRST_KEY_STATE"
|
||||
case yaml_PARSE_FLOW_MAPPING_KEY_STATE:
|
||||
return "yaml_PARSE_FLOW_MAPPING_KEY_STATE"
|
||||
case yaml_PARSE_FLOW_MAPPING_VALUE_STATE:
|
||||
return "yaml_PARSE_FLOW_MAPPING_VALUE_STATE"
|
||||
case yaml_PARSE_FLOW_MAPPING_EMPTY_VALUE_STATE:
|
||||
return "yaml_PARSE_FLOW_MAPPING_EMPTY_VALUE_STATE"
|
||||
case yaml_PARSE_END_STATE:
|
||||
return "yaml_PARSE_END_STATE"
|
||||
}
|
||||
return "<unknown parser state>"
|
||||
}
|
||||
|
||||
// This structure holds aliases data.
|
||||
type yaml_alias_data_t struct {
|
||||
anchor []byte // The anchor.
|
||||
index int // The node id.
|
||||
mark yaml_mark_t // The anchor mark.
|
||||
}
|
||||
|
||||
// The parser structure.
|
||||
//
|
||||
// All members are internal. Manage the structure using the
|
||||
// yaml_parser_ family of functions.
|
||||
type yaml_parser_t struct {
|
||||
|
||||
// Error handling
|
||||
|
||||
error yaml_error_type_t // Error type.
|
||||
|
||||
problem string // Error description.
|
||||
|
||||
// The byte about which the problem occurred.
|
||||
problem_offset int
|
||||
problem_value int
|
||||
problem_mark yaml_mark_t
|
||||
|
||||
// The error context.
|
||||
context string
|
||||
context_mark yaml_mark_t
|
||||
|
||||
// Reader stuff
|
||||
|
||||
read_handler yaml_read_handler_t // Read handler.
|
||||
|
||||
input_reader io.Reader // File input data.
|
||||
input []byte // String input data.
|
||||
input_pos int
|
||||
|
||||
eof bool // EOF flag
|
||||
|
||||
buffer []byte // The working buffer.
|
||||
buffer_pos int // The current position of the buffer.
|
||||
|
||||
unread int // The number of unread characters in the buffer.
|
||||
|
||||
newlines int // The number of line breaks since last non-break/non-blank character
|
||||
|
||||
raw_buffer []byte // The raw buffer.
|
||||
raw_buffer_pos int // The current position of the buffer.
|
||||
|
||||
encoding yaml_encoding_t // The input encoding.
|
||||
|
||||
offset int // The offset of the current position (in bytes).
|
||||
mark yaml_mark_t // The mark of the current position.
|
||||
|
||||
// Comments
|
||||
|
||||
head_comment []byte // The current head comments
|
||||
line_comment []byte // The current line comments
|
||||
foot_comment []byte // The current foot comments
|
||||
tail_comment []byte // Foot comment that happens at the end of a block.
|
||||
stem_comment []byte // Comment in item preceding a nested structure (list inside list item, etc)
|
||||
|
||||
comments []yaml_comment_t // The folded comments for all parsed tokens
|
||||
comments_head int
|
||||
|
||||
// Scanner stuff
|
||||
|
||||
stream_start_produced bool // Have we started to scan the input stream?
|
||||
stream_end_produced bool // Have we reached the end of the input stream?
|
||||
|
||||
flow_level int // The number of unclosed '[' and '{' indicators.
|
||||
|
||||
tokens []yaml_token_t // The tokens queue.
|
||||
tokens_head int // The head of the tokens queue.
|
||||
tokens_parsed int // The number of tokens fetched from the queue.
|
||||
token_available bool // Does the tokens queue contain a token ready for dequeueing.
|
||||
|
||||
indent int // The current indentation level.
|
||||
indents []int // The indentation levels stack.
|
||||
|
||||
simple_key_allowed bool // May a simple key occur at the current position?
|
||||
simple_keys []yaml_simple_key_t // The stack of simple keys.
|
||||
simple_keys_by_tok map[int]int // possible simple_key indexes indexed by token_number
|
||||
|
||||
// Parser stuff
|
||||
|
||||
state yaml_parser_state_t // The current parser state.
|
||||
states []yaml_parser_state_t // The parser states stack.
|
||||
marks []yaml_mark_t // The stack of marks.
|
||||
tag_directives []yaml_tag_directive_t // The list of TAG directives.
|
||||
|
||||
// Dumper stuff
|
||||
|
||||
aliases []yaml_alias_data_t // The alias data.
|
||||
|
||||
document *yaml_document_t // The currently parsed document.
|
||||
}
|
||||
|
||||
type yaml_comment_t struct {
|
||||
|
||||
scan_mark yaml_mark_t // Position where scanning for comments started
|
||||
token_mark yaml_mark_t // Position after which tokens will be associated with this comment
|
||||
start_mark yaml_mark_t // Position of '#' comment mark
|
||||
end_mark yaml_mark_t // Position where comment terminated
|
||||
|
||||
head []byte
|
||||
line []byte
|
||||
foot []byte
|
||||
}
|
||||
|
||||
// Emitter Definitions
|
||||
|
||||
// The prototype of a write handler.
|
||||
//
|
||||
// The write handler is called when the emitter needs to flush the accumulated
|
||||
// characters to the output. The handler should write @a size bytes of the
|
||||
// @a buffer to the output.
|
||||
//
|
||||
// @param[in,out] data A pointer to an application data specified by
|
||||
// yaml_emitter_set_output().
|
||||
// @param[in] buffer The buffer with bytes to be written.
|
||||
// @param[in] size The size of the buffer.
|
||||
//
|
||||
// @returns On success, the handler should return @c 1. If the handler failed,
|
||||
// the returned value should be @c 0.
|
||||
//
|
||||
type yaml_write_handler_t func(emitter *yaml_emitter_t, buffer []byte) error
|
||||
|
||||
type yaml_emitter_state_t int
|
||||
|
||||
// The emitter states.
|
||||
const (
|
||||
// Expect STREAM-START.
|
||||
yaml_EMIT_STREAM_START_STATE yaml_emitter_state_t = iota
|
||||
|
||||
yaml_EMIT_FIRST_DOCUMENT_START_STATE // Expect the first DOCUMENT-START or STREAM-END.
|
||||
yaml_EMIT_DOCUMENT_START_STATE // Expect DOCUMENT-START or STREAM-END.
|
||||
yaml_EMIT_DOCUMENT_CONTENT_STATE // Expect the content of a document.
|
||||
yaml_EMIT_DOCUMENT_END_STATE // Expect DOCUMENT-END.
|
||||
yaml_EMIT_FLOW_SEQUENCE_FIRST_ITEM_STATE // Expect the first item of a flow sequence.
|
||||
yaml_EMIT_FLOW_SEQUENCE_TRAIL_ITEM_STATE // Expect the next item of a flow sequence, with the comma already written out
|
||||
yaml_EMIT_FLOW_SEQUENCE_ITEM_STATE // Expect an item of a flow sequence.
|
||||
yaml_EMIT_FLOW_MAPPING_FIRST_KEY_STATE // Expect the first key of a flow mapping.
|
||||
yaml_EMIT_FLOW_MAPPING_TRAIL_KEY_STATE // Expect the next key of a flow mapping, with the comma already written out
|
||||
yaml_EMIT_FLOW_MAPPING_KEY_STATE // Expect a key of a flow mapping.
|
||||
yaml_EMIT_FLOW_MAPPING_SIMPLE_VALUE_STATE // Expect a value for a simple key of a flow mapping.
|
||||
yaml_EMIT_FLOW_MAPPING_VALUE_STATE // Expect a value of a flow mapping.
|
||||
yaml_EMIT_BLOCK_SEQUENCE_FIRST_ITEM_STATE // Expect the first item of a block sequence.
|
||||
yaml_EMIT_BLOCK_SEQUENCE_ITEM_STATE // Expect an item of a block sequence.
|
||||
yaml_EMIT_BLOCK_MAPPING_FIRST_KEY_STATE // Expect the first key of a block mapping.
|
||||
yaml_EMIT_BLOCK_MAPPING_KEY_STATE // Expect the key of a block mapping.
|
||||
yaml_EMIT_BLOCK_MAPPING_SIMPLE_VALUE_STATE // Expect a value for a simple key of a block mapping.
|
||||
yaml_EMIT_BLOCK_MAPPING_VALUE_STATE // Expect a value of a block mapping.
|
||||
yaml_EMIT_END_STATE // Expect nothing.
|
||||
)
|
||||
|
||||
// The emitter structure.
|
||||
//
|
||||
// All members are internal. Manage the structure using the @c yaml_emitter_
|
||||
// family of functions.
|
||||
type yaml_emitter_t struct {
|
||||
|
||||
// Error handling
|
||||
|
||||
error yaml_error_type_t // Error type.
|
||||
problem string // Error description.
|
||||
|
||||
// Writer stuff
|
||||
|
||||
write_handler yaml_write_handler_t // Write handler.
|
||||
|
||||
output_buffer *[]byte // String output data.
|
||||
output_writer io.Writer // File output data.
|
||||
|
||||
buffer []byte // The working buffer.
|
||||
buffer_pos int // The current position of the buffer.
|
||||
|
||||
raw_buffer []byte // The raw buffer.
|
||||
raw_buffer_pos int // The current position of the buffer.
|
||||
|
||||
encoding yaml_encoding_t // The stream encoding.
|
||||
|
||||
// Emitter stuff
|
||||
|
||||
canonical bool // If the output is in the canonical style?
|
||||
best_indent int // The number of indentation spaces.
|
||||
best_width int // The preferred width of the output lines.
|
||||
unicode bool // Allow unescaped non-ASCII characters?
|
||||
line_break yaml_break_t // The preferred line break.
|
||||
|
||||
state yaml_emitter_state_t // The current emitter state.
|
||||
states []yaml_emitter_state_t // The stack of states.
|
||||
|
||||
events []yaml_event_t // The event queue.
|
||||
events_head int // The head of the event queue.
|
||||
|
||||
indents []int // The stack of indentation levels.
|
||||
|
||||
tag_directives []yaml_tag_directive_t // The list of tag directives.
|
||||
|
||||
indent int // The current indentation level.
|
||||
|
||||
flow_level int // The current flow level.
|
||||
|
||||
root_context bool // Is it the document root context?
|
||||
sequence_context bool // Is it a sequence context?
|
||||
mapping_context bool // Is it a mapping context?
|
||||
simple_key_context bool // Is it a simple mapping key context?
|
||||
|
||||
line int // The current line.
|
||||
column int // The current column.
|
||||
whitespace bool // If the last character was a whitespace?
|
||||
indention bool // If the last character was an indentation character (' ', '-', '?', ':')?
|
||||
open_ended bool // If an explicit document end is required?
|
||||
|
||||
space_above bool // Is there's an empty line above?
|
||||
foot_indent int // The indent used to write the foot comment above, or -1 if none.
|
||||
|
||||
// Anchor analysis.
|
||||
anchor_data struct {
|
||||
anchor []byte // The anchor value.
|
||||
alias bool // Is it an alias?
|
||||
}
|
||||
|
||||
// Tag analysis.
|
||||
tag_data struct {
|
||||
handle []byte // The tag handle.
|
||||
suffix []byte // The tag suffix.
|
||||
}
|
||||
|
||||
// Scalar analysis.
|
||||
scalar_data struct {
|
||||
value []byte // The scalar value.
|
||||
multiline bool // Does the scalar contain line breaks?
|
||||
flow_plain_allowed bool // Can the scalar be expessed in the flow plain style?
|
||||
block_plain_allowed bool // Can the scalar be expressed in the block plain style?
|
||||
single_quoted_allowed bool // Can the scalar be expressed in the single quoted style?
|
||||
block_allowed bool // Can the scalar be expressed in the literal or folded styles?
|
||||
style yaml_scalar_style_t // The output style.
|
||||
}
|
||||
|
||||
// Comments
|
||||
head_comment []byte
|
||||
line_comment []byte
|
||||
foot_comment []byte
|
||||
tail_comment []byte
|
||||
|
||||
key_line_comment []byte
|
||||
|
||||
// Dumper stuff
|
||||
|
||||
opened bool // If the stream was already opened?
|
||||
closed bool // If the stream was already closed?
|
||||
|
||||
// The information associated with the document nodes.
|
||||
anchors *struct {
|
||||
references int // The number of references.
|
||||
anchor int // The anchor id.
|
||||
serialized bool // If the node has been emitted?
|
||||
}
|
||||
|
||||
last_anchor_id int // The last assigned anchor id.
|
||||
|
||||
document *yaml_document_t // The currently emitted document.
|
||||
}
|
||||
-198
@@ -1,198 +0,0 @@
|
||||
//
|
||||
// Copyright (c) 2011-2019 Canonical Ltd
|
||||
// Copyright (c) 2006-2010 Kirill Simonov
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
||||
// of the Software, and to permit persons to whom the Software is furnished to do
|
||||
// so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
|
||||
package yaml
|
||||
|
||||
const (
|
||||
// The size of the input raw buffer.
|
||||
input_raw_buffer_size = 512
|
||||
|
||||
// The size of the input buffer.
|
||||
// It should be possible to decode the whole raw buffer.
|
||||
input_buffer_size = input_raw_buffer_size * 3
|
||||
|
||||
// The size of the output buffer.
|
||||
output_buffer_size = 128
|
||||
|
||||
// The size of the output raw buffer.
|
||||
// It should be possible to encode the whole output buffer.
|
||||
output_raw_buffer_size = (output_buffer_size*2 + 2)
|
||||
|
||||
// The size of other stacks and queues.
|
||||
initial_stack_size = 16
|
||||
initial_queue_size = 16
|
||||
initial_string_size = 16
|
||||
)
|
||||
|
||||
// Check if the character at the specified position is an alphabetical
|
||||
// character, a digit, '_', or '-'.
|
||||
func is_alpha(b []byte, i int) bool {
|
||||
return b[i] >= '0' && b[i] <= '9' || b[i] >= 'A' && b[i] <= 'Z' || b[i] >= 'a' && b[i] <= 'z' || b[i] == '_' || b[i] == '-'
|
||||
}
|
||||
|
||||
// Check if the character at the specified position is a digit.
|
||||
func is_digit(b []byte, i int) bool {
|
||||
return b[i] >= '0' && b[i] <= '9'
|
||||
}
|
||||
|
||||
// Get the value of a digit.
|
||||
func as_digit(b []byte, i int) int {
|
||||
return int(b[i]) - '0'
|
||||
}
|
||||
|
||||
// Check if the character at the specified position is a hex-digit.
|
||||
func is_hex(b []byte, i int) bool {
|
||||
return b[i] >= '0' && b[i] <= '9' || b[i] >= 'A' && b[i] <= 'F' || b[i] >= 'a' && b[i] <= 'f'
|
||||
}
|
||||
|
||||
// Get the value of a hex-digit.
|
||||
func as_hex(b []byte, i int) int {
|
||||
bi := b[i]
|
||||
if bi >= 'A' && bi <= 'F' {
|
||||
return int(bi) - 'A' + 10
|
||||
}
|
||||
if bi >= 'a' && bi <= 'f' {
|
||||
return int(bi) - 'a' + 10
|
||||
}
|
||||
return int(bi) - '0'
|
||||
}
|
||||
|
||||
// Check if the character is ASCII.
|
||||
func is_ascii(b []byte, i int) bool {
|
||||
return b[i] <= 0x7F
|
||||
}
|
||||
|
||||
// Check if the character at the start of the buffer can be printed unescaped.
|
||||
func is_printable(b []byte, i int) bool {
|
||||
return ((b[i] == 0x0A) || // . == #x0A
|
||||
(b[i] >= 0x20 && b[i] <= 0x7E) || // #x20 <= . <= #x7E
|
||||
(b[i] == 0xC2 && b[i+1] >= 0xA0) || // #0xA0 <= . <= #xD7FF
|
||||
(b[i] > 0xC2 && b[i] < 0xED) ||
|
||||
(b[i] == 0xED && b[i+1] < 0xA0) ||
|
||||
(b[i] == 0xEE) ||
|
||||
(b[i] == 0xEF && // #xE000 <= . <= #xFFFD
|
||||
!(b[i+1] == 0xBB && b[i+2] == 0xBF) && // && . != #xFEFF
|
||||
!(b[i+1] == 0xBF && (b[i+2] == 0xBE || b[i+2] == 0xBF))))
|
||||
}
|
||||
|
||||
// Check if the character at the specified position is NUL.
|
||||
func is_z(b []byte, i int) bool {
|
||||
return b[i] == 0x00
|
||||
}
|
||||
|
||||
// Check if the beginning of the buffer is a BOM.
|
||||
func is_bom(b []byte, i int) bool {
|
||||
return b[0] == 0xEF && b[1] == 0xBB && b[2] == 0xBF
|
||||
}
|
||||
|
||||
// Check if the character at the specified position is space.
|
||||
func is_space(b []byte, i int) bool {
|
||||
return b[i] == ' '
|
||||
}
|
||||
|
||||
// Check if the character at the specified position is tab.
|
||||
func is_tab(b []byte, i int) bool {
|
||||
return b[i] == '\t'
|
||||
}
|
||||
|
||||
// Check if the character at the specified position is blank (space or tab).
|
||||
func is_blank(b []byte, i int) bool {
|
||||
//return is_space(b, i) || is_tab(b, i)
|
||||
return b[i] == ' ' || b[i] == '\t'
|
||||
}
|
||||
|
||||
// Check if the character at the specified position is a line break.
|
||||
func is_break(b []byte, i int) bool {
|
||||
return (b[i] == '\r' || // CR (#xD)
|
||||
b[i] == '\n' || // LF (#xA)
|
||||
b[i] == 0xC2 && b[i+1] == 0x85 || // NEL (#x85)
|
||||
b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA8 || // LS (#x2028)
|
||||
b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA9) // PS (#x2029)
|
||||
}
|
||||
|
||||
func is_crlf(b []byte, i int) bool {
|
||||
return b[i] == '\r' && b[i+1] == '\n'
|
||||
}
|
||||
|
||||
// Check if the character is a line break or NUL.
|
||||
func is_breakz(b []byte, i int) bool {
|
||||
//return is_break(b, i) || is_z(b, i)
|
||||
return (
|
||||
// is_break:
|
||||
b[i] == '\r' || // CR (#xD)
|
||||
b[i] == '\n' || // LF (#xA)
|
||||
b[i] == 0xC2 && b[i+1] == 0x85 || // NEL (#x85)
|
||||
b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA8 || // LS (#x2028)
|
||||
b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA9 || // PS (#x2029)
|
||||
// is_z:
|
||||
b[i] == 0)
|
||||
}
|
||||
|
||||
// Check if the character is a line break, space, or NUL.
|
||||
func is_spacez(b []byte, i int) bool {
|
||||
//return is_space(b, i) || is_breakz(b, i)
|
||||
return (
|
||||
// is_space:
|
||||
b[i] == ' ' ||
|
||||
// is_breakz:
|
||||
b[i] == '\r' || // CR (#xD)
|
||||
b[i] == '\n' || // LF (#xA)
|
||||
b[i] == 0xC2 && b[i+1] == 0x85 || // NEL (#x85)
|
||||
b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA8 || // LS (#x2028)
|
||||
b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA9 || // PS (#x2029)
|
||||
b[i] == 0)
|
||||
}
|
||||
|
||||
// Check if the character is a line break, space, tab, or NUL.
|
||||
func is_blankz(b []byte, i int) bool {
|
||||
//return is_blank(b, i) || is_breakz(b, i)
|
||||
return (
|
||||
// is_blank:
|
||||
b[i] == ' ' || b[i] == '\t' ||
|
||||
// is_breakz:
|
||||
b[i] == '\r' || // CR (#xD)
|
||||
b[i] == '\n' || // LF (#xA)
|
||||
b[i] == 0xC2 && b[i+1] == 0x85 || // NEL (#x85)
|
||||
b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA8 || // LS (#x2028)
|
||||
b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA9 || // PS (#x2029)
|
||||
b[i] == 0)
|
||||
}
|
||||
|
||||
// Determine the width of the character.
|
||||
func width(b byte) int {
|
||||
// Don't replace these by a switch without first
|
||||
// confirming that it is being inlined.
|
||||
if b&0x80 == 0x00 {
|
||||
return 1
|
||||
}
|
||||
if b&0xE0 == 0xC0 {
|
||||
return 2
|
||||
}
|
||||
if b&0xF0 == 0xE0 {
|
||||
return 3
|
||||
}
|
||||
if b&0xF8 == 0xF0 {
|
||||
return 4
|
||||
}
|
||||
return 0
|
||||
|
||||
}
|
||||
Vendored
+1
-22
@@ -1,6 +1,3 @@
|
||||
# github.com/davecgh/go-spew v1.1.1
|
||||
## explicit
|
||||
github.com/davecgh/go-spew/spew
|
||||
# github.com/google/uuid v1.6.0
|
||||
## explicit
|
||||
github.com/google/uuid
|
||||
@@ -10,24 +7,6 @@ github.com/gorilla/securecookie
|
||||
# github.com/gorilla/sessions v1.3.0
|
||||
## explicit; go 1.20
|
||||
github.com/gorilla/sessions
|
||||
# github.com/pmezard/go-difflib v1.0.0
|
||||
## explicit
|
||||
github.com/pmezard/go-difflib/difflib
|
||||
# github.com/stretchr/objx v0.5.2
|
||||
## explicit; go 1.20
|
||||
github.com/stretchr/objx
|
||||
# github.com/stretchr/testify v1.9.0
|
||||
## explicit; go 1.17
|
||||
github.com/stretchr/testify/assert
|
||||
github.com/stretchr/testify/mock
|
||||
github.com/stretchr/testify/require
|
||||
github.com/stretchr/testify/suite
|
||||
# golang.org/x/sync v0.7.0
|
||||
## explicit; go 1.18
|
||||
golang.org/x/sync/errgroup
|
||||
# golang.org/x/time v0.5.0
|
||||
# golang.org/x/time v0.7.0
|
||||
## explicit; go 1.18
|
||||
golang.org/x/time/rate
|
||||
# gopkg.in/yaml.v3 v3.0.1
|
||||
## explicit
|
||||
gopkg.in/yaml.v3
|
||||
|
||||
Reference in New Issue
Block a user