mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
add a method to retrieve non-QUIC packets from the Transport (#3992)
This commit is contained in:
parent
6880f88089
commit
fe3c4f271d
4 changed files with 195 additions and 4 deletions
|
@ -2,9 +2,11 @@ package self_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
|
@ -210,4 +212,67 @@ var _ = Describe("Multiplexing", func() {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("sends and receives non-QUIC packets", func() {
|
||||||
|
addr1, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
conn1, err := net.ListenUDP("udp", addr1)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer conn1.Close()
|
||||||
|
tr1 := &quic.Transport{Conn: conn1}
|
||||||
|
|
||||||
|
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
conn2, err := net.ListenUDP("udp", addr2)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer conn2.Close()
|
||||||
|
tr2 := &quic.Transport{Conn: conn2}
|
||||||
|
|
||||||
|
server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
runServer(server)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
var sentPackets, rcvdPackets atomic.Int64
|
||||||
|
const packetLen = 128
|
||||||
|
// send a non-QUIC packet every 100µs
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
ticker := time.NewTicker(time.Millisecond / 10)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b := make([]byte, packetLen)
|
||||||
|
rand.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet
|
||||||
|
_, err := tr1.WriteTo(b, tr2.Conn.LocalAddr())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
sentPackets.Add(1)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// receive and count non-QUIC packets
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
for {
|
||||||
|
b := make([]byte, 1024)
|
||||||
|
n, addr, err := tr2.ReadNonQUICPacket(ctx, b)
|
||||||
|
if err != nil {
|
||||||
|
Expect(err).To(MatchError(context.Canceled))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Expect(addr).To(Equal(tr1.Conn.LocalAddr()))
|
||||||
|
Expect(n).To(Equal(packetLen))
|
||||||
|
rcvdPackets.Add(1)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
dial(tr2, server.Addr())
|
||||||
|
Eventually(func() int64 { return sentPackets.Load() }).Should(BeNumerically(">", 10))
|
||||||
|
Eventually(func() int64 { return rcvdPackets.Load() }).Should(BeNumerically(">=", sentPackets.Load()*4/5))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -74,6 +74,10 @@ func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.Arbitra
|
||||||
return destConnID, srcConnID, nil
|
return destConnID, srcConnID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsPotentialQUICPacket(firstByte byte) bool {
|
||||||
|
return firstByte&0x40 > 0
|
||||||
|
}
|
||||||
|
|
||||||
// IsLongHeaderPacket says if this is a Long Header packet
|
// IsLongHeaderPacket says if this is a Long Header packet
|
||||||
func IsLongHeaderPacket(firstByte byte) bool {
|
func IsLongHeaderPacket(firstByte byte) bool {
|
||||||
return firstByte&0x80 > 0
|
return firstByte&0x80 > 0
|
||||||
|
|
53
transport.go
53
transport.go
|
@ -7,12 +7,12 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/wire"
|
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
"github.com/quic-go/quic-go/internal/utils"
|
||||||
|
"github.com/quic-go/quic-go/internal/wire"
|
||||||
"github.com/quic-go/quic-go/logging"
|
"github.com/quic-go/quic-go/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -85,6 +85,9 @@ type Transport struct {
|
||||||
createdConn bool
|
createdConn bool
|
||||||
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
|
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
|
||||||
|
|
||||||
|
readingNonQUICPackets atomic.Bool
|
||||||
|
nonQUICPackets chan receivedPacket
|
||||||
|
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -341,6 +344,13 @@ func (t *Transport) listen(conn rawConn) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) handlePacket(p receivedPacket) {
|
func (t *Transport) handlePacket(p receivedPacket) {
|
||||||
|
if len(p.data) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !wire.IsPotentialQUICPacket(p.data[0]) && !wire.IsLongHeaderPacket(p.data[0]) {
|
||||||
|
t.handleNonQUICPacket(p)
|
||||||
|
return
|
||||||
|
}
|
||||||
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
|
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
|
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
|
||||||
|
@ -429,3 +439,42 @@ func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Transport) handleNonQUICPacket(p receivedPacket) {
|
||||||
|
// Strictly speaking, this is racy,
|
||||||
|
// but we only care about receiving packets at some point after ReadNonQUICPacket has been called.
|
||||||
|
if !t.readingNonQUICPackets.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case t.nonQUICPackets <- p:
|
||||||
|
default:
|
||||||
|
if t.Tracer != nil {
|
||||||
|
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxQueuedNonQUICPackets = 32
|
||||||
|
|
||||||
|
// ReadNonQUICPacket reads non-QUIC packets received on the underlying connection.
|
||||||
|
// The detection logic is very simple: Any packet that has the first and second bit of the packet set to 0.
|
||||||
|
// Note that this is stricter than the detection logic defined in RFC 9443.
|
||||||
|
func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) {
|
||||||
|
if err := t.init(false); err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
if !t.readingNonQUICPackets.Load() {
|
||||||
|
t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets)
|
||||||
|
t.readingNonQUICPackets.Store(true)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return 0, nil, ctx.Err()
|
||||||
|
case p := <-t.nonQUICPackets:
|
||||||
|
n := copy(b, p.data)
|
||||||
|
return n, p.remoteAddr, nil
|
||||||
|
case <-t.listening:
|
||||||
|
return 0, nil, errors.New("closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -122,7 +123,7 @@ var _ = Describe("Transport", func() {
|
||||||
tr.Close()
|
tr.Close()
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops unparseable packets", func() {
|
It("drops unparseable QUIC packets", func() {
|
||||||
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
||||||
packetChan := make(chan packetToRead)
|
packetChan := make(chan packetToRead)
|
||||||
tracer := mocklogging.NewMockTracer(mockCtrl)
|
tracer := mocklogging.NewMockTracer(mockCtrl)
|
||||||
|
@ -136,7 +137,7 @@ var _ = Describe("Transport", func() {
|
||||||
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
|
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
|
||||||
packetChan <- packetToRead{
|
packetChan <- packetToRead{
|
||||||
addr: addr,
|
addr: addr,
|
||||||
data: []byte{0, 1, 2, 3},
|
data: []byte{0x40 /* set the QUIC bit */, 1, 2, 3},
|
||||||
}
|
}
|
||||||
Eventually(dropped).Should(BeClosed())
|
Eventually(dropped).Should(BeClosed())
|
||||||
|
|
||||||
|
@ -323,6 +324,78 @@ var _ = Describe("Transport", func() {
|
||||||
conns := getMultiplexer().(*connMultiplexer).conns
|
conns := getMultiplexer().(*connMultiplexer).conns
|
||||||
Expect(len(conns)).To(BeZero())
|
Expect(len(conns)).To(BeZero())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("allows receiving non-QUIC packets", func() {
|
||||||
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
||||||
|
packetChan := make(chan packetToRead)
|
||||||
|
tracer := mocklogging.NewMockTracer(mockCtrl)
|
||||||
|
tr := &Transport{
|
||||||
|
Conn: newMockPacketConn(packetChan),
|
||||||
|
ConnectionIDLength: 10,
|
||||||
|
Tracer: tracer,
|
||||||
|
}
|
||||||
|
tr.init(true)
|
||||||
|
receivedPacketChan := make(chan []byte)
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
b := make([]byte, 100)
|
||||||
|
n, addr, err := tr.ReadNonQUICPacket(context.Background(), b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(addr).To(Equal(remoteAddr))
|
||||||
|
receivedPacketChan <- b[:n]
|
||||||
|
}()
|
||||||
|
// Receiving of non-QUIC packets is enabled when ReadNonQUICPacket is called.
|
||||||
|
// Give the Go routine some time to spin up.
|
||||||
|
time.Sleep(scaleDuration(50 * time.Millisecond))
|
||||||
|
packetChan <- packetToRead{
|
||||||
|
addr: remoteAddr,
|
||||||
|
data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
|
||||||
|
}
|
||||||
|
|
||||||
|
Eventually(receivedPacketChan).Should(Receive(Equal([]byte{0, 1, 2, 3})))
|
||||||
|
|
||||||
|
// shutdown
|
||||||
|
close(packetChan)
|
||||||
|
tr.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
It("drops non-QUIC packet if the application doesn't process them quickly enough", func() {
|
||||||
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
||||||
|
packetChan := make(chan packetToRead)
|
||||||
|
tracer := mocklogging.NewMockTracer(mockCtrl)
|
||||||
|
tr := &Transport{
|
||||||
|
Conn: newMockPacketConn(packetChan),
|
||||||
|
ConnectionIDLength: 10,
|
||||||
|
Tracer: tracer,
|
||||||
|
}
|
||||||
|
tr.init(true)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
_, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 10))
|
||||||
|
Expect(err).To(MatchError(context.Canceled))
|
||||||
|
|
||||||
|
for i := 0; i < maxQueuedNonQUICPackets; i++ {
|
||||||
|
packetChan <- packetToRead{
|
||||||
|
addr: remoteAddr,
|
||||||
|
data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
tracer.EXPECT().DroppedPacket(remoteAddr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
|
packetChan <- packetToRead{
|
||||||
|
addr: remoteAddr,
|
||||||
|
data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
|
||||||
|
}
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
|
||||||
|
// shutdown
|
||||||
|
close(packetChan)
|
||||||
|
tr.Close()
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
type mockSyscallConn struct {
|
type mockSyscallConn struct {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue