diff --git a/fasthttpadaptor/adaptor.go b/fasthttpadaptor/adaptor.go index 7153c6d..48cd0e3 100644 --- a/fasthttpadaptor/adaptor.go +++ b/fasthttpadaptor/adaptor.go @@ -56,45 +56,277 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler { ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError) return } - w := netHTTPResponseWriter{ - w: ctx.Response.BodyWriter(), - ctx: ctx, - } - h.ServeHTTP(&w, r.WithContext(ctx)) - ctx.SetStatusCode(w.StatusCode()) - haveContentType := false - for k, vv := range w.Header() { - if k == fasthttp.HeaderContentType { - haveContentType = true + w := acquireNetHTTPResponseWriter(ctx) + + // Concurrently serve the net/http handler. + go func() { + h.ServeHTTP(w, r.WithContext(ctx)) + select { + case w.modeCh <- modeDone: + default: + } + _ = w.Close() + }() + + mode := <-w.modeCh + + switch mode { + case modeDone: + // No flush occurred before the handler returned. + // Send the data as one chunk. + ctx.SetStatusCode(w.StatusCode()) + haveContentType := false + for k, vv := range w.Header() { + if k == fasthttp.HeaderContentType { + haveContentType = true + } + + for _, v := range vv { + ctx.Response.Header.Add(k, v) + } } - for _, v := range vv { - ctx.Response.Header.Add(k, v) + if !haveContentType { + // From net/http.ResponseWriter.Write: + // If the Header does not contain a Content-Type line, Write adds a Content-Type set + // to the result of passing the initial 512 bytes of written data to DetectContentType. + l := 512 + b := *w.responseBody + if len(b) < 512 { + l = len(b) + } + ctx.Response.Header.Set(fasthttp.HeaderContentType, http.DetectContentType(b[:l])) } - } - if !haveContentType { - // From net/http.ResponseWriter.Write: - // If the Header does not contain a Content-Type line, Write adds a Content-Type set - // to the result of passing the initial 512 bytes of written data to DetectContentType. - l := 512 - b := ctx.Response.Body() - if len(b) < 512 { - l = len(b) + + w.responseMutex.Lock() + if len(*w.responseBody) > 0 { + ctx.Response.SetBody(*w.responseBody) } - ctx.Response.Header.Set(fasthttp.HeaderContentType, http.DetectContentType(b[:l])) + w.responseMutex.Unlock() + + // Release after sending response. + releaseNetHTTPResponseWriter(w) + + case modeFlushed: + // Flush occurred before handler returned. + // Send the first 512 bytes and start streaming + // the rest of the first chunk and new data as it arrives. + ctx.SetStatusCode(w.StatusCode()) + haveContentType := false + for k, vv := range w.Header() { + // Don't copy Content-Length header when + // streaming. + if k == fasthttp.HeaderContentLength { + continue + } + + if k == fasthttp.HeaderContentType { + haveContentType = true + } + + for _, v := range vv { + ctx.Response.Header.Add(k, v) + } + } + + // Lock the current response body until + // it is sent in the StreamWriter function. + w.responseMutex.Lock() + if !haveContentType { + // From net/http.ResponseWriter.Write: + // If the Header does not contain a Content-Type line, Write adds a Content-Type set + // to the result of passing the initial 512 bytes of written data to DetectContentType. + l := 512 + b := *w.responseBody + if len(b) < 512 { + l = len(b) + } + ctx.Response.Header.Set(fasthttp.HeaderContentType, http.DetectContentType(b[:l])) + } + + // Start streaming mode on return. + ctx.SetBodyStreamWriter(func(bw *bufio.Writer) { + // Stream the first chunk. + if len(*w.responseBody) > 0 { + _, _ = bw.Write(*w.responseBody) + _ = bw.Flush() + } + // The current response body is no longer used + // past this point. + w.responseMutex.Unlock() + + // Stream the rest of the data that is read + // from the net/http handler in 32 KiB chunks. + // + // Note: Data must be manually copied in chunks + // as data comes in. + chunk := acquireBuffer() + *chunk = (*chunk)[:minBufferSize] + for { + // Read net/http handler chunk. + n, err := w.r.Read(*chunk) + if err != nil { + // Handler ended due to an io.EOF + // or an error occurred. + // + // Release the response writer for reuse. + releaseBuffer(chunk) + releaseNetHTTPResponseWriter(w) + return + } + + // Copy chunk to fasthttp response + if n > 0 { + _, err = bw.Write((*chunk)[:n]) + if err != nil { + // Handler ended due to an io.ErrPipeClosed + // or an error occurred. + // + // Release the response writer for reuse. + releaseBuffer(chunk) + releaseNetHTTPResponseWriter(w) + return + } + + err = bw.Flush() + if err != nil { + // Handler ended due to an io.ErrPipeClosed + // or an error occurred. + // + // Release the response writer for reuse. + releaseBuffer(chunk) + releaseNetHTTPResponseWriter(w) + return + } + } + } + }) + // Activate streaming mode for consequent `w.Flush()` + // by net/http handler. + w.streamCond.L.Lock() + w.isStreaming = true + w.streamCond.Signal() + w.streamCond.L.Unlock() + + case modeHijacked: + // The net/http handler called w.Hijack(). + // Copy data bidirectionally between the + // net/http and fasthttp connections. + var wg sync.WaitGroup + wg.Add(2) + + // Note: It is safe to assume that net.Conn automatically + // flushes data while copying. + go func() { + defer wg.Done() + _, _ = io.Copy(ctx.Conn(), w.handlerConn) + + // Close the fasthttp connection when + // the net/http connection closes. + _ = ctx.Conn().Close() + }() + go func() { + defer wg.Done() + _, _ = io.Copy(w.handlerConn, ctx.Conn()) + // Note: Only the net/http handler + // should close the connection. + }() + + // Wait for the net/http handler to finish + // writing to the hijacked connection prior to releasing + // the writer into the writer pool. + wg.Wait() + releaseNetHTTPResponseWriter(w) } } } +// Use a minimum buffer size of 32 KiB. +const minBufferSize = 32 * 1024 + +var bufferPool = &sync.Pool{ + New: func() any { + b := make([]byte, minBufferSize) + return &b + }, +} + +var writerPool = &sync.Pool{ + New: func() any { + pr, pw := io.Pipe() + return &netHTTPResponseWriter{ + h: make(http.Header), + r: pr, + w: pw, + modeCh: make(chan ModeType), + responseBody: acquireBuffer(), + streamCond: sync.NewCond(&sync.Mutex{}), + } + }, +} + +type ModeType int + +const ( + modeUnknown ModeType = iota + modeDone + modeFlushed + modeHijacked +) + type netHTTPResponseWriter struct { - w io.Writer - h http.Header - ctx *fasthttp.RequestCtx - statusCode int + handlerConn net.Conn + ctx *fasthttp.RequestCtx + h http.Header + r *io.PipeReader + w *io.PipeWriter + modeCh chan ModeType + responseBody *[]byte + streamCond *sync.Cond + statusCode int + once sync.Once + statusMutex sync.Mutex + responseMutex sync.Mutex + connMutex sync.Mutex + isStreaming bool +} + +func acquireNetHTTPResponseWriter(ctx *fasthttp.RequestCtx) *netHTTPResponseWriter { + w, ok := writerPool.Get().(*netHTTPResponseWriter) + if !ok { + panic("fasthttpadaptor: cannot get *netHTTPResponseWriter from writerPool") + } + w.reset() + + w.ctx = ctx + return w +} + +func releaseNetHTTPResponseWriter(w *netHTTPResponseWriter) { + releaseBuffer(w.responseBody) + w.Close() + writerPool.Put(w) +} + +func acquireBuffer() *[]byte { + buf, ok := bufferPool.Get().(*[]byte) + if !ok { + panic("fasthttpadaptor: cannot get *[]byte from bufferPool") + } + + *buf = (*buf)[:0] + return buf +} + +func releaseBuffer(buf *[]byte) { + bufferPool.Put(buf) } func (w *netHTTPResponseWriter) StatusCode() int { + w.statusMutex.Lock() + defer w.statusMutex.Unlock() + if w.statusCode == 0 { return http.StatusOK } @@ -102,35 +334,46 @@ func (w *netHTTPResponseWriter) StatusCode() int { } func (w *netHTTPResponseWriter) Header() http.Header { - if w.h == nil { - w.h = make(http.Header) - } return w.h } func (w *netHTTPResponseWriter) WriteHeader(statusCode int) { + w.statusMutex.Lock() + defer w.statusMutex.Unlock() + w.statusCode = statusCode } func (w *netHTTPResponseWriter) Write(p []byte) (int, error) { - return w.w.Write(p) + w.streamCond.L.Lock() + defer w.streamCond.L.Unlock() + + if w.isStreaming { + // Streaming mode is on. + // Stream directly to the conn writer. + return w.w.Write(p) + } + + // Streaming mode is off. + // Write to the first chunk for flushing later. + w.responseMutex.Lock() + *w.responseBody = append(*w.responseBody, p...) + w.responseMutex.Unlock() + return len(p), nil } -func (w *netHTTPResponseWriter) Flush() {} - -type wrappedConn struct { - net.Conn - - wg sync.WaitGroup - once sync.Once -} - -func (c *wrappedConn) Close() (err error) { - c.once.Do(func() { - err = c.Conn.Close() - c.wg.Done() +func (w *netHTTPResponseWriter) Flush() { + // Trigger streaming mode setup. + w.once.Do(func() { + w.modeCh <- modeFlushed }) - return + + // Wait for streaming mode. + w.streamCond.L.Lock() + defer w.streamCond.L.Unlock() + for !w.isStreaming { + w.streamCond.Wait() + } } func (w *netHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { @@ -138,22 +381,62 @@ func (w *netHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { // doing anything else with it. w.ctx.HijackSetNoResponse(true) - conn := &wrappedConn{Conn: w.ctx.Conn()} - conn.wg.Add(1) - w.ctx.Hijack(func(net.Conn) { - conn.wg.Wait() + netHTTPConn, fasthttpConn := net.Pipe() + w.handlerConn = fasthttpConn + + // Trigger hijacked mode. + w.once.Do(func() { + w.modeCh <- modeHijacked }) - bufW := bufio.NewWriter(conn) + bufRW := bufio.NewReadWriter(bufio.NewReader(netHTTPConn), bufio.NewWriter(netHTTPConn)) // Write any unflushed body to the hijacked connection buffer. - unflushedBody := w.ctx.Response.Body() - if len(unflushedBody) > 0 { - if _, err := bufW.Write(unflushedBody); err != nil { - conn.Close() - return nil, nil, err - } + w.responseMutex.Lock() + if len(*w.responseBody) > 0 { + _, _ = bufRW.Write(*w.responseBody) + _ = bufRW.Flush() + } + w.responseMutex.Unlock() + return netHTTPConn, bufRW, nil +} + +func (w *netHTTPResponseWriter) Close() error { + _ = w.w.Close() + _ = w.r.Close() + + w.connMutex.Lock() + if w.handlerConn != nil { + _ = w.handlerConn.Close() + } + w.connMutex.Unlock() + return nil +} + +func (w *netHTTPResponseWriter) reset() { + // Note: reset() must only run after a fasthttp handler finishes + // proxying the full net/http handler response to ensure no data races. + w.ctx = nil + w.connMutex.Lock() + w.handlerConn = nil + w.connMutex.Unlock() + w.statusCode = 0 + + // Open new bidirectional pipes + pr, pw := io.Pipe() + w.r = pr + w.w = pw + + // Clear the http Header + for key := range w.h { + delete(w.h, key) } - return conn, &bufio.ReadWriter{Reader: bufio.NewReader(conn), Writer: bufW}, nil + // Get a new buffer for the response body + w.responseBody = acquireBuffer() + + w.once = sync.Once{} + w.streamCond.L.Lock() + w.isStreaming = false + w.streamCond.L.Unlock() } diff --git a/fasthttpadaptor/adaptor_test.go b/fasthttpadaptor/adaptor_test.go index 94ecf9f..e5b6f8c 100644 --- a/fasthttpadaptor/adaptor_test.go +++ b/fasthttpadaptor/adaptor_test.go @@ -1,6 +1,7 @@ package fasthttpadaptor import ( + "bufio" "io" "net" "net/http" @@ -135,7 +136,7 @@ func TestNewFastHTTPHandler(t *testing.T) { t.Fatalf("unexpected response body %q. Expecting %q", resp.Body(), expectedBody) } if string(resp.Header.Peek("Content-Type")) != expectedContentType { - t.Fatalf("unexpected response content-type %q. Expecting %q", string(resp.Header.Peek("Content-Type")), expectedBody) + t.Fatalf("unexpected response content-type %q. Expecting %q", string(resp.Header.Peek("Content-Type")), expectedContentType) } } @@ -263,3 +264,169 @@ func TestHijack(t *testing.T) { t.Fatal("timeout") } } + +func TestFlushHandler(t *testing.T) { + t.Parallel() + + nethttpH := func(w http.ResponseWriter, r *http.Request) { + if f, ok := w.(http.Flusher); !ok { + t.Errorf("expected http.ResponseWriter to implement http.Flusher") + } else { + if _, err := w.Write([]byte("foo")); err != nil { + t.Error(err) + } + + f.Flush() + + time.Sleep(time.Second) + + if _, err := w.Write([]byte("bar")); err != nil { + t.Error(err) + } + + f.Flush() + } + } + + s := &fasthttp.Server{ + Handler: NewFastHTTPHandler(http.HandlerFunc(nethttpH)), + } + + ln := fasthttputil.NewInmemoryListener() + + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %v", err) + } + }() + + clientCh := make(chan struct{}) + go func() { + c, err := ln.Dial() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %v", err) + } + + time.AfterFunc(500*time.Millisecond, func() { + c.Close() + }) + resp, err := http.ReadResponse(bufio.NewReader(c), nil) + if err != nil { + t.Errorf("unexpected error reading response: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK) + } + + if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { + t.Errorf("unexpected Content-Type header: %q. Expecting %q", resp.Header.Get("Content-Type"), "text/plain; charset=utf-8") + } + + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil && err != io.ErrUnexpectedEOF { + t.Errorf("unexpected error reading body: %v", err) + } + + if string(body) != "foo" { + t.Errorf("unexpected response body: %q. Expecting %q", body, "foo") + } + + close(clientCh) + }() + + select { + case <-clientCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } +} + +func TestHijackFlush(t *testing.T) { + t.Parallel() + + nethttpH := func(w http.ResponseWriter, r *http.Request) { + if f, ok := w.(http.Hijacker); !ok { + t.Errorf("expected http.ResponseWriter to implement http.Hijacker") + } else { + if _, err := w.Write([]byte("foo")); err != nil { + t.Error(err) + } + + if c, rw, err := f.Hijack(); err != nil { + t.Error(err) + } else { + if _, err := rw.WriteString("bar"); err != nil { + t.Error(err) + } + + if err := rw.Flush(); err != nil { + t.Error(err) + } + + time.Sleep(time.Second) + + if _, err := rw.WriteString("bazz"); err != nil { + t.Error(err) + } + + if err := rw.Flush(); err != nil { + t.Error(err) + } + + if err := c.Close(); err != nil { + t.Error(err) + } + } + } + } + + s := &fasthttp.Server{ + Handler: NewFastHTTPHandler(http.HandlerFunc(nethttpH)), + } + + ln := fasthttputil.NewInmemoryListener() + + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %v", err) + } + }() + + clientCh := make(chan struct{}) + go func() { + c, err := ln.Dial() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %v", err) + } + + time.AfterFunc(500*time.Millisecond, func() { + c.Close() + }) + buf, err := io.ReadAll(c) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if string(buf) != "foobar" { + t.Errorf("unexpected response: %q. Expecting %q", buf, "foobar") + } + + close(clientCh) + }() + + select { + case <-clientCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } +}