diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 1634702c..89043d61 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -205,7 +205,7 @@ func newCryptoSetup( qlogger.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) qlogger.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) } - extHandler := newExtensionHandler(tp.Marshal(), perspective) + extHandler := newExtensionHandler(tp.Marshal(perspective), perspective) cs := &cryptoSetup{ initialStream: initialStream, initialSealer: initialSealer, diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index ad5299c1..cbd22fb8 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -88,13 +88,14 @@ var _ = Describe("Crypto Setup TLS", func() { return &tls.Config{ServerName: ch.ServerName}, nil }, } + var token [16]byte server := NewCryptoSetupServer( &bytes.Buffer{}, &bytes.Buffer{}, protocol.ConnectionID{}, nil, nil, - &wire.TransportParameters{}, + &wire.TransportParameters{StatelessResetToken: &token}, NewMockHandshakeRunner(mockCtrl), tlsConf, false, @@ -121,13 +122,14 @@ var _ = Describe("Crypto Setup TLS", func() { runner := NewMockHandshakeRunner(mockCtrl) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) _, sInitialStream, sHandshakeStream := initStreams() + var token [16]byte server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, protocol.ConnectionID{}, nil, nil, - &wire.TransportParameters{}, + &wire.TransportParameters{StatelessResetToken: &token}, runner, testdata.GetTLSConfig(), false, @@ -160,13 +162,14 @@ var _ = Describe("Crypto Setup TLS", func() { _, sInitialStream, sHandshakeStream := initStreams() runner := NewMockHandshakeRunner(mockCtrl) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) + var token [16]byte server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, protocol.ConnectionID{}, nil, nil, - &wire.TransportParameters{}, + &wire.TransportParameters{StatelessResetToken: &token}, runner, testdata.GetTLSConfig(), false, @@ -202,13 +205,14 @@ var _ = Describe("Crypto Setup TLS", func() { _, sInitialStream, sHandshakeStream := initStreams() runner := NewMockHandshakeRunner(mockCtrl) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) + var token [16]byte server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, protocol.ConnectionID{}, nil, nil, - &wire.TransportParameters{}, + &wire.TransportParameters{StatelessResetToken: &token}, runner, serverConf, false, @@ -237,13 +241,14 @@ var _ = Describe("Crypto Setup TLS", func() { It("returns Handshake() when it is closed", func() { _, sInitialStream, sHandshakeStream := initStreams() + var token [16]byte server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, protocol.ConnectionID{}, nil, nil, - &wire.TransportParameters{}, + &wire.TransportParameters{StatelessResetToken: &token}, NewMockHandshakeRunner(mockCtrl), serverConf, false, @@ -517,13 +522,14 @@ var _ = Describe("Crypto Setup TLS", func() { sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnHandshakeComplete() + var token [16]byte server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, protocol.ConnectionID{}, nil, nil, - &wire.TransportParameters{}, + &wire.TransportParameters{StatelessResetToken: &token}, sRunner, serverConf, false, @@ -576,13 +582,14 @@ var _ = Describe("Crypto Setup TLS", func() { sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnHandshakeComplete() + var token [16]byte server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, protocol.ConnectionID{}, nil, nil, - &wire.TransportParameters{}, + &wire.TransportParameters{StatelessResetToken: &token}, sRunner, serverConf, false, @@ -707,13 +714,14 @@ var _ = Describe("Crypto Setup TLS", func() { sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnHandshakeComplete() + var token [16]byte server = NewCryptoSetupServer( sInitialStream, sHandshakeStream, protocol.ConnectionID{}, nil, nil, - &wire.TransportParameters{}, + &wire.TransportParameters{StatelessResetToken: &token}, sRunner, serverConf, true, diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index 89da1905..1103755c 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -24,6 +24,8 @@ var _ = Describe("Transport Parameters", func() { rand.Seed(GinkgoRandomSeed()) }) + var token [16]byte + It("has a string representation", func() { p := &TransportParameters{ InitialMaxStreamDataBidiLocal: 1234, @@ -77,7 +79,7 @@ var _ = Describe("Transport Parameters", func() { MaxAckDelay: 42 * time.Millisecond, ActiveConnectionIDLimit: getRandomValue(), } - data := params.Marshal() + data := params.Marshal(protocol.PerspectiveServer) p := &TransportParameters{} Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) @@ -101,8 +103,7 @@ var _ = Describe("Transport Parameters", func() { utils.WriteVarInt(b, uint64(statelessResetTokenParameterID)) utils.WriteVarInt(b, 15) b.Write(make([]byte, 15)) - p := &TransportParameters{} - Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: wrong length for stateless_reset_token: 15 (expected 16)")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: wrong length for stateless_reset_token: 15 (expected 16)")) }) It("errors when the max_packet_size is too small", func() { @@ -110,8 +111,7 @@ var _ = Describe("Transport Parameters", func() { utils.WriteVarInt(b, uint64(maxUDPPayloadSizeParameterID)) utils.WriteVarInt(b, uint64(utils.VarIntLen(1199))) utils.WriteVarInt(b, 1199) - p := &TransportParameters{} - Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid value for max_packet_size: 1199 (minimum 1200)")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid value for max_packet_size: 1199 (minimum 1200)")) }) It("errors when disable_active_migration has content", func() { @@ -119,12 +119,22 @@ var _ = Describe("Transport Parameters", func() { utils.WriteVarInt(b, uint64(disableActiveMigrationParameterID)) utils.WriteVarInt(b, 6) b.Write([]byte("foobar")) - p := &TransportParameters{} - Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: wrong length for disable_active_migration: 6 (expected empty)")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: wrong length for disable_active_migration: 6 (expected empty)")) + }) + + It("errors when the server doesn't set the original_destination_connection_id", func() { + b := &bytes.Buffer{} + utils.WriteVarInt(b, uint64(statelessResetTokenParameterID)) + utils.WriteVarInt(b, 16) + b.Write(token[:]) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_destination_connection_id")) }) It("errors when the max_ack_delay is too large", func() { - data := (&TransportParameters{MaxAckDelay: 1 << 14 * time.Millisecond}).Marshal() + data := (&TransportParameters{ + MaxAckDelay: 1 << 14 * time.Millisecond, + StatelessResetToken: &token, + }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid value for max_ack_delay: 16384ms (maximum 16383ms)")) }) @@ -135,9 +145,15 @@ var _ = Describe("Transport Parameters", func() { // marshal 1000 times to average out the greasing transport parameter maxAckDelay := protocol.DefaultMaxAckDelay + time.Millisecond for i := 0; i < num; i++ { - dataDefault := (&TransportParameters{MaxAckDelay: protocol.DefaultMaxAckDelay}).Marshal() + dataDefault := (&TransportParameters{ + MaxAckDelay: protocol.DefaultMaxAckDelay, + StatelessResetToken: &token, + }).Marshal(protocol.PerspectiveServer) defaultLen += len(dataDefault) - data := (&TransportParameters{MaxAckDelay: maxAckDelay}).Marshal() + data := (&TransportParameters{ + MaxAckDelay: maxAckDelay, + StatelessResetToken: &token, + }).Marshal(protocol.PerspectiveServer) dataLen += len(data) } entryLen := utils.VarIntLen(uint64(ackDelayExponentParameterID)) /* parameter id */ + utils.VarIntLen(uint64(utils.VarIntLen(uint64(maxAckDelay.Milliseconds())))) /*length */ + utils.VarIntLen(uint64(maxAckDelay.Milliseconds())) /* value */ @@ -145,7 +161,10 @@ var _ = Describe("Transport Parameters", func() { }) It("errors when the ack_delay_exponenent is too large", func() { - data := (&TransportParameters{AckDelayExponent: 21}).Marshal() + data := (&TransportParameters{ + AckDelayExponent: 21, + StatelessResetToken: &token, + }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid value for ack_delay_exponent: 21 (maximum 20)")) }) @@ -155,9 +174,15 @@ var _ = Describe("Transport Parameters", func() { var defaultLen, dataLen int // marshal 1000 times to average out the greasing transport parameter for i := 0; i < num; i++ { - dataDefault := (&TransportParameters{AckDelayExponent: protocol.DefaultAckDelayExponent}).Marshal() + dataDefault := (&TransportParameters{ + AckDelayExponent: protocol.DefaultAckDelayExponent, + StatelessResetToken: &token, + }).Marshal(protocol.PerspectiveServer) defaultLen += len(dataDefault) - data := (&TransportParameters{AckDelayExponent: protocol.DefaultAckDelayExponent + 1}).Marshal() + data := (&TransportParameters{ + AckDelayExponent: protocol.DefaultAckDelayExponent + 1, + StatelessResetToken: &token, + }).Marshal(protocol.PerspectiveServer) dataLen += len(data) } entryLen := utils.VarIntLen(uint64(ackDelayExponentParameterID)) /* parameter id */ + utils.VarIntLen(uint64(utils.VarIntLen(protocol.DefaultAckDelayExponent+1))) /* length */ + utils.VarIntLen(protocol.DefaultAckDelayExponent+1) /* value */ @@ -165,7 +190,10 @@ var _ = Describe("Transport Parameters", func() { }) It("sets the default value for the ack_delay_exponent, when no value was sent", func() { - data := (&TransportParameters{AckDelayExponent: protocol.DefaultAckDelayExponent}).Marshal() + data := (&TransportParameters{ + AckDelayExponent: protocol.DefaultAckDelayExponent, + StatelessResetToken: &token, + }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) Expect(p.AckDelayExponent).To(BeEquivalentTo(protocol.DefaultAckDelayExponent)) @@ -191,7 +219,7 @@ var _ = Describe("Transport Parameters", func() { utils.WriteVarInt(b, uint64(utils.VarIntLen(val))) utils.WriteVarInt(b, val) p := &TransportParameters{} - Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(Succeed()) + Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(Succeed()) Expect(p.MaxAckDelay).To(BeNumerically(">", 290*365*24*time.Hour)) }) @@ -210,7 +238,7 @@ var _ = Describe("Transport Parameters", func() { utils.WriteVarInt(b, uint64(utils.VarIntLen(0x42))) utils.WriteVarInt(b, 0x42) p := &TransportParameters{} - Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(Succeed()) + Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(Succeed()) Expect(p.InitialMaxStreamDataBidiLocal).To(Equal(protocol.ByteCount(0x1337))) Expect(p.InitialMaxStreamDataBidiRemote).To(Equal(protocol.ByteCount(0x42))) }) @@ -230,7 +258,7 @@ var _ = Describe("Transport Parameters", func() { utils.WriteVarInt(b, uint64(utils.VarIntLen(0x1337))) utils.WriteVarInt(b, 0x1337) p := &TransportParameters{} - err := p.Unmarshal(b.Bytes(), protocol.PerspectiveServer) + err := p.Unmarshal(b.Bytes(), protocol.PerspectiveClient) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("received duplicate transport parameter")) }) @@ -245,18 +273,19 @@ var _ = Describe("Transport Parameters", func() { }) It("errors if the client sent a stateless_reset_token", func() { - var token [16]byte - params := &TransportParameters{StatelessResetToken: &token} - data := params.Marshal() - Expect((&TransportParameters{}).Unmarshal(data, protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent a stateless_reset_token")) + b := &bytes.Buffer{} + utils.WriteVarInt(b, uint64(statelessResetTokenParameterID)) + utils.WriteVarInt(b, uint64(utils.VarIntLen(16))) + b.Write(token[:]) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent a stateless_reset_token")) }) - It("errors if the client sent a stateless_reset_token", func() { - params := &TransportParameters{ - OriginalDestinationConnectionID: protocol.ConnectionID{0xca, 0xfe}, - } - data := params.Marshal() - Expect((&TransportParameters{}).Unmarshal(data, protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent an original_destination_connection_id")) + It("errors if the client sent the original_destination_connection_id", func() { + b := &bytes.Buffer{} + utils.WriteVarInt(b, uint64(originalDestinationConnectionIDParameterID)) + utils.WriteVarInt(b, 6) + b.Write([]byte("foobar")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent an original_destination_connection_id")) }) Context("preferred address", func() { @@ -274,7 +303,10 @@ var _ = Describe("Transport Parameters", func() { }) It("marshals and unmarshals", func() { - data := (&TransportParameters{PreferredAddress: pa}).Marshal() + data := (&TransportParameters{ + PreferredAddress: pa, + StatelessResetToken: &token, + }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) Expect(p.PreferredAddress.IPv4.String()).To(Equal(pa.IPv4.String())) @@ -286,14 +318,20 @@ var _ = Describe("Transport Parameters", func() { }) It("errors if the client sent a preferred_address", func() { - data := (&TransportParameters{PreferredAddress: pa}).Marshal() + b := &bytes.Buffer{} + utils.WriteVarInt(b, uint64(preferredAddressParameterID)) + utils.WriteVarInt(b, 6) + b.Write([]byte("foobar")) p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent a preferred_address")) + Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent a preferred_address")) }) It("errors on zero-length connection IDs", func() { pa.ConnectionID = protocol.ConnectionID{} - data := (&TransportParameters{PreferredAddress: pa}).Marshal() + data := (&TransportParameters{ + PreferredAddress: pa, + StatelessResetToken: &token, + }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid connection ID length: 0")) }) @@ -301,7 +339,10 @@ var _ = Describe("Transport Parameters", func() { It("errors on too long connection IDs", func() { pa.ConnectionID = protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21} Expect(pa.ConnectionID.Len()).To(BeNumerically(">", protocol.MaxConnIDLen)) - data := (&TransportParameters{PreferredAddress: pa}).Marshal() + data := (&TransportParameters{ + PreferredAddress: pa, + StatelessResetToken: &token, + }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid connection ID length: 21")) }) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index c5268447..2fe056d8 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -80,18 +80,21 @@ type TransportParameters struct { // Unmarshal the transport parameters func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error { - if err := p.unmarshal(data, sentBy); err != nil { + if err := p.unmarshal(data, sentBy, false); err != nil { return qerr.NewError(qerr.TransportParameterError, err.Error()) } return nil } -func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective) error { +func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective, fromSessionTicket bool) error { // needed to check that every parameter is only sent at most once var parameterIDs []transportParameterID - var readAckDelayExponent bool - var readMaxAckDelay bool + var ( + readAckDelayExponent bool + readMaxAckDelay bool + readOriginalDestinationConnectionID bool + ) r := bytes.NewReader(data) for r.Len() > 0 { @@ -160,12 +163,16 @@ func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective return errors.New("client sent an original_destination_connection_id") } p.OriginalDestinationConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) + readOriginalDestinationConnectionID = true default: r.Seek(int64(paramLen), io.SeekCurrent) } } } + if sentBy == protocol.PerspectiveServer && !fromSessionTicket && !readOriginalDestinationConnectionID { + return errors.New("expected original_destination_connection_id") + } if !readAckDelayExponent { p.AckDelayExponent = protocol.DefaultAckDelayExponent } @@ -288,7 +295,7 @@ func (p *TransportParameters) readNumericTransportParameter( } // Marshal the transport parameters -func (p *TransportParameters) Marshal() []byte { +func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { b := &bytes.Buffer{} //add a greased value @@ -330,27 +337,28 @@ func (p *TransportParameters) Marshal() []byte { utils.WriteVarInt(b, uint64(disableActiveMigrationParameterID)) utils.WriteVarInt(b, 0) } - if p.StatelessResetToken != nil { + if pers == protocol.PerspectiveServer { + // stateless_reset_token utils.WriteVarInt(b, uint64(statelessResetTokenParameterID)) utils.WriteVarInt(b, 16) b.Write(p.StatelessResetToken[:]) - } - if p.PreferredAddress != nil { - utils.WriteVarInt(b, uint64(preferredAddressParameterID)) - utils.WriteVarInt(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) - ipv4 := p.PreferredAddress.IPv4 - b.Write(ipv4[len(ipv4)-4:]) - utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv4Port) - b.Write(p.PreferredAddress.IPv6) - utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv6Port) - b.WriteByte(uint8(p.PreferredAddress.ConnectionID.Len())) - b.Write(p.PreferredAddress.ConnectionID.Bytes()) - b.Write(p.PreferredAddress.StatelessResetToken[:]) - } - if p.OriginalDestinationConnectionID.Len() > 0 { + // original_destination_connection_id utils.WriteVarInt(b, uint64(originalDestinationConnectionIDParameterID)) utils.WriteVarInt(b, uint64(p.OriginalDestinationConnectionID.Len())) b.Write(p.OriginalDestinationConnectionID.Bytes()) + // preferred_address + if p.PreferredAddress != nil { + utils.WriteVarInt(b, uint64(preferredAddressParameterID)) + utils.WriteVarInt(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) + ipv4 := p.PreferredAddress.IPv4 + b.Write(ipv4[len(ipv4)-4:]) + utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv4Port) + b.Write(p.PreferredAddress.IPv6) + utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv6Port) + b.WriteByte(uint8(p.PreferredAddress.ConnectionID.Len())) + b.Write(p.PreferredAddress.ConnectionID.Bytes()) + b.Write(p.PreferredAddress.StatelessResetToken[:]) + } } // active_connection_id_limit @@ -401,7 +409,7 @@ func (p *TransportParameters) UnmarshalFromSessionTicket(data []byte) error { if version != transportParameterMarshalingVersion { return fmt.Errorf("unknown transport parameter marshaling version: %d", version) } - return p.Unmarshal(data[len(data)-r.Len():], protocol.PerspectiveServer) + return p.unmarshal(data[len(data)-r.Len():], protocol.PerspectiveServer, true) } // ValidFor0RTT checks if the transport parameters match those saved in the session ticket. diff --git a/server.go b/server.go index a88737f0..5011150b 100644 --- a/server.go +++ b/server.go @@ -348,7 +348,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro } var token *Token - var origDestConnectionID protocol.ConnectionID + origDestConnectionID := hdr.DestConnectionID if len(hdr.Token) > 0 { c, err := s.tokenGenerator.DecodeToken(hdr.Token) if err == nil { @@ -357,7 +357,9 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro RemoteAddr: c.RemoteAddr, SentTime: c.SentTime, } - origDestConnectionID = c.OriginalDestConnectionID + if token.IsRetryToken { + origDestConnectionID = c.OriginalDestConnectionID + } } } if !s.config.AcceptToken(p.remoteAddr, token) { diff --git a/server_test.go b/server_test.go index 5ab73b22..003ef1ba 100644 --- a/server_test.go +++ b/server_test.go @@ -289,6 +289,79 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) + It("creates a session when the token is accepted", func() { + serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true } + retryToken, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}) + Expect(err).ToNot(HaveOccurred()) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Version: protocol.VersionTLS, + Token: retryToken, + } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + run := make(chan struct{}) + var token [16]byte + rand.Read(token[:]) + var newConnID protocol.ConnectionID + phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte { + newConnID = c + return token + }) + sess := NewMockQuicSession(mockCtrl) + serv.newSession = func( + _ connection, + _ sessionRunner, + origDestConnID protocol.ConnectionID, + clientDestConnID protocol.ConnectionID, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + tokenP [16]byte, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + enable0RTT bool, + _ qlog.Tracer, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicSession { + Expect(enable0RTT).To(BeFalse()) + Expect(origDestConnID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) + Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) + Expect(destConnID).To(Equal(hdr.SrcConnectionID)) + // make sure we're using a server-generated connection ID + Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) + Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) + Expect(srcConnID).To(Equal(newConnID)) + Expect(tokenP).To(Equal(token)) + sess.EXPECT().handlePacket(p) + sess.EXPECT().run().Do(func() { close(run) }) + sess.EXPECT().Context().Return(context.Background()) + sess.EXPECT().HandshakeComplete().Return(context.Background()) + return sess + } + + phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true) + phm.EXPECT().Add(gomock.Any(), sess).DoAndReturn(func(c protocol.ConnectionID, _ packetHandler) bool { + Expect(c).To(Equal(newConnID)) + return true + }) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + serv.handlePacket(p) + // the Handshake packet is written by the session + Consistently(conn.dataWritten).ShouldNot(Receive()) + close(done) + }() + // make sure we're using a server-generated connection ID + Eventually(run).Should(BeClosed()) + Eventually(done).Should(BeClosed()) + }) + It("sends a Version Negotiation Packet for unsupported versions", func() { srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} @@ -408,8 +481,8 @@ var _ = Describe("Server", func() { serv.newSession = func( _ connection, _ sessionRunner, - _ protocol.ConnectionID, - origConnID protocol.ConnectionID, + origDestConnID protocol.ConnectionID, + clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, tokenP [16]byte, @@ -422,7 +495,8 @@ var _ = Describe("Server", func() { _ protocol.VersionNumber, ) quicSession { Expect(enable0RTT).To(BeFalse()) - Expect(origConnID).To(Equal(hdr.DestConnectionID)) + Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) + Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) Expect(destConnID).To(Equal(hdr.SrcConnectionID)) // make sure we're using a server-generated connection ID Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) diff --git a/session.go b/session.go index f5df5059..88669313 100644 --- a/session.go +++ b/session.go @@ -111,7 +111,7 @@ type session struct { // Destination connection ID used during the handshake. // Used to check source connection ID on incoming packets. handshakeDestConnID protocol.ConnectionID - // if the server sends a Retry, this is the connection ID we used initially + // Set for the client. Destination connection ID used on the first Initial sent. origDestConnID protocol.ConnectionID srcConnIDLen int @@ -338,6 +338,7 @@ var newClientSession = func( s := &session{ conn: conn, config: conf, + origDestConnID: destConnID, handshakeDestConnID: destConnID, srcConnIDLen: srcConnID.Len(), perspective: protocol.PerspectiveClient, @@ -876,7 +877,6 @@ func (s *session) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was t if s.qlogger != nil { s.qlogger.ReceivedRetry(hdr) } - s.origDestConnID = s.handshakeDestConnID newDestConnID := hdr.SrcConnectionID s.receivedRetry = true if err := s.sentPacketHandler.ResetForRetry(); err != nil { diff --git a/session_test.go b/session_test.go index b459c066..c940d990 100644 --- a/session_test.go +++ b/session_test.go @@ -2227,6 +2227,7 @@ var _ = Describe("Client Session", func() { It("uses the preferred_address connection ID", func() { params := &wire.TransportParameters{ + OriginalDestinationConnectionID: destConnID, PreferredAddress: &wire.PreferredAddress{ IPv4: net.IPv4(127, 0, 0, 1), IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, @@ -2251,7 +2252,8 @@ var _ = Describe("Client Session", func() { It("uses the minimum of the peers' idle timeouts", func() { sess.config.MaxIdleTimeout = 19 * time.Second params := &wire.TransportParameters{ - MaxIdleTimeout: 18 * time.Second, + OriginalDestinationConnectionID: destConnID, + MaxIdleTimeout: 18 * time.Second, } packer.EXPECT().HandleTransportParameters(gomock.Any()) qlogger.EXPECT().ReceivedTransportParameters(params) @@ -2259,17 +2261,6 @@ var _ = Describe("Client Session", func() { Expect(sess.idleTimeout).To(Equal(18 * time.Second)) }) - It("errors if the TransportParameters contain an original_destination_connection_id, although no Retry was performed", func() { - params := &wire.TransportParameters{ - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - } - expectClose() - qlogger.EXPECT().ReceivedTransportParameters(params) - sess.processTransportParameters(params) - Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_destination_connection_id to equal (empty), is 0xdecafbad"))) - }) - It("errors if the TransportParameters contain a wrong original_destination_connection_id", func() { sess.origDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} params := &wire.TransportParameters{