mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-14 15:56:44 +03:00
Add option for middleware to set custom remote address (#1009)
* Add option for middleware to set custom remote address * Update Init2 to clear custom context remoteAddr
This commit is contained in:
@@ -572,6 +572,7 @@ type RequestCtx struct {
|
||||
connID uint64
|
||||
connRequestNum uint64
|
||||
connTime time.Time
|
||||
remoteAddr net.Addr
|
||||
|
||||
time time.Time
|
||||
|
||||
@@ -1091,6 +1092,9 @@ func (ctx *RequestCtx) IsHead() bool {
|
||||
//
|
||||
// Always returns non-nil result.
|
||||
func (ctx *RequestCtx) RemoteAddr() net.Addr {
|
||||
if ctx.remoteAddr != nil {
|
||||
return ctx.remoteAddr
|
||||
}
|
||||
if ctx.c == nil {
|
||||
return zeroTCPAddr
|
||||
}
|
||||
@@ -1101,6 +1105,14 @@ func (ctx *RequestCtx) RemoteAddr() net.Addr {
|
||||
return addr
|
||||
}
|
||||
|
||||
// SetRemoteAddr sets remote address to the given value.
|
||||
//
|
||||
// Set nil value to resore default behaviour for using
|
||||
// connection remote address.
|
||||
func (ctx *RequestCtx) SetRemoteAddr(remoteAddr net.Addr) {
|
||||
ctx.remoteAddr = remoteAddr
|
||||
}
|
||||
|
||||
// LocalAddr returns server address for the given request.
|
||||
//
|
||||
// Always returns non-nil result.
|
||||
@@ -2524,6 +2536,7 @@ func (s *Server) acquireCtx(c net.Conn) (ctx *RequestCtx) {
|
||||
// See https://github.com/valyala/httpteleport for details.
|
||||
func (ctx *RequestCtx) Init2(conn net.Conn, logger Logger, reduceMemoryUsage bool) {
|
||||
ctx.c = conn
|
||||
ctx.remoteAddr = nil
|
||||
ctx.logger.logger = logger
|
||||
ctx.connID = nextConnID()
|
||||
ctx.s = fakeServer
|
||||
@@ -2636,6 +2649,7 @@ func (s *Server) releaseCtx(ctx *RequestCtx) {
|
||||
panic("BUG: cannot release timed out RequestCtx")
|
||||
}
|
||||
ctx.c = nil
|
||||
ctx.remoteAddr = nil
|
||||
ctx.fbr.c = nil
|
||||
s.ctxPool.Put(ctx)
|
||||
}
|
||||
|
||||
@@ -2966,6 +2966,56 @@ func TestServerRemoteAddr(t *testing.T) {
|
||||
verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.4:8765, remoteIP=1.2.3.4")
|
||||
}
|
||||
|
||||
func TestServerCustomRemoteAddr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
customRemoteAddrHandler := func(h RequestHandler) RequestHandler {
|
||||
return func(ctx *RequestCtx) {
|
||||
ctx.SetRemoteAddr(&net.TCPAddr{
|
||||
IP: []byte{1, 2, 3, 5},
|
||||
Port: 0,
|
||||
})
|
||||
h(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
Handler: customRemoteAddrHandler(func(ctx *RequestCtx) {
|
||||
h := &ctx.Request.Header
|
||||
ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, remoteAddr=%s, remoteIP=%s",
|
||||
h.RequestURI(), ctx.RemoteAddr(), ctx.RemoteIP())))
|
||||
}),
|
||||
}
|
||||
|
||||
rw := &readWriter{}
|
||||
rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
|
||||
|
||||
rwx := &readWriterRemoteAddr{
|
||||
rw: rw,
|
||||
addr: &net.TCPAddr{
|
||||
IP: []byte{1, 2, 3, 4},
|
||||
Port: 8765,
|
||||
},
|
||||
}
|
||||
|
||||
ch := make(chan error)
|
||||
go func() {
|
||||
ch <- s.ServeConn(rwx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-ch:
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error from serveConn: %s", err)
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
br := bufio.NewReader(&rw.w)
|
||||
verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.5:0, remoteIP=1.2.3.5")
|
||||
}
|
||||
|
||||
type readWriterRemoteAddr struct {
|
||||
net.Conn
|
||||
rw io.ReadWriteCloser
|
||||
|
||||
Reference in New Issue
Block a user