mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-14 15:56:44 +03:00
Uniformly process all headers
This commit is contained in:
@@ -43,9 +43,7 @@ var (
|
||||
|
||||
type ResponseHeader struct {
|
||||
StatusCode int
|
||||
ContentType []byte
|
||||
ContentLength int
|
||||
Server []byte
|
||||
ConnectionClose bool
|
||||
|
||||
h []argsKV
|
||||
@@ -55,10 +53,6 @@ type ResponseHeader struct {
|
||||
type RequestHeader struct {
|
||||
Method []byte
|
||||
RequestURI []byte
|
||||
Host []byte
|
||||
UserAgent []byte
|
||||
Referer []byte
|
||||
ContentType []byte
|
||||
ContentLength int
|
||||
|
||||
h []argsKV
|
||||
@@ -80,8 +74,6 @@ func (h *RequestHeader) IsMethodHead() bool {
|
||||
func (h *ResponseHeader) Clear() {
|
||||
h.StatusCode = 0
|
||||
h.ContentLength = 0
|
||||
h.ContentType = h.ContentType[:0]
|
||||
h.Server = h.Server[:0]
|
||||
h.ConnectionClose = false
|
||||
|
||||
h.h = h.h[:0]
|
||||
@@ -90,107 +82,79 @@ func (h *ResponseHeader) Clear() {
|
||||
func (h *RequestHeader) Clear() {
|
||||
h.Method = h.Method[:0]
|
||||
h.RequestURI = h.RequestURI[:0]
|
||||
h.Host = h.Host[:0]
|
||||
h.UserAgent = h.UserAgent[:0]
|
||||
h.Referer = h.Referer[:0]
|
||||
h.ContentType = h.ContentType[:0]
|
||||
h.ContentLength = 0
|
||||
|
||||
h.h = h.h[:0]
|
||||
}
|
||||
|
||||
func (h *ResponseHeader) Set(key, value string) {
|
||||
k := getKeyBytes(&h.bufKV, key)
|
||||
initHeaderKV(&h.bufKV, key, value)
|
||||
h.set(h.bufKV.key, h.bufKV.value)
|
||||
}
|
||||
|
||||
func (h *ResponseHeader) set(key, value []byte) {
|
||||
switch {
|
||||
case bytes.Equal(strContentLength, k):
|
||||
case bytes.Equal(strContentLength, key):
|
||||
// skip Conent-Length setting, since it will be set automatically.
|
||||
return
|
||||
case bytes.Equal(strContentType, k):
|
||||
h.ContentType = AppendBytesStr(h.ContentType[:0], value)
|
||||
return
|
||||
case bytes.Equal(strServer, k):
|
||||
h.Server = AppendBytesStr(h.Server[:0], value)
|
||||
return
|
||||
case bytes.Equal(strConnection, k):
|
||||
if EqualBytesStr(strClose, value) {
|
||||
case bytes.Equal(strConnection, key):
|
||||
if bytes.Equal(strClose, value) {
|
||||
h.ConnectionClose = true
|
||||
}
|
||||
// skip other 'Connection' shit :)
|
||||
return
|
||||
case bytes.Equal(strTransferEncoding, k):
|
||||
case bytes.Equal(strTransferEncoding, key):
|
||||
// Transfer-Encoding is managed automatically.
|
||||
return
|
||||
case bytes.Equal(strDate, key):
|
||||
// Date is managed automatically.
|
||||
default:
|
||||
h.h = setKV(h.h, key, value)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHeader) setStr(key []byte, value string) {
|
||||
h.bufKV.value = AppendBytesStr(h.bufKV.value[:0], value)
|
||||
h.h = setKV(h.h, k, h.bufKV.value)
|
||||
h.set(key, h.bufKV.value)
|
||||
}
|
||||
|
||||
func (h *RequestHeader) Set(key, value string) {
|
||||
k := getKeyBytes(&h.bufKV, key)
|
||||
initHeaderKV(&h.bufKV, key, value)
|
||||
h.set(h.bufKV.key, h.bufKV.value)
|
||||
}
|
||||
|
||||
func (h *RequestHeader) set(key, value []byte) {
|
||||
switch {
|
||||
case bytes.Equal(strHost, k):
|
||||
h.Host = AppendBytesStr(h.Host[:0], value)
|
||||
return
|
||||
case bytes.Equal(strUserAgent, k):
|
||||
h.UserAgent = AppendBytesStr(h.UserAgent[:0], value)
|
||||
return
|
||||
case bytes.Equal(strReferer, k):
|
||||
h.Referer = AppendBytesStr(h.Referer[:0], value)
|
||||
return
|
||||
case bytes.Equal(strContentType, k):
|
||||
h.ContentType = AppendBytesStr(h.ContentType[:0], value)
|
||||
return
|
||||
case bytes.Equal(strContentLength, k):
|
||||
case bytes.Equal(strContentLength, key):
|
||||
// Content-Length is managed automatically.
|
||||
return
|
||||
case bytes.Equal(strTransferEncoding, k):
|
||||
case bytes.Equal(strTransferEncoding, key):
|
||||
// Transfer-Encoding is managed automatically.
|
||||
return
|
||||
case bytes.Equal(strConnection, k):
|
||||
case bytes.Equal(strConnection, key):
|
||||
// Connection is managed automatically.
|
||||
return
|
||||
default:
|
||||
h.h = setKV(h.h, key, value)
|
||||
}
|
||||
|
||||
h.bufKV.value = AppendBytesStr(h.bufKV.value[:0], value)
|
||||
h.h = setKV(h.h, k, h.bufKV.value)
|
||||
}
|
||||
|
||||
func (h *ResponseHeader) Peek(key string) []byte {
|
||||
k := getKeyBytes(&h.bufKV, key)
|
||||
k := getHeaderKeyBytes(&h.bufKV, key)
|
||||
return h.peek(k)
|
||||
}
|
||||
|
||||
switch {
|
||||
case bytes.Equal(strContentType, k):
|
||||
return h.ContentType
|
||||
case bytes.Equal(strServer, k):
|
||||
return h.Server
|
||||
case bytes.Equal(strConnection, k):
|
||||
func (h *RequestHeader) Peek(key string) []byte {
|
||||
k := getHeaderKeyBytes(&h.bufKV, key)
|
||||
return h.peek(k)
|
||||
}
|
||||
|
||||
func (h *ResponseHeader) peek(key []byte) []byte {
|
||||
if bytes.Equal(strConnection, key) {
|
||||
if h.ConnectionClose {
|
||||
return strClose
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return peekKV(h.h, k)
|
||||
return peekKV(h.h, key)
|
||||
}
|
||||
|
||||
func (h *RequestHeader) Peek(key string) []byte {
|
||||
k := getKeyBytes(&h.bufKV, key)
|
||||
|
||||
switch {
|
||||
case bytes.Equal(strHost, k):
|
||||
return h.Host
|
||||
case bytes.Equal(strUserAgent, k):
|
||||
return h.UserAgent
|
||||
case bytes.Equal(strReferer, k):
|
||||
return h.Referer
|
||||
case bytes.Equal(strContentType, k):
|
||||
return h.ContentType
|
||||
}
|
||||
|
||||
return peekKV(h.h, k)
|
||||
func (h *RequestHeader) peek(key []byte) []byte {
|
||||
return peekKV(h.h, key)
|
||||
}
|
||||
|
||||
func (h *ResponseHeader) Get(key string) string {
|
||||
@@ -201,12 +165,6 @@ func (h *RequestHeader) Get(key string) string {
|
||||
return string(h.Peek(key))
|
||||
}
|
||||
|
||||
func getKeyBytes(kv *argsKV, key string) []byte {
|
||||
kv.key = AppendBytesStr(kv.key[:0], key)
|
||||
normalizeHeaderKey(kv.key)
|
||||
return kv.key
|
||||
}
|
||||
|
||||
func (h *ResponseHeader) Read(r *bufio.Reader) error {
|
||||
n := 1
|
||||
for {
|
||||
@@ -325,14 +283,14 @@ func (h *ResponseHeader) Write(w *bufio.Writer) error {
|
||||
}
|
||||
w.Write(statusLine(statusCode))
|
||||
|
||||
server := h.Server
|
||||
server := h.peek(strServer)
|
||||
if len(server) == 0 {
|
||||
server = defaultServerName
|
||||
}
|
||||
writeHeaderLine(w, strServer, server)
|
||||
writeHeaderLine(w, strDate, serverDate.Load().([]byte))
|
||||
|
||||
contentType := h.ContentType
|
||||
contentType := h.peek(strContentType)
|
||||
if len(contentType) == 0 {
|
||||
contentType = defaultContentType
|
||||
}
|
||||
@@ -349,7 +307,9 @@ func (h *ResponseHeader) Write(w *bufio.Writer) error {
|
||||
|
||||
for i, n := 0, len(h.h); i < n; i++ {
|
||||
kv := &h.h[i]
|
||||
writeHeaderLine(w, kv.key, kv.value)
|
||||
if !bytes.Equal(strServer, kv.key) && !bytes.Equal(strContentType, kv.key) {
|
||||
writeHeaderLine(w, kv.key, kv.value)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := w.Write(strCRLF)
|
||||
@@ -404,23 +364,18 @@ func (h *RequestHeader) Write(w *bufio.Writer) error {
|
||||
w.Write(strHTTP11)
|
||||
w.Write(strCRLF)
|
||||
|
||||
if len(h.UserAgent) > 0 {
|
||||
writeHeaderLine(w, strUserAgent, h.UserAgent)
|
||||
}
|
||||
if len(h.Referer) > 0 {
|
||||
writeHeaderLine(w, strReferer, h.Referer)
|
||||
}
|
||||
|
||||
if len(h.Host) == 0 {
|
||||
host := h.peek(strHost)
|
||||
if len(host) == 0 {
|
||||
return fmt.Errorf("missing required Host header")
|
||||
}
|
||||
writeHeaderLine(w, strHost, h.Host)
|
||||
writeHeaderLine(w, strHost, host)
|
||||
|
||||
if h.IsMethodPost() {
|
||||
if len(h.ContentType) == 0 {
|
||||
contentType := h.peek(strContentType)
|
||||
if len(contentType) == 0 {
|
||||
return fmt.Errorf("missing required Content-Type header for POST request")
|
||||
}
|
||||
writeHeaderLine(w, strContentType, h.ContentType)
|
||||
writeHeaderLine(w, strContentType, contentType)
|
||||
if h.ContentLength < 0 {
|
||||
return fmt.Errorf("missing required Content-Length header for POST request")
|
||||
}
|
||||
@@ -429,7 +384,9 @@ func (h *RequestHeader) Write(w *bufio.Writer) error {
|
||||
|
||||
for i, n := 0, len(h.h); i < n; i++ {
|
||||
kv := &h.h[i]
|
||||
writeHeaderLine(w, kv.key, kv.value)
|
||||
if !bytes.Equal(strHost, kv.key) && !bytes.Equal(strContentType, kv.key) {
|
||||
writeHeaderLine(w, kv.key, kv.value)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := w.Write(strCRLF)
|
||||
@@ -542,39 +499,34 @@ func (h *ResponseHeader) parseHeaders(buf []byte) ([]byte, error) {
|
||||
p.init(buf)
|
||||
var err error
|
||||
for p.next() {
|
||||
if bytes.Equal(p.key, strContentType) {
|
||||
h.ContentType = append(h.ContentType[:0], p.value...)
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(p.key, strContentLength) && h.ContentLength != -1 {
|
||||
h.ContentLength, err = parseContentLength(p.value)
|
||||
if err != nil {
|
||||
if isNeedMoreError(err) {
|
||||
return nil, err
|
||||
switch {
|
||||
case bytes.Equal(p.key, strContentLength):
|
||||
if h.ContentLength != -1 {
|
||||
h.ContentLength, err = parseContentLength(p.value)
|
||||
if err != nil {
|
||||
if isNeedMoreError(err) {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("cannot parse Content-Length %q: %s at %q", p.value, err, buf)
|
||||
}
|
||||
return nil, fmt.Errorf("cannot parse Content-Length %q: %s at %q", p.value, err, buf)
|
||||
}
|
||||
continue
|
||||
case bytes.Equal(p.key, strTransferEncoding):
|
||||
if bytes.Equal(p.value, strChunked) {
|
||||
h.ContentLength = -1
|
||||
}
|
||||
case bytes.Equal(p.key, strConnection):
|
||||
if bytes.Equal(p.value, strClose) {
|
||||
h.ConnectionClose = true
|
||||
}
|
||||
default:
|
||||
h.h = setKV(h.h, p.key, p.value)
|
||||
}
|
||||
if bytes.Equal(p.key, strTransferEncoding) && bytes.Equal(p.value, strChunked) {
|
||||
h.ContentLength = -1
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(p.key, strServer) {
|
||||
h.Server = append(h.Server[:0], p.value...)
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(p.key, strConnection) && bytes.Equal(p.value, strClose) {
|
||||
h.ConnectionClose = true
|
||||
continue
|
||||
}
|
||||
h.h = setKV(h.h, p.key, p.value)
|
||||
}
|
||||
if p.err != nil {
|
||||
return nil, p.err
|
||||
}
|
||||
|
||||
if len(h.ContentType) == 0 {
|
||||
if len(h.peek(strContentType)) == 0 {
|
||||
return nil, fmt.Errorf("missing required Content-Type header in %q", buf)
|
||||
}
|
||||
if h.ContentLength == -2 {
|
||||
@@ -590,47 +542,34 @@ func (h *RequestHeader) parseHeaders(buf []byte) ([]byte, error) {
|
||||
p.init(buf)
|
||||
var err error
|
||||
for p.next() {
|
||||
if bytes.Equal(p.key, strHost) {
|
||||
h.Host = append(h.Host[:0], p.value...)
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(p.key, strUserAgent) {
|
||||
h.UserAgent = append(h.UserAgent[:0], p.value...)
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(p.key, strReferer) {
|
||||
h.Referer = append(h.Referer[:0], p.value...)
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(p.key, strContentType) {
|
||||
h.ContentType = append(h.ContentType[:0], p.value...)
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(p.key, strContentLength) && h.ContentLength != -1 {
|
||||
h.ContentLength, err = parseContentLength(p.value)
|
||||
if err != nil {
|
||||
if isNeedMoreError(err) {
|
||||
return nil, err
|
||||
switch {
|
||||
case bytes.Equal(p.key, strContentLength):
|
||||
if h.ContentLength != -1 {
|
||||
h.ContentLength, err = parseContentLength(p.value)
|
||||
if err != nil {
|
||||
if isNeedMoreError(err) {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("cannot parse Content-Length %q: %s at %q", p.value, err, buf)
|
||||
}
|
||||
return nil, fmt.Errorf("cannot parse Content-Length %q: %s at %q", p.value, err, buf)
|
||||
}
|
||||
continue
|
||||
case bytes.Equal(p.key, strTransferEncoding):
|
||||
if bytes.Equal(p.value, strChunked) {
|
||||
h.ContentLength = -1
|
||||
}
|
||||
default:
|
||||
h.h = setKV(h.h, p.key, p.value)
|
||||
}
|
||||
if bytes.Equal(p.key, strTransferEncoding) && bytes.Equal(p.value, strChunked) {
|
||||
h.ContentLength = -1
|
||||
continue
|
||||
}
|
||||
h.h = setKV(h.h, p.key, p.value)
|
||||
}
|
||||
if p.err != nil {
|
||||
return nil, p.err
|
||||
}
|
||||
|
||||
if len(h.Host) == 0 {
|
||||
if len(h.peek(strHost)) == 0 {
|
||||
return nil, fmt.Errorf("missing required Host header in %q", buf)
|
||||
}
|
||||
if h.IsMethodPost() {
|
||||
if len(h.ContentType) == 0 {
|
||||
if len(h.peek(strContentType)) == 0 {
|
||||
return nil, fmt.Errorf("missing Content-Type for POST header in %q", buf)
|
||||
}
|
||||
if h.ContentLength == -2 {
|
||||
@@ -708,6 +647,17 @@ func nextLine(b []byte) ([]byte, []byte, error) {
|
||||
return b[:n], b[nNext+1:], nil
|
||||
}
|
||||
|
||||
func initHeaderKV(kv *argsKV, key, value string) {
|
||||
kv.key = getHeaderKeyBytes(kv, key)
|
||||
kv.value = AppendBytesStr(kv.value[:0], value)
|
||||
}
|
||||
|
||||
func getHeaderKeyBytes(kv *argsKV, key string) []byte {
|
||||
kv.key = AppendBytesStr(kv.key[:0], key)
|
||||
normalizeHeaderKey(kv.key)
|
||||
return kv.key
|
||||
}
|
||||
|
||||
func normalizeHeaderKey(b []byte) {
|
||||
n := len(b)
|
||||
up := true
|
||||
|
||||
+8
-44
@@ -33,18 +33,6 @@ func TestRequestHeaderSetGet(t *testing.T) {
|
||||
expectRequestHeaderGet(t, h, "baz", "xxxxx")
|
||||
expectRequestHeaderGet(t, h, "Transfer-Encoding", "")
|
||||
|
||||
if !bytes.Equal(h.Host, []byte("12345")) {
|
||||
t.Fatalf("Unexpected host %q. Expected %q", h.Host, "12345")
|
||||
}
|
||||
if !bytes.Equal(h.ContentType, []byte("aaa/bbb")) {
|
||||
t.Fatalf("Unexpected content-type %q. Expected %q", h.ContentType, "aaa/bbb")
|
||||
}
|
||||
if !bytes.Equal(h.UserAgent, []byte("aaabbb")) {
|
||||
t.Fatalf("Unepxected Server %q. Expected %q", h.UserAgent, "aaabbb")
|
||||
}
|
||||
if !bytes.Equal(h.Referer, []byte("axcv")) {
|
||||
t.Fatalf("Unexpected referer %q. Expected %q", h.Referer, "axcv")
|
||||
}
|
||||
if h.ContentLength != 0 {
|
||||
t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength, 0)
|
||||
}
|
||||
@@ -65,18 +53,6 @@ func TestRequestHeaderSetGet(t *testing.T) {
|
||||
t.Fatalf("Unexpected error when reading request header: %s", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(h1.Host, h.Host) {
|
||||
t.Fatalf("Unexpected host %q. Expected %q", h1.Host, h.Host)
|
||||
}
|
||||
if !bytes.Equal(h1.ContentType, h.ContentType) {
|
||||
t.Fatalf("Unexpected content-type %q. Expected %q", h1.ContentType, h.ContentType)
|
||||
}
|
||||
if !bytes.Equal(h1.UserAgent, h.UserAgent) {
|
||||
t.Fatalf("Unexpected user-agent %q. Expected %q", h1.UserAgent, h.UserAgent)
|
||||
}
|
||||
if !bytes.Equal(h1.Referer, h.Referer) {
|
||||
t.Fatalf("Unepxected referer %q. Expected %q", h1.Referer, h.Referer)
|
||||
}
|
||||
if h1.ContentLength != h.ContentLength {
|
||||
t.Fatalf("Unexpected Content-Length %d. Expected %d", h1.ContentLength, h.ContentLength)
|
||||
}
|
||||
@@ -109,12 +85,6 @@ func TestResponseHeaderSetGet(t *testing.T) {
|
||||
expectResponseHeaderGet(t, h, "baz", "xxxxx")
|
||||
expectResponseHeaderGet(t, h, "Transfer-Encoding", "")
|
||||
|
||||
if !bytes.Equal(h.ContentType, []byte("aaa/bbb")) {
|
||||
t.Fatalf("Unexpected content-type %q. Expected %q", h.ContentType, "aaa/bbb")
|
||||
}
|
||||
if !bytes.Equal(h.Server, []byte("aaaa")) {
|
||||
t.Fatalf("Unepxected Server %q. Expected %q", h.Server, "aaaa")
|
||||
}
|
||||
if h.ContentLength != 0 {
|
||||
t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength, 0)
|
||||
}
|
||||
@@ -138,12 +108,6 @@ func TestResponseHeaderSetGet(t *testing.T) {
|
||||
t.Fatalf("Unexpected error when reading response header: %s", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(h1.ContentType, h.ContentType) {
|
||||
t.Fatalf("Unexpected content-type %q. Expected %q", h1.ContentType, h.ContentType)
|
||||
}
|
||||
if !bytes.Equal(h1.Server, h.Server) {
|
||||
t.Fatalf("Unepxected Server %q. Expected %q", h1.Server, h.Server)
|
||||
}
|
||||
if h1.ContentLength != h.ContentLength {
|
||||
t.Fatalf("Unexpected Content-Length %d. Expected %d", h1.ContentLength, h.ContentLength)
|
||||
}
|
||||
@@ -559,8 +523,8 @@ func verifyResponseHeader(t *testing.T, h *ResponseHeader, expectedStatusCode, e
|
||||
if h.ContentLength != expectedContentLength {
|
||||
t.Fatalf("Unexpected content length %d. Expected %d", h.ContentLength, expectedContentLength)
|
||||
}
|
||||
if !bytes.Equal(h.ContentType, []byte(expectedContentType)) {
|
||||
t.Fatalf("Unexpected content type %q. Expected %q", h.ContentType, expectedContentType)
|
||||
if h.Get("Content-Type") != expectedContentType {
|
||||
t.Fatalf("Unexpected content type %q. Expected %q", h.Get("Content-Type"), expectedContentType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -572,14 +536,14 @@ func verifyRequestHeader(t *testing.T, h *RequestHeader, expectedContentLength i
|
||||
if !bytes.Equal(h.RequestURI, []byte(expectedRequestURI)) {
|
||||
t.Fatalf("Unexpected RequestURI %q. Expected %q", h.RequestURI, expectedRequestURI)
|
||||
}
|
||||
if !bytes.Equal(h.Host, []byte(expectedHost)) {
|
||||
t.Fatalf("Unexpected host %q. Expected %q", h.Host, expectedHost)
|
||||
if h.Get("Host") != expectedHost {
|
||||
t.Fatalf("Unexpected host %q. Expected %q", h.Get("Host"), expectedHost)
|
||||
}
|
||||
if !bytes.Equal(h.Referer, []byte(expectedReferer)) {
|
||||
t.Fatalf("Unexpected referer %q. Expected %q", h.Referer, expectedReferer)
|
||||
if h.Get("Referer") != expectedReferer {
|
||||
t.Fatalf("Unexpected referer %q. Expected %q", h.Get("Referer"), expectedReferer)
|
||||
}
|
||||
if !bytes.Equal(h.ContentType, []byte(expectedContentType)) {
|
||||
t.Fatalf("Unexpected content-type %q. Expected %q", h.ContentType, expectedContentType)
|
||||
if h.Get("Content-Type") != expectedContentType {
|
||||
t.Fatalf("Unexpected content-type %q. Expected %q", h.Get("Content-Type"), expectedContentType)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ func (req *Request) ParseURI() {
|
||||
if req.parsedURI {
|
||||
return
|
||||
}
|
||||
req.URI.Parse(req.Header.Host, req.Header.RequestURI)
|
||||
req.URI.Parse(req.Header.peek(strHost), req.Header.RequestURI)
|
||||
req.parsedURI = true
|
||||
}
|
||||
|
||||
@@ -45,9 +45,9 @@ func (req *Request) ParsePostArgs() error {
|
||||
if !req.Header.IsMethodPost() {
|
||||
return fmt.Errorf("Cannot parse POST args for %q request", req.Header.Method)
|
||||
}
|
||||
if !bytes.Equal(req.Header.ContentType, strPostArgsContentType) {
|
||||
if !bytes.Equal(req.Header.peek(strContentType), strPostArgsContentType) {
|
||||
return fmt.Errorf("Cannot parse POST args for %q Content-Type. Required %q Content-Type",
|
||||
req.Header.ContentType, strPostArgsContentType)
|
||||
req.Header.peek(strContentType), strPostArgsContentType)
|
||||
}
|
||||
req.PostArgs.ParseBytes(req.Body)
|
||||
req.parsedPostArgs = true
|
||||
|
||||
+19
-19
@@ -102,8 +102,8 @@ func testResponseSuccess(t *testing.T, statusCode int, contentType, serverName,
|
||||
expectedStatusCode int, expectedContentType, expectedServerName string) {
|
||||
var resp Response
|
||||
resp.Header.StatusCode = statusCode
|
||||
resp.Header.ContentType = []byte(contentType)
|
||||
resp.Header.Server = []byte(serverName)
|
||||
resp.Header.Set("Content-Type", contentType)
|
||||
resp.Header.Set("Server", serverName)
|
||||
resp.Body = []byte(body)
|
||||
|
||||
w := &bytes.Buffer{}
|
||||
@@ -127,11 +127,11 @@ func testResponseSuccess(t *testing.T, statusCode int, contentType, serverName,
|
||||
if resp1.Header.ContentLength != len(body) {
|
||||
t.Fatalf("Unexpected content-length: %d. Expected %d", resp1.Header.ContentLength, len(body))
|
||||
}
|
||||
if !bytes.Equal(resp1.Header.ContentType, []byte(expectedContentType)) {
|
||||
t.Fatalf("Unexpected content-type: %q. Expected %q", resp1.Header.ContentType, expectedContentType)
|
||||
if resp1.Header.Get("Content-Type") != expectedContentType {
|
||||
t.Fatalf("Unexpected content-type: %q. Expected %q", resp1.Header.Get("Content-Type"), expectedContentType)
|
||||
}
|
||||
if !bytes.Equal(resp1.Header.Server, []byte(expectedServerName)) {
|
||||
t.Fatalf("Unexpected server: %q. Expected %q", resp1.Header.Server, expectedServerName)
|
||||
if resp1.Header.Get("Server") != expectedServerName {
|
||||
t.Fatalf("Unexpected server: %q. Expected %q", resp1.Header.Get("Server"), expectedServerName)
|
||||
}
|
||||
if !bytes.Equal(resp1.Body, []byte(body)) {
|
||||
t.Fatalf("Unexpected body: %q. Expected %q", resp1.Body, body)
|
||||
@@ -167,8 +167,8 @@ func testRequestWriteError(t *testing.T, method, requestURI, host, userAgent, bo
|
||||
|
||||
req.Header.Method = []byte(method)
|
||||
req.Header.RequestURI = []byte(requestURI)
|
||||
req.Header.Host = []byte(host)
|
||||
req.Header.UserAgent = []byte(userAgent)
|
||||
req.Header.Set("Host", host)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
req.Body = []byte(body)
|
||||
|
||||
w := &bytes.Buffer{}
|
||||
@@ -184,13 +184,13 @@ func testRequestSuccess(t *testing.T, method, requestURI, host, userAgent, body,
|
||||
|
||||
req.Header.Method = []byte(method)
|
||||
req.Header.RequestURI = []byte(requestURI)
|
||||
req.Header.Host = []byte(host)
|
||||
req.Header.UserAgent = []byte(userAgent)
|
||||
req.Header.Set("Host", host)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
req.Body = []byte(body)
|
||||
|
||||
contentType := []byte("foobar")
|
||||
contentType := "foobar"
|
||||
if method == "POST" {
|
||||
req.Header.ContentType = contentType
|
||||
req.Header.Set("Content-Type", contentType)
|
||||
}
|
||||
|
||||
w := &bytes.Buffer{}
|
||||
@@ -214,18 +214,18 @@ func testRequestSuccess(t *testing.T, method, requestURI, host, userAgent, body,
|
||||
if !bytes.Equal(req1.Header.RequestURI, []byte(requestURI)) {
|
||||
t.Fatalf("Unexpected RequestURI: %q. Expected %q", req1.Header.RequestURI, requestURI)
|
||||
}
|
||||
if !bytes.Equal(req1.Header.Host, []byte(host)) {
|
||||
t.Fatalf("Unexpected host: %q. Expected %q", req1.Header.Host, host)
|
||||
if req1.Header.Get("Host") != host {
|
||||
t.Fatalf("Unexpected host: %q. Expected %q", req1.Header.Get("Host"), host)
|
||||
}
|
||||
if !bytes.Equal(req1.Header.UserAgent, []byte(userAgent)) {
|
||||
t.Fatalf("Unexpected user-agent: %q. Expected %q", req1.Header.UserAgent, userAgent)
|
||||
if req1.Header.Get("User-Agent") != userAgent {
|
||||
t.Fatalf("Unexpected user-agent: %q. Expected %q", req1.Header.Get("User-Agent"), userAgent)
|
||||
}
|
||||
if !bytes.Equal(req1.Body, []byte(body)) {
|
||||
t.Fatalf("Unexpected body: %q. Expected %q", req1.Body, body)
|
||||
}
|
||||
|
||||
if method == "POST" && !bytes.Equal(req1.Header.ContentType, contentType) {
|
||||
t.Fatalf("Unexpected content-type: %q. Expected %q", req1.Header.ContentType, contentType)
|
||||
if method == "POST" && req1.Header.Get("Content-Type") != contentType {
|
||||
t.Fatalf("Unexpected content-type: %q. Expected %q", req1.Header.Get("Content-Type"), contentType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -345,7 +345,7 @@ func TestRequestParseURI(t *testing.T) {
|
||||
expectedHash := "1334dfds&=d"
|
||||
|
||||
var req Request
|
||||
req.Header.Host = []byte(host)
|
||||
req.Header.Set("Host", host)
|
||||
req.Header.RequestURI = []byte(requestURI)
|
||||
|
||||
req.ParseURI()
|
||||
|
||||
@@ -39,13 +39,13 @@ type Server struct {
|
||||
type RequestHandler func(ctx *ServerCtx)
|
||||
|
||||
type ServerCtx struct {
|
||||
Request Request
|
||||
Response Response
|
||||
Request Request
|
||||
|
||||
// Unique id of the context.
|
||||
// Used by ServerCtx.Logger().
|
||||
ID uint64
|
||||
|
||||
resp Response
|
||||
logger ctxLogger
|
||||
s *Server
|
||||
c remoteAddrer
|
||||
@@ -97,32 +97,24 @@ func (ctx *ServerCtx) RemoteIP() string {
|
||||
}
|
||||
|
||||
func (ctx *ServerCtx) Error(msg string, statusCode int) {
|
||||
resp := ctx.zeroResponse()
|
||||
resp := ctx.Response()
|
||||
resp.Clear()
|
||||
resp.Header.StatusCode = statusCode
|
||||
resp.Header.ContentType = append(resp.Header.ContentType, defaultContentType...)
|
||||
resp.Header.set(strContentType, defaultContentType)
|
||||
resp.Body = append(resp.Body, []byte(msg)...)
|
||||
}
|
||||
|
||||
func (ctx *ServerCtx) Success(contentType string, body []byte) {
|
||||
resp := ctx.zeroResponse()
|
||||
resp.Header.ContentType = appendString(resp.Header.ContentType, contentType)
|
||||
resp := ctx.Response()
|
||||
resp.Header.setStr(strContentType, contentType)
|
||||
resp.Body = append(resp.Body, body...)
|
||||
}
|
||||
|
||||
func appendString(b []byte, s string) []byte {
|
||||
for i, n := 0, len(s); i < n; i++ {
|
||||
b = append(b, s[i])
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (ctx *ServerCtx) zeroResponse() *Response {
|
||||
func (ctx *ServerCtx) Response() *Response {
|
||||
if ctx.shadow != nil {
|
||||
ctx = ctx.shadow
|
||||
}
|
||||
resp := &ctx.Response
|
||||
resp.Clear()
|
||||
return resp
|
||||
return &ctx.resp
|
||||
}
|
||||
|
||||
func (ctx *ServerCtx) Logger() Logger {
|
||||
@@ -136,7 +128,7 @@ func (ctx *ServerCtx) Steal() {
|
||||
|
||||
shadow := *ctx
|
||||
shadow.Request = Request{}
|
||||
shadow.Response = Response{}
|
||||
shadow.resp = Response{}
|
||||
shadow.logger.ctx = &shadow
|
||||
shadow.v = &shadow
|
||||
ctx.shadow = &shadow
|
||||
@@ -144,17 +136,16 @@ func (ctx *ServerCtx) Steal() {
|
||||
|
||||
func (ctx *ServerCtx) writeResponse() error {
|
||||
if ctx.shadow != nil {
|
||||
panic("BUG: ServerCtx.writeResponse() shouldn't be called on shadow")
|
||||
panic("BUG: ctx.shadow is not null")
|
||||
}
|
||||
resp := &ctx.Response
|
||||
h := &resp.Header
|
||||
serverOld := h.Server
|
||||
h := &ctx.resp.Header
|
||||
serverOld := h.peek(strServer)
|
||||
if len(serverOld) == 0 {
|
||||
h.Server = ctx.s.getServerName()
|
||||
h.set(strServer, ctx.s.getServerName())
|
||||
}
|
||||
err := resp.Write(ctx.w)
|
||||
err := ctx.resp.Write(ctx.w)
|
||||
if len(serverOld) == 0 {
|
||||
h.Server = serverOld
|
||||
h.set(strServer, serverOld)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -305,9 +296,9 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **ServerCtx) error {
|
||||
if err = ctx.writeResponse(); err != nil {
|
||||
break
|
||||
}
|
||||
connectionClose := ctx.Response.Header.ConnectionClose
|
||||
connectionClose := ctx.resp.Header.ConnectionClose
|
||||
|
||||
ctx.Response.Clear()
|
||||
ctx.resp.Clear()
|
||||
trimBigBuffers(ctx)
|
||||
|
||||
if ctx.r.Buffered() == 0 || connectionClose {
|
||||
@@ -329,8 +320,8 @@ func trimBigBuffers(ctx *ServerCtx) {
|
||||
if cap(ctx.Request.Body) > bigBufferLimit {
|
||||
ctx.Request.Body = nil
|
||||
}
|
||||
if cap(ctx.Response.Body) > bigBufferLimit {
|
||||
ctx.Response.Body = nil
|
||||
if cap(ctx.resp.Body) > bigBufferLimit {
|
||||
ctx.resp.Body = nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+5
-5
@@ -51,7 +51,7 @@ func TestServerSteal(t *testing.T) {
|
||||
func TestServerConnectionClose(t *testing.T) {
|
||||
s := &Server{
|
||||
Handler: func(ctx *ServerCtx) {
|
||||
ctx.Response.Header.ConnectionClose = true
|
||||
ctx.Response().Header.ConnectionClose = true
|
||||
},
|
||||
}
|
||||
|
||||
@@ -266,8 +266,8 @@ func TestServerConnError(t *testing.T) {
|
||||
if resp.Header.ContentLength != 6 {
|
||||
t.Fatalf("Unexpected Content-Length %d. Expected %d", resp.Header.ContentLength, 6)
|
||||
}
|
||||
if !bytes.Equal(resp.Header.ContentType, defaultContentType) {
|
||||
t.Fatalf("Unexpected Content-Type %q. Expected %q", resp.Header.ContentType, defaultContentType)
|
||||
if resp.Header.Get("Content-Type") != string(defaultContentType) {
|
||||
t.Fatalf("Unexpected Content-Type %q. Expected %q", resp.Header.Get("Content-Type"), defaultContentType)
|
||||
}
|
||||
if !bytes.Equal(resp.Body, []byte("foobar")) {
|
||||
t.Fatalf("Unexpected body %q. Expected %q", resp.Body, "foobar")
|
||||
@@ -278,7 +278,7 @@ func TestServeConnSingleRequest(t *testing.T) {
|
||||
s := &Server{
|
||||
Handler: func(ctx *ServerCtx) {
|
||||
h := &ctx.Request.Header
|
||||
ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI, h.Host)))
|
||||
ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI, h.Get("Host"))))
|
||||
},
|
||||
}
|
||||
|
||||
@@ -307,7 +307,7 @@ func TestServeConnMultiRequests(t *testing.T) {
|
||||
s := &Server{
|
||||
Handler: func(ctx *ServerCtx) {
|
||||
h := &ctx.Request.Header
|
||||
ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI, h.Host)))
|
||||
ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI, h.Get("Host"))))
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user