Add flushing support to fasthttpadaptor (#2054)

* Add flushing support to fasthttpadaptor

* refactor(fasthttphandler): Fix comment typos

* refactor(fasthttphandler): Fix early closing of net/http handler

* refactor(fasthttphandler): Apply requested changes

* refactor(fasthttphandler): Reduce memory allocations by using sync.Pool

* refactor(fasthttphandler): Fix improper releaseNetHTTPResponseWriter

* refactor(fasthttphandler): Add buffer sync.Pool with panic assert

* refactor(fasthttphandler): Fix hijacked-related response writer race condition

* refactor(fasthttphandler): Rename bufW to bufRW

* refactor(fasthttphandler): Ensure proper responseMutex use

* refactor(fasthttphandler): Specify minBufferSize to ensure reading 32 KiB chunks in streaming mode

* refactor(fasthttphandler): Fix release logic

* refactor(fasthttphandler): Fix handlerConn race condition

* refactor(fasthttphandler): Explicitly ignore handlerConn close error

* refactor(fasthttphandler): Use sync.Once, sync.Cond, and a single channel for mode management

* refactor(fasthttphandler): Remove commented code

* refactor(fasthttphandler): Add period to respect linter

* refactor(fasthttphandler): Remove return else clauses to respect lint
This commit is contained in:
Giovanni Rivera
2025-09-11 08:36:02 -07:00
committed by GitHub
parent e9640b4d39
commit e04490f830
2 changed files with 507 additions and 57 deletions
+339 -56
View File
@@ -56,45 +56,277 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler {
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
return
}
w := netHTTPResponseWriter{
w: ctx.Response.BodyWriter(),
ctx: ctx,
}
h.ServeHTTP(&w, r.WithContext(ctx))
ctx.SetStatusCode(w.StatusCode())
haveContentType := false
for k, vv := range w.Header() {
if k == fasthttp.HeaderContentType {
haveContentType = true
w := acquireNetHTTPResponseWriter(ctx)
// Concurrently serve the net/http handler.
go func() {
h.ServeHTTP(w, r.WithContext(ctx))
select {
case w.modeCh <- modeDone:
default:
}
_ = w.Close()
}()
mode := <-w.modeCh
switch mode {
case modeDone:
// No flush occurred before the handler returned.
// Send the data as one chunk.
ctx.SetStatusCode(w.StatusCode())
haveContentType := false
for k, vv := range w.Header() {
if k == fasthttp.HeaderContentType {
haveContentType = true
}
for _, v := range vv {
ctx.Response.Header.Add(k, v)
}
}
for _, v := range vv {
ctx.Response.Header.Add(k, v)
if !haveContentType {
// From net/http.ResponseWriter.Write:
// If the Header does not contain a Content-Type line, Write adds a Content-Type set
// to the result of passing the initial 512 bytes of written data to DetectContentType.
l := 512
b := *w.responseBody
if len(b) < 512 {
l = len(b)
}
ctx.Response.Header.Set(fasthttp.HeaderContentType, http.DetectContentType(b[:l]))
}
}
if !haveContentType {
// From net/http.ResponseWriter.Write:
// If the Header does not contain a Content-Type line, Write adds a Content-Type set
// to the result of passing the initial 512 bytes of written data to DetectContentType.
l := 512
b := ctx.Response.Body()
if len(b) < 512 {
l = len(b)
w.responseMutex.Lock()
if len(*w.responseBody) > 0 {
ctx.Response.SetBody(*w.responseBody)
}
ctx.Response.Header.Set(fasthttp.HeaderContentType, http.DetectContentType(b[:l]))
w.responseMutex.Unlock()
// Release after sending response.
releaseNetHTTPResponseWriter(w)
case modeFlushed:
// Flush occurred before handler returned.
// Send the first 512 bytes and start streaming
// the rest of the first chunk and new data as it arrives.
ctx.SetStatusCode(w.StatusCode())
haveContentType := false
for k, vv := range w.Header() {
// Don't copy Content-Length header when
// streaming.
if k == fasthttp.HeaderContentLength {
continue
}
if k == fasthttp.HeaderContentType {
haveContentType = true
}
for _, v := range vv {
ctx.Response.Header.Add(k, v)
}
}
// Lock the current response body until
// it is sent in the StreamWriter function.
w.responseMutex.Lock()
if !haveContentType {
// From net/http.ResponseWriter.Write:
// If the Header does not contain a Content-Type line, Write adds a Content-Type set
// to the result of passing the initial 512 bytes of written data to DetectContentType.
l := 512
b := *w.responseBody
if len(b) < 512 {
l = len(b)
}
ctx.Response.Header.Set(fasthttp.HeaderContentType, http.DetectContentType(b[:l]))
}
// Start streaming mode on return.
ctx.SetBodyStreamWriter(func(bw *bufio.Writer) {
// Stream the first chunk.
if len(*w.responseBody) > 0 {
_, _ = bw.Write(*w.responseBody)
_ = bw.Flush()
}
// The current response body is no longer used
// past this point.
w.responseMutex.Unlock()
// Stream the rest of the data that is read
// from the net/http handler in 32 KiB chunks.
//
// Note: Data must be manually copied in chunks
// as data comes in.
chunk := acquireBuffer()
*chunk = (*chunk)[:minBufferSize]
for {
// Read net/http handler chunk.
n, err := w.r.Read(*chunk)
if err != nil {
// Handler ended due to an io.EOF
// or an error occurred.
//
// Release the response writer for reuse.
releaseBuffer(chunk)
releaseNetHTTPResponseWriter(w)
return
}
// Copy chunk to fasthttp response
if n > 0 {
_, err = bw.Write((*chunk)[:n])
if err != nil {
// Handler ended due to an io.ErrPipeClosed
// or an error occurred.
//
// Release the response writer for reuse.
releaseBuffer(chunk)
releaseNetHTTPResponseWriter(w)
return
}
err = bw.Flush()
if err != nil {
// Handler ended due to an io.ErrPipeClosed
// or an error occurred.
//
// Release the response writer for reuse.
releaseBuffer(chunk)
releaseNetHTTPResponseWriter(w)
return
}
}
}
})
// Activate streaming mode for consequent `w.Flush()`
// by net/http handler.
w.streamCond.L.Lock()
w.isStreaming = true
w.streamCond.Signal()
w.streamCond.L.Unlock()
case modeHijacked:
// The net/http handler called w.Hijack().
// Copy data bidirectionally between the
// net/http and fasthttp connections.
var wg sync.WaitGroup
wg.Add(2)
// Note: It is safe to assume that net.Conn automatically
// flushes data while copying.
go func() {
defer wg.Done()
_, _ = io.Copy(ctx.Conn(), w.handlerConn)
// Close the fasthttp connection when
// the net/http connection closes.
_ = ctx.Conn().Close()
}()
go func() {
defer wg.Done()
_, _ = io.Copy(w.handlerConn, ctx.Conn())
// Note: Only the net/http handler
// should close the connection.
}()
// Wait for the net/http handler to finish
// writing to the hijacked connection prior to releasing
// the writer into the writer pool.
wg.Wait()
releaseNetHTTPResponseWriter(w)
}
}
}
// Use a minimum buffer size of 32 KiB.
const minBufferSize = 32 * 1024
var bufferPool = &sync.Pool{
New: func() any {
b := make([]byte, minBufferSize)
return &b
},
}
var writerPool = &sync.Pool{
New: func() any {
pr, pw := io.Pipe()
return &netHTTPResponseWriter{
h: make(http.Header),
r: pr,
w: pw,
modeCh: make(chan ModeType),
responseBody: acquireBuffer(),
streamCond: sync.NewCond(&sync.Mutex{}),
}
},
}
type ModeType int
const (
modeUnknown ModeType = iota
modeDone
modeFlushed
modeHijacked
)
type netHTTPResponseWriter struct {
w io.Writer
h http.Header
ctx *fasthttp.RequestCtx
statusCode int
handlerConn net.Conn
ctx *fasthttp.RequestCtx
h http.Header
r *io.PipeReader
w *io.PipeWriter
modeCh chan ModeType
responseBody *[]byte
streamCond *sync.Cond
statusCode int
once sync.Once
statusMutex sync.Mutex
responseMutex sync.Mutex
connMutex sync.Mutex
isStreaming bool
}
func acquireNetHTTPResponseWriter(ctx *fasthttp.RequestCtx) *netHTTPResponseWriter {
w, ok := writerPool.Get().(*netHTTPResponseWriter)
if !ok {
panic("fasthttpadaptor: cannot get *netHTTPResponseWriter from writerPool")
}
w.reset()
w.ctx = ctx
return w
}
func releaseNetHTTPResponseWriter(w *netHTTPResponseWriter) {
releaseBuffer(w.responseBody)
w.Close()
writerPool.Put(w)
}
func acquireBuffer() *[]byte {
buf, ok := bufferPool.Get().(*[]byte)
if !ok {
panic("fasthttpadaptor: cannot get *[]byte from bufferPool")
}
*buf = (*buf)[:0]
return buf
}
func releaseBuffer(buf *[]byte) {
bufferPool.Put(buf)
}
func (w *netHTTPResponseWriter) StatusCode() int {
w.statusMutex.Lock()
defer w.statusMutex.Unlock()
if w.statusCode == 0 {
return http.StatusOK
}
@@ -102,35 +334,46 @@ func (w *netHTTPResponseWriter) StatusCode() int {
}
func (w *netHTTPResponseWriter) Header() http.Header {
if w.h == nil {
w.h = make(http.Header)
}
return w.h
}
func (w *netHTTPResponseWriter) WriteHeader(statusCode int) {
w.statusMutex.Lock()
defer w.statusMutex.Unlock()
w.statusCode = statusCode
}
func (w *netHTTPResponseWriter) Write(p []byte) (int, error) {
return w.w.Write(p)
w.streamCond.L.Lock()
defer w.streamCond.L.Unlock()
if w.isStreaming {
// Streaming mode is on.
// Stream directly to the conn writer.
return w.w.Write(p)
}
// Streaming mode is off.
// Write to the first chunk for flushing later.
w.responseMutex.Lock()
*w.responseBody = append(*w.responseBody, p...)
w.responseMutex.Unlock()
return len(p), nil
}
func (w *netHTTPResponseWriter) Flush() {}
type wrappedConn struct {
net.Conn
wg sync.WaitGroup
once sync.Once
}
func (c *wrappedConn) Close() (err error) {
c.once.Do(func() {
err = c.Conn.Close()
c.wg.Done()
func (w *netHTTPResponseWriter) Flush() {
// Trigger streaming mode setup.
w.once.Do(func() {
w.modeCh <- modeFlushed
})
return
// Wait for streaming mode.
w.streamCond.L.Lock()
defer w.streamCond.L.Unlock()
for !w.isStreaming {
w.streamCond.Wait()
}
}
func (w *netHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
@@ -138,22 +381,62 @@ func (w *netHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// doing anything else with it.
w.ctx.HijackSetNoResponse(true)
conn := &wrappedConn{Conn: w.ctx.Conn()}
conn.wg.Add(1)
w.ctx.Hijack(func(net.Conn) {
conn.wg.Wait()
netHTTPConn, fasthttpConn := net.Pipe()
w.handlerConn = fasthttpConn
// Trigger hijacked mode.
w.once.Do(func() {
w.modeCh <- modeHijacked
})
bufW := bufio.NewWriter(conn)
bufRW := bufio.NewReadWriter(bufio.NewReader(netHTTPConn), bufio.NewWriter(netHTTPConn))
// Write any unflushed body to the hijacked connection buffer.
unflushedBody := w.ctx.Response.Body()
if len(unflushedBody) > 0 {
if _, err := bufW.Write(unflushedBody); err != nil {
conn.Close()
return nil, nil, err
}
w.responseMutex.Lock()
if len(*w.responseBody) > 0 {
_, _ = bufRW.Write(*w.responseBody)
_ = bufRW.Flush()
}
w.responseMutex.Unlock()
return netHTTPConn, bufRW, nil
}
func (w *netHTTPResponseWriter) Close() error {
_ = w.w.Close()
_ = w.r.Close()
w.connMutex.Lock()
if w.handlerConn != nil {
_ = w.handlerConn.Close()
}
w.connMutex.Unlock()
return nil
}
func (w *netHTTPResponseWriter) reset() {
// Note: reset() must only run after a fasthttp handler finishes
// proxying the full net/http handler response to ensure no data races.
w.ctx = nil
w.connMutex.Lock()
w.handlerConn = nil
w.connMutex.Unlock()
w.statusCode = 0
// Open new bidirectional pipes
pr, pw := io.Pipe()
w.r = pr
w.w = pw
// Clear the http Header
for key := range w.h {
delete(w.h, key)
}
return conn, &bufio.ReadWriter{Reader: bufio.NewReader(conn), Writer: bufW}, nil
// Get a new buffer for the response body
w.responseBody = acquireBuffer()
w.once = sync.Once{}
w.streamCond.L.Lock()
w.isStreaming = false
w.streamCond.L.Unlock()
}
+168 -1
View File
@@ -1,6 +1,7 @@
package fasthttpadaptor
import (
"bufio"
"io"
"net"
"net/http"
@@ -135,7 +136,7 @@ func TestNewFastHTTPHandler(t *testing.T) {
t.Fatalf("unexpected response body %q. Expecting %q", resp.Body(), expectedBody)
}
if string(resp.Header.Peek("Content-Type")) != expectedContentType {
t.Fatalf("unexpected response content-type %q. Expecting %q", string(resp.Header.Peek("Content-Type")), expectedBody)
t.Fatalf("unexpected response content-type %q. Expecting %q", string(resp.Header.Peek("Content-Type")), expectedContentType)
}
}
@@ -263,3 +264,169 @@ func TestHijack(t *testing.T) {
t.Fatal("timeout")
}
}
func TestFlushHandler(t *testing.T) {
t.Parallel()
nethttpH := func(w http.ResponseWriter, r *http.Request) {
if f, ok := w.(http.Flusher); !ok {
t.Errorf("expected http.ResponseWriter to implement http.Flusher")
} else {
if _, err := w.Write([]byte("foo")); err != nil {
t.Error(err)
}
f.Flush()
time.Sleep(time.Second)
if _, err := w.Write([]byte("bar")); err != nil {
t.Error(err)
}
f.Flush()
}
}
s := &fasthttp.Server{
Handler: NewFastHTTPHandler(http.HandlerFunc(nethttpH)),
}
ln := fasthttputil.NewInmemoryListener()
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %v", err)
}
}()
clientCh := make(chan struct{})
go func() {
c, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %v", err)
}
time.AfterFunc(500*time.Millisecond, func() {
c.Close()
})
resp, err := http.ReadResponse(bufio.NewReader(c), nil)
if err != nil {
t.Errorf("unexpected error reading response: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK)
}
if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" {
t.Errorf("unexpected Content-Type header: %q. Expecting %q", resp.Header.Get("Content-Type"), "text/plain; charset=utf-8")
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil && err != io.ErrUnexpectedEOF {
t.Errorf("unexpected error reading body: %v", err)
}
if string(body) != "foo" {
t.Errorf("unexpected response body: %q. Expecting %q", body, "foo")
}
close(clientCh)
}()
select {
case <-clientCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
func TestHijackFlush(t *testing.T) {
t.Parallel()
nethttpH := func(w http.ResponseWriter, r *http.Request) {
if f, ok := w.(http.Hijacker); !ok {
t.Errorf("expected http.ResponseWriter to implement http.Hijacker")
} else {
if _, err := w.Write([]byte("foo")); err != nil {
t.Error(err)
}
if c, rw, err := f.Hijack(); err != nil {
t.Error(err)
} else {
if _, err := rw.WriteString("bar"); err != nil {
t.Error(err)
}
if err := rw.Flush(); err != nil {
t.Error(err)
}
time.Sleep(time.Second)
if _, err := rw.WriteString("bazz"); err != nil {
t.Error(err)
}
if err := rw.Flush(); err != nil {
t.Error(err)
}
if err := c.Close(); err != nil {
t.Error(err)
}
}
}
}
s := &fasthttp.Server{
Handler: NewFastHTTPHandler(http.HandlerFunc(nethttpH)),
}
ln := fasthttputil.NewInmemoryListener()
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %v", err)
}
}()
clientCh := make(chan struct{})
go func() {
c, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
t.Errorf("unexpected error: %v", err)
}
time.AfterFunc(500*time.Millisecond, func() {
c.Close()
})
buf, err := io.ReadAll(c)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if string(buf) != "foobar" {
t.Errorf("unexpected response: %q. Expecting %q", buf, "foobar")
}
close(clientCh)
}()
select {
case <-clientCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}