mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
only create a single session when two Initials arrive at the same time
This commit is contained in:
parent
5a834851a8
commit
e65df402dd
7 changed files with 190 additions and 59 deletions
159
server_test.go
159
server_test.go
|
@ -3,6 +3,7 @@ package quic
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
|
@ -10,6 +11,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
@ -39,6 +41,29 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
}
|
||||
|
||||
getInitial := func(destConnID protocol.ConnectionID) *receivedPacket {
|
||||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: destConnID,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
p.buffer = getPacketBuffer()
|
||||
p.remoteAddr = senderAddr
|
||||
return p
|
||||
}
|
||||
|
||||
getInitialWithRandomDestConnID := func() *receivedPacket {
|
||||
destConnID := make([]byte, 10)
|
||||
_, err := rand.Read(destConnID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
return getInitial(destConnID)
|
||||
}
|
||||
|
||||
parseHeader := func(data []byte) *wire.Header {
|
||||
hdr, _, _, err := wire.ParsePacket(data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -127,12 +152,17 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
Context("server accepting sessions that completed the handshake", func() {
|
||||
var serv *baseServer
|
||||
var (
|
||||
serv *baseServer
|
||||
phm *MockPacketHandlerManager
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
ln, err := Listen(conn, tlsConf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serv = ln.(*baseServer)
|
||||
phm = NewMockPacketHandlerManager(mockCtrl)
|
||||
serv.sessionHandler = phm
|
||||
})
|
||||
|
||||
Context("handling packets", func() {
|
||||
|
@ -282,6 +312,14 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
run := make(chan struct{})
|
||||
var token [16]byte
|
||||
rand.Read(token[:])
|
||||
var newConnID protocol.ConnectionID
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
|
||||
newConnID = c
|
||||
return token
|
||||
})
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
serv.newSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
|
@ -289,6 +327,7 @@ var _ = Describe("Server", func() {
|
|||
origConnID protocol.ConnectionID,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
tokenP [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -300,7 +339,8 @@ var _ = Describe("Server", func() {
|
|||
// make sure we're using a server-generated connection ID
|
||||
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
|
||||
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
Expect(srcConnID).To(Equal(newConnID))
|
||||
Expect(tokenP).To(Equal(token))
|
||||
sess.EXPECT().handlePacket(p)
|
||||
sess.EXPECT().run().Do(func() { close(run) })
|
||||
sess.EXPECT().Context().Return(context.Background())
|
||||
|
@ -308,6 +348,11 @@ var _ = Describe("Server", func() {
|
|||
return sess
|
||||
}
|
||||
|
||||
phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true)
|
||||
phm.EXPECT().Add(gomock.Any(), sess).Do(func(c protocol.ConnectionID, _ packetHandler) {
|
||||
Expect(c).To(Equal(newConnID))
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -321,19 +366,10 @@ var _ = Describe("Server", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects new connection attempts if the accept queue is full", func() {
|
||||
It("only creates a single session for a duplicate Initial", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
p.remoteAddr = senderAddr
|
||||
var createdSession bool
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
serv.newSession = func(
|
||||
_ connection,
|
||||
runner sessionRunner,
|
||||
|
@ -341,6 +377,35 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicSession {
|
||||
createdSession = true
|
||||
return sess
|
||||
}
|
||||
|
||||
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false)
|
||||
Expect(serv.handlePacketImpl(p)).To(BeFalse())
|
||||
Expect(createdSession).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects new connection attempts if the accept queue is full", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
|
||||
serv.newSession = func(
|
||||
_ connection,
|
||||
runner sessionRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -348,7 +413,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.VersionNumber,
|
||||
) quicSession {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(p)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
sess.EXPECT().run()
|
||||
sess.EXPECT().Context().Return(context.Background())
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -357,21 +422,28 @@ var _ = Describe("Server", func() {
|
|||
return sess
|
||||
}
|
||||
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
||||
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
|
||||
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(protocol.MaxAcceptQueueSize)
|
||||
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
serv.handlePacket(p)
|
||||
serv.handlePacket(getInitialWithRandomDestConnID())
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
p := getInitialWithRandomDestConnID()
|
||||
hdr, _, _, err := wire.ParsePacket(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serv.handlePacket(p)
|
||||
var reject mockPacketConnWrite
|
||||
Eventually(conn.dataWritten).Should(Receive(&reject))
|
||||
Expect(reject.to).To(Equal(senderAddr))
|
||||
Expect(reject.to).To(Equal(p.remoteAddr))
|
||||
rejectHdr := parseHeader(reject.data)
|
||||
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
Expect(rejectHdr.Version).To(Equal(hdr.Version))
|
||||
|
@ -381,17 +453,8 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("doesn't accept new sessions if they were closed in the mean time", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
p.remoteAddr = senderAddr
|
||||
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
sessionCreated := make(chan struct{})
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
|
@ -402,6 +465,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -418,6 +482,10 @@ var _ = Describe("Server", func() {
|
|||
return sess
|
||||
}
|
||||
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true)
|
||||
phm.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
|
||||
serv.handlePacket(p)
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
Eventually(sessionCreated).Should(BeClosed())
|
||||
|
@ -433,6 +501,8 @@ var _ = Describe("Server", func() {
|
|||
Consistently(done).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
phm.EXPECT().CloseServer()
|
||||
sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -498,6 +568,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -509,6 +580,9 @@ var _ = Describe("Server", func() {
|
|||
sess.EXPECT().Context().Return(context.Background())
|
||||
return sess
|
||||
}
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true)
|
||||
phm.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
cancel() // complete the handshake
|
||||
|
@ -546,6 +620,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -567,15 +642,6 @@ var _ = Describe("Server", func() {
|
|||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
p.remoteAddr = senderAddr
|
||||
serv.newSession = func(
|
||||
_ connection,
|
||||
runner sessionRunner,
|
||||
|
@ -583,6 +649,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -592,7 +659,7 @@ var _ = Describe("Server", func() {
|
|||
ready := make(chan struct{})
|
||||
close(ready)
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(p)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
sess.EXPECT().run()
|
||||
sess.EXPECT().earlySessionReady().Return(ready)
|
||||
sess.EXPECT().Context().Return(context.Background())
|
||||
|
@ -605,11 +672,14 @@ var _ = Describe("Server", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
serv.handlePacket(p)
|
||||
serv.handlePacket(getInitialWithRandomDestConnID())
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
p := getInitialWithRandomDestConnID()
|
||||
hdr := parseHeader(p.data)
|
||||
serv.handlePacket(p)
|
||||
var reject mockPacketConnWrite
|
||||
Eventually(conn.dataWritten).Should(Receive(&reject))
|
||||
|
@ -623,17 +693,8 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("doesn't accept new sessions if they were closed in the mean time", func() {
|
||||
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
|
||||
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
p.remoteAddr = senderAddr
|
||||
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
sessionCreated := make(chan struct{})
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
|
@ -644,6 +705,7 @@ var _ = Describe("Server", func() {
|
|||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ [16]byte,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
|
@ -673,6 +735,7 @@ var _ = Describe("Server", func() {
|
|||
Consistently(done).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue