diff --git a/session.go b/session.go index 88669313..4e6b5ddd 100644 --- a/session.go +++ b/session.go @@ -274,6 +274,7 @@ var newSession = func( StatelessResetToken: &statelessResetToken, OriginalDestinationConnectionID: origDestConnID, ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, + InitialSourceConnectionID: srcConnID, } if s.qlogger != nil { s.qlogger.SentTransportParameters(params) @@ -390,6 +391,7 @@ var newClientSession = func( AckDelayExponent: protocol.AckDelayExponent, DisableActiveMigration: true, ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, + InitialSourceConnectionID: srcConnID, } if s.qlogger != nil { s.qlogger.SentTransportParameters(params) @@ -1289,7 +1291,13 @@ func (s *session) processTransportParameters(params *wire.TransportParameters) { s.qlogger.ReceivedTransportParameters(params) } - // check the Retry token + // check the initial_source_connection_id + if !params.InitialSourceConnectionID.Equal(s.handshakeDestConnID) { + s.closeLocal(qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID))) + return + } + + // check the original_destination_connection_id if s.perspective == protocol.PerspectiveClient && !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) { s.closeLocal(qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID))) return diff --git a/session_test.go b/session_test.go index c940d990..ca94d37c 100644 --- a/session_test.go +++ b/session_test.go @@ -1679,7 +1679,8 @@ var _ = Describe("Session", func() { InitialMaxData: 0x5000, ActiveConnectionIDLimit: 3, // marshaling always sets it to this value - MaxUDPPayloadSize: protocol.MaxReceivePacketSize, + MaxUDPPayloadSize: protocol.MaxReceivePacketSize, + InitialSourceConnectionID: destConnID, } streamManager.EXPECT().UpdateLimits(params) packer.EXPECT().HandleTransportParameters(params) @@ -1698,7 +1699,10 @@ var _ = Describe("Session", func() { streamManager.EXPECT().UpdateLimits(gomock.Any()) packer.EXPECT().HandleTransportParameters(gomock.Any()) qlogger.EXPECT().ReceivedTransportParameters(gomock.Any()) - sess.processTransportParameters(&wire.TransportParameters{MaxIdleTimeout: t}) + sess.processTransportParameters(&wire.TransportParameters{ + MaxIdleTimeout: t, + InitialSourceConnectionID: destConnID, + }) } runSession := func() { @@ -2228,6 +2232,7 @@ var _ = Describe("Client Session", func() { It("uses the preferred_address connection ID", func() { params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, + InitialSourceConnectionID: destConnID, PreferredAddress: &wire.PreferredAddress{ IPv4: net.IPv4(127, 0, 0, 1), IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, @@ -2253,6 +2258,7 @@ var _ = Describe("Client Session", func() { sess.config.MaxIdleTimeout = 19 * time.Second params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, + InitialSourceConnectionID: destConnID, MaxIdleTimeout: 18 * time.Second, } packer.EXPECT().HandleTransportParameters(gomock.Any()) @@ -2261,10 +2267,24 @@ var _ = Describe("Client Session", func() { Expect(sess.idleTimeout).To(Equal(18 * time.Second)) }) + It("errors if the TransportParameters contain a wrong initial_source_connection_id", func() { + sess.handshakeDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + params := &wire.TransportParameters{ + OriginalDestinationConnectionID: destConnID, + InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + } + expectClose() + qlogger.EXPECT().ReceivedTransportParameters(params) + sess.processTransportParameters(params) + Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected initial_source_connection_id to equal 0xdeadbeef, is 0xdecafbad"))) + }) + It("errors if the TransportParameters contain a wrong original_destination_connection_id", func() { sess.origDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} params := &wire.TransportParameters{ OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + InitialSourceConnectionID: sess.handshakeDestConnID, StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } expectClose()