Update deps

This commit is contained in:
Frank Denis 2024-06-03 08:40:06 +02:00
parent 35d7aa0603
commit 0059194a9e
92 changed files with 19298 additions and 13340 deletions

View file

@ -1,6 +1,3 @@
run:
skip-files:
- internal/handshake/cipher_suite.go
linters-settings:
misspell:
ignore-words:
@ -26,6 +23,7 @@ linters:
- gofmt # redundant, since gofmt *should* be a no-op after gofumpt
- gofumpt
- gosimple
- govet
- ineffassign
- misspell
- prealloc
@ -34,10 +32,14 @@ linters:
- unconvert
- unparam
- unused
- vet
issues:
exclude-files:
- internal/handshake/cipher_suite.go
exclude-rules:
- path: internal/qtls
linters:
- depguard
- path: _test\.go
linters:
- exhaustive

View file

@ -39,6 +39,12 @@ func validateConfig(config *Config) error {
if config.MaxConnectionReceiveWindow > quicvarint.Max {
config.MaxConnectionReceiveWindow = quicvarint.Max
}
if config.InitialPacketSize > 0 && config.InitialPacketSize < protocol.MinInitialPacketSize {
config.InitialPacketSize = protocol.MinInitialPacketSize
}
if config.InitialPacketSize > protocol.MaxPacketBufferSize {
config.InitialPacketSize = protocol.MaxPacketBufferSize
}
// check that all QUIC versions are actually supported
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
@ -94,6 +100,10 @@ func populateConfig(config *Config) *Config {
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
initialPacketSize := config.InitialPacketSize
if initialPacketSize == 0 {
initialPacketSize = protocol.InitialPacketSize
}
return &Config{
GetConfigForClient: config.GetConfigForClient,
@ -110,6 +120,7 @@ func populateConfig(config *Config) *Config {
MaxIncomingUniStreams: maxIncomingUniStreams,
TokenStore: config.TokenStore,
EnableDatagrams: config.EnableDatagrams,
InitialPacketSize: initialPacketSize,
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
Allow0RTT: config.Allow0RTT,
Tracer: config.Tracer,

View file

@ -153,7 +153,9 @@ type connection struct {
unpacker unpacker
frameParser wire.FrameParser
packer packer
mtuDiscoverer mtuDiscoverer // initialized when the handshake completes
mtuDiscoverer mtuDiscoverer // initialized when the transport parameters are received
maxPayloadSizeEstimate atomic.Uint32
initialStream cryptoStream
handshakeStream cryptoStream
@ -276,7 +278,7 @@ var newConnection = func(
s.preSetup()
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
0,
getMaxPacketSize(s.conn.RemoteAddr()),
protocol.ByteCount(s.config.InitialPacketSize),
s.rttStats,
clientAddressValidated,
s.conn.capabilities().ECN,
@ -284,7 +286,7 @@ var newConnection = func(
s.tracer,
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
params := &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@ -295,6 +297,7 @@ var newConnection = func(
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
AckDelayExponent: protocol.AckDelayExponent,
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
DisableActiveMigration: true,
StatelessResetToken: &statelessResetToken,
OriginalDestinationConnectionID: origDestConnID,
@ -385,7 +388,7 @@ var newClientConnection = func(
s.preSetup()
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
initialPacketNumber,
getMaxPacketSize(s.conn.RemoteAddr()),
protocol.ByteCount(s.config.InitialPacketSize),
s.rttStats,
false, // has no effect
s.conn.capabilities().ECN,
@ -393,7 +396,7 @@ var newClientConnection = func(
s.tracer,
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@ -404,6 +407,7 @@ var newClientConnection = func(
MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams),
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
AckDelayExponent: protocol.AckDelayExponent,
DisableActiveMigration: true,
// For interoperability with quic-go versions before May 2023, this value must be set to a value
@ -781,11 +785,7 @@ func (s *connection) handleHandshakeConfirmed() error {
s.cryptoStreamHandler.SetHandshakeConfirmed()
if !s.config.DisablePathMTUDiscovery && s.conn.capabilities().DF {
maxPacketSize := s.peerParams.MaxUDPPayloadSize
if maxPacketSize == 0 {
maxPacketSize = protocol.MaxByteCount
}
s.mtuDiscoverer.Start(min(maxPacketSize, protocol.MaxPacketBufferSize))
s.mtuDiscoverer.Start()
}
return nil
}
@ -1774,6 +1774,17 @@ func (s *connection) applyTransportParameters() {
// Retire the connection ID.
s.connIDManager.AddFromPreferredAddress(params.PreferredAddress.ConnectionID, params.PreferredAddress.StatelessResetToken)
}
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
if params.MaxUDPPayloadSize > 0 && params.MaxUDPPayloadSize < maxPacketSize {
maxPacketSize = params.MaxUDPPayloadSize
}
s.mtuDiscoverer = newMTUDiscoverer(
s.rttStats,
protocol.ByteCount(s.config.InitialPacketSize),
maxPacketSize,
s.onMTUIncreased,
s.tracer,
)
}
func (s *connection) triggerSending(now time.Time) error {
@ -1862,7 +1873,7 @@ func (s *connection) sendPackets(now time.Time) error {
}
if !s.handshakeConfirmed {
packet, err := s.packer.PackCoalescedPacket(false, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err := s.packer.PackCoalescedPacket(false, s.maxPacketSize(), s.version)
if err != nil || packet == nil {
return err
}
@ -1889,7 +1900,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
for {
buf := getPacketBuffer()
ecn := s.sentPacketHandler.ECNMode(true)
if _, err := s.appendOneShortHeaderPacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil {
if _, err := s.appendOneShortHeaderPacket(buf, s.maxPacketSize(), ecn, now); err != nil {
if err == errNothingToPack {
buf.Release()
return nil
@ -1920,7 +1931,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
func (s *connection) sendPacketsWithGSO(now time.Time) error {
buf := getLargePacketBuffer()
maxSize := s.mtuDiscoverer.CurrentSize()
maxSize := s.maxPacketSize()
ecn := s.sentPacketHandler.ECNMode(true)
for {
@ -1989,7 +2000,7 @@ func (s *connection) resetPacingDeadline() {
func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
if !s.handshakeConfirmed {
ecn := s.sentPacketHandler.ECNMode(false)
packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err := s.packer.PackCoalescedPacket(true, s.maxPacketSize(), s.version)
if err != nil {
return err
}
@ -2000,7 +2011,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
}
ecn := s.sentPacketHandler.ECNMode(true)
p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version)
p, buf, err := s.packer.PackAckOnlyPacket(s.maxPacketSize(), s.version)
if err != nil {
if err == errNothingToPack {
return nil
@ -2022,7 +2033,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
break
}
var err error
packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), s.version)
if err != nil {
return err
}
@ -2033,7 +2044,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
if packet == nil {
s.retransmissionQueue.AddPing(encLevel)
var err error
packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), s.version)
if err != nil {
return err
}
@ -2112,14 +2123,14 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
var transportErr *qerr.TransportError
var applicationErr *qerr.ApplicationError
if errors.As(e, &transportErr) {
packet, err = s.packer.PackConnectionClose(transportErr, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.PackConnectionClose(transportErr, s.maxPacketSize(), s.version)
} else if errors.As(e, &applicationErr) {
packet, err = s.packer.PackApplicationClose(applicationErr, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.PackApplicationClose(applicationErr, s.maxPacketSize(), s.version)
} else {
packet, err = s.packer.PackConnectionClose(&qerr.TransportError{
ErrorCode: qerr.InternalError,
ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()),
}, s.mtuDiscoverer.CurrentSize(), s.version)
}, s.maxPacketSize(), s.version)
}
if err != nil {
return nil, err
@ -2129,6 +2140,24 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn)
}
func (s *connection) maxPacketSize() protocol.ByteCount {
if s.mtuDiscoverer == nil {
// Use the configured packet size on the client side.
// If the server sends a max_udp_payload_size that's smaller than this size, we can ignore this:
// Apparently the server still processed the (fully padded) Initial packet anyway.
if s.perspective == protocol.PerspectiveClient {
return protocol.ByteCount(s.config.InitialPacketSize)
}
// On the server side, there's no downside to using 1200 bytes until we received the client's transport
// parameters:
// * If the first packet didn't contain the entire ClientHello, all we can do is ACK that packet. We don't
// need a lot of bytes for that.
// * If it did, we will have processed the transport parameters and initialized the MTU discoverer.
return protocol.MinInitialPacketSize
}
return s.mtuDiscoverer.CurrentSize()
}
func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) {
// quic-go logging
if s.logger.Debug() {
@ -2352,13 +2381,23 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) {
}
}
func (s *connection) onMTUIncreased(mtu protocol.ByteCount) {
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(mtu)))
s.sentPacketHandler.SetMaxDatagramSize(mtu)
}
func (s *connection) SendDatagram(p []byte) error {
if !s.supportsDatagrams() {
return errors.New("datagram support disabled")
}
f := &wire.DatagramFrame{DataLenPresent: true}
maxDataLen := f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version)
// The payload size estimate is conservative.
// Under many circumstances we could send a few more bytes.
maxDataLen := min(
f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version),
protocol.ByteCount(s.maxPayloadSizeEstimate.Load()),
)
if protocol.ByteCount(len(p)) > maxDataLen {
return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)}
}
@ -2391,3 +2430,10 @@ func (s *connection) NextConnection() Connection {
s.streamsMap.UseResetMaps()
return s
}
// estimateMaxPayloadSize estimates the maximum payload size for short header packets.
// It is not very sophisticated: it just subtracts the size of header (assuming the maximum
// connection ID length), and the size of the encryption tag.
func estimateMaxPayloadSize(mtu protocol.ByteCount) protocol.ByteCount {
return mtu - 1 /* type byte */ - 20 /* maximum connection ID length */ - 16 /* tag size */
}

View file

@ -21,13 +21,29 @@ func (r *exactReader) Read(b []byte) (int, error) {
return n, err
}
type countingByteReader struct {
io.ByteReader
Read int
}
func (r *countingByteReader) ReadByte() (byte, error) {
b, err := r.ByteReader.ReadByte()
if err == nil {
r.Read++
}
return b, err
}
// ParseCapsule parses the header of a Capsule.
// It returns an io.LimitedReader that can be used to read the Capsule value.
// The Capsule value must be read entirely (i.e. until the io.EOF) before using r again.
func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) {
ct, err := quicvarint.Read(r)
cbr := countingByteReader{ByteReader: r}
ct, err := quicvarint.Read(&cbr)
if err != nil {
if err == io.EOF {
// If an io.EOF is returned without consuming any bytes, return it unmodified.
// Otherwise, return an io.ErrUnexpectedEOF.
if err == io.EOF && cbr.Read > 0 {
return 0, nil, io.ErrUnexpectedEOF
}
return 0, nil, err

View file

@ -121,12 +121,16 @@ func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() {
}
return
}
go func(str quic.Stream) {
_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
fp := &frameParser{
r: str,
conn: c.hconn,
unknownFrameHandler: func(ft FrameType, e error) (processed bool, err error) {
id := c.hconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
return c.StreamHijacker(ft, id, str, e)
})
if err == errHijacked {
},
}
go func() {
if _, err := fp.ParseNext(); err == errHijacked {
return
}
if err != nil {
@ -135,7 +139,7 @@ func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() {
}
}
c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
}(str)
}()
}
}

View file

@ -1,7 +1,6 @@
package http3
import (
"bytes"
"context"
"fmt"
"log/slog"
@ -71,27 +70,11 @@ func newConnection(
return c
}
func (c *connection) onStreamStateChange(id quic.StreamID, state streamState, e error) {
func (c *connection) clearStream(id quic.StreamID) {
c.streamMx.Lock()
defer c.streamMx.Unlock()
d, ok := c.streams[id]
if !ok { // should never happen
return
}
var isDone bool
//nolint:exhaustive // These are all the cases we care about.
switch state {
case streamStateReceiveClosed:
isDone = d.SetReceiveError(e)
case streamStateSendClosed:
isDone = d.SetSendError(e)
default:
return
}
if isDone {
delete(c.streams, id)
}
delete(c.streams, id)
}
func (c *connection) openRequestStream(
@ -109,7 +92,7 @@ func (c *connection) openRequestStream(
c.streamMx.Lock()
c.streams[str.StreamID()] = datagrams
c.streamMx.Unlock()
qstr := newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) })
qstr := newStateTrackingStream(str, c, datagrams)
hstr := newStream(qstr, c, datagrams)
return newRequestStream(hstr, requestWriter, reqDone, c.decoder, disableCompression, maxHeaderBytes), nil
}
@ -121,10 +104,11 @@ func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagramme
}
datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })
if c.perspective == protocol.PerspectiveServer {
strID := str.StreamID()
c.streamMx.Lock()
c.streams[str.StreamID()] = datagrams
c.streams[strID] = datagrams
c.streamMx.Unlock()
str = newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) })
str = newStateTrackingStream(str, c, datagrams)
}
return str, datagrams, nil
}
@ -201,7 +185,8 @@ func (c *connection) HandleUnidirectionalStreams(hijack func(StreamType, quic.Co
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream")
return
}
f, err := parseNextFrame(str, nil)
fp := &frameParser{conn: c.Connection, r: str}
f, err := fp.ParseNext()
if err != nil {
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
return
@ -252,9 +237,7 @@ func (c *connection) receiveDatagrams() error {
if err != nil {
return err
}
// TODO: this is quite wasteful in terms of allocations
r := bytes.NewReader(b)
quarterStreamID, err := quicvarint.Read(r)
quarterStreamID, n, err := quicvarint.Parse(b)
if err != nil {
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
return fmt.Errorf("could not read quarter stream id: %w", err)
@ -271,7 +254,7 @@ func (c *connection) receiveDatagrams() error {
return nil
}
c.streamMx.Unlock()
dg.enqueue(b[len(b)-r.Len():])
dg.enqueue(b[n:])
}
}

View file

@ -27,21 +27,19 @@ func newDatagrammer(sendDatagram func([]byte) error) *datagrammer {
}
}
func (d *datagrammer) SetReceiveError(err error) (isDone bool) {
func (d *datagrammer) SetReceiveError(err error) {
d.mx.Lock()
defer d.mx.Unlock()
d.receiveErr = err
d.signalHasData()
return d.sendErr != nil
}
func (d *datagrammer) SetSendError(err error) (isDone bool) {
func (d *datagrammer) SetSendError(err error) {
d.mx.Lock()
defer d.mx.Unlock()
d.sendErr = err
return d.receiveErr != nil
}
func (d *datagrammer) Send(b []byte) error {
@ -85,9 +83,9 @@ start:
d.mx.Unlock()
return data, nil
}
if d.receiveErr != nil {
if receiveErr := d.receiveErr; receiveErr != nil {
d.mx.Unlock()
return nil, d.receiveErr
return nil, receiveErr
}
d.mx.Unlock()

View file

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/quicvarint"
)
@ -18,13 +19,19 @@ type frame interface{}
var errHijacked = errors.New("hijacked")
func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) {
qr := quicvarint.NewReader(r)
type frameParser struct {
r io.Reader
conn quic.Connection
unknownFrameHandler unknownFrameHandlerFunc
}
func (p *frameParser) ParseNext() (frame, error) {
qr := quicvarint.NewReader(p.r)
for {
t, err := quicvarint.Read(qr)
if err != nil {
if unknownFrameHandler != nil {
hijacked, err := unknownFrameHandler(0, err)
if p.unknownFrameHandler != nil {
hijacked, err := p.unknownFrameHandler(0, err)
if err != nil {
return nil, err
}
@ -35,8 +42,8 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f
return nil, err
}
// Call the unknownFrameHandler for frames not defined in the HTTP/3 spec
if t > 0xd && unknownFrameHandler != nil {
hijacked, err := unknownFrameHandler(FrameType(t), nil)
if t > 0xd && p.unknownFrameHandler != nil {
hijacked, err := p.unknownFrameHandler(FrameType(t), nil)
if err != nil {
return nil, err
}
@ -56,11 +63,14 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f
case 0x1:
return &headersFrame{Length: l}, nil
case 0x4:
return parseSettingsFrame(r, l)
return parseSettingsFrame(p.r, l)
case 0x3: // CANCEL_PUSH
case 0x5: // PUSH_PROMISE
case 0x7: // GOAWAY
case 0xd: // MAX_PUSH_ID
case 0x2, 0x6, 0x8, 0x9:
p.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
return nil, fmt.Errorf("http3: reserved frame type: %d", t)
}
// skip over unknown frames
if _, err := io.CopyN(io.Discard, qr, int64(l)); err != nil {

View file

@ -63,10 +63,14 @@ func newStream(str quic.Stream, conn *connection, datagrams *datagrammer) *strea
}
func (s *stream) Read(b []byte) (int, error) {
fp := &frameParser{
r: s.Stream,
conn: s.conn,
}
if s.bytesRemainingInFrame == 0 {
parseLoop:
for {
frame, err := parseNextFrame(s.Stream, nil)
frame, err := fp.ParseNext()
if err != nil {
return 0, err
}
@ -177,7 +181,11 @@ func (s *requestStream) SendRequestHeader(req *http.Request) error {
}
func (s *requestStream) ReadResponse() (*http.Response, error) {
frame, err := parseNextFrame(s.Stream, nil)
fp := &frameParser{
r: s.Stream,
conn: s.conn,
}
frame, err := fp.ParseNext()
if err != nil {
s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
@ -250,7 +258,7 @@ func (s *requestStream) ReadResponse() (*http.Response, error) {
func (s *stream) SendDatagram(b []byte) error {
// TODO: reject if datagrams are not negotiated (yet)
return s.conn.sendDatagram(s.Stream.StreamID(), b)
return s.datagrams.Send(b)
}
func (s *stream) ReceiveDatagram(ctx context.Context) ([]byte, error) {

View file

@ -477,7 +477,8 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
)
}
}
frame, err := parseNextFrame(str, ufh)
fp := &frameParser{conn: conn, r: str, unknownFrameHandler: ufh}
frame, err := fp.ParseNext()
if err != nil {
if !errors.Is(err, errHijacked) {
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
@ -665,11 +666,16 @@ func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) er
return server.ListenAndServeTLS(certFile, keyFile)
}
// ListenAndServe listens on the given network address for both TLS/TCP and QUIC
// Deprecated: use ListenAndServeTLS instead.
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
return ListenAndServeTLS(addr, certFile, keyFile, handler)
}
// ListenAndServeTLS listens on the given network address for both TLS/TCP and QUIC
// connections in parallel. It returns if one of the two returns an error.
// http.DefaultServeMux is used when handler is nil.
// The correct Alt-Svc headers for QUIC are set.
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
func ListenAndServeTLS(addr, certFile, keyFile string, handler http.Handler) error {
// Load certs
var err error
certs := make([]tls.Certificate, 1)

View file

@ -1,62 +1,87 @@
package http3
import (
"context"
"errors"
"os"
"sync"
"github.com/quic-go/quic-go"
)
type streamState uint8
const (
streamStateOpen streamState = iota
streamStateReceiveClosed
streamStateSendClosed
streamStateSendAndReceiveClosed
)
var _ quic.Stream = &stateTrackingStream{}
// stateTrackingStream is an implementation of quic.Stream that delegates
// to an underlying stream
// it takes care of proxying send and receive errors onto an implementation of
// the errorSetter interface (intended to be occupied by a datagrammer)
// it is also responsible for clearing the stream based on its ID from its
// parent connection, this is done through the streamClearer interface when
// both the send and receive sides are closed
type stateTrackingStream struct {
quic.Stream
mx sync.Mutex
state streamState
mx sync.Mutex
sendErr error
recvErr error
onStateChange func(streamState, error)
clearer streamClearer
setter errorSetter
}
func newStateTrackingStream(s quic.Stream, onStateChange func(streamState, error)) *stateTrackingStream {
return &stateTrackingStream{
Stream: s,
state: streamStateOpen,
onStateChange: onStateChange,
type streamClearer interface {
clearStream(quic.StreamID)
}
type errorSetter interface {
SetSendError(error)
SetReceiveError(error)
}
func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream {
t := &stateTrackingStream{
Stream: s,
clearer: clearer,
setter: setter,
}
}
var _ quic.Stream = &stateTrackingStream{}
context.AfterFunc(s.Context(), func() {
t.closeSend(context.Cause(s.Context()))
})
return t
}
func (s *stateTrackingStream) closeSend(e error) {
s.mx.Lock()
defer s.mx.Unlock()
if s.state == streamStateReceiveClosed || s.state == streamStateSendAndReceiveClosed {
s.state = streamStateSendAndReceiveClosed
} else {
s.state = streamStateSendClosed
// clear the stream the first time both the send
// and receive are finished
if s.sendErr == nil {
if s.recvErr != nil {
s.clearer.clearStream(s.StreamID())
}
s.setter.SetSendError(e)
s.sendErr = e
}
s.onStateChange(s.state, e)
}
func (s *stateTrackingStream) closeReceive(e error) {
s.mx.Lock()
defer s.mx.Unlock()
if s.state == streamStateSendClosed || s.state == streamStateSendAndReceiveClosed {
s.state = streamStateSendAndReceiveClosed
} else {
s.state = streamStateReceiveClosed
// clear the stream the first time both the send
// and receive are finished
if s.recvErr == nil {
if s.sendErr != nil {
s.clearer.clearStream(s.StreamID())
}
s.setter.SetReceiveError(e)
s.recvErr = e
}
s.onStateChange(s.state, e)
}
func (s *stateTrackingStream) Close() error {
@ -71,7 +96,7 @@ func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) {
func (s *stateTrackingStream) Write(b []byte) (int, error) {
n, err := s.Stream.Write(b)
if err != nil {
if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
s.closeSend(err)
}
return n, err
@ -84,7 +109,7 @@ func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) {
func (s *stateTrackingStream) Read(b []byte) (int, error) {
n, err := s.Stream.Read(b)
if err != nil {
if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
s.closeReceive(err)
}
return n, err

View file

@ -325,10 +325,16 @@ type Config struct {
// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
// every half of MaxIdleTimeout, whichever is smaller).
KeepAlivePeriod time.Duration
// InitialPacketSize is the initial size of packets sent.
// It is usually not necessary to manually set this value,
// since Path MTU discovery very quickly finds the path's MTU.
// If set too high, the path might not support packets that large, leading to a timeout of the QUIC handshake.
// Values below 1200 are invalid.
InitialPacketSize uint16
// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
// This allows the sending of QUIC packets that fully utilize the available MTU of the path.
// Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
// If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
// If unavailable or disabled, packets will be at most 1280 bytes in size.
DisablePathMTUDiscovery bool
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
// Only valid for the server.

View file

@ -17,11 +17,11 @@ import (
// 1024*1024^3 (first 1024 is from 0.100^3)
// where 0.100 is 100 ms which is the scaling round trip time.
const (
cubeScale = 40
cubeCongestionWindowScale = 410
cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
cubeScale = 40
cubeCongestionWindowScale = 410
cubeFactor = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
// TODO: when re-enabling cubic, make sure to use the actual packet size here
maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSize)
)
const defaultNumConnections = 1

View file

@ -12,7 +12,7 @@ import (
const (
// maxDatagramSize is the default maximum packet size used in the Linux TCP implementation.
// Used in QUIC for congestion window computations in bytes.
initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSize)
maxBurstPackets = 3
renoBeta = 0.7 // Reno backoff factor.
minCongestionWindowPackets = 2

View file

@ -1,7 +1,6 @@
package handshake
import (
"bytes"
"context"
"crypto/tls"
"errors"
@ -338,25 +337,26 @@ func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (a
return false
}
func decodeDataFromSessionState(data []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
r := bytes.NewReader(data)
ver, err := quicvarint.Read(r)
func decodeDataFromSessionState(b []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
ver, l, err := quicvarint.Parse(b)
if err != nil {
return 0, nil, err
}
b = b[l:]
if ver != clientSessionStateRevision {
return 0, nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
}
rttEncoded, err := quicvarint.Read(r)
rttEncoded, l, err := quicvarint.Parse(b)
if err != nil {
return 0, nil, err
}
b = b[l:]
rtt := time.Duration(rttEncoded) * time.Microsecond
if !earlyData {
return rtt, nil, nil
}
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
return 0, nil, err
}
return rtt, &tp, nil

View file

@ -1,7 +1,6 @@
package handshake
import (
"bytes"
"errors"
"fmt"
"time"
@ -28,25 +27,26 @@ func (t *sessionTicket) Marshal() []byte {
}
func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
r := bytes.NewReader(b)
rev, err := quicvarint.Read(r)
rev, l, err := quicvarint.Parse(b)
if err != nil {
return errors.New("failed to read session ticket revision")
}
b = b[l:]
if rev != sessionTicketRevision {
return fmt.Errorf("unknown session ticket revision: %d", rev)
}
rtt, err := quicvarint.Read(r)
rtt, l, err := quicvarint.Parse(b)
if err != nil {
return errors.New("failed to read RTT")
}
b = b[l:]
if using0RTT {
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
}
t.Parameters = &tp
} else if r.Len() > 0 {
} else if len(b) > 0 {
return fmt.Errorf("the session ticket has more bytes than expected")
}
t.RTT = time.Duration(rtt) * time.Microsecond

View file

@ -8,11 +8,8 @@ const DesiredReceiveBufferSize = (1 << 20) * 7 // 7 MB
// DesiredSendBufferSize is the kernel UDP send buffer size that we'd like to use.
const DesiredSendBufferSize = (1 << 20) * 7 // 7 MB
// InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets.
const InitialPacketSizeIPv4 = 1252
// InitialPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
const InitialPacketSizeIPv6 = 1232
// InitialPacketSize is the initial (before Path MTU discovery) maximum packet size used.
const InitialPacketSize = 1280
// MaxCongestionWindowPackets is the maximum congestion window in packet.
const MaxCongestionWindowPackets = 10000

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"errors"
"sort"
"time"
@ -22,18 +21,21 @@ type AckFrame struct {
}
// parseAckFrame reads an ACK frame
func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.Version) error {
func parseAckFrame(frame *AckFrame, b []byte, typ uint64, ackDelayExponent uint8, _ protocol.Version) (int, error) {
startLen := len(b)
ecn := typ == ackECNFrameType
la, err := quicvarint.Read(r)
la, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
largestAcked := protocol.PacketNumber(la)
delay, err := quicvarint.Read(r)
delay, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
if delayTime < 0 {
@ -42,71 +44,78 @@ func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponen
}
frame.DelayTime = delayTime
numBlocks, err := quicvarint.Read(r)
numBlocks, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
// read the first ACK range
ab, err := quicvarint.Read(r)
ab, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
ackBlock := protocol.PacketNumber(ab)
if ackBlock > largestAcked {
return errors.New("invalid first ACK range")
return 0, errors.New("invalid first ACK range")
}
smallest := largestAcked - ackBlock
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked})
// read all the other ACK ranges
for i := uint64(0); i < numBlocks; i++ {
g, err := quicvarint.Read(r)
g, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
gap := protocol.PacketNumber(g)
if smallest < gap+2 {
return errInvalidAckRanges
return 0, errInvalidAckRanges
}
largest := smallest - gap - 2
ab, err := quicvarint.Read(r)
ab, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
ackBlock := protocol.PacketNumber(ab)
if ackBlock > largest {
return errInvalidAckRanges
return 0, errInvalidAckRanges
}
smallest = largest - ackBlock
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest})
}
if !frame.validateAckRanges() {
return errInvalidAckRanges
return 0, errInvalidAckRanges
}
if ecn {
ect0, err := quicvarint.Read(r)
ect0, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
frame.ECT0 = ect0
ect1, err := quicvarint.Read(r)
ect1, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
frame.ECT1 = ect1
ecnce, err := quicvarint.Read(r)
ecnce, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
frame.ECNCE = ecnce
}
return nil
return startLen - len(b), nil
}
// Append appends an ACK frame.

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"io"
"github.com/quic-go/quic-go/internal/protocol"
@ -16,40 +15,38 @@ type ConnectionCloseFrame struct {
ReasonPhrase string
}
func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, error) {
func parseConnectionCloseFrame(b []byte, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, int, error) {
startLen := len(b)
f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType}
ec, err := quicvarint.Read(r)
ec, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
f.ErrorCode = ec
// read the Frame Type, if this is not an application error
if !f.IsApplicationError {
ft, err := quicvarint.Read(r)
ft, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
f.FrameType = ft
}
var reasonPhraseLen uint64
reasonPhraseLen, err = quicvarint.Read(r)
reasonPhraseLen, l, err = quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
// shortcut to prevent the unnecessary allocation of dataLen bytes
// if the dataLen is larger than the remaining length of the packet
// reading the whole reason phrase would result in EOF when attempting to READ
if int(reasonPhraseLen) > r.Len() {
return nil, io.EOF
b = b[l:]
if int(reasonPhraseLen) > len(b) {
return nil, 0, io.EOF
}
reasonPhrase := make([]byte, reasonPhraseLen)
if _, err := io.ReadFull(r, reasonPhrase); err != nil {
// this should never happen, since we already checked the reasonPhraseLen earlier
return nil, err
}
copy(reasonPhrase, b)
f.ReasonPhrase = string(reasonPhrase)
return f, nil
return f, startLen - len(b) + int(reasonPhraseLen), nil
}
// Length of a written frame

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"io"
"github.com/quic-go/quic-go/internal/protocol"
@ -14,28 +13,28 @@ type CryptoFrame struct {
Data []byte
}
func parseCryptoFrame(r *bytes.Reader, _ protocol.Version) (*CryptoFrame, error) {
func parseCryptoFrame(b []byte, _ protocol.Version) (*CryptoFrame, int, error) {
startLen := len(b)
frame := &CryptoFrame{}
offset, err := quicvarint.Read(r)
offset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
frame.Offset = protocol.ByteCount(offset)
dataLen, err := quicvarint.Read(r)
dataLen, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
if dataLen > uint64(r.Len()) {
return nil, io.EOF
b = b[l:]
if dataLen > uint64(len(b)) {
return nil, 0, io.EOF
}
if dataLen != 0 {
frame.Data = make([]byte, dataLen)
if _, err := io.ReadFull(r, frame.Data); err != nil {
// this should never happen, since we already checked the dataLen earlier
return nil, err
}
copy(frame.Data, b)
}
return frame, nil
return frame, startLen - len(b) + int(dataLen), nil
}
func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View file

@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
@ -12,12 +10,12 @@ type DataBlockedFrame struct {
MaximumData protocol.ByteCount
}
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*DataBlockedFrame, error) {
offset, err := quicvarint.Read(r)
func parseDataBlockedFrame(b []byte, _ protocol.Version) (*DataBlockedFrame, int, error) {
offset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, nil
return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, l, nil
}
func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, error) {

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"io"
"github.com/quic-go/quic-go/internal/protocol"
@ -20,29 +19,29 @@ type DatagramFrame struct {
Data []byte
}
func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*DatagramFrame, error) {
func parseDatagramFrame(b []byte, typ uint64, _ protocol.Version) (*DatagramFrame, int, error) {
startLen := len(b)
f := &DatagramFrame{}
f.DataLenPresent = typ&0x1 > 0
var length uint64
if f.DataLenPresent {
var err error
len, err := quicvarint.Read(r)
var l int
length, l, err = quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
if len > uint64(r.Len()) {
return nil, io.EOF
b = b[l:]
if length > uint64(len(b)) {
return nil, 0, io.EOF
}
length = len
} else {
length = uint64(r.Len())
length = uint64(len(b))
}
f.Data = make([]byte, length)
if _, err := io.ReadFull(r, f.Data); err != nil {
return nil, err
}
return f, nil
copy(f.Data, b)
return f, startLen - len(b) + int(length), nil
}
func (f *DatagramFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View file

@ -1,9 +1,9 @@
package wire
import (
"bytes"
"errors"
"fmt"
"io"
"reflect"
"github.com/quic-go/quic-go/internal/protocol"
@ -38,8 +38,6 @@ const (
// The FrameParser parses QUIC frames, one by one.
type FrameParser struct {
r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them
ackDelayExponent uint8
supportsDatagrams bool
@ -51,7 +49,6 @@ type FrameParser struct {
// NewFrameParser creates a new frame parser.
func NewFrameParser(supportsDatagrams bool) *FrameParser {
return &FrameParser{
r: *bytes.NewReader(nil),
supportsDatagrams: supportsDatagrams,
ackFrame: &AckFrame{},
}
@ -60,45 +57,46 @@ func NewFrameParser(supportsDatagrams bool) *FrameParser {
// ParseNext parses the next frame.
// It skips PADDING frames.
func (p *FrameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (int, Frame, error) {
startLen := len(data)
p.r.Reset(data)
frame, err := p.parseNext(&p.r, encLevel, v)
n := startLen - p.r.Len()
p.r.Reset(nil)
return n, frame, err
frame, l, err := p.parseNext(data, encLevel, v)
return l, frame, err
}
func (p *FrameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
for r.Len() != 0 {
typ, err := quicvarint.Read(r)
func (p *FrameParser) parseNext(b []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) {
var parsed int
for len(b) != 0 {
typ, l, err := quicvarint.Parse(b)
parsed += l
if err != nil {
return nil, &qerr.TransportError{
return nil, parsed, &qerr.TransportError{
ErrorCode: qerr.FrameEncodingError,
ErrorMessage: err.Error(),
}
}
b = b[l:]
if typ == 0x0 { // skip PADDING frames
continue
}
f, err := p.parseFrame(r, typ, encLevel, v)
f, l, err := p.parseFrame(b, typ, encLevel, v)
parsed += l
if err != nil {
return nil, &qerr.TransportError{
return nil, parsed, &qerr.TransportError{
FrameType: typ,
ErrorCode: qerr.FrameEncodingError,
ErrorMessage: err.Error(),
}
}
return f, nil
return f, parsed, nil
}
return nil, nil
return nil, parsed, nil
}
func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
func (p *FrameParser) parseFrame(b []byte, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) {
var frame Frame
var err error
var l int
if typ&0xf8 == 0x8 {
frame, err = parseStreamFrame(r, typ, v)
frame, l, err = parseStreamFrame(b, typ, v)
} else {
switch typ {
case pingFrameType:
@ -109,43 +107,43 @@ func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
ackDelayExponent = protocol.DefaultAckDelayExponent
}
p.ackFrame.Reset()
err = parseAckFrame(p.ackFrame, r, typ, ackDelayExponent, v)
l, err = parseAckFrame(p.ackFrame, b, typ, ackDelayExponent, v)
frame = p.ackFrame
case resetStreamFrameType:
frame, err = parseResetStreamFrame(r, v)
frame, l, err = parseResetStreamFrame(b, v)
case stopSendingFrameType:
frame, err = parseStopSendingFrame(r, v)
frame, l, err = parseStopSendingFrame(b, v)
case cryptoFrameType:
frame, err = parseCryptoFrame(r, v)
frame, l, err = parseCryptoFrame(b, v)
case newTokenFrameType:
frame, err = parseNewTokenFrame(r, v)
frame, l, err = parseNewTokenFrame(b, v)
case maxDataFrameType:
frame, err = parseMaxDataFrame(r, v)
frame, l, err = parseMaxDataFrame(b, v)
case maxStreamDataFrameType:
frame, err = parseMaxStreamDataFrame(r, v)
frame, l, err = parseMaxStreamDataFrame(b, v)
case bidiMaxStreamsFrameType, uniMaxStreamsFrameType:
frame, err = parseMaxStreamsFrame(r, typ, v)
frame, l, err = parseMaxStreamsFrame(b, typ, v)
case dataBlockedFrameType:
frame, err = parseDataBlockedFrame(r, v)
frame, l, err = parseDataBlockedFrame(b, v)
case streamDataBlockedFrameType:
frame, err = parseStreamDataBlockedFrame(r, v)
frame, l, err = parseStreamDataBlockedFrame(b, v)
case bidiStreamBlockedFrameType, uniStreamBlockedFrameType:
frame, err = parseStreamsBlockedFrame(r, typ, v)
frame, l, err = parseStreamsBlockedFrame(b, typ, v)
case newConnectionIDFrameType:
frame, err = parseNewConnectionIDFrame(r, v)
frame, l, err = parseNewConnectionIDFrame(b, v)
case retireConnectionIDFrameType:
frame, err = parseRetireConnectionIDFrame(r, v)
frame, l, err = parseRetireConnectionIDFrame(b, v)
case pathChallengeFrameType:
frame, err = parsePathChallengeFrame(r, v)
frame, l, err = parsePathChallengeFrame(b, v)
case pathResponseFrameType:
frame, err = parsePathResponseFrame(r, v)
frame, l, err = parsePathResponseFrame(b, v)
case connectionCloseFrameType, applicationCloseFrameType:
frame, err = parseConnectionCloseFrame(r, typ, v)
frame, l, err = parseConnectionCloseFrame(b, typ, v)
case handshakeDoneFrameType:
frame = &HandshakeDoneFrame{}
case 0x30, 0x31:
if p.supportsDatagrams {
frame, err = parseDatagramFrame(r, typ, v)
frame, l, err = parseDatagramFrame(b, typ, v)
break
}
fallthrough
@ -154,12 +152,12 @@ func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
}
}
if err != nil {
return nil, err
return nil, 0, err
}
if !p.isAllowedAtEncLevel(frame, encLevel) {
return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel)
return nil, l, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel)
}
return frame, nil
return frame, l, nil
}
func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool {
@ -190,3 +188,10 @@ func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionL
func (p *FrameParser) SetAckDelayExponent(exp uint8) {
p.ackDelayExponent = exp
}
func replaceUnexpectedEOF(e error) error {
if e == io.ErrUnexpectedEOF {
return io.EOF
}
return e
}

View file

@ -8,7 +8,6 @@ import (
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
)
@ -139,18 +138,18 @@ type Header struct {
parsedLen protocol.ByteCount // how many bytes were read while parsing this header
}
// ParsePacket parses a packet.
// If the packet has a long header, the packet is cut according to the length field.
// If we understand the version, the packet is header up unto the packet number.
// ParsePacket parses a long header packet.
// The packet is cut according to the length field.
// If we understand the version, the packet is parsed up unto the packet number.
// Otherwise, only the invariant part of the header is parsed.
func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
if len(data) == 0 || !IsLongHeaderPacket(data[0]) {
return nil, nil, nil, errors.New("not a long header packet")
}
hdr, err := parseHeader(bytes.NewReader(data))
hdr, err := parseHeader(data)
if err != nil {
if err == ErrUnsupportedVersion {
return hdr, nil, nil, ErrUnsupportedVersion
if errors.Is(err, ErrUnsupportedVersion) {
return hdr, nil, nil, err
}
return nil, nil, nil, err
}
@ -161,55 +160,55 @@ func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
return hdr, data[:packetLen], data[packetLen:], nil
}
// ParseHeader parses the header.
// For short header packets: up to the packet number.
// For long header packets:
// ParseHeader parses the header:
// * if we understand the version: up to the packet number
// * if not, only the invariant part of the header
func parseHeader(b *bytes.Reader) (*Header, error) {
startLen := b.Len()
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
func parseHeader(b []byte) (*Header, error) {
if len(b) == 0 {
return nil, io.EOF
}
typeByte := b[0]
h := &Header{typeByte: typeByte}
err = h.parseLongHeader(b)
h.parsedLen = protocol.ByteCount(startLen - b.Len())
l, err := h.parseLongHeader(b[1:])
h.parsedLen = protocol.ByteCount(l) + 1
return h, err
}
func (h *Header) parseLongHeader(b *bytes.Reader) error {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return err
func (h *Header) parseLongHeader(b []byte) (int, error) {
startLen := len(b)
if len(b) < 5 {
return 0, io.EOF
}
h.Version = protocol.Version(v)
h.Version = protocol.Version(binary.BigEndian.Uint32(b[:4]))
if h.Version != 0 && h.typeByte&0x40 == 0 {
return errors.New("not a QUIC packet")
return startLen - len(b), errors.New("not a QUIC packet")
}
destConnIDLen, err := b.ReadByte()
if err != nil {
return err
destConnIDLen := int(b[4])
if destConnIDLen > protocol.MaxConnIDLen {
return startLen - len(b), protocol.ErrInvalidConnectionIDLen
}
h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen))
if err != nil {
return err
b = b[5:]
if len(b) < destConnIDLen+1 {
return startLen - len(b), io.EOF
}
srcConnIDLen, err := b.ReadByte()
if err != nil {
return err
h.DestConnectionID = protocol.ParseConnectionID(b[:destConnIDLen])
srcConnIDLen := int(b[destConnIDLen])
if srcConnIDLen > protocol.MaxConnIDLen {
return startLen - len(b), protocol.ErrInvalidConnectionIDLen
}
h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen))
if err != nil {
return err
b = b[destConnIDLen+1:]
if len(b) < srcConnIDLen {
return startLen - len(b), io.EOF
}
h.SrcConnectionID = protocol.ParseConnectionID(b[:srcConnIDLen])
b = b[srcConnIDLen:]
if h.Version == 0 { // version negotiation packet
return nil
return startLen - len(b), nil
}
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
return ErrUnsupportedVersion
return startLen - len(b), ErrUnsupportedVersion
}
if h.Version == protocol.Version2 {
@ -237,38 +236,35 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error {
}
if h.Type == protocol.PacketTypeRetry {
tokenLen := b.Len() - 16
tokenLen := len(b) - 16
if tokenLen <= 0 {
return io.EOF
return startLen - len(b), io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
_, err := b.Seek(16, io.SeekCurrent)
return err
copy(h.Token, b[:tokenLen])
return startLen - len(b) + tokenLen + 16, nil
}
if h.Type == protocol.PacketTypeInitial {
tokenLen, err := quicvarint.Read(b)
tokenLen, n, err := quicvarint.Parse(b)
if err != nil {
return err
return startLen - len(b), err
}
if tokenLen > uint64(b.Len()) {
return io.EOF
b = b[n:]
if tokenLen > uint64(len(b)) {
return startLen - len(b), io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
copy(h.Token, b[:tokenLen])
b = b[tokenLen:]
}
pl, err := quicvarint.Read(b)
pl, n, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, err
}
h.Length = protocol.ByteCount(pl)
return nil
return startLen - len(b) + n, nil
}
// ParsedLen returns the number of bytes that were consumed when parsing the header

View file

@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
@ -13,14 +11,14 @@ type MaxDataFrame struct {
}
// parseMaxDataFrame parses a MAX_DATA frame
func parseMaxDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxDataFrame, error) {
func parseMaxDataFrame(b []byte, _ protocol.Version) (*MaxDataFrame, int, error) {
frame := &MaxDataFrame{}
byteOffset, err := quicvarint.Read(r)
byteOffset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
frame.MaximumData = protocol.ByteCount(byteOffset)
return frame, nil
return frame, l, nil
}
func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View file

@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
@ -13,20 +11,23 @@ type MaxStreamDataFrame struct {
MaximumStreamData protocol.ByteCount
}
func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxStreamDataFrame, error) {
sid, err := quicvarint.Read(r)
func parseMaxStreamDataFrame(b []byte, _ protocol.Version) (*MaxStreamDataFrame, int, error) {
startLen := len(b)
sid, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
offset, err := quicvarint.Read(r)
b = b[l:]
offset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
return &MaxStreamDataFrame{
StreamID: protocol.StreamID(sid),
MaximumStreamData: protocol.ByteCount(offset),
}, nil
}, startLen - len(b), nil
}
func (f *MaxStreamDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
@ -14,7 +13,7 @@ type MaxStreamsFrame struct {
MaxStreamNum protocol.StreamNum
}
func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*MaxStreamsFrame, error) {
func parseMaxStreamsFrame(b []byte, typ uint64, _ protocol.Version) (*MaxStreamsFrame, int, error) {
f := &MaxStreamsFrame{}
switch typ {
case bidiMaxStreamsFrameType:
@ -22,15 +21,15 @@ func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*Max
case uniMaxStreamsFrameType:
f.Type = protocol.StreamTypeUni
}
streamID, err := quicvarint.Read(r)
streamID, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
f.MaxStreamNum = protocol.StreamNum(streamID)
if f.MaxStreamNum > protocol.MaxStreamCount {
return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum)
return nil, 0, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum)
}
return f, nil
return f, l, nil
}
func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"errors"
"fmt"
"io"
@ -18,43 +17,47 @@ type NewConnectionIDFrame struct {
StatelessResetToken protocol.StatelessResetToken
}
func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*NewConnectionIDFrame, error) {
seq, err := quicvarint.Read(r)
func parseNewConnectionIDFrame(b []byte, _ protocol.Version) (*NewConnectionIDFrame, int, error) {
startLen := len(b)
seq, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
ret, err := quicvarint.Read(r)
b = b[l:]
ret, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
if ret > seq {
//nolint:stylecheck
return nil, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq)
return nil, 0, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq)
}
connIDLen, err := r.ReadByte()
if err != nil {
return nil, err
if len(b) == 0 {
return nil, 0, io.EOF
}
connIDLen := int(b[0])
b = b[1:]
if connIDLen == 0 {
return nil, errors.New("invalid zero-length connection ID")
return nil, 0, errors.New("invalid zero-length connection ID")
}
connID, err := protocol.ReadConnectionID(r, int(connIDLen))
if err != nil {
return nil, err
if connIDLen > protocol.MaxConnIDLen {
return nil, 0, protocol.ErrInvalidConnectionIDLen
}
if len(b) < connIDLen {
return nil, 0, io.EOF
}
frame := &NewConnectionIDFrame{
SequenceNumber: seq,
RetirePriorTo: ret,
ConnectionID: connID,
ConnectionID: protocol.ParseConnectionID(b[:connIDLen]),
}
if _, err := io.ReadFull(r, frame.StatelessResetToken[:]); err != nil {
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return nil, err
b = b[connIDLen:]
if len(b) < len(frame.StatelessResetToken) {
return nil, 0, io.EOF
}
return frame, nil
copy(frame.StatelessResetToken[:], b)
return frame, startLen - len(b) + len(frame.StatelessResetToken), nil
}
func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"errors"
"io"
@ -14,22 +13,21 @@ type NewTokenFrame struct {
Token []byte
}
func parseNewTokenFrame(r *bytes.Reader, _ protocol.Version) (*NewTokenFrame, error) {
tokenLen, err := quicvarint.Read(r)
func parseNewTokenFrame(b []byte, _ protocol.Version) (*NewTokenFrame, int, error) {
tokenLen, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
}
if uint64(r.Len()) < tokenLen {
return nil, io.EOF
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
if tokenLen == 0 {
return nil, errors.New("token must not be empty")
return nil, 0, errors.New("token must not be empty")
}
if uint64(len(b)) < tokenLen {
return nil, 0, io.EOF
}
token := make([]byte, int(tokenLen))
if _, err := io.ReadFull(r, token); err != nil {
return nil, err
}
return &NewTokenFrame{Token: token}, nil
copy(token, b)
return &NewTokenFrame{Token: token}, l + int(tokenLen), nil
}
func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"io"
"github.com/quic-go/quic-go/internal/protocol"
@ -12,15 +11,13 @@ type PathChallengeFrame struct {
Data [8]byte
}
func parsePathChallengeFrame(r *bytes.Reader, _ protocol.Version) (*PathChallengeFrame, error) {
frame := &PathChallengeFrame{}
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return nil, err
func parsePathChallengeFrame(b []byte, _ protocol.Version) (*PathChallengeFrame, int, error) {
f := &PathChallengeFrame{}
if len(b) < 8 {
return nil, 0, io.EOF
}
return frame, nil
copy(f.Data[:], b)
return f, 8, nil
}
func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"io"
"github.com/quic-go/quic-go/internal/protocol"
@ -12,15 +11,13 @@ type PathResponseFrame struct {
Data [8]byte
}
func parsePathResponseFrame(r *bytes.Reader, _ protocol.Version) (*PathResponseFrame, error) {
frame := &PathResponseFrame{}
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return nil, err
func parsePathResponseFrame(b []byte, _ protocol.Version) (*PathResponseFrame, int, error) {
f := &PathResponseFrame{}
if len(b) < 8 {
return nil, 0, io.EOF
}
return frame, nil
copy(f.Data[:], b)
return f, 8, nil
}
func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View file

@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/quicvarint"
@ -15,21 +13,24 @@ type ResetStreamFrame struct {
FinalSize protocol.ByteCount
}
func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFrame, error) {
func parseResetStreamFrame(b []byte, _ protocol.Version) (*ResetStreamFrame, int, error) {
startLen := len(b)
var streamID protocol.StreamID
var byteOffset protocol.ByteCount
sid, err := quicvarint.Read(r)
sid, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
streamID = protocol.StreamID(sid)
errorCode, err := quicvarint.Read(r)
errorCode, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
bo, err := quicvarint.Read(r)
b = b[l:]
bo, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
byteOffset = protocol.ByteCount(bo)
@ -37,7 +38,7 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFra
StreamID: streamID,
ErrorCode: qerr.StreamErrorCode(errorCode),
FinalSize: byteOffset,
}, nil
}, startLen - len(b) + l, nil
}
func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View file

@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
@ -12,12 +10,12 @@ type RetireConnectionIDFrame struct {
SequenceNumber uint64
}
func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*RetireConnectionIDFrame, error) {
seq, err := quicvarint.Read(r)
func parseRetireConnectionIDFrame(b []byte, _ protocol.Version) (*RetireConnectionIDFrame, int, error) {
seq, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
return &RetireConnectionIDFrame{SequenceNumber: seq}, nil
return &RetireConnectionIDFrame{SequenceNumber: seq}, l, nil
}
func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View file

@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/quicvarint"
@ -15,20 +13,23 @@ type StopSendingFrame struct {
}
// parseStopSendingFrame parses a STOP_SENDING frame
func parseStopSendingFrame(r *bytes.Reader, _ protocol.Version) (*StopSendingFrame, error) {
streamID, err := quicvarint.Read(r)
func parseStopSendingFrame(b []byte, _ protocol.Version) (*StopSendingFrame, int, error) {
startLen := len(b)
streamID, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
errorCode, err := quicvarint.Read(r)
b = b[l:]
errorCode, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
return &StopSendingFrame{
StreamID: protocol.StreamID(streamID),
ErrorCode: qerr.StreamErrorCode(errorCode),
}, nil
}, startLen - len(b), nil
}
// Length of a written frame

View file

@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
@ -13,20 +11,22 @@ type StreamDataBlockedFrame struct {
MaximumStreamData protocol.ByteCount
}
func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*StreamDataBlockedFrame, error) {
sid, err := quicvarint.Read(r)
func parseStreamDataBlockedFrame(b []byte, _ protocol.Version) (*StreamDataBlockedFrame, int, error) {
startLen := len(b)
sid, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
offset, err := quicvarint.Read(r)
b = b[l:]
offset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
return &StreamDataBlockedFrame{
StreamID: protocol.StreamID(sid),
MaximumStreamData: protocol.ByteCount(offset),
}, nil
}, startLen - len(b) + l, nil
}
func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"errors"
"io"
@ -20,33 +19,41 @@ type StreamFrame struct {
fromPool bool
}
func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamFrame, error) {
func parseStreamFrame(b []byte, typ uint64, _ protocol.Version) (*StreamFrame, int, error) {
startLen := len(b)
hasOffset := typ&0b100 > 0
fin := typ&0b1 > 0
hasDataLen := typ&0b10 > 0
streamID, err := quicvarint.Read(r)
streamID, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
var offset uint64
if hasOffset {
offset, err = quicvarint.Read(r)
offset, l, err = quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
}
var dataLen uint64
if hasDataLen {
var err error
dataLen, err = quicvarint.Read(r)
var l int
dataLen, l, err = quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
if dataLen > uint64(len(b)) {
return nil, 0, io.EOF
}
} else {
// The rest of the packet is data
dataLen = uint64(r.Len())
dataLen = uint64(len(b))
}
var frame *StreamFrame
@ -57,7 +64,7 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamF
// The STREAM frame can't be larger than the StreamFrame we obtained from the buffer,
// since those StreamFrames have a buffer length of the maximum packet size.
if dataLen > uint64(cap(frame.Data)) {
return nil, io.EOF
return nil, 0, io.EOF
}
frame.Data = frame.Data[:dataLen]
}
@ -68,17 +75,14 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamF
frame.DataLenPresent = hasDataLen
if dataLen != 0 {
if _, err := io.ReadFull(r, frame.Data); err != nil {
return nil, err
}
copy(frame.Data, b)
}
if frame.Offset+frame.DataLen() > protocol.MaxByteCount {
return nil, errors.New("stream data overflows maximum offset")
return nil, 0, errors.New("stream data overflows maximum offset")
}
return frame, nil
return frame, startLen - len(b) + int(dataLen), nil
}
// Write writes a STREAM frame
func (f *StreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
if len(f.Data) == 0 && !f.Fin {
return nil, errors.New("StreamFrame: attempting to write empty frame without FIN")

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
@ -14,7 +13,7 @@ type StreamsBlockedFrame struct {
StreamLimit protocol.StreamNum
}
func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, error) {
func parseStreamsBlockedFrame(b []byte, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, int, error) {
f := &StreamsBlockedFrame{}
switch typ {
case bidiStreamBlockedFrameType:
@ -22,15 +21,15 @@ func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (
case uniStreamBlockedFrameType:
f.Type = protocol.StreamTypeUni
}
streamLimit, err := quicvarint.Read(r)
streamLimit, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
f.StreamLimit = protocol.StreamNum(streamLimit)
if f.StreamLimit > protocol.MaxStreamCount {
return nil, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit)
return nil, 0, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit)
}
return f, nil
return f, l, nil
}
func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View file

@ -1,19 +1,17 @@
package wire
import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net/netip"
"sort"
"slices"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
)
@ -89,7 +87,7 @@ type TransportParameters struct {
// Unmarshal the transport parameters
func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error {
if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil {
if err := p.unmarshal(data, sentBy, false); err != nil {
return &qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: err.Error(),
@ -98,9 +96,9 @@ func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective
return nil
}
func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspective, fromSessionTicket bool) error {
func (p *TransportParameters) unmarshal(b []byte, sentBy protocol.Perspective, fromSessionTicket bool) error {
// needed to check that every parameter is only sent at most once
var parameterIDs []transportParameterID
parameterIDs := make([]transportParameterID, 0, 32)
var (
readOriginalDestinationConnectionID bool
@ -112,18 +110,20 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
p.MaxAckDelay = protocol.DefaultMaxAckDelay
p.MaxDatagramFrameSize = protocol.InvalidByteCount
for r.Len() > 0 {
paramIDInt, err := quicvarint.Read(r)
for len(b) > 0 {
paramIDInt, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
paramID := transportParameterID(paramIDInt)
paramLen, err := quicvarint.Read(r)
b = b[l:]
paramLen, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
if uint64(r.Len()) < paramLen {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen)
b = b[l:]
if uint64(len(b)) < paramLen {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(b), paramLen)
}
parameterIDs = append(parameterIDs, paramID)
switch paramID {
@ -141,16 +141,18 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
maxAckDelayParameterID,
maxDatagramFrameSizeParameterID,
ackDelayExponentParameterID:
if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil {
if err := p.readNumericTransportParameter(b, paramID, int(paramLen)); err != nil {
return err
}
b = b[paramLen:]
case preferredAddressParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent a preferred_address")
}
if err := p.readPreferredAddress(r, int(paramLen)); err != nil {
if err := p.readPreferredAddress(b, int(paramLen)); err != nil {
return err
}
b = b[paramLen:]
case disableActiveMigrationParameterID:
if paramLen != 0 {
return fmt.Errorf("wrong length for disable_active_migration: %d (expected empty)", paramLen)
@ -164,25 +166,41 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
}
var token protocol.StatelessResetToken
r.Read(token[:])
if len(b) < len(token) {
return io.EOF
}
copy(token[:], b)
b = b[len(token):]
p.StatelessResetToken = &token
case originalDestinationConnectionIDParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent an original_destination_connection_id")
}
p.OriginalDestinationConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
p.OriginalDestinationConnectionID = protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
readOriginalDestinationConnectionID = true
case initialSourceConnectionIDParameterID:
p.InitialSourceConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
p.InitialSourceConnectionID = protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
readInitialSourceConnectionID = true
case retrySourceConnectionIDParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent a retry_source_connection_id")
}
connID, _ := protocol.ReadConnectionID(r, int(paramLen))
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
connID := protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
p.RetrySourceConnectionID = &connID
default:
r.Seek(int64(paramLen), io.SeekCurrent)
b = b[paramLen:]
}
}
@ -202,7 +220,12 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
}
// check that every transport parameter was sent at most once
sort.Slice(parameterIDs, func(i, j int) bool { return parameterIDs[i] < parameterIDs[j] })
slices.SortFunc(parameterIDs, func(a, b transportParameterID) int {
if a < b {
return -1
}
return 1
})
for i := 0; i < len(parameterIDs)-1; i++ {
if parameterIDs[i] == parameterIDs[i+1] {
return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i])
@ -212,60 +235,47 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
return nil
}
func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error {
remainingLen := r.Len()
func (p *TransportParameters) readPreferredAddress(b []byte, expectedLen int) error {
remainingLen := len(b)
pa := &PreferredAddress{}
if len(b) < 4+2+16+2+1 {
return io.EOF
}
var ipv4 [4]byte
if _, err := io.ReadFull(r, ipv4[:]); err != nil {
return err
}
port, err := utils.BigEndian.ReadUint16(r)
if err != nil {
return err
}
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port)
copy(ipv4[:], b[:4])
port4 := binary.BigEndian.Uint16(b[4:])
b = b[4+2:]
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port4)
var ipv6 [16]byte
if _, err := io.ReadFull(r, ipv6[:]); err != nil {
return err
}
port, err = utils.BigEndian.ReadUint16(r)
if err != nil {
return err
}
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port)
connIDLen, err := r.ReadByte()
if err != nil {
return err
}
copy(ipv6[:], b[:16])
port6 := binary.BigEndian.Uint16(b[16:])
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port6)
b = b[16+2:]
connIDLen := int(b[0])
b = b[1:]
if connIDLen == 0 || connIDLen > protocol.MaxConnIDLen {
return fmt.Errorf("invalid connection ID length: %d", connIDLen)
}
connID, err := protocol.ReadConnectionID(r, int(connIDLen))
if err != nil {
return err
if len(b) < connIDLen+len(pa.StatelessResetToken) {
return io.EOF
}
pa.ConnectionID = connID
if _, err := io.ReadFull(r, pa.StatelessResetToken[:]); err != nil {
return err
}
if bytesRead := remainingLen - r.Len(); bytesRead != expectedLen {
pa.ConnectionID = protocol.ParseConnectionID(b[:connIDLen])
b = b[connIDLen:]
copy(pa.StatelessResetToken[:], b)
b = b[len(pa.StatelessResetToken):]
if bytesRead := remainingLen - len(b); bytesRead != expectedLen {
return fmt.Errorf("expected preferred_address to be %d long, read %d bytes", expectedLen, bytesRead)
}
p.PreferredAddress = pa
return nil
}
func (p *TransportParameters) readNumericTransportParameter(
r *bytes.Reader,
paramID transportParameterID,
expectedLen int,
) error {
remainingLen := r.Len()
val, err := quicvarint.Read(r)
func (p *TransportParameters) readNumericTransportParameter(b []byte, paramID transportParameterID, expectedLen int) error {
val, l, err := quicvarint.Parse(b)
if err != nil {
return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err)
}
if remainingLen-r.Len() != expectedLen {
if l != expectedLen {
return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID)
}
//nolint:exhaustive // This only covers the numeric transport parameters.
@ -292,7 +302,7 @@ func (p *TransportParameters) readNumericTransportParameter(
p.MaxIdleTimeout = max(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond)
case maxUDPPayloadSizeParameterID:
if val < 1200 {
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val)
return fmt.Errorf("invalid value for max_udp_payload_size: %d (minimum 1200)", val)
}
p.MaxUDPPayloadSize = protocol.ByteCount(val)
case ackDelayExponentParameterID:
@ -347,8 +357,10 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum))
// idle_timeout
b = p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond))
// max_packet_size
b = p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize))
// max_udp_payload_size
if p.MaxUDPPayloadSize > 0 {
b = p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(p.MaxUDPPayloadSize))
}
// max_ack_delay
// Only send it if is different from the default value.
if p.MaxAckDelay != protocol.DefaultMaxAckDelay {
@ -457,15 +469,15 @@ func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte {
}
// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket.
func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error {
version, err := quicvarint.Read(r)
func (p *TransportParameters) UnmarshalFromSessionTicket(b []byte) error {
version, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
if version != transportParameterMarshalingVersion {
return fmt.Errorf("unknown transport parameter marshaling version: %d", version)
}
return p.unmarshal(r, protocol.PerspectiveServer, true)
return p.unmarshal(b[l:], protocol.PerspectiveServer, true)
}
// ValidFor0RTT checks if the transport parameters match those saved in the session ticket.

View file

@ -24,6 +24,7 @@ type ConnectionTracer struct {
UpdatedMetrics func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int)
AcknowledgedPacket func(EncryptionLevel, PacketNumber)
LostPacket func(EncryptionLevel, PacketNumber, PacketLossReason)
UpdatedMTU func(mtu ByteCount, done bool)
UpdatedCongestionState func(CongestionState)
UpdatedPTOCount func(value uint32)
UpdatedKeyFromTLS func(EncryptionLevel, Perspective)
@ -168,6 +169,13 @@ func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTra
}
}
},
UpdatedMTU: func(mtu ByteCount, done bool) {
for _, t := range tracers {
if t.UpdatedMTU != nil {
t.UpdatedMTU(mtu, done)
}
}
},
UpdatedCongestionState: func(state CongestionState) {
for _, t := range tracers {
if t.UpdatedCongestionState != nil {

View file

@ -1,19 +1,19 @@
package quic
import (
"net"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
type mtuDiscoverer interface {
// Start starts the MTU discovery process.
// It's unnecessary to call ShouldSendProbe before that.
Start(maxPacketSize protocol.ByteCount)
Start()
ShouldSendProbe(now time.Time) bool
CurrentSize() protocol.ByteCount
GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount)
@ -27,20 +27,6 @@ const (
mtuProbeDelay = 5
)
func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
maxSize := protocol.ByteCount(protocol.MinInitialPacketSize)
// If this is not a UDP address, we don't know anything about the MTU.
// Use the minimum size of an Initial packet as the max packet size.
if udpAddr, ok := addr.(*net.UDPAddr); ok {
if utils.IsIPv4(udpAddr.IP) {
maxSize = protocol.InitialPacketSizeIPv4
} else {
maxSize = protocol.InitialPacketSizeIPv6
}
}
return maxSize
}
type mtuFinder struct {
lastProbeTime time.Time
mtuIncreased func(protocol.ByteCount)
@ -49,16 +35,25 @@ type mtuFinder struct {
inFlight protocol.ByteCount // the size of the probe packet currently in flight. InvalidByteCount if none is in flight
current protocol.ByteCount
max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer)
tracer *logging.ConnectionTracer
}
var _ mtuDiscoverer = &mtuFinder{}
func newMTUDiscoverer(rttStats *utils.RTTStats, start protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuFinder {
func newMTUDiscoverer(
rttStats *utils.RTTStats,
start, max protocol.ByteCount,
mtuIncreased func(protocol.ByteCount),
tracer *logging.ConnectionTracer,
) *mtuFinder {
return &mtuFinder{
inFlight: protocol.InvalidByteCount,
current: start,
max: max,
rttStats: rttStats,
mtuIncreased: mtuIncreased,
tracer: tracer,
}
}
@ -66,9 +61,15 @@ func (f *mtuFinder) done() bool {
return f.max-f.current <= maxMTUDiff+1
}
func (f *mtuFinder) Start(maxPacketSize protocol.ByteCount) {
func (f *mtuFinder) SetMax(max protocol.ByteCount) {
f.max = max
}
func (f *mtuFinder) Start() {
if f.max == protocol.InvalidByteCount {
panic("invalid")
}
f.lastProbeTime = time.Now() // makes sure the first probe packet is not sent immediately
f.max = maxPacketSize
}
func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
@ -87,7 +88,7 @@ func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) {
f.inFlight = size
return ackhandler.Frame{
Frame: &wire.PingFrame{},
Handler: (*mtuFinderAckHandler)(f),
Handler: &mtuFinderAckHandler{f},
}, size
}
@ -95,7 +96,9 @@ func (f *mtuFinder) CurrentSize() protocol.ByteCount {
return f.current
}
type mtuFinderAckHandler mtuFinder
type mtuFinderAckHandler struct {
*mtuFinder
}
var _ ackhandler.FrameHandler = &mtuFinderAckHandler{}
@ -106,6 +109,9 @@ func (h *mtuFinderAckHandler) OnAcked(wire.Frame) {
}
h.inFlight = protocol.InvalidByteCount
h.current = size
if h.tracer != nil && h.tracer.UpdatedMTU != nil {
h.tracer.UpdatedMTU(size, h.done())
}
h.mtuIncreased(size)
}

View file

@ -26,16 +26,16 @@ func Read(r io.ByteReader) (uint64, error) {
return 0, err
}
// the first two bits of the first byte encode the length
len := 1 << ((firstByte & 0xc0) >> 6)
l := 1 << ((firstByte & 0xc0) >> 6)
b1 := firstByte & (0xff - 0xc0)
if len == 1 {
if l == 1 {
return uint64(b1), nil
}
b2, err := r.ReadByte()
if err != nil {
return 0, err
}
if len == 2 {
if l == 2 {
return uint64(b2) + uint64(b1)<<8, nil
}
b3, err := r.ReadByte()
@ -46,7 +46,7 @@ func Read(r io.ByteReader) (uint64, error) {
if err != nil {
return 0, err
}
if len == 4 {
if l == 4 {
return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil
}
b5, err := r.ReadByte()
@ -68,6 +68,31 @@ func Read(r io.ByteReader) (uint64, error) {
return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil
}
// Parse reads a number in the QUIC varint format.
// It returns the number of bytes consumed.
func Parse(b []byte) (uint64 /* value */, int /* bytes consumed */, error) {
if len(b) == 0 {
return 0, 0, io.EOF
}
firstByte := b[0]
// the first two bits of the first byte encode the length
l := 1 << ((firstByte & 0xc0) >> 6)
if len(b) < l {
return 0, 0, io.ErrUnexpectedEOF
}
b0 := firstByte & (0xff - 0xc0)
if l == 1 {
return uint64(b0), 1, nil
}
if l == 2 {
return uint64(b[1]) + uint64(b0)<<8, 2, nil
}
if l == 4 {
return uint64(b[3]) + uint64(b[2])<<8 + uint64(b[1])<<16 + uint64(b0)<<24, 4, nil
}
return uint64(b[7]) + uint64(b[6])<<8 + uint64(b[5])<<16 + uint64(b[4])<<24 + uint64(b[3])<<32 + uint64(b[2])<<40 + uint64(b[1])<<48 + uint64(b0)<<56, 8, nil
}
// Append appends i in the QUIC varint format.
func Append(b []byte, i uint64) []byte {
if i <= maxVarInt1 {