diff --git a/server.go b/server.go index 6b48a74..98818c4 100644 --- a/server.go +++ b/server.go @@ -1350,6 +1350,7 @@ func nextConnID() uint64 { } func (s *Server) serveConn(c net.Conn) error { + serverName := s.getServerName() connRequestNum := uint64(0) connID := nextConnID() currentTime := time.Now() @@ -1445,6 +1446,7 @@ func (s *Server) serveConn(c net.Conn) error { connectionClose = s.DisableKeepalive || ctx.Request.Header.connectionCloseFast() isHTTP11 = ctx.Request.Header.IsHTTP11() + ctx.Response.Header.SetServerBytes(serverName) ctx.connID = connID ctx.connRequestNum = connRequestNum ctx.connTime = connTime @@ -1667,15 +1669,7 @@ func writeResponse(ctx *RequestCtx, w *bufio.Writer) error { if ctx.timeoutResponse != nil { panic("BUG: cannot write timed out response") } - h := &ctx.Response.Header - serverOld := h.Server() - if len(serverOld) == 0 { - h.server = ctx.s.getServerName() - } err := ctx.Response.Write(w) - if len(serverOld) == 0 { - h.server = serverOld - } ctx.Response.Reset() return err } diff --git a/server_test.go b/server_test.go index ca1cd36..ae87d2e 100644 --- a/server_test.go +++ b/server_test.go @@ -17,6 +17,75 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) +func TestServerResponseServerHeader(t *testing.T) { + serverName := "foobar serv" + + s := &Server{ + Handler: func(ctx *RequestCtx) { + name := ctx.Response.Header.Server() + if string(name) != serverName { + fmt.Fprintf(ctx, "unexpected server name: %q. Expecting %q", name, serverName) + } else { + ctx.WriteString("OK") + } + }, + Name: serverName, + } + + ln := fasthttputil.NewInmemoryListener() + + serverCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Fatalf("unexpected error: %s", err) + } + close(serverCh) + }() + + clientCh := make(chan struct{}) + go func() { + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) + } + br := bufio.NewReader(c) + var resp Response + if err = resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + if string(resp.Body()) != "OK" { + t.Fatalf("unexpected body: %q. Expecting %q", resp.Body(), "OK") + } + if err = c.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + close(clientCh) + }() + + select { + case <-clientCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + select { + case <-serverCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } +} + func TestServerResponseBodyStream(t *testing.T) { ln := fasthttputil.NewInmemoryListener()