make RequestCtx's userdata accept keys that are of type: interface{} (#1387)

Co-authored-by: rocketlaunchr-cto <rocketlaunchr.cloud@gmail.com>
This commit is contained in:
pj
2022-10-07 01:25:32 +11:00
committed by GitHub
parent bcf7e8e944
commit d404f2db91
3 changed files with 40 additions and 21 deletions
+18 -8
View File
@@ -670,7 +670,7 @@ func (ctx *RequestCtx) Hijacked() bool {
// All the values are removed from ctx after returning from the top
// RequestHandler. Additionally, Close method is called on each value
// implementing io.Closer before removing the value from ctx.
func (ctx *RequestCtx) SetUserValue(key string, value interface{}) {
func (ctx *RequestCtx) SetUserValue(key interface{}, value interface{}) {
ctx.userValues.Set(key, value)
}
@@ -688,7 +688,7 @@ func (ctx *RequestCtx) SetUserValueBytes(key []byte, value interface{}) {
}
// UserValue returns the value stored via SetUserValue* under the given key.
func (ctx *RequestCtx) UserValue(key string) interface{} {
func (ctx *RequestCtx) UserValue(key interface{}) interface{} {
return ctx.userValues.Get(key)
}
@@ -698,11 +698,24 @@ func (ctx *RequestCtx) UserValueBytes(key []byte) interface{} {
return ctx.userValues.GetBytes(key)
}
// VisitUserValues calls visitor for each existing userValue.
// VisitUserValues calls visitor for each existing userValue with a key that is a string or []byte.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestCtx) VisitUserValues(visitor func([]byte, interface{})) {
for i, n := 0, len(ctx.userValues); i < n; i++ {
kv := &ctx.userValues[i]
if _, ok := kv.key.(string); ok {
visitor(s2b(kv.key.(string)), kv.value)
}
}
}
// VisitUserValuesAll calls visitor for each existing userValue.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestCtx) VisitUserValuesAll(visitor func(interface{}, interface{})) {
for i, n := 0, len(ctx.userValues); i < n; i++ {
kv := &ctx.userValues[i]
visitor(kv.key, kv.value)
@@ -715,7 +728,7 @@ func (ctx *RequestCtx) ResetUserValues() {
}
// RemoveUserValue removes the given key and the value under it in ctx.
func (ctx *RequestCtx) RemoveUserValue(key string) {
func (ctx *RequestCtx) RemoveUserValue(key interface{}) {
ctx.userValues.Remove(key)
}
@@ -2696,10 +2709,7 @@ func (ctx *RequestCtx) Err() error {
// This method is present to make RequestCtx implement the context interface.
// This method is the same as calling ctx.UserValue(key)
func (ctx *RequestCtx) Value(key interface{}) interface{} {
if keyString, ok := key.(string); ok {
return ctx.UserValue(keyString)
}
return nil
return ctx.UserValue(key)
}
var fakeServer = &Server{
+1 -1
View File
@@ -1737,7 +1737,7 @@ func TestRequestCtxUserValue(t *testing.T) {
vlen := 0
ctx.VisitUserValues(func(key []byte, value interface{}) {
vlen++
v := ctx.UserValueBytes(key)
v := ctx.UserValue(key)
if v != value {
t.Fatalf("unexpected value obtained from VisitUserValues for key: %q, expecting: %#v but got: %#v", key, v, value)
}
+21 -12
View File
@@ -5,18 +5,21 @@ import (
)
type userDataKV struct {
key []byte
key interface{}
value interface{}
}
type userData []userDataKV
func (d *userData) Set(key string, value interface{}) {
func (d *userData) Set(key interface{}, value interface{}) {
if b, ok := key.([]byte); ok {
key = string(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
if kv.key == key {
kv.value = value
return
}
@@ -30,28 +33,31 @@ func (d *userData) Set(key string, value interface{}) {
if c > n {
args = args[:n+1]
kv := &args[n]
kv.key = append(kv.key[:0], key...)
kv.key = key
kv.value = value
*d = args
return
}
kv := userDataKV{}
kv.key = append(kv.key[:0], key...)
kv.key = key
kv.value = value
*d = append(args, kv)
}
func (d *userData) SetBytes(key []byte, value interface{}) {
d.Set(b2s(key), value)
d.Set(key, value)
}
func (d *userData) Get(key string) interface{} {
func (d *userData) Get(key interface{}) interface{} {
if b, ok := key.([]byte); ok {
key = b2s(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
if kv.key == key {
return kv.value
}
}
@@ -59,7 +65,7 @@ func (d *userData) Get(key string) interface{} {
}
func (d *userData) GetBytes(key []byte) interface{} {
return d.Get(b2s(key))
return d.Get(key)
}
func (d *userData) Reset() {
@@ -74,12 +80,15 @@ func (d *userData) Reset() {
*d = (*d)[:0]
}
func (d *userData) Remove(key string) {
func (d *userData) Remove(key interface{}) {
if b, ok := key.([]byte); ok {
key = b2s(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
if kv.key == key {
n--
args[i], args[n] = args[n], args[i]
args[n].value = nil
@@ -91,5 +100,5 @@ func (d *userData) Remove(key string) {
}
func (d *userData) RemoveBytes(key []byte) {
d.Remove(b2s(key))
d.Remove(key)
}