From 8d8443d77cb496b6f8f016e8b60bd4ce99c7fde0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Fri, 10 Jan 2020 16:41:16 +0100 Subject: [PATCH] Forward context in fasthttpadaptor (#720) * forward context in fasthttpadaptor * run go fmt --- fasthttpadaptor/adaptor.go | 2 +- fasthttpadaptor/adaptor_test.go | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/fasthttpadaptor/adaptor.go b/fasthttpadaptor/adaptor.go index 296ef7c..7b33245 100644 --- a/fasthttpadaptor/adaptor.go +++ b/fasthttpadaptor/adaptor.go @@ -82,7 +82,7 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler { r.URL = rURL var w netHTTPResponseWriter - h.ServeHTTP(&w, &r) + h.ServeHTTP(&w, r.WithContext(ctx)) ctx.SetStatusCode(w.StatusCode()) for k, vv := range w.Header() { diff --git a/fasthttpadaptor/adaptor_test.go b/fasthttpadaptor/adaptor_test.go index ba15534..698b204 100644 --- a/fasthttpadaptor/adaptor_test.go +++ b/fasthttpadaptor/adaptor_test.go @@ -32,6 +32,8 @@ func TestNewFastHTTPHandler(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } + expectedContextKey := "contextKey" + expectedContextValue := "contextValue" callsCount := 0 nethttpH := func(w http.ResponseWriter, r *http.Request) { @@ -74,6 +76,9 @@ func TestNewFastHTTPHandler(t *testing.T) { if !reflect.DeepEqual(r.URL, expectedURL) { t.Fatalf("unexpected URL: %#v. Expecting %#v", r.URL, expectedURL) } + if r.Context().Value(expectedContextKey) != expectedContextValue { + t.Fatalf("unexpected context value for key %q. Expecting %q", expectedContextKey, expectedContextValue) + } for k, expectedV := range expectedHeader { v := r.Header.Get(k) @@ -88,6 +93,7 @@ func TestNewFastHTTPHandler(t *testing.T) { fmt.Fprintf(w, "request body is %q", body) } fasthttpH := NewFastHTTPHandler(http.HandlerFunc(nethttpH)) + fasthttpH = setContextValueMiddleware(fasthttpH, expectedContextKey, expectedContextValue) var ctx fasthttp.RequestCtx var req fasthttp.Request @@ -128,3 +134,10 @@ func TestNewFastHTTPHandler(t *testing.T) { t.Fatalf("unexpected response body %q. Expecting %q", resp.Body(), expectedResponseBody) } } + +func setContextValueMiddleware(next fasthttp.RequestHandler, key string, value interface{}) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + ctx.SetUserValue(key, value) + next(ctx) + } +}