simplify connection handling for the client

This commit is contained in:
Marten Seemann 2023-04-25 13:48:57 +02:00
parent 7a0ef5f867
commit 5544f0f9a1
5 changed files with 34 additions and 70 deletions

View file

@ -12,10 +12,7 @@ import (
)
type client struct {
sconn sendConn
// If the client is created with DialAddr, we create a packet conn.
// If it is started with Dial, we take a packet conn as a parameter.
createdPacketConn bool
sendConn sendConn
use0RTT bool
@ -133,17 +130,15 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo
func dial(
ctx context.Context,
conn net.PacketConn,
conn sendConn,
connIDGenerator ConnectionIDGenerator,
packetHandlers packetHandlerManager,
addr net.Addr,
tlsConf *tls.Config,
config *Config,
onClose func(),
use0RTT bool,
createdPacketConn bool,
) (quicConn, error) {
c, err := newClient(conn, addr, connIDGenerator, config, tlsConf, onClose, use0RTT, createdPacketConn)
c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
if err != nil {
return nil, err
}
@ -158,7 +153,7 @@ func dial(
)
}
if c.tracer != nil {
c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID)
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
}
if err := c.dial(ctx); err != nil {
return nil, err
@ -166,16 +161,7 @@ func dial(
return c.conn, nil
}
func newClient(
pconn net.PacketConn,
remoteAddr net.Addr,
connIDGenerator ConnectionIDGenerator,
config *Config,
tlsConf *tls.Config,
onClose func(),
use0RTT bool,
createdPacketConn bool,
) (*client, error) {
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
if tlsConf == nil {
tlsConf = &tls.Config{}
} else {
@ -191,27 +177,26 @@ func newClient(
return nil, err
}
c := &client{
connIDGenerator: connIDGenerator,
srcConnID: srcConnID,
destConnID: destConnID,
sconn: newSendPconn(pconn, remoteAddr),
createdPacketConn: createdPacketConn,
use0RTT: use0RTT,
onClose: onClose,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
connIDGenerator: connIDGenerator,
srcConnID: srcConnID,
destConnID: destConnID,
sendConn: sendConn,
use0RTT: use0RTT,
onClose: onClose,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
}
return c, nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.conn = newClientConnection(
c.sconn,
c.sendConn,
c.packetHandlers,
c.destConnID,
c.srcConnID,

View file

@ -26,8 +26,7 @@ func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil }
var _ = Describe("Client", func() {
var (
cl *client
packetConn *MockPacketConn
addr net.Addr
packetConn *MockSendConn
connID protocol.ConnectionID
origMultiplexer multiplexer
tlsConf *tls.Config
@ -62,14 +61,14 @@ var _ = Describe("Client", func() {
tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}}
Eventually(areConnsRunning).Should(BeFalse())
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = NewMockPacketConn(mockCtrl)
packetConn = NewMockSendConn(mockCtrl)
packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
packetConn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes()
cl = &client{
srcConnID: connID,
destConnID: connID,
version: protocol.Version1,
sconn: newSendPconn(packetConn, addr),
sendConn: packetConn,
tracer: tracer,
logger: utils.DefaultLogger,
}
@ -134,7 +133,7 @@ var _ = Describe("Client", func() {
conn.EXPECT().HandshakeComplete().Return(c)
return conn
}
cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false, false)
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
@ -171,7 +170,7 @@ var _ = Describe("Client", func() {
return conn
}
cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true, false)
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
@ -207,7 +206,7 @@ var _ = Describe("Client", func() {
return conn
}
var closed bool
cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true, false)
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
@ -266,7 +265,6 @@ var _ = Describe("Client", func() {
It("creates new connections with the right parameters", func() {
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}}
c := make(chan struct{})
var cconn sendConn
var version protocol.VersionNumber
var conf *Config
done := make(chan struct{})
@ -286,7 +284,6 @@ var _ = Describe("Client", func() {
_ utils.Logger,
versionP protocol.VersionNumber,
) quicConn {
cconn = connP
version = versionP
conf = configP
close(c)
@ -298,15 +295,16 @@ var _ = Describe("Client", func() {
close(done)
return conn
}
packetConn := NewMockPacketConn(mockCtrl)
packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) {
<-done
return 0, nil, errors.New("closed")
})
packetConn.EXPECT().LocalAddr()
packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
_, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
_, err := Dial(context.Background(), packetConn, &net.UDPAddr{}, tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Eventually(c).Should(BeClosed())
Expect(cconn.(*spconn).PacketConn).To(Equal(packetConn))
Expect(version).To(Equal(config.Versions[0]))
Expect(conf.Versions).To(Equal(config.Versions))
})

View file

@ -22,7 +22,7 @@ type sconn struct {
var _ sendConn = &sconn{}
func newSendConn(c rawConn, remote net.Addr, info *packetInfo) sendConn {
func newSendConn(c rawConn, remote net.Addr, info *packetInfo) *sconn {
return &sconn{
rawConn: c,
remoteAddr: remote,
@ -51,24 +51,3 @@ func (c *sconn) LocalAddr() net.Addr {
}
return addr
}
type spconn struct {
net.PacketConn
remoteAddr net.Addr
}
var _ sendConn = &spconn{}
func newSendPconn(c net.PacketConn, remote net.Addr) sendConn {
return &spconn{PacketConn: c, remoteAddr: remote}
}
func (c *spconn) Write(p []byte) error {
_, err := c.WriteTo(p, c.remoteAddr)
return err
}
func (c *spconn) RemoteAddr() net.Addr {
return c.remoteAddr
}

View file

@ -17,7 +17,9 @@ var _ = Describe("Connection (for sending packets)", func() {
BeforeEach(func() {
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = NewMockPacketConn(mockCtrl)
c = newSendPconn(packetConn, addr)
rawConn, err := wrapConn(packetConn)
Expect(err).ToNot(HaveOccurred())
c = newSendConn(rawConn, addr, nil)
})
It("writes", func() {

View file

@ -148,7 +148,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
if t.isSingleUse {
onClose = func() { t.Close() }
}
return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn)
return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
}
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
@ -164,7 +164,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
if t.isSingleUse {
onClose = func() { t.Close() }
}
return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn)
return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
}
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {