mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-06 05:37:36 +03:00
also use the multiplexer when dialing an address
This commit is contained in:
parent
d24fb4b86d
commit
f806f9146b
2 changed files with 49 additions and 120 deletions
74
client.go
74
client.go
|
@ -7,9 +7,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -21,6 +19,7 @@ import (
|
|||
type client struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
pconn net.PacketConn
|
||||
conn connection
|
||||
hostname string
|
||||
|
||||
|
@ -83,15 +82,7 @@ func DialAddrContext(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := newClient(udpConn, udpAddr, config, tlsConf, addr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go c.listen()
|
||||
if err := c.dial(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.session, nil
|
||||
return DialContext(ctx, udpConn, udpAddr, addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
|
@ -129,6 +120,11 @@ func DialContext(
|
|||
if err := multiplexer.AddHandler(pconn, c.srcConnID, c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if config.RequestConnectionIDOmission {
|
||||
if err := multiplexer.AddHandler(pconn, protocol.ConnectionID{}, c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := c.dial(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -168,6 +164,7 @@ func newClient(
|
|||
onClose = closeCallback
|
||||
}
|
||||
c := &client{
|
||||
pconn: pconn,
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
|
@ -350,58 +347,6 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Listen listens on the underlying connection and passes packets on for handling.
|
||||
// It returns when the connection is closed.
|
||||
func (c *client) listen() {
|
||||
var err error
|
||||
|
||||
for {
|
||||
var n int
|
||||
var addr net.Addr
|
||||
data := *getPacketBuffer()
|
||||
data = data[:protocol.MaxReceivePacketSize]
|
||||
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
||||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||
n, addr, err = c.conn.Read(data)
|
||||
if err != nil {
|
||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
c.mutex.Lock()
|
||||
if c.session != nil {
|
||||
c.session.Close(err)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
break
|
||||
}
|
||||
c.handleRead(addr, data[:n])
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) handleRead(remoteAddr net.Addr, packet []byte) {
|
||||
rcvTime := time.Now()
|
||||
|
||||
r := bytes.NewReader(packet)
|
||||
iHdr, err := wire.ParseInvariantHeader(r, c.config.ConnectionIDLength)
|
||||
// drop the packet if we can't parse the header
|
||||
if err != nil {
|
||||
c.logger.Errorf("error parsing invariant header: %s", err)
|
||||
return
|
||||
}
|
||||
hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, c.version)
|
||||
if err != nil {
|
||||
c.logger.Errorf("error parsing header: %s", err)
|
||||
return
|
||||
}
|
||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||
packetData := packet[len(packet)-r.Len():]
|
||||
c.handlePacket(&receivedPacket{
|
||||
remoteAddr: remoteAddr,
|
||||
header: hdr,
|
||||
data: packetData,
|
||||
rcvTime: rcvTime,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *client) handlePacket(p *receivedPacket) {
|
||||
if err := c.handlePacketImpl(p); err != nil {
|
||||
c.logger.Errorf("error handling packet: %s", err)
|
||||
|
@ -524,6 +469,9 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
|||
c.initialVersion = c.version
|
||||
c.version = newVersion
|
||||
c.generateConnectionIDs()
|
||||
if err := getClientMultiplexer().AddHandler(c.pconn, c.srcConnID, c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
|
||||
c.session.Close(errCloseSessionForNewVersion)
|
||||
|
|
|
@ -45,6 +45,17 @@ var _ = Describe("Client", func() {
|
|||
return b.Bytes()
|
||||
}
|
||||
|
||||
composeVersionNegotiationPacket := func(connID protocol.ConnectionID, versions []protocol.VersionNumber) *receivedPacket {
|
||||
return &receivedPacket{
|
||||
rcvTime: time.Now(),
|
||||
header: &wire.Header{
|
||||
IsVersionNegotiation: true,
|
||||
DestConnectionID: connID,
|
||||
SupportedVersions: versions,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
|
||||
originalClientSessConstructor = newClientSession
|
||||
|
@ -95,6 +106,10 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
It("resolves the address", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
|
||||
if os.Getenv("APPVEYOR") == "True" {
|
||||
Skip("This test is flaky on AppVeyor.")
|
||||
}
|
||||
|
@ -122,6 +137,10 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
It("uses the tls.Config.ServerName as the hostname, if present", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
|
||||
hostnameChan := make(chan string, 1)
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
|
@ -466,6 +485,8 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
It("changes the version after receiving a version negotiation packet", func() {
|
||||
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
|
||||
version1 := protocol.Version39
|
||||
version2 := protocol.Version39 + 1
|
||||
Expect(version2.UsesTLS()).To(BeFalse())
|
||||
|
@ -502,11 +523,13 @@ var _ = Describe("Client", func() {
|
|||
close(dialed)
|
||||
}()
|
||||
Eventually(sessionChan).Should(HaveLen(1))
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2}))
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{version2}))
|
||||
Eventually(sessionChan).Should(BeEmpty())
|
||||
})
|
||||
|
||||
It("only accepts one version negotiation packet", func() {
|
||||
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
|
||||
version1 := protocol.Version39
|
||||
version2 := protocol.Version39 + 1
|
||||
version3 := protocol.Version39 + 2
|
||||
|
@ -545,10 +568,10 @@ var _ = Describe("Client", func() {
|
|||
close(dialed)
|
||||
}()
|
||||
Eventually(sessionChan).Should(HaveLen(1))
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2}))
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{version2}))
|
||||
Eventually(sessionChan).Should(BeEmpty())
|
||||
Expect(cl.version).To(Equal(version2))
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version3}))
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{version3}))
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
Expect(cl.version).To(Equal(version2))
|
||||
})
|
||||
|
@ -558,7 +581,7 @@ var _ = Describe("Client", func() {
|
|||
sess.EXPECT().Close(gomock.Any())
|
||||
cl.session = sess
|
||||
cl.config = &Config{Versions: protocol.SupportedVersions}
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1}))
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1}))
|
||||
})
|
||||
|
||||
It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
|
||||
|
@ -568,23 +591,24 @@ var _ = Describe("Client", func() {
|
|||
v := protocol.VersionNumber(1234)
|
||||
Expect(v).ToNot(Equal(cl.version))
|
||||
cl.config = &Config{Versions: protocol.SupportedVersions}
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{v}))
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{v}))
|
||||
})
|
||||
|
||||
It("changes to the version preferred by the quic.Config", func() {
|
||||
mockMultiplexer.EXPECT().AddHandler(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().Close(errCloseSessionForNewVersion)
|
||||
cl.session = sess
|
||||
config := &Config{Versions: []protocol.VersionNumber{1234, 4321}}
|
||||
cl.config = config
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{4321, 1234}))
|
||||
versions := []protocol.VersionNumber{1234, 4321}
|
||||
cl.config = &Config{Versions: versions}
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, versions))
|
||||
Expect(cl.version).To(Equal(protocol.VersionNumber(1234)))
|
||||
})
|
||||
|
||||
It("drops version negotiation packets that contain the offered version", func() {
|
||||
cl.config = &Config{}
|
||||
ver := cl.version
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver}))
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver}))
|
||||
Expect(cl.version).To(Equal(ver))
|
||||
})
|
||||
})
|
||||
|
@ -595,12 +619,6 @@ var _ = Describe("Client", func() {
|
|||
Expect(cl.GetVersion()).To(Equal(cl.version))
|
||||
})
|
||||
|
||||
It("ignores packets with an invalid public header", func() {
|
||||
cl.config = &Config{}
|
||||
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
|
||||
cl.handleRead(addr, []byte("invalid packet"))
|
||||
})
|
||||
|
||||
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
|
||||
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
|
||||
hdr := &wire.Header{
|
||||
|
@ -829,47 +847,6 @@ var _ = Describe("Client", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("handling packets", func() {
|
||||
It("handles packets", func() {
|
||||
cl.config = &Config{}
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
cl.session = sess
|
||||
ph := wire.Header{
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
DestConnectionID: connID,
|
||||
}
|
||||
b := &bytes.Buffer{}
|
||||
err := ph.Write(b, protocol.PerspectiveServer, cl.version)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
packetConn.dataToRead <- b.Bytes()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cl.listen()
|
||||
// it should continue listening when receiving valid packets
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
// make the go routine return
|
||||
sess.EXPECT().Close(gomock.Any())
|
||||
Expect(packetConn.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("closes the session when encountering an error while reading from the connection", func() {
|
||||
testErr := errors.New("test error")
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().Close(testErr)
|
||||
cl.session = sess
|
||||
packetConn.readErr = testErr
|
||||
cl.listen()
|
||||
})
|
||||
})
|
||||
|
||||
Context("Public Reset handling", func() {
|
||||
var (
|
||||
pr []byte
|
||||
|
@ -895,7 +872,11 @@ var _ = Describe("Client", func() {
|
|||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset))
|
||||
})
|
||||
cl.session = sess
|
||||
cl.handleRead(addr, wire.WritePublicReset(cl.destConnID, 1, 0))
|
||||
cl.handlePacketImpl(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
header: hdr,
|
||||
data: pr[len(pr)-hdrLen:],
|
||||
})
|
||||
})
|
||||
|
||||
It("ignores Public Resets from the wrong remote address", func() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue