mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
move initialization of the client's transport parameters to the session
This commit is contained in:
parent
b64535e656
commit
35ea8213c5
10 changed files with 37 additions and 89 deletions
|
@ -368,6 +368,8 @@ func (c *client) createNewTLSSession(_ protocol.VersionNumber) {
|
||||||
c.version,
|
c.version,
|
||||||
)
|
)
|
||||||
c.mutex.Unlock()
|
c.mutex.Unlock()
|
||||||
|
// It's not possible to use the stateless reset token for the client's (first) connection ID,
|
||||||
|
// since there's no way to securely communicate it to the server.
|
||||||
c.packetHandlers.Add(c.srcConnID, c)
|
c.packetHandlers.Add(c.srcConnID, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -338,45 +338,6 @@ var _ = Describe("Client", func() {
|
||||||
Eventually(dialed).Should(BeClosed())
|
Eventually(dialed).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("removes closed sessions from the multiplexer", func() {
|
|
||||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
|
||||||
manager.EXPECT().Add(connID, gomock.Any())
|
|
||||||
manager.EXPECT().Retire(connID)
|
|
||||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil)
|
|
||||||
|
|
||||||
var runner sessionRunner
|
|
||||||
sess := NewMockQuicSession(mockCtrl)
|
|
||||||
newClientSession = func(
|
|
||||||
_ connection,
|
|
||||||
runnerP sessionRunner,
|
|
||||||
_ protocol.ConnectionID,
|
|
||||||
_ protocol.ConnectionID,
|
|
||||||
_ *Config,
|
|
||||||
_ *tls.Config,
|
|
||||||
_ protocol.PacketNumber,
|
|
||||||
_ protocol.VersionNumber,
|
|
||||||
_ utils.Logger,
|
|
||||||
_ protocol.VersionNumber,
|
|
||||||
) quicSession {
|
|
||||||
runner = runnerP
|
|
||||||
return sess
|
|
||||||
}
|
|
||||||
sess.EXPECT().run().Do(func() {
|
|
||||||
runner.Retire(connID)
|
|
||||||
})
|
|
||||||
sess.EXPECT().HandshakeComplete().Return(context.Background())
|
|
||||||
|
|
||||||
_, err := DialContext(
|
|
||||||
context.Background(),
|
|
||||||
packetConn,
|
|
||||||
addr,
|
|
||||||
"localhost:1337",
|
|
||||||
tlsConf,
|
|
||||||
&Config{},
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("closes the connection when it was created by DialAddr", func() {
|
It("closes the connection when it was created by DialAddr", func() {
|
||||||
if os.Getenv("APPVEYOR") == "True" {
|
if os.Getenv("APPVEYOR") == "True" {
|
||||||
Skip("This test is flaky on AppVeyor.")
|
Skip("This test is flaky on AppVeyor.")
|
||||||
|
|
|
@ -35,9 +35,11 @@ func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorde
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add mocks base method
|
// Add mocks base method
|
||||||
func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) {
|
func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) [16]byte {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "Add", arg0, arg1)
|
ret := m.ctrl.Call(m, "Add", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].([16]byte)
|
||||||
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add indicates an expected call of Add
|
// Add indicates an expected call of Add
|
||||||
|
@ -84,20 +86,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStatelessResetToken mocks base method
|
|
||||||
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) [16]byte {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0)
|
|
||||||
ret0, _ := ret[0].([16]byte)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetStatelessResetToken indicates an expected call of GetStatelessResetToken
|
|
||||||
func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove mocks base method
|
// Remove mocks base method
|
||||||
func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) {
|
func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -34,6 +34,20 @@ func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder {
|
||||||
return m.recorder
|
return m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add mocks base method
|
||||||
|
func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) [16]byte {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Add", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].([16]byte)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add indicates an expected call of Add
|
||||||
|
func (mr *MockSessionRunnerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSessionRunner)(nil).Add), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
// AddResetToken mocks base method
|
// AddResetToken mocks base method
|
||||||
func (m *MockSessionRunner) AddResetToken(arg0 [16]byte, arg1 packetHandler) {
|
func (m *MockSessionRunner) AddResetToken(arg0 [16]byte, arg1 packetHandler) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -46,20 +60,6 @@ func (mr *MockSessionRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) *
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockSessionRunner)(nil).AddResetToken), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockSessionRunner)(nil).AddResetToken), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStatelessResetToken mocks base method
|
|
||||||
func (m *MockSessionRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) [16]byte {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0)
|
|
||||||
ret0, _ := ret[0].([16]byte)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetStatelessResetToken indicates an expected call of GetStatelessResetToken
|
|
||||||
func (mr *MockSessionRunnerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockSessionRunner)(nil).GetStatelessResetToken), arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove mocks base method
|
// Remove mocks base method
|
||||||
func (m *MockSessionRunner) Remove(arg0 protocol.ConnectionID) {
|
func (m *MockSessionRunner) Remove(arg0 protocol.ConnectionID) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -95,10 +95,11 @@ func (h *packetHandlerMap) logUsage() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
|
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) [16]byte {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
h.handlers[string(id)] = handler
|
h.handlers[string(id)] = handler
|
||||||
h.mutex.Unlock()
|
h.mutex.Unlock()
|
||||||
|
return h.getStatelessResetToken(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
||||||
|
@ -283,7 +284,7 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte {
|
func (h *packetHandlerMap) getStatelessResetToken(connID protocol.ConnectionID) [16]byte {
|
||||||
var token [16]byte
|
var token [16]byte
|
||||||
if !h.statelessResetEnabled {
|
if !h.statelessResetEnabled {
|
||||||
// Return a random stateless reset token.
|
// Return a random stateless reset token.
|
||||||
|
@ -310,7 +311,7 @@ func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID pro
|
||||||
if len(p.data) <= protocol.MinStatelessResetSize {
|
if len(p.data) <= protocol.MinStatelessResetSize {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
token := h.GetStatelessResetToken(connID)
|
token := h.getStatelessResetToken(connID)
|
||||||
h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
|
h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
|
||||||
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
|
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
|
||||||
rand.Read(data)
|
rand.Read(data)
|
||||||
|
|
|
@ -292,9 +292,9 @@ var _ = Describe("Packet Handler Map", func() {
|
||||||
It("generates stateless reset tokens", func() {
|
It("generates stateless reset tokens", func() {
|
||||||
connID1 := []byte{0xde, 0xad, 0xbe, 0xef}
|
connID1 := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||||
connID2 := []byte{0xde, 0xca, 0xfb, 0xad}
|
connID2 := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||||
token1 := handler.GetStatelessResetToken(connID1)
|
token1 := handler.Add(connID1, nil)
|
||||||
Expect(handler.GetStatelessResetToken(connID1)).To(Equal(token1))
|
Expect(handler.Add(connID1, nil)).To(Equal(token1))
|
||||||
Expect(handler.GetStatelessResetToken(connID2)).ToNot(Equal(token1))
|
Expect(handler.Add(connID2, nil)).ToNot(Equal(token1))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends stateless resets", func() {
|
It("sends stateless resets", func() {
|
||||||
|
|
16
server.go
16
server.go
|
@ -34,7 +34,6 @@ type unknownPacketHandler interface {
|
||||||
|
|
||||||
type packetHandlerManager interface {
|
type packetHandlerManager interface {
|
||||||
io.Closer
|
io.Closer
|
||||||
Add(protocol.ConnectionID, packetHandler)
|
|
||||||
SetServer(unknownPacketHandler)
|
SetServer(unknownPacketHandler)
|
||||||
CloseServer()
|
CloseServer()
|
||||||
sessionRunner
|
sessionRunner
|
||||||
|
@ -364,7 +363,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet
|
||||||
|
|
||||||
s.logger.Debugf("<- Received Initial packet.")
|
s.logger.Debugf("<- Received Initial packet.")
|
||||||
|
|
||||||
sess, connID, err := s.handleInitialImpl(p, hdr)
|
sess, err := s.handleInitialImpl(p, hdr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("Error occurred handling initial packet: %s", err)
|
s.logger.Errorf("Error occurred handling initial packet: %s", err)
|
||||||
return false
|
return false
|
||||||
|
@ -374,13 +373,12 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet
|
||||||
}
|
}
|
||||||
// Don't put the packet buffer back if a new session was created.
|
// Don't put the packet buffer back if a new session was created.
|
||||||
// The session will handle the packet and take of that.
|
// The session will handle the packet and take of that.
|
||||||
s.sessionHandler.Add(connID, sess)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, protocol.ConnectionID, error) {
|
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, error) {
|
||||||
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||||
return nil, nil, errors.New("too short connection ID")
|
return nil, errors.New("too short connection ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
var token *Token
|
var token *Token
|
||||||
|
@ -400,17 +398,17 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
|
||||||
// Log the Initial packet now.
|
// Log the Initial packet now.
|
||||||
// If no Retry is sent, the packet will be logged by the session.
|
// If no Retry is sent, the packet will be logged by the session.
|
||||||
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
|
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
|
||||||
return nil, nil, s.sendRetry(p.remoteAddr, hdr)
|
return nil, s.sendRetry(p.remoteAddr, hdr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if queueLen := atomic.LoadInt32(&s.sessionQueueLen); queueLen >= protocol.MaxAcceptQueueSize {
|
if queueLen := atomic.LoadInt32(&s.sessionQueueLen); queueLen >= protocol.MaxAcceptQueueSize {
|
||||||
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
|
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
|
||||||
return nil, nil, s.sendServerBusy(p.remoteAddr, hdr)
|
return nil, s.sendServerBusy(p.remoteAddr, hdr)
|
||||||
}
|
}
|
||||||
|
|
||||||
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
|
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
s.logger.Debugf("Changing connection ID to %s.", connID)
|
s.logger.Debugf("Changing connection ID to %s.", connID)
|
||||||
sess := s.createNewSession(
|
sess := s.createNewSession(
|
||||||
|
@ -422,7 +420,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
|
||||||
hdr.Version,
|
hdr.Version,
|
||||||
)
|
)
|
||||||
sess.handlePacket(p)
|
sess.handlePacket(p)
|
||||||
return sess, connID, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *baseServer) createNewSession(
|
func (s *baseServer) createNewSession(
|
||||||
|
|
|
@ -433,7 +433,6 @@ var _ = Describe("Server", func() {
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
|
|
||||||
// make the go routine return
|
// make the go routine return
|
||||||
sess.EXPECT().getPerspective()
|
|
||||||
Expect(serv.Close()).To(Succeed())
|
Expect(serv.Close()).To(Succeed())
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -674,7 +673,6 @@ var _ = Describe("Server", func() {
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
|
|
||||||
// make the go routine return
|
// make the go routine return
|
||||||
sess.EXPECT().getPerspective()
|
|
||||||
Expect(serv.Close()).To(Succeed())
|
Expect(serv.Close()).To(Succeed())
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
|
@ -73,13 +73,13 @@ func (p *receivedPacket) Clone() *receivedPacket {
|
||||||
}
|
}
|
||||||
|
|
||||||
type sessionRunner interface {
|
type sessionRunner interface {
|
||||||
|
Add(protocol.ConnectionID, packetHandler) [16]byte
|
||||||
Retire(protocol.ConnectionID)
|
Retire(protocol.ConnectionID)
|
||||||
Remove(protocol.ConnectionID)
|
Remove(protocol.ConnectionID)
|
||||||
ReplaceWithClosed(protocol.ConnectionID, packetHandler)
|
ReplaceWithClosed(protocol.ConnectionID, packetHandler)
|
||||||
AddResetToken([16]byte, packetHandler)
|
AddResetToken([16]byte, packetHandler)
|
||||||
RemoveResetToken([16]byte)
|
RemoveResetToken([16]byte)
|
||||||
RetireResetToken([16]byte)
|
RetireResetToken([16]byte)
|
||||||
GetStatelessResetToken(protocol.ConnectionID) [16]byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type handshakeRunner struct {
|
type handshakeRunner struct {
|
||||||
|
@ -220,7 +220,7 @@ var newSession = func(
|
||||||
initialStream := newCryptoStream()
|
initialStream := newCryptoStream()
|
||||||
handshakeStream := newCryptoStream()
|
handshakeStream := newCryptoStream()
|
||||||
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
|
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
|
||||||
token := s.sessionRunner.GetStatelessResetToken(srcConnID)
|
token := s.sessionRunner.Add(srcConnID, s)
|
||||||
params := &handshake.TransportParameters{
|
params := &handshake.TransportParameters{
|
||||||
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
|
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
|
||||||
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
|
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
|
||||||
|
|
|
@ -104,7 +104,7 @@ var _ = Describe("Session", func() {
|
||||||
Eventually(areSessionsRunning).Should(BeFalse())
|
Eventually(areSessionsRunning).Should(BeFalse())
|
||||||
|
|
||||||
sessionRunner = NewMockSessionRunner(mockCtrl)
|
sessionRunner = NewMockSessionRunner(mockCtrl)
|
||||||
sessionRunner.EXPECT().GetStatelessResetToken(gomock.Any())
|
sessionRunner.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||||
mconn = newMockConnection()
|
mconn = newMockConnection()
|
||||||
tokenGenerator, err := handshake.NewTokenGenerator()
|
tokenGenerator, err := handshake.NewTokenGenerator()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue