diff --git a/compress.go b/compress.go index c3b81cf..b6869c6 100644 --- a/compress.go +++ b/compress.go @@ -9,6 +9,7 @@ import ( "github.com/klauspost/compress/flate" "github.com/klauspost/compress/gzip" "github.com/klauspost/compress/zlib" + "github.com/valyala/fasthttp/stackless" ) // Supported compression levels. @@ -77,12 +78,15 @@ func acquireGzipWriter(w io.Writer, level int) *gzipWriter { v := p.Get() if v == nil { - zw, err := gzip.NewWriterLevel(w, level) - if err != nil { - panic(fmt.Sprintf("BUG: unexpected error from gzip.NewWriterLevel(%d): %s", level, err)) - } + sw := stackless.NewWriter(w, func(w io.Writer) stackless.Writer { + zw, err := gzip.NewWriterLevel(w, level) + if err != nil { + panic(fmt.Sprintf("BUG: unexpected error from gzip.NewWriterLevel(%d): %s", level, err)) + } + return zw + }) return &gzipWriter{ - Writer: zw, + Writer: sw, p: p, } } @@ -97,7 +101,7 @@ func releaseGzipWriter(zw *gzipWriter) { } type gzipWriter struct { - *gzip.Writer + stackless.Writer p *sync.Pool } @@ -225,12 +229,15 @@ func acquireFlateWriter(w io.Writer, level int) *flateWriter { v := p.Get() if v == nil { - zw, err := zlib.NewWriterLevel(w, level) - if err != nil { - panic(fmt.Sprintf("BUG: unexpected error in zlib.NewWriterLevel(%d): %s", level, err)) - } + sw := stackless.NewWriter(w, func(w io.Writer) stackless.Writer { + zw, err := zlib.NewWriterLevel(w, level) + if err != nil { + panic(fmt.Sprintf("BUG: unexpected error in zlib.NewWriterLevel(%d): %s", level, err)) + } + return zw + }) return &flateWriter{ - Writer: zw, + Writer: sw, p: p, } } @@ -245,7 +252,7 @@ func releaseFlateWriter(zw *flateWriter) { } type flateWriter struct { - *zlib.Writer + stackless.Writer p *sync.Pool } diff --git a/stackless/doc.go b/stackless/doc.go new file mode 100644 index 0000000..37591dd --- /dev/null +++ b/stackless/doc.go @@ -0,0 +1,3 @@ +// Package stackless saves stack space when using writers from compress/* +// packages. +package stackless diff --git a/stackless/writer.go b/stackless/writer.go new file mode 100644 index 0000000..18c3e62 --- /dev/null +++ b/stackless/writer.go @@ -0,0 +1,146 @@ +package stackless + +import ( + "fmt" + "github.com/valyala/bytebufferpool" + "io" + "runtime" +) + +// Writer is an interface stackless writer must conform to. +// +// The interface contains common subset for Writers from compress/* packages. +type Writer interface { + Write(p []byte) (int, error) + Flush() error + Close() error + Reset(w io.Writer) +} + +// NewWriterFunc must return new writer that will be wrapped into +// stackless writer. +type NewWriterFunc func(w io.Writer) Writer + +// NewWriter creates a stackless writer around a writer returned +// from newWriter. +// +// The returned writer writes data to dstW. +// +// Writers that use a lot of stack space may be wrapped into stackless writer, +// thus saving stack space. +func NewWriter(dstW io.Writer, newWriter NewWriterFunc) Writer { + w := &writer{ + dstW: dstW, + done: make(chan error), + } + w.zw = newWriter(&w.xw) + return w +} + +type writer struct { + dstW io.Writer + zw Writer + xw xWriter + + done chan error + n int + + p []byte + op op +} + +type op int + +const ( + opWrite op = iota + opFlush + opClose + opReset +) + +func (w *writer) Write(p []byte) (int, error) { + w.p = p + err := w.do(opWrite) + w.p = nil + return w.n, err +} + +func (w *writer) Flush() error { + return w.do(opFlush) +} + +func (w *writer) Close() error { + return w.do(opClose) +} + +func (w *writer) Reset(dstW io.Writer) { + w.xw.Reset() + w.do(opReset) + w.dstW = dstW +} + +func (w *writer) do(op op) error { + w.op = op + writerCh <- w + err := <-w.done + if err != nil { + return err + } + if w.xw.bb != nil { + _, err = w.dstW.Write(w.xw.bb.B) + } + w.xw.Reset() + + return err +} + +type xWriter struct { + bb *bytebufferpool.ByteBuffer +} + +func (w *xWriter) Write(p []byte) (int, error) { + if w.bb == nil { + w.bb = bufferPool.Get() + } + w.bb.Write(p) + return len(p), nil +} + +func (w *xWriter) Reset() { + if w.bb != nil { + bufferPool.Put(w.bb) + w.bb = nil + } +} + +var bufferPool bytebufferpool.Pool + +func init() { + n := runtime.GOMAXPROCS(-1) + writerCh = make(chan *writer, n) + for i := 0; i < n; i++ { + go worker() + } +} + +var writerCh chan *writer + +func worker() { + var err error + for w := range writerCh { + switch w.op { + case opWrite: + w.n, err = w.zw.Write(w.p) + case opFlush: + err = w.zw.Flush() + case opClose: + err = w.zw.Close() + case opReset: + w.zw.Reset(&w.xw) + err = nil + default: + panic(fmt.Sprintf("BUG: unexpected op: %d", w.op)) + } + w.done <- err + } +} diff --git a/stackless/writer_test.go b/stackless/writer_test.go new file mode 100644 index 0000000..f36f18d --- /dev/null +++ b/stackless/writer_test.go @@ -0,0 +1,122 @@ +package stackless + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "fmt" + "io" + "io/ioutil" + "testing" + "time" +) + +func TestCompressFlateSerial(t *testing.T) { + if err := testCompressFlate(); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestCompressFlateConcurrent(t *testing.T) { + if err := testConcurrent(testCompressFlate, 10); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + +func testCompressFlate() error { + return testWriter(func(w io.Writer) Writer { + zw, err := flate.NewWriter(w, flate.DefaultCompression) + if err != nil { + panic(fmt.Sprintf("BUG: unexpected error: %s", err)) + } + return zw + }, func(r io.Reader) io.Reader { + return flate.NewReader(r) + }) +} + +func TestCompressGzipSerial(t *testing.T) { + if err := testCompressGzip(); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestCompressGzipConcurrent(t *testing.T) { + if err := testConcurrent(testCompressGzip, 10); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + +func testCompressGzip() error { + return testWriter(func(w io.Writer) Writer { + return gzip.NewWriter(w) + }, func(r io.Reader) io.Reader { + zr, err := gzip.NewReader(r) + if err != nil { + panic(fmt.Sprintf("BUG: cannot create gzip reader: %s", err)) + } + return zr + }) +} + +func testWriter(newWriter NewWriterFunc, newReader func(io.Reader) io.Reader) error { + dstW := &bytes.Buffer{} + w := NewWriter(dstW, newWriter) + + for i := 0; i < 5; i++ { + if err := testWriterReuse(w, dstW, newReader); err != nil { + return fmt.Errorf("unepxected error when re-using writer on iteration %d: %s", i, err) + } + dstW = &bytes.Buffer{} + w.Reset(dstW) + } + + return nil +} + +func testWriterReuse(w Writer, r io.Reader, newReader func(io.Reader) io.Reader) error { + wantW := &bytes.Buffer{} + mw := io.MultiWriter(w, wantW) + for i := 0; i < 30; i++ { + fmt.Fprintf(mw, "foobar %d\n", i) + if i%13 == 0 { + if err := w.Flush(); err != nil { + return fmt.Errorf("error on flush: %s", err) + } + } + } + w.Close() + + zr := newReader(r) + data, err := ioutil.ReadAll(zr) + if err != nil { + return fmt.Errorf("unexpected error: %s, data=%q", err, data) + } + + wantData := wantW.Bytes() + if !bytes.Equal(data, wantData) { + return fmt.Errorf("unexpected data: %q. Expecting %q", data, wantData) + } + + return nil +} + +func testConcurrent(testFunc func() error, concurrency int) error { + ch := make(chan error, concurrency) + for i := 0; i < concurrency; i++ { + go func() { + ch <- testFunc() + }() + } + for i := 0; i < concurrency; i++ { + select { + case err := <-ch: + if err != nil { + return fmt.Errorf("unexpected error on goroutine %d: %s", i, err) + } + case <-time.After(time.Second): + return fmt.Errorf("timeout on goroutine %d", i) + } + } + return nil +}