Files
fasthttp/server_test.go
T
newacorn 3aa972e2fc Fix issues with tests interfering with each other in certain situations. (#1842)
In some cases, goroutines started by one test do not terminate cleanly before the next round of tests begins, causing tests to interfere with each other.

Performance Impact: As a result, total test run time does not scale linearly with the -count value.

Correctness Impact: It affects the accuracy of memory allocation test cases.
2024-08-31 14:27:05 +02:00

4417 lines
107 KiB
Go

package fasthttp
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"mime/multipart"
"net"
"os"
"reflect"
"regexp"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/valyala/fasthttp/fasthttputil"
)
// Compile-time assertion: *RequestCtx must satisfy context.Context so it can
// be handed directly to APIs that expect a context.
var _ context.Context = (*RequestCtx)(nil)
// closerWithRequestCtx turns a closure into a value with a Close() error
// method; Close forwards the stored ctx to closeFunc.
type closerWithRequestCtx struct {
	ctx       *RequestCtx
	closeFunc func(ctx *RequestCtx) error
}

// Close calls closeFunc with the stored ctx and returns whatever it returns.
func (c *closerWithRequestCtx) Close() error {
	fn := c.closeFunc
	return fn(c.ctx)
}
// TestServerCRNLAfterPost_Pipeline checks that a bogus CRLF pair following a
// POST body does not break the next pipelined request on the same connection:
// both the POST and the following GET must be answered with 200 OK.
func TestServerCRNLAfterPost_Pipeline(t *testing.T) {
	t.Parallel()

	s := &Server{
		Handler: func(ctx *RequestCtx) {
		},
		Logger: &testLogger{},
	}

	ln := fasthttputil.NewInmemoryListener()
	serverCh := make(chan struct{})
	go func() {
		defer close(serverCh)
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()
	// Stop the server and wait for its goroutine before the test returns, so it
	// can neither call t.Errorf after completion nor leak into later test runs.
	defer func() {
		if err := ln.Close(); err != nil {
			t.Errorf("unexpected error when closing listener: %v", err)
		}
		select {
		case <-serverCh:
		case <-time.After(time.Second):
			t.Error("timeout when waiting for the server to stop")
		}
	}()

	c, err := ln.Dial()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer c.Close()

	if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: go.dev\r\nContent-Length: 3\r\n\r\nABC" +
		"\r\n\r\n" + // <-- this stuff is bogus, but we'll ignore it
		"GET / HTTP/1.1\r\nHost: go.dev\r\n\r\n")); err != nil {
		t.Fatal(err)
	}

	br := bufio.NewReader(c)
	var resp Response
	// One response for the POST, then one for the pipelined GET.
	for i := 0; i < 2; i++ {
		if err := resp.Read(br); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if resp.StatusCode() != StatusOK {
			t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
		}
	}
}
// TestServerCRNLAfterPost checks that a bogus CRLF pair after a POST body is
// ignored: the POST is answered with 200 OK and, since no further request is
// sent, the next response read must fail (the server drops the connection once
// ReadTimeout expires).
func TestServerCRNLAfterPost(t *testing.T) {
	t.Parallel()

	s := &Server{
		Handler: func(ctx *RequestCtx) {
		},
		Logger:      &testLogger{},
		ReadTimeout: time.Millisecond * 100,
	}

	ln := fasthttputil.NewInmemoryListener()
	serverCh := make(chan struct{})
	go func() {
		defer close(serverCh)
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()
	// Stop the server and wait for its goroutine before the test returns, so it
	// can neither call t.Errorf after completion nor leak into later test runs.
	defer func() {
		if err := ln.Close(); err != nil {
			t.Errorf("unexpected error when closing listener: %v", err)
		}
		select {
		case <-serverCh:
		case <-time.After(time.Second):
			t.Error("timeout when waiting for the server to stop")
		}
	}()

	c, err := ln.Dial()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer c.Close()

	if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: go.dev\r\nContent-Length: 3\r\n\r\nABC" +
		"\r\n\r\n", // <-- this stuff is bogus, but we'll ignore it
	)); err != nil {
		t.Fatal(err)
	}

	br := bufio.NewReader(c)
	var resp Response
	if err := resp.Read(br); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.StatusCode() != StatusOK {
		t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
	}
	if err := resp.Read(br); err == nil {
		t.Fatal("expected error") // We didn't send a request so we should get an error here.
	}
}
// TestServerPipelineFlush checks that the response to a complete request is
// flushed right away even when a second, still-incomplete request is already
// buffered on the same connection, and that the second request is answered
// once it is completed.
func TestServerPipelineFlush(t *testing.T) {
	t.Parallel()

	s := &Server{
		Handler: func(ctx *RequestCtx) {
		},
	}

	ln := fasthttputil.NewInmemoryListener()
	serverCh := make(chan struct{})
	go func() {
		defer close(serverCh)
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()
	// Stop the server and wait for its goroutine so it cannot outlive the test.
	defer func() {
		if err := ln.Close(); err != nil {
			t.Errorf("unexpected error when closing listener: %v", err)
		}
		select {
		case <-serverCh:
		case <-time.After(time.Second):
			t.Error("timeout when waiting for the server to stop")
		}
	}()

	c, err := ln.Dial()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer c.Close()

	if _, err = c.Write([]byte("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
		t.Fatal(err)
	}
	// Write a partial request.
	if _, err = c.Write([]byte("GET /foo1 HTTP/1.1\r\nHost: ")); err != nil {
		t.Fatal(err)
	}

	writerCh := make(chan struct{})
	go func() {
		defer close(writerCh)
		// Wait for 200ms to finish the request.
		time.Sleep(time.Millisecond * 200)
		// Use a goroutine-local err rather than the test goroutine's err.
		if _, err := c.Write([]byte("google.com\r\n\r\n")); err != nil {
			t.Error(err)
		}
	}()
	// Runs before the deferred c.Close, so the writer never reports an error
	// from a connection we closed ourselves, and never logs after the test ends.
	defer func() { <-writerCh }()

	start := time.Now()
	br := bufio.NewReader(c)
	var resp Response
	if err := resp.Read(br); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.StatusCode() != StatusOK {
		t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
	}
	// Since the second request takes 200ms to finish we expect the first one to be flushed earlier.
	if d := time.Since(start); d >= time.Millisecond*200 {
		t.Fatalf("had to wait for %v", d)
	}

	if err := resp.Read(br); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.StatusCode() != StatusOK {
		t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
	}
}
// TestServerInvalidHeader checks that a header with whitespace before the colon
// ("Foo : bar") makes the server answer 400 Bad Request, both for a POST with a
// body and for a bodyless GET, each on its own connection.
func TestServerInvalidHeader(t *testing.T) {
	t.Parallel()

	s := &Server{
		Handler: func(ctx *RequestCtx) {
			// The malformed header must never be exposed to a handler.
			if ctx.Request.Header.Peek("Foo") != nil || ctx.Request.Header.Peek("Foo ") != nil {
				t.Error("unexpected Foo header")
			}
		},
		Logger: &testLogger{},
	}

	ln := fasthttputil.NewInmemoryListener()
	serverCh := make(chan struct{})
	go func() {
		defer close(serverCh)
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()
	// Stop the server and wait for its goroutine so it cannot outlive the test.
	defer func() {
		if err := ln.Close(); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		select {
		case <-serverCh:
		case <-time.After(time.Second):
			t.Error("timeout when waiting for the server to stop")
		}
	}()

	// expectBadRequest sends rawReq on a fresh connection, expects a 400
	// response and always closes the connection (the original leaked the first one).
	expectBadRequest := func(rawReq string) {
		c, err := ln.Dial()
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		defer func() {
			if err := c.Close(); err != nil {
				t.Errorf("unexpected error: %v", err)
			}
		}()

		if _, err = c.Write([]byte(rawReq)); err != nil {
			t.Fatal(err)
		}
		br := bufio.NewReader(c)
		var resp Response
		if err := resp.Read(br); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if resp.StatusCode() != StatusBadRequest {
			t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusBadRequest)
		}
	}

	expectBadRequest("POST /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\nContent-Length: 5\r\n\r\n12345")
	expectBadRequest("GET /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\n\r\n")
}
// TestServerConnState checks the ConnState transitions reported for a single
// connection that serves two requests and is then closed by the client.
func TestServerConnState(t *testing.T) {
	t.Parallel()

	// ConnState is invoked from the server's goroutines while the test
	// goroutine reads the result, so the recorded states are mutex-guarded.
	var statesMu sync.Mutex
	states := make([]string, 0)
	s := &Server{
		Handler: func(ctx *RequestCtx) {},
		ConnState: func(_ net.Conn, state ConnState) {
			statesMu.Lock()
			states = append(states, state.String())
			statesMu.Unlock()
		},
	}

	ln := fasthttputil.NewInmemoryListener()
	serverCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverCh)
	}()

	clientCh := make(chan struct{})
	go func() {
		defer close(clientCh)

		c, err := ln.Dial()
		if err != nil {
			// c is nil on failure; stop instead of dereferencing it below.
			t.Errorf("unexpected error: %v", err)
			return
		}
		br := bufio.NewReader(c)
		// Send 2 requests on the same connection.
		for i := 0; i < 2; i++ {
			if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
				t.Errorf("unexpected error: %v", err)
			}
			var resp Response
			if err := resp.Read(br); err != nil {
				t.Errorf("unexpected error: %v", err)
			}
			if resp.StatusCode() != StatusOK {
				t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
			}
		}
		if err := c.Close(); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		// Give the server a little bit of time to transition the connection to the close state.
		time.Sleep(time.Millisecond * 100)
	}()

	select {
	case <-clientCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}

	statesMu.Lock()
	got := append([]string(nil), states...)
	statesMu.Unlock()

	// 2 requests so we go to active and idle twice.
	expected := []string{"new", "active", "idle", "active", "idle", "closed"}
	if !reflect.DeepEqual(expected, got) {
		t.Fatalf("wrong state, expected %q, got %q", expected, got)
	}
}
// TestSaveMultipartFile checks that SaveMultipartFile writes the exact
// contents of both a small file part and a larger part that multipart stored
// in an *os.File.
func TestSaveMultipartFile(t *testing.T) {
	t.Parallel()

	filea := "This is a test file."
	fileb := strings.Repeat("test", 64)
	mr := multipart.NewReader(strings.NewReader(""+
		"--foo\r\n"+
		"Content-Disposition: form-data; name=\"filea\"; filename=\"filea.txt\"\r\n"+
		"Content-Type: text/plain\r\n"+
		"\r\n"+
		filea+"\r\n"+
		"--foo\r\n"+
		"Content-Disposition: form-data; name=\"fileb\"; filename=\"fileb.txt\"\r\n"+
		"Content-Type: text/plain\r\n"+
		"\r\n"+
		fileb+"\r\n"+
		"--foo--\r\n",
	), "foo")
	// A small maxMemory so that fileb ends up in a temporary file (checked below).
	f, err := mr.ReadForm(64)
	if err != nil {
		t.Fatal(err)
	}
	// Clean up any temporary files ReadForm created. This is best-effort
	// cleanup only, so its error is deliberately ignored.
	defer f.RemoveAll() //nolint:errcheck

	if err := SaveMultipartFile(f.File["filea"][0], "filea.txt"); err != nil {
		t.Fatal(err)
	}
	defer os.Remove("filea.txt")
	if c, err := os.ReadFile("filea.txt"); err != nil {
		t.Fatal(err)
	} else if string(c) != filea {
		t.Fatalf("filea changed expected %q got %q", filea, c)
	}

	// Make sure fileb was saved to a file.
	if ff, err := f.File["fileb"][0].Open(); err != nil {
		t.Fatalf("expected FileHeader.Open to work: %v", err)
	} else if _, ok := ff.(*os.File); !ok {
		t.Fatal("expected fileb to be an os.File")
	} else {
		ff.Close()
	}

	if err := SaveMultipartFile(f.File["fileb"][0], "fileb.txt"); err != nil {
		t.Fatal(err)
	}
	defer os.Remove("fileb.txt")
	if c, err := os.ReadFile("fileb.txt"); err != nil {
		t.Fatal(err)
	} else if string(c) != fileb {
		t.Fatalf("fileb changed expected %q got %q", fileb, c)
	}
}
// TestServerName checks the Server response header for the default name, a
// custom Name, and NoDefaultServerHeader (together with the NoDefaultContentType
// and NoDefaultDate options, which must suppress their headers too).
func TestServerName(t *testing.T) {
	t.Parallel()

	// serve pushes one GET request through srv via ServeConn and returns the
	// raw bytes the server wrote back.
	serve := func(srv *Server) []byte {
		rw := &readWriter{}
		rw.r.WriteString("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")
		if err := srv.ServeConn(rw); err != nil {
			t.Fatalf("Unexpected error from serveConn: %v", err)
		}
		out, err := io.ReadAll(&rw.w)
		if err != nil {
			t.Fatalf("Unexpected error from ReadAll: %v", err)
		}
		return out
	}
	noop := func(ctx *RequestCtx) {}

	out := serve(&Server{Handler: noop})
	if !bytes.Contains(out, []byte("\r\nServer: "+defaultServerName+"\r\n")) {
		t.Fatalf("Unexpected response %q expected Server: "+defaultServerName, out)
	}

	// A fresh Server is used per case: the name is cached (in an atomic.Value)
	// after first use, so mutating s.Name would not take effect.
	out = serve(&Server{Handler: noop, Name: "foobar"})
	if !bytes.Contains(out, []byte("\r\nServer: foobar\r\n")) {
		t.Fatalf("Unexpected response %q expected Server: foobar", out)
	}

	out = serve(&Server{
		Handler:               noop,
		NoDefaultServerHeader: true,
		NoDefaultContentType:  true,
		NoDefaultDate:         true,
	})
	switch {
	case bytes.Contains(out, []byte("\r\nServer: ")):
		t.Fatalf("Unexpected response %q expected no Server header", out)
	case bytes.Contains(out, []byte("\r\nContent-Type: ")):
		t.Fatalf("Unexpected response %q expected no Content-Type header", out)
	case bytes.Contains(out, []byte("\r\nDate: ")):
		t.Fatalf("Unexpected response %q expected no Date header", out)
	}
}
// TestRequestCtxString checks RequestCtx.String for a zero-value context and
// after a request URI has been set on the same context.
func TestRequestCtxString(t *testing.T) {
	t.Parallel()

	var ctx RequestCtx
	steps := []struct {
		uri  string // set on ctx.Request before checking; empty means leave as is
		want string
	}{
		{"", "#0000000000000000 - 0.0.0.0:0<->0.0.0.0:0 - GET http:///"},
		{"https://foobar.com/aaa?bb=c", "#0000000000000000 - 0.0.0.0:0<->0.0.0.0:0 - GET https://foobar.com/aaa?bb=c"},
	}
	for _, step := range steps {
		if step.uri != "" {
			ctx.Request.SetRequestURI(step.uri)
		}
		if got := ctx.String(); got != step.want {
			t.Fatalf("unexpected ctx.String: %q. Expecting %q", got, step.want)
		}
	}
}
// TestServerErrSmallBuffer checks that a request whose headers do not fit in
// ReadBufferSize makes ServeConn fail with errSmallBuffer. The client must get
// 431 Request Header Fields Too Large with Connection: close.
func TestServerErrSmallBuffer(t *testing.T) {
	t.Parallel()

	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.WriteString("shouldn't be never called") //nolint:errcheck
		},
		ReadBufferSize: 20,
	}
	rw := &readWriter{}
	rw.r.WriteString("GET / HTTP/1.1\r\nHost: aabb.com\r\nVERY-long-Header: sdfdfsd dsf dsaf dsf df fsd\r\n\r\n")

	// Buffered so the ServeConn goroutine can always deliver its result and
	// exit, even if the test has already given up on timeout.
	ch := make(chan error, 1)
	go func() {
		ch <- s.ServeConn(rw)
	}()

	var serverErr error
	select {
	case serverErr = <-ch:
	case <-time.After(200 * time.Millisecond):
		t.Fatal("timeout")
	}
	if serverErr == nil {
		t.Fatal("expected error")
	}

	br := bufio.NewReader(&rw.w)
	var resp Response
	if err := resp.Read(br); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	statusCode := resp.StatusCode()
	if statusCode != StatusRequestHeaderFieldsTooLarge {
		t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusRequestHeaderFieldsTooLarge)
	}
	if !resp.ConnectionClose() {
		t.Fatal("missing 'Connection: close' response header")
	}
	expectedErr := errSmallBuffer.Error()
	if !strings.Contains(serverErr.Error(), expectedErr) {
		t.Fatalf("unexpected log output: %v. Expecting %q", serverErr, expectedErr)
	}
}
// TestRequestCtxIsTLS checks IsTLS for several connection types:
//   - a plain *tls.Conn;
//   - a non-TLS conn;
//   - a struct embedding *tls.Conn;
//   - a perIPConn wrapping a *tls.Conn.
func TestRequestCtxIsTLS(t *testing.T) {
	t.Parallel()

	var ctx RequestCtx
	for _, tc := range []struct {
		conn net.Conn
		want bool
	}{
		// tls.Conn
		{conn: &tls.Conn{}, want: true},
		// non-tls.Conn
		{conn: &readWriter{}, want: false},
		// overridden tls.Conn
		{conn: &struct {
			*tls.Conn
			fooBar bool
		}{}, want: true},
		// tls.Conn wrapped in perIPConn
		{conn: &perIPConn{Conn: &tls.Conn{}}, want: true},
	} {
		ctx.c = tc.conn
		got := ctx.IsTLS()
		switch {
		case tc.want && !got:
			t.Fatal("IsTLS must return true")
		case !tc.want && got:
			t.Fatal("IsTLS must return false")
		}
	}
}
// TestRequestCtxRedirectHTTPSSchemeless checks that redirecting a TLS request to
// a scheme-relative URL ("//host/path") produces an https Location header.
func TestRequestCtxRedirectHTTPSSchemeless(t *testing.T) {
	t.Parallel()

	const rawReq = "GET /foo/bar?baz HTTP/1.1\nHost: aaa.com\n\n"
	const want = "https://foobar.com/aa/bbb"

	var ctx RequestCtx
	if err := ctx.Request.Read(bufio.NewReader(bytes.NewBufferString(rawReq))); err != nil {
		t.Fatalf("cannot read request: %v", err)
	}
	// Mark the request as received over TLS.
	ctx.Request.isTLS = true
	ctx.Redirect("//foobar.com/aa/bbb", StatusFound)

	if got := ctx.Response.Header.Peek(HeaderLocation); string(got) != want {
		t.Fatalf("Unexpected location: %q. Expecting %q", got, want)
	}
}
// TestRequestCtxRedirect checks how RequestCtx.Redirect resolves several kinds
// of redirect targets against the original URL:
//   - empty targets, fragments and query strings;
//   - relative, absolute and scheme-relative URLs;
//   - on non-Windows platforms only, targets containing ".." segments.
func TestRequestCtxRedirect(t *testing.T) {
	t.Parallel()

	// Each case is {origURL, redirectURL, expectedURL}.
	cases := [][3]string{
		{"http://qqq/", "", "http://qqq/"},
		{"http://qqq/foo/bar?baz=111", "", "http://qqq/foo/bar?baz=111"},
		{"http://qqq/foo/bar?baz=111", "#aaa", "http://qqq/foo/bar?baz=111#aaa"},
		{"http://qqq/foo/bar?baz=111", "?abc=de&f", "http://qqq/foo/bar?abc=de&f"},
		{"http://qqq/foo/bar?baz=111", "?abc=de&f#sf", "http://qqq/foo/bar?abc=de&f#sf"},
		{"http://qqq/foo/bar?baz=111", "x.html", "http://qqq/foo/x.html"},
		{"http://qqq/foo/bar?baz=111", "x.html?a=1", "http://qqq/foo/x.html?a=1"},
		{"http://qqq/foo/bar?baz=111", "x.html#aaa=bbb&cc=ddd", "http://qqq/foo/x.html#aaa=bbb&cc=ddd"},
		{"http://qqq/foo/bar?baz=111", "x.html?b=1#aaa=bbb&cc=ddd", "http://qqq/foo/x.html?b=1#aaa=bbb&cc=ddd"},
		{"http://qqq/foo/bar?baz=111", "/x.html", "http://qqq/x.html"},
		{"http://qqq/foo/bar?baz=111", "/x.html#aaa=bbb&cc=ddd", "http://qqq/x.html#aaa=bbb&cc=ddd"},
		{"http://qqq/foo/bar?baz=111", "http://foo.bar/baz", "http://foo.bar/baz"},
		{"http://qqq/foo/bar?baz=111", "https://foo.bar/baz", "https://foo.bar/baz"},
		{"https://foo.com/bar?aaa", "//google.com/aaa?bb", "https://google.com/aaa?bb"},
	}
	if runtime.GOOS != "windows" {
		cases = append(cases,
			[3]string{"http://qqq/foo/bar?baz=111", "../x.html", "http://qqq/x.html"},
			[3]string{"http://qqq/foo/bar?baz=111", "../../x.html", "http://qqq/x.html"},
			[3]string{"http://qqq/foo/bar?baz=111", "./.././../x.html", "http://qqq/x.html"},
		)
	}

	for _, c := range cases {
		testRequestCtxRedirect(t, c[0], c[1], c[2])
	}
}
// testRequestCtxRedirect calls Redirect(redirectURL, StatusFound) on a context
// initialized with a request for origURL. It fails the test unless the
// resulting Location header equals expectedURL.
func testRequestCtxRedirect(t *testing.T, origURL, redirectURL, expectedURL string) {
	// Mark as helper so failures are reported at the caller's line.
	t.Helper()

	var ctx RequestCtx
	var req Request
	req.SetRequestURI(origURL)
	ctx.Init(&req, nil, nil)
	ctx.Redirect(redirectURL, StatusFound)
	loc := ctx.Response.Header.Peek(HeaderLocation)
	if string(loc) != expectedURL {
		t.Fatalf("unexpected redirect url %q. Expecting %q. origURL=%q, redirectURL=%q", loc, expectedURL, origURL, redirectURL)
	}
}
// TestServerResponseServerHeader checks that the configured Server name is
// visible to the handler. It must also still be sent to the client after the
// handler replaces the response via ctx.NotFound().
func TestServerResponseServerHeader(t *testing.T) {
	t.Parallel()

	serverName := "foobar serv"
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			name := ctx.Response.Header.Server()
			if string(name) != serverName {
				fmt.Fprintf(ctx, "unexpected server name: %q. Expecting %q", name, serverName)
			} else {
				ctx.WriteString("OK") //nolint:errcheck
			}
			// make sure the server name is sent to the client after ctx.Response.Reset()
			ctx.NotFound()
		},
		Name: serverName,
	}

	ln := fasthttputil.NewInmemoryListener()
	serverCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverCh)
	}()

	clientCh := make(chan struct{})
	go func() {
		defer close(clientCh)

		c, err := ln.Dial()
		if err != nil {
			// c is nil on failure; stop instead of dereferencing it below.
			t.Errorf("unexpected error: %v", err)
			return
		}
		if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		br := bufio.NewReader(c)
		var resp Response
		if err = resp.Read(br); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if resp.StatusCode() != StatusNotFound {
			t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusNotFound)
		}
		if string(resp.Body()) != "404 Page not found" {
			t.Errorf("unexpected body: %q. Expecting %q", resp.Body(), "404 Page not found")
		}
		if string(resp.Header.Server()) != serverName {
			t.Errorf("unexpected server header: %q. Expecting %q", resp.Header.Server(), serverName)
		}
		if err = c.Close(); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()

	select {
	case <-clientCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServerResponseBodyStream checks streamed (chunked) responses. The chunk
// flushed by the stream writer must reach the client before the writer is
// unblocked, and the remaining chunk plus terminator must arrive afterwards.
func TestServerResponseBodyStream(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	readyCh := make(chan struct{})
	h := func(ctx *RequestCtx) {
		ctx.SetConnectionClose()
		// The handler runs on a server goroutine, where t.Fatal must not be
		// used (FailNow is only valid on the test goroutine); use t.Error.
		if ctx.IsBodyStream() {
			t.Error("IsBodyStream must return false")
		}
		ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
			fmt.Fprintf(w, "first")
			if err := w.Flush(); err != nil {
				return
			}
			<-readyCh
			fmt.Fprintf(w, "second")
			// there is no need to flush w here, since it will
			// be flushed automatically after returning from StreamWriter.
		})
		if !ctx.IsBodyStream() {
			t.Error("IsBodyStream must return true")
		}
	}

	serverCh := make(chan struct{})
	go func() {
		if err := Serve(ln, h); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverCh)
	}()

	clientCh := make(chan struct{})
	go func() {
		defer close(clientCh)

		c, err := ln.Dial()
		if err != nil {
			// c is nil on failure; stop instead of dereferencing it below.
			t.Errorf("unexpected error: %v", err)
			return
		}
		defer c.Close()

		if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		br := bufio.NewReader(c)
		var respH ResponseHeader
		if err = respH.Read(br); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if respH.StatusCode() != StatusOK {
			t.Errorf("unexpected status code: %d. Expecting %d", respH.StatusCode(), StatusOK)
		}

		buf := make([]byte, 1024)
		n, err := br.Read(buf)
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		b := buf[:n]
		if string(b) != "5\r\nfirst\r\n" {
			t.Errorf("unexpected result %q. Expecting %q", b, "5\r\nfirst\r\n")
		}
		close(readyCh)

		tail, err := io.ReadAll(br)
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if string(tail) != "6\r\nsecond\r\n0\r\n\r\n" {
			t.Errorf("unexpected tail %q. Expecting %q", tail, "6\r\nsecond\r\n0\r\n\r\n")
		}
	}()

	select {
	case <-clientCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServerDisableKeepalive checks that with DisableKeepalive the server
// answers a request with "Connection: close" and then closes the connection,
// so nothing further can be read from it.
func TestServerDisableKeepalive(t *testing.T) {
	t.Parallel()

	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.WriteString("OK") //nolint:errcheck
		},
		DisableKeepalive: true,
	}

	ln := fasthttputil.NewInmemoryListener()
	serverCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverCh)
	}()

	clientCh := make(chan struct{})
	go func() {
		defer close(clientCh)

		c, err := ln.Dial()
		if err != nil {
			// c is nil on failure; stop instead of dereferencing it below.
			t.Errorf("unexpected error: %v", err)
			return
		}
		defer c.Close()

		if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		br := bufio.NewReader(c)
		var resp Response
		if err = resp.Read(br); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if resp.StatusCode() != StatusOK {
			t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
		}
		if !resp.ConnectionClose() {
			t.Error("expecting 'Connection: close' response header")
		}
		if string(resp.Body()) != "OK" {
			t.Errorf("unexpected body: %q. Expecting %q", resp.Body(), "OK")
		}

		// make sure the connection is closed
		data, err := io.ReadAll(br)
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if len(data) > 0 {
			t.Errorf("unexpected data read from the connection: %q. Expecting empty data", data)
		}
	}()

	select {
	case <-clientCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServerMaxConnsPerIPLimit checks the MaxConnsPerIP: 1 limit. A second
// connection from the same (faked) remote IP must be answered with 429, while
// the first connection is still served normally.
func TestServerMaxConnsPerIPLimit(t *testing.T) {
	t.Parallel()

	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.WriteString("OK") //nolint:errcheck
		},
		MaxConnsPerIP: 1,
		Logger:        &testLogger{},
	}

	ln := fasthttputil.NewInmemoryListener()
	serverCh := make(chan struct{})
	go func() {
		fakeLN := &fakeIPListener{
			Listener: ln,
		}
		if err := s.Serve(fakeLN); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverCh)
	}()

	clientCh := make(chan struct{})
	go func() {
		defer close(clientCh)

		// A failed Dial returns a nil conn, so stop instead of dereferencing it.
		c1, err := ln.Dial()
		if err != nil {
			t.Errorf("unexpected error: %v", err)
			return
		}
		defer c1.Close()
		c2, err := ln.Dial()
		if err != nil {
			t.Errorf("unexpected error: %v", err)
			return
		}
		defer c2.Close()

		br := bufio.NewReader(c2)
		var resp Response
		if err = resp.Read(br); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if resp.StatusCode() != StatusTooManyRequests {
			t.Errorf("unexpected status code for the second connection: %d. Expecting %d",
				resp.StatusCode(), StatusTooManyRequests)
		}

		if _, err = c1.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
			t.Errorf("unexpected error when writing to the first connection: %v", err)
		}
		br = bufio.NewReader(c1)
		if err = resp.Read(br); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if resp.StatusCode() != StatusOK {
			t.Errorf("unexpected status code for the first connection: %d. Expecting %d",
				resp.StatusCode(), StatusOK)
		}
		if string(resp.Body()) != "OK" {
			t.Errorf("unexpected body for the first connection: %q. Expecting %q", resp.Body(), "OK")
		}
	}()

	select {
	case <-clientCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// fakeIPListener wraps a net.Listener and returns every accepted connection
// wrapped in a fakeIPConn.
type fakeIPListener struct {
	net.Listener
}

// Accept waits for the next connection on the wrapped listener and wraps it
// in a fakeIPConn; on error it returns a nil conn and the error.
func (ln *fakeIPListener) Accept() (net.Conn, error) {
	c, err := ln.Listener.Accept()
	if err == nil {
		return &fakeIPConn{Conn: c}, nil
	}
	return nil, err
}
// fakeIPConn wraps a net.Conn and overrides RemoteAddr with a fixed address.
type fakeIPConn struct {
	net.Conn
}

// RemoteAddr always reports 1.2.3.4:5789 as a TCP address. It panics if that
// literal cannot be resolved.
func (conn *fakeIPConn) RemoteAddr() net.Addr {
	addr, err := net.ResolveTCPAddr("tcp4", "1.2.3.4:5789")
	if err == nil {
		return addr
	}
	panic(fmt.Sprintf("BUG: unexpected error: %v", err))
}
// TestServerConcurrencyLimit checks the Concurrency: 1 limit. A second
// connection opened while the first is being served must be answered with 503,
// while the first connection is still served normally.
func TestServerConcurrencyLimit(t *testing.T) {
	t.Parallel()

	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.WriteString("OK") //nolint:errcheck
		},
		Concurrency: 1,
		Logger:      &testLogger{},
	}

	ln := fasthttputil.NewInmemoryListener()
	serverCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverCh)
	}()

	clientCh := make(chan struct{})
	go func() {
		defer close(clientCh)

		// A failed Dial returns a nil conn, so stop instead of dereferencing it.
		c1, err := ln.Dial()
		if err != nil {
			t.Errorf("unexpected error: %v", err)
			return
		}
		defer c1.Close()
		c2, err := ln.Dial()
		if err != nil {
			t.Errorf("unexpected error: %v", err)
			return
		}
		defer c2.Close()

		br := bufio.NewReader(c2)
		var resp Response
		if err = resp.Read(br); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if resp.StatusCode() != StatusServiceUnavailable {
			t.Errorf("unexpected status code for the second connection: %d. Expecting %d",
				resp.StatusCode(), StatusServiceUnavailable)
		}

		if _, err = c1.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
			t.Errorf("unexpected error when writing to the first connection: %v", err)
		}
		br = bufio.NewReader(c1)
		if err = resp.Read(br); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if resp.StatusCode() != StatusOK {
			t.Errorf("unexpected status code for the first connection: %d. Expecting %d",
				resp.StatusCode(), StatusOK)
		}
		if string(resp.Body()) != "OK" {
			t.Errorf("unexpected body for the first connection: %q. Expecting %q", resp.Body(), "OK")
		}
	}()

	select {
	case <-clientCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestRejectedRequestsCount checks GetRejectedConnectionsCount with
// Concurrency: 1. After expectedCount+1 connections are dialed, exactly
// expectedCount of them must be counted as rejected.
func TestRejectedRequestsCount(t *testing.T) {
	t.Parallel()

	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.WriteString("OK") //nolint:errcheck
		},
		Concurrency: 1,
		Logger:      &testLogger{},
	}

	ln := fasthttputil.NewInmemoryListener()
	serverCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverCh)
	}()

	clientCh := make(chan struct{})
	expectedCount := 5
	go func() {
		defer close(clientCh)

		conns := make([]net.Conn, 0, expectedCount+1)
		// Close every dialed connection when done so none outlive the test
		// (in particular the one still held by the single worker).
		defer func() {
			for _, c := range conns {
				c.Close() //nolint:errcheck // best-effort cleanup
			}
		}()
		for i := 0; i < expectedCount+1; i++ {
			c, err := ln.Dial()
			if err != nil {
				t.Errorf("unexpected error: %v", err)
				continue
			}
			conns = append(conns, c)
		}

		// The server's worker pool is a separate goroutine, so the counter is
		// updated asynchronously. Poll it with a deadline instead of relying on
		// a single fixed sleep, which made the test fail from time to time.
		deadline := time.Now().Add(500 * time.Millisecond)
		for s.GetRejectedConnectionsCount() < uint32(expectedCount) && time.Now().Before(deadline) {
			time.Sleep(time.Millisecond)
		}
		if cnt := s.GetRejectedConnectionsCount(); cnt != uint32(expectedCount) {
			t.Errorf("unexpected rejected connections count: %d. Expecting %d",
				cnt, expectedCount)
		}
	}()

	select {
	case <-clientCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServerWriteFastError checks the response produced by writeFastError. It
// must carry the given status and body, the server's Name, a text/plain
// content type, and Connection: close.
func TestServerWriteFastError(t *testing.T) {
	t.Parallel()

	const expectedBody = "access denied"
	s := &Server{Name: "foobar"}

	var out bytes.Buffer
	s.writeFastError(&out, StatusForbidden, expectedBody)

	var resp Response
	if err := resp.Read(bufio.NewReader(&out)); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if code := resp.StatusCode(); code != StatusForbidden {
		t.Fatalf("unexpected status code: %d. Expecting %d", code, StatusForbidden)
	}
	if body := resp.Body(); string(body) != expectedBody {
		t.Fatalf("unexpected body: %q. Expecting %q", body, expectedBody)
	}
	if name := string(resp.Header.Server()); name != s.Name {
		t.Fatalf("unexpected server: %q. Expecting %q", name, s.Name)
	}
	if ct := string(resp.Header.ContentType()); ct != "text/plain" {
		t.Fatalf("unexpected content-type: %q. Expecting %q", ct, "text/plain")
	}
	if !resp.Header.ConnectionClose() {
		t.Fatal("expecting 'Connection: close' response header")
	}
}
// TestServerTLS checks end-to-end request handling over TLS. The server uses an
// embedded test certificate and a Client skips certificate verification.
func TestServerTLS(t *testing.T) {
	t.Parallel()

	text := []byte("Make fasthttp great again")
	ln := fasthttputil.NewInmemoryListener()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.Write(text) //nolint:errcheck
		},
	}

	certData, keyData, err := GenerateTestCertificate("localhost")
	if err != nil {
		t.Fatal(err)
	}
	if err = s.AppendCertEmbed(certData, keyData); err != nil {
		t.Fatal(err)
	}

	serverCh := make(chan struct{})
	go func() {
		defer close(serverCh)
		// Goroutine-local err: the test goroutine keeps assigning the outer err
		// below, and sharing it would be a data race.
		if err := s.ServeTLS(ln, "", ""); err != nil {
			t.Error(err)
		}
	}()
	// Stop the server and wait for its goroutine so it cannot outlive the test.
	defer func() {
		if err := ln.Close(); err != nil {
			t.Errorf("unexpected error when closing listener: %v", err)
		}
		select {
		case <-serverCh:
		case <-time.After(time.Second):
			t.Error("timeout when waiting for the server to stop")
		}
	}()

	c := &Client{
		ReadTimeout: time.Second * 2,
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
		TLSConfig: &tls.Config{
			InsecureSkipVerify: true,
		},
	}

	req, res := AcquireRequest(), AcquireResponse()
	defer ReleaseRequest(req)
	defer ReleaseResponse(res)

	req.SetRequestURI("https://some.url")
	if err = c.Do(req, res); err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(text, res.Body()) {
		t.Fatal("error transmitting information")
	}
}
// TestServerTLSReadTimeout checks that a TLS server closes a connection on
// which the client never sends anything, once ReadTimeout has elapsed.
func TestServerTLSReadTimeout(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	s := &Server{
		ReadTimeout: time.Millisecond * 500,
		Logger:      &testLogger{}, // Ignore log output.
		Handler: func(ctx *RequestCtx) {
		},
	}

	certData, keyData, err := GenerateTestCertificate("localhost")
	if err != nil {
		t.Fatal(err)
	}
	if err = s.AppendCertEmbed(certData, keyData); err != nil {
		t.Fatal(err)
	}

	serverCh := make(chan struct{})
	go func() {
		defer close(serverCh)
		// Goroutine-local err avoids a data race with the test goroutine's err.
		if err := s.ServeTLS(ln, "", ""); err != nil {
			t.Error(err)
		}
	}()
	// Stop the server and wait for its goroutine so it cannot outlive the test.
	defer func() {
		if err := ln.Close(); err != nil {
			t.Errorf("unexpected error when closing listener: %v", err)
		}
		select {
		case <-serverCh:
		case <-time.After(time.Second):
			t.Error("timeout when waiting for the server to stop")
		}
	}()

	c, err := ln.Dial()
	if err != nil {
		// c is nil on failure, so the read below cannot proceed.
		t.Fatal(err)
	}

	// The client never writes anything (not even a TLS handshake); the read
	// is expected to fail once the server closes the connection.
	r := make(chan error, 1) // buffered so the reader can always exit
	go func() {
		b := make([]byte, 1)
		_, err := c.Read(b)
		c.Close()
		r <- err
	}()

	select {
	case err = <-r:
	case <-time.After(time.Second * 2):
		// Unblock the reader so it does not outlive the test. err is still nil
		// from Dial here, so the failure is reported below.
		c.Close()
	}
	if err == nil {
		t.Error("server didn't close connection after timeout")
	}
}
// TestServerServeTLSEmbed checks ServeTLSEmbed. Over a TLS client connection
// the handler must see IsTLS() and an https URI, and the client must receive
// the "success" body.
func TestServerServeTLSEmbed(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	certData, keyData, err := GenerateTestCertificate("localhost")
	if err != nil {
		t.Fatal(err)
	}

	// start the server
	ch := make(chan struct{})
	go func() {
		err := ServeTLSEmbed(ln, certData, keyData, func(ctx *RequestCtx) {
			if !ctx.IsTLS() {
				ctx.Error("expecting tls", StatusBadRequest)
				return
			}
			if !ctx.URI().isHTTPS() {
				ctx.Error(fmt.Sprintf("unexpected scheme=%q. Expecting %q", ctx.URI().Scheme(), "https"), StatusBadRequest)
				return
			}
			ctx.WriteString("success") //nolint:errcheck
		})
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(ch)
	}()

	// establish connection to the server
	conn, err := ln.Dial()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	tlsConn := tls.Client(conn, &tls.Config{
		InsecureSkipVerify: true,
	})
	defer tlsConn.Close()

	// send request
	if _, err = tlsConn.Write([]byte("GET / HTTP/1.1\r\nHost: aaa\r\n\r\n")); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// read response
	respCh := make(chan struct{})
	go func() {
		defer close(respCh)
		br := bufio.NewReader(tlsConn)
		var resp Response
		if err := resp.Read(br); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		body := resp.Body()
		if string(body) != "success" {
			t.Errorf("unexpected response body %q. Expecting %q", body, "success")
		}
	}()
	select {
	case <-respCh:
	case <-time.After(time.Second):
		// Unblock the reader and wait for it, so it cannot log after the test ends.
		tlsConn.Close() //nolint:errcheck
		<-respCh
		t.Fatal("timeout")
	}

	// close the server
	if err = ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-ch:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServerMultipartFormDataRequest checks multipart form handling for every
// combination of StreamRequestBody and DisablePreParseMultipartForm. A
// multipart upload is pipelined with a plain GET on one connection. The upload
// must be parsed (one value, one file) and redirected with 303; the GET must
// then be answered normally.
func TestServerMultipartFormDataRequest(t *testing.T) {
	t.Parallel()

	for _, test := range []struct {
		StreamRequestBody            bool
		DisablePreParseMultipartForm bool
	}{
		{StreamRequestBody: false, DisablePreParseMultipartForm: false},
		{StreamRequestBody: false, DisablePreParseMultipartForm: true},
		{StreamRequestBody: true, DisablePreParseMultipartForm: false},
		{StreamRequestBody: true, DisablePreParseMultipartForm: true},
	} {
		reqS := `POST /upload HTTP/1.1
Host: qwerty.com
Content-Length: 521
Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg
------WebKitFormBoundaryJwfATyF8tmxSJnLg
Content-Disposition: form-data; name="f1"
value1
------WebKitFormBoundaryJwfATyF8tmxSJnLg
Content-Disposition: form-data; name="fileaaa"; filename="TODO"
Content-Type: application/octet-stream
- SessionClient with referer and cookies support.
- Client with requests' pipelining support.
- ProxyHandler similar to FSHandler.
- WebSockets. See https://tools.ietf.org/html/rfc6455 .
- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 .
------WebKitFormBoundaryJwfATyF8tmxSJnLg--
GET / HTTP/1.1
Host: asbd
Connection: close
`
		ln := fasthttputil.NewInmemoryListener()
		s := &Server{
			StreamRequestBody:            test.StreamRequestBody,
			DisablePreParseMultipartForm: test.DisablePreParseMultipartForm,
			Handler: func(ctx *RequestCtx) {
				switch string(ctx.Path()) {
				case "/upload":
					f, err := ctx.MultipartForm()
					if err != nil {
						// f is nil on error; stop instead of dereferencing it
						// (which would panic inside the server goroutine).
						t.Errorf("unexpected error: %v", err)
						return
					}
					if len(f.Value) != 1 {
						t.Errorf("unexpected values %d. Expecting %d", len(f.Value), 1)
					}
					if len(f.File) != 1 {
						t.Errorf("unexpected file values %d. Expecting %d", len(f.File), 1)
					}
					fv := ctx.FormValue("f1")
					if string(fv) != "value1" {
						t.Errorf("unexpected form value: %q. Expecting %q", fv, "value1")
					}
					ctx.Redirect("/", StatusSeeOther)
				default:
					ctx.WriteString("non-upload") //nolint:errcheck
				}
			},
		}

		ch := make(chan struct{})
		go func() {
			if err := s.Serve(ln); err != nil {
				t.Errorf("unexpected error: %v", err)
			}
			close(ch)
		}()

		conn, err := ln.Dial()
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if _, err = conn.Write([]byte(reqS)); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}

		var resp Response
		br := bufio.NewReader(conn)
		respCh := make(chan struct{})
		go func() {
			if err := resp.Read(br); err != nil {
				t.Errorf("error when reading response: %v", err)
			}
			if resp.StatusCode() != StatusSeeOther {
				t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusSeeOther)
			}
			loc := resp.Header.Peek(HeaderLocation)
			if string(loc) != "http://qwerty.com/" {
				t.Errorf("unexpected location %q. Expecting %q", loc, "http://qwerty.com/")
			}
			if err := resp.Read(br); err != nil {
				t.Errorf("error when reading the second response: %v", err)
			}
			if resp.StatusCode() != StatusOK {
				t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
			}
			body := resp.Body()
			if string(body) != "non-upload" {
				t.Errorf("unexpected body %q. Expecting %q", body, "non-upload")
			}
			close(respCh)
		}()

		select {
		case <-respCh:
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
		// Release the client side of this iteration's connection before
		// shutting the server down (defer would only run at test end).
		conn.Close() //nolint:errcheck
		if err := ln.Close(); err != nil {
			t.Fatalf("error when closing listener: %v", err)
		}
		select {
		case <-ch:
		case <-time.After(time.Second):
			t.Fatal("timeout when waiting for the server to stop")
		}
	}
}
func TestServerGetWithContent(t *testing.T) {
t.Parallel()
h := func(ctx *RequestCtx) {
ctx.Success("foo/bar", []byte("success"))
}
s := &Server{
Handler: h,
}
rw := &readWriter{}
rw.r.WriteString("GET / HTTP/1.1\r\nHost: mm.com\r\nContent-Length: 5\r\n\r\nabcde")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}
resp := rw.w.String()
if !strings.HasSuffix(resp, "success") {
t.Fatalf("unexpected response %q.", resp)
}
}
// TestServerDisableHeaderNamesNormalizing checks that with
// DisableHeaderNamesNormalizing set, header names keep their exact case in
// both directions. The handler must see the request header only under its
// original mixed-case name. The response header must come back under that
// same name, and not under its lower-cased form.
func TestServerDisableHeaderNamesNormalizing(t *testing.T) {
	t.Parallel()

	const (
		name  = "CASE-senSITive-HEAder-NAME"
		value = "foobar baz"
	)
	lowerName := strings.ToLower(name)

	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			if got := ctx.Request.Header.Peek(name); string(got) != value {
				t.Errorf("unexpected header value for %q: %q. Expecting %q", name, got, value)
			}
			if got := ctx.Request.Header.Peek(lowerName); len(got) > 0 {
				t.Errorf("unexpected header value for %q: %q. Expecting empty value", lowerName, got)
			}
			ctx.Response.Header.Set(name, value)
			ctx.WriteString("ok") //nolint:errcheck
			ctx.SetContentType("aaa")
		},
		DisableHeaderNamesNormalizing: true,
	}

	conn := &readWriter{}
	fmt.Fprintf(&conn.r, "GET / HTTP/1.1\r\n%s: %s\r\nHost: google.com\r\n\r\n", name, value)
	if err := srv.ServeConn(conn); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}

	var resp Response
	// The client side must not normalize either, or the check is moot.
	resp.Header.DisableNormalizing()
	if err := resp.Read(bufio.NewReader(&conn.w)); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got := resp.Header.Peek(name); string(got) != value {
		t.Fatalf("unexpected header value for %q: %q. Expecting %q", name, got, value)
	}
	if got := resp.Header.Peek(lowerName); len(got) > 0 {
		t.Fatalf("unexpected header value for %q: %q. Expecting empty value", lowerName, got)
	}
}
// TestServerReduceMemoryUsageSerial runs one client's sequence of keep-alive
// requests against a Server with ReduceMemoryUsage enabled. It then closes
// the listener and waits for Serve to return.
func TestServerReduceMemoryUsageSerial(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	srv := &Server{
		ReduceMemoryUsage: true,
		Handler:           func(ctx *RequestCtx) {},
	}

	serverDone := make(chan struct{})
	go func() {
		defer close(serverDone)
		if err := srv.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()

	testServerRequests(t, ln)

	if err := ln.Close(); err != nil {
		t.Fatalf("error when closing listener: %v", err)
	}
	select {
	case <-serverDone:
	case <-time.After(time.Second):
		t.Fatal("timeout when waiting for the server to stop")
	}
}
// TestServerReduceMemoryUsageConcurrent drives a Server with
// ReduceMemoryUsage enabled from several clients at once. Each client runs
// testServerRequests on its own connection. Afterwards the test closes the
// listener and waits for Serve to return.
func TestServerReduceMemoryUsageConcurrent(t *testing.T) {
	t.Parallel()
	const clients = 10
	ln := fasthttputil.NewInmemoryListener()
	s := &Server{
		Handler:           func(ctx *RequestCtx) {},
		ReduceMemoryUsage: true,
	}
	ch := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(ch)
	}()
	// gCh is buffered so that every client goroutine can always deliver its
	// completion signal. With an unbuffered channel, clients that finish
	// after the wait loop below has bailed out on a timeout would block on
	// the send forever and leak into subsequent tests.
	gCh := make(chan struct{}, clients)
	for i := 0; i < clients; i++ {
		go func() {
			testServerRequests(t, ln)
			gCh <- struct{}{}
		}()
	}
	for i := 0; i < clients; i++ {
		select {
		case <-gCh:
		case <-time.After(time.Second):
			t.Fatalf("timeout on goroutine %d", i)
		}
	}
	if err := ln.Close(); err != nil {
		t.Fatalf("error when closing listener: %v", err)
	}
	select {
	case <-ch:
	case <-time.After(time.Second):
		t.Fatal("timeout when waiting for the server to stop")
	}
}
// testServerRequests dials ln and sends ten sequential
// "GET / HTTP/1.1" requests over that single keep-alive connection. It
// reads each response before sending the next request, and each response
// must arrive within one second.
//
// The helper is also run from goroutines other than the test goroutine
// (see TestServerReduceMemoryUsageConcurrent). t.Fatal and friends may only
// be called from the goroutine running the test, so failures here are
// reported with t.Errorf followed by an early return.
func testServerRequests(t *testing.T, ln *fasthttputil.InmemoryListener) {
	t.Helper()
	conn, err := ln.Dial()
	if err != nil {
		t.Errorf("unexpected error: %v", err)
		return
	}
	// Closing the connection on every exit path also unblocks a reader
	// goroutine that is still waiting for a response after a timeout.
	defer func() {
		if err := conn.Close(); err != nil {
			t.Errorf("error when closing the connection: %v", err)
		}
	}()
	br := bufio.NewReader(conn)
	var resp Response
	for i := 0; i < 10; i++ {
		if _, err := fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: aaa\r\n\r\n"); err != nil {
			t.Errorf("unexpected error on iteration %d: %v", i, err)
			return
		}
		// The read result travels through a buffered channel, not a shared
		// variable, so the reader goroutine neither races with this one nor
		// blocks on send if we have already stopped waiting for it.
		readErrCh := make(chan error, 1)
		go func() {
			readErrCh <- resp.Read(br)
		}()
		select {
		case readErr := <-readErrCh:
			if readErr != nil {
				t.Errorf("unexpected error when reading response on iteration %d: %v", i, readErr)
				return
			}
		case <-time.After(time.Second):
			t.Errorf("timeout on iteration %d", i)
			return
		}
	}
}
// TestServerHTTP10ConnectionKeepAlive pipelines two HTTP/1.0 requests that
// both ask for "Connection: keep-alive".
//   - The first response must not carry "Connection: close".
//   - The second request hits /close, where the handler calls
//     SetConnectionClose. Its response must carry "Connection: close", and
//     nothing may follow it on the wire.
func TestServerHTTP10ConnectionKeepAlive(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	serverDone := make(chan struct{})
	go func() {
		defer close(serverDone)
		handler := func(ctx *RequestCtx) {
			if string(ctx.Path()) == "/close" {
				ctx.SetConnectionClose()
			}
		}
		if err := Serve(ln, handler); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()

	conn, err := ln.Dial()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	requests := []string{
		"GET / HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n",
		"GET /close HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n",
	}
	for _, req := range requests {
		if _, err = fmt.Fprintf(conn, "%s", req); err != nil {
			t.Fatalf("error when writing request: %v", err)
		}
	}

	br := bufio.NewReader(conn)
	var resp Response
	if err = resp.Read(br); err != nil {
		t.Fatalf("error when reading response: %v", err)
	}
	if resp.ConnectionClose() {
		t.Fatal("response mustn't have 'Connection: close' header")
	}
	if err = resp.Read(br); err != nil {
		t.Fatalf("error when reading response: %v", err)
	}
	if !resp.ConnectionClose() {
		t.Fatal("response must have 'Connection: close' header")
	}

	// After the closing response the server should hang up, so reading the
	// remainder must hit EOF with nothing left over.
	tailDone := make(chan struct{})
	go func() {
		defer close(tailDone)
		tail, readErr := io.ReadAll(br)
		if readErr != nil {
			t.Errorf("error when reading tail: %v", readErr)
		}
		if len(tail) > 0 {
			t.Errorf("unexpected non-zero tail %q", tail)
		}
	}()
	select {
	case <-tailDone:
	case <-time.After(time.Second):
		t.Fatal("timeout when reading tail")
	}

	if err = conn.Close(); err != nil {
		t.Fatalf("error when closing the connection: %v", err)
	}
	if err = ln.Close(); err != nil {
		t.Fatalf("error when closing listener: %v", err)
	}
	select {
	case <-serverDone:
	case <-time.After(time.Second):
		t.Fatal("timeout when waiting for the server to stop")
	}
}
// TestServerHTTP10ConnectionClose sends one HTTP/1.0 request without
// "Connection: keep-alive". The handler forces keep-alive headers on both
// the request and the response. The server must still reply with
// "Connection: close" and then send nothing further.
func TestServerHTTP10ConnectionClose(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	serverDone := make(chan struct{})
	go func() {
		defer close(serverDone)
		handler := func(ctx *RequestCtx) {
			// The server must close the connection regardless of the
			// request and response state set inside the handler, since the
			// HTTP/1.0 request had no 'Connection: keep-alive' header.
			ctx.Request.Header.ResetConnectionClose()
			ctx.Request.Header.Set(HeaderConnection, "keep-alive")
			ctx.Response.Header.ResetConnectionClose()
			ctx.Response.Header.Set(HeaderConnection, "keep-alive")
		}
		if err := Serve(ln, handler); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()

	conn, err := ln.Dial()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if _, err = fmt.Fprintf(conn, "%s", "GET / HTTP/1.0\r\nHost: aaa\r\n\r\n"); err != nil {
		t.Fatalf("error when writing request: %v", err)
	}

	br := bufio.NewReader(conn)
	var resp Response
	if err = resp.Read(br); err != nil {
		t.Fatalf("error when reading response: %v", err)
	}
	if !resp.ConnectionClose() {
		t.Fatal("HTTP1.0 response must have 'Connection: close' header")
	}

	tailDone := make(chan struct{})
	go func() {
		defer close(tailDone)
		tail, readErr := io.ReadAll(br)
		if readErr != nil {
			t.Errorf("error when reading tail: %v", readErr)
		}
		if len(tail) > 0 {
			t.Errorf("unexpected non-zero tail %q", tail)
		}
	}()
	select {
	case <-tailDone:
	case <-time.After(time.Second):
		t.Fatal("timeout when reading tail")
	}

	if err = conn.Close(); err != nil {
		t.Fatalf("error when closing the connection: %v", err)
	}
	if err = ln.Close(); err != nil {
		t.Fatalf("error when closing listener: %v", err)
	}
	select {
	case <-serverDone:
	case <-time.After(time.Second):
		t.Fatal("timeout when waiting for the server to stop")
	}
}
func TestRequestCtxFormValue(t *testing.T) {
t.Parallel()
var ctx RequestCtx
var req Request
req.SetRequestURI("/foo/bar?baz=123&aaa=bbb")
req.SetBodyString("qqq=port&mmm=sddd")
req.Header.SetContentType("application/x-www-form-urlencoded")
ctx.Init(&req, nil, nil)
v := ctx.FormValue("baz")
if string(v) != "123" {
t.Fatalf("unexpected value %q. Expecting %q", v, "123")
}
v = ctx.FormValue("mmm")
if string(v) != "sddd" {
t.Fatalf("unexpected value %q. Expecting %q", v, "sddd")
}
v = ctx.FormValue("aaaasdfsdf")
if len(v) > 0 {
t.Fatalf("unexpected value for unknown key %q", v)
}
}
func TestSetStandardFormValueFunc(t *testing.T) {
t.Parallel()
var ctx RequestCtx
var req Request
req.SetRequestURI("/foo/bar?aaa=bbb")
req.SetBodyString("aaa=port")
req.Header.SetContentType("application/x-www-form-urlencoded")
ctx.Init(&req, nil, nil)
ctx.formValueFunc = NetHttpFormValueFunc
v := ctx.FormValue("aaa")
if string(v) != "port" {
t.Fatalf("unexpected value %q. Expecting %q", v, "port")
}
}
// TestRequestCtxUserValue stores ten user values: keys 0-4 via SetUserValue
// (string keys) and keys 5-9 via SetUserValueBytes ([]byte keys). It then
// checks that:
//   - each value can be read back with UserValue;
//   - VisitUserValues visits exactly the stored entries, with matching values;
//   - ResetUserValues removes every entry.
func TestRequestCtxUserValue(t *testing.T) {
	t.Parallel()

	const total = 10
	keyOf := func(i int) string { return fmt.Sprintf("key-%d", i) }

	var ctx RequestCtx
	for i := 0; i < total; i++ {
		if i < total/2 {
			ctx.SetUserValue(keyOf(i), i)
		} else {
			ctx.SetUserValueBytes([]byte(keyOf(i)), i)
		}
	}

	for i := 0; i < total; i++ {
		k := keyOf(i)
		v := ctx.UserValue(k)
		if n, ok := v.(int); !ok || n != i {
			t.Fatalf("unexpected value obtained for key %q: %v. Expecting %d", k, v, i)
		}
	}

	visited := 0
	ctx.VisitUserValues(func(key []byte, value any) {
		visited++
		if v := ctx.UserValue(key); v != value {
			t.Fatalf("unexpected value obtained from VisitUserValues for key: %q, expecting: %#v but got: %#v", key, v, value)
		}
	})
	if len(ctx.userValues) != visited {
		t.Fatalf("the length of user values returned from VisitUserValues is not equal to the length of the userValues, expecting: %d but got: %d", len(ctx.userValues), visited)
	}

	ctx.ResetUserValues()
	for i := 0; i < total; i++ {
		k := keyOf(i)
		if v := ctx.UserValue(k); v != nil {
			t.Fatalf("unexpected value obtained for key %q: %v. Expecting nil", k, v)
		}
	}
}
// TestServerHeadRequest checks that the response to a HEAD request keeps the
// status, Content-Length and Content-Type the handler produced, but carries
// no body bytes.
func TestServerHeadRequest(t *testing.T) {
	t.Parallel()

	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			fmt.Fprintf(ctx, "Request method is %q", ctx.Method())
			ctx.SetContentType("aaa/bbb")
		},
	}

	conn := &readWriter{}
	conn.r.WriteString("HEAD /foobar HTTP/1.1\r\nHost: aaa.com\r\n\r\n")
	if err := srv.ServeConn(conn); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}

	br := bufio.NewReader(&conn.w)
	var resp Response
	resp.SkipBody = true // the server sends headers only for HEAD
	if err := resp.Read(br); err != nil {
		t.Fatalf("Unexpected error when parsing response: %v", err)
	}

	hdr := &resp.Header
	if code := hdr.StatusCode(); code != StatusOK {
		t.Fatalf("unexpected status code: %d. Expecting %d", code, StatusOK)
	}
	if body := resp.Body(); len(body) > 0 {
		t.Fatalf("Unexpected non-zero body %q", body)
	}
	// len(`Request method is "HEAD"`) == 24
	if cl := hdr.ContentLength(); cl != 24 {
		t.Fatalf("unexpected content-length %d. Expecting %d", cl, 24)
	}
	if ct := hdr.ContentType(); string(ct) != "aaa/bbb" {
		t.Fatalf("unexpected content-type %q. Expecting %q", ct, "aaa/bbb")
	}

	rest, err := io.ReadAll(br)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %v", err)
	}
	if len(rest) > 0 {
		t.Fatalf("unexpected remaining data %q", rest)
	}
}
func TestServerExpect100Continue(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
if !ctx.IsPost() {
t.Errorf("unexpected method %q. Expecting POST", ctx.Method())
}
if string(ctx.Path()) != "/foo" {
t.Errorf("unexpected path %q. Expecting %q", ctx.Path(), "/foo")
}
ct := ctx.Request.Header.ContentType()
if string(ct) != "a/b" {
t.Errorf("unexpected content-type: %q. Expecting %q", ct, "a/b")
}
if string(ctx.PostBody()) != "12345" {
t.Errorf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345")
}
ctx.WriteString("foobar") //nolint:errcheck
},
}
rw := &readWriter{}
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, StatusOK, string(defaultContentType), "foobar")
data, err := io.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %v", err)
}
if len(data) > 0 {
t.Fatalf("unexpected remaining data %q", data)
}
}
// TestServerContinueHandler runs 25 rounds against one Server. Each round
// sends three kinds of POST request:
//   - a plain request;
//   - an "Expect: 100-continue" request that ContinueHandler accepts
//     (Content-Length 5);
//   - an "Expect: 100-continue" request that ContinueHandler rejects
//     (Content-Length 6).
//
// Accepted requests must reach Handler and get "foobar". The rejected one
// must be answered with 417 Expectation Failed and never reach Handler.
func TestServerContinueHandler(t *testing.T) {
	t.Parallel()

	const acceptContentLength = 5

	srv := &Server{
		ContinueHandler: func(headers *RequestHeader) bool {
			if !headers.IsPost() {
				t.Errorf("unexpected method %q. Expecting POST", headers.Method())
			}
			if ct := headers.ContentType(); string(ct) != "a/b" {
				t.Errorf("unexpected content-type: %q. Expecting %q", ct, "a/b")
			}
			// Only let through requests with the accepted content length.
			return headers.contentLength == acceptContentLength
		},
		Handler: func(ctx *RequestCtx) {
			if ctx.Request.Header.contentLength != acceptContentLength {
				t.Errorf("all requests with content-length: other than %d, should be denied", acceptContentLength)
			}
			if !ctx.IsPost() {
				t.Errorf("unexpected method %q. Expecting POST", ctx.Method())
			}
			if path := ctx.Path(); string(path) != "/foo" {
				t.Errorf("unexpected path %q. Expecting %q", path, "/foo")
			}
			if ct := ctx.Request.Header.ContentType(); string(ct) != "a/b" {
				t.Errorf("unexpected content-type: %q. Expecting %q", ct, "a/b")
			}
			if body := ctx.PostBody(); string(body) != "12345" {
				t.Errorf("unexpected body: %q. Expecting %q", body, "12345")
			}
			ctx.WriteString("foobar") //nolint:errcheck
		},
	}

	// serveAndVerify serves whatever is queued in rw.r, then expects exactly
	// one response with the given status and body on rw.w.
	serveAndVerify := func(rw *readWriter, wantStatus int, wantBody string) {
		if err := srv.ServeConn(rw); err != nil {
			t.Fatalf("Unexpected error from serveConn: %v", err)
		}
		br := bufio.NewReader(&rw.w)
		verifyResponse(t, br, wantStatus, string(defaultContentType), wantBody)
		rest, err := io.ReadAll(br)
		if err != nil {
			t.Fatalf("Unexpected error when reading remaining data: %v", err)
		}
		if len(rest) > 0 {
			t.Fatalf("unexpected remaining data %q", rest)
		}
	}

	cases := []struct {
		request    string
		wantStatus int
		wantBody   string
	}{
		{ // regular request without an Expect: 100-continue header
			request:    "POST /foo HTTP/1.1\r\nHost: gle.com\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345",
			wantStatus: StatusOK,
			wantBody:   "foobar",
		},
		{ // Expect: 100-continue request that ContinueHandler accepts
			request:    "POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345",
			wantStatus: StatusOK,
			wantBody:   "foobar",
		},
		{ // Expect: 100-continue request that ContinueHandler denies
			request:    "POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 6\r\nContent-Type: a/b\r\n\r\n123456",
			wantStatus: StatusExpectationFailed,
			wantBody:   "",
		},
	}

	// The same server must keep handling all three kinds, interleaved.
	rw := &readWriter{}
	for i := 0; i < 25; i++ {
		for _, tc := range cases {
			rw.r.Reset()
			rw.r.WriteString(tc.request)
			serveAndVerify(rw, tc.wantStatus, tc.wantBody)
		}
	}
}
// TestCompressHandler sends four requests through CompressHandler, all
// sharing one RequestCtx:
//   - no Accept-Encoding (plain body);
//   - gzip preferred;
//   - gzip preferred, through a doubly wrapped handler;
//   - deflate preferred.
//
// Each response is serialized and parsed back. Its Content-Encoding and
// decoded body are then compared with what the handler wrote.
func TestCompressHandler(t *testing.T) {
	t.Parallel()

	expectedBody := string(createFixedBody(2e4))
	h := CompressHandler(func(ctx *RequestCtx) {
		ctx.WriteString(expectedBody) //nolint:errcheck
	})
	plainBody := func(r *Response) ([]byte, error) { return r.Body(), nil }

	cases := []struct {
		acceptEncoding string
		handler        RequestHandler
		wantEncoding   string
		decode         func(*Response) ([]byte, error)
	}{
		// uncompressed: the fresh ctx is used as-is
		{"", h, "", plainBody},
		{"gzip, deflate, sdhc", h, "gzip", (*Response).BodyGunzip},
		// an attempt to compress already compressed response
		{"gzip, deflate, sdhc", CompressHandler(h), "gzip", (*Response).BodyGunzip},
		{"foobar, deflate, sdhc", h, "deflate", (*Response).BodyInflate},
	}

	var ctx RequestCtx
	var resp Response
	for _, tc := range cases {
		if tc.acceptEncoding != "" {
			ctx.Request.Reset()
			ctx.Response.Reset()
			ctx.Request.Header.Set(HeaderAcceptEncoding, tc.acceptEncoding)
		}
		tc.handler(&ctx)

		br := bufio.NewReader(bytes.NewBufferString(ctx.Response.String()))
		if err := resp.Read(br); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if ce := resp.Header.ContentEncoding(); string(ce) != tc.wantEncoding {
			t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, tc.wantEncoding)
		}
		body, err := tc.decode(&resp)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if string(body) != expectedBody {
			t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
		}
	}
}
// TestCompressHandlerVary sends five requests through a Brotli-capable
// compress handler, all sharing one RequestCtx:
//   - no Accept-Encoding;
//   - gzip preferred;
//   - gzip, through a doubly wrapped handler;
//   - deflate preferred;
//   - br offered.
//
// Besides the Content-Encoding and decoded body, every compressed response
// must carry "Vary: Accept-Encoding", and the uncompressed one must have no
// Vary header.
func TestCompressHandlerVary(t *testing.T) {
	t.Parallel()

	expectedBody := string(createFixedBody(2e4))
	h := CompressHandlerBrotliLevel(func(ctx *RequestCtx) {
		ctx.WriteString(expectedBody) //nolint:errcheck
	}, CompressBrotliBestSpeed, CompressBestSpeed)
	plainBody := func(r *Response) ([]byte, error) { return r.Body(), nil }

	cases := []struct {
		acceptEncoding string
		handler        RequestHandler
		wantEncoding   string
		wantVary       string
		decode         func(*Response) ([]byte, error)
	}{
		// uncompressed: the fresh ctx is used as-is
		{"", h, "", "", plainBody},
		{"gzip, deflate, sdhc", h, "gzip", "Accept-Encoding", (*Response).BodyGunzip},
		// an attempt to compress already compressed response
		{"gzip, deflate, sdhc", CompressHandler(h), "gzip", "Accept-Encoding", (*Response).BodyGunzip},
		{"foobar, deflate, sdhc", h, "deflate", "Accept-Encoding", (*Response).BodyInflate},
		{"gzip, deflate, br", h, "br", "Accept-Encoding", (*Response).BodyUnbrotli},
	}

	var ctx RequestCtx
	var resp Response
	for _, tc := range cases {
		if tc.acceptEncoding != "" {
			ctx.Request.Reset()
			ctx.Response.Reset()
			ctx.Request.Header.Set(HeaderAcceptEncoding, tc.acceptEncoding)
		}
		tc.handler(&ctx)

		br := bufio.NewReader(bytes.NewBufferString(ctx.Response.String()))
		if err := resp.Read(br); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if ce := resp.Header.ContentEncoding(); string(ce) != tc.wantEncoding {
			t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, tc.wantEncoding)
		}
		if vary := resp.Header.Peek("Vary"); string(vary) != tc.wantVary {
			t.Fatalf("unexpected Vary: %q. Expecting %q", vary, tc.wantVary)
		}
		body, err := tc.decode(&resp)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if string(body) != expectedBody {
			t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
		}
	}
}
// TestRequestCtxWriteString checks that WriteString reports the number of
// bytes (not runes) written, and that successive writes are appended to the
// response body.
func TestRequestCtxWriteString(t *testing.T) {
	t.Parallel()

	var ctx RequestCtx

	n, err := ctx.WriteString("foo")
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case n != 3:
		t.Fatalf("unexpected n %d. Expecting 3", n)
	}

	// "привет" is 6 Cyrillic runes, each 2 bytes in UTF-8.
	n, err = ctx.WriteString("привет")
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case n != 12:
		t.Fatalf("unexpected n=%d. Expecting 12", n)
	}

	const want = "fooпривет"
	if body := ctx.Response.Body(); string(body) != want {
		t.Fatalf("unexpected response body %q. Expecting %q", body, want)
	}
}
// TestServeConnKeepRequestAndResponseUntilResetUserValues stores a user
// value whose Close method snapshots ctx.Request and ctx.Response. Close is
// expected to have run by the time ServeConn returns. The snapshots must
// equal the request that was sent and match the 308 response the handler
// produced, i.e. neither may have been reset before Close ran.
func TestServeConnKeepRequestAndResponseUntilResetUserValues(t *testing.T) {
	t.Parallel()

	reqStr := "POST /foo HTTP/1.0\r\nHost: google.com\r\nContent-Type: application/octet-stream\r\nContent-Length: 0\r\nConnection: keep-alive\r\n\r\n"
	respRegex := regexp.MustCompile("HTTP/1.1 308 Permanent Redirect\r\nServer: fasthttp\r\nDate: (.*)\r\nContent-Length: 0\r\nConnection: keep-alive\r\n\r\n")

	conn := &readWriter{}
	conn.r.WriteString(reqStr)

	var gotReq, gotResp string
	handler := func(ctx *RequestCtx) {
		ctx.Response.SetStatusCode(StatusPermanentRedirect)
		ctx.SetUserValue("myKey", &closerWithRequestCtx{
			ctx: ctx,
			closeFunc: func(c *RequestCtx) error {
				gotReq = c.Request.String()
				gotResp = c.Response.String()
				return nil
			},
		})
	}

	done := make(chan struct{})
	go func() {
		defer close(done)
		if err := ServeConn(conn, handler); err != nil {
			t.Errorf("unexpected error in ServeConn: %v", err)
		}
	}()
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}

	if gotReq != reqStr {
		t.Errorf("Request == %q, want %q", gotReq, reqStr)
	}
	if !respRegex.MatchString(gotResp) {
		t.Errorf("Response == %q, want regex %q", gotResp, respRegex)
	}
}
// TestServerErrorHandler tests the unexpected cases in which the serve loop
// breaks before the request/response reset call; here the trigger is a
// request body that exceeds MaxRequestBodySize. In such cases the reset
// must still happen before the ctx is released (fix for #548).
//
// Each iteration sends one oversized request through ServeConn and checks
// what ErrorHandler observed: the request serialized with Content-Length: 0,
// and a bare 200 OK response.
func TestServerErrorHandler(t *testing.T) {
	t.Parallel()
	var resultReqStr, resultRespStr string
	s := &Server{
		Handler: func(ctx *RequestCtx) {},
		ErrorHandler: func(ctx *RequestCtx, _ error) {
			resultReqStr = ctx.Request.String()
			resultRespStr = ctx.Response.String()
		},
		MaxRequestBodySize: 10,
	}
	reqStrTpl := "POST %s HTTP/1.1\r\nHost: example.com\r\nContent-Type: application/octet-stream\r\nContent-Length: %d\r\nConnection: keep-alive\r\n\r\n"
	respRegex := regexp.MustCompile("HTTP/1.1 200 OK\r\nDate: (.*)\r\nContent-Length: 0\r\n\r\n")
	// One byte over the limit; identical on every iteration.
	body := strings.Repeat("@", s.MaxRequestBodySize+1)
	rw := &readWriter{}
	for i := 0; i < 100; i++ {
		path := fmt.Sprintf("/%d", i)
		reqStr := fmt.Sprintf(reqStrTpl, path, len(body))
		expectedReqStr := fmt.Sprintf(reqStrTpl, path, 0)
		// Clear the previous iteration's observations, so a missing
		// ErrorHandler call cannot pass by matching stale values.
		resultReqStr, resultRespStr = "", ""
		rw.r.WriteString(reqStr)
		rw.r.WriteString(body)
		ch := make(chan struct{})
		go func() {
			err := s.ServeConn(rw)
			if err != nil && !errors.Is(err, ErrBodyTooLarge) {
				t.Errorf("unexpected error in ServeConn: %v", err)
			}
			close(ch)
		}()
		select {
		case <-ch:
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
		// Report the string actually compared against (expectedReqStr,
		// with Content-Length: 0), not the request that was sent.
		if resultReqStr != expectedReqStr {
			t.Errorf("[iter: %d] Request == %q, want %q", i, resultReqStr, expectedReqStr)
		}
		if !respRegex.MatchString(resultRespStr) {
			t.Errorf("[iter: %d] Response == %q, want regex %q", i, resultRespStr, respRegex)
		}
	}
}
// TestServeConnHijackResetUserValues checks that user values are reset
// (their Close method is invoked) even when the handler hijacks the
// connection.
//
// It also waits for ServeConn itself to return. Otherwise the goroutine
// running it could outlive the test and call t.Errorf on a finished test,
// interfering with later tests.
func TestServeConnHijackResetUserValues(t *testing.T) {
	t.Parallel()
	rw := &readWriter{}
	rw.r.WriteString("GET /foo HTTP/1.0\r\nConnection: keep-alive\r\nHost: google.com\r\n\r\n")
	ch := make(chan struct{})
	serveDone := make(chan struct{})
	go func() {
		defer close(serveDone)
		err := ServeConn(rw, func(ctx *RequestCtx) {
			ctx.Hijack(func(c net.Conn) {})
			ctx.SetUserValue("myKey", &closerWithRequestCtx{
				closeFunc: func(_ *RequestCtx) error {
					close(ch)
					return nil
				},
			},
			)
		})
		if err != nil {
			t.Errorf("unexpected error in ServeConn: %v", err)
		}
	}()
	select {
	case <-ch:
	case <-time.After(time.Second):
		t.Errorf("Timeout: UserValues should be reset")
	}
	select {
	case <-serveDone:
	case <-time.After(time.Second):
		t.Errorf("timeout when waiting for ServeConn to return")
	}
}
// TestServeConnNonHTTP11KeepAlive queues three HTTP/1.0 requests on one
// connection, of which only the first asks for keep-alive.
//   - The first response must carry "Connection: keep-alive".
//   - The second response must carry "Connection: close".
//   - The third request must never be served, so exactly two requests are
//     handled and no data follows the second response.
func TestServeConnNonHTTP11KeepAlive(t *testing.T) {
	t.Parallel()

	conn := &readWriter{}
	for _, req := range []string{
		"GET /foo HTTP/1.0\r\nConnection: keep-alive\r\nHost: google.com\r\n\r\n",
		"GET /bar HTTP/1.0\r\nHost: google.com\r\n\r\n",
		"GET /must/be/ignored HTTP/1.0\r\nHost: google.com\r\n\r\n",
	} {
		conn.r.WriteString(req)
	}

	served := 0
	done := make(chan struct{})
	go func() {
		defer close(done)
		err := ServeConn(conn, func(ctx *RequestCtx) {
			served++
			ctx.SuccessString("aaa/bbb", "foobar")
		})
		if err != nil {
			t.Errorf("unexpected error in ServeConn: %v", err)
		}
	}()
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}

	br := bufio.NewReader(&conn.w)
	var resp Response

	// First response: connection stays open.
	if err := resp.Read(br); err != nil {
		t.Fatalf("Unexpected error when parsing response: %v", err)
	}
	if got := resp.Header.Peek(HeaderConnection); string(got) != "keep-alive" {
		t.Fatalf("unexpected Connection header %q. Expecting %q", got, "keep-alive")
	}
	if resp.Header.ConnectionClose() {
		t.Fatal("unexpected Connection: close")
	}

	// Second response: the server announces it is closing.
	if err := resp.Read(br); err != nil {
		t.Fatalf("Unexpected error when parsing response: %v", err)
	}
	if got := resp.Header.Peek(HeaderConnection); string(got) != "close" {
		t.Fatalf("unexpected Connection header %q. Expecting %q", got, "close")
	}
	if !resp.Header.ConnectionClose() {
		t.Fatal("expecting Connection: close")
	}

	rest, err := io.ReadAll(br)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %v", err)
	}
	if len(rest) != 0 {
		t.Fatalf("Unexpected data read after responses %q", rest)
	}
	if served != 2 {
		t.Fatalf("unexpected number of requests served: %d. Expecting 2", served)
	}
}
// TestRequestCtxSetBodyStreamWriter checks that SetBodyStreamWriter turns
// the response into a body stream. Everything the writer emits, including
// output written after an intermediate Flush, must end up in the serialized
// response body.
func TestRequestCtxSetBodyStreamWriter(t *testing.T) {
	t.Parallel()

	var (
		ctx RequestCtx
		req Request
	)
	ctx.Init(&req, nil, defaultLogger)
	if ctx.IsBodyStream() {
		t.Fatal("IsBodyStream must return false")
	}

	ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
		fmt.Fprintf(w, "body writer line 1\n")
		if err := w.Flush(); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		fmt.Fprintf(w, "body writer line 2\n")
	})
	if !ctx.IsBodyStream() {
		t.Fatal("IsBodyStream must return true")
	}

	raw := ctx.Response.String()
	var resp Response
	if err := resp.Read(bufio.NewReader(bytes.NewBufferString(raw))); err != nil {
		t.Fatalf("Error when reading response: %v", err)
	}
	const expectedBody = "body writer line 1\nbody writer line 2\n"
	if body := string(resp.Body()); body != expectedBody {
		t.Fatalf("unexpected body: %q. Expecting %q", body, expectedBody)
	}
}
// TestRequestCtxIfModifiedSince first checks IfModifiedSince with no
// If-Modified-Since header, which must always report "modified". It then
// sets the header to lastModified and probes resource modification times
// equal to it, one hour earlier, and one hour later.
func TestRequestCtxIfModifiedSince(t *testing.T) {
	t.Parallel()

	var ctx RequestCtx
	var req Request
	ctx.Init(&req, nil, defaultLogger)

	lastModified := time.Now().Add(-time.Hour)
	if !ctx.IfModifiedSince(lastModified) {
		t.Fatal("IfModifiedSince must return true for non-existing If-Modified-Since header")
	}

	ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified)))
	for _, tc := range []struct {
		at   time.Time
		want bool
		msg  string
	}{
		{lastModified, false, "If-Modified-Since current time must return false"},
		{lastModified.Add(-time.Hour), false, "If-Modified-Since past time must return false"},
		{lastModified.Add(time.Hour), true, "If-Modified-Since future time must return true"},
	} {
		if ctx.IfModifiedSince(tc.at) != tc.want {
			t.Fatal(tc.msg)
		}
	}
}
// TestRequestCtxSendFileNotModified checks that SendFile answers with
// 304 Not Modified and an empty body when If-Modified-Since equals the
// file's last modification time.
func TestRequestCtxSendFileNotModified(t *testing.T) {
	t.Parallel()

	var ctx RequestCtx
	var req Request
	ctx.Init(&req, nil, defaultLogger)

	const filePath = "./server_test.go"
	lastModified, err := FileLastModified(filePath)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified)))
	ctx.SendFile(filePath)

	raw := ctx.Response.String()
	var resp Response
	if err = resp.Read(bufio.NewReader(bytes.NewBufferString(raw))); err != nil {
		t.Fatalf("error when reading response: %v", err)
	}
	if code := resp.StatusCode(); code != StatusNotModified {
		t.Fatalf("unexpected status code: %d. Expecting %d", code, StatusNotModified)
	}
	if body := resp.Body(); len(body) > 0 {
		t.Fatalf("unexpected non-zero response body: %q", body)
	}
}
// TestRequestCtxSendFileModified sets If-Modified-Since to one hour before
// the file's last modification time. SendFile must then serve the whole
// file with 200 OK.
func TestRequestCtxSendFileModified(t *testing.T) {
	t.Parallel()
	var ctx RequestCtx
	var req Request
	ctx.Init(&req, nil, defaultLogger)
	filePath := "./server_test.go"
	lastModified, err := FileLastModified(filePath)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	lastModified = lastModified.Add(-time.Hour)
	ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified)))
	ctx.SendFile(filePath)
	s := ctx.Response.String()
	var resp Response
	br := bufio.NewReader(bytes.NewBufferString(s))
	if err := resp.Read(br); err != nil {
		t.Fatalf("error when reading response: %v", err)
	}
	if resp.StatusCode() != StatusOK {
		t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
	}
	// os.ReadFile opens, reads and closes the file in one call, replacing
	// the manual Open + ReadAll + Close sequence.
	body, err := os.ReadFile(filePath)
	if err != nil {
		t.Fatalf("error when reading file: %v", err)
	}
	if !bytes.Equal(resp.Body(), body) {
		t.Fatalf("unexpected response body: %q. Expecting %q", resp.Body(), body)
	}
}
// TestRequestCtxSendFile checks SendFile with no conditional headers. The
// response is written through a bufio.Writer and parsed back; it must have
// status 200 and a body byte-for-byte equal to the file's content.
func TestRequestCtxSendFile(t *testing.T) {
	t.Parallel()
	var ctx RequestCtx
	var req Request
	ctx.Init(&req, nil, defaultLogger)
	filePath := "./server_test.go"
	ctx.SendFile(filePath)
	w := &bytes.Buffer{}
	bw := bufio.NewWriter(w)
	if err := ctx.Response.Write(bw); err != nil {
		t.Fatalf("error when writing response: %v", err)
	}
	if err := bw.Flush(); err != nil {
		t.Fatalf("error when flushing response: %v", err)
	}
	var resp Response
	br := bufio.NewReader(w)
	if err := resp.Read(br); err != nil {
		t.Fatalf("error when reading response: %v", err)
	}
	if resp.StatusCode() != StatusOK {
		t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
	}
	// os.ReadFile opens, reads and closes the file in one call, replacing
	// the manual Open + ReadAll + Close sequence.
	body, err := os.ReadFile(filePath)
	if err != nil {
		t.Fatalf("error when reading file: %v", err)
	}
	if !bytes.Equal(resp.Body(), body) {
		t.Fatalf("unexpected response body: %q. Expecting %q", resp.Body(), body)
	}
}
// testRequestCtxHijack serves totalConns connections through s concurrently.
// Each connection carries one GET request followed by hijackedString.
//
// The handler, which overwrites s.Handler, answers the request with
// "hijack it!" and hijacks the connection. The hijack handler then echoes
// the remaining bytes back one at a time until EOF.
//
// After every hijack handler has returned, each connection's output must be
// the regular response followed by the echoed hijackedString.
func testRequestCtxHijack(t *testing.T, s *Server) {
	t.Helper()
	type hijackSignal struct {
		rw *readWriter
		id int
	}
	const totalConns = 100
	var wg sync.WaitGroup
	hijackStartCh := make(chan *hijackSignal, totalConns)
	hijackStopCh := make(chan *hijackSignal, totalConns)
	s.Handler = func(ctx *RequestCtx) {
		if ctx.Hijacked() {
			t.Error("connection mustn't be hijacked")
		}
		ctx.Hijack(func(c net.Conn) {
			// NOTE: the signal received here is not necessarily the one sent
			// for this connection. The verification below only relies on
			// every signal being forwarded to hijackStopCh exactly once.
			signal := <-hijackStartCh
			defer func() {
				hijackStopCh <- signal
				wg.Done()
			}()
			b := make([]byte, 1)
			// ping-pong echo via hijacked conn until EOF
			for {
				n, err := c.Read(b)
				if err != nil {
					// Stop on any read error. Continuing after a non-EOF
					// error would echo stale data and could spin forever.
					if !errors.Is(err, io.EOF) {
						t.Errorf("unexpected read error: %v", err)
					}
					return
				}
				if n != 1 {
					t.Errorf("unexpected number of bytes read: %d. Expecting 1", n)
				}
				if _, err = c.Write(b); err != nil {
					t.Errorf("unexpected error when writing data: %v", err)
				}
			}
		})
		if !ctx.Hijacked() {
			t.Error("connection must be hijacked")
		}
		ctx.Success("foo/bar", []byte("hijack it!"))
	}
	hijackedString := "foobar baz hijacked!!!"
	for i := 0; i < totalConns; i++ {
		wg.Add(1)
		go func(id int) {
			rw := new(readWriter)
			rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
			rw.r.WriteString(hijackedString)
			if err := s.ServeConn(rw); err != nil {
				t.Errorf("[iter: %d] Unexpected error from serveConn: %v", id, err)
			}
			hijackStartCh <- &hijackSignal{id: id, rw: rw}
		}(i)
	}
	// Wait for every hijack handler to return. After that, all totalConns
	// signals are already buffered in hijackStopCh.
	wg.Wait()
	for count := 0; count < totalConns; count++ {
		select {
		case signal := <-hijackStopCh:
			id := signal.id
			br := bufio.NewReader(&signal.rw.w)
			verifyResponse(t, br, StatusOK, "foo/bar", "hijack it!")
			data, err := io.ReadAll(br)
			if err != nil {
				t.Errorf("[iter: %d] Unexpected error when reading remaining data: %v", id, err)
				return
			}
			if string(data) != hijackedString {
				t.Errorf(
					"[iter: %d] Unexpected response %q. Expecting %q",
					id, data, hijackedString,
				)
				return
			}
		case <-time.After(200 * time.Millisecond):
			// Bail out: no signal means count would never advance, and the
			// loop would otherwise never terminate.
			t.Errorf("timeout")
			return
		}
	}
	close(hijackStartCh)
	close(hijackStopCh)
}
func TestRequestCtxHijack(t *testing.T) {
t.Parallel()
testRequestCtxHijack(t, &Server{})
}
func TestRequestCtxHijackReduceMemoryUsage(t *testing.T) {
t.Parallel()
testRequestCtxHijack(t, &Server{
ReduceMemoryUsage: true,
})
}
// TestRequestCtxHijackNoResponse checks that with HijackSetNoResponse(true)
// the connection carries only what the hijack handler writes ("test"), and
// no HTTP response from the server.
func TestRequestCtxHijackNoResponse(t *testing.T) {
	t.Parallel()
	// Buffered so the hijack goroutine can always deliver its result and exit,
	// even if the test has already stopped waiting for it.
	hijackDone := make(chan error, 1)
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.Hijack(func(c net.Conn) {
				_, err := c.Write([]byte("test"))
				hijackDone <- err
			})
			ctx.HijackSetNoResponse(true)
		},
	}
	rw := &readWriter{}
	rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 0\r\n\r\n")
	if err := s.ServeConn(rw); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}
	select {
	case err := <-hijackDone:
		if err != nil {
			t.Fatalf("Unexpected error from hijack: %v", err)
		}
	case <-time.After(100 * time.Millisecond):
		t.Fatal("timeout")
	}
	if got := rw.w.String(); got != "test" {
		t.Errorf(`expected "test", got %q`, got)
	}
}
// TestRequestCtxNoHijackNoResponse checks that HijackSetNoResponse(true) has
// no effect when the handler does not hijack: the regular response carrying
// the handler's body is still sent.
func TestRequestCtxNoHijackNoResponse(t *testing.T) {
	t.Parallel()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.WriteString("test") //nolint:errcheck
			ctx.HijackSetNoResponse(true)
		},
	}
	rw := &readWriter{}
	rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 0\r\n\r\n")
	if err := s.ServeConn(rw); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}
	bf := bufio.NewReader(
		strings.NewReader(rw.w.String()),
	)
	resp := AcquireResponse()
	defer ReleaseResponse(resp)
	// A parse failure here means no valid response was written at all.
	if err := resp.Read(bf); err != nil {
		t.Fatalf("Unexpected error when reading response: %v", err)
	}
	if got := string(resp.Body()); got != "test" {
		t.Errorf(`expected "test", got %q`, got)
	}
}
// TestRequestCtxInit checks the remote IP and the log-line prefix of a
// RequestCtx initialised with zeroTCPAddr and a preset globalConnID.
func TestRequestCtxInit(t *testing.T) {
	// This test can't run parallel as it modifies globalConnID.
	var (
		ctx    RequestCtx
		logger testLogger
	)
	globalConnID = 0x123456
	ctx.Init(&ctx.Request, zeroTCPAddr, &logger)

	if ip := ctx.RemoteIP(); !ip.IsUnspecified() {
		t.Fatalf("unexpected ip for bare RequestCtx: %q. Expected 0.0.0.0", ip)
	}

	ctx.Logger().Printf("foo bar %d", 10)
	const expectedLog = "#0012345700000000 - 0.0.0.0:0<->0.0.0.0:0 - GET http:/// - foo bar 10\n"
	if got := logger.out; got != expectedLog {
		t.Fatalf("Unexpected log output: %q. Expected %q", got, expectedLog)
	}
}
// TestTimeoutHandlerSuccess checks that a handler finishing well within the
// TimeoutHandler deadline gets its own response delivered to concurrent
// clients, and that Serve returns after the listener is closed.
func TestTimeoutHandlerSuccess(t *testing.T) {
	t.Parallel()
	ln := fasthttputil.NewInmemoryListener()
	h := func(ctx *RequestCtx) {
		if string(ctx.Path()) == "/" {
			ctx.Success("aaa/bbb", []byte("real response"))
		}
	}
	s := &Server{
		Handler: TimeoutHandler(h, 10*time.Second, "timeout!!!"),
	}
	serverCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverCh)
	}()
	concurrency := 20
	clientCh := make(chan struct{}, concurrency)
	for i := 0; i < concurrency; i++ {
		go func() {
			conn, err := ln.Dial()
			if err != nil {
				// conn is nil here: report and bail out instead of using it.
				t.Errorf("unexpected error: %v", err)
				return
			}
			// Close the client side so the server sees EOF and its connection
			// goroutine can finish instead of outliving the test.
			defer conn.Close()
			if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
				t.Errorf("unexpected error: %v", err)
				return
			}
			br := bufio.NewReader(conn)
			verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
			clientCh <- struct{}{}
		}()
	}
	for i := 0; i < concurrency; i++ {
		select {
		case <-clientCh:
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestTimeoutHandlerTimeout checks that requests whose handler outlives the
// TimeoutHandler deadline get a 408 with the timeout message. It also checks
// that every blocked handler still runs to completion once released, and
// that Serve returns after the listener is closed.
func TestTimeoutHandlerTimeout(t *testing.T) {
	t.Parallel()
	ln := fasthttputil.NewInmemoryListener()
	concurrency := 20
	readyCh := make(chan struct{})
	// Buffered so handlers never block on it if the test bails out early.
	doneCh := make(chan struct{}, concurrency)
	h := func(ctx *RequestCtx) {
		ctx.Success("aaa/bbb", []byte("real response"))
		<-readyCh
		doneCh <- struct{}{}
	}
	s := &Server{
		Handler: TimeoutHandler(h, 20*time.Millisecond, "timeout!!!"),
	}
	serverCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverCh)
	}()
	clientCh := make(chan struct{}, concurrency)
	for i := 0; i < concurrency; i++ {
		go func() {
			conn, err := ln.Dial()
			if err != nil {
				// conn is nil here: report and bail out instead of using it.
				t.Errorf("unexpected error: %v", err)
				return
			}
			// Close the client side so the server sees EOF and its connection
			// goroutine can finish instead of outliving the test.
			defer conn.Close()
			if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
				t.Errorf("unexpected error: %v", err)
				return
			}
			br := bufio.NewReader(conn)
			verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "timeout!!!")
			clientCh <- struct{}{}
		}()
	}
	for i := 0; i < concurrency; i++ {
		select {
		case <-clientCh:
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
	// Release the handlers that are still blocked after their timeout fired.
	close(readyCh)
	for i := 0; i < concurrency; i++ {
		select {
		case <-doneCh:
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestTimeoutHandlerTimeoutReuse checks that a keep-alive connection remains
// usable after a request on it hit the TimeoutHandler deadline: the next
// request on the same connection gets a normal response.
func TestTimeoutHandlerTimeoutReuse(t *testing.T) {
	t.Parallel()
	ln := fasthttputil.NewInmemoryListener()
	h := func(ctx *RequestCtx) {
		if string(ctx.Path()) == "/timeout" {
			time.Sleep(time.Second)
		}
		ctx.SetBodyString("ok")
	}
	s := &Server{
		Handler: TimeoutHandler(h, 500*time.Millisecond, "timeout!!!"),
	}
	serverCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverCh)
	}()
	conn, err := ln.Dial()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Close the client side so the server's connection goroutine can finish.
	defer conn.Close()
	br := bufio.NewReader(conn)
	if _, err = conn.Write([]byte("GET /timeout HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "timeout!!!")
	if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	verifyResponse(t, br, StatusOK, string(defaultContentType), "ok")
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Wait for Serve so it cannot report through t after the test returned.
	select {
	case <-serverCh:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServerGetOnly checks that a GetOnly server rejects a POST: ServeConn
// returns ErrGetOnly and the client gets 400 with Connection: close. The
// handler is expected never to see a non-GET request.
func TestServerGetOnly(t *testing.T) {
	t.Parallel()
	h := func(ctx *RequestCtx) {
		if !ctx.IsGet() {
			t.Errorf("non-get request: %q", ctx.Method())
		}
		ctx.Success("foo/bar", []byte("success"))
	}
	s := &Server{
		Handler: h,
		GetOnly: true,
	}
	rw := &readWriter{}
	rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 5\r\nContent-Type: aaa\r\n\r\n12345")
	// Buffered so the ServeConn goroutine can exit even after a timeout.
	ch := make(chan error, 1)
	go func() {
		ch <- s.ServeConn(rw)
	}()
	select {
	case err := <-ch:
		if err == nil {
			t.Fatal("expecting error")
		}
		if !errors.Is(err, ErrGetOnly) {
			t.Fatalf("Unexpected error from serveConn: %v. Expecting %v", err, ErrGetOnly)
		}
	case <-time.After(100 * time.Millisecond):
		t.Fatal("timeout")
	}
	br := bufio.NewReader(&rw.w)
	var resp Response
	if err := resp.Read(br); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	statusCode := resp.StatusCode()
	if statusCode != StatusBadRequest {
		t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusBadRequest)
	}
	if !resp.ConnectionClose() {
		t.Fatal("missing 'Connection: close' response header")
	}
}
// TestServerTimeoutErrorWithResponse checks that after the handler calls
// TimeoutErrorWithResponse, each request is answered only with the final
// timeout response (456, foo/bar, "path=<path>"), even though a background
// goroutine keeps writing to the same ctx.
func TestServerTimeoutErrorWithResponse(t *testing.T) {
	t.Parallel()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			// This late write races with the timeout and must not reach the client.
			go func() {
				ctx.Success("aaa/bbb", []byte("xxxyyy"))
			}()
			var resp Response
			resp.SetStatusCode(123)
			resp.SetBodyString("foobar. Should be ignored")
			ctx.TimeoutErrorWithResponse(&resp)
			resp.SetStatusCode(456)
			resp.ResetBody()
			fmt.Fprintf(resp.BodyWriter(), "path=%s", ctx.Path())
			resp.Header.SetContentType("foo/bar")
			ctx.TimeoutErrorWithResponse(&resp)
		},
	}
	rw := &readWriter{}
	rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
	rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: google.com\r\n\r\n")
	if err := s.ServeConn(rw); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}
	br := bufio.NewReader(&rw.w)
	verifyResponse(t, br, 456, "foo/bar", "path=/foo")
	verifyResponse(t, br, 456, "foo/bar", "path=/bar")
	data, err := io.ReadAll(br)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %v", err)
	}
	if len(data) != 0 {
		t.Fatalf("Unexpected data read after the last response %q. Expecting %q", data, "")
	}
}
// TestServerTimeoutErrorWithCode checks that after TimeoutErrorWithCode each
// request is answered only with the last message and code (400,
// "stolen ctx"), even though a background goroutine keeps writing to the
// same ctx.
func TestServerTimeoutErrorWithCode(t *testing.T) {
	t.Parallel()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			// This late write races with the timeout and must not reach the client.
			go func() {
				ctx.Success("aaa/bbb", []byte("xxxyyy"))
			}()
			ctx.TimeoutErrorWithCode("should be ignored", 234)
			ctx.TimeoutErrorWithCode("stolen ctx", StatusBadRequest)
		},
	}
	rw := &readWriter{}
	rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
	rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
	if err := s.ServeConn(rw); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}
	br := bufio.NewReader(&rw.w)
	verifyResponse(t, br, StatusBadRequest, string(defaultContentType), "stolen ctx")
	verifyResponse(t, br, StatusBadRequest, string(defaultContentType), "stolen ctx")
	data, err := io.ReadAll(br)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %v", err)
	}
	if len(data) != 0 {
		t.Fatalf("Unexpected data read after the last response %q. Expecting %q", data, "")
	}
}
// TestServerTimeoutError checks that after TimeoutError each request is
// answered only with a 408 carrying the last message ("stolen ctx"), even
// though a background goroutine keeps writing to the same ctx.
func TestServerTimeoutError(t *testing.T) {
	t.Parallel()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			// This late write races with the timeout and must not reach the client.
			go func() {
				ctx.Success("aaa/bbb", []byte("xxxyyy"))
			}()
			ctx.TimeoutError("should be ignored")
			ctx.TimeoutError("stolen ctx")
		},
	}
	rw := &readWriter{}
	rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
	rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
	if err := s.ServeConn(rw); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}
	br := bufio.NewReader(&rw.w)
	verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "stolen ctx")
	verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "stolen ctx")
	data, err := io.ReadAll(br)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %v", err)
	}
	if len(data) != 0 {
		t.Fatalf("Unexpected data read after the last response %q. Expecting %q", data, "")
	}
}
// TestServerMaxRequestsPerConn checks that with MaxRequestsPerConn set to 1
// the server answers one request and marks it Connection: close, and that
// nothing is written for the pipelined second request.
func TestServerMaxRequestsPerConn(t *testing.T) {
	t.Parallel()
	s := &Server{
		MaxRequestsPerConn: 1,
		Handler:            func(ctx *RequestCtx) {},
	}

	rw := &readWriter{}
	for _, raw := range []string{
		"GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n",
		"GET /bar HTTP/1.1\r\nHost: aaa.com\r\n\r\n",
	} {
		rw.r.WriteString(raw)
	}
	if err := s.ServeConn(rw); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}

	br := bufio.NewReader(&rw.w)
	var resp Response
	if err := resp.Read(br); err != nil {
		t.Fatalf("Unexpected error when parsing response: %v", err)
	}
	if !resp.ConnectionClose() {
		t.Fatal("Response must have 'connection: close' header")
	}
	verifyResponseHeader(t, &resp.Header, 200, 0, string(defaultContentType), "")

	switch rest, err := io.ReadAll(br); {
	case err != nil:
		t.Fatalf("Unexpected error when reading remaining data: %v", err)
	case len(rest) != 0:
		t.Fatalf("Unexpected data read after the first response %q. Expecting %q", rest, "")
	}
}
func TestServerConnectionClose(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.SetConnectionClose()
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
rw.r.WriteString("GET /must/be/ignored HTTP/1.1\r\nHost: aaa.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("Unexpected error when parsing response: %v", err)
}
if !resp.ConnectionClose() {
t.Fatal("expecting Connection: close header")
}
data, err := io.ReadAll(br)
if err != nil {
t.Fatalf("Unexpected error when reading remaining data: %v", err)
}
if len(data) != 0 {
t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "")
}
}
// TestServerRequestNumAndTime checks that ConnRequestNum counts the requests
// on a connection starting at 1, and that ConnTime stays the same for every
// request served on that connection.
func TestServerRequestNumAndTime(t *testing.T) {
	t.Parallel()
	// n and connT are touched only by the handler; the test assumes ServeConn
	// runs it sequentially for the pipelined requests of this one connection.
	n := uint64(0)
	var connT time.Time
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			n++
			if ctx.ConnRequestNum() != n {
				t.Errorf("unexpected request number: %d. Expecting %d", ctx.ConnRequestNum(), n)
			}
			if connT.IsZero() {
				connT = ctx.ConnTime()
			}
			// Compare instants with Equal: == also compares the location and
			// the monotonic clock reading.
			if !ctx.ConnTime().Equal(connT) {
				t.Errorf("unexpected serve conn time: %q. Expecting %q", ctx.ConnTime(), connT)
			}
		},
	}
	rw := &readWriter{}
	rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
	rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: google.com\r\n\r\n")
	rw.r.WriteString("GET /baz HTTP/1.1\r\nHost: google.com\r\n\r\n")
	if err := s.ServeConn(rw); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}
	if n != 3 {
		t.Fatalf("unexpected number of requests served: %d. Expecting %d", n, 3)
	}
	br := bufio.NewReader(&rw.w)
	// The handler writes nothing, so all three responses are the same empty
	// default response.
	for i := 0; i < 3; i++ {
		verifyResponse(t, br, 200, string(defaultContentType), "")
	}
}
func TestServerEmptyResponse(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
// do nothing :)
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, string(defaultContentType), "")
}
// TestServerLogger checks that lines logged through ctx.Logger() are prefixed
// with the connection/request ID, the local and remote addresses, the method
// and the full request URL.
func TestServerLogger(t *testing.T) {
	// This test can't run parallel as it modifies globalConnID.
	testLog := &testLogger{}
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			lg := ctx.Logger()
			lg.Printf("begin")
			hdr := &ctx.Request.Header
			body := fmt.Sprintf("requestURI=%s, body=%q, remoteAddr=%s",
				hdr.RequestURI(), ctx.Request.Body(), ctx.RemoteAddr())
			ctx.Success("text/html", []byte(body))
			lg.Printf("end")
		},
		Logger: testLog,
	}

	rw := &readWriter{}
	rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
	rw.r.WriteString("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 5\r\nContent-Type: aa\r\n\r\nabcde")
	addr := &net.TCPAddr{IP: []byte{1, 2, 3, 4}, Port: 8765}
	conn := &readWriterRemoteAddr{rw: rw, addr: addr}

	globalConnID = 0
	if err := s.ServeConn(conn); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}

	br := bufio.NewReader(&rw.w)
	verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, body=\"\", remoteAddr=1.2.3.4:8765")
	verifyResponse(t, br, 200, "text/html", "requestURI=/foo2, body=\"abcde\", remoteAddr=1.2.3.4:8765")

	const expectedLogOut = `#0000000100000001 - 1.2.3.4:8765<->1.2.3.4:8765 - GET http://google.com/foo1 - begin
#0000000100000001 - 1.2.3.4:8765<->1.2.3.4:8765 - GET http://google.com/foo1 - end
#0000000100000002 - 1.2.3.4:8765<->1.2.3.4:8765 - POST http://aaa.com/foo2 - begin
#0000000100000002 - 1.2.3.4:8765<->1.2.3.4:8765 - POST http://aaa.com/foo2 - end
`
	if got := testLog.out; got != expectedLogOut {
		t.Fatalf("Unexpected logger output: %q. Expected %q", got, expectedLogOut)
	}
}
// TestServerRemoteAddr checks that RemoteAddr and RemoteIP reflect the
// address reported by the underlying connection.
func TestServerRemoteAddr(t *testing.T) {
	t.Parallel()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			uri := ctx.Request.Header.RequestURI()
			msg := fmt.Sprintf("requestURI=%s, remoteAddr=%s, remoteIP=%s", uri, ctx.RemoteAddr(), ctx.RemoteIP())
			ctx.Success("text/html", []byte(msg))
		},
	}

	rw := &readWriter{}
	rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
	conn := &readWriterRemoteAddr{
		rw:   rw,
		addr: &net.TCPAddr{IP: []byte{1, 2, 3, 4}, Port: 8765},
	}
	if err := s.ServeConn(conn); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}

	verifyResponse(t, bufio.NewReader(&rw.w), 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.4:8765, remoteIP=1.2.3.4")
}
// TestServerCustomRemoteAddr checks that a wrapping handler can override the
// remote address with SetRemoteAddr, and that RemoteAddr/RemoteIP then
// report the override rather than the connection's address.
func TestServerCustomRemoteAddr(t *testing.T) {
	t.Parallel()
	withRemoteAddr := func(next RequestHandler) RequestHandler {
		return func(ctx *RequestCtx) {
			ctx.SetRemoteAddr(&net.TCPAddr{IP: []byte{1, 2, 3, 5}})
			next(ctx)
		}
	}
	echo := func(ctx *RequestCtx) {
		uri := ctx.Request.Header.RequestURI()
		msg := fmt.Sprintf("requestURI=%s, remoteAddr=%s, remoteIP=%s", uri, ctx.RemoteAddr(), ctx.RemoteIP())
		ctx.Success("text/html", []byte(msg))
	}
	s := &Server{Handler: withRemoteAddr(echo)}

	rw := &readWriter{}
	rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
	conn := &readWriterRemoteAddr{
		rw:   rw,
		addr: &net.TCPAddr{IP: []byte{1, 2, 3, 4}, Port: 8765},
	}
	if err := s.ServeConn(conn); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}

	verifyResponse(t, bufio.NewReader(&rw.w), 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.5:0, remoteIP=1.2.3.5")
}
// readWriterRemoteAddr adapts an io.ReadWriteCloser into a net.Conn whose
// RemoteAddr and LocalAddr both return addr. Read, Write and Close go to rw.
// All other net.Conn methods come from the embedded net.Conn, which the
// tests in this chunk leave nil, so calling them there would panic.
type readWriterRemoteAddr struct {
	net.Conn
	rw   io.ReadWriteCloser
	addr net.Addr
}

// Close closes the wrapped stream.
func (c *readWriterRemoteAddr) Close() error { return c.rw.Close() }

// Read reads from the wrapped stream.
func (c *readWriterRemoteAddr) Read(b []byte) (int, error) { return c.rw.Read(b) }

// Write writes to the wrapped stream.
func (c *readWriterRemoteAddr) Write(b []byte) (int, error) { return c.rw.Write(b) }

// RemoteAddr returns the configured address.
func (c *readWriterRemoteAddr) RemoteAddr() net.Addr { return c.addr }

// LocalAddr returns the same configured address as RemoteAddr.
func (c *readWriterRemoteAddr) LocalAddr() net.Addr { return c.addr }
// TestServerConnError checks that ctx.Error produces a response with the
// given status code, the message as its body (and matching Content-Length),
// and the default content type.
func TestServerConnError(t *testing.T) {
	t.Parallel()
	const (
		errMsg  = "foobar"
		errCode = 423
	)
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.Error(errMsg, errCode)
		},
	}

	rw := &readWriter{}
	rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n")
	if err := s.ServeConn(rw); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}

	var resp Response
	if err := resp.Read(bufio.NewReader(&rw.w)); err != nil {
		t.Fatalf("Unexpected error when reading response: %v", err)
	}
	if code := resp.Header.StatusCode(); code != errCode {
		t.Fatalf("Unexpected status code %d. Expected %d", code, errCode)
	}
	if cl := resp.Header.ContentLength(); cl != len(errMsg) {
		t.Fatalf("Unexpected Content-Length %d. Expected %d", cl, len(errMsg))
	}
	if ct := resp.Header.Peek(HeaderContentType); !bytes.Equal(ct, defaultContentType) {
		t.Fatalf("Unexpected Content-Type %q. Expected %q", ct, defaultContentType)
	}
	if body := string(resp.Body()); body != errMsg {
		t.Fatalf("Unexpected body %q. Expected %q", body, errMsg)
	}
}
func TestServeConnSingleRequest(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
h := &ctx.Request.Header
ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI(), h.Peek(HeaderHost))))
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, "aaa", "requestURI=/foo/bar?baz, host=google.com")
}
func TestServerSetFormValueFunc(t *testing.T) {
t.Parallel()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Success("aaa", ctx.FormValue("aaa"))
},
FormValueFunc: func(ctx *RequestCtx, s string) []byte {
return []byte(s)
},
}
rw := &readWriter{}
rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, "aaa", "aaa")
}
// TestServeConnMultiRequests checks that ServeConn answers two pipelined
// requests on one connection, in order, each with its own URI and Host.
func TestServeConnMultiRequests(t *testing.T) {
	t.Parallel()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			uri := ctx.Request.Header.RequestURI()
			host := ctx.Request.Header.Peek(HeaderHost)
			ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", uri, host)))
		},
	}

	rw := &readWriter{}
	rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n")
	rw.r.WriteString("GET /abc HTTP/1.1\r\nHost: foobar.com\r\n\r\n")
	if err := s.ServeConn(rw); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}

	br := bufio.NewReader(&rw.w)
	for _, want := range []string{
		"requestURI=/foo/bar?baz, host=google.com",
		"requestURI=/abc, host=foobar.com",
	} {
		verifyResponse(t, br, 200, "aaa", want)
	}
}
// TestShutdown checks three things about Shutdown:
//   - it lets an in-flight request finish;
//   - that response carries no Connection header;
//   - Serve returns and the listener refuses further dials.
func TestShutdown(t *testing.T) {
	t.Parallel()
	ln := fasthttputil.NewInmemoryListener()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			time.Sleep(time.Millisecond * 500)
			ctx.Success("aaa/bbb", []byte("real response"))
		},
	}
	// The signal channels are buffered so none of the goroutines below blocks
	// forever if the test fails before receiving from them.
	serveCh := make(chan struct{}, 1)
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		_, err := ln.Dial()
		if err == nil {
			t.Error("server is still listening")
		}
		serveCh <- struct{}{}
	}()
	clientCh := make(chan struct{}, 1)
	go func() {
		conn, err := ln.Dial()
		if err != nil {
			// conn is nil here: report and bail out instead of using it.
			t.Errorf("unexpected error: %v", err)
			return
		}
		defer conn.Close()
		if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
			t.Errorf("unexpected error: %v", err)
			return
		}
		br := bufio.NewReader(conn)
		resp := verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
		verifyResponseHeaderConnection(t, &resp.Header, "")
		clientCh <- struct{}{}
	}()
	// Give the request time to reach the handler before shutting down.
	time.Sleep(time.Millisecond * 100)
	shutdownCh := make(chan struct{}, 1)
	go func() {
		if err := s.Shutdown(); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		shutdownCh <- struct{}{}
	}()
	done := 0
	for {
		select {
		case <-time.After(time.Second * 2):
			t.Fatal("shutdown took too long")
		case <-serveCh:
			done++
		case <-clientCh:
			done++
		case <-shutdownCh:
			done++
		}
		if done == 3 {
			return
		}
	}
}
// TestCloseOnShutdown checks that with CloseOnShutdown an in-flight request
// still completes during Shutdown, but its response carries
// Connection: close. It also checks that Serve returns and the listener
// refuses further dials.
func TestCloseOnShutdown(t *testing.T) {
	t.Parallel()
	ln := fasthttputil.NewInmemoryListener()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			time.Sleep(time.Millisecond * 500)
			ctx.Success("aaa/bbb", []byte("real response"))
		},
		CloseOnShutdown: true,
	}
	// The signal channels are buffered so none of the goroutines below blocks
	// forever if the test fails before receiving from them.
	serveCh := make(chan struct{}, 1)
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		_, err := ln.Dial()
		if err == nil {
			t.Error("server is still listening")
		}
		serveCh <- struct{}{}
	}()
	clientCh := make(chan struct{}, 1)
	go func() {
		conn, err := ln.Dial()
		if err != nil {
			// conn is nil here: report and bail out instead of using it.
			t.Errorf("unexpected error: %v", err)
			return
		}
		defer conn.Close()
		if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
			t.Errorf("unexpected error: %v", err)
			return
		}
		br := bufio.NewReader(conn)
		resp := verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
		verifyResponseHeaderConnection(t, &resp.Header, "close")
		clientCh <- struct{}{}
	}()
	// Give the request time to reach the handler before shutting down.
	time.Sleep(time.Millisecond * 100)
	shutdownCh := make(chan struct{}, 1)
	go func() {
		if err := s.Shutdown(); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		shutdownCh <- struct{}{}
	}()
	done := 0
	for {
		select {
		case <-time.After(time.Second):
			t.Fatal("shutdown took too long")
		case <-serveCh:
			done++
		case <-clientCh:
			done++
		case <-shutdownCh:
			done++
		}
		if done == 3 {
			return
		}
	}
}
func TestShutdownReuse(t *testing.T) {
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.Success("aaa/bbb", []byte("real response"))
},
ReadTimeout: time.Millisecond * 100,
Logger: &testLogger{}, // Ignore log output.
}
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %v", err)
}
}()
conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
if err := s.Shutdown(); err != nil {
t.Fatalf("unexpected error: %v", err)
}
ln = fasthttputil.NewInmemoryListener()
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %v", err)
}
}()
conn, err = ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %v", err)
}
br = bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
if err := s.Shutdown(); err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
// TestShutdownDone checks that Shutdown closes ctx.Done() for an in-flight
// request, so a handler blocked on it can finish and send its response.
func TestShutdownDone(t *testing.T) {
	t.Parallel()
	ln := fasthttputil.NewInmemoryListener()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			<-ctx.Done()
			ctx.Success("aaa/bbb", []byte("real response"))
		},
	}
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()
	conn, err := ln.Dial()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Buffered so the Shutdown goroutine can always hand over its result.
	shutdownErr := make(chan error, 1)
	go func() {
		// Shutdown won't return if the connection doesn't close,
		// which doesn't happen until we read the response.
		shutdownErr <- s.Shutdown()
	}()
	// We can only reach this point and get a valid response
	// if reading from ctx.Done() returned.
	br := bufio.NewReader(conn)
	verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
	if err := conn.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Wait for Shutdown so its goroutine does not outlive the test.
	select {
	case err := <-shutdownErr:
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("timeout waiting for Shutdown to return")
	}
}
// TestShutdownErr checks that cancellation from Shutdown reaches a context
// derived from RequestCtx with context.WithCancel. The handler blocks on the
// child context's Done, and a valid response proves that channel was closed.
func TestShutdownErr(t *testing.T) {
	t.Parallel()
	ln := fasthttputil.NewInmemoryListener()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			// NOTE(review): an older comment here claimed this "will panic";
			// the test receiving a valid response shows it does not.
			c, cancel := context.WithCancel(ctx)
			defer cancel()
			<-c.Done()
			ctx.Success("aaa/bbb", []byte("real response"))
		},
	}
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()
	conn, err := ln.Dial()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Buffered so the Shutdown goroutine can always hand over its result.
	shutdownErr := make(chan error, 1)
	go func() {
		// Shutdown won't return if the connection doesn't close,
		// which doesn't happen until we read the response.
		shutdownErr <- s.Shutdown()
	}()
	// We can only reach this point and get a valid response
	// if reading from ctx.Done() returned.
	br := bufio.NewReader(conn)
	verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
	if err := conn.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Wait for Shutdown so its goroutine does not outlive the test.
	select {
	case err := <-shutdownErr:
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("timeout waiting for Shutdown to return")
	}
}
// TestShutdownCloseIdleConns checks that Shutdown returns promptly when the
// only connection is an idle keep-alive one whose response has already been
// read.
func TestShutdownCloseIdleConns(t *testing.T) {
	t.Parallel()
	ln := fasthttputil.NewInmemoryListener()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.Success("aaa/bbb", []byte("real response"))
		},
	}
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()

	conn, err := ln.Dial()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
		t.Errorf("unexpected error: %v", err)
	}
	verifyResponse(t, bufio.NewReader(conn), StatusOK, "aaa/bbb", "real response")

	// The connection is now idle; Shutdown must not wait for the client.
	shutdownErr := make(chan error)
	go func() { shutdownErr <- s.Shutdown() }()
	select {
	case err := <-shutdownErr:
		if err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("idle connections not closed on shutdown")
	}
}
// TestShutdownWithContext checks that ShutdownWithContext gives up with
// context.DeadlineExceeded while a request is still being handled, and that
// the busy connection is still counted as open afterwards.
func TestShutdownWithContext(t *testing.T) {
	t.Parallel()
	// done keeps the handler blocked until the test releases it. release may
	// be called more than once, so it serves both the happy path below and
	// the deferred cleanup on early failure.
	done := make(chan struct{})
	var releaseOnce sync.Once
	release := func() { releaseOnce.Do(func() { close(done) }) }
	defer release()

	ln := fasthttputil.NewInmemoryListener()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			<-done
			ctx.Success("aaa/bbb", []byte("real response"))
		},
	}
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()
	time.Sleep(500 * time.Millisecond)

	clientDone := make(chan struct{})
	go func() {
		defer close(clientDone)
		conn, err := ln.Dial()
		if err != nil {
			// conn is nil here: report and bail out instead of using it.
			t.Errorf("unexpected error: %v", err)
			return
		}
		// Close the client side so the server's connection goroutine can finish.
		defer conn.Close()
		if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
			t.Errorf("unexpected error: %v", err)
			return
		}
		br := bufio.NewReader(conn)
		verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
	}()
	// Let the request reach the blocked handler before shutting down.
	time.Sleep(500 * time.Millisecond)

	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	shutdownErr := make(chan error, 1)
	go func() {
		shutdownErr <- s.ShutdownWithContext(ctx)
	}()
	timer := time.NewTimer(time.Second)
	defer timer.Stop()
	select {
	case <-timer.C:
		t.Fatal("idle connections not closed on shutdown")
	case err := <-shutdownErr:
		if !errors.Is(err, context.DeadlineExceeded) {
			t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded)
		}
	}
	if o := atomic.LoadInt32(&s.open); o != 1 {
		t.Fatalf("unexpected open connection num: %#v. Expecting %#v", o, 1)
	}

	// Unblock the handler and wait for the client to read its response, so the
	// client goroutine cannot report through t after the test has returned.
	release()
	select {
	case <-clientDone:
	case <-time.After(time.Second):
		t.Fatal("timeout waiting for the client to finish")
	}
}
// TestMultipleServe checks that one Server can Serve two listeners at the
// same time and answers requests arriving on either of them.
func TestMultipleServe(t *testing.T) {
	t.Parallel()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.Success("aaa/bbb", []byte("real response"))
		},
	}
	ln1 := fasthttputil.NewInmemoryListener()
	ln2 := fasthttputil.NewInmemoryListener()
	go func() {
		if err := s.Serve(ln1); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()
	go func() {
		if err := s.Serve(ln2); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()

	conn1, err := ln1.Dial()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Close client conns so the server's per-connection goroutines can finish.
	defer conn1.Close()
	if _, err = conn1.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	verifyResponse(t, bufio.NewReader(conn1), StatusOK, "aaa/bbb", "real response")

	conn2, err := ln2.Dial()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer conn2.Close()
	if _, err = conn2.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	verifyResponse(t, bufio.NewReader(conn2), StatusOK, "aaa/bbb", "real response")
}
// TestMaxBodySizePerRequest checks that a per-request MaxRequestBodySize
// returned from HeaderReceived (5 KiB) overrides the larger server-wide
// limit, so a 5 KiB + 1 body fails with ErrBodyTooLarge.
func TestMaxBodySizePerRequest(t *testing.T) {
	t.Parallel()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			// do nothing :)
		},
		HeaderReceived: func(header *RequestHeader) RequestConfig {
			return RequestConfig{
				MaxRequestBodySize: 5 << 10,
			}
		},
		ReadTimeout:        time.Second * 5,
		WriteTimeout:       time.Second * 5,
		MaxRequestBodySize: 1 << 20,
	}
	rw := &readWriter{}
	rw.r.WriteString(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", (5<<10)+1, strings.Repeat("a", (5<<10)+1)))
	if err := s.ServeConn(rw); !errors.Is(err, ErrBodyTooLarge) {
		t.Fatalf("Unexpected error from serveConn: %v. Expecting %v", err, ErrBodyTooLarge)
	}
}
// TestStreamRequestBody checks that with StreamRequestBody the handler can
// consume the first part of a body before the client has sent the rest.
func TestStreamRequestBody(t *testing.T) {
	t.Parallel()
	part1 := strings.Repeat("1", 1<<15)
	part2 := strings.Repeat("2", 1<<16)
	contentLength := len(part1) + len(part2)
	next := make(chan struct{})
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			checkReader(t, ctx.RequestBodyStream(), part1)
			close(next)
			checkReader(t, ctx.RequestBodyStream(), part2)
		},
		StreamRequestBody: true,
		Logger:            &testLogger{},
	}
	pipe := fasthttputil.NewPipeConns()
	cc, sc := pipe.Conn1(), pipe.Conn2()
	// write headers and part1 body
	if _, err := fmt.Fprintf(cc, "POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", contentLength); err != nil {
		t.Fatal(err)
	}
	if _, err := cc.Write([]byte(part1)); err != nil {
		t.Fatal(err)
	}
	// Buffered so the ServeConn goroutine can exit even after a timeout.
	ch := make(chan error, 1)
	go func() {
		ch <- s.ServeConn(sc)
	}()
	select {
	case <-next:
	case <-time.After(500 * time.Millisecond):
		t.Fatal("part1 timeout")
	}
	if _, err := cc.Write([]byte(part2)); err != nil {
		t.Fatal(err)
	}
	if err := sc.Close(); err != nil {
		t.Fatal(err)
	}
	select {
	case err := <-ch:
		if err != nil && err.Error() != "connection closed" { // fasthttputil.errConnectionClosed is private so do a string match.
			t.Fatalf("Unexpected error from serveConn: %v", err)
		}
	case <-time.After(500 * time.Millisecond):
		t.Fatal("part2 timeout")
	}
}
// TestStreamRequestBodyExceedMaxSize checks that with StreamRequestBody a
// body far larger than MaxRequestBodySize is still streamed to the handler
// in full and ServeConn returns no error.
func TestStreamRequestBodyExceedMaxSize(t *testing.T) {
	part1 := strings.Repeat("1", 1<<18)
	part2 := strings.Repeat("2", 1<<20-1<<18)
	contentLength := len(part1) + len(part2)
	next := make(chan struct{})
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			checkReader(t, ctx.RequestBodyStream(), part1)
			close(next)
			checkReader(t, ctx.RequestBodyStream(), part2)
		},
		DisableKeepalive:   true,
		StreamRequestBody:  true,
		MaxRequestBodySize: 1,
	}
	pipe := fasthttputil.NewPipeConns()
	cc, sc := pipe.Conn1(), pipe.Conn2()
	// write headers and part1 body
	if _, err := fmt.Fprintf(cc, "POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", contentLength, part1); err != nil {
		t.Error(err)
	}
	// Buffered so the ServeConn goroutine can exit even after a timeout.
	ch := make(chan error, 1)
	go func() {
		ch <- s.ServeConn(sc)
	}()
	select {
	case <-next:
	case <-time.After(500 * time.Millisecond):
		t.Fatal("part1 timeout")
	}
	if _, err := cc.Write([]byte(part2)); err != nil {
		t.Error(err)
	}
	select {
	case err := <-ch:
		if err != nil {
			t.Error(err)
		}
	case <-time.After(500 * time.Millisecond):
		t.Fatal("part2 timeout")
	}
}
// TestStreamBodyRequestContentLength verifies that, with StreamRequestBody
// enabled, the handler sees the Content-Length value sent by the client.
func TestStreamBodyRequestContentLength(t *testing.T) {
	content := strings.Repeat("1", 1<<15) // 32K
	contentLength := len(content)

	s := &Server{
		Handler: func(ctx *RequestCtx) {
			// The handler runs on the ServeConn goroutine, not the test
			// goroutine, so report with Errorf: t.Fatal (FailNow) may only be
			// called from the goroutine running the test.
			if got := ctx.Request.Header.ContentLength(); got != contentLength {
				t.Errorf("incorrect content length: got %d, want %d", got, contentLength)
			}
		},
		MaxRequestBodySize: 1 * 1024 * 1024, // 1M
		StreamRequestBody:  true,
	}

	pipe := fasthttputil.NewPipeConns()
	cc, sc := pipe.Conn1(), pipe.Conn2()

	if _, err := fmt.Fprintf(cc, "POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", contentLength, content); err != nil {
		t.Fatal(err)
	}

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even if the test gives up on the timeout below.
	ch := make(chan error, 1)
	go func() {
		ch <- s.ServeConn(sc)
	}()

	if err := sc.Close(); err != nil {
		t.Fatal(err)
	}

	select {
	case err := <-ch:
		if err == nil || err.Error() != "connection closed" { // fasthttputil.errConnectionClosed is private so do a string match.
			t.Fatalf("Unexpected error from serveConn: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("test timeout")
	}
}
// checkReader reads exactly len(expected) bytes from r and records a test
// error if the read fails or the bytes differ from expected.
//
// In this file it is called from inside server handlers, which run on a
// goroutine other than the test's. It therefore reports with t.Errorf rather
// than t.Fatalf, because FailNow must only be called from the test goroutine.
func checkReader(t *testing.T, r io.Reader, expected string) {
	t.Helper()

	b := make([]byte, len(expected))
	if n, err := io.ReadFull(r, b); err != nil {
		t.Errorf("Unexpected error from reader after %d of %d bytes: %v", n, len(expected), err)
		return
	}

	// Report the first mismatching byte rather than dumping bodies that may
	// be hundreds of kilobytes long.
	for i := range b {
		if b[i] != expected[i] {
			t.Errorf("incorrect request body: mismatch at byte %d: got %q, want %q", i, b[i], expected[i])
			return
		}
	}
}
// TestMaxReadTimeoutPerRequest verifies that the read timeout returned from
// HeaderReceived aborts a request whose body arrives too slowly, making
// ServeConn fail with a timeout error before the handler is reached.
func TestMaxReadTimeoutPerRequest(t *testing.T) {
	t.Parallel()

	headers := []byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", 5*1024))
	s := &Server{
		Handler: func(_ *RequestCtx) {
			t.Error("shouldn't reach handler")
		},
		// HeaderReceived lowers the read timeout to 1ms, far below the 5s
		// server-wide ReadTimeout.
		HeaderReceived: func(header *RequestHeader) RequestConfig {
			return RequestConfig{
				ReadTimeout: time.Millisecond,
			}
		},
		ReadBufferSize: len(headers),
		ReadTimeout:    time.Second * 5,
		WriteTimeout:   time.Second * 5,
	}

	pipe := fasthttputil.NewPipeConns()
	cc, sc := pipe.Conn1(), pipe.Conn2()

	// Client: send the headers, then trickle the body one byte per
	// millisecond so the short read timeout expires before it completes.
	go func() {
		if _, err := cc.Write(headers); err != nil {
			t.Error(err)
			return // no point sending a body without headers
		}
		for i := 0; i < 5*1024; i++ {
			time.Sleep(time.Millisecond)
			if _, err := cc.Write([]byte{'a'}); err != nil {
				return // the connection is no longer usable
			}
		}
	}()

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even if the test gives up on the timeout below.
	ch := make(chan error, 1)
	go func() {
		ch <- s.ServeConn(sc)
	}()

	select {
	case err := <-ch:
		if err == nil || !strings.EqualFold(err.Error(), "timeout") {
			t.Fatalf("Unexpected error from serveConn: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("test timeout")
	}
}
// TestMaxWriteTimeoutPerRequest verifies that the write timeout returned from
// HeaderReceived aborts an endless streamed response to a slowly reading
// client, making ServeConn fail with a timeout error.
func TestMaxWriteTimeoutPerRequest(t *testing.T) {
	t.Parallel()

	headers := []byte("GET /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: aa\r\n\r\n")
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			// Stream an endless body; only a write error stops it.
			ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
				var buf [192]byte
				for {
					if _, err := w.Write(buf[:]); err != nil {
						return
					}
				}
			})
		},
		// HeaderReceived lowers the write timeout to 1ms, far below the 5s
		// server-wide WriteTimeout.
		HeaderReceived: func(header *RequestHeader) RequestConfig {
			return RequestConfig{
				WriteTimeout: time.Millisecond,
			}
		},
		ReadBufferSize: 192,
		ReadTimeout:    time.Second * 5,
		WriteTimeout:   time.Second * 5,
	}

	pipe := fasthttputil.NewPipeConns()
	cc, sc := pipe.Conn1(), pipe.Conn2()

	// Client: send the request and read the response headers, then drain the
	// body slowly (at most 192 bytes per read, one read per millisecond).
	go func() {
		if _, err := cc.Write(headers); err != nil {
			t.Error(err)
			return
		}
		br := bufio.NewReaderSize(cc, 192)
		var resp Response
		if err := resp.Header.Read(br); err != nil {
			t.Error(err)
			return
		}
		var chunk [192]byte
		for {
			time.Sleep(time.Millisecond)
			if _, err := br.Read(chunk[:]); err != nil {
				return
			}
		}
	}()

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even if the test gives up on the timeout below.
	ch := make(chan error, 1)
	go func() {
		ch <- s.ServeConn(sc)
	}()

	select {
	case err := <-ch:
		if err == nil || !strings.EqualFold(err.Error(), "timeout") {
			t.Fatalf("Unexpected error from serveConn: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("test timeout")
	}
}
// TestIncompleteBodyReturnsUnexpectedEOF verifies that ServeConn reports
// "unexpected EOF" when the connection ends before all the body bytes
// announced by Content-Length have arrived.
func TestIncompleteBodyReturnsUnexpectedEOF(t *testing.T) {
	t.Parallel()

	const want = "unexpected EOF"

	rw := &readWriter{}
	// Content-Length announces 5 body bytes, but only 3 are supplied.
	rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 5\r\n\r\n123")
	s := &Server{
		Handler: func(ctx *RequestCtx) {},
	}

	// readWriter is backed by in-memory buffers, so ServeConn runs to
	// completion synchronously; no helper goroutine is needed.
	if err := s.ServeConn(rw); err == nil || err.Error() != want {
		t.Fatalf("Unexpected error from serveConn: %v. Expected %q", err, want)
	}
}
// TestServerChunkedResponse verifies that a streamed response sent with
// Transfer-Encoding: chunked carries no Content-Length. It also checks that
// the client sees the concatenated body and every declared trailer with its
// value.
func TestServerChunkedResponse(t *testing.T) {
	t.Parallel()

	trailer := map[string]string{
		"AtEnd1": "1111",
		"AtEnd2": "2222",
		"AtEnd3": "3333",
	}
	const expectedBody = "message 0" + "message 1" + "message 2"

	handler := func(ctx *RequestCtx) {
		ctx.Response.Header.DisableNormalizing()
		ctx.Response.Header.Set("Transfer-Encoding", "chunked")
		for k := range trailer {
			if err := ctx.Response.Header.AddTrailer(k); err != nil {
				t.Errorf("unexpected error: %v", err)
			}
		}
		ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) {
			for i := 0; i < 3; i++ {
				// bufio.Writer errors are sticky, so a failed write is
				// reported by the Flush below.
				fmt.Fprintf(w, "message %d", i)
				if err := w.Flush(); err != nil {
					t.Errorf("unexpected error: %v", err)
				}
				time.Sleep(time.Millisecond * 100)
			}
		})
		for k, v := range trailer {
			ctx.Response.Header.Set(k, v)
		}
	}

	s := &Server{
		Handler: handler,
	}
	rw := &readWriter{}
	rw.r.WriteString("GET / HTTP/1.1\r\nHost: test.com\r\n\r\n")
	if err := s.ServeConn(rw); err != nil {
		t.Fatalf("Unexpected error from serveConn: %v", err)
	}

	br := bufio.NewReader(&rw.w)
	var resp Response
	if err := resp.Read(br); err != nil {
		t.Fatalf("Unexpected error when reading response: %v", err)
	}
	if resp.Header.ContentLength() != -1 {
		t.Fatalf("Unexpected Content-Length %d. Expected %d", resp.Header.ContentLength(), -1)
	}
	if body := resp.Body(); string(body) != expectedBody {
		t.Fatalf("Unexpected body %q. Expected %q", body, expectedBody)
	}
	for k, v := range trailer {
		if got := resp.Header.Peek(k); string(got) != v {
			t.Fatalf("Unexpected trailer %q. Expected %q. Got %q", k, v, got)
		}
	}
}
// verifyResponse reads one response from r and fails the test unless its body
// equals expectedBody. It then delegates header checks (status code, content
// length equal to the body length, content type) to verifyResponseHeader, and
// returns the parsed response for further assertions.
func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) *Response {
	// Attribute failures to the caller's line rather than to this helper.
	t.Helper()

	var resp Response
	if err := resp.Read(r); err != nil {
		t.Fatalf("Unexpected error when parsing response: %v", err)
	}
	if !bytes.Equal(resp.Body(), []byte(expectedBody)) {
		t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), []byte(expectedBody))
	}
	verifyResponseHeader(t, &resp.Header, expectedStatusCode, len(resp.Body()), expectedContentType, "")
	return &resp
}
// readWriter is an in-memory net.Conn used to drive Server.ServeConn
// synchronously in tests. Request bytes are pre-loaded into r, and everything
// the server writes ends up in w. Close and the deadline setters do nothing,
// and both endpoint addresses are zeroTCPAddr.
type readWriter struct {
	net.Conn
	r bytes.Buffer
	w bytes.Buffer
}

// Read drains the pre-loaded request bytes; once r is empty it yields io.EOF.
func (c *readWriter) Read(p []byte) (int, error) { return c.r.Read(p) }

// Write appends server output to w.
func (c *readWriter) Write(p []byte) (int, error) { return c.w.Write(p) }

// Close does nothing; both buffers stay readable afterwards.
func (c *readWriter) Close() error { return nil }

// LocalAddr reports zeroTCPAddr.
func (c *readWriter) LocalAddr() net.Addr { return zeroTCPAddr }

// RemoteAddr reports zeroTCPAddr.
func (c *readWriter) RemoteAddr() net.Addr { return zeroTCPAddr }

// SetDeadline is a no-op; the in-memory buffers never block.
func (c *readWriter) SetDeadline(time.Time) error { return nil }

// SetReadDeadline is a no-op; the in-memory buffers never block.
func (c *readWriter) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline is a no-op; the in-memory buffers never block.
func (c *readWriter) SetWriteDeadline(time.Time) error { return nil }
// testLogger is a Logger that collects every message in out, one per line,
// so tests can inspect what the server logged. Printf may be called
// concurrently.
type testLogger struct {
	out  string
	lock sync.Mutex
}

// Printf formats the message, drops its first six bytes, and appends the rest
// to cl.out followed by a newline.
//
// NOTE(review): the six dropped bytes presumably strip a fixed-width prefix
// (such as an elapsed-time stamp) that the caller prepends; confirm. Messages
// shorter than six bytes are kept whole instead of panicking on the slice.
func (cl *testLogger) Printf(format string, args ...any) {
	msg := fmt.Sprintf(format, args...)
	if len(msg) >= 6 {
		msg = msg[6:]
	}

	// Format outside the lock; only the shared field update needs it.
	cl.lock.Lock()
	cl.out += msg + "\n"
	cl.lock.Unlock()
}
// TestRequestBodyStreamReadIssue1816 is a regression test for issue #1816.
// With StreamRequestBody enabled and a body larger than MaxRequestBodySize,
// the handler must be able to drain RequestBodyStream until io.EOF, and
// serveConn must return without error, even when a second request is
// pipelined behind the first.
func TestRequestBodyStreamReadIssue1816(t *testing.T) {
	pcs := fasthttputil.NewPipeConns()
	cliCon, serverCon := pcs.Conn1(), pcs.Conn2()

	// Closed when the client goroutine exits, so the test can wait for it.
	clientDone := make(chan struct{})
	go func() {
		defer close(clientDone)

		req := AcquireRequest()
		defer ReleaseRequest(req)
		req.Header.SetContentLength(10)
		req.Header.SetMethod("POST")
		req.SetRequestURI("http://localhost:8080")
		req.SetBodyRaw(bytes.Repeat([]byte{'1'}, 10))

		// Send the same request twice back to back (pipelined).
		reqBody := req.String()
		pipelineReqBody := make([]byte, 0, 2*len(reqBody))
		pipelineReqBody = append(pipelineReqBody, reqBody...)
		pipelineReqBody = append(pipelineReqBody, reqBody...)
		if _, err := cliCon.Write(pipelineReqBody); err != nil {
			t.Error(err)
		}

		resp := AcquireResponse()
		if err := resp.Read(bufio.NewReader(cliCon)); err != nil {
			t.Error(err)
		}
		if err := cliCon.Close(); err != nil {
			t.Error(err)
		}
	}()

	server := Server{StreamRequestBody: true, MaxRequestBodySize: 5, Handler: func(ctx *RequestCtx) {
		r := ctx.RequestBodyStream()
		p := make([]byte, 1300)
		for {
			_, err := r.Read(p)
			if err != nil {
				if err != io.EOF {
					t.Fatal(err)
				}
				break
			}
		}
	}}
	if err := server.serveConn(serverCon); err != nil {
		t.Fatal(err)
	}

	// Don't return while the client goroutine may still be running. It would
	// otherwise be able to report into an already completed test.
	select {
	case <-clientDone:
	case <-time.After(time.Second):
		t.Fatal("client goroutine did not finish")
	}
}