diff --git a/http.go b/http.go index fa1bd4e..1b2b92c 100644 --- a/http.go +++ b/http.go @@ -102,7 +102,8 @@ var ErrReadTimeout = errors.New("read timeout") // the given timeout. // // If request couldn't be read during the given timeout, -// it returns ErrReadTimeout. +// ErrReadTimeout is returned. +// Request can no longer be used after ErrReadTimeout error. func (req *Request) ReadTimeout(r *bufio.Reader, timeout time.Duration) error { if timeout <= 0 { return req.Read(r) @@ -136,7 +137,8 @@ func (req *Request) ReadTimeout(r *bufio.Reader, timeout time.Duration) error { // the given timeout. // // If response couldn't be read during the given timeout, -// it returns ErrReadTimeout. +// ErrReadTimeout is returned. +// Request can no longer be used after ErrReadTimeout error. func (resp *Response) ReadTimeout(r *bufio.Reader, timeout time.Duration) error { if timeout <= 0 { return resp.Read(r) diff --git a/http_test.go b/http_test.go index d52e599..6740ec2 100644 --- a/http_test.go +++ b/http_test.go @@ -11,7 +11,7 @@ import ( ) func TestResponseReadTimeout(t *testing.T) { - var resp Response + resp := &Response{} for i := 0; i < 5; i++ { testResponseReadTimeoutError(t, &resp) @@ -31,7 +31,7 @@ func TestResponseReadTimeout(t *testing.T) { } func TestRequestReadTimeout(t *testing.T) { - var req Request + req := &Request{} for i := 0; i < 5; i++ { testRequestReadTimeoutError(t, &req) @@ -50,28 +50,30 @@ func TestRequestReadTimeout(t *testing.T) { } } -func testResponseReadTimeoutError(t *testing.T, resp *Response) { +func testResponseReadTimeoutError(t *testing.T, resp **Response) { r, _ := io.Pipe() rb := bufio.NewReader(r) - err := resp.ReadTimeout(rb, 5*time.Millisecond) + err := (*resp).ReadTimeout(rb, 5*time.Millisecond) if err == nil { t.Fatalf("Expecting error") } if err != ErrReadTimeout { t.Fatalf("Unexpected error: %s. Expecting %s", err, ErrReadTimeout) } + *resp = &Response{} } -func testRequestReadTimeoutError(t *testing.T, req *Request) { +func testRequestReadTimeoutError(t *testing.T, req **Request) { r, _ := io.Pipe() rb := bufio.NewReader(r) - err := req.ReadTimeout(rb, 5*time.Millisecond) + err := (*req).ReadTimeout(rb, 5*time.Millisecond) if err == nil { t.Fatalf("Expecting error") } if err != ErrReadTimeout { t.Fatalf("Unexpected error: %s. Expecting %s", err, ErrReadTimeout) } + *req = &Request{} } func TestRequestReadChunked(t *testing.T) { diff --git a/server.go b/server.go index 6a6904c..71cd7a9 100644 --- a/server.go +++ b/server.go @@ -13,7 +13,6 @@ import ( "sync" "sync/atomic" "time" - "unsafe" ) // RequestHandler must process incoming requests. @@ -124,7 +123,9 @@ type RequestCtx struct { c io.ReadWriter r *bufio.Reader w *bufio.Writer - shadow unsafe.Pointer + + // shadow is set by TimeoutError(). + shadow *RequestCtx timeoutCh chan struct{} timeoutTimer *time.Timer @@ -227,18 +228,9 @@ func (ctx *RequestCtx) Logger() Logger { // TimeoutError MUST be called before returning from RequestHandler if there are // references to ctx and/or its members in other goroutines. func (ctx *RequestCtx) TimeoutError(msg string) { - var shadow RequestCtx - shadow.Request = Request{} - shadow.Response = Response{} - shadow.logger.ctx = &shadow - shadow.v = &shadow - - shadow.s = ctx.s - shadow.c = ctx.c - shadow.r = ctx.r - shadow.w = ctx.w - - if atomic.CompareAndSwapPointer(&ctx.shadow, nil, unsafe.Pointer(&shadow)) { + shadow := makeShadow(ctx) + if ctx.shadow == nil { + ctx.shadow = shadow shadow.Error(msg, StatusRequestTimeout) } } @@ -405,7 +397,11 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **RequestCtx) error { } err = ctx.Request.Read(ctx.r) } else { - err = ctx.Request.ReadTimeout(ctx.r, readTimeout) + if err = ctx.Request.ReadTimeout(ctx.r, readTimeout); err == ErrReadTimeout { + // ctx.Requests cannot be used after ErrReadTimeout, so create ctx shadow. + *ctxP = makeShadow(ctx) + break + } } if err != nil { if err == io.EOF { @@ -416,9 +412,9 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **RequestCtx) error { ctx.ID++ ctx.Time = time.Now() s.Handler(ctx) - shadow := atomic.LoadPointer(&ctx.shadow) + shadow := ctx.shadow if shadow != nil { - ctx = (*RequestCtx)(shadow) + ctx = shadow *ctxP = ctx } @@ -451,8 +447,22 @@ func (s *Server) serveConn(c io.ReadWriter, ctxP **RequestCtx) error { return err } +func makeShadow(ctx *RequestCtx) *RequestCtx { + var shadow RequestCtx + shadow.Request = Request{} + shadow.Response = Response{} + shadow.logger.ctx = &shadow + shadow.v = &shadow + + shadow.s = ctx.s + shadow.c = ctx.c + shadow.r = ctx.r + shadow.w = ctx.w + return &shadow +} + func writeResponse(ctx *RequestCtx) error { - if atomic.LoadPointer(&ctx.shadow) != nil { + if ctx.shadow != nil { panic("BUG: cannot write response with shadow") } h := &ctx.Response.Header @@ -517,7 +527,7 @@ func (s *Server) acquireCtx() *RequestCtx { } func (s *Server) releaseCtx(ctx *RequestCtx) { - if atomic.LoadPointer(&ctx.shadow) != nil { + if ctx.shadow != nil { panic("BUG: cannot release RequestCtx with shadow") } ctx.c = nil diff --git a/server_test.go b/server_test.go index 0960790..68994eb 100644 --- a/server_test.go +++ b/server_test.go @@ -75,7 +75,6 @@ func TestServerTimeoutError(t *testing.T) { Handler: func(ctx *RequestCtx) { go func() { ctx.Success("aaa/bbb", []byte("xxxyyy")) - ctx.TimeoutError("ignore this") }() ctx.TimeoutError("stolen ctx") ctx.TimeoutError("should be ignored")