From b6d323004eddb292e885027b70ba8882133b1a89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 22 Sep 2023 11:28:50 +0800 Subject: [PATCH] Remove use of Write Unreachable as SendRejectionError panics when passing invalid packet --- stack_gvisor.go | 3 +-- stack_gvisor_udp.go | 50 +++------------------------------------------ stack_mixed.go | 3 +-- 3 files changed, 5 insertions(+), 51 deletions(-) diff --git a/stack_gvisor.go b/stack_gvisor.go index 8e219e7..51c9179 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -6,7 +6,6 @@ import ( "context" "net/netip" "time" - "unsafe" "github.com/sagernet/gvisor/pkg/tcpip" "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" @@ -129,7 +128,7 @@ func (t *GVisor) Start() error { endpoint.Abort() return } - gConn := &gUDPConn{UDPConn: udpConn, stack: ipStack, packet: (*gRequest)(unsafe.Pointer(request)).pkt.IncRef()} + gConn := &gUDPConn{UDPConn: udpConn} go func() { var metadata M.Metadata metadata.Source = M.SocksaddrFromNet(lAddr) diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index d29fa46..ce4648b 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -30,9 +30,8 @@ type UDPForwarder struct { udpNat *udpnat.Service[netip.AddrPort] // cache - cacheProto tcpip.NetworkProtocolNumber - cacheID stack.TransportEndpointID - cachePacket stack.PacketBufferPtr + cacheProto tcpip.NetworkProtocolNumber + cacheID stack.TransportEndpointID } func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder { @@ -58,7 +57,6 @@ func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt stack.Pack sBuffer.Write(view.AsSlice()) }) f.cacheID = id - f.cachePacket = pkt f.udpNat.NewPacket( f.ctx, upstreamMetadata.Source.AddrPort(), @@ -75,7 +73,6 @@ func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter { source: f.cacheID.RemoteAddress, sourcePort: f.cacheID.RemotePort, sourceNetwork: f.cacheProto, - packet: f.cachePacket.IncRef(), } } @@ -88,17 +85,6 @@ type UDPBackWriter struct { packet stack.PacketBufferPtr } -func (w *UDPBackWriter) Close() error { - w.access.Lock() - defer w.access.Unlock() - if w.packet == nil { - return os.ErrClosed - } - w.packet.DecRef() - w.packet = nil - return nil -} - func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error { if !destination.IsIP() { return E.Cause(os.ErrInvalid, "invalid destination") @@ -163,16 +149,6 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock return nil } -func (w *UDPBackWriter) HandshakeFailure(err error) error { - if w.packet == nil { - return os.ErrClosed - } - err = gWriteUnreachable(w.stack, w.packet, err) - w.packet.DecRef() - w.packet = nil - return err -} - type gRequest struct { stack *stack.Stack id stack.TransportEndpointID @@ -181,9 +157,6 @@ type gRequest struct { type gUDPConn struct { *gonet.UDPConn - access sync.Mutex - stack *stack.Stack - packet stack.PacketBufferPtr } func (c *gUDPConn) Read(b []byte) (n int, err error) { @@ -205,27 +178,10 @@ func (c *gUDPConn) Write(b []byte) (n int, err error) { } func (c *gUDPConn) Close() error { - c.access.Lock() - defer c.access.Unlock() - if c.packet == nil { - return os.ErrClosed - } - c.packet.DecRef() - c.packet = nil return c.UDPConn.Close() } -func (c *gUDPConn) HandshakeFailure(err error) error { - if c.packet == nil { - return os.ErrClosed - } - err = gWriteUnreachable(c.stack, c.packet, err) - c.packet.DecRef() - c.packet = nil - return err -} - -func gWriteUnreachable(gStack *stack.Stack, packet stack.PacketBufferPtr, err error) error { +func gWriteUnreachable(gStack *stack.Stack, packet stack.PacketBufferPtr, err error) (retErr error) { if errors.Is(err, syscall.ENETUNREACH) { if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable) diff --git a/stack_mixed.go b/stack_mixed.go index f38c632..6d186e7 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -4,7 +4,6 @@ package tun import ( "time" - "unsafe" "github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" @@ -68,7 +67,7 @@ func (m *Mixed) Start() error { endpoint.Abort() return } - gConn := &gUDPConn{UDPConn: udpConn, stack: ipStack, packet: (*gRequest)(unsafe.Pointer(request)).pkt.IncRef()} + gConn := &gUDPConn{UDPConn: udpConn} go func() { var metadata M.Metadata metadata.Source = M.SocksaddrFromNet(lAddr)