Files
fasthttp/server_test.go
T

1290 lines
32 KiB
Go

package fasthttp
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"testing"
"time"
)
func TestRequestCtxFormValue(t *testing.T) {
var ctx RequestCtx
var req Request
req.SetRequestURI("/foo/bar?baz=123&aaa=bbb")
req.SetBodyString("qqq=port&mmm=sddd")
req.Header.SetContentType("application/x-www-form-urlencoded")
ctx.Init(&req, nil, nil)
v := ctx.FormValue("baz")
if string(v) != "123" {
t.Fatalf("unexpected value %q. Expecting %q", v, "123")
}
v = ctx.FormValue("mmm")
if string(v) != "sddd" {
t.Fatalf("unexpected value %q. Expecting %q", v, "sddd")
}
v = ctx.FormValue("aaaasdfsdf")
if len(v) > 0 {
t.Fatalf("unexpected value for unknown key %q", v)
}
}
// TestRequestCtxUserValue stores ten integers under "key-N" names, the first
// half through SetUserValue and the second half through SetUserValueBytes,
// and verifies that UserValue returns every one of them unchanged.
func TestRequestCtxUserValue(t *testing.T) {
	const total = 10
	var ctx RequestCtx

	for i := 0; i < total; i++ {
		name := fmt.Sprintf("key-%d", i)
		if i < total/2 {
			ctx.SetUserValue(name, i)
		} else {
			ctx.SetUserValueBytes([]byte(name), i)
		}
	}

	for want := 0; want < total; want++ {
		name := fmt.Sprintf("key-%d", want)
		stored := ctx.UserValue(name)
		if got, isInt := stored.(int); !isInt || got != want {
			t.Fatalf("unexpected value obtained for key %q: %v. Expecting %d", name, stored, want)
		}
	}
}
// TestServerHeadRequest serves one HEAD request and checks that the response
// reports the handler's status code, Content-Length and Content-Type without
// carrying any body bytes.
func TestServerHeadRequest(t *testing.T) {
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			fmt.Fprintf(ctx, "Request method is %q", ctx.Method())
			ctx.SetContentType("aaa/bbb")
		},
	}

	conn := &readWriter{}
	conn.r.WriteString("HEAD /foobar HTTP/1.1\r\nHost: aaa.com\r\n\r\n")

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(conn)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	reader := bufio.NewReader(&conn.w)

	var resp Response
	// SkipBody is set because the response to a HEAD request is not expected
	// to carry body bytes after the headers.
	resp.SkipBody = true
	if err := resp.Read(reader); err != nil {
		t.Fatalf("Unexpected error when parsing response: %s", err)
	}

	header := &resp.Header
	if code := header.StatusCode(); code != StatusOK {
		t.Fatalf("unexpected status code: %d. Expecting %d", code, StatusOK)
	}
	if body := resp.Body(); len(body) > 0 {
		t.Fatalf("Unexpected non-zero body %q", body)
	}
	// The handler writes `Request method is "HEAD"`, which is 24 bytes long.
	if length := header.ContentLength(); length != 24 {
		t.Fatalf("unexpected content-length %d. Expecting %d", length, 24)
	}
	if contentType := header.ContentType(); string(contentType) != "aaa/bbb" {
		t.Fatalf("unexpected content-type %q. Expecting %q", contentType, "aaa/bbb")
	}

	// Nothing may follow the headers on the wire.
	leftover, err := ioutil.ReadAll(reader)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %s", err)
	}
	if len(leftover) > 0 {
		t.Fatalf("unexpected remaining data %q", leftover)
	}
}
// TestServerExpect100Continue sends a POST request carrying
// "Expect: 100-continue" together with its body and checks that the handler
// observes the expected method, path, content type and body, and that a 200
// response with body "foobar" is produced with nothing following it.
func TestServerExpect100Continue(t *testing.T) {
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			// The handler runs on the goroutine started below, not on the
			// goroutine running the test, so mismatches are reported with
			// t.Errorf: t.Fatalf (FailNow) must only be called from the test
			// goroutine.
			if !ctx.IsPost() {
				t.Errorf("unexpected method %q. Expecting POST", ctx.Method())
			}
			if string(ctx.Path()) != "/foo" {
				t.Errorf("unexpected path %q. Expecting %q", ctx.Path(), "/foo")
			}
			ct := ctx.Request.Header.ContentType()
			if string(ct) != "a/b" {
				t.Errorf("unexpected content-type: %q. Expecting %q", ct, "a/b")
			}
			if string(ctx.PostBody()) != "12345" {
				t.Errorf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345")
			}
			ctx.WriteString("foobar")
		},
	}

	rw := &readWriter{}
	rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")

	// Buffered so the serving goroutine can still deliver its result and
	// exit if the test has already given up on the timeout branch.
	ch := make(chan error, 1)
	go func() {
		ch <- s.ServeConn(rw)
	}()

	select {
	case err := <-ch:
		if err != nil {
			t.Fatalf("Unexpected error from serveConn: %s", err)
		}
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	}

	br := bufio.NewReader(&rw.w)
	verifyResponse(t, br, StatusOK, string(defaultContentType), "foobar")

	data, err := ioutil.ReadAll(br)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %s", err)
	}
	if len(data) > 0 {
		t.Fatalf("unexpected remaining data %q", data)
	}
}
// TestCompressHandler wraps a handler with CompressHandler and checks the
// response for three requests: one without Accept-Encoding (body left
// as-is, no Content-Encoding), one that lists gzip first, and one whose only
// recognised encoding is deflate.
func TestCompressHandler(t *testing.T) {
	const expectedBody = "foo/bar/baz"

	handler := CompressHandler(func(ctx *RequestCtx) {
		ctx.Write([]byte(expectedBody))
	})

	var (
		ctx  RequestCtx
		resp Response
	)

	// exchange runs the handler on ctx and parses the serialized response
	// back into resp.
	exchange := func() {
		handler(&ctx)
		raw := ctx.Response.String()
		if err := resp.Read(bufio.NewReader(bytes.NewBufferString(raw))); err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
	}
	// prepare resets ctx and sets the Accept-Encoding request header.
	prepare := func(acceptEncoding string) {
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.Set("Accept-Encoding", acceptEncoding)
	}

	// verify uncompressed response
	exchange()
	if enc := resp.Header.Peek("Content-Encoding"); string(enc) != "" {
		t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", enc, "")
	}
	if plain := resp.Body(); string(plain) != expectedBody {
		t.Fatalf("unexpected body %q. Expecting %q", plain, expectedBody)
	}

	// verify gzip-compressed response
	prepare("gzip, deflate, sdhc")
	exchange()
	if enc := resp.Header.Peek("Content-Encoding"); string(enc) != "gzip" {
		t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", enc, "gzip")
	}
	gunzipped, err := resp.BodyGunzip()
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	if string(gunzipped) != expectedBody {
		t.Fatalf("unexpected body %q. Expecting %q", gunzipped, expectedBody)
	}

	// verify deflate-compressed response
	prepare("foobar, deflate, sdhc")
	exchange()
	if enc := resp.Header.Peek("Content-Encoding"); string(enc) != "deflate" {
		t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", enc, "deflate")
	}
	inflated, err := resp.BodyInflate()
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	if string(inflated) != expectedBody {
		t.Fatalf("unexpected body %q. Expecting %q", inflated, expectedBody)
	}
}
// TestRequestCtxWriteString checks that RequestCtx.WriteString reports the
// number of bytes written (not characters) and appends to the response body.
func TestRequestCtxWriteString(t *testing.T) {
	var ctx RequestCtx

	written, err := ctx.WriteString("foo")
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %s", err)
	case written != 3:
		t.Fatalf("unexpected n %d. Expecting 3", written)
	}

	// Six Cyrillic letters, two bytes each in UTF-8.
	written, err = ctx.WriteString("привет")
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %s", err)
	case written != 12:
		t.Fatalf("unexpected n=%d. Expecting 12", written)
	}

	if body := ctx.Response.Body(); string(body) != "fooпривет" {
		t.Fatalf("unexpected response body %q. Expecting %q", body, "fooпривет")
	}
}
// TestServeConnNonHTTP11KeepAlive queues three HTTP/1.0 requests on one
// connection: the first asks for keep-alive, the second does not, and the
// third must never be served. It checks the Connection header of the two
// responses, that no further data is written, and that exactly two requests
// reached the handler.
func TestServeConnNonHTTP11KeepAlive(t *testing.T) {
	rw := &readWriter{}
	rw.r.WriteString("GET /foo HTTP/1.0\r\nConnection: keep-alive\r\nHost: google.com\r\n\r\n")
	rw.r.WriteString("GET /bar HTTP/1.0\r\nHost: google.com\r\n\r\n")
	rw.r.WriteString("GET /must/be/ignored HTTP/1.0\r\nHost: google.com\r\n\r\n")

	requestsServed := 0

	// The ServeConn result is handed back over the channel and checked on the
	// test goroutine: t.Fatalf (FailNow) must not be called from a goroutine
	// other than the one running the test. The channel is buffered so the
	// goroutine can finish even if the test has already timed out.
	ch := make(chan error, 1)
	go func() {
		ch <- ServeConn(rw, func(ctx *RequestCtx) {
			requestsServed++
			ctx.SuccessString("aaa/bbb", "foobar")
		})
	}()

	select {
	case err := <-ch:
		if err != nil {
			t.Fatalf("unexpected error in ServeConn: %s", err)
		}
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}

	br := bufio.NewReader(&rw.w)
	var resp Response

	// verify the first response
	if err := resp.Read(br); err != nil {
		t.Fatalf("Unexpected error when parsing response: %s", err)
	}
	if string(resp.Header.Peek("Connection")) != "keep-alive" {
		t.Fatalf("unexpected Connection header %q. Expecting %q", resp.Header.Peek("Connection"), "keep-alive")
	}
	if resp.Header.ConnectionClose() {
		t.Fatalf("unexpected Connection: close")
	}

	// verify the second response
	if err := resp.Read(br); err != nil {
		t.Fatalf("Unexpected error when parsing response: %s", err)
	}
	if string(resp.Header.Peek("Connection")) != "close" {
		t.Fatalf("unexpected Connection header %q. Expecting %q", resp.Header.Peek("Connection"), "close")
	}
	if !resp.Header.ConnectionClose() {
		t.Fatalf("expecting Connection: close")
	}

	data, err := ioutil.ReadAll(br)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %s", err)
	}
	if len(data) != 0 {
		t.Fatalf("Unexpected data read after responses %q", data)
	}

	// Safe to read here: the receive from ch above orders this read after
	// every increment made by the handler.
	if requestsServed != 2 {
		t.Fatalf("unexpected number of requests served: %d. Expecting 2", requestsServed)
	}
}
func TestRequestCtxSetBodyStreamWriter(t *testing.T) {
var ctx RequestCtx
var req Request
ctx.Init(&req, nil, defaultLogger)
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
fmt.Fprintf(w, "body writer line 1\n")
if err := w.Flush(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
fmt.Fprintf(w, "body writer line 2\n")
})
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("Error when reading response: %s", err)
}
body := string(resp.Body())
expectedBody := "body writer line 1\nbody writer line 2\n"
if body != expectedBody {
t.Fatalf("unexpected body: %q. Expecting %q", body, expectedBody)
}
}
func TestRequestCtxIfModifiedSince(t *testing.T) {
var ctx RequestCtx
var req Request
ctx.Init(&req, nil, defaultLogger)
lastModified := time.Now().Add(-time.Hour)
if !ctx.IfModifiedSince(lastModified) {
t.Fatalf("IfModifiedSince must return true for non-existing If-Modified-Since header")
}
ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified)))
if ctx.IfModifiedSince(lastModified) {
t.Fatalf("If-Modified-Since current time must return false")
}
past := lastModified.Add(-time.Hour)
if ctx.IfModifiedSince(past) {
t.Fatalf("If-Modified-Since past time must return false")
}
future := lastModified.Add(time.Hour)
if !ctx.IfModifiedSince(future) {
t.Fatalf("If-Modified-Since future time must return true")
}
}
func TestRequestCtxSendFileNotModified(t *testing.T) {
var ctx RequestCtx
var req Request
ctx.Init(&req, nil, defaultLogger)
filePath := "./server_test.go"
lastModified, err := fsLastModified(filePath)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified)))
if err := ctx.SendFile(filePath); err != nil {
t.Fatalf("error in SendFile: %s", err)
}
s := ctx.Response.String()
var resp Response
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("error when reading response: %s", err)
}
if resp.StatusCode() != StatusNotModified {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusNotModified)
}
if len(resp.Body()) > 0 {
t.Fatalf("unexpected non-zero response body: %q", resp.Body())
}
}
func TestRequestCtxSendFileModified(t *testing.T) {
var ctx RequestCtx
var req Request
ctx.Init(&req, nil, defaultLogger)
filePath := "./server_test.go"
lastModified, err := fsLastModified(filePath)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
lastModified = lastModified.Add(-time.Hour)
ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified)))
if err := ctx.SendFile(filePath); err != nil {
t.Fatalf("error in SendFile: %s", err)
}
s := ctx.Response.String()
var resp Response
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
t.Fatalf("error when reading response: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
f, err := os.Open(filePath)
if err != nil {
t.Fatalf("cannot open file: %s", err)
}
body, err := ioutil.ReadAll(f)
f.Close()
if err != nil {
t.Fatalf("error when reading file: %s", err)
}
if !bytes.Equal(resp.Body(), body) {
t.Fatalf("unexpected response body: %q. Expecting %q", resp.Body(), body)
}
}
func TestRequestCtxSendFile(t *testing.T) {
var ctx RequestCtx
var req Request
ctx.Init(&req, nil, defaultLogger)
filePath := "./server_test.go"
if err := ctx.SendFile(filePath); err != nil {
t.Fatalf("error in SendFile: %s", err)
}
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := ctx.Response.Write(bw); err != nil {
t.Fatalf("error when writing response: %s", err)
}
if err := bw.Flush(); err != nil {
t.Fatalf("error when flushing response: %s", err)
}
var resp Response
br := bufio.NewReader(w)
if err := resp.Read(br); err != nil {
t.Fatalf("error when reading response: %s", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
f, err := os.Open(filePath)
if err != nil {
t.Fatalf("cannot open file: %s", err)
}
body, err := ioutil.ReadAll(f)
f.Close()
if err != nil {
t.Fatalf("error when reading file: %s", err)
}
if !bytes.Equal(resp.Body(), body) {
t.Fatalf("unexpected response body: %q. Expecting %q", resp.Body(), body)
}
}
// TestRequestCtxHijack serves one request whose handler hijacks the
// connection. It checks that the regular response is written first and that
// the bytes following the request on the connection are then echoed back
// one at a time by the hijack handler.
func TestRequestCtxHijack(t *testing.T) {
	hijackStartCh := make(chan struct{})
	hijackStopCh := make(chan struct{})
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.Hijack(func(c net.Conn) {
				// Always signal completion, so that a failure reported below
				// is not followed by a secondary "timeout" in the test body.
				defer close(hijackStopCh)

				<-hijackStartCh

				b := make([]byte, 1)
				// ping-pong echo via hijacked conn
				for {
					n, err := c.Read(b)
					if n != 1 {
						// This callback does not run on the test goroutine,
						// so failures are reported with t.Errorf followed by
						// a return; t.Fatalf (FailNow) is only valid on the
						// goroutine running the test.
						switch {
						case err == io.EOF:
							// All input has been echoed.
						case err != nil:
							t.Errorf("unexpected error: %s", err)
						default:
							t.Errorf("unexpected number of bytes read: %d. Expecting 1", n)
						}
						return
					}
					if _, err = c.Write(b); err != nil {
						t.Errorf("unexpected error when writing data: %s", err)
						return
					}
				}
			})
			ctx.Success("foo/bar", []byte("hijack it!"))
		},
	}

	hijackedString := "foobar baz hijacked!!!"
	rw := &readWriter{}
	rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
	rw.r.WriteString(hijackedString)

	// Buffered so the serving goroutine can exit even after a timeout.
	ch := make(chan error, 1)
	go func() {
		ch <- s.ServeConn(rw)
	}()

	select {
	case err := <-ch:
		if err != nil {
			t.Fatalf("Unexpected error from serveConn: %s", err)
		}
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	}

	br := bufio.NewReader(&rw.w)
	verifyResponse(t, br, StatusOK, "foo/bar", "hijack it!")

	// Release the hijack handler only after the regular response has been
	// parsed, then wait for it to finish echoing.
	close(hijackStartCh)
	select {
	case <-hijackStopCh:
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	}

	data, err := ioutil.ReadAll(br)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %s", err)
	}
	if string(data) != hijackedString {
		t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, hijackedString)
	}
}
// TestRequestCtxInit initializes a bare RequestCtx with zeroTCPAddr and a
// customLogger, then checks that RemoteIP is unspecified and that a message
// logged through ctx.Logger() comes out with the expected prefix.
func TestRequestCtxInit(t *testing.T) {
	var (
		ctx    RequestCtx
		logger customLogger
	)

	// Pin the package-level id so the id in the log line is predictable
	// (presumably Init advances it, given 00123457 in the expected output).
	globalCtxID = 0x123456
	ctx.Init(&ctx.Request, zeroTCPAddr, &logger)

	if ip := ctx.RemoteIP(); !ip.IsUnspecified() {
		t.Fatalf("unexpected ip for bare RequestCtx: %q. Expected 0.0.0.0", ip)
	}

	const expectedLog = "#0012345700000000 - 0.0.0.0:0<->0.0.0.0:0 - GET http:/// - foo bar 10\n"
	ctx.Logger().Printf("foo bar %d", 10)
	if got := logger.out; got != expectedLog {
		t.Fatalf("Unexpected log output: %q. Expected %q", got, expectedLog)
	}
}
// TestTimeoutHandlerSuccess wraps a handler that responds immediately in
// TimeoutHandler and expects the handler's own response, not the timeout
// message, to be served.
func TestTimeoutHandlerSuccess(t *testing.T) {
	inner := func(ctx *RequestCtx) {
		ctx.Success("aaa/bbb", []byte("real response"))
	}
	srv := &Server{
		Handler: TimeoutHandler(inner, 100*time.Millisecond, "timeout!!!"),
	}

	conn := &readWriter{}
	conn.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(conn)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	verifyResponse(t, bufio.NewReader(&conn.w), StatusOK, "aaa/bbb", "real response")
}
// TestServerGetOnly sends a POST request to a server configured with
// GetOnly and expects ServeConn to fail with errGetOnly without writing any
// response bytes.
func TestServerGetOnly(t *testing.T) {
	h := func(ctx *RequestCtx) {
		// The handler runs on the serving goroutine, not the test goroutine,
		// so the failure is reported with t.Errorf: t.Fatalf (FailNow) must
		// only be called from the goroutine running the test.
		if !ctx.IsGet() {
			t.Errorf("non-get request: %q", ctx.Method())
		}
		ctx.Success("foo/bar", []byte("success"))
	}
	s := &Server{
		Handler: h,
		GetOnly: true,
	}

	rw := &readWriter{}
	rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 5\r\nContent-Type: aaa\r\n\r\n12345")

	// Buffered so the serving goroutine can exit even after a timeout.
	ch := make(chan error, 1)
	go func() {
		ch <- s.ServeConn(rw)
	}()

	select {
	case err := <-ch:
		if err == nil {
			t.Fatalf("expecting error")
		}
		if err != errGetOnly {
			t.Fatalf("Unexpected error from serveConn: %s. Expecting %s", err, errGetOnly)
		}
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	}

	resp := rw.w.Bytes()
	if len(resp) > 0 {
		t.Fatalf("unexpected response %q. Expecting zero", resp)
	}
}
// TestTimeoutHandlerTimeout wraps a handler that sleeps for one second in a
// TimeoutHandler with a 10ms limit and expects a StatusRequestTimeout
// response carrying the configured timeout message.
func TestTimeoutHandlerTimeout(t *testing.T) {
	slow := func(ctx *RequestCtx) {
		time.Sleep(time.Second)
		ctx.Success("aaa/bbb", []byte("this shouldn't pass to client because of timeout"))
	}
	srv := &Server{
		Handler: TimeoutHandler(slow, 10*time.Millisecond, "timeout!!!"),
	}

	conn := &readWriter{}
	conn.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(conn)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	verifyResponse(t, bufio.NewReader(&conn.w), StatusRequestTimeout, string(defaultContentType), "timeout!!!")
}
// TestServerTimeoutErrorWithResponse calls ctx.TimeoutErrorWithResponse twice
// from the handler, while a background goroutine also writes to ctx, and
// expects the values passed to the second call to be the ones served. Two
// requests are queued but only one response may be written.
func TestServerTimeoutErrorWithResponse(t *testing.T) {
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			// Concurrent write to ctx; the test expects it not to show up
			// in the served response.
			go func() {
				ctx.Success("aaa/bbb", []byte("xxxyyy"))
			}()

			var timeoutResp Response

			timeoutResp.SetStatusCode(123)
			timeoutResp.SetBodyString("foobar. Should be ignored")
			ctx.TimeoutErrorWithResponse(&timeoutResp)

			timeoutResp.SetStatusCode(456)
			timeoutResp.ResetBody()
			fmt.Fprintf(timeoutResp.BodyWriter(), "path=%s", ctx.Path())
			timeoutResp.Header.SetContentType("foo/bar")
			ctx.TimeoutErrorWithResponse(&timeoutResp)
		},
	}

	conn := &readWriter{}
	conn.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
	conn.r.WriteString("GET /bar HTTP/1.1\r\nHost: google.com\r\n\r\n")

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(conn)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	reader := bufio.NewReader(&conn.w)
	verifyResponse(t, reader, 456, "foo/bar", "path=/foo")

	// The second queued request must not have produced a response.
	leftover, err := ioutil.ReadAll(reader)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %s", err)
	}
	if len(leftover) != 0 {
		t.Fatalf("Unexpected data read after the first response %q. Expecting %q", leftover, "")
	}
}
// TestServerTimeoutErrorWithCode calls ctx.TimeoutErrorWithCode twice from
// the handler, while a background goroutine also writes to ctx, and expects
// the message and status of the second call to be served. Two requests are
// queued but only one response may be written.
func TestServerTimeoutErrorWithCode(t *testing.T) {
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			// Concurrent write to ctx; the test expects it not to show up
			// in the served response.
			go func() {
				ctx.Success("aaa/bbb", []byte("xxxyyy"))
			}()
			ctx.TimeoutErrorWithCode("should be ignored", 234)
			ctx.TimeoutErrorWithCode("stolen ctx", StatusBadRequest)
		},
	}

	conn := &readWriter{}
	for i := 0; i < 2; i++ {
		conn.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
	}

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(conn)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	reader := bufio.NewReader(&conn.w)
	verifyResponse(t, reader, StatusBadRequest, string(defaultContentType), "stolen ctx")

	// The second queued request must not have produced a response.
	leftover, err := ioutil.ReadAll(reader)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %s", err)
	}
	if len(leftover) != 0 {
		t.Fatalf("Unexpected data read after the first response %q. Expecting %q", leftover, "")
	}
}
// TestServerTimeoutError calls ctx.TimeoutError twice from the handler, while
// a background goroutine also writes to ctx, and expects a
// StatusRequestTimeout response carrying the message of the second call. Two
// requests are queued but only one response may be written.
func TestServerTimeoutError(t *testing.T) {
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			// Concurrent write to ctx; the test expects it not to show up
			// in the served response.
			go func() {
				ctx.Success("aaa/bbb", []byte("xxxyyy"))
			}()
			ctx.TimeoutError("should be ignored")
			ctx.TimeoutError("stolen ctx")
		},
	}

	conn := &readWriter{}
	for i := 0; i < 2; i++ {
		conn.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
	}

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(conn)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	reader := bufio.NewReader(&conn.w)
	verifyResponse(t, reader, StatusRequestTimeout, string(defaultContentType), "stolen ctx")

	// The second queued request must not have produced a response.
	leftover, err := ioutil.ReadAll(reader)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %s", err)
	}
	if len(leftover) != 0 {
		t.Fatalf("Unexpected data read after the first response %q. Expecting %q", leftover, "")
	}
}
// TestServerMaxKeepaliveDuration queues two requests on a server whose
// handler takes 20ms while MaxKeepaliveDuration is 10ms. It expects a single
// empty 200 response flagged with connection close and nothing after it.
func TestServerMaxKeepaliveDuration(t *testing.T) {
	srv := &Server{
		MaxKeepaliveDuration: 10 * time.Millisecond,
		Handler: func(ctx *RequestCtx) {
			// Outlast MaxKeepaliveDuration while handling the request.
			time.Sleep(20 * time.Millisecond)
		},
	}

	conn := &readWriter{}
	conn.r.WriteString("GET /aaa HTTP/1.1\r\nHost: aa.com\r\n\r\n")
	conn.r.WriteString("GET /bbbb HTTP/1.1\r\nHost: bbb.com\r\n\r\n")

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(conn)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	reader := bufio.NewReader(&conn.w)
	var resp Response
	if readErr := resp.Read(reader); readErr != nil {
		t.Fatalf("Unexpected error when parsing response: %s", readErr)
	}
	if closing := resp.ConnectionClose(); !closing {
		t.Fatalf("Response must have 'connection: close' header")
	}
	verifyResponseHeader(t, &resp.Header, 200, 0, string(defaultContentType))

	// The second queued request must not have produced a response.
	leftover, err := ioutil.ReadAll(reader)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %s", err)
	}
	if len(leftover) != 0 {
		t.Fatalf("Unexpected data read after the first response %q. Expecting %q", leftover, "")
	}
}
// TestServerMaxRequestsPerConn queues two requests on a server limited to one
// request per connection. It expects a single empty 200 response flagged
// with connection close and nothing after it.
func TestServerMaxRequestsPerConn(t *testing.T) {
	srv := &Server{
		MaxRequestsPerConn: 1,
		Handler:            func(ctx *RequestCtx) {},
	}

	conn := &readWriter{}
	conn.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
	conn.r.WriteString("GET /bar HTTP/1.1\r\nHost: aaa.com\r\n\r\n")

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(conn)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	reader := bufio.NewReader(&conn.w)
	var resp Response
	if readErr := resp.Read(reader); readErr != nil {
		t.Fatalf("Unexpected error when parsing response: %s", readErr)
	}
	if closing := resp.ConnectionClose(); !closing {
		t.Fatalf("Response must have 'connection: close' header")
	}
	verifyResponseHeader(t, &resp.Header, 200, 0, string(defaultContentType))

	// The second queued request must not have produced a response.
	leftover, err := ioutil.ReadAll(reader)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %s", err)
	}
	if len(leftover) != 0 {
		t.Fatalf("Unexpected data read after the first response %q. Expecting %q", leftover, "")
	}
}
// TestServerConnectionClose has the handler call ctx.SetConnectionClose and
// expects the response to be flagged with connection close and the second
// queued request to be left unanswered.
func TestServerConnectionClose(t *testing.T) {
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.SetConnectionClose()
		},
	}

	conn := &readWriter{}
	conn.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
	conn.r.WriteString("GET /must/be/ignored HTTP/1.1\r\nHost: aaa.com\r\n\r\n")

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(conn)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	reader := bufio.NewReader(&conn.w)
	var resp Response
	if readErr := resp.Read(reader); readErr != nil {
		t.Fatalf("Unexpected error when parsing response: %s", readErr)
	}
	if closing := resp.ConnectionClose(); !closing {
		t.Fatalf("expecting Connection: close header")
	}

	// The second queued request must not have produced a response.
	leftover, err := ioutil.ReadAll(reader)
	if err != nil {
		t.Fatalf("Unexpected error when reading remaining data: %s", err)
	}
	if len(leftover) != 0 {
		t.Fatalf("Unexpected data read after the first response %q. Expecting %q", leftover, "")
	}
}
// TestServerRequestNumAndTime serves three requests on one connection and
// checks, inside the handler, that ConnRequestNum counts up from 1 and that
// ConnTime stays the same for every request. Afterwards it checks that three
// requests were served and that the first response is an empty 200.
func TestServerRequestNumAndTime(t *testing.T) {
	n := uint64(0)
	var connT time.Time
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			// The handler runs on the serving goroutine, not the test
			// goroutine, so mismatches are reported with t.Errorf: t.Fatalf
			// (FailNow) must only be called from the goroutine running the
			// test.
			n++
			if ctx.ConnRequestNum() != n {
				t.Errorf("unexpected request number: %d. Expecting %d", ctx.ConnRequestNum(), n)
			}
			if connT.IsZero() {
				connT = ctx.ConnTime()
			}
			// Equal compares the instants; == on time.Time would also
			// compare location and monotonic-clock fields.
			if !ctx.ConnTime().Equal(connT) {
				t.Errorf("unexpected serve conn time: %s. Expecting %s", ctx.ConnTime(), connT)
			}
		},
	}

	rw := &readWriter{}
	rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
	rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: google.com\r\n\r\n")
	rw.r.WriteString("GET /baz HTTP/1.1\r\nHost: google.com\r\n\r\n")

	// Buffered so the serving goroutine can exit even after a timeout.
	ch := make(chan error, 1)
	go func() {
		ch <- s.ServeConn(rw)
	}()

	select {
	case err := <-ch:
		if err != nil {
			t.Fatalf("Unexpected error from serveConn: %s", err)
		}
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	}

	// Safe to read here: the receive from ch orders this read after every
	// increment made by the handler.
	if n != 3 {
		t.Fatalf("unexpected number of requests served: %d. Expecting %d", n, 3)
	}

	br := bufio.NewReader(&rw.w)
	verifyResponse(t, br, 200, string(defaultContentType), "")
}
// TestServerEmptyResponse serves one request with a handler that does
// nothing and expects a 200 response with the default content type and an
// empty body.
func TestServerEmptyResponse(t *testing.T) {
	srv := &Server{
		// The handler deliberately leaves the response untouched.
		Handler: func(ctx *RequestCtx) {},
	}

	conn := &readWriter{}
	conn.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(conn)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	verifyResponse(t, bufio.NewReader(&conn.w), 200, string(defaultContentType), "")
}
// customLogger is a test logger that accumulates everything printed through
// it in the out field.
type customLogger struct {
	// out holds all recorded messages, each terminated by a newline.
	out string
}

// Printf formats the message, drops its first 6 bytes and appends the rest
// plus a newline to cl.out. The dropped prefix is presumably a variable part
// of the log line (e.g. a timestamp) that the tests cannot predict — confirm
// against the logger that produces the prefix. A formatted message shorter
// than 6 bytes makes the slice expression panic.
func (cl *customLogger) Printf(format string, args ...interface{}) {
	msg := fmt.Sprintf(format, args...)
	cl.out += msg[6:] + "\n"
}
// TestServerLogger serves a GET and a POST request over a connection with a
// fixed remote address, logging "begin" and "end" from the handler through
// ctx.Logger(). It checks both responses and the exact lines captured by the
// server's custom logger.
func TestServerLogger(t *testing.T) {
	captured := &customLogger{}
	srv := &Server{
		Logger: captured,
		Handler: func(ctx *RequestCtx) {
			log := ctx.Logger()
			reqHeader := &ctx.Request.Header
			log.Printf("begin")
			ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, body=%q, remoteAddr=%s",
				reqHeader.RequestURI(), ctx.Request.Body(), ctx.RemoteAddr())))
			log.Printf("end")
		},
	}

	conn := &readWriter{}
	conn.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
	conn.r.WriteString("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 5\r\nContent-Type: aa\r\n\r\nabcde")

	// Wrap the in-memory connection so it reports a fixed address.
	addressed := &readWriterRemoteAddr{
		rw: conn,
		addr: &net.TCPAddr{
			IP:   net.IP{1, 2, 3, 4},
			Port: 8765,
		},
	}

	// Reset the package-level id so the ids in the log lines are predictable.
	globalCtxID = 0

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(addressed)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	reader := bufio.NewReader(&conn.w)
	verifyResponse(t, reader, 200, "text/html", "requestURI=/foo1, body=\"\", remoteAddr=1.2.3.4:8765")
	verifyResponse(t, reader, 200, "text/html", "requestURI=/foo2, body=\"abcde\", remoteAddr=1.2.3.4:8765")

	wantLog := `#0000000100000001 - 1.2.3.4:8765<->1.2.3.4:8765 - GET http://google.com/foo1 - begin
#0000000100000001 - 1.2.3.4:8765<->1.2.3.4:8765 - GET http://google.com/foo1 - end
#0000000100000002 - 1.2.3.4:8765<->1.2.3.4:8765 - POST http://aaa.com/foo2 - begin
#0000000100000002 - 1.2.3.4:8765<->1.2.3.4:8765 - POST http://aaa.com/foo2 - end
`
	if got := captured.out; got != wantLog {
		t.Fatalf("Unexpected logger output: %q. Expected %q", got, wantLog)
	}
}
// TestServerRemoteAddr serves one request over a connection with a fixed
// remote address and checks that ctx.RemoteAddr and ctx.RemoteIP, as echoed
// by the handler, reflect that address.
func TestServerRemoteAddr(t *testing.T) {
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			reqHeader := &ctx.Request.Header
			ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, remoteAddr=%s, remoteIP=%s",
				reqHeader.RequestURI(), ctx.RemoteAddr(), ctx.RemoteIP())))
		},
	}

	conn := &readWriter{}
	conn.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")

	// Wrap the in-memory connection so it reports a fixed address.
	addressed := &readWriterRemoteAddr{
		rw: conn,
		addr: &net.TCPAddr{
			IP:   net.IP{1, 2, 3, 4},
			Port: 8765,
		},
	}

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(addressed)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	verifyResponse(t, bufio.NewReader(&conn.w), 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.4:8765, remoteIP=1.2.3.4")
}
// readWriterRemoteAddr is a test connection that forwards Read, Write and
// Close to an underlying io.ReadWriteCloser and reports a fixed address as
// both its remote and local address. The embedded net.Conn is never set in
// this file; it only supplies the remaining net.Conn methods so the type
// satisfies the interface.
type readWriterRemoteAddr struct {
	net.Conn
	// rw receives all Read, Write and Close calls.
	rw io.ReadWriteCloser
	// addr is returned by both RemoteAddr and LocalAddr.
	addr net.Addr
}

// Close closes the underlying io.ReadWriteCloser.
func (c *readWriterRemoteAddr) Close() error {
	err := c.rw.Close()
	return err
}

// Read reads from the underlying io.ReadWriteCloser.
func (c *readWriterRemoteAddr) Read(b []byte) (int, error) {
	n, err := c.rw.Read(b)
	return n, err
}

// Write writes to the underlying io.ReadWriteCloser.
func (c *readWriterRemoteAddr) Write(b []byte) (int, error) {
	n, err := c.rw.Write(b)
	return n, err
}

// RemoteAddr returns the configured address.
func (c *readWriterRemoteAddr) RemoteAddr() net.Addr {
	return c.addr
}

// LocalAddr returns the same configured address as RemoteAddr.
func (c *readWriterRemoteAddr) LocalAddr() net.Addr {
	return c.addr
}
// TestServerConnError has the handler call ctx.Error("foobar", 423) and
// checks the status code, Content-Length, Content-Type and body of the
// resulting response.
func TestServerConnError(t *testing.T) {
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.Error("foobar", 423)
		},
	}

	conn := &readWriter{}
	conn.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n")

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(conn)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	var resp Response
	if readErr := resp.Read(bufio.NewReader(&conn.w)); readErr != nil {
		t.Fatalf("Unexpected error when reading response: %s", readErr)
	}

	header := &resp.Header
	if code := header.StatusCode(); code != 423 {
		t.Fatalf("Unexpected status code %d. Expected %d", code, 423)
	}
	// "foobar" is six bytes long.
	if length := header.ContentLength(); length != 6 {
		t.Fatalf("Unexpected Content-Length %d. Expected %d", length, 6)
	}
	if contentType := header.Peek("Content-Type"); !bytes.Equal(contentType, defaultContentType) {
		t.Fatalf("Unexpected Content-Type %q. Expected %q", contentType, defaultContentType)
	}
	if body := resp.Body(); string(body) != "foobar" {
		t.Fatalf("Unexpected body %q. Expected %q", body, "foobar")
	}
}
// TestServeConnSingleRequest serves one GET request and checks that the
// handler sees the request URI and Host header sent by the client.
func TestServeConnSingleRequest(t *testing.T) {
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			reqHeader := &ctx.Request.Header
			payload := fmt.Sprintf("requestURI=%s, host=%s", reqHeader.RequestURI(), reqHeader.Peek("Host"))
			ctx.Success("aaa", []byte(payload))
		},
	}

	conn := &readWriter{}
	conn.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n")

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(conn)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	verifyResponse(t, bufio.NewReader(&conn.w), 200, "aaa", "requestURI=/foo/bar?baz, host=google.com")
}
// TestServeConnMultiRequests sends two GET requests back to back on one
// connection and checks that each gets its own response, in order, echoing
// its request URI and Host header.
func TestServeConnMultiRequests(t *testing.T) {
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			reqHeader := &ctx.Request.Header
			payload := fmt.Sprintf("requestURI=%s, host=%s", reqHeader.RequestURI(), reqHeader.Peek("Host"))
			ctx.Success("aaa", []byte(payload))
		},
	}

	conn := &readWriter{}
	conn.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\nGET /abc HTTP/1.1\r\nHost: foobar.com\r\n\r\n")

	done := make(chan error)
	go func() {
		done <- srv.ServeConn(conn)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		t.Fatalf("timeout")
	case serveErr := <-done:
		if serveErr != nil {
			t.Fatalf("Unexpected error from serveConn: %s", serveErr)
		}
	}

	reader := bufio.NewReader(&conn.w)
	for _, wantBody := range []string{
		"requestURI=/foo/bar?baz, host=google.com",
		"requestURI=/abc, host=foobar.com",
	} {
		verifyResponse(t, reader, 200, "aaa", wantBody)
	}
}
// verifyResponse reads one response from r and fails the test if it cannot be
// parsed or if its body differs from expectedBody. The status code, the body
// length and the content type are then passed to verifyResponseHeader.
func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) {
	var resp Response
	if readErr := resp.Read(r); readErr != nil {
		t.Fatalf("Unexpected error when parsing response: %s", readErr)
	}

	got, want := resp.Body(), []byte(expectedBody)
	if !bytes.Equal(got, want) {
		t.Fatalf("Unexpected body %q. Expected %q", got, want)
	}

	verifyResponseHeader(t, &resp.Header, expectedStatusCode, len(got), expectedContentType)
}
// readWriter is an in-memory test connection: Read consumes from the r buffer
// and Write appends to the w buffer. The embedded net.Conn is never set in
// this file; it only supplies the net.Conn methods that readWriter does not
// define itself.
type readWriter struct {
	net.Conn
	// r holds the bytes handed out by Read.
	r bytes.Buffer
	// w collects the bytes passed to Write.
	w bytes.Buffer
}

// Close does nothing and always succeeds.
func (c *readWriter) Close() error {
	return nil
}

// Read reads from the r buffer.
func (c *readWriter) Read(b []byte) (int, error) {
	n, err := c.r.Read(b)
	return n, err
}

// Write appends b to the w buffer.
func (c *readWriter) Write(b []byte) (int, error) {
	n, err := c.w.Write(b)
	return n, err
}
// RemoteAddr always returns the package-level zeroTCPAddr.
func (c *readWriter) RemoteAddr() net.Addr {
	var addr net.Addr = zeroTCPAddr
	return addr
}

// LocalAddr always returns the package-level zeroTCPAddr.
func (c *readWriter) LocalAddr() net.Addr {
	var addr net.Addr = zeroTCPAddr
	return addr
}
// SetReadDeadline ignores the deadline and always succeeds.
func (c *readWriter) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline and always succeeds.
func (c *readWriter) SetWriteDeadline(_ time.Time) error {
	return nil
}