mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
move initialization of the client's transport parameters to the session
This commit is contained in:
parent
b64535e656
commit
35ea8213c5
10 changed files with 37 additions and 89 deletions
|
@ -368,6 +368,8 @@ func (c *client) createNewTLSSession(_ protocol.VersionNumber) {
|
|||
c.version,
|
||||
)
|
||||
c.mutex.Unlock()
|
||||
// It's not possible to use the stateless reset token for the client's (first) connection ID,
|
||||
// since there's no way to securely communicate it to the server.
|
||||
c.packetHandlers.Add(c.srcConnID, c)
|
||||
}
|
||||
|
||||
|
|
|
@ -338,45 +338,6 @@ var _ = Describe("Client", func() {
|
|||
Eventually(dialed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("removes closed sessions from the multiplexer", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(connID, gomock.Any())
|
||||
manager.EXPECT().Retire(connID)
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
var runner sessionRunner
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
runnerP sessionRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicSession {
|
||||
runner = runnerP
|
||||
return sess
|
||||
}
|
||||
sess.EXPECT().run().Do(func() {
|
||||
runner.Retire(connID)
|
||||
})
|
||||
sess.EXPECT().HandshakeComplete().Return(context.Background())
|
||||
|
||||
_, err := DialContext(
|
||||
context.Background(),
|
||||
packetConn,
|
||||
addr,
|
||||
"localhost:1337",
|
||||
tlsConf,
|
||||
&Config{},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("closes the connection when it was created by DialAddr", func() {
|
||||
if os.Getenv("APPVEYOR") == "True" {
|
||||
Skip("This test is flaky on AppVeyor.")
|
||||
|
|
|
@ -35,9 +35,11 @@ func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorde
|
|||
}
|
||||
|
||||
// Add mocks base method
|
||||
func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) {
|
||||
func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) [16]byte {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Add", arg0, arg1)
|
||||
ret := m.ctrl.Call(m, "Add", arg0, arg1)
|
||||
ret0, _ := ret[0].([16]byte)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Add indicates an expected call of Add
|
||||
|
@ -84,20 +86,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
|
||||
}
|
||||
|
||||
// GetStatelessResetToken mocks base method
|
||||
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) [16]byte {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0)
|
||||
ret0, _ := ret[0].([16]byte)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetStatelessResetToken indicates an expected call of GetStatelessResetToken
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0)
|
||||
}
|
||||
|
||||
// Remove mocks base method
|
||||
func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -34,6 +34,20 @@ func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder {
|
|||
return m.recorder
|
||||
}
|
||||
|
||||
// Add mocks base method
|
||||
func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) [16]byte {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Add", arg0, arg1)
|
||||
ret0, _ := ret[0].([16]byte)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Add indicates an expected call of Add
|
||||
func (mr *MockSessionRunnerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSessionRunner)(nil).Add), arg0, arg1)
|
||||
}
|
||||
|
||||
// AddResetToken mocks base method
|
||||
func (m *MockSessionRunner) AddResetToken(arg0 [16]byte, arg1 packetHandler) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -46,20 +60,6 @@ func (mr *MockSessionRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) *
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockSessionRunner)(nil).AddResetToken), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetStatelessResetToken mocks base method
|
||||
func (m *MockSessionRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) [16]byte {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0)
|
||||
ret0, _ := ret[0].([16]byte)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetStatelessResetToken indicates an expected call of GetStatelessResetToken
|
||||
func (mr *MockSessionRunnerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockSessionRunner)(nil).GetStatelessResetToken), arg0)
|
||||
}
|
||||
|
||||
// Remove mocks base method
|
||||
func (m *MockSessionRunner) Remove(arg0 protocol.ConnectionID) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -95,10 +95,11 @@ func (h *packetHandlerMap) logUsage() {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
|
||||
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) [16]byte {
|
||||
h.mutex.Lock()
|
||||
h.handlers[string(id)] = handler
|
||||
h.mutex.Unlock()
|
||||
return h.getStatelessResetToken(id)
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
||||
|
@ -283,7 +284,7 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte {
|
||||
func (h *packetHandlerMap) getStatelessResetToken(connID protocol.ConnectionID) [16]byte {
|
||||
var token [16]byte
|
||||
if !h.statelessResetEnabled {
|
||||
// Return a random stateless reset token.
|
||||
|
@ -310,7 +311,7 @@ func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID pro
|
|||
if len(p.data) <= protocol.MinStatelessResetSize {
|
||||
return
|
||||
}
|
||||
token := h.GetStatelessResetToken(connID)
|
||||
token := h.getStatelessResetToken(connID)
|
||||
h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
|
||||
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
|
||||
rand.Read(data)
|
||||
|
|
|
@ -292,9 +292,9 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
It("generates stateless reset tokens", func() {
|
||||
connID1 := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||
connID2 := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
token1 := handler.GetStatelessResetToken(connID1)
|
||||
Expect(handler.GetStatelessResetToken(connID1)).To(Equal(token1))
|
||||
Expect(handler.GetStatelessResetToken(connID2)).ToNot(Equal(token1))
|
||||
token1 := handler.Add(connID1, nil)
|
||||
Expect(handler.Add(connID1, nil)).To(Equal(token1))
|
||||
Expect(handler.Add(connID2, nil)).ToNot(Equal(token1))
|
||||
})
|
||||
|
||||
It("sends stateless resets", func() {
|
||||
|
|
16
server.go
16
server.go
|
@ -34,7 +34,6 @@ type unknownPacketHandler interface {
|
|||
|
||||
type packetHandlerManager interface {
|
||||
io.Closer
|
||||
Add(protocol.ConnectionID, packetHandler)
|
||||
SetServer(unknownPacketHandler)
|
||||
CloseServer()
|
||||
sessionRunner
|
||||
|
@ -364,7 +363,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet
|
|||
|
||||
s.logger.Debugf("<- Received Initial packet.")
|
||||
|
||||
sess, connID, err := s.handleInitialImpl(p, hdr)
|
||||
sess, err := s.handleInitialImpl(p, hdr)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Error occurred handling initial packet: %s", err)
|
||||
return false
|
||||
|
@ -374,13 +373,12 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet
|
|||
}
|
||||
// Don't put the packet buffer back if a new session was created.
|
||||
// The session will handle the packet and take of that.
|
||||
s.sessionHandler.Add(connID, sess)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, protocol.ConnectionID, error) {
|
||||
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, error) {
|
||||
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||
return nil, nil, errors.New("too short connection ID")
|
||||
return nil, errors.New("too short connection ID")
|
||||
}
|
||||
|
||||
var token *Token
|
||||
|
@ -400,17 +398,17 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
|
|||
// Log the Initial packet now.
|
||||
// If no Retry is sent, the packet will be logged by the session.
|
||||
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
|
||||
return nil, nil, s.sendRetry(p.remoteAddr, hdr)
|
||||
return nil, s.sendRetry(p.remoteAddr, hdr)
|
||||
}
|
||||
|
||||
if queueLen := atomic.LoadInt32(&s.sessionQueueLen); queueLen >= protocol.MaxAcceptQueueSize {
|
||||
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
|
||||
return nil, nil, s.sendServerBusy(p.remoteAddr, hdr)
|
||||
return nil, s.sendServerBusy(p.remoteAddr, hdr)
|
||||
}
|
||||
|
||||
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
s.logger.Debugf("Changing connection ID to %s.", connID)
|
||||
sess := s.createNewSession(
|
||||
|
@ -422,7 +420,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
|
|||
hdr.Version,
|
||||
)
|
||||
sess.handlePacket(p)
|
||||
return sess, connID, nil
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (s *baseServer) createNewSession(
|
||||
|
|
|
@ -433,7 +433,6 @@ var _ = Describe("Server", func() {
|
|||
Consistently(done).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
sess.EXPECT().getPerspective()
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -674,7 +673,6 @@ var _ = Describe("Server", func() {
|
|||
Consistently(done).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
sess.EXPECT().getPerspective()
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
|
|
@ -73,13 +73,13 @@ func (p *receivedPacket) Clone() *receivedPacket {
|
|||
}
|
||||
|
||||
type sessionRunner interface {
|
||||
Add(protocol.ConnectionID, packetHandler) [16]byte
|
||||
Retire(protocol.ConnectionID)
|
||||
Remove(protocol.ConnectionID)
|
||||
ReplaceWithClosed(protocol.ConnectionID, packetHandler)
|
||||
AddResetToken([16]byte, packetHandler)
|
||||
RemoveResetToken([16]byte)
|
||||
RetireResetToken([16]byte)
|
||||
GetStatelessResetToken(protocol.ConnectionID) [16]byte
|
||||
}
|
||||
|
||||
type handshakeRunner struct {
|
||||
|
@ -220,7 +220,7 @@ var newSession = func(
|
|||
initialStream := newCryptoStream()
|
||||
handshakeStream := newCryptoStream()
|
||||
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
|
||||
token := s.sessionRunner.GetStatelessResetToken(srcConnID)
|
||||
token := s.sessionRunner.Add(srcConnID, s)
|
||||
params := &handshake.TransportParameters{
|
||||
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
|
||||
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
|
||||
|
|
|
@ -104,7 +104,7 @@ var _ = Describe("Session", func() {
|
|||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
|
||||
sessionRunner = NewMockSessionRunner(mockCtrl)
|
||||
sessionRunner.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
sessionRunner.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
mconn = newMockConnection()
|
||||
tokenGenerator, err := handshake.NewTokenGenerator()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue