always send the original_destination_connection_id TP as a server

This commit is contained in:
Marten Seemann 2020-05-24 12:43:14 +07:00
parent a7005ac936
commit b391cce35c
8 changed files with 205 additions and 81 deletions

View file

@ -205,7 +205,7 @@ func newCryptoSetup(
qlogger.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
qlogger.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
}
extHandler := newExtensionHandler(tp.Marshal(), perspective)
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective)
cs := &cryptoSetup{
initialStream: initialStream,
initialSealer: initialSealer,

View file

@ -88,13 +88,14 @@ var _ = Describe("Crypto Setup TLS", func() {
return &tls.Config{ServerName: ch.ServerName}, nil
},
}
var token [16]byte
server := NewCryptoSetupServer(
&bytes.Buffer{},
&bytes.Buffer{},
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
&wire.TransportParameters{StatelessResetToken: &token},
NewMockHandshakeRunner(mockCtrl),
tlsConf,
false,
@ -121,13 +122,14 @@ var _ = Describe("Crypto Setup TLS", func() {
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
_, sInitialStream, sHandshakeStream := initStreams()
var token [16]byte
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
false,
@ -160,13 +162,14 @@ var _ = Describe("Crypto Setup TLS", func() {
_, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
var token [16]byte
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
false,
@ -202,13 +205,14 @@ var _ = Describe("Crypto Setup TLS", func() {
_, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
var token [16]byte
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
&wire.TransportParameters{StatelessResetToken: &token},
runner,
serverConf,
false,
@ -237,13 +241,14 @@ var _ = Describe("Crypto Setup TLS", func() {
It("returns Handshake() when it is closed", func() {
_, sInitialStream, sHandshakeStream := initStreams()
var token [16]byte
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
&wire.TransportParameters{StatelessResetToken: &token},
NewMockHandshakeRunner(mockCtrl),
serverConf,
false,
@ -517,13 +522,14 @@ var _ = Describe("Crypto Setup TLS", func() {
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
var token [16]byte
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
&wire.TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
false,
@ -576,13 +582,14 @@ var _ = Describe("Crypto Setup TLS", func() {
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
var token [16]byte
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
&wire.TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
false,
@ -707,13 +714,14 @@ var _ = Describe("Crypto Setup TLS", func() {
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
var token [16]byte
server = NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
&wire.TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
true,

View file

@ -24,6 +24,8 @@ var _ = Describe("Transport Parameters", func() {
rand.Seed(GinkgoRandomSeed())
})
var token [16]byte
It("has a string representation", func() {
p := &TransportParameters{
InitialMaxStreamDataBidiLocal: 1234,
@ -77,7 +79,7 @@ var _ = Describe("Transport Parameters", func() {
MaxAckDelay: 42 * time.Millisecond,
ActiveConnectionIDLimit: getRandomValue(),
}
data := params.Marshal()
data := params.Marshal(protocol.PerspectiveServer)
p := &TransportParameters{}
Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed())
@ -101,8 +103,7 @@ var _ = Describe("Transport Parameters", func() {
utils.WriteVarInt(b, uint64(statelessResetTokenParameterID))
utils.WriteVarInt(b, 15)
b.Write(make([]byte, 15))
p := &TransportParameters{}
Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: wrong length for stateless_reset_token: 15 (expected 16)"))
Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: wrong length for stateless_reset_token: 15 (expected 16)"))
})
It("errors when the max_packet_size is too small", func() {
@ -110,8 +111,7 @@ var _ = Describe("Transport Parameters", func() {
utils.WriteVarInt(b, uint64(maxUDPPayloadSizeParameterID))
utils.WriteVarInt(b, uint64(utils.VarIntLen(1199)))
utils.WriteVarInt(b, 1199)
p := &TransportParameters{}
Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid value for max_packet_size: 1199 (minimum 1200)"))
Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid value for max_packet_size: 1199 (minimum 1200)"))
})
It("errors when disable_active_migration has content", func() {
@ -119,12 +119,22 @@ var _ = Describe("Transport Parameters", func() {
utils.WriteVarInt(b, uint64(disableActiveMigrationParameterID))
utils.WriteVarInt(b, 6)
b.Write([]byte("foobar"))
p := &TransportParameters{}
Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: wrong length for disable_active_migration: 6 (expected empty)"))
Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: wrong length for disable_active_migration: 6 (expected empty)"))
})
It("errors when the server doesn't set the original_destination_connection_id", func() {
b := &bytes.Buffer{}
utils.WriteVarInt(b, uint64(statelessResetTokenParameterID))
utils.WriteVarInt(b, 16)
b.Write(token[:])
Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_destination_connection_id"))
})
It("errors when the max_ack_delay is too large", func() {
data := (&TransportParameters{MaxAckDelay: 1 << 14 * time.Millisecond}).Marshal()
data := (&TransportParameters{
MaxAckDelay: 1 << 14 * time.Millisecond,
StatelessResetToken: &token,
}).Marshal(protocol.PerspectiveServer)
p := &TransportParameters{}
Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid value for max_ack_delay: 16384ms (maximum 16383ms)"))
})
@ -135,9 +145,15 @@ var _ = Describe("Transport Parameters", func() {
// marshal 1000 times to average out the greasing transport parameter
maxAckDelay := protocol.DefaultMaxAckDelay + time.Millisecond
for i := 0; i < num; i++ {
dataDefault := (&TransportParameters{MaxAckDelay: protocol.DefaultMaxAckDelay}).Marshal()
dataDefault := (&TransportParameters{
MaxAckDelay: protocol.DefaultMaxAckDelay,
StatelessResetToken: &token,
}).Marshal(protocol.PerspectiveServer)
defaultLen += len(dataDefault)
data := (&TransportParameters{MaxAckDelay: maxAckDelay}).Marshal()
data := (&TransportParameters{
MaxAckDelay: maxAckDelay,
StatelessResetToken: &token,
}).Marshal(protocol.PerspectiveServer)
dataLen += len(data)
}
entryLen := utils.VarIntLen(uint64(ackDelayExponentParameterID)) /* parameter id */ + utils.VarIntLen(uint64(utils.VarIntLen(uint64(maxAckDelay.Milliseconds())))) /*length */ + utils.VarIntLen(uint64(maxAckDelay.Milliseconds())) /* value */
@ -145,7 +161,10 @@ var _ = Describe("Transport Parameters", func() {
})
It("errors when the ack_delay_exponenent is too large", func() {
data := (&TransportParameters{AckDelayExponent: 21}).Marshal()
data := (&TransportParameters{
AckDelayExponent: 21,
StatelessResetToken: &token,
}).Marshal(protocol.PerspectiveServer)
p := &TransportParameters{}
Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid value for ack_delay_exponent: 21 (maximum 20)"))
})
@ -155,9 +174,15 @@ var _ = Describe("Transport Parameters", func() {
var defaultLen, dataLen int
// marshal 1000 times to average out the greasing transport parameter
for i := 0; i < num; i++ {
dataDefault := (&TransportParameters{AckDelayExponent: protocol.DefaultAckDelayExponent}).Marshal()
dataDefault := (&TransportParameters{
AckDelayExponent: protocol.DefaultAckDelayExponent,
StatelessResetToken: &token,
}).Marshal(protocol.PerspectiveServer)
defaultLen += len(dataDefault)
data := (&TransportParameters{AckDelayExponent: protocol.DefaultAckDelayExponent + 1}).Marshal()
data := (&TransportParameters{
AckDelayExponent: protocol.DefaultAckDelayExponent + 1,
StatelessResetToken: &token,
}).Marshal(protocol.PerspectiveServer)
dataLen += len(data)
}
entryLen := utils.VarIntLen(uint64(ackDelayExponentParameterID)) /* parameter id */ + utils.VarIntLen(uint64(utils.VarIntLen(protocol.DefaultAckDelayExponent+1))) /* length */ + utils.VarIntLen(protocol.DefaultAckDelayExponent+1) /* value */
@ -165,7 +190,10 @@ var _ = Describe("Transport Parameters", func() {
})
It("sets the default value for the ack_delay_exponent, when no value was sent", func() {
data := (&TransportParameters{AckDelayExponent: protocol.DefaultAckDelayExponent}).Marshal()
data := (&TransportParameters{
AckDelayExponent: protocol.DefaultAckDelayExponent,
StatelessResetToken: &token,
}).Marshal(protocol.PerspectiveServer)
p := &TransportParameters{}
Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed())
Expect(p.AckDelayExponent).To(BeEquivalentTo(protocol.DefaultAckDelayExponent))
@ -191,7 +219,7 @@ var _ = Describe("Transport Parameters", func() {
utils.WriteVarInt(b, uint64(utils.VarIntLen(val)))
utils.WriteVarInt(b, val)
p := &TransportParameters{}
Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(Succeed())
Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(Succeed())
Expect(p.MaxAckDelay).To(BeNumerically(">", 290*365*24*time.Hour))
})
@ -210,7 +238,7 @@ var _ = Describe("Transport Parameters", func() {
utils.WriteVarInt(b, uint64(utils.VarIntLen(0x42)))
utils.WriteVarInt(b, 0x42)
p := &TransportParameters{}
Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(Succeed())
Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(Succeed())
Expect(p.InitialMaxStreamDataBidiLocal).To(Equal(protocol.ByteCount(0x1337)))
Expect(p.InitialMaxStreamDataBidiRemote).To(Equal(protocol.ByteCount(0x42)))
})
@ -230,7 +258,7 @@ var _ = Describe("Transport Parameters", func() {
utils.WriteVarInt(b, uint64(utils.VarIntLen(0x1337)))
utils.WriteVarInt(b, 0x1337)
p := &TransportParameters{}
err := p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)
err := p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("received duplicate transport parameter"))
})
@ -245,18 +273,19 @@ var _ = Describe("Transport Parameters", func() {
})
It("errors if the client sent a stateless_reset_token", func() {
var token [16]byte
params := &TransportParameters{StatelessResetToken: &token}
data := params.Marshal()
Expect((&TransportParameters{}).Unmarshal(data, protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent a stateless_reset_token"))
b := &bytes.Buffer{}
utils.WriteVarInt(b, uint64(statelessResetTokenParameterID))
utils.WriteVarInt(b, uint64(utils.VarIntLen(16)))
b.Write(token[:])
Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent a stateless_reset_token"))
})
It("errors if the client sent a stateless_reset_token", func() {
params := &TransportParameters{
OriginalDestinationConnectionID: protocol.ConnectionID{0xca, 0xfe},
}
data := params.Marshal()
Expect((&TransportParameters{}).Unmarshal(data, protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent an original_destination_connection_id"))
It("errors if the client sent the original_destination_connection_id", func() {
b := &bytes.Buffer{}
utils.WriteVarInt(b, uint64(originalDestinationConnectionIDParameterID))
utils.WriteVarInt(b, 6)
b.Write([]byte("foobar"))
Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent an original_destination_connection_id"))
})
Context("preferred address", func() {
@ -274,7 +303,10 @@ var _ = Describe("Transport Parameters", func() {
})
It("marshals and unmarshals", func() {
data := (&TransportParameters{PreferredAddress: pa}).Marshal()
data := (&TransportParameters{
PreferredAddress: pa,
StatelessResetToken: &token,
}).Marshal(protocol.PerspectiveServer)
p := &TransportParameters{}
Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed())
Expect(p.PreferredAddress.IPv4.String()).To(Equal(pa.IPv4.String()))
@ -286,14 +318,20 @@ var _ = Describe("Transport Parameters", func() {
})
It("errors if the client sent a preferred_address", func() {
data := (&TransportParameters{PreferredAddress: pa}).Marshal()
b := &bytes.Buffer{}
utils.WriteVarInt(b, uint64(preferredAddressParameterID))
utils.WriteVarInt(b, 6)
b.Write([]byte("foobar"))
p := &TransportParameters{}
Expect(p.Unmarshal(data, protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent a preferred_address"))
Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent a preferred_address"))
})
It("errors on zero-length connection IDs", func() {
pa.ConnectionID = protocol.ConnectionID{}
data := (&TransportParameters{PreferredAddress: pa}).Marshal()
data := (&TransportParameters{
PreferredAddress: pa,
StatelessResetToken: &token,
}).Marshal(protocol.PerspectiveServer)
p := &TransportParameters{}
Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid connection ID length: 0"))
})
@ -301,7 +339,10 @@ var _ = Describe("Transport Parameters", func() {
It("errors on too long connection IDs", func() {
pa.ConnectionID = protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}
Expect(pa.ConnectionID.Len()).To(BeNumerically(">", protocol.MaxConnIDLen))
data := (&TransportParameters{PreferredAddress: pa}).Marshal()
data := (&TransportParameters{
PreferredAddress: pa,
StatelessResetToken: &token,
}).Marshal(protocol.PerspectiveServer)
p := &TransportParameters{}
Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid connection ID length: 21"))
})

View file

@ -80,18 +80,21 @@ type TransportParameters struct {
// Unmarshal the transport parameters
func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error {
if err := p.unmarshal(data, sentBy); err != nil {
if err := p.unmarshal(data, sentBy, false); err != nil {
return qerr.NewError(qerr.TransportParameterError, err.Error())
}
return nil
}
func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective) error {
func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective, fromSessionTicket bool) error {
// needed to check that every parameter is only sent at most once
var parameterIDs []transportParameterID
var readAckDelayExponent bool
var readMaxAckDelay bool
var (
readAckDelayExponent bool
readMaxAckDelay bool
readOriginalDestinationConnectionID bool
)
r := bytes.NewReader(data)
for r.Len() > 0 {
@ -160,12 +163,16 @@ func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective
return errors.New("client sent an original_destination_connection_id")
}
p.OriginalDestinationConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
readOriginalDestinationConnectionID = true
default:
r.Seek(int64(paramLen), io.SeekCurrent)
}
}
}
if sentBy == protocol.PerspectiveServer && !fromSessionTicket && !readOriginalDestinationConnectionID {
return errors.New("expected original_destination_connection_id")
}
if !readAckDelayExponent {
p.AckDelayExponent = protocol.DefaultAckDelayExponent
}
@ -288,7 +295,7 @@ func (p *TransportParameters) readNumericTransportParameter(
}
// Marshal the transport parameters
func (p *TransportParameters) Marshal() []byte {
func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
b := &bytes.Buffer{}
//add a greased value
@ -330,27 +337,28 @@ func (p *TransportParameters) Marshal() []byte {
utils.WriteVarInt(b, uint64(disableActiveMigrationParameterID))
utils.WriteVarInt(b, 0)
}
if p.StatelessResetToken != nil {
if pers == protocol.PerspectiveServer {
// stateless_reset_token
utils.WriteVarInt(b, uint64(statelessResetTokenParameterID))
utils.WriteVarInt(b, 16)
b.Write(p.StatelessResetToken[:])
}
if p.PreferredAddress != nil {
utils.WriteVarInt(b, uint64(preferredAddressParameterID))
utils.WriteVarInt(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16)
ipv4 := p.PreferredAddress.IPv4
b.Write(ipv4[len(ipv4)-4:])
utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv4Port)
b.Write(p.PreferredAddress.IPv6)
utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv6Port)
b.WriteByte(uint8(p.PreferredAddress.ConnectionID.Len()))
b.Write(p.PreferredAddress.ConnectionID.Bytes())
b.Write(p.PreferredAddress.StatelessResetToken[:])
}
if p.OriginalDestinationConnectionID.Len() > 0 {
// original_destination_connection_id
utils.WriteVarInt(b, uint64(originalDestinationConnectionIDParameterID))
utils.WriteVarInt(b, uint64(p.OriginalDestinationConnectionID.Len()))
b.Write(p.OriginalDestinationConnectionID.Bytes())
// preferred_address
if p.PreferredAddress != nil {
utils.WriteVarInt(b, uint64(preferredAddressParameterID))
utils.WriteVarInt(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16)
ipv4 := p.PreferredAddress.IPv4
b.Write(ipv4[len(ipv4)-4:])
utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv4Port)
b.Write(p.PreferredAddress.IPv6)
utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv6Port)
b.WriteByte(uint8(p.PreferredAddress.ConnectionID.Len()))
b.Write(p.PreferredAddress.ConnectionID.Bytes())
b.Write(p.PreferredAddress.StatelessResetToken[:])
}
}
// active_connection_id_limit
@ -401,7 +409,7 @@ func (p *TransportParameters) UnmarshalFromSessionTicket(data []byte) error {
if version != transportParameterMarshalingVersion {
return fmt.Errorf("unknown transport parameter marshaling version: %d", version)
}
return p.Unmarshal(data[len(data)-r.Len():], protocol.PerspectiveServer)
return p.unmarshal(data[len(data)-r.Len():], protocol.PerspectiveServer, true)
}
// ValidFor0RTT checks if the transport parameters match those saved in the session ticket.

View file

@ -348,7 +348,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
}
var token *Token
var origDestConnectionID protocol.ConnectionID
origDestConnectionID := hdr.DestConnectionID
if len(hdr.Token) > 0 {
c, err := s.tokenGenerator.DecodeToken(hdr.Token)
if err == nil {
@ -357,7 +357,9 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
RemoteAddr: c.RemoteAddr,
SentTime: c.SentTime,
}
origDestConnectionID = c.OriginalDestConnectionID
if token.IsRetryToken {
origDestConnectionID = c.OriginalDestConnectionID
}
}
}
if !s.config.AcceptToken(p.remoteAddr, token) {

View file

@ -289,6 +289,79 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
})
It("creates a session when the token is accepted", func() {
serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true }
retryToken, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
Token: retryToken,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
run := make(chan struct{})
var token [16]byte
rand.Read(token[:])
var newConnID protocol.ConnectionID
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
newConnID = c
return token
})
sess := NewMockQuicSession(mockCtrl)
serv.newSession = func(
_ connection,
_ sessionRunner,
origDestConnID protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
tokenP [16]byte,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
enable0RTT bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
) quicSession {
Expect(enable0RTT).To(BeFalse())
Expect(origDestConnID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}))
Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
Expect(destConnID).To(Equal(hdr.SrcConnectionID))
// make sure we're using a server-generated connection ID
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
Expect(srcConnID).To(Equal(newConnID))
Expect(tokenP).To(Equal(token))
sess.EXPECT().handlePacket(p)
sess.EXPECT().run().Do(func() { close(run) })
sess.EXPECT().Context().Return(context.Background())
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess
}
phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true)
phm.EXPECT().Add(gomock.Any(), sess).DoAndReturn(func(c protocol.ConnectionID, _ packetHandler) bool {
Expect(c).To(Equal(newConnID))
return true
})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.handlePacket(p)
// the Handshake packet is written by the session
Consistently(conn.dataWritten).ShouldNot(Receive())
close(done)
}()
// make sure we're using a server-generated connection ID
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
})
It("sends a Version Negotiation Packet for unsupported versions", func() {
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
@ -408,8 +481,8 @@ var _ = Describe("Server", func() {
serv.newSession = func(
_ connection,
_ sessionRunner,
_ protocol.ConnectionID,
origConnID protocol.ConnectionID,
origDestConnID protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
tokenP [16]byte,
@ -422,7 +495,8 @@ var _ = Describe("Server", func() {
_ protocol.VersionNumber,
) quicSession {
Expect(enable0RTT).To(BeFalse())
Expect(origConnID).To(Equal(hdr.DestConnectionID))
Expect(origDestConnID).To(Equal(hdr.DestConnectionID))
Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
Expect(destConnID).To(Equal(hdr.SrcConnectionID))
// make sure we're using a server-generated connection ID
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))

View file

@ -111,7 +111,7 @@ type session struct {
// Destination connection ID used during the handshake.
// Used to check source connection ID on incoming packets.
handshakeDestConnID protocol.ConnectionID
// if the server sends a Retry, this is the connection ID we used initially
// Set for the client. Destination connection ID used on the first Initial sent.
origDestConnID protocol.ConnectionID
srcConnIDLen int
@ -338,6 +338,7 @@ var newClientSession = func(
s := &session{
conn: conn,
config: conf,
origDestConnID: destConnID,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
perspective: protocol.PerspectiveClient,
@ -876,7 +877,6 @@ func (s *session) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was t
if s.qlogger != nil {
s.qlogger.ReceivedRetry(hdr)
}
s.origDestConnID = s.handshakeDestConnID
newDestConnID := hdr.SrcConnectionID
s.receivedRetry = true
if err := s.sentPacketHandler.ResetForRetry(); err != nil {

View file

@ -2227,6 +2227,7 @@ var _ = Describe("Client Session", func() {
It("uses the preferred_address connection ID", func() {
params := &wire.TransportParameters{
OriginalDestinationConnectionID: destConnID,
PreferredAddress: &wire.PreferredAddress{
IPv4: net.IPv4(127, 0, 0, 1),
IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
@ -2251,7 +2252,8 @@ var _ = Describe("Client Session", func() {
It("uses the minimum of the peers' idle timeouts", func() {
sess.config.MaxIdleTimeout = 19 * time.Second
params := &wire.TransportParameters{
MaxIdleTimeout: 18 * time.Second,
OriginalDestinationConnectionID: destConnID,
MaxIdleTimeout: 18 * time.Second,
}
packer.EXPECT().HandleTransportParameters(gomock.Any())
qlogger.EXPECT().ReceivedTransportParameters(params)
@ -2259,17 +2261,6 @@ var _ = Describe("Client Session", func() {
Expect(sess.idleTimeout).To(Equal(18 * time.Second))
})
It("errors if the TransportParameters contain an original_destination_connection_id, although no Retry was performed", func() {
params := &wire.TransportParameters{
OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad},
StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
}
expectClose()
qlogger.EXPECT().ReceivedTransportParameters(params)
sess.processTransportParameters(params)
Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_destination_connection_id to equal (empty), is 0xdecafbad")))
})
It("errors if the TransportParameters contain a wrong original_destination_connection_id", func() {
sess.origDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
params := &wire.TransportParameters{