Use the correct source IP when binding multiple IPs

When the server is listening on multiple interfaces or interfaces with
multiple IPs, the outgoing datagrams are sometime delivered with the
wrong source IP address.

In order to fix that, each quic connection needs to extract the
destination IP (and optionally interface id) of the received datagrams,
and set it as source IP (and interface) on the sent datagrams.

On most platforms, this can be done using ancillary data with recvmsg()
and sendmsg(). Some of the machinery for this is already there for ECN,
this change extends it to read the destination IP info and write it to
the outgoing packets.

Fix #1736
This commit is contained in:
Olivier Poitrey 2021-03-16 00:43:41 +01:00
parent 3bce408c8d
commit eb6bdfdfc1
15 changed files with 468 additions and 181 deletions

View file

@ -116,13 +116,14 @@ func dialAddrContext(
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true) return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true)
} }
// Dial establishes a new QUIC connection to a server using a net.PacketConn. // Dial establishes a new QUIC connection to a server using a net.PacketConn. If
// If the PacketConn satisfies the ECNCapablePacketConn interface (as a net.UDPConn does), ECN support will be enabled. // the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn
// In this case, ReadMsgUDP will be used instead of ReadFrom to read packets. // does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
// The same PacketConn can be used for multiple calls to Dial and Listen, // and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
// QUIC connection IDs are used for demultiplexing the different connections. // packets. The same PacketConn can be used for multiple calls to Dial and
// The host parameter is used for SNI. // Listen, QUIC connection IDs are used for demultiplexing the different
// The tls.Config must define an application protocol (using NextProtos). // connections. The host parameter is used for SNI. The tls.Config must define
// an application protocol (using NextProtos).
func Dial( func Dial(
pconn net.PacketConn, pconn net.PacketConn,
remoteAddr net.Addr, remoteAddr net.Addr,
@ -255,7 +256,7 @@ func newClient(
c := &client{ c := &client{
srcConnID: srcConnID, srcConnID: srcConnID,
destConnID: destConnID, destConnID: destConnID,
conn: newSendConn(pconn, remoteAddr), conn: newSendPconn(pconn, remoteAddr),
createdPacketConn: createdPacketConn, createdPacketConn: createdPacketConn,
use0RTT: use0RTT, use0RTT: use0RTT,
tlsConf: tlsConf, tlsConf: tlsConf,

View file

@ -64,7 +64,7 @@ var _ = Describe("Client", func() {
srcConnID: connID, srcConnID: connID,
destConnID: connID, destConnID: connID,
version: protocol.VersionTLS, version: protocol.VersionTLS,
conn: newSendConn(packetConn, addr), conn: newSendPconn(packetConn, addr),
tracer: tracer, tracer: tracer,
logger: utils.DefaultLogger, logger: utils.DefaultLogger,
} }
@ -548,7 +548,7 @@ var _ = Describe("Client", func() {
_, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config) _, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(c).Should(BeClosed()) Eventually(c).Should(BeClosed())
Expect(cconn.(*sconn).PacketConn).To(Equal(packetConn)) Expect(cconn.(*spconn).PacketConn).To(Equal(packetConn))
Expect(version).To(Equal(config.Versions[0])) Expect(version).To(Equal(config.Versions[0]))
Expect(conf.Versions).To(Equal(config.Versions)) Expect(conf.Versions).To(Equal(config.Versions))
}) })

13
conn.go
View file

@ -12,23 +12,24 @@ import (
type connection interface { type connection interface {
ReadPacket() (*receivedPacket, error) ReadPacket() (*receivedPacket, error)
WriteTo([]byte, net.Addr) (int, error) WritePacket(b []byte, addr net.Addr, info *packetInfo) (int, error)
LocalAddr() net.Addr LocalAddr() net.Addr
io.Closer io.Closer
} }
// If the PacketConn passed to Dial or Listen satisfies this interface, quic-go will read the ECN bits from the IP header. // If the PacketConn passed to Dial or Listen satisfies this interface, quic-go will read the ECN bits from the IP header.
// In this case, ReadMsgUDP() will be used instead of ReadFrom() to read packets. // In this case, ReadMsgUDP() will be used instead of ReadFrom() to read packets.
type ECNCapablePacketConn interface { type OOBCapablePacketConn interface {
net.PacketConn net.PacketConn
SyscallConn() (syscall.RawConn, error) SyscallConn() (syscall.RawConn, error)
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error)
} }
var _ ECNCapablePacketConn = &net.UDPConn{} var _ OOBCapablePacketConn = &net.UDPConn{}
func wrapConn(pc net.PacketConn) (connection, error) { func wrapConn(pc net.PacketConn) (connection, error) {
c, ok := pc.(ECNCapablePacketConn) c, ok := pc.(OOBCapablePacketConn)
if !ok { if !ok {
utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.") utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.")
return &basicConn{PacketConn: pc}, nil return &basicConn{PacketConn: pc}, nil
@ -58,3 +59,7 @@ func (c *basicConn) ReadPacket() (*receivedPacket, error) {
buffer: buffer, buffer: buffer,
}, nil }, nil
} }
func (c *basicConn) WritePacket(b []byte, addr net.Addr, info *packetInfo) (n int, err error) {
return c.PacketConn.WriteTo(b, addr)
}

View file

@ -1,115 +0,0 @@
// +build darwin linux
package quic
import (
"errors"
"fmt"
"net"
"syscall"
"time"
"golang.org/x/sys/unix"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const ecnMask uint8 = 0x3
func inspectReadBuffer(c net.PacketConn) (int, error) {
conn, ok := c.(interface {
SyscallConn() (syscall.RawConn, error)
})
if !ok {
return 0, errors.New("doesn't have a SyscallConn")
}
rawConn, err := conn.SyscallConn()
if err != nil {
return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err)
}
var size int
var serr error
if err := rawConn.Control(func(fd uintptr) {
size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
}); err != nil {
return 0, err
}
return size, serr
}
type ecnConn struct {
ECNCapablePacketConn
oobBuffer []byte
}
var _ connection = &ecnConn{}
func newConn(c ECNCapablePacketConn) (*ecnConn, error) {
rawConn, err := c.SyscallConn()
if err != nil {
return nil, err
}
// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
// Try enabling receiving of ECN for both IP versions.
// We expect at least one of those syscalls to succeed.
var errIPv4, errIPv6 error
if err := rawConn.Control(func(fd uintptr) {
errIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
}); err != nil {
return nil, err
}
if err := rawConn.Control(func(fd uintptr) {
errIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
}); err != nil {
return nil, err
}
switch {
case errIPv4 == nil && errIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.")
case errIPv4 == nil && errIPv6 != nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.")
case errIPv4 != nil && errIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.")
case errIPv4 != nil && errIPv6 != nil:
return nil, errors.New("activating ECN failed for both IPv4 and IPv6")
}
return &ecnConn{
ECNCapablePacketConn: c,
oobBuffer: make([]byte, 128),
}, nil
}
func (c *ecnConn) ReadPacket() (*receivedPacket, error) {
buffer := getPacketBuffer()
// The packet size should not exceed protocol.MaxPacketBufferSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
c.oobBuffer = c.oobBuffer[:cap(c.oobBuffer)]
n, oobn, _, addr, err := c.ECNCapablePacketConn.ReadMsgUDP(buffer.Data, c.oobBuffer)
if err != nil {
return nil, err
}
ctrlMsgs, err := unix.ParseSocketControlMessage(c.oobBuffer[:oobn])
if err != nil {
return nil, err
}
var ecn protocol.ECN
for _, ctrlMsg := range ctrlMsgs {
if ctrlMsg.Header.Level == unix.IPPROTO_IP && ctrlMsg.Header.Type == msgTypeIPTOS {
ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
break
}
if ctrlMsg.Header.Level == unix.IPPROTO_IPV6 && ctrlMsg.Header.Type == unix.IPV6_TCLASS {
ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
break
}
}
return &receivedPacket{
remoteAddr: addr,
rcvTime: time.Now(),
data: buffer.Data[:n],
ecn: ecn,
buffer: buffer,
}, nil
}

View file

@ -15,31 +15,31 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
var _ = Describe("Basic Conn Test", func() { var _ = Describe("OOB Conn Test", func() {
Context("ECN conn", func() { runServer := func(network, address string) (*net.UDPConn, <-chan *receivedPacket) {
runServer := func(network, address string) (*net.UDPConn, <-chan *receivedPacket) { addr, err := net.ResolveUDPAddr(network, address)
addr, err := net.ResolveUDPAddr(network, address) Expect(err).ToNot(HaveOccurred())
Expect(err).ToNot(HaveOccurred()) udpConn, err := net.ListenUDP(network, addr)
udpConn, err := net.ListenUDP(network, addr) Expect(err).ToNot(HaveOccurred())
Expect(err).ToNot(HaveOccurred()) ecnConn, err := newConn(udpConn)
ecnConn, err := newConn(udpConn) Expect(err).ToNot(HaveOccurred())
Expect(err).ToNot(HaveOccurred())
packetChan := make(chan *receivedPacket) packetChan := make(chan *receivedPacket)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
for { for {
p, err := ecnConn.ReadPacket() p, err := ecnConn.ReadPacket()
if err != nil { if err != nil {
return return
}
packetChan <- p
} }
}() packetChan <- p
}
}()
return udpConn, packetChan return udpConn, packetChan
} }
Context("ECN conn", func() {
sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr { sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr {
conn, err := net.DialUDP(network, nil, addr) conn, err := net.DialUDP(network, nil, addr)
ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, err).ToNot(HaveOccurred())
@ -126,4 +126,75 @@ var _ = Describe("Basic Conn Test", func() {
Expect(p.ecn).To(Equal(protocol.ECT1)) Expect(p.ecn).To(Equal(protocol.ECT1))
}) })
}) })
Context("Packet Info conn", func() {
sendPacket := func(network string, addr *net.UDPAddr) net.Addr {
conn, err := net.DialUDP(network, nil, addr)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
_, err = conn.Write([]byte("foobar"))
ExpectWithOffset(1, err).ToNot(HaveOccurred())
return conn.LocalAddr()
}
It("reads packet info on IPv4", func() {
conn, packetChan := runServer("udp4", ":0")
defer conn.Close()
addr := conn.LocalAddr().(*net.UDPAddr)
ip := net.ParseIP("127.0.0.1").To4()
addr.IP = ip
sentFrom := sendPacket("udp4", addr)
var p *receivedPacket
Eventually(packetChan).Should(Receive(&p))
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
Expect(p.data).To(Equal([]byte("foobar")))
Expect(p.remoteAddr).To(Equal(sentFrom))
Expect(p.info).To(Not(BeNil()))
Expect(p.info.addr.To4()).To(Equal(ip))
})
It("reads packet info on IPv6", func() {
conn, packetChan := runServer("udp6", ":0")
defer conn.Close()
addr := conn.LocalAddr().(*net.UDPAddr)
ip := net.ParseIP("::1")
addr.IP = ip
sentFrom := sendPacket("udp6", addr)
var p *receivedPacket
Eventually(packetChan).Should(Receive(&p))
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
Expect(p.data).To(Equal([]byte("foobar")))
Expect(p.remoteAddr).To(Equal(sentFrom))
Expect(p.info).To(Not(BeNil()))
Expect(p.info.addr).To(Equal(ip))
})
It("reads packet info on a connection that supports both IPv4 and IPv6", func() {
conn, packetChan := runServer("udp", ":0")
defer conn.Close()
port := conn.LocalAddr().(*net.UDPAddr).Port
// IPv4
ip4 := net.ParseIP("127.0.0.1").To4()
sendPacket("udp4", &net.UDPAddr{IP: ip4, Port: port})
var p *receivedPacket
Eventually(packetChan).Should(Receive(&p))
Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue())
Expect(p.info).To(Not(BeNil()))
Expect(p.info.addr.To4()).To(Equal(ip4))
// IPv6
ip6 := net.ParseIP("::1")
sendPacket("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: port})
Eventually(packetChan).Should(Receive(&p))
Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse())
Expect(p.info).To(Not(BeNil()))
Expect(p.info.addr).To(Equal(ip6))
})
})
}) })

View file

@ -1,4 +1,4 @@
// +build !darwin,!linux,!windows // +build !darwin,!linux,!freebsd,!windows
package quic package quic
@ -8,6 +8,6 @@ func newConn(c net.PacketConn) (connection, error) {
return &basicConn{PacketConn: c}, nil return &basicConn{PacketConn: c}, nil
} }
func inspectReadBuffer(net.PacketConn) (int, error) { func inspectReadBuffer(interface{}) (int, error) {
return 0, nil return 0, nil
} }

View file

@ -5,3 +5,13 @@ package quic
import "golang.org/x/sys/unix" import "golang.org/x/sys/unix"
const msgTypeIPTOS = unix.IP_RECVTOS const msgTypeIPTOS = unix.IP_RECVTOS
const (
ipv4RECVPKTINFO = 0x1a
ipv6RECVPKTINFO = 0x3d
)
const (
msgTypeIPv4PKTINFO = 0x1a
msgTypeIPv6PKTINFO = 0x2e
)

17
conn_helper_freebsd.go Normal file
View file

@ -0,0 +1,17 @@
// +build freebsd
package quic
import "golang.org/x/sys/unix"
const msgTypeIPTOS = unix.IP_RECVTOS
const (
ipv4RECVPKTINFO = 0x7
ipv6RECVPKTINFO = 0x24
)
const (
msgTypeIPv4PKTINFO = 0x7
msgTypeIPv6PKTINFO = 0x2e
)

View file

@ -5,3 +5,13 @@ package quic
import "golang.org/x/sys/unix" import "golang.org/x/sys/unix"
const msgTypeIPTOS = unix.IP_TOS const msgTypeIPTOS = unix.IP_TOS
const (
ipv4RECVPKTINFO = 0x8
ipv6RECVPKTINFO = 0x31
)
const (
msgTypeIPv4PKTINFO = 0x8
msgTypeIPv6PKTINFO = 0x32
)

239
conn_oob.go Normal file
View file

@ -0,0 +1,239 @@
// +build darwin linux freebsd
package quic
import (
"encoding/binary"
"errors"
"fmt"
"net"
"runtime"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/unix"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const ecnMask uint8 = 0x3
func inspectReadBuffer(c interface{}) (int, error) {
conn, ok := c.(interface {
SyscallConn() (syscall.RawConn, error)
})
if !ok {
return 0, errors.New("doesn't have a SyscallConn")
}
rawConn, err := conn.SyscallConn()
if err != nil {
return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err)
}
var size int
var serr error
if err := rawConn.Control(func(fd uintptr) {
size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
}); err != nil {
return 0, err
}
return size, serr
}
type oobConn struct {
OOBCapablePacketConn
oobBuffer []byte
}
var _ connection = &oobConn{}
func newConn(c OOBCapablePacketConn) (*oobConn, error) {
rawConn, err := c.SyscallConn()
if err != nil {
return nil, err
}
needsPacketInfo := false
if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() {
needsPacketInfo = true
}
// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
// Try enabling receiving of ECN and packet info for both IP versions.
// We expect at least one of those syscalls to succeed.
var errECNIPv4, errECNIPv6, errPIIPv4, errPIIPv6 error
if err := rawConn.Control(func(fd uintptr) {
errECNIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
errECNIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
if needsPacketInfo {
errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4RECVPKTINFO, 1)
errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, ipv6RECVPKTINFO, 1)
}
}); err != nil {
return nil, err
}
switch {
case errECNIPv4 == nil && errECNIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.")
case errECNIPv4 == nil && errECNIPv6 != nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.")
case errECNIPv4 != nil && errECNIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.")
case errECNIPv4 != nil && errECNIPv6 != nil:
return nil, errors.New("activating ECN failed for both IPv4 and IPv6")
}
if needsPacketInfo {
switch {
case errPIIPv4 == nil && errPIIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of packet info for IPv4 and IPv6.")
case errPIIPv4 == nil && errPIIPv6 != nil:
utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv4.")
case errPIIPv4 != nil && errPIIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv6.")
case errPIIPv4 != nil && errPIIPv6 != nil:
return nil, errors.New("activating packet info failed for both IPv4 and IPv6")
}
}
return &oobConn{
OOBCapablePacketConn: c,
oobBuffer: make([]byte, 128),
}, nil
}
func (c *oobConn) ReadPacket() (*receivedPacket, error) {
buffer := getPacketBuffer()
// The packet size should not exceed protocol.MaxPacketBufferSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
c.oobBuffer = c.oobBuffer[:cap(c.oobBuffer)]
n, oobn, _, addr, err := c.OOBCapablePacketConn.ReadMsgUDP(buffer.Data, c.oobBuffer)
if err != nil {
return nil, err
}
ctrlMsgs, err := unix.ParseSocketControlMessage(c.oobBuffer[:oobn])
if err != nil {
return nil, err
}
var ecn protocol.ECN
var destIP net.IP
var ifIndex uint32
for _, ctrlMsg := range ctrlMsgs {
if ctrlMsg.Header.Level == unix.IPPROTO_IP {
switch ctrlMsg.Header.Type {
case msgTypeIPTOS:
ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
case msgTypeIPv4PKTINFO:
// struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */
// struct in_addr ipi_spec_dst; /* Local address */
// struct in_addr ipi_addr; /* Header Destination
// address */
// };
if len(ctrlMsg.Data) == 12 {
ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data)
destIP = net.IP(ctrlMsg.Data[8:12])
} else if len(ctrlMsg.Data) == 4 {
// FreeBSD
destIP = net.IP(ctrlMsg.Data)
}
}
}
if ctrlMsg.Header.Level == unix.IPPROTO_IPV6 {
switch ctrlMsg.Header.Type {
case unix.IPV6_TCLASS:
ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
case msgTypeIPv6PKTINFO:
// struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */
// unsigned int ipi6_ifindex; /* send/recv interface index */
// };
if len(ctrlMsg.Data) == 20 {
destIP = net.IP(ctrlMsg.Data[:16])
ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data[16:])
}
}
}
}
var info *packetInfo
if destIP != nil {
info = &packetInfo{
addr: destIP,
ifIndex: ifIndex,
}
}
return &receivedPacket{
remoteAddr: addr,
rcvTime: time.Now(),
data: buffer.Data[:n],
ecn: ecn,
info: info,
buffer: buffer,
}, nil
}
func (c *oobConn) WritePacket(b []byte, addr net.Addr, info *packetInfo) (n int, err error) {
n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, info.OOB(), addr.(*net.UDPAddr))
return n, err
}
func (info *packetInfo) OOB() []byte {
if info == nil {
return nil
}
info.once.Do(info.computeOOB)
return info.oob
}
func (info *packetInfo) computeOOB() {
if ip4 := info.addr.To4(); ip4 != nil {
// struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */
// struct in_addr ipi_spec_dst; /* Local address */
// struct in_addr ipi_addr; /* Header Destination address */
// };
msgLen := 12
if runtime.GOOS == "freebsd" {
msgLen = 4
}
cmsglen := cmsgLen(msgLen)
oob := make([]byte, cmsglen)
cmsg := (*syscall.Cmsghdr)(unsafe.Pointer(&oob[0]))
cmsg.Level = syscall.IPPROTO_TCP
cmsg.Type = msgTypeIPv4PKTINFO
cmsg.SetLen(cmsglen)
off := cmsgLen(0)
if runtime.GOOS != "freebsd" {
// FreeBSD does not support in_pktinfo, just an in_addr is sent
binary.LittleEndian.PutUint32(oob[off:], info.ifIndex)
off += 4
}
copy(oob[off:], ip4)
info.oob = oob
} else if len(info.addr) == 16 {
// struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */
// unsigned int ipi6_ifindex; /* send/recv interface index */
// };
const msgLen = 20
cmsglen := cmsgLen(msgLen)
oob := make([]byte, cmsglen)
cmsg := (*syscall.Cmsghdr)(unsafe.Pointer(&oob[0]))
cmsg.Level = syscall.IPPROTO_IPV6
cmsg.Type = msgTypeIPv6PKTINFO
cmsg.SetLen(cmsglen)
off := cmsgLen(0)
off += copy(oob[off:], info.addr)
binary.LittleEndian.PutUint32(oob[off:], info.ifIndex)
info.oob = oob
}
}
func cmsgLen(datalen int) int {
return cmsgAlign(syscall.SizeofCmsghdr) + datalen
}
func cmsgAlign(salen int) int {
const sizeOfPtr = 0x8
salign := sizeOfPtr
return (salen + salign - 1) & ^(salign - 1)
}

View file

@ -469,7 +469,7 @@ func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID pro
rand.Read(data) rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40 data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...) data = append(data, token[:]...)
if _, err := h.conn.WriteTo(data, p.remoteAddr); err != nil { if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info); err != nil {
h.logger.Debugf("Error sending Stateless Reset: %s", err) h.logger.Debugf("Error sending Stateless Reset: %s", err)
} }
} }

View file

@ -13,22 +13,56 @@ type sendConn interface {
} }
type sconn struct { type sconn struct {
net.PacketConn connection
remoteAddr net.Addr remoteAddr net.Addr
info *packetInfo
} }
var _ sendConn = &sconn{} var _ sendConn = &sconn{}
func newSendConn(c net.PacketConn, remote net.Addr) sendConn { func newSendConn(c connection, remote net.Addr, info *packetInfo) sendConn {
return &sconn{PacketConn: c, remoteAddr: remote} return &sconn{connection: c, remoteAddr: remote, info: info}
} }
func (c *sconn) Write(p []byte) error { func (c *sconn) Write(p []byte) error {
_, err := c.PacketConn.WriteTo(p, c.remoteAddr) _, err := c.WritePacket(p, c.remoteAddr, c.info)
return err return err
} }
func (c *sconn) RemoteAddr() net.Addr { func (c *sconn) RemoteAddr() net.Addr {
return c.remoteAddr return c.remoteAddr
} }
func (c *sconn) LocalAddr() net.Addr {
addr := c.connection.LocalAddr()
if c.info != nil {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
addrCopy := *udpAddr
addrCopy.IP = c.info.addr
addr = &addrCopy
}
}
return addr
}
type spconn struct {
net.PacketConn
remoteAddr net.Addr
}
var _ sendConn = &spconn{}
func newSendPconn(c net.PacketConn, remote net.Addr) sendConn {
return &spconn{PacketConn: c, remoteAddr: remote}
}
func (c *spconn) Write(p []byte) error {
_, err := c.WriteTo(p, c.remoteAddr)
return err
}
func (c *spconn) RemoteAddr() net.Addr {
return c.remoteAddr
}

View file

@ -17,7 +17,7 @@ var _ = Describe("Connection (for sending packets)", func() {
BeforeEach(func() { BeforeEach(func() {
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = NewMockPacketConn(mockCtrl) packetConn = NewMockPacketConn(mockCtrl)
c = newSendConn(packetConn, addr) c = newSendPconn(packetConn, addr)
}) })
It("writes", func() { It("writes", func() {

View file

@ -61,7 +61,7 @@ type baseServer struct {
tlsConf *tls.Config tlsConf *tls.Config
config *Config config *Config
conn net.PacketConn conn connection
// If the server is started with ListenAddr, we create a packet conn. // If the server is started with ListenAddr, we create a packet conn.
// If it is started with Listen, we take a packet conn as a parameter. // If it is started with Listen, we take a packet conn as a parameter.
createdPacketConn bool createdPacketConn bool
@ -148,16 +148,17 @@ func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bo
return serv, nil return serv, nil
} }
// Listen listens for QUIC connections on a given net.PacketConn. // Listen listens for QUIC connections on a given net.PacketConn. If the
// If the PacketConn satisfies the ECNCapablePacketConn interface (as a net.UDPConn does), ECN support will be enabled. // PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn
// In this case, ReadMsgUDP will be used instead of ReadFrom to read packets. // does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
// A single net.PacketConn only be used for a single call to Listen. // and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
// The PacketConn can be used for simultaneous calls to Dial. // packets. A single net.PacketConn only be used for a single call to Listen.
// QUIC connection IDs are used for demultiplexing the different connections. // The PacketConn can be used for simultaneous calls to Dial. QUIC connection
// The tls.Config must not be nil and must contain a certificate configuration. // IDs are used for demultiplexing the different connections. The tls.Config
// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. // must not be nil and must contain a certificate configuration. The
// Furthermore, it must define an application control (using NextProtos). // tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. Furthermore,
// The quic.Config may be nil, in that case the default values will be used. // it must define an application control (using NextProtos). The quic.Config may
// be nil, in that case the default values will be used.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
return listen(conn, tlsConf, config, false) return listen(conn, tlsConf, config, false)
} }
@ -193,8 +194,12 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
if err != nil { if err != nil {
return nil, err return nil, err
} }
c, err := wrapConn(conn)
if err != nil {
return nil, err
}
s := &baseServer{ s := &baseServer{
conn: conn, conn: c,
tlsConf: tlsConf, tlsConf: tlsConf,
config: config, config: config,
tokenGenerator: tokenGenerator, tokenGenerator: tokenGenerator,
@ -421,7 +426,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
} }
return return
} }
if err := s.sendRetry(p.remoteAddr, hdr); err != nil { if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
s.logger.Debugf("Error sending Retry: %s", err) s.logger.Debugf("Error sending Retry: %s", err)
} }
}() }()
@ -432,7 +437,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
go func() { go func() {
defer p.buffer.Release() defer p.buffer.Release()
if err := s.sendConnectionRefused(p.remoteAddr, hdr); err != nil { if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil {
s.logger.Debugf("Error rejecting connection: %s", err) s.logger.Debugf("Error rejecting connection: %s", err)
} }
}() }()
@ -456,7 +461,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
tracer = s.config.Tracer.TracerForConnection(protocol.PerspectiveServer, connID) tracer = s.config.Tracer.TracerForConnection(protocol.PerspectiveServer, connID)
} }
sess = s.newSession( sess = s.newSession(
newSendConn(s.conn, p.remoteAddr), newSendConn(s.conn, p.remoteAddr, p.info),
s.sessionHandler, s.sessionHandler,
origDestConnID, origDestConnID,
retrySrcConnID, retrySrcConnID,
@ -514,7 +519,7 @@ func (s *baseServer) handleNewSession(sess quicSession) {
} }
} }
func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error {
// Log the Initial packet now. // Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the session. // If no Retry is sent, the packet will be logged by the session.
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
@ -551,7 +556,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
if s.config.Tracer != nil { if s.config.Tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(buf.Len()), nil) s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(buf.Len()), nil)
} }
_, err = s.conn.WriteTo(buf.Bytes(), remoteAddr) _, err = s.conn.WritePacket(buf.Bytes(), remoteAddr, info)
return err return err
} }
@ -579,16 +584,16 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header)
if s.logger.Debug() { if s.logger.Debug() {
s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr) s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr)
} }
return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken) return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info)
} }
func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header) error { func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error {
sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused) return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info)
} }
// sendError sends the error as a response to the packet received with header hdr // sendError sends the error as a response to the packet received with header hdr
func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.ErrorCode) error { func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.ErrorCode, info *packetInfo) error {
packetBuffer := getPacketBuffer() packetBuffer := getPacketBuffer()
defer packetBuffer.Release() defer packetBuffer.Release()
buf := bytes.NewBuffer(packetBuffer.Data) buf := bytes.NewBuffer(packetBuffer.Data)
@ -628,7 +633,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
if s.config.Tracer != nil { if s.config.Tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(raw)), []logging.Frame{ccf}) s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(raw)), []logging.Frame{ccf})
} }
_, err := s.conn.WriteTo(raw, remoteAddr) _, err := s.conn.WritePacket(raw, remoteAddr, info)
return err return err
} }
@ -651,7 +656,7 @@ func (s *baseServer) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.H
nil, nil,
) )
} }
if _, err := s.conn.WriteTo(data, p.remoteAddr); err != nil { if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err) s.logger.Debugf("Error sending Version Negotiation: %s", err)
} }
} }

View file

@ -59,6 +59,13 @@ type cryptoStreamHandler interface {
ConnectionState() handshake.ConnectionState ConnectionState() handshake.ConnectionState
} }
type packetInfo struct {
addr net.IP
ifIndex uint32
once sync.Once
oob []byte
}
type receivedPacket struct { type receivedPacket struct {
buffer *packetBuffer buffer *packetBuffer
@ -67,6 +74,8 @@ type receivedPacket struct {
data []byte data []byte
ecn protocol.ECN ecn protocol.ECN
info *packetInfo
} }
func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) } func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) }
@ -78,6 +87,7 @@ func (p *receivedPacket) Clone() *receivedPacket {
data: p.data, data: p.data,
buffer: p.buffer, buffer: p.buffer,
ecn: p.ecn, ecn: p.ecn,
info: p.info,
} }
} }