Update deps

This commit is contained in:
Frank Denis 2024-08-07 12:53:44 +02:00
parent 56bc6e6a06
commit 7447fc4a0e
114 changed files with 1405 additions and 1003 deletions

View file

@ -4,6 +4,7 @@ main
mockgen_tmp.go
*.qtr
*.qlog
*.sqlog
*.txt
race.[0-9]*

View file

@ -16,7 +16,6 @@ import (
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/logutils"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
@ -25,15 +24,10 @@ import (
)
type unpacker interface {
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error)
UnpackLongHeader(hdr *wire.Header, data []byte) (*unpackedPacket, error)
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
}
type streamGetter interface {
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
}
type streamManager interface {
GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
@ -59,6 +53,7 @@ type cryptoStreamHandler interface {
GetSessionTicket() ([]byte, error)
NextEvent() handshake.Event
DiscardInitialKeys()
HandleMessage([]byte, protocol.EncryptionLevel) error
io.Closer
ConnectionState() handshake.ConnectionState
}
@ -144,8 +139,7 @@ type connection struct {
sentPacketHandler ackhandler.SentPacketHandler
receivedPacketHandler ackhandler.ReceivedPacketHandler
retransmissionQueue *retransmissionQueue
framer framer
windowUpdateQueue *windowUpdateQueue
framer *framer
connFlowController flowcontrol.ConnectionFlowController
tokenStoreKey string // only set for the client
tokenGenerator *handshake.TokenGenerator // only set for the server
@ -157,9 +151,9 @@ type connection struct {
maxPayloadSizeEstimate atomic.Uint32
initialStream cryptoStream
handshakeStream cryptoStream
oneRTTStream cryptoStream // only set for the server
initialStream *cryptoStream
handshakeStream *cryptoStream
oneRTTStream *cryptoStream // only set for the server
cryptoStreamHandler cryptoStreamHandler
receivedPackets chan receivedPacket
@ -334,7 +328,7 @@ var newConnection = func(
s.cryptoStreamHandler = cs
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, s.oneRTTStream)
s.cryptoStreamManager = newCryptoStreamManager(s.initialStream, s.handshakeStream, s.oneRTTStream)
return s
}
@ -438,7 +432,7 @@ var newClientConnection = func(
s.version,
)
s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, oneRTTStream)
s.cryptoStreamManager = newCryptoStreamManager(s.initialStream, s.handshakeStream, oneRTTStream)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
if len(tlsConf.ServerName) > 0 {
@ -464,7 +458,6 @@ func (s *connection) preSetup() {
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
protocol.ByteCount(s.config.MaxConnectionReceiveWindow),
s.onHasConnectionWindowUpdate,
func(size protocol.ByteCount) bool {
if s.config.AllowConnectionWindowIncrease == nil {
return true
@ -478,12 +471,13 @@ func (s *connection) preSetup() {
s.streamsMap = newStreamsMap(
s.ctx,
s,
s.queueControlFrame,
s.newFlowController,
uint64(s.config.MaxIncomingStreams),
uint64(s.config.MaxIncomingUniStreams),
s.perspective,
)
s.framer = newFramer(s.streamsMap)
s.framer = newFramer()
s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets)
s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1)
@ -493,7 +487,6 @@ func (s *connection) preSetup() {
s.lastPacketReceivedTime = now
s.creationTime = now
s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame)
s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger)
s.connState.Version = s.version
}
@ -706,10 +699,10 @@ func (s *connection) nextKeepAliveTime() time.Time {
func (s *connection) maybeResetTimer() {
var deadline time.Time
if !s.handshakeComplete {
deadline = utils.MinTime(
s.creationTime.Add(s.config.handshakeTimeout()),
s.idleTimeoutStartTime().Add(s.config.HandshakeIdleTimeout),
)
deadline = s.creationTime.Add(s.config.handshakeTimeout())
if t := s.idleTimeoutStartTime().Add(s.config.HandshakeIdleTimeout); t.Before(deadline) {
deadline = t
}
} else {
if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() {
deadline = keepAliveTime
@ -727,7 +720,11 @@ func (s *connection) maybeResetTimer() {
}
func (s *connection) idleTimeoutStartTime() time.Time {
return utils.MaxTime(s.lastPacketReceivedTime, s.firstAckElicitingPacketAfterIdleSentTime)
startTime := s.lastPacketReceivedTime
if t := s.firstAckElicitingPacketAfterIdleSentTime; t.After(startTime) {
startTime = t
}
return startTime
}
func (s *connection) handleHandshakeComplete() error {
@ -803,13 +800,11 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool {
data := rp.data
p := rp
for len(data) > 0 {
var destConnID protocol.ConnectionID
if counter > 0 {
p = *(p.Clone())
p.data = data
var err error
destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen)
destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen)
if err != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError)
@ -869,7 +864,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool {
if counter > 0 {
p.buffer.Split()
}
processed = s.handleShortHeaderPacket(p, destConnID)
processed = s.handleShortHeaderPacket(p)
break
}
}
@ -878,7 +873,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool {
return processed
}
func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protocol.ConnectionID) bool {
func (s *connection) handleShortHeaderPacket(p receivedPacket) bool {
var wasQueued bool
defer func() {
@ -888,6 +883,11 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protoc
}
}()
destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen)
if err != nil {
s.tracer.DroppedPacket(logging.PacketType1RTT, protocol.InvalidPacketNumber, protocol.ByteCount(len(p.data)), logging.PacketDropHeaderParseError)
return false
}
pn, pnLen, keyPhase, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data)
if err != nil {
wasQueued = s.handleUnpackError(err, p, logging.PacketType1RTT)
@ -961,7 +961,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header)
return false
}
packet, err := s.unpacker.UnpackLongHeader(hdr, p.rcvTime, p.data, s.version)
packet, err := s.unpacker.UnpackLongHeader(hdr, p.data)
if err != nil {
wasQueued = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr))
return false
@ -1261,7 +1261,7 @@ func (s *connection) handleFrames(
isAckEliciting = true
}
if log != nil {
frames = append(frames, logutils.ConvertFrame(frame))
frames = append(frames, toLoggingFrame(frame))
}
// An error occurred handling a previous frame.
// Don't handle the current frame.
@ -1378,6 +1378,15 @@ func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protoco
if err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel); err != nil {
return err
}
for {
data := s.cryptoStreamManager.GetCryptoData(encLevel)
if data == nil {
break
}
if err := s.cryptoStreamHandler.HandleMessage(data, encLevel); err != nil {
return err
}
}
return s.handleHandshakeEvents()
}
@ -1668,10 +1677,8 @@ func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) erro
s.cryptoStreamHandler.DiscardInitialKeys()
case protocol.Encryption0RTT:
s.streamsMap.ResetFor0RTT()
if err := s.connFlowController.Reset(); err != nil {
return err
}
return s.framer.Handle0RTTRejection()
s.framer.Handle0RTTRejection()
return s.connFlowController.Reset()
}
return s.cryptoStreamManager.Drop(encLevel)
}
@ -1758,7 +1765,10 @@ func (s *connection) checkTransportParameters(params *wire.TransportParameters)
func (s *connection) applyTransportParameters() {
params := s.peerParams
// Our local idle timeout will always be > 0.
s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout)
s.idleTimeout = s.config.MaxIdleTimeout
if s.idleTimeout > 0 && params.MaxIdleTimeout < s.idleTimeout {
s.idleTimeout = params.MaxIdleTimeout
}
s.keepAliveInterval = min(s.config.KeepAlivePeriod, min(s.idleTimeout/2, protocol.MaxKeepAliveInterval))
s.streamsMap.UpdateLimits(params)
s.frameParser.SetAckDelayExponent(params.AckDelayExponent)
@ -1866,7 +1876,9 @@ func (s *connection) sendPackets(now time.Time) error {
if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked {
s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset})
}
s.windowUpdateQueue.QueueAll()
if offset := s.connFlowController.GetWindowUpdate(); offset > 0 {
s.framer.QueueControlFrame(&wire.MaxDataFrame{MaximumData: offset})
}
if cf := s.cryptoStreamManager.GetPostHandshakeData(protocol.MaxPostHandshakeCryptoFrameSize); cf != nil {
s.queueControlFrame(cf)
}
@ -2157,128 +2169,6 @@ func (s *connection) maxPacketSize() protocol.ByteCount {
return s.mtuDiscoverer.CurrentSize()
}
func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) {
// quic-go logging
if s.logger.Debug() {
p.header.Log(s.logger)
if p.ack != nil {
wire.LogFrame(s.logger, p.ack, true)
}
for _, frame := range p.frames {
wire.LogFrame(s.logger, frame.Frame, true)
}
for _, frame := range p.streamFrames {
wire.LogFrame(s.logger, frame.Frame, true)
}
}
// tracing
if s.tracer != nil && s.tracer.SentLongHeaderPacket != nil {
frames := make([]logging.Frame, 0, len(p.frames))
for _, f := range p.frames {
frames = append(frames, logutils.ConvertFrame(f.Frame))
}
for _, f := range p.streamFrames {
frames = append(frames, logutils.ConvertFrame(f.Frame))
}
var ack *logging.AckFrame
if p.ack != nil {
ack = logutils.ConvertAckFrame(p.ack)
}
s.tracer.SentLongHeaderPacket(p.header, p.length, ecn, ack, frames)
}
}
func (s *connection) logShortHeaderPacket(
destConnID protocol.ConnectionID,
ackFrame *wire.AckFrame,
frames []ackhandler.Frame,
streamFrames []ackhandler.StreamFrame,
pn protocol.PacketNumber,
pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit,
ecn protocol.ECN,
size protocol.ByteCount,
isCoalesced bool,
) {
if s.logger.Debug() && !isCoalesced {
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", pn, size, s.logID, ecn)
}
// quic-go logging
if s.logger.Debug() {
wire.LogShortHeader(s.logger, destConnID, pn, pnLen, kp)
if ackFrame != nil {
wire.LogFrame(s.logger, ackFrame, true)
}
for _, f := range frames {
wire.LogFrame(s.logger, f.Frame, true)
}
for _, f := range streamFrames {
wire.LogFrame(s.logger, f.Frame, true)
}
}
// tracing
if s.tracer != nil && s.tracer.SentShortHeaderPacket != nil {
fs := make([]logging.Frame, 0, len(frames)+len(streamFrames))
for _, f := range frames {
fs = append(fs, logutils.ConvertFrame(f.Frame))
}
for _, f := range streamFrames {
fs = append(fs, logutils.ConvertFrame(f.Frame))
}
var ack *logging.AckFrame
if ackFrame != nil {
ack = logutils.ConvertAckFrame(ackFrame)
}
s.tracer.SentShortHeaderPacket(
&logging.ShortHeader{
DestConnectionID: destConnID,
PacketNumber: pn,
PacketNumberLen: pnLen,
KeyPhase: kp,
},
size,
ecn,
ack,
fs,
)
}
}
func (s *connection) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) {
if s.logger.Debug() {
// There's a short period between dropping both Initial and Handshake keys and completion of the handshake,
// during which we might call PackCoalescedPacket but just pack a short header packet.
if len(packet.longHdrPackets) == 0 && packet.shortHdrPacket != nil {
s.logShortHeaderPacket(
packet.shortHdrPacket.DestConnID,
packet.shortHdrPacket.Ack,
packet.shortHdrPacket.Frames,
packet.shortHdrPacket.StreamFrames,
packet.shortHdrPacket.PacketNumber,
packet.shortHdrPacket.PacketNumberLen,
packet.shortHdrPacket.KeyPhase,
ecn,
packet.shortHdrPacket.Length,
false,
)
return
}
if len(packet.longHdrPackets) > 1 {
s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.longHdrPackets), packet.buffer.Len(), s.logID)
} else {
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.longHdrPackets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.longHdrPackets[0].EncryptionLevel())
}
}
for _, p := range packet.longHdrPackets {
s.logLongHeaderPacket(p, ecn)
}
if p := packet.shortHdrPacket; p != nil {
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true)
}
}
// AcceptStream returns the next stream opened by the peer
func (s *connection) AcceptStream(ctx context.Context) (Stream, error) {
return s.streamsMap.AcceptStream(ctx)
@ -2320,7 +2210,6 @@ func (s *connection) newFlowController(id protocol.StreamID) flowcontrol.StreamF
protocol.ByteCount(s.config.InitialStreamReceiveWindow),
protocol.ByteCount(s.config.MaxStreamReceiveWindow),
initialSendWindow,
s.onHasStreamWindowUpdate,
s.rttStats,
s.logger,
)
@ -2359,18 +2248,13 @@ func (s *connection) queueControlFrame(f wire.Frame) {
s.scheduleSending()
}
func (s *connection) onHasStreamWindowUpdate(id protocol.StreamID) {
s.windowUpdateQueue.AddStream(id)
func (s *connection) onHasStreamData(id protocol.StreamID, str sendStreamI) {
s.framer.AddActiveStream(id, str)
s.scheduleSending()
}
func (s *connection) onHasConnectionWindowUpdate() {
s.windowUpdateQueue.AddConnection()
s.scheduleSending()
}
func (s *connection) onHasStreamData(id protocol.StreamID) {
s.framer.AddActiveStream(id)
func (s *connection) onHasStreamControlFrame(id protocol.StreamID, str streamControlFrameGetter) {
s.framer.AddStreamWithControlFrames(id, str)
s.scheduleSending()
}
@ -2378,6 +2262,7 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) {
if err := s.streamsMap.DeleteStream(id); err != nil {
s.closeLocal(err)
}
s.framer.RemoveActiveStream(id)
}
func (s *connection) onMTUIncreased(mtu protocol.ByteCount) {

173
vendor/github.com/quic-go/quic-go/connection_logging.go generated vendored Normal file
View file

@ -0,0 +1,173 @@
package quic
import (
"slices"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
// toLoggingFrame converts a wire.Frame into a logging.Frame.
// This makes it possible for tracer implementations to access the frames.
// Furthermore, it removes the data slices from CRYPTO and STREAM frames.
func toLoggingFrame(frame wire.Frame) logging.Frame {
switch f := frame.(type) {
case *wire.AckFrame:
// We use a pool for ACK frames.
// Implementations of the tracer interface may hold on to frames, so we need to make a copy here.
return toLoggingAckFrame(f)
case *wire.CryptoFrame:
return &logging.CryptoFrame{
Offset: f.Offset,
Length: protocol.ByteCount(len(f.Data)),
}
case *wire.StreamFrame:
return &logging.StreamFrame{
StreamID: f.StreamID,
Offset: f.Offset,
Length: f.DataLen(),
Fin: f.Fin,
}
case *wire.DatagramFrame:
return &logging.DatagramFrame{
Length: logging.ByteCount(len(f.Data)),
}
default:
return logging.Frame(frame)
}
}
func toLoggingAckFrame(f *wire.AckFrame) *logging.AckFrame {
ack := &logging.AckFrame{
AckRanges: slices.Clone(f.AckRanges),
DelayTime: f.DelayTime,
ECNCE: f.ECNCE,
ECT0: f.ECT0,
ECT1: f.ECT1,
}
return ack
}
func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) {
// quic-go logging
if s.logger.Debug() {
p.header.Log(s.logger)
if p.ack != nil {
wire.LogFrame(s.logger, p.ack, true)
}
for _, frame := range p.frames {
wire.LogFrame(s.logger, frame.Frame, true)
}
for _, frame := range p.streamFrames {
wire.LogFrame(s.logger, frame.Frame, true)
}
}
// tracing
if s.tracer != nil && s.tracer.SentLongHeaderPacket != nil {
frames := make([]logging.Frame, 0, len(p.frames))
for _, f := range p.frames {
frames = append(frames, toLoggingFrame(f.Frame))
}
for _, f := range p.streamFrames {
frames = append(frames, toLoggingFrame(f.Frame))
}
var ack *logging.AckFrame
if p.ack != nil {
ack = toLoggingAckFrame(p.ack)
}
s.tracer.SentLongHeaderPacket(p.header, p.length, ecn, ack, frames)
}
}
func (s *connection) logShortHeaderPacket(
destConnID protocol.ConnectionID,
ackFrame *wire.AckFrame,
frames []ackhandler.Frame,
streamFrames []ackhandler.StreamFrame,
pn protocol.PacketNumber,
pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit,
ecn protocol.ECN,
size protocol.ByteCount,
isCoalesced bool,
) {
if s.logger.Debug() && !isCoalesced {
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", pn, size, s.logID, ecn)
}
// quic-go logging
if s.logger.Debug() {
wire.LogShortHeader(s.logger, destConnID, pn, pnLen, kp)
if ackFrame != nil {
wire.LogFrame(s.logger, ackFrame, true)
}
for _, f := range frames {
wire.LogFrame(s.logger, f.Frame, true)
}
for _, f := range streamFrames {
wire.LogFrame(s.logger, f.Frame, true)
}
}
// tracing
if s.tracer != nil && s.tracer.SentShortHeaderPacket != nil {
fs := make([]logging.Frame, 0, len(frames)+len(streamFrames))
for _, f := range frames {
fs = append(fs, toLoggingFrame(f.Frame))
}
for _, f := range streamFrames {
fs = append(fs, toLoggingFrame(f.Frame))
}
var ack *logging.AckFrame
if ackFrame != nil {
ack = toLoggingAckFrame(ackFrame)
}
s.tracer.SentShortHeaderPacket(
&logging.ShortHeader{
DestConnectionID: destConnID,
PacketNumber: pn,
PacketNumberLen: pnLen,
KeyPhase: kp,
},
size,
ecn,
ack,
fs,
)
}
}
func (s *connection) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) {
if s.logger.Debug() {
// There's a short period between dropping both Initial and Handshake keys and completion of the handshake,
// during which we might call PackCoalescedPacket but just pack a short header packet.
if len(packet.longHdrPackets) == 0 && packet.shortHdrPacket != nil {
s.logShortHeaderPacket(
packet.shortHdrPacket.DestConnID,
packet.shortHdrPacket.Ack,
packet.shortHdrPacket.Frames,
packet.shortHdrPacket.StreamFrames,
packet.shortHdrPacket.PacketNumber,
packet.shortHdrPacket.PacketNumberLen,
packet.shortHdrPacket.KeyPhase,
ecn,
packet.shortHdrPacket.Length,
false,
)
return
}
if len(packet.longHdrPackets) > 1 {
s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.longHdrPackets), packet.buffer.Len(), s.logID)
} else {
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.longHdrPackets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.longHdrPackets[0].EncryptionLevel())
}
}
for _, p := range packet.longHdrPackets {
s.logLongHeaderPacket(p, ecn)
}
if p := packet.shortHdrPacket; p != nil {
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true)
}
}

View file

@ -2,27 +2,14 @@ package quic
import (
"fmt"
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
type cryptoStream interface {
// for receiving data
HandleCryptoFrame(*wire.CryptoFrame) error
GetCryptoData() []byte
Finish() error
// for sending data
io.Writer
HasData() bool
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
}
type cryptoStreamImpl struct {
queue *frameSorter
msgBuf []byte
type cryptoStream struct {
queue frameSorter
highestOffset protocol.ByteCount
finished bool
@ -31,11 +18,11 @@ type cryptoStreamImpl struct {
writeBuf []byte
}
func newCryptoStream() cryptoStream {
return &cryptoStreamImpl{queue: newFrameSorter()}
func newCryptoStream() *cryptoStream {
return &cryptoStream{queue: *newFrameSorter()}
}
func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
func (s *cryptoStream) HandleCryptoFrame(f *wire.CryptoFrame) error {
highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
return &qerr.TransportError{
@ -56,26 +43,16 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
return nil
}
s.highestOffset = max(s.highestOffset, highestOffset)
if err := s.queue.Push(f.Data, f.Offset, nil); err != nil {
return err
}
for {
_, data, _ := s.queue.Pop()
if data == nil {
return nil
}
s.msgBuf = append(s.msgBuf, data...)
}
return s.queue.Push(f.Data, f.Offset, nil)
}
// GetCryptoData retrieves data that was received in CRYPTO frames
func (s *cryptoStreamImpl) GetCryptoData() []byte {
b := s.msgBuf
s.msgBuf = nil
return b
func (s *cryptoStream) GetCryptoData() []byte {
_, data, _ := s.queue.Pop()
return data
}
func (s *cryptoStreamImpl) Finish() error {
func (s *cryptoStream) Finish() error {
if s.queue.HasMoreData() {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
@ -87,16 +64,16 @@ func (s *cryptoStreamImpl) Finish() error {
}
// Writes writes data that should be sent out in CRYPTO frames
func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
func (s *cryptoStream) Write(p []byte) (int, error) {
s.writeBuf = append(s.writeBuf, p...)
return len(p), nil
}
func (s *cryptoStreamImpl) HasData() bool {
func (s *cryptoStream) HasData() bool {
return len(s.writeBuf) > 0
}
func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
func (s *cryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
f := &wire.CryptoFrame{Offset: s.writeOffset}
n := min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
f.Data = s.writeBuf[:n]

View file

@ -3,32 +3,22 @@ package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
type cryptoDataHandler interface {
HandleMessage([]byte, protocol.EncryptionLevel) error
NextEvent() handshake.Event
}
type cryptoStreamManager struct {
cryptoHandler cryptoDataHandler
initialStream cryptoStream
handshakeStream cryptoStream
oneRTTStream cryptoStream
initialStream *cryptoStream
handshakeStream *cryptoStream
oneRTTStream *cryptoStream
}
func newCryptoStreamManager(
cryptoHandler cryptoDataHandler,
initialStream cryptoStream,
handshakeStream cryptoStream,
oneRTTStream cryptoStream,
initialStream *cryptoStream,
handshakeStream *cryptoStream,
oneRTTStream *cryptoStream,
) *cryptoStreamManager {
return &cryptoStreamManager{
cryptoHandler: cryptoHandler,
initialStream: initialStream,
handshakeStream: handshakeStream,
oneRTTStream: oneRTTStream,
@ -36,7 +26,7 @@ func newCryptoStreamManager(
}
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
var str cryptoStream
var str *cryptoStream
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
@ -48,18 +38,23 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
default:
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
}
if err := str.HandleCryptoFrame(frame); err != nil {
return err
}
for {
data := str.GetCryptoData()
if data == nil {
return nil
}
if err := m.cryptoHandler.HandleMessage(data, encLevel); err != nil {
return err
}
return str.HandleCryptoFrame(frame)
}
func (m *cryptoStreamManager) GetCryptoData(encLevel protocol.EncryptionLevel) []byte {
var str *cryptoStream
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
str = m.initialStream
case protocol.EncryptionHandshake:
str = m.handshakeStream
case protocol.Encryption1RTT:
str = m.oneRTTStream
default:
panic(fmt.Sprintf("received CRYPTO frame with unexpected encryption level: %s", encLevel))
}
return str.GetCryptoData()
}
func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame {

View file

@ -1,7 +1,7 @@
package quic
import (
"errors"
"slices"
"sync"
"github.com/quic-go/quic-go/internal/ackhandler"
@ -11,37 +11,25 @@ import (
"github.com/quic-go/quic-go/quicvarint"
)
type framer interface {
HasData() bool
QueueControlFrame(wire.Frame)
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount)
AddActiveStream(protocol.StreamID)
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount)
Handle0RTTRejection() error
// QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length.
// This is a hack.
// It is easier to implement than propagating an error return value in QueueControlFrame.
// The correct solution would be to queue frames with their respective structs.
// See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames.
QueuedTooManyControlFrames() bool
}
const (
maxPathResponses = 256
maxControlFrames = 16 << 10
)
type framerI struct {
// This is the largest possible size of a stream-related control frame
// (which is the RESET_STREAM frame).
const maxStreamControlFrameSize = 25
type streamControlFrameGetter interface {
getControlFrame() (_ ackhandler.Frame, ok, hasMore bool)
}
type framer struct {
mutex sync.Mutex
streamGetter streamGetter
activeStreams map[protocol.StreamID]struct{}
streamQueue ringbuffer.RingBuffer[protocol.StreamID]
activeStreams map[protocol.StreamID]sendStreamI
streamQueue ringbuffer.RingBuffer[protocol.StreamID]
streamsWithControlFrames map[protocol.StreamID]streamControlFrameGetter
controlFrameMutex sync.Mutex
controlFrames []wire.Frame
@ -49,16 +37,14 @@ type framerI struct {
queuedTooManyControlFrames bool
}
var _ framer = &framerI{}
func newFramer(streamGetter streamGetter) framer {
return &framerI{
streamGetter: streamGetter,
activeStreams: make(map[protocol.StreamID]struct{}),
func newFramer() *framer {
return &framer{
activeStreams: make(map[protocol.StreamID]sendStreamI),
streamsWithControlFrames: make(map[protocol.StreamID]streamControlFrameGetter),
}
}
func (f *framerI) HasData() bool {
func (f *framer) HasData() bool {
f.mutex.Lock()
hasData := !f.streamQueue.Empty()
f.mutex.Unlock()
@ -67,10 +53,10 @@ func (f *framerI) HasData() bool {
}
f.controlFrameMutex.Lock()
defer f.controlFrameMutex.Unlock()
return len(f.controlFrames) > 0 || len(f.pathResponses) > 0
return len(f.streamsWithControlFrames) > 0 || len(f.controlFrames) > 0 || len(f.pathResponses) > 0
}
func (f *framerI) QueueControlFrame(frame wire.Frame) {
func (f *framer) QueueControlFrame(frame wire.Frame) {
f.controlFrameMutex.Lock()
defer f.controlFrameMutex.Unlock()
@ -92,7 +78,7 @@ func (f *framerI) QueueControlFrame(frame wire.Frame) {
f.controlFrames = append(f.controlFrames, frame)
}
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) {
func (f *framer) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) {
f.controlFrameMutex.Lock()
defer f.controlFrameMutex.Unlock()
@ -108,6 +94,29 @@ func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol
}
}
// add stream-related control frames
for id, str := range f.streamsWithControlFrames {
start:
remainingLen := maxLen - length
if remainingLen <= maxStreamControlFrameSize {
break
}
fr, ok, hasMore := str.getControlFrame()
if !hasMore {
delete(f.streamsWithControlFrames, id)
}
if !ok {
continue
}
frames = append(frames, fr)
length += fr.Frame.Length(v)
if hasMore {
// It is rare that a stream has more than one control frame to queue.
// We don't want to spawn another loop just to cover that case.
goto start
}
}
for len(f.controlFrames) > 0 {
frame := f.controlFrames[len(f.controlFrames)-1]
frameLen := frame.Length(v)
@ -118,27 +127,51 @@ func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol
length += frameLen
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
}
return frames, length
}
func (f *framerI) QueuedTooManyControlFrames() bool {
// QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length.
// This is a hack.
// It is easier to implement than propagating an error return value in QueueControlFrame.
// The correct solution would be to queue frames with their respective structs.
// See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames.
func (f *framer) QueuedTooManyControlFrames() bool {
return f.queuedTooManyControlFrames
}
func (f *framerI) AddActiveStream(id protocol.StreamID) {
func (f *framer) AddActiveStream(id protocol.StreamID, str sendStreamI) {
f.mutex.Lock()
if _, ok := f.activeStreams[id]; !ok {
f.streamQueue.PushBack(id)
f.activeStreams[id] = struct{}{}
f.activeStreams[id] = str
}
f.mutex.Unlock()
}
func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) {
func (f *framer) AddStreamWithControlFrames(id protocol.StreamID, str streamControlFrameGetter) {
f.controlFrameMutex.Lock()
if _, ok := f.streamsWithControlFrames[id]; !ok {
f.streamsWithControlFrames[id] = str
}
f.controlFrameMutex.Unlock()
}
// RemoveActiveStream is called when a stream completes.
func (f *framer) RemoveActiveStream(id protocol.StreamID) {
f.mutex.Lock()
delete(f.activeStreams, id)
// We don't delete the stream from the streamQueue,
// since we'd have to iterate over the ringbuffer.
// Instead, we check if the stream is still in activeStreams in AppendStreamFrames.
f.mutex.Unlock()
}
func (f *framer) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) {
startLen := len(frames)
var length protocol.ByteCount
f.mutex.Lock()
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
// pop STREAM frames, until less than 128 bytes are left in the packet
numActiveStreams := f.streamQueue.Len()
for i := 0; i < numActiveStreams; i++ {
if protocol.MinStreamFrameSize+length > maxLen {
@ -147,10 +180,9 @@ func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen pro
id := f.streamQueue.PopFront()
// This should never return an error. Better check it anyway.
// The stream will only be in the streamQueue, if it enqueued itself there.
str, err := f.streamGetter.GetOrOpenSendStream(id)
// The stream can be nil if it completed after it said it had data.
if str == nil || err != nil {
delete(f.activeStreams, id)
str, ok := f.activeStreams[id]
// The stream might have been removed after being enqueued.
if !ok {
continue
}
remainingLen := maxLen - length
@ -165,7 +197,7 @@ func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen pro
delete(f.activeStreams, id)
}
// The frame can be "nil"
// * if the receiveStream was canceled after it said it had data
// * if the stream was canceled after it said it had data
// * the remaining size doesn't allow us to add another STREAM frame
if !ok {
continue
@ -183,11 +215,12 @@ func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen pro
return frames, length
}
func (f *framerI) Handle0RTTRejection() error {
func (f *framer) Handle0RTTRejection() {
f.mutex.Lock()
defer f.mutex.Unlock()
f.controlFrameMutex.Lock()
defer f.controlFrameMutex.Unlock()
f.streamQueue.Clear()
for id := range f.activeStreams {
delete(f.activeStreams, id)
@ -195,16 +228,13 @@ func (f *framerI) Handle0RTTRejection() error {
var j int
for i, frame := range f.controlFrames {
switch frame.(type) {
case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame:
return errors.New("didn't expect MAX_DATA / MAX_STREAM_DATA / MAX_STREAMS frame to be sent in 0-RTT")
case *wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame:
case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame,
*wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame:
continue
default:
f.controlFrames[j] = f.controlFrames[i]
j++
}
}
f.controlFrames = f.controlFrames[:j]
f.controlFrameMutex.Unlock()
return nil
f.controlFrames = slices.Delete(f.controlFrames, j, len(f.controlFrames))
}

View file

@ -88,6 +88,7 @@ func (c *SingleDestinationRoundTripper) init() {
c.EnableDatagrams,
protocol.PerspectiveClient,
c.Logger,
0,
)
// send the SETTINGs frame, using 0-RTT data, if possible
go func() {

View file

@ -7,6 +7,7 @@ import (
"net"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/protocol"
@ -51,6 +52,9 @@ type connection struct {
settings *Settings
receivedSettings chan struct{}
idleTimeout time.Duration
idleTimer *time.Timer
}
func newConnection(
@ -59,17 +63,27 @@ func newConnection(
enableDatagrams bool,
perspective protocol.Perspective,
logger *slog.Logger,
idleTimeout time.Duration,
) *connection {
return &connection{
c := &connection{
ctx: ctx,
Connection: quicConn,
perspective: perspective,
logger: logger,
idleTimeout: idleTimeout,
enableDatagrams: enableDatagrams,
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
receivedSettings: make(chan struct{}),
streams: make(map[protocol.StreamID]*datagrammer),
}
if idleTimeout > 0 {
c.idleTimer = time.AfterFunc(idleTimeout, c.onIdleTimer)
}
return c
}
// onIdleTimer fires when the HTTP/3 idle timeout elapses with no open request
// streams. It closes the underlying QUIC connection with H3_NO_ERROR.
func (c *connection) onIdleTimer() {
	c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "idle timeout")
}
func (c *connection) clearStream(id quic.StreamID) {
@ -77,6 +91,9 @@ func (c *connection) clearStream(id quic.StreamID) {
defer c.streamMx.Unlock()
delete(c.streams, id)
if c.idleTimeout > 0 && len(c.streams) == 0 {
c.idleTimer.Reset(c.idleTimeout)
}
}
func (c *connection) openRequestStream(
@ -109,12 +126,24 @@ func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagramme
strID := str.StreamID()
c.streamMx.Lock()
c.streams[strID] = datagrams
if c.idleTimeout > 0 {
if len(c.streams) == 1 {
c.idleTimer.Stop()
}
}
c.streamMx.Unlock()
str = newStateTrackingStream(str, c, datagrams)
}
return str, datagrams, nil
}
// CloseWithError closes the QUIC connection with the given application error
// code and message.
// The idle timer (if armed) is stopped first, so it can't fire after the
// connection has already been closed.
func (c *connection) CloseWithError(code quic.ApplicationErrorCode, msg string) error {
	if c.idleTimer != nil {
		c.idleTimer.Stop()
	}
	return c.Connection.CloseWithError(code, msg)
}
func (c *connection) HandleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) {
var (
rcvdControlStr atomic.Bool

View file

@ -198,6 +198,12 @@ type Server struct {
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
// IdleTimeout specifies how long until idle clients connection should be
// closed. Idle refers only to the HTTP/3 layer, activity at the QUIC layer
// like PING frames are not considered.
// If zero or negative, there is no timeout.
IdleTimeout time.Duration
// ConnContext optionally specifies a function that modifies the context used for a new connection c.
// The provided ctx has a ServerContextKey value.
ConnContext func(ctx context.Context, c quic.Connection) context.Context
@ -216,7 +222,13 @@ type Server struct {
//
// If s.Addr is blank, ":https" is used.
func (s *Server) ListenAndServe() error {
return s.serveConn(s.TLSConfig, nil)
ln, err := s.setupListenerForConn(s.TLSConfig, nil)
if err != nil {
return err
}
defer s.removeListener(&ln)
return s.serveListener(ln)
}
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
@ -231,17 +243,26 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
}
// We currently only use the cert-related stuff from tls.Config,
// so we don't need to make a full copy.
config := &tls.Config{
Certificates: certs,
ln, err := s.setupListenerForConn(&tls.Config{Certificates: certs}, nil)
if err != nil {
return err
}
return s.serveConn(config, nil)
defer s.removeListener(&ln)
return s.serveListener(ln)
}
// Serve an existing UDP connection.
// It is possible to reuse the same connection for outgoing connections.
// Closing the server does not close the connection.
func (s *Server) Serve(conn net.PacketConn) error {
return s.serveConn(s.TLSConfig, conn)
ln, err := s.setupListenerForConn(s.TLSConfig, conn)
if err != nil {
return err
}
defer s.removeListener(&ln)
return s.serveListener(ln)
}
// ServeQUICConn serves a single QUIC connection.
@ -255,10 +276,18 @@ func (s *Server) ServeQUICConn(conn quic.Connection) error {
// Closing the server does close the listener.
// ServeListener always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed.
func (s *Server) ServeListener(ln QUICEarlyListener) error {
s.mutex.Lock()
if err := s.addListener(&ln); err != nil {
s.mutex.Unlock()
return err
}
s.mutex.Unlock()
defer s.removeListener(&ln)
return s.serveListener(ln)
}
func (s *Server) serveListener(ln QUICEarlyListener) error {
for {
conn, err := ln.Accept(context.Background())
if err == quic.ErrServerClosed {
@ -279,16 +308,9 @@ func (s *Server) ServeListener(ln QUICEarlyListener) error {
var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig")
func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
func (s *Server) setupListenerForConn(tlsConf *tls.Config, conn net.PacketConn) (QUICEarlyListener, error) {
if tlsConf == nil {
return errServerWithoutTLSConfig
}
s.mutex.Lock()
closed := s.closed
s.mutex.Unlock()
if closed {
return http.ErrServerClosed
return nil, errServerWithoutTLSConfig
}
baseConf := ConfigureTLSConfig(tlsConf)
@ -302,6 +324,13 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
quicConf.EnableDatagrams = true
}
s.mutex.Lock()
defer s.mutex.Unlock()
closed := s.closed
if closed {
return nil, http.ErrServerClosed
}
var ln QUICEarlyListener
var err error
if conn == nil {
@ -314,9 +343,12 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
ln, err = quicListen(conn, baseConf, quicConf)
}
if err != nil {
return err
return nil, err
}
return s.ServeListener(ln)
if err := s.addListener(&ln); err != nil {
return nil, err
}
return ln, nil
}
func extractPort(addr string) (int, error) {
@ -392,9 +424,6 @@ func (s *Server) generateAltSvcHeader() {
// call trackListener via Serve and can track+defer untrack the same pointer to
// local variable there. We never need to compare a Listener from another caller.
func (s *Server) addListener(l *QUICEarlyListener) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closed {
return http.ErrServerClosed
}
@ -456,8 +485,10 @@ func (s *Server) handleConn(conn quic.Connection) error {
s.EnableDatagrams,
protocol.PerspectiveServer,
s.Logger,
s.IdleTimeout,
)
go hconn.HandleUnidirectionalStreams(s.UniStreamHijacker)
// Process all requests immediately.
// It's the client's responsibility to decide which requests are eligible for 0-RTT.
for {

View file

@ -89,8 +89,8 @@ type ReceiveStream interface {
// Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
// If the stream was canceled by the peer, the error is a StreamError and
// Remote == true.
// If the connection was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true.
io.Reader
@ -113,8 +113,8 @@ type SendStream interface {
// Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
// If the stream was canceled by the peer, the error is a StreamError and
// Remote == true.
// If the connection was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true.
io.Writer
@ -150,7 +150,7 @@ type SendStream interface {
// * TransportError: for errors triggered by the QUIC transport (in many cases a misbehaving peer)
// * IdleTimeoutError: when the peer goes away unexpectedly (this is a net.Error timeout error)
// * HandshakeTimeoutError: when the cryptographic handshake takes too long (this is a net.Error timeout error)
// * StatelessResetError: when we receive a stateless reset (this is a net.Error temporary error)
// * StatelessResetError: when we receive a stateless reset
// * VersionNegotiationError: returned by the client, when there's no version overlap between the peers
type Connection interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
@ -163,28 +163,29 @@ type Connection interface {
AcceptUniStream(context.Context) (ReceiveStream, error)
// OpenStream opens a new bidirectional QUIC stream.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream.
// If the error is non-nil, it satisfies the net.Error interface.
// When reaching the peer's stream limit, err.Temporary() will be true.
// If the connection was closed due to a timeout, Timeout() will be true.
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
// When reaching the peer's stream limit, it is not possible to open a new stream until the
// peer raises the stream limit. In that case, a StreamLimitReachedError is returned.
OpenStream() (Stream, error)
// OpenStreamSync opens a new bidirectional QUIC stream.
// It blocks until a new stream can be opened.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
// If the error is non-nil, it satisfies the net.Error interface.
// If the connection was closed due to a timeout, Timeout() will be true.
OpenStreamSync(context.Context) (Stream, error)
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
// If the error is non-nil, it satisfies the net.Error interface.
// When reaching the peer's stream limit, Temporary() will be true.
// If the connection was closed due to a timeout, Timeout() will be true.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
// When reaching the peer's stream limit, it is not possible to open a new stream until the
// peer raises the stream limit. In that case, a StreamLimitReachedError is returned.
OpenUniStream() (SendStream, error)
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
// It blocks until a new stream can be opened.
// If the error is non-nil, it satisfies the net.Error interface.
// If the connection was closed due to a timeout, Timeout() will be true.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
OpenUniStreamSync(context.Context) (SendStream, error)
// LocalAddr returns the local address.
LocalAddr() net.Addr

View file

@ -1,10 +1,9 @@
package ackhandler
import (
"sync"
"slices"
"github.com/quic-go/quic-go/internal/protocol"
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
"github.com/quic-go/quic-go/internal/wire"
)
@ -14,25 +13,17 @@ type interval struct {
End protocol.PacketNumber
}
var intervalElementPool sync.Pool
func init() {
intervalElementPool = *list.NewPool[interval]()
}
// The receivedPacketHistory stores if a packet number has already been received.
// It generates ACK ranges which can be used to assemble an ACK frame.
// It does not store packet contents.
type receivedPacketHistory struct {
ranges *list.List[interval]
ranges []interval // maximum length: protocol.MaxNumAckRanges
deletedBelow protocol.PacketNumber
}
func newReceivedPacketHistory() *receivedPacketHistory {
return &receivedPacketHistory{
ranges: list.NewWithPool[interval](&intervalElementPool),
}
return &receivedPacketHistory{}
}
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges
@ -41,58 +32,54 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /*
if p < h.deletedBelow {
return false
}
isNew := h.addToRanges(p)
h.maybeDeleteOldRanges()
// Delete old ranges, if we're tracking too many of them.
// This is a DoS defense against a peer that sends us too many gaps.
if len(h.ranges) > protocol.MaxNumAckRanges {
h.ranges = slices.Delete(h.ranges, 0, len(h.ranges)-protocol.MaxNumAckRanges)
}
return isNew
}
func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
if h.ranges.Len() == 0 {
h.ranges.PushBack(interval{Start: p, End: p})
if len(h.ranges) == 0 {
h.ranges = append(h.ranges, interval{Start: p, End: p})
return true
}
for el := h.ranges.Back(); el != nil; el = el.Prev() {
for i := len(h.ranges) - 1; i >= 0; i-- {
// p already included in an existing range. Nothing to do here
if p >= el.Value.Start && p <= el.Value.End {
if p >= h.ranges[i].Start && p <= h.ranges[i].End {
return false
}
if el.Value.End == p-1 { // extend a range at the end
el.Value.End = p
if h.ranges[i].End == p-1 { // extend a range at the end
h.ranges[i].End = p
return true
}
if el.Value.Start == p+1 { // extend a range at the beginning
el.Value.Start = p
if h.ranges[i].Start == p+1 { // extend a range at the beginning
h.ranges[i].Start = p
prev := el.Prev()
if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges
prev.Value.End = el.Value.End
h.ranges.Remove(el)
if i > 0 && h.ranges[i-1].End+1 == h.ranges[i].Start { // merge two ranges
h.ranges[i-1].End = h.ranges[i].End
h.ranges = slices.Delete(h.ranges, i, i+1)
}
return true
}
// create a new range at the end
if p > el.Value.End {
h.ranges.InsertAfter(interval{Start: p, End: p}, el)
// create a new range after the current one
if p > h.ranges[i].End {
h.ranges = slices.Insert(h.ranges, i+1, interval{Start: p, End: p})
return true
}
}
// create a new range at the beginning
h.ranges.InsertBefore(interval{Start: p, End: p}, h.ranges.Front())
h.ranges = slices.Insert(h.ranges, 0, interval{Start: p, End: p})
return true
}
// Delete old ranges, if we're tracking more than 500 of them.
// This is a DoS defense against a peer that sends us too many gaps.
func (h *receivedPacketHistory) maybeDeleteOldRanges() {
for h.ranges.Len() > protocol.MaxNumAckRanges {
h.ranges.Remove(h.ranges.Front())
}
}
// DeleteBelow deletes all entries below (but not including) p
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
if p < h.deletedBelow {
@ -100,37 +87,39 @@ func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
}
h.deletedBelow = p
nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next()
if len(h.ranges) == 0 {
return
}
if el.Value.End < p { // delete a whole range
h.ranges.Remove(el)
} else if p > el.Value.Start && p <= el.Value.End {
el.Value.Start = p
return
idx := -1
for i := 0; i < len(h.ranges); i++ {
if h.ranges[i].End < p { // delete a whole range
idx = i
} else if p > h.ranges[i].Start && p <= h.ranges[i].End {
h.ranges[i].Start = p
break
} else { // no ranges affected. Nothing to do
return
break
}
}
if idx >= 0 {
h.ranges = slices.Delete(h.ranges, 0, idx+1)
}
}
// AppendAckRanges appends to a slice of all AckRanges that can be used in an AckFrame
func (h *receivedPacketHistory) AppendAckRanges(ackRanges []wire.AckRange) []wire.AckRange {
if h.ranges.Len() > 0 {
for el := h.ranges.Back(); el != nil; el = el.Prev() {
ackRanges = append(ackRanges, wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End})
}
for i := len(h.ranges) - 1; i >= 0; i-- {
ackRanges = append(ackRanges, wire.AckRange{Smallest: h.ranges[i].Start, Largest: h.ranges[i].End})
}
return ackRanges
}
func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
ackRange := wire.AckRange{}
if h.ranges.Len() > 0 {
r := h.ranges.Back().Value
ackRange.Smallest = r.Start
ackRange.Largest = r.End
if len(h.ranges) > 0 {
ackRange.Smallest = h.ranges[len(h.ranges)-1].Start
ackRange.Largest = h.ranges[len(h.ranges)-1].End
}
return ackRange
}
@ -139,11 +128,12 @@ func (h *receivedPacketHistory) IsPotentiallyDuplicate(p protocol.PacketNumber)
if p < h.deletedBelow {
return true
}
for el := h.ranges.Back(); el != nil; el = el.Prev() {
if p > el.Value.End {
// Iterating over the slices is faster than using a binary search (using slices.BinarySearchFunc).
for i := len(h.ranges) - 1; i >= 0; i-- {
if p > h.ranges[i].End {
return false
}
if p <= el.Value.End && p >= el.Value.Start {
if p <= h.ranges[i].End && p >= h.ranges[i].Start {
return true
}
}

View file

@ -28,7 +28,7 @@ const (
)
type packetNumberSpace struct {
history *sentPacketHistory
history sentPacketHistory
pns packetNumberGenerator
lossTime time.Time
@ -38,15 +38,15 @@ type packetNumberSpace struct {
largestSent protocol.PacketNumber
}
func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool) *packetNumberSpace {
func newPacketNumberSpace(initialPN protocol.PacketNumber, isAppData bool) *packetNumberSpace {
var pns packetNumberGenerator
if skipPNs {
if isAppData {
pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod)
} else {
pns = newSequentialPacketNumberGenerator(initialPN)
}
return &packetNumberSpace{
history: newSentPacketHistory(),
history: *newSentPacketHistory(isAppData),
pns: pns,
largestSent: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber,

View file

@ -14,11 +14,16 @@ type sentPacketHistory struct {
highestPacketNumber protocol.PacketNumber
}
func newSentPacketHistory() *sentPacketHistory {
return &sentPacketHistory{
packets: make([]*packet, 0, 32),
func newSentPacketHistory(isAppData bool) *sentPacketHistory {
h := &sentPacketHistory{
highestPacketNumber: protocol.InvalidPacketNumber,
}
if isAppData {
h.packets = make([]*packet, 0, 32)
} else {
h.packets = make([]*packet, 0, 6)
}
return h
}
func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNumber) {

View file

@ -12,8 +12,6 @@ import (
type connectionFlowController struct {
baseFlowController
queueWindowUpdate func()
}
var _ ConnectionFlowController = &connectionFlowController{}
@ -23,7 +21,6 @@ var _ ConnectionFlowController = &connectionFlowController{}
func NewConnectionFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
queueWindowUpdate func(),
allowWindowIncrease func(size protocol.ByteCount) bool,
rttStats *utils.RTTStats,
logger utils.Logger,
@ -37,7 +34,6 @@ func NewConnectionFlowController(
allowWindowIncrease: allowWindowIncrease,
logger: logger,
},
queueWindowUpdate: queueWindowUpdate,
}
}
@ -63,18 +59,14 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B
func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) {
c.mutex.Lock()
c.baseFlowController.addBytesRead(n)
shouldQueueWindowUpdate := c.hasWindowUpdate()
c.mutex.Unlock()
if shouldQueueWindowUpdate {
c.queueWindowUpdate()
}
}
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
c.mutex.Lock()
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if oldWindowSize < c.receiveWindowSize {
if c.logger.Debug() && oldWindowSize < c.receiveWindowSize {
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
}
c.mutex.Unlock()

View file

@ -8,14 +8,13 @@ type flowController interface {
UpdateSendWindow(protocol.ByteCount) (updated bool)
AddBytesSent(protocol.ByteCount)
// for receiving
AddBytesRead(protocol.ByteCount)
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
IsNewlyBlocked() (bool, protocol.ByteCount)
}
// A StreamFlowController is a flow controller for a QUIC stream.
type StreamFlowController interface {
flowController
AddBytesRead(protocol.ByteCount) (shouldQueueWindowUpdate bool)
// UpdateHighestReceived is called when a new highest offset is received
// final has to be to true if this is the final offset of the stream,
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
@ -23,12 +22,15 @@ type StreamFlowController interface {
// Abandon is called when reading from the stream is aborted early,
// and there won't be any further calls to AddBytesRead.
Abandon()
IsNewlyBlocked() bool
}
// The ConnectionFlowController is the flow controller for the connection.
type ConnectionFlowController interface {
flowController
AddBytesRead(protocol.ByteCount)
Reset() error
IsNewlyBlocked() (bool, protocol.ByteCount)
}
type connectionFlowControllerI interface {

View file

@ -13,8 +13,6 @@ type streamFlowController struct {
streamID protocol.StreamID
queueWindowUpdate func()
connection connectionFlowControllerI
receivedFinalOffset bool
@ -29,14 +27,12 @@ func NewStreamFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount,
queueWindowUpdate func(protocol.StreamID),
rttStats *utils.RTTStats,
logger utils.Logger,
) StreamFlowController {
return &streamFlowController{
streamID: streamID,
connection: cfc.(connectionFlowControllerI),
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
streamID: streamID,
connection: cfc.(connectionFlowControllerI),
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
@ -97,15 +93,13 @@ func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount,
return c.connection.IncrementHighestReceived(increment)
}
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) (shouldQueueWindowUpdate bool) {
c.mutex.Lock()
c.baseFlowController.addBytesRead(n)
shouldQueueWindowUpdate := c.shouldQueueWindowUpdate()
shouldQueueWindowUpdate = c.shouldQueueWindowUpdate()
c.mutex.Unlock()
if shouldQueueWindowUpdate {
c.queueWindowUpdate()
}
c.connection.AddBytesRead(n)
return
}
func (c *streamFlowController) Abandon() {
@ -127,6 +121,11 @@ func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
return min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
}
// IsNewlyBlocked reports whether the stream just became blocked by the
// stream-level flow control limit. The offset at which it became blocked
// (reported by the base controller) is discarded.
func (c *streamFlowController) IsNewlyBlocked() bool {
	newlyBlocked, _ := c.baseFlowController.IsNewlyBlocked()
	return newlyBlocked
}
// shouldQueueWindowUpdate reports whether a flow control window update
// (MAX_STREAM_DATA) should be queued for this stream.
// Once the final offset has been received, no further updates are needed,
// since the peer won't send any more data on this stream.
func (c *streamFlowController) shouldQueueWindowUpdate() bool {
	return !c.receivedFinalOffset && c.hasWindowUpdate()
}

View file

@ -253,7 +253,10 @@ func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) {
h.handshakeComplete()
return false, nil
default:
return false, fmt.Errorf("unexpected event: %d", ev.Kind)
// Unknown events should be ignored.
// crypto/tls will ensure that this is safe to do.
// See the discussion following https://github.com/golang/go/issues/68124#issuecomment-2187042510 for details.
return false, nil
}
}
@ -621,8 +624,7 @@ func (h *cryptoSetup) ConnectionState() ConnectionState {
}
func wrapError(err error) error {
// alert 80 is an internal error
if alertErr := tls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 {
if alertErr := tls.AlertError(0); errors.As(err, &alertErr) {
return qerr.NewLocalCryptoError(uint8(alertErr), err)
}
return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()}

View file

@ -46,7 +46,7 @@ type TokenGenerator struct {
// NewTokenGenerator initializes a new TokenGenerator
func NewTokenGenerator(key TokenProtectorKey) *TokenGenerator {
return &TokenGenerator{tokenProtector: newTokenProtector(key)}
return &TokenGenerator{tokenProtector: *newTokenProtector(key)}
}
// NewRetryToken generates a new token for a Retry for a given source address

View file

@ -14,28 +14,20 @@ import (
// TokenProtectorKey is the key used to encrypt both Retry and session resumption tokens.
type TokenProtectorKey [32]byte
// TokenProtector is used to create and verify a token
type tokenProtector interface {
// NewToken creates a new token
NewToken([]byte) ([]byte, error)
// DecodeToken decodes a token
DecodeToken([]byte) ([]byte, error)
}
const tokenNonceSize = 32
// tokenProtector is used to create and verify a token
type tokenProtectorImpl struct {
type tokenProtector struct {
key TokenProtectorKey
}
// newTokenProtector creates a source for source address tokens
func newTokenProtector(key TokenProtectorKey) tokenProtector {
return &tokenProtectorImpl{key: key}
func newTokenProtector(key TokenProtectorKey) *tokenProtector {
return &tokenProtector{key: key}
}
// NewToken encodes data into a new token.
func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) {
func (s *tokenProtector) NewToken(data []byte) ([]byte, error) {
var nonce [tokenNonceSize]byte
if _, err := rand.Read(nonce[:]); err != nil {
return nil, err
@ -48,7 +40,7 @@ func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) {
}
// DecodeToken decodes a token.
func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
func (s *tokenProtector) DecodeToken(p []byte) ([]byte, error) {
if len(p) < tokenNonceSize {
return nil, fmt.Errorf("token too short: %d", len(p))
}
@ -60,7 +52,7 @@ func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil)
}
func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
func (s *tokenProtector) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
h := hkdf.New(sha256.New, s.key[:], nonce, []byte("quic-go token source"))
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
if _, err := io.ReadFull(h, key); err != nil {

View file

@ -1,50 +0,0 @@
package logutils
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
// ConvertFrame converts a wire.Frame into a logging.Frame.
// This makes it possible for external packages to access the frames.
// Furthermore, it removes the data slices from CRYPTO and STREAM frames.
func ConvertFrame(frame wire.Frame) logging.Frame {
	switch fr := frame.(type) {
	case *wire.CryptoFrame:
		// Drop the payload; only offset and length are logged.
		return &logging.CryptoFrame{
			Offset: fr.Offset,
			Length: protocol.ByteCount(len(fr.Data)),
		}
	case *wire.StreamFrame:
		// Drop the payload; only the frame metadata is logged.
		return &logging.StreamFrame{
			StreamID: fr.StreamID,
			Offset:   fr.Offset,
			Length:   fr.DataLen(),
			Fin:      fr.Fin,
		}
	case *wire.DatagramFrame:
		return &logging.DatagramFrame{
			Length: logging.ByteCount(len(fr.Data)),
		}
	case *wire.AckFrame:
		// ACK frames come from a pool.
		// Tracer implementations may hold on to frames, so a copy is required.
		return ConvertAckFrame(fr)
	default:
		return logging.Frame(frame)
	}
}
// ConvertAckFrame copies a wire.AckFrame into a logging.AckFrame.
// The ACK ranges are deep-copied, so the (pooled) original frame can be
// reused safely after this call.
func ConvertAckFrame(f *wire.AckFrame) *logging.AckFrame {
	ranges := make([]wire.AckRange, len(f.AckRanges))
	copy(ranges, f.AckRanges)
	return &logging.AckFrame{
		AckRanges: ranges,
		DelayTime: f.DelayTime,
		ECNCE:     f.ECNCE,
		ECT0:      f.ECT0,
		ECT1:      f.ECT1,
	}
}

View file

@ -1,21 +0,0 @@
package utils
import (
"bytes"
"io"
)
// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
type ByteOrder interface {
	// Uint32 decodes the first 4 bytes of the slice as an unsigned 32-bit integer.
	Uint32([]byte) uint32
	// Uint24 decodes the first 3 bytes of the slice as an unsigned 24-bit integer.
	Uint24([]byte) uint32
	// Uint16 decodes the first 2 bytes of the slice as an unsigned 16-bit integer.
	Uint16([]byte) uint16
	// ReadUint32 reads 4 bytes from the reader and decodes them as an unsigned 32-bit integer.
	ReadUint32(io.ByteReader) (uint32, error)
	// ReadUint24 reads 3 bytes from the reader and decodes them as an unsigned 24-bit integer.
	ReadUint24(io.ByteReader) (uint32, error)
	// ReadUint16 reads 2 bytes from the reader and decodes them as an unsigned 16-bit integer.
	ReadUint16(io.ByteReader) (uint16, error)
	// WriteUint32 appends the 4-byte encoding of the integer to the buffer.
	WriteUint32(*bytes.Buffer, uint32)
	// WriteUint24 appends the 3-byte encoding of the integer to the buffer.
	WriteUint24(*bytes.Buffer, uint32)
	// WriteUint16 appends the 2-byte encoding of the integer to the buffer.
	WriteUint16(*bytes.Buffer, uint16)
}

View file

@ -1,103 +0,0 @@
package utils
import (
"bytes"
"encoding/binary"
"io"
)
// BigEndian is the big-endian implementation of ByteOrder.
var BigEndian ByteOrder = bigEndian{}
type bigEndian struct{}
var _ ByteOrder = &bigEndian{}
// ReadUintN reads length bytes in network (big-endian) byte order and returns
// them as a uint64. On a read error, 0 and the error are returned.
func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
	var res uint64
	for rem := length; rem > 0; rem-- {
		bt, err := b.ReadByte()
		if err != nil {
			return 0, err
		}
		// Shift the accumulator left and append the next (less significant) byte.
		res = res<<8 | uint64(bt)
	}
	return res, nil
}
// ReadUint32 reads 4 bytes in big-endian order and decodes them as a uint32.
// On a read error, 0 and the error are returned.
func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
	var res uint32
	for i := 0; i < 4; i++ {
		bt, err := b.ReadByte()
		if err != nil {
			return 0, err
		}
		res = res<<8 | uint32(bt)
	}
	return res, nil
}
// ReadUint24 reads 3 bytes in big-endian order and decodes them as a 24-bit
// unsigned integer. On a read error, 0 and the error are returned.
func (bigEndian) ReadUint24(b io.ByteReader) (uint32, error) {
	var res uint32
	for i := 0; i < 3; i++ {
		bt, err := b.ReadByte()
		if err != nil {
			return 0, err
		}
		res = res<<8 | uint32(bt)
	}
	return res, nil
}
// ReadUint16 reads 2 bytes in big-endian order and decodes them as a uint16.
// On a read error, 0 and the error are returned.
func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
	var res uint16
	for i := 0; i < 2; i++ {
		bt, err := b.ReadByte()
		if err != nil {
			return 0, err
		}
		res = res<<8 | uint16(bt)
	}
	return res, nil
}
// Uint32 decodes the first 4 bytes of b as a big-endian uint32.
// Panics if len(b) < 4.
func (bigEndian) Uint32(b []byte) uint32 {
	return binary.BigEndian.Uint32(b)
}
// Uint24 decodes the first 3 bytes of b as a big-endian 24-bit unsigned
// integer. Panics if len(b) < 3.
func (bigEndian) Uint24(b []byte) uint32 {
	_ = b[2] // bounds check hint to compiler; see golang.org/issue/14808
	return uint32(b[0])<<16 | uint32(b[1])<<8 | uint32(b[2])
}
// Uint16 decodes the first 2 bytes of b as a big-endian uint16.
// Panics if len(b) < 2.
func (bigEndian) Uint16(b []byte) uint16 {
	return binary.BigEndian.Uint16(b)
}
// WriteUint32 appends the 4-byte big-endian encoding of i to the buffer.
func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) {
	var tmp [4]byte
	binary.BigEndian.PutUint32(tmp[:], i)
	b.Write(tmp[:])
}
// WriteUint24 appends the 3-byte big-endian encoding of i to the buffer.
// The most significant byte of i is ignored.
func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) {
	b.WriteByte(uint8(i >> 16))
	b.WriteByte(uint8(i >> 8))
	b.WriteByte(uint8(i))
}
// WriteUint16 appends the 2-byte big-endian encoding of i to the buffer.
func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) {
	var tmp [2]byte
	binary.BigEndian.PutUint16(tmp[:], i)
	b.Write(tmp[:])
}

View file

@ -1,10 +0,0 @@
package utils
import "net"
func IsIPv4(ip net.IP) bool {
// If ip is not an IPv4 address, To4 returns nil.
// Note that there might be some corner cases, where this is not correct.
// See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6.
return ip.To4() != nil
}

View file

@ -1,36 +0,0 @@
package utils
import (
"math"
"time"
)
// InfDuration is a duration of infinite length.
// math.MaxInt64 is the largest value representable by time.Duration,
// so it serves as a "never" sentinel in timer and deadline calculations.
const InfDuration = time.Duration(math.MaxInt64)
// MinNonZeroDuration return the minimum duration that's not zero.
func MinNonZeroDuration(a, b time.Duration) time.Duration {
if a == 0 {
return b
}
if b == 0 {
return a
}
return min(a, b)
}
// MinTime returns the earlier time
func MinTime(a, b time.Time) time.Time {
if a.After(b) {
return b
}
return a
}
// MaxTime returns the later time
func MaxTime(a, b time.Time) time.Time {
if a.After(b) {
return a
}
return b
}

View file

@ -64,7 +64,7 @@ func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration {
// UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if sendDelta == InfDuration || sendDelta <= 0 {
if sendDelta <= 0 {
return
}

View file

@ -2,11 +2,11 @@ package wire
import (
"errors"
"math"
"sort"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
)
@ -40,7 +40,7 @@ func parseAckFrame(frame *AckFrame, b []byte, typ uint64, ackDelayExponent uint8
delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
if delayTime < 0 {
// If the delay time overflows, set it to the maximum encode-able value.
delayTime = utils.InfDuration
delayTime = time.Duration(math.MaxInt64)
}
frame.DelayTime = delayTime

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
@ -32,66 +31,23 @@ type ExtendedHeader struct {
parsedLen protocol.ByteCount
}
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.Version) (bool /* reserved bits valid */, error) {
startLen := b.Len()
func (h *ExtendedHeader) parse(data []byte) (bool /* reserved bits valid */, error) {
// read the (now unencrypted) first byte
var err error
h.typeByte, err = b.ReadByte()
if err != nil {
return false, err
}
if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil {
return false, err
}
reservedBitsValid, err := h.parseLongHeader(b, v)
if err != nil {
return false, err
}
h.parsedLen = protocol.ByteCount(startLen - b.Len())
return reservedBitsValid, err
}
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.Version) (bool /* reserved bits valid */, error) {
if err := h.readPacketNumber(b); err != nil {
return false, err
}
if h.typeByte&0xc != 0 {
return false, nil
}
return true, nil
}
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
h.typeByte = data[0]
h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
n, err := b.ReadByte()
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen2:
n, err := utils.BigEndian.ReadUint16(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen3:
n, err := utils.BigEndian.ReadUint24(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen4:
n, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
default:
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
if protocol.ByteCount(len(data)) < h.Header.ParsedLen()+protocol.ByteCount(h.PacketNumberLen) {
return false, io.EOF
}
return nil
pn, err := readPacketNumber(data[h.Header.ParsedLen():], h.PacketNumberLen)
if err != nil {
return true, nil
}
h.PacketNumber = pn
reservedBitsValid := h.typeByte&0xc == 0
h.parsedLen = h.Header.ParsedLen() + protocol.ByteCount(h.PacketNumberLen)
return reservedBitsValid, err
}
// Append appends the Header.

View file

@ -1,7 +1,6 @@
package wire
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
@ -40,37 +39,27 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti
// https://datatracker.ietf.org/doc/html/rfc8999#section-5.1.
// This function should only be called on Long Header packets for which we don't support the version.
func ParseArbitraryLenConnectionIDs(data []byte) (bytesParsed int, dest, src protocol.ArbitraryLenConnectionID, _ error) {
r := bytes.NewReader(data)
remaining := r.Len()
src, dest, err := parseArbitraryLenConnectionIDs(r)
return remaining - r.Len(), src, dest, err
}
func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.ArbitraryLenConnectionID, _ error) {
r.Seek(5, io.SeekStart) // skip first byte and version field
destConnIDLen, err := r.ReadByte()
if err != nil {
return nil, nil, err
startLen := len(data)
if len(data) < 6 {
return 0, nil, nil, io.EOF
}
data = data[5:] // skip first byte and version field
destConnIDLen := data[0]
data = data[1:]
destConnID := make(protocol.ArbitraryLenConnectionID, destConnIDLen)
if _, err := io.ReadFull(r, destConnID); err != nil {
if err == io.ErrUnexpectedEOF {
err = io.EOF
}
return nil, nil, err
if len(data) < int(destConnIDLen)+1 {
return 0, nil, nil, io.EOF
}
srcConnIDLen, err := r.ReadByte()
if err != nil {
return nil, nil, err
copy(destConnID, data)
data = data[destConnIDLen:]
srcConnIDLen := data[0]
data = data[1:]
if len(data) < int(srcConnIDLen) {
return 0, nil, nil, io.EOF
}
srcConnID := make(protocol.ArbitraryLenConnectionID, srcConnIDLen)
if _, err := io.ReadFull(r, srcConnID); err != nil {
if err == io.ErrUnexpectedEOF {
err = io.EOF
}
return nil, nil, err
}
return destConnID, srcConnID, nil
copy(srcConnID, data)
return startLen - len(data) + int(srcConnIDLen), destConnID, srcConnID, nil
}
func IsPotentialQUICPacket(firstByte byte) bool {
@ -274,9 +263,9 @@ func (h *Header) ParsedLen() protocol.ByteCount {
// ParseExtended parses the version dependent part of the header.
// The Reader has to be set such that it points to the first byte of the header.
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.Version) (*ExtendedHeader, error) {
func (h *Header) ParseExtended(data []byte) (*ExtendedHeader, error) {
extHdr := h.toExtendedHeader()
reservedBitsValid, err := extHdr.parse(b, ver)
reservedBitsValid, err := extHdr.parse(data)
if err != nil {
return nil, err
}
@ -294,3 +283,20 @@ func (h *Header) toExtendedHeader() *ExtendedHeader {
func (h *Header) PacketType() string {
return h.Type.String()
}
func readPacketNumber(data []byte, pnLen protocol.PacketNumberLen) (protocol.PacketNumber, error) {
var pn protocol.PacketNumber
switch pnLen {
case protocol.PacketNumberLen1:
pn = protocol.PacketNumber(data[0])
case protocol.PacketNumberLen2:
pn = protocol.PacketNumber(binary.BigEndian.Uint16(data[:2]))
case protocol.PacketNumberLen3:
pn = protocol.PacketNumber(uint32(data[2]) + uint32(data[1])<<8 + uint32(data[0])<<16)
case protocol.PacketNumberLen4:
pn = protocol.PacketNumber(binary.BigEndian.Uint32(data[:4]))
default:
return 0, fmt.Errorf("invalid packet number length: %d", pnLen)
}
return pn, nil
}

View file

@ -2,7 +2,6 @@ package wire
import (
"errors"
"fmt"
"io"
"github.com/quic-go/quic-go/internal/protocol"
@ -28,25 +27,15 @@ func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.Packet
}
pos := 1 + connIDLen
var pn protocol.PacketNumber
switch pnLen {
case protocol.PacketNumberLen1:
pn = protocol.PacketNumber(data[pos])
case protocol.PacketNumberLen2:
pn = protocol.PacketNumber(utils.BigEndian.Uint16(data[pos : pos+2]))
case protocol.PacketNumberLen3:
pn = protocol.PacketNumber(utils.BigEndian.Uint24(data[pos : pos+3]))
case protocol.PacketNumberLen4:
pn = protocol.PacketNumber(utils.BigEndian.Uint32(data[pos : pos+4]))
default:
return 0, 0, 0, 0, fmt.Errorf("invalid packet number length: %d", pnLen)
pn, err := readPacketNumber(data[pos:], pnLen)
if err != nil {
return 0, 0, 0, 0, err
}
kp := protocol.KeyPhaseZero
if data[0]&0b100 > 0 {
kp = protocol.KeyPhaseOne
}
var err error
if data[0]&0x18 != 0 {
err = ErrInvalidReservedBits
}

View file

@ -8,14 +8,14 @@ import (
// A ConnectionTracer records events.
type ConnectionTracer struct {
StartedConnection func(local, remote net.Addr, srcConnID, destConnID ConnectionID)
NegotiatedVersion func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber)
NegotiatedVersion func(chosen Version, clientVersions, serverVersions []Version)
ClosedConnection func(error)
SentTransportParameters func(*TransportParameters)
ReceivedTransportParameters func(*TransportParameters)
RestoredTransportParameters func(parameters *TransportParameters) // for 0-RTT
SentLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame)
SentShortHeaderPacket func(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame)
ReceivedVersionNegotiationPacket func(dest, src ArbitraryLenConnectionID, _ []VersionNumber)
ReceivedVersionNegotiationPacket func(dest, src ArbitraryLenConnectionID, _ []Version)
ReceivedRetry func(*Header)
ReceivedLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, []Frame)
ReceivedShortHeaderPacket func(*ShortHeader, ByteCount, ECN, []Frame)
@ -57,7 +57,7 @@ func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTra
}
}
},
NegotiatedVersion: func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) {
NegotiatedVersion: func(chosen Version, clientVersions, serverVersions []Version) {
for _, t := range tracers {
if t.NegotiatedVersion != nil {
t.NegotiatedVersion(chosen, clientVersions, serverVersions)
@ -106,7 +106,7 @@ func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTra
}
}
},
ReceivedVersionNegotiationPacket: func(dest, src ArbitraryLenConnectionID, versions []VersionNumber) {
ReceivedVersionNegotiationPacket: func(dest, src ArbitraryLenConnectionID, versions []Version) {
for _, t := range tracers {
if t.ReceivedVersionNegotiationPacket != nil {
t.ReceivedVersionNegotiationPacket(dest, src, versions)

View file

@ -37,7 +37,10 @@ type (
// The StreamType is the type of the stream (unidirectional or bidirectional).
StreamType = protocol.StreamType
// The VersionNumber is the QUIC version.
// Deprecated: use Version instead.
VersionNumber = protocol.Version
// The Version is the QUIC version.
Version = protocol.Version
// The Header is the QUIC packet header, before removing header protection.
Header = wire.Header
@ -72,27 +75,27 @@ const (
const (
// KeyPhaseZero is key phase bit 0
KeyPhaseZero KeyPhaseBit = protocol.KeyPhaseZero
KeyPhaseZero = protocol.KeyPhaseZero
// KeyPhaseOne is key phase bit 1
KeyPhaseOne KeyPhaseBit = protocol.KeyPhaseOne
KeyPhaseOne = protocol.KeyPhaseOne
)
const (
// PerspectiveServer is used for a QUIC server
PerspectiveServer Perspective = protocol.PerspectiveServer
PerspectiveServer = protocol.PerspectiveServer
// PerspectiveClient is used for a QUIC client
PerspectiveClient Perspective = protocol.PerspectiveClient
PerspectiveClient = protocol.PerspectiveClient
)
const (
// EncryptionInitial is the Initial encryption level
EncryptionInitial EncryptionLevel = protocol.EncryptionInitial
EncryptionInitial = protocol.EncryptionInitial
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake EncryptionLevel = protocol.EncryptionHandshake
EncryptionHandshake = protocol.EncryptionHandshake
// Encryption1RTT is the 1-RTT encryption level
Encryption1RTT EncryptionLevel = protocol.Encryption1RTT
Encryption1RTT = protocol.Encryption1RTT
// Encryption0RTT is the 0-RTT encryption level
Encryption0RTT EncryptionLevel = protocol.Encryption0RTT
Encryption0RTT = protocol.Encryption0RTT
)
const (

View file

@ -5,7 +5,7 @@ import "net"
// A Tracer traces events.
type Tracer struct {
SentPacket func(net.Addr, *Header, ByteCount, []Frame)
SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []Version)
DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason)
Debug func(name, msg string)
Close func()
@ -27,7 +27,7 @@ func NewMultiplexedTracer(tracers ...*Tracer) *Tracer {
}
}
},
SentVersionNegotiationPacket: func(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) {
SentVersionNegotiationPacket: func(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []Version) {
for _, t := range tracers {
if t.SentVersionNegotiationPacket != nil {
t.SentVersionNegotiationPacket(remote, dest, src, versions)

View file

@ -14,23 +14,17 @@ type Sender = sender
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_internal_test.go github.com/quic-go/quic-go StreamI"
type StreamI = streamI
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_stream_test.go github.com/quic-go/quic-go CryptoStream"
type CryptoStream = cryptoStream
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_receive_stream_internal_test.go github.com/quic-go/quic-go ReceiveStreamI"
type ReceiveStreamI = receiveStreamI
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI"
type SendStreamI = sendStreamI
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter"
type StreamGetter = streamGetter
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender"
type StreamSender = streamSender
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_data_handler_test.go github.com/quic-go/quic-go CryptoDataHandler"
type CryptoDataHandler = cryptoDataHandler
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_control_frame_getter_test.go github.com/quic-go/quic-go StreamControlFrameGetter"
type StreamControlFrameGetter = streamControlFrameGetter
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource"
type FrameSource = frameSource
@ -72,5 +66,4 @@ type PacketHandlerManager = packetHandlerManager
//
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn"
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_token_store_test.go github.com/quic-go/quic-go TokenStore"
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn"

View file

@ -121,8 +121,8 @@ type packetPacker struct {
perspective protocol.Perspective
cryptoSetup sealingManager
initialStream cryptoStream
handshakeStream cryptoStream
initialStream *cryptoStream
handshakeStream *cryptoStream
token []byte
@ -141,7 +141,7 @@ var _ packer = &packetPacker{}
func newPacketPacker(
srcConnID protocol.ConnectionID,
getDestConnID func() protocol.ConnectionID,
initialStream, handshakeStream cryptoStream,
initialStream, handshakeStream *cryptoStream,
packetNumberManager packetNumberManager,
retransmissionQueue *retransmissionQueue,
cryptoSetup sealingManager,
@ -482,7 +482,7 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
return nil, payload{}
}
var s cryptoStream
var s *cryptoStream
var handler ackhandler.FrameHandler
var hasRetransmission bool
//nolint:exhaustive // Initial and Handshake are the only two encryption levels here.
@ -645,6 +645,9 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
pl.length += lengthAdded
// add handlers for the control frames that were added
for i := startLen; i < len(pl.frames); i++ {
if pl.frames[i].Handler != nil {
continue
}
switch pl.frames[i].Frame.(type) {
case *wire.PathChallengeFrame, *wire.PathResponseFrame:
// Path probing is currently not supported, therefore we don't need to set the OnAcked callback yet.

View file

@ -1,7 +1,6 @@
package quic
import (
"bytes"
"fmt"
"time"
@ -53,7 +52,7 @@ func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetU
// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
// If any other error occurred when parsing the header, the error is of type headerParseError.
// If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) {
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
var encLevel protocol.EncryptionLevel
var extHdr *wire.ExtendedHeader
var decrypted []byte
@ -65,7 +64,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
@ -75,7 +74,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
@ -85,7 +84,7 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
@ -125,8 +124,8 @@ func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (prot
return pn, pnLen, kp, decrypted, nil
}
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v)
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data)
// If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker
// to gain information about the header encryption.
@ -187,17 +186,15 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int
}
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
extHdr, err := unpackLongHeader(hd, hdr, data, v)
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
extHdr, err := unpackLongHeader(hd, hdr, data)
if err != nil && err != wire.ErrInvalidReservedBits {
return nil, &headerParseError{err: err}
}
return extHdr, err
}
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) {
r := bytes.NewReader(data)
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
hdrLen := hdr.ParsedLen()
if protocol.ByteCount(len(data)) < hdrLen+4+16 {
//nolint:stylecheck
@ -214,7 +211,7 @@ func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v proto
data[hdrLen:hdrLen+4],
)
// 3. parse the header (and learn the actual length of the packet number)
extHdr, parseErr := hdr.ParseExtended(r, v)
extHdr, parseErr := hdr.ParseExtended(data)
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, parseErr
}

View file

@ -6,6 +6,7 @@ import (
"sync"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
@ -19,7 +20,6 @@ type receiveStreamI interface {
handleStreamFrame(*wire.StreamFrame) error
handleResetStreamFrame(*wire.ResetStreamFrame) error
closeForShutdown(error)
getWindowUpdate() protocol.ByteCount
}
type receiveStream struct {
@ -37,6 +37,9 @@ type receiveStream struct {
readPosInFrame int
currentFrameIsLast bool // is the currentFrame the last frame on this stream
queuedStopSending bool
queuedMaxStreamData bool
// Set once we read the io.EOF or the cancellation error.
// Note that for local cancellations, this doesn't necessarily mean that we know the final offset yet.
errorRead bool
@ -54,8 +57,9 @@ type receiveStream struct {
}
var (
_ ReceiveStream = &receiveStream{}
_ receiveStreamI = &receiveStream{}
_ ReceiveStream = &receiveStream{}
_ receiveStreamI = &receiveStream{}
_ streamControlFrameGetter = &receiveStream{}
)
func newReceiveStream(
@ -87,13 +91,16 @@ func (s *receiveStream) Read(p []byte) (int, error) {
defer func() { <-s.readOnce }()
s.mutex.Lock()
n, err := s.readImpl(p)
queuedNewControlFrame, n, err := s.readImpl(p)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if completed {
s.sender.onStreamCompleted(s.streamID)
}
if queuedNewControlFrame {
s.sender.onHasStreamControlFrame(s.streamID, s)
}
return n, err
}
@ -118,19 +125,20 @@ func (s *receiveStream) isNewlyCompleted() bool {
return false
}
func (s *receiveStream) readImpl(p []byte) (int, error) {
func (s *receiveStream) readImpl(p []byte) (bool, int, error) {
if s.currentFrameIsLast && s.currentFrame == nil {
s.errorRead = true
return 0, io.EOF
return false, 0, io.EOF
}
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
return false, 0, s.cancelErr
}
if s.closeForShutdownErr != nil {
return 0, s.closeForShutdownErr
return false, 0, s.closeForShutdownErr
}
var queuedNewControlFrame bool
var bytesRead int
var deadlineTimer *utils.Timer
for bytesRead < len(p) {
@ -138,23 +146,23 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
s.dequeueNextFrame()
}
if s.currentFrame == nil && bytesRead > 0 {
return bytesRead, s.closeForShutdownErr
return queuedNewControlFrame, bytesRead, s.closeForShutdownErr
}
for {
// Stop waiting on errors
if s.closeForShutdownErr != nil {
return bytesRead, s.closeForShutdownErr
return queuedNewControlFrame, bytesRead, s.closeForShutdownErr
}
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
return queuedNewControlFrame, 0, s.cancelErr
}
deadline := s.deadline
if !deadline.IsZero() {
if !time.Now().Before(deadline) {
return bytesRead, errDeadline
return queuedNewControlFrame, bytesRead, errDeadline
}
if deadlineTimer == nil {
deadlineTimer = utils.NewTimer()
@ -184,10 +192,10 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
}
if bytesRead > len(p) {
return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
return queuedNewControlFrame, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
}
if s.readPosInFrame > len(s.currentFrame) {
return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
return queuedNewControlFrame, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
}
m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:])
@ -197,7 +205,10 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
// when a RESET_STREAM was received, the flow controller was already
// informed about the final byteOffset for this stream
if !s.cancelledRemotely {
s.flowController.AddBytesRead(protocol.ByteCount(m))
if queueMaxStreamData := s.flowController.AddBytesRead(protocol.ByteCount(m)); queueMaxStreamData {
s.queuedMaxStreamData = true
queuedNewControlFrame = true
}
}
if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
@ -206,10 +217,10 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
s.currentFrameDone()
}
s.errorRead = true
return bytesRead, io.EOF
return queuedNewControlFrame, bytesRead, io.EOF
}
}
return bytesRead, nil
return queuedNewControlFrame, bytesRead, nil
}
func (s *receiveStream) dequeueNextFrame() {
@ -225,30 +236,31 @@ func (s *receiveStream) dequeueNextFrame() {
func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
s.mutex.Lock()
s.cancelReadImpl(errorCode)
queuedNewControlFrame := s.cancelReadImpl(errorCode)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if queuedNewControlFrame {
s.sender.onHasStreamControlFrame(s.streamID, s)
}
if completed {
s.flowController.Abandon()
s.sender.onStreamCompleted(s.streamID)
}
}
func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) {
func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) (queuedNewControlFrame bool) {
if s.cancelledLocally { // duplicate call to CancelRead
return
return false
}
s.cancelledLocally = true
if s.errorRead || s.cancelledRemotely {
return
return false
}
s.queuedStopSending = true
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.signalRead()
s.sender.queueControlFrame(&wire.StopSendingFrame{
StreamID: s.streamID,
ErrorCode: errorCode,
})
return true
}
func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
@ -318,6 +330,26 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame)
return nil
}
func (s *receiveStream) getControlFrame() (_ ackhandler.Frame, ok, hasMore bool) {
s.mutex.Lock()
defer s.mutex.Unlock()
if !s.queuedStopSending && !s.queuedMaxStreamData {
return ackhandler.Frame{}, false, false
}
if s.queuedStopSending {
s.queuedStopSending = false
return ackhandler.Frame{
Frame: &wire.StopSendingFrame{StreamID: s.streamID, ErrorCode: s.cancelErr.ErrorCode},
}, true, s.queuedMaxStreamData
}
s.queuedMaxStreamData = false
return ackhandler.Frame{
Frame: &wire.MaxStreamDataFrame{StreamID: s.streamID, MaximumStreamData: s.flowController.GetWindowUpdate()},
}, true, false
}
func (s *receiveStream) SetReadDeadline(t time.Time) error {
s.mutex.Lock()
s.deadline = t
@ -336,10 +368,6 @@ func (s *receiveStream) closeForShutdown(err error) {
s.signalRead()
}
func (s *receiveStream) getWindowUpdate() protocol.ByteCount {
return s.flowController.GetWindowUpdate()
}
// signalRead performs a non-blocking send on the readChan
func (s *receiveStream) signalRead() {
select {

View file

@ -26,7 +26,7 @@ type sendStreamI interface {
type sendStream struct {
mutex sync.Mutex
numOutstandingFrames int64
numOutstandingFrames int64 // outstanding STREAM and RESET_STREAM frames
retransmissionQueue []*wire.StreamFrame
ctx context.Context
@ -37,9 +37,12 @@ type sendStream struct {
writeOffset protocol.ByteCount
cancelWriteErr error
cancelWriteErr *StreamError
closeForShutdownErr error
queuedResetStreamFrame bool
queuedBlockedFrame bool
finishedWriting bool // set once Close() is called
finSent bool // set when a STREAM_FRAME with FIN bit has been sent
// Set when the application knows about the cancellation.
@ -59,8 +62,9 @@ type sendStream struct {
}
var (
_ SendStream = &sendStream{}
_ sendStreamI = &sendStream{}
_ SendStream = &sendStream{}
_ sendStreamI = &sendStream{}
_ streamControlFrameGetter = &sendStream{}
)
func newSendStream(
@ -172,7 +176,7 @@ func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error)
s.mutex.Unlock()
if !notifiedSender {
s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex
s.sender.onHasStreamData(s.streamID, s) // must be called without holding the mutex
notifiedSender = true
}
if copied {
@ -215,12 +219,15 @@ func (s *sendStream) canBufferStreamFrame() bool {
// maxBytes is the maximum length this frame (including frame header) will have.
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (af ackhandler.StreamFrame, ok, hasMore bool) {
s.mutex.Lock()
f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
f, hasMoreData, queuedControlFrame := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
if f != nil {
s.numOutstandingFrames++
}
s.mutex.Unlock()
if queuedControlFrame {
s.sender.onHasStreamControlFrame(s.streamID, s)
}
if f == nil {
return ackhandler.StreamFrame{}, false, hasMoreData
}
@ -230,20 +237,20 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers
}, true, hasMoreData
}
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more data to send */) {
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, hasMoreData, queuedControlFrame bool) {
if s.cancelWriteErr != nil || s.closeForShutdownErr != nil {
return nil, false
return nil, false, false
}
if len(s.retransmissionQueue) > 0 {
f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v)
if f != nil || hasMoreRetransmissions {
if f == nil {
return nil, true
return nil, true, false
}
// We always claim that we have more data to send.
// This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future.
return f, true
return f, true, false
}
}
@ -255,21 +262,18 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun
Offset: s.writeOffset,
DataLenPresent: true,
Fin: true,
}, false
}, false, false
}
return nil, false
return nil, false, false
}
sendWindow := s.flowController.SendWindowSize()
if sendWindow == 0 {
if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{
StreamID: s.streamID,
MaximumStreamData: offset,
})
return nil, false
if s.flowController.IsNewlyBlocked() {
s.queuedBlockedFrame = true
return nil, false, true
}
return nil, true
return nil, true, false
}
f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow, v)
@ -281,7 +285,7 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun
if f.Fin {
s.finSent = true
}
return f, hasMoreData
return f, hasMoreData, false
}
func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) {
@ -367,7 +371,7 @@ func (s *sendStream) isNewlyCompleted() bool {
return false
}
// We need to keep the stream around until all frames have been sent and acknowledged.
if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 {
if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame {
return false
}
// The stream is completed if we sent the FIN.
@ -379,7 +383,7 @@ func (s *sendStream) isNewlyCompleted() bool {
// 1. the application called CancelWrite, or
// 2. we received a STOP_SENDING, and
// * the application consumed the error via Write, or
// * the application called CLsoe
// * the application called Close
if s.cancelWriteErr != nil && (s.cancellationFlagged || s.finishedWriting) {
s.completed = true
return true
@ -407,7 +411,7 @@ func (s *sendStream) Close() error {
if cancelWriteErr != nil {
return fmt.Errorf("close called for canceled stream %d", s.streamID)
}
s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
s.sender.onHasStreamData(s.streamID, s) // need to send the FIN, must be called without holding the mutex
s.ctxCancel(nil)
return nil
@ -421,6 +425,17 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool
s.mutex.Lock()
if !remote {
s.cancellationFlagged = true
if s.cancelWriteErr != nil {
completed := s.isNewlyCompleted()
s.mutex.Unlock()
// The user has called CancelWrite. If the previous cancellation was
// because of a STOP_SENDING, we don't need to flag the error to the
// user anymore.
if completed {
s.sender.onStreamCompleted(s.streamID)
}
return
}
}
if s.cancelWriteErr != nil {
s.mutex.Unlock()
@ -430,18 +445,11 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool
s.ctxCancel(s.cancelWriteErr)
s.numOutstandingFrames = 0
s.retransmissionQueue = nil
newlyCompleted := s.isNewlyCompleted()
s.queuedResetStreamFrame = true
s.mutex.Unlock()
s.signalWrite()
s.sender.queueControlFrame(&wire.ResetStreamFrame{
StreamID: s.streamID,
FinalSize: s.writeOffset,
ErrorCode: errorCode,
})
if newlyCompleted {
s.sender.onStreamCompleted(s.streamID)
}
s.sender.onHasStreamControlFrame(s.streamID, s)
}
func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
@ -453,7 +461,7 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) {
hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
s.mutex.Unlock()
if hasStreamData {
s.sender.onHasStreamData(s.streamID)
s.sender.onHasStreamData(s.streamID, s)
}
}
@ -461,6 +469,32 @@ func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
s.cancelWriteImpl(frame.ErrorCode, true)
}
// getControlFrame dequeues the next pending stream-level control frame,
// handing out a queued STREAM_DATA_BLOCKED before a queued RESET_STREAM.
// ok reports whether a frame was returned; hasMore reports whether another
// control frame is still queued on this stream.
func (s *sendStream) getControlFrame() (_ ackhandler.Frame, ok, hasMore bool) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	switch {
	case s.queuedBlockedFrame:
		s.queuedBlockedFrame = false
		blocked := &wire.StreamDataBlockedFrame{
			StreamID:          s.streamID,
			MaximumStreamData: s.writeOffset,
		}
		// A RESET_STREAM might still be queued; tell the caller to ask again.
		return ackhandler.Frame{Frame: blocked}, true, s.queuedResetStreamFrame
	case s.queuedResetStreamFrame:
		s.queuedResetStreamFrame = false
		// The RESET_STREAM counts as outstanding until acknowledged
		// (see sendStreamResetStreamHandler).
		s.numOutstandingFrames++
		reset := &wire.ResetStreamFrame{
			StreamID:  s.streamID,
			FinalSize: s.writeOffset,
			ErrorCode: s.cancelWriteErr.ErrorCode,
		}
		return ackhandler.Frame{
			Frame:   reset,
			Handler: (*sendStreamResetStreamHandler)(s),
		}, true, false
	default:
		// Nothing queued.
		return ackhandler.Frame{}, false, false
	}
}
// Context returns the stream's context. It is canceled when the write side
// is closed, either cleanly (Close) or by cancellation (see cancelWriteImpl).
func (s *sendStream) Context() context.Context {
	return s.ctx
}
@ -507,10 +541,10 @@ func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
newlyCompleted := (*sendStream)(s).isNewlyCompleted()
completed := (*sendStream)(s).isNewlyCompleted()
s.mutex.Unlock()
if newlyCompleted {
if completed {
s.sender.onStreamCompleted(s.streamID)
}
}
@ -530,5 +564,30 @@ func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
}
s.mutex.Unlock()
s.sender.onHasStreamData(s.streamID)
s.sender.onHasStreamData(s.streamID, (*sendStream)(s))
}
// sendStreamResetStreamHandler handles acknowledgement and loss of the
// RESET_STREAM frame sent for a stream. It is a type-converted *sendStream,
// so it operates directly on the stream's own state and mutex.
type sendStreamResetStreamHandler sendStream

var _ ackhandler.FrameHandler = &sendStreamResetStreamHandler{}
// OnAcked is called when the RESET_STREAM frame sent for this stream is
// acknowledged. It decrements the outstanding frame count and, if the stream
// is now completed, notifies the sender.
func (s *sendStreamResetStreamHandler) OnAcked(wire.Frame) {
	s.mutex.Lock()
	s.numOutstandingFrames--
	if s.numOutstandingFrames < 0 {
		// more frames acknowledged than were ever sent: a bookkeeping bug
		panic("numOutStandingFrames negative")
	}
	completed := (*sendStream)(s).isNewlyCompleted()
	// onStreamCompleted must be called without holding the mutex
	// (see the streamSender interface contract).
	s.mutex.Unlock()
	if completed {
		s.sender.onStreamCompleted(s.streamID)
	}
}
// OnLost is called when the RESET_STREAM frame is declared lost.
// It re-queues the frame (getControlFrame will hand it out again) and
// notifies the sender that this stream has a control frame to send.
func (s *sendStreamResetStreamHandler) OnLost(wire.Frame) {
	s.mutex.Lock()
	s.queuedResetStreamFrame = true
	s.mutex.Unlock()
	s.sender.onHasStreamControlFrame(s.streamID, (*sendStream)(s))
}

View file

@ -18,7 +18,12 @@ import (
)
// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close.
var ErrServerClosed = errors.New("quic: server closed")
var ErrServerClosed = errServerClosed{}

// errServerClosed is the concrete type behind ErrServerClosed. A dedicated
// type (rather than a value created with errors.New) allows the error to
// unwrap to net.ErrClosed, so callers can match it with
// errors.Is(err, net.ErrClosed).
type errServerClosed struct{}

func (errServerClosed) Error() string { return "quic: server closed" }

// Unwrap makes errors.Is(ErrServerClosed, net.ErrClosed) return true.
func (errServerClosed) Unwrap() error { return net.ErrClosed }
// packetHandler handles packets
type packetHandler interface {
@ -803,7 +808,7 @@ func (s *baseServer) maybeSendInvalidToken(p rejectedPacket) {
hdr := p.hdr
sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
data := p.data[:hdr.ParsedLen()+hdr.Length]
extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version)
extHdr, err := unpackLongHeader(opener, hdr, data)
// Only send INVALID_TOKEN if we can unprotect the packet.
// This makes sure that we won't send it for packets that were corrupted.
if err != nil {

View file

@ -24,8 +24,8 @@ var errDeadline net.Error = &deadlineError{}
// The streamSender is notified by the stream about various events.
type streamSender interface {
queueControlFrame(wire.Frame)
onHasStreamData(protocol.StreamID)
onHasStreamData(protocol.StreamID, sendStreamI)
onHasStreamControlFrame(protocol.StreamID, streamControlFrameGetter)
// must be called without holding the mutex that is acquired by closeForShutdown
onStreamCompleted(protocol.StreamID)
}
@ -34,19 +34,16 @@ type streamSender interface {
// This is necessary in order to keep track when both halves have been completed.
type uniStreamSender struct {
streamSender
onStreamCompletedImpl func()
onStreamCompletedImpl func()
onHasStreamControlFrameImpl func(protocol.StreamID, streamControlFrameGetter)
}
func (s *uniStreamSender) queueControlFrame(f wire.Frame) {
s.streamSender.queueControlFrame(f)
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID, str sendStreamI) {
s.streamSender.onHasStreamData(id, str)
}
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) {
s.streamSender.onHasStreamData(id)
}
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) {
s.onStreamCompletedImpl()
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { s.onStreamCompletedImpl() }
func (s *uniStreamSender) onHasStreamControlFrame(id protocol.StreamID, str streamControlFrameGetter) {
s.onHasStreamControlFrameImpl(id, str)
}
var _ streamSender = &uniStreamSender{}
@ -57,7 +54,6 @@ type streamI interface {
// for receiving
handleStreamFrame(*wire.StreamFrame) error
handleResetStreamFrame(*wire.ResetStreamFrame) error
getWindowUpdate() protocol.ByteCount
// for sending
hasData() bool
handleStopSendingFrame(*wire.StopSendingFrame)
@ -83,7 +79,10 @@ type stream struct {
sendStreamCompleted bool
}
var _ Stream = &stream{}
var (
_ Stream = &stream{}
_ streamControlFrameGetter = &receiveStream{}
)
// newStream creates a new Stream
func newStream(
@ -101,6 +100,9 @@ func newStream(
s.checkIfCompleted()
s.completedMutex.Unlock()
},
onHasStreamControlFrameImpl: func(id protocol.StreamID, str streamControlFrameGetter) {
sender.onHasStreamControlFrame(streamID, s)
},
}
s.sendStream = *newSendStream(ctx, streamID, senderForSendStream, flowController)
senderForReceiveStream := &uniStreamSender{
@ -111,6 +113,9 @@ func newStream(
s.checkIfCompleted()
s.completedMutex.Unlock()
},
onHasStreamControlFrameImpl: func(id protocol.StreamID, str streamControlFrameGetter) {
sender.onHasStreamControlFrame(streamID, s)
},
}
s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController)
return s
@ -126,6 +131,14 @@ func (s *stream) Close() error {
return s.sendStream.Close()
}
// getControlFrame returns a queued control frame for this bidirectional
// stream, draining the send side before the receive side.
func (s *stream) getControlFrame() (_ ackhandler.Frame, ok, hasMore bool) {
	if frame, found, _ := s.sendStream.getControlFrame(); found {
		// Report hasMore unconditionally: the receive side might have a
		// control frame queued as well, so the caller should ask again.
		return frame, true, true
	}
	return s.receiveStream.getControlFrame()
}
func (s *stream) SetDeadline(t time.Time) error {
_ = s.SetReadDeadline(t) // SetReadDeadline never errors
_ = s.SetWriteDeadline(t) // SetWriteDeadline never errors

View file

@ -38,11 +38,21 @@ type streamOpenErr struct{ error }
var _ net.Error = &streamOpenErr{}
func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams }
func (streamOpenErr) Timeout() bool { return false }
func (streamOpenErr) Timeout() bool { return false }
func (e streamOpenErr) Unwrap() error { return e.error }
// errTooManyOpenStreams is used internally by the outgoing streams maps.
var errTooManyOpenStreams = errors.New("too many open streams")
// Temporary reports whether the error is a stream limit error.
// In older versions of quic-go, the stream limit error was documented to be a net.Error.Temporary.
// net.Error.Temporary was since deprecated, but we keep the existing behavior.
func (e streamOpenErr) Temporary() bool {
	return errors.Is(e, &StreamLimitReachedError{})
}
// StreamLimitReachedError is returned from Connection.OpenStream and Connection.OpenUniStream
// when it is not possible to open a new stream because the number of opens streams reached
// the peer's stream limit.
type StreamLimitReachedError struct{}

// Error implements the error interface.
func (StreamLimitReachedError) Error() string {
	return "too many open streams"
}
type streamsMap struct {
ctx context.Context // not used for cancellations, but carries the values associated with the connection
@ -52,6 +62,7 @@ type streamsMap struct {
maxIncomingUniStreams uint64
sender streamSender
queueControlFrame func(wire.Frame)
newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
mutex sync.Mutex
@ -67,14 +78,16 @@ var _ streamManager = &streamsMap{}
func newStreamsMap(
ctx context.Context,
sender streamSender,
queueControlFrame func(wire.Frame),
newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
maxIncomingBidiStreams uint64,
maxIncomingUniStreams uint64,
perspective protocol.Perspective,
) streamManager {
) *streamsMap {
m := &streamsMap{
ctx: ctx,
perspective: perspective,
queueControlFrame: queueControlFrame,
newFlowController: newFlowController,
maxIncomingBidiStreams: maxIncomingBidiStreams,
maxIncomingUniStreams: maxIncomingUniStreams,
@ -91,7 +104,7 @@ func (m *streamsMap) initMaps() {
id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
return newStream(m.ctx, id, m.sender, m.newFlowController(id))
},
m.sender.queueControlFrame,
m.queueControlFrame,
)
m.incomingBidiStreams = newIncomingStreamsMap(
protocol.StreamTypeBidi,
@ -100,7 +113,7 @@ func (m *streamsMap) initMaps() {
return newStream(m.ctx, id, m.sender, m.newFlowController(id))
},
m.maxIncomingBidiStreams,
m.sender.queueControlFrame,
m.queueControlFrame,
)
m.outgoingUniStreams = newOutgoingStreamsMap(
protocol.StreamTypeUni,
@ -108,7 +121,7 @@ func (m *streamsMap) initMaps() {
id := num.StreamID(protocol.StreamTypeUni, m.perspective)
return newSendStream(m.ctx, id, m.sender, m.newFlowController(id))
},
m.sender.queueControlFrame,
m.queueControlFrame,
)
m.incomingUniStreams = newIncomingStreamsMap(
protocol.StreamTypeUni,
@ -117,7 +130,7 @@ func (m *streamsMap) initMaps() {
return newReceiveStream(id, m.sender, m.newFlowController(id))
},
m.maxIncomingUniStreams,
m.sender.queueControlFrame,
m.queueControlFrame,
)
}

View file

@ -60,7 +60,7 @@ func (m *outgoingStreamsMap[T]) OpenStream() (T, error) {
// if there are OpenStreamSync calls waiting, return an error here
if len(m.openQueue) > 0 || m.nextStream > m.maxStream {
m.maybeSendBlockedFrame()
return *new(T), streamOpenErr{errTooManyOpenStreams}
return *new(T), streamOpenErr{&StreamLimitReachedError{}}
}
return m.openStream(), nil
}

View file

@ -1,71 +0,0 @@
package quic
import (
"sync"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
// windowUpdateQueue collects pending flow control window updates
// (stream-level MAX_STREAM_DATA and connection-level MAX_DATA) until
// they are dequeued as frames by QueueAll.
type windowUpdateQueue struct {
	mutex sync.Mutex

	queue      map[protocol.StreamID]struct{} // used as a set
	queuedConn bool                           // connection-level window update

	streamGetter       streamGetter                         // looks up the receive stream for a queued stream ID
	connFlowController flowcontrol.ConnectionFlowController // provides the connection-level window update offset
	callback           func(wire.Frame)                     // called with each dequeued window update frame
}
// newWindowUpdateQueue creates an empty window update queue.
// cb is invoked for every window update frame dequeued by QueueAll.
func newWindowUpdateQueue(
	streamGetter streamGetter,
	connFC flowcontrol.ConnectionFlowController,
	cb func(wire.Frame),
) *windowUpdateQueue {
	q := windowUpdateQueue{
		streamGetter:       streamGetter,
		connFlowController: connFC,
		callback:           cb,
	}
	q.queue = make(map[protocol.StreamID]struct{})
	return &q
}
// AddStream queues a window update for the given stream.
// Adding the same stream multiple times before QueueAll runs results in a
// single MAX_STREAM_DATA frame, since the queue is a set.
func (q *windowUpdateQueue) AddStream(id protocol.StreamID) {
	q.mutex.Lock()
	defer q.mutex.Unlock()
	q.queue[id] = struct{}{}
}
// AddConnection queues a connection-level window update (a MAX_DATA frame).
func (q *windowUpdateQueue) AddConnection() {
	q.mutex.Lock()
	defer q.mutex.Unlock()
	q.queuedConn = true
}
// QueueAll dequeues every pending window update and passes the resulting
// MAX_DATA / MAX_STREAM_DATA frames to the callback. Entries that became
// stale in the meantime (completed streams, zero offsets) are dropped.
func (q *windowUpdateQueue) QueueAll() {
	q.mutex.Lock()
	// queue a connection-level window update
	if q.queuedConn {
		q.callback(&wire.MaxDataFrame{MaximumData: q.connFlowController.GetWindowUpdate()})
		q.queuedConn = false
	}
	// queue all stream-level window updates
	for id := range q.queue {
		delete(q.queue, id) // deleting during range is safe in Go
		str, err := q.streamGetter.GetOrOpenReceiveStream(id)
		if err != nil || str == nil { // the stream can be nil if it was completed before dequeing the window update
			continue
		}
		offset := str.getWindowUpdate()
		if offset == 0 { // can happen if we received a final offset, right after queueing the window update
			continue
		}
		q.callback(&wire.MaxStreamDataFrame{
			StreamID:          id,
			MaximumStreamData: offset,
		})
	}
	q.mutex.Unlock()
}