From 8189e75be6121fdc31dc1d6085f17015e9154667 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 6 Apr 2023 18:02:51 +0800 Subject: [PATCH] implement the Transport --- client.go | 158 +++-- client_test.go | 236 ++------ integrationtests/self/cancelation_test.go | 1 + integrationtests/self/close_test.go | 2 + integrationtests/self/conn_id_test.go | 12 +- integrationtests/self/datagram_test.go | 36 +- integrationtests/self/early_data_test.go | 1 + integrationtests/self/go119_test.go | 7 +- integrationtests/self/go120_test.go | 2 +- integrationtests/self/handshake_rtt_test.go | 11 +- integrationtests/self/handshake_test.go | 8 +- integrationtests/self/http_test.go | 79 ++- integrationtests/self/multiplex_test.go | 33 +- integrationtests/self/packetization_test.go | 3 +- integrationtests/self/self_suite_test.go | 8 + integrationtests/self/stateless_reset_test.go | 28 +- integrationtests/self/stream_test.go | 3 + integrationtests/self/timeout_test.go | 1 + integrationtests/self/uni_stream_test.go | 4 + mock_multiplexer_test.go | 65 -- mock_packet_handler_manager_test.go | 53 +- mockgen.go | 3 - multiplexer.go | 69 +-- multiplexer_test.go | 69 +-- packet_handler_map.go | 266 ++------- packet_handler_map_test.go | 555 ++++++------------ quic_suite_test.go | 17 + server.go | 92 +-- server_test.go | 40 +- transport.go | 410 +++++++++++++ transport_test.go | 287 +++++++++ 31 files changed, 1309 insertions(+), 1250 deletions(-) delete mode 100644 mock_multiplexer_test.go create mode 100644 transport.go create mode 100644 transport_test.go diff --git a/client.go b/client.go index ad80d4f2..c8ea0641 100644 --- a/client.go +++ b/client.go @@ -20,6 +20,7 @@ type client struct { use0RTT bool packetHandlers packetHandlerManager + onClose func() tlsConf *tls.Config config *Config @@ -45,32 +46,58 @@ var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial // DialAddr establishes a new QUIC connection to a server. // It uses a new UDP connection and closes this connection when the QUIC connection is closed. -// The hostname for SNI is taken from the given address. -func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (Connection, error) { - return dialAddrContext(ctx, addr, tlsConf, config, false) -} - -// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. -// It uses a new UDP connection and closes this connection when the QUIC connection is closed. -func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (EarlyConnection, error) { - conn, err := dialAddrContext(ctx, addr, tlsConf, config, true) - if err != nil { - return nil, err - } - utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection") - return conn, nil -} - -func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, config *Config, use0RTT bool) (quicConn, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } +func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) { udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { return nil, err } - return dialContext(ctx, udpConn, udpAddr, tlsConf, config, use0RTT, true) + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + dl, err := setupTransport(udpConn, tlsConf, true) + if err != nil { + return nil, err + } + return dl.Dial(ctx, udpAddr, tlsConf, conf) +} + +// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. +// It uses a new UDP connection and closes this connection when the QUIC connection is closed. +func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + return nil, err + } + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + dl, err := setupTransport(udpConn, tlsConf, true) + if err != nil { + return nil, err + } + conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf) + if err != nil { + dl.Close() + return nil, err + } + return conn, nil +} + +// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context. +// See DialEarly for details. +func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + dl, err := setupTransport(c, tlsConf, false) + if err != nil { + return nil, err + } + conn, err := dl.DialEarly(ctx, addr, tlsConf, conf) + if err != nil { + dl.Close() + return nil, err + } + return conn, nil } // Dial establishes a new QUIC connection to a server using a net.PacketConn. If @@ -78,34 +105,43 @@ func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, conf // does), ECN and packet info support will be enabled. In this case, ReadMsgUDP // and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write // packets. -// The same PacketConn can be used for multiple calls to Dial and Listen. -// QUIC connection IDs are used for demultiplexing the different connections. // The tls.Config must define an application protocol (using NextProtos). -func Dial(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (Connection, error) { - return dialContext(ctx, pconn, addr, tlsConf, config, false, false) -} - -// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn. -// The same PacketConn can be used for multiple calls to Dial and Listen, -// QUIC connection IDs are used for demultiplexing the different connections. -// The tls.Config must define an application protocol (using NextProtos). -func DialEarly(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (EarlyConnection, error) { - return dialContext(ctx, pconn, addr, tlsConf, config, true, false) -} - -func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config, use0RTT bool, createdPacketConn bool) (quicConn, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(config); err != nil { - return nil, err - } - config = populateClientConfig(config, createdPacketConn) - packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) +func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { + dl, err := setupTransport(c, tlsConf, false) if err != nil { return nil, err } - c, err := newClient(pconn, addr, config, tlsConf, use0RTT, createdPacketConn) + conn, err := dl.Dial(ctx, addr, tlsConf, conf) + if err != nil { + dl.Close() + return nil, err + } + return conn, nil +} + +func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + return &Transport{ + Conn: c, + createdConn: createdPacketConn, + isSingleUse: true, + }, nil +} + +func dial( + ctx context.Context, + conn net.PacketConn, + packetHandlers packetHandlerManager, + addr net.Addr, + tlsConf *tls.Config, + config *Config, + onClose func(), + use0RTT bool, + createdPacketConn bool, +) (quicConn, error) { + c, err := newClient(conn, addr, config, tlsConf, onClose, use0RTT, createdPacketConn) if err != nil { return nil, err } @@ -128,7 +164,7 @@ func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsCo return c.conn, nil } -func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, use0RTT bool, createdPacketConn bool) (*client, error) { +func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, onClose func(), use0RTT, createdPacketConn bool) (*client, error) { if tlsConf == nil { tlsConf = &tls.Config{} } else { @@ -149,6 +185,7 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon sconn: newSendPconn(pconn, remoteAddr), createdPacketConn: createdPacketConn, use0RTT: use0RTT, + onClose: onClose, tlsConf: tlsConf, config: config, version: config.Versions[0], @@ -179,13 +216,18 @@ func (c *client) dial(ctx context.Context) error { c.packetHandlers.Add(c.srcConnID, c.conn) errorChan := make(chan error, 1) + recreateChan := make(chan errCloseForRecreating) go func() { - err := c.conn.run() // returns as soon as the connection is closed - - if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn { - c.packetHandlers.Destroy() + err := c.conn.run() + var recreateErr *errCloseForRecreating + if errors.As(err, &recreateErr) { + recreateChan <- *recreateErr + return } - errorChan <- err + if c.onClose != nil { + c.onClose() + } + errorChan <- err // returns as soon as the connection is closed }() // only set when we're using 0-RTT @@ -200,14 +242,12 @@ func (c *client) dial(ctx context.Context) error { c.conn.shutdown() return ctx.Err() case err := <-errorChan: - var recreateErr *errCloseForRecreating - if errors.As(err, &recreateErr) { - c.initialPacketNumber = recreateErr.nextPacketNumber - c.version = recreateErr.nextVersion - c.hasNegotiatedVersion = true - return c.dial(ctx) - } return err + case recreateErr := <-recreateChan: + c.initialPacketNumber = recreateErr.nextPacketNumber + c.version = recreateErr.nextVersion + c.hasNegotiatedVersion = true + return c.dial(ctx) case <-earlyConnChan: // ready to send 0-RTT data return nil diff --git a/client_test.go b/client_test.go index dbd03cf3..ec8b83b8 100644 --- a/client_test.go +++ b/client_test.go @@ -18,13 +18,17 @@ import ( . "github.com/onsi/gomega" ) +type nullMultiplexer struct{} + +func (n nullMultiplexer) AddConn(indexableConn) {} +func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil } + var _ = Describe("Client", func() { var ( cl *client packetConn *MockPacketConn addr net.Addr connID protocol.ConnectionID - mockMultiplexer *MockMultiplexer origMultiplexer multiplexer tlsConf *tls.Config tracer *mocklogging.MockConnectionTracer @@ -53,6 +57,7 @@ var _ = Describe("Client", func() { originalClientConnConstructor = newClientConnection tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tr := mocklogging.NewMockTracer(mockCtrl) + tr.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}} Eventually(areConnsRunning).Should(BeFalse()) @@ -68,10 +73,9 @@ var _ = Describe("Client", func() { logger: utils.DefaultLogger, } getMultiplexer() // make the sync.Once execute - // replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer - mockMultiplexer = NewMockMultiplexer(mockCtrl) + // replace the clientMuxer. getMultiplexer will now return the nullMultiplexer origMultiplexer = connMuxer - connMuxer = mockMultiplexer + connMuxer = &nullMultiplexer{} }) AfterEach(func() { @@ -100,48 +104,14 @@ var _ = Describe("Client", func() { generateConnectionIDForInitial = origGenerateConnectionIDForInitial }) - It("resolves the address", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - manager.EXPECT().Destroy() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - remoteAddrChan := make(chan string, 1) - newClientConnection = func( - sconn sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - remoteAddrChan <- sconn.RemoteAddr().String() - conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run() - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) - return conn - } - _, err := DialAddr(context.Background(), "localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond}) - Expect(err).ToNot(HaveOccurred()) - Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890"))) - }) - It("returns after the handshake is complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) run := make(chan struct{}) newClientConnection = func( _ sendConn, - runner connRunner, + _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, @@ -162,18 +132,17 @@ var _ = Describe("Client", func() { conn.EXPECT().HandshakeComplete().Return(c) return conn } - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - s, err := Dial(context.Background(), packetConn, addr, tlsConf, config) + cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, false, false) Expect(err).ToNot(HaveOccurred()) - Expect(s).ToNot(BeNil()) + cl.packetHandlers = manager + Expect(cl).ToNot(BeNil()) + Expect(cl.dial(context.Background())).To(Succeed()) Eventually(run).Should(BeClosed()) }) It("returns early connections", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - readyChan := make(chan struct{}) done := make(chan struct{}) newClientConnection = func( @@ -193,29 +162,23 @@ var _ = Describe("Client", func() { ) quicConn { Expect(enable0RTT).To(BeTrue()) conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run().Do(func() { <-done }) + conn.EXPECT().run().Do(func() { close(done) }) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) conn.EXPECT().earlyConnReady().Return(readyChan) return conn } - go func() { - defer GinkgoRecover() - defer close(done) - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - s, err := DialEarly(context.Background(), packetConn, addr, tlsConf, config) - Expect(err).ToNot(HaveOccurred()) - Expect(s).ToNot(BeNil()) - }() - Consistently(done).ShouldNot(BeClosed()) - close(readyChan) + cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, nil, true, false) + Expect(err).ToNot(HaveOccurred()) + cl.packetHandlers = manager + Expect(cl).ToNot(BeNil()) + Expect(cl.dial(context.Background())).To(Succeed()) Eventually(done).Should(BeClosed()) }) It("returns an error that occurs while waiting for the handshake to complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) testErr := errors.New("early handshake error") newClientConnection = func( @@ -236,108 +199,16 @@ var _ = Describe("Client", func() { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().run().Return(testErr) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) + conn.EXPECT().earlyConnReady().Return(make(chan struct{})) return conn } - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - _, err := Dial(context.Background(), packetConn, addr, tlsConf, config) - Expect(err).To(MatchError(testErr)) - }) - - It("closes the connection when the context is canceled", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - connRunning := make(chan struct{}) - defer close(connRunning) - conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run().Do(func() { - <-connRunning - }) - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) - newClientConnection = func( - _ sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - return conn - } - ctx, cancel := context.WithCancel(context.Background()) - dialed := make(chan struct{}) - go func() { - defer GinkgoRecover() - tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - _, err := Dial(ctx, packetConn, addr, tlsConf, config) - Expect(err).To(MatchError(context.Canceled)) - close(dialed) - }() - Consistently(dialed).ShouldNot(BeClosed()) - conn.EXPECT().shutdown() - cancel() - Eventually(dialed).Should(BeClosed()) - }) - - It("closes the connection when it was created by DialAddr", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - manager.EXPECT().Add(gomock.Any(), gomock.Any()) - - var sconn sendConn - run := make(chan struct{}) - connCreated := make(chan struct{}) - conn := NewMockQUICConn(mockCtrl) - newClientConnection = func( - connP sendConn, - _ connRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ protocol.PacketNumber, - _ bool, - _ bool, - _ logging.ConnectionTracer, - _ uint64, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicConn { - sconn = connP - close(connCreated) - return conn - } - conn.EXPECT().run().Do(func() { - <-run - }) - conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := DialAddr(context.Background(), "localhost:1337", tlsConf, nil) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - - Eventually(connCreated).Should(BeClosed()) - - // check that the connection is not closed - Expect(sconn.Write([]byte("foobar"))).To(Succeed()) - - manager.EXPECT().Destroy() - close(run) - time.Sleep(50 * time.Millisecond) - - Eventually(done).Should(BeClosed()) + var closed bool + cl, err := newClient(packetConn, addr, populateClientConfig(config, true), tlsConf, func() { closed = true }, true, false) + Expect(err).ToNot(HaveOccurred()) + cl.packetHandlers = manager + Expect(cl).ToNot(BeNil()) + Expect(cl.dial(context.Background())).To(MatchError(testErr)) + Expect(closed).To(BeTrue()) }) Context("quic.Config", func() { @@ -365,12 +236,6 @@ var _ = Describe("Client", func() { Expect(c.EnableDatagrams).To(BeTrue()) }) - It("errors when the Config contains an invalid version", func() { - version := protocol.VersionNumber(0x1234) - _, err := Dial(context.Background(), packetConn, nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) - Expect(err).To(MatchError("invalid QUIC version: 0x1234")) - }) - It("disables bidirectional streams", func() { config := &Config{ MaxIncomingStreams: -1, @@ -405,15 +270,12 @@ var _ = Describe("Client", func() { }) It("creates new connections with the right parameters", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(connID, gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - - config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}} + config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}} c := make(chan struct{}) var cconn sendConn var version protocol.VersionNumber var conf *Config + done := make(chan struct{}) newClientConnection = func( connP sendConn, _ connRunner, @@ -437,8 +299,15 @@ var _ = Describe("Client", func() { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().run() conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) + conn.EXPECT().destroy(gomock.Any()) + close(done) return conn } + packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) { + <-done + return 0, nil, errors.New("closed") + }) + packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes() _, err := Dial(context.Background(), packetConn, addr, tlsConf, config) Expect(err).ToNot(HaveOccurred()) Eventually(c).Should(BeClosed()) @@ -448,17 +317,12 @@ var _ = Describe("Client", func() { }) It("creates a new connections after version negotiation", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - manager.EXPECT().Add(connID, gomock.Any()).Times(2) - manager.EXPECT().Destroy() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - var counter int newClientConnection = func( _ sendConn, - _ connRunner, - _ protocol.ConnectionID, + runner connRunner, _ protocol.ConnectionID, + connID protocol.ConnectionID, configP *Config, _ *tls.Config, pn protocol.PacketNumber, @@ -474,20 +338,24 @@ var _ = Describe("Client", func() { if counter == 0 { Expect(pn).To(BeZero()) Expect(hasNegotiatedVersion).To(BeFalse()) - conn.EXPECT().run().Return(&errCloseForRecreating{ - nextPacketNumber: 109, - nextVersion: 789, + conn.EXPECT().run().DoAndReturn(func() error { + runner.Remove(connID) + return &errCloseForRecreating{ + nextPacketNumber: 109, + nextVersion: 789, + } }) } else { Expect(pn).To(Equal(protocol.PacketNumber(109))) Expect(hasNegotiatedVersion).To(BeTrue()) conn.EXPECT().run() + conn.EXPECT().destroy(gomock.Any()) } counter++ return conn } - config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}} + config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &protocol.DefaultConnectionIDGenerator{}} tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) _, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config) Expect(err).ToNot(HaveOccurred()) @@ -495,15 +363,3 @@ var _ = Describe("Client", func() { }) }) }) - -type mockConnIDGenerator struct { - ConnID protocol.ConnectionID -} - -func (m *mockConnIDGenerator) GenerateConnectionID() (protocol.ConnectionID, error) { - return m.ConnID, nil -} - -func (m *mockConnIDGenerator) ConnectionIDLen() int { - return m.ConnID.Len() -} diff --git a/integrationtests/self/cancelation_test.go b/integrationtests/self/cancelation_test.go index 63c32677..5f95c0b7 100644 --- a/integrationtests/self/cancelation_test.go +++ b/integrationtests/self/cancelation_test.go @@ -660,6 +660,7 @@ var _ = Describe("Stream Cancellations", func() { getQuicConfig(&quic.Config{MaxIncomingStreams: maxIncomingStreams, MaxIdleTimeout: 10 * time.Second}), ) Expect(err).ToNot(HaveOccurred()) + defer server.Close() var wg sync.WaitGroup wg.Add(2 * 4 * maxIncomingStreams) diff --git a/integrationtests/self/close_test.go b/integrationtests/self/close_test.go index 31905e30..d0bcf7f0 100644 --- a/integrationtests/self/close_test.go +++ b/integrationtests/self/close_test.go @@ -24,6 +24,7 @@ var _ = Describe("Connection ID lengths tests", func() { }), ) Expect(err).ToNot(HaveOccurred()) + defer server.Close() var drop atomic.Bool dropped := make(chan []byte, 100) @@ -50,6 +51,7 @@ var _ = Describe("Connection ID lengths tests", func() { getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") sconn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index c835a2ef..7cb1904d 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -35,7 +35,11 @@ var _ = Describe("Connection ID lengths tests", func() { randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) } runServer := func(conf *quic.Config) *quic.Listener { - GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength))) + if conf.ConnectionIDGenerator != nil { + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", conf.ConnectionIDGenerator.ConnectionIDLen()))) + } else { + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength))) + } ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), conf) Expect(err).ToNot(HaveOccurred()) go func() { @@ -59,7 +63,11 @@ var _ = Describe("Connection ID lengths tests", func() { } runClient := func(addr net.Addr, conf *quic.Config) { - GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength))) + if conf.ConnectionIDGenerator != nil { + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", conf.ConnectionIDGenerator.ConnectionIDLen()))) + } else { + GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength))) + } cl, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port), diff --git a/integrationtests/self/datagram_test.go b/integrationtests/self/datagram_test.go index 552488f4..2ccbbe16 100644 --- a/integrationtests/self/datagram_test.go +++ b/integrationtests/self/datagram_test.go @@ -22,12 +22,11 @@ var _ = Describe("Datagram test", func() { const num = 100 var ( - proxy *quicproxy.QuicProxy serverConn, clientConn *net.UDPConn dropped, total int32 ) - startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) { + startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) (port int, closeFn func()) { addr, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) serverConn, err = net.ListenUDP("udp", addr) @@ -39,8 +38,10 @@ var _ = Describe("Datagram test", func() { ) Expect(err).ToNot(HaveOccurred()) + accepted := make(chan struct{}) go func() { defer GinkgoRecover() + defer close(accepted) conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) @@ -67,7 +68,7 @@ var _ = Describe("Datagram test", func() { }() serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), // drop 10% of Short Header packets sent from the server DropPacket: func(dir quicproxy.Direction, packet []byte) bool { @@ -87,6 +88,11 @@ var _ = Describe("Datagram test", func() { }, }) Expect(err).ToNot(HaveOccurred()) + return proxy.LocalPort(), func() { + Eventually(accepted).Should(BeClosed()) + proxy.Close() + ln.Close() + } } BeforeEach(func() { @@ -96,13 +102,10 @@ var _ = Describe("Datagram test", func() { Expect(err).ToNot(HaveOccurred()) }) - AfterEach(func() { - Expect(proxy.Close()).To(Succeed()) - }) - It("sends datagrams", func() { - startServerAndProxy(true, true) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + proxyPort, close := startServerAndProxy(true, true) + defer close() + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( context.Background(), @@ -117,6 +120,7 @@ var _ = Describe("Datagram test", func() { for { // Close the connection if no message is received for 100 ms. timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() { + fmt.Println("closing conn") conn.CloseWithError(0, "") }) if _, err := conn.ReceiveMessage(); err != nil { @@ -134,11 +138,12 @@ var _ = Describe("Datagram test", func() { BeNumerically(">", expVal*9/10), BeNumerically("<", num), )) + Eventually(conn.Context().Done).Should(BeClosed()) }) It("server can disable datagram", func() { - startServerAndProxy(false, true) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + proxyPort, close := startServerAndProxy(false, true) + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( context.Background(), @@ -150,13 +155,13 @@ var _ = Describe("Datagram test", func() { Expect(err).ToNot(HaveOccurred()) Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) + close() conn.CloseWithError(0, "") - <-time.After(10 * time.Millisecond) }) It("client can disable datagram", func() { - startServerAndProxy(false, true) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + proxyPort, close := startServerAndProxy(false, true) + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( context.Background(), @@ -169,7 +174,8 @@ var _ = Describe("Datagram test", func() { Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) Expect(conn.SendMessage([]byte{0})).To(HaveOccurred()) + + close() conn.CloseWithError(0, "") - <-time.After(10 * time.Millisecond) }) }) diff --git a/integrationtests/self/early_data_test.go b/integrationtests/self/early_data_test.go index 0ce09926..136c3d0b 100644 --- a/integrationtests/self/early_data_test.go +++ b/integrationtests/self/early_data_test.go @@ -24,6 +24,7 @@ var _ = Describe("early data", func() { getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer ln.Close() done := make(chan struct{}) go func() { defer GinkgoRecover() diff --git a/integrationtests/self/go119_test.go b/integrationtests/self/go119_test.go index cd9824dd..c676693d 100644 --- a/integrationtests/self/go119_test.go +++ b/integrationtests/self/go119_test.go @@ -8,10 +8,9 @@ import ( "time" ) -var ( - go120 = false - errNotSupported = errors.New("not supported") -) +const go120 = false + +var errNotSupported = errors.New("not supported") func setReadDeadline(w http.ResponseWriter, deadline time.Time) error { return errNotSupported diff --git a/integrationtests/self/go120_test.go b/integrationtests/self/go120_test.go index 4ddf3c7c..88eb4a7e 100644 --- a/integrationtests/self/go120_test.go +++ b/integrationtests/self/go120_test.go @@ -7,7 +7,7 @@ import ( "time" ) -var go120 = true +const go120 = true func setReadDeadline(w http.ResponseWriter, deadline time.Time) error { rc := http.NewResponseController(w) diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index e1dff321..36ea7c78 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -62,13 +62,14 @@ var _ = Describe("Handshake RTT tests", func() { runProxy(ln.Addr()) startTime := time.Now() - _, err = quic.DialAddr( + conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") expectDurationInRTTs(startTime, 2) }) @@ -79,13 +80,14 @@ var _ = Describe("Handshake RTT tests", func() { runProxy(ln.Addr()) startTime := time.Now() - _, err = quic.DialAddr( + conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") expectDurationInRTTs(startTime, 1) }) @@ -97,13 +99,14 @@ var _ = Describe("Handshake RTT tests", func() { runProxy(ln.Addr()) startTime := time.Now() - _, err = quic.DialAddr( + conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") expectDurationInRTTs(startTime, 2) }) @@ -131,6 +134,7 @@ var _ = Describe("Handshake RTT tests", func() { getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) @@ -166,6 +170,7 @@ var _ = Describe("Handshake RTT tests", func() { getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 3274b84c..b3a13e9d 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -114,7 +114,7 @@ var _ = Describe("Handshake tests", func() { context.Background(), fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - nil, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) str, err := conn.AcceptStream(context.Background()) @@ -223,13 +223,14 @@ var _ = Describe("Handshake tests", func() { var ( server *quic.Listener pconn net.PacketConn + dialer *quic.Transport ) dial := func() (quic.Connection, error) { remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) raddr, err := net.ResolveUDPAddr("udp", remoteAddr) Expect(err).ToNot(HaveOccurred()) - return quic.Dial(context.Background(), pconn, raddr, getTLSClientConfig(), nil) + return dialer.Dial(context.Background(), raddr, getTLSClientConfig(), getQuicConfig(nil)) } BeforeEach(func() { @@ -243,11 +244,13 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) pconn, err = net.ListenUDP("udp", laddr) Expect(err).ToNot(HaveOccurred()) + dialer = &quic.Transport{Conn: pconn} }) AfterEach(func() { Expect(server.Close()).To(Succeed()) Expect(pconn.Close()).To(Succeed()) + Expect(dialer.Close()).To(Succeed()) }) It("rejects new connection attempts if connections don't get accepted", func() { @@ -366,6 +369,7 @@ var _ = Describe("Handshake tests", func() { It("uses tokens provided in NEW_TOKEN frames", func() { server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) + defer server.Close() // dial the first connection and receive the token go func() { diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 4e384d9f..918bfdf2 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -382,6 +382,7 @@ var _ = Describe("HTTP tests", func() { tlsConf.NextProtos = []string{"h3"} ln, err := quic.ListenAddr("localhost:0", tlsConf, nil) Expect(err).ToNot(HaveOccurred()) + defer ln.Close() done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -398,57 +399,51 @@ var _ = Describe("HTTP tests", func() { Eventually(done).Should(BeClosed()) }) - It("supports read deadlines", func() { - if !go120 { - Skip("This test requires Go 1.20+") - } + if go120 { + It("supports read deadlines", func() { + mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + err := setReadDeadline(w, time.Now().Add(deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) - mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - err := setReadDeadline(w, time.Now().Add(deadlineDelay)) + body, err := io.ReadAll(r.Body) + Expect(err).To(MatchError(os.ErrDeadlineExceeded)) + Expect(body).To(ContainSubstring("aa")) + + w.Write([]byte("ok")) + }) + + expectedEnd := time.Now().Add(deadlineDelay) + resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a')) Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) - body, err := io.ReadAll(r.Body) - Expect(err).To(MatchError(os.ErrDeadlineExceeded)) - Expect(body).To(ContainSubstring("aa")) - - w.Write([]byte("ok")) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + Expect(time.Now().After(expectedEnd)).To(BeTrue()) + Expect(string(body)).To(Equal("ok")) }) - expectedEnd := time.Now().Add(deadlineDelay) - resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a')) - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) + It("supports write deadlines", func() { + mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + err := setWriteDeadline(w, time.Now().Add(deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) - Expect(err).ToNot(HaveOccurred()) - Expect(time.Now().After(expectedEnd)).To(BeTrue()) - Expect(string(body)).To(Equal("ok")) - }) + _, err = io.Copy(w, neverEnding('a')) + Expect(err).To(MatchError(os.ErrDeadlineExceeded)) + }) - It("supports write deadlines", func() { - if !go120 { - Skip("This test requires Go 1.20+") - } + expectedEnd := time.Now().Add(deadlineDelay) - mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - err := setWriteDeadline(w, time.Now().Add(deadlineDelay)) + resp, err := client.Get("https://localhost:" + port + "/write-deadline") Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) - _, err = io.Copy(w, neverEnding('a')) - Expect(err).To(MatchError(os.ErrDeadlineExceeded)) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + Expect(time.Now().After(expectedEnd)).To(BeTrue()) + Expect(string(body)).To(ContainSubstring("aa")) }) - - expectedEnd := time.Now().Add(deadlineDelay) - - resp, err := client.Get("https://localhost:" + port + "/write-deadline") - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) - Expect(err).ToNot(HaveOccurred()) - Expect(time.Now().After(expectedEnd)).To(BeTrue()) - Expect(string(body)).To(ContainSubstring("aa")) - }) + } }) diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index b75d1656..dcac1b46 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -34,10 +34,9 @@ var _ = Describe("Multiplexing", func() { }() } - dial := func(pconn net.PacketConn, addr net.Addr) { - conn, err := quic.Dial( + dial := func(tr *quic.Transport, addr net.Addr) { + conn, err := tr.Dial( context.Background(), - pconn, addr, getTLSClientConfig(), getQuicConfig(nil), @@ -72,17 +71,18 @@ var _ = Describe("Multiplexing", func() { conn, err := net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) defer conn.Close() + tr := &quic.Transport{Conn: conn} done1 := make(chan struct{}) done2 := make(chan struct{}) go func() { defer GinkgoRecover() - dial(conn, server.Addr()) + dial(tr, server.Addr()) close(done1) }() go func() { defer GinkgoRecover() - dial(conn, server.Addr()) + dial(tr, server.Addr()) close(done2) }() timeout := 30 * time.Second @@ -106,17 +106,18 @@ var _ = Describe("Multiplexing", func() { conn, err := net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) defer conn.Close() + tr := &quic.Transport{Conn: conn} done1 := make(chan struct{}) done2 := make(chan struct{}) go func() { defer GinkgoRecover() - dial(conn, server1.Addr()) + dial(tr, server1.Addr()) close(done1) }() go func() { defer GinkgoRecover() - dial(conn, server2.Addr()) + dial(tr, server2.Addr()) close(done2) }() timeout := 30 * time.Second @@ -135,9 +136,9 @@ var _ = Describe("Multiplexing", func() { conn, err := net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) defer conn.Close() + tr := &quic.Transport{Conn: conn} - server, err := quic.Listen( - conn, + server, err := tr.Listen( getTLSConfig(), getQuicConfig(nil), ) @@ -146,7 +147,7 @@ var _ = Describe("Multiplexing", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - dial(conn, server.Addr()) + dial(tr, server.Addr()) close(done) }() timeout := 30 * time.Second @@ -165,15 +166,16 @@ var _ = Describe("Multiplexing", func() { conn1, err := net.ListenUDP("udp", addr1) Expect(err).ToNot(HaveOccurred()) defer conn1.Close() + tr1 := &quic.Transport{Conn: conn1} addr2, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) conn2, err := net.ListenUDP("udp", addr2) Expect(err).ToNot(HaveOccurred()) defer conn2.Close() + tr2 := &quic.Transport{Conn: conn2} - server1, err := quic.Listen( - conn1, + server1, err := tr1.Listen( getTLSConfig(), getQuicConfig(nil), ) @@ -181,8 +183,7 @@ var _ = Describe("Multiplexing", func() { runServer(server1) defer server1.Close() - server2, err := quic.Listen( - conn2, + server2, err := tr2.Listen( getTLSConfig(), getQuicConfig(nil), ) @@ -194,12 +195,12 @@ var _ = Describe("Multiplexing", func() { done2 := make(chan struct{}) go func() { defer GinkgoRecover() - dial(conn2, server1.Addr()) + dial(tr2, server1.Addr()) close(done1) }() go func() { defer GinkgoRecover() - dial(conn1, server2.Addr()) + dial(tr1, server2.Addr()) close(done2) }() timeout := 30 * time.Second diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index 1338b30c..86062bd5 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -31,8 +31,8 @@ var _ = Describe("Packetization", func() { }), ) Expect(err).ToNot(HaveOccurred()) - serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) defer server.Close() + serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: serverAddr, @@ -54,6 +54,7 @@ var _ = Describe("Packetization", func() { }), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") go func() { defer GinkgoRecover() diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index 9dfafa31..966a5715 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -199,8 +199,16 @@ func areHandshakesRunning() bool { return strings.Contains(b.String(), "RunHandshake") } +func areTransportsRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "quic-go.(*Transport).listen") +} + var _ = AfterEach(func() { Expect(areHandshakesRunning()).To(BeFalse()) + Eventually(areTransportsRunning).Should(BeFalse()) + if debugLog() { logFile, err := os.Create(logFileName) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index 98b21b2a..cc5afe90 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -2,7 +2,6 @@ package self_test import ( "context" - "errors" "fmt" "math/rand" "net" @@ -27,7 +26,13 @@ var _ = Describe("Stateless Resets", func() { rand.Read(statelessResetKey[:]) serverConfig := getQuicConfig(&quic.Config{StatelessResetKey: &statelessResetKey}) - ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) + c, err := net.ListenUDP("udp", nil) + Expect(err).ToNot(HaveOccurred()) + tr := &quic.Transport{ + Conn: c, + } + defer tr.Close() + ln, err := tr.Listen(getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) serverPort := ln.Addr().(*net.UDPAddr).Port @@ -42,7 +47,8 @@ var _ = Describe("Stateless Resets", func() { _, err = str.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) <-closeServer - ln.Close() + Expect(ln.Close()).To(Succeed()) + Expect(tr.Close()).To(Succeed()) }() var drop atomic.Bool @@ -77,11 +83,14 @@ var _ = Describe("Stateless Resets", func() { close(closeServer) time.Sleep(100 * time.Millisecond) - ln2, err := quic.ListenAddr( - fmt.Sprintf("localhost:%d", serverPort), - getTLSConfig(), - serverConfig, - ) + // We need to create a new Transport here, since the old one is still sending out + // CONNECTION_CLOSE packets for (recently) closed connections). + tr2 := &quic.Transport{ + Conn: c, + StatelessResetKey: &statelessResetKey, + } + defer tr2.Close() + ln2, err := tr2.Listen(getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) drop.Store(false) @@ -100,8 +109,7 @@ var _ = Describe("Stateless Resets", func() { _, serr = str.Read([]byte{0}) } Expect(serr).To(HaveOccurred()) - statelessResetErr := &quic.StatelessResetError{} - Expect(errors.As(serr, &statelessResetErr)).To(BeTrue()) + Expect(serr).To(BeAssignableToTypeOf(&quic.StatelessResetError{})) Expect(ln2.Close()).To(Succeed()) Eventually(acceptStopped).Should(BeClosed()) }) diff --git a/integrationtests/self/stream_test.go b/integrationtests/self/stream_test.go index 0af14b8f..332cd505 100644 --- a/integrationtests/self/stream_test.go +++ b/integrationtests/self/stream_test.go @@ -94,6 +94,8 @@ var _ = Describe("Bidirectional streams", func() { ) Expect(err).ToNot(HaveOccurred()) runSendingPeer(client) + client.CloseWithError(0, "") + <-conn.Context().Done() }) It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { @@ -149,5 +151,6 @@ var _ = Describe("Bidirectional streams", func() { runReceivingPeer(client) <-done1 <-done2 + client.CloseWithError(0, "") }) }) diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index 5996a534..abc05dd7 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -473,6 +473,7 @@ var _ = Describe("Timeout tests", func() { }), ) Expect(err).ToNot(HaveOccurred()) + defer ln.Close() serverErrChan := make(chan error, 1) go func() { diff --git a/integrationtests/self/uni_stream_test.go b/integrationtests/self/uni_stream_test.go index 9253b701..a2fe4e50 100644 --- a/integrationtests/self/uni_stream_test.go +++ b/integrationtests/self/uni_stream_test.go @@ -88,11 +88,14 @@ var _ = Describe("Unidirectional Streams", func() { }) It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { + done := make(chan struct{}) go func() { defer GinkgoRecover() + defer close(done) conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) runSendingPeer(conn) + <-conn.Context().Done() }() client, err := quic.DialAddr( @@ -103,6 +106,7 @@ var _ = Describe("Unidirectional Streams", func() { ) Expect(err).ToNot(HaveOccurred()) runReceivingPeer(client) + client.CloseWithError(0, "") }) It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() { diff --git a/mock_multiplexer_test.go b/mock_multiplexer_test.go deleted file mode 100644 index 0383b1a4..00000000 --- a/mock_multiplexer_test.go +++ /dev/null @@ -1,65 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go (interfaces: Multiplexer) - -// Package quic is a generated GoMock package. -package quic - -import ( - net "net" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - logging "github.com/quic-go/quic-go/logging" -) - -// MockMultiplexer is a mock of Multiplexer interface. -type MockMultiplexer struct { - ctrl *gomock.Controller - recorder *MockMultiplexerMockRecorder -} - -// MockMultiplexerMockRecorder is the mock recorder for MockMultiplexer. -type MockMultiplexerMockRecorder struct { - mock *MockMultiplexer -} - -// NewMockMultiplexer creates a new mock instance. -func NewMockMultiplexer(ctrl *gomock.Controller) *MockMultiplexer { - mock := &MockMultiplexer{ctrl: ctrl} - mock.recorder = &MockMultiplexerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder { - return m.recorder -} - -// AddConn mocks base method. -func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int, arg2 *StatelessResetKey, arg3 logging.Tracer) (packetHandlerManager, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddConn", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(packetHandlerManager) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AddConn indicates an expected call of AddConn. -func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1, arg2, arg3) -} - -// RemoveConn mocks base method. -func (m *MockMultiplexer) RemoveConn(arg0 indexableConn) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveConn", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemoveConn indicates an expected call of RemoveConn. -func (mr *MockMultiplexerMockRecorder) RemoveConn(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveConn", reflect.TypeOf((*MockMultiplexer)(nil).RemoveConn), arg0) -} diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index bcd16038..25ae5420 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -74,6 +74,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 i return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2) } +// Close mocks base method. +func (m *MockPacketHandlerManager) Close(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close", arg0) +} + +// Close indicates an expected call of Close. +func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0) +} + // CloseServer mocks base method. func (m *MockPacketHandlerManager) CloseServer() { m.ctrl.T.Helper() @@ -86,20 +98,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer)) } -// Destroy mocks base method. -func (m *MockPacketHandlerManager) Destroy() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Destroy") - ret0, _ := ret[0].(error) - return ret0 -} - -// Destroy indicates an expected call of Destroy. -func (mr *MockPacketHandlerManagerMockRecorder) Destroy() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockPacketHandlerManager)(nil).Destroy)) -} - // Get mocks base method. func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) { m.ctrl.T.Helper() @@ -115,6 +113,21 @@ func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0) } +// GetByResetToken mocks base method. +func (m *MockPacketHandlerManager) GetByResetToken(arg0 protocol.StatelessResetToken) (packetHandler, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetByResetToken", arg0) + ret0, _ := ret[0].(packetHandler) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetByResetToken indicates an expected call of GetByResetToken. +func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0) +} + // GetStatelessResetToken mocks base method. func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken { m.ctrl.T.Helper() @@ -176,15 +189,3 @@ func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 interface{}) *gomock mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0) } - -// SetServer mocks base method. -func (m *MockPacketHandlerManager) SetServer(arg0 unknownPacketHandler) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetServer", arg0) -} - -// SetServer indicates an expected call of SetServer. -func (mr *MockPacketHandlerManagerMockRecorder) SetServer(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).SetServer), arg0) -} diff --git a/mockgen.go b/mockgen.go index 443e9c10..eb700864 100644 --- a/mockgen.go +++ b/mockgen.go @@ -65,9 +65,6 @@ type UnknownPacketHandler = unknownPacketHandler //go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" type PacketHandlerManager = packetHandlerManager -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_multiplexer_test.go github.com/quic-go/quic-go Multiplexer" -type Multiplexer = multiplexer - // Need to use source mode for the batchConn, since reflect mode follows type aliases. // See https://github.com/golang/mock/issues/244 for details. // diff --git a/multiplexer.go b/multiplexer.go index 37d4e75c..85f7f403 100644 --- a/multiplexer.go +++ b/multiplexer.go @@ -6,7 +6,6 @@ import ( "sync" "github.com/quic-go/quic-go/internal/utils" - "github.com/quic-go/quic-go/logging" ) var ( @@ -14,30 +13,19 @@ var ( connMuxer multiplexer ) -type indexableConn interface { - LocalAddr() net.Addr -} +type indexableConn interface{ LocalAddr() net.Addr } type multiplexer interface { - AddConn(c net.PacketConn, connIDLen int, statelessResetKey *StatelessResetKey, tracer logging.Tracer) (packetHandlerManager, error) + AddConn(conn indexableConn) RemoveConn(indexableConn) error } -type connManager struct { - connIDLen int - statelessResetKey *StatelessResetKey - tracer logging.Tracer - manager packetHandlerManager -} - // The connMultiplexer listens on multiple net.PacketConns and dispatches // incoming packets to the connection handler. type connMultiplexer struct { mutex sync.Mutex - conns map[string] /* LocalAddr().String() */ connManager - newPacketHandlerManager func(net.PacketConn, int, *StatelessResetKey, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests - + conns map[string] /* LocalAddr().String() */ indexableConn logger utils.Logger } @@ -46,57 +34,38 @@ var _ multiplexer = &connMultiplexer{} func getMultiplexer() multiplexer { connMuxerOnce.Do(func() { connMuxer = &connMultiplexer{ - conns: make(map[string]connManager), - logger: utils.DefaultLogger.WithPrefix("muxer"), - newPacketHandlerManager: newPacketHandlerMap, + conns: make(map[string]indexableConn), + logger: utils.DefaultLogger.WithPrefix("muxer"), } }) return connMuxer } -func (m *connMultiplexer) AddConn( - c net.PacketConn, - connIDLen int, - statelessResetKey *StatelessResetKey, - tracer logging.Tracer, -) (packetHandlerManager, error) { +func (m *connMultiplexer) index(addr net.Addr) string { + return addr.Network() + " " + addr.String() +} + +func (m *connMultiplexer) AddConn(c indexableConn) { m.mutex.Lock() defer m.mutex.Unlock() - addr := c.LocalAddr() - connIndex := addr.Network() + " " + addr.String() + connIndex := m.index(c.LocalAddr()) p, ok := m.conns[connIndex] - if !ok { - manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger) - if err != nil { - return nil, err - } - p = connManager{ - connIDLen: connIDLen, - statelessResetKey: statelessResetKey, - manager: manager, - tracer: tracer, - } - m.conns[connIndex] = p - } else { - if p.connIDLen != connIDLen { - return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) - } - if statelessResetKey != nil && p.statelessResetKey != statelessResetKey { - return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn") - } - if tracer != p.tracer { - return nil, fmt.Errorf("cannot use different tracers on the same packet conn") - } + if ok { + // Panics if we're already listening on this connection. + // This is a safeguard because we're introducing a breaking API change, see + // https://github.com/quic-go/quic-go/issues/3727 for details. + // We'll remove this at a later time, when most users of the library have made the switch. + panic("connection already exists") // TODO: write a nice message } - return p.manager, nil + m.conns[connIndex] = p } func (m *connMultiplexer) RemoveConn(c indexableConn) error { m.mutex.Lock() defer m.mutex.Unlock() - connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String() + connIndex := m.index(c.LocalAddr()) if _, ok := m.conns[connIndex]; !ok { return fmt.Errorf("cannote remove connection, connection is unknown") } diff --git a/multiplexer_test.go b/multiplexer_test.go index 3730cc33..48590b28 100644 --- a/multiplexer_test.go +++ b/multiplexer_test.go @@ -3,71 +3,24 @@ package quic import ( "net" - mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" - - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) -type testConn struct { - counter int - net.PacketConn -} - var _ = Describe("Multiplexer", func() { - It("adds a new packet conn ", func() { - conn := NewMockPacketConn(mockCtrl) - conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}) - _, err := getMultiplexer().AddConn(conn, 8, nil, nil) - Expect(err).ToNot(HaveOccurred()) + It("adds new packet conns", func() { + conn1 := NewMockPacketConn(mockCtrl) + conn1.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}) + getMultiplexer().AddConn(conn1) + conn2 := NewMockPacketConn(mockCtrl) + conn2.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1235}) + getMultiplexer().AddConn(conn2) }) - It("recognizes when the same connection is added twice", func() { - srk := &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'} - pconn := NewMockPacketConn(mockCtrl) - pconn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2) - pconn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn := testConn{PacketConn: pconn} - tracer := mocklogging.NewMockTracer(mockCtrl) - _, err := getMultiplexer().AddConn(conn, 8, srk, tracer) - Expect(err).ToNot(HaveOccurred()) - conn.counter++ - _, err = getMultiplexer().AddConn(conn, 8, srk, tracer) - Expect(err).ToNot(HaveOccurred()) - Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1)) - }) - - It("errors when adding an existing conn with a different connection ID length", func() { + It("panics when the same connection is added twice", func() { conn := NewMockPacketConn(mockCtrl) - conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) - _, err := getMultiplexer().AddConn(conn, 5, nil, nil) - Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 6, nil, nil) - Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs")) - }) - - It("errors when adding an existing conn with a different stateless rest key", func() { - srk1 := &StatelessResetKey{'f', 'o', 'o'} - srk2 := &StatelessResetKey{'b', 'a', 'r'} - conn := NewMockPacketConn(mockCtrl) - conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) - _, err := getMultiplexer().AddConn(conn, 7, srk1, nil) - Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 7, srk2, nil) - Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn")) - }) - - It("errors when adding an existing conn with different tracers", func() { - conn := NewMockPacketConn(mockCtrl) - conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) - _, err := getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl)) - Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl)) - Expect(err).To(MatchError("cannot use different tracers on the same packet conn")) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2) + getMultiplexer().AddConn(conn) + Expect(func() { getMultiplexer().AddConn(conn) }).To(Panic()) }) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index 99e2bfb1..2a08359a 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -5,28 +5,22 @@ import ( "crypto/rand" "crypto/sha256" "errors" - "fmt" "hash" "io" - "log" "net" - "os" - "strconv" - "strings" "sync" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" - "github.com/quic-go/quic-go/internal/wire" - "github.com/quic-go/quic-go/logging" ) -// rawConn is a connection that allow reading of a receivedPacket. +// rawConn is a connection that allow reading of a receivedPackeh. type rawConn interface { ReadPacket() (*receivedPacket, error) WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) LocalAddr() net.Addr + SetReadDeadline(time.Time) error io.Closer } @@ -36,113 +30,49 @@ type closePacket struct { info *packetInfo } -// The packetHandlerMap stores packetHandlers, identified by connection ID. -// It is used: -// * by the server to store connections -// * when multiplexing outgoing connections to store clients +type unknownPacketHandler interface { + handlePacket(*receivedPacket) + setCloseError(error) +} + +var errListenerAlreadySet = errors.New("listener already set") + type packetHandlerMap struct { - mutex sync.Mutex - - conn rawConn - connIDLen int - - closeQueue chan closePacket - + mutex sync.Mutex handlers map[protocol.ConnectionID]packetHandler resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler - server unknownPacketHandler - listening chan struct{} // is closed when listen returns closed bool + closeChan chan struct{} + + enqueueClosePacket func(closePacket) deleteRetiredConnsAfter time.Duration - statelessResetEnabled bool - statelessResetMutex sync.Mutex - statelessResetHasher hash.Hash + statelessResetMutex sync.Mutex + statelessResetHasher hash.Hash - tracer logging.Tracer logger utils.Logger } var _ packetHandlerManager = &packetHandlerMap{} -func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { - conn, ok := c.(interface{ SetReadBuffer(int) error }) - if !ok { - return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?") - } - size, err := inspectReadBuffer(c) - if err != nil { - return fmt.Errorf("failed to determine receive buffer size: %w", err) - } - if size >= protocol.DesiredReceiveBufferSize { - logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) - return nil - } - if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { - return fmt.Errorf("failed to increase receive buffer size: %w", err) - } - newSize, err := inspectReadBuffer(c) - if err != nil { - return fmt.Errorf("failed to determine receive buffer size: %w", err) - } - if newSize == size { - return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024) - } - if newSize < protocol.DesiredReceiveBufferSize { - return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024) - } - logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024) - return nil -} - -// only print warnings about the UDP receive buffer size once -var receiveBufferWarningOnce sync.Once - -func newPacketHandlerMap( - c net.PacketConn, - connIDLen int, - statelessResetKey *StatelessResetKey, - tracer logging.Tracer, - logger utils.Logger, -) (packetHandlerManager, error) { - if err := setReceiveBuffer(c, logger); err != nil { - if !strings.Contains(err.Error(), "use of closed network connection") { - receiveBufferWarningOnce.Do(func() { - if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { - return - } - log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err) - }) - } - } - conn, err := wrapConn(c) - if err != nil { - return nil, err - } - m := &packetHandlerMap{ - conn: conn, - connIDLen: connIDLen, - listening: make(chan struct{}), +func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap { + h := &packetHandlerMap{ + closeChan: make(chan struct{}), handlers: make(map[protocol.ConnectionID]packetHandler), resetTokens: make(map[protocol.StatelessResetToken]packetHandler), deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, - closeQueue: make(chan closePacket, 4), - statelessResetEnabled: statelessResetKey != nil, - tracer: tracer, + enqueueClosePacket: enqueueClosePacket, logger: logger, } - if m.statelessResetEnabled { - m.statelessResetHasher = hmac.New(sha256.New, statelessResetKey[:]) + if key != nil { + h.statelessResetHasher = hmac.New(sha256.New, key[:]) } - go m.listen() - go m.runCloseQueue() - - if logger.Debug() { - go m.logUsage() + if h.logger.Debug() { + go h.logUsage() } - return m, nil + return h } func (h *packetHandlerMap) logUsage() { @@ -150,7 +80,7 @@ func (h *packetHandlerMap) logUsage() { var printedZero bool for { select { - case <-h.listening: + case <-h.closeChan: return case <-ticker.C: } @@ -233,12 +163,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p if connClosePacket != nil { handler = newClosedLocalConn( func(addr net.Addr, info *packetInfo) { - select { - case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}: - default: - // Oops, we're backlogged. - // Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway. - } + h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info}) }, pers, h.logger, @@ -265,17 +190,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p }) } -func (h *packetHandlerMap) runCloseQueue() { - for { - select { - case <-h.listening: - return - case p := <-h.closeQueue: - h.conn.WritePacket(p.payload, p.addr, p.info.OOB()) - } - } -} - func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { h.mutex.Lock() h.resetTokens[token] = handler @@ -288,19 +202,16 @@ func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) h.mutex.Unlock() } -func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { +func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) { h.mutex.Lock() - h.server = s - h.mutex.Unlock() + defer h.mutex.Unlock() + + handler, ok := h.resetTokens[token] + return handler, ok } func (h *packetHandlerMap) CloseServer() { h.mutex.Lock() - if h.server == nil { - h.mutex.Unlock() - return - } - h.server = nil var wg sync.WaitGroup for _, handler := range h.handlers { if handler.getPerspective() == protocol.PerspectiveServer { @@ -316,23 +227,16 @@ func (h *packetHandlerMap) CloseServer() { wg.Wait() } -// Destroy closes the underlying connection and waits until listen() has returned. -// It does not close active connections. -func (h *packetHandlerMap) Destroy() error { - if err := h.conn.Close(); err != nil { - return err - } - <-h.listening // wait until listening returns - return nil -} - -func (h *packetHandlerMap) close(e error) error { +func (h *packetHandlerMap) Close(e error) { h.mutex.Lock() + if h.closed { h.mutex.Unlock() - return nil + return } + close(h.closeChan) + var wg sync.WaitGroup for _, handler := range h.handlers { wg.Add(1) @@ -341,89 +245,14 @@ func (h *packetHandlerMap) close(e error) error { wg.Done() }(handler) } - - if h.server != nil { - h.server.setCloseError(e) - } h.closed = true h.mutex.Unlock() wg.Wait() - return getMultiplexer().RemoveConn(h.conn) -} - -func (h *packetHandlerMap) listen() { - defer close(h.listening) - for { - p, err := h.conn.ReadPacket() - //nolint:staticcheck // SA1019 ignore this! - // TODO: This code is used to ignore wsa errors on Windows. - // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. - // See https://github.com/quic-go/quic-go/issues/1737 for details. - if nerr, ok := err.(net.Error); ok && nerr.Temporary() { - h.logger.Debugf("Temporary error reading from conn: %w", err) - continue - } - if err != nil { - h.close(err) - return - } - h.handlePacket(p) - } -} - -func (h *packetHandlerMap) handlePacket(p *receivedPacket) { - connID, err := wire.ParseConnectionID(p.data, h.connIDLen) - if err != nil { - h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) - if h.tracer != nil { - h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) - } - p.buffer.MaybeRelease() - return - } - - h.mutex.Lock() - defer h.mutex.Unlock() - - if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset { - return - } - if handler, ok := h.handlers[connID]; ok { - handler.handlePacket(p) - return - } - if !wire.IsLongHeaderPacket(p.data[0]) { - go h.maybeSendStatelessReset(p, connID) - return - } - if h.server == nil { // no server set - h.logger.Debugf("received a packet with an unexpected connection ID %s", connID) - return - } - h.server.handlePacket(p) -} - -func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { - // stateless resets are always short header packets - if wire.IsLongHeaderPacket(data[0]) { - return false - } - if len(data) < 17 /* type byte + 16 bytes for the reset token */ { - return false - } - - token := *(*protocol.StatelessResetToken)(data[len(data)-16:]) - if sess, ok := h.resetTokens[token]; ok { - h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) - go sess.destroy(&StatelessResetError{Token: token}) - return true - } - return false } func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { var token protocol.StatelessResetToken - if !h.statelessResetEnabled { + if h.statelessResetHasher == nil { // Return a random stateless reset token. // This token will be sent in the server's transport parameters. // By using a random token, an off-path attacker won't be able to disrupt the connection. @@ -437,24 +266,3 @@ func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) h.statelessResetMutex.Unlock() return token } - -func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { - defer p.buffer.Release() - if !h.statelessResetEnabled { - return - } - // Don't send a stateless reset in response to very small packets. - // This includes packets that could be stateless resets. - if len(p.data) <= protocol.MinStatelessResetSize { - return - } - token := h.GetStatelessResetToken(connID) - h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) - data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) - rand.Read(data) - data[0] = (data[0] & 0x7f) | 0x40 - data = append(data, token[:]...) - if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { - h.logger.Debugf("Error sending Stateless Reset: %s", err) - } -} diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 0397e3f0..e87a75f8 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -6,405 +6,188 @@ import ( "net" "time" - mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" - "github.com/quic-go/quic-go/internal/wire" - "github.com/quic-go/quic-go/logging" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("Packet Handler Map", func() { - type packetToRead struct { - addr net.Addr - data []byte - err error - } - - var ( - handler *packetHandlerMap - conn *MockPacketConn - tracer *mocklogging.MockTracer - packetChan chan packetToRead - - connIDLen int - statelessResetKey *StatelessResetKey - ) - - getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { - b, err := (&wire.ExtendedHeader{ - Header: wire.Header{ - Type: t, - DestConnectionID: connID, - Length: length, - Version: protocol.Version1, - }, - PacketNumberLen: protocol.PacketNumberLen2, - }).Append(nil, protocol.Version1) - Expect(err).ToNot(HaveOccurred()) - return b - } - - getPacket := func(connID protocol.ConnectionID) []byte { - return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2) - } - - BeforeEach(func() { - statelessResetKey = nil - connIDLen = 0 - tracer = mocklogging.NewMockTracer(mockCtrl) - packetChan = make(chan packetToRead, 10) + It("adds and gets", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + handler := NewMockPacketHandler(mockCtrl) + Expect(m.Add(connID, handler)).To(BeTrue()) + h, ok := m.Get(connID) + Expect(ok).To(BeTrue()) + Expect(h).To(Equal(handler)) }) - JustBeforeEach(func() { - conn = NewMockPacketConn(mockCtrl) - conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() - conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { - p, ok := <-packetChan - if !ok { - return 0, nil, errors.New("closed") + It("refused to add duplicates", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + handler := NewMockPacketHandler(mockCtrl) + Expect(m.Add(connID, handler)).To(BeTrue()) + Expect(m.Add(connID, handler)).To(BeFalse()) + }) + + It("removes", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + handler := NewMockPacketHandler(mockCtrl) + Expect(m.Add(connID, handler)).To(BeTrue()) + m.Remove(connID) + _, ok := m.Get(connID) + Expect(ok).To(BeFalse()) + Expect(m.Add(connID, handler)).To(BeTrue()) + }) + + It("retires", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + dur := scaleDuration(50 * time.Millisecond) + m.deleteRetiredConnsAfter = dur + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + handler := NewMockPacketHandler(mockCtrl) + Expect(m.Add(connID, handler)).To(BeTrue()) + m.Retire(connID) + _, ok := m.Get(connID) + Expect(ok).To(BeTrue()) + time.Sleep(dur) + Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse()) + }) + + It("adds newly to-be-constructed handlers", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + var called bool + connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) + Expect(m.AddWithConnID(connID1, connID2, func() packetHandler { + called = true + return NewMockPacketHandler(mockCtrl) + })).To(BeTrue()) + Expect(called).To(BeTrue()) + Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() packetHandler { + Fail("didn't expect the constructor to be executed") + return nil + })).To(BeFalse()) + }) + + It("adds, gets and removes reset tokens", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} + handler := NewMockPacketHandler(mockCtrl) + m.AddResetToken(token, handler) + h, ok := m.GetByResetToken(token) + Expect(ok).To(BeTrue()) + Expect(h).To(Equal(h)) + m.RemoveResetToken(token) + _, ok = m.GetByResetToken(token) + Expect(ok).To(BeFalse()) + }) + + It("generates stateless reset token, if no key is set", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + b := make([]byte, 8) + rand.Read(b) + connID := protocol.ParseConnectionID(b) + token := m.GetStatelessResetToken(connID) + for i := 0; i < 1000; i++ { + to := m.GetStatelessResetToken(connID) + Expect(to).ToNot(Equal(token)) + token = to + } + }) + + It("generates stateless reset token, if a key is set", func() { + var key StatelessResetKey + rand.Read(key[:]) + m := newPacketHandlerMap(&key, nil, utils.DefaultLogger) + b := make([]byte, 8) + rand.Read(b) + connID := protocol.ParseConnectionID(b) + token := m.GetStatelessResetToken(connID) + Expect(token).ToNot(BeZero()) + Expect(m.GetStatelessResetToken(connID)).To(Equal(token)) + // generate a new connection ID + rand.Read(b) + connID2 := protocol.ParseConnectionID(b) + Expect(m.GetStatelessResetToken(connID2)).ToNot(Equal(token)) + }) + + It("replaces locally closed connections", func() { + var closePackets []closePacket + m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger) + dur := scaleDuration(50 * time.Millisecond) + m.deleteRetiredConnsAfter = dur + + handler := NewMockPacketHandler(mockCtrl) + connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) + Expect(m.Add(connID, handler)).To(BeTrue()) + m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, []byte("foobar")) + h, ok := m.Get(connID) + Expect(ok).To(BeTrue()) + Expect(h).ToNot(Equal(handler)) + addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} + h.handlePacket(&receivedPacket{remoteAddr: addr}) + Expect(closePackets).To(HaveLen(1)) + Expect(closePackets[0].addr).To(Equal(addr)) + Expect(closePackets[0].payload).To(Equal([]byte("foobar"))) + + time.Sleep(dur) + Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse()) + }) + + It("replaces remote closed connections", func() { + var closePackets []closePacket + m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger) + dur := scaleDuration(50 * time.Millisecond) + m.deleteRetiredConnsAfter = dur + + handler := NewMockPacketHandler(mockCtrl) + connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) + Expect(m.Add(connID, handler)).To(BeTrue()) + m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, nil) + h, ok := m.Get(connID) + Expect(ok).To(BeTrue()) + Expect(h).ToNot(Equal(handler)) + addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} + h.handlePacket(&receivedPacket{remoteAddr: addr}) + Expect(closePackets).To(BeEmpty()) + + time.Sleep(dur) + Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse()) + }) + + It("closes the server", func() { + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + for i := 0; i < 10; i++ { + conn := NewMockPacketHandler(mockCtrl) + if i%2 == 0 { + conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient) + } else { + conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer) + conn.EXPECT().shutdown() } - return copy(b, p.data), p.addr, p.err - }).AnyTimes() - phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger) - Expect(err).ToNot(HaveOccurred()) - handler = phm.(*packetHandlerMap) + b := make([]byte, 12) + rand.Read(b) + m.Add(protocol.ParseConnectionID(b), conn) + } + m.CloseServer() }) It("closes", func() { - getMultiplexer() // make the sync.Once execute - // replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer - mockMultiplexer := NewMockMultiplexer(mockCtrl) - origMultiplexer := connMuxer - connMuxer = mockMultiplexer - - defer func() { - connMuxer = origMultiplexer - }() - - testErr := errors.New("test error ") - conn1 := NewMockPacketHandler(mockCtrl) - conn1.EXPECT().destroy(testErr) - conn2 := NewMockPacketHandler(mockCtrl) - conn2.EXPECT().destroy(testErr) - handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), conn1) - handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), conn2) - mockMultiplexer.EXPECT().RemoveConn(gomock.Any()) - handler.close(testErr) - close(packetChan) - Eventually(handler.listening).Should(BeClosed()) - }) - - Context("other operations", func() { - AfterEach(func() { - // delete connections and the server before closing - // They might be mock implementations, and we'd have to register the expected calls before otherwise. - handler.mutex.Lock() - for connID := range handler.handlers { - delete(handler.handlers, connID) - } - handler.server = nil - handler.mutex.Unlock() - conn.EXPECT().Close().MaxTimes(1) - close(packetChan) - handler.Destroy() - Eventually(handler.listening).Should(BeClosed()) - }) - - Context("handling packets", func() { - BeforeEach(func() { - connIDLen = 5 - }) - - It("handles packets for different packet handlers on the same packet conn", func() { - connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) - packetHandler1 := NewMockPacketHandler(mockCtrl) - packetHandler2 := NewMockPacketHandler(mockCtrl) - handledPacket1 := make(chan struct{}) - handledPacket2 := make(chan struct{}) - packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - connID, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(connID1)) - close(handledPacket1) - }) - packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - connID, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(connID2)) - close(handledPacket2) - }) - handler.Add(connID1, packetHandler1) - handler.Add(connID2, packetHandler2) - packetChan <- packetToRead{data: getPacket(connID1)} - packetChan <- packetToRead{data: getPacket(connID2)} - - Eventually(handledPacket1).Should(BeClosed()) - Eventually(handledPacket2).Should(BeClosed()) - }) - - It("drops unparseable packets", func() { - addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} - tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: []byte{0, 1, 2, 3}, - }) - }) - - It("deletes removed connections immediately", func() { - handler.deleteRetiredConnsAfter = time.Hour - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - handler.Add(connID, NewMockPacketHandler(mockCtrl)) - handler.Remove(connID) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - // don't EXPECT any calls to handlePacket of the MockPacketHandler - }) - - It("deletes retired connection entries after a wait time", func() { - handler.deleteRetiredConnsAfter = scaleDuration(10 * time.Millisecond) - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - conn := NewMockPacketHandler(mockCtrl) - handler.Add(connID, conn) - handler.Retire(connID) - time.Sleep(scaleDuration(30 * time.Millisecond)) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - // don't EXPECT any calls to handlePacket of the MockPacketHandler - }) - - It("passes packets arriving late for closed connections to that connection", func() { - handler.deleteRetiredConnsAfter = time.Hour - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - packetHandler := NewMockPacketHandler(mockCtrl) - handled := make(chan struct{}) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - close(handled) - }) - handler.Add(connID, packetHandler) - handler.Retire(connID) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - Eventually(handled).Should(BeClosed()) - }) - - It("drops packets for unknown receivers", func() { - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - }) - - It("closes the packet handlers when reading from the conn fails", func() { - done := make(chan struct{}) - packetHandler := NewMockPacketHandler(mockCtrl) - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) { - Expect(e).To(HaveOccurred()) - close(done) - }) - handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler) - packetChan <- packetToRead{err: errors.New("read failed")} - Eventually(done).Should(BeClosed()) - }) - - It("continues listening for temporary errors", func() { - packetHandler := NewMockPacketHandler(mockCtrl) - handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler) - err := deadlineError{} - Expect(err.Temporary()).To(BeTrue()) - packetChan <- packetToRead{err: err} - // don't EXPECT any calls to packetHandler.destroy - time.Sleep(50 * time.Millisecond) - }) - - It("says if a connection ID is already taken", func() { - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) - Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) - }) - - It("says if a connection ID is already taken, for AddWithConnID", func() { - clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - newConnID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) - newConnID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) - Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue()) - Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse()) - }) - }) - - Context("running a server", func() { - It("adds a server", func() { - connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) - p := getPacket(connID) - server := NewMockUnknownPacketHandler(mockCtrl) - server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - cid, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(cid).To(Equal(connID)) - }) - handler.SetServer(server) - handler.handlePacket(&receivedPacket{data: p}) - }) - - It("closes all server connections", func() { - handler.SetServer(NewMockUnknownPacketHandler(mockCtrl)) - clientConn := NewMockPacketHandler(mockCtrl) - clientConn.EXPECT().getPerspective().Return(protocol.PerspectiveClient) - serverConn := NewMockPacketHandler(mockCtrl) - serverConn.EXPECT().getPerspective().Return(protocol.PerspectiveServer) - serverConn.EXPECT().shutdown() - - handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), clientConn) - handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), serverConn) - handler.CloseServer() - }) - - It("stops handling packets with unknown connection IDs after the server is closed", func() { - connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) - p := getPacket(connID) - server := NewMockUnknownPacketHandler(mockCtrl) - // don't EXPECT any calls to server.handlePacket - handler.SetServer(server) - handler.CloseServer() - handler.handlePacket(&receivedPacket{data: p}) - }) - }) - - Context("stateless resets", func() { - BeforeEach(func() { - connIDLen = 5 - }) - - Context("handling", func() { - It("handles stateless resets", func() { - packetHandler := NewMockPacketHandler(mockCtrl) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - destroyed := make(chan struct{}) - packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { - defer GinkgoRecover() - defer close(destroyed) - Expect(err).To(HaveOccurred()) - var resetErr *StatelessResetError - Expect(errors.As(err, &resetErr)).To(BeTrue()) - Expect(err.Error()).To(ContainSubstring("received a stateless reset")) - Expect(resetErr.Token).To(Equal(token)) - }) - packetChan <- packetToRead{data: packet} - Eventually(destroyed).Should(BeClosed()) - }) - - It("handles stateless resets for 0-length connection IDs", func() { - handler.connIDLen = 0 - packetHandler := NewMockPacketHandler(mockCtrl) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - destroyed := make(chan struct{}) - packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { - defer GinkgoRecover() - Expect(err).To(HaveOccurred()) - var resetErr *StatelessResetError - Expect(errors.As(err, &resetErr)).To(BeTrue()) - Expect(err.Error()).To(ContainSubstring("received a stateless reset")) - Expect(resetErr.Token).To(Equal(token)) - close(destroyed) - }) - packetChan <- packetToRead{data: packet} - Eventually(destroyed).Should(BeClosed()) - }) - - It("removes reset tokens", func() { - connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) - packetHandler := NewMockPacketHandler(mockCtrl) - handler.Add(connID, packetHandler) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) - handler.RemoveResetToken(token) - // don't EXPECT any call to packetHandler.destroy() - packetHandler.EXPECT().handlePacket(gomock.Any()) - p := append([]byte{0x40} /* short header packet */, connID.Bytes()...) - p = append(p, make([]byte, 50)...) - p = append(p, token[:]...) - - handler.handlePacket(&receivedPacket{data: p}) - }) - - It("ignores packets too small to contain a stateless reset", func() { - handler.connIDLen = 0 - packetHandler := NewMockPacketHandler(mockCtrl) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - done := make(chan struct{}) - // don't EXPECT any calls here, but register the closing of the done channel - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) { - close(done) - }).AnyTimes() - packetChan <- packetToRead{data: append([]byte{0x40} /* short header packet */, token[:15]...)} - Consistently(done).ShouldNot(BeClosed()) - }) - }) - - Context("generating", func() { - BeforeEach(func() { - var key StatelessResetKey - rand.Read(key[:]) - statelessResetKey = &key - }) - - It("generates stateless reset tokens", func() { - connID1 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) - connID2 := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}) - Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2))) - }) - - It("sends stateless resets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, 100)...) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) { - defer close(done) - Expect(wire.IsLongHeaderPacket(b[0])).To(BeFalse()) // short header packet - Expect(b).To(HaveLen(protocol.MinStatelessResetSize)) - }) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: p, - }) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't send stateless resets for small packets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: p, - }) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }) - }) - - Context("if no key is configured", func() { - It("doesn't send stateless resets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, 100)...) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: p, - }) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - }) - }) - }) + m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) + testErr := errors.New("shutdown") + for i := 0; i < 10; i++ { + conn := NewMockPacketHandler(mockCtrl) + conn.EXPECT().destroy(testErr) + b := make([]byte, 12) + rand.Read(b) + m.Add(protocol.ParseConnectionID(b), conn) + } + m.Close(testErr) + // check that Close can be called multiple times + m.Close(errors.New("close")) }) }) diff --git a/quic_suite_test.go b/quic_suite_test.go index 0eb6f03c..d979d81b 100644 --- a/quic_suite_test.go +++ b/quic_suite_test.go @@ -1,8 +1,11 @@ package quic import ( + "bytes" "io" "log" + "runtime/pprof" + "strings" "sync" "testing" @@ -29,6 +32,20 @@ var _ = BeforeSuite(func() { log.SetOutput(io.Discard) }) +func areServersRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "quic-go.(*baseServer).run") +} + +func areTransportsRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "quic-go.(*Transport).listen") +} + var _ = AfterEach(func() { mockCtrl.Finish() + Eventually(areServersRunning).Should(BeFalse()) + Eventually(areTransportsRunning()).Should(BeFalse()) }) diff --git a/server.go b/server.go index d5bb19e6..f8c9b3cd 100644 --- a/server.go +++ b/server.go @@ -20,7 +20,7 @@ import ( ) // ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close. -var ErrServerClosed = errors.New("quic: Server closed") +var ErrServerClosed = errors.New("quic: server closed") // packetHandler handles packets type packetHandler interface { @@ -30,18 +30,13 @@ type packetHandler interface { getPerspective() protocol.Perspective } -type unknownPacketHandler interface { - handlePacket(*receivedPacket) - setCloseError(error) -} - type packetHandlerManager interface { Get(protocol.ConnectionID) (packetHandler, bool) + GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool) AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool - Destroy() error - connRunner - SetServer(unknownPacketHandler) + Close(error) CloseServer() + connRunner } type quicConn interface { @@ -70,13 +65,11 @@ type baseServer struct { config *Config conn rawConn - // If the server is started with ListenAddr, we create a packet conn. - // If it is started with Listen, we take a packet conn as a parameter. - createdPacketConn bool tokenGenerator *handshake.TokenGenerator connHandler packetHandlerManager + onClose func() receivedPackets chan *receivedPacket @@ -114,8 +107,6 @@ type baseServer struct { logger utils.Logger } -var _ unknownPacketHandler = &baseServer{} - // A Listener listens for incoming QUIC connections. // It returns connections once the handshake has completed. type Listener struct { @@ -166,37 +157,36 @@ func (l *EarlyListener) Addr() net.Addr { // The tls.Config must not be nil and must contain a certificate configuration. // The quic.Config may be nil, in that case the default values will be used. func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) { - s, err := listenAddr(addr, tlsConf, config, false) + conn, err := listenUDP(addr) if err != nil { return nil, err } - return &Listener{baseServer: s}, nil + return (&Transport{ + Conn: conn, + createdConn: true, + isSingleUse: true, + }).Listen(tlsConf, config) } // ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes. func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) { - s, err := listenAddr(addr, tlsConf, config, true) + conn, err := listenUDP(addr) if err != nil { return nil, err } - return &EarlyListener{baseServer: s}, nil + return (&Transport{ + Conn: conn, + createdConn: true, + isSingleUse: true, + }).ListenEarly(tlsConf, config) } -func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { +func listenUDP(addr string) (*net.UDPConn, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } - conn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - return nil, err - } - serv, err := listen(conn, tlsConf, config, acceptEarly) - if err != nil { - return nil, err - } - serv.createdPacketConn = true - return serv, nil + return net.ListenUDP("udp", udpAddr) } // Listen listens for QUIC connections on a given net.PacketConn. If the @@ -210,45 +200,23 @@ func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bo // Furthermore, it must define an application control (using NextProtos). // The quic.Config may be nil, in that case the default values will be used. func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) { - s, err := listen(conn, tlsConf, config, false) - if err != nil { - return nil, err - } - return &Listener{baseServer: s}, nil + tr := &Transport{Conn: conn, isSingleUse: true} + return tr.Listen(tlsConf, config) } // ListenEarly works like Listen, but it returns connections before the handshake completes. func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) { - s, err := listen(conn, tlsConf, config, true) - if err != nil { - return nil, err - } - return &EarlyListener{baseServer: s}, nil + tr := &Transport{Conn: conn, isSingleUse: true} + return tr.ListenEarly(tlsConf, config) } -func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(config); err != nil { - return nil, err - } - config = populateServerConfig(config) - - connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) - if err != nil { - return nil, err - } +func newServer(conn rawConn, connHandler packetHandlerManager, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) { tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) if err != nil { return nil, err } - c, err := wrapConn(conn) - if err != nil { - return nil, err - } s := &baseServer{ - conn: c, + conn: conn, tlsConf: tlsConf, config: config, tokenGenerator: tokenGenerator, @@ -260,12 +228,12 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl newConn: newConnection, logger: utils.DefaultLogger.WithPrefix("server"), acceptEarlyConns: acceptEarly, + onClose: onClose, } if acceptEarly { s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} } go s.run() - connHandler.SetServer(s) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil } @@ -317,18 +285,12 @@ func (s *baseServer) Close() error { if s.serverError == nil { s.serverError = ErrServerClosed } - // If the server was started with ListenAddr, we created the packet conn. - // We need to close it in order to make the go routine reading from that conn return. - createdPacketConn := s.createdPacketConn s.closed = true close(s.errorChan) s.mutex.Unlock() <-s.running - s.connHandler.CloseServer() - if createdPacketConn { - return s.connHandler.Destroy() - } + s.onClose() return nil } diff --git a/server_test.go b/server_test.go index 4108c698..7f6e49a7 100644 --- a/server_test.go +++ b/server_test.go @@ -1,15 +1,12 @@ package quic import ( - "bytes" "context" "crypto/rand" "crypto/tls" "errors" "net" "reflect" - "runtime/pprof" - "strings" "sync" "sync/atomic" "time" @@ -24,17 +21,10 @@ import ( "github.com/quic-go/quic-go/logging" "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) -func areServersRunning() bool { - var b bytes.Buffer - pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "quic-go.(*baseServer).run") -} - var _ = Describe("Server", func() { var ( conn *MockPacketConn @@ -96,15 +86,19 @@ var _ = Describe("Server", func() { BeforeEach(func() { conn = NewMockPacketConn(mockCtrl) conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() - conn.EXPECT().ReadFrom(gomock.Any()).Do(func(_ []byte) { <-(make(chan struct{})) }).MaxTimes(1) + wait := make(chan struct{}) + conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) { + <-wait + return 0, nil, errors.New("done") + }).MaxTimes(1) + conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) { + close(wait) + conn.EXPECT().SetReadDeadline(time.Time{}) + }).MaxTimes(1) tlsConf = testdata.GetTLSConfig() tlsConf.NextProtos = []string{"proto1"} }) - AfterEach(func() { - Eventually(areServersRunning).Should(BeFalse()) - }) - It("errors when no tls.Config is given", func() { _, err := ListenAddr("localhost:0", nil, nil) Expect(err).To(HaveOccurred()) @@ -178,6 +172,7 @@ var _ = Describe("Server", func() { Context("server accepting connections that completed the handshake", func() { var ( + ln *Listener serv *baseServer phm *MockPacketHandlerManager tracer *mocklogging.MockTracer @@ -185,7 +180,8 @@ var _ = Describe("Server", func() { BeforeEach(func() { tracer = mocklogging.NewMockTracer(mockCtrl) - ln, err := Listen(conn, tlsConf, &Config{Tracer: tracer}) + var err error + ln, err = Listen(conn, tlsConf, &Config{Tracer: tracer}) Expect(err).ToNot(HaveOccurred()) serv = ln.baseServer phm = NewMockPacketHandlerManager(mockCtrl) @@ -193,8 +189,7 @@ var _ = Describe("Server", func() { }) AfterEach(func() { - phm.EXPECT().CloseServer().MaxTimes(1) - serv.Close() + ln.Close() }) Context("handling packets", func() { @@ -753,8 +748,7 @@ var _ = Describe("Server", func() { Consistently(done).ShouldNot(BeClosed()) // make the go routine return - phm.EXPECT().CloseServer() - conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID + conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID Expect(serv.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -968,6 +962,7 @@ var _ = Describe("Server", func() { serv.setCloseError(testErr) Eventually(done).Should(BeClosed()) + serv.onClose() // shutdown }) It("returns immediately, if an error occurred before", func() { @@ -977,6 +972,7 @@ var _ = Describe("Server", func() { _, err := serv.Accept(context.Background()) Expect(err).To(MatchError(testErr)) } + serv.onClose() // shutdown }) It("returns when the context is canceled", func() { @@ -1064,7 +1060,6 @@ var _ = Describe("Server", func() { }) AfterEach(func() { - phm.EXPECT().CloseServer().MaxTimes(1) serv.Close() }) @@ -1234,8 +1229,7 @@ var _ = Describe("Server", func() { Consistently(done).ShouldNot(BeClosed()) // make the go routine return - phm.EXPECT().CloseServer() - conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID + conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID Expect(serv.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) diff --git a/transport.go b/transport.go new file mode 100644 index 00000000..2e860ee2 --- /dev/null +++ b/transport.go @@ -0,0 +1,410 @@ +package quic + +import ( + "context" + "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "log" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/quic-go/quic-go/internal/wire" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" +) + +type Transport struct { + // A single net.PacketConn can only be handled by one Transport. + // Bad things will happen if passed to multiple Transports. + // + // If the connection satisfies the OOBCapablePacketConn interface + // (as a net.UDPConn does), ECN and packet info support will be enabled. + // In this case, optimized syscalls might be used, skipping the + // ReadFrom and WriteTo calls to read / write packets. + Conn net.PacketConn + + // The length of the connection ID in bytes. + // It can be 0, or any value between 4 and 18. + // If unset, a 4 byte connection ID will be used. + ConnectionIDLength int + + // Use for generating new connection IDs. + // This allows the application to control of the connection IDs used, + // which allows routing / load balancing based on connection IDs. + // All Connection IDs returned by the ConnectionIDGenerator MUST + // have the same length. + ConnectionIDGenerator ConnectionIDGenerator + + // The StatelessResetKey is used to generate stateless reset tokens. + // If no key is configured, sending of stateless resets is disabled. + StatelessResetKey *StatelessResetKey + + // A Tracer traces events that don't belong to a single QUIC connection. + Tracer logging.Tracer + + handlerMap packetHandlerManager + + mutex sync.Mutex + initOnce sync.Once + initErr error + + // Set in init. + // If no ConnectionIDGenerator is set, this is the ConnectionIDLength. + connIDLen int + + server unknownPacketHandler + + conn rawConn + + closeQueue chan closePacket + + listening chan struct{} // is closed when listen returns + closed bool + createdConn bool + isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial + + logger utils.Logger +} + +// Listen starts listening for incoming QUIC connections. +// There can only be a single listener on any net.PacketConn. +// Listen may only be called again after the current Listener was closed. +func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + if err := validateConfig(conf); err != nil { + return nil, err + } + + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.server != nil { + return nil, errListenerAlreadySet + } + conf = populateServerConfig(conf) + if err := t.init(conf); err != nil { + return nil, err + } + s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, false) + if err != nil { + return nil, err + } + t.server = s + return &Listener{baseServer: s}, nil +} + +// ListenEarly starts listening for incoming QUIC connections. +// There can only be a single listener on any net.PacketConn. +// Listen may only be called again after the current Listener was closed. +func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + if err := validateConfig(conf); err != nil { + return nil, err + } + + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.server != nil { + return nil, errListenerAlreadySet + } + conf = populateServerConfig(conf) + if err := t.init(conf); err != nil { + return nil, err + } + s, err := newServer(t.conn, t.handlerMap, tlsConf, conf, t.closeServer, true) + if err != nil { + return nil, err + } + t.server = s + return &EarlyListener{baseServer: s}, nil +} + +// Dial dials a new connection to a remote host (not using 0-RTT). +func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { + if err := validateConfig(conf); err != nil { + return nil, err + } + conf = populateClientConfig(conf, t.createdConn) + if err := t.init(conf); err != nil { + return nil, err + } + var onClose func() + if t.isSingleUse { + onClose = func() { t.Close() } + } + return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn) +} + +// DialEarly dials a new connection, attempting to use 0-RTT if possible. +func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + if err := validateConfig(conf); err != nil { + return nil, err + } + conf = populateClientConfig(conf, t.createdConn) + if err := t.init(conf); err != nil { + return nil, err + } + var onClose func() + if t.isSingleUse { + onClose = func() { t.Close() } + } + return dial(ctx, t.Conn, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn) +} + +func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { + conn, ok := c.(interface{ SetReadBuffer(int) error }) + if !ok { + return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?") + } + size, err := inspectReadBuffer(c) + if err != nil { + return fmt.Errorf("failed to determine receive buffer size: %w", err) + } + if size >= protocol.DesiredReceiveBufferSize { + logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) + return nil + } + if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { + return fmt.Errorf("failed to increase receive buffer size: %w", err) + } + newSize, err := inspectReadBuffer(c) + if err != nil { + return fmt.Errorf("failed to determine receive buffer size: %w", err) + } + if newSize == size { + return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024) + } + if newSize < protocol.DesiredReceiveBufferSize { + return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024) + } + logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024) + return nil +} + +// only print warnings about the UDP receive buffer size once +var receiveBufferWarningOnce sync.Once + +func (t *Transport) init(conf *Config) error { + t.initOnce.Do(func() { + getMultiplexer().AddConn(t.Conn) + + conn, err := wrapConn(t.Conn) + if err != nil { + t.initErr = err + return + } + + t.StatelessResetKey = conf.StatelessResetKey + t.Tracer = conf.Tracer + t.ConnectionIDLength = conf.ConnectionIDLength + t.ConnectionIDGenerator = conf.ConnectionIDGenerator + + t.logger = utils.DefaultLogger // TODO: make this configurable + t.conn = conn + t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) + t.listening = make(chan struct{}) + + t.closeQueue = make(chan closePacket, 4) + + if t.ConnectionIDGenerator != nil { + t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen() + } else { + t.connIDLen = t.ConnectionIDLength + } + + go t.listen(conn) + go t.runCloseQueue() + }) + return t.initErr +} + +func (t *Transport) enqueueClosePacket(p closePacket) { + select { + case t.closeQueue <- p: + default: + // Oops, we're backlogged. + // Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway. + } +} + +func (t *Transport) runCloseQueue() { + for { + select { + case <-t.listening: + return + case p := <-t.closeQueue: + t.conn.WritePacket(p.payload, p.addr, p.info.OOB()) + } + } +} + +// Close closes the underlying connection and waits until listen has returned. +// It is invalid to start new listeners or connections after that. +func (t *Transport) Close() error { + t.close(errors.New("closing")) + if t.createdConn { + if err := t.conn.Close(); err != nil { + return err + } + } else { + t.conn.SetReadDeadline(time.Now()) + defer func() { t.conn.SetReadDeadline(time.Time{}) }() + } + <-t.listening // wait until listening returns + return nil +} + +func (t *Transport) closeServer() { + t.handlerMap.CloseServer() + t.mutex.Lock() + t.server = nil + if t.isSingleUse { + t.closed = true + } + t.mutex.Unlock() + if t.createdConn { + t.Conn.Close() + } + if t.isSingleUse { + t.conn.SetReadDeadline(time.Now()) + defer func() { t.conn.SetReadDeadline(time.Time{}) }() + <-t.listening // wait until listening returns + } +} + +func (t *Transport) close(e error) { + t.mutex.Lock() + defer t.mutex.Unlock() + if t.closed { + return + } + + t.handlerMap.Close(e) + if t.server != nil { + t.server.setCloseError(e) + } + t.closed = true +} + +func (t *Transport) listen(conn rawConn) { + defer close(t.listening) + defer getMultiplexer().RemoveConn(t.Conn) + + if err := setReceiveBuffer(t.Conn, t.logger); err != nil { + if !strings.Contains(err.Error(), "use of closed network connection") { + receiveBufferWarningOnce.Do(func() { + if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { + return + } + log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err) + }) + } + } + + for { + p, err := conn.ReadPacket() + //nolint:staticcheck // SA1019 ignore this! + // TODO: This code is used to ignore wsa errors on Windows. + // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. + // See https://github.com/quic-go/quic-go/issues/1737 for details. + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + t.mutex.Lock() + closed := t.closed + t.mutex.Unlock() + if closed { + return + } + t.logger.Debugf("Temporary error reading from conn: %w", err) + continue + } + if err != nil { + t.close(err) + return + } + t.handlePacket(p) + } +} + +func (t *Transport) handlePacket(p *receivedPacket) { + connID, err := wire.ParseConnectionID(p.data, t.connIDLen) + if err != nil { + t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) + if t.Tracer != nil { + t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + } + p.buffer.MaybeRelease() + return + } + + if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset { + return + } + if handler, ok := t.handlerMap.Get(connID); ok { + handler.handlePacket(p) + return + } + if !wire.IsLongHeaderPacket(p.data[0]) { + go t.maybeSendStatelessReset(p, connID) + return + } + + t.mutex.Lock() + defer t.mutex.Unlock() + if t.server == nil { // no server set + t.logger.Debugf("received a packet with an unexpected connection ID %s", connID) + return + } + t.server.handlePacket(p) +} + +func (t *Transport) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { + defer p.buffer.Release() + if t.StatelessResetKey == nil { + return + } + // Don't send a stateless reset in response to very small packets. + // This includes packets that could be stateless resets. + if len(p.data) <= protocol.MinStatelessResetSize { + return + } + token := t.handlerMap.GetStatelessResetToken(connID) + t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) + data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) + rand.Read(data) + data[0] = (data[0] & 0x7f) | 0x40 + data = append(data, token[:]...) + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + t.logger.Debugf("Error sending Stateless Reset: %s", err) + } +} + +func (t *Transport) maybeHandleStatelessReset(data []byte) bool { + // stateless resets are always short header packets + if wire.IsLongHeaderPacket(data[0]) { + return false + } + if len(data) < 17 /* type byte + 16 bytes for the reset token */ { + return false + } + + token := *(*protocol.StatelessResetToken)(data[len(data)-16:]) + if conn, ok := t.handlerMap.GetByResetToken(token); ok { + t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) + go conn.destroy(&StatelessResetError{Token: token}) + return true + } + return false +} diff --git a/transport_test.go b/transport_test.go new file mode 100644 index 00000000..122dbba3 --- /dev/null +++ b/transport_test.go @@ -0,0 +1,287 @@ +package quic + +import ( + "bytes" + "crypto/rand" + "crypto/tls" + "errors" + "net" + "time" + + mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" + + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Transport", func() { + type packetToRead struct { + addr net.Addr + data []byte + err error + } + + getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { + b, err := (&wire.ExtendedHeader{ + Header: wire.Header{ + Type: t, + DestConnectionID: connID, + Length: length, + Version: protocol.Version1, + }, + PacketNumberLen: protocol.PacketNumberLen2, + }).Append(nil, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + return b + } + + getPacket := func(connID protocol.ConnectionID) []byte { + return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2) + } + + newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn { + conn := NewMockPacketConn(mockCtrl) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() + conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { + p, ok := <-packetChan + if !ok { + return 0, nil, errors.New("closed") + } + return copy(b, p.data), p.addr, p.err + }).AnyTimes() + // for shutdown + conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes() + return conn + } + + It("handles packets for different packet handlers on the same packet conn", func() { + packetChan := make(chan packetToRead) + tr := &Transport{Conn: newMockPacketConn(packetChan)} + tr.init(&Config{}) + phm := NewMockPacketHandlerManager(mockCtrl) + tr.handlerMap = phm + connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) + + handled := make(chan struct{}, 2) + phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { + h := NewMockPacketHandler(mockCtrl) + h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + defer GinkgoRecover() + connID, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(connID1)) + handled <- struct{}{} + }) + return h, true + }) + phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { + h := NewMockPacketHandler(mockCtrl) + h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + defer GinkgoRecover() + connID, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(connID2)) + handled <- struct{}{} + }) + return h, true + }) + + packetChan <- packetToRead{data: getPacket(connID1)} + packetChan <- packetToRead{data: getPacket(connID2)} + + Eventually(handled).Should(Receive()) + Eventually(handled).Should(Receive()) + + // shutdown + phm.EXPECT().Close(gomock.Any()) + close(packetChan) + tr.Close() + }) + + It("closes listeners", func() { + packetChan := make(chan packetToRead) + tr := &Transport{Conn: newMockPacketConn(packetChan)} + defer tr.Close() + ln, err := tr.Listen(&tls.Config{}, nil) + Expect(err).ToNot(HaveOccurred()) + phm := NewMockPacketHandlerManager(mockCtrl) + tr.handlerMap = phm + + phm.EXPECT().CloseServer() + Expect(ln.Close()).To(Succeed()) + + // shutdown + phm.EXPECT().Close(gomock.Any()) + close(packetChan) + tr.Close() + }) + + It("drops unparseable packets", func() { + addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} + packetChan := make(chan packetToRead) + tracer := mocklogging.NewMockTracer(mockCtrl) + tr := &Transport{ + Conn: newMockPacketConn(packetChan), + } + tr.init(&Config{Tracer: tracer, ConnectionIDLength: 10}) + dropped := make(chan struct{}) + tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }) + packetChan <- packetToRead{ + addr: addr, + data: []byte{0, 1, 2, 3}, + } + Eventually(dropped).Should(BeClosed()) + + // shutdown + close(packetChan) + tr.Close() + }) + + It("closes when reading from the conn fails", func() { + packetChan := make(chan packetToRead) + tr := Transport{Conn: newMockPacketConn(packetChan)} + defer tr.Close() + phm := NewMockPacketHandlerManager(mockCtrl) + tr.init(&Config{}) + tr.handlerMap = phm + + done := make(chan struct{}) + phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) }) + packetChan <- packetToRead{err: errors.New("read failed")} + Eventually(done).Should(BeClosed()) + + // shutdown + close(packetChan) + tr.Close() + }) + + It("continues listening after temporary errors", func() { + packetChan := make(chan packetToRead) + tr := Transport{Conn: newMockPacketConn(packetChan)} + defer tr.Close() + phm := NewMockPacketHandlerManager(mockCtrl) + tr.init(&Config{}) + tr.handlerMap = phm + + tempErr := deadlineError{} + Expect(tempErr.Temporary()).To(BeTrue()) + packetChan <- packetToRead{err: tempErr} + // don't expect any calls to phm.Close + time.Sleep(50 * time.Millisecond) + + // shutdown + phm.EXPECT().Close(gomock.Any()) + close(packetChan) + tr.Close() + }) + + It("handles short header packets resets", func() { + connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) + packetChan := make(chan packetToRead) + tr := Transport{Conn: newMockPacketConn(packetChan)} + tr.init(&Config{ConnectionIDLength: connID.Len()}) + defer tr.Close() + phm := NewMockPacketHandlerManager(mockCtrl) + tr.init(&Config{}) + tr.handlerMap = phm + + var token protocol.StatelessResetToken + rand.Read(token[:]) + + var b []byte + b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) + Expect(err).ToNot(HaveOccurred()) + b = append(b, token[:]...) + conn := NewMockPacketHandler(mockCtrl) + gomock.InOrder( + phm.EXPECT().GetByResetToken(token), + phm.EXPECT().Get(connID).Return(conn, true), + conn.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + Expect(p.data).To(Equal(b)) + Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second)) + }), + ) + packetChan <- packetToRead{data: b} + + // shutdown + phm.EXPECT().Close(gomock.Any()) + close(packetChan) + tr.Close() + }) + + It("handles stateless resets", func() { + connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) + packetChan := make(chan packetToRead) + tr := Transport{Conn: newMockPacketConn(packetChan)} + tr.init(&Config{ConnectionIDLength: connID.Len()}) + defer tr.Close() + phm := NewMockPacketHandlerManager(mockCtrl) + tr.init(&Config{}) + tr.handlerMap = phm + + var token protocol.StatelessResetToken + rand.Read(token[:]) + + var b []byte + b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) + Expect(err).ToNot(HaveOccurred()) + b = append(b, token[:]...) + conn := NewMockPacketHandler(mockCtrl) + gomock.InOrder( + phm.EXPECT().GetByResetToken(token).Return(conn, true), + conn.EXPECT().destroy(gomock.Any()).Do(func(err error) { + Expect(err).To(MatchError(&StatelessResetError{Token: token})) + }), + ) + packetChan <- packetToRead{data: b} + + // shutdown + phm.EXPECT().Close(gomock.Any()) + close(packetChan) + tr.Close() + }) + + It("sends stateless resets", func() { + connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) + packetChan := make(chan packetToRead) + conn := newMockPacketConn(packetChan) + tr := Transport{ + Conn: conn, + } + tr.init(&Config{ConnectionIDLength: connID.Len(), StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}}) + defer tr.Close() + phm := NewMockPacketHandlerManager(mockCtrl) + tr.init(&Config{}) + tr.handlerMap = phm + + var b []byte + b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne) + Expect(err).ToNot(HaveOccurred()) + b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...) + + var token protocol.StatelessResetToken + rand.Read(token[:]) + written := make(chan struct{}) + gomock.InOrder( + phm.EXPECT().GetByResetToken(gomock.Any()), + phm.EXPECT().Get(connID), + phm.EXPECT().GetStatelessResetToken(connID).Return(token), + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func(b []byte, _ net.Addr) { + defer close(written) + Expect(bytes.Contains(b, token[:])).To(BeTrue()) + }), + ) + packetChan <- packetToRead{data: b} + Eventually(written).Should(BeClosed()) + + // shutdown + phm.EXPECT().Close(gomock.Any()) + close(packetChan) + tr.Close() + }) +})