Uniformly process all headers

This commit is contained in:
Aliaksandr Valialkin
2015-10-20 12:36:33 +03:00
parent 731dfe6539
commit ee62382f34
6 changed files with 154 additions and 249 deletions
+99 -149
View File
@@ -43,9 +43,7 @@ var (
type ResponseHeader struct {
StatusCode int
ContentType []byte
ContentLength int
Server []byte
ConnectionClose bool
h []argsKV
@@ -55,10 +53,6 @@ type ResponseHeader struct {
type RequestHeader struct {
Method []byte
RequestURI []byte
Host []byte
UserAgent []byte
Referer []byte
ContentType []byte
ContentLength int
h []argsKV
@@ -80,8 +74,6 @@ func (h *RequestHeader) IsMethodHead() bool {
func (h *ResponseHeader) Clear() {
h.StatusCode = 0
h.ContentLength = 0
h.ContentType = h.ContentType[:0]
h.Server = h.Server[:0]
h.ConnectionClose = false
h.h = h.h[:0]
@@ -90,107 +82,79 @@ func (h *ResponseHeader) Clear() {
func (h *RequestHeader) Clear() {
h.Method = h.Method[:0]
h.RequestURI = h.RequestURI[:0]
h.Host = h.Host[:0]
h.UserAgent = h.UserAgent[:0]
h.Referer = h.Referer[:0]
h.ContentType = h.ContentType[:0]
h.ContentLength = 0
h.h = h.h[:0]
}
func (h *ResponseHeader) Set(key, value string) {
k := getKeyBytes(&h.bufKV, key)
initHeaderKV(&h.bufKV, key, value)
h.set(h.bufKV.key, h.bufKV.value)
}
func (h *ResponseHeader) set(key, value []byte) {
switch {
case bytes.Equal(strContentLength, k):
case bytes.Equal(strContentLength, key):
// skip Conent-Length setting, since it will be set automatically.
return
case bytes.Equal(strContentType, k):
h.ContentType = AppendBytesStr(h.ContentType[:0], value)
return
case bytes.Equal(strServer, k):
h.Server = AppendBytesStr(h.Server[:0], value)
return
case bytes.Equal(strConnection, k):
if EqualBytesStr(strClose, value) {
case bytes.Equal(strConnection, key):
if bytes.Equal(strClose, value) {
h.ConnectionClose = true
}
// skip other 'Connection' shit :)
return
case bytes.Equal(strTransferEncoding, k):
case bytes.Equal(strTransferEncoding, key):
// Transfer-Encoding is managed automatically.
return
case bytes.Equal(strDate, key):
// Date is managed automatically.
default:
h.h = setKV(h.h, key, value)
}
}
func (h *ResponseHeader) setStr(key []byte, value string) {
h.bufKV.value = AppendBytesStr(h.bufKV.value[:0], value)
h.h = setKV(h.h, k, h.bufKV.value)
h.set(key, h.bufKV.value)
}
func (h *RequestHeader) Set(key, value string) {
k := getKeyBytes(&h.bufKV, key)
initHeaderKV(&h.bufKV, key, value)
h.set(h.bufKV.key, h.bufKV.value)
}
func (h *RequestHeader) set(key, value []byte) {
switch {
case bytes.Equal(strHost, k):
h.Host = AppendBytesStr(h.Host[:0], value)
return
case bytes.Equal(strUserAgent, k):
h.UserAgent = AppendBytesStr(h.UserAgent[:0], value)
return
case bytes.Equal(strReferer, k):
h.Referer = AppendBytesStr(h.Referer[:0], value)
return
case bytes.Equal(strContentType, k):
h.ContentType = AppendBytesStr(h.ContentType[:0], value)
return
case bytes.Equal(strContentLength, k):
case bytes.Equal(strContentLength, key):
// Content-Length is managed automatically.
return
case bytes.Equal(strTransferEncoding, k):
case bytes.Equal(strTransferEncoding, key):
// Transfer-Encoding is managed automatically.
return
case bytes.Equal(strConnection, k):
case bytes.Equal(strConnection, key):
// Connection is managed automatically.
return
default:
h.h = setKV(h.h, key, value)
}
h.bufKV.value = AppendBytesStr(h.bufKV.value[:0], value)
h.h = setKV(h.h, k, h.bufKV.value)
}
func (h *ResponseHeader) Peek(key string) []byte {
k := getKeyBytes(&h.bufKV, key)
k := getHeaderKeyBytes(&h.bufKV, key)
return h.peek(k)
}
switch {
case bytes.Equal(strContentType, k):
return h.ContentType
case bytes.Equal(strServer, k):
return h.Server
case bytes.Equal(strConnection, k):
func (h *RequestHeader) Peek(key string) []byte {
k := getHeaderKeyBytes(&h.bufKV, key)
return h.peek(k)
}
func (h *ResponseHeader) peek(key []byte) []byte {
if bytes.Equal(strConnection, key) {
if h.ConnectionClose {
return strClose
}
return nil
}
return peekKV(h.h, k)
return peekKV(h.h, key)
}
func (h *RequestHeader) Peek(key string) []byte {
k := getKeyBytes(&h.bufKV, key)
switch {
case bytes.Equal(strHost, k):
return h.Host
case bytes.Equal(strUserAgent, k):
return h.UserAgent
case bytes.Equal(strReferer, k):
return h.Referer
case bytes.Equal(strContentType, k):
return h.ContentType
}
return peekKV(h.h, k)
func (h *RequestHeader) peek(key []byte) []byte {
return peekKV(h.h, key)
}
func (h *ResponseHeader) Get(key string) string {
@@ -201,12 +165,6 @@ func (h *RequestHeader) Get(key string) string {
return string(h.Peek(key))
}
func getKeyBytes(kv *argsKV, key string) []byte {
kv.key = AppendBytesStr(kv.key[:0], key)
normalizeHeaderKey(kv.key)
return kv.key
}
func (h *ResponseHeader) Read(r *bufio.Reader) error {
n := 1
for {
@@ -325,14 +283,14 @@ func (h *ResponseHeader) Write(w *bufio.Writer) error {
}
w.Write(statusLine(statusCode))
server := h.Server
server := h.peek(strServer)
if len(server) == 0 {
server = defaultServerName
}
writeHeaderLine(w, strServer, server)
writeHeaderLine(w, strDate, serverDate.Load().([]byte))
contentType := h.ContentType
contentType := h.peek(strContentType)
if len(contentType) == 0 {
contentType = defaultContentType
}
@@ -349,7 +307,9 @@ func (h *ResponseHeader) Write(w *bufio.Writer) error {
for i, n := 0, len(h.h); i < n; i++ {
kv := &h.h[i]
writeHeaderLine(w, kv.key, kv.value)
if !bytes.Equal(strServer, kv.key) && !bytes.Equal(strContentType, kv.key) {
writeHeaderLine(w, kv.key, kv.value)
}
}
_, err := w.Write(strCRLF)
@@ -404,23 +364,18 @@ func (h *RequestHeader) Write(w *bufio.Writer) error {
w.Write(strHTTP11)
w.Write(strCRLF)
if len(h.UserAgent) > 0 {
writeHeaderLine(w, strUserAgent, h.UserAgent)
}
if len(h.Referer) > 0 {
writeHeaderLine(w, strReferer, h.Referer)
}
if len(h.Host) == 0 {
host := h.peek(strHost)
if len(host) == 0 {
return fmt.Errorf("missing required Host header")
}
writeHeaderLine(w, strHost, h.Host)
writeHeaderLine(w, strHost, host)
if h.IsMethodPost() {
if len(h.ContentType) == 0 {
contentType := h.peek(strContentType)
if len(contentType) == 0 {
return fmt.Errorf("missing required Content-Type header for POST request")
}
writeHeaderLine(w, strContentType, h.ContentType)
writeHeaderLine(w, strContentType, contentType)
if h.ContentLength < 0 {
return fmt.Errorf("missing required Content-Length header for POST request")
}
@@ -429,7 +384,9 @@ func (h *RequestHeader) Write(w *bufio.Writer) error {
for i, n := 0, len(h.h); i < n; i++ {
kv := &h.h[i]
writeHeaderLine(w, kv.key, kv.value)
if !bytes.Equal(strHost, kv.key) && !bytes.Equal(strContentType, kv.key) {
writeHeaderLine(w, kv.key, kv.value)
}
}
_, err := w.Write(strCRLF)
@@ -542,39 +499,34 @@ func (h *ResponseHeader) parseHeaders(buf []byte) ([]byte, error) {
p.init(buf)
var err error
for p.next() {
if bytes.Equal(p.key, strContentType) {
h.ContentType = append(h.ContentType[:0], p.value...)
continue
}
if bytes.Equal(p.key, strContentLength) && h.ContentLength != -1 {
h.ContentLength, err = parseContentLength(p.value)
if err != nil {
if isNeedMoreError(err) {
return nil, err
switch {
case bytes.Equal(p.key, strContentLength):
if h.ContentLength != -1 {
h.ContentLength, err = parseContentLength(p.value)
if err != nil {
if isNeedMoreError(err) {
return nil, err
}
return nil, fmt.Errorf("cannot parse Content-Length %q: %s at %q", p.value, err, buf)
}
return nil, fmt.Errorf("cannot parse Content-Length %q: %s at %q", p.value, err, buf)
}
continue
case bytes.Equal(p.key, strTransferEncoding):
if bytes.Equal(p.value, strChunked) {
h.ContentLength = -1
}
case bytes.Equal(p.key, strConnection):
if bytes.Equal(p.value, strClose) {
h.ConnectionClose = true
}
default:
h.h = setKV(h.h, p.key, p.value)
}
if bytes.Equal(p.key, strTransferEncoding) && bytes.Equal(p.value, strChunked) {
h.ContentLength = -1
continue
}
if bytes.Equal(p.key, strServer) {
h.Server = append(h.Server[:0], p.value...)
continue
}
if bytes.Equal(p.key, strConnection) && bytes.Equal(p.value, strClose) {
h.ConnectionClose = true
continue
}
h.h = setKV(h.h, p.key, p.value)
}
if p.err != nil {
return nil, p.err
}
if len(h.ContentType) == 0 {
if len(h.peek(strContentType)) == 0 {
return nil, fmt.Errorf("missing required Content-Type header in %q", buf)
}
if h.ContentLength == -2 {
@@ -590,47 +542,34 @@ func (h *RequestHeader) parseHeaders(buf []byte) ([]byte, error) {
p.init(buf)
var err error
for p.next() {
if bytes.Equal(p.key, strHost) {
h.Host = append(h.Host[:0], p.value...)
continue
}
if bytes.Equal(p.key, strUserAgent) {
h.UserAgent = append(h.UserAgent[:0], p.value...)
continue
}
if bytes.Equal(p.key, strReferer) {
h.Referer = append(h.Referer[:0], p.value...)
continue
}
if bytes.Equal(p.key, strContentType) {
h.ContentType = append(h.ContentType[:0], p.value...)
continue
}
if bytes.Equal(p.key, strContentLength) && h.ContentLength != -1 {
h.ContentLength, err = parseContentLength(p.value)
if err != nil {
if isNeedMoreError(err) {
return nil, err
switch {
case bytes.Equal(p.key, strContentLength):
if h.ContentLength != -1 {
h.ContentLength, err = parseContentLength(p.value)
if err != nil {
if isNeedMoreError(err) {
return nil, err
}
return nil, fmt.Errorf("cannot parse Content-Length %q: %s at %q", p.value, err, buf)
}
return nil, fmt.Errorf("cannot parse Content-Length %q: %s at %q", p.value, err, buf)
}
continue
case bytes.Equal(p.key, strTransferEncoding):
if bytes.Equal(p.value, strChunked) {
h.ContentLength = -1
}
default:
h.h = setKV(h.h, p.key, p.value)
}
if bytes.Equal(p.key, strTransferEncoding) && bytes.Equal(p.value, strChunked) {
h.ContentLength = -1
continue
}
h.h = setKV(h.h, p.key, p.value)
}
if p.err != nil {
return nil, p.err
}
if len(h.Host) == 0 {
if len(h.peek(strHost)) == 0 {
return nil, fmt.Errorf("missing required Host header in %q", buf)
}
if h.IsMethodPost() {
if len(h.ContentType) == 0 {
if len(h.peek(strContentType)) == 0 {
return nil, fmt.Errorf("missing Content-Type for POST header in %q", buf)
}
if h.ContentLength == -2 {
@@ -708,6 +647,17 @@ func nextLine(b []byte) ([]byte, []byte, error) {
return b[:n], b[nNext+1:], nil
}
func initHeaderKV(kv *argsKV, key, value string) {
kv.key = getHeaderKeyBytes(kv, key)
kv.value = AppendBytesStr(kv.value[:0], value)
}
func getHeaderKeyBytes(kv *argsKV, key string) []byte {
kv.key = AppendBytesStr(kv.key[:0], key)
normalizeHeaderKey(kv.key)
return kv.key
}
func normalizeHeaderKey(b []byte) {
n := len(b)
up := true
+8 -44
View File
@@ -33,18 +33,6 @@ func TestRequestHeaderSetGet(t *testing.T) {
expectRequestHeaderGet(t, h, "baz", "xxxxx")
expectRequestHeaderGet(t, h, "Transfer-Encoding", "")
if !bytes.Equal(h.Host, []byte("12345")) {
t.Fatalf("Unexpected host %q. Expected %q", h.Host, "12345")
}
if !bytes.Equal(h.ContentType, []byte("aaa/bbb")) {
t.Fatalf("Unexpected content-type %q. Expected %q", h.ContentType, "aaa/bbb")
}
if !bytes.Equal(h.UserAgent, []byte("aaabbb")) {
t.Fatalf("Unepxected Server %q. Expected %q", h.UserAgent, "aaabbb")
}
if !bytes.Equal(h.Referer, []byte("axcv")) {
t.Fatalf("Unexpected referer %q. Expected %q", h.Referer, "axcv")
}
if h.ContentLength != 0 {
t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength, 0)
}
@@ -65,18 +53,6 @@ func TestRequestHeaderSetGet(t *testing.T) {
t.Fatalf("Unexpected error when reading request header: %s", err)
}
if !bytes.Equal(h1.Host, h.Host) {
t.Fatalf("Unexpected host %q. Expected %q", h1.Host, h.Host)
}
if !bytes.Equal(h1.ContentType, h.ContentType) {
t.Fatalf("Unexpected content-type %q. Expected %q", h1.ContentType, h.ContentType)
}
if !bytes.Equal(h1.UserAgent, h.UserAgent) {
t.Fatalf("Unexpected user-agent %q. Expected %q", h1.UserAgent, h.UserAgent)
}
if !bytes.Equal(h1.Referer, h.Referer) {
t.Fatalf("Unepxected referer %q. Expected %q", h1.Referer, h.Referer)
}
if h1.ContentLength != h.ContentLength {
t.Fatalf("Unexpected Content-Length %d. Expected %d", h1.ContentLength, h.ContentLength)
}
@@ -109,12 +85,6 @@ func TestResponseHeaderSetGet(t *testing.T) {
expectResponseHeaderGet(t, h, "baz", "xxxxx")
expectResponseHeaderGet(t, h, "Transfer-Encoding", "")
if !bytes.Equal(h.ContentType, []byte("aaa/bbb")) {
t.Fatalf("Unexpected content-type %q. Expected %q", h.ContentType, "aaa/bbb")
}
if !bytes.Equal(h.Server, []byte("aaaa")) {
t.Fatalf("Unepxected Server %q. Expected %q", h.Server, "aaaa")
}
if h.ContentLength != 0 {
t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength, 0)
}
@@ -138,12 +108,6 @@ func TestResponseHeaderSetGet(t *testing.T) {
t.Fatalf("Unexpected error when reading response header: %s", err)
}
if !bytes.Equal(h1.ContentType, h.ContentType) {
t.Fatalf("Unexpected content-type %q. Expected %q", h1.ContentType, h.ContentType)
}
if !bytes.Equal(h1.Server, h.Server) {
t.Fatalf("Unepxected Server %q. Expected %q", h1.Server, h.Server)
}
if h1.ContentLength != h.ContentLength {
t.Fatalf("Unexpected Content-Length %d. Expected %d", h1.ContentLength, h.ContentLength)
}
@@ -559,8 +523,8 @@ func verifyResponseHeader(t *testing.T, h *ResponseHeader, expectedStatusCode, e
if h.ContentLength != expectedContentLength {
t.Fatalf("Unexpected content length %d. Expected %d", h.ContentLength, expectedContentLength)
}
if !bytes.Equal(h.ContentType, []byte(expectedContentType)) {
t.Fatalf("Unexpected content type %q. Expected %q", h.ContentType, expectedContentType)
if h.Get("Content-Type") != expectedContentType {
t.Fatalf("Unexpected content type %q. Expected %q", h.Get("Content-Type"), expectedContentType)
}
}
@@ -572,14 +536,14 @@ func verifyRequestHeader(t *testing.T, h *RequestHeader, expectedContentLength i
if !bytes.Equal(h.RequestURI, []byte(expectedRequestURI)) {
t.Fatalf("Unexpected RequestURI %q. Expected %q", h.RequestURI, expectedRequestURI)
}
if !bytes.Equal(h.Host, []byte(expectedHost)) {
t.Fatalf("Unexpected host %q. Expected %q", h.Host, expectedHost)
if h.Get("Host") != expectedHost {
t.Fatalf("Unexpected host %q. Expected %q", h.Get("Host"), expectedHost)
}
if !bytes.Equal(h.Referer, []byte(expectedReferer)) {
t.Fatalf("Unexpected referer %q. Expected %q", h.Referer, expectedReferer)
if h.Get("Referer") != expectedReferer {
t.Fatalf("Unexpected referer %q. Expected %q", h.Get("Referer"), expectedReferer)
}
if !bytes.Equal(h.ContentType, []byte(expectedContentType)) {
t.Fatalf("Unexpected content-type %q. Expected %q", h.ContentType, expectedContentType)
if h.Get("Content-Type") != expectedContentType {
t.Fatalf("Unexpected content-type %q. Expected %q", h.Get("Content-Type"), expectedContentType)
}
}
+3 -3
View File
@@ -33,7 +33,7 @@ func (req *Request) ParseURI() {
if req.parsedURI {
return
}
req.URI.Parse(req.Header.Host, req.Header.RequestURI)
req.URI.Parse(req.Header.peek(strHost), req.Header.RequestURI)
req.parsedURI = true
}
@@ -45,9 +45,9 @@ func (req *Request) ParsePostArgs() error {
if !req.Header.IsMethodPost() {
return fmt.Errorf("Cannot parse POST args for %q request", req.Header.Method)
}
if !bytes.Equal(req.Header.ContentType, strPostArgsContentType) {
if !bytes.Equal(req.Header.peek(strContentType), strPostArgsContentType) {
return fmt.Errorf("Cannot parse POST args for %q Content-Type. Required %q Content-Type",
req.Header.ContentType, strPostArgsContentType)
req.Header.peek(strContentType), strPostArgsContentType)
}
req.PostArgs.ParseBytes(req.Body)
req.parsedPostArgs = true
+19 -19
View File
@@ -102,8 +102,8 @@ func testResponseSuccess(t *testing.T, statusCode int, contentType, serverName,
expectedStatusCode int, expectedContentType, expectedServerName string) {
var resp Response
resp.Header.StatusCode = statusCode
resp.Header.ContentType = []byte(contentType)
resp.Header.Server = []byte(serverName)
resp.Header.Set("Content-Type", contentType)
resp.Header.Set("Server", serverName)
resp.Body = []byte(body)
w := &bytes.Buffer{}
@@ -127,11 +127,11 @@ func testResponseSuccess(t *testing.T, statusCode int, contentType, serverName,
if resp1.Header.ContentLength != len(body) {
t.Fatalf("Unexpected content-length: %d. Expected %d", resp1.Header.ContentLength, len(body))
}
if !bytes.Equal(resp1.Header.ContentType, []byte(expectedContentType)) {
t.Fatalf("Unexpected content-type: %q. Expected %q", resp1.Header.ContentType, expectedContentType)
if resp1.Header.Get("Content-Type") != expectedContentType {
t.Fatalf("Unexpected content-type: %q. Expected %q", resp1.Header.Get("Content-Type"), expectedContentType)
}
if !bytes.Equal(resp1.Header.Server, []byte(expectedServerName)) {
t.Fatalf("Unexpected server: %q. Expected %q", resp1.Header.Server, expectedServerName)
if resp1.Header.Get("Server") != expectedServerName {
t.Fatalf("Unexpected server: %q. Expected %q", resp1.Header.Get("Server"), expectedServerName)
}
if !bytes.Equal(resp1.Body, []byte(body)) {
t.Fatalf("Unexpected body: %q. Expected %q", resp1.Body, body)
@@ -167,8 +167,8 @@ func testRequestWriteError(t *testing.T, method, requestURI, host, userAgent, bo
req.Header.Method = []byte(method)
req.Header.RequestURI = []byte(requestURI)
req.Header.Host = []byte(host)
req.Header.UserAgent = []byte(userAgent)
req.Header.Set("Host", host)
req.Header.Set("User-Agent", userAgent)
req.Body = []byte(body)
w := &bytes.Buffer{}
@@ -184,13 +184,13 @@ func testRequestSuccess(t *testing.T, method, requestURI, host, userAgent, body,
req.Header.Method = []byte(method)
req.Header.RequestURI = []byte(requestURI)
req.Header.Host = []byte(host)
req.Header.UserAgent = []byte(userAgent)
req.Header.Set("Host", host)
req.Header.Set("User-Agent", userAgent)
req.Body = []byte(body)
contentType := []byte("foobar")
contentType := "foobar"
if method == "POST" {
req.Header.ContentType = contentType
req.Header.Set("Content-Type", contentType)
}
w := &bytes.Buffer{}
@@ -214,18 +214,18 @@ func testRequestSuccess(t *testing.T, method, requestURI, host, userAgent, body,
if !bytes.Equal(req1.Header.RequestURI, []byte(requestURI)) {
t.Fatalf("Unexpected RequestURI: %q. Expected %q", req1.Header.RequestURI, requestURI)
}
if !bytes.Equal(req1.Header.Host, []byte(host)) {
t.Fatalf("Unexpected host: %q. Expected %q", req1.Header.Host, host)
if req1.Header.Get("Host") != host {
t.Fatalf("Unexpected host: %q. Expected %q", req1.Header.Get("Host"), host)
}
if !bytes.Equal(req1.Header.UserAgent, []byte(userAgent)) {
t.Fatalf("Unexpected user-agent: %q. Expected %q", req1.Header.UserAgent, userAgent)
if req1.Header.Get("User-Agent") != userAgent {
t.Fatalf("Unexpected user-agent: %q. Expected %q", req1.Header.Get("User-Agent"), userAgent)
}
if !bytes.Equal(req1.Body, []byte(body)) {
t.Fatalf("Unexpected body: %q. Expected %q", req1.Body, body)
}
if method == "POST" && !bytes.Equal(req1.Header.ContentType, contentType) {
t.Fatalf("Unexpected content-type: %q. Expected %q", req1.Header.ContentType, contentType)
if method == "POST" && req1.Header.Get("Content-Type") != contentType {
t.Fatalf("Unexpected content-type: %q. Expected %q", req1.Header.Get("Content-Type"), contentType)
}
}
@@ -345,7 +345,7 @@ func TestRequestParseURI(t *testing.T) {
expectedHash := "1334dfds&=d"
var req Request
req.Header.Host = []byte(host)
req.Header.Set("Host", host)
req.Header.RequestURI = []byte(requestURI)
req.ParseURI()
+20 -29
View File
@@ -39,13 +39,13 @@ type Server struct {
type RequestHandler func(ctx *ServerCtx)
type ServerCtx struct {
Request Request
Response Response
Request Request
// Unique id of the context.
// Used by ServerCtx.Logger().
ID uint64
resp Response
logger ctxLogger
s *Server
c remoteAddrer
@@ -97,32 +97,24 @@ func (ctx *ServerCtx) RemoteIP() string {
}
func (ctx *ServerCtx) Error(msg string, statusCode int) {
resp := ctx.zeroResponse()
resp := ctx.Response()
resp.Clear()
resp.Header.StatusCode = statusCode
resp.Header.ContentType = append(resp.Header.ContentType, defaultContentType...)
resp.Header.set(strContentType, defaultContentType)
resp.Body = append(resp.Body, []byte(msg)...)
}
func (ctx *ServerCtx) Success(contentType string, body []byte) {
resp := ctx.zeroResponse()
resp.Header.ContentType = appendString(resp.Header.ContentType, contentType)
resp := ctx.Response()
resp.Header.setStr(strContentType, contentType)
resp.Body = append(resp.Body, body...)
}
func appendString(b []byte, s string) []byte {
for i, n := 0, len(s); i < n; i++ {
b = append(b, s[i])
}
return b
}
func (ctx *ServerCtx) zeroResponse() *Response {
func (ctx *ServerCtx) Response() *Response {
if ctx.shadow != nil {
ctx = ctx.shadow
}
resp := &ctx.Response
resp.Clear()
return resp
return &ctx.resp
}
func (ctx *ServerCtx) Logger() Logger {
@@ -136,7 +128,7 @@ func (ctx *ServerCtx) Steal() {
shadow := *ctx
shadow.Request = Request{}
shadow.Response = Response{}
shadow.resp = Response{}
shadow.logger.ctx = &shadow
shadow.v = &shadow
ctx.shadow = &shadow
@@ -144,17 +136,16 @@ func (ctx *ServerCtx) Steal() {
func (ctx *ServerCtx) writeResponse() error {
if ctx.shadow != nil {
panic("BUG: ServerCtx.writeResponse() shouldn't be called on shadow")
panic("BUG: ctx.shadow is not null")
}
resp := &ctx.Response
h := &resp.Header
serverOld := h.Server
h := &ctx.resp.Header
serverOld := h.peek(strServer)
if len(serverOld) == 0 {
h.Server = ctx.s.getServerName()
h.set(strServer, ctx.s.getServerName())
}
err := resp.Write(ctx.w)
err := ctx.resp.Write(ctx.w)
if len(serverOld) == 0 {
h.Server = serverOld
h.set(strServer, serverOld)
}
return err
}
@@ -305,9 +296,9 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **ServerCtx) error {
if err = ctx.writeResponse(); err != nil {
break
}
connectionClose := ctx.Response.Header.ConnectionClose
connectionClose := ctx.resp.Header.ConnectionClose
ctx.Response.Clear()
ctx.resp.Clear()
trimBigBuffers(ctx)
if ctx.r.Buffered() == 0 || connectionClose {
@@ -329,8 +320,8 @@ func trimBigBuffers(ctx *ServerCtx) {
if cap(ctx.Request.Body) > bigBufferLimit {
ctx.Request.Body = nil
}
if cap(ctx.Response.Body) > bigBufferLimit {
ctx.Response.Body = nil
if cap(ctx.resp.Body) > bigBufferLimit {
ctx.resp.Body = nil
}
}
+5 -5
View File
@@ -51,7 +51,7 @@ func TestServerSteal(t *testing.T) {
func TestServerConnectionClose(t *testing.T) {
s := &Server{
Handler: func(ctx *ServerCtx) {
ctx.Response.Header.ConnectionClose = true
ctx.Response().Header.ConnectionClose = true
},
}
@@ -266,8 +266,8 @@ func TestServerConnError(t *testing.T) {
if resp.Header.ContentLength != 6 {
t.Fatalf("Unexpected Content-Length %d. Expected %d", resp.Header.ContentLength, 6)
}
if !bytes.Equal(resp.Header.ContentType, defaultContentType) {
t.Fatalf("Unexpected Content-Type %q. Expected %q", resp.Header.ContentType, defaultContentType)
if resp.Header.Get("Content-Type") != string(defaultContentType) {
t.Fatalf("Unexpected Content-Type %q. Expected %q", resp.Header.Get("Content-Type"), defaultContentType)
}
if !bytes.Equal(resp.Body, []byte("foobar")) {
t.Fatalf("Unexpected body %q. Expected %q", resp.Body, "foobar")
@@ -278,7 +278,7 @@ func TestServeConnSingleRequest(t *testing.T) {
s := &Server{
Handler: func(ctx *ServerCtx) {
h := &ctx.Request.Header
ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI, h.Host)))
ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI, h.Get("Host"))))
},
}
@@ -307,7 +307,7 @@ func TestServeConnMultiRequests(t *testing.T) {
s := &Server{
Handler: func(ctx *ServerCtx) {
h := &ctx.Request.Header
ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI, h.Host)))
ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI, h.Get("Host"))))
},
}