diff --git a/header.go b/header.go index fb6f005..c14532c 100644 --- a/header.go +++ b/header.go @@ -33,7 +33,8 @@ type ResponseHeader struct { noDefaultDate bool statusCode int - statusLine []byte + statusMessage []byte + protocol []byte contentLength int contentLengthBytes []byte secureErrorLogMessage bool @@ -137,17 +138,27 @@ func (h *ResponseHeader) SetStatusCode(statusCode int) { h.statusCode = statusCode } -// StatusLine returns response status line. -func (h *ResponseHeader) StatusLine() []byte { - if len(h.statusLine) > 0 { - return h.statusLine - } - return statusLine(h.StatusCode()) +// StatusMessage returns response status message. +func (h *ResponseHeader) StatusMessage() []byte { + return h.statusMessage } -// SetStatusLine sets response status line bytes. -func (h *ResponseHeader) SetStatusLine(statusLine []byte) { - h.statusLine = append(h.statusLine[:0], statusLine...) +// SetStatusMessage sets response status message bytes. +func (h *ResponseHeader) SetStatusMessage(statusMessage []byte) { + h.statusMessage = append(h.statusMessage[:0], statusMessage...) +} + +// Protocol returns response protocol bytes. +func (h *ResponseHeader) Protocol() []byte { + if len(h.protocol) > 0 { + return h.protocol + } + return strHTTP11 +} + +// SetProtocol sets response protocol bytes. +func (h *ResponseHeader) SetProtocol(protocol []byte) { + h.protocol = append(h.protocol[:0], protocol...) } // SetLastModified sets 'Last-Modified' header to the given value. @@ -697,7 +708,8 @@ func (h *ResponseHeader) resetSkipNormalize() { h.connectionClose = false h.statusCode = 0 - h.statusLine = h.statusLine[:0] + h.statusMessage = h.statusMessage[:0] + h.protocol = h.protocol[:0] h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] @@ -746,7 +758,8 @@ func (h *ResponseHeader) CopyTo(dst *ResponseHeader) { dst.noDefaultDate = h.noDefaultDate dst.statusCode = h.statusCode - dst.statusLine = append(dst.statusLine, h.statusLine...) + dst.statusMessage = append(dst.statusMessage, h.statusMessage...) + dst.protocol = append(dst.protocol, h.protocol...) dst.contentLength = h.contentLength dst.contentLengthBytes = append(dst.contentLengthBytes, h.contentLengthBytes...) dst.contentType = append(dst.contentType, h.contentType...) @@ -1648,19 +1661,20 @@ func (h *ResponseHeader) String() string { return string(h.Header()) } -// AppendBytes appends response header representation to dst and returns +// appendStatusLine appends the response status line to dst and returns // the extended dst. -func (h *ResponseHeader) AppendBytes(dst []byte) []byte { +func (h *ResponseHeader) appendStatusLine(dst []byte) []byte { statusCode := h.StatusCode() if statusCode < 0 { statusCode = StatusOK } + return formatStatusLine(dst, h.Protocol(), statusCode, h.StatusMessage()) +} - if len(h.statusLine) > 0 { - dst = append(dst, h.statusLine...) - } else { - dst = append(dst, statusLine(statusCode)...) - } +// AppendBytes appends response header representation to dst and returns +// the extended dst. +func (h *ResponseHeader) AppendBytes(dst []byte) []byte { + dst = h.appendStatusLine(dst[:0]) server := h.Server() if len(server) != 0 { @@ -1880,8 +1894,8 @@ func (h *ResponseHeader) parseFirstLine(buf []byte) (int, error) { } return 0, fmt.Errorf("unexpected char at the end of status code. Response %q", buf) } - if len(b) > n+1 && !bytes.Equal(b[n+1:], statusLine(h.statusCode)) { - h.SetStatusLine(b[n+1:]) + if len(b) > n+1 { + h.SetStatusMessage(b[n+1:]) } return len(buf) - len(bNext), nil diff --git a/header_test.go b/header_test.go index 19138ad..27a292c 100644 --- a/header_test.go +++ b/header_test.go @@ -52,8 +52,29 @@ func TestResponseHeaderMultiLineValue(t *testing.T) { t.Fatalf("parse response using net/http failed, %s", err) } - if !bytes.Equal(header.StatusLine(), []byte("SuperOK")) { - t.Errorf("parse status line with non-default value failed, got: %s want: SuperOK", header.StatusLine()) + if !bytes.Equal(header.StatusMessage(), []byte("SuperOK")) { + t.Errorf("parse status line with non-default value failed, got: '%s' want: 'SuperOK'", header.StatusMessage()) + } + + header.SetProtocol([]byte("HTTP/3.3")) + if !bytes.Equal(header.Protocol(), []byte("HTTP/3.3")) { + t.Errorf("parse protocol with non-default value failed, got: '%s' want: 'HTTP/3.3'", header.Protocol()) + } + + if !bytes.Equal(header.appendStatusLine(nil), []byte("HTTP/3.3 200 SuperOK\r\n")) { + t.Errorf("parse status line with non-default value failed, got: '%s' want: 'HTTP/3.3 200 SuperOK'", header.Protocol()) + } + + header.SetStatusMessage(nil) + + if !bytes.Equal(header.appendStatusLine(nil), []byte("HTTP/3.3 200 OK\r\n")) { + t.Errorf("parse status line with default protocol value failed, got: '%s' want: 'HTTP/3.3 200 OK'", header.appendStatusLine(nil)) + } + + header.SetStatusMessage(s2b(StatusMessage(200))) + + if !bytes.Equal(header.appendStatusLine(nil), []byte("HTTP/3.3 200 OK\r\n")) { + t.Errorf("parse status line with default protocol value failed, got: '%s' want: 'HTTP/3.3 200 OK'", header.appendStatusLine(nil)) } for name, vals := range response.Header { @@ -83,8 +104,16 @@ func TestResponseHeaderMultiLineName(t *testing.T) { t.Errorf("expected error, got %q (%v)", m, err) } - if !bytes.Equal(header.StatusLine(), []byte("OK")) { - t.Errorf("expected default status line, got: %s", header.StatusLine()) + if !bytes.Equal(header.StatusMessage(), []byte("OK")) { + t.Errorf("expected default status line, got: %s", header.StatusMessage()) + } + + if !bytes.Equal(header.Protocol(), []byte("HTTP/1.1")) { + t.Errorf("expected default protocol, got: %s", header.Protocol()) + } + + if !bytes.Equal(header.appendStatusLine(nil), []byte("HTTP/1.1 200 OK\r\n")) { + t.Errorf("parse status line with non-default value failed, got: %s want: HTTP/1.1 200 OK", header.Protocol()) } } diff --git a/http_test.go b/http_test.go index 8d57bb0..00ce8a4 100644 --- a/http_test.go +++ b/http_test.go @@ -837,9 +837,9 @@ func TestResponseSkipBody(t *testing.T) { t.Fatalf("unexpected content-type in response %q", s) } - // set StatusNoContent with statusLine + // set StatusNoContent with statusMessage r.Header.SetStatusCode(StatusNoContent) - r.Header.SetStatusLine([]byte("HTTP/1.1 204 NC\r\n")) + r.Header.SetStatusMessage([]byte("NC")) r.SetBodyString("foobar") s = r.String() if strings.Contains(s, "\r\n\r\nfoobar") { diff --git a/server.go b/server.go index cb4c823..bc4a735 100644 --- a/server.go +++ b/server.go @@ -2741,7 +2741,7 @@ func (s *Server) getServerName() []byte { } func (s *Server) writeFastError(w io.Writer, statusCode int, msg string) { - w.Write(statusLine(statusCode)) //nolint:errcheck + w.Write(formatStatusLine(nil, strHTTP11, statusCode, s2b(StatusMessage(statusCode)))) //nolint:errcheck server := "" if !s.NoDefaultServerHeader { diff --git a/status.go b/status.go index 28d1286..c88ba11 100644 --- a/status.go +++ b/status.go @@ -1,7 +1,6 @@ package fasthttp import ( - "fmt" "strconv" ) @@ -81,7 +80,7 @@ const ( ) var ( - statusLines = make([][]byte, statusMessageMax+1) + unknownStatusCode = "Unknown Status Code" statusMessages = []string{ StatusContinue: "Continue", @@ -155,39 +154,24 @@ var ( // StatusMessage returns HTTP status message for the given status code. func StatusMessage(statusCode int) string { if statusCode < statusMessageMin || statusCode > statusMessageMax { - return "Unknown Status Code" + return unknownStatusCode } - s := statusMessages[statusCode] - if s == "" { - s = "Unknown Status Code" + if s := statusMessages[statusCode]; s != "" { + return s } - return s + return unknownStatusCode } -func init() { - // Fill all valid status lines - for i := 0; i < len(statusLines); i++ { - statusLines[i] = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n", i, StatusMessage(i))) +func formatStatusLine(dst []byte, protocol []byte, statusCode int, statusText []byte) []byte { + dst = append(dst, protocol...) + dst = append(dst, ' ') + dst = strconv.AppendInt(dst, int64(statusCode), 10) + dst = append(dst, ' ') + if len(statusText) == 0 { + dst = append(dst, s2b(StatusMessage(statusCode))...) + } else { + dst = append(dst, statusText...) } -} - -func statusLine(statusCode int) []byte { - if statusCode < 0 || statusCode > statusMessageMax { - return invalidStatusLine(statusCode) - } - - return statusLines[statusCode] -} - -func invalidStatusLine(statusCode int) []byte { - statusText := StatusMessage(statusCode) - // xxx placeholder of status code - var line = make([]byte, 0, len("HTTP/1.1 xxx \r\n")+len(statusText)) - line = append(line, "HTTP/1.1 "...) - line = strconv.AppendInt(line, int64(statusCode), 10) - line = append(line, ' ') - line = append(line, statusText...) - line = append(line, "\r\n"...) - return line + return append(dst, strCRLF...) } diff --git a/status_test.go b/status_test.go index e3bfa8a..512c884 100644 --- a/status_test.go +++ b/status_test.go @@ -17,7 +17,7 @@ func TestStatusLine(t *testing.T) { } func testStatusLine(t *testing.T, statusCode int, expected []byte) { - line := statusLine(statusCode) + line := formatStatusLine(nil, strHTTP11, statusCode, s2b(StatusMessage(statusCode))) if !bytes.Equal(expected, line) { t.Fatalf("unexpected status line %s. Expecting %s", string(line), string(expected)) } diff --git a/status_timing_test.go b/status_timing_test.go index 42b1c91..d2ec34a 100644 --- a/status_timing_test.go +++ b/status_timing_test.go @@ -20,7 +20,7 @@ func BenchmarkStatusLine512(b *testing.B) { func benchmarkStatusLine(b *testing.B, statusCode int, expected []byte) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - line := statusLine(statusCode) + line := formatStatusLine(nil, strHTTP11, statusCode, s2b(StatusMessage(statusCode))) if !bytes.Equal(expected, line) { b.Fatalf("unexpected status line %s. Expecting %s", string(line), string(expected)) }