mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
quicvarint: add Reader and Writer interfaces (#3233)
This commit is contained in:
parent
79ce9740a4
commit
346bd63a60
6 changed files with 171 additions and 53 deletions
|
@ -148,7 +148,7 @@ func (c *client) handleUnidirectionalStreams() {
|
|||
}
|
||||
|
||||
go func() {
|
||||
streamType, err := quicvarint.Read(&byteReaderImpl{str})
|
||||
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
||||
if err != nil {
|
||||
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
|
||||
return
|
||||
|
|
|
@ -10,33 +10,15 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
type byteReader interface {
|
||||
io.ByteReader
|
||||
io.Reader
|
||||
}
|
||||
|
||||
type byteReaderImpl struct{ io.Reader }
|
||||
|
||||
func (br *byteReaderImpl) ReadByte() (byte, error) {
|
||||
b := make([]byte, 1)
|
||||
if _, err := br.Reader.Read(b); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return b[0], nil
|
||||
}
|
||||
|
||||
type frame interface{}
|
||||
|
||||
func parseNextFrame(b io.Reader) (frame, error) {
|
||||
br, ok := b.(byteReader)
|
||||
if !ok {
|
||||
br = &byteReaderImpl{b}
|
||||
}
|
||||
t, err := quicvarint.Read(br)
|
||||
func parseNextFrame(r io.Reader) (frame, error) {
|
||||
qr := quicvarint.NewReader(r)
|
||||
t, err := quicvarint.Read(qr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l, err := quicvarint.Read(br)
|
||||
l, err := quicvarint.Read(qr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -47,7 +29,7 @@ func parseNextFrame(b io.Reader) (frame, error) {
|
|||
case 0x1:
|
||||
return &headersFrame{Length: l}, nil
|
||||
case 0x4:
|
||||
return parseSettingsFrame(br, l)
|
||||
return parseSettingsFrame(r, l)
|
||||
case 0x3: // CANCEL_PUSH
|
||||
fallthrough
|
||||
case 0x5: // PUSH_PROMISE
|
||||
|
@ -60,10 +42,10 @@ func parseNextFrame(b io.Reader) (frame, error) {
|
|||
fallthrough
|
||||
default:
|
||||
// skip over unknown frames
|
||||
if _, err := io.CopyN(ioutil.Discard, br, int64(l)); err != nil {
|
||||
if _, err := io.CopyN(ioutil.Discard, qr, int64(l)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parseNextFrame(b)
|
||||
return parseNextFrame(qr)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -281,7 +281,7 @@ func (s *Server) handleUnidirectionalStreams(sess quic.EarlySession) {
|
|||
}
|
||||
|
||||
go func(str quic.ReceiveStream) {
|
||||
streamType, err := quicvarint.Read(&byteReaderImpl{str})
|
||||
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
||||
if err != nil {
|
||||
s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
|
||||
return
|
||||
|
|
65
quicvarint/io.go
Normal file
65
quicvarint/io.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package quicvarint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Reader implements both the io.ByteReader and io.Reader interfaces.
|
||||
type Reader interface {
|
||||
io.ByteReader
|
||||
io.Reader
|
||||
}
|
||||
|
||||
var _ Reader = &bytes.Reader{}
|
||||
|
||||
type byteReader struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
var _ Reader = &byteReader{}
|
||||
|
||||
// NewReader returns a Reader for r.
|
||||
// If r already implements both io.ByteReader and io.Reader, NewReader returns r.
|
||||
// Otherwise, r is wrapped to add the missing interfaces.
|
||||
func NewReader(r io.Reader) Reader {
|
||||
if r, ok := r.(Reader); ok {
|
||||
return r
|
||||
}
|
||||
return &byteReader{r}
|
||||
}
|
||||
|
||||
func (r *byteReader) ReadByte() (byte, error) {
|
||||
var b [1]byte
|
||||
_, err := r.Reader.Read(b[:])
|
||||
return b[0], err
|
||||
}
|
||||
|
||||
// Writer implements both the io.ByteWriter and io.Writer interfaces.
|
||||
type Writer interface {
|
||||
io.ByteWriter
|
||||
io.Writer
|
||||
}
|
||||
|
||||
var _ Writer = &bytes.Buffer{}
|
||||
|
||||
type byteWriter struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
var _ Writer = &byteWriter{}
|
||||
|
||||
// NewWriter returns a Writer for w.
|
||||
// If r already implements both io.ByteWriter and io.Writer, NewWriter returns w.
|
||||
// Otherwise, w is wrapped to add the missing interfaces.
|
||||
func NewWriter(w io.Writer) Writer {
|
||||
if w, ok := w.(Writer); ok {
|
||||
return w
|
||||
}
|
||||
return &byteWriter{w}
|
||||
}
|
||||
|
||||
func (w *byteWriter) WriteByte(c byte) error {
|
||||
_, err := w.Writer.Write([]byte{c})
|
||||
return err
|
||||
}
|
72
quicvarint/io_test.go
Normal file
72
quicvarint/io_test.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package quicvarint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type nopReader struct{}
|
||||
|
||||
func (r *nopReader) Read(_ []byte) (int, error) {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
var _ io.Reader = &nopReader{}
|
||||
|
||||
type nopWriter struct{}
|
||||
|
||||
func (r *nopWriter) Write(_ []byte) (int, error) {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
var _ io.Writer = &nopWriter{}
|
||||
|
||||
var _ = Describe("Varint I/O", func() {
|
||||
Context("Reader", func() {
|
||||
Context("NewReader", func() {
|
||||
It("passes through a Reader unchanged", func() {
|
||||
b := bytes.NewReader([]byte{0})
|
||||
r := NewReader(b)
|
||||
Expect(r).To(Equal(b))
|
||||
})
|
||||
|
||||
It("wraps an io.Reader", func() {
|
||||
n := &nopReader{}
|
||||
r := NewReader(n)
|
||||
Expect(r).ToNot(Equal(n))
|
||||
})
|
||||
})
|
||||
|
||||
It("returns an error when reading from an underlying io.Reader fails", func() {
|
||||
r := NewReader(&nopReader{})
|
||||
val, err := r.ReadByte()
|
||||
Expect(err).To(Equal(io.ErrUnexpectedEOF))
|
||||
Expect(val).To(Equal(byte(0)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Writer", func() {
|
||||
Context("NewWriter", func() {
|
||||
It("passes through a Writer unchanged", func() {
|
||||
b := &bytes.Buffer{}
|
||||
w := NewWriter(b)
|
||||
Expect(w).To(Equal(b))
|
||||
})
|
||||
|
||||
It("wraps an io.Writer", func() {
|
||||
n := &nopWriter{}
|
||||
w := NewWriter(n)
|
||||
Expect(w).ToNot(Equal(n))
|
||||
})
|
||||
})
|
||||
|
||||
It("returns an error when writing to an underlying io.Writer fails", func() {
|
||||
w := NewWriter(&nopWriter{})
|
||||
err := w.WriteByte(0)
|
||||
Expect(err).To(Equal(io.ErrShortBuffer))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -1,7 +1,6 @@
|
|||
package quicvarint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
|
@ -16,9 +15,9 @@ const (
|
|||
maxVarInt8 = 4611686018427387903
|
||||
)
|
||||
|
||||
// Read reads a number in the QUIC varint format
|
||||
func Read(b io.ByteReader) (uint64, error) {
|
||||
firstByte, err := b.ReadByte()
|
||||
// Read reads a number in the QUIC varint format from r.
|
||||
func Read(r io.ByteReader) (uint64, error) {
|
||||
firstByte, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -28,53 +27,53 @@ func Read(b io.ByteReader) (uint64, error) {
|
|||
if len == 1 {
|
||||
return uint64(b1), nil
|
||||
}
|
||||
b2, err := b.ReadByte()
|
||||
b2, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len == 2 {
|
||||
return uint64(b2) + uint64(b1)<<8, nil
|
||||
}
|
||||
b3, err := b.ReadByte()
|
||||
b3, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b4, err := b.ReadByte()
|
||||
b4, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len == 4 {
|
||||
return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil
|
||||
}
|
||||
b5, err := b.ReadByte()
|
||||
b5, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b6, err := b.ReadByte()
|
||||
b6, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b7, err := b.ReadByte()
|
||||
b7, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b8, err := b.ReadByte()
|
||||
b8, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil
|
||||
}
|
||||
|
||||
// Write writes a number in the QUIC varint format
|
||||
func Write(b *bytes.Buffer, i uint64) {
|
||||
// Write writes i in the QUIC varint format to w.
|
||||
func Write(w Writer, i uint64) {
|
||||
if i <= maxVarInt1 {
|
||||
b.WriteByte(uint8(i))
|
||||
w.WriteByte(uint8(i))
|
||||
} else if i <= maxVarInt2 {
|
||||
b.Write([]byte{uint8(i>>8) | 0x40, uint8(i)})
|
||||
w.Write([]byte{uint8(i>>8) | 0x40, uint8(i)})
|
||||
} else if i <= maxVarInt4 {
|
||||
b.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
||||
w.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
||||
} else if i <= maxVarInt8 {
|
||||
b.Write([]byte{
|
||||
w.Write([]byte{
|
||||
uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
|
||||
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
|
||||
})
|
||||
|
@ -83,35 +82,35 @@ func Write(b *bytes.Buffer, i uint64) {
|
|||
}
|
||||
}
|
||||
|
||||
// WriteWithLen writes a number in the QUIC varint format, with the desired length.
|
||||
func WriteWithLen(b *bytes.Buffer, i uint64, length protocol.ByteCount) {
|
||||
// WriteWithLen writes i in the QUIC varint format with the desired length to w.
|
||||
func WriteWithLen(w Writer, i uint64, length protocol.ByteCount) {
|
||||
if length != 1 && length != 2 && length != 4 && length != 8 {
|
||||
panic("invalid varint length")
|
||||
}
|
||||
l := Len(i)
|
||||
if l == length {
|
||||
Write(b, i)
|
||||
Write(w, i)
|
||||
return
|
||||
}
|
||||
if l > length {
|
||||
panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length))
|
||||
}
|
||||
if length == 2 {
|
||||
b.WriteByte(0b01000000)
|
||||
w.WriteByte(0b01000000)
|
||||
} else if length == 4 {
|
||||
b.WriteByte(0b10000000)
|
||||
w.WriteByte(0b10000000)
|
||||
} else if length == 8 {
|
||||
b.WriteByte(0b11000000)
|
||||
w.WriteByte(0b11000000)
|
||||
}
|
||||
for j := protocol.ByteCount(1); j < length-l; j++ {
|
||||
b.WriteByte(0)
|
||||
w.WriteByte(0)
|
||||
}
|
||||
for j := protocol.ByteCount(0); j < l; j++ {
|
||||
b.WriteByte(uint8(i >> (8 * (l - 1 - j))))
|
||||
w.WriteByte(uint8(i >> (8 * (l - 1 - j))))
|
||||
}
|
||||
}
|
||||
|
||||
// Len determines the number of bytes that will be needed to write a number
|
||||
// Len determines the number of bytes that will be needed to write the number i.
|
||||
func Len(i uint64) protocol.ByteCount {
|
||||
if i <= maxVarInt1 {
|
||||
return 1
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue