move initialization of the client's transport parameters to the session

This commit is contained in:
Marten Seemann 2019-10-28 08:25:45 +07:00
parent b64535e656
commit 35ea8213c5
10 changed files with 37 additions and 89 deletions

View file

@ -368,6 +368,8 @@ func (c *client) createNewTLSSession(_ protocol.VersionNumber) {
c.version, c.version,
) )
c.mutex.Unlock() c.mutex.Unlock()
// It's not possible to use the stateless reset token for the client's (first) connection ID,
// since there's no way to securely communicate it to the server.
c.packetHandlers.Add(c.srcConnID, c) c.packetHandlers.Add(c.srcConnID, c)
} }

View file

@ -338,45 +338,6 @@ var _ = Describe("Client", func() {
Eventually(dialed).Should(BeClosed()) Eventually(dialed).Should(BeClosed())
}) })
It("removes closed sessions from the multiplexer", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
manager.EXPECT().Retire(connID)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil)
var runner sessionRunner
sess := NewMockQuicSession(mockCtrl)
newClientSession = func(
_ connection,
runnerP sessionRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) quicSession {
runner = runnerP
return sess
}
sess.EXPECT().run().Do(func() {
runner.Retire(connID)
})
sess.EXPECT().HandshakeComplete().Return(context.Background())
_, err := DialContext(
context.Background(),
packetConn,
addr,
"localhost:1337",
tlsConf,
&Config{},
)
Expect(err).ToNot(HaveOccurred())
})
It("closes the connection when it was created by DialAddr", func() { It("closes the connection when it was created by DialAddr", func() {
if os.Getenv("APPVEYOR") == "True" { if os.Getenv("APPVEYOR") == "True" {
Skip("This test is flaky on AppVeyor.") Skip("This test is flaky on AppVeyor.")

View file

@ -35,9 +35,11 @@ func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorde
} }
// Add mocks base method // Add mocks base method
func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) { func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) [16]byte {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", arg0, arg1) ret := m.ctrl.Call(m, "Add", arg0, arg1)
ret0, _ := ret[0].([16]byte)
return ret0
} }
// Add indicates an expected call of Add // Add indicates an expected call of Add
@ -84,20 +86,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
} }
// GetStatelessResetToken mocks base method
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) [16]byte {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0)
ret0, _ := ret[0].([16]byte)
return ret0
}
// GetStatelessResetToken indicates an expected call of GetStatelessResetToken
func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0)
}
// Remove mocks base method // Remove mocks base method
func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -34,6 +34,20 @@ func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder {
return m.recorder return m.recorder
} }
// Add mocks base method
func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) [16]byte {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Add", arg0, arg1)
ret0, _ := ret[0].([16]byte)
return ret0
}
// Add indicates an expected call of Add
func (mr *MockSessionRunnerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSessionRunner)(nil).Add), arg0, arg1)
}
// AddResetToken mocks base method // AddResetToken mocks base method
func (m *MockSessionRunner) AddResetToken(arg0 [16]byte, arg1 packetHandler) { func (m *MockSessionRunner) AddResetToken(arg0 [16]byte, arg1 packetHandler) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -46,20 +60,6 @@ func (mr *MockSessionRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockSessionRunner)(nil).AddResetToken), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockSessionRunner)(nil).AddResetToken), arg0, arg1)
} }
// GetStatelessResetToken mocks base method
func (m *MockSessionRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) [16]byte {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0)
ret0, _ := ret[0].([16]byte)
return ret0
}
// GetStatelessResetToken indicates an expected call of GetStatelessResetToken
func (mr *MockSessionRunnerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockSessionRunner)(nil).GetStatelessResetToken), arg0)
}
// Remove mocks base method // Remove mocks base method
func (m *MockSessionRunner) Remove(arg0 protocol.ConnectionID) { func (m *MockSessionRunner) Remove(arg0 protocol.ConnectionID) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -95,10 +95,11 @@ func (h *packetHandlerMap) logUsage() {
} }
} }
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) { func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) [16]byte {
h.mutex.Lock() h.mutex.Lock()
h.handlers[string(id)] = handler h.handlers[string(id)] = handler
h.mutex.Unlock() h.mutex.Unlock()
return h.getStatelessResetToken(id)
} }
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
@ -283,7 +284,7 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
return false return false
} }
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte { func (h *packetHandlerMap) getStatelessResetToken(connID protocol.ConnectionID) [16]byte {
var token [16]byte var token [16]byte
if !h.statelessResetEnabled { if !h.statelessResetEnabled {
// Return a random stateless reset token. // Return a random stateless reset token.
@ -310,7 +311,7 @@ func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID pro
if len(p.data) <= protocol.MinStatelessResetSize { if len(p.data) <= protocol.MinStatelessResetSize {
return return
} }
token := h.GetStatelessResetToken(connID) token := h.getStatelessResetToken(connID)
h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
rand.Read(data) rand.Read(data)

View file

@ -292,9 +292,9 @@ var _ = Describe("Packet Handler Map", func() {
It("generates stateless reset tokens", func() { It("generates stateless reset tokens", func() {
connID1 := []byte{0xde, 0xad, 0xbe, 0xef} connID1 := []byte{0xde, 0xad, 0xbe, 0xef}
connID2 := []byte{0xde, 0xca, 0xfb, 0xad} connID2 := []byte{0xde, 0xca, 0xfb, 0xad}
token1 := handler.GetStatelessResetToken(connID1) token1 := handler.Add(connID1, nil)
Expect(handler.GetStatelessResetToken(connID1)).To(Equal(token1)) Expect(handler.Add(connID1, nil)).To(Equal(token1))
Expect(handler.GetStatelessResetToken(connID2)).ToNot(Equal(token1)) Expect(handler.Add(connID2, nil)).ToNot(Equal(token1))
}) })
It("sends stateless resets", func() { It("sends stateless resets", func() {

View file

@ -34,7 +34,6 @@ type unknownPacketHandler interface {
type packetHandlerManager interface { type packetHandlerManager interface {
io.Closer io.Closer
Add(protocol.ConnectionID, packetHandler)
SetServer(unknownPacketHandler) SetServer(unknownPacketHandler)
CloseServer() CloseServer()
sessionRunner sessionRunner
@ -364,7 +363,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet
s.logger.Debugf("<- Received Initial packet.") s.logger.Debugf("<- Received Initial packet.")
sess, connID, err := s.handleInitialImpl(p, hdr) sess, err := s.handleInitialImpl(p, hdr)
if err != nil { if err != nil {
s.logger.Errorf("Error occurred handling initial packet: %s", err) s.logger.Errorf("Error occurred handling initial packet: %s", err)
return false return false
@ -374,13 +373,12 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet
} }
// Don't put the packet buffer back if a new session was created. // Don't put the packet buffer back if a new session was created.
// The session will handle the packet and take of that. // The session will handle the packet and take of that.
s.sessionHandler.Add(connID, sess)
return true return true
} }
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, protocol.ConnectionID, error) { func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, error) {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
return nil, nil, errors.New("too short connection ID") return nil, errors.New("too short connection ID")
} }
var token *Token var token *Token
@ -400,17 +398,17 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
// Log the Initial packet now. // Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the session. // If no Retry is sent, the packet will be logged by the session.
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
return nil, nil, s.sendRetry(p.remoteAddr, hdr) return nil, s.sendRetry(p.remoteAddr, hdr)
} }
if queueLen := atomic.LoadInt32(&s.sessionQueueLen); queueLen >= protocol.MaxAcceptQueueSize { if queueLen := atomic.LoadInt32(&s.sessionQueueLen); queueLen >= protocol.MaxAcceptQueueSize {
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
return nil, nil, s.sendServerBusy(p.remoteAddr, hdr) return nil, s.sendServerBusy(p.remoteAddr, hdr)
} }
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
s.logger.Debugf("Changing connection ID to %s.", connID) s.logger.Debugf("Changing connection ID to %s.", connID)
sess := s.createNewSession( sess := s.createNewSession(
@ -422,7 +420,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
hdr.Version, hdr.Version,
) )
sess.handlePacket(p) sess.handlePacket(p)
return sess, connID, nil return sess, nil
} }
func (s *baseServer) createNewSession( func (s *baseServer) createNewSession(

View file

@ -433,7 +433,6 @@ var _ = Describe("Server", func() {
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
// make the go routine return // make the go routine return
sess.EXPECT().getPerspective()
Expect(serv.Close()).To(Succeed()) Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -674,7 +673,6 @@ var _ = Describe("Server", func() {
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
// make the go routine return // make the go routine return
sess.EXPECT().getPerspective()
Expect(serv.Close()).To(Succeed()) Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })

View file

@ -73,13 +73,13 @@ func (p *receivedPacket) Clone() *receivedPacket {
} }
type sessionRunner interface { type sessionRunner interface {
Add(protocol.ConnectionID, packetHandler) [16]byte
Retire(protocol.ConnectionID) Retire(protocol.ConnectionID)
Remove(protocol.ConnectionID) Remove(protocol.ConnectionID)
ReplaceWithClosed(protocol.ConnectionID, packetHandler) ReplaceWithClosed(protocol.ConnectionID, packetHandler)
AddResetToken([16]byte, packetHandler) AddResetToken([16]byte, packetHandler)
RemoveResetToken([16]byte) RemoveResetToken([16]byte)
RetireResetToken([16]byte) RetireResetToken([16]byte)
GetStatelessResetToken(protocol.ConnectionID) [16]byte
} }
type handshakeRunner struct { type handshakeRunner struct {
@ -220,7 +220,7 @@ var newSession = func(
initialStream := newCryptoStream() initialStream := newCryptoStream()
handshakeStream := newCryptoStream() handshakeStream := newCryptoStream()
oneRTTStream := newPostHandshakeCryptoStream(s.framer) oneRTTStream := newPostHandshakeCryptoStream(s.framer)
token := s.sessionRunner.GetStatelessResetToken(srcConnID) token := s.sessionRunner.Add(srcConnID, s)
params := &handshake.TransportParameters{ params := &handshake.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,

View file

@ -104,7 +104,7 @@ var _ = Describe("Session", func() {
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
sessionRunner = NewMockSessionRunner(mockCtrl) sessionRunner = NewMockSessionRunner(mockCtrl)
sessionRunner.EXPECT().GetStatelessResetToken(gomock.Any()) sessionRunner.EXPECT().Add(gomock.Any(), gomock.Any())
mconn = newMockConnection() mconn = newMockConnection()
tokenGenerator, err := handshake.NewTokenGenerator() tokenGenerator, err := handshake.NewTokenGenerator()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())