mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
* avoid lock contention when accepting new connections The server used to hold the packet handler map's lock while creating the connection struct for a newly accepted connection. This was intended to make sure that no two connections with the same Destination Connection ID could be created. This is a corner case: it can only happen if two Initial packets with the same Destination Connection ID are received at the same time. If the second one is received after the first one has already been processed, it would be routed to the first connection. We don't need to optimize for this corner case. It's ok to create a new connection in that case, and immediately close it if this collision is detected. * only pass 0-RTT to the connection if it was actually accepted
1584 lines
58 KiB
Go
1584 lines
58 KiB
Go
package quic
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"errors"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/quic-go/quic-go/internal/handshake"
|
|
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
|
"github.com/quic-go/quic-go/internal/protocol"
|
|
"github.com/quic-go/quic-go/internal/qerr"
|
|
"github.com/quic-go/quic-go/internal/testdata"
|
|
"github.com/quic-go/quic-go/internal/utils"
|
|
"github.com/quic-go/quic-go/internal/wire"
|
|
"github.com/quic-go/quic-go/logging"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
"go.uber.org/mock/gomock"
|
|
)
|
|
|
|
var _ = Describe("Server", func() {
|
|
var (
|
|
conn *MockPacketConn
|
|
tlsConf *tls.Config
|
|
)
|
|
|
|
// getPacket builds a receivedPacket for use in the specs below: it serializes
// hdr as a long header (packet number 0x42, 4-byte packet number length),
// appends the payload p, seals the payload with the Initial AEAD derived from
// hdr.DestConnectionID (client perspective) and applies header protection.
// The returned packet carries a fixed remote address of 4.5.6.7:456; callers
// that need a different sender overwrite remoteAddr afterwards.
// Any serialization error fails the running spec.
getPacket := func(hdr *wire.Header, p []byte) receivedPacket {
|
|
buf := getPacketBuffer()
|
|
// NOTE: this mutates the caller's hdr. The length is 4 (packet number bytes)
// plus the payload plus 16 (presumably the AEAD tag size — confirm against
// the sealer's overhead).
hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
|
|
var err error
|
|
buf.Data, err = (&wire.ExtendedHeader{
|
|
Header: *hdr,
|
|
PacketNumber: 0x42,
|
|
PacketNumberLen: protocol.PacketNumberLen4,
|
|
}).Append(buf.Data, protocol.Version1)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
// n is the length of the serialized header, i.e. the offset of the payload.
n := len(buf.Data)
|
|
buf.Data = append(buf.Data, p...)
|
|
data := buf.Data
|
|
sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version)
|
|
// Seal the payload in place (dst starts at offset n), using the serialized
// header as associated data and 0x42 as the packet number.
_ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n])
|
|
// Re-slice to include the 16 bytes that Seal wrote past the end of data.
// This relies on the backing array having enough capacity.
data = data[:len(data)+16]
|
|
// Header protection: the first 16 payload bytes serve as the sample, and
// the first byte plus the 4 packet number bytes are masked.
sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n])
|
|
return receivedPacket{
|
|
remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
|
|
data: data,
|
|
buffer: buf,
|
|
}
|
|
}
|
|
|
|
// getInitial returns a protected Initial packet (QUIC version 1) addressed to
// destConnID, with source connection ID {5, 4, 3, 2, 1} and an all-zero payload
// of protocol.MinInitialPacketSize bytes. The packet's sender is 1.2.3.4:42.
getInitial := func(destConnID protocol.ConnectionID) receivedPacket {
|
|
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
|
hdr := &wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
|
DestConnectionID: destConnID,
|
|
Version: protocol.Version1,
|
|
}
|
|
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
|
// Swap in a fresh packet buffer; p.data still references the bytes built by
// getPacket. NOTE(review): the buffer obtained inside getPacket is no longer
// referenced by p after this — confirm this is intentional.
p.buffer = getPacketBuffer()
|
|
p.remoteAddr = senderAddr
|
|
return p
|
|
}
|
|
|
|
// getInitialWithRandomDestConnID returns an Initial packet (see getInitial)
// whose destination connection ID is 10 bytes read from crypto/rand, so that
// repeated calls yield packets for distinct connections. A failure to read
// random bytes fails the running spec.
getInitialWithRandomDestConnID := func() receivedPacket {
|
|
b := make([]byte, 10)
|
|
_, err := rand.Read(b)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
return getInitial(protocol.ParseConnectionID(b))
|
|
}
|
|
|
|
// parseHeader parses data with wire.ParsePacket and returns the resulting
// header, discarding the other return values. A parse error fails the
// running spec.
parseHeader := func(data []byte) *wire.Header {
|
|
hdr, _, _, err := wire.ParsePacket(data)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
return hdr
|
|
}
|
|
|
|
// checkConnectionCloseError asserts that b is an Initial packet sent in reply
// to the packet described by origHdr and that its first frame is a transport
// level CONNECTION_CLOSE carrying errorCode and an empty reason phrase.
//
// b is the raw reply as written to the packet conn, origHdr is the header of
// the packet that triggered the reply, and errorCode is the expected
// transport error code. Every failed check fails the running spec.
checkConnectionCloseError := func(b []byte, origHdr *wire.Header, errorCode qerr.TransportErrorCode) {
|
|
replyHdr := parseHeader(b)
|
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
|
// The reply must swap the connection IDs of the original packet.
Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID))
|
|
Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID))
|
|
// The opener is derived from the original destination connection ID, the
// same ID getPacket uses to derive the sealer for the client's packet.
_, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version)
|
|
extHdr, err := unpackLongHeader(opener, replyHdr, b, origHdr.Version)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
// Decrypt the payload, using the unprotected header as associated data.
data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()])
|
|
Expect(err).ToNot(HaveOccurred())
|
|
// Only the first frame of the payload is inspected.
_, f, err := wire.NewFrameParser(false).ParseNext(data, protocol.EncryptionInitial, origHdr.Version)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
|
ccf := f.(*wire.ConnectionCloseFrame)
|
|
Expect(ccf.IsApplicationError).To(BeFalse())
|
|
Expect(ccf.ErrorCode).To(BeEquivalentTo(errorCode))
|
|
Expect(ccf.ReasonPhrase).To(BeEmpty())
|
|
}
|
|
|
|
// Per-spec setup: installs a fresh mock packet conn and a TLS config.
// The mock's ReadFrom blocks until SetReadDeadline is called and then returns
// an error, so a read loop started on the conn stays parked until then.
// NOTE(review): presumably SetReadDeadline is what the transport calls on
// shutdown to unblock its read loop — confirm against Transport.Close.
BeforeEach(func() {
|
|
conn = NewMockPacketConn(mockCtrl)
|
|
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
|
// wait is closed by the first SetReadDeadline call and releases ReadFrom.
wait := make(chan struct{})
|
|
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) {
|
|
<-wait
|
|
return 0, nil, errors.New("done")
|
|
}).MaxTimes(1)
|
|
conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) error {
|
|
close(wait)
|
|
// Registered only once the first call has happened: a follow-up call
// with the zero time.Time is then expected.
conn.EXPECT().SetReadDeadline(time.Time{})
|
|
return nil
|
|
}).MaxTimes(1)
|
|
tlsConf = testdata.GetTLSConfig()
|
|
tlsConf.NextProtos = []string{"proto1"}
|
|
})
|
|
|
|
It("errors when no tls.Config is given", func() {
|
|
_, err := ListenAddr("localhost:0", nil, nil)
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set"))
|
|
})
|
|
|
|
It("errors when the Config contains an invalid version", func() {
|
|
version := protocol.Version(0x1234)
|
|
_, err := Listen(nil, tlsConf, &Config{Versions: []protocol.Version{version}})
|
|
Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
|
|
})
|
|
|
|
It("fills in default values if options are not set in the Config", func() {
|
|
ln, err := Listen(conn, tlsConf, &Config{})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
server := ln.baseServer
|
|
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
|
|
Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
|
|
Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
|
|
Expect(server.config.KeepAlivePeriod).To(BeZero())
|
|
// stop the listener
|
|
Expect(ln.Close()).To(Succeed())
|
|
})
|
|
|
|
It("setups with the right values", func() {
|
|
supportedVersions := []protocol.Version{protocol.Version1}
|
|
config := Config{
|
|
Versions: supportedVersions,
|
|
HandshakeIdleTimeout: 1337 * time.Hour,
|
|
MaxIdleTimeout: 42 * time.Minute,
|
|
KeepAlivePeriod: 5 * time.Second,
|
|
}
|
|
ln, err := Listen(conn, tlsConf, &config)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
server := ln.baseServer
|
|
Expect(server.connHandler).ToNot(BeNil())
|
|
Expect(server.config.Versions).To(Equal(supportedVersions))
|
|
Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour))
|
|
Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
|
|
Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
|
|
// stop the listener
|
|
Expect(ln.Close()).To(Succeed())
|
|
})
|
|
|
|
It("listens on a given address", func() {
|
|
addr := "127.0.0.1:13579"
|
|
ln, err := ListenAddr(addr, tlsConf, &Config{})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(ln.Addr().String()).To(Equal(addr))
|
|
// stop the listener
|
|
Expect(ln.Close()).To(Succeed())
|
|
})
|
|
|
|
It("errors if given an invalid address", func() {
|
|
addr := "127.0.0.1"
|
|
_, err := ListenAddr(addr, tlsConf, &Config{})
|
|
Expect(err).To(BeAssignableToTypeOf(&net.AddrError{}))
|
|
})
|
|
|
|
It("errors if given an invalid address", func() {
|
|
addr := "1.1.1.1:1111"
|
|
_, err := ListenAddr(addr, tlsConf, &Config{})
|
|
Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
|
|
})
|
|
|
|
Context("server accepting connections that completed the handshake", func() {
|
|
var (
|
|
tr *Transport
|
|
serv *baseServer
|
|
phm *MockPacketHandlerManager
|
|
tracer *mocklogging.MockTracer
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
var t *logging.Tracer
|
|
t, tracer = mocklogging.NewMockTracer(mockCtrl)
|
|
tr = &Transport{Conn: conn, Tracer: t}
|
|
ln, err := tr.Listen(tlsConf, nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
serv = ln.baseServer
|
|
phm = NewMockPacketHandlerManager(mockCtrl)
|
|
serv.connHandler = phm
|
|
})
|
|
|
|
AfterEach(func() {
|
|
tracer.EXPECT().Close()
|
|
tr.Close()
|
|
})
|
|
|
|
Context("handling packets", func() {
|
|
It("drops Initial packets with a too short connection ID", func() {
|
|
p := getPacket(&wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
|
Version: serv.config.Versions[0],
|
|
}, nil)
|
|
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
|
serv.handlePacket(p)
|
|
// make sure there are no Write calls on the packet conn
|
|
time.Sleep(50 * time.Millisecond)
|
|
})
|
|
|
|
It("drops too small Initial", func() {
|
|
p := getPacket(&wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
|
|
Version: serv.config.Versions[0],
|
|
}, make([]byte, protocol.MinInitialPacketSize-100))
|
|
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
|
serv.handlePacket(p)
|
|
// make sure there are no Write calls on the packet conn
|
|
time.Sleep(50 * time.Millisecond)
|
|
})
|
|
|
|
It("drops non-Initial packets", func() {
|
|
p := getPacket(&wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
Version: serv.config.Versions[0],
|
|
}, []byte("invalid"))
|
|
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket)
|
|
serv.handlePacket(p)
|
|
// make sure there are no Write calls on the packet conn
|
|
time.Sleep(50 * time.Millisecond)
|
|
})
|
|
|
|
It("passes packets to existing connections", func() {
|
|
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
|
p := getPacket(&wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
DestConnectionID: connID,
|
|
Version: serv.config.Versions[0],
|
|
}, make([]byte, protocol.MinInitialPacketSize))
|
|
conn := NewMockPacketHandler(mockCtrl)
|
|
phm.EXPECT().Get(connID).Return(conn, true)
|
|
handled := make(chan struct{})
|
|
conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
|
|
serv.handlePacket(p)
|
|
Eventually(handled).Should(BeClosed())
|
|
})
|
|
|
|
It("creates a connection when the token is accepted", func() {
|
|
serv.maxNumHandshakesUnvalidated = 0
|
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
|
retryToken, err := serv.tokenGenerator.NewRetryToken(
|
|
raddr,
|
|
protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}),
|
|
protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
|
|
)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
|
hdr := &wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
|
DestConnectionID: connID,
|
|
Version: protocol.Version1,
|
|
Token: retryToken,
|
|
}
|
|
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
|
p.remoteAddr = raddr
|
|
run := make(chan struct{})
|
|
var token protocol.StatelessResetToken
|
|
rand.Read(token[:])
|
|
|
|
var newConnID protocol.ConnectionID
|
|
conn := NewMockQUICConn(mockCtrl)
|
|
serv.newConn = func(
|
|
_ sendConn,
|
|
_ connRunner,
|
|
origDestConnID protocol.ConnectionID,
|
|
retrySrcConnID *protocol.ConnectionID,
|
|
clientDestConnID protocol.ConnectionID,
|
|
destConnID protocol.ConnectionID,
|
|
srcConnID protocol.ConnectionID,
|
|
_ ConnectionIDGenerator,
|
|
tokenP protocol.StatelessResetToken,
|
|
_ *Config,
|
|
_ *tls.Config,
|
|
_ *handshake.TokenGenerator,
|
|
_ bool,
|
|
_ *logging.ConnectionTracer,
|
|
_ uint64,
|
|
_ utils.Logger,
|
|
_ protocol.Version,
|
|
) quicConn {
|
|
Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})))
|
|
Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
|
|
Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
|
|
Expect(destConnID).To(Equal(hdr.SrcConnectionID))
|
|
// make sure we're using a server-generated connection ID
|
|
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
|
|
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
|
|
newConnID = srcConnID
|
|
Expect(tokenP).To(Equal(token))
|
|
conn.EXPECT().handlePacket(p)
|
|
conn.EXPECT().run().Do(func() error { close(run); return nil })
|
|
conn.EXPECT().Context().Return(context.Background())
|
|
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
|
return conn
|
|
}
|
|
phm.EXPECT().Get(connID)
|
|
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token)
|
|
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, cid protocol.ConnectionID, h packetHandler) bool {
|
|
Expect(cid).To(Equal(newConnID))
|
|
return true
|
|
})
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
serv.handlePacket(p)
|
|
// the Handshake packet is written by the connection.
|
|
// Make sure there are no Write calls on the packet conn.
|
|
time.Sleep(50 * time.Millisecond)
|
|
close(done)
|
|
}()
|
|
// make sure we're using a server-generated connection ID
|
|
Eventually(run).Should(BeClosed())
|
|
Eventually(done).Should(BeClosed())
|
|
// shutdown
|
|
conn.EXPECT().closeWithTransportError(gomock.Any())
|
|
})
|
|
|
|
It("sends a Version Negotiation Packet for unsupported versions", func() {
|
|
srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
|
|
destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
|
|
packet := getPacket(&wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
SrcConnectionID: srcConnID,
|
|
DestConnectionID: destConnID,
|
|
Version: 0x42,
|
|
}, make([]byte, protocol.MinUnknownVersionPacketSize))
|
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
|
packet.remoteAddr = raddr
|
|
tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.Version) {
|
|
Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
|
|
Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
|
|
})
|
|
done := make(chan struct{})
|
|
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
|
defer close(done)
|
|
Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue())
|
|
dest, src, versions, err := wire.ParseVersionNegotiationPacket(b)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes())))
|
|
Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes())))
|
|
Expect(versions).ToNot(ContainElement(protocol.Version(0x42)))
|
|
return len(b), nil
|
|
})
|
|
serv.handlePacket(packet)
|
|
Eventually(done).Should(BeClosed())
|
|
})
|
|
|
|
It("doesn't send a Version Negotiation packets if sending them is disabled", func() {
|
|
serv.disableVersionNegotiation = true
|
|
srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
|
|
destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
|
|
packet := getPacket(&wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
SrcConnectionID: srcConnID,
|
|
DestConnectionID: destConnID,
|
|
Version: 0x42,
|
|
}, make([]byte, protocol.MinUnknownVersionPacketSize))
|
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
|
packet.remoteAddr = raddr
|
|
done := make(chan struct{})
|
|
serv.handlePacket(packet)
|
|
Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed())
|
|
})
|
|
|
|
It("ignores Version Negotiation packets", func() {
|
|
data := wire.ComposeVersionNegotiation(
|
|
protocol.ArbitraryLenConnectionID{1, 2, 3, 4},
|
|
protocol.ArbitraryLenConnectionID{4, 3, 2, 1},
|
|
[]protocol.Version{1, 2, 3},
|
|
)
|
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
|
done := make(chan struct{})
|
|
tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
|
|
close(done)
|
|
})
|
|
serv.handlePacket(receivedPacket{
|
|
remoteAddr: raddr,
|
|
data: data,
|
|
buffer: getPacketBuffer(),
|
|
})
|
|
Eventually(done).Should(BeClosed())
|
|
// make sure no other packet is sent
|
|
time.Sleep(scaleDuration(20 * time.Millisecond))
|
|
})
|
|
|
|
It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() {
|
|
srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
|
|
destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
|
|
p := getPacket(&wire.Header{
|
|
Type: protocol.PacketTypeHandshake,
|
|
SrcConnectionID: srcConnID,
|
|
DestConnectionID: destConnID,
|
|
Version: 0x42,
|
|
}, make([]byte, protocol.MinUnknownVersionPacketSize-50))
|
|
Expect(p.Size()).To(BeNumerically("<", protocol.MinUnknownVersionPacketSize))
|
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
|
p.remoteAddr = raddr
|
|
done := make(chan struct{})
|
|
tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
|
|
close(done)
|
|
})
|
|
serv.handlePacket(p)
|
|
Eventually(done).Should(BeClosed())
|
|
// make sure no other packet is sent
|
|
time.Sleep(scaleDuration(20 * time.Millisecond))
|
|
})
|
|
|
|
It("replies with a Retry packet, if a token is required", func() {
|
|
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
|
serv.maxNumHandshakesUnvalidated = 0
|
|
hdr := &wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
|
DestConnectionID: connID,
|
|
Version: protocol.Version1,
|
|
}
|
|
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
|
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
|
packet.remoteAddr = raddr
|
|
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
|
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
|
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
|
|
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
|
Expect(replyHdr.Token).ToNot(BeEmpty())
|
|
})
|
|
done := make(chan struct{})
|
|
conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
|
defer close(done)
|
|
replyHdr := parseHeader(b)
|
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
|
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
|
|
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
|
Expect(replyHdr.Token).ToNot(BeEmpty())
|
|
Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:]))
|
|
return len(b), nil
|
|
})
|
|
phm.EXPECT().Get(connID)
|
|
serv.handlePacket(packet)
|
|
Eventually(done).Should(BeClosed())
|
|
})
|
|
|
|
It("creates a connection, if no token is required", func() {
|
|
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
|
|
hdr := &wire.Header{
|
|
Type: protocol.PacketTypeInitial,
|
|
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
|
|
DestConnectionID: connID,
|
|
Version: protocol.Version1,
|
|
}
|
|
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
|
|
run := make(chan struct{})
|
|
var token protocol.StatelessResetToken
|
|
rand.Read(token[:])
|
|
|
|
var newConnID protocol.ConnectionID
|
|
conn := NewMockQUICConn(mockCtrl)
|
|
serv.newConn = func(
|
|
_ sendConn,
|
|
_ connRunner,
|
|
origDestConnID protocol.ConnectionID,
|
|
retrySrcConnID *protocol.ConnectionID,
|
|
clientDestConnID protocol.ConnectionID,
|
|
destConnID protocol.ConnectionID,
|
|
srcConnID protocol.ConnectionID,
|
|
_ ConnectionIDGenerator,
|
|
tokenP protocol.StatelessResetToken,
|
|
_ *Config,
|
|
_ *tls.Config,
|
|
_ *handshake.TokenGenerator,
|
|
_ bool,
|
|
_ *logging.ConnectionTracer,
|
|
_ uint64,
|
|
_ utils.Logger,
|
|
_ protocol.Version,
|
|
) quicConn {
|
|
Expect(origDestConnID).To(Equal(hdr.DestConnectionID))
|
|
Expect(retrySrcConnID).To(BeNil())
|
|
Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
|
|
Expect(destConnID).To(Equal(hdr.SrcConnectionID))
|
|
// make sure we're using a server-generated connection ID
|
|
Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID))
|
|
Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID))
|
|
newConnID = srcConnID
|
|
Expect(tokenP).To(Equal(token))
|
|
conn.EXPECT().handlePacket(p)
|
|
conn.EXPECT().run().Do(func() error { close(run); return nil })
|
|
conn.EXPECT().Context().Return(context.Background())
|
|
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
|
return conn
|
|
}
|
|
gomock.InOrder(
|
|
phm.EXPECT().Get(connID),
|
|
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Return(token),
|
|
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, h packetHandler) bool {
|
|
Expect(c).To(Equal(newConnID))
|
|
return true
|
|
}),
|
|
)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
serv.handlePacket(p)
|
|
// the Handshake packet is written by the connection
|
|
// make sure there are no Write calls on the packet conn
|
|
time.Sleep(50 * time.Millisecond)
|
|
close(done)
|
|
}()
|
|
// make sure we're using a server-generated connection ID
|
|
Eventually(run).Should(BeClosed())
|
|
Eventually(done).Should(BeClosed())
|
|
// shutdown
|
|
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
|
|
})
|
|
|
|
It("drops packets if the receive queue is full", func() {
|
|
serv.maxNumHandshakesTotal = 10000
|
|
serv.maxNumHandshakesUnvalidated = 10000
|
|
|
|
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
|
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
|
|
|
|
acceptConn := make(chan struct{})
|
|
var counter atomic.Uint32
|
|
serv.newConn = func(
|
|
_ sendConn,
|
|
runner connRunner,
|
|
_ protocol.ConnectionID,
|
|
_ *protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ ConnectionIDGenerator,
|
|
_ protocol.StatelessResetToken,
|
|
_ *Config,
|
|
_ *tls.Config,
|
|
_ *handshake.TokenGenerator,
|
|
_ bool,
|
|
_ *logging.ConnectionTracer,
|
|
_ uint64,
|
|
_ utils.Logger,
|
|
_ protocol.Version,
|
|
) quicConn {
|
|
<-acceptConn
|
|
counter.Add(1)
|
|
conn := NewMockQUICConn(mockCtrl)
|
|
conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1)
|
|
conn.EXPECT().run().MaxTimes(1)
|
|
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
|
|
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
|
|
// shutdown
|
|
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
|
|
return conn
|
|
}
|
|
|
|
p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
|
|
serv.handlePacket(p)
|
|
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1)
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 3*protocol.MaxServerUnprocessedPackets; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
defer wg.Done()
|
|
serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})))
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
close(acceptConn)
|
|
Eventually(
|
|
func() uint32 { return counter.Load() },
|
|
scaleDuration(100*time.Millisecond),
|
|
).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
|
|
Consistently(func() uint32 { return counter.Load() }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1))
|
|
})
|
|
|
|
PIt("only creates a single connection for a duplicate Initial", func() {
|
|
var createdConn bool
|
|
serv.newConn = func(
|
|
_ sendConn,
|
|
runner connRunner,
|
|
_ protocol.ConnectionID,
|
|
_ *protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ ConnectionIDGenerator,
|
|
_ protocol.StatelessResetToken,
|
|
_ *Config,
|
|
_ *tls.Config,
|
|
_ *handshake.TokenGenerator,
|
|
_ bool,
|
|
_ *logging.ConnectionTracer,
|
|
_ uint64,
|
|
_ utils.Logger,
|
|
_ protocol.Version,
|
|
) quicConn {
|
|
createdConn = true
|
|
return NewMockQUICConn(mockCtrl)
|
|
}
|
|
|
|
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
|
p := getInitial(connID)
|
|
phm.EXPECT().Get(connID)
|
|
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) // connection ID collision
|
|
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
done := make(chan struct{})
|
|
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) (int, error) { close(done); return 0, nil })
|
|
Expect(serv.handlePacketImpl(p)).To(BeTrue())
|
|
Expect(createdConn).To(BeFalse())
|
|
Eventually(done).Should(BeClosed())
|
|
})
|
|
|
|
It("limits the number of unvalidated handshakes", func() {
|
|
const limit = 3
|
|
serv.maxNumHandshakesTotal = 10000
|
|
serv.maxNumHandshakesUnvalidated = limit
|
|
|
|
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
|
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
|
|
|
|
handshakeChan := make(chan struct{})
|
|
connChan := make(chan *MockQUICConn, 1)
|
|
var wg sync.WaitGroup
|
|
wg.Add(2 * limit)
|
|
serv.newConn = func(
|
|
_ sendConn,
|
|
runner connRunner,
|
|
_ protocol.ConnectionID,
|
|
_ *protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ ConnectionIDGenerator,
|
|
_ protocol.StatelessResetToken,
|
|
_ *Config,
|
|
_ *tls.Config,
|
|
_ *handshake.TokenGenerator,
|
|
_ bool,
|
|
_ *logging.ConnectionTracer,
|
|
_ uint64,
|
|
_ utils.Logger,
|
|
_ protocol.Version,
|
|
) quicConn {
|
|
conn := <-connChan
|
|
conn.EXPECT().handlePacket(gomock.Any())
|
|
conn.EXPECT().run()
|
|
conn.EXPECT().Context().Return(context.Background())
|
|
conn.EXPECT().HandshakeComplete().Return(handshakeChan).Do(func() <-chan struct{} { wg.Done(); return nil })
|
|
return conn
|
|
}
|
|
|
|
// Initiate the maximum number of allowed connection attempts.
|
|
for i := 0; i < limit; i++ {
|
|
conn := NewMockQUICConn(mockCtrl)
|
|
connChan <- conn
|
|
serv.handlePacket(getInitialWithRandomDestConnID())
|
|
}
|
|
|
|
// Now initiate another connection attempt.
|
|
p := getInitialWithRandomDestConnID()
|
|
done := make(chan struct{})
|
|
tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
|
defer GinkgoRecover()
|
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
|
})
|
|
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
|
defer GinkgoRecover()
|
|
defer close(done)
|
|
hdr, _, _, err := wire.ParsePacket(b)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry))
|
|
return len(b), nil
|
|
})
|
|
serv.handlePacket(p)
|
|
Eventually(done).Should(BeClosed())
|
|
|
|
close(handshakeChan)
|
|
for i := 0; i < limit; i++ {
|
|
_, err := serv.Accept(context.Background())
|
|
Expect(err).ToNot(HaveOccurred())
|
|
}
|
|
for i := 0; i < limit; i++ {
|
|
conn := NewMockQUICConn(mockCtrl)
|
|
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
|
|
connChan <- conn
|
|
serv.handlePacket(getInitialWithRandomDestConnID())
|
|
}
|
|
wg.Wait()
|
|
})
|
|
|
|
It("limits the number of total handshakes", func() {
|
|
const limit = 3
|
|
serv.maxNumHandshakesTotal = limit
|
|
serv.maxNumHandshakesUnvalidated = limit // same limit, but we check that we send CONNECTION_REFUSED and not Retry
|
|
|
|
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
|
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
|
|
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).AnyTimes()
|
|
|
|
handshakeChan := make(chan struct{})
|
|
connChan := make(chan *MockQUICConn, 1)
|
|
serv.newConn = func(
|
|
_ sendConn,
|
|
runner connRunner,
|
|
_ protocol.ConnectionID,
|
|
_ *protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ protocol.ConnectionID,
|
|
_ ConnectionIDGenerator,
|
|
_ protocol.StatelessResetToken,
|
|
_ *Config,
|
|
_ *tls.Config,
|
|
_ *handshake.TokenGenerator,
|
|
_ bool,
|
|
_ *logging.ConnectionTracer,
|
|
_ uint64,
|
|
_ utils.Logger,
|
|
_ protocol.Version,
|
|
) quicConn {
|
|
conn := <-connChan
|
|
conn.EXPECT().handlePacket(gomock.Any())
|
|
conn.EXPECT().run()
|
|
conn.EXPECT().Context().Return(context.Background())
|
|
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
|
return conn
|
|
}
|
|
|
|
for i := 0; i < limit; i++ {
|
|
conn := NewMockQUICConn(mockCtrl)
|
|
connChan <- conn
|
|
serv.handlePacket(getInitialWithRandomDestConnID())
|
|
}
|
|
|
|
p := getInitialWithRandomDestConnID()
|
|
done := make(chan struct{})
|
|
tracer.EXPECT().SentPacket(p.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
|
|
defer GinkgoRecover()
|
|
hdr, _, _, err := wire.ParsePacket(p.data)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
|
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
|
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
|
Expect(frames).To(HaveLen(1))
|
|
Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
|
|
ccf := frames[0].(*logging.ConnectionCloseFrame)
|
|
Expect(ccf.IsApplicationError).To(BeFalse())
|
|
Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ConnectionRefused))
|
|
})
|
|
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
|
defer GinkgoRecover()
|
|
defer close(done)
|
|
hdr, _, _, err := wire.ParsePacket(p.data)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
checkConnectionCloseError(b, hdr, qerr.ConnectionRefused)
|
|
return len(b), nil
|
|
})
|
|
serv.handlePacket(p)
|
|
Eventually(done).Should(BeClosed())
|
|
|
|
close(handshakeChan)
|
|
for i := 0; i < limit; i++ {
|
|
_, err := serv.Accept(context.Background())
|
|
Expect(err).ToNot(HaveOccurred())
|
|
}
|
|
// make sure we can enqueue and accept more connections after that
|
|
for i := 0; i < limit; i++ {
|
|
conn := NewMockQUICConn(mockCtrl)
|
|
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1) // called when the server is closed
|
|
connChan <- conn
|
|
serv.handlePacket(getInitialWithRandomDestConnID())
|
|
}
|
|
for i := 0; i < limit; i++ {
|
|
_, err := serv.Accept(context.Background())
|
|
Expect(err).ToNot(HaveOccurred())
|
|
}
|
|
})
|
|
})
|
|
|
|
Context("token validation", func() {
|
|
// Sends an Initial whose token field carries a Retry token issued for the
// packet's own remote address, and passes once the server calls AddWithConnID
// on the packet handler map.
It("decodes the token from the token field", func() {
	clientAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337}
	retryToken, err := serv.tokenGenerator.NewRetryToken(clientAddr, protocol.ConnectionID{}, protocol.ConnectionID{})
	Expect(err).ToNot(HaveOccurred())

	hdr := &wire.Header{
		Type:    protocol.PacketTypeInitial,
		Token:   retryToken,
		Version: serv.config.Versions[0],
	}
	initial := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
	initial.remoteAddr = clientAddr

	// At most one outgoing packet (and its trace event) is tolerated.
	tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
	conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)

	registered := make(chan struct{})
	phm.EXPECT().Get(gomock.Any())
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().Remove(gomock.Any()).AnyTimes()
	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(_, _ protocol.ConnectionID, _ packetHandler) bool {
			close(registered)
			return true
		},
	)

	serv.handlePacket(initial)
	Eventually(registered).Should(BeClosed())
})
|
|
|
|
// The Retry token is issued for an empty UDPAddr while the packet arrives from
// 127.0.0.1:1337. The server is expected to answer with an Initial packet that
// carries a single CONNECTION_CLOSE frame with the INVALID_TOKEN transport error.
It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
	// NOTE(review): presumably a limit of 0 makes the server validate the
	// address of every handshake; confirm against the server implementation.
	serv.maxNumHandshakesUnvalidated = 0
	token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
	Expect(err).ToNot(HaveOccurred())
	hdr := &wire.Header{
		Type:             protocol.PacketTypeInitial,
		SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
		DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
		Token:            token,
		Version:          protocol.Version1,
	}
	packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
	packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet
	raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
	packet.remoteAddr = raddr

	tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
		// The callback may run on a goroutine other than the spec's, so a
		// failed assertion has to be recovered to be reported as a failure.
		defer GinkgoRecover()
		Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
		// The reply swaps the connection IDs of the received packet.
		Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
		Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
		Expect(frames).To(HaveLen(1))
		Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
		ccf := frames[0].(*logging.ConnectionCloseFrame)
		Expect(ccf.IsApplicationError).To(BeFalse())
		Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
	})

	done := make(chan struct{})
	conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
		defer GinkgoRecover()
		defer close(done)
		checkConnectionCloseError(b, hdr, qerr.InvalidToken)
		return len(b), nil
	})
	phm.EXPECT().Get(gomock.Any())

	serv.handlePacket(packet)
	// Wait until the reply has been written to the packet conn.
	Eventually(done).Should(BeClosed())
})
|
|
|
|
// The Retry token is issued for the correct remote address but is older than
// the maximum retry token age by the time the Initial is handled. The server is
// expected to answer with a CONNECTION_CLOSE carrying INVALID_TOKEN.
It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
	// NOTE(review): presumably a limit of 0 makes the server validate the
	// address of every handshake; confirm against the server implementation.
	serv.maxNumHandshakesUnvalidated = 0
	// The maximum retry token age is derived from the handshake timeout;
	// the assertion below pins it to one millisecond.
	serv.config.HandshakeIdleTimeout = time.Millisecond / 2
	Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
	raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
	token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
	Expect(err).ToNot(HaveOccurred())
	time.Sleep(2 * time.Millisecond) // make sure the token is expired
	hdr := &wire.Header{
		Type:             protocol.PacketTypeInitial,
		SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
		DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
		Token:            token,
		Version:          protocol.Version1,
	}
	packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
	packet.remoteAddr = raddr

	tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) {
		// The callback may run on a goroutine other than the spec's, so a
		// failed assertion has to be recovered to be reported as a failure.
		defer GinkgoRecover()
		Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial))
		// The reply swaps the connection IDs of the received packet.
		Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
		Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
		Expect(frames).To(HaveLen(1))
		Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{}))
		ccf := frames[0].(*logging.ConnectionCloseFrame)
		Expect(ccf.IsApplicationError).To(BeFalse())
		Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken))
	})

	done := make(chan struct{})
	conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
		defer GinkgoRecover()
		defer close(done)
		checkConnectionCloseError(b, hdr, qerr.InvalidToken)
		return len(b), nil
	})
	phm.EXPECT().Get(gomock.Any())

	serv.handlePacket(packet)
	// Wait until the reply has been written to the packet conn.
	Eventually(done).Should(BeClosed())
})
|
|
|
|
// The non-retry token is issued for 192.168.0.1:1337 while the packet arrives
// from 127.0.0.1:1337, and the packet payload is corrupted on top of that.
// The server is expected to reply with a Retry packet, not a CONNECTION_CLOSE.
It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
	// NOTE(review): presumably a limit of 0 makes the server validate the
	// address of every handshake; confirm against the server implementation.
	serv.maxNumHandshakesUnvalidated = 0
	token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
	Expect(err).ToNot(HaveOccurred())
	hdr := &wire.Header{
		Type:             protocol.PacketTypeInitial,
		SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
		DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
		Token:            token,
		Version:          protocol.Version1,
	}
	packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
	packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet
	raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
	packet.remoteAddr = raddr

	tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
	done := make(chan struct{})
	conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
		// The callback may run on a goroutine other than the spec's, so a
		// failed assertion has to be recovered to be reported as a failure.
		defer GinkgoRecover()
		defer close(done)
		replyHdr := parseHeader(b)
		Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
		return len(b), nil
	})
	phm.EXPECT().Get(gomock.Any())

	serv.handlePacket(packet)
	// wait until the Retry packet has been written to the packet conn
	Eventually(done).Should(BeClosed())
})
|
|
|
|
// The non-retry token is issued for the correct remote address but is older
// than serv.maxTokenAge by the time the Initial is handled.
//
// NOTE(review): despite the spec name, the assertions below expect a Retry
// packet, not a CONNECTION_CLOSE with INVALID_TOKEN. The name is kept so that
// existing focus filters and reports keep working; consider renaming it.
It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
	// NOTE(review): presumably a limit of 0 makes the server validate the
	// address of every handshake; confirm against the server implementation.
	serv.maxNumHandshakesUnvalidated = 0
	serv.maxTokenAge = time.Millisecond
	raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
	token, err := serv.tokenGenerator.NewToken(raddr)
	Expect(err).ToNot(HaveOccurred())
	time.Sleep(2 * time.Millisecond) // make sure the token is expired
	hdr := &wire.Header{
		Type:             protocol.PacketTypeInitial,
		SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
		DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
		Token:            token,
		Version:          protocol.Version1,
	}
	packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
	packet.remoteAddr = raddr

	tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
		// The callback may run on a goroutine other than the spec's, so a
		// failed assertion has to be recovered to be reported as a failure.
		defer GinkgoRecover()
		Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
	})
	done := make(chan struct{})
	conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
		defer close(done)
		return len(b), nil
	})
	phm.EXPECT().Get(gomock.Any())

	serv.handlePacket(packet)
	// Wait until the reply has been written to the packet conn.
	Eventually(done).Should(BeClosed())
})
|
|
|
|
// An Initial with a Retry token issued for a different address is corrupted
// before it is handed to the server. The only expected reaction is a
// DroppedPacket trace event with the payload-decrypt-error reason; no WriteTo
// expectation is registered on the packet conn.
It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
	serv.maxNumHandshakesUnvalidated = 0
	retryToken, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
	Expect(err).ToNot(HaveOccurred())

	hdr := &wire.Header{
		Type:             protocol.PacketTypeInitial,
		SrcConnectionID:  protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
		DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
		Token:            retryToken,
		Version:          protocol.Version1,
	}
	corrupted := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
	corrupted.data[len(corrupted.data)-10] ^= 0xff // corrupt the packet
	corrupted.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}

	dropped := make(chan struct{})
	tracer.EXPECT().DroppedPacket(
		corrupted.remoteAddr,
		logging.PacketTypeInitial,
		corrupted.Size(),
		logging.PacketDropPayloadDecryptError,
	).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
		close(dropped)
	})
	phm.EXPECT().Get(gomock.Any())

	serv.handlePacket(corrupted)
	// make sure there are no Write calls on the packet conn
	time.Sleep(50 * time.Millisecond)
	Eventually(dropped).Should(BeClosed())
})
|
|
})
|
|
|
|
Context("accepting connections", func() {
|
|
// A pending Accept call must return ErrServerClosed once the server is closed.
It("returns Accept when closed", func() {
	returned := make(chan struct{})
	acceptUntilClosed := func() {
		defer GinkgoRecover()
		_, acceptErr := serv.Accept(context.Background())
		Expect(acceptErr).To(MatchError(ErrServerClosed))
		close(returned)
	}
	go acceptUntilClosed()

	serv.Close()
	Eventually(returned).Should(BeClosed())
})
|
|
|
|
// Once the server has been closed, every Accept call reports ErrServerClosed.
It("returns immediately, if an error occurred before", func() {
	serv.Close()
	for attempt := 1; attempt <= 3; attempt++ {
		_, acceptErr := serv.Accept(context.Background())
		Expect(acceptErr).To(MatchError(ErrServerClosed))
	}
})
|
|
|
|
// The server is closed before the Initial is handled. The connection created
// for that Initial never completes its handshake (HandshakeComplete returns a
// channel that is never closed) and is expected to be closed with
// ConnectionRefused.
It("closes connection that are still handshaking after Close", func() {
	serv.Close()

	refused := make(chan struct{})
	serv.newConn = func(
		sendConn,
		connRunner,
		protocol.ConnectionID,
		*protocol.ConnectionID,
		protocol.ConnectionID,
		protocol.ConnectionID,
		protocol.ConnectionID,
		ConnectionIDGenerator,
		protocol.StatelessResetToken,
		*Config,
		*tls.Config,
		*handshake.TokenGenerator,
		bool,
		*logging.ConnectionTracer,
		uint64,
		utils.Logger,
		protocol.Version,
	) quicConn {
		mockConn := NewMockQUICConn(mockCtrl)
		mockConn.EXPECT().handlePacket(gomock.Any())
		mockConn.EXPECT().run().MaxTimes(1)
		mockConn.EXPECT().Context().Return(context.Background())
		mockConn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
		mockConn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) {
			close(refused)
		})
		return mockConn
	}

	phm.EXPECT().Get(gomock.Any())
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)

	initialHdr := &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}
	serv.handleInitialImpl(receivedPacket{buffer: getPacketBuffer()}, initialHdr)
	Eventually(refused).Should(BeClosed())
})
|
|
|
|
// Accept blocks while its context is live and returns a "context canceled"
// error after the context has been canceled.
It("returns when the context is canceled", func() {
	acceptCtx, cancelAccept := context.WithCancel(context.Background())
	returned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		_, acceptErr := serv.Accept(acceptCtx)
		Expect(acceptErr).To(MatchError("context canceled"))
		close(returned)
	}()

	// Nothing is queued, so Accept must still be blocked here.
	Consistently(returned).ShouldNot(BeClosed())
	cancelAccept()
	Eventually(returned).Should(BeClosed())
})
|
|
|
|
// GetConfigForClient returns a config with MaxIncomingStreams set to 1234.
// The spec checks that this value reaches newConn and that the connection is
// returned from Accept only after the handshake channel is closed.
It("uses the config returned by GetConfigClient", func() {
	mockConn := NewMockQUICConn(mockCtrl)

	perClientConf := &Config{MaxIncomingStreams: 1234}
	serv.config = populateConfig(&Config{
		GetConfigForClient: func(*ClientHelloInfo) (*Config, error) {
			return perClientConf, nil
		},
	})

	accepted := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		got, acceptErr := serv.Accept(context.Background())
		Expect(acceptErr).ToNot(HaveOccurred())
		Expect(got).To(Equal(mockConn))
		close(accepted)
	}()

	handshakeDone := make(chan struct{})
	serv.newConn = func(
		_ sendConn,
		_ connRunner,
		_ protocol.ConnectionID,
		_ *protocol.ConnectionID,
		_ protocol.ConnectionID,
		_ protocol.ConnectionID,
		_ protocol.ConnectionID,
		_ ConnectionIDGenerator,
		_ protocol.StatelessResetToken,
		connConf *Config,
		_ *tls.Config,
		_ *handshake.TokenGenerator,
		_ bool,
		_ *logging.ConnectionTracer,
		_ uint64,
		_ utils.Logger,
		_ protocol.Version,
	) quicConn {
		Expect(connConf.MaxIncomingStreams).To(BeEquivalentTo(1234))
		mockConn.EXPECT().handlePacket(gomock.Any())
		mockConn.EXPECT().run()
		mockConn.EXPECT().Context().Return(context.Background())
		mockConn.EXPECT().HandshakeComplete().Return(handshakeDone)
		return mockConn
	}

	phm.EXPECT().Get(gomock.Any())
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)

	initialHdr := &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}
	serv.handleInitialImpl(receivedPacket{buffer: getPacketBuffer()}, initialHdr)

	Consistently(accepted).ShouldNot(BeClosed())
	close(handshakeDone) // complete the handshake
	Eventually(accepted).Should(BeClosed())
})
|
|
|
|
// GetConfigForClient returns an error. The server is expected to write exactly
// one reply to the packet conn, and that reply must be an Initial packet.
It("rejects a connection attempt when GetConfigClient returns an error", func() {
	serv.config = populateConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})

	phm.EXPECT().Get(gomock.Any())
	done := make(chan struct{})
	tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
	conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
		// The callback may run on a goroutine other than the spec's, so a
		// failed assertion has to be recovered to be reported as a failure.
		defer GinkgoRecover()
		defer close(done)
		rejectHdr := parseHeader(b)
		Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
		return len(b), nil
	})
	serv.handleInitialImpl(
		receivedPacket{buffer: getPacketBuffer()},
		&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1},
	)
	// Wait until the rejection has been written to the packet conn.
	Eventually(done).Should(BeClosed())
})
|
|
|
|
// Accept must not return the new connection before the channel returned by
// HandshakeComplete is closed, and must return it afterwards.
It("accepts new connections when the handshake completes", func() {
	mockConn := NewMockQUICConn(mockCtrl)

	accepted := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		got, acceptErr := serv.Accept(context.Background())
		Expect(acceptErr).ToNot(HaveOccurred())
		Expect(got).To(Equal(mockConn))
		close(accepted)
	}()

	handshakeDone := make(chan struct{})
	serv.newConn = func(
		sendConn,
		connRunner,
		protocol.ConnectionID,
		*protocol.ConnectionID,
		protocol.ConnectionID,
		protocol.ConnectionID,
		protocol.ConnectionID,
		ConnectionIDGenerator,
		protocol.StatelessResetToken,
		*Config,
		*tls.Config,
		*handshake.TokenGenerator,
		bool,
		*logging.ConnectionTracer,
		uint64,
		utils.Logger,
		protocol.Version,
	) quicConn {
		mockConn.EXPECT().handlePacket(gomock.Any())
		mockConn.EXPECT().run()
		mockConn.EXPECT().Context().Return(context.Background())
		mockConn.EXPECT().HandshakeComplete().Return(handshakeDone)
		return mockConn
	}

	phm.EXPECT().Get(gomock.Any())
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)

	initialHdr := &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}
	serv.handleInitialImpl(receivedPacket{buffer: getPacketBuffer()}, initialHdr)

	Consistently(accepted).ShouldNot(BeClosed())
	close(handshakeDone) // complete the handshake
	Eventually(accepted).Should(BeClosed())
})
|
|
})
|
|
})
|
|
|
|
Context("server accepting connections that haven't completed the handshake", func() {
|
|
var (
|
|
serv *EarlyListener
|
|
phm *MockPacketHandlerManager
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
var err error
|
|
serv, err = ListenEarly(conn, tlsConf, nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
phm = NewMockPacketHandlerManager(mockCtrl)
|
|
serv.baseServer.connHandler = phm
|
|
})
|
|
|
|
AfterEach(func() {
|
|
serv.Close()
|
|
})
|
|
|
|
// For an early listener, Accept must not return the connection before the
// channel returned by earlyConnReady is closed, and must return it afterwards.
It("accepts new connections when they become ready", func() {
	mockConn := NewMockQUICConn(mockCtrl)

	accepted := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		got, acceptErr := serv.Accept(context.Background())
		Expect(acceptErr).ToNot(HaveOccurred())
		Expect(got).To(Equal(mockConn))
		close(accepted)
	}()

	earlyReady := make(chan struct{})
	serv.baseServer.newConn = func(
		sendConn,
		connRunner,
		protocol.ConnectionID,
		*protocol.ConnectionID,
		protocol.ConnectionID,
		protocol.ConnectionID,
		protocol.ConnectionID,
		ConnectionIDGenerator,
		protocol.StatelessResetToken,
		*Config,
		*tls.Config,
		*handshake.TokenGenerator,
		bool,
		*logging.ConnectionTracer,
		uint64,
		utils.Logger,
		protocol.Version,
	) quicConn {
		mockConn.EXPECT().handlePacket(gomock.Any())
		mockConn.EXPECT().run()
		mockConn.EXPECT().Context().Return(context.Background())
		mockConn.EXPECT().earlyConnReady().Return(earlyReady)
		return mockConn
	}

	phm.EXPECT().Get(gomock.Any())
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)

	initialHdr := &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}
	serv.baseServer.handleInitialImpl(receivedPacket{buffer: getPacketBuffer()}, initialHdr)

	Consistently(accepted).ShouldNot(BeClosed())
	close(earlyReady)
	Eventually(accepted).Should(BeClosed())
})
|
|
|
|
// Fills the accept queue with protocol.MaxAcceptQueueSize connections that are
// ready immediately, then sends one more Initial and expects the connection
// created for it to be closed with ConnectionRefused.
It("rejects new connection attempts if the accept queue is full", func() {
	// Hands the next mock connection to newConn, one at a time.
	pendingConns := make(chan *MockQUICConn, 1)
	var created sync.WaitGroup // to make sure the test fully completes
	created.Add(protocol.MaxAcceptQueueSize + 1)
	serv.baseServer.newConn = func(
		sendConn,
		connRunner,
		protocol.ConnectionID,
		*protocol.ConnectionID,
		protocol.ConnectionID,
		protocol.ConnectionID,
		protocol.ConnectionID,
		ConnectionIDGenerator,
		protocol.StatelessResetToken,
		*Config,
		*tls.Config,
		*handshake.TokenGenerator,
		bool,
		*logging.ConnectionTracer,
		uint64,
		utils.Logger,
		protocol.Version,
	) quicConn {
		defer created.Done()
		// Every connection reports itself as ready right away.
		alreadyReady := make(chan struct{})
		close(alreadyReady)
		next := <-pendingConns
		next.EXPECT().handlePacket(gomock.Any())
		next.EXPECT().run()
		next.EXPECT().earlyConnReady().Return(alreadyReady)
		next.EXPECT().Context().Return(context.Background())
		return next
	}

	phm.EXPECT().Get(gomock.Any()).AnyTimes()
	phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
	for queued := 0; queued < protocol.MaxAcceptQueueSize; queued++ {
		pendingConns <- NewMockQUICConn(mockCtrl)
		serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
	}

	Eventually(serv.baseServer.connQueue).Should(HaveLen(protocol.MaxAcceptQueueSize))

	// One connection too many: it is still created and registered, then refused.
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
	overflow := NewMockQUICConn(mockCtrl)
	overflow.EXPECT().closeWithTransportError(ConnectionRefused)
	pendingConns <- overflow
	serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
	created.Wait()
})
|
|
|
|
// The connection's context is canceled after it was created but before it is
// reported as ready (earlyConnReady has no return value configured). A later
// Accept call must keep blocking until the server itself is closed.
It("doesn't accept new connections if they were closed in the mean time", func() {
	initial := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
	connCtx, cancelConn := context.WithCancel(context.Background())
	created := make(chan struct{})
	mockConn := NewMockQUICConn(mockCtrl)
	serv.baseServer.newConn = func(
		sendConn,
		connRunner,
		protocol.ConnectionID,
		*protocol.ConnectionID,
		protocol.ConnectionID,
		protocol.ConnectionID,
		protocol.ConnectionID,
		ConnectionIDGenerator,
		protocol.StatelessResetToken,
		*Config,
		*tls.Config,
		*handshake.TokenGenerator,
		bool,
		*logging.ConnectionTracer,
		uint64,
		utils.Logger,
		protocol.Version,
	) quicConn {
		mockConn.EXPECT().handlePacket(initial)
		mockConn.EXPECT().run()
		mockConn.EXPECT().earlyConnReady()
		mockConn.EXPECT().Context().Return(connCtx)
		close(created)
		return mockConn
	}

	phm.EXPECT().Get(gomock.Any())
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
	serv.baseServer.handlePacket(initial)
	// make sure there are no Write calls on the packet conn
	time.Sleep(50 * time.Millisecond)
	Eventually(created).Should(BeClosed())
	cancelConn()
	time.Sleep(scaleDuration(200 * time.Millisecond))

	returned := make(chan struct{})
	go func() {
		defer GinkgoRecover()
		// Only the fact that Accept returns matters here, not its result.
		_, _ = serv.Accept(context.Background())
		close(returned)
	}()
	Consistently(returned).ShouldNot(BeClosed())

	// make the go routine return
	mockConn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
	Expect(serv.Close()).To(Succeed())
	Eventually(returned).Should(BeClosed())
})
|
|
})
|
|
|
|
Context("0-RTT", func() {
|
|
var (
|
|
tr *Transport
|
|
serv *baseServer
|
|
phm *MockPacketHandlerManager
|
|
tracer *mocklogging.MockTracer
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
var t *logging.Tracer
|
|
t, tracer = mocklogging.NewMockTracer(mockCtrl)
|
|
tr = &Transport{Conn: conn, Tracer: t}
|
|
ln, err := tr.ListenEarly(tlsConf, nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
phm = NewMockPacketHandlerManager(mockCtrl)
|
|
serv = ln.baseServer
|
|
serv.connHandler = phm
|
|
})
|
|
|
|
AfterEach(func() {
|
|
tracer.EXPECT().Close()
|
|
Expect(tr.Close()).To(Succeed())
|
|
})
|
|
|
|
// When the packet handler map already knows the destination connection ID,
// a 0-RTT packet is handed to that handler.
It("passes packets to existing connections", func() {
	destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
	hdr := &wire.Header{
		Type:             protocol.PacketType0RTT,
		DestConnectionID: destConnID,
		Version:          serv.config.Versions[0],
	}
	zeroRTT := getPacket(hdr, make([]byte, 100))

	handler := NewMockPacketHandler(mockCtrl)
	phm.EXPECT().Get(destConnID).Return(handler, true)
	delivered := make(chan struct{})
	handler.EXPECT().handlePacket(zeroRTT).Do(func(receivedPacket) {
		close(delivered)
	})

	serv.handlePacket(zeroRTT)
	Eventually(delivered).Should(BeClosed())
})
|
|
|
|
// Sends protocol.Max0RTTQueueLen 0-RTT packets for an unknown connection ID,
// then one more that is expected to be dropped with the DOS-prevention reason.
// When the matching Initial arrives, the new connection is expected to receive
// the Initial first, followed by the queued 0-RTT packets in arrival order.
It("queues 0-RTT packets, up to Max0RTTQueueSize", func() {
	destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})

	var queued []receivedPacket

	for n := 0; n < protocol.Max0RTTQueueLen; n++ {
		// Each packet gets a distinct payload size.
		pkt := getPacket(&wire.Header{
			Type:             protocol.PacketType0RTT,
			DestConnectionID: destConnID,
			Version:          serv.config.Versions[0],
		}, make([]byte, 100+n))
		phm.EXPECT().Get(destConnID)
		serv.handlePacket(pkt)
		queued = append(queued, pkt)
	}

	// send one more packet, this one should be dropped
	excess := getPacket(&wire.Header{
		Type:             protocol.PacketType0RTT,
		DestConnectionID: destConnID,
		Version:          serv.config.Versions[0],
	}, make([]byte, 200))
	phm.EXPECT().Get(destConnID)
	tracer.EXPECT().DroppedPacket(excess.remoteAddr, logging.PacketType0RTT, excess.Size(), logging.PacketDropDOSPrevention)
	serv.handlePacket(excess)

	initial := getPacket(&wire.Header{
		Type:             protocol.PacketTypeInitial,
		DestConnectionID: destConnID,
		Version:          serv.config.Versions[0],
	}, make([]byte, protocol.MinInitialPacketSize))
	connCreated := make(chan struct{})
	serv.newConn = func(
		sendConn,
		connRunner,
		protocol.ConnectionID,
		*protocol.ConnectionID,
		protocol.ConnectionID,
		protocol.ConnectionID,
		protocol.ConnectionID,
		ConnectionIDGenerator,
		protocol.StatelessResetToken,
		*Config,
		*tls.Config,
		*handshake.TokenGenerator,
		bool,
		*logging.ConnectionTracer,
		uint64,
		utils.Logger,
		protocol.Version,
	) quicConn {
		mockConn := NewMockQUICConn(mockCtrl)
		// The Initial has to be delivered before any of the queued packets.
		ordered := []any{mockConn.EXPECT().handlePacket(initial)}
		for _, pkt := range queued {
			ordered = append(ordered, mockConn.EXPECT().handlePacket(pkt))
		}
		gomock.InOrder(ordered...)
		mockConn.EXPECT().run()
		mockConn.EXPECT().earlyConnReady()
		mockConn.EXPECT().Context().Return(context.Background())
		close(connCreated)
		// shutdown
		mockConn.EXPECT().closeWithTransportError(gomock.Any())
		return mockConn
	}

	phm.EXPECT().Get(destConnID)
	phm.EXPECT().GetStatelessResetToken(gomock.Any())
	phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
	serv.handlePacket(initial)
	Eventually(connCreated).Should(BeClosed())
})
|
|
|
|
// Opens protocol.Max0RTTQueues queues by sending one 0-RTT packet for each of
// that many random connection IDs. A 0-RTT packet for one further connection
// ID is then expected to be dropped with the DOS-prevention reason.
It("limits the number of queues", func() {
	for i := 0; i < protocol.Max0RTTQueues; i++ {
		b := make([]byte, 16)
		// Distinct connection IDs are what fills distinct queues, so a failed
		// read (which would leave b all-zero) must fail the spec.
		_, err := rand.Read(b)
		Expect(err).ToNot(HaveOccurred())
		connID := protocol.ParseConnectionID(b)
		p := getPacket(&wire.Header{
			Type:             protocol.PacketType0RTT,
			DestConnectionID: connID,
			Version:          serv.config.Versions[0],
		}, make([]byte, 100+i))
		phm.EXPECT().Get(connID)
		serv.handlePacket(p)
	}

	connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
	p := getPacket(&wire.Header{
		Type:             protocol.PacketType0RTT,
		DestConnectionID: connID,
		Version:          serv.config.Versions[0],
	}, make([]byte, 200))
	phm.EXPECT().Get(connID)
	dropped := make(chan struct{})
	tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
		close(dropped)
	})
	serv.handlePacket(p)
	Eventually(dropped).Should(BeClosed())
})
|
|
|
|
// Queued 0-RTT packets are dropped once a later packet arrives whose receive
// time is more than protocol.Max0RTTQueueingDuration after theirs. The spec
// drives this purely through the rcvTime of the packets it injects.
It("drops queues after a while", func() {
	start := time.Now()

	// new0RTT builds a 0-RTT packet with the given destination connection ID,
	// payload length and receive time.
	new0RTT := func(id protocol.ConnectionID, payloadLen int, rcvTime time.Time) receivedPacket {
		pkt := getPacket(&wire.Header{
			Type:             protocol.PacketType0RTT,
			DestConnectionID: id,
			Version:          serv.config.Versions[0],
		}, make([]byte, payloadLen))
		pkt.rcvTime = rcvTime
		return pkt
	}

	firstID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
	first := new0RTT(firstID, 200, start)

	secondID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9})
	secondRcvTime := start.Add(protocol.Max0RTTQueueingDuration / 2)
	second := new0RTT(secondID, 300, secondRcvTime) // doesn't trigger the cleanup of the first packet

	firstDropped := make(chan struct{})
	secondDropped := make(chan struct{})
	// need to register the call before handling the packet to avoid race condition
	gomock.InOrder(
		tracer.EXPECT().DroppedPacket(first.remoteAddr, logging.PacketType0RTT, first.Size(), logging.PacketDropDOSPrevention).Do(
			func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
				close(firstDropped)
			},
		),
		tracer.EXPECT().DroppedPacket(second.remoteAddr, logging.PacketType0RTT, second.Size(), logging.PacketDropDOSPrevention).Do(
			func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
				close(secondDropped)
			},
		),
	)

	phm.EXPECT().Get(firstID)
	serv.handlePacket(first)

	// There's no cleanup Go routine.
	// Cleanup is triggered when new packets are received.

	phm.EXPECT().Get(secondID)
	serv.handlePacket(second)
	// make sure no cleanup is executed
	Consistently(firstDropped, 50*time.Millisecond).ShouldNot(BeClosed())

	// There's no cleanup Go routine.
	// Cleanup is triggered when new packets are received.
	thirdID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0})
	third := new0RTT(thirdID, 200, start.Add(protocol.Max0RTTQueueingDuration+time.Nanosecond)) // now triggers the cleanup
	phm.EXPECT().Get(thirdID)
	serv.handlePacket(third)
	Eventually(firstDropped).Should(BeClosed())
	Consistently(secondDropped, 50*time.Millisecond).ShouldNot(BeClosed())

	// make sure the second packet is also cleaned up
	fourthID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1})
	fourth := new0RTT(fourthID, 200, secondRcvTime.Add(protocol.Max0RTTQueueingDuration+time.Nanosecond)) // now triggers the cleanup
	phm.EXPECT().Get(fourthID)
	serv.handlePacket(fourth)
	Eventually(secondDropped).Should(BeClosed())
})
|
|
})
|
|
})
|