mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 04:07:35 +03:00
add a method to retrieve non-QUIC packets from the Transport (#3992)
This commit is contained in:
parent
6880f88089
commit
fe3c4f271d
4 changed files with 195 additions and 4 deletions
|
@ -2,9 +2,11 @@ package self_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
|
@ -210,4 +212,67 @@ var _ = Describe("Multiplexing", func() {
|
|||
})
|
||||
}
|
||||
})
|
||||
|
||||
It("sends and receives non-QUIC packets", func() {
|
||||
addr1, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn1, err := net.ListenUDP("udp", addr1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn1.Close()
|
||||
tr1 := &quic.Transport{Conn: conn1}
|
||||
|
||||
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn2, err := net.ListenUDP("udp", addr2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn2.Close()
|
||||
tr2 := &quic.Transport{Conn: conn2}
|
||||
|
||||
server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runServer(server)
|
||||
defer server.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
var sentPackets, rcvdPackets atomic.Int64
|
||||
const packetLen = 128
|
||||
// send a non-QUIC packet every 100µs
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
ticker := time.NewTicker(time.Millisecond / 10)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
b := make([]byte, packetLen)
|
||||
rand.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet
|
||||
_, err := tr1.WriteTo(b, tr2.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sentPackets.Add(1)
|
||||
}
|
||||
}()
|
||||
|
||||
// receive and count non-QUIC packets
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for {
|
||||
b := make([]byte, 1024)
|
||||
n, addr, err := tr2.ReadNonQUICPacket(ctx, b)
|
||||
if err != nil {
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
return
|
||||
}
|
||||
Expect(addr).To(Equal(tr1.Conn.LocalAddr()))
|
||||
Expect(n).To(Equal(packetLen))
|
||||
rcvdPackets.Add(1)
|
||||
}
|
||||
}()
|
||||
dial(tr2, server.Addr())
|
||||
Eventually(func() int64 { return sentPackets.Load() }).Should(BeNumerically(">", 10))
|
||||
Eventually(func() int64 { return rcvdPackets.Load() }).Should(BeNumerically(">=", sentPackets.Load()*4/5))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -74,6 +74,10 @@ func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.Arbitra
|
|||
return destConnID, srcConnID, nil
|
||||
}
|
||||
|
||||
func IsPotentialQUICPacket(firstByte byte) bool {
|
||||
return firstByte&0x40 > 0
|
||||
}
|
||||
|
||||
// IsLongHeaderPacket says if this is a Long Header packet
|
||||
func IsLongHeaderPacket(firstByte byte) bool {
|
||||
return firstByte&0x80 > 0
|
||||
|
|
53
transport.go
53
transport.go
|
@ -7,12 +7,12 @@ import (
|
|||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
|
@ -85,6 +85,9 @@ type Transport struct {
|
|||
createdConn bool
|
||||
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
|
||||
|
||||
readingNonQUICPackets atomic.Bool
|
||||
nonQUICPackets chan receivedPacket
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
|
@ -341,6 +344,13 @@ func (t *Transport) listen(conn rawConn) {
|
|||
}
|
||||
|
||||
func (t *Transport) handlePacket(p receivedPacket) {
|
||||
if len(p.data) == 0 {
|
||||
return
|
||||
}
|
||||
if !wire.IsPotentialQUICPacket(p.data[0]) && !wire.IsLongHeaderPacket(p.data[0]) {
|
||||
t.handleNonQUICPacket(p)
|
||||
return
|
||||
}
|
||||
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
|
||||
if err != nil {
|
||||
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
|
||||
|
@ -429,3 +439,42 @@ func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *Transport) handleNonQUICPacket(p receivedPacket) {
|
||||
// Strictly speaking, this is racy,
|
||||
// but we only care about receiving packets at some point after ReadNonQUICPacket has been called.
|
||||
if !t.readingNonQUICPackets.Load() {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case t.nonQUICPackets <- p:
|
||||
default:
|
||||
if t.Tracer != nil {
|
||||
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const maxQueuedNonQUICPackets = 32
|
||||
|
||||
// ReadNonQUICPacket reads non-QUIC packets received on the underlying connection.
|
||||
// The detection logic is very simple: Any packet that has the first and second bit of the packet set to 0.
|
||||
// Note that this is stricter than the detection logic defined in RFC 9443.
|
||||
func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) {
|
||||
if err := t.init(false); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if !t.readingNonQUICPackets.Load() {
|
||||
t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets)
|
||||
t.readingNonQUICPackets.Store(true)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return 0, nil, ctx.Err()
|
||||
case p := <-t.nonQUICPackets:
|
||||
n := copy(b, p.data)
|
||||
return n, p.remoteAddr, nil
|
||||
case <-t.listening:
|
||||
return 0, nil, errors.New("closed")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package quic
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
|
@ -122,7 +123,7 @@ var _ = Describe("Transport", func() {
|
|||
tr.Close()
|
||||
})
|
||||
|
||||
It("drops unparseable packets", func() {
|
||||
It("drops unparseable QUIC packets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
||||
packetChan := make(chan packetToRead)
|
||||
tracer := mocklogging.NewMockTracer(mockCtrl)
|
||||
|
@ -136,7 +137,7 @@ var _ = Describe("Transport", func() {
|
|||
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
|
||||
packetChan <- packetToRead{
|
||||
addr: addr,
|
||||
data: []byte{0, 1, 2, 3},
|
||||
data: []byte{0x40 /* set the QUIC bit */, 1, 2, 3},
|
||||
}
|
||||
Eventually(dropped).Should(BeClosed())
|
||||
|
||||
|
@ -323,6 +324,78 @@ var _ = Describe("Transport", func() {
|
|||
conns := getMultiplexer().(*connMultiplexer).conns
|
||||
Expect(len(conns)).To(BeZero())
|
||||
})
|
||||
|
||||
It("allows receiving non-QUIC packets", func() {
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
||||
packetChan := make(chan packetToRead)
|
||||
tracer := mocklogging.NewMockTracer(mockCtrl)
|
||||
tr := &Transport{
|
||||
Conn: newMockPacketConn(packetChan),
|
||||
ConnectionIDLength: 10,
|
||||
Tracer: tracer,
|
||||
}
|
||||
tr.init(true)
|
||||
receivedPacketChan := make(chan []byte)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
b := make([]byte, 100)
|
||||
n, addr, err := tr.ReadNonQUICPacket(context.Background(), b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(addr).To(Equal(remoteAddr))
|
||||
receivedPacketChan <- b[:n]
|
||||
}()
|
||||
// Receiving of non-QUIC packets is enabled when ReadNonQUICPacket is called.
|
||||
// Give the Go routine some time to spin up.
|
||||
time.Sleep(scaleDuration(50 * time.Millisecond))
|
||||
packetChan <- packetToRead{
|
||||
addr: remoteAddr,
|
||||
data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
|
||||
}
|
||||
|
||||
Eventually(receivedPacketChan).Should(Receive(Equal([]byte{0, 1, 2, 3})))
|
||||
|
||||
// shutdown
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("drops non-QUIC packet if the application doesn't process them quickly enough", func() {
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
||||
packetChan := make(chan packetToRead)
|
||||
tracer := mocklogging.NewMockTracer(mockCtrl)
|
||||
tr := &Transport{
|
||||
Conn: newMockPacketConn(packetChan),
|
||||
ConnectionIDLength: 10,
|
||||
Tracer: tracer,
|
||||
}
|
||||
tr.init(true)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 10))
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
|
||||
for i := 0; i < maxQueuedNonQUICPackets; i++ {
|
||||
packetChan <- packetToRead{
|
||||
addr: remoteAddr,
|
||||
data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
|
||||
}
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
tracer.EXPECT().DroppedPacket(remoteAddr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
|
||||
close(done)
|
||||
})
|
||||
packetChan <- packetToRead{
|
||||
addr: remoteAddr,
|
||||
data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
|
||||
}
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
||||
// shutdown
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
})
|
||||
|
||||
type mockSyscallConn struct {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue