mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-14 15:56:44 +03:00
3eab931bf1
Prevent request and response first-line setters from serializing embedded CR or LF bytes into the start line. Route SetMethod, SetRequestURI, SetProtocol, and SetStatusMessage through the existing newline sanitization used by other header-value setters. This preserves behavior for valid inputs while preventing header injection through malformed first-line values. Thanks to @vnykmshr for reporting this issue.
3839 lines
85 KiB
Go
3839 lines
85 KiB
Go
package fasthttp
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"regexp"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"syscall"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/valyala/fasthttp/fasthttputil"
|
|
)
|
|
|
|
func TestCloseIdleConnections(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
},
|
|
}
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Error(err)
|
|
}
|
|
}()
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
if _, _, err := c.Get(nil, "http://google.com"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
connsLen := func() int {
|
|
c.mLock.Lock()
|
|
defer c.mLock.Unlock()
|
|
|
|
if _, ok := c.m["google.com"]; !ok {
|
|
return 0
|
|
}
|
|
|
|
c.m["google.com"].connsLock.Lock()
|
|
defer c.m["google.com"].connsLock.Unlock()
|
|
|
|
return len(c.m["google.com"].conns)
|
|
}
|
|
|
|
if conns := connsLen(); conns > 1 {
|
|
t.Errorf("expected 1 conns got %d", conns)
|
|
}
|
|
|
|
c.CloseIdleConnections()
|
|
|
|
if conns := connsLen(); conns > 0 {
|
|
t.Errorf("expected 0 conns got %d", conns)
|
|
}
|
|
}
|
|
|
|
func TestPipelineClientSetUserAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientSetUserAgent(t, 0)
|
|
}
|
|
|
|
func TestPipelineClientSetUserAgentTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientSetUserAgent(t, time.Second)
|
|
}
|
|
|
|
func testPipelineClientSetUserAgent(t *testing.T, timeout time.Duration) {
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
userAgentSeen := ""
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
userAgentSeen = string(ctx.UserAgent())
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
|
|
userAgent := "I'm not fasthttp"
|
|
c := &HostClient{
|
|
Name: userAgent,
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://example.com")
|
|
|
|
var err error
|
|
if timeout <= 0 {
|
|
err = c.Do(req, res)
|
|
} else {
|
|
err = c.DoTimeout(req, res, timeout)
|
|
}
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if userAgentSeen != userAgent {
|
|
t.Fatalf("User-Agent defers %q != %q", userAgentSeen, userAgent)
|
|
}
|
|
}
|
|
|
|
func TestHostClientNegativeTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &HostClient{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com")
|
|
if err := c.DoTimeout(req, nil, -time.Second); err != ErrTimeout {
|
|
t.Fatalf("expected ErrTimeout error got: %+v", err)
|
|
}
|
|
if err := c.DoDeadline(req, nil, time.Now().Add(-time.Second)); err != ErrTimeout {
|
|
t.Fatalf("expected ErrTimeout error got: %+v", err)
|
|
}
|
|
ln.Close()
|
|
}
|
|
|
|
func TestDoDeadlineRetry(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var tries atomic.Int32
|
|
done := make(chan struct{})
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
go func() {
|
|
for {
|
|
c, err := ln.Accept()
|
|
if err != nil {
|
|
close(done)
|
|
break
|
|
}
|
|
tries.Add(1)
|
|
br := bufio.NewReader(c)
|
|
(&RequestHeader{}).Read(br) //nolint:errcheck
|
|
(&Request{}).readBodyStream(br, 0, false, false) //nolint:errcheck
|
|
if tries.Load() == 1 {
|
|
time.Sleep(time.Millisecond * 10)
|
|
} else {
|
|
time.Sleep(time.Millisecond * 200)
|
|
}
|
|
c.Close()
|
|
}
|
|
}()
|
|
c := &HostClient{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com")
|
|
if err := c.DoDeadline(req, nil, time.Now().Add(time.Millisecond*200)); err != ErrTimeout {
|
|
t.Fatalf("expected ErrTimeout error got: %+v", err)
|
|
}
|
|
ln.Close()
|
|
<-done
|
|
if tr := tries.Load(); tr != 2 {
|
|
t.Fatalf("expected 2 tries got %d", tr)
|
|
}
|
|
}
|
|
|
|
func TestPipelineClientIssue832(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
req := AcquireRequest()
|
|
// Don't defer ReleaseRequest as we use it in a goroutine that might not be done at the end.
|
|
|
|
req.SetHost("example.com")
|
|
|
|
res := AcquireResponse()
|
|
// Don't defer ReleaseResponse as we use it in a goroutine that might not be done at the end.
|
|
|
|
client := PipelineClient{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
ReadTimeout: time.Millisecond * 10,
|
|
Logger: &testLogger{}, // Ignore log output.
|
|
}
|
|
|
|
attempts := 10
|
|
go func() {
|
|
for range attempts {
|
|
c, err := ln.Accept()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if c != nil {
|
|
go func() {
|
|
time.Sleep(time.Millisecond * 50)
|
|
c.Close()
|
|
}()
|
|
}
|
|
}
|
|
}()
|
|
|
|
done := make(chan int)
|
|
go func() {
|
|
defer close(done)
|
|
|
|
for range attempts {
|
|
if err := client.Do(req, res); err == nil {
|
|
t.Error("error expected")
|
|
}
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case <-time.After(time.Second * 2):
|
|
t.Fatal("PipelineClient did not restart worker")
|
|
case <-done:
|
|
}
|
|
}
|
|
|
|
func TestClientInvalidURI(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
var requests atomic.Int64
|
|
s := &Server{
|
|
Handler: func(_ *RequestCtx) {
|
|
requests.Add(1)
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req, res := AcquireRequest(), AcquireResponse()
|
|
defer func() {
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
}()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com\r\n\r\nGET /\r\n\r\n")
|
|
err := c.Do(req, res)
|
|
if err == nil && res.StatusCode() != StatusBadRequest {
|
|
t.Fatalf("expected invalid URI to be rejected, got status code %d", res.StatusCode())
|
|
}
|
|
if n := requests.Load(); n != 0 {
|
|
t.Fatalf("0 requests expected, got %d", n)
|
|
}
|
|
}
|
|
|
|
func TestClientRequestProtocolSetterSanitizesNewlines(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
var requests atomic.Int64
|
|
s := &Server{
|
|
Handler: func(_ *RequestCtx) {
|
|
requests.Add(1)
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
req, res := AcquireRequest(), AcquireResponse()
|
|
defer func() {
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
}()
|
|
|
|
req.SetRequestURI("http://example.com/")
|
|
req.Header.SetProtocol("HTTP/1.1\r\nX-Injected-Protocol: true")
|
|
|
|
if err := c.Do(req, res); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if got := res.StatusCode(); got != StatusBadRequest {
|
|
t.Fatalf("unexpected status code: %d. Expected %d", got, StatusBadRequest)
|
|
}
|
|
if n := requests.Load(); n != 0 {
|
|
t.Fatalf("expected malformed request to be rejected before reaching handler, got %d handled requests", n)
|
|
}
|
|
}
|
|
|
|
func TestClientResponseStatusMessageSetterSanitizesNewlines(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ctx.Response.Header.SetStatusCode(StatusOK)
|
|
ctx.Response.Header.SetStatusMessage([]byte("OK\r\nX-Injected-Status: true"))
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
req, res := AcquireRequest(), AcquireResponse()
|
|
defer func() {
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
}()
|
|
|
|
req.SetRequestURI("http://example.com/")
|
|
|
|
if err := c.Do(req, res); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if got := string(res.Header.StatusMessage()); got != "OK X-Injected-Status: true" {
|
|
t.Fatalf("unexpected status message: %q. Expected %q", got, "OK X-Injected-Status: true")
|
|
}
|
|
if got := string(res.Header.Peek("X-Injected-Status")); got != "" {
|
|
t.Fatalf("unexpected injected response header value: %q", got)
|
|
}
|
|
}
|
|
|
|
func TestClientGetWithBody(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
body := ctx.Request.Body()
|
|
ctx.Write(body) //nolint:errcheck
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req, res := AcquireRequest(), AcquireResponse()
|
|
defer func() {
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
}()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com")
|
|
req.SetBodyString("test")
|
|
err := c.Do(req, res)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(res.Body()) == 0 {
|
|
t.Fatal("missing request body")
|
|
}
|
|
}
|
|
|
|
func TestClientURLAuth(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := map[string]string{
|
|
"user:pass@": "Basic dXNlcjpwYXNz",
|
|
"foo:@": "Basic Zm9vOg==",
|
|
":@": "",
|
|
"@": "",
|
|
"": "",
|
|
}
|
|
|
|
ch := make(chan string, 1)
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ch <- string(ctx.Request.Header.Peek(HeaderAuthorization))
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
for up, expected := range cases {
|
|
req := AcquireRequest()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://" + up + "example.com/foo/bar")
|
|
if err := c.Do(req, nil); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
val := <-ch
|
|
|
|
if val != expected {
|
|
t.Fatalf("wrong %q header: %q expected %q", HeaderAuthorization, val, expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestClientNilResp(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com")
|
|
if err := c.Do(req, nil); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := c.DoTimeout(req, nil, time.Second); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
ln.Close()
|
|
}
|
|
|
|
func TestClientNegativeTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com")
|
|
if err := c.DoTimeout(req, nil, -time.Second); err != ErrTimeout {
|
|
t.Fatalf("expected ErrTimeout error got: %+v", err)
|
|
}
|
|
if err := c.DoDeadline(req, nil, time.Now().Add(-time.Second)); err != ErrTimeout {
|
|
t.Fatalf("expected ErrTimeout error got: %+v", err)
|
|
}
|
|
ln.Close()
|
|
}
|
|
|
|
func TestPipelineClientNilResp(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &PipelineClient{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com")
|
|
if err := c.Do(req, nil); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := c.DoTimeout(req, nil, time.Second); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := c.DoDeadline(req, nil, time.Now().Add(time.Second)); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestClientParseConn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
network := "tcp"
|
|
ln, _ := net.Listen(network, "127.0.0.1:0")
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
host := ln.Addr().String()
|
|
c := &Client{}
|
|
req, res := AcquireRequest(), AcquireResponse()
|
|
defer func() {
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
}()
|
|
req.SetRequestURI("http://" + host + "")
|
|
if err := c.Do(req, res); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if res.RemoteAddr().Network() != network {
|
|
t.Fatalf("req RemoteAddr parse network fail: %q, hope: %q", res.RemoteAddr().Network(), network)
|
|
}
|
|
if host != res.RemoteAddr().String() {
|
|
t.Fatalf("req RemoteAddr parse addr fail: %q, hope: %q", res.RemoteAddr().String(), host)
|
|
}
|
|
|
|
if !regexp.MustCompile(`^127\.0\.0\.1:\d{4,5}$`).MatchString(res.LocalAddr().String()) {
|
|
t.Fatalf("res LocalAddr addr match fail: %q, hope match: %q", res.LocalAddr().String(), "^127.0.0.1:[0-9]{4,5}$")
|
|
}
|
|
}
|
|
|
|
func TestClientPostArgs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
body := ctx.Request.Body()
|
|
if len(body) == 0 {
|
|
return
|
|
}
|
|
ctx.Write(body) //nolint:errcheck
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req, res := AcquireRequest(), AcquireResponse()
|
|
defer func() {
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
}()
|
|
args := req.PostArgs()
|
|
args.Add("addhttp2", "support")
|
|
args.Add("fast", "http")
|
|
req.Header.SetMethod(MethodPost)
|
|
req.SetRequestURI("http://make.fasthttp.great?again")
|
|
err := c.Do(req, res)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(res.Body()) == 0 {
|
|
t.Fatal("cannot set args as body")
|
|
}
|
|
}
|
|
|
|
func TestClientRedirectSameSchema(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
listenHTTPS1 := testClientRedirectListener(t, true)
|
|
defer listenHTTPS1.Close()
|
|
|
|
listenHTTPS2 := testClientRedirectListener(t, true)
|
|
defer listenHTTPS2.Close()
|
|
|
|
sHTTPS1 := testClientRedirectChangingSchemaServer(t, listenHTTPS1, listenHTTPS1, true)
|
|
defer sHTTPS1.Stop()
|
|
|
|
sHTTPS2 := testClientRedirectChangingSchemaServer(t, listenHTTPS2, listenHTTPS2, false)
|
|
defer sHTTPS2.Stop()
|
|
|
|
destURL := fmt.Sprintf("https://%s/baz", listenHTTPS1.Addr().String())
|
|
|
|
urlParsed, err := url.Parse(destURL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
return
|
|
}
|
|
|
|
reqClient := &HostClient{
|
|
IsTLS: true,
|
|
Addr: urlParsed.Host,
|
|
TLSConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
|
|
statusCode, _, err := reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
|
|
if err != nil {
|
|
t.Fatalf("HostClient error: %v", err)
|
|
return
|
|
}
|
|
|
|
if statusCode != 200 {
|
|
t.Fatalf("HostClient error code response %d", statusCode)
|
|
return
|
|
}
|
|
}
|
|
|
|
func TestClientRedirectClientChangingSchemaHttp2Https(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
listenHTTPS := testClientRedirectListener(t, true)
|
|
defer listenHTTPS.Close()
|
|
|
|
listenHTTP := testClientRedirectListener(t, false)
|
|
defer listenHTTP.Close()
|
|
|
|
sHTTPS := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, true)
|
|
defer sHTTPS.Stop()
|
|
|
|
sHTTP := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, false)
|
|
defer sHTTP.Stop()
|
|
|
|
destURL := fmt.Sprintf("http://%s/baz", listenHTTP.Addr().String())
|
|
|
|
reqClient := &Client{
|
|
TLSConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
|
|
statusCode, _, err := reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
|
|
if err != nil {
|
|
t.Fatalf("HostClient error: %v", err)
|
|
return
|
|
}
|
|
|
|
if statusCode != 200 {
|
|
t.Fatalf("HostClient error code response %d", statusCode)
|
|
return
|
|
}
|
|
}
|
|
|
|
func TestClientRedirectHostClientChangingSchemaHttp2Https(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
listenHTTPS := testClientRedirectListener(t, true)
|
|
defer listenHTTPS.Close()
|
|
|
|
listenHTTP := testClientRedirectListener(t, false)
|
|
defer listenHTTP.Close()
|
|
|
|
sHTTPS := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, true)
|
|
defer sHTTPS.Stop()
|
|
|
|
sHTTP := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, false)
|
|
defer sHTTP.Stop()
|
|
|
|
destURL := fmt.Sprintf("http://%s/baz", listenHTTP.Addr().String())
|
|
|
|
urlParsed, err := url.Parse(destURL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
return
|
|
}
|
|
|
|
reqClient := &HostClient{
|
|
Addr: urlParsed.Host,
|
|
TLSConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
|
|
_, _, err = reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
|
|
if err != ErrHostClientRedirectToDifferentScheme {
|
|
t.Fatal("expected HostClient error")
|
|
}
|
|
}
|
|
|
|
func testClientRedirectListener(t *testing.T, isTLS bool) net.Listener {
|
|
var ln net.Listener
|
|
var err error
|
|
var tlsConfig *tls.Config
|
|
|
|
if isTLS {
|
|
certData, keyData, kerr := GenerateTestCertificate("localhost")
|
|
if kerr != nil {
|
|
t.Fatal(kerr)
|
|
}
|
|
|
|
cert, kerr := tls.X509KeyPair(certData, keyData)
|
|
if kerr != nil {
|
|
t.Fatal(kerr)
|
|
}
|
|
|
|
tlsConfig = &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
}
|
|
ln, err = tls.Listen("tcp", "localhost:0", tlsConfig)
|
|
} else {
|
|
ln, err = net.Listen("tcp", "localhost:0")
|
|
}
|
|
|
|
if err != nil {
|
|
t.Fatalf("cannot listen isTLS %v: %v", isTLS, err)
|
|
}
|
|
|
|
return ln
|
|
}
|
|
|
|
func testClientRedirectChangingSchemaServer(t *testing.T, https, http net.Listener, isTLS bool) *testEchoServer {
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
if ctx.IsTLS() {
|
|
ctx.SetStatusCode(200)
|
|
} else {
|
|
ctx.Redirect(fmt.Sprintf("https://%s/baz", https.Addr().String()), 301)
|
|
}
|
|
},
|
|
}
|
|
|
|
var ln net.Listener
|
|
if isTLS {
|
|
ln = https
|
|
} else {
|
|
ln = http
|
|
}
|
|
|
|
ch := make(chan struct{})
|
|
go func() {
|
|
err := s.Serve(ln)
|
|
if err != nil {
|
|
t.Errorf("unexpected error returned from Serve(): %v", err)
|
|
}
|
|
close(ch)
|
|
}()
|
|
return &testEchoServer{
|
|
s: s,
|
|
ln: ln,
|
|
ch: ch,
|
|
t: t,
|
|
}
|
|
}
|
|
|
|
func TestClientHeaderCase(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
defer ln.Close()
|
|
|
|
go func() {
|
|
c, err := ln.Accept()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
c.Write([]byte("HTTP/1.1 200 OK\r\n" + //nolint:errcheck
|
|
"content-type: text/plain\r\n" +
|
|
"transfer-encoding: chunked\r\n\r\n" +
|
|
"24\r\nThis is the data in the first chunk \r\n" +
|
|
"1B\r\nand this is the second one \r\n" +
|
|
"0\r\n\r\n",
|
|
))
|
|
}()
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
ReadTimeout: time.Millisecond * 10,
|
|
|
|
// Even without name normalizing we should parse headers correctly.
|
|
DisableHeaderNamesNormalizing: true,
|
|
}
|
|
|
|
code, body, err := c.Get(nil, "http://example.com")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if code != 200 {
|
|
t.Errorf("expected status code 200 got %d", code)
|
|
}
|
|
if string(body) != "This is the data in the first chunk and this is the second one " {
|
|
t.Errorf("wrong body: %q", body)
|
|
}
|
|
}
|
|
|
|
func TestClientReadTimeout(t *testing.T) {
|
|
if runtime.GOOS == "windows" {
|
|
t.SkipNow()
|
|
}
|
|
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
timeout := false
|
|
s := &Server{
|
|
Handler: func(_ *RequestCtx) {
|
|
if timeout {
|
|
time.Sleep(time.Second)
|
|
} else {
|
|
timeout = true
|
|
}
|
|
},
|
|
Logger: &testLogger{}, // Don't print closed pipe errors.
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
|
|
c := &HostClient{
|
|
ReadTimeout: time.Millisecond * 400,
|
|
MaxIdemponentCallAttempts: 1,
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://localhost")
|
|
|
|
// Setting Connection: Close will make the connection be
|
|
// returned to the pool.
|
|
req.SetConnectionClose()
|
|
|
|
if err := c.Do(req, res); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://localhost")
|
|
req.SetConnectionClose()
|
|
|
|
if err := c.Do(req, res); err != ErrTimeout {
|
|
t.Errorf("expected ErrTimeout got %#v", err)
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
// This shouldn't take longer than the timeout times the number of requests it is going to try to do.
|
|
// Give it an extra second just to be sure.
|
|
case <-time.After(c.ReadTimeout*time.Duration(c.MaxIdemponentCallAttempts) + time.Second):
|
|
t.Fatal("Client.ReadTimeout didn't work")
|
|
}
|
|
}
|
|
|
|
func TestClientDefaultUserAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
userAgentSeen := ""
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
userAgentSeen = string(ctx.UserAgent())
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://example.com")
|
|
|
|
err := c.Do(req, res)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if userAgentSeen != defaultUserAgent {
|
|
t.Fatalf("User-Agent defers %q != %q", userAgentSeen, defaultUserAgent)
|
|
}
|
|
}
|
|
|
|
func TestClientSetUserAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
userAgentSeen := ""
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
userAgentSeen = string(ctx.UserAgent())
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
|
|
userAgent := "I'm not fasthttp"
|
|
c := &Client{
|
|
Name: userAgent,
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://example.com")
|
|
|
|
err := c.Do(req, res)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if userAgentSeen != userAgent {
|
|
t.Fatalf("User-Agent defers %q != %q", userAgentSeen, userAgent)
|
|
}
|
|
}
|
|
|
|
func TestClientNoUserAgent(t *testing.T) {
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
userAgentSeen := ""
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
userAgentSeen = string(ctx.UserAgent())
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
|
|
c := &Client{
|
|
NoDefaultUserAgentHeader: true,
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://example.com")
|
|
|
|
err := c.Do(req, res)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if userAgentSeen != "" {
|
|
t.Fatalf("User-Agent wrong %q != %q", userAgentSeen, "")
|
|
}
|
|
}
|
|
|
|
func TestClientDoWithCustomHeaders(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// make sure that the client sends all the request headers and body.
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
uri := "/foo/bar/baz?a=b&cd=12"
|
|
headers := map[string]string{
|
|
"Foo": "bar",
|
|
"Host": "example.com",
|
|
"Content-Type": "asdfsdf",
|
|
"a-b-c-d-f": "",
|
|
}
|
|
body := "request body"
|
|
|
|
ch := make(chan error)
|
|
go func() {
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
ch <- fmt.Errorf("cannot accept client connection: %w", err)
|
|
return
|
|
}
|
|
br := bufio.NewReader(conn)
|
|
|
|
var req Request
|
|
if err = req.Read(br); err != nil {
|
|
ch <- fmt.Errorf("cannot read client request: %w", err)
|
|
return
|
|
}
|
|
if string(req.Header.Method()) != MethodPost {
|
|
ch <- fmt.Errorf("unexpected request method: %q. Expecting %q", req.Header.Method(), MethodPost)
|
|
return
|
|
}
|
|
reqURI := req.RequestURI()
|
|
if string(reqURI) != uri {
|
|
ch <- fmt.Errorf("unexpected request uri: %q. Expecting %q", reqURI, uri)
|
|
return
|
|
}
|
|
for k, v := range headers {
|
|
hv := req.Header.Peek(k)
|
|
if string(hv) != v {
|
|
ch <- fmt.Errorf("unexpected value for header %q: %q. Expecting %q", k, hv, v)
|
|
return
|
|
}
|
|
}
|
|
cl := req.Header.ContentLength()
|
|
if cl != len(body) {
|
|
ch <- fmt.Errorf("unexpected content-length %d. Expecting %d", cl, len(body))
|
|
return
|
|
}
|
|
reqBody := req.Body()
|
|
if string(reqBody) != body {
|
|
ch <- fmt.Errorf("unexpected request body: %q. Expecting %q", reqBody, body)
|
|
return
|
|
}
|
|
|
|
var resp Response
|
|
bw := bufio.NewWriter(conn)
|
|
if err = resp.Write(bw); err != nil {
|
|
ch <- fmt.Errorf("cannot send response: %w", err)
|
|
return
|
|
}
|
|
if err = bw.Flush(); err != nil {
|
|
ch <- fmt.Errorf("cannot flush response: %w", err)
|
|
return
|
|
}
|
|
|
|
ch <- nil
|
|
}()
|
|
|
|
var req Request
|
|
req.Header.SetMethod(MethodPost)
|
|
req.SetRequestURI(uri)
|
|
for k, v := range headers {
|
|
req.Header.Set(k, v)
|
|
}
|
|
req.SetBodyString(body)
|
|
|
|
var resp Response
|
|
|
|
err := c.DoTimeout(&req, &resp, time.Second)
|
|
if err != nil {
|
|
t.Fatalf("error when doing request: %v", err)
|
|
}
|
|
|
|
select {
|
|
case <-ch:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
func TestPipelineClientDoSerial(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientDoConcurrent(t, 1, 0, 0)
|
|
}
|
|
|
|
func TestPipelineClientDoConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientDoConcurrent(t, 10, 0, 1)
|
|
}
|
|
|
|
func TestPipelineClientDoBatchDelayConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientDoConcurrent(t, 10, 5*time.Millisecond, 1)
|
|
}
|
|
|
|
func TestPipelineClientDoBatchDelayConcurrentMultiConn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientDoConcurrent(t, 10, 5*time.Millisecond, 3)
|
|
}
|
|
|
|
func testPipelineClientDoConcurrent(t *testing.T, concurrency int, maxBatchDelay time.Duration, maxConns int) {
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ctx.WriteString("OK") //nolint:errcheck
|
|
},
|
|
}
|
|
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &PipelineClient{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
MaxConns: maxConns,
|
|
MaxPendingRequests: concurrency,
|
|
MaxBatchDelay: maxBatchDelay,
|
|
Logger: &testLogger{},
|
|
}
|
|
|
|
clientStopCh := make(chan struct{}, concurrency)
|
|
for range concurrency {
|
|
go func() {
|
|
testPipelineClientDo(t, c)
|
|
clientStopCh <- struct{}{}
|
|
}()
|
|
}
|
|
|
|
for range concurrency {
|
|
select {
|
|
case <-clientStopCh:
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
if c.PendingRequests() != 0 {
|
|
t.Fatalf("unexpected number of pending requests: %d. Expecting zero", c.PendingRequests())
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
func testPipelineClientDo(t *testing.T, c *PipelineClient) {
|
|
var err error
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
resp := AcquireResponse()
|
|
for i := range 10 {
|
|
if i&1 == 0 {
|
|
err = c.DoTimeout(req, resp, time.Second)
|
|
} else {
|
|
err = c.Do(req, resp)
|
|
}
|
|
if err != nil {
|
|
if err == ErrPipelineOverflow {
|
|
time.Sleep(10 * time.Millisecond)
|
|
continue
|
|
}
|
|
t.Errorf("unexpected error on iteration %d: %v", i, err)
|
|
}
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
body := string(resp.Body())
|
|
if body != "OK" {
|
|
t.Errorf("unexpected body: %q. Expecting %q", body, "OK")
|
|
}
|
|
|
|
// sleep for a while, so the connection to the host may expire.
|
|
if i%5 == 0 {
|
|
time.Sleep(30 * time.Millisecond)
|
|
}
|
|
}
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(resp)
|
|
}
|
|
|
|
func TestPipelineClientDoDisableHeaderNamesNormalizing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientDisableHeaderNamesNormalizing(t, 0)
|
|
}
|
|
|
|
func TestPipelineClientDoTimeoutDisableHeaderNamesNormalizing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientDisableHeaderNamesNormalizing(t, time.Second)
|
|
}
|
|
|
|
func testPipelineClientDisableHeaderNamesNormalizing(t *testing.T, timeout time.Duration) {
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ctx.Response.Header.Set("foo-BAR", "baz")
|
|
},
|
|
DisableHeaderNamesNormalizing: true,
|
|
}
|
|
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &PipelineClient{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
DisableHeaderNamesNormalizing: true,
|
|
}
|
|
|
|
var req Request
|
|
req.SetRequestURI("http://aaaai.com/bsdf?sddfsd")
|
|
var resp Response
|
|
for range 5 {
|
|
if timeout > 0 {
|
|
if err := c.DoTimeout(&req, &resp, timeout); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
} else {
|
|
if err := c.Do(&req, &resp); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
hv := resp.Header.Peek("foo-BAR")
|
|
if string(hv) != "baz" {
|
|
t.Fatalf("unexpected header value: %q. Expecting %q", hv, "baz")
|
|
}
|
|
hv = resp.Header.Peek("Foo-Bar")
|
|
if len(hv) > 0 {
|
|
t.Fatalf("unexpected non-empty header value %q", hv)
|
|
}
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
func TestClientDoTimeoutDisableHeaderNamesNormalizing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ctx.Response.Header.Set("foo-BAR", "baz")
|
|
},
|
|
DisableHeaderNamesNormalizing: true,
|
|
}
|
|
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
DisableHeaderNamesNormalizing: true,
|
|
}
|
|
|
|
var req Request
|
|
req.SetRequestURI("http://aaaai.com/bsdf?sddfsd")
|
|
var resp Response
|
|
for range 5 {
|
|
if err := c.DoTimeout(&req, &resp, time.Second); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
hv := resp.Header.Peek("foo-BAR")
|
|
if string(hv) != "baz" {
|
|
t.Fatalf("unexpected header value: %q. Expecting %q", hv, "baz")
|
|
}
|
|
hv = resp.Header.Peek("Foo-Bar")
|
|
if len(hv) > 0 {
|
|
t.Fatalf("unexpected non-empty header value %q", hv)
|
|
}
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
uri := ctx.URI()
|
|
uri.DisablePathNormalizing = true
|
|
ctx.Response.Header.Set("received-uri", string(uri.FullURI()))
|
|
},
|
|
}
|
|
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
DisablePathNormalizing: true,
|
|
}
|
|
|
|
urlWithEncodedPath := "http://example.com/encoded/Y%2BY%2FY%3D/stuff"
|
|
|
|
var req Request
|
|
req.SetRequestURI(urlWithEncodedPath)
|
|
var resp Response
|
|
for range 5 {
|
|
if err := c.DoTimeout(&req, &resp, time.Second); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
hv := resp.Header.Peek("received-uri")
|
|
if string(hv) != urlWithEncodedPath {
|
|
t.Fatalf("request uri was normalized: %q. Expecting %q", hv, urlWithEncodedPath)
|
|
}
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
func TestHostClientPendingRequests(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
const concurrency = 10
|
|
doneCh := make(chan struct{})
|
|
readyCh := make(chan struct{}, concurrency)
|
|
s := &Server{
|
|
Handler: func(_ *RequestCtx) {
|
|
readyCh <- struct{}{}
|
|
<-doneCh
|
|
},
|
|
}
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
pendingRequests := c.PendingRequests()
|
|
if pendingRequests != 0 {
|
|
t.Fatalf("non-zero pendingRequests: %d", pendingRequests)
|
|
}
|
|
|
|
resultCh := make(chan error, concurrency)
|
|
for range concurrency {
|
|
go func() {
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
resp := AcquireResponse()
|
|
|
|
if err := c.DoTimeout(req, resp, 10*time.Second); err != nil {
|
|
resultCh <- fmt.Errorf("unexpected error: %w", err)
|
|
return
|
|
}
|
|
|
|
if resp.StatusCode() != StatusOK {
|
|
resultCh <- fmt.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
return
|
|
}
|
|
resultCh <- nil
|
|
}()
|
|
}
|
|
|
|
// wait while all the requests reach server
|
|
for range concurrency {
|
|
select {
|
|
case <-readyCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
pendingRequests = c.PendingRequests()
|
|
if pendingRequests != concurrency {
|
|
t.Fatalf("unexpected pendingRequests: %d. Expecting %d", pendingRequests, concurrency)
|
|
}
|
|
|
|
// unblock request handlers on the server and wait until all the requests are finished.
|
|
close(doneCh)
|
|
for range concurrency {
|
|
select {
|
|
case err := <-resultCh:
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
pendingRequests = c.PendingRequests()
|
|
if pendingRequests != 0 {
|
|
t.Fatalf("non-zero pendingRequests: %d", pendingRequests)
|
|
}
|
|
|
|
// stop the server
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
func TestHostClientMaxConnsWithDeadline(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
emptyBodyCount uint8
|
|
ln = fasthttputil.NewInmemoryListener()
|
|
timeout = 200 * time.Millisecond
|
|
wg sync.WaitGroup
|
|
)
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
if len(ctx.PostBody()) == 0 {
|
|
emptyBodyCount++
|
|
}
|
|
|
|
ctx.WriteString("foo") //nolint:errcheck
|
|
},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
MaxConns: 1,
|
|
}
|
|
|
|
for range 5 {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
req.Header.SetMethod(MethodPost)
|
|
req.SetBodyString("bar")
|
|
resp := AcquireResponse()
|
|
|
|
for {
|
|
if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil {
|
|
if err == ErrNoFreeConns {
|
|
time.Sleep(time.Millisecond)
|
|
continue
|
|
}
|
|
t.Errorf("unexpected error: %v", err)
|
|
return
|
|
}
|
|
break
|
|
}
|
|
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
|
|
body := resp.Body()
|
|
if string(body) != "foo" {
|
|
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
|
|
}
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if emptyBodyCount > 0 {
|
|
t.Fatalf("at least one request body was empty")
|
|
}
|
|
}
|
|
|
|
func TestHostClientMaxConnDuration(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
var connectionCloseCount atomic.Uint32
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ctx.WriteString("abcd") //nolint:errcheck
|
|
if ctx.Request.ConnectionClose() {
|
|
connectionCloseCount.Add(1)
|
|
}
|
|
},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
MaxConnDuration: 10 * time.Millisecond,
|
|
}
|
|
|
|
for range 5 {
|
|
statusCode, body, err := c.Get(nil, "http://aaaa.com/bbb/cc")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
|
|
}
|
|
if string(body) != "abcd" {
|
|
t.Fatalf("unexpected body %q. Expecting %q", body, "abcd")
|
|
}
|
|
time.Sleep(c.MaxConnDuration)
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if connectionCloseCount.Load() == 0 {
|
|
t.Fatalf("expecting at least one 'Connection: close' request header")
|
|
}
|
|
}
|
|
|
|
func TestHostClientMultipleAddrs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ctx.Write(ctx.Host()) //nolint:errcheck
|
|
ctx.SetConnectionClose()
|
|
},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
dialsCount := make(map[string]int)
|
|
c := &HostClient{
|
|
Addr: "foo,bar,baz",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
dialsCount[addr]++
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
for range 9 {
|
|
statusCode, body, err := c.Get(nil, "http://foobar/baz/aaa?bbb=ddd")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
|
|
}
|
|
if string(body) != "foobar" {
|
|
t.Fatalf("unexpected body %q. Expecting %q", body, "foobar")
|
|
}
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if len(dialsCount) != 3 {
|
|
t.Fatalf("unexpected dialsCount size %d. Expecting 3", len(dialsCount))
|
|
}
|
|
for _, k := range []string{"foo", "bar", "baz"} {
|
|
if dialsCount[k] != 3 {
|
|
t.Fatalf("unexpected dialsCount for %q. Expecting 3", k)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestClientFollowRedirects(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
switch string(ctx.Path()) {
|
|
case "/foo":
|
|
u := ctx.URI()
|
|
u.Update("/xy?z=wer")
|
|
ctx.Redirect(u.String(), StatusFound)
|
|
case "/xy":
|
|
u := ctx.URI()
|
|
u.Update("/bar")
|
|
ctx.Redirect(u.String(), StatusFound)
|
|
case "/abc/*/123":
|
|
u := ctx.URI()
|
|
u.Update("/xyz/*/456")
|
|
ctx.Redirect(u.String(), StatusFound)
|
|
default:
|
|
ctx.Success("text/plain", ctx.Path())
|
|
}
|
|
},
|
|
}
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "xxx",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
for range 10 {
|
|
statusCode, body, err := c.GetTimeout(nil, "http://xxx/foo", time.Second)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code: %d", statusCode)
|
|
}
|
|
if string(body) != "/bar" {
|
|
t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
|
|
}
|
|
}
|
|
|
|
for range 10 {
|
|
statusCode, body, err := c.Get(nil, "http://xxx/aaab/sss")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code: %d", statusCode)
|
|
}
|
|
if string(body) != "/aaab/sss" {
|
|
t.Fatalf("unexpected response %q. Expecting %q", body, "/aaab/sss")
|
|
}
|
|
}
|
|
|
|
for range 10 {
|
|
req := AcquireRequest()
|
|
resp := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://xxx/foo")
|
|
|
|
err := c.DoRedirects(req, resp, 16)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if statusCode := resp.StatusCode(); statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code: %d", statusCode)
|
|
}
|
|
|
|
if body := string(resp.Body()); body != "/bar" {
|
|
t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(resp)
|
|
}
|
|
|
|
for range 10 {
|
|
req := AcquireRequest()
|
|
resp := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://xxx/foo")
|
|
|
|
req.SetTimeout(time.Second)
|
|
err := c.DoRedirects(req, resp, 16)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if statusCode := resp.StatusCode(); statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code: %d", statusCode)
|
|
}
|
|
|
|
if body := string(resp.Body()); body != "/bar" {
|
|
t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(resp)
|
|
}
|
|
|
|
for range 10 {
|
|
req := AcquireRequest()
|
|
resp := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://xxx/foo")
|
|
|
|
testConn, _ := net.Dial("tcp", ln.Addr().String())
|
|
timeoutConn := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
|
|
},
|
|
}
|
|
|
|
req.SetTimeout(time.Millisecond)
|
|
err := timeoutConn.DoRedirects(req, resp, 16)
|
|
if err == nil {
|
|
t.Errorf("expecting error")
|
|
}
|
|
if err != ErrTimeout {
|
|
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(resp)
|
|
}
|
|
|
|
for range 10 {
|
|
req := AcquireRequest()
|
|
resp := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://xxx/abc/*/123")
|
|
req.URI().DisablePathNormalizing = true
|
|
req.DisableRedirectPathNormalizing = true
|
|
|
|
err := c.DoRedirects(req, resp, 16)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if statusCode := resp.StatusCode(); statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code: %d", statusCode)
|
|
}
|
|
|
|
if body := string(resp.Body()); body != "/xyz/*/456" {
|
|
t.Fatalf("unexpected response %q. Expecting %q", body, "/xyz/*/456")
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(resp)
|
|
}
|
|
|
|
req := AcquireRequest()
|
|
resp := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://xxx/foo")
|
|
|
|
err := c.DoRedirects(req, resp, 0)
|
|
if have, want := err, ErrTooManyRedirects; have != want {
|
|
t.Fatalf("want error: %v, have %v", want, have)
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(resp)
|
|
}
|
|
|
|
func TestShouldStripSensitiveHeadersOnRedirect(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
initialURL string
|
|
redirectURL string
|
|
want bool
|
|
}{
|
|
{
|
|
name: "same host keeps headers",
|
|
initialURL: "http://example.com/foo",
|
|
redirectURL: "http://example.com/bar",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "subdomain keeps headers",
|
|
initialURL: "http://example.com/foo",
|
|
redirectURL: "https://sub.example.com:8443/bar",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "same host different port keeps headers",
|
|
initialURL: "http://example.com/foo",
|
|
redirectURL: "http://example.com:8080/bar",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "http upgrade keeps headers",
|
|
initialURL: "http://example.com/foo",
|
|
redirectURL: "https://example.com/bar",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "https downgrade keeps headers",
|
|
initialURL: "https://example.com/foo",
|
|
redirectURL: "http://example.com/bar",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "parent domain strips when initial host is subdomain",
|
|
initialURL: "http://sub.example.com/foo",
|
|
redirectURL: "http://example.com/bar",
|
|
want: true,
|
|
},
|
|
{
|
|
name: "unrelated host strips headers",
|
|
initialURL: "http://example.com/foo",
|
|
redirectURL: "http://example.net/bar",
|
|
want: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
initialHost := hostnameFromURLString(tc.initialURL)
|
|
|
|
var redirectURI URI
|
|
redirectURI.Update(tc.redirectURL)
|
|
|
|
if got := shouldStripSensitiveHeadersOnRedirect(initialHost, redirectURI.Host()); got != tc.want {
|
|
t.Fatalf("unexpected redirect stripping decision: got %v, want %v", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestClientGetTimeoutSuccess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testClientGetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
}
|
|
|
|
func TestClientGetTimeoutSuccessConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
var wg sync.WaitGroup
|
|
for range 10 {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientGetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestClientDoTimeoutSuccess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testClientDoTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
testClientRequestSetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
}
|
|
|
|
func TestClientDoTimeoutSuccessConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
var wg sync.WaitGroup
|
|
for range 10 {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientDoTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
testClientRequestSetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestClientGetTimeoutError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testConn, _ := net.Dial("tcp", s.ln.Addr().String())
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
|
|
},
|
|
}
|
|
|
|
testClientGetTimeoutError(t, c, 100)
|
|
}
|
|
|
|
func TestClientGetTimeoutErrorConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testConn, _ := net.Dial("tcp", s.ln.Addr().String())
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
|
|
},
|
|
MaxConnsPerHost: 1000,
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for range 10 {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientGetTimeoutError(t, c, 100)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestClientDoTimeoutError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testConn, _ := net.Dial("tcp", s.ln.Addr().String())
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
|
|
},
|
|
}
|
|
|
|
testClientDoTimeoutError(t, c, 100)
|
|
testClientRequestSetTimeoutError(t, c, 100)
|
|
}
|
|
|
|
func TestClientDoTimeoutErrorConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testConn, _ := net.Dial("tcp", s.ln.Addr().String())
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
|
|
},
|
|
MaxConnsPerHost: 1000,
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for range 10 {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientDoTimeoutError(t, c, 100)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func testClientDoTimeoutError(t *testing.T, c *Client, n int) {
|
|
var req Request
|
|
var resp Response
|
|
req.SetRequestURI("http://foobar.com/baz")
|
|
for range n {
|
|
err := c.DoTimeout(&req, &resp, time.Millisecond)
|
|
if err == nil {
|
|
t.Errorf("expecting error")
|
|
}
|
|
if err != ErrTimeout {
|
|
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testClientGetTimeoutError(t *testing.T, c *Client, n int) {
|
|
buf := make([]byte, 10)
|
|
for range n {
|
|
statusCode, _, err := c.GetTimeout(buf, "http://foobar.com/baz", time.Millisecond)
|
|
if err == nil {
|
|
t.Errorf("expecting error")
|
|
}
|
|
if err != ErrTimeout {
|
|
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
|
|
}
|
|
if statusCode != 0 {
|
|
t.Errorf("unexpected statusCode=%d. Expecting %d", statusCode, 0)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testClientRequestSetTimeoutError(t *testing.T, c *Client, n int) {
|
|
var req Request
|
|
var resp Response
|
|
req.SetRequestURI("http://foobar.com/baz")
|
|
for range n {
|
|
req.SetTimeout(time.Millisecond)
|
|
err := c.Do(&req, &resp)
|
|
if err == nil {
|
|
t.Errorf("expecting error")
|
|
}
|
|
if err != ErrTimeout {
|
|
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
|
|
}
|
|
}
|
|
}
|
|
|
|
// readTimeoutConn is a fake net.Conn for timeout tests. Read and Write never
// move bytes. Each blocks until the deadline most recently armed through
// SetReadDeadline/SetWriteDeadline has elapsed, then fails with
// os.ErrDeadlineExceeded. Read/Write block forever if no deadline was armed.
// Close, LocalAddr and RemoteAddr are stubs. Any other net.Conn method falls
// through to the embedded Conn.
type readTimeoutConn struct {
	net.Conn

	wc chan struct{} // receives a token once the write deadline elapses
	rc chan struct{} // receives a token once the read deadline elapses
	t  time.Duration // NOTE(review): not read by any method defined here
}

// Read waits for the read-deadline token and then reports a timeout.
func (c *readTimeoutConn) Read(_ []byte) (int, error) {
	<-c.rc
	return 0, os.ErrDeadlineExceeded
}

// Write waits for the write-deadline token and then reports a timeout.
func (c *readTimeoutConn) Write(_ []byte) (int, error) {
	<-c.wc
	return 0, os.ErrDeadlineExceeded
}

// Close is a no-op; it does not close the embedded Conn.
func (c *readTimeoutConn) Close() error { return nil }

// LocalAddr always returns nil.
func (c *readTimeoutConn) LocalAddr() net.Addr { return nil }

// RemoteAddr always returns nil.
func (c *readTimeoutConn) RemoteAddr() net.Addr { return nil }

// SetReadDeadline installs a fresh read channel and arms a goroutine that
// drops a token into c.rc once d has passed. The goroutine looks up c.rc when
// it fires, so it targets whichever channel is current at that moment.
func (c *readTimeoutConn) SetReadDeadline(d time.Time) error {
	c.rc = make(chan struct{}, 1)
	go func() {
		time.Sleep(time.Until(d))
		c.rc <- struct{}{}
	}()
	return nil
}

// SetWriteDeadline is the write-side twin of SetReadDeadline, using c.wc.
func (c *readTimeoutConn) SetWriteDeadline(d time.Time) error {
	c.wc = make(chan struct{}, 1)
	go func() {
		time.Sleep(time.Until(d))
		c.wc <- struct{}{}
	}()
	return nil
}
|
|
|
|
func TestClientNonIdempotentRetry_BodyStream(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dialsCount := 0
|
|
c := &Client{
|
|
Dial: func(_ string) (net.Conn, error) {
|
|
dialsCount++
|
|
switch dialsCount {
|
|
case 1, 2:
|
|
return &readErrorConn{}, nil
|
|
case 3:
|
|
return &singleEchoConn{
|
|
b: []byte("HTTP/1.1 345 OK\r\nContent-Type: foobar\r\n\r\n"),
|
|
}, nil
|
|
default:
|
|
return nil, fmt.Errorf("unexpected number of dials: %d", dialsCount)
|
|
}
|
|
},
|
|
}
|
|
|
|
dialsCount = 0
|
|
|
|
req := Request{}
|
|
res := Response{}
|
|
|
|
req.SetRequestURI("http://foobar/a/b")
|
|
req.Header.SetMethod("POST")
|
|
body := bytes.NewBufferString("test")
|
|
req.SetBodyStream(body, body.Len())
|
|
|
|
err := c.Do(&req, &res)
|
|
if err == nil {
|
|
t.Fatal("expected error from being unable to retry a bodyStream")
|
|
}
|
|
}
|
|
|
|
func TestClientIdempotentRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dialsCount := 0
|
|
c := &Client{
|
|
Dial: func(_ string) (net.Conn, error) {
|
|
dialsCount++
|
|
switch dialsCount {
|
|
case 1:
|
|
return &singleReadConn{
|
|
s: "invalid response",
|
|
}, nil
|
|
case 2:
|
|
return &writeErrorConn{}, nil
|
|
case 3:
|
|
return &readErrorConn{}, nil
|
|
case 4:
|
|
return &singleReadConn{
|
|
s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456",
|
|
}, nil
|
|
default:
|
|
return nil, fmt.Errorf("unexpected number of dials: %d", dialsCount)
|
|
}
|
|
},
|
|
}
|
|
|
|
// idempotent GET must succeed.
|
|
statusCode, body, err := c.Get(nil, "http://foobar/a/b")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != 345 {
|
|
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
|
|
}
|
|
if string(body) != "0123456" {
|
|
t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456")
|
|
}
|
|
|
|
var args Args
|
|
|
|
// non-idempotent POST must fail on incorrect singleReadConn
|
|
dialsCount = 0
|
|
_, _, err = c.Post(nil, "http://foobar/a/b", &args)
|
|
if err == nil {
|
|
t.Fatalf("expecting error")
|
|
}
|
|
|
|
// non-idempotent POST must fail on incorrect singleReadConn
|
|
dialsCount = 0
|
|
_, _, err = c.Post(nil, "http://foobar/a/b", nil)
|
|
if err == nil {
|
|
t.Fatalf("expecting error")
|
|
}
|
|
}
|
|
|
|
func TestClientRetryRequestWithCustomDecider(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dialsCount := 0
|
|
c := &Client{
|
|
Dial: func(_ string) (net.Conn, error) {
|
|
dialsCount++
|
|
switch dialsCount {
|
|
case 1:
|
|
return &singleReadConn{
|
|
s: "invalid response",
|
|
}, nil
|
|
case 2:
|
|
return &writeErrorConn{}, nil
|
|
case 3:
|
|
return &readErrorConn{}, nil
|
|
case 4:
|
|
return &singleReadConn{
|
|
s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456",
|
|
}, nil
|
|
default:
|
|
return nil, fmt.Errorf("unexpected number of dials: %d", dialsCount)
|
|
}
|
|
},
|
|
RetryIf: func(req *Request) bool {
|
|
return req.URI().String() == "http://foobar/a/b"
|
|
},
|
|
}
|
|
|
|
var args Args
|
|
|
|
// Post must succeed for http://foobar/a/b uri.
|
|
statusCode, body, err := c.Post(nil, "http://foobar/a/b", &args)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != 345 {
|
|
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
|
|
}
|
|
if string(body) != "0123456" {
|
|
t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456")
|
|
}
|
|
|
|
// POST must fail for http://foobar/a/b/c uri.
|
|
dialsCount = 0
|
|
_, _, err = c.Post(nil, "http://foobar/a/b/c", &args)
|
|
if err == nil {
|
|
t.Fatalf("expecting error")
|
|
}
|
|
}
|
|
|
|
type TransportDemo struct {
|
|
br *bufio.Reader
|
|
bw *bufio.Writer
|
|
}
|
|
|
|
func (t TransportDemo) RoundTrip(hc *HostClient, req *Request, res *Response) (retry bool, err error) {
|
|
if err = req.Write(t.bw); err != nil {
|
|
return false, err
|
|
}
|
|
if err = t.bw.Flush(); err != nil {
|
|
return false, err
|
|
}
|
|
err = res.Read(t.br)
|
|
return err != nil, err
|
|
}
|
|
|
|
func TestHostClientTransport(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ctx.WriteString("abcd") //nolint:errcheck
|
|
},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Transport: func() RoundTripper {
|
|
c, _ := ln.Dial()
|
|
|
|
br := bufio.NewReader(c)
|
|
bw := bufio.NewWriter(c)
|
|
|
|
return TransportDemo{br: br, bw: bw}
|
|
}(),
|
|
}
|
|
|
|
for range 5 {
|
|
statusCode, body, err := c.Get(nil, "http://aaaa.com/bbb/cc")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
|
|
}
|
|
if string(body) != "abcd" {
|
|
t.Fatalf("unexpected body %q. Expecting %q", body, "abcd")
|
|
}
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
// writeErrorConn is a fake net.Conn whose writes always fail. Write reports
// one byte written together with an error. Close and the deadline setters
// are no-ops, and the address getters return nil. Methods not defined here
// (such as Read) go to the embedded net.Conn.
type writeErrorConn struct {
	net.Conn
}

// Write claims a one-byte partial write and fails.
func (c *writeErrorConn) Write(_ []byte) (int, error) {
	const written = 1
	return written, errors.New("error")
}

// Close is a no-op.
func (c *writeErrorConn) Close() error { return nil }

// LocalAddr always returns nil.
func (c *writeErrorConn) LocalAddr() net.Addr { return nil }

// RemoteAddr always returns nil.
func (c *writeErrorConn) RemoteAddr() net.Addr { return nil }

// SetReadDeadline ignores the deadline.
func (c *writeErrorConn) SetReadDeadline(_ time.Time) error { return nil }

// SetWriteDeadline ignores the deadline.
func (c *writeErrorConn) SetWriteDeadline(_ time.Time) error { return nil }
|
|
|
|
// readErrorConn is a fake net.Conn that accepts every write in full but
// fails every read. Close and the deadline setters are no-ops, and the
// address getters return nil.
type readErrorConn struct {
	net.Conn
}

// Read always fails without producing data.
func (c *readErrorConn) Read(_ []byte) (int, error) {
	return 0, errors.New("error")
}

// Write pretends the whole buffer was written.
func (c *readErrorConn) Write(p []byte) (int, error) {
	n := len(p)
	return n, nil
}

// Close is a no-op.
func (c *readErrorConn) Close() error { return nil }

// LocalAddr always returns nil.
func (c *readErrorConn) LocalAddr() net.Addr { return nil }

// RemoteAddr always returns nil.
func (c *readErrorConn) RemoteAddr() net.Addr { return nil }

// SetReadDeadline ignores the deadline.
func (c *readErrorConn) SetReadDeadline(_ time.Time) error { return nil }

// SetWriteDeadline ignores the deadline.
func (c *readErrorConn) SetWriteDeadline(_ time.Time) error { return nil }
|
|
|
|
// singleReadConn is a fake net.Conn that serves the canned string s to its
// readers exactly once, then reports io.EOF. Writes are swallowed; Close and
// the deadline setters are no-ops, and the address getters return nil.
type singleReadConn struct {
	net.Conn

	s string // canned data to serve
	n int    // number of bytes of s already served
}

// Read copies as much of the unread remainder of s into p as fits.
func (c *singleReadConn) Read(p []byte) (int, error) {
	rest := c.s[c.n:]
	if len(rest) == 0 {
		return 0, io.EOF
	}
	copied := copy(p, rest)
	c.n += copied
	return copied, nil
}

// Write discards p and reports it as fully written.
func (c *singleReadConn) Write(p []byte) (int, error) {
	return len(p), nil
}

// Close is a no-op.
func (c *singleReadConn) Close() error { return nil }

// LocalAddr always returns nil.
func (c *singleReadConn) LocalAddr() net.Addr { return nil }

// RemoteAddr always returns nil.
func (c *singleReadConn) RemoteAddr() net.Addr { return nil }

// SetReadDeadline ignores the deadline.
func (c *singleReadConn) SetReadDeadline(_ time.Time) error { return nil }

// SetWriteDeadline ignores the deadline.
func (c *singleReadConn) SetWriteDeadline(_ time.Time) error { return nil }
|
|
|
|
// singleEchoConn is a fake net.Conn that serves its buffer b to readers and
// appends everything written to it onto b. Readers therefore see any preset
// prefix followed by the writer's own bytes, then io.EOF once b is drained.
// Close and the deadline setters are no-ops, and the address getters return nil.
type singleEchoConn struct {
	net.Conn

	b []byte // preset data plus everything written so far
	n int    // number of bytes of b already served
}

// Read copies as much of the unread remainder of b into p as fits.
func (c *singleEchoConn) Read(p []byte) (int, error) {
	rest := c.b[c.n:]
	if len(rest) == 0 {
		return 0, io.EOF
	}
	copied := copy(p, rest)
	c.n += copied
	return copied, nil
}

// Write appends p to the buffer that Read serves from.
func (c *singleEchoConn) Write(p []byte) (int, error) {
	c.b = append(c.b, p...)
	return len(p), nil
}

// Close is a no-op.
func (c *singleEchoConn) Close() error { return nil }

// LocalAddr always returns nil.
func (c *singleEchoConn) LocalAddr() net.Addr { return nil }

// RemoteAddr always returns nil.
func (c *singleEchoConn) RemoteAddr() net.Addr { return nil }

// SetReadDeadline ignores the deadline.
func (c *singleEchoConn) SetReadDeadline(_ time.Time) error { return nil }

// SetWriteDeadline ignores the deadline.
func (c *singleEchoConn) SetWriteDeadline(_ time.Time) error { return nil }
|
|
|
|
func TestSingleEchoConn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return &singleEchoConn{
|
|
b: []byte("HTTP/1.1 345 OK\r\nContent-Type: foobar\r\n\r\n"),
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
req := Request{}
|
|
res := Response{}
|
|
|
|
req.SetRequestURI("http://foobar/a/b")
|
|
req.Header.SetMethod("POST")
|
|
req.Header.Set("Content-Type", "text/plain")
|
|
body := bytes.NewBufferString("test")
|
|
req.SetBodyStream(body, body.Len())
|
|
|
|
err := c.Do(&req, &res)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if res.StatusCode() != 345 {
|
|
t.Fatalf("unexpected status code: %d. Expecting 345", res.StatusCode())
|
|
}
|
|
expected := "POST /a/b HTTP/1.1\r\nUser-Agent: fasthttp\r\nHost: foobar\r\nContent-Type: text/plain\r\nContent-Length: 4\r\n\r\ntest"
|
|
if string(res.Body()) != expected {
|
|
t.Fatalf("unexpected body: %q. Expecting %q", res.Body(), expected)
|
|
}
|
|
}
|
|
|
|
func TestClientHTTPSInvalidServerName(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
|
|
defer sHTTPS.Stop()
|
|
|
|
var c Client
|
|
|
|
for range 10 {
|
|
_, _, err := c.GetTimeout(nil, "https://"+sHTTPS.Addr(), time.Second)
|
|
if err == nil {
|
|
t.Fatalf("expecting TLS error")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestClientHTTPSConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sHTTP := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer sHTTP.Stop()
|
|
|
|
sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
|
|
defer sHTTPS.Stop()
|
|
|
|
c := &Client{
|
|
TLSConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for i := range 4 {
|
|
wg.Add(1)
|
|
addr := "http://" + sHTTP.Addr()
|
|
if i&1 != 0 {
|
|
addr = "https://" + sHTTPS.Addr()
|
|
}
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientGet(t, c, addr, 20)
|
|
testClientPost(t, c, addr, 10)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestClientManyServers(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
addrs := make([]string, 0, 10)
|
|
for range 10 {
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
addrs = append(addrs, s.Addr())
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for i := range 4 {
|
|
wg.Add(1)
|
|
addr := "http://" + addrs[i]
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientGet(t, &defaultClient, addr, 20)
|
|
testClientPost(t, &defaultClient, addr, 10)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestClientGet(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testClientGet(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
}
|
|
|
|
func TestClientPost(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testClientPost(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
}
|
|
|
|
func TestClientConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
addr := "http://" + s.Addr()
|
|
var wg sync.WaitGroup
|
|
for range 10 {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientGet(t, &defaultClient, addr, 30)
|
|
testClientPost(t, &defaultClient, addr, 10)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
// skipIfNotUnix skips tb on platforms without unix-domain socket support:
// android, nacl, plan9, windows, and iOS (darwin on arm or arm64).
func skipIfNotUnix(tb testing.TB) {
	goos, goarch := runtime.GOOS, runtime.GOARCH
	switch {
	case goos == "android" || goos == "nacl" || goos == "plan9" || goos == "windows":
		tb.Skipf("%s does not support unix sockets", goos)
	case goos == "darwin" && (goarch == "arm" || goarch == "arm64"):
		tb.Skip("iOS does not support unix, unixgram")
	}
}
|
|
|
|
func TestHostClientGet(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
skipIfNotUnix(t)
|
|
addr := "TestHostClientGet.unix"
|
|
s := startEchoServer(t, "unix", addr)
|
|
defer s.Stop()
|
|
c := createEchoClient("unix", addr)
|
|
|
|
testHostClientGet(t, c, 100)
|
|
}
|
|
|
|
func TestHostClientPost(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
skipIfNotUnix(t)
|
|
addr := "./TestHostClientPost.unix"
|
|
s := startEchoServer(t, "unix", addr)
|
|
defer s.Stop()
|
|
c := createEchoClient("unix", addr)
|
|
|
|
testHostClientPost(t, c, 100)
|
|
}
|
|
|
|
func TestHostClientConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
skipIfNotUnix(t)
|
|
addr := "./TestHostClientConcurrent.unix"
|
|
s := startEchoServer(t, "unix", addr)
|
|
defer s.Stop()
|
|
c := createEchoClient("unix", addr)
|
|
|
|
var wg sync.WaitGroup
|
|
for range 10 {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
testHostClientGet(t, c, 30)
|
|
testHostClientPost(t, c, 10)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func testClientGet(t *testing.T, c clientGetter, addr string, n int) {
|
|
var buf []byte
|
|
for i := range n {
|
|
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
|
|
statusCode, body, err := c.Get(buf, uri)
|
|
buf = body
|
|
if err != nil {
|
|
t.Errorf("unexpected error when doing http request: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Errorf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
|
|
}
|
|
resultURI := string(body)
|
|
if resultURI != uri {
|
|
t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testClientDoTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
|
|
var req Request
|
|
var resp Response
|
|
|
|
for i := range n {
|
|
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
|
|
req.SetRequestURI(uri)
|
|
if err := c.DoTimeout(&req, &resp, time.Second); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
resultURI := string(resp.Body())
|
|
if strings.HasPrefix(uri, "https") {
|
|
resultURI = uri[:5] + resultURI[4:]
|
|
}
|
|
if resultURI != uri {
|
|
t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testClientRequestSetTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
|
|
var req Request
|
|
var resp Response
|
|
|
|
for i := range n {
|
|
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
|
|
req.SetRequestURI(uri)
|
|
req.SetTimeout(time.Second)
|
|
if err := c.Do(&req, &resp); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
resultURI := string(resp.Body())
|
|
if strings.HasPrefix(uri, "https") {
|
|
resultURI = uri[:5] + resultURI[4:]
|
|
}
|
|
if resultURI != uri {
|
|
t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testClientGetTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
|
|
var buf []byte
|
|
for i := range n {
|
|
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
|
|
statusCode, body, err := c.GetTimeout(buf, uri, time.Second)
|
|
buf = body
|
|
if err != nil {
|
|
t.Fatalf("unexpected error when doing http request: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Errorf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
|
|
}
|
|
resultURI := string(body)
|
|
if strings.HasPrefix(uri, "https") {
|
|
resultURI = uri[:5] + resultURI[4:]
|
|
}
|
|
if resultURI != uri {
|
|
t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testClientPost(t *testing.T, c clientPoster, addr string, n int) {
|
|
var buf []byte
|
|
var args Args
|
|
for i := range n {
|
|
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
|
|
args.Set("xx", fmt.Sprintf("yy%d", i))
|
|
args.Set("zzz", fmt.Sprintf("qwe_%d", i))
|
|
argsS := args.String()
|
|
statusCode, body, err := c.Post(buf, uri, &args)
|
|
buf = body
|
|
if err != nil {
|
|
t.Errorf("unexpected error when doing http request: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Errorf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
|
|
}
|
|
s := string(body)
|
|
if s != argsS {
|
|
t.Errorf("unexpected response %q. Expecting %q", s, argsS)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testHostClientGet(t *testing.T, c *HostClient, n int) {
|
|
testClientGet(t, c, "http://google.com", n)
|
|
}
|
|
|
|
func testHostClientPost(t *testing.T, c *HostClient, n int) {
|
|
testClientPost(t, c, "http://post-host.com", n)
|
|
}
|
|
|
|
type clientPoster interface {
|
|
Post(dst []byte, uri string, postArgs *Args) (int, []byte, error)
|
|
}
|
|
|
|
type clientGetter interface {
|
|
Get(dst []byte, uri string) (int, []byte, error)
|
|
}
|
|
|
|
func createEchoClient(network, addr string) *HostClient {
|
|
return &HostClient{
|
|
Addr: addr,
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return net.Dial(network, addr)
|
|
},
|
|
}
|
|
}
|
|
|
|
type testEchoServer struct {
|
|
s *Server
|
|
ln net.Listener
|
|
ch chan struct{}
|
|
t *testing.T
|
|
}
|
|
|
|
func (s *testEchoServer) Stop() {
|
|
s.ln.Close()
|
|
select {
|
|
case <-s.ch:
|
|
case <-time.After(time.Second):
|
|
s.t.Fatalf("timeout when waiting for server close")
|
|
}
|
|
}
|
|
|
|
func (s *testEchoServer) Addr() string {
|
|
return s.ln.Addr().String()
|
|
}
|
|
|
|
func startEchoServerTLS(t *testing.T, network, addr string) *testEchoServer {
|
|
return startEchoServerExt(t, network, addr, true)
|
|
}
|
|
|
|
func startEchoServer(t *testing.T, network, addr string) *testEchoServer {
|
|
return startEchoServerExt(t, network, addr, false)
|
|
}
|
|
|
|
func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEchoServer {
|
|
if network == "unix" {
|
|
os.Remove(addr)
|
|
}
|
|
var ln net.Listener
|
|
var err error
|
|
if isTLS {
|
|
certData, keyData, kerr := GenerateTestCertificate("localhost")
|
|
if kerr != nil {
|
|
t.Fatal(kerr)
|
|
}
|
|
|
|
cert, kerr := tls.X509KeyPair(certData, keyData)
|
|
if kerr != nil {
|
|
t.Fatal(kerr)
|
|
}
|
|
|
|
tlsConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
}
|
|
ln, err = tls.Listen(network, addr, tlsConfig)
|
|
} else {
|
|
ln, err = net.Listen(network, addr)
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("cannot listen %q: %v", addr, err)
|
|
}
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
if ctx.IsGet() {
|
|
ctx.Success("text/plain", ctx.URI().FullURI())
|
|
} else if ctx.IsPost() {
|
|
ctx.PostArgs().WriteTo(ctx) //nolint:errcheck
|
|
}
|
|
},
|
|
Logger: &testLogger{}, // Ignore log output.
|
|
}
|
|
ch := make(chan struct{})
|
|
go func() {
|
|
err := s.Serve(ln)
|
|
if err != nil {
|
|
t.Errorf("unexpected error returned from Serve(): %v", err)
|
|
}
|
|
close(ch)
|
|
}()
|
|
return &testEchoServer{
|
|
s: s,
|
|
ln: ln,
|
|
ch: ch,
|
|
t: t,
|
|
}
|
|
}
|
|
|
|
func TestClientTLSHandshakeTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
addr := listener.Addr().String()
|
|
defer listener.Close()
|
|
|
|
complete := make(chan bool)
|
|
defer close(complete)
|
|
|
|
go func() {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
<-complete
|
|
conn.Close()
|
|
}()
|
|
|
|
client := Client{
|
|
WriteTimeout: 100 * time.Millisecond,
|
|
ReadTimeout: 100 * time.Millisecond,
|
|
}
|
|
|
|
_, _, err = client.Get(nil, "https://"+addr)
|
|
if err == nil {
|
|
t.Fatal("tlsClientHandshake completed successfully")
|
|
}
|
|
|
|
if err != ErrTLSHandshakeTimeout {
|
|
t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
|
|
}
|
|
}
|
|
|
|
func TestClientConfigureClientFailed(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
c := &Client{
|
|
ConfigureClient: func(hc *HostClient) error {
|
|
return errors.New("failed to configure")
|
|
},
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return &singleEchoConn{
|
|
b: []byte("HTTP/1.1 345 OK\r\nContent-Type: foobar\r\n\r\n"),
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
req := Request{}
|
|
req.SetRequestURI("http://example.com")
|
|
|
|
err := c.Do(&req, &Response{})
|
|
if err == nil {
|
|
t.Fatal("expected error (failed to configure)")
|
|
}
|
|
|
|
c.ConfigureClient = nil
|
|
err = c.Do(&req, &Response{})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
emptyBodyCount uint8
|
|
ln = fasthttputil.NewInmemoryListener()
|
|
wg sync.WaitGroup
|
|
)
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
if len(ctx.PostBody()) == 0 {
|
|
emptyBodyCount++
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
ctx.WriteString("foo") //nolint:errcheck
|
|
},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
MaxConns: 1,
|
|
MaxConnWaitTimeout: time.Second * 2,
|
|
}
|
|
|
|
for range 5 {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
req.Header.SetMethod(MethodPost)
|
|
req.SetBodyString("bar")
|
|
resp := AcquireResponse()
|
|
|
|
if err := c.Do(req, resp); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
|
|
body := resp.Body()
|
|
if string(body) != "foo" {
|
|
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
|
|
}
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
if c.connsWait.len() > 0 {
|
|
t.Errorf("connsWait has %v items remaining", c.connsWait.len())
|
|
}
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second * 5):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if emptyBodyCount > 0 {
|
|
t.Fatalf("at least one request body was empty")
|
|
}
|
|
}
|
|
|
|
func TestHostClientMaxConnWaitTimeoutError(t *testing.T) {
|
|
var (
|
|
emptyBodyCount uint8
|
|
ln = fasthttputil.NewInmemoryListener()
|
|
wg sync.WaitGroup
|
|
)
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
if len(ctx.PostBody()) == 0 {
|
|
emptyBodyCount++
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
ctx.WriteString("foo") //nolint:errcheck
|
|
},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
MaxConns: 1,
|
|
MaxConnWaitTimeout: 10 * time.Millisecond,
|
|
}
|
|
|
|
var errNoFreeConnsCount atomic.Uint32
|
|
for range 5 {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
req.Header.SetMethod(MethodPost)
|
|
req.SetBodyString("bar")
|
|
resp := AcquireResponse()
|
|
|
|
if err := c.Do(req, resp); err != nil {
|
|
if err != ErrNoFreeConns {
|
|
t.Errorf("unexpected error: %v. Expecting %v", err, ErrNoFreeConns)
|
|
}
|
|
errNoFreeConnsCount.Add(1)
|
|
} else {
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
|
|
body := resp.Body()
|
|
if string(body) != "foo" {
|
|
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
time.Sleep(time.Millisecond * 200)
|
|
|
|
// Prevent a race condition with the conns cleaner that might still be running.
|
|
c.connsLock.Lock()
|
|
defer c.connsLock.Unlock()
|
|
|
|
if c.connsWait.len() > 0 {
|
|
t.Errorf("connsWait has %v items remaining", c.connsWait.len())
|
|
}
|
|
if count := errNoFreeConnsCount.Load(); count == 0 {
|
|
t.Errorf("unexpected errorCount: %d. Expecting > 0", count)
|
|
}
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if emptyBodyCount > 0 {
|
|
t.Fatalf("at least one request body was empty")
|
|
}
|
|
}
|
|
|
|
func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
emptyBodyCount uint8
|
|
ln = fasthttputil.NewInmemoryListener()
|
|
wg sync.WaitGroup
|
|
// make deadline reach earlier than conns wait timeout
|
|
sleep = 100 * time.Millisecond
|
|
timeout = 10 * time.Millisecond
|
|
maxConnWaitTimeout = 50 * time.Millisecond
|
|
)
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
if len(ctx.PostBody()) == 0 {
|
|
emptyBodyCount++
|
|
}
|
|
time.Sleep(sleep)
|
|
ctx.WriteString("foo") //nolint:errcheck
|
|
},
|
|
Logger: &testLogger{}, // Don't print connection closed errors.
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
MaxConns: 1,
|
|
MaxConnWaitTimeout: maxConnWaitTimeout,
|
|
}
|
|
|
|
var errTimeoutCount atomic.Uint32
|
|
for range 5 {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
req.Header.SetMethod(MethodPost)
|
|
req.SetBodyString("bar")
|
|
resp := AcquireResponse()
|
|
|
|
if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil {
|
|
if err != ErrTimeout {
|
|
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
|
|
}
|
|
errTimeoutCount.Add(1)
|
|
} else {
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
|
|
body := resp.Body()
|
|
if string(body) != "foo" {
|
|
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
c.connsLock.Lock()
|
|
for {
|
|
w := c.connsWait.popFront()
|
|
if w == nil {
|
|
break
|
|
}
|
|
w.mu.Lock()
|
|
if w.err != nil && w.err != ErrTimeout {
|
|
t.Errorf("unexpected error: %v. Expecting %v", w.err, ErrTimeout)
|
|
}
|
|
w.mu.Unlock()
|
|
}
|
|
c.connsLock.Unlock()
|
|
if count := errTimeoutCount.Load(); count == 0 {
|
|
t.Errorf("unexpected errTimeoutCount: %d. Expecting > 0", count)
|
|
}
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if emptyBodyCount > 0 {
|
|
t.Fatalf("at least one request body was empty")
|
|
}
|
|
}
|
|
|
|
type TransportEmpty struct{}
|
|
|
|
func (t TransportEmpty) RoundTrip(hc *HostClient, req *Request, res *Response) (retry bool, err error) {
|
|
return false, nil
|
|
}
|
|
|
|
func TestHttpsRequestWithoutParsedURL(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
client := HostClient{
|
|
IsTLS: true,
|
|
Transport: TransportEmpty{},
|
|
}
|
|
|
|
req := &Request{}
|
|
|
|
req.SetRequestURI("https://foo.com/bar")
|
|
|
|
_, err := client.doNonNilReqResp(req, &Response{})
|
|
if err != nil {
|
|
t.Fatal("https requests with IsTLS client must succeed")
|
|
}
|
|
}
|
|
|
|
func TestHostClientErrConnPoolStrategyNotImpl(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
server := &Server{
|
|
Handler: func(ctx *RequestCtx) {},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := server.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
client := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
ConnPoolStrategy: ConnPoolStrategyType(100),
|
|
}
|
|
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
|
|
if err := client.Do(req, AcquireResponse()); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if err := client.Do(req, &Response{}); err != ErrConnPoolStrategyNotImpl {
|
|
t.Errorf("expected ErrConnPoolStrategyNotImpl error, got %v", err)
|
|
}
|
|
if err := client.Do(req, &Response{}); err != ErrConnPoolStrategyNotImpl {
|
|
t.Errorf("expected ErrConnPoolStrategyNotImpl error, got %v", err)
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func Test_AddMissingPort(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
type args struct {
|
|
addr string
|
|
isTLS bool
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
want string
|
|
args args
|
|
}{
|
|
{
|
|
args: args{addr: "127.1", isTLS: false}, // 127.1 is a short form of 127.0.0.1
|
|
want: "127.1:80",
|
|
},
|
|
{
|
|
args: args{addr: "127.0.0.1", isTLS: false},
|
|
want: "127.0.0.1:80",
|
|
},
|
|
{
|
|
args: args{addr: "127.0.0.1", isTLS: true},
|
|
want: "127.0.0.1:443",
|
|
},
|
|
{
|
|
args: args{addr: "[::1]", isTLS: false},
|
|
want: "[::1]:80",
|
|
},
|
|
{
|
|
args: args{addr: "::1", isTLS: false},
|
|
want: "::1", // keep as is
|
|
},
|
|
{
|
|
args: args{addr: "[::1]", isTLS: true},
|
|
want: "[::1]:443",
|
|
},
|
|
{
|
|
args: args{addr: "127.0.0.1:8080", isTLS: false},
|
|
want: "127.0.0.1:8080",
|
|
},
|
|
{
|
|
args: args{addr: "127.0.0.1:8443", isTLS: true},
|
|
want: "127.0.0.1:8443",
|
|
},
|
|
{
|
|
args: args{addr: "[::1]:8080", isTLS: false},
|
|
want: "[::1]:8080",
|
|
},
|
|
{
|
|
args: args{addr: "[::1]:8443", isTLS: true},
|
|
want: "[::1]:8443",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.want, func(t *testing.T) {
|
|
if got := AddMissingPort(tt.args.addr, tt.args.isTLS); got != tt.want {
|
|
t.Errorf("AddMissingPort() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
type TransportWrapper struct {
|
|
base RoundTripper
|
|
count *int
|
|
t *testing.T
|
|
}
|
|
|
|
func (tw *TransportWrapper) RoundTrip(hc *HostClient, req *Request, resp *Response) (bool, error) {
|
|
req.Header.Set("trace-id", "123")
|
|
tw.assertRequestLog(req.String())
|
|
retry, err := tw.transport().RoundTrip(hc, req, resp)
|
|
resp.Header.Set("trace-id", "124")
|
|
tw.assertResponseLog(resp.String())
|
|
*tw.count++
|
|
return retry, err
|
|
}
|
|
|
|
func (tw *TransportWrapper) transport() RoundTripper {
|
|
if tw.base == nil {
|
|
return DefaultTransport
|
|
}
|
|
return tw.base
|
|
}
|
|
|
|
func (tw *TransportWrapper) assertRequestLog(reqLog string) {
|
|
if !strings.Contains(reqLog, "Trace-Id: 123") {
|
|
tw.t.Errorf("request log should contains: %v", "Trace-Id: 123")
|
|
}
|
|
}
|
|
|
|
func (tw *TransportWrapper) assertResponseLog(respLog string) {
|
|
if !strings.Contains(respLog, "Trace-Id: 124") {
|
|
tw.t.Errorf("response log should contains: %v", "Trace-Id: 124")
|
|
}
|
|
}
|
|
|
|
func TestClientTransportEx(t *testing.T) {
|
|
sHTTP := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer sHTTP.Stop()
|
|
|
|
sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
|
|
defer sHTTPS.Stop()
|
|
|
|
count := 0
|
|
c := &Client{
|
|
TLSConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
ConfigureClient: func(hc *HostClient) error {
|
|
hc.Transport = &TransportWrapper{base: hc.Transport, count: &count, t: t}
|
|
return nil
|
|
},
|
|
}
|
|
// test transport
|
|
const loopCount = 4
|
|
const getCount = 20
|
|
const postCount = 10
|
|
for i := range loopCount {
|
|
addr := "http://" + sHTTP.Addr()
|
|
if i&1 != 0 {
|
|
addr = "https://" + sHTTPS.Addr()
|
|
}
|
|
// test get
|
|
testClientGet(t, c, addr, getCount)
|
|
// test post
|
|
testClientPost(t, c, addr, postCount)
|
|
}
|
|
roundTripCount := loopCount * (getCount + postCount)
|
|
if count != roundTripCount {
|
|
t.Errorf("round trip count should be: %v", roundTripCount)
|
|
}
|
|
}
|
|
|
|
func Test_getRedirectURL(t *testing.T) {
|
|
type args struct {
|
|
baseURL string
|
|
location []byte
|
|
disablePathNormalizing bool
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
want string
|
|
args args
|
|
}{
|
|
{
|
|
name: "Path normalizing enabled, no special characters in path",
|
|
args: args{
|
|
baseURL: "http://foo.example.com/abc",
|
|
location: []byte("http://bar.example.com/def"),
|
|
disablePathNormalizing: false,
|
|
},
|
|
want: "http://bar.example.com/def",
|
|
},
|
|
{
|
|
name: "Path normalizing enabled, special characters in path",
|
|
args: args{
|
|
baseURL: "http://foo.example.com/abc/*/def",
|
|
location: []byte("http://bar.example.com/123/*/456"),
|
|
disablePathNormalizing: false,
|
|
},
|
|
want: "http://bar.example.com/123/%2A/456",
|
|
},
|
|
{
|
|
name: "Path normalizing disabled, no special characters in path",
|
|
args: args{
|
|
baseURL: "http://foo.example.com/abc",
|
|
location: []byte("http://bar.example.com/def"),
|
|
disablePathNormalizing: true,
|
|
},
|
|
want: "http://bar.example.com/def",
|
|
},
|
|
{
|
|
name: "Path normalizing disabled, special characters in path",
|
|
args: args{
|
|
baseURL: "http://foo.example.com/abc/*/def",
|
|
location: []byte("http://bar.example.com/123/*/456"),
|
|
disablePathNormalizing: true,
|
|
},
|
|
want: "http://bar.example.com/123/*/456",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
redirectURI := AcquireURI()
|
|
got := getRedirectURL(tt.args.baseURL, tt.args.location, tt.args.disablePathNormalizing, redirectURI)
|
|
ReleaseURI(redirectURI)
|
|
if got != tt.want {
|
|
t.Errorf("getRedirectURL() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
type clientDoTimeOuter interface {
|
|
DoTimeout(req *Request, resp *Response, timeout time.Duration) error
|
|
}
|
|
|
|
func TestDialTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
client clientDoTimeOuter
|
|
name string
|
|
requestTimeout time.Duration
|
|
shouldFailFast bool
|
|
}{
|
|
{
|
|
name: "Client should fail after a millisecond due to request timeout",
|
|
client: &Client{
|
|
// should be ignored due to DialTimeout
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
time.Sleep(time.Second)
|
|
return nil, errors.New("timeout")
|
|
},
|
|
// should be used
|
|
DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) {
|
|
time.Sleep(timeout)
|
|
return nil, errors.New("timeout")
|
|
},
|
|
},
|
|
requestTimeout: time.Millisecond,
|
|
shouldFailFast: true,
|
|
},
|
|
{
|
|
name: "Client should fail after a second due to no DialTimeout set",
|
|
client: &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
time.Sleep(time.Second)
|
|
return nil, errors.New("timeout")
|
|
},
|
|
},
|
|
requestTimeout: time.Millisecond,
|
|
shouldFailFast: false,
|
|
},
|
|
{
|
|
name: "HostClient should fail after a millisecond due to request timeout",
|
|
client: &HostClient{
|
|
// should be ignored due to DialTimeout
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
time.Sleep(time.Second)
|
|
return nil, errors.New("timeout")
|
|
},
|
|
// should be used
|
|
DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) {
|
|
time.Sleep(timeout)
|
|
return nil, errors.New("timeout")
|
|
},
|
|
},
|
|
requestTimeout: time.Millisecond,
|
|
shouldFailFast: true,
|
|
},
|
|
{
|
|
name: "HostClient should fail after a second due to no DialTimeout set",
|
|
client: &HostClient{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
time.Sleep(time.Second)
|
|
return nil, errors.New("timeout")
|
|
},
|
|
},
|
|
requestTimeout: time.Millisecond,
|
|
shouldFailFast: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
start := time.Now()
|
|
err := tt.client.DoTimeout(&Request{}, &Response{}, tt.requestTimeout)
|
|
if err == nil {
|
|
t.Fatal("expected error (timeout)")
|
|
}
|
|
if tt.shouldFailFast {
|
|
if time.Since(start) > time.Second {
|
|
t.Fatal("expected timeout after a millisecond")
|
|
}
|
|
} else {
|
|
if time.Since(start) < time.Second {
|
|
t.Fatal("expected timeout after a second")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestClientHeadWithBody(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
defer ln.Close()
|
|
|
|
go func() {
|
|
c, err := ln.Accept()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
c.Write([]byte("HTTP/1.1 200 OK\r\n" + //nolint:errcheck
|
|
"content-type: text/plain\r\n" +
|
|
"transfer-encoding: chunked\r\n\r\n" +
|
|
"24\r\nThis is the data in the first chunk \r\n" +
|
|
"1B\r\nand this is the second one \r\n" +
|
|
"0\r\n\r\n",
|
|
))
|
|
}()
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
ReadTimeout: time.Millisecond * 10,
|
|
MaxIdemponentCallAttempts: 1,
|
|
}
|
|
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://127.0.0.1:7070")
|
|
req.Header.SetMethod(MethodHead)
|
|
|
|
resp := AcquireResponse()
|
|
|
|
err := c.Do(req, resp)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// The second request on the same connection is going to give a timeout as it can't find
|
|
// a proper request in what is left on the connection.
|
|
err = c.Do(req, resp)
|
|
if err == nil {
|
|
t.Error("expected timeout error")
|
|
} else if err != ErrTimeout {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
|
|
func TestRevertPull1233(t *testing.T) {
|
|
if runtime.GOOS == "windows" {
|
|
t.SkipNow()
|
|
}
|
|
t.Parallel()
|
|
const expectedStatus = http.StatusTeapot
|
|
ln, err := net.Listen("tcp", "127.0.0.1:8089")
|
|
defer func() { ln.Close() }()
|
|
if err != nil {
|
|
t.Fatal(err.Error())
|
|
}
|
|
|
|
go func() {
|
|
for {
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
if !strings.Contains(err.Error(), "closed") {
|
|
t.Error(err)
|
|
}
|
|
return
|
|
}
|
|
_, err = conn.Write([]byte("HTTP/1.1 418 Teapot\r\n\r\n"))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
err = conn.(*net.TCPConn).SetLinger(0)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
conn.Close()
|
|
}
|
|
}()
|
|
|
|
reqURL := "http://" + ln.Addr().String()
|
|
reqStrBody := "hello 2323 23323 2323 2323 232323 323 2323 2333333 hello 2323 23323 2323 2323 232323 323 2323 2333333 hello 2323 23323 2323 2323 232323 323 2323 2333333 hello 2323 23323 2323 2323 232323 323 2323 2333333 hello 2323 23323 2323 2323 232323 323 2323 2333333"
|
|
req2 := AcquireRequest()
|
|
resp2 := AcquireResponse()
|
|
defer func() {
|
|
ReleaseRequest(req2)
|
|
ReleaseResponse(resp2)
|
|
}()
|
|
req2.SetRequestURI(reqURL)
|
|
req2.SetBodyStream(F{strings.NewReader(reqStrBody)}, -1)
|
|
cli2 := Client{}
|
|
err = cli2.Do(req2, resp2)
|
|
if !errors.Is(err, syscall.EPIPE) && !errors.Is(err, syscall.ECONNRESET) {
|
|
t.Errorf("expected error %v or %v, but got nil", syscall.EPIPE, syscall.ECONNRESET)
|
|
}
|
|
if expectedStatus == resp2.StatusCode() {
|
|
t.Errorf("Not Expected status code %d", resp2.StatusCode())
|
|
}
|
|
}
|
|
|
|
// F wraps a *strings.Reader and trickles its content out slowly. Each Read
// returns at most 10 bytes and is preceded by a 500µs pause.
type F struct {
	*strings.Reader
}

// Read copies up to min(len(p), 10) bytes from the underlying reader into p.
func (f F) Read(p []byte) (n int, err error) {
	const maxChunk = 10
	chunk := p[:min(len(p), maxChunk)]
	// Ensure that subsequent segments can see the RST packet caused by sending previous
	// segments to a closed connection.
	time.Sleep(500 * time.Microsecond)
	return f.Reader.Read(chunk)
}
|
|
|
|
func TestTCPDialerFlushDNSCache(t *testing.T) {
|
|
resolver := &testResolver{
|
|
lookupCountByHost: make(map[string]int),
|
|
resolver: net.DefaultResolver,
|
|
}
|
|
|
|
dialer := &TCPDialer{
|
|
DNSCacheDuration: 30 * time.Minute, // Long cache
|
|
Resolver: resolver,
|
|
}
|
|
|
|
// First dial - should trigger DNS lookup
|
|
conn1, err := dialer.DialTimeout("httpbin.org:80", 5*time.Second)
|
|
if err != nil {
|
|
t.Skip("Dial failed:", err)
|
|
}
|
|
conn1.Close()
|
|
|
|
if resolver.lookupCountByHost["httpbin.org"] != 1 {
|
|
t.Errorf("Expected 1 DNS lookup after first dial, got %d", resolver.lookupCountByHost["httpbin.org"])
|
|
}
|
|
|
|
// Second dial - should use cache (no new DNS lookup)
|
|
conn2, err := dialer.DialTimeout("httpbin.org:80", 5*time.Second)
|
|
if err != nil {
|
|
t.Skip("Second dial failed:", err)
|
|
}
|
|
conn2.Close()
|
|
|
|
if resolver.lookupCountByHost["httpbin.org"] != 1 {
|
|
t.Errorf("Expected 1 DNS lookup after cached dial, got %d", resolver.lookupCountByHost["httpbin.org"])
|
|
}
|
|
|
|
// Flush cache - should clear all entries
|
|
dialer.FlushDNSCache()
|
|
|
|
// Third dial - should trigger new DNS lookup since cache was flushed
|
|
conn3, err := dialer.DialTimeout("httpbin.org:80", 5*time.Second)
|
|
if err != nil {
|
|
t.Skip("Third dial failed:", err)
|
|
}
|
|
conn3.Close()
|
|
|
|
if resolver.lookupCountByHost["httpbin.org"] != 2 {
|
|
t.Errorf("Expected 2 DNS lookups after cache flush, got %d", resolver.lookupCountByHost["httpbin.org"])
|
|
}
|
|
}
|
|
|
|
// testResolver implements the Resolver interface by delegating to a
// *net.Resolver while counting LookupIPAddr calls per host. The counter map
// is not synchronized.
type testResolver struct {
	resolver          *net.Resolver
	lookupCountByHost map[string]int
}

// LookupIPAddr records one lookup for host, then delegates to the wrapped
// resolver.
func (r *testResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) {
	r.lookupCountByHost[host] = r.lookupCountByHost[host] + 1
	inner := r.resolver
	return inner.LookupIPAddr(ctx, host)
}
|
|
|
|
type TransportMock struct {
|
|
wrapperFunc func(hc *HostClient, req *Request, resp *Response) (retry bool, err error)
|
|
}
|
|
|
|
func (t *TransportMock) RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
|
|
return t.wrapperFunc(hc, req, resp)
|
|
}
|
|
|
|
func TestClient_RetryIfErrUpstream(t *testing.T) {
|
|
t.Parallel()
|
|
upstreamErr := errors.New("upstream error")
|
|
|
|
t.Run("upstream_known", func(t *testing.T) {
|
|
retryIfErrCalled := false
|
|
c := &Client{
|
|
Transport: &TransportMock{
|
|
wrapperFunc: func(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
|
|
resp.raddr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}
|
|
return true, upstreamErr
|
|
},
|
|
},
|
|
RetryIfErrUpstream: func(request *Request, attempts int, err error, upstream string) (resetTimeout bool, retry bool) {
|
|
retryIfErrCalled = true
|
|
if upstream != "127.0.0.1:8080" {
|
|
t.Errorf("expected upstream to be 127.0.0.1:8080, got %s", upstream)
|
|
}
|
|
|
|
return false, false
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://example.com")
|
|
|
|
err := c.Do(req, res)
|
|
if !errors.Is(err, upstreamErr) {
|
|
t.Fatal(err)
|
|
}
|
|
if !retryIfErrCalled {
|
|
t.Fatal("RetryIfErrUpstream should be called")
|
|
}
|
|
})
|
|
|
|
t.Run("no_upstream", func(t *testing.T) {
|
|
retryIfErrCalled := false
|
|
c := &Client{
|
|
Transport: &TransportMock{
|
|
wrapperFunc: func(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
|
|
return true, upstreamErr
|
|
},
|
|
},
|
|
RetryIfErrUpstream: func(request *Request, attempts int, err error, upstream string) (resetTimeout bool, retry bool) {
|
|
retryIfErrCalled = true
|
|
if upstream != "" {
|
|
t.Errorf("expected upstream to be empty, got %s", upstream)
|
|
}
|
|
|
|
return false, false
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://example.com")
|
|
|
|
err := c.Do(req, res)
|
|
if !errors.Is(err, upstreamErr) {
|
|
t.Fatal(err)
|
|
}
|
|
if !retryIfErrCalled {
|
|
t.Fatal("RetryIfErrUpstream should be called")
|
|
}
|
|
})
|
|
}
|