From a0ee798bbed2a2e2ffc5417bb8ee8b8cd44e3e1b Mon Sep 17 00:00:00 2001 From: Aliaksandr Valialkin Date: Mon, 9 Nov 2015 18:41:27 +0200 Subject: [PATCH] Added RequestCtx.Init() helper for custom Server implementations --- server.go | 70 +++++++++++++++++++++++++++++++++++++++++++++----- server_test.go | 17 ++++++++++++ 2 files changed, 81 insertions(+), 6 deletions(-) diff --git a/server.go b/server.go index 80f5e91..ab43016 100644 --- a/server.go +++ b/server.go @@ -184,7 +184,8 @@ type Logger interface { var ctxLoggerLock sync.Mutex type ctxLogger struct { - ctx *RequestCtx + ctx *RequestCtx + logger Logger } func (cl *ctxLogger) Printf(format string, args ...interface{}) { @@ -193,7 +194,7 @@ func (cl *ctxLogger) Printf(format string, args ...interface{}) { ctx := cl.ctx req := &ctx.Request req.ParseURI() - ctx.s.logger().Printf("%.3f #%016X - %s - %s %s - %s", + cl.logger.Printf("%.3f #%016X - %s - %s %s - %s", time.Since(ctx.Time).Seconds(), ctx.ID, ctx.RemoteAddr(), req.Header.Method, req.URI.URI, s) ctxLoggerLock.Unlock() } @@ -213,7 +214,11 @@ func (ctx *RequestCtx) RemoteAddr() net.Addr { // RemoteIP returns client ip for the given request. func (ctx *RequestCtx) RemoteIP() net.IP { - x, ok := ctx.RemoteAddr().(*net.TCPAddr) + addr := ctx.RemoteAddr() + if addr == nil { + return net.IPv4zero + } + x, ok := addr.(*net.TCPAddr) if !ok { return net.IPv4zero } @@ -251,6 +256,12 @@ func (ctx *RequestCtx) Success(contentType string, body []byte) { // // It is safe re-using returned logger for logging multiple messages. func (ctx *RequestCtx) Logger() Logger { + if ctx.logger.ctx == nil { + ctx.logger.ctx = ctx + } + if ctx.logger.logger == nil { + ctx.logger.logger = ctx.s.logger() + } return &ctx.logger } @@ -409,7 +420,7 @@ func wrapPerIPConn(s *Server, c net.Conn) net.Conn { } } -var defaultLogger = log.New(os.Stderr, "", log.LstdFlags) +var defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags)) func (s *Server) logger() Logger { if s.Logger != nil { @@ -708,17 +719,64 @@ func (s *Server) acquireCtx(c io.ReadWriteCloser) *RequestCtx { ctx = &RequestCtx{ s: s, } - ctx.logger.ctx = ctx ctx.v = ctx v = ctx } else { ctx = v.(*RequestCtx) } - ctx.ID = (atomic.AddUint64(&globalCtxID, 1)) << 32 + ctx.initID() ctx.c = c return ctx } +// Init prepares ctx for passing to RequestHandler. +// +// remoteAddr and logger are optional. They are used by RequestCtx.Logger(). +// +// This function is intended for custom Server implementations. +func (ctx *RequestCtx) Init(req *Request, remoteAddr net.Addr, logger Logger) { + if remoteAddr == nil { + remoteAddr = zeroIPAddr + } + ctx.c = &fakeAddrer{ + addr: remoteAddr, + } + if logger != nil { + ctx.logger.logger = logger + } + ctx.s = &fakeServer + ctx.initID() + req.CopyTo(&ctx.Request) + ctx.Response.Clear() + ctx.Time = time.Now() +} + +var fakeServer Server + +type fakeAddrer struct { + addr net.Addr +} + +func (fa *fakeAddrer) RemoteAddr() net.Addr { + return fa.addr +} + +func (fa *fakeAddrer) Read(p []byte) (int, error) { + panic("BUG: unexpected Read call") +} + +func (fa *fakeAddrer) Write(p []byte) (int, error) { + panic("BUG: unexpected Write call") +} + +func (fa *fakeAddrer) Close() error { + panic("BUG: unexpected Close call") +} + +func (ctx *RequestCtx) initID() { + ctx.ID = (atomic.AddUint64(&globalCtxID, 1)) << 32 +} + func (s *Server) releaseCtx(ctx *RequestCtx) { if len(ctx.timeoutErrMsg) > 0 { panic("BUG: cannot release timed out RequestCtx") diff --git a/server_test.go b/server_test.go index 62dca71..11013d3 100644 --- a/server_test.go +++ b/server_test.go @@ -11,6 +11,23 @@ import ( "time" ) +func TestRequestCtxInit(t *testing.T) { + var ctx RequestCtx + var logger customLogger + globalCtxID = 0x123456 + ctx.Init(&ctx.Request, zeroIPAddr, &logger) + ip := ctx.RemoteIP() + if !ip.IsUnspecified() { + t.Fatalf("unexpected ip for bare RequestCtx: %q. Expected 0.0.0.0", ip) + } + ctx.Logger().Printf("foo bar %d", 10) + + expectedLog := "0.000 #0012345700000000 - 0.0.0.0 - http:// - foo bar 10\n" + if logger.out != expectedLog { + t.Fatalf("Unexpected log output: %q. Expected %q", logger.out, expectedLog) + } +} + func TestTimeoutHandlerSuccess(t *testing.T) { h := func(ctx *RequestCtx) { ctx.Success("aaa/bbb", []byte("real response"))