quicvarint: add Reader and Writer interfaces (#3233)

This commit is contained in:
Randy Reddig 2021-08-05 10:49:17 -07:00 committed by GitHub
parent 79ce9740a4
commit 346bd63a60
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 171 additions and 53 deletions

View file

@ -148,7 +148,7 @@ func (c *client) handleUnidirectionalStreams() {
}
go func() {
streamType, err := quicvarint.Read(&byteReaderImpl{str})
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
return

View file

@ -10,33 +10,15 @@ import (
"github.com/lucas-clemente/quic-go/quicvarint"
)
type byteReader interface {
io.ByteReader
io.Reader
}
type byteReaderImpl struct{ io.Reader }
func (br *byteReaderImpl) ReadByte() (byte, error) {
b := make([]byte, 1)
if _, err := br.Reader.Read(b); err != nil {
return 0, err
}
return b[0], nil
}
type frame interface{}
func parseNextFrame(b io.Reader) (frame, error) {
br, ok := b.(byteReader)
if !ok {
br = &byteReaderImpl{b}
}
t, err := quicvarint.Read(br)
func parseNextFrame(r io.Reader) (frame, error) {
qr := quicvarint.NewReader(r)
t, err := quicvarint.Read(qr)
if err != nil {
return nil, err
}
l, err := quicvarint.Read(br)
l, err := quicvarint.Read(qr)
if err != nil {
return nil, err
}
@ -47,7 +29,7 @@ func parseNextFrame(b io.Reader) (frame, error) {
case 0x1:
return &headersFrame{Length: l}, nil
case 0x4:
return parseSettingsFrame(br, l)
return parseSettingsFrame(r, l)
case 0x3: // CANCEL_PUSH
fallthrough
case 0x5: // PUSH_PROMISE
@ -60,10 +42,10 @@ func parseNextFrame(b io.Reader) (frame, error) {
fallthrough
default:
// skip over unknown frames
if _, err := io.CopyN(ioutil.Discard, br, int64(l)); err != nil {
if _, err := io.CopyN(ioutil.Discard, qr, int64(l)); err != nil {
return nil, err
}
return parseNextFrame(b)
return parseNextFrame(qr)
}
}

View file

@ -281,7 +281,7 @@ func (s *Server) handleUnidirectionalStreams(sess quic.EarlySession) {
}
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(&byteReaderImpl{str})
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
return

65
quicvarint/io.go Normal file
View file

@ -0,0 +1,65 @@
package quicvarint
import (
"bytes"
"io"
)
// Reader implements both the io.ByteReader and io.Reader interfaces.
type Reader interface {
io.ByteReader
io.Reader
}
var _ Reader = &bytes.Reader{}
type byteReader struct {
io.Reader
}
var _ Reader = &byteReader{}
// NewReader returns a Reader for r.
// If r already implements both io.ByteReader and io.Reader, NewReader returns r.
// Otherwise, r is wrapped to add the missing interfaces.
func NewReader(r io.Reader) Reader {
if r, ok := r.(Reader); ok {
return r
}
return &byteReader{r}
}
func (r *byteReader) ReadByte() (byte, error) {
var b [1]byte
_, err := r.Reader.Read(b[:])
return b[0], err
}
// Writer implements both the io.ByteWriter and io.Writer interfaces.
type Writer interface {
io.ByteWriter
io.Writer
}
var _ Writer = &bytes.Buffer{}
type byteWriter struct {
io.Writer
}
var _ Writer = &byteWriter{}
// NewWriter returns a Writer for w.
// If r already implements both io.ByteWriter and io.Writer, NewWriter returns w.
// Otherwise, w is wrapped to add the missing interfaces.
func NewWriter(w io.Writer) Writer {
if w, ok := w.(Writer); ok {
return w
}
return &byteWriter{w}
}
func (w *byteWriter) WriteByte(c byte) error {
_, err := w.Writer.Write([]byte{c})
return err
}

72
quicvarint/io_test.go Normal file
View file

@ -0,0 +1,72 @@
package quicvarint
import (
"bytes"
"io"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type nopReader struct{}
func (r *nopReader) Read(_ []byte) (int, error) {
return 0, io.ErrUnexpectedEOF
}
var _ io.Reader = &nopReader{}
type nopWriter struct{}
func (r *nopWriter) Write(_ []byte) (int, error) {
return 0, io.ErrShortBuffer
}
var _ io.Writer = &nopWriter{}
var _ = Describe("Varint I/O", func() {
Context("Reader", func() {
Context("NewReader", func() {
It("passes through a Reader unchanged", func() {
b := bytes.NewReader([]byte{0})
r := NewReader(b)
Expect(r).To(Equal(b))
})
It("wraps an io.Reader", func() {
n := &nopReader{}
r := NewReader(n)
Expect(r).ToNot(Equal(n))
})
})
It("returns an error when reading from an underlying io.Reader fails", func() {
r := NewReader(&nopReader{})
val, err := r.ReadByte()
Expect(err).To(Equal(io.ErrUnexpectedEOF))
Expect(val).To(Equal(byte(0)))
})
})
Context("Writer", func() {
Context("NewWriter", func() {
It("passes through a Writer unchanged", func() {
b := &bytes.Buffer{}
w := NewWriter(b)
Expect(w).To(Equal(b))
})
It("wraps an io.Writer", func() {
n := &nopWriter{}
w := NewWriter(n)
Expect(w).ToNot(Equal(n))
})
})
It("returns an error when writing to an underlying io.Writer fails", func() {
w := NewWriter(&nopWriter{})
err := w.WriteByte(0)
Expect(err).To(Equal(io.ErrShortBuffer))
})
})
})

View file

@ -1,7 +1,6 @@
package quicvarint
import (
"bytes"
"fmt"
"io"
@ -16,9 +15,9 @@ const (
maxVarInt8 = 4611686018427387903
)
// Read reads a number in the QUIC varint format
func Read(b io.ByteReader) (uint64, error) {
firstByte, err := b.ReadByte()
// Read reads a number in the QUIC varint format from r.
func Read(r io.ByteReader) (uint64, error) {
firstByte, err := r.ReadByte()
if err != nil {
return 0, err
}
@ -28,53 +27,53 @@ func Read(b io.ByteReader) (uint64, error) {
if len == 1 {
return uint64(b1), nil
}
b2, err := b.ReadByte()
b2, err := r.ReadByte()
if err != nil {
return 0, err
}
if len == 2 {
return uint64(b2) + uint64(b1)<<8, nil
}
b3, err := b.ReadByte()
b3, err := r.ReadByte()
if err != nil {
return 0, err
}
b4, err := b.ReadByte()
b4, err := r.ReadByte()
if err != nil {
return 0, err
}
if len == 4 {
return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil
}
b5, err := b.ReadByte()
b5, err := r.ReadByte()
if err != nil {
return 0, err
}
b6, err := b.ReadByte()
b6, err := r.ReadByte()
if err != nil {
return 0, err
}
b7, err := b.ReadByte()
b7, err := r.ReadByte()
if err != nil {
return 0, err
}
b8, err := b.ReadByte()
b8, err := r.ReadByte()
if err != nil {
return 0, err
}
return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil
}
// Write writes a number in the QUIC varint format
func Write(b *bytes.Buffer, i uint64) {
// Write writes i in the QUIC varint format to w.
func Write(w Writer, i uint64) {
if i <= maxVarInt1 {
b.WriteByte(uint8(i))
w.WriteByte(uint8(i))
} else if i <= maxVarInt2 {
b.Write([]byte{uint8(i>>8) | 0x40, uint8(i)})
w.Write([]byte{uint8(i>>8) | 0x40, uint8(i)})
} else if i <= maxVarInt4 {
b.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)})
w.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)})
} else if i <= maxVarInt8 {
b.Write([]byte{
w.Write([]byte{
uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
@ -83,35 +82,35 @@ func Write(b *bytes.Buffer, i uint64) {
}
}
// WriteWithLen writes a number in the QUIC varint format, with the desired length.
func WriteWithLen(b *bytes.Buffer, i uint64, length protocol.ByteCount) {
// WriteWithLen writes i in the QUIC varint format with the desired length to w.
func WriteWithLen(w Writer, i uint64, length protocol.ByteCount) {
if length != 1 && length != 2 && length != 4 && length != 8 {
panic("invalid varint length")
}
l := Len(i)
if l == length {
Write(b, i)
Write(w, i)
return
}
if l > length {
panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length))
}
if length == 2 {
b.WriteByte(0b01000000)
w.WriteByte(0b01000000)
} else if length == 4 {
b.WriteByte(0b10000000)
w.WriteByte(0b10000000)
} else if length == 8 {
b.WriteByte(0b11000000)
w.WriteByte(0b11000000)
}
for j := protocol.ByteCount(1); j < length-l; j++ {
b.WriteByte(0)
w.WriteByte(0)
}
for j := protocol.ByteCount(0); j < l; j++ {
b.WriteByte(uint8(i >> (8 * (l - 1 - j))))
w.WriteByte(uint8(i >> (8 * (l - 1 - j))))
}
}
// Len determines the number of bytes that will be needed to write a number
// Len determines the number of bytes that will be needed to write the number i.
func Len(i uint64) protocol.ByteCount {
if i <= maxVarInt1 {
return 1