refactor session creation in the server

This commit is contained in:
Marten Seemann 2021-03-08 12:46:22 +08:00
parent bab0384444
commit ecc86aa1ab
2 changed files with 32 additions and 43 deletions

View file

@ -401,10 +401,10 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
}
var (
token *Token
retrySrcConnectionID *protocol.ConnectionID
token *Token
retrySrcConnID *protocol.ConnectionID
)
origDestConnectionID := hdr.DestConnectionID
origDestConnID := hdr.DestConnectionID
if len(hdr.Token) > 0 {
c, err := s.tokenGenerator.DecodeToken(hdr.Token)
if err == nil {
@ -414,8 +414,8 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
SentTime: c.SentTime,
}
if token.IsRetryToken {
origDestConnectionID = c.OriginalDestConnectionID
retrySrcConnectionID = &c.RetrySrcConnectionID
origDestConnID = c.OriginalDestConnectionID
retrySrcConnID = &c.RetrySrcConnectionID
}
}
}
@ -451,68 +451,47 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
return err
}
s.logger.Debugf("Changing connection ID to %s.", connID)
sess := s.createNewSession(
p.remoteAddr,
origDestConnectionID,
retrySrcConnectionID,
hdr.DestConnectionID,
hdr.SrcConnectionID,
connID,
hdr.Version,
)
if sess == nil {
p.buffer.Release()
return nil
}
sess.handlePacket(p)
s.zeroRTTQueue.DequeueToSession(hdr.DestConnectionID, sess)
return nil
}
func (s *baseServer) createNewSession(
remoteAddr net.Addr,
origDestConnID protocol.ConnectionID,
retrySrcConnID *protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
version protocol.VersionNumber,
) quicSession {
var sess quicSession
if added := s.sessionHandler.AddWithConnID(clientDestConnID, srcConnID, func() packetHandler {
if added := s.sessionHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler {
var tracer logging.ConnectionTracer
if s.config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
connID := clientDestConnID
connID := hdr.DestConnectionID
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
tracer = s.config.Tracer.TracerForConnection(protocol.PerspectiveServer, connID)
}
sess = s.newSession(
newSendConn(s.conn, remoteAddr),
newSendConn(s.conn, p.remoteAddr),
s.sessionHandler,
origDestConnID,
retrySrcConnID,
clientDestConnID,
destConnID,
srcConnID,
s.sessionHandler.GetStatelessResetToken(srcConnID),
hdr.DestConnectionID,
hdr.SrcConnectionID,
connID,
s.sessionHandler.GetStatelessResetToken(connID),
s.config,
s.tlsConf,
s.tokenGenerator,
s.acceptEarlySessions,
tracer,
s.logger,
version,
hdr.Version,
)
sess.handlePacket(p)
return sess
}); !added {
return nil
}
go sess.run()
go s.handleNewSession(sess)
return sess
if sess == nil {
p.buffer.Release()
return nil
}
s.zeroRTTQueue.DequeueToSession(hdr.DestConnectionID, sess)
return nil
}
func (s *baseServer) handleNewSession(sess quicSession) {

View file

@ -958,6 +958,7 @@ var _ = Describe("Server", func() {
}()
ctx, cancel := context.WithCancel(context.Background()) // handshake context
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
serv.newSession = func(
_ sendConn,
runner sessionRunner,
@ -975,6 +976,7 @@ var _ = Describe("Server", func() {
_ utils.Logger,
_ protocol.VersionNumber,
) quicSession {
sess.EXPECT().handlePacket(gomock.Any())
sess.EXPECT().HandshakeComplete().Return(ctx)
sess.EXPECT().run().Do(func() {})
sess.EXPECT().Context().Return(context.Background())
@ -986,7 +988,10 @@ var _ = Describe("Server", func() {
return true
})
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any())
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, nil, protocol.VersionWhatever)
serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}},
)
Consistently(done).ShouldNot(BeClosed())
cancel() // complete the handshake
Eventually(done).Should(BeClosed())
@ -1026,6 +1031,7 @@ var _ = Describe("Server", func() {
}()
ready := make(chan struct{})
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
serv.newSession = func(
_ sendConn,
runner sessionRunner,
@ -1044,6 +1050,7 @@ var _ = Describe("Server", func() {
_ protocol.VersionNumber,
) quicSession {
Expect(enable0RTT).To(BeTrue())
sess.EXPECT().handlePacket(gomock.Any())
sess.EXPECT().run().Do(func() {})
sess.EXPECT().earlySessionReady().Return(ready)
sess.EXPECT().Context().Return(context.Background())
@ -1054,7 +1061,10 @@ var _ = Describe("Server", func() {
fn()
return true
})
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, nil, protocol.VersionWhatever)
serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}},
)
Consistently(done).ShouldNot(BeClosed())
close(ready)
Eventually(done).Should(BeClosed())