diff --git a/server.go b/server.go index 83217e4..f7e6b40 100644 --- a/server.go +++ b/server.go @@ -152,6 +152,14 @@ type Server struct { // DefaultConcurrency is used if not set. Concurrency int + // Whether to disable keep-alive connections. + // + // The server will close all the incoming connections after sending + // the first response to client if this option is set to true. + // + // By default keep-alive connections are enabled. + DisableKeepalive bool + // Per-connection buffer size for requests' reading. // This also limits the maximum header size. // @@ -1333,14 +1341,17 @@ func (s *Server) serveConn(c net.Conn) error { connRequestNum := uint64(0) ctx := s.acquireCtx(c) - var br *bufio.Reader - var bw *bufio.Writer - - var err error - var connectionClose bool - var isHTTP11 bool - var timeoutResponse *Response - var hijackHandler HijackHandler + var ( + br *bufio.Reader + bw *bufio.Writer + ) + var ( + err error + connectionClose bool + isHTTP11 bool + timeoutResponse *Response + hijackHandler HijackHandler + ) for { ctx.id++ connRequestNum++ @@ -1422,7 +1433,7 @@ func (s *Server) serveConn(c net.Conn) error { } } - connectionClose = ctx.Request.Header.connectionCloseFast() + connectionClose = s.DisableKeepalive || ctx.Request.Header.connectionCloseFast() isHTTP11 = ctx.Request.Header.IsHTTP11() ctx.connRequestNum = connRequestNum diff --git a/server_test.go b/server_test.go index f144435..a7b46ff 100644 --- a/server_test.go +++ b/server_test.go @@ -16,6 +16,77 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) +func TestServerDisableKeepalive(t *testing.T) { + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.WriteString("OK") + }, + DisableKeepalive: true, + } + + ln := fasthttputil.NewInmemoryListener() + + serverCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Fatalf("unexpected error: %s", err) + } + close(serverCh) + }() + + clientCh := make(chan struct{}) + go func() { + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) + } + br := bufio.NewReader(c) + var resp Response + if err = resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + if !resp.ConnectionClose() { + t.Fatalf("expecting 'Connection: close' response header") + } + if string(resp.Body()) != "OK" { + t.Fatalf("unexpected body: %q. Expecting %q", resp.Body(), "OK") + } + + // make sure the connection is closed + data, err := ioutil.ReadAll(br) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if len(data) > 0 { + t.Fatalf("unexpected data read from the connection: %q. Expecting empty data", data) + } + + close(clientCh) + }() + + select { + case <-clientCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + select { + case <-serverCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } +} + func TestServerMaxConnsPerIPLimit(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) {