move 0-RTT queue handling from the packet handler map to the server

This commit is contained in:
Marten Seemann 2023-04-26 10:07:28 +02:00
parent c0b94ee4b0
commit 58487803d3
7 changed files with 367 additions and 217 deletions

View file

@ -215,8 +215,7 @@ var _ = Describe("Server", func() {
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize-100),
)
}, make([]byte, protocol.MinInitialPacketSize-100))
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
serv.handlePacket(p)
// make sure there are no Write calls on the packet conn
@ -234,6 +233,21 @@ var _ = Describe("Server", func() {
time.Sleep(50 * time.Millisecond)
})
It("passes packets to existing connections", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
p := getPacket(&wire.Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
conn := NewMockPacketHandler(mockCtrl)
phm.EXPECT().Get(connID).Return(conn, true)
handled := make(chan struct{})
conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) })
serv.handlePacket(p)
Eventually(handled).Should(BeClosed())
})
It("creates a connection when the token is accepted", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
@ -243,10 +257,11 @@ var _ = Describe("Server", func() {
protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
)
Expect(err).ToNot(HaveOccurred())
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
DestConnectionID: connID,
Version: protocol.Version1,
Token: retryToken,
}
@ -257,7 +272,9 @@ var _ = Describe("Server", func() {
rand.Read(token[:])
var newConnID protocol.ConnectionID
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
@ -408,11 +425,12 @@ var _ = Describe("Server", func() {
})
It("replies with a Retry packet, if a token is required", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
DestConnectionID: connID,
Version: protocol.Version1,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
@ -435,15 +453,17 @@ var _ = Describe("Server", func() {
Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:]))
return len(b), nil
})
phm.EXPECT().Get(connID)
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
It("creates a connection, if no token is required", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
DestConnectionID: connID,
Version: protocol.Version1,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
@ -452,15 +472,18 @@ var _ = Describe("Server", func() {
rand.Read(token[:])
var newConnID protocol.ConnectionID
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
gomock.InOrder(
phm.EXPECT().Get(connID),
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
newConnID = c
return token
})
fn()
return true
})
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
return token
})
fn()
return true
}),
)
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
conn := NewMockQUICConn(mockCtrl)
@ -513,6 +536,7 @@ var _ = Describe("Server", func() {
})
It("drops packets if the receive queue is full", func() {
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -597,8 +621,10 @@ var _ = Describe("Server", func() {
return conn
}
p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}))
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}), gomock.Any(), gomock.Any()).Return(false)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
p := getInitial(connID)
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false)
Expect(serv.handlePacketImpl(p)).To(BeTrue())
Expect(createdConn).To(BeFalse())
})
@ -632,6 +658,7 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -702,6 +729,7 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -765,6 +793,7 @@ var _ = Describe("Server", func() {
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
done := make(chan struct{})
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) })
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
@ -801,6 +830,7 @@ var _ = Describe("Server", func() {
checkInvalidToken(b, hdr)
return len(b), nil
})
phm.EXPECT().Get(gomock.Any())
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
@ -837,6 +867,7 @@ var _ = Describe("Server", func() {
checkInvalidToken(b, hdr)
return len(b), nil
})
phm.EXPECT().Get(gomock.Any())
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
@ -864,6 +895,7 @@ var _ = Describe("Server", func() {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
return len(b), nil
})
phm.EXPECT().Get(gomock.Any())
serv.handlePacket(packet)
// make sure there are no Write calls on the packet conn
Eventually(done).Should(BeClosed())
@ -893,6 +925,7 @@ var _ = Describe("Server", func() {
defer close(done)
return len(b), nil
})
phm.EXPECT().Get(gomock.Any())
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
@ -913,6 +946,7 @@ var _ = Describe("Server", func() {
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
done := make(chan struct{})
tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
phm.EXPECT().Get(gomock.Any())
serv.handlePacket(packet)
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
@ -997,6 +1031,7 @@ var _ = Describe("Server", func() {
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -1070,6 +1105,7 @@ var _ = Describe("Server", func() {
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -1115,6 +1151,7 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -1175,6 +1212,7 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
@ -1202,4 +1240,166 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
})
})
Context("0-RTT", func() {
var (
serv *earlyServer
phm *MockPacketHandlerManager
tracer *mocklogging.MockTracer
)
BeforeEach(func() {
tracer = mocklogging.NewMockTracer(mockCtrl)
ln, err := ListenEarly(conn, tlsConf, &Config{Tracer: tracer})
Expect(err).ToNot(HaveOccurred())
serv = ln.(*earlyServer)
phm = NewMockPacketHandlerManager(mockCtrl)
serv.connHandler = phm
serv.zeroRTTQueueDuration = time.Hour
})
AfterEach(func() {
phm.EXPECT().CloseServer().MaxTimes(1)
serv.Close()
})
It("passes packets to existing connections", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 100))
conn := NewMockPacketHandler(mockCtrl)
phm.EXPECT().Get(connID).Return(conn, true)
handled := make(chan struct{})
conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) })
serv.handlePacket(p)
Eventually(handled).Should(BeClosed())
})
It("queues 0-RTT packets, up to Max0RTTQueueSize", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
var zeroRTTPackets []*receivedPacket
for i := 0; i < protocol.Max0RTTQueueLen; i++ {
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 100+i))
phm.EXPECT().Get(connID)
serv.handlePacket(p)
zeroRTTPackets = append(zeroRTTPackets, p)
}
// send one more packet, this one should be dropped
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 200))
phm.EXPECT().Get(connID)
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
serv.handlePacket(p)
initial := getPacket(&wire.Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
called := make(chan struct{})
serv.newConn = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
conn := NewMockQUICConn(mockCtrl)
var calls []*gomock.Call
calls = append(calls, conn.EXPECT().handlePacket(initial))
for _, p := range zeroRTTPackets {
calls = append(calls, conn.EXPECT().handlePacket(p))
}
gomock.InOrder(calls...)
conn.EXPECT().run()
conn.EXPECT().earlyConnReady()
conn.EXPECT().Context().Return(context.Background())
close(called)
return conn
}
tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any())
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
})
serv.handlePacket(initial)
Eventually(called).Should(BeClosed())
})
It("limits the number of queues", func() {
for i := 0; i < protocol.Max0RTTQueues; i++ {
b := make([]byte, 16)
rand.Read(b)
connID := protocol.ParseConnectionID(b)
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 100+i))
phm.EXPECT().Get(connID)
serv.handlePacket(p)
}
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 200))
phm.EXPECT().Get(connID)
dropped := make(chan struct{})
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
close(dropped)
})
serv.handlePacket(p)
Eventually(dropped).Should(BeClosed())
})
It("drops queues after a while", func() {
serv.zeroRTTQueueDuration = time.Millisecond
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
p := getPacket(&wire.Header{
Type: protocol.PacketType0RTT,
DestConnectionID: connID,
Version: serv.config.Versions[0],
}, make([]byte, 200))
got := make(chan struct{})
phm.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(got) })
serv.handlePacket(p)
<-got
// TODO: use the tracer call
Eventually(func() int {
serv.mutex.Lock()
defer serv.mutex.Unlock()
return len(serv.zeroRTTQueues)
}).Should(BeZero())
})
})
})