mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 21:27:35 +03:00
move 0-RTT queue handling from the packet handler map to the server
This commit is contained in:
parent
c0b94ee4b0
commit
58487803d3
7 changed files with 367 additions and 217 deletions
232
server_test.go
232
server_test.go
|
@ -215,8 +215,7 @@ var _ = Describe("Server", func() {
|
|||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, protocol.MinInitialPacketSize-100),
|
||||
)
|
||||
}, make([]byte, protocol.MinInitialPacketSize-100))
|
||||
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
serv.handlePacket(p)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
|
@ -234,6 +233,21 @@ var _ = Describe("Server", func() {
|
|||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
|
||||
It("passes packets to existing connections", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, protocol.MinInitialPacketSize))
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
phm.EXPECT().Get(connID).Return(conn, true)
|
||||
handled := make(chan struct{})
|
||||
conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) })
|
||||
serv.handlePacket(p)
|
||||
Eventually(handled).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("creates a connection when the token is accepted", func() {
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
|
@ -243,10 +257,11 @@ var _ = Describe("Server", func() {
|
|||
protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
hdr := &wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
|
||||
DestConnectionID: connID,
|
||||
Version: protocol.Version1,
|
||||
Token: retryToken,
|
||||
}
|
||||
|
@ -257,7 +272,9 @@ var _ = Describe("Server", func() {
|
|||
rand.Read(token[:])
|
||||
|
||||
var newConnID protocol.ConnectionID
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
newConnID = c
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
newConnID = c
|
||||
|
@ -408,11 +425,12 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("replies with a Retry packet, if a token is required", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
|
||||
hdr := &wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
|
||||
DestConnectionID: connID,
|
||||
Version: protocol.Version1,
|
||||
}
|
||||
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
|
@ -435,15 +453,17 @@ var _ = Describe("Server", func() {
|
|||
Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:]))
|
||||
return len(b), nil
|
||||
})
|
||||
phm.EXPECT().Get(connID)
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("creates a connection, if no token is required", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
||||
hdr := &wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
|
||||
DestConnectionID: connID,
|
||||
Version: protocol.Version1,
|
||||
}
|
||||
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
||||
|
@ -452,15 +472,18 @@ var _ = Describe("Server", func() {
|
|||
rand.Read(token[:])
|
||||
|
||||
var newConnID protocol.ConnectionID
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
newConnID = c
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
gomock.InOrder(
|
||||
phm.EXPECT().Get(connID),
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
newConnID = c
|
||||
return token
|
||||
})
|
||||
fn()
|
||||
return true
|
||||
})
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
newConnID = c
|
||||
return token
|
||||
})
|
||||
fn()
|
||||
return true
|
||||
}),
|
||||
)
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
|
||||
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
|
@ -513,6 +536,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("drops packets if the receive queue is full", func() {
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -597,8 +621,10 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}))
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}), gomock.Any(), gomock.Any()).Return(false)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
p := getInitial(connID)
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false)
|
||||
Expect(serv.handlePacketImpl(p)).To(BeTrue())
|
||||
Expect(createdConn).To(BeFalse())
|
||||
})
|
||||
|
@ -632,6 +658,7 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -702,6 +729,7 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -765,6 +793,7 @@ var _ = Describe("Server", func() {
|
|||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
|
||||
done := make(chan struct{})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) })
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -801,6 +830,7 @@ var _ = Describe("Server", func() {
|
|||
checkInvalidToken(b, hdr)
|
||||
return len(b), nil
|
||||
})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -837,6 +867,7 @@ var _ = Describe("Server", func() {
|
|||
checkInvalidToken(b, hdr)
|
||||
return len(b), nil
|
||||
})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -864,6 +895,7 @@ var _ = Describe("Server", func() {
|
|||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||
return len(b), nil
|
||||
})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
serv.handlePacket(packet)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -893,6 +925,7 @@ var _ = Describe("Server", func() {
|
|||
defer close(done)
|
||||
return len(b), nil
|
||||
})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -913,6 +946,7 @@ var _ = Describe("Server", func() {
|
|||
packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
done := make(chan struct{})
|
||||
tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) })
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
serv.handlePacket(packet)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
@ -997,6 +1031,7 @@ var _ = Describe("Server", func() {
|
|||
conn.EXPECT().Context().Return(context.Background())
|
||||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -1070,6 +1105,7 @@ var _ = Describe("Server", func() {
|
|||
conn.EXPECT().Context().Return(context.Background())
|
||||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -1115,6 +1151,7 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -1175,6 +1212,7 @@ var _ = Describe("Server", func() {
|
|||
return conn
|
||||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
|
@ -1202,4 +1240,166 @@ var _ = Describe("Server", func() {
|
|||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("0-RTT", func() {
|
||||
var (
|
||||
serv *earlyServer
|
||||
phm *MockPacketHandlerManager
|
||||
tracer *mocklogging.MockTracer
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
tracer = mocklogging.NewMockTracer(mockCtrl)
|
||||
ln, err := ListenEarly(conn, tlsConf, &Config{Tracer: tracer})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serv = ln.(*earlyServer)
|
||||
phm = NewMockPacketHandlerManager(mockCtrl)
|
||||
serv.connHandler = phm
|
||||
serv.zeroRTTQueueDuration = time.Hour
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
phm.EXPECT().CloseServer().MaxTimes(1)
|
||||
serv.Close()
|
||||
})
|
||||
|
||||
It("passes packets to existing connections", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 100))
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
phm.EXPECT().Get(connID).Return(conn, true)
|
||||
handled := make(chan struct{})
|
||||
conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) })
|
||||
serv.handlePacket(p)
|
||||
Eventually(handled).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("queues 0-RTT packets, up to Max0RTTQueueSize", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
|
||||
var zeroRTTPackets []*receivedPacket
|
||||
|
||||
for i := 0; i < protocol.Max0RTTQueueLen; i++ {
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 100+i))
|
||||
phm.EXPECT().Get(connID)
|
||||
serv.handlePacket(p)
|
||||
zeroRTTPackets = append(zeroRTTPackets, p)
|
||||
}
|
||||
|
||||
// send one more packet, this one should be dropped
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 200))
|
||||
phm.EXPECT().Get(connID)
|
||||
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
serv.handlePacket(p)
|
||||
|
||||
initial := getPacket(&wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, protocol.MinInitialPacketSize))
|
||||
called := make(chan struct{})
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ *protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
_ bool,
|
||||
_ logging.ConnectionTracer,
|
||||
_ uint64,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
var calls []*gomock.Call
|
||||
calls = append(calls, conn.EXPECT().handlePacket(initial))
|
||||
for _, p := range zeroRTTPackets {
|
||||
calls = append(calls, conn.EXPECT().handlePacket(p))
|
||||
}
|
||||
gomock.InOrder(calls...)
|
||||
conn.EXPECT().run()
|
||||
conn.EXPECT().earlyConnReady()
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
close(called)
|
||||
return conn
|
||||
}
|
||||
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
})
|
||||
serv.handlePacket(initial)
|
||||
Eventually(called).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("limits the number of queues", func() {
|
||||
for i := 0; i < protocol.Max0RTTQueues; i++ {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
connID := protocol.ParseConnectionID(b)
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 100+i))
|
||||
phm.EXPECT().Get(connID)
|
||||
serv.handlePacket(p)
|
||||
}
|
||||
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 200))
|
||||
phm.EXPECT().Get(connID)
|
||||
dropped := make(chan struct{})
|
||||
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
|
||||
close(dropped)
|
||||
})
|
||||
serv.handlePacket(p)
|
||||
Eventually(dropped).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("drops queues after a while", func() {
|
||||
serv.zeroRTTQueueDuration = time.Millisecond
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
p := getPacket(&wire.Header{
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: connID,
|
||||
Version: serv.config.Versions[0],
|
||||
}, make([]byte, 200))
|
||||
got := make(chan struct{})
|
||||
phm.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(got) })
|
||||
serv.handlePacket(p)
|
||||
<-got
|
||||
// TODO: use the tracer call
|
||||
Eventually(func() int {
|
||||
serv.mutex.Lock()
|
||||
defer serv.mutex.Unlock()
|
||||
return len(serv.zeroRTTQueues)
|
||||
}).Should(BeZero())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue