mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 04:07:35 +03:00
simplify connection handling for the client
This commit is contained in:
parent
7a0ef5f867
commit
5544f0f9a1
5 changed files with 34 additions and 70 deletions
51
client.go
51
client.go
|
@ -12,10 +12,7 @@ import (
|
|||
)
|
||||
|
||||
type client struct {
|
||||
sconn sendConn
|
||||
// If the client is created with DialAddr, we create a packet conn.
|
||||
// If it is started with Dial, we take a packet conn as a parameter.
|
||||
createdPacketConn bool
|
||||
sendConn sendConn
|
||||
|
||||
use0RTT bool
|
||||
|
||||
|
@ -133,17 +130,15 @@ func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn boo
|
|||
|
||||
func dial(
|
||||
ctx context.Context,
|
||||
conn net.PacketConn,
|
||||
conn sendConn,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
packetHandlers packetHandlerManager,
|
||||
addr net.Addr,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
onClose func(),
|
||||
use0RTT bool,
|
||||
createdPacketConn bool,
|
||||
) (quicConn, error) {
|
||||
c, err := newClient(conn, addr, connIDGenerator, config, tlsConf, onClose, use0RTT, createdPacketConn)
|
||||
c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -158,7 +153,7 @@ func dial(
|
|||
)
|
||||
}
|
||||
if c.tracer != nil {
|
||||
c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID)
|
||||
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
|
||||
}
|
||||
if err := c.dial(ctx); err != nil {
|
||||
return nil, err
|
||||
|
@ -166,16 +161,7 @@ func dial(
|
|||
return c.conn, nil
|
||||
}
|
||||
|
||||
func newClient(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
config *Config,
|
||||
tlsConf *tls.Config,
|
||||
onClose func(),
|
||||
use0RTT bool,
|
||||
createdPacketConn bool,
|
||||
) (*client, error) {
|
||||
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
|
||||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{}
|
||||
} else {
|
||||
|
@ -191,27 +177,26 @@ func newClient(
|
|||
return nil, err
|
||||
}
|
||||
c := &client{
|
||||
connIDGenerator: connIDGenerator,
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
sconn: newSendPconn(pconn, remoteAddr),
|
||||
createdPacketConn: createdPacketConn,
|
||||
use0RTT: use0RTT,
|
||||
onClose: onClose,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
connIDGenerator: connIDGenerator,
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
sendConn: sendConn,
|
||||
use0RTT: use0RTT,
|
||||
onClose: onClose,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
|
||||
c.conn = newClientConnection(
|
||||
c.sconn,
|
||||
c.sendConn,
|
||||
c.packetHandlers,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
|
|
|
@ -26,8 +26,7 @@ func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil }
|
|||
var _ = Describe("Client", func() {
|
||||
var (
|
||||
cl *client
|
||||
packetConn *MockPacketConn
|
||||
addr net.Addr
|
||||
packetConn *MockSendConn
|
||||
connID protocol.ConnectionID
|
||||
origMultiplexer multiplexer
|
||||
tlsConf *tls.Config
|
||||
|
@ -62,14 +61,14 @@ var _ = Describe("Client", func() {
|
|||
tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
|
||||
config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}}
|
||||
Eventually(areConnsRunning).Should(BeFalse())
|
||||
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
||||
packetConn = NewMockPacketConn(mockCtrl)
|
||||
packetConn = NewMockSendConn(mockCtrl)
|
||||
packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||
packetConn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||
cl = &client{
|
||||
srcConnID: connID,
|
||||
destConnID: connID,
|
||||
version: protocol.Version1,
|
||||
sconn: newSendPconn(packetConn, addr),
|
||||
sendConn: packetConn,
|
||||
tracer: tracer,
|
||||
logger: utils.DefaultLogger,
|
||||
}
|
||||
|
@ -134,7 +133,7 @@ var _ = Describe("Client", func() {
|
|||
conn.EXPECT().HandshakeComplete().Return(c)
|
||||
return conn
|
||||
}
|
||||
cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false, false)
|
||||
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
|
@ -171,7 +170,7 @@ var _ = Describe("Client", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true, false)
|
||||
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
|
@ -207,7 +206,7 @@ var _ = Describe("Client", func() {
|
|||
return conn
|
||||
}
|
||||
var closed bool
|
||||
cl, err := newClient(packetConn, addr, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true, false)
|
||||
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
|
@ -266,7 +265,6 @@ var _ = Describe("Client", func() {
|
|||
It("creates new connections with the right parameters", func() {
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}}
|
||||
c := make(chan struct{})
|
||||
var cconn sendConn
|
||||
var version protocol.VersionNumber
|
||||
var conf *Config
|
||||
done := make(chan struct{})
|
||||
|
@ -286,7 +284,6 @@ var _ = Describe("Client", func() {
|
|||
_ utils.Logger,
|
||||
versionP protocol.VersionNumber,
|
||||
) quicConn {
|
||||
cconn = connP
|
||||
version = versionP
|
||||
conf = configP
|
||||
close(c)
|
||||
|
@ -298,15 +295,16 @@ var _ = Describe("Client", func() {
|
|||
close(done)
|
||||
return conn
|
||||
}
|
||||
packetConn := NewMockPacketConn(mockCtrl)
|
||||
packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) {
|
||||
<-done
|
||||
return 0, nil, errors.New("closed")
|
||||
})
|
||||
packetConn.EXPECT().LocalAddr()
|
||||
packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
|
||||
_, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
|
||||
_, err := Dial(context.Background(), packetConn, &net.UDPAddr{}, tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(c).Should(BeClosed())
|
||||
Expect(cconn.(*spconn).PacketConn).To(Equal(packetConn))
|
||||
Expect(version).To(Equal(config.Versions[0]))
|
||||
Expect(conf.Versions).To(Equal(config.Versions))
|
||||
})
|
||||
|
|
23
send_conn.go
23
send_conn.go
|
@ -22,7 +22,7 @@ type sconn struct {
|
|||
|
||||
var _ sendConn = &sconn{}
|
||||
|
||||
func newSendConn(c rawConn, remote net.Addr, info *packetInfo) sendConn {
|
||||
func newSendConn(c rawConn, remote net.Addr, info *packetInfo) *sconn {
|
||||
return &sconn{
|
||||
rawConn: c,
|
||||
remoteAddr: remote,
|
||||
|
@ -51,24 +51,3 @@ func (c *sconn) LocalAddr() net.Addr {
|
|||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
type spconn struct {
|
||||
net.PacketConn
|
||||
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
var _ sendConn = &spconn{}
|
||||
|
||||
func newSendPconn(c net.PacketConn, remote net.Addr) sendConn {
|
||||
return &spconn{PacketConn: c, remoteAddr: remote}
|
||||
}
|
||||
|
||||
func (c *spconn) Write(p []byte) error {
|
||||
_, err := c.WriteTo(p, c.remoteAddr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *spconn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
|
|
@ -17,7 +17,9 @@ var _ = Describe("Connection (for sending packets)", func() {
|
|||
BeforeEach(func() {
|
||||
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
||||
packetConn = NewMockPacketConn(mockCtrl)
|
||||
c = newSendPconn(packetConn, addr)
|
||||
rawConn, err := wrapConn(packetConn)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
c = newSendConn(rawConn, addr, nil)
|
||||
})
|
||||
|
||||
It("writes", func() {
|
||||
|
|
|
@ -148,7 +148,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
|
|||
if t.isSingleUse {
|
||||
onClose = func() { t.Close() }
|
||||
}
|
||||
return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, false, t.createdConn)
|
||||
return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
|
||||
}
|
||||
|
||||
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
|
||||
|
@ -164,7 +164,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
|
|||
if t.isSingleUse {
|
||||
onClose = func() { t.Close() }
|
||||
}
|
||||
return dial(ctx, t.Conn, t.connIDGenerator, t.handlerMap, addr, tlsConf, conf, onClose, true, t.createdConn)
|
||||
return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
|
||||
}
|
||||
|
||||
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue