mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 04:07:35 +03:00
read the ECN bits
This commit is contained in:
parent
876ab1d531
commit
ea3d32394d
19 changed files with 420 additions and 75 deletions
|
@ -14,9 +14,7 @@ env:
|
|||
global:
|
||||
- TIMESCALE_FACTOR=20
|
||||
matrix:
|
||||
- TRAVIS_GOARCH=amd64 TESTMODE=unit
|
||||
- TRAVIS_GOARCH=amd64 TESTMODE=integration
|
||||
- TRAVIS_GOARCH=386 TESTMODE=unit
|
||||
- TRAVIS_GOARCH=386 TESTMODE=integration
|
||||
|
||||
# second part of the GOARCH workaround
|
||||
|
|
|
@ -2,18 +2,12 @@
|
|||
|
||||
set -ex
|
||||
|
||||
if [ "${TESTMODE}" == "unit" ]; then
|
||||
ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace -skipPackage integrationtests,benchmark
|
||||
fi
|
||||
|
||||
if [ "${TESTMODE}" == "integration" ]; then
|
||||
# run benchmark tests
|
||||
ginkgo -randomizeAllSpecs -randomizeSuites -trace benchmark -- -size=10
|
||||
# run benchmark tests with the Go race detector
|
||||
# The Go race detector only works on amd64.
|
||||
if [ "${TRAVIS_GOARCH}" == 'amd64' ]; then
|
||||
ginkgo -race -randomizeAllSpecs -randomizeSuites -trace benchmark -- -size=5
|
||||
fi
|
||||
# run integration tests
|
||||
ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace integrationtests
|
||||
# run benchmark tests
|
||||
ginkgo -randomizeAllSpecs -randomizeSuites -trace benchmark -- -size=10
|
||||
# run benchmark tests with the Go race detector
|
||||
# The Go race detector only works on amd64.
|
||||
if [ "${TRAVIS_GOARCH}" != '386' ]; then
|
||||
ginkgo -race -randomizeAllSpecs -randomizeSuites -trace benchmark -- -size=5
|
||||
fi
|
||||
# run integration tests
|
||||
ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace integrationtests
|
||||
|
|
|
@ -575,7 +575,7 @@ var _ = Describe("Client", func() {
|
|||
_, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(c).Should(BeClosed())
|
||||
Expect(cconn.(*conn).PacketConn).To(Equal(packetConn))
|
||||
Expect(cconn.(*sconn).PacketConn).To(Equal(packetConn))
|
||||
Expect(version).To(Equal(config.Versions[0]))
|
||||
Expect(conf.Versions).To(Equal(config.Versions))
|
||||
})
|
||||
|
|
50
conn.go
Normal file
50
conn.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type connection interface {
|
||||
ReadPacket() (*receivedPacket, error)
|
||||
WriteTo([]byte, net.Addr) (int, error)
|
||||
LocalAddr() net.Addr
|
||||
io.Closer
|
||||
}
|
||||
|
||||
func wrapConn(pc net.PacketConn) (connection, error) {
|
||||
udpConn, ok := pc.(*net.UDPConn)
|
||||
if !ok {
|
||||
utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.")
|
||||
return &basicConn{PacketConn: pc}, nil
|
||||
}
|
||||
return newConn(udpConn)
|
||||
}
|
||||
|
||||
type basicConn struct {
|
||||
net.PacketConn
|
||||
}
|
||||
|
||||
var _ connection = &basicConn{}
|
||||
|
||||
func (c *basicConn) ReadPacket() (*receivedPacket, error) {
|
||||
buffer := getPacketBuffer()
|
||||
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
||||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||
buffer.Data = buffer.Data[:protocol.MaxReceivePacketSize]
|
||||
n, addr, err := c.PacketConn.ReadFrom(buffer.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &receivedPacket{
|
||||
remoteAddr: addr,
|
||||
rcvTime: time.Now(),
|
||||
data: buffer.Data[:n],
|
||||
buffer: buffer,
|
||||
}, nil
|
||||
}
|
91
conn_ecn.go
Normal file
91
conn_ecn.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
// +build !windows
|
||||
|
||||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
const ecnMask uint8 = 0x3
|
||||
|
||||
type ecnConn struct {
|
||||
*net.UDPConn
|
||||
oobBuffer []byte
|
||||
}
|
||||
|
||||
var _ connection = &ecnConn{}
|
||||
|
||||
func newConn(c *net.UDPConn) (*ecnConn, error) {
|
||||
rawConn, err := c.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
|
||||
// Try enabling receiving of ECN for both IP versions.
|
||||
// We expect at least one of those syscalls to succeed.
|
||||
var errIPv4, errIPv6 error
|
||||
if err := rawConn.Control(func(fd uintptr) {
|
||||
errIPv4 = setRECVTOS(fd)
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rawConn.Control(func(fd uintptr) {
|
||||
errIPv6 = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_RECVTCLASS, 1)
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch {
|
||||
case errIPv4 == nil && errIPv6 == nil:
|
||||
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.")
|
||||
case errIPv4 == nil && errIPv6 != nil:
|
||||
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.")
|
||||
case errIPv4 != nil && errIPv6 == nil:
|
||||
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.")
|
||||
case errIPv4 != nil && errIPv6 != nil:
|
||||
return nil, errors.New("activating ECN failed for both IPv4 and IPv6")
|
||||
}
|
||||
return &ecnConn{
|
||||
UDPConn: c,
|
||||
oobBuffer: make([]byte, 128),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *ecnConn) ReadPacket() (*receivedPacket, error) {
|
||||
buffer := getPacketBuffer()
|
||||
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
||||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||
buffer.Data = buffer.Data[:protocol.MaxReceivePacketSize]
|
||||
c.oobBuffer = c.oobBuffer[:cap(c.oobBuffer)]
|
||||
n, oobn, _, addr, err := c.UDPConn.ReadMsgUDP(buffer.Data, c.oobBuffer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctrlMsgs, err := syscall.ParseSocketControlMessage(c.oobBuffer[:oobn])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ecn protocol.ECN
|
||||
for _, ctrlMsg := range ctrlMsgs {
|
||||
if ctrlMsg.Header.Level == syscall.IPPROTO_IP && ctrlMsg.Header.Type == msgTypeIPTOS {
|
||||
ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
|
||||
break
|
||||
}
|
||||
if ctrlMsg.Header.Level == syscall.IPPROTO_IPV6 && ctrlMsg.Header.Type == syscall.IPV6_TCLASS {
|
||||
ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
|
||||
break
|
||||
}
|
||||
}
|
||||
return &receivedPacket{
|
||||
remoteAddr: addr,
|
||||
rcvTime: time.Now(),
|
||||
data: buffer.Data[:n],
|
||||
ecn: ecn,
|
||||
buffer: buffer,
|
||||
}, nil
|
||||
}
|
128
conn_ecn_test.go
Normal file
128
conn_ecn_test.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
// +build !windows
|
||||
|
||||
package quic
|
||||
|
||||
import (
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Basic Conn Test", func() {
|
||||
Context("ECN conn", func() {
|
||||
runServer := func(network, address string) (*net.UDPConn, <-chan *receivedPacket) {
|
||||
addr, err := net.ResolveUDPAddr(network, address)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
udpConn, err := net.ListenUDP(network, addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ecnConn, err := newConn(udpConn)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
packetChan := make(chan *receivedPacket)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for {
|
||||
p, err := ecnConn.ReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
packetChan <- p
|
||||
}
|
||||
}()
|
||||
|
||||
return udpConn, packetChan
|
||||
}
|
||||
|
||||
sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr {
|
||||
conn, err := net.DialUDP(network, nil, addr)
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||
rawConn, err := conn.SyscallConn()
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||
ExpectWithOffset(1, rawConn.Control(func(fd uintptr) {
|
||||
setECN(fd)
|
||||
})).To(Succeed())
|
||||
_, err = conn.Write([]byte("foobar"))
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||
return conn.LocalAddr()
|
||||
}
|
||||
|
||||
It("reads ECN flags on IPv4", func() {
|
||||
conn, packetChan := runServer("udp4", "localhost:0")
|
||||
defer conn.Close()
|
||||
|
||||
sentFrom := sendPacketWithECN(
|
||||
"udp4",
|
||||
conn.LocalAddr().(*net.UDPAddr),
|
||||
func(fd uintptr) {
|
||||
Expect(syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TOS, 2)).To(Succeed())
|
||||
},
|
||||
)
|
||||
|
||||
var p *receivedPacket
|
||||
Eventually(packetChan).Should(Receive(&p))
|
||||
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
|
||||
Expect(p.data).To(Equal([]byte("foobar")))
|
||||
Expect(p.remoteAddr).To(Equal(sentFrom))
|
||||
Expect(p.ecn).To(Equal(protocol.ECT0))
|
||||
})
|
||||
|
||||
It("reads ECN flags on IPv6", func() {
|
||||
conn, packetChan := runServer("udp6", "[::]:0")
|
||||
defer conn.Close()
|
||||
|
||||
sentFrom := sendPacketWithECN(
|
||||
"udp6",
|
||||
conn.LocalAddr().(*net.UDPAddr),
|
||||
func(fd uintptr) {
|
||||
Expect(syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, 3)).To(Succeed())
|
||||
},
|
||||
)
|
||||
|
||||
var p *receivedPacket
|
||||
Eventually(packetChan).Should(Receive(&p))
|
||||
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
|
||||
Expect(p.data).To(Equal([]byte("foobar")))
|
||||
Expect(p.remoteAddr).To(Equal(sentFrom))
|
||||
Expect(p.ecn).To(Equal(protocol.ECNCE))
|
||||
})
|
||||
|
||||
It("reads ECN flags on a connection that supports both IPv4 and IPv6", func() {
|
||||
conn, packetChan := runServer("udp", "0.0.0.0:0")
|
||||
defer conn.Close()
|
||||
port := conn.LocalAddr().(*net.UDPAddr).Port
|
||||
|
||||
// IPv4
|
||||
sendPacketWithECN(
|
||||
"udp4",
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port},
|
||||
func(fd uintptr) {
|
||||
Expect(syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TOS, 3)).To(Succeed())
|
||||
},
|
||||
)
|
||||
|
||||
var p *receivedPacket
|
||||
Eventually(packetChan).Should(Receive(&p))
|
||||
Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue())
|
||||
Expect(p.ecn).To(Equal(protocol.ECNCE))
|
||||
|
||||
// IPv6
|
||||
sendPacketWithECN(
|
||||
"udp6",
|
||||
&net.UDPAddr{IP: net.IPv6loopback, Port: port},
|
||||
func(fd uintptr) {
|
||||
Expect(syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, 1)).To(Succeed())
|
||||
},
|
||||
)
|
||||
|
||||
Eventually(packetChan).Should(Receive(&p))
|
||||
Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse())
|
||||
Expect(p.ecn).To(Equal(protocol.ECT1))
|
||||
})
|
||||
})
|
||||
})
|
14
conn_helper_darwin.go
Normal file
14
conn_helper_darwin.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
// +build darwin
|
||||
|
||||
package quic
|
||||
|
||||
import "syscall"
|
||||
|
||||
const (
|
||||
ip_recvtos = 27
|
||||
msgTypeIPTOS = ip_recvtos
|
||||
)
|
||||
|
||||
func setRECVTOS(fd uintptr) error {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, ip_recvtos, 1)
|
||||
}
|
11
conn_helper_generic.go
Normal file
11
conn_helper_generic.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
// +build !darwin,!windows
|
||||
|
||||
package quic
|
||||
|
||||
import "syscall"
|
||||
|
||||
const msgTypeIPTOS = syscall.IP_TOS
|
||||
|
||||
func setRECVTOS(fd uintptr) error {
|
||||
return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_RECVTOS, 1)
|
||||
}
|
26
conn_test.go
Normal file
26
conn_test.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Basic Conn Test", func() {
|
||||
It("reads a packet", func() {
|
||||
c := newMockPacketConn()
|
||||
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
|
||||
c.dataReadFrom = addr
|
||||
c.dataToRead <- []byte("foobar")
|
||||
|
||||
conn, err := wrapConn(c)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
p, err := conn.ReadPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.data).To(Equal([]byte("foobar")))
|
||||
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(100*time.Millisecond)))
|
||||
Expect(p.remoteAddr).To(Equal(addr))
|
||||
})
|
||||
})
|
9
conn_windows.go
Normal file
9
conn_windows.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
// +build windows
|
||||
|
||||
package quic
|
||||
|
||||
import "net"
|
||||
|
||||
func newConn(c *net.UDPConn) (connection, error) {
|
||||
return &basicConn{PacketConn: c}, nil
|
||||
}
|
|
@ -37,10 +37,10 @@ func (t PacketType) String() string {
|
|||
type ECN uint8
|
||||
|
||||
const (
|
||||
ECNNon ECN = iota
|
||||
ECT0
|
||||
ECT1
|
||||
ECNCE
|
||||
ECNNon ECN = iota // 00
|
||||
ECT1 // 01
|
||||
ECT0 // 10
|
||||
ECNCE // 11
|
||||
)
|
||||
|
||||
// A ByteCount in QUIC
|
||||
|
|
|
@ -15,4 +15,11 @@ var _ = Describe("Protocol", func() {
|
|||
Expect(PacketType(10).String()).To(Equal("unknown packet type: 10"))
|
||||
})
|
||||
})
|
||||
|
||||
It("converts ECN bits from the IP header wire to the correct types", func() {
|
||||
Expect(ECN(0)).To(Equal(ECNNon))
|
||||
Expect(ECN(0b00000010)).To(Equal(ECT0))
|
||||
Expect(ECN(0b00000001)).To(Equal(ECT1))
|
||||
Expect(ECN(0b00000011)).To(Equal(ECNCE))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -51,7 +51,7 @@ func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2, arg3 interface{
|
|||
}
|
||||
|
||||
// RemoveConn mocks base method
|
||||
func (m *MockMultiplexer) RemoveConn(arg0 net.PacketConn) error {
|
||||
func (m *MockMultiplexer) RemoveConn(arg0 indexableConn) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemoveConn", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
|
|
|
@ -15,9 +15,13 @@ var (
|
|||
connMuxer multiplexer
|
||||
)
|
||||
|
||||
type indexableConn interface {
|
||||
LocalAddr() net.Addr
|
||||
}
|
||||
|
||||
type multiplexer interface {
|
||||
AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer) (packetHandlerManager, error)
|
||||
RemoveConn(net.PacketConn) error
|
||||
RemoveConn(indexableConn) error
|
||||
}
|
||||
|
||||
type connManager struct {
|
||||
|
@ -33,7 +37,7 @@ type connMultiplexer struct {
|
|||
mutex sync.Mutex
|
||||
|
||||
conns map[string] /* LocalAddr().String() */ connManager
|
||||
newPacketHandlerManager func(net.PacketConn, int, []byte, logging.Tracer, utils.Logger) packetHandlerManager // so it can be replaced in the tests
|
||||
newPacketHandlerManager func(net.PacketConn, int, []byte, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
@ -63,7 +67,10 @@ func (m *connMultiplexer) AddConn(
|
|||
connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String()
|
||||
p, ok := m.conns[connIndex]
|
||||
if !ok {
|
||||
manager := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger)
|
||||
manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p = connManager{
|
||||
connIDLen: connIDLen,
|
||||
statelessResetKey: statelessResetKey,
|
||||
|
@ -85,7 +92,7 @@ func (m *connMultiplexer) AddConn(
|
|||
return p.manager, nil
|
||||
}
|
||||
|
||||
func (m *connMultiplexer) RemoveConn(c net.PacketConn) error {
|
||||
func (m *connMultiplexer) RemoveConn(c indexableConn) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ func (e statelessResetErr) Error() string {
|
|||
type packetHandlerMap struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
conn net.PacketConn
|
||||
conn connection
|
||||
connIDLen int
|
||||
|
||||
handlers map[string] /* string(ConnectionID)*/ packetHandler
|
||||
|
@ -54,12 +54,16 @@ type packetHandlerMap struct {
|
|||
var _ packetHandlerManager = &packetHandlerMap{}
|
||||
|
||||
func newPacketHandlerMap(
|
||||
conn net.PacketConn,
|
||||
c net.PacketConn,
|
||||
connIDLen int,
|
||||
statelessResetKey []byte,
|
||||
tracer logging.Tracer,
|
||||
logger utils.Logger,
|
||||
) packetHandlerManager {
|
||||
) (packetHandlerManager, error) {
|
||||
conn, err := wrapConn(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := &packetHandlerMap{
|
||||
conn: conn,
|
||||
connIDLen: connIDLen,
|
||||
|
@ -78,7 +82,7 @@ func newPacketHandlerMap(
|
|||
go m.logUsage()
|
||||
}
|
||||
|
||||
return m
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) logUsage() {
|
||||
|
@ -253,55 +257,40 @@ func (h *packetHandlerMap) close(e error) error {
|
|||
func (h *packetHandlerMap) listen() {
|
||||
defer close(h.listening)
|
||||
for {
|
||||
buffer := getPacketBuffer()
|
||||
data := buffer.Data[:protocol.MaxReceivePacketSize]
|
||||
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
||||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||
n, addr, err := h.conn.ReadFrom(data)
|
||||
p, err := h.conn.ReadPacket()
|
||||
if err != nil {
|
||||
h.close(err)
|
||||
return
|
||||
}
|
||||
h.handlePacket(addr, buffer, data[:n])
|
||||
h.handlePacket(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) handlePacket(
|
||||
addr net.Addr,
|
||||
buffer *packetBuffer,
|
||||
data []byte,
|
||||
) {
|
||||
connID, err := wire.ParseConnectionID(data, h.connIDLen)
|
||||
func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
|
||||
connID, err := wire.ParseConnectionID(p.data, h.connIDLen)
|
||||
if err != nil {
|
||||
buffer.MaybeRelease()
|
||||
h.logger.Debugf("error parsing connection ID on packet from %s: %s", addr, err)
|
||||
h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
|
||||
if h.tracer != nil {
|
||||
h.tracer.DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError)
|
||||
h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
p.buffer.MaybeRelease()
|
||||
return
|
||||
}
|
||||
rcvTime := time.Now()
|
||||
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
if isStatelessReset := h.maybeHandleStatelessReset(data); isStatelessReset {
|
||||
if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset {
|
||||
return
|
||||
}
|
||||
|
||||
handler, handlerFound := h.handlers[string(connID)]
|
||||
|
||||
p := &receivedPacket{
|
||||
remoteAddr: addr,
|
||||
rcvTime: rcvTime,
|
||||
buffer: buffer,
|
||||
data: data,
|
||||
}
|
||||
if handlerFound { // existing session
|
||||
handler.handlePacket(p)
|
||||
return
|
||||
}
|
||||
if data[0]&0x80 == 0 {
|
||||
if p.data[0]&0x80 == 0 {
|
||||
go h.maybeSendStatelessReset(p, connID)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -56,7 +56,9 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
|
||||
JustBeforeEach(func() {
|
||||
conn = newMockPacketConn()
|
||||
handler = newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger).(*packetHandlerMap)
|
||||
phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler = phm.(*packetHandlerMap)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
|
@ -130,7 +132,11 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
It("drops unparseable packets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
||||
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError)
|
||||
handler.handlePacket(addr, getPacketBuffer(), []byte{0, 1, 2, 3})
|
||||
handler.handlePacket(&receivedPacket{
|
||||
buffer: getPacketBuffer(),
|
||||
remoteAddr: addr,
|
||||
data: []byte{0, 1, 2, 3},
|
||||
})
|
||||
})
|
||||
|
||||
It("deletes removed sessions immediately", func() {
|
||||
|
@ -138,7 +144,7 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
||||
handler.Remove(connID)
|
||||
handler.handlePacket(nil, nil, getPacket(connID))
|
||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
||||
// don't EXPECT any calls to handlePacket of the MockPacketHandler
|
||||
})
|
||||
|
||||
|
@ -149,7 +155,7 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
handler.Add(connID, sess)
|
||||
handler.Retire(connID)
|
||||
time.Sleep(scaleDuration(30 * time.Millisecond))
|
||||
handler.handlePacket(nil, nil, getPacket(connID))
|
||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
||||
// don't EXPECT any calls to handlePacket of the MockPacketHandler
|
||||
})
|
||||
|
||||
|
@ -163,13 +169,13 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
})
|
||||
handler.Add(connID, packetHandler)
|
||||
handler.Retire(connID)
|
||||
handler.handlePacket(nil, nil, getPacket(connID))
|
||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
||||
Eventually(handled).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("drops packets for unknown receivers", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
handler.handlePacket(nil, nil, getPacket(connID))
|
||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
||||
})
|
||||
|
||||
It("closes the packet handlers when reading from the conn fails", func() {
|
||||
|
@ -210,7 +216,7 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
Expect(cid).To(Equal(connID))
|
||||
})
|
||||
handler.SetServer(server)
|
||||
handler.handlePacket(nil, nil, p)
|
||||
handler.handlePacket(&receivedPacket{data: p})
|
||||
})
|
||||
|
||||
It("closes all server sessions", func() {
|
||||
|
@ -232,7 +238,7 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
// don't EXPECT any calls to server.handlePacket
|
||||
handler.SetServer(server)
|
||||
handler.CloseServer()
|
||||
handler.handlePacket(nil, nil, p)
|
||||
handler.handlePacket(&receivedPacket{data: p})
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -297,7 +303,7 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
p = append(p, token[:]...)
|
||||
|
||||
time.Sleep(scaleDuration(30 * time.Millisecond))
|
||||
handler.handlePacket(nil, nil, p)
|
||||
handler.handlePacket(&receivedPacket{data: p})
|
||||
})
|
||||
|
||||
It("ignores packets too small to contain a stateless reset", func() {
|
||||
|
@ -332,7 +338,11 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
It("sends stateless resets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
p := append([]byte{40}, make([]byte, 100)...)
|
||||
handler.handlePacket(addr, getPacketBuffer(), p)
|
||||
handler.handlePacket(&receivedPacket{
|
||||
buffer: getPacketBuffer(),
|
||||
remoteAddr: addr,
|
||||
data: p,
|
||||
})
|
||||
var reset mockPacketConnWrite
|
||||
Eventually(conn.dataWritten).Should(Receive(&reset))
|
||||
Expect(reset.to).To(Equal(addr))
|
||||
|
@ -343,7 +353,11 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
It("doesn't send stateless resets for small packets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...)
|
||||
handler.handlePacket(addr, getPacketBuffer(), p)
|
||||
handler.handlePacket(&receivedPacket{
|
||||
buffer: getPacketBuffer(),
|
||||
remoteAddr: addr,
|
||||
data: p,
|
||||
})
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
})
|
||||
})
|
||||
|
@ -352,7 +366,11 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
It("doesn't send stateless resets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
p := append([]byte{40}, make([]byte, 100)...)
|
||||
handler.handlePacket(addr, getPacketBuffer(), p)
|
||||
handler.handlePacket(&receivedPacket{
|
||||
buffer: getPacketBuffer(),
|
||||
remoteAddr: addr,
|
||||
data: p,
|
||||
})
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
})
|
||||
})
|
||||
|
|
10
send_conn.go
10
send_conn.go
|
@ -12,23 +12,23 @@ type sendConn interface {
|
|||
RemoteAddr() net.Addr
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
type sconn struct {
|
||||
net.PacketConn
|
||||
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
var _ sendConn = &conn{}
|
||||
var _ sendConn = &sconn{}
|
||||
|
||||
func newSendConn(c net.PacketConn, remote net.Addr) sendConn {
|
||||
return &conn{PacketConn: c, remoteAddr: remote}
|
||||
return &sconn{PacketConn: c, remoteAddr: remote}
|
||||
}
|
||||
|
||||
func (c *conn) Write(p []byte) error {
|
||||
func (c *sconn) Write(p []byte) error {
|
||||
_, err := c.PacketConn.WriteTo(p, c.remoteAddr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
func (c *sconn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
|
|
@ -855,7 +855,7 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /
|
|||
return false
|
||||
}
|
||||
|
||||
if err := s.handleUnpackedPacket(packet, p.rcvTime, p.Size()); err != nil {
|
||||
if err := s.handleUnpackedPacket(packet, p.ecn, p.rcvTime, p.Size()); err != nil {
|
||||
s.closeLocal(err)
|
||||
return false
|
||||
}
|
||||
|
@ -973,6 +973,7 @@ func (s *session) handleVersionNegotiationPacket(p *receivedPacket) {
|
|||
|
||||
func (s *session) handleUnpackedPacket(
|
||||
packet *unpackedPacket,
|
||||
ecn protocol.ECN,
|
||||
rcvTime time.Time,
|
||||
packetSize protocol.ByteCount, // only for logging
|
||||
) error {
|
||||
|
@ -1070,7 +1071,7 @@ func (s *session) handleUnpackedPacket(
|
|||
}
|
||||
}
|
||||
|
||||
return s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, protocol.ECNNon, packet.encryptionLevel, rcvTime, isAckEliciting)
|
||||
return s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting)
|
||||
}
|
||||
|
||||
func (s *session) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error {
|
||||
|
|
|
@ -756,6 +756,7 @@ var _ = Describe("Session", func() {
|
|||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}
|
||||
packet := getPacket(hdr, nil)
|
||||
packet.ecn = protocol.ECNCE
|
||||
rcvTime := time.Now().Add(-10 * time.Second)
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{
|
||||
packetNumber: 0x1337,
|
||||
|
@ -766,7 +767,7 @@ var _ = Describe("Session", func() {
|
|||
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
||||
gomock.InOrder(
|
||||
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial),
|
||||
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNNon, protocol.EncryptionInitial, rcvTime, false),
|
||||
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNCE, protocol.EncryptionInitial, rcvTime, false),
|
||||
)
|
||||
sess.receivedPacketHandler = rph
|
||||
packet.rcvTime = rcvTime
|
||||
|
@ -785,6 +786,7 @@ var _ = Describe("Session", func() {
|
|||
buf := &bytes.Buffer{}
|
||||
Expect((&wire.PingFrame{}).Write(buf, sess.version)).To(Succeed())
|
||||
packet := getPacket(hdr, nil)
|
||||
packet.ecn = protocol.ECT1
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{
|
||||
packetNumber: 0x1337,
|
||||
encryptionLevel: protocol.Encryption1RTT,
|
||||
|
@ -794,7 +796,7 @@ var _ = Describe("Session", func() {
|
|||
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
||||
gomock.InOrder(
|
||||
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT),
|
||||
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNNon, protocol.Encryption1RTT, rcvTime, true),
|
||||
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECT1, protocol.Encryption1RTT, rcvTime, true),
|
||||
)
|
||||
sess.receivedPacketHandler = rph
|
||||
packet.rcvTime = rcvTime
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue