implement a caching reader needed for crypto streams

This commit is contained in:
Lucas Clemente 2016-04-18 10:59:10 +02:00
parent bc736feada
commit c430fbd5d4
3 changed files with 84 additions and 0 deletions

35
utils/caching_reader.go Normal file
View file

@ -0,0 +1,35 @@
package utils
import "bytes"
// CachingReader wraps a reader and saves all data it reads
type CachingReader struct {
buf bytes.Buffer
r ReadStream
}
// NewCachingReader returns a new CachingReader
func NewCachingReader(r ReadStream) *CachingReader {
return &CachingReader{r: r}
}
// Read implements io.Reader
func (r *CachingReader) Read(p []byte) (int, error) {
n, err := r.r.Read(p)
r.buf.Write(p[:n])
return n, err
}
// ReadByte implements io.ByteReader
func (r *CachingReader) ReadByte() (byte, error) {
b, err := r.r.ReadByte()
if err == nil {
r.buf.WriteByte(b)
}
return b, err
}
// Get the data cached
func (r *CachingReader) Get() []byte {
return r.buf.Bytes()
}

View file

@ -0,0 +1,36 @@
package utils
import (
"bytes"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Caching reader", func() {
It("caches Read()", func() {
r := bytes.NewReader([]byte("foobar"))
cr := NewCachingReader(r)
p := make([]byte, 3)
n, err := cr.Read(p)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(p).To(Equal([]byte("foo")))
Expect(cr.Get()).To(Equal([]byte("foo")))
})
It("caches ReadByte()", func() {
r := bytes.NewReader([]byte("foobar"))
cr := NewCachingReader(r)
b, err := cr.ReadByte()
Expect(err).ToNot(HaveOccurred())
Expect(b).To(Equal(byte('f')))
b, err = cr.ReadByte()
Expect(err).ToNot(HaveOccurred())
Expect(b).To(Equal(byte('o')))
b, err = cr.ReadByte()
Expect(err).ToNot(HaveOccurred())
Expect(b).To(Equal(byte('o')))
Expect(cr.Get()).To(Equal([]byte("foo")))
})
})

View file

@ -6,6 +6,19 @@ import (
"io"
)
// ReadStream is the read part of a QUIC stream
type ReadStream interface {
io.Reader
io.ByteReader
}
// Stream is the interface for QUIC streams
type Stream interface {
io.Reader
io.ByteReader
io.Writer
}
// ReadUintN reads N bytes
func ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
var res uint64