Use the correct source IP when binding multiple IPs

When the server is listening on multiple interfaces or interfaces with
multiple IPs, the outgoing datagrams are sometime delivered with the
wrong source IP address.

In order to fix that, each quic connection needs to extract the
destination IP (and optionally interface id) of the received datagrams,
and set it as source IP (and interface) on the sent datagrams.

On most platforms, this can be done using ancillary data with recvmsg()
and sendmsg(). Some of the machinery for this is already there for ECN,
this change extends it to read the destination IP info and write it to
the outgoing packets.

Fix #1736
This commit is contained in:
Olivier Poitrey 2021-03-16 00:43:41 +01:00
parent 3bce408c8d
commit eb6bdfdfc1
15 changed files with 468 additions and 181 deletions

View file

@ -116,13 +116,14 @@ func dialAddrContext(
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true)
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// If the PacketConn satisfies the ECNCapablePacketConn interface (as a net.UDPConn does), ECN support will be enabled.
// In this case, ReadMsgUDP will be used instead of ReadFrom to read packets.
// The same PacketConn can be used for multiple calls to Dial and Listen,
// QUIC connection IDs are used for demultiplexing the different connections.
// The host parameter is used for SNI.
// The tls.Config must define an application protocol (using NextProtos).
// Dial establishes a new QUIC connection to a server using a net.PacketConn. If
// the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
// packets. The same PacketConn can be used for multiple calls to Dial and
// Listen, QUIC connection IDs are used for demultiplexing the different
// connections. The host parameter is used for SNI. The tls.Config must define
// an application protocol (using NextProtos).
func Dial(
pconn net.PacketConn,
remoteAddr net.Addr,
@ -255,7 +256,7 @@ func newClient(
c := &client{
srcConnID: srcConnID,
destConnID: destConnID,
conn: newSendConn(pconn, remoteAddr),
conn: newSendPconn(pconn, remoteAddr),
createdPacketConn: createdPacketConn,
use0RTT: use0RTT,
tlsConf: tlsConf,

View file

@ -64,7 +64,7 @@ var _ = Describe("Client", func() {
srcConnID: connID,
destConnID: connID,
version: protocol.VersionTLS,
conn: newSendConn(packetConn, addr),
conn: newSendPconn(packetConn, addr),
tracer: tracer,
logger: utils.DefaultLogger,
}
@ -548,7 +548,7 @@ var _ = Describe("Client", func() {
_, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Eventually(c).Should(BeClosed())
Expect(cconn.(*sconn).PacketConn).To(Equal(packetConn))
Expect(cconn.(*spconn).PacketConn).To(Equal(packetConn))
Expect(version).To(Equal(config.Versions[0]))
Expect(conf.Versions).To(Equal(config.Versions))
})

13
conn.go
View file

@ -12,23 +12,24 @@ import (
type connection interface {
ReadPacket() (*receivedPacket, error)
WriteTo([]byte, net.Addr) (int, error)
WritePacket(b []byte, addr net.Addr, info *packetInfo) (int, error)
LocalAddr() net.Addr
io.Closer
}
// If the PacketConn passed to Dial or Listen satisfies this interface, quic-go will read the ECN bits from the IP header.
// In this case, ReadMsgUDP() will be used instead of ReadFrom() to read packets.
type ECNCapablePacketConn interface {
type OOBCapablePacketConn interface {
net.PacketConn
SyscallConn() (syscall.RawConn, error)
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error)
}
var _ ECNCapablePacketConn = &net.UDPConn{}
var _ OOBCapablePacketConn = &net.UDPConn{}
func wrapConn(pc net.PacketConn) (connection, error) {
c, ok := pc.(ECNCapablePacketConn)
c, ok := pc.(OOBCapablePacketConn)
if !ok {
utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.")
return &basicConn{PacketConn: pc}, nil
@ -58,3 +59,7 @@ func (c *basicConn) ReadPacket() (*receivedPacket, error) {
buffer: buffer,
}, nil
}
func (c *basicConn) WritePacket(b []byte, addr net.Addr, info *packetInfo) (n int, err error) {
return c.PacketConn.WriteTo(b, addr)
}

View file

@ -1,115 +0,0 @@
// +build darwin linux
package quic
import (
"errors"
"fmt"
"net"
"syscall"
"time"
"golang.org/x/sys/unix"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const ecnMask uint8 = 0x3
func inspectReadBuffer(c net.PacketConn) (int, error) {
conn, ok := c.(interface {
SyscallConn() (syscall.RawConn, error)
})
if !ok {
return 0, errors.New("doesn't have a SyscallConn")
}
rawConn, err := conn.SyscallConn()
if err != nil {
return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err)
}
var size int
var serr error
if err := rawConn.Control(func(fd uintptr) {
size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
}); err != nil {
return 0, err
}
return size, serr
}
type ecnConn struct {
ECNCapablePacketConn
oobBuffer []byte
}
var _ connection = &ecnConn{}
func newConn(c ECNCapablePacketConn) (*ecnConn, error) {
rawConn, err := c.SyscallConn()
if err != nil {
return nil, err
}
// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
// Try enabling receiving of ECN for both IP versions.
// We expect at least one of those syscalls to succeed.
var errIPv4, errIPv6 error
if err := rawConn.Control(func(fd uintptr) {
errIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
}); err != nil {
return nil, err
}
if err := rawConn.Control(func(fd uintptr) {
errIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
}); err != nil {
return nil, err
}
switch {
case errIPv4 == nil && errIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.")
case errIPv4 == nil && errIPv6 != nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.")
case errIPv4 != nil && errIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.")
case errIPv4 != nil && errIPv6 != nil:
return nil, errors.New("activating ECN failed for both IPv4 and IPv6")
}
return &ecnConn{
ECNCapablePacketConn: c,
oobBuffer: make([]byte, 128),
}, nil
}
func (c *ecnConn) ReadPacket() (*receivedPacket, error) {
buffer := getPacketBuffer()
// The packet size should not exceed protocol.MaxPacketBufferSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
c.oobBuffer = c.oobBuffer[:cap(c.oobBuffer)]
n, oobn, _, addr, err := c.ECNCapablePacketConn.ReadMsgUDP(buffer.Data, c.oobBuffer)
if err != nil {
return nil, err
}
ctrlMsgs, err := unix.ParseSocketControlMessage(c.oobBuffer[:oobn])
if err != nil {
return nil, err
}
var ecn protocol.ECN
for _, ctrlMsg := range ctrlMsgs {
if ctrlMsg.Header.Level == unix.IPPROTO_IP && ctrlMsg.Header.Type == msgTypeIPTOS {
ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
break
}
if ctrlMsg.Header.Level == unix.IPPROTO_IPV6 && ctrlMsg.Header.Type == unix.IPV6_TCLASS {
ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
break
}
}
return &receivedPacket{
remoteAddr: addr,
rcvTime: time.Now(),
data: buffer.Data[:n],
ecn: ecn,
buffer: buffer,
}, nil
}

View file

@ -15,31 +15,31 @@ import (
. "github.com/onsi/gomega"
)
var _ = Describe("Basic Conn Test", func() {
Context("ECN conn", func() {
runServer := func(network, address string) (*net.UDPConn, <-chan *receivedPacket) {
addr, err := net.ResolveUDPAddr(network, address)
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP(network, addr)
Expect(err).ToNot(HaveOccurred())
ecnConn, err := newConn(udpConn)
Expect(err).ToNot(HaveOccurred())
var _ = Describe("OOB Conn Test", func() {
runServer := func(network, address string) (*net.UDPConn, <-chan *receivedPacket) {
addr, err := net.ResolveUDPAddr(network, address)
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP(network, addr)
Expect(err).ToNot(HaveOccurred())
ecnConn, err := newConn(udpConn)
Expect(err).ToNot(HaveOccurred())
packetChan := make(chan *receivedPacket)
go func() {
defer GinkgoRecover()
for {
p, err := ecnConn.ReadPacket()
if err != nil {
return
}
packetChan <- p
packetChan := make(chan *receivedPacket)
go func() {
defer GinkgoRecover()
for {
p, err := ecnConn.ReadPacket()
if err != nil {
return
}
}()
packetChan <- p
}
}()
return udpConn, packetChan
}
return udpConn, packetChan
}
Context("ECN conn", func() {
sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr {
conn, err := net.DialUDP(network, nil, addr)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
@ -126,4 +126,75 @@ var _ = Describe("Basic Conn Test", func() {
Expect(p.ecn).To(Equal(protocol.ECT1))
})
})
Context("Packet Info conn", func() {
sendPacket := func(network string, addr *net.UDPAddr) net.Addr {
conn, err := net.DialUDP(network, nil, addr)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
_, err = conn.Write([]byte("foobar"))
ExpectWithOffset(1, err).ToNot(HaveOccurred())
return conn.LocalAddr()
}
It("reads packet info on IPv4", func() {
conn, packetChan := runServer("udp4", ":0")
defer conn.Close()
addr := conn.LocalAddr().(*net.UDPAddr)
ip := net.ParseIP("127.0.0.1").To4()
addr.IP = ip
sentFrom := sendPacket("udp4", addr)
var p *receivedPacket
Eventually(packetChan).Should(Receive(&p))
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
Expect(p.data).To(Equal([]byte("foobar")))
Expect(p.remoteAddr).To(Equal(sentFrom))
Expect(p.info).To(Not(BeNil()))
Expect(p.info.addr.To4()).To(Equal(ip))
})
It("reads packet info on IPv6", func() {
conn, packetChan := runServer("udp6", ":0")
defer conn.Close()
addr := conn.LocalAddr().(*net.UDPAddr)
ip := net.ParseIP("::1")
addr.IP = ip
sentFrom := sendPacket("udp6", addr)
var p *receivedPacket
Eventually(packetChan).Should(Receive(&p))
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
Expect(p.data).To(Equal([]byte("foobar")))
Expect(p.remoteAddr).To(Equal(sentFrom))
Expect(p.info).To(Not(BeNil()))
Expect(p.info.addr).To(Equal(ip))
})
It("reads packet info on a connection that supports both IPv4 and IPv6", func() {
conn, packetChan := runServer("udp", ":0")
defer conn.Close()
port := conn.LocalAddr().(*net.UDPAddr).Port
// IPv4
ip4 := net.ParseIP("127.0.0.1").To4()
sendPacket("udp4", &net.UDPAddr{IP: ip4, Port: port})
var p *receivedPacket
Eventually(packetChan).Should(Receive(&p))
Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue())
Expect(p.info).To(Not(BeNil()))
Expect(p.info.addr.To4()).To(Equal(ip4))
// IPv6
ip6 := net.ParseIP("::1")
sendPacket("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: port})
Eventually(packetChan).Should(Receive(&p))
Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse())
Expect(p.info).To(Not(BeNil()))
Expect(p.info.addr).To(Equal(ip6))
})
})
})

View file

@ -1,4 +1,4 @@
// +build !darwin,!linux,!windows
// +build !darwin,!linux,!freebsd,!windows
package quic
@ -8,6 +8,6 @@ func newConn(c net.PacketConn) (connection, error) {
return &basicConn{PacketConn: c}, nil
}
func inspectReadBuffer(net.PacketConn) (int, error) {
func inspectReadBuffer(interface{}) (int, error) {
return 0, nil
}

View file

@ -5,3 +5,13 @@ package quic
import "golang.org/x/sys/unix"
const msgTypeIPTOS = unix.IP_RECVTOS
const (
ipv4RECVPKTINFO = 0x1a
ipv6RECVPKTINFO = 0x3d
)
const (
msgTypeIPv4PKTINFO = 0x1a
msgTypeIPv6PKTINFO = 0x2e
)

17
conn_helper_freebsd.go Normal file
View file

@ -0,0 +1,17 @@
// +build freebsd
package quic
import "golang.org/x/sys/unix"
const msgTypeIPTOS = unix.IP_RECVTOS
const (
ipv4RECVPKTINFO = 0x7
ipv6RECVPKTINFO = 0x24
)
const (
msgTypeIPv4PKTINFO = 0x7
msgTypeIPv6PKTINFO = 0x2e
)

View file

@ -5,3 +5,13 @@ package quic
import "golang.org/x/sys/unix"
const msgTypeIPTOS = unix.IP_TOS
const (
ipv4RECVPKTINFO = 0x8
ipv6RECVPKTINFO = 0x31
)
const (
msgTypeIPv4PKTINFO = 0x8
msgTypeIPv6PKTINFO = 0x32
)

239
conn_oob.go Normal file
View file

@ -0,0 +1,239 @@
// +build darwin linux freebsd
package quic
import (
"encoding/binary"
"errors"
"fmt"
"net"
"runtime"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/unix"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const ecnMask uint8 = 0x3
func inspectReadBuffer(c interface{}) (int, error) {
conn, ok := c.(interface {
SyscallConn() (syscall.RawConn, error)
})
if !ok {
return 0, errors.New("doesn't have a SyscallConn")
}
rawConn, err := conn.SyscallConn()
if err != nil {
return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err)
}
var size int
var serr error
if err := rawConn.Control(func(fd uintptr) {
size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
}); err != nil {
return 0, err
}
return size, serr
}
type oobConn struct {
OOBCapablePacketConn
oobBuffer []byte
}
var _ connection = &oobConn{}
func newConn(c OOBCapablePacketConn) (*oobConn, error) {
rawConn, err := c.SyscallConn()
if err != nil {
return nil, err
}
needsPacketInfo := false
if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() {
needsPacketInfo = true
}
// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
// Try enabling receiving of ECN and packet info for both IP versions.
// We expect at least one of those syscalls to succeed.
var errECNIPv4, errECNIPv6, errPIIPv4, errPIIPv6 error
if err := rawConn.Control(func(fd uintptr) {
errECNIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
errECNIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
if needsPacketInfo {
errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4RECVPKTINFO, 1)
errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, ipv6RECVPKTINFO, 1)
}
}); err != nil {
return nil, err
}
switch {
case errECNIPv4 == nil && errECNIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.")
case errECNIPv4 == nil && errECNIPv6 != nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.")
case errECNIPv4 != nil && errECNIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.")
case errECNIPv4 != nil && errECNIPv6 != nil:
return nil, errors.New("activating ECN failed for both IPv4 and IPv6")
}
if needsPacketInfo {
switch {
case errPIIPv4 == nil && errPIIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of packet info for IPv4 and IPv6.")
case errPIIPv4 == nil && errPIIPv6 != nil:
utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv4.")
case errPIIPv4 != nil && errPIIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv6.")
case errPIIPv4 != nil && errPIIPv6 != nil:
return nil, errors.New("activating packet info failed for both IPv4 and IPv6")
}
}
return &oobConn{
OOBCapablePacketConn: c,
oobBuffer: make([]byte, 128),
}, nil
}
func (c *oobConn) ReadPacket() (*receivedPacket, error) {
buffer := getPacketBuffer()
// The packet size should not exceed protocol.MaxPacketBufferSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
c.oobBuffer = c.oobBuffer[:cap(c.oobBuffer)]
n, oobn, _, addr, err := c.OOBCapablePacketConn.ReadMsgUDP(buffer.Data, c.oobBuffer)
if err != nil {
return nil, err
}
ctrlMsgs, err := unix.ParseSocketControlMessage(c.oobBuffer[:oobn])
if err != nil {
return nil, err
}
var ecn protocol.ECN
var destIP net.IP
var ifIndex uint32
for _, ctrlMsg := range ctrlMsgs {
if ctrlMsg.Header.Level == unix.IPPROTO_IP {
switch ctrlMsg.Header.Type {
case msgTypeIPTOS:
ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
case msgTypeIPv4PKTINFO:
// struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */
// struct in_addr ipi_spec_dst; /* Local address */
// struct in_addr ipi_addr; /* Header Destination
// address */
// };
if len(ctrlMsg.Data) == 12 {
ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data)
destIP = net.IP(ctrlMsg.Data[8:12])
} else if len(ctrlMsg.Data) == 4 {
// FreeBSD
destIP = net.IP(ctrlMsg.Data)
}
}
}
if ctrlMsg.Header.Level == unix.IPPROTO_IPV6 {
switch ctrlMsg.Header.Type {
case unix.IPV6_TCLASS:
ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
case msgTypeIPv6PKTINFO:
// struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */
// unsigned int ipi6_ifindex; /* send/recv interface index */
// };
if len(ctrlMsg.Data) == 20 {
destIP = net.IP(ctrlMsg.Data[:16])
ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data[16:])
}
}
}
}
var info *packetInfo
if destIP != nil {
info = &packetInfo{
addr: destIP,
ifIndex: ifIndex,
}
}
return &receivedPacket{
remoteAddr: addr,
rcvTime: time.Now(),
data: buffer.Data[:n],
ecn: ecn,
info: info,
buffer: buffer,
}, nil
}
func (c *oobConn) WritePacket(b []byte, addr net.Addr, info *packetInfo) (n int, err error) {
n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, info.OOB(), addr.(*net.UDPAddr))
return n, err
}
func (info *packetInfo) OOB() []byte {
if info == nil {
return nil
}
info.once.Do(info.computeOOB)
return info.oob
}
func (info *packetInfo) computeOOB() {
if ip4 := info.addr.To4(); ip4 != nil {
// struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */
// struct in_addr ipi_spec_dst; /* Local address */
// struct in_addr ipi_addr; /* Header Destination address */
// };
msgLen := 12
if runtime.GOOS == "freebsd" {
msgLen = 4
}
cmsglen := cmsgLen(msgLen)
oob := make([]byte, cmsglen)
cmsg := (*syscall.Cmsghdr)(unsafe.Pointer(&oob[0]))
cmsg.Level = syscall.IPPROTO_TCP
cmsg.Type = msgTypeIPv4PKTINFO
cmsg.SetLen(cmsglen)
off := cmsgLen(0)
if runtime.GOOS != "freebsd" {
// FreeBSD does not support in_pktinfo, just an in_addr is sent
binary.LittleEndian.PutUint32(oob[off:], info.ifIndex)
off += 4
}
copy(oob[off:], ip4)
info.oob = oob
} else if len(info.addr) == 16 {
// struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */
// unsigned int ipi6_ifindex; /* send/recv interface index */
// };
const msgLen = 20
cmsglen := cmsgLen(msgLen)
oob := make([]byte, cmsglen)
cmsg := (*syscall.Cmsghdr)(unsafe.Pointer(&oob[0]))
cmsg.Level = syscall.IPPROTO_IPV6
cmsg.Type = msgTypeIPv6PKTINFO
cmsg.SetLen(cmsglen)
off := cmsgLen(0)
off += copy(oob[off:], info.addr)
binary.LittleEndian.PutUint32(oob[off:], info.ifIndex)
info.oob = oob
}
}
func cmsgLen(datalen int) int {
return cmsgAlign(syscall.SizeofCmsghdr) + datalen
}
func cmsgAlign(salen int) int {
const sizeOfPtr = 0x8
salign := sizeOfPtr
return (salen + salign - 1) & ^(salign - 1)
}

View file

@ -469,7 +469,7 @@ func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID pro
rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...)
if _, err := h.conn.WriteTo(data, p.remoteAddr); err != nil {
if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info); err != nil {
h.logger.Debugf("Error sending Stateless Reset: %s", err)
}
}

View file

@ -13,22 +13,56 @@ type sendConn interface {
}
type sconn struct {
net.PacketConn
connection
remoteAddr net.Addr
info *packetInfo
}
var _ sendConn = &sconn{}
func newSendConn(c net.PacketConn, remote net.Addr) sendConn {
return &sconn{PacketConn: c, remoteAddr: remote}
func newSendConn(c connection, remote net.Addr, info *packetInfo) sendConn {
return &sconn{connection: c, remoteAddr: remote, info: info}
}
func (c *sconn) Write(p []byte) error {
_, err := c.PacketConn.WriteTo(p, c.remoteAddr)
_, err := c.WritePacket(p, c.remoteAddr, c.info)
return err
}
func (c *sconn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *sconn) LocalAddr() net.Addr {
addr := c.connection.LocalAddr()
if c.info != nil {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
addrCopy := *udpAddr
addrCopy.IP = c.info.addr
addr = &addrCopy
}
}
return addr
}
type spconn struct {
net.PacketConn
remoteAddr net.Addr
}
var _ sendConn = &spconn{}
func newSendPconn(c net.PacketConn, remote net.Addr) sendConn {
return &spconn{PacketConn: c, remoteAddr: remote}
}
func (c *spconn) Write(p []byte) error {
_, err := c.WriteTo(p, c.remoteAddr)
return err
}
func (c *spconn) RemoteAddr() net.Addr {
return c.remoteAddr
}

View file

@ -17,7 +17,7 @@ var _ = Describe("Connection (for sending packets)", func() {
BeforeEach(func() {
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = NewMockPacketConn(mockCtrl)
c = newSendConn(packetConn, addr)
c = newSendPconn(packetConn, addr)
})
It("writes", func() {

View file

@ -61,7 +61,7 @@ type baseServer struct {
tlsConf *tls.Config
config *Config
conn net.PacketConn
conn connection
// If the server is started with ListenAddr, we create a packet conn.
// If it is started with Listen, we take a packet conn as a parameter.
createdPacketConn bool
@ -148,16 +148,17 @@ func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bo
return serv, nil
}
// Listen listens for QUIC connections on a given net.PacketConn.
// If the PacketConn satisfies the ECNCapablePacketConn interface (as a net.UDPConn does), ECN support will be enabled.
// In this case, ReadMsgUDP will be used instead of ReadFrom to read packets.
// A single net.PacketConn only be used for a single call to Listen.
// The PacketConn can be used for simultaneous calls to Dial.
// QUIC connection IDs are used for demultiplexing the different connections.
// The tls.Config must not be nil and must contain a certificate configuration.
// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites.
// Furthermore, it must define an application control (using NextProtos).
// The quic.Config may be nil, in that case the default values will be used.
// Listen listens for QUIC connections on a given net.PacketConn. If the
// PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
// packets. A single net.PacketConn only be used for a single call to Listen.
// The PacketConn can be used for simultaneous calls to Dial. QUIC connection
// IDs are used for demultiplexing the different connections. The tls.Config
// must not be nil and must contain a certificate configuration. The
// tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. Furthermore,
// it must define an application control (using NextProtos). The quic.Config may
// be nil, in that case the default values will be used.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
return listen(conn, tlsConf, config, false)
}
@ -193,8 +194,12 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
if err != nil {
return nil, err
}
c, err := wrapConn(conn)
if err != nil {
return nil, err
}
s := &baseServer{
conn: conn,
conn: c,
tlsConf: tlsConf,
config: config,
tokenGenerator: tokenGenerator,
@ -421,7 +426,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
}
return
}
if err := s.sendRetry(p.remoteAddr, hdr); err != nil {
if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
s.logger.Debugf("Error sending Retry: %s", err)
}
}()
@ -432,7 +437,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
go func() {
defer p.buffer.Release()
if err := s.sendConnectionRefused(p.remoteAddr, hdr); err != nil {
if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil {
s.logger.Debugf("Error rejecting connection: %s", err)
}
}()
@ -456,7 +461,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
tracer = s.config.Tracer.TracerForConnection(protocol.PerspectiveServer, connID)
}
sess = s.newSession(
newSendConn(s.conn, p.remoteAddr),
newSendConn(s.conn, p.remoteAddr, p.info),
s.sessionHandler,
origDestConnID,
retrySrcConnID,
@ -514,7 +519,7 @@ func (s *baseServer) handleNewSession(sess quicSession) {
}
}
func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error {
// Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the session.
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
@ -551,7 +556,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
if s.config.Tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(buf.Len()), nil)
}
_, err = s.conn.WriteTo(buf.Bytes(), remoteAddr)
_, err = s.conn.WritePacket(buf.Bytes(), remoteAddr, info)
return err
}
@ -579,16 +584,16 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header)
if s.logger.Debug() {
s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr)
}
return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken)
return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info)
}
func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header) error {
func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error {
sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused)
return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info)
}
// sendError sends the error as a response to the packet received with header hdr
func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.ErrorCode) error {
func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.ErrorCode, info *packetInfo) error {
packetBuffer := getPacketBuffer()
defer packetBuffer.Release()
buf := bytes.NewBuffer(packetBuffer.Data)
@ -628,7 +633,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
if s.config.Tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(raw)), []logging.Frame{ccf})
}
_, err := s.conn.WriteTo(raw, remoteAddr)
_, err := s.conn.WritePacket(raw, remoteAddr, info)
return err
}
@ -651,7 +656,7 @@ func (s *baseServer) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.H
nil,
)
}
if _, err := s.conn.WriteTo(data, p.remoteAddr); err != nil {
if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err)
}
}

View file

@ -59,6 +59,13 @@ type cryptoStreamHandler interface {
ConnectionState() handshake.ConnectionState
}
type packetInfo struct {
addr net.IP
ifIndex uint32
once sync.Once
oob []byte
}
type receivedPacket struct {
buffer *packetBuffer
@ -67,6 +74,8 @@ type receivedPacket struct {
data []byte
ecn protocol.ECN
info *packetInfo
}
func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) }
@ -78,6 +87,7 @@ func (p *receivedPacket) Clone() *receivedPacket {
data: p.data,
buffer: p.buffer,
ecn: p.ecn,
info: p.info,
}
}