diff --git a/.clusterfuzzlite/Dockerfile b/.clusterfuzzlite/Dockerfile index b57da8ca..d9fedb9a 100644 --- a/.clusterfuzzlite/Dockerfile +++ b/.clusterfuzzlite/Dockerfile @@ -3,7 +3,7 @@ FROM gcr.io/oss-fuzz-base/base-builder-go:v1 ARG TARGETPLATFORM RUN echo "TARGETPLATFORM: ${TARGETPLATFORM}" -ENV GOVERSION=1.20.7 +ENV GOVERSION=1.22.0 RUN platform=$(echo ${TARGETPLATFORM} | tr '/' '-') && \ filename="go${GOVERSION}.${platform}.tar.gz" && \ diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..5ace4600 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.golangci.yml b/.golangci.yml index 1315759b..469d54cf 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -2,16 +2,6 @@ run: skip-files: - internal/handshake/cipher_suite.go linters-settings: - depguard: - type: blacklist - packages: - - github.com/marten-seemann/qtls - - github.com/quic-go/qtls-go1-19 - - github.com/quic-go/qtls-go1-20 - packages-with-error-message: - - github.com/marten-seemann/qtls: "importing qtls only allowed in internal/qtls" - - github.com/quic-go/qtls-go1-19: "importing qtls only allowed in internal/qtls" - - github.com/quic-go/qtls-go1-20: "importing qtls only allowed in internal/qtls" misspell: ignore-words: - ect @@ -20,7 +10,6 @@ linters: disable-all: true enable: - asciicheck - - depguard - exhaustive - exportloopref - goimports diff --git a/client.go b/client.go index c70937f5..70f7143b 100644 --- a/client.go +++ b/client.go @@ -29,7 +29,7 @@ type client struct { initialPacketNumber protocol.PacketNumber hasNegotiatedVersion bool - version protocol.VersionNumber + version protocol.Version handshakeChan chan struct{} @@ -237,7 +237,7 @@ func (c *client) dial(ctx context.Context) error { select { case <-ctx.Done(): - c.conn.shutdown() + c.conn.destroy(nil) return context.Cause(ctx) case err := <-errorChan: return err diff --git a/client_test.go b/client_test.go index c55953f3..3b1da3b0 100644 --- a/client_test.go +++ b/client_test.go @@ -47,7 +47,7 @@ var _ = Describe("Client", func() { tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, - v protocol.VersionNumber, + v protocol.Version, ) quicConn ) @@ -61,7 +61,7 @@ var _ = Describe("Client", func() { Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) *logging.ConnectionTracer { return tr }, - Versions: []protocol.VersionNumber{protocol.Version1}, + Versions: []protocol.Version{protocol.Version1}, } Eventually(areConnsRunning).Should(BeFalse()) packetConn = NewMockSendConn(mockCtrl) @@ -88,7 +88,7 @@ var _ = Describe("Client", func() { AfterEach(func() { if s, ok := cl.conn.(*connection); ok { - s.shutdown() + s.destroy(nil) } Eventually(areConnsRunning).Should(BeFalse()) }) @@ -126,11 +126,11 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, - _ protocol.VersionNumber, + _ protocol.Version, ) quicConn { Expect(enable0RTT).To(BeFalse()) conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run().Do(func() { close(run) }) + conn.EXPECT().run().Do(func() error { close(run); return nil }) c := make(chan struct{}) close(c) conn.EXPECT().HandshakeComplete().Return(c) @@ -163,11 +163,11 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, - _ protocol.VersionNumber, + _ protocol.Version, ) quicConn { Expect(enable0RTT).To(BeTrue()) conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run().Do(func() { close(done) }) + conn.EXPECT().run().Do(func() error { close(done); return nil }) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) conn.EXPECT().earlyConnReady().Return(readyChan) return conn @@ -200,7 +200,7 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, - _ protocol.VersionNumber, + _ protocol.Version, ) quicConn { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().run().Return(testErr) @@ -266,9 +266,9 @@ var _ = Describe("Client", func() { }) It("creates new connections with the right parameters", func() { - config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}} + config := &Config{Versions: []protocol.Version{protocol.Version1}} c := make(chan struct{}) - var version protocol.VersionNumber + var version protocol.Version var conf *Config done := make(chan struct{}) newClientConnection = func( @@ -285,7 +285,7 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, - versionP protocol.VersionNumber, + versionP protocol.Version, ) quicConn { version = versionP conf = configP @@ -328,7 +328,7 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, - versionP protocol.VersionNumber, + versionP protocol.Version, ) quicConn { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) @@ -352,7 +352,7 @@ var _ = Describe("Client", func() { return conn } - config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}} + config := &Config{Tracer: config.Tracer, Versions: []protocol.Version{protocol.Version1}} tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) _, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config) Expect(err).ToNot(HaveOccurred()) diff --git a/closed_conn.go b/closed_conn.go index 071071c7..1dc14425 100644 --- a/closed_conn.go +++ b/closed_conn.go @@ -4,7 +4,6 @@ import ( "math/bits" "net" - "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/utils" ) @@ -12,9 +11,8 @@ import ( // When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame, // with an exponential backoff. type closedLocalConn struct { - counter uint32 - perspective protocol.Perspective - logger utils.Logger + counter uint32 + logger utils.Logger sendPacket func(net.Addr, packetInfo) } @@ -22,11 +20,10 @@ type closedLocalConn struct { var _ packetHandler = &closedLocalConn{} // newClosedLocalConn creates a new closedLocalConn and runs it. -func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler { +func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), logger utils.Logger) packetHandler { return &closedLocalConn{ - sendPacket: sendPacket, - perspective: pers, - logger: logger, + sendPacket: sendPacket, + logger: logger, } } @@ -41,24 +38,20 @@ func (c *closedLocalConn) handlePacket(p receivedPacket) { c.sendPacket(p.remoteAddr, p.info) } -func (c *closedLocalConn) shutdown() {} -func (c *closedLocalConn) destroy(error) {} -func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective } +func (c *closedLocalConn) destroy(error) {} +func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {} // A closedRemoteConn is a connection that was closed remotely. // For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE. // We can just ignore those packets. -type closedRemoteConn struct { - perspective protocol.Perspective -} +type closedRemoteConn struct{} var _ packetHandler = &closedRemoteConn{} -func newClosedRemoteConn(pers protocol.Perspective) packetHandler { - return &closedRemoteConn{perspective: pers} +func newClosedRemoteConn() packetHandler { + return &closedRemoteConn{} } -func (s *closedRemoteConn) handlePacket(receivedPacket) {} -func (s *closedRemoteConn) shutdown() {} -func (s *closedRemoteConn) destroy(error) {} -func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } +func (c *closedRemoteConn) handlePacket(receivedPacket) {} +func (c *closedRemoteConn) destroy(error) {} +func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {} diff --git a/closed_conn_test.go b/closed_conn_test.go index 40337998..d36d664b 100644 --- a/closed_conn_test.go +++ b/closed_conn_test.go @@ -3,7 +3,6 @@ package quic import ( "net" - "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/utils" . "github.com/onsi/ginkgo/v2" @@ -11,20 +10,9 @@ import ( ) var _ = Describe("Closed local connection", func() { - It("tells its perspective", func() { - conn := newClosedLocalConn(nil, protocol.PerspectiveClient, utils.DefaultLogger) - Expect(conn.getPerspective()).To(Equal(protocol.PerspectiveClient)) - // stop the connection - conn.shutdown() - }) - It("repeats the packet containing the CONNECTION_CLOSE frame", func() { written := make(chan net.Addr, 1) - conn := newClosedLocalConn( - func(addr net.Addr, _ packetInfo) { written <- addr }, - protocol.PerspectiveClient, - utils.DefaultLogger, - ) + conn := newClosedLocalConn(func(addr net.Addr, _ packetInfo) { written <- addr }, utils.DefaultLogger) addr := &net.UDPAddr{IP: net.IPv4(127, 1, 2, 3), Port: 1337} for i := 1; i <= 20; i++ { conn.handlePacket(receivedPacket{remoteAddr: addr}) diff --git a/codecov.yml b/codecov.yml index a24c7a15..59e4b58f 100644 --- a/codecov.yml +++ b/codecov.yml @@ -5,6 +5,8 @@ coverage: - interop/ - internal/handshake/cipher_suite.go - internal/utils/linkedlist/linkedlist.go + - internal/testdata + - testutils/ - fuzzing/ - metrics/ status: diff --git a/config.go b/config.go index 501ed1a0..41440488 100644 --- a/config.go +++ b/config.go @@ -2,7 +2,6 @@ package quic import ( "fmt" - "net" "time" "github.com/refraction-networking/uquic/internal/protocol" @@ -49,16 +48,6 @@ func validateConfig(config *Config) error { return nil } -// populateServerConfig populates fields in the quic.Config with their default values, if none are set -// it may be called with nil -func populateServerConfig(config *Config) *Config { - config = populateConfig(config) - if config.RequireAddressValidation == nil { - config.RequireAddressValidation = func(net.Addr) bool { return false } - } - return config -} - // populateConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil func populateConfig(config *Config) *Config { @@ -111,7 +100,6 @@ func populateConfig(config *Config) *Config { Versions: versions, HandshakeIdleTimeout: handshakeIdleTimeout, MaxIdleTimeout: idleTimeout, - RequireAddressValidation: config.RequireAddressValidation, KeepAlivePeriod: config.KeepAlivePeriod, InitialStreamReceiveWindow: initialStreamReceiveWindow, MaxStreamReceiveWindow: maxStreamReceiveWindow, diff --git a/config_test.go b/config_test.go index e0eef430..b34575c0 100644 --- a/config_test.go +++ b/config_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "reflect" "time" @@ -23,7 +22,7 @@ var _ = Describe("Config", func() { }) It("validates a config with normal values", func() { - conf := populateServerConfig(&Config{ + conf := populateConfig(&Config{ MaxIncomingStreams: 5, MaxStreamReceiveWindow: 10, }) @@ -69,7 +68,7 @@ var _ = Describe("Config", func() { case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer": // Can't compare functions. case "Versions": - f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) + f.Set(reflect.ValueOf([]Version{1, 2, 3})) case "ConnectionIDLength": f.Set(reflect.ValueOf(8)) case "ConnectionIDGenerator": @@ -118,19 +117,16 @@ var _ = Describe("Config", func() { Context("cloning", func() { It("clones function fields", func() { - var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool + var calledAllowConnectionWindowIncrease, calledTracer bool c1 := &Config{ GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, - RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer { calledTracer = true return nil }, } c2 := c1.Clone() - c2.RequireAddressValidation(&net.UDPAddr{}) - Expect(calledAddrValidation).To(BeTrue()) c2.AllowConnectionWindowIncrease(nil, 1234) Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) _, err := c2.GetConfigForClient(&ClientHelloInfo{}) @@ -145,29 +141,15 @@ var _ = Describe("Config", func() { }) It("returns a copy", func() { - c1 := &Config{ - MaxIncomingStreams: 100, - RequireAddressValidation: func(net.Addr) bool { return true }, - } + c1 := &Config{MaxIncomingStreams: 100} c2 := c1.Clone() c2.MaxIncomingStreams = 200 - c2.RequireAddressValidation = func(net.Addr) bool { return false } Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) - Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue()) }) }) Context("populating", func() { - It("populates function fields", func() { - var calledAddrValidation bool - c1 := &Config{} - c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true } - c2 := populateConfig(c1) - c2.RequireAddressValidation(&net.UDPAddr{}) - Expect(calledAddrValidation).To(BeTrue()) - }) - It("copies non-function fields", func() { c := configWithNonZeroNonFunctionFields() Expect(populateConfig(c)).To(Equal(c)) @@ -186,10 +168,5 @@ var _ = Describe("Config", func() { Expect(c.DisablePathMTUDiscovery).To(BeFalse()) Expect(c.GetConfigForClient).To(BeNil()) }) - - It("populates empty fields with default values, for the server", func() { - c := populateServerConfig(&Config{}) - Expect(c.RequireAddressValidation).ToNot(BeNil()) - }) }) }) diff --git a/conn_id_generator.go b/conn_id_generator.go index aab6bb42..5f64b8a5 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -5,7 +5,6 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" - "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" ) @@ -20,7 +19,7 @@ type connIDGenerator struct { getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken removeConnectionID func(protocol.ConnectionID) retireConnectionID func(protocol.ConnectionID) - replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte) + replaceWithClosed func([]protocol.ConnectionID, []byte) queueControlFrame func(wire.Frame) } @@ -31,7 +30,7 @@ func newConnIDGenerator( getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, removeConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID), - replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte), + replaceWithClosed func([]protocol.ConnectionID, []byte), queueControlFrame func(wire.Frame), generator ConnectionIDGenerator, ) *connIDGenerator { @@ -60,7 +59,7 @@ func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { // transport parameter. // We currently don't send the preferred_address transport parameter, // so we can issue (limit - 1) connection IDs. - for i := uint64(len(m.activeSrcConnIDs)); i < utils.Min(limit, protocol.MaxIssuedConnectionIDs); i++ { + for i := uint64(len(m.activeSrcConnIDs)); i < min(limit, protocol.MaxIssuedConnectionIDs); i++ { if err := m.issueNewConnID(); err != nil { return err } @@ -127,7 +126,7 @@ func (m *connIDGenerator) RemoveAll() { } } -func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) { +func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) { connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1) if m.initialClientDestConnID != nil { connIDs = append(connIDs, *m.initialClientDestConnID) @@ -135,5 +134,5 @@ func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose for _, connID := range m.activeSrcConnIDs { connIDs = append(connIDs, connID) } - m.replaceWithClosed(connIDs, pers, connClose) + m.replaceWithClosed(connIDs, connClose) } diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 36e5db65..2e8a2785 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -41,9 +41,7 @@ var _ = Describe("Connection ID Generator", func() { connIDToToken, func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) }, - func(cs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { - replacedWithClosed = append(replacedWithClosed, cs...) - }, + func(cs []protocol.ConnectionID, _ []byte) { replacedWithClosed = append(replacedWithClosed, cs...) }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, &protocol.DefaultConnectionIDGenerator{ConnLen: initialConnID.Len()}, ) @@ -177,7 +175,7 @@ var _ = Describe("Connection ID Generator", func() { It("replaces with a closed connection for all connection IDs", func() { Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) Expect(queuedFrames).To(HaveLen(4)) - g.ReplaceWithClosed(protocol.PerspectiveClient, []byte("foobar")) + g.ReplaceWithClosed([]byte("foobar")) Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones Expect(replacedWithClosed).To(ContainElement(initialClientDestConnID)) Expect(replacedWithClosed).To(ContainElement(initialConnID)) diff --git a/conn_id_manager.go b/conn_id_manager.go index 49ca79ec..86f013c9 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -155,7 +155,7 @@ func (h *connIDManager) updateConnectionID() { h.queueControlFrame(&wire.RetireConnectionIDFrame{ SequenceNumber: h.activeSequenceNumber, }) - h.highestRetired = utils.Max(h.highestRetired, h.activeSequenceNumber) + h.highestRetired = max(h.highestRetired, h.activeSequenceNumber) if h.activeStatelessResetToken != nil { h.removeStatelessResetToken(*h.activeStatelessResetToken) } diff --git a/connection.go b/connection.go index c7e0d4a3..cdf53169 100644 --- a/connection.go +++ b/connection.go @@ -26,7 +26,7 @@ import ( ) type unpacker interface { - UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) + UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) } @@ -94,7 +94,7 @@ type connRunner interface { GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) - ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte) + ReplaceWithClosed([]protocol.ConnectionID, []byte) AddResetToken(protocol.StatelessResetToken, packetHandler) RemoveResetToken(protocol.StatelessResetToken) } @@ -107,7 +107,7 @@ type closeError struct { type errCloseForRecreating struct { nextPacketNumber protocol.PacketNumber - nextVersion protocol.VersionNumber + nextVersion protocol.Version } func (e *errCloseForRecreating) Error() string { @@ -129,7 +129,7 @@ type connection struct { srcConnIDLen int perspective protocol.Perspective - version protocol.VersionNumber + version protocol.Version config *Config conn sendConn @@ -178,6 +178,7 @@ type connection struct { earlyConnReadyChan chan struct{} sentFirstPacket bool + droppedInitialKeys bool handshakeComplete bool handshakeConfirmed bool @@ -236,7 +237,7 @@ var newConnection = func( tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, - v protocol.VersionNumber, + v protocol.Version, ) quicConn { s := &connection{ conn: conn, @@ -308,7 +309,7 @@ var newConnection = func( RetrySourceConnectionID: retrySrcConnID, } if s.config.EnableDatagrams { - params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + params.MaxDatagramFrameSize = wire.MaxDatagramSize } else { params.MaxDatagramFrameSize = protocol.InvalidByteCount } @@ -349,7 +350,7 @@ var newClientConnection = func( tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, - v protocol.VersionNumber, + v protocol.Version, ) quicConn { s := &connection{ conn: conn, @@ -417,7 +418,7 @@ var newClientConnection = func( InitialSourceConnectionID: srcConnID, } if s.config.EnableDatagrams { - params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + params.MaxDatagramFrameSize = wire.MaxDatagramSize } else { params.MaxDatagramFrameSize = protocol.InvalidByteCount } @@ -457,7 +458,7 @@ func (s *connection) preSetup() { s.handshakeStream = newCryptoStream() s.sendQueue = newSendQueue(s.conn) s.retransmissionQueue = newRetransmissionQueue() - s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams) + s.frameParser = *wire.NewFrameParser(s.config.EnableDatagrams) s.rttStats = &utils.RTTStats{} s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.ByteCount(s.config.InitialConnectionReceiveWindow), @@ -524,6 +525,9 @@ func (s *connection) run() error { runLoop: for { + if s.framer.QueuedTooManyControlFrames() { + s.closeLocal(&qerr.TransportError{ErrorCode: InternalError}) + } // Close immediately if requested select { case closeErr = <-s.closeChan: @@ -633,7 +637,7 @@ runLoop: sendQueueAvailable = s.sendQueue.Available() continue } - if err := s.triggerSending(); err != nil { + if err := s.triggerSending(now); err != nil { s.closeLocal(err) } if s.sendQueue.WouldBlock() { @@ -685,7 +689,7 @@ func (s *connection) ConnectionState() ConnectionState { // Time when the connection should time out func (s *connection) nextIdleTimeoutTime() time.Time { - idleTimeout := utils.Max(s.idleTimeout, s.rttStats.PTO(true)*3) + idleTimeout := max(s.idleTimeout, s.rttStats.PTO(true)*3) return s.idleTimeoutStartTime().Add(idleTimeout) } @@ -695,7 +699,7 @@ func (s *connection) nextKeepAliveTime() time.Time { if s.config.KeepAlivePeriod == 0 || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() { return time.Time{} } - keepAliveInterval := utils.Max(s.keepAliveInterval, s.rttStats.PTO(true)*3/2) + keepAliveInterval := max(s.keepAliveInterval, s.rttStats.PTO(true)*3/2) return s.lastPacketReceivedTime.Add(keepAliveInterval) } @@ -735,6 +739,10 @@ func (s *connection) handleHandshakeComplete() error { s.connIDManager.SetHandshakeComplete() s.connIDGenerator.SetHandshakeComplete() + if s.tracer != nil && s.tracer.ChoseALPN != nil { + s.tracer.ChoseALPN(s.cryptoStreamHandler.ConnectionState().NegotiatedProtocol) + } + // The server applies transport parameters right away, but the client side has to wait for handshake completion. // During a 0-RTT connection, the client is only allowed to use the new transport parameters for 1-RTT packets. if s.perspective == protocol.PerspectiveClient { @@ -780,7 +788,7 @@ func (s *connection) handleHandshakeConfirmed() error { if maxPacketSize == 0 { maxPacketSize = protocol.MaxByteCount } - s.mtuDiscoverer.Start(utils.Min(maxPacketSize, protocol.MaxPacketBufferSize)) + s.mtuDiscoverer.Start(min(maxPacketSize, protocol.MaxPacketBufferSize)) } return nil } @@ -808,14 +816,14 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen) if err != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) } s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err) break } if destConnID != lastConnID { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) } s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID) break @@ -830,7 +838,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { if err == wire.ErrUnsupportedVersion { dropReason = logging.PacketDropUnsupportedVersion } - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason) + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), dropReason) } s.logger.Debugf("error parsing packet: %s", err) break @@ -839,7 +847,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { if hdr.Version != s.version { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) } s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version) break @@ -898,7 +906,7 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protoc if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) { s.logger.Debugf("Dropping (potentially) duplicate packet.") if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate) + s.tracer.DroppedPacket(logging.PacketType1RTT, pn, p.Size(), logging.PacketDropDuplicate) } return false } @@ -944,7 +952,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) // After this, all packets with a different source connection have to be ignored. if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) + s.tracer.DroppedPacket(logging.PacketTypeInitial, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnknownConnectionID) } s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID) return false @@ -952,7 +960,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) // drop 0-RTT packets, if we are a client if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketType0RTT, p.Size(), logging.PacketDropKeyUnavailable) + s.tracer.DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropKeyUnavailable) } return false } @@ -968,10 +976,10 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) packet.hdr.Log(s.logger) } - if s.receivedPacketHandler.IsPotentiallyDuplicate(packet.hdr.PacketNumber, packet.encryptionLevel) { + if pn := packet.hdr.PacketNumber; s.receivedPacketHandler.IsPotentiallyDuplicate(pn, packet.encryptionLevel) { s.logger.Debugf("Dropping (potentially) duplicate packet.") if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDuplicate) + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), pn, p.Size(), logging.PacketDropDuplicate) } return false } @@ -987,7 +995,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P switch err { case handshake.ErrKeysDropped: if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropKeyUnavailable) + s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropKeyUnavailable) } s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size()) case handshake.ErrKeysNotYetAvailable: @@ -1003,7 +1011,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P case handshake.ErrDecryptionFailed: // This might be a packet injected by an attacker. Drop it. if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropPayloadDecryptError) + s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err) default: @@ -1011,7 +1019,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P if errors.As(err, &headerErr) { // This might be a packet injected by an attacker. Drop it. if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropHeaderParseError) + s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err) } else { @@ -1026,14 +1034,14 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime time.Time) bool /* was this a valid Retry */ { if s.perspective == protocol.PerspectiveServer { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry.") return false } if s.receivedFirstPacket { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry, since we already received a packet.") return false @@ -1041,7 +1049,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti destConnID := s.connIDManager.Get() if hdr.SrcConnectionID == destConnID { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.") return false @@ -1056,7 +1064,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version) if !bytes.Equal(data[len(data)-16:], tag[:]) { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError) + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.") return false @@ -1089,7 +1097,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket) } return } @@ -1097,7 +1105,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data) if err != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing Version Negotiation packet: %s", err) return @@ -1106,7 +1114,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { for _, v := range supportedVersions { if v == s.version { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedVersion) } // The Version Negotiation packet contains the version that we offered. // This might be a packet sent by an attacker, or it was corrupted. @@ -1148,7 +1156,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( if !s.receivedFirstPacket { s.receivedFirstPacket = true if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil { - var clientVersions, serverVersions []protocol.VersionNumber + var clientVersions, serverVersions []protocol.Version switch s.perspective { case protocol.PerspectiveClient: clientVersions = s.config.Versions @@ -1185,7 +1193,8 @@ func (s *connection) handleUnpackedLongHeaderPacket( } } - if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake { + if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake && + !s.droppedInitialKeys { // On the server side, Initial keys are dropped as soon as the first Handshake packet is received. // See Section 4.9.1 of RFC 9001. if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { @@ -1347,7 +1356,7 @@ func (s *connection) handlePacket(p receivedPacket) { case s.receivedPackets <- p: default: if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropDOSPrevention) } } } @@ -1526,7 +1535,7 @@ func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encr } func (s *connection) handleDatagramFrame(f *wire.DatagramFrame) error { - if f.Length(s.version) > protocol.MaxDatagramFrameSize { + if f.Length(s.version) > wire.MaxDatagramSize { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "DATAGRAM frame too large", @@ -1572,13 +1581,6 @@ func (s *connection) closeRemote(e error) { }) } -// Close the connection. It sends a NO_ERROR application error. -// It waits until the run loop has stopped before returning -func (s *connection) shutdown() { - s.closeLocal(nil) - <-s.ctx.Done() -} - func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error { s.closeLocal(&qerr.ApplicationError{ ErrorCode: code, @@ -1588,6 +1590,11 @@ func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) erro return nil } +func (s *connection) closeWithTransportError(code TransportErrorCode) { + s.closeLocal(&qerr.TransportError{ErrorCode: code}) + <-s.ctx.Done() +} + func (s *connection) handleCloseError(closeErr *closeError) { e := closeErr.err if e == nil { @@ -1632,7 +1639,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { // If this is a remote close we're done here if closeErr.remote { - s.connIDGenerator.ReplaceWithClosed(s.perspective, nil) + s.connIDGenerator.ReplaceWithClosed(nil) return } if closeErr.immediate { @@ -1649,7 +1656,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { if err != nil { s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) } - s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket) + s.connIDGenerator.ReplaceWithClosed(connClosePacket) } func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error { @@ -1661,6 +1668,7 @@ func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) erro //nolint:exhaustive // only Initial and 0-RTT need special treatment switch encLevel { case protocol.EncryptionInitial: + s.droppedInitialKeys = true s.cryptoStreamHandler.DiscardInitialKeys() case protocol.Encryption0RTT: s.streamsMap.ResetFor0RTT() @@ -1755,7 +1763,7 @@ func (s *connection) applyTransportParameters() { params := s.peerParams // Our local idle timeout will always be > 0. s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) - s.keepAliveInterval = utils.Min(s.config.KeepAlivePeriod, utils.Min(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) + s.keepAliveInterval = min(s.config.KeepAlivePeriod, min(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) s.streamsMap.UpdateLimits(params) s.frameParser.SetAckDelayExponent(params.AckDelayExponent) s.connFlowController.UpdateSendWindow(params.InitialMaxData) @@ -1771,9 +1779,8 @@ func (s *connection) applyTransportParameters() { } } -func (s *connection) triggerSending() error { +func (s *connection) triggerSending(now time.Time) error { s.pacingDeadline = time.Time{} - now := time.Now() sendMode := s.sentPacketHandler.SendMode(now) //nolint:exhaustive // No need to handle pacing limited here. @@ -1805,7 +1812,7 @@ func (s *connection) triggerSending() error { s.scheduleSending() return nil } - return s.triggerSending() + return s.triggerSending(now) case ackhandler.SendPTOHandshake: if err := s.sendProbePacket(protocol.EncryptionHandshake, now); err != nil { return err @@ -1814,7 +1821,7 @@ func (s *connection) triggerSending() error { s.scheduleSending() return nil } - return s.triggerSending() + return s.triggerSending(now) case ackhandler.SendPTOAppData: if err := s.sendProbePacket(protocol.Encryption1RTT, now); err != nil { return err @@ -1823,7 +1830,7 @@ func (s *connection) triggerSending() error { s.scheduleSending() return nil } - return s.triggerSending() + return s.triggerSending(now) default: return fmt.Errorf("BUG: invalid send mode %d", sendMode) } @@ -1992,7 +1999,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { if packet == nil { return nil } - return s.sendPackedCoalescedPacket(packet, ecn, time.Now()) + return s.sendPackedCoalescedPacket(packet, ecn, now) } ecn := s.sentPacketHandler.ECNMode(true) @@ -2078,7 +2085,8 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn prot largestAcked = p.ack.LargestAcked() } s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false) - if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake { + if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake && + !s.droppedInitialKeys { // On the client side, Initial keys are dropped as soon as the first Handshake packet is sent. // See Section 4.9.1 of RFC 9001. if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { @@ -2309,7 +2317,7 @@ func (s *connection) tryQueueingUndecryptablePacket(p receivedPacket, pt logging } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropDOSPrevention) + s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropDOSPrevention) } s.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", p.Size()) return @@ -2347,21 +2355,23 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) { } } -func (s *connection) SendMessage(p []byte) error { +func (s *connection) SendDatagram(p []byte) error { if !s.supportsDatagrams() { return errors.New("datagram support disabled") } f := &wire.DatagramFrame{DataLenPresent: true} if protocol.ByteCount(len(p)) > f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) { - return errors.New("message too large") + return &DatagramTooLargeError{ + PeerMaxDatagramFrameSize: int64(s.peerParams.MaxDatagramFrameSize), + } } f.Data = make([]byte, len(p)) copy(f.Data, p) - return s.datagramQueue.AddAndWait(f) + return s.datagramQueue.Add(f) } -func (s *connection) ReceiveMessage(ctx context.Context) ([]byte, error) { +func (s *connection) ReceiveDatagram(ctx context.Context) ([]byte, error) { if !s.config.EnableDatagrams { return nil, errors.New("datagram support disabled") } @@ -2376,11 +2386,7 @@ func (s *connection) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } -func (s *connection) getPerspective() protocol.Perspective { - return s.perspective -} - -func (s *connection) GetVersion() protocol.VersionNumber { +func (s *connection) GetVersion() protocol.Version { return s.version } diff --git a/connection_test.go b/connection_test.go index 87cf05c9..9947e586 100644 --- a/connection_test.go +++ b/connection_test.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net" + "net/netip" "runtime/pprof" "strings" "time" @@ -21,10 +22,10 @@ import ( mocklogging "github.com/refraction-networking/uquic/internal/mocks/logging" "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" - "github.com/refraction-networking/uquic/internal/testutils" "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" "github.com/refraction-networking/uquic/logging" + "github.com/refraction-networking/uquic/testutils" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -76,7 +77,7 @@ var _ = Describe("Connection", func() { } expectReplaceWithClosed := func() { - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ []byte) { Expect(connIDs).To(ContainElement(srcConnID)) if len(connIDs) > 1 { Expect(connIDs).To(ContainElement(clientDestConnID)) @@ -84,8 +85,8 @@ var _ = Describe("Connection", func() { }) } - expectAppendPacket := func(packer *MockPacker, p shortHeaderPacket, b []byte) *gomock.Call { - return packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), Version1).DoAndReturn(func(buf *packetBuffer, _ protocol.ByteCount, _ protocol.VersionNumber) (shortHeaderPacket, error) { + expectAppendPacket := func(packer *MockPacker, p shortHeaderPacket, b []byte) *MockPackerAppendPacketCall { + return packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), Version1).DoAndReturn(func(buf *packetBuffer, _ protocol.ByteCount, _ protocol.Version) (shortHeaderPacket, error) { buf.Data = append(buf.Data, b...) return p, nil }) @@ -118,7 +119,7 @@ var _ = Describe("Connection", func() { srcConnID, &protocol.DefaultConnectionIDGenerator{}, protocol.StatelessResetToken{}, - populateServerConfig(&Config{DisablePathMTUDiscovery: true}), + populateConfig(&Config{DisablePathMTUDiscovery: true}), &tls.Config{}, tokenGenerator, false, @@ -346,7 +347,7 @@ var _ = Describe("Connection", func() { ErrorMessage: "foobar", } streamManager.EXPECT().CloseWithError(expectedErr) - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ []byte) { Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID)) }) cryptoSetup.EXPECT().Close() @@ -375,7 +376,7 @@ var _ = Describe("Connection", func() { ErrorMessage: "foobar", } streamManager.EXPECT().CloseWithError(testErr) - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ []byte) { Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID)) }) cryptoSetup.EXPECT().Close() @@ -410,7 +411,7 @@ var _ = Describe("Connection", func() { It("tells its versions", func() { conn.version = 4242 - Expect(conn.GetVersion()).To(Equal(protocol.VersionNumber(4242))) + Expect(conn.GetVersion()).To(Equal(protocol.Version(4242))) }) Context("closing", func() { @@ -450,7 +451,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().Close() buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("connection close")...) - packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(e *qerr.ApplicationError, _ protocol.ByteCount, _ protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(e *qerr.ApplicationError, _ protocol.ByteCount, _ protocol.Version) (*coalescedPacket, error) { Expect(e.ErrorCode).To(BeEquivalentTo(qerr.NoError)) Expect(e.ErrorMessage).To(BeEmpty()) return &coalescedPacket{buffer: buffer}, nil @@ -465,7 +466,7 @@ var _ = Describe("Connection", func() { }), tracer.EXPECT().Close(), ) - conn.shutdown() + conn.CloseWithError(0, "") Eventually(areConnsRunning).Should(BeFalse()) Expect(conn.Context().Done()).To(BeClosed()) }) @@ -479,8 +480,8 @@ var _ = Describe("Connection", func() { mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() - conn.shutdown() + conn.CloseWithError(0, "") + conn.CloseWithError(0, "") Eventually(areConnsRunning).Should(BeFalse()) Expect(conn.Context().Done()).To(BeClosed()) }) @@ -551,29 +552,6 @@ var _ = Describe("Connection", func() { } }) - It("cancels the context when the run loop exists", func() { - runConn() - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - cryptoSetup.EXPECT().Close() - packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - returned := make(chan struct{}) - go func() { - defer GinkgoRecover() - ctx := conn.Context() - <-ctx.Done() - Expect(ctx.Err()).To(MatchError(context.Canceled)) - close(returned) - }() - Consistently(returned).ShouldNot(BeClosed()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(returned).Should(BeClosed()) - Expect(context.Cause(conn.Context())).To(MatchError(context.Canceled)) - }) - It("doesn't send any more packets after receiving a CONNECTION_CLOSE", func() { unpacker := NewMockUnpacker(mockCtrl) conn.handshakeConfirmed = true @@ -581,7 +559,7 @@ var _ = Describe("Connection", func() { runConn() cryptoSetup.EXPECT().Close() streamManager.EXPECT().CloseWithError(gomock.Any()) - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes() b, err := wire.AppendShortHeader(nil, srcConnID, 42, protocol.PacketNumberLen2, protocol.KeyPhaseOne) Expect(err).ToNot(HaveOccurred()) @@ -687,7 +665,7 @@ var _ = Describe("Connection", func() { Version: conn.version, Token: []byte("foobar"), }}, make([]byte, 16) /* Retry integrity tag */) - tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) + tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) @@ -697,7 +675,7 @@ var _ = Describe("Connection", func() { protocol.ArbitraryLenConnectionID(destConnID.Bytes()), conn.config.Versions, ) - tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket) + tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket) Expect(conn.handlePacketImpl(receivedPacket{ data: b, buffer: getPacketBuffer(), @@ -713,7 +691,7 @@ var _ = Describe("Connection", func() { PacketNumberLen: protocol.PacketNumberLen2, }, nil) p.data[0] ^= 0x40 // unset the QUIC bit - tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) @@ -725,12 +703,12 @@ var _ = Describe("Connection", func() { }, PacketNumberLen: protocol.PacketNumberLen2, }, nil) - tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnsupportedVersion) + tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnsupportedVersion) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) It("drops packets with an unsupported version", func() { - origSupportedVersions := make([]protocol.VersionNumber, len(protocol.SupportedVersions)) + origSupportedVersions := make([]protocol.Version, len(protocol.SupportedVersions)) copy(origSupportedVersions, protocol.SupportedVersions) defer func() { protocol.SupportedVersions = origSupportedVersions @@ -746,7 +724,7 @@ var _ = Describe("Connection", func() { }, PacketNumberLen: protocol.PacketNumberLen2, }, nil) - tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedVersion) + tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedVersion) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) @@ -812,7 +790,7 @@ var _ = Describe("Connection", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true) conn.receivedPacketHandler = rph - tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, protocol.ByteCount(len(packet.data)), logging.PacketDropDuplicate) + tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, protocol.PacketNumber(0x1337), protocol.ByteCount(len(packet.data)), logging.PacketDropDuplicate) Expect(conn.handlePacketImpl(packet)).To(BeFalse()) }) @@ -838,7 +816,7 @@ var _ = Describe("Connection", func() { PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen2, }, []byte("foobar")) - tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, p.Size(), logging.PacketDropPayloadDecryptError) + tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropPayloadDecryptError) conn.handlePacket(p) Consistently(conn.Context().Done()).ShouldNot(BeClosed()) // make the go routine return @@ -956,7 +934,7 @@ var _ = Describe("Connection", func() { runErr <- conn.run() }() expectReplaceWithClosed() - tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, gomock.Any(), logging.PacketDropHeaderParseError) + tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, protocol.InvalidPacketNumber, gomock.Any(), logging.PacketDropHeaderParseError) conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil)) Consistently(runErr).ShouldNot(Receive()) // make the go routine return @@ -964,7 +942,7 @@ var _ = Describe("Connection", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -1029,7 +1007,7 @@ var _ = Describe("Connection", func() { Expect(conn.handlePacketImpl(p1)).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. p2 := getLongHeaderPacket(hdr2, nil) - tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.ByteCount(len(p2.data)), logging.PacketDropUnknownConnectionID) + tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.InvalidPacketNumber, protocol.ByteCount(len(p2.data)), logging.PacketDropUnknownConnectionID) Expect(conn.handlePacketImpl(p2)).To(BeFalse()) }) @@ -1088,7 +1066,7 @@ var _ = Describe("Connection", func() { It("cuts packets to the right length", func() { hdrLen, packet := getPacketWithLength(srcConnID, 456) - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen + 456 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1105,7 +1083,7 @@ var _ = Describe("Connection", func() { It("handles coalesced packets", func() { hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) packet1.ecn = protocol.ECT1 - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1117,7 +1095,7 @@ var _ = Describe("Connection", func() { }, nil }) hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1144,7 +1122,7 @@ var _ = Describe("Connection", func() { hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) gomock.InOrder( unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(nil, handshake.ErrKeysNotYetAvailable), - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1170,7 +1148,7 @@ var _ = Describe("Connection", func() { wrongConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) Expect(srcConnID).ToNot(Equal(wrongConnID)) hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1184,7 +1162,7 @@ var _ = Describe("Connection", func() { // don't EXPECT any more calls to unpacker.UnpackLongHeader() gomock.InOrder( tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any(), gomock.Any()), - tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.PacketDropUnknownConnectionID), + tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.InvalidPacketNumber, protocol.ByteCount(len(packet2.data)), logging.PacketDropUnknownConnectionID), ) packet1.data = append(packet1.data, packet2.data...) Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) @@ -1219,7 +1197,7 @@ var _ = Describe("Connection", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) Eventually(connDone).Should(BeClosed()) }) @@ -1270,7 +1248,6 @@ var _ = Describe("Connection", func() { sph.EXPECT().ECNMode(true).AnyTimes() runConn() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() - conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) conn.scheduleSending() time.Sleep(50 * time.Millisecond) // make sure there are no calls to mconn.Write() }) @@ -1281,7 +1258,10 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes() done := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket(true, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) { close(done) }) + packer.EXPECT().PackCoalescedPacket(true, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) { + close(done) + return nil, nil + }) runConn() conn.scheduleSending() Eventually(done).Should(BeClosed()) @@ -1323,7 +1303,7 @@ var _ = Describe("Connection", func() { Context(fmt.Sprintf("sending %s probe packets", encLevel), func() { var sendMode ackhandler.SendMode - var getFrame func(protocol.ByteCount, protocol.VersionNumber) wire.Frame + var getFrame func(protocol.ByteCount, protocol.Version) wire.Frame BeforeEach(func() { //nolint:exhaustive @@ -1420,7 +1400,7 @@ var _ = Describe("Connection", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -1809,7 +1789,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Close() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -1915,7 +1895,7 @@ var _ = Describe("Connection", func() { ) sent := make(chan struct{}) - mconn.EXPECT().Write([]byte("foobar"), uint16(0), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) { close(sent) }) + mconn.EXPECT().Write([]byte("foobar"), uint16(0), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) error { close(sent); return nil }) go func() { defer GinkgoRecover() @@ -1935,7 +1915,7 @@ var _ = Describe("Connection", func() { mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -1944,6 +1924,7 @@ var _ = Describe("Connection", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) conn.sentPacketHandler = sph tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake) + tracer.EXPECT().ChoseALPN(gomock.Any()) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).AnyTimes() @@ -1952,6 +1933,7 @@ var _ = Describe("Connection", func() { connRunner.EXPECT().Retire(clientDestConnID) cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() + cryptoSetup.EXPECT().ConnectionState() handshakeCtx := conn.HandshakeComplete() Consistently(handshakeCtx).ShouldNot(BeClosed()) Expect(conn.handleHandshakeComplete()).To(Succeed()) @@ -1964,8 +1946,10 @@ var _ = Describe("Connection", func() { connRunner.EXPECT().Retire(clientDestConnID) conn.sentPacketHandler.DropPackets(protocol.EncryptionInitial) tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake) + tracer.EXPECT().ChoseALPN(gomock.Any()) cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil) + cryptoSetup.EXPECT().ConnectionState() handshakeCtx := conn.HandshakeComplete() Consistently(handshakeCtx).ShouldNot(BeClosed()) @@ -2019,10 +2003,11 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ChoseALPN(gomock.Any()) conn.sentPacketHandler = sph done := make(chan struct{}) connRunner.EXPECT().Retire(clientDestConnID) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *packetBuffer, _ protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) { + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *packetBuffer, _ protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) { frames, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount, v) Expect(frames).ToNot(BeEmpty()) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) @@ -2039,6 +2024,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() + cryptoSetup.EXPECT().ConnectionState() mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handleHandshakeComplete()).To(Succeed()) conn.run() @@ -2051,7 +2037,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().Close() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -2065,13 +2051,11 @@ var _ = Describe("Connection", func() { close(done) }() streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) + connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.destroy(nil) Eventually(done).Should(BeClosed()) Expect(context.Cause(conn.Context())).To(MatchError(context.Canceled)) }) @@ -2157,7 +2141,7 @@ var _ = Describe("Connection", func() { mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -2165,7 +2149,7 @@ var _ = Describe("Connection", func() { setRemoteIdleTimeout(5 * time.Second) conn.lastPacketReceivedTime = time.Now().Add(-5 * time.Second / 2) sent := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) { close(sent) return nil, nil }) @@ -2178,7 +2162,7 @@ var _ = Describe("Connection", func() { setRemoteIdleTimeout(time.Hour) conn.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval).Add(-time.Millisecond) sent := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) { close(sent) return nil, nil }) @@ -2211,7 +2195,7 @@ var _ = Describe("Connection", func() { pto := conn.rttStats.PTO(true) conn.lastPacketReceivedTime = time.Now() sentPingTimeChan := make(chan time.Time) - packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) { sentPingTimeChan <- time.Now() return nil, nil }) @@ -2282,7 +2266,7 @@ var _ = Describe("Connection", func() { conn.config.HandshakeIdleTimeout = 9999 * time.Second conn.config.MaxIdleTimeout = 9999 * time.Second conn.lastPacketReceivedTime = time.Now().Add(-time.Minute) - packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(e *qerr.ApplicationError, _ protocol.ByteCount, _ protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(e *qerr.ApplicationError, _ protocol.ByteCount, _ protocol.Version) (*coalescedPacket, error) { Expect(e.ErrorCode).To(BeZero()) return &coalescedPacket{buffer: getPacketBuffer()}, nil }) @@ -2308,7 +2292,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -2349,6 +2333,7 @@ var _ = Describe("Connection", func() { ) cryptoSetup.EXPECT().Close() gomock.InOrder( + tracer.EXPECT().ChoseALPN(gomock.Any()), tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake), tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { Expect(e).To(MatchError(&IdleTimeoutError{})) @@ -2364,6 +2349,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1) + cryptoSetup.EXPECT().ConnectionState() Expect(conn.handleHandshakeComplete()).To(Succeed()) err := conn.run() nerr, ok := err.(net.Error) @@ -2393,7 +2379,7 @@ var _ = Describe("Connection", func() { mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -2428,7 +2414,7 @@ var _ = Describe("Connection", func() { It("stores up to MaxConnUnprocessedPackets packets", func() { done := make(chan struct{}) - tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, logging.ByteCount(6), logging.PacketDropDOSPrevention).Do(func(logging.PacketType, logging.ByteCount, logging.PacketDropReason) { + tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, logging.ByteCount(6), logging.PacketDropDOSPrevention).Do(func(logging.PacketType, logging.PacketNumber, logging.ByteCount, logging.PacketDropReason) { close(done) }) // Nothing here should block @@ -2573,7 +2559,7 @@ var _ = Describe("Client Connection", func() { It("changes the connection ID when receiving the first packet from the server", func() { unpacker := NewMockUnpacker(mockCtrl) - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) { return &unpackedPacket{ encryptionLevel: protocol.Encryption1RTT, hdr: &wire.ExtendedHeader{Header: *hdr}, @@ -2582,7 +2568,10 @@ var _ = Describe("Client Connection", func() { }) conn.unpacker = unpacker done := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) { close(done) }) + packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { + close(done) + return nil, nil + }) newConnID := protocol.ParseConnectionID([]byte{1, 3, 3, 7, 1, 3, 3, 7}) p := getPacket(&wire.ExtendedHeader{ Header: wire.Header{ @@ -2606,11 +2595,11 @@ var _ = Describe("Client Connection", func() { // make sure the go routine returns packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any()) + connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) time.Sleep(200 * time.Millisecond) }) @@ -2626,7 +2615,7 @@ var _ = Describe("Client Connection", func() { }) Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}))) // now receive a packet with the original source connection ID - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte, _ protocol.Version) (*unpackedPacket, error) { return &unpackedPacket{ hdr: &wire.ExtendedHeader{Header: *hdr}, data: []byte{0}, @@ -2672,9 +2661,10 @@ var _ = Describe("Client Connection", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() running := make(chan struct{}) - cryptoSetup.EXPECT().StartHandshake().Do(func() { + cryptoSetup.EXPECT().StartHandshake().Do(func() error { close(running) conn.closeLocal(errors.New("early error")) + return nil }) cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().Close() @@ -2705,7 +2695,7 @@ var _ = Describe("Client Connection", func() { }) Context("handling Version Negotiation", func() { - getVNP := func(versions ...protocol.VersionNumber) receivedPacket { + getVNP := func(versions ...protocol.Version) receivedPacket { b := wire.ComposeVersionNegotiation( protocol.ArbitraryLenConnectionID(srcConnID.Bytes()), protocol.ArbitraryLenConnectionID(destConnID.Bytes()), @@ -2723,7 +2713,7 @@ var _ = Describe("Client Connection", func() { conn.sentPacketHandler = sph sph.EXPECT().ReceivedBytes(gomock.Any()) sph.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(128), protocol.PacketNumberLen4) - conn.config.Versions = []protocol.VersionNumber{1234, 4321} + conn.config.Versions = []protocol.Version{1234, 4321} errChan := make(chan error, 1) start := make(chan struct{}) go func() { @@ -2738,8 +2728,8 @@ var _ = Describe("Client Connection", func() { connRunner.EXPECT().Remove(srcConnID) tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.VersionNumber) { Expect(versions).To(And( - ContainElement(protocol.VersionNumber(4321)), - ContainElement(protocol.VersionNumber(1337)), + ContainElement(protocol.Version(4321)), + ContainElement(protocol.Version(1337)), )) }) cryptoSetup.EXPECT().Close() @@ -2750,7 +2740,7 @@ var _ = Describe("Client Connection", func() { Expect(err).To(HaveOccurred()) Expect(err).To(BeAssignableToTypeOf(&errCloseForRecreating{})) recreateErr := err.(*errCloseForRecreating) - Expect(recreateErr.nextVersion).To(Equal(protocol.VersionNumber(4321))) + Expect(recreateErr.nextVersion).To(Equal(protocol.Version(4321))) Expect(recreateErr.nextPacketNumber).To(Equal(protocol.PacketNumber(128))) }) @@ -2784,14 +2774,14 @@ var _ = Describe("Client Connection", func() { It("ignores Version Negotiation packets that offer the current version", func() { p := getVNP(conn.version) - tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) + tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedVersion) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) It("ignores unparseable Version Negotiation packets", func() { p := getVNP(conn.version) p.data = p.data[:len(p.data)-2] - tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) + tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) }) @@ -2840,14 +2830,14 @@ var _ = Describe("Client Connection", func() { It("ignores Retry packets after receiving a regular packet", func() { conn.receivedFirstPacket = true p := getPacket(retryHdr, getRetryTag(retryHdr)) - tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) + tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) It("ignores Retry packets if the server didn't change the connection ID", func() { retryHdr.SrcConnectionID = destConnID p := getPacket(retryHdr, getRetryTag(retryHdr)) - tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) + tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) @@ -2855,7 +2845,7 @@ var _ = Describe("Client Connection", func() { tag := getRetryTag(retryHdr) tag[0]++ p := getPacket(retryHdr, tag) - tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropPayloadDecryptError) + tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropPayloadDecryptError) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) }) @@ -2890,6 +2880,8 @@ var _ = Describe("Client Connection", func() { defer GinkgoRecover() Expect(conn.handleHandshakeComplete()).To(Succeed()) }) + tracer.EXPECT().ChoseALPN(gomock.Any()).MaxTimes(1) + cryptoSetup.EXPECT().ConnectionState().MaxTimes(1) errChan <- conn.run() close(errChan) }() @@ -2897,7 +2889,7 @@ var _ = Describe("Client Connection", func() { expectClose := func(applicationClose, errored bool) { if !closed && !errored { - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()) + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()) if applicationClose { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) } else { @@ -2914,7 +2906,7 @@ var _ = Describe("Client Connection", func() { } AfterEach(func() { - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) Eventually(errChan).Should(BeClosed()) }) @@ -2924,8 +2916,8 @@ var _ = Describe("Client Connection", func() { OriginalDestinationConnectionID: destConnID, InitialSourceConnectionID: destConnID, PreferredAddress: &wire.PreferredAddress{ - IPv4: net.IPv4(127, 0, 0, 1), - IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + IPv4: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 42), + IPv6: netip.AddrPortFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}), 13), ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, }, @@ -2958,7 +2950,7 @@ var _ = Describe("Client Connection", func() { Eventually(processed).Should(BeClosed()) // close first expectClose(true, false) - conn.shutdown() + conn.CloseWithError(0, "") // then check. Avoids race condition when accessing idleTimeout Expect(conn.idleTimeout).To(Equal(18 * time.Second)) }) @@ -3153,7 +3145,7 @@ var _ = Describe("Client Connection", func() { tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(getPacket(hdr1, nil))).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. - tracer.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.InvalidPacketNumber, gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(getPacket(hdr2, nil))).To(BeFalse()) }) @@ -3168,7 +3160,7 @@ var _ = Describe("Client Connection", func() { PacketNumber: 0x42, PacketNumberLen: protocol.PacketNumberLen2, }, []byte("foobar")) - tracer.EXPECT().DroppedPacket(logging.PacketType0RTT, p.Size(), gomock.Any()) + tracer.EXPECT().DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), gomock.Any()) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) @@ -3176,7 +3168,7 @@ var _ = Describe("Client Connection", func() { // the connection to immediately break down It("fails on Initial-level ACK for unsent packet", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{ack}) + initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, destConnID, []wire.Frame{ack}, protocol.PerspectiveServer, conn.version) tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) }) @@ -3188,7 +3180,7 @@ var _ = Describe("Client Connection", func() { IsApplicationError: true, ReasonPhrase: "mitm attacker", } - initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{connCloseFrame}) + initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, destConnID, []wire.Frame{connCloseFrame}, protocol.PerspectiveServer, conn.version) tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue()) }) @@ -3206,8 +3198,8 @@ var _ = Describe("Client Connection", func() { tracer.EXPECT().ReceivedRetry(gomock.Any()) conn.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, destConnID, destConnID, []byte("foobar"), conn.version))) - initialPacket := testutils.ComposeInitialPacket(conn.connIDManager.Get(), srcConnID, conn.version, conn.connIDManager.Get(), nil) - tracer.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any()) + initialPacket := testutils.ComposeInitialPacket(conn.connIDManager.Get(), srcConnID, conn.connIDManager.Get(), nil, protocol.PerspectiveServer, conn.version) + tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.InvalidPacketNumber, gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) }) }) diff --git a/crypto_stream.go b/crypto_stream.go index 4ad097ce..c0f26d43 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -6,7 +6,6 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" - "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" ) @@ -56,7 +55,7 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { // could e.g. be a retransmission return nil } - s.highestOffset = utils.Max(s.highestOffset, highestOffset) + s.highestOffset = max(s.highestOffset, highestOffset) if err := s.queue.Push(f.Data, f.Offset, nil); err != nil { return err } @@ -99,7 +98,7 @@ func (s *cryptoStreamImpl) HasData() bool { func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame { f := &wire.CryptoFrame{Offset: s.writeOffset} - n := utils.Min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf))) + n := min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf))) f.Data = s.writeBuf[:n] s.writeBuf = s.writeBuf[n:] s.writeOffset += n diff --git a/datagram_queue.go b/datagram_queue.go index fbd6a251..979f3b86 100644 --- a/datagram_queue.go +++ b/datagram_queue.go @@ -4,14 +4,20 @@ import ( "context" "sync" - "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/utils" + "github.com/refraction-networking/uquic/internal/utils/ringbuffer" "github.com/refraction-networking/uquic/internal/wire" ) +const ( + maxDatagramSendQueueLen = 32 + maxDatagramRcvQueueLen = 128 +) + type datagramQueue struct { - sendQueue chan *wire.DatagramFrame - nextFrame *wire.DatagramFrame + sendMx sync.Mutex + sendQueue ringbuffer.RingBuffer[*wire.DatagramFrame] + sent chan struct{} // used to notify Add that a datagram was dequeued rcvMx sync.Mutex rcvQueue [][]byte @@ -22,60 +28,65 @@ type datagramQueue struct { hasData func() - dequeued chan struct{} - logger utils.Logger } func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue { return &datagramQueue{ - hasData: hasData, - sendQueue: make(chan *wire.DatagramFrame, 1), - rcvd: make(chan struct{}, 1), - dequeued: make(chan struct{}), - closed: make(chan struct{}), - logger: logger, + hasData: hasData, + rcvd: make(chan struct{}, 1), + sent: make(chan struct{}, 1), + closed: make(chan struct{}), + logger: logger, } } -// AddAndWait queues a new DATAGRAM frame for sending. -// It blocks until the frame has been dequeued. -func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error { - select { - case h.sendQueue <- f: - h.hasData() - case <-h.closed: - return h.closeErr - } +// Add queues a new DATAGRAM frame for sending. +// Up to 32 DATAGRAM frames will be queued. +// Once that limit is reached, Add blocks until the queue size has reduced. +func (h *datagramQueue) Add(f *wire.DatagramFrame) error { + h.sendMx.Lock() - select { - case <-h.dequeued: - return nil - case <-h.closed: - return h.closeErr + for { + if h.sendQueue.Len() < maxDatagramSendQueueLen { + h.sendQueue.PushBack(f) + h.sendMx.Unlock() + h.hasData() + return nil + } + select { + case <-h.sent: // drain the queue so we don't loop immediately + default: + } + h.sendMx.Unlock() + select { + case <-h.closed: + return h.closeErr + case <-h.sent: + } + h.sendMx.Lock() } } // Peek gets the next DATAGRAM frame for sending. // If actually sent out, Pop needs to be called before the next call to Peek. func (h *datagramQueue) Peek() *wire.DatagramFrame { - if h.nextFrame != nil { - return h.nextFrame - } - select { - case h.nextFrame = <-h.sendQueue: - h.dequeued <- struct{}{} - default: + h.sendMx.Lock() + defer h.sendMx.Unlock() + if h.sendQueue.Empty() { return nil } - return h.nextFrame + return h.sendQueue.PeekFront() } func (h *datagramQueue) Pop() { - if h.nextFrame == nil { - panic("datagramQueue BUG: Pop called for nil frame") + h.sendMx.Lock() + defer h.sendMx.Unlock() + _ = h.sendQueue.PopFront() + select { + case h.sent <- struct{}{}: + default: } - h.nextFrame = nil } // HandleDatagramFrame handles a received DATAGRAM frame. @@ -84,7 +95,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { copy(data, f.Data) var queued bool h.rcvMx.Lock() - if len(h.rcvQueue) < protocol.DatagramRcvQueueLen { + if len(h.rcvQueue) < maxDatagramRcvQueueLen { h.rcvQueue = append(h.rcvQueue, data) queued = true select { @@ -94,7 +105,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { } h.rcvMx.Unlock() if !queued && h.logger.Debug() { - h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data)) + h.logger.Debugf("Discarding received DATAGRAM frame (%d bytes payload)", len(f.Data)) } } diff --git a/datagram_queue_test.go b/datagram_queue_test.go index f581f429..3d8efb4e 100644 --- a/datagram_queue_test.go +++ b/datagram_queue_test.go @@ -3,6 +3,7 @@ package quic import ( "context" "errors" + "time" "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" @@ -26,55 +27,65 @@ var _ = Describe("Datagram Queue", func() { }) It("queues a datagram", func() { - done := make(chan struct{}) frame := &wire.DatagramFrame{Data: []byte("foobar")} - go func() { - defer GinkgoRecover() - defer close(done) - Expect(queue.AddAndWait(frame)).To(Succeed()) - }() - - Eventually(queued).Should(HaveLen(1)) - Consistently(done).ShouldNot(BeClosed()) + Expect(queue.Add(frame)).To(Succeed()) + Expect(queued).To(HaveLen(1)) f := queue.Peek() Expect(f.Data).To(Equal([]byte("foobar"))) - Eventually(done).Should(BeClosed()) queue.Pop() Expect(queue.Peek()).To(BeNil()) }) - It("returns the same datagram multiple times, when Pop isn't called", func() { - sent := make(chan struct{}, 1) + It("blocks when the maximum number of datagrams have been queued", func() { + for i := 0; i < maxDatagramSendQueueLen; i++ { + Expect(queue.Add(&wire.DatagramFrame{Data: []byte{0}})).To(Succeed()) + } + errChan := make(chan error, 1) go func() { defer GinkgoRecover() - Expect(queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foo")})).To(Succeed()) - sent <- struct{}{} - Expect(queue.AddAndWait(&wire.DatagramFrame{Data: []byte("bar")})).To(Succeed()) - sent <- struct{}{} + errChan <- queue.Add(&wire.DatagramFrame{Data: []byte("foobar")}) }() + Consistently(errChan, 50*time.Millisecond).ShouldNot(Receive()) + Expect(queue.Peek()).ToNot(BeNil()) + Consistently(errChan, 50*time.Millisecond).ShouldNot(Receive()) + queue.Pop() + Eventually(errChan).Should(Receive(BeNil())) + for i := 1; i < maxDatagramSendQueueLen; i++ { + queue.Pop() + } + f := queue.Peek() + Expect(f).ToNot(BeNil()) + Expect(f.Data).To(Equal([]byte("foobar"))) + }) - Eventually(queued).Should(HaveLen(1)) + It("returns the same datagram multiple times, when Pop isn't called", func() { + Expect(queue.Add(&wire.DatagramFrame{Data: []byte("foo")})).To(Succeed()) + Expect(queue.Add(&wire.DatagramFrame{Data: []byte("bar")})).To(Succeed()) + + Eventually(queued).Should(HaveLen(2)) f := queue.Peek() Expect(f.Data).To(Equal([]byte("foo"))) - Eventually(sent).Should(Receive()) Expect(queue.Peek()).To(Equal(f)) Expect(queue.Peek()).To(Equal(f)) queue.Pop() - Eventually(func() *wire.DatagramFrame { f = queue.Peek(); return f }).ShouldNot(BeNil()) f = queue.Peek() + Expect(f).ToNot(BeNil()) Expect(f.Data).To(Equal([]byte("bar"))) }) It("closes", func() { + for i := 0; i < maxDatagramSendQueueLen; i++ { + Expect(queue.Add(&wire.DatagramFrame{Data: []byte("foo")})).To(Succeed()) + } errChan := make(chan error, 1) go func() { defer GinkgoRecover() - errChan <- queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")}) + errChan <- queue.Add(&wire.DatagramFrame{Data: []byte("foo")}) }() - - Consistently(errChan).ShouldNot(Receive()) - queue.CloseWithError(errors.New("test error")) - Eventually(errChan).Should(Receive(MatchError("test error"))) + Consistently(errChan, 25*time.Millisecond).ShouldNot(Receive()) + testErr := errors.New("test error") + queue.CloseWithError(testErr) + Eventually(errChan).Should(Receive(MatchError(testErr))) }) }) diff --git a/errors.go b/errors.go index fba74e86..c4ffcf87 100644 --- a/errors.go +++ b/errors.go @@ -61,3 +61,15 @@ func (e *StreamError) Error() string { } return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode) } + +// DatagramTooLargeError is returned from Connection.SendDatagram if the payload is too large to be sent. +type DatagramTooLargeError struct { + PeerMaxDatagramFrameSize int64 +} + +func (e *DatagramTooLargeError) Is(target error) bool { + _, ok := target.(*DatagramTooLargeError) + return ok +} + +func (e *DatagramTooLargeError) Error() string { return "DATAGRAM frame too large" } diff --git a/example/client/main.go b/example/client/main.go index 18a09a19..50a100f8 100644 --- a/example/client/main.go +++ b/example/client/main.go @@ -1,12 +1,9 @@ package main import ( - "bufio" "bytes" - "context" "crypto/x509" "flag" - "fmt" "io" "log" "net/http" @@ -18,29 +15,16 @@ import ( quic "github.com/refraction-networking/uquic" "github.com/refraction-networking/uquic/http3" "github.com/refraction-networking/uquic/internal/testdata" - "github.com/refraction-networking/uquic/internal/utils" - "github.com/refraction-networking/uquic/logging" "github.com/refraction-networking/uquic/qlog" ) func main() { - verbose := flag.Bool("v", false, "verbose") quiet := flag.Bool("q", false, "don't print the data") keyLogFile := flag.String("keylog", "", "key log file") insecure := flag.Bool("insecure", false, "skip certificate verification") - enableQlog := flag.Bool("qlog", false, "output a qlog (in the same directory)") flag.Parse() urls := flag.Args() - logger := utils.DefaultLogger - - if *verbose { - logger.SetLogLevel(utils.LogLevelDebug) - } else { - logger.SetLogLevel(utils.LogLevelInfo) - } - logger.SetLogTimeFormat("") - var keyLog io.Writer if len(*keyLogFile) > 0 { f, err := os.Create(*keyLogFile) @@ -57,25 +41,15 @@ func main() { } testdata.AddRootCA(pool) - var qconf quic.Config - if *enableQlog { - qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { - filename := fmt.Sprintf("client_%x.qlog", connID) - f, err := os.Create(filename) - if err != nil { - log.Fatal(err) - } - log.Printf("Creating qlog file %s.\n", filename) - return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID) - } - } roundTripper := &http3.RoundTripper{ TLSClientConfig: &tls.Config{ RootCAs: pool, InsecureSkipVerify: *insecure, KeyLogWriter: keyLog, }, - QuicConfig: &qconf, + QuicConfig: &quic.Config{ + Tracer: qlog.DefaultTracer, + }, } defer roundTripper.Close() hclient := &http.Client{ @@ -85,13 +59,13 @@ func main() { var wg sync.WaitGroup wg.Add(len(urls)) for _, addr := range urls { - logger.Infof("GET %s", addr) + log.Printf("GET %s", addr) go func(addr string) { rsp, err := hclient.Get(addr) if err != nil { log.Fatal(err) } - logger.Infof("Got response for %s: %#v", addr, rsp) + log.Printf("Got response for %s: %#v", addr, rsp) body := &bytes.Buffer{} _, err = io.Copy(body, rsp.Body) @@ -99,10 +73,9 @@ func main() { log.Fatal(err) } if *quiet { - logger.Infof("Response Body: %d bytes", body.Len()) + log.Printf("Response Body: %d bytes", body.Len()) } else { - logger.Infof("Response Body:") - logger.Infof("%s", body.Bytes()) + log.Printf("Response Body (%d bytes):\n%s", body.Len(), body.Bytes()) } wg.Done() }(addr) diff --git a/example/echo/echo.go b/example/echo/echo.go index e8962a77..77603a0c 100644 --- a/example/echo/echo.go +++ b/example/echo/echo.go @@ -37,14 +37,19 @@ func echoServer() error { if err != nil { return err } + defer listener.Close() + conn, err := listener.Accept(context.Background()) if err != nil { return err } + stream, err := conn.AcceptStream(context.Background()) if err != nil { panic(err) } + defer stream.Close() + // Echo through the loggingWriter _, err = io.Copy(loggingWriter{stream}, stream) return err @@ -59,11 +64,13 @@ func clientMain() error { if err != nil { return err } + defer conn.CloseWithError(0, "") stream, err := conn.OpenStreamSync(context.Background()) if err != nil { return err } + defer stream.Close() fmt.Printf("Client: Sending '%s'\n", message) _, err = stream.Write([]byte(message)) diff --git a/example/main.go b/example/main.go index a77af339..ce3ef51f 100644 --- a/example/main.go +++ b/example/main.go @@ -1,8 +1,6 @@ package main import ( - "bufio" - "context" "crypto/md5" "errors" "flag" @@ -11,7 +9,6 @@ import ( "log" "mime/multipart" "net/http" - "os" "strconv" "strings" "sync" @@ -21,8 +18,6 @@ import ( quic "github.com/refraction-networking/uquic" "github.com/refraction-networking/uquic/http3" "github.com/refraction-networking/uquic/internal/testdata" - "github.com/refraction-networking/uquic/internal/utils" - "github.com/refraction-networking/uquic/logging" "github.com/refraction-networking/uquic/qlog" ) @@ -121,7 +116,7 @@ func setupHandler(www string) http.Handler { err = errors.New("couldn't get uploaded file size") } } - utils.DefaultLogger.Infof("Error receiving upload: %#v", err) + log.Printf("Error receiving upload: %#v", err) } io.WriteString(w, `