move unmarshaling of the transport parameters to the crypto setup

This commit is contained in:
Marten Seemann 2019-08-05 10:44:31 +07:00
parent fc37cdc5c5
commit 9b0a4a8813
6 changed files with 69 additions and 107 deletions

View file

@ -339,7 +339,7 @@ func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
h.logger.Debugf("Sending HelloRetryRequest")
return false
case data := <-h.paramsChan:
h.runner.OnReceivedParams(data)
h.handleTransportParameters(data)
case <-h.handshakeDone:
return false
}
@ -404,7 +404,7 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
case typeEncryptedExtensions:
select {
case data := <-h.paramsChan:
h.runner.OnReceivedParams(data)
h.handleTransportParameters(data)
case <-h.handshakeDone:
return false
}
@ -431,6 +431,14 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
}
}
func (h *cryptoSetup) handleTransportParameters(data []byte) {
var tp TransportParameters
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
h.runner.OnError(qerr.Error(qerr.TransportParameterError, err.Error()))
}
h.runner.OnReceivedParams(&tp)
}
// only valid for the server
func (h *cryptoSetup) maybeSendSessionTicket() {
ticket, err := h.conn.GetSessionTicket(h.ourParams.MarshalForSessionTicket())

View file

@ -430,11 +430,11 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("receives transport parameters", func() {
var cTransportParametersRcvd, sTransportParametersRcvd []byte
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
cTransportParameters := &TransportParameters{MaxIdleTimeout: 0x42 * time.Second}
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { sTransportParametersRcvd = b })
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { sTransportParametersRcvd = tp })
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
@ -453,7 +453,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
var token [16]byte
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { cTransportParametersRcvd = b })
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { cTransportParametersRcvd = tp })
sRunner.EXPECT().OnHandshakeComplete()
sTransportParameters := &TransportParameters{
MaxIdleTimeout: 0x1337 * time.Second,
@ -480,14 +480,9 @@ var _ = Describe("Crypto Setup TLS", func() {
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(cTransportParametersRcvd).ToNot(BeNil())
clTP := &TransportParameters{}
Expect(clTP.Unmarshal(cTransportParametersRcvd, protocol.PerspectiveClient)).To(Succeed())
Expect(clTP.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout))
Expect(cTransportParametersRcvd.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout))
Expect(sTransportParametersRcvd).ToNot(BeNil())
srvTP := &TransportParameters{}
Expect(srvTP.Unmarshal(sTransportParametersRcvd, protocol.PerspectiveServer)).To(Succeed())
Expect(srvTP.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout))
Expect(sTransportParametersRcvd.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout))
})
Context("with session tickets", func() {

View file

@ -59,7 +59,7 @@ type tlsExtensionHandler interface {
}
type handshakeRunner interface {
OnReceivedParams([]byte)
OnReceivedParams(*TransportParameters)
OnHandshakeComplete()
OnError(error)
DropKeys(protocol.EncryptionLevel)

View file

@ -71,7 +71,7 @@ func (mr *MockHandshakeRunnerMockRecorder) OnHandshakeComplete() *gomock.Call {
}
// OnReceivedParams mocks base method
func (m *MockHandshakeRunner) OnReceivedParams(arg0 []byte) {
func (m *MockHandshakeRunner) OnReceivedParams(arg0 *TransportParameters) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnReceivedParams", arg0)
}

View file

@ -84,16 +84,16 @@ type sessionRunner interface {
}
type handshakeRunner struct {
onReceivedParams func([]byte)
onReceivedParams func(*handshake.TransportParameters)
onError func(error)
dropKeys func(protocol.EncryptionLevel)
onHandshakeComplete func()
}
func (r *handshakeRunner) OnReceivedParams(b []byte) { r.onReceivedParams(b) }
func (r *handshakeRunner) OnError(e error) { r.onError(e) }
func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) }
func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() }
func (r *handshakeRunner) OnReceivedParams(tp *handshake.TransportParameters) { r.onReceivedParams(tp) }
func (r *handshakeRunner) OnError(e error) { r.onError(e) }
func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) }
func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() }
type closeError struct {
err error
@ -1092,19 +1092,13 @@ func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
s.receivedPacketHandler.DropPackets(encLevel)
}
func (s *session) processTransportParameters(data []byte) {
var params *handshake.TransportParameters
var err error
switch s.perspective {
case protocol.PerspectiveClient:
params, err = s.processTransportParametersForClient(data)
case protocol.PerspectiveServer:
params, err = s.processTransportParametersForServer(data)
}
if err != nil {
s.closeLocal(err)
func (s *session) processTransportParameters(params *handshake.TransportParameters) {
// check the Retry token
if s.perspective == protocol.PerspectiveClient && !params.OriginalConnectionID.Equal(s.origDestConnID) {
s.closeLocal(qerr.Error(qerr.TransportParameterError, fmt.Sprintf("expected original_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalConnectionID)))
return
}
s.logger.Debugf("Received Transport Parameters: %s", params)
s.peerParams = params
// Our local idle timeout will always be > 0.
@ -1122,36 +1116,15 @@ func (s *session) processTransportParameters(data []byte) {
if params.StatelessResetToken != nil {
s.connIDManager.SetStatelessResetToken(*params.StatelessResetToken)
}
// On the server side, the early session is ready as soon as we processed
// the client's transport parameters.
close(s.earlySessionReadyChan)
}
func (s *session) processTransportParametersForClient(data []byte) (*handshake.TransportParameters, error) {
params := &handshake.TransportParameters{}
if err := params.Unmarshal(data, s.perspective.Opposite()); err != nil {
return nil, err
}
// check the Retry token
if !params.OriginalConnectionID.Equal(s.origDestConnID) {
return nil, qerr.Error(qerr.TransportParameterError, fmt.Sprintf("expected original_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalConnectionID))
}
// We don't support connection migration yet, so we don't have any use for the preferred_address.
if params.PreferredAddress != nil {
s.logger.Debugf("Server sent preferred_address. Retiring the preferred_address connection ID.")
// Retire the connection ID.
s.framer.QueueControlFrame(&wire.RetireConnectionIDFrame{SequenceNumber: 1})
}
return params, nil
}
func (s *session) processTransportParametersForServer(data []byte) (*handshake.TransportParameters, error) {
params := &handshake.TransportParameters{}
if err := params.Unmarshal(data, s.perspective.Opposite()); err != nil {
return nil, err
}
return params, nil
// On the server side, the early session is ready as soon as we processed
// the client's transport parameters.
close(s.earlySessionReadyChan)
}
func (s *session) sendPackets() error {

View file

@ -1313,23 +1313,7 @@ var _ = Describe("Session", func() {
})
Context("transport parameters", func() {
It("errors if it can't unmarshal the TransportParameters", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
err := sess.run()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("transport parameter"))
}()
streamManager.EXPECT().CloseWithError(gomock.Any())
expectReplaceWithClosed()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
cryptoSetup.EXPECT().Close()
sess.processTransportParameters([]byte("invalid"))
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("processes transport parameters received from the client", func() {
It("process transport parameters received from the client", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
@ -1348,7 +1332,7 @@ var _ = Describe("Session", func() {
packer.EXPECT().PackPacket().MaxTimes(3)
Expect(sess.earlySessionReady()).ToNot(BeClosed())
sessionRunner.EXPECT().Add(gomock.Any(), sess).Times(2)
sess.processTransportParameters(params.Marshal())
sess.processTransportParameters(params)
Expect(sess.earlySessionReady()).To(BeClosed())
// make the go routine return
@ -1367,10 +1351,9 @@ var _ = Describe("Session", func() {
Context("keep-alives", func() {
setRemoteIdleTimeout := func(t time.Duration) {
tp := &handshake.TransportParameters{MaxIdleTimeout: t}
streamManager.EXPECT().UpdateLimits(gomock.Any())
packer.EXPECT().HandleTransportParameters(gomock.Any())
sess.processTransportParameters(tp.Marshal())
sess.processTransportParameters(&handshake.TransportParameters{MaxIdleTimeout: t})
}
runSession := func() {
@ -1814,27 +1797,40 @@ var _ = Describe("Client Session", func() {
})
Context("transport parameters", func() {
It("errors if it can't unmarshal the TransportParameters", func() {
var (
closed bool
errChan chan error
)
JustBeforeEach(func() {
errChan = make(chan error, 1)
closed = false
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
err := sess.run()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("transport parameter"))
errChan <- sess.run()
}()
expectReplaceWithClosed()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
cryptoSetup.EXPECT().Close()
sess.processTransportParameters([]byte("invalid"))
})
expectClose := func() {
if !closed {
sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{}))
Expect(s.Close()).To(Succeed())
})
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil).MaxTimes(1)
cryptoSetup.EXPECT().Close()
}
closed = true
}
AfterEach(func() {
expectClose()
sess.Close()
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("immediately retires the preferred_address connection ID", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
sess.run()
}()
params := &handshake.TransportParameters{
PreferredAddress: &handshake.PreferredAddress{
IPv4: net.IPv4(127, 0, 0, 1),
@ -1844,20 +1840,10 @@ var _ = Describe("Client Session", func() {
}
packer.EXPECT().HandleTransportParameters(gomock.Any())
packer.EXPECT().PackPacket().MaxTimes(1)
sess.processTransportParameters(params.Marshal())
sess.processTransportParameters(params)
cf, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount)
Expect(cf).To(HaveLen(1))
Expect(cf[0].Frame).To(Equal(&wire.RetireConnectionIDFrame{SequenceNumber: 1}))
// make the go routine return
sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{}))
Expect(s.Close()).To(Succeed())
})
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
cryptoSetup.EXPECT().Close()
sess.Close()
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("uses the minimum of the peers' idle timeouts", func() {
@ -1866,27 +1852,27 @@ var _ = Describe("Client Session", func() {
MaxIdleTimeout: 18 * time.Second,
}
packer.EXPECT().HandleTransportParameters(gomock.Any())
sess.processTransportParameters(params.Marshal())
sess.processTransportParameters(params)
Expect(sess.idleTimeout).To(Equal(18 * time.Second))
})
It("errors if the TransportParameters contain an original_connection_id, although no Retry was performed", func() {
params := &handshake.TransportParameters{
expectClose()
sess.processTransportParameters(&handshake.TransportParameters{
OriginalConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad},
StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
}
_, err := sess.processTransportParametersForClient(params.Marshal())
Expect(err).To(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_connection_id to equal (empty), is 0xdecafbad"))
})
Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_connection_id to equal (empty), is 0xdecafbad")))
})
It("errors if the TransportParameters contain a wrong original_connection_id", func() {
sess.origDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
params := &handshake.TransportParameters{
expectClose()
sess.processTransportParameters(&handshake.TransportParameters{
OriginalConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad},
StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
}
_, err := sess.processTransportParametersForClient(params.Marshal())
Expect(err).To(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_connection_id to equal 0xdeadbeef, is 0xdecafbad"))
})
Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_connection_id to equal 0xdeadbeef, is 0xdecafbad")))
})
})