diff --git a/header.go b/header.go index c2ab760..d155403 100644 --- a/header.go +++ b/header.go @@ -43,9 +43,7 @@ var ( type ResponseHeader struct { StatusCode int - ContentType []byte ContentLength int - Server []byte ConnectionClose bool h []argsKV @@ -55,10 +53,6 @@ type ResponseHeader struct { type RequestHeader struct { Method []byte RequestURI []byte - Host []byte - UserAgent []byte - Referer []byte - ContentType []byte ContentLength int h []argsKV @@ -80,8 +74,6 @@ func (h *RequestHeader) IsMethodHead() bool { func (h *ResponseHeader) Clear() { h.StatusCode = 0 h.ContentLength = 0 - h.ContentType = h.ContentType[:0] - h.Server = h.Server[:0] h.ConnectionClose = false h.h = h.h[:0] @@ -90,107 +82,79 @@ func (h *ResponseHeader) Clear() { func (h *RequestHeader) Clear() { h.Method = h.Method[:0] h.RequestURI = h.RequestURI[:0] - h.Host = h.Host[:0] - h.UserAgent = h.UserAgent[:0] - h.Referer = h.Referer[:0] - h.ContentType = h.ContentType[:0] h.ContentLength = 0 h.h = h.h[:0] } func (h *ResponseHeader) Set(key, value string) { - k := getKeyBytes(&h.bufKV, key) + initHeaderKV(&h.bufKV, key, value) + h.set(h.bufKV.key, h.bufKV.value) +} +func (h *ResponseHeader) set(key, value []byte) { switch { - case bytes.Equal(strContentLength, k): + case bytes.Equal(strContentLength, key): // skip Conent-Length setting, since it will be set automatically. - return - case bytes.Equal(strContentType, k): - h.ContentType = AppendBytesStr(h.ContentType[:0], value) - return - case bytes.Equal(strServer, k): - h.Server = AppendBytesStr(h.Server[:0], value) - return - case bytes.Equal(strConnection, k): - if EqualBytesStr(strClose, value) { + case bytes.Equal(strConnection, key): + if bytes.Equal(strClose, value) { h.ConnectionClose = true } // skip other 'Connection' shit :) - return - case bytes.Equal(strTransferEncoding, k): + case bytes.Equal(strTransferEncoding, key): // Transfer-Encoding is managed automatically. - return + case bytes.Equal(strDate, key): + // Date is managed automatically. + default: + h.h = setKV(h.h, key, value) } +} +func (h *ResponseHeader) setStr(key []byte, value string) { h.bufKV.value = AppendBytesStr(h.bufKV.value[:0], value) - h.h = setKV(h.h, k, h.bufKV.value) + h.set(key, h.bufKV.value) } func (h *RequestHeader) Set(key, value string) { - k := getKeyBytes(&h.bufKV, key) + initHeaderKV(&h.bufKV, key, value) + h.set(h.bufKV.key, h.bufKV.value) +} +func (h *RequestHeader) set(key, value []byte) { switch { - case bytes.Equal(strHost, k): - h.Host = AppendBytesStr(h.Host[:0], value) - return - case bytes.Equal(strUserAgent, k): - h.UserAgent = AppendBytesStr(h.UserAgent[:0], value) - return - case bytes.Equal(strReferer, k): - h.Referer = AppendBytesStr(h.Referer[:0], value) - return - case bytes.Equal(strContentType, k): - h.ContentType = AppendBytesStr(h.ContentType[:0], value) - return - case bytes.Equal(strContentLength, k): + case bytes.Equal(strContentLength, key): // Content-Length is managed automatically. - return - case bytes.Equal(strTransferEncoding, k): + case bytes.Equal(strTransferEncoding, key): // Transfer-Encoding is managed automatically. - return - case bytes.Equal(strConnection, k): + case bytes.Equal(strConnection, key): // Connection is managed automatically. - return + default: + h.h = setKV(h.h, key, value) } - - h.bufKV.value = AppendBytesStr(h.bufKV.value[:0], value) - h.h = setKV(h.h, k, h.bufKV.value) } func (h *ResponseHeader) Peek(key string) []byte { - k := getKeyBytes(&h.bufKV, key) + k := getHeaderKeyBytes(&h.bufKV, key) + return h.peek(k) +} - switch { - case bytes.Equal(strContentType, k): - return h.ContentType - case bytes.Equal(strServer, k): - return h.Server - case bytes.Equal(strConnection, k): +func (h *RequestHeader) Peek(key string) []byte { + k := getHeaderKeyBytes(&h.bufKV, key) + return h.peek(k) +} + +func (h *ResponseHeader) peek(key []byte) []byte { + if bytes.Equal(strConnection, key) { if h.ConnectionClose { return strClose } return nil } - - return peekKV(h.h, k) + return peekKV(h.h, key) } -func (h *RequestHeader) Peek(key string) []byte { - k := getKeyBytes(&h.bufKV, key) - - switch { - case bytes.Equal(strHost, k): - return h.Host - case bytes.Equal(strUserAgent, k): - return h.UserAgent - case bytes.Equal(strReferer, k): - return h.Referer - case bytes.Equal(strContentType, k): - return h.ContentType - } - - return peekKV(h.h, k) +func (h *RequestHeader) peek(key []byte) []byte { + return peekKV(h.h, key) } func (h *ResponseHeader) Get(key string) string { @@ -201,12 +165,6 @@ func (h *RequestHeader) Get(key string) string { return string(h.Peek(key)) } -func getKeyBytes(kv *argsKV, key string) []byte { - kv.key = AppendBytesStr(kv.key[:0], key) - normalizeHeaderKey(kv.key) - return kv.key -} - func (h *ResponseHeader) Read(r *bufio.Reader) error { n := 1 for { @@ -325,14 +283,14 @@ func (h *ResponseHeader) Write(w *bufio.Writer) error { } w.Write(statusLine(statusCode)) - server := h.Server + server := h.peek(strServer) if len(server) == 0 { server = defaultServerName } writeHeaderLine(w, strServer, server) writeHeaderLine(w, strDate, serverDate.Load().([]byte)) - contentType := h.ContentType + contentType := h.peek(strContentType) if len(contentType) == 0 { contentType = defaultContentType } @@ -349,7 +307,9 @@ func (h *ResponseHeader) Write(w *bufio.Writer) error { for i, n := 0, len(h.h); i < n; i++ { kv := &h.h[i] - writeHeaderLine(w, kv.key, kv.value) + if !bytes.Equal(strServer, kv.key) && !bytes.Equal(strContentType, kv.key) { + writeHeaderLine(w, kv.key, kv.value) + } } _, err := w.Write(strCRLF) @@ -404,23 +364,18 @@ func (h *RequestHeader) Write(w *bufio.Writer) error { w.Write(strHTTP11) w.Write(strCRLF) - if len(h.UserAgent) > 0 { - writeHeaderLine(w, strUserAgent, h.UserAgent) - } - if len(h.Referer) > 0 { - writeHeaderLine(w, strReferer, h.Referer) - } - - if len(h.Host) == 0 { + host := h.peek(strHost) + if len(host) == 0 { return fmt.Errorf("missing required Host header") } - writeHeaderLine(w, strHost, h.Host) + writeHeaderLine(w, strHost, host) if h.IsMethodPost() { - if len(h.ContentType) == 0 { + contentType := h.peek(strContentType) + if len(contentType) == 0 { return fmt.Errorf("missing required Content-Type header for POST request") } - writeHeaderLine(w, strContentType, h.ContentType) + writeHeaderLine(w, strContentType, contentType) if h.ContentLength < 0 { return fmt.Errorf("missing required Content-Length header for POST request") } @@ -429,7 +384,9 @@ func (h *RequestHeader) Write(w *bufio.Writer) error { for i, n := 0, len(h.h); i < n; i++ { kv := &h.h[i] - writeHeaderLine(w, kv.key, kv.value) + if !bytes.Equal(strHost, kv.key) && !bytes.Equal(strContentType, kv.key) { + writeHeaderLine(w, kv.key, kv.value) + } } _, err := w.Write(strCRLF) @@ -542,39 +499,34 @@ func (h *ResponseHeader) parseHeaders(buf []byte) ([]byte, error) { p.init(buf) var err error for p.next() { - if bytes.Equal(p.key, strContentType) { - h.ContentType = append(h.ContentType[:0], p.value...) - continue - } - if bytes.Equal(p.key, strContentLength) && h.ContentLength != -1 { - h.ContentLength, err = parseContentLength(p.value) - if err != nil { - if isNeedMoreError(err) { - return nil, err + switch { + case bytes.Equal(p.key, strContentLength): + if h.ContentLength != -1 { + h.ContentLength, err = parseContentLength(p.value) + if err != nil { + if isNeedMoreError(err) { + return nil, err + } + return nil, fmt.Errorf("cannot parse Content-Length %q: %s at %q", p.value, err, buf) } - return nil, fmt.Errorf("cannot parse Content-Length %q: %s at %q", p.value, err, buf) } - continue + case bytes.Equal(p.key, strTransferEncoding): + if bytes.Equal(p.value, strChunked) { + h.ContentLength = -1 + } + case bytes.Equal(p.key, strConnection): + if bytes.Equal(p.value, strClose) { + h.ConnectionClose = true + } + default: + h.h = setKV(h.h, p.key, p.value) } - if bytes.Equal(p.key, strTransferEncoding) && bytes.Equal(p.value, strChunked) { - h.ContentLength = -1 - continue - } - if bytes.Equal(p.key, strServer) { - h.Server = append(h.Server[:0], p.value...) - continue - } - if bytes.Equal(p.key, strConnection) && bytes.Equal(p.value, strClose) { - h.ConnectionClose = true - continue - } - h.h = setKV(h.h, p.key, p.value) } if p.err != nil { return nil, p.err } - if len(h.ContentType) == 0 { + if len(h.peek(strContentType)) == 0 { return nil, fmt.Errorf("missing required Content-Type header in %q", buf) } if h.ContentLength == -2 { @@ -590,47 +542,34 @@ func (h *RequestHeader) parseHeaders(buf []byte) ([]byte, error) { p.init(buf) var err error for p.next() { - if bytes.Equal(p.key, strHost) { - h.Host = append(h.Host[:0], p.value...) - continue - } - if bytes.Equal(p.key, strUserAgent) { - h.UserAgent = append(h.UserAgent[:0], p.value...) - continue - } - if bytes.Equal(p.key, strReferer) { - h.Referer = append(h.Referer[:0], p.value...) - continue - } - if bytes.Equal(p.key, strContentType) { - h.ContentType = append(h.ContentType[:0], p.value...) - continue - } - if bytes.Equal(p.key, strContentLength) && h.ContentLength != -1 { - h.ContentLength, err = parseContentLength(p.value) - if err != nil { - if isNeedMoreError(err) { - return nil, err + switch { + case bytes.Equal(p.key, strContentLength): + if h.ContentLength != -1 { + h.ContentLength, err = parseContentLength(p.value) + if err != nil { + if isNeedMoreError(err) { + return nil, err + } + return nil, fmt.Errorf("cannot parse Content-Length %q: %s at %q", p.value, err, buf) } - return nil, fmt.Errorf("cannot parse Content-Length %q: %s at %q", p.value, err, buf) } - continue + case bytes.Equal(p.key, strTransferEncoding): + if bytes.Equal(p.value, strChunked) { + h.ContentLength = -1 + } + default: + h.h = setKV(h.h, p.key, p.value) } - if bytes.Equal(p.key, strTransferEncoding) && bytes.Equal(p.value, strChunked) { - h.ContentLength = -1 - continue - } - h.h = setKV(h.h, p.key, p.value) } if p.err != nil { return nil, p.err } - if len(h.Host) == 0 { + if len(h.peek(strHost)) == 0 { return nil, fmt.Errorf("missing required Host header in %q", buf) } if h.IsMethodPost() { - if len(h.ContentType) == 0 { + if len(h.peek(strContentType)) == 0 { return nil, fmt.Errorf("missing Content-Type for POST header in %q", buf) } if h.ContentLength == -2 { @@ -708,6 +647,17 @@ func nextLine(b []byte) ([]byte, []byte, error) { return b[:n], b[nNext+1:], nil } +func initHeaderKV(kv *argsKV, key, value string) { + kv.key = getHeaderKeyBytes(kv, key) + kv.value = AppendBytesStr(kv.value[:0], value) +} + +func getHeaderKeyBytes(kv *argsKV, key string) []byte { + kv.key = AppendBytesStr(kv.key[:0], key) + normalizeHeaderKey(kv.key) + return kv.key +} + func normalizeHeaderKey(b []byte) { n := len(b) up := true diff --git a/header_test.go b/header_test.go index 5f0fa05..c8d92a3 100644 --- a/header_test.go +++ b/header_test.go @@ -33,18 +33,6 @@ func TestRequestHeaderSetGet(t *testing.T) { expectRequestHeaderGet(t, h, "baz", "xxxxx") expectRequestHeaderGet(t, h, "Transfer-Encoding", "") - if !bytes.Equal(h.Host, []byte("12345")) { - t.Fatalf("Unexpected host %q. Expected %q", h.Host, "12345") - } - if !bytes.Equal(h.ContentType, []byte("aaa/bbb")) { - t.Fatalf("Unexpected content-type %q. Expected %q", h.ContentType, "aaa/bbb") - } - if !bytes.Equal(h.UserAgent, []byte("aaabbb")) { - t.Fatalf("Unepxected Server %q. Expected %q", h.UserAgent, "aaabbb") - } - if !bytes.Equal(h.Referer, []byte("axcv")) { - t.Fatalf("Unexpected referer %q. Expected %q", h.Referer, "axcv") - } if h.ContentLength != 0 { t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength, 0) } @@ -65,18 +53,6 @@ func TestRequestHeaderSetGet(t *testing.T) { t.Fatalf("Unexpected error when reading request header: %s", err) } - if !bytes.Equal(h1.Host, h.Host) { - t.Fatalf("Unexpected host %q. Expected %q", h1.Host, h.Host) - } - if !bytes.Equal(h1.ContentType, h.ContentType) { - t.Fatalf("Unexpected content-type %q. Expected %q", h1.ContentType, h.ContentType) - } - if !bytes.Equal(h1.UserAgent, h.UserAgent) { - t.Fatalf("Unexpected user-agent %q. Expected %q", h1.UserAgent, h.UserAgent) - } - if !bytes.Equal(h1.Referer, h.Referer) { - t.Fatalf("Unepxected referer %q. Expected %q", h1.Referer, h.Referer) - } if h1.ContentLength != h.ContentLength { t.Fatalf("Unexpected Content-Length %d. Expected %d", h1.ContentLength, h.ContentLength) } @@ -109,12 +85,6 @@ func TestResponseHeaderSetGet(t *testing.T) { expectResponseHeaderGet(t, h, "baz", "xxxxx") expectResponseHeaderGet(t, h, "Transfer-Encoding", "") - if !bytes.Equal(h.ContentType, []byte("aaa/bbb")) { - t.Fatalf("Unexpected content-type %q. Expected %q", h.ContentType, "aaa/bbb") - } - if !bytes.Equal(h.Server, []byte("aaaa")) { - t.Fatalf("Unepxected Server %q. Expected %q", h.Server, "aaaa") - } if h.ContentLength != 0 { t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength, 0) } @@ -138,12 +108,6 @@ func TestResponseHeaderSetGet(t *testing.T) { t.Fatalf("Unexpected error when reading response header: %s", err) } - if !bytes.Equal(h1.ContentType, h.ContentType) { - t.Fatalf("Unexpected content-type %q. Expected %q", h1.ContentType, h.ContentType) - } - if !bytes.Equal(h1.Server, h.Server) { - t.Fatalf("Unepxected Server %q. Expected %q", h1.Server, h.Server) - } if h1.ContentLength != h.ContentLength { t.Fatalf("Unexpected Content-Length %d. Expected %d", h1.ContentLength, h.ContentLength) } @@ -559,8 +523,8 @@ func verifyResponseHeader(t *testing.T, h *ResponseHeader, expectedStatusCode, e if h.ContentLength != expectedContentLength { t.Fatalf("Unexpected content length %d. Expected %d", h.ContentLength, expectedContentLength) } - if !bytes.Equal(h.ContentType, []byte(expectedContentType)) { - t.Fatalf("Unexpected content type %q. Expected %q", h.ContentType, expectedContentType) + if h.Get("Content-Type") != expectedContentType { + t.Fatalf("Unexpected content type %q. Expected %q", h.Get("Content-Type"), expectedContentType) } } @@ -572,14 +536,14 @@ func verifyRequestHeader(t *testing.T, h *RequestHeader, expectedContentLength i if !bytes.Equal(h.RequestURI, []byte(expectedRequestURI)) { t.Fatalf("Unexpected RequestURI %q. Expected %q", h.RequestURI, expectedRequestURI) } - if !bytes.Equal(h.Host, []byte(expectedHost)) { - t.Fatalf("Unexpected host %q. Expected %q", h.Host, expectedHost) + if h.Get("Host") != expectedHost { + t.Fatalf("Unexpected host %q. Expected %q", h.Get("Host"), expectedHost) } - if !bytes.Equal(h.Referer, []byte(expectedReferer)) { - t.Fatalf("Unexpected referer %q. Expected %q", h.Referer, expectedReferer) + if h.Get("Referer") != expectedReferer { + t.Fatalf("Unexpected referer %q. Expected %q", h.Get("Referer"), expectedReferer) } - if !bytes.Equal(h.ContentType, []byte(expectedContentType)) { - t.Fatalf("Unexpected content-type %q. Expected %q", h.ContentType, expectedContentType) + if h.Get("Content-Type") != expectedContentType { + t.Fatalf("Unexpected content-type %q. Expected %q", h.Get("Content-Type"), expectedContentType) } } diff --git a/http.go b/http.go index 610acde..08da00f 100644 --- a/http.go +++ b/http.go @@ -33,7 +33,7 @@ func (req *Request) ParseURI() { if req.parsedURI { return } - req.URI.Parse(req.Header.Host, req.Header.RequestURI) + req.URI.Parse(req.Header.peek(strHost), req.Header.RequestURI) req.parsedURI = true } @@ -45,9 +45,9 @@ func (req *Request) ParsePostArgs() error { if !req.Header.IsMethodPost() { return fmt.Errorf("Cannot parse POST args for %q request", req.Header.Method) } - if !bytes.Equal(req.Header.ContentType, strPostArgsContentType) { + if !bytes.Equal(req.Header.peek(strContentType), strPostArgsContentType) { return fmt.Errorf("Cannot parse POST args for %q Content-Type. Required %q Content-Type", - req.Header.ContentType, strPostArgsContentType) + req.Header.peek(strContentType), strPostArgsContentType) } req.PostArgs.ParseBytes(req.Body) req.parsedPostArgs = true diff --git a/http_test.go b/http_test.go index 684bf26..5bbb5c5 100644 --- a/http_test.go +++ b/http_test.go @@ -102,8 +102,8 @@ func testResponseSuccess(t *testing.T, statusCode int, contentType, serverName, expectedStatusCode int, expectedContentType, expectedServerName string) { var resp Response resp.Header.StatusCode = statusCode - resp.Header.ContentType = []byte(contentType) - resp.Header.Server = []byte(serverName) + resp.Header.Set("Content-Type", contentType) + resp.Header.Set("Server", serverName) resp.Body = []byte(body) w := &bytes.Buffer{} @@ -127,11 +127,11 @@ func testResponseSuccess(t *testing.T, statusCode int, contentType, serverName, if resp1.Header.ContentLength != len(body) { t.Fatalf("Unexpected content-length: %d. Expected %d", resp1.Header.ContentLength, len(body)) } - if !bytes.Equal(resp1.Header.ContentType, []byte(expectedContentType)) { - t.Fatalf("Unexpected content-type: %q. Expected %q", resp1.Header.ContentType, expectedContentType) + if resp1.Header.Get("Content-Type") != expectedContentType { + t.Fatalf("Unexpected content-type: %q. Expected %q", resp1.Header.Get("Content-Type"), expectedContentType) } - if !bytes.Equal(resp1.Header.Server, []byte(expectedServerName)) { - t.Fatalf("Unexpected server: %q. Expected %q", resp1.Header.Server, expectedServerName) + if resp1.Header.Get("Server") != expectedServerName { + t.Fatalf("Unexpected server: %q. Expected %q", resp1.Header.Get("Server"), expectedServerName) } if !bytes.Equal(resp1.Body, []byte(body)) { t.Fatalf("Unexpected body: %q. Expected %q", resp1.Body, body) @@ -167,8 +167,8 @@ func testRequestWriteError(t *testing.T, method, requestURI, host, userAgent, bo req.Header.Method = []byte(method) req.Header.RequestURI = []byte(requestURI) - req.Header.Host = []byte(host) - req.Header.UserAgent = []byte(userAgent) + req.Header.Set("Host", host) + req.Header.Set("User-Agent", userAgent) req.Body = []byte(body) w := &bytes.Buffer{} @@ -184,13 +184,13 @@ func testRequestSuccess(t *testing.T, method, requestURI, host, userAgent, body, req.Header.Method = []byte(method) req.Header.RequestURI = []byte(requestURI) - req.Header.Host = []byte(host) - req.Header.UserAgent = []byte(userAgent) + req.Header.Set("Host", host) + req.Header.Set("User-Agent", userAgent) req.Body = []byte(body) - contentType := []byte("foobar") + contentType := "foobar" if method == "POST" { - req.Header.ContentType = contentType + req.Header.Set("Content-Type", contentType) } w := &bytes.Buffer{} @@ -214,18 +214,18 @@ func testRequestSuccess(t *testing.T, method, requestURI, host, userAgent, body, if !bytes.Equal(req1.Header.RequestURI, []byte(requestURI)) { t.Fatalf("Unexpected RequestURI: %q. Expected %q", req1.Header.RequestURI, requestURI) } - if !bytes.Equal(req1.Header.Host, []byte(host)) { - t.Fatalf("Unexpected host: %q. Expected %q", req1.Header.Host, host) + if req1.Header.Get("Host") != host { + t.Fatalf("Unexpected host: %q. Expected %q", req1.Header.Get("Host"), host) } - if !bytes.Equal(req1.Header.UserAgent, []byte(userAgent)) { - t.Fatalf("Unexpected user-agent: %q. Expected %q", req1.Header.UserAgent, userAgent) + if req1.Header.Get("User-Agent") != userAgent { + t.Fatalf("Unexpected user-agent: %q. Expected %q", req1.Header.Get("User-Agent"), userAgent) } if !bytes.Equal(req1.Body, []byte(body)) { t.Fatalf("Unexpected body: %q. Expected %q", req1.Body, body) } - if method == "POST" && !bytes.Equal(req1.Header.ContentType, contentType) { - t.Fatalf("Unexpected content-type: %q. Expected %q", req1.Header.ContentType, contentType) + if method == "POST" && req1.Header.Get("Content-Type") != contentType { + t.Fatalf("Unexpected content-type: %q. Expected %q", req1.Header.Get("Content-Type"), contentType) } } @@ -345,7 +345,7 @@ func TestRequestParseURI(t *testing.T) { expectedHash := "1334dfds&=d" var req Request - req.Header.Host = []byte(host) + req.Header.Set("Host", host) req.Header.RequestURI = []byte(requestURI) req.ParseURI() diff --git a/server.go b/server.go index 4189748..9173515 100644 --- a/server.go +++ b/server.go @@ -39,13 +39,13 @@ type Server struct { type RequestHandler func(ctx *ServerCtx) type ServerCtx struct { - Request Request - Response Response + Request Request // Unique id of the context. // Used by ServerCtx.Logger(). ID uint64 + resp Response logger ctxLogger s *Server c remoteAddrer @@ -97,32 +97,24 @@ func (ctx *ServerCtx) RemoteIP() string { } func (ctx *ServerCtx) Error(msg string, statusCode int) { - resp := ctx.zeroResponse() + resp := ctx.Response() + resp.Clear() resp.Header.StatusCode = statusCode - resp.Header.ContentType = append(resp.Header.ContentType, defaultContentType...) + resp.Header.set(strContentType, defaultContentType) resp.Body = append(resp.Body, []byte(msg)...) } func (ctx *ServerCtx) Success(contentType string, body []byte) { - resp := ctx.zeroResponse() - resp.Header.ContentType = appendString(resp.Header.ContentType, contentType) + resp := ctx.Response() + resp.Header.setStr(strContentType, contentType) resp.Body = append(resp.Body, body...) } -func appendString(b []byte, s string) []byte { - for i, n := 0, len(s); i < n; i++ { - b = append(b, s[i]) - } - return b -} - -func (ctx *ServerCtx) zeroResponse() *Response { +func (ctx *ServerCtx) Response() *Response { if ctx.shadow != nil { ctx = ctx.shadow } - resp := &ctx.Response - resp.Clear() - return resp + return &ctx.resp } func (ctx *ServerCtx) Logger() Logger { @@ -136,7 +128,7 @@ func (ctx *ServerCtx) Steal() { shadow := *ctx shadow.Request = Request{} - shadow.Response = Response{} + shadow.resp = Response{} shadow.logger.ctx = &shadow shadow.v = &shadow ctx.shadow = &shadow @@ -144,17 +136,16 @@ func (ctx *ServerCtx) Steal() { func (ctx *ServerCtx) writeResponse() error { if ctx.shadow != nil { - panic("BUG: ServerCtx.writeResponse() shouldn't be called on shadow") + panic("BUG: ctx.shadow is not null") } - resp := &ctx.Response - h := &resp.Header - serverOld := h.Server + h := &ctx.resp.Header + serverOld := h.peek(strServer) if len(serverOld) == 0 { - h.Server = ctx.s.getServerName() + h.set(strServer, ctx.s.getServerName()) } - err := resp.Write(ctx.w) + err := ctx.resp.Write(ctx.w) if len(serverOld) == 0 { - h.Server = serverOld + h.set(strServer, serverOld) } return err } @@ -305,9 +296,9 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **ServerCtx) error { if err = ctx.writeResponse(); err != nil { break } - connectionClose := ctx.Response.Header.ConnectionClose + connectionClose := ctx.resp.Header.ConnectionClose - ctx.Response.Clear() + ctx.resp.Clear() trimBigBuffers(ctx) if ctx.r.Buffered() == 0 || connectionClose { @@ -329,8 +320,8 @@ func trimBigBuffers(ctx *ServerCtx) { if cap(ctx.Request.Body) > bigBufferLimit { ctx.Request.Body = nil } - if cap(ctx.Response.Body) > bigBufferLimit { - ctx.Response.Body = nil + if cap(ctx.resp.Body) > bigBufferLimit { + ctx.resp.Body = nil } } diff --git a/server_test.go b/server_test.go index 970daf8..bf8ccfe 100644 --- a/server_test.go +++ b/server_test.go @@ -51,7 +51,7 @@ func TestServerSteal(t *testing.T) { func TestServerConnectionClose(t *testing.T) { s := &Server{ Handler: func(ctx *ServerCtx) { - ctx.Response.Header.ConnectionClose = true + ctx.Response().Header.ConnectionClose = true }, } @@ -266,8 +266,8 @@ func TestServerConnError(t *testing.T) { if resp.Header.ContentLength != 6 { t.Fatalf("Unexpected Content-Length %d. Expected %d", resp.Header.ContentLength, 6) } - if !bytes.Equal(resp.Header.ContentType, defaultContentType) { - t.Fatalf("Unexpected Content-Type %q. Expected %q", resp.Header.ContentType, defaultContentType) + if resp.Header.Get("Content-Type") != string(defaultContentType) { + t.Fatalf("Unexpected Content-Type %q. Expected %q", resp.Header.Get("Content-Type"), defaultContentType) } if !bytes.Equal(resp.Body, []byte("foobar")) { t.Fatalf("Unexpected body %q. Expected %q", resp.Body, "foobar") @@ -278,7 +278,7 @@ func TestServeConnSingleRequest(t *testing.T) { s := &Server{ Handler: func(ctx *ServerCtx) { h := &ctx.Request.Header - ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI, h.Host))) + ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI, h.Get("Host")))) }, } @@ -307,7 +307,7 @@ func TestServeConnMultiRequests(t *testing.T) { s := &Server{ Handler: func(ctx *ServerCtx) { h := &ctx.Request.Header - ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI, h.Host))) + ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI, h.Get("Host")))) }, }