mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
reject packets with the wrong connection ID in the client
This commit is contained in:
parent
5a94b2034c
commit
84f3ec5343
2 changed files with 47 additions and 19 deletions
|
@ -38,6 +38,8 @@ type client struct {
|
|||
}
|
||||
|
||||
var (
|
||||
// make it possible to mock connection ID generation in the tests
|
||||
generateConnectionID = utils.GenerateConnectionID
|
||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||
)
|
||||
|
||||
|
@ -82,7 +84,7 @@ func DialNonFWSecure(
|
|||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (NonFWSession, error) {
|
||||
connID, err := utils.GenerateConnectionID()
|
||||
connID, err := generateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -257,6 +259,10 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||
if hdr.TruncateConnectionID && !c.config.RequestConnectionIDTruncation {
|
||||
return
|
||||
}
|
||||
// reject packets with the wrong connection ID
|
||||
if !hdr.TruncateConnectionID && hdr.ConnectionID != c.connectionID {
|
||||
return
|
||||
}
|
||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||
|
||||
c.mutex.Lock()
|
||||
|
|
|
@ -27,6 +27,18 @@ var _ = Describe("Client", func() {
|
|||
originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error)
|
||||
)
|
||||
|
||||
// generate a packet sent by the server that accepts the QUIC version suggested by the client
|
||||
acceptClientVersionPacket := func(connID protocol.ConnectionID) []byte {
|
||||
b := &bytes.Buffer{}
|
||||
err := (&wire.PublicHeader{
|
||||
ConnectionID: connID,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: 1,
|
||||
}).Write(b, protocol.VersionWhatever, protocol.PerspectiveServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
originalClientSessConstructor = newClientSession
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
|
@ -62,7 +74,7 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
Context("Dialing", func() {
|
||||
var acceptClientVersionPacket []byte
|
||||
var origGenerateConnectionID func() (protocol.ConnectionID, error)
|
||||
|
||||
BeforeEach(func() {
|
||||
newClientSession = func(
|
||||
|
@ -75,22 +87,20 @@ var _ = Describe("Client", func() {
|
|||
_ []protocol.VersionNumber,
|
||||
) (packetHandler, <-chan handshakeEvent, error) {
|
||||
Expect(conn.Write([]byte("fake CHLO"))).To(Succeed())
|
||||
// Expect(err).ToNot(HaveOccurred())
|
||||
return sess, sess.handshakeChan, nil
|
||||
}
|
||||
// accept the QUIC version suggested by the client
|
||||
b := &bytes.Buffer{}
|
||||
err := (&wire.PublicHeader{
|
||||
ConnectionID: 0x1337,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: 1,
|
||||
}).Write(b, protocol.VersionWhatever, protocol.PerspectiveServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
acceptClientVersionPacket = b.Bytes()
|
||||
origGenerateConnectionID = generateConnectionID
|
||||
generateConnectionID = func() (protocol.ConnectionID, error) {
|
||||
return cl.connectionID, nil
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
generateConnectionID = origGenerateConnectionID
|
||||
})
|
||||
|
||||
It("dials non-forward-secure", func(done Done) {
|
||||
packetConn.dataToRead = acceptClientVersionPacket
|
||||
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -118,7 +128,7 @@ var _ = Describe("Client", func() {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = server.WriteToUDP(acceptClientVersionPacket, clientAddr)
|
||||
_, err = server.WriteToUDP(acceptClientVersionPacket(cl.connectionID), clientAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
}()
|
||||
|
@ -138,7 +148,7 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
It("Dial only returns after the handshake is complete", func(done Done) {
|
||||
packetConn.dataToRead = acceptClientVersionPacket
|
||||
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -219,7 +229,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("returns an error that occurs while waiting for the connection to become secure", func(done Done) {
|
||||
testErr := errors.New("early handshake error")
|
||||
packetConn.dataToRead = acceptClientVersionPacket
|
||||
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
|
@ -231,7 +241,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("returns an error that occurs while waiting for the handshake to complete", func(done Done) {
|
||||
testErr := errors.New("late handshake error")
|
||||
packetConn.dataToRead = acceptClientVersionPacket
|
||||
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
||||
go func() {
|
||||
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(dialErr).To(MatchError(testErr))
|
||||
|
@ -307,7 +317,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(newVersion).ToNot(Equal(cl.version))
|
||||
Expect(config.Versions).To(ContainElement(newVersion))
|
||||
packetConn.dataToRead = wire.ComposeVersionNegotiation(
|
||||
0x1337,
|
||||
cl.connectionID,
|
||||
[]protocol.VersionNumber{newVersion},
|
||||
)
|
||||
sessionChan := make(chan *mockSession)
|
||||
|
@ -324,7 +334,7 @@ var _ = Describe("Client", func() {
|
|||
negotiatedVersions = negotiatedVersionsP
|
||||
// make the server accept the new version
|
||||
if len(negotiatedVersionsP) > 0 {
|
||||
packetConn.dataToRead = acceptClientVersionPacket
|
||||
packetConn.dataToRead = acceptClientVersionPacket(connectionID)
|
||||
}
|
||||
sess := &mockSession{
|
||||
connectionID: connectionID,
|
||||
|
@ -440,6 +450,18 @@ var _ = Describe("Client", func() {
|
|||
Expect(sess.closed).To(BeFalse())
|
||||
})
|
||||
|
||||
It("ignores packets with the wrong connection ID", func() {
|
||||
buf := &bytes.Buffer{}
|
||||
(&wire.PublicHeader{
|
||||
ConnectionID: cl.connectionID + 1,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: 1,
|
||||
}).Write(buf, protocol.VersionWhatever, protocol.PerspectiveServer)
|
||||
cl.handlePacket(addr, buf.Bytes())
|
||||
Expect(sess.packetCount).To(BeZero())
|
||||
Expect(sess.closed).To(BeFalse())
|
||||
})
|
||||
|
||||
It("creates new sessions with the right parameters", func(done Done) {
|
||||
c := make(chan struct{})
|
||||
var cconn connection
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue