reject packets with the wrong connection ID in the client

This commit is contained in:
Marten Seemann 2017-09-20 09:27:37 +07:00
parent 5a94b2034c
commit 84f3ec5343
2 changed files with 47 additions and 19 deletions

View file

@ -38,6 +38,8 @@ type client struct {
}
var (
// make it possible to mock connection ID generation in the tests
generateConnectionID = utils.GenerateConnectionID
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
)
@ -82,7 +84,7 @@ func DialNonFWSecure(
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
connID, err := utils.GenerateConnectionID()
connID, err := generateConnectionID()
if err != nil {
return nil, err
}
@ -257,6 +259,10 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
if hdr.TruncateConnectionID && !c.config.RequestConnectionIDTruncation {
return
}
// reject packets with the wrong connection ID
if !hdr.TruncateConnectionID && hdr.ConnectionID != c.connectionID {
return
}
hdr.Raw = packet[:len(packet)-r.Len()]
c.mutex.Lock()

View file

@ -27,6 +27,18 @@ var _ = Describe("Client", func() {
originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error)
)
// generate a packet sent by the server that accepts the QUIC version suggested by the client
acceptClientVersionPacket := func(connID protocol.ConnectionID) []byte {
b := &bytes.Buffer{}
err := (&wire.PublicHeader{
ConnectionID: connID,
PacketNumber: 1,
PacketNumberLen: 1,
}).Write(b, protocol.VersionWhatever, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
return b.Bytes()
}
BeforeEach(func() {
originalClientSessConstructor = newClientSession
Eventually(areSessionsRunning).Should(BeFalse())
@ -62,7 +74,7 @@ var _ = Describe("Client", func() {
})
Context("Dialing", func() {
var acceptClientVersionPacket []byte
var origGenerateConnectionID func() (protocol.ConnectionID, error)
BeforeEach(func() {
newClientSession = func(
@ -75,22 +87,20 @@ var _ = Describe("Client", func() {
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
Expect(conn.Write([]byte("fake CHLO"))).To(Succeed())
// Expect(err).ToNot(HaveOccurred())
return sess, sess.handshakeChan, nil
}
// accept the QUIC version suggested by the client
b := &bytes.Buffer{}
err := (&wire.PublicHeader{
ConnectionID: 0x1337,
PacketNumber: 1,
PacketNumberLen: 1,
}).Write(b, protocol.VersionWhatever, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
acceptClientVersionPacket = b.Bytes()
origGenerateConnectionID = generateConnectionID
generateConnectionID = func() (protocol.ConnectionID, error) {
return cl.connectionID, nil
}
})
AfterEach(func() {
generateConnectionID = origGenerateConnectionID
})
It("dials non-forward-secure", func(done Done) {
packetConn.dataToRead = acceptClientVersionPacket
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -118,7 +128,7 @@ var _ = Describe("Client", func() {
if err != nil {
return
}
_, err = server.WriteToUDP(acceptClientVersionPacket, clientAddr)
_, err = server.WriteToUDP(acceptClientVersionPacket(cl.connectionID), clientAddr)
Expect(err).ToNot(HaveOccurred())
}
}()
@ -138,7 +148,7 @@ var _ = Describe("Client", func() {
})
It("Dial only returns after the handshake is complete", func(done Done) {
packetConn.dataToRead = acceptClientVersionPacket
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -219,7 +229,7 @@ var _ = Describe("Client", func() {
It("returns an error that occurs while waiting for the connection to become secure", func(done Done) {
testErr := errors.New("early handshake error")
packetConn.dataToRead = acceptClientVersionPacket
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
go func() {
defer GinkgoRecover()
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
@ -231,7 +241,7 @@ var _ = Describe("Client", func() {
It("returns an error that occurs while waiting for the handshake to complete", func(done Done) {
testErr := errors.New("late handshake error")
packetConn.dataToRead = acceptClientVersionPacket
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
go func() {
_, dialErr := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(dialErr).To(MatchError(testErr))
@ -307,7 +317,7 @@ var _ = Describe("Client", func() {
Expect(newVersion).ToNot(Equal(cl.version))
Expect(config.Versions).To(ContainElement(newVersion))
packetConn.dataToRead = wire.ComposeVersionNegotiation(
0x1337,
cl.connectionID,
[]protocol.VersionNumber{newVersion},
)
sessionChan := make(chan *mockSession)
@ -324,7 +334,7 @@ var _ = Describe("Client", func() {
negotiatedVersions = negotiatedVersionsP
// make the server accept the new version
if len(negotiatedVersionsP) > 0 {
packetConn.dataToRead = acceptClientVersionPacket
packetConn.dataToRead = acceptClientVersionPacket(connectionID)
}
sess := &mockSession{
connectionID: connectionID,
@ -440,6 +450,18 @@ var _ = Describe("Client", func() {
Expect(sess.closed).To(BeFalse())
})
It("ignores packets with the wrong connection ID", func() {
buf := &bytes.Buffer{}
(&wire.PublicHeader{
ConnectionID: cl.connectionID + 1,
PacketNumber: 1,
PacketNumberLen: 1,
}).Write(buf, protocol.VersionWhatever, protocol.PerspectiveServer)
cl.handlePacket(addr, buf.Bytes())
Expect(sess.packetCount).To(BeZero())
Expect(sess.closed).To(BeFalse())
})
It("creates new sessions with the right parameters", func(done Done) {
c := make(chan struct{})
var cconn connection