diff --git a/header.go b/header.go index 19eb914..f86fc60 100644 --- a/header.go +++ b/header.go @@ -46,6 +46,9 @@ type ResponseHeader struct { ContentLength int ConnectionClose bool + contentType []byte + server []byte + h []argsKV bufKV argsKV } @@ -55,6 +58,9 @@ type RequestHeader struct { RequestURI []byte ContentLength int + host []byte + contentType []byte + h []argsKV bufKV argsKV } @@ -76,6 +82,9 @@ func (h *ResponseHeader) Clear() { h.ContentLength = 0 h.ConnectionClose = false + h.contentType = h.contentType[:0] + h.server = h.server[:0] + h.h = h.h[:0] } @@ -84,6 +93,9 @@ func (h *RequestHeader) Clear() { h.RequestURI = h.RequestURI[:0] h.ContentLength = 0 + h.host = h.host[:0] + h.contentType = h.contentType[:0] + h.h = h.h[:0] } @@ -94,6 +106,10 @@ func (h *ResponseHeader) Set(key, value string) { func (h *ResponseHeader) set(key, value []byte) { switch { + case bytes.Equal(strContentType, key): + h.contentType = append(h.contentType[:0], value...) + case bytes.Equal(strServer, key): + h.server = append(h.server[:0], value...) case bytes.Equal(strContentLength, key): // skip Conent-Length setting, since it will be set automatically. case bytes.Equal(strConnection, key): @@ -127,6 +143,10 @@ func (h *RequestHeader) Set(key, value string) { func (h *RequestHeader) set(key, value []byte) { switch { + case bytes.Equal(strHost, key): + h.host = append(h.host[:0], value...) + case bytes.Equal(strContentType, key): + h.contentType = append(h.contentType[:0], value...) case bytes.Equal(strContentLength, key): // Content-Length is managed automatically. case bytes.Equal(strTransferEncoding, key): @@ -154,17 +174,30 @@ func (h *RequestHeader) Peek(key string) []byte { } func (h *ResponseHeader) peek(key []byte) []byte { - if bytes.Equal(strConnection, key) { + switch { + case bytes.Equal(strContentType, key): + return h.contentType + case bytes.Equal(strServer, key): + return h.server + case bytes.Equal(strConnection, key): if h.ConnectionClose { return strClose } return nil + default: + return peekKV(h.h, key) } - return peekKV(h.h, key) } func (h *RequestHeader) peek(key []byte) []byte { - return peekKV(h.h, key) + switch { + case bytes.Equal(strHost, key): + return h.host + case bytes.Equal(strContentType, key): + return h.contentType + default: + return peekKV(h.h, key) + } } func (h *ResponseHeader) Get(key string) string { @@ -293,14 +326,14 @@ func (h *ResponseHeader) Write(w *bufio.Writer) error { } w.Write(statusLine(statusCode)) - server := h.peek(strServer) + server := h.server if len(server) == 0 { server = defaultServerName } writeHeaderLine(w, strServer, server) writeHeaderLine(w, strDate, serverDate.Load().([]byte)) - contentType := h.peek(strContentType) + contentType := h.contentType if len(contentType) == 0 { contentType = defaultContentType } @@ -341,14 +374,14 @@ func (h *RequestHeader) Write(w *bufio.Writer) error { w.Write(strHTTP11) w.Write(strCRLF) - host := h.peek(strHost) + host := h.host if len(host) == 0 { return fmt.Errorf("missing required Host header") } writeHeaderLine(w, strHost, host) if h.IsMethodPost() { - contentType := h.peek(strContentType) + contentType := h.contentType if len(contentType) == 0 { return fmt.Errorf("missing required Content-Type header for POST request") } @@ -384,20 +417,6 @@ func writeContentLength(w *bufio.Writer, contentLength int) { w.Write(strCRLF) } -func mustPeekBuffered(r *bufio.Reader) []byte { - buf, err := r.Peek(r.Buffered()) - if len(buf) == 0 || err != nil { - panic(fmt.Sprintf("bufio.Reader.Peek() returned unexpected data (%q, %v)", buf, err)) - } - return buf -} - -func mustDiscard(r *bufio.Reader, n int) { - if _, err := r.Discard(n); err != nil { - panic(fmt.Sprintf("bufio.Reader.Discard(%d) failed: %s", n, err)) - } -} - func (h *ResponseHeader) parse(buf []byte) (b []byte, err error) { b, err = h.parseFirstLine(buf) if err != nil { @@ -477,6 +496,10 @@ func (h *ResponseHeader) parseHeaders(buf []byte) ([]byte, error) { var err error for p.next() { switch { + case bytes.Equal(p.key, strContentType): + h.contentType = append(h.contentType[:0], p.value...) + case bytes.Equal(p.key, strServer): + h.server = append(h.server[:0], p.value...) case bytes.Equal(p.key, strContentLength): if h.ContentLength != -1 { h.ContentLength, err = parseContentLength(p.value) @@ -503,7 +526,7 @@ func (h *ResponseHeader) parseHeaders(buf []byte) ([]byte, error) { return nil, p.err } - if len(h.peek(strContentType)) == 0 { + if len(h.contentType) == 0 { return nil, fmt.Errorf("missing required Content-Type header in %q", buf) } if h.ContentLength == -2 { @@ -520,6 +543,10 @@ func (h *RequestHeader) parseHeaders(buf []byte) ([]byte, error) { var err error for p.next() { switch { + case bytes.Equal(p.key, strHost): + h.host = append(h.host[:0], p.value...) + case bytes.Equal(p.key, strContentType): + h.contentType = append(h.contentType[:0], p.value...) case bytes.Equal(p.key, strContentLength): if h.ContentLength != -1 { h.ContentLength, err = parseContentLength(p.value) @@ -542,11 +569,11 @@ func (h *RequestHeader) parseHeaders(buf []byte) ([]byte, error) { return nil, p.err } - if len(h.peek(strHost)) == 0 { + if len(h.host) == 0 { return nil, fmt.Errorf("missing required Host header in %q", buf) } if h.IsMethodPost() { - if len(h.peek(strContentType)) == 0 { + if len(h.contentType) == 0 { return nil, fmt.Errorf("missing Content-Type for POST header in %q", buf) } if h.ContentLength == -2 { @@ -671,3 +698,17 @@ func isNeedMoreError(err error) bool { _, ok := err.(*errNeedMore) return ok } + +func mustPeekBuffered(r *bufio.Reader) []byte { + buf, err := r.Peek(r.Buffered()) + if len(buf) == 0 || err != nil { + panic(fmt.Sprintf("bufio.Reader.Peek() returned unexpected data (%q, %v)", buf, err)) + } + return buf +} + +func mustDiscard(r *bufio.Reader, n int) { + if _, err := r.Discard(n); err != nil { + panic(fmt.Sprintf("bufio.Reader.Discard(%d) failed: %s", n, err)) + } +}