only create a single session for duplicate Initials

This commit is contained in:
Marten Seemann 2020-05-29 15:54:25 +07:00
parent 85c19fbb5a
commit dad30e7c33
6 changed files with 123 additions and 58 deletions

View file

@ -271,8 +271,6 @@ func (c *client) dial(ctx context.Context) error {
c.version, c.version,
) )
c.mutex.Unlock() c.mutex.Unlock()
// It's not possible to use the stateless reset token for the client's (first) connection ID,
// since there's no way to securely communicate it to the server.
c.packetHandlers.Add(c.srcConnID, c) c.packetHandlers.Add(c.srcConnID, c)
errorChan := make(chan error, 1) errorChan := make(chan error, 1)

View file

@ -60,6 +60,20 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1)
} }
// AddWithConnID mocks base method
func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() packetHandler) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2)
ret0, _ := ret[0].(bool)
return ret0
}
// AddWithConnID indicates an expected call of AddWithConnID
func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2)
}
// CloseServer mocks base method // CloseServer mocks base method
func (m *MockPacketHandlerManager) CloseServer() { func (m *MockPacketHandlerManager) CloseServer() {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -118,6 +118,21 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
return true return true
} }
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool {
sid := string(clientDestConnID)
h.mutex.Lock()
defer h.mutex.Unlock()
if _, ok := h.handlers[sid]; ok {
return false
}
sess := fn()
h.handlers[sid] = sess
h.handlers[string(newConnID)] = sess
return true
}
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
h.mutex.Lock() h.mutex.Lock()
delete(h.handlers, string(id)) delete(h.handlers, string(id))

View file

@ -181,6 +181,14 @@ var _ = Describe("Packet Handler Map", func() {
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
}) })
It("says if a connection ID is already taken, for AddWithConnID", func() {
clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
newConnID1 := protocol.ConnectionID{1, 2, 3, 4}
newConnID2 := protocol.ConnectionID{4, 3, 2, 1}
Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue())
Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse())
})
}) })
Context("running a server", func() { Context("running a server", func() {

View file

@ -34,6 +34,7 @@ type unknownPacketHandler interface {
} }
type packetHandlerManager interface { type packetHandlerManager interface {
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
Destroy() error Destroy() error
sessionRunner sessionRunner
SetServer(unknownPacketHandler) SetServer(unknownPacketHandler)
@ -421,39 +422,39 @@ func (s *baseServer) createNewSession(
srcConnID protocol.ConnectionID, srcConnID protocol.ConnectionID,
version protocol.VersionNumber, version protocol.VersionNumber,
) quicSession { ) quicSession {
var qlogger qlog.Tracer var sess quicSession
if s.config.GetLogWriter != nil { if added := s.sessionHandler.AddWithConnID(clientDestConnID, srcConnID, func() packetHandler {
// Use the same connection ID that is passed to the client's GetLogWriter callback. var qlogger qlog.Tracer
connID := clientDestConnID if s.config.GetLogWriter != nil {
if origDestConnID.Len() > 0 { // Use the same connection ID that is passed to the client's GetLogWriter callback.
connID = origDestConnID connID := clientDestConnID
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
if w := s.config.GetLogWriter(connID); w != nil {
qlogger = qlog.NewTracer(w, protocol.PerspectiveServer, connID)
}
} }
if w := s.config.GetLogWriter(connID); w != nil { sess = s.newSession(
qlogger = qlog.NewTracer(w, protocol.PerspectiveServer, connID) &conn{pconn: s.conn, currentAddr: remoteAddr},
} s.sessionHandler,
} origDestConnID,
sess := s.newSession( clientDestConnID,
&conn{pconn: s.conn, currentAddr: remoteAddr}, destConnID,
s.sessionHandler, srcConnID,
origDestConnID, s.sessionHandler.GetStatelessResetToken(srcConnID),
clientDestConnID, s.config,
destConnID, s.tlsConf,
srcConnID, s.tokenGenerator,
s.sessionHandler.GetStatelessResetToken(srcConnID), s.acceptEarlySessions,
s.config, qlogger,
s.tlsConf, s.logger,
s.tokenGenerator, version,
s.acceptEarlySessions, )
qlogger, return sess
s.logger, }); !added {
version,
)
if added := s.sessionHandler.Add(clientDestConnID, sess); !added {
// We're already keeping track of this connection ID.
// This might happen if we receive two copies of the Initial at the same time.
return nil return nil
} }
s.sessionHandler.Add(srcConnID, sess)
go sess.run() go sess.run()
go s.handleNewSession(sess) go s.handleNewSession(sess)
return sess return sess

View file

@ -403,10 +403,7 @@ var _ = Describe("Server", func() {
var token [16]byte var token [16]byte
rand.Read(token[:]) rand.Read(token[:])
var newConnID protocol.ConnectionID var newConnID protocol.ConnectionID
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
newConnID = c
return token
})
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
serv.newSession = func( serv.newSession = func(
_ connection, _ connection,
@ -439,9 +436,13 @@ var _ = Describe("Server", func() {
return sess return sess
} }
phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true) phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().Add(gomock.Any(), sess).DoAndReturn(func(c protocol.ConnectionID, _ packetHandler) bool { newConnID = c
Expect(c).To(Equal(newConnID)) phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
newConnID = c
return token
})
fn()
return true return true
}) })
@ -502,14 +503,20 @@ var _ = Describe("Server", func() {
Expect(serv.handlePacketImpl(zeroRTTPacket)).To(BeTrue()) Expect(serv.handlePacketImpl(zeroRTTPacket)).To(BeTrue())
// Then receive the Initial packet. // Then receive the Initial packet.
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
fn()
return true
})
Expect(serv.handlePacketImpl(initialPacket)).To(BeTrue()) Expect(serv.handlePacketImpl(initialPacket)).To(BeTrue())
Expect(createdSession).To(BeTrue()) Expect(createdSession).To(BeTrue())
}) })
It("drops packets if the receive queue is full", func() { It("drops packets if the receive queue is full", func() {
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().Add(gomock.Any(), gomock.Any()).AnyTimes() phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
}).AnyTimes()
serv.config.AcceptToken = func(net.Addr, *Token) bool { return true } serv.config.AcceptToken = func(net.Addr, *Token) bool { return true }
acceptSession := make(chan struct{}) acceptSession := make(chan struct{})
@ -532,7 +539,12 @@ var _ = Describe("Server", func() {
) quicSession { ) quicSession {
<-acceptSession <-acceptSession
atomic.AddUint32(&counter, 1) atomic.AddUint32(&counter, 1)
return nil sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
sess.EXPECT().run()
sess.EXPECT().Context().Return(context.Background())
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess
} }
serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})) serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}))
@ -577,10 +589,9 @@ var _ = Describe("Server", func() {
} }
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}) p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false)
phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false)
Expect(serv.handlePacketImpl(p)).To(BeTrue()) Expect(serv.handlePacketImpl(p)).To(BeTrue())
Expect(createdSession).To(BeTrue()) Expect(createdSession).To(BeFalse())
}) })
It("rejects new connection attempts if the accept queue is full", func() { It("rejects new connection attempts if the accept queue is full", func() {
@ -612,8 +623,11 @@ var _ = Describe("Server", func() {
return sess return sess
} }
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2 * protocol.MaxAcceptQueueSize) phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
}).Times(protocol.MaxAcceptQueueSize)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(protocol.MaxAcceptQueueSize) wg.Add(protocol.MaxAcceptQueueSize)
@ -673,8 +687,11 @@ var _ = Describe("Server", func() {
return sess return sess
} }
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2) phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})
serv.handlePacket(p) serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive()) Consistently(conn.dataWritten).ShouldNot(Receive())
@ -772,8 +789,11 @@ var _ = Describe("Server", func() {
sess.EXPECT().Context().Return(context.Background()) sess.EXPECT().Context().Return(context.Background())
return sess return sess
} }
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2) phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
cancel() // complete the handshake cancel() // complete the handshake
@ -836,8 +856,11 @@ var _ = Describe("Server", func() {
sess.EXPECT().Context().Return(context.Background()) sess.EXPECT().Context().Return(context.Background())
return sess return sess
} }
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2) phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
close(ready) close(ready)
@ -874,8 +897,11 @@ var _ = Describe("Server", func() {
return sess return sess
} }
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2 * protocol.MaxAcceptQueueSize) phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
}).Times(protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ { for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
serv.handlePacket(getInitialWithRandomDestConnID()) serv.handlePacket(getInitialWithRandomDestConnID())
} }
@ -927,8 +953,11 @@ var _ = Describe("Server", func() {
return sess return sess
} }
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2) phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})
serv.handlePacket(p) serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive()) Consistently(conn.dataWritten).ShouldNot(Receive())
Eventually(sessionCreated).Should(BeClosed()) Eventually(sessionCreated).Should(BeClosed())