only create a single session when two Initials arrive at the same time

This commit is contained in:
Marten Seemann 2019-11-27 09:24:56 +07:00
parent 5a834851a8
commit e65df402dd
7 changed files with 190 additions and 59 deletions

View file

@ -3,6 +3,7 @@ package quic
import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"errors"
"net"
@ -10,6 +11,7 @@ import (
"sync"
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
@ -39,6 +41,29 @@ var _ = Describe("Server", func() {
}
}
getInitial := func(destConnID protocol.ConnectionID) *receivedPacket {
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: destConnID,
Version: protocol.VersionTLS,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
p.buffer = getPacketBuffer()
p.remoteAddr = senderAddr
return p
}
getInitialWithRandomDestConnID := func() *receivedPacket {
destConnID := make([]byte, 10)
_, err := rand.Read(destConnID)
Expect(err).ToNot(HaveOccurred())
return getInitial(destConnID)
}
parseHeader := func(data []byte) *wire.Header {
hdr, _, _, err := wire.ParsePacket(data, 0)
Expect(err).ToNot(HaveOccurred())
@ -127,12 +152,17 @@ var _ = Describe("Server", func() {
})
Context("server accepting sessions that completed the handshake", func() {
var serv *baseServer
var (
serv *baseServer
phm *MockPacketHandlerManager
)
BeforeEach(func() {
ln, err := Listen(conn, tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
serv = ln.(*baseServer)
phm = NewMockPacketHandlerManager(mockCtrl)
serv.sessionHandler = phm
})
Context("handling packets", func() {
@ -282,6 +312,14 @@ var _ = Describe("Server", func() {
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
run := make(chan struct{})
var token [16]byte
rand.Read(token[:])
var newConnID protocol.ConnectionID
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte {
newConnID = c
return token
})
sess := NewMockQuicSession(mockCtrl)
serv.newSession = func(
_ connection,
_ sessionRunner,
@ -289,6 +327,7 @@ var _ = Describe("Server", func() {
origConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
tokenP [16]byte,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -300,7 +339,8 @@ var _ = Describe("Server", func() {
// make sure we're using a server-generated connection ID
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
sess := NewMockQuicSession(mockCtrl)
Expect(srcConnID).To(Equal(newConnID))
Expect(tokenP).To(Equal(token))
sess.EXPECT().handlePacket(p)
sess.EXPECT().run().Do(func() { close(run) })
sess.EXPECT().Context().Return(context.Background())
@ -308,6 +348,11 @@ var _ = Describe("Server", func() {
return sess
}
phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true)
phm.EXPECT().Add(gomock.Any(), sess).Do(func(c protocol.ConnectionID, _ packetHandler) {
Expect(c).To(Equal(newConnID))
})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -321,19 +366,10 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
})
It("rejects new connection attempts if the accept queue is full", func() {
It("only creates a single session for a duplicate Initial", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
p.remoteAddr = senderAddr
var createdSession bool
sess := NewMockQuicSession(mockCtrl)
serv.newSession = func(
_ connection,
runner sessionRunner,
@ -341,6 +377,35 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ [16]byte,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ utils.Logger,
_ protocol.VersionNumber,
) quicSession {
createdSession = true
return sess
}
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false)
Expect(serv.handlePacketImpl(p)).To(BeFalse())
Expect(createdSession).To(BeTrue())
})
It("rejects new connection attempts if the accept queue is full", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
serv.newSession = func(
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ [16]byte,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -348,7 +413,7 @@ var _ = Describe("Server", func() {
_ protocol.VersionNumber,
) quicSession {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(p)
sess.EXPECT().handlePacket(gomock.Any())
sess.EXPECT().run()
sess.EXPECT().Context().Return(context.Background())
ctx, cancel := context.WithCancel(context.Background())
@ -357,21 +422,28 @@ var _ = Describe("Server", func() {
return sess
}
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Times(protocol.MaxAcceptQueueSize)
var wg sync.WaitGroup
wg.Add(protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
go func() {
defer GinkgoRecover()
defer wg.Done()
serv.handlePacket(p)
serv.handlePacket(getInitialWithRandomDestConnID())
Consistently(conn.dataWritten).ShouldNot(Receive())
}()
}
wg.Wait()
p := getInitialWithRandomDestConnID()
hdr, _, _, err := wire.ParsePacket(p.data, 0)
Expect(err).ToNot(HaveOccurred())
serv.handlePacket(p)
var reject mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&reject))
Expect(reject.to).To(Equal(senderAddr))
Expect(reject.to).To(Equal(p.remoteAddr))
rejectHdr := parseHeader(reject.data)
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
Expect(rejectHdr.Version).To(Equal(hdr.Version))
@ -381,17 +453,8 @@ var _ = Describe("Server", func() {
It("doesn't accept new sessions if they were closed in the mean time", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
p.remoteAddr = senderAddr
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
ctx, cancel := context.WithCancel(context.Background())
sessionCreated := make(chan struct{})
sess := NewMockQuicSession(mockCtrl)
@ -402,6 +465,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ [16]byte,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -418,6 +482,10 @@ var _ = Describe("Server", func() {
return sess
}
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true)
phm.EXPECT().Add(gomock.Any(), gomock.Any())
serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive())
Eventually(sessionCreated).Should(BeClosed())
@ -433,6 +501,8 @@ var _ = Describe("Server", func() {
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
phm.EXPECT().CloseServer()
sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
@ -498,6 +568,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ [16]byte,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -509,6 +580,9 @@ var _ = Describe("Server", func() {
sess.EXPECT().Context().Return(context.Background())
return sess
}
phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true)
phm.EXPECT().Add(gomock.Any(), gomock.Any())
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Consistently(done).ShouldNot(BeClosed())
cancel() // complete the handshake
@ -546,6 +620,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ [16]byte,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -567,15 +642,6 @@ var _ = Describe("Server", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
p.remoteAddr = senderAddr
serv.newSession = func(
_ connection,
runner sessionRunner,
@ -583,6 +649,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ [16]byte,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -592,7 +659,7 @@ var _ = Describe("Server", func() {
ready := make(chan struct{})
close(ready)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(p)
sess.EXPECT().handlePacket(gomock.Any())
sess.EXPECT().run()
sess.EXPECT().earlySessionReady().Return(ready)
sess.EXPECT().Context().Return(context.Background())
@ -605,11 +672,14 @@ var _ = Describe("Server", func() {
go func() {
defer GinkgoRecover()
defer wg.Done()
serv.handlePacket(p)
serv.handlePacket(getInitialWithRandomDestConnID())
Consistently(conn.dataWritten).ShouldNot(Receive())
}()
}
wg.Wait()
p := getInitialWithRandomDestConnID()
hdr := parseHeader(p.data)
serv.handlePacket(p)
var reject mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&reject))
@ -623,17 +693,8 @@ var _ = Describe("Server", func() {
It("doesn't accept new sessions if they were closed in the mean time", func() {
serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true }
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
p.remoteAddr = senderAddr
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
ctx, cancel := context.WithCancel(context.Background())
sessionCreated := make(chan struct{})
sess := NewMockQuicSession(mockCtrl)
@ -644,6 +705,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ [16]byte,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
@ -673,6 +735,7 @@ var _ = Describe("Server", func() {
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})