diff --git a/http3/client.go b/http3/client.go index 8db038f8..861eaf0a 100644 --- a/http3/client.go +++ b/http3/client.go @@ -148,7 +148,7 @@ func (c *client) handleUnidirectionalStreams() { } go func() { - streamType, err := quicvarint.Read(&byteReaderImpl{str}) + streamType, err := quicvarint.Read(quicvarint.NewReader(str)) if err != nil { c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) return diff --git a/http3/frames.go b/http3/frames.go index 3f6c54a7..679f66c1 100644 --- a/http3/frames.go +++ b/http3/frames.go @@ -10,33 +10,15 @@ import ( "github.com/lucas-clemente/quic-go/quicvarint" ) -type byteReader interface { - io.ByteReader - io.Reader -} - -type byteReaderImpl struct{ io.Reader } - -func (br *byteReaderImpl) ReadByte() (byte, error) { - b := make([]byte, 1) - if _, err := br.Reader.Read(b); err != nil { - return 0, err - } - return b[0], nil -} - type frame interface{} -func parseNextFrame(b io.Reader) (frame, error) { - br, ok := b.(byteReader) - if !ok { - br = &byteReaderImpl{b} - } - t, err := quicvarint.Read(br) +func parseNextFrame(r io.Reader) (frame, error) { + qr := quicvarint.NewReader(r) + t, err := quicvarint.Read(qr) if err != nil { return nil, err } - l, err := quicvarint.Read(br) + l, err := quicvarint.Read(qr) if err != nil { return nil, err } @@ -47,7 +29,7 @@ func parseNextFrame(b io.Reader) (frame, error) { case 0x1: return &headersFrame{Length: l}, nil case 0x4: - return parseSettingsFrame(br, l) + return parseSettingsFrame(r, l) case 0x3: // CANCEL_PUSH fallthrough case 0x5: // PUSH_PROMISE @@ -60,10 +42,10 @@ func parseNextFrame(b io.Reader) (frame, error) { fallthrough default: // skip over unknown frames - if _, err := io.CopyN(ioutil.Discard, br, int64(l)); err != nil { + if _, err := io.CopyN(ioutil.Discard, qr, int64(l)); err != nil { return nil, err } - return parseNextFrame(b) + return parseNextFrame(qr) } } diff --git a/http3/server.go b/http3/server.go index ad3cdaf3..b798abd4 100644 --- a/http3/server.go +++ b/http3/server.go @@ -281,7 +281,7 @@ func (s *Server) handleUnidirectionalStreams(sess quic.EarlySession) { } go func(str quic.ReceiveStream) { - streamType, err := quicvarint.Read(&byteReaderImpl{str}) + streamType, err := quicvarint.Read(quicvarint.NewReader(str)) if err != nil { s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) return diff --git a/quicvarint/io.go b/quicvarint/io.go new file mode 100644 index 00000000..c4d976b5 --- /dev/null +++ b/quicvarint/io.go @@ -0,0 +1,65 @@ +package quicvarint + +import ( + "bytes" + "io" +) + +// Reader implements both the io.ByteReader and io.Reader interfaces. +type Reader interface { + io.ByteReader + io.Reader +} + +var _ Reader = &bytes.Reader{} + +type byteReader struct { + io.Reader +} + +var _ Reader = &byteReader{} + +// NewReader returns a Reader for r. +// If r already implements both io.ByteReader and io.Reader, NewReader returns r. +// Otherwise, r is wrapped to add the missing interfaces. +func NewReader(r io.Reader) Reader { + if r, ok := r.(Reader); ok { + return r + } + return &byteReader{r} +} + +func (r *byteReader) ReadByte() (byte, error) { + var b [1]byte + _, err := r.Reader.Read(b[:]) + return b[0], err +} + +// Writer implements both the io.ByteWriter and io.Writer interfaces. +type Writer interface { + io.ByteWriter + io.Writer +} + +var _ Writer = &bytes.Buffer{} + +type byteWriter struct { + io.Writer +} + +var _ Writer = &byteWriter{} + +// NewWriter returns a Writer for w. +// If r already implements both io.ByteWriter and io.Writer, NewWriter returns w. +// Otherwise, w is wrapped to add the missing interfaces. +func NewWriter(w io.Writer) Writer { + if w, ok := w.(Writer); ok { + return w + } + return &byteWriter{w} +} + +func (w *byteWriter) WriteByte(c byte) error { + _, err := w.Writer.Write([]byte{c}) + return err +} diff --git a/quicvarint/io_test.go b/quicvarint/io_test.go new file mode 100644 index 00000000..cc28cd90 --- /dev/null +++ b/quicvarint/io_test.go @@ -0,0 +1,72 @@ +package quicvarint + +import ( + "bytes" + "io" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type nopReader struct{} + +func (r *nopReader) Read(_ []byte) (int, error) { + return 0, io.ErrUnexpectedEOF +} + +var _ io.Reader = &nopReader{} + +type nopWriter struct{} + +func (r *nopWriter) Write(_ []byte) (int, error) { + return 0, io.ErrShortBuffer +} + +var _ io.Writer = &nopWriter{} + +var _ = Describe("Varint I/O", func() { + Context("Reader", func() { + Context("NewReader", func() { + It("passes through a Reader unchanged", func() { + b := bytes.NewReader([]byte{0}) + r := NewReader(b) + Expect(r).To(Equal(b)) + }) + + It("wraps an io.Reader", func() { + n := &nopReader{} + r := NewReader(n) + Expect(r).ToNot(Equal(n)) + }) + }) + + It("returns an error when reading from an underlying io.Reader fails", func() { + r := NewReader(&nopReader{}) + val, err := r.ReadByte() + Expect(err).To(Equal(io.ErrUnexpectedEOF)) + Expect(val).To(Equal(byte(0))) + }) + }) + + Context("Writer", func() { + Context("NewWriter", func() { + It("passes through a Writer unchanged", func() { + b := &bytes.Buffer{} + w := NewWriter(b) + Expect(w).To(Equal(b)) + }) + + It("wraps an io.Writer", func() { + n := &nopWriter{} + w := NewWriter(n) + Expect(w).ToNot(Equal(n)) + }) + }) + + It("returns an error when writing to an underlying io.Writer fails", func() { + w := NewWriter(&nopWriter{}) + err := w.WriteByte(0) + Expect(err).To(Equal(io.ErrShortBuffer)) + }) + }) +}) diff --git a/quicvarint/varint.go b/quicvarint/varint.go index fd3ae11d..723c57c4 100644 --- a/quicvarint/varint.go +++ b/quicvarint/varint.go @@ -1,7 +1,6 @@ package quicvarint import ( - "bytes" "fmt" "io" @@ -16,9 +15,9 @@ const ( maxVarInt8 = 4611686018427387903 ) -// Read reads a number in the QUIC varint format -func Read(b io.ByteReader) (uint64, error) { - firstByte, err := b.ReadByte() +// Read reads a number in the QUIC varint format from r. +func Read(r io.ByteReader) (uint64, error) { + firstByte, err := r.ReadByte() if err != nil { return 0, err } @@ -28,53 +27,53 @@ func Read(b io.ByteReader) (uint64, error) { if len == 1 { return uint64(b1), nil } - b2, err := b.ReadByte() + b2, err := r.ReadByte() if err != nil { return 0, err } if len == 2 { return uint64(b2) + uint64(b1)<<8, nil } - b3, err := b.ReadByte() + b3, err := r.ReadByte() if err != nil { return 0, err } - b4, err := b.ReadByte() + b4, err := r.ReadByte() if err != nil { return 0, err } if len == 4 { return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil } - b5, err := b.ReadByte() + b5, err := r.ReadByte() if err != nil { return 0, err } - b6, err := b.ReadByte() + b6, err := r.ReadByte() if err != nil { return 0, err } - b7, err := b.ReadByte() + b7, err := r.ReadByte() if err != nil { return 0, err } - b8, err := b.ReadByte() + b8, err := r.ReadByte() if err != nil { return 0, err } return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil } -// Write writes a number in the QUIC varint format -func Write(b *bytes.Buffer, i uint64) { +// Write writes i in the QUIC varint format to w. +func Write(w Writer, i uint64) { if i <= maxVarInt1 { - b.WriteByte(uint8(i)) + w.WriteByte(uint8(i)) } else if i <= maxVarInt2 { - b.Write([]byte{uint8(i>>8) | 0x40, uint8(i)}) + w.Write([]byte{uint8(i>>8) | 0x40, uint8(i)}) } else if i <= maxVarInt4 { - b.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}) + w.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}) } else if i <= maxVarInt8 { - b.Write([]byte{ + w.Write([]byte{ uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), }) @@ -83,35 +82,35 @@ func Write(b *bytes.Buffer, i uint64) { } } -// WriteWithLen writes a number in the QUIC varint format, with the desired length. -func WriteWithLen(b *bytes.Buffer, i uint64, length protocol.ByteCount) { +// WriteWithLen writes i in the QUIC varint format with the desired length to w. +func WriteWithLen(w Writer, i uint64, length protocol.ByteCount) { if length != 1 && length != 2 && length != 4 && length != 8 { panic("invalid varint length") } l := Len(i) if l == length { - Write(b, i) + Write(w, i) return } if l > length { panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length)) } if length == 2 { - b.WriteByte(0b01000000) + w.WriteByte(0b01000000) } else if length == 4 { - b.WriteByte(0b10000000) + w.WriteByte(0b10000000) } else if length == 8 { - b.WriteByte(0b11000000) + w.WriteByte(0b11000000) } for j := protocol.ByteCount(1); j < length-l; j++ { - b.WriteByte(0) + w.WriteByte(0) } for j := protocol.ByteCount(0); j < l; j++ { - b.WriteByte(uint8(i >> (8 * (l - 1 - j)))) + w.WriteByte(uint8(i >> (8 * (l - 1 - j)))) } } -// Len determines the number of bytes that will be needed to write a number +// Len determines the number of bytes that will be needed to write the number i. func Len(i uint64) protocol.ByteCount { if i <= maxVarInt1 { return 1