mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 21:27:35 +03:00
reject packets with the wrong connection ID in the client
This commit is contained in:
parent
5a94b2034c
commit
84f3ec5343
2 changed files with 47 additions and 19 deletions
|
@ -38,6 +38,8 @@ type client struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
// make it possible to mock connection ID generation in the tests
|
||||||
|
generateConnectionID = utils.GenerateConnectionID
|
||||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -82,7 +84,7 @@ func DialNonFWSecure(
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
config *Config,
|
config *Config,
|
||||||
) (NonFWSession, error) {
|
) (NonFWSession, error) {
|
||||||
connID, err := utils.GenerateConnectionID()
|
connID, err := generateConnectionID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -257,6 +259,10 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
||||||
if hdr.TruncateConnectionID && !c.config.RequestConnectionIDTruncation {
|
if hdr.TruncateConnectionID && !c.config.RequestConnectionIDTruncation {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// reject packets with the wrong connection ID
|
||||||
|
if !hdr.TruncateConnectionID && hdr.ConnectionID != c.connectionID {
|
||||||
|
return
|
||||||
|
}
|
||||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||||
|
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
|
|
|
@ -27,6 +27,18 @@ var _ = Describe("Client", func() {
|
||||||
originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error)
|
originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// generate a packet sent by the server that accepts the QUIC version suggested by the client
|
||||||
|
acceptClientVersionPacket := func(connID protocol.ConnectionID) []byte {
|
||||||
|
b := &bytes.Buffer{}
|
||||||
|
err := (&wire.PublicHeader{
|
||||||
|
ConnectionID: connID,
|
||||||
|
PacketNumber: 1,
|
||||||
|
PacketNumberLen: 1,
|
||||||
|
}).Write(b, protocol.VersionWhatever, protocol.PerspectiveServer)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
return b.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
originalClientSessConstructor = newClientSession
|
originalClientSessConstructor = newClientSession
|
||||||
Eventually(areSessionsRunning).Should(BeFalse())
|
Eventually(areSessionsRunning).Should(BeFalse())
|
||||||
|
@ -62,7 +74,7 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("Dialing", func() {
|
Context("Dialing", func() {
|
||||||
var acceptClientVersionPacket []byte
|
var origGenerateConnectionID func() (protocol.ConnectionID, error)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
newClientSession = func(
|
newClientSession = func(
|
||||||
|
@ -75,22 +87,20 @@ var _ = Describe("Client", func() {
|
||||||
_ []protocol.VersionNumber,
|
_ []protocol.VersionNumber,
|
||||||
) (packetHandler, <-chan handshakeEvent, error) {
|
) (packetHandler, <-chan handshakeEvent, error) {
|
||||||
Expect(conn.Write([]byte("fake CHLO"))).To(Succeed())
|
Expect(conn.Write([]byte("fake CHLO"))).To(Succeed())
|
||||||
// Expect(err).ToNot(HaveOccurred())
|
|
||||||
return sess, sess.handshakeChan, nil
|
return sess, sess.handshakeChan, nil
|
||||||
}
|
}
|
||||||
// accept the QUIC version suggested by the client
|
origGenerateConnectionID = generateConnectionID
|
||||||
b := &bytes.Buffer{}
|
generateConnectionID = func() (protocol.ConnectionID, error) {
|
||||||
err := (&wire.PublicHeader{
|
return cl.connectionID, nil
|
||||||
ConnectionID: 0x1337,
|
}
|
||||||
PacketNumber: 1,
|
})
|
||||||
PacketNumberLen: 1,
|
|
||||||
}).Write(b, protocol.VersionWhatever, protocol.PerspectiveServer)
|
AfterEach(func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
generateConnectionID = origGenerateConnectionID
|
||||||
acceptClientVersionPacket = b.Bytes()
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("dials non-forward-secure", func(done Done) {
|
It("dials non-forward-secure", func(done Done) {
|
||||||
packetConn.dataToRead = acceptClientVersionPacket
|
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
||||||
dialed := make(chan struct{})
|
dialed := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
@ -118,7 +128,7 @@ var _ = Describe("Client", func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = server.WriteToUDP(acceptClientVersionPacket, clientAddr)
|
_, err = server.WriteToUDP(acceptClientVersionPacket(cl.connectionID), clientAddr)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -138,7 +148,7 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("Dial only returns after the handshake is complete", func(done Done) {
|
It("Dial only returns after the handshake is complete", func(done Done) {
|
||||||
packetConn.dataToRead = acceptClientVersionPacket
|
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
||||||
dialed := make(chan struct{})
|
dialed := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
@ -219,7 +229,7 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
It("returns an error that occurs while waiting for the connection to become secure", func(done Done) {
|
It("returns an error that occurs while waiting for the connection to become secure", func(done Done) {
|
||||||
testErr := errors.New("early handshake error")
|
testErr := errors.New("early handshake error")
|
||||||
packetConn.dataToRead = acceptClientVersionPacket
|
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||||
|
@ -231,7 +241,7 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
It("returns an error that occurs while waiting for the handshake to complete", func(done Done) {
|
It("returns an error that occurs while waiting for the handshake to complete", func(done Done) {
|
||||||
testErr := errors.New("late handshake error")
|
testErr := errors.New("late handshake error")
|
||||||
packetConn.dataToRead = acceptClientVersionPacket
|
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
||||||
go func() {
|
go func() {
|
||||||
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||||
Expect(dialErr).To(MatchError(testErr))
|
Expect(dialErr).To(MatchError(testErr))
|
||||||
|
@ -307,7 +317,7 @@ var _ = Describe("Client", func() {
|
||||||
Expect(newVersion).ToNot(Equal(cl.version))
|
Expect(newVersion).ToNot(Equal(cl.version))
|
||||||
Expect(config.Versions).To(ContainElement(newVersion))
|
Expect(config.Versions).To(ContainElement(newVersion))
|
||||||
packetConn.dataToRead = wire.ComposeVersionNegotiation(
|
packetConn.dataToRead = wire.ComposeVersionNegotiation(
|
||||||
0x1337,
|
cl.connectionID,
|
||||||
[]protocol.VersionNumber{newVersion},
|
[]protocol.VersionNumber{newVersion},
|
||||||
)
|
)
|
||||||
sessionChan := make(chan *mockSession)
|
sessionChan := make(chan *mockSession)
|
||||||
|
@ -324,7 +334,7 @@ var _ = Describe("Client", func() {
|
||||||
negotiatedVersions = negotiatedVersionsP
|
negotiatedVersions = negotiatedVersionsP
|
||||||
// make the server accept the new version
|
// make the server accept the new version
|
||||||
if len(negotiatedVersionsP) > 0 {
|
if len(negotiatedVersionsP) > 0 {
|
||||||
packetConn.dataToRead = acceptClientVersionPacket
|
packetConn.dataToRead = acceptClientVersionPacket(connectionID)
|
||||||
}
|
}
|
||||||
sess := &mockSession{
|
sess := &mockSession{
|
||||||
connectionID: connectionID,
|
connectionID: connectionID,
|
||||||
|
@ -440,6 +450,18 @@ var _ = Describe("Client", func() {
|
||||||
Expect(sess.closed).To(BeFalse())
|
Expect(sess.closed).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("ignores packets with the wrong connection ID", func() {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
(&wire.PublicHeader{
|
||||||
|
ConnectionID: cl.connectionID + 1,
|
||||||
|
PacketNumber: 1,
|
||||||
|
PacketNumberLen: 1,
|
||||||
|
}).Write(buf, protocol.VersionWhatever, protocol.PerspectiveServer)
|
||||||
|
cl.handlePacket(addr, buf.Bytes())
|
||||||
|
Expect(sess.packetCount).To(BeZero())
|
||||||
|
Expect(sess.closed).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
It("creates new sessions with the right parameters", func(done Done) {
|
It("creates new sessions with the right parameters", func(done Done) {
|
||||||
c := make(chan struct{})
|
c := make(chan struct{})
|
||||||
var cconn connection
|
var cconn connection
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue