Merge pull request #3881 from quic-go/packet-info-allocs

embed packetInfo in receivedPacket struct, use netip.Addr
This commit is contained in:
Marten Seemann 2023-06-03 10:54:52 +03:00 committed by GitHub
commit f777c452ce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 70 additions and 61 deletions

View file

@ -16,13 +16,13 @@ type closedLocalConn struct {
perspective protocol.Perspective perspective protocol.Perspective
logger utils.Logger logger utils.Logger
sendPacket func(net.Addr, *packetInfo) sendPacket func(net.Addr, packetInfo)
} }
var _ packetHandler = &closedLocalConn{} var _ packetHandler = &closedLocalConn{}
// newClosedLocalConn creates a new closedLocalConn and runs it. // newClosedLocalConn creates a new closedLocalConn and runs it.
func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler { func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
return &closedLocalConn{ return &closedLocalConn{
sendPacket: sendPacket, sendPacket: sendPacket,
perspective: pers, perspective: pers,

View file

@ -21,7 +21,7 @@ var _ = Describe("Closed local connection", func() {
It("repeats the packet containing the CONNECTION_CLOSE frame", func() { It("repeats the packet containing the CONNECTION_CLOSE frame", func() {
written := make(chan net.Addr, 1) written := make(chan net.Addr, 1)
conn := newClosedLocalConn( conn := newClosedLocalConn(
func(addr net.Addr, _ *packetInfo) { written <- addr }, func(addr net.Addr, _ packetInfo) { written <- addr },
protocol.PerspectiveClient, protocol.PerspectiveClient,
utils.DefaultLogger, utils.DefaultLogger,
) )

View file

@ -70,7 +70,7 @@ type receivedPacket struct {
ecn protocol.ECN ecn protocol.ECN
info *packetInfo info packetInfo // only valid if the contained IP address is valid
} }
func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) } func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) }

View file

@ -39,7 +39,7 @@ type rawConn interface {
type closePacket struct { type closePacket struct {
payload []byte payload []byte
addr net.Addr addr net.Addr
info *packetInfo info packetInfo
} }
type unknownPacketHandler interface { type unknownPacketHandler interface {
@ -177,7 +177,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
var handler packetHandler var handler packetHandler
if connClosePacket != nil { if connClosePacket != nil {
handler = newClosedLocalConn( handler = newClosedLocalConn(
func(addr net.Addr, info *packetInfo) { func(addr net.Addr, info packetInfo) {
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info}) h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
}, },
pers, pers,

View file

@ -21,13 +21,25 @@ type sconn struct {
rawConn rawConn
remoteAddr net.Addr remoteAddr net.Addr
info *packetInfo info packetInfo
oob []byte oob []byte
} }
var _ sendConn = &sconn{} var _ sendConn = &sconn{}
func newSendConn(c rawConn, remote net.Addr, info *packetInfo) *sconn { func newSendConn(c rawConn, remote net.Addr) *sconn {
sc := &sconn{
rawConn: c,
remoteAddr: remote,
}
if c.capabilities().GSO {
// add 32 bytes, so we can add the UDP_SEGMENT msg
sc.oob = make([]byte, 0, 32)
}
return sc
}
func newSendConnWithPacketInfo(c rawConn, remote net.Addr, info packetInfo) *sconn {
oob := info.OOB() oob := info.OOB()
if c.capabilities().GSO { if c.capabilities().GSO {
// add 32 bytes, so we can add the UDP_SEGMENT msg // add 32 bytes, so we can add the UDP_SEGMENT msg
@ -57,10 +69,10 @@ func (c *sconn) RemoteAddr() net.Addr {
func (c *sconn) LocalAddr() net.Addr { func (c *sconn) LocalAddr() net.Addr {
addr := c.rawConn.LocalAddr() addr := c.rawConn.LocalAddr()
if c.info != nil { if c.info.addr.IsValid() {
if udpAddr, ok := addr.(*net.UDPAddr); ok { if udpAddr, ok := addr.(*net.UDPAddr); ok {
addrCopy := *udpAddr addrCopy := *udpAddr
addrCopy.IP = c.info.addr addrCopy.IP = c.info.addr.AsSlice()
addr = &addrCopy addr = &addrCopy
} }
} }

View file

@ -19,7 +19,7 @@ var _ = Describe("Connection (for sending packets)", func() {
packetConn = NewMockPacketConn(mockCtrl) packetConn = NewMockPacketConn(mockCtrl)
rawConn, err := wrapConn(packetConn) rawConn, err := wrapConn(packetConn)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
c = newSendConn(rawConn, addr, nil) c = newSendConnWithPacketInfo(rawConn, addr, packetInfo{})
}) })
It("writes", func() { It("writes", func() {

View file

@ -632,7 +632,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID) tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
} }
conn = s.newConn( conn = s.newConn(
newSendConn(s.conn, p.remoteAddr, p.info), newSendConnWithPacketInfo(s.conn, p.remoteAddr, p.info),
s.connHandler, s.connHandler,
origDestConnID, origDestConnID,
retrySrcConnID, retrySrcConnID,
@ -706,7 +706,7 @@ func (s *baseServer) handleNewConn(conn quicConn) {
} }
} }
func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packetInfo) error {
// Log the Initial packet now. // Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the connection. // If no Retry is sent, the packet will be logged by the connection.
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
@ -795,13 +795,13 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) {
} }
} }
func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info packetInfo) error {
sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info) return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info)
} }
// sendError sends the error as a response to the packet received with header hdr // sendError sends the error as a response to the packet received with header hdr
func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error { func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info packetInfo) error {
b := getPacketBuffer() b := getPacketBuffer()
defer b.Release() defer b.Release()

View file

@ -2,7 +2,10 @@
package quic package quic
import "net" import (
"net"
"net/netip"
)
func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) { func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) {
return &basicConn{PacketConn: c, supportsDF: supportsDF}, nil return &basicConn{PacketConn: c, supportsDF: supportsDF}, nil
@ -12,7 +15,7 @@ func inspectReadBuffer(any) (int, error) { return 0, nil }
func inspectWriteBuffer(any) (int, error) { return 0, nil } func inspectWriteBuffer(any) (int, error) { return 0, nil }
type packetInfo struct { type packetInfo struct {
addr net.IP addr netip.Addr
} }
func (i *packetInfo) OOB() []byte { return nil } func (i *packetInfo) OOB() []byte { return nil }

View file

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"time" "time"
@ -172,9 +173,12 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
c.readPos++ c.readPos++
data := msg.OOB[:msg.NN] data := msg.OOB[:msg.NN]
var ecn protocol.ECN p := receivedPacket{
var destIP net.IP remoteAddr: msg.Addr,
var ifIndex uint32 rcvTime: time.Now(),
data: msg.Buffers[0][:msg.N],
buffer: buffer,
}
for len(data) > 0 { for len(data) > 0 {
hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data) hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data)
if err != nil { if err != nil {
@ -183,7 +187,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
if hdr.Level == unix.IPPROTO_IP { if hdr.Level == unix.IPPROTO_IP {
switch hdr.Type { switch hdr.Type {
case msgTypeIPTOS: case msgTypeIPTOS:
ecn = protocol.ECN(body[0] & ecnMask) p.ecn = protocol.ECN(body[0] & ecnMask)
case msgTypeIPv4PKTINFO: case msgTypeIPv4PKTINFO:
// struct in_pktinfo { // struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */ // unsigned int ipi_ifindex; /* Interface index */
@ -191,51 +195,37 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
// struct in_addr ipi_addr; /* Header Destination // struct in_addr ipi_addr; /* Header Destination
// address */ // address */
// }; // };
ip := make([]byte, 4) var ip [4]byte
if len(body) == 12 { if len(body) == 12 {
ifIndex = binary.LittleEndian.Uint32(body) copy(ip[:], body[8:12])
copy(ip, body[8:12]) p.info.ifIndex = binary.LittleEndian.Uint32(body)
} else if len(body) == 4 { } else if len(body) == 4 {
// FreeBSD // FreeBSD
copy(ip, body) copy(ip[:], body)
} }
destIP = net.IP(ip) p.info.addr = netip.AddrFrom4(ip)
} }
} }
if hdr.Level == unix.IPPROTO_IPV6 { if hdr.Level == unix.IPPROTO_IPV6 {
switch hdr.Type { switch hdr.Type {
case unix.IPV6_TCLASS: case unix.IPV6_TCLASS:
ecn = protocol.ECN(body[0] & ecnMask) p.ecn = protocol.ECN(body[0] & ecnMask)
case msgTypeIPv6PKTINFO: case msgTypeIPv6PKTINFO:
// struct in6_pktinfo { // struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */ // struct in6_addr ipi6_addr; /* src/dst IPv6 address */
// unsigned int ipi6_ifindex; /* send/recv interface index */ // unsigned int ipi6_ifindex; /* send/recv interface index */
// }; // };
if len(body) == 20 { if len(body) == 20 {
ip := make([]byte, 16) var ip [16]byte
copy(ip, body[:16]) copy(ip[:], body[:16])
destIP = net.IP(ip) p.info.addr = netip.AddrFrom16(ip)
ifIndex = binary.LittleEndian.Uint32(body[16:]) p.info.ifIndex = binary.LittleEndian.Uint32(body[16:])
} }
} }
} }
data = remainder data = remainder
} }
var info *packetInfo return p, nil
if destIP != nil {
info = &packetInfo{
addr: destIP,
ifIndex: ifIndex,
}
}
return receivedPacket{
remoteAddr: msg.Addr,
rcvTime: time.Now(),
data: msg.Buffers[0][:msg.N],
ecn: ecn,
info: info,
buffer: buffer,
}, nil
} }
// WriteTo (re)implements the net.PacketConn method. // WriteTo (re)implements the net.PacketConn method.
@ -265,7 +255,7 @@ func (c *oobConn) capabilities() connCapabilities {
} }
type packetInfo struct { type packetInfo struct {
addr net.IP addr netip.Addr
ifIndex uint32 ifIndex uint32
} }
@ -273,24 +263,26 @@ func (info *packetInfo) OOB() []byte {
if info == nil { if info == nil {
return nil return nil
} }
if ip4 := info.addr.To4(); ip4 != nil { if info.addr.Is4() {
ip := info.addr.As4()
// struct in_pktinfo { // struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */ // unsigned int ipi_ifindex; /* Interface index */
// struct in_addr ipi_spec_dst; /* Local address */ // struct in_addr ipi_spec_dst; /* Local address */
// struct in_addr ipi_addr; /* Header Destination address */ // struct in_addr ipi_addr; /* Header Destination address */
// }; // };
cm := ipv4.ControlMessage{ cm := ipv4.ControlMessage{
Src: ip4, Src: ip[:],
IfIndex: int(info.ifIndex), IfIndex: int(info.ifIndex),
} }
return cm.Marshal() return cm.Marshal()
} else if len(info.addr) == 16 { } else if info.addr.Is6() {
ip := info.addr.As16()
// struct in6_pktinfo { // struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */ // struct in6_addr ipi6_addr; /* src/dst IPv6 address */
// unsigned int ipi6_ifindex; /* send/recv interface index */ // unsigned int ipi6_ifindex; /* send/recv interface index */
// }; // };
cm := ipv6.ControlMessage{ cm := ipv6.ControlMessage{
Src: info.addr, Src: ip[:],
IfIndex: int(info.ifIndex), IfIndex: int(info.ifIndex),
} }
return cm.Marshal() return cm.Marshal()

View file

@ -154,8 +154,8 @@ var _ = Describe("OOB Conn Test", func() {
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
Expect(p.data).To(Equal([]byte("foobar"))) Expect(p.data).To(Equal([]byte("foobar")))
Expect(p.remoteAddr).To(Equal(sentFrom)) Expect(p.remoteAddr).To(Equal(sentFrom))
Expect(p.info).To(Not(BeNil())) Expect(p.info.addr.IsValid()).To(BeTrue())
Expect(p.info.addr.To4()).To(Equal(ip)) Expect(net.IP(p.info.addr.AsSlice())).To(Equal(ip))
}) })
It("reads packet info on IPv6", func() { It("reads packet info on IPv6", func() {
@ -173,7 +173,7 @@ var _ = Describe("OOB Conn Test", func() {
Expect(p.data).To(Equal([]byte("foobar"))) Expect(p.data).To(Equal([]byte("foobar")))
Expect(p.remoteAddr).To(Equal(sentFrom)) Expect(p.remoteAddr).To(Equal(sentFrom))
Expect(p.info).To(Not(BeNil())) Expect(p.info).To(Not(BeNil()))
Expect(p.info.addr).To(Equal(ip)) Expect(net.IP(p.info.addr.AsSlice())).To(Equal(ip))
}) })
It("reads packet info on a connection that supports both IPv4 and IPv6", func() { It("reads packet info on a connection that supports both IPv4 and IPv6", func() {
@ -182,14 +182,16 @@ var _ = Describe("OOB Conn Test", func() {
port := conn.LocalAddr().(*net.UDPAddr).Port port := conn.LocalAddr().(*net.UDPAddr).Port
// IPv4 // IPv4
ip4 := net.ParseIP("127.0.0.1").To4() ip4 := net.ParseIP("127.0.0.1")
sendPacket("udp4", &net.UDPAddr{IP: ip4, Port: port}) sendPacket("udp4", &net.UDPAddr{IP: ip4, Port: port})
var p receivedPacket var p receivedPacket
Eventually(packetChan).Should(Receive(&p)) Eventually(packetChan).Should(Receive(&p))
Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue()) Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue())
Expect(p.info).To(Not(BeNil())) Expect(p.info).To(Not(BeNil()))
Expect(p.info.addr.To4()).To(Equal(ip4)) Expect(p.info.addr.Is4In6() || p.info.addr.Is4()).To(BeTrue())
ip := p.info.addr.As4()
Expect(net.IP(ip[:])).To(Equal(ip4.To4()))
// IPv6 // IPv6
ip6 := net.ParseIP("::1") ip6 := net.ParseIP("::1")
@ -198,7 +200,7 @@ var _ = Describe("OOB Conn Test", func() {
Eventually(packetChan).Should(Receive(&p)) Eventually(packetChan).Should(Receive(&p))
Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse()) Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse())
Expect(p.info).To(Not(BeNil())) Expect(p.info).To(Not(BeNil()))
Expect(p.info.addr).To(Equal(ip6)) Expect(net.IP(p.info.addr.AsSlice())).To(Equal(ip6))
}) })
}) })

View file

@ -3,7 +3,7 @@
package quic package quic
import ( import (
"net" "net/netip"
"syscall" "syscall"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@ -36,7 +36,7 @@ func inspectWriteBuffer(c syscall.RawConn) (int, error) {
} }
type packetInfo struct { type packetInfo struct {
addr net.IP addr netip.Addr
} }
func (i *packetInfo) OOB() []byte { return nil } func (i *packetInfo) OOB() []byte { return nil }

View file

@ -155,7 +155,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
if t.isSingleUse { if t.isSingleUse {
onClose = func() { t.Close() } onClose = func() { t.Close() }
} }
return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false) return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
} }
// DialEarly dials a new connection, attempting to use 0-RTT if possible. // DialEarly dials a new connection, attempting to use 0-RTT if possible.
@ -171,7 +171,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
if t.isSingleUse { if t.isSingleUse {
onClose = func() { t.Close() } onClose = func() { t.Close() }
} }
return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
} }
func (t *Transport) init(isServer bool) error { func (t *Transport) init(isServer bool) error {