diff --git a/closed_conn.go b/closed_conn.go index 901bb8ae..0c988b53 100644 --- a/closed_conn.go +++ b/closed_conn.go @@ -16,13 +16,13 @@ type closedLocalConn struct { perspective protocol.Perspective logger utils.Logger - sendPacket func(net.Addr, *packetInfo) + sendPacket func(net.Addr, packetInfo) } var _ packetHandler = &closedLocalConn{} // newClosedLocalConn creates a new closedLocalConn and runs it. -func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler { +func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler { return &closedLocalConn{ sendPacket: sendPacket, perspective: pers, diff --git a/closed_conn_test.go b/closed_conn_test.go index 21fddab4..ac9da31e 100644 --- a/closed_conn_test.go +++ b/closed_conn_test.go @@ -21,7 +21,7 @@ var _ = Describe("Closed local connection", func() { It("repeats the packet containing the CONNECTION_CLOSE frame", func() { written := make(chan net.Addr, 1) conn := newClosedLocalConn( - func(addr net.Addr, _ *packetInfo) { written <- addr }, + func(addr net.Addr, _ packetInfo) { written <- addr }, protocol.PerspectiveClient, utils.DefaultLogger, ) diff --git a/connection.go b/connection.go index e081e45f..c4668503 100644 --- a/connection.go +++ b/connection.go @@ -70,7 +70,7 @@ type receivedPacket struct { ecn protocol.ECN - info *packetInfo + info packetInfo // only valid if the contained IP address is valid } func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) } diff --git a/packet_handler_map.go b/packet_handler_map.go index 47f0dcc2..2a16773c 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -39,7 +39,7 @@ type rawConn interface { type closePacket struct { payload []byte addr net.Addr - info *packetInfo + info packetInfo } type unknownPacketHandler interface { @@ -177,7 +177,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p var handler packetHandler if connClosePacket != nil { handler = newClosedLocalConn( - func(addr net.Addr, info *packetInfo) { + func(addr net.Addr, info packetInfo) { h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info}) }, pers, diff --git a/send_conn.go b/send_conn.go index f2bc0bd8..4e7007fa 100644 --- a/send_conn.go +++ b/send_conn.go @@ -21,13 +21,25 @@ type sconn struct { rawConn remoteAddr net.Addr - info *packetInfo + info packetInfo oob []byte } var _ sendConn = &sconn{} -func newSendConn(c rawConn, remote net.Addr, info *packetInfo) *sconn { +func newSendConn(c rawConn, remote net.Addr) *sconn { + sc := &sconn{ + rawConn: c, + remoteAddr: remote, + } + if c.capabilities().GSO { + // add 32 bytes, so we can add the UDP_SEGMENT msg + sc.oob = make([]byte, 0, 32) + } + return sc +} + +func newSendConnWithPacketInfo(c rawConn, remote net.Addr, info packetInfo) *sconn { oob := info.OOB() if c.capabilities().GSO { // add 32 bytes, so we can add the UDP_SEGMENT msg @@ -57,7 +69,7 @@ func (c *sconn) RemoteAddr() net.Addr { func (c *sconn) LocalAddr() net.Addr { addr := c.rawConn.LocalAddr() - if c.info != nil { + if c.info.addr.IsValid() { if udpAddr, ok := addr.(*net.UDPAddr); ok { addrCopy := *udpAddr addrCopy.IP = c.info.addr.AsSlice() diff --git a/send_conn_test.go b/send_conn_test.go index 0b5cc621..56fe9236 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -19,7 +19,7 @@ var _ = Describe("Connection (for sending packets)", func() { packetConn = NewMockPacketConn(mockCtrl) rawConn, err := wrapConn(packetConn) Expect(err).ToNot(HaveOccurred()) - c = newSendConn(rawConn, addr, nil) + c = newSendConnWithPacketInfo(rawConn, addr, packetInfo{}) }) It("writes", func() { diff --git a/server.go b/server.go index 8337a978..0f8219e3 100644 --- a/server.go +++ b/server.go @@ -632,7 +632,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID) } conn = s.newConn( - newSendConn(s.conn, p.remoteAddr, p.info), + newSendConnWithPacketInfo(s.conn, p.remoteAddr, p.info), s.connHandler, origDestConnID, retrySrcConnID, @@ -706,7 +706,7 @@ func (s *baseServer) handleNewConn(conn quicConn) { } } -func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { +func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packetInfo) error { // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the connection. (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) @@ -795,13 +795,13 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) { } } -func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { +func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info packetInfo) error { sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info) } // sendError sends the error as a response to the packet received with header hdr -func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error { +func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info packetInfo) error { b := getPacketBuffer() defer b.Release() diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 5528ef81..fc28f94b 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -173,8 +173,12 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { c.readPos++ data := msg.OOB[:msg.NN] - var ecn protocol.ECN - var info *packetInfo + p := receivedPacket{ + remoteAddr: msg.Addr, + rcvTime: time.Now(), + data: msg.Buffers[0][:msg.N], + buffer: buffer, + } for len(data) > 0 { hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data) if err != nil { @@ -183,7 +187,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { if hdr.Level == unix.IPPROTO_IP { switch hdr.Type { case msgTypeIPTOS: - ecn = protocol.ECN(body[0] & ecnMask) + p.ecn = protocol.ECN(body[0] & ecnMask) case msgTypeIPv4PKTINFO: // struct in_pktinfo { // unsigned int ipi_ifindex; /* Interface index */ @@ -191,22 +195,21 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { // struct in_addr ipi_addr; /* Header Destination // address */ // }; - info = &packetInfo{} var ip [4]byte if len(body) == 12 { copy(ip[:], body[8:12]) - info.ifIndex = binary.LittleEndian.Uint32(body) + p.info.ifIndex = binary.LittleEndian.Uint32(body) } else if len(body) == 4 { // FreeBSD copy(ip[:], body) } - info.addr = netip.AddrFrom4(ip) + p.info.addr = netip.AddrFrom4(ip) } } if hdr.Level == unix.IPPROTO_IPV6 { switch hdr.Type { case unix.IPV6_TCLASS: - ecn = protocol.ECN(body[0] & ecnMask) + p.ecn = protocol.ECN(body[0] & ecnMask) case msgTypeIPv6PKTINFO: // struct in6_pktinfo { // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ @@ -215,23 +218,14 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { if len(body) == 20 { var ip [16]byte copy(ip[:], body[:16]) - info = &packetInfo{ - addr: netip.AddrFrom16(ip), - ifIndex: binary.LittleEndian.Uint32(body[16:]), - } + p.info.addr = netip.AddrFrom16(ip) + p.info.ifIndex = binary.LittleEndian.Uint32(body[16:]) } } } data = remainder } - return receivedPacket{ - remoteAddr: msg.Addr, - rcvTime: time.Now(), - data: msg.Buffers[0][:msg.N], - ecn: ecn, - info: info, - buffer: buffer, - }, nil + return p, nil } // WriteTo (re)implements the net.PacketConn method. diff --git a/sys_conn_oob_test.go b/sys_conn_oob_test.go index f6f864f7..30b333b9 100644 --- a/sys_conn_oob_test.go +++ b/sys_conn_oob_test.go @@ -154,7 +154,7 @@ var _ = Describe("OOB Conn Test", func() { Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) Expect(p.data).To(Equal([]byte("foobar"))) Expect(p.remoteAddr).To(Equal(sentFrom)) - Expect(p.info).To(Not(BeNil())) + Expect(p.info.addr.IsValid()).To(BeTrue()) Expect(net.IP(p.info.addr.AsSlice())).To(Equal(ip)) }) diff --git a/transport.go b/transport.go index cc3d294f..242eedc1 100644 --- a/transport.go +++ b/transport.go @@ -155,7 +155,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config if t.isSingleUse { onClose = func() { t.Close() } } - return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false) + return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false) } // DialEarly dials a new connection, attempting to use 0-RTT if possible. @@ -171,7 +171,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C if t.isSingleUse { onClose = func() { t.Close() } } - return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) + return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) } func (t *Transport) init(isServer bool) error {