set and verify the retry_source_connection_id TP

This commit is contained in:
Marten Seemann 2020-05-24 21:58:00 +07:00
parent cdb22513f3
commit 4f19b15670
4 changed files with 137 additions and 36 deletions

View file

@ -11,13 +11,12 @@ import (
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go/qlog"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qlog"
)
// packetHandler handles packets
@ -75,7 +74,23 @@ type baseServer struct {
receivedPackets chan *receivedPacket
// set as a member, so they can be set in the tests
newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* client dest connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, [16]byte, *Config, *tls.Config, *handshake.TokenGenerator, bool /* enable 0-RTT */, qlog.Tracer, utils.Logger, protocol.VersionNumber) quicSession
newSession func(
connection,
sessionRunner,
protocol.ConnectionID, /* original dest connection ID */
*protocol.ConnectionID, /* retry src connection ID */
protocol.ConnectionID, /* client dest connection ID */
protocol.ConnectionID, /* destination connection ID */
protocol.ConnectionID, /* source connection ID */
[16]byte,
*Config,
*tls.Config,
*handshake.TokenGenerator,
bool, /* enable 0-RTT */
qlog.Tracer,
utils.Logger,
protocol.VersionNumber,
) quicSession
serverError error
errorChan chan struct{}
@ -347,7 +362,10 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
return errors.New("too short connection ID")
}
var token *Token
var (
token *Token
retrySrcConnectionID *protocol.ConnectionID
)
origDestConnectionID := hdr.DestConnectionID
if len(hdr.Token) > 0 {
c, err := s.tokenGenerator.DecodeToken(hdr.Token)
@ -359,6 +377,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
}
if token.IsRetryToken {
origDestConnectionID = c.OriginalDestConnectionID
retrySrcConnectionID = &c.RetrySrcConnectionID
}
}
}
@ -396,6 +415,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
sess := s.createNewSession(
p.remoteAddr,
origDestConnectionID,
retrySrcConnectionID,
hdr.DestConnectionID,
hdr.SrcConnectionID,
connID,
@ -419,6 +439,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
func (s *baseServer) createNewSession(
remoteAddr net.Addr,
origDestConnID protocol.ConnectionID,
retrySrcConnID *protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
@ -441,6 +462,7 @@ func (s *baseServer) createNewSession(
&conn{pconn: s.conn, currentAddr: remoteAddr},
s.sessionHandler,
origDestConnID,
retrySrcConnID,
clientDestConnID,
destConnID,
srcConnID,

View file

@ -291,7 +291,11 @@ var _ = Describe("Server", func() {
It("creates a session when the token is accepted", func() {
serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true }
retryToken, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, nil)
retryToken, err := serv.tokenGenerator.NewRetryToken(
&net.UDPAddr{},
protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde},
protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad},
)
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
IsLongHeader: true,
@ -305,16 +309,23 @@ var _ = Describe("Server", func() {
run := make(chan struct{})
var token [16]byte
rand.Read(token[:])
var newConnID protocol.ConnectionID
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
newConnID = c
return token
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
newConnID = c
return token
})
fn()
return true
})
sess := NewMockQuicSession(mockCtrl)
serv.newSession = func(
_ connection,
_ sessionRunner,
origDestConnID protocol.ConnectionID,
retrySrcConnID *protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
@ -329,6 +340,7 @@ var _ = Describe("Server", func() {
) quicSession {
Expect(enable0RTT).To(BeFalse())
Expect(origDestConnID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}))
Expect(retrySrcConnID).To(Equal(&protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}))
Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
Expect(destConnID).To(Equal(hdr.SrcConnectionID))
// make sure we're using a server-generated connection ID
@ -343,12 +355,6 @@ var _ = Describe("Server", func() {
return sess
}
phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true)
phm.EXPECT().Add(gomock.Any(), sess).DoAndReturn(func(c protocol.ConnectionID, _ packetHandler) bool {
Expect(c).To(Equal(newConnID))
return true
})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -475,13 +481,24 @@ var _ = Describe("Server", func() {
run := make(chan struct{})
var token [16]byte
rand.Read(token[:])
var newConnID protocol.ConnectionID
phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
newConnID = c
return token
})
fn()
return true
})
sess := NewMockQuicSession(mockCtrl)
serv.newSession = func(
_ connection,
_ sessionRunner,
origDestConnID protocol.ConnectionID,
retrySrcConnID *protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
@ -496,6 +513,7 @@ var _ = Describe("Server", func() {
) quicSession {
Expect(enable0RTT).To(BeFalse())
Expect(origDestConnID).To(Equal(hdr.DestConnectionID))
Expect(retrySrcConnID).To(BeNil())
Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
Expect(destConnID).To(Equal(hdr.SrcConnectionID))
// make sure we're using a server-generated connection ID
@ -510,16 +528,6 @@ var _ = Describe("Server", func() {
return sess
}
phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
newConnID = c
return token
})
fn()
return true
})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -557,6 +565,7 @@ var _ = Describe("Server", func() {
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
@ -599,6 +608,7 @@ var _ = Describe("Server", func() {
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
@ -646,6 +656,7 @@ var _ = Describe("Server", func() {
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
@ -675,6 +686,7 @@ var _ = Describe("Server", func() {
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
@ -739,6 +751,7 @@ var _ = Describe("Server", func() {
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
@ -846,6 +859,7 @@ var _ = Describe("Server", func() {
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
@ -868,7 +882,7 @@ var _ = Describe("Server", func() {
fn()
return true
})
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, nil, protocol.VersionWhatever)
Consistently(done).ShouldNot(BeClosed())
cancel() // complete the handshake
Eventually(done).Should(BeClosed())
@ -912,6 +926,7 @@ var _ = Describe("Server", func() {
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
@ -935,7 +950,7 @@ var _ = Describe("Server", func() {
fn()
return true
})
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, nil, protocol.VersionWhatever)
Consistently(done).ShouldNot(BeClosed())
close(ready)
Eventually(done).Should(BeClosed())
@ -949,6 +964,7 @@ var _ = Describe("Server", func() {
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
@ -1007,6 +1023,7 @@ var _ = Describe("Server", func() {
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,

View file

@ -113,7 +113,9 @@ type session struct {
handshakeDestConnID protocol.ConnectionID
// Set for the client. Destination connection ID used on the first Initial sent.
origDestConnID protocol.ConnectionID
srcConnIDLen int
retrySrcConnID *protocol.ConnectionID // only set for the client (and if a Retry was performed)
srcConnIDLen int
perspective protocol.Perspective
initialVersion protocol.VersionNumber // if version negotiation is performed, this is the version we initially tried
@ -201,6 +203,7 @@ var newSession = func(
conn connection,
runner sessionRunner,
origDestConnID protocol.ConnectionID,
retrySrcConnID *protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
@ -275,6 +278,7 @@ var newSession = func(
OriginalDestinationConnectionID: origDestConnID,
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
InitialSourceConnectionID: srcConnID,
RetrySourceConnectionID: retrySrcConnID,
}
if s.qlogger != nil {
s.qlogger.SentTransportParameters(params)
@ -874,7 +878,7 @@ func (s *session) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was t
return false
}
s.logger.Debugf("<- Received Retry")
s.logger.Debugf("<- Received Retry: %#v", hdr)
s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID)
if s.qlogger != nil {
s.qlogger.ReceivedRetry(hdr)
@ -886,6 +890,7 @@ func (s *session) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was t
return false
}
s.handshakeDestConnID = newDestConnID
s.retrySrcConnID = &newDestConnID
s.cryptoStreamHandler.ChangeConnectionID(newDestConnID)
s.packer.SetToken(hdr.Token)
s.connIDManager.ChangeInitialConnID(newDestConnID)
@ -1284,6 +1289,12 @@ func (s *session) restoreTransportParameters(params *wire.TransportParameters) {
}
func (s *session) processTransportParameters(params *wire.TransportParameters) {
if err := s.processTransportParametersImpl(params); err != nil {
s.closeLocal(err)
}
}
func (s *session) processTransportParametersImpl(params *wire.TransportParameters) error {
if s.logger.Debug() {
s.logger.Debugf("Processed Transport Parameters: %s", params)
}
@ -1293,14 +1304,24 @@ func (s *session) processTransportParameters(params *wire.TransportParameters) {
// check the initial_source_connection_id
if !params.InitialSourceConnectionID.Equal(s.handshakeDestConnID) {
s.closeLocal(qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID)))
return
return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID))
}
// check the original_destination_connection_id
if s.perspective == protocol.PerspectiveClient && !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) {
s.closeLocal(qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID)))
return
if s.perspective == protocol.PerspectiveClient {
// check the original_destination_connection_id
if !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) {
return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID))
}
if s.retrySrcConnID != nil { // a Retry was performed
if params.RetrySourceConnectionID == nil {
return qerr.NewError(qerr.TransportParameterError, "missing retry_source_connection_id")
}
if !(*params.RetrySourceConnectionID).Equal(*s.retrySrcConnID) {
return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID))
}
} else if params.RetrySourceConnectionID != nil {
return qerr.NewError(qerr.TransportParameterError, "received retry_source_connection_id, although no Retry was performed")
}
}
s.peerParams = params
@ -1308,8 +1329,7 @@ func (s *session) processTransportParameters(params *wire.TransportParameters) {
s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout)
s.keepAliveInterval = utils.MinDuration(s.idleTimeout/2, protocol.MaxKeepAliveInterval)
if err := s.streamsMap.UpdateLimits(params); err != nil {
s.closeLocal(err)
return
return err
}
s.packer.HandleTransportParameters(params)
s.frameParser.SetAckDelayExponent(params.AckDelayExponent)
@ -1330,6 +1350,7 @@ func (s *session) processTransportParameters(params *wire.TransportParameters) {
if s.perspective == protocol.PerspectiveServer {
close(s.earlySessionReadyChan)
}
return nil
}
func (s *session) sendPackets() error {

View file

@ -94,6 +94,7 @@ var _ = Describe("Session", func() {
mconn,
sessionRunner,
nil,
nil,
clientDestConnID,
destConnID,
srcConnID,
@ -2280,7 +2281,47 @@ var _ = Describe("Client Session", func() {
Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected initial_source_connection_id to equal 0xdeadbeef, is 0xdecafbad")))
})
It("errors if the TransportParameters contain a wrong original_destination_connection_id", func() {
It("errors if the transport parameters don't contain the retry_source_connection_id, if a Retry was performed", func() {
sess.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
params := &wire.TransportParameters{
OriginalDestinationConnectionID: destConnID,
InitialSourceConnectionID: destConnID,
StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
}
expectClose()
qlogger.EXPECT().ReceivedTransportParameters(params)
sess.processTransportParameters(params)
Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: missing retry_source_connection_id")))
})
It("errors if the transport parameters contain the wrong retry_source_connection_id, if a Retry was performed", func() {
sess.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
params := &wire.TransportParameters{
OriginalDestinationConnectionID: destConnID,
InitialSourceConnectionID: destConnID,
RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde},
StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
}
expectClose()
qlogger.EXPECT().ReceivedTransportParameters(params)
sess.processTransportParameters(params)
Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected retry_source_connection_id to equal 0xdeadbeef, is 0xdeadc0de")))
})
It("errors if the transport parameters contain the retry_source_connection_id, if no Retry was performed", func() {
params := &wire.TransportParameters{
OriginalDestinationConnectionID: destConnID,
InitialSourceConnectionID: destConnID,
RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde},
StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
}
expectClose()
qlogger.EXPECT().ReceivedTransportParameters(params)
sess.processTransportParameters(params)
Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: received retry_source_connection_id, although no Retry was performed")))
})
It("errors if the transport parameters contain a wrong original_destination_connection_id", func() {
sess.origDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
params := &wire.TransportParameters{
OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad},