mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
only create a single session for duplicate Initials
This commit is contained in:
parent
85c19fbb5a
commit
dad30e7c33
6 changed files with 123 additions and 58 deletions
|
@ -271,8 +271,6 @@ func (c *client) dial(ctx context.Context) error {
|
||||||
c.version,
|
c.version,
|
||||||
)
|
)
|
||||||
c.mutex.Unlock()
|
c.mutex.Unlock()
|
||||||
// It's not possible to use the stateless reset token for the client's (first) connection ID,
|
|
||||||
// since there's no way to securely communicate it to the server.
|
|
||||||
c.packetHandlers.Add(c.srcConnID, c)
|
c.packetHandlers.Add(c.srcConnID, c)
|
||||||
|
|
||||||
errorChan := make(chan error, 1)
|
errorChan := make(chan error, 1)
|
||||||
|
|
|
@ -60,6 +60,20 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interfa
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddWithConnID mocks base method
|
||||||
|
func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() packetHandler) bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddWithConnID indicates an expected call of AddWithConnID
|
||||||
|
func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2)
|
||||||
|
}
|
||||||
|
|
||||||
// CloseServer mocks base method
|
// CloseServer mocks base method
|
||||||
func (m *MockPacketHandlerManager) CloseServer() {
|
func (m *MockPacketHandlerManager) CloseServer() {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -118,6 +118,21 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool {
|
||||||
|
sid := string(clientDestConnID)
|
||||||
|
h.mutex.Lock()
|
||||||
|
defer h.mutex.Unlock()
|
||||||
|
|
||||||
|
if _, ok := h.handlers[sid]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
sess := fn()
|
||||||
|
h.handlers[sid] = sess
|
||||||
|
h.handlers[string(newConnID)] = sess
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
delete(h.handlers, string(id))
|
delete(h.handlers, string(id))
|
||||||
|
|
|
@ -181,6 +181,14 @@ var _ = Describe("Packet Handler Map", func() {
|
||||||
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
|
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
|
||||||
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
|
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("says if a connection ID is already taken, for AddWithConnID", func() {
|
||||||
|
clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||||
|
newConnID1 := protocol.ConnectionID{1, 2, 3, 4}
|
||||||
|
newConnID2 := protocol.ConnectionID{4, 3, 2, 1}
|
||||||
|
Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue())
|
||||||
|
Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse())
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("running a server", func() {
|
Context("running a server", func() {
|
||||||
|
|
61
server.go
61
server.go
|
@ -34,6 +34,7 @@ type unknownPacketHandler interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type packetHandlerManager interface {
|
type packetHandlerManager interface {
|
||||||
|
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
|
||||||
Destroy() error
|
Destroy() error
|
||||||
sessionRunner
|
sessionRunner
|
||||||
SetServer(unknownPacketHandler)
|
SetServer(unknownPacketHandler)
|
||||||
|
@ -421,39 +422,39 @@ func (s *baseServer) createNewSession(
|
||||||
srcConnID protocol.ConnectionID,
|
srcConnID protocol.ConnectionID,
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) quicSession {
|
) quicSession {
|
||||||
var qlogger qlog.Tracer
|
var sess quicSession
|
||||||
if s.config.GetLogWriter != nil {
|
if added := s.sessionHandler.AddWithConnID(clientDestConnID, srcConnID, func() packetHandler {
|
||||||
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
var qlogger qlog.Tracer
|
||||||
connID := clientDestConnID
|
if s.config.GetLogWriter != nil {
|
||||||
if origDestConnID.Len() > 0 {
|
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
||||||
connID = origDestConnID
|
connID := clientDestConnID
|
||||||
|
if origDestConnID.Len() > 0 {
|
||||||
|
connID = origDestConnID
|
||||||
|
}
|
||||||
|
if w := s.config.GetLogWriter(connID); w != nil {
|
||||||
|
qlogger = qlog.NewTracer(w, protocol.PerspectiveServer, connID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if w := s.config.GetLogWriter(connID); w != nil {
|
sess = s.newSession(
|
||||||
qlogger = qlog.NewTracer(w, protocol.PerspectiveServer, connID)
|
&conn{pconn: s.conn, currentAddr: remoteAddr},
|
||||||
}
|
s.sessionHandler,
|
||||||
}
|
origDestConnID,
|
||||||
sess := s.newSession(
|
clientDestConnID,
|
||||||
&conn{pconn: s.conn, currentAddr: remoteAddr},
|
destConnID,
|
||||||
s.sessionHandler,
|
srcConnID,
|
||||||
origDestConnID,
|
s.sessionHandler.GetStatelessResetToken(srcConnID),
|
||||||
clientDestConnID,
|
s.config,
|
||||||
destConnID,
|
s.tlsConf,
|
||||||
srcConnID,
|
s.tokenGenerator,
|
||||||
s.sessionHandler.GetStatelessResetToken(srcConnID),
|
s.acceptEarlySessions,
|
||||||
s.config,
|
qlogger,
|
||||||
s.tlsConf,
|
s.logger,
|
||||||
s.tokenGenerator,
|
version,
|
||||||
s.acceptEarlySessions,
|
)
|
||||||
qlogger,
|
return sess
|
||||||
s.logger,
|
}); !added {
|
||||||
version,
|
|
||||||
)
|
|
||||||
if added := s.sessionHandler.Add(clientDestConnID, sess); !added {
|
|
||||||
// We're already keeping track of this connection ID.
|
|
||||||
// This might happen if we receive two copies of the Initial at the same time.
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
s.sessionHandler.Add(srcConnID, sess)
|
|
||||||
go sess.run()
|
go sess.run()
|
||||||
go s.handleNewSession(sess)
|
go s.handleNewSession(sess)
|
||||||
return sess
|
return sess
|
||||||
|
|
|
@ -403,10 +403,7 @@ var _ = Describe("Server", func() {
|
||||||
var token [16]byte
|
var token [16]byte
|
||||||
rand.Read(token[:])
|
rand.Read(token[:])
|
||||||
var newConnID protocol.ConnectionID
|
var newConnID protocol.ConnectionID
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
|
|
||||||
newConnID = c
|
|
||||||
return token
|
|
||||||
})
|
|
||||||
sess := NewMockQuicSession(mockCtrl)
|
sess := NewMockQuicSession(mockCtrl)
|
||||||
serv.newSession = func(
|
serv.newSession = func(
|
||||||
_ connection,
|
_ connection,
|
||||||
|
@ -439,9 +436,13 @@ var _ = Describe("Server", func() {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true)
|
phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||||
phm.EXPECT().Add(gomock.Any(), sess).DoAndReturn(func(c protocol.ConnectionID, _ packetHandler) bool {
|
newConnID = c
|
||||||
Expect(c).To(Equal(newConnID))
|
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
|
||||||
|
newConnID = c
|
||||||
|
return token
|
||||||
|
})
|
||||||
|
fn()
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -502,14 +503,20 @@ var _ = Describe("Server", func() {
|
||||||
Expect(serv.handlePacketImpl(zeroRTTPacket)).To(BeTrue())
|
Expect(serv.handlePacketImpl(zeroRTTPacket)).To(BeTrue())
|
||||||
// Then receive the Initial packet.
|
// Then receive the Initial packet.
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2)
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||||
|
fn()
|
||||||
|
return true
|
||||||
|
})
|
||||||
Expect(serv.handlePacketImpl(initialPacket)).To(BeTrue())
|
Expect(serv.handlePacketImpl(initialPacket)).To(BeTrue())
|
||||||
Expect(createdSession).To(BeTrue())
|
Expect(createdSession).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops packets if the receive queue is full", func() {
|
It("drops packets if the receive queue is full", func() {
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||||
phm.EXPECT().Add(gomock.Any(), gomock.Any()).AnyTimes()
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
|
fn()
|
||||||
|
return true
|
||||||
|
}).AnyTimes()
|
||||||
|
|
||||||
serv.config.AcceptToken = func(net.Addr, *Token) bool { return true }
|
serv.config.AcceptToken = func(net.Addr, *Token) bool { return true }
|
||||||
acceptSession := make(chan struct{})
|
acceptSession := make(chan struct{})
|
||||||
|
@ -532,7 +539,12 @@ var _ = Describe("Server", func() {
|
||||||
) quicSession {
|
) quicSession {
|
||||||
<-acceptSession
|
<-acceptSession
|
||||||
atomic.AddUint32(&counter, 1)
|
atomic.AddUint32(&counter, 1)
|
||||||
return nil
|
sess := NewMockQuicSession(mockCtrl)
|
||||||
|
sess.EXPECT().handlePacket(gomock.Any())
|
||||||
|
sess.EXPECT().run()
|
||||||
|
sess.EXPECT().Context().Return(context.Background())
|
||||||
|
sess.EXPECT().HandshakeComplete().Return(context.Background())
|
||||||
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}))
|
serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}))
|
||||||
|
@ -577,10 +589,9 @@ var _ = Describe("Server", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false)
|
||||||
phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false)
|
|
||||||
Expect(serv.handlePacketImpl(p)).To(BeTrue())
|
Expect(serv.handlePacketImpl(p)).To(BeTrue())
|
||||||
Expect(createdSession).To(BeTrue())
|
Expect(createdSession).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects new connection attempts if the accept queue is full", func() {
|
It("rejects new connection attempts if the accept queue is full", func() {
|
||||||
|
@ -612,8 +623,11 @@ var _ = Describe("Server", func() {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||||
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2 * protocol.MaxAcceptQueueSize)
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
|
fn()
|
||||||
|
return true
|
||||||
|
}).Times(protocol.MaxAcceptQueueSize)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(protocol.MaxAcceptQueueSize)
|
wg.Add(protocol.MaxAcceptQueueSize)
|
||||||
|
@ -673,8 +687,11 @@ var _ = Describe("Server", func() {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||||
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2)
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
|
fn()
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
serv.handlePacket(p)
|
serv.handlePacket(p)
|
||||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
|
@ -772,8 +789,11 @@ var _ = Describe("Server", func() {
|
||||||
sess.EXPECT().Context().Return(context.Background())
|
sess.EXPECT().Context().Return(context.Background())
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||||
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2)
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
|
fn()
|
||||||
|
return true
|
||||||
|
})
|
||||||
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
|
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
cancel() // complete the handshake
|
cancel() // complete the handshake
|
||||||
|
@ -836,8 +856,11 @@ var _ = Describe("Server", func() {
|
||||||
sess.EXPECT().Context().Return(context.Background())
|
sess.EXPECT().Context().Return(context.Background())
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||||
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2)
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
|
fn()
|
||||||
|
return true
|
||||||
|
})
|
||||||
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
|
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
close(ready)
|
close(ready)
|
||||||
|
@ -874,8 +897,11 @@ var _ = Describe("Server", func() {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||||
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2 * protocol.MaxAcceptQueueSize)
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
|
fn()
|
||||||
|
return true
|
||||||
|
}).Times(protocol.MaxAcceptQueueSize)
|
||||||
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
||||||
serv.handlePacket(getInitialWithRandomDestConnID())
|
serv.handlePacket(getInitialWithRandomDestConnID())
|
||||||
}
|
}
|
||||||
|
@ -927,8 +953,11 @@ var _ = Describe("Server", func() {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||||
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2)
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
|
fn()
|
||||||
|
return true
|
||||||
|
})
|
||||||
serv.handlePacket(p)
|
serv.handlePacket(p)
|
||||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
Eventually(sessionCreated).Should(BeClosed())
|
Eventually(sessionCreated).Should(BeClosed())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue