Add option for middleware to set custom remote address (#1009)

* Add option for middleware to set custom remote address

* Update Init2 to clear custom context remoteAddr
This commit is contained in:
Lauris BH
2021-04-23 14:25:02 +03:00
committed by GitHub
parent 894272e578
commit 620f0c83ad
2 changed files with 64 additions and 0 deletions
+14
View File
@@ -572,6 +572,7 @@ type RequestCtx struct {
connID uint64
connRequestNum uint64
connTime time.Time
remoteAddr net.Addr
time time.Time
@@ -1091,6 +1092,9 @@ func (ctx *RequestCtx) IsHead() bool {
//
// Always returns non-nil result.
func (ctx *RequestCtx) RemoteAddr() net.Addr {
if ctx.remoteAddr != nil {
return ctx.remoteAddr
}
if ctx.c == nil {
return zeroTCPAddr
}
@@ -1101,6 +1105,14 @@ func (ctx *RequestCtx) RemoteAddr() net.Addr {
return addr
}
// SetRemoteAddr sets remote address to the given value.
//
// Set nil value to resore default behaviour for using
// connection remote address.
func (ctx *RequestCtx) SetRemoteAddr(remoteAddr net.Addr) {
ctx.remoteAddr = remoteAddr
}
// LocalAddr returns server address for the given request.
//
// Always returns non-nil result.
@@ -2524,6 +2536,7 @@ func (s *Server) acquireCtx(c net.Conn) (ctx *RequestCtx) {
// See https://github.com/valyala/httpteleport for details.
func (ctx *RequestCtx) Init2(conn net.Conn, logger Logger, reduceMemoryUsage bool) {
ctx.c = conn
ctx.remoteAddr = nil
ctx.logger.logger = logger
ctx.connID = nextConnID()
ctx.s = fakeServer
@@ -2636,6 +2649,7 @@ func (s *Server) releaseCtx(ctx *RequestCtx) {
panic("BUG: cannot release timed out RequestCtx")
}
ctx.c = nil
ctx.remoteAddr = nil
ctx.fbr.c = nil
s.ctxPool.Put(ctx)
}
+50
View File
@@ -2966,6 +2966,56 @@ func TestServerRemoteAddr(t *testing.T) {
verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.4:8765, remoteIP=1.2.3.4")
}
func TestServerCustomRemoteAddr(t *testing.T) {
t.Parallel()
customRemoteAddrHandler := func(h RequestHandler) RequestHandler {
return func(ctx *RequestCtx) {
ctx.SetRemoteAddr(&net.TCPAddr{
IP: []byte{1, 2, 3, 5},
Port: 0,
})
h(ctx)
}
}
s := &Server{
Handler: customRemoteAddrHandler(func(ctx *RequestCtx) {
h := &ctx.Request.Header
ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, remoteAddr=%s, remoteIP=%s",
h.RequestURI(), ctx.RemoteAddr(), ctx.RemoteIP())))
}),
}
rw := &readWriter{}
rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
rwx := &readWriterRemoteAddr{
rw: rw,
addr: &net.TCPAddr{
IP: []byte{1, 2, 3, 4},
Port: 8765,
},
}
ch := make(chan error)
go func() {
ch <- s.ServeConn(rwx)
}()
select {
case err := <-ch:
if err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout")
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.5:0, remoteIP=1.2.3.5")
}
type readWriterRemoteAddr struct {
net.Conn
rw io.ReadWriteCloser