From 5544f0f9a1ef627fde4c6bc92d0f928d53f7edd6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 25 Apr 2023 13:48:57 +0200 Subject: [PATCH] simplify connection handling for the client --- client.go | 51 +++++++++++++++++------------------------------ client_test.go | 22 ++++++++++---------- send_conn.go | 23 +-------------------- send_conn_test.go | 4 +++- transport.go | 4 ++-- 5 files changed, 34 insertions(+), 70 deletions(-) diff --git a/client.go b/client.go index ed6ccfb8..e1f03c87 100644 --- a/client.go +++ b/client.go @@ -12,10 +12,7 @@ import ( ) type client struct { - sconn sendConn - // If the client is created with DialAddr, we create a packet conn. - // If it is started with Dial, we take a packet conn as a parameter. - createdPacketConn bool + sendConn sendConn use0RTT bool @@ -133,17 +130,15 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo func dial( ctx context.Context, - conn net.PacketConn, + conn sendConn, connIDGenerator ConnectionIDGenerator, packetHandlers packetHandlerManager, - addr net.Addr, tlsConf *tls.Config, config *Config, onClose func(), use0RTT bool, - createdPacketConn bool, ) (quicConn, error) { - c, err := newClient(conn, addr, connIDGenerator, config, tlsConf, onClose, use0RTT, createdPacketConn) + c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT) if err != nil { return nil, err } @@ -158,7 +153,7 @@ func dial( ) } if c.tracer != nil { - c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID) + c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) } if err := c.dial(ctx); err != nil { return nil, err @@ -166,16 +161,7 @@ func dial( return c.conn, nil } -func newClient( - pconn net.PacketConn, - remoteAddr net.Addr, - connIDGenerator ConnectionIDGenerator, - config *Config, - tlsConf *tls.Config, - onClose func(), - use0RTT bool, - createdPacketConn bool, -) (*client, error) { +func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) { if tlsConf == nil { tlsConf = &tls.Config{} } else { @@ -191,27 +177,26 @@ func newClient( return nil, err } c := &client{ - connIDGenerator: connIDGenerator, - srcConnID: srcConnID, - destConnID: destConnID, - sconn: newSendPconn(pconn, remoteAddr), - createdPacketConn: createdPacketConn, - use0RTT: use0RTT, - onClose: onClose, - tlsConf: tlsConf, - config: config, - version: config.Versions[0], - handshakeChan: make(chan struct{}), - logger: utils.DefaultLogger.WithPrefix("client"), + connIDGenerator: connIDGenerator, + srcConnID: srcConnID, + destConnID: destConnID, + sendConn: sendConn, + use0RTT: use0RTT, + onClose: onClose, + tlsConf: tlsConf, + config: config, + version: config.Versions[0], + handshakeChan: make(chan struct{}), + logger: utils.DefaultLogger.WithPrefix("client"), } return c, nil } func (c *client) dial(ctx context.Context) error { - c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) + c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) c.conn = newClientConnection( - c.sconn, + c.sendConn, c.packetHandlers, c.destConnID, c.srcConnID, diff --git a/client_test.go b/client_test.go index ce53ef4b..a3c31e1d 100644 --- a/client_test.go +++ b/client_test.go @@ -26,8 +26,7 @@ func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil } var _ = Describe("Client", func() { var ( cl *client - packetConn *MockPacketConn - addr net.Addr + packetConn *MockSendConn connID protocol.ConnectionID origMultiplexer multiplexer tlsConf *tls.Config @@ -62,14 +61,14 @@ var _ = Describe("Client", func() { tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}} Eventually(areConnsRunning).Should(BeFalse()) - addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - packetConn = NewMockPacketConn(mockCtrl) + packetConn = NewMockSendConn(mockCtrl) packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() + packetConn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes() cl = &client{ srcConnID: connID, destConnID: connID, version: protocol.Version1, - sconn: newSendPconn(packetConn, addr), + sendConn: packetConn, tracer: tracer, logger: utils.DefaultLogger, } @@ -134,7 +133,7 @@ var _ = Describe("Client", func() { conn.EXPECT().HandshakeComplete().Return(c) return conn } - cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false, false) + cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -171,7 +170,7 @@ var _ = Describe("Client", func() { return conn } - cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true, false) + cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -207,7 +206,7 @@ var _ = Describe("Client", func() { return conn } var closed bool - cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true, false) + cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -266,7 +265,6 @@ var _ = Describe("Client", func() { It("creates new connections with the right parameters", func() { config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}} c := make(chan struct{}) - var cconn sendConn var version protocol.VersionNumber var conf *Config done := make(chan struct{}) @@ -286,7 +284,6 @@ var _ = Describe("Client", func() { _ utils.Logger, versionP protocol.VersionNumber, ) quicConn { - cconn = connP version = versionP conf = configP close(c) @@ -298,15 +295,16 @@ var _ = Describe("Client", func() { close(done) return conn } + packetConn := NewMockPacketConn(mockCtrl) packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) { <-done return 0, nil, errors.New("closed") }) + packetConn.EXPECT().LocalAddr() packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes() - _, err := Dial(context.Background(), packetConn, addr, tlsConf, config) + _, err := Dial(context.Background(), packetConn, &net.UDPAddr{}, tlsConf, config) Expect(err).ToNot(HaveOccurred()) Eventually(c).Should(BeClosed()) - Expect(cconn.(*spconn).PacketConn).To(Equal(packetConn)) Expect(version).To(Equal(config.Versions[0])) Expect(conf.Versions).To(Equal(config.Versions)) }) diff --git a/send_conn.go b/send_conn.go index c53ebdfa..0ac27037 100644 --- a/send_conn.go +++ b/send_conn.go @@ -22,7 +22,7 @@ type sconn struct { var _ sendConn = &sconn{} -func newSendConn(c rawConn, remote net.Addr, info *packetInfo) sendConn { +func newSendConn(c rawConn, remote net.Addr, info *packetInfo) *sconn { return &sconn{ rawConn: c, remoteAddr: remote, @@ -51,24 +51,3 @@ func (c *sconn) LocalAddr() net.Addr { } return addr } - -type spconn struct { - net.PacketConn - - remoteAddr net.Addr -} - -var _ sendConn = &spconn{} - -func newSendPconn(c net.PacketConn, remote net.Addr) sendConn { - return &spconn{PacketConn: c, remoteAddr: remote} -} - -func (c *spconn) Write(p []byte) error { - _, err := c.WriteTo(p, c.remoteAddr) - return err -} - -func (c *spconn) RemoteAddr() net.Addr { - return c.remoteAddr -} diff --git a/send_conn_test.go b/send_conn_test.go index 6c36c1b6..2da3e3ab 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -17,7 +17,9 @@ var _ = Describe("Connection (for sending packets)", func() { BeforeEach(func() { addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} packetConn = NewMockPacketConn(mockCtrl) - c = newSendPconn(packetConn, addr) + rawConn, err := wrapConn(packetConn) + Expect(err).ToNot(HaveOccurred()) + c = newSendConn(rawConn, addr, nil) }) It("writes", func() { diff --git a/transport.go b/transport.go index ebb144f4..baeb592b 100644 --- a/transport.go +++ b/transport.go @@ -148,7 +148,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config if t.isSingleUse { onClose = func() { t.Close() } } - return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn) + return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false) } // DialEarly dials a new connection, attempting to use 0-RTT if possible. @@ -164,7 +164,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C if t.isSingleUse { onClose = func() { t.Close() } } - return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn) + return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) } func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {