diff --git a/bytebuffer.go b/bytebuffer.go index 7ea3ce5..c0c5a2b 100644 --- a/bytebuffer.go +++ b/bytebuffer.go @@ -1,5 +1,7 @@ package bytebufferpool +import "io" + // ByteBuffer provides byte buffer, which can be used for minimizing // memory allocations. // @@ -14,6 +16,12 @@ type ByteBuffer struct { B []byte } +// WriteTo implements io.WriterTo +func (b *ByteBuffer) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(b.B) + return int64(n), err +} + // Bytes returns b.B, i.e. all the bytes accumulated in the buffer. // // The purpose of this function is bytes.Buffer compatibility. diff --git a/bytebuffer_test.go b/bytebuffer_test.go index 7b98815..4eeaca9 100644 --- a/bytebuffer_test.go +++ b/bytebuffer_test.go @@ -1,11 +1,37 @@ package bytebufferpool import ( + "bytes" "fmt" + "io" "testing" "time" ) +func TestByteBufferWriteTo(t *testing.T) { + expectedS := "foobarbaz" + var bb ByteBuffer + bb.WriteString(expectedS[:3]) + bb.WriteString(expectedS[3:]) + + wt := (io.WriterTo)(&bb) + var w bytes.Buffer + for i := 0; i < 10; i++ { + n, err := wt.WriteTo(&w) + if n != int64(len(expectedS)) { + t.Fatalf("unexpected n returned from WriteTo: %d. Expecting %d", n, len(expectedS)) + } + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + s := string(w.Bytes()) + if s != expectedS { + t.Fatalf("unexpected string written %q. Expecting %q", s, expectedS) + } + w.Reset() + } +} + func TestByteBufferGetPutSerial(t *testing.T) { testByteBufferGetPut(t) }