uquic/packet_handler_map_test.go
Marten Seemann 8e93770dd3
avoid lock contention when accepting new connections (#4313)
* avoid lock contention when accepting new connections

The server used to hold the packet handler map's lock while creating the
connection struct for a newly accepted connection. This was intended to
make sure that no two connections with the same Destination Connection
ID could be created.

This is a corner case: it can only happen if two Initial packets with
the same Destination Connection ID are received at the same time. If
the second one is received after the first one has already been
processed, it would be routed to the first connection. We don't need to
optimize for this corner case. It's ok to create a new connection in
that case, and immediately close it if this collision is detected.

* only pass 0-RTT to the connection if it was actually accepted
2024-02-08 19:34:42 -08:00

170 lines
5.8 KiB
Go

package quic
import (
"crypto/rand"
"errors"
"net"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Packet Handler Map", func() {
	It("adds and gets", func() {
		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
		handler := NewMockPacketHandler(mockCtrl)
		Expect(m.Add(connID, handler)).To(BeTrue())
		h, ok := m.Get(connID)
		Expect(ok).To(BeTrue())
		Expect(h).To(Equal(handler))
	})

	It("refused to add duplicates", func() {
		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
		handler := NewMockPacketHandler(mockCtrl)
		Expect(m.Add(connID, handler)).To(BeTrue())
		// A second registration under an already-used connection ID is rejected.
		Expect(m.Add(connID, handler)).To(BeFalse())
	})

	It("removes", func() {
		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
		handler := NewMockPacketHandler(mockCtrl)
		Expect(m.Add(connID, handler)).To(BeTrue())
		m.Remove(connID)
		_, ok := m.Get(connID)
		Expect(ok).To(BeFalse())
		// After removal the connection ID can be registered again.
		Expect(m.Add(connID, handler)).To(BeTrue())
	})

	It("retires", func() {
		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
		dur := scaleDuration(50 * time.Millisecond)
		m.deleteRetiredConnsAfter = dur
		connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
		handler := NewMockPacketHandler(mockCtrl)
		Expect(m.Add(connID, handler)).To(BeTrue())
		m.Retire(connID)
		// Retiring does not remove the entry immediately ...
		_, ok := m.Get(connID)
		Expect(ok).To(BeTrue())
		// ... it disappears once deleteRetiredConnsAfter has passed.
		time.Sleep(dur)
		Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
	})

	It("adds newly to-be-constructed handlers", func() {
		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
		connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
		connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
		h := NewMockPacketHandler(mockCtrl)
		Expect(m.AddWithConnID(connID1, connID2, h)).To(BeTrue())
		// collision of the destination connection ID, this handler should not be added
		Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), nil)).To(BeFalse())
	})

	It("adds, gets and removes reset tokens", func() {
		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
		token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
		handler := NewMockPacketHandler(mockCtrl)
		m.AddResetToken(token, handler)
		h, ok := m.GetByResetToken(token)
		Expect(ok).To(BeTrue())
		// Compare against the registered handler: comparing h with itself
		// would always succeed and verify nothing.
		Expect(h).To(Equal(handler))
		m.RemoveResetToken(token)
		_, ok = m.GetByResetToken(token)
		Expect(ok).To(BeFalse())
	})

	It("generates stateless reset token, if no key is set", func() {
		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
		b := make([]byte, 8)
		rand.Read(b)
		connID := protocol.ParseConnectionID(b)
		// Without a key, repeated calls for the same connection ID are expected
		// to yield a different token every time.
		token := m.GetStatelessResetToken(connID)
		for i := 0; i < 1000; i++ {
			to := m.GetStatelessResetToken(connID)
			Expect(to).ToNot(Equal(token))
			token = to
		}
	})

	It("generates stateless reset token, if a key is set", func() {
		var key StatelessResetKey
		rand.Read(key[:])
		m := newPacketHandlerMap(&key, nil, utils.DefaultLogger)
		b := make([]byte, 8)
		rand.Read(b)
		connID := protocol.ParseConnectionID(b)
		// With a key, the token is non-zero and stable for a given connection ID ...
		token := m.GetStatelessResetToken(connID)
		Expect(token).ToNot(BeZero())
		Expect(m.GetStatelessResetToken(connID)).To(Equal(token))
		// ... and differs for a different connection ID.
		rand.Read(b)
		connID2 := protocol.ParseConnectionID(b)
		Expect(m.GetStatelessResetToken(connID2)).ToNot(Equal(token))
	})

	It("replaces locally closed connections", func() {
		var closePackets []closePacket
		m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
		dur := scaleDuration(50 * time.Millisecond)
		m.deleteRetiredConnsAfter = dur
		handler := NewMockPacketHandler(mockCtrl)
		connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
		Expect(m.Add(connID, handler)).To(BeTrue())
		// A non-nil payload means the replacement handler answers incoming
		// packets by enqueueing that payload as a close packet.
		m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, []byte("foobar"))
		h, ok := m.Get(connID)
		Expect(ok).To(BeTrue())
		Expect(h).ToNot(Equal(handler))
		addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
		h.handlePacket(receivedPacket{remoteAddr: addr})
		Expect(closePackets).To(HaveLen(1))
		Expect(closePackets[0].addr).To(Equal(addr))
		Expect(closePackets[0].payload).To(Equal([]byte("foobar")))
		time.Sleep(dur)
		Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
	})

	It("replaces remote closed connections", func() {
		var closePackets []closePacket
		m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
		dur := scaleDuration(50 * time.Millisecond)
		m.deleteRetiredConnsAfter = dur
		handler := NewMockPacketHandler(mockCtrl)
		connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
		Expect(m.Add(connID, handler)).To(BeTrue())
		// A nil payload means the replacement handler must not send anything.
		m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, nil)
		h, ok := m.Get(connID)
		Expect(ok).To(BeTrue())
		Expect(h).ToNot(Equal(handler))
		addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
		h.handlePacket(receivedPacket{remoteAddr: addr})
		Expect(closePackets).To(BeEmpty())
		time.Sleep(dur)
		Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
	})

	It("closes", func() {
		m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
		testErr := errors.New("shutdown")
		for i := 0; i < 10; i++ {
			conn := NewMockPacketHandler(mockCtrl)
			// Each handler must be destroyed with the error passed to the first Close.
			conn.EXPECT().destroy(testErr)
			b := make([]byte, 12)
			rand.Read(b)
			// Assert the registration so a failed Add surfaces here rather than
			// as a confusing missing-call failure on the mock.
			Expect(m.Add(protocol.ParseConnectionID(b), conn)).To(BeTrue())
		}
		m.Close(testErr)
		// check that Close can be called multiple times; a destroy call with this
		// different error would not match the mock expectations above
		m.Close(errors.New("close"))
	})
})