mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 04:07:35 +03:00
only create a single session when two Initials arrive at the same time
This commit is contained in:
parent
5a834851a8
commit
e65df402dd
7 changed files with 190 additions and 59 deletions
|
@ -48,6 +48,20 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1)
|
||||
}
|
||||
|
||||
// AddIfNotTaken mocks base method
|
||||
func (m *MockPacketHandlerManager) AddIfNotTaken(arg0 protocol.ConnectionID, arg1 packetHandler) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddIfNotTaken", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AddIfNotTaken indicates an expected call of AddIfNotTaken
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) AddIfNotTaken(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddIfNotTaken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddIfNotTaken), arg0, arg1)
|
||||
}
|
||||
|
||||
// AddResetToken mocks base method
|
||||
func (m *MockPacketHandlerManager) AddResetToken(arg0 [16]byte, arg1 packetHandler) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -86,6 +100,20 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
|
||||
}
|
||||
|
||||
// GetStatelessResetToken mocks base method
|
||||
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) [16]byte {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0)
|
||||
ret0, _ := ret[0].([16]byte)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetStatelessResetToken indicates an expected call of GetStatelessResetToken
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0)
|
||||
}
|
||||
|
||||
// Remove mocks base method
|
||||
func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -99,7 +99,19 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
|
|||
h.mutex.Lock()
|
||||
h.handlers[string(id)] = handler
|
||||
h.mutex.Unlock()
|
||||
return h.getStatelessResetToken(id)
|
||||
return h.GetStatelessResetToken(id)
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) AddIfNotTaken(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
|
||||
sid := string(id)
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if _, ok := h.handlers[sid]; !ok {
|
||||
h.handlers[sid] = handler
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
||||
|
@ -284,7 +296,7 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) getStatelessResetToken(connID protocol.ConnectionID) [16]byte {
|
||||
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte {
|
||||
var token [16]byte
|
||||
if !h.statelessResetEnabled {
|
||||
// Return a random stateless reset token.
|
||||
|
@ -311,7 +323,7 @@ func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID pro
|
|||
if len(p.data) <= protocol.MinStatelessResetSize {
|
||||
return
|
||||
}
|
||||
token := h.getStatelessResetToken(connID)
|
||||
token := h.GetStatelessResetToken(connID)
|
||||
h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
|
||||
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
|
||||
rand.Read(data)
|
||||
|
|
|
@ -175,6 +175,12 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
conn.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("says if a connection ID is already taken", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
Expect(handler.AddIfNotTaken(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
|
||||
Expect(handler.AddIfNotTaken(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("running a server", func() {
|
||||
|
@ -289,7 +295,7 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
statelessResetKey = key
|
||||
})
|
||||
|
||||
It("generates stateless reset tokens", func() {
|
||||
It("generates stateless reset tokens when adding new sessions", func() {
|
||||
connID1 := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||
connID2 := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
token1 := handler.Add(connID1, nil)
|
||||
|
@ -297,6 +303,12 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
Expect(handler.Add(connID2, nil)).ToNot(Equal(token1))
|
||||
})
|
||||
|
||||
It("generates stateless reset tokens", func() {
|
||||
connID1 := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||
connID2 := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2)))
|
||||
})
|
||||
|
||||
It("sends stateless resets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
p := append([]byte{40}, make([]byte, 100)...)
|
||||
|
|
23
server.go
23
server.go
|
@ -17,6 +17,8 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
|
||||
"github.com/onsi/ginkgo"
|
||||
)
|
||||
|
||||
// packetHandler handles packets
|
||||
|
@ -37,6 +39,8 @@ type packetHandlerManager interface {
|
|||
SetServer(unknownPacketHandler)
|
||||
CloseServer()
|
||||
sessionRunner
|
||||
AddIfNotTaken(protocol.ConnectionID, packetHandler) bool
|
||||
GetStatelessResetToken(protocol.ConnectionID) [16]byte
|
||||
}
|
||||
|
||||
type quicSession interface {
|
||||
|
@ -70,7 +74,7 @@ type baseServer struct {
|
|||
sessionHandler packetHandlerManager
|
||||
|
||||
// set as a member, so they can be set in the tests
|
||||
newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* client dest connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, *Config, *tls.Config, *handshake.TokenGenerator, utils.Logger, protocol.VersionNumber) quicSession
|
||||
newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* client dest connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, [16]byte, *Config, *tls.Config, *handshake.TokenGenerator, utils.Logger, protocol.VersionNumber) quicSession
|
||||
|
||||
serverError error
|
||||
errorChan chan struct{}
|
||||
|
@ -327,6 +331,7 @@ func (s *baseServer) Addr() net.Addr {
|
|||
|
||||
func (s *baseServer) handlePacket(p *receivedPacket) {
|
||||
go func() {
|
||||
defer ginkgo.GinkgoRecover()
|
||||
if shouldReleaseBuffer := s.handlePacketImpl(p); !shouldReleaseBuffer {
|
||||
p.buffer.Release()
|
||||
}
|
||||
|
@ -368,7 +373,9 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet
|
|||
s.logger.Errorf("Error occurred handling initial packet: %s", err)
|
||||
return false
|
||||
}
|
||||
if sess == nil { // a retry was done, or the connection attempt was rejected
|
||||
// A retry was done, or the connection attempt was rejected,
|
||||
// or if the Initial was a duplicate.
|
||||
if sess == nil {
|
||||
return false
|
||||
}
|
||||
// Don't put the packet buffer back if a new session was created.
|
||||
|
@ -419,7 +426,9 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
|
|||
connID,
|
||||
hdr.Version,
|
||||
)
|
||||
sess.handlePacket(p)
|
||||
if sess != nil {
|
||||
sess.handlePacket(p)
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
|
@ -438,12 +447,20 @@ func (s *baseServer) createNewSession(
|
|||
clientDestConnID,
|
||||
destConnID,
|
||||
srcConnID,
|
||||
s.sessionHandler.GetStatelessResetToken(srcConnID),
|
||||
s.config,
|
||||
s.tlsConf,
|
||||
s.tokenGenerator,
|
||||
s.logger,
|
||||
version,
|
||||
)
|
||||
added := s.sessionHandler.AddIfNotTaken(clientDestConnID, sess)
|
||||
// We're already keeping track of this connection ID.
|
||||
// This might happen if we receive two copies of the Initial at the same time.
|
||||
if !added {
|
||||
return nil
|
||||
}
|
||||
s.sessionHandler.Add(srcConnID, sess)
|
||||
go sess.run()
|
||||
go s.handleNewSession(sess)
|
||||
return sess
|
||||
|
|
159
server_test.go
159
server_test.go
|
@ -3,6 +3,7 @@ package quic
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
|
@ -10,6 +11,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
@ -39,6 +41,29 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
}
|
||||
|
||||
getInitial := func(destConnID protocol.ConnectionID) *receivedPacket {
|
||||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: destConnID,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
p.buffer = getPacketBuffer()
|
||||
p.remoteAddr = senderAddr
|
||||
return p
|
||||
}
|
||||
|
||||
getInitialWithRandomDestConnID := func() *receivedPacket {
|
||||
destConnID := make([]byte, 10)
|
||||
_, err := rand.Read(destConnID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
return getInitial(destConnID)
|
||||
}
|
||||
|
||||
parseHeader := func(data []byte) *wire.Header {
|
||||
hdr, _, _, err := wire.ParsePacket(data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -127,12 +152,17 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
Context("server accepting sessions that completed the handshake", func() {
|
||||
var serv *baseServer
|
||||
var (
|
||||
serv *baseServer
|
||||
phm *MockPacketHandlerManager
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
ln, err := Listen(conn, tlsConf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serv = ln.(*baseServer)
|
||||
phm = NewMockPacketHandlerManager(mockCtrl)
|
||||
serv.sessionHandler = phm
|
||||
})
|
||||
|
||||
Context("handling packets", func() {
|
||||
|
@ -282,6 +312,14 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
run := make(chan struct{})
|
||||
var token [16]byte
|
||||
rand.Read(token[:])
|
||||
var newConnID protocol.ConnectionID
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
|
||||
newConnID = c
|
||||
return token
|
||||
})
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
serv.newSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
|
@ -289,6 +327,7 @@ var _ = Describe("Server", func() {
|
|||
origConnID protocol.ConnectionID,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
tokenP [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -300,7 +339,8 @@ var _ = Describe("Server", func() {
|
|||
// make sure we're using a server-generated connection ID
|
||||
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
|
||||
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
Expect(srcConnID).To(Equal(newConnID))
|
||||
Expect(tokenP).To(Equal(token))
|
||||
sess.EXPECT().handlePacket(p)
|
||||
sess.EXPECT().run().Do(func() { close(run) })
|
||||
sess.EXPECT().Context().Return(context.Background())
|
||||
|
@ -308,6 +348,11 @@ var _ = Describe("Server", func() {
|
|||
return sess
|
||||
}
|
||||
|
||||
phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true)
|
||||
phm.EXPECT().Add(gomock.Any(), sess).Do(func(c protocol.ConnectionID, _ packetHandler) {
|
||||
Expect(c).To(Equal(newConnID))
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -321,19 +366,10 @@ var _ = Describe("Server", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects new connection attempts if the accept queue is full", func() {
|
||||
It("only creates a single session for a duplicate Initial", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
p.remoteAddr = senderAddr
|
||||
var createdSession bool
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
serv.newSession = func(
|
||||
_ connection,
|
||||
runner sessionRunner,
|
||||
|
@ -341,6 +377,35 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicSession {
|
||||
createdSession = true
|
||||
return sess
|
||||
}
|
||||
|
||||
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false)
|
||||
Expect(serv.handlePacketImpl(p)).To(BeFalse())
|
||||
Expect(createdSession).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects new connection attempts if the accept queue is full", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
|
||||
serv.newSession = func(
|
||||
_ connection,
|
||||
runner sessionRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -348,7 +413,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.VersionNumber,
|
||||
) quicSession {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(p)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
sess.EXPECT().run()
|
||||
sess.EXPECT().Context().Return(context.Background())
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -357,21 +422,28 @@ var _ = Describe("Server", func() {
|
|||
return sess
|
||||
}
|
||||
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
||||
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
|
||||
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(protocol.MaxAcceptQueueSize)
|
||||
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
serv.handlePacket(p)
|
||||
serv.handlePacket(getInitialWithRandomDestConnID())
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
p := getInitialWithRandomDestConnID()
|
||||
hdr, _, _, err := wire.ParsePacket(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serv.handlePacket(p)
|
||||
var reject mockPacketConnWrite
|
||||
Eventually(conn.dataWritten).Should(Receive(&reject))
|
||||
Expect(reject.to).To(Equal(senderAddr))
|
||||
Expect(reject.to).To(Equal(p.remoteAddr))
|
||||
rejectHdr := parseHeader(reject.data)
|
||||
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
Expect(rejectHdr.Version).To(Equal(hdr.Version))
|
||||
|
@ -381,17 +453,8 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("doesn't accept new sessions if they were closed in the mean time", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
p.remoteAddr = senderAddr
|
||||
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
sessionCreated := make(chan struct{})
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
|
@ -402,6 +465,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -418,6 +482,10 @@ var _ = Describe("Server", func() {
|
|||
return sess
|
||||
}
|
||||
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true)
|
||||
phm.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
|
||||
serv.handlePacket(p)
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
Eventually(sessionCreated).Should(BeClosed())
|
||||
|
@ -433,6 +501,8 @@ var _ = Describe("Server", func() {
|
|||
Consistently(done).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
phm.EXPECT().CloseServer()
|
||||
sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -498,6 +568,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -509,6 +580,9 @@ var _ = Describe("Server", func() {
|
|||
sess.EXPECT().Context().Return(context.Background())
|
||||
return sess
|
||||
}
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true)
|
||||
phm.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
cancel() // complete the handshake
|
||||
|
@ -546,6 +620,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -567,15 +642,6 @@ var _ = Describe("Server", func() {
|
|||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
p.remoteAddr = senderAddr
|
||||
serv.newSession = func(
|
||||
_ connection,
|
||||
runner sessionRunner,
|
||||
|
@ -583,6 +649,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -592,7 +659,7 @@ var _ = Describe("Server", func() {
|
|||
ready := make(chan struct{})
|
||||
close(ready)
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(p)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
sess.EXPECT().run()
|
||||
sess.EXPECT().earlySessionReady().Return(ready)
|
||||
sess.EXPECT().Context().Return(context.Background())
|
||||
|
@ -605,11 +672,14 @@ var _ = Describe("Server", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
serv.handlePacket(p)
|
||||
serv.handlePacket(getInitialWithRandomDestConnID())
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
p := getInitialWithRandomDestConnID()
|
||||
hdr := parseHeader(p.data)
|
||||
serv.handlePacket(p)
|
||||
var reject mockPacketConnWrite
|
||||
Eventually(conn.dataWritten).Should(Receive(&reject))
|
||||
|
@ -623,17 +693,8 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("doesn't accept new sessions if they were closed in the mean time", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
p.remoteAddr = senderAddr
|
||||
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
sessionCreated := make(chan struct{})
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
|
@ -644,6 +705,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -673,6 +735,7 @@ var _ = Describe("Server", func() {
|
|||
Consistently(done).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
|
|
@ -196,6 +196,7 @@ var newSession = func(
|
|||
clientDestConnID protocol.ConnectionID,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
statelessResetToken [16]byte,
|
||||
conf *Config,
|
||||
tlsConf *tls.Config,
|
||||
tokenGenerator *handshake.TokenGenerator,
|
||||
|
@ -239,8 +240,6 @@ var newSession = func(
|
|||
initialStream := newCryptoStream()
|
||||
handshakeStream := newCryptoStream()
|
||||
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
|
||||
runner.Add(clientDestConnID, s)
|
||||
token := runner.Add(srcConnID, s)
|
||||
params := &handshake.TransportParameters{
|
||||
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
|
||||
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
|
||||
|
@ -252,7 +251,7 @@ var newSession = func(
|
|||
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
|
||||
AckDelayExponent: protocol.AckDelayExponent,
|
||||
DisableMigration: true,
|
||||
StatelessResetToken: &token,
|
||||
StatelessResetToken: &statelessResetToken,
|
||||
OriginalConnectionID: origDestConnID,
|
||||
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
|
||||
}
|
||||
|
|
|
@ -108,7 +108,6 @@ var _ = Describe("Session", func() {
|
|||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
|
||||
sessionRunner = NewMockSessionRunner(mockCtrl)
|
||||
sessionRunner.EXPECT().Add(gomock.Any(), gomock.Any()).Times(2)
|
||||
mconn = newMockConnection()
|
||||
tokenGenerator, err := handshake.NewTokenGenerator()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -119,6 +118,7 @@ var _ = Describe("Session", func() {
|
|||
clientDestConnID,
|
||||
destConnID,
|
||||
srcConnID,
|
||||
[16]byte{},
|
||||
populateServerConfig(&Config{}),
|
||||
nil, // tls.Config
|
||||
tokenGenerator,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue