mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-14 15:56:44 +03:00
adaptor ResponseWriter - adding Hijack method and pass proper fields (#1525)
* adding hijack method and pass proper fields * adding hijack method and pass proper fields - adding tests * improve hijack handling, use proper test for hijacking * extend hijackhandler propogation to NewFastHTTPHandlerFunc * align hijacking of fasthttp adaptor net request with fasthttp request, safe conn handling for proper release of resources and custom hijack handler for more controlled by hijacking implementation * Implement actual behaviour of net/http Hijacker --------- Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
This commit is contained in:
@@ -3,8 +3,11 @@
|
||||
package fasthttpadaptor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
@@ -53,8 +56,10 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler {
|
||||
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w := netHTTPResponseWriter{w: ctx.Response.BodyWriter()}
|
||||
w := netHTTPResponseWriter{
|
||||
w: ctx.Response.BodyWriter(),
|
||||
ctx: ctx,
|
||||
}
|
||||
h.ServeHTTP(&w, r.WithContext(ctx))
|
||||
|
||||
ctx.SetStatusCode(w.StatusCode())
|
||||
@@ -86,6 +91,7 @@ type netHTTPResponseWriter struct {
|
||||
statusCode int
|
||||
h http.Header
|
||||
w io.Writer
|
||||
ctx *fasthttp.RequestCtx
|
||||
}
|
||||
|
||||
func (w *netHTTPResponseWriter) StatusCode() int {
|
||||
@@ -111,3 +117,43 @@ func (w *netHTTPResponseWriter) Write(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (w *netHTTPResponseWriter) Flush() {}
|
||||
|
||||
type wrappedConn struct {
|
||||
net.Conn
|
||||
|
||||
wg sync.WaitGroup
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (c *wrappedConn) Close() (err error) {
|
||||
c.once.Do(func() {
|
||||
err = c.Conn.Close()
|
||||
c.wg.Done()
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (w *netHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
// Hijack assumes control of the connection, so we need to prevent fasthttp from closing it or
|
||||
// doing anything else with it.
|
||||
w.ctx.HijackSetNoResponse(true)
|
||||
|
||||
conn := &wrappedConn{Conn: w.ctx.Conn()}
|
||||
conn.wg.Add(1)
|
||||
w.ctx.Hijack(func(net.Conn) {
|
||||
conn.wg.Wait()
|
||||
})
|
||||
|
||||
bufW := bufio.NewWriter(conn)
|
||||
|
||||
// Write any unflushed body to the hijacked connection buffer.
|
||||
unflushedBody := w.ctx.Response.Body()
|
||||
if len(unflushedBody) > 0 {
|
||||
if _, err := bufW.Write(unflushedBody); err != nil {
|
||||
conn.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return conn, &bufio.ReadWriter{Reader: bufio.NewReader(conn), Writer: bufW}, nil
|
||||
}
|
||||
|
||||
@@ -7,8 +7,10 @@ import (
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"github.com/valyala/fasthttp/fasthttputil"
|
||||
)
|
||||
|
||||
func TestNewFastHTTPHandler(t *testing.T) {
|
||||
@@ -143,3 +145,74 @@ func setContextValueMiddleware(next fasthttp.RequestHandler, key string, value a
|
||||
next(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHijack(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
nethttpH := func(w http.ResponseWriter, r *http.Request) {
|
||||
if f, ok := w.(http.Hijacker); !ok {
|
||||
t.Errorf("expected http.ResponseWriter to implement http.Hijacker")
|
||||
} else {
|
||||
if _, err := w.Write([]byte("foo")); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if c, rw, err := f.Hijack(); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
if _, err := rw.Write([]byte("bar")); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err := rw.Flush(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err := c.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s := &fasthttp.Server{
|
||||
Handler: NewFastHTTPHandler(http.HandlerFunc(nethttpH)),
|
||||
}
|
||||
|
||||
ln := fasthttputil.NewInmemoryListener()
|
||||
|
||||
go func() {
|
||||
if err := s.Serve(ln); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientCh := make(chan struct{})
|
||||
go func() {
|
||||
c, err := ln.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
buf, err := io.ReadAll(c)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if string(buf) != "foobar" {
|
||||
t.Errorf("unexpected response: %q. Expecting %q", buf, "foobar")
|
||||
}
|
||||
|
||||
close(clientCh)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-clientCh:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user