Files
fasthttp/fasthttpadaptor/adaptor_test.go
T
gilwo aefd080674 adaptor ResponseWriter - adding Hijack method and pass proper fields (#1525)
* adding hijack method and pass proper fields

* adding hijack method and pass proper fields - adding tests

* improve hijack handling, use proper test for hijacking

* extend hijackhandler propagation to NewFastHTTPHandlerFunc

* align hijacking of fasthttp adaptor net request with fasthttp request, add safe conn handling for proper release of resources, and add a custom hijack handler for more control over the hijacking implementation

* Implement actual behaviour of net/http Hijacker

---------

Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
2024-02-17 14:51:38 +08:00

219 lines
6.0 KiB
Go

package fasthttpadaptor
import (
"io"
"net"
"net/http"
"net/url"
"reflect"
"testing"
"time"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
)
// TestNewFastHTTPHandler checks that a net/http handler wrapped with
// NewFastHTTPHandler observes the request fields set on the fasthttp request
// (method, protocol, URI, host, remote address, headers, body and a context
// value set via user values), and that the status code, headers, body and
// content type written through the http.ResponseWriter end up on the fasthttp
// response.
func TestNewFastHTTPHandler(t *testing.T) {
	t.Parallel()

	expectedMethod := fasthttp.MethodPost
	expectedProto := "HTTP/1.1"
	expectedProtoMajor := 1
	expectedProtoMinor := 1
	expectedRequestURI := "/foo/bar?baz=123"
	expectedBody := "<!doctype html><html>"
	expectedContentLength := len(expectedBody)
	expectedHost := "foobar.com"
	expectedRemoteAddr := "1.2.3.4:6789"
	expectedHeader := map[string]string{
		"Foo-Bar":         "baz",
		"Abc":             "defg",
		"XXX-Remote-Addr": "123.43.4543.345",
	}
	expectedURL, err := url.ParseRequestURI(expectedRequestURI)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	expectedContextKey := "contextKey"
	expectedContextValue := "contextValue"
	expectedContentType := "text/html; charset=utf-8"

	// The handler is invoked synchronously on the test goroutine below, so
	// calling t.Fatalf from inside it is safe.
	callsCount := 0
	nethttpH := func(w http.ResponseWriter, r *http.Request) {
		callsCount++
		if r.Method != expectedMethod {
			t.Fatalf("unexpected method %q. Expecting %q", r.Method, expectedMethod)
		}
		if r.Proto != expectedProto {
			t.Fatalf("unexpected proto %q. Expecting %q", r.Proto, expectedProto)
		}
		if r.ProtoMajor != expectedProtoMajor {
			t.Fatalf("unexpected protoMajor %d. Expecting %d", r.ProtoMajor, expectedProtoMajor)
		}
		if r.ProtoMinor != expectedProtoMinor {
			t.Fatalf("unexpected protoMinor %d. Expecting %d", r.ProtoMinor, expectedProtoMinor)
		}
		if r.RequestURI != expectedRequestURI {
			t.Fatalf("unexpected requestURI %q. Expecting %q", r.RequestURI, expectedRequestURI)
		}
		if r.ContentLength != int64(expectedContentLength) {
			t.Fatalf("unexpected contentLength %d. Expecting %d", r.ContentLength, expectedContentLength)
		}
		if len(r.TransferEncoding) != 0 {
			t.Fatalf("unexpected transferEncoding %q. Expecting []", r.TransferEncoding)
		}
		if r.Host != expectedHost {
			t.Fatalf("unexpected host %q. Expecting %q", r.Host, expectedHost)
		}
		if r.RemoteAddr != expectedRemoteAddr {
			t.Fatalf("unexpected remoteAddr %q. Expecting %q", r.RemoteAddr, expectedRemoteAddr)
		}
		body, err := io.ReadAll(r.Body)
		r.Body.Close()
		if err != nil {
			t.Fatalf("unexpected error when reading request body: %v", err)
		}
		if string(body) != expectedBody {
			t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
		}
		if !reflect.DeepEqual(r.URL, expectedURL) {
			t.Fatalf("unexpected URL: %#v. Expecting %#v", r.URL, expectedURL)
		}
		if r.Context().Value(expectedContextKey) != expectedContextValue {
			t.Fatalf("unexpected context value for key %q. Expecting %q", expectedContextKey, expectedContextValue)
		}
		for k, expectedV := range expectedHeader {
			v := r.Header.Get(k)
			if v != expectedV {
				t.Fatalf("unexpected header value %q for key %q. Expecting %q", v, k, expectedV)
			}
		}

		// Echo the request body back with a non-default status and two
		// custom headers so the response side can be verified below.
		w.Header().Set("Header1", "value1")
		w.Header().Set("Header2", "value2")
		w.WriteHeader(http.StatusBadRequest)
		w.Write(body) //nolint:errcheck
	}
	fasthttpH := NewFastHTTPHandler(http.HandlerFunc(nethttpH))
	fasthttpH = setContextValueMiddleware(fasthttpH, expectedContextKey, expectedContextValue)

	// Build the fasthttp request that the wrapped handler will receive.
	var ctx fasthttp.RequestCtx
	var req fasthttp.Request

	req.Header.SetMethod(expectedMethod)
	req.SetRequestURI(expectedRequestURI)
	req.Header.SetHost(expectedHost)
	req.BodyWriter().Write([]byte(expectedBody)) //nolint:errcheck
	for k, v := range expectedHeader {
		req.Header.Set(k, v)
	}

	remoteAddr, err := net.ResolveTCPAddr("tcp", expectedRemoteAddr)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	ctx.Init(&req, remoteAddr, nil)

	fasthttpH(&ctx)

	if callsCount != 1 {
		t.Fatalf("unexpected callsCount: %d. Expecting 1", callsCount)
	}

	resp := &ctx.Response
	if resp.StatusCode() != fasthttp.StatusBadRequest {
		t.Fatalf("unexpected statusCode: %d. Expecting %d", resp.StatusCode(), fasthttp.StatusBadRequest)
	}
	if string(resp.Header.Peek("Header1")) != "value1" {
		t.Fatalf("unexpected header value: %q. Expecting %q", resp.Header.Peek("Header1"), "value1")
	}
	if string(resp.Header.Peek("Header2")) != "value2" {
		t.Fatalf("unexpected header value: %q. Expecting %q", resp.Header.Peek("Header2"), "value2")
	}
	if string(resp.Body()) != expectedBody {
		t.Fatalf("unexpected response body %q. Expecting %q", resp.Body(), expectedBody)
	}
	if string(resp.Header.Peek("Content-Type")) != expectedContentType {
		// Report the expected content type (not the expected body) on failure.
		t.Fatalf("unexpected response content-type %q. Expecting %q", string(resp.Header.Peek("Content-Type")), expectedContentType)
	}
}
// setContextValueMiddleware returns a request handler that stores value under
// key as a user value on the request context and then invokes next.
func setContextValueMiddleware(next fasthttp.RequestHandler, key string, value any) fasthttp.RequestHandler {
	withValue := func(reqCtx *fasthttp.RequestCtx) {
		reqCtx.SetUserValue(key, value)
		next(reqCtx)
	}
	return withValue
}
// TestHijack checks that the http.ResponseWriter handed to a net/http handler
// by NewFastHTTPHandler implements http.Hijacker. The handler writes "foo"
// through the ResponseWriter, hijacks the connection, writes "bar" through the
// hijacked buffered writer and closes the connection; the client must read
// exactly "foobar".
func TestHijack(t *testing.T) {
	t.Parallel()

	nethttpH := func(w http.ResponseWriter, r *http.Request) {
		hj, ok := w.(http.Hijacker)
		if !ok {
			t.Errorf("expected http.ResponseWriter to implement http.Hijacker")
			return
		}

		if _, err := w.Write([]byte("foo")); err != nil {
			t.Error(err)
		}

		c, rw, err := hj.Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		if _, err := rw.Write([]byte("bar")); err != nil {
			t.Error(err)
		}
		if err := rw.Flush(); err != nil {
			t.Error(err)
		}
		if err := c.Close(); err != nil {
			t.Error(err)
		}
	}

	s := &fasthttp.Server{
		Handler: NewFastHTTPHandler(http.HandlerFunc(nethttpH)),
	}

	ln := fasthttputil.NewInmemoryListener()

	// serverCh is closed once Serve returns so the test can wait for the
	// server goroutine instead of leaking it past the end of the test.
	serverCh := make(chan struct{})
	go func() {
		defer close(serverCh)
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()

	clientCh := make(chan struct{})
	go func() {
		// Always signal completion, including on the early-return error paths,
		// so the test reports the real failure instead of a timeout.
		defer close(clientCh)

		c, err := ln.Dial()
		if err != nil {
			t.Errorf("unexpected error: %v", err)
			return // c is unusable; continuing would dereference a nil conn
		}
		defer c.Close() //nolint:errcheck

		if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
			t.Errorf("unexpected error: %v", err)
			return
		}

		// The handler closes the hijacked connection, which ends ReadAll.
		buf, err := io.ReadAll(c)
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}

		if string(buf) != "foobar" {
			t.Errorf("unexpected response: %q. Expecting %q", buf, "foobar")
		}
	}()

	select {
	case <-clientCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}

	// Closing the listener makes Serve return, letting the server goroutine
	// finish before the test does.
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverCh:
	case <-time.After(time.Second):
		t.Fatal("timeout waiting for server to stop")
	}
}