From 149f0f38b798501980ae3e91a75cf18cc3c69a12 Mon Sep 17 00:00:00 2001 From: Aliaksandr Valialkin Date: Fri, 25 Dec 2015 13:33:01 +0200 Subject: [PATCH] Issue #14: added CompressHandler wrapper for transparent response compression support --- header.go | 25 +++++++++++++++++ header_test.go | 25 +++++++++++++++++ http.go | 76 +++++++++++++++++++++++++++++++++++--------------- server.go | 28 +++++++++++++++++++ server_test.go | 72 +++++++++++++++++++++++++++++++++++++++++++++++ strings.go | 1 + 6 files changed, 204 insertions(+), 23 deletions(-) diff --git a/header.go b/header.go index aa2ae6e..e45d009 100644 --- a/header.go +++ b/header.go @@ -408,6 +408,31 @@ func (h *ResponseHeader) IsHTTP11() bool { return !h.noHTTP11 } +// HasAcceptEncoding returns true if the header contains +// the given Accept-Encoding value. +func (h *RequestHeader) HasAcceptEncoding(acceptEncoding string) bool { + h.bufKV.value = append(h.bufKV.value[:0], acceptEncoding...) + return h.HasAcceptEncodingBytes(h.bufKV.value) +} + +// HasAcceptEncodingBytes returns true if the header contains +// the given Accept-Encoding value. +func (h *RequestHeader) HasAcceptEncodingBytes(acceptEncoding []byte) bool { + ae := h.peek(strAcceptEncoding) + n := bytes.Index(ae, acceptEncoding) + if n < 0 { + return false + } + b := ae[n+len(acceptEncoding):] + if len(b) > 0 && b[0] != ',' { + return false + } + if n == 0 { + return true + } + return ae[n-1] == ' ' +} + // Len returns the number of headers set, // i.e. the number of times f is called in VisitAll. func (h *ResponseHeader) Len() int { diff --git a/header_test.go b/header_test.go index d43a52a..940ea85 100644 --- a/header_test.go +++ b/header_test.go @@ -10,6 +10,31 @@ import ( "testing" ) +func TestRequestHeaderHasAcceptEncoding(t *testing.T) { + testRequestHeaderHasAcceptEncoding(t, "", "gzip", false) + testRequestHeaderHasAcceptEncoding(t, "gzip", "sdhc", false) + testRequestHeaderHasAcceptEncoding(t, "deflate", "deflate", true) + testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "gzi", false) + testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "dhc", false) + testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "sdh", false) + testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "zip", false) + testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "flat", false) + testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "flate", false) + testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "def", false) + testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "gzip", true) + testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "deflate", true) + testRequestHeaderHasAcceptEncoding(t, "gzip, deflate, sdhc", "sdhc", true) +} + +func testRequestHeaderHasAcceptEncoding(t *testing.T, ae, v string, resultExpected bool) { + var h RequestHeader + h.Set("Accept-Encoding", ae) + result := h.HasAcceptEncoding(v) + if result != resultExpected { + t.Fatalf("unexpected result in HasAcceptEncoding(%q, %q): %v. Expecting %v", ae, v, result, resultExpected) + } +} + func TestRequestMultipartFormBoundary(t *testing.T) { testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; boundary=foobar\r\n\r\n", "foobar") diff --git a/http.go b/http.go index 8bf27e6..0fb7fa7 100644 --- a/http.go +++ b/http.go @@ -568,17 +568,65 @@ func (req *Request) Write(w *bufio.Writer) error { // // WriteGzip doesn't flush response to w for performance reasons. func (resp *Response) WriteGzip(w *bufio.Writer) error { - return resp.WriteGzipLevel(w, gzip.DefaultCompression) + return resp.WriteGzipLevel(w, CompressDefaultCompression) } // WriteGzipLevel writes response with gzipped body to w. // -// Level is compression level. See available levels in encoding/gzip package. +// Level is the desired compression level: +// +// * CompressNoCompression +// * CompressBestSpeed +// * CompressBestCompression +// * CompressDefaultCompression // // The method sets 'Content-Encoding: gzip' header. // // WriteGzipLevel doesn't flush response to w for performance reasons. func (resp *Response) WriteGzipLevel(w *bufio.Writer, level int) error { + if err := resp.gzipBody(level); err != nil { + return err + } + return resp.Write(w) +} + +// WriteDeflate writes response with deflated body to w. +// +// The method sets 'Content-Encoding: deflate' header. +// +// WriteDeflate doesn't flush response to w for performance reasons. +func (resp *Response) WriteDeflate(w *bufio.Writer) error { + return resp.WriteDeflateLevel(w, CompressDefaultCompression) +} + +// WriteDeflateLevel writes response with deflated body to w. +// +// Level is the desired compression level: +// +// * CompressNoCompression +// * CompressBestSpeed +// * CompressBestCompression +// * CompressDefaultCompression +// +// The method sets 'Content-Encoding: deflate' header. +// +// WriteDeflateLevel doesn't flush response to w for performance reasons. +func (resp *Response) WriteDeflateLevel(w *bufio.Writer, level int) error { + if err := resp.deflateBody(level); err != nil { + return err + } + return resp.Write(w) +} + +// Supported compression levels. +const ( + CompressNoCompression = flate.NoCompression + CompressBestSpeed = flate.BestSpeed + CompressBestCompression = flate.BestCompression + CompressDefaultCompression = flate.DefaultCompression +) + +func (resp *Response) gzipBody(level int) error { // Do not care about memory allocations here, since gzip is slow // and allocates a lot of memory by itself. if resp.bodyStream != nil { @@ -597,28 +645,11 @@ func (resp *Response) WriteGzipLevel(w *bufio.Writer, level int) error { zw.Close() resp.body = buf.Bytes() } - resp.Header.SetCanonical(strContentEncoding, strGzip) - return resp.Write(w) + return nil } -// WriteDeflate writes response with deflated body to w. -// -// The method sets 'Content-Encoding: deflate' header. -// -// WriteDeflate doesn't flush response to w for performance reasons. -func (resp *Response) WriteDeflate(w *bufio.Writer) error { - return resp.WriteDeflateLevel(w, flate.DefaultCompression) -} - -// WriteDeflateLevel writes response with deflated body to w. -// -// Level is compression level. See available levels in encoding/flate package. -// -// The method sets 'Content-Encoding: deflate' header. -// -// WriteDeflateLevel doesn't flush response to w for performance reasons. -func (resp *Response) WriteDeflateLevel(w *bufio.Writer, level int) error { +func (resp *Response) deflateBody(level int) error { // Do not care about memory allocations here, since flate is slow // and allocates a lot of memory by itself. if resp.bodyStream != nil { @@ -637,9 +668,8 @@ func (resp *Response) WriteDeflateLevel(w *bufio.Writer, level int) error { zw.Close() resp.body = buf.Bytes() } - resp.Header.SetCanonical(strContentEncoding, strDeflate) - return resp.Write(w) + return nil } func newDeflateWriter(w io.Writer, level int) *flate.Writer { diff --git a/server.go b/server.go index 94cc1c7..1c93aac 100644 --- a/server.go +++ b/server.go @@ -231,6 +231,34 @@ func TimeoutHandler(h RequestHandler, timeout time.Duration, msg string) Request } } +// CompressHandlerLevel returns RequestHandler that transparently compresses +// response body generated by h if the request contains 'gzip' or 'deflate' +// 'Accept-Encoding' header. +func CompressHandler(h RequestHandler) RequestHandler { + return CompressHandlerLevel(h, CompressDefaultCompression) +} + +// CompressHandlerLevel returns RequestHandler that transparently compresses +// response body generated by h if the request contains 'gzip' or 'deflate' +// 'Accept-Encoding' header. +// +// Level is the desired compression level: +// +// * CompressNoCompression +// * CompressBestSpeed +// * CompressBestCompression +// * CompressDefaultCompression +func CompressHandlerLevel(h RequestHandler, level int) RequestHandler { + return func(ctx *RequestCtx) { + h(ctx) + if ctx.Request.Header.HasAcceptEncodingBytes(strGzip) { + ctx.Response.gzipBody(level) + } else if ctx.Request.Header.HasAcceptEncodingBytes(strDeflate) { + ctx.Response.deflateBody(level) + } + } +} + // RequestCtx contains incoming request and manages outgoing response. // // It is forbidden copying RequestCtx instances. diff --git a/server_test.go b/server_test.go index 8c0a310..43e6bf2 100644 --- a/server_test.go +++ b/server_test.go @@ -12,6 +12,78 @@ import ( "time" ) +func TestCompressHandler(t *testing.T) { + expectedBody := "foo/bar/baz" + h := CompressHandler(func(ctx *RequestCtx) { + ctx.Write([]byte(expectedBody)) + }) + + var ctx RequestCtx + var resp Response + + // verify uncompressed response + h(&ctx) + s := ctx.Response.String() + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + ce := resp.Header.Peek("Content-Encoding") + if string(ce) != "" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "") + } + body := resp.Body() + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } + + // verify gzip-compressed response + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc") + + h(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + ce = resp.Header.Peek("Content-Encoding") + if string(ce) != "gzip" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip") + } + body, err := resp.BodyGunzip() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } + + // verify deflate-compressed response + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.Set("Accept-Encoding", "foobar, deflate, sdhc") + + h(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + ce = resp.Header.Peek("Content-Encoding") + if string(ce) != "deflate" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "deflate") + } + body, err = resp.BodyInflate() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } +} + func TestRequestCtxWriteString(t *testing.T) { var ctx RequestCtx n, err := ctx.WriteString("foo") diff --git a/strings.go b/strings.go index 6dc42fb..02132a1 100644 --- a/strings.go +++ b/strings.go @@ -33,6 +33,7 @@ var ( strServer = []byte("Server") strTransferEncoding = []byte("Transfer-Encoding") strContentEncoding = []byte("Content-Encoding") + strAcceptEncoding = []byte("Accept-Encoding") strUserAgent = []byte("User-Agent") strCookie = []byte("Cookie") strSetCookie = []byte("Set-Cookie")