diff --git a/monitor_windows.go b/monitor_windows.go index 1b9333e..cc9d8af 100644 --- a/monitor_windows.go +++ b/monitor_windows.go @@ -4,7 +4,6 @@ import ( "sync" "github.com/sagernet/sing-tun/internal/winipcfg" - E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/x/list" @@ -14,7 +13,6 @@ import ( type networkUpdateMonitor struct { routeListener *winipcfg.RouteChangeCallback interfaceListener *winipcfg.InterfaceChangeCallback - errorHandler E.Handler access sync.Mutex callbacks list.List[NetworkUpdateCallback] diff --git a/redirect.go b/redirect.go index 189f09e..0569eb3 100644 --- a/redirect.go +++ b/redirect.go @@ -5,6 +5,7 @@ import ( "github.com/sagernet/sing/common/control" "github.com/sagernet/sing/common/logger" + N "github.com/sagernet/sing/common/network" "go4.org/netipx" ) @@ -23,7 +24,7 @@ type AutoRedirect interface { type AutoRedirectOptions struct { TunOptions *Options Context context.Context - Handler Handler + Handler N.TCPConnectionHandlerEx Logger logger.Logger NetworkMonitor NetworkUpdateMonitor InterfaceFinder control.InterfaceFinder diff --git a/redirect_linux.go b/redirect_linux.go index 6c3706e..1645b85 100644 --- a/redirect_linux.go +++ b/redirect_linux.go @@ -13,6 +13,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" "go4.org/netipx" @@ -21,7 +22,7 @@ import ( type autoRedirect struct { tunOptions *Options ctx context.Context - handler Handler + handler N.TCPConnectionHandlerEx logger logger.Logger tableName string networkMonitor NetworkUpdateMonitor diff --git a/redirect_server.go b/redirect_server.go index 7727cba..86abfd8 100644 --- a/redirect_server.go +++ b/redirect_server.go @@ -14,20 +14,19 @@ import ( E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) -const ProtocolRedirect = "redirect" - type redirectServer struct { ctx context.Context - handler Handler + handler N.TCPConnectionHandlerEx logger logger.Logger listenAddr netip.Addr listener *net.TCPListener inShutdown atomic.Bool } -func newRedirectServer(ctx context.Context, handler Handler, logger logger.Logger, listenAddr netip.Addr) *redirectServer { +func newRedirectServer(ctx context.Context, handler N.TCPConnectionHandlerEx, logger logger.Logger, listenAddr netip.Addr) *redirectServer { return &redirectServer{ ctx: ctx, handler: handler, @@ -59,7 +58,6 @@ func (s *redirectServer) loopIn() { conn, err := s.listener.AcceptTCP() if err != nil { var netError net.Error - //goland:noinspection GoDeprecation //nolint:staticcheck if errors.As(err, &netError) && netError.Temporary() { s.logger.Error(err) @@ -72,17 +70,14 @@ func (s *redirectServer) loopIn() { s.logger.Error("serve error: ", err) continue } - var metadata M.Metadata - metadata.Protocol = ProtocolRedirect - metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap() + source := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap() destination, err := control.GetOriginalDestination(conn) if err != nil { _ = conn.SetLinger(0) _ = conn.Close() - s.logger.Error("process connection from ", metadata.Source, ": invalid connection: ", err) + s.logger.Error("process redirect connection from ", source, ": invalid connection: ", err) continue } - metadata.Destination = M.SocksaddrFromNetIP(destination).Unwrap() - go s.handler.NewConnection(s.ctx, conn, metadata) + go s.handler.NewConnectionEx(s.ctx, conn, source, M.SocksaddrFromNetIP(destination).Unwrap(), nil) } } diff --git a/stack.go b/stack.go index 4e61f8b..63727b6 100644 --- a/stack.go +++ b/stack.go @@ -11,6 +11,8 @@ import ( "github.com/sagernet/sing/common/logger" ) +var ErrDrop = E.New("drop connections by rule") + type Stack interface { Start() error Close() error diff --git a/stack_gvisor.go b/stack_gvisor.go index fcdeeba..5044729 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -76,17 +76,16 @@ func (t *GVisor) Start() error { return err } tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) { - var metadata M.Metadata - metadata.Source = M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) - metadata.Destination = M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) + source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) + destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) conn := &gLazyConn{ parentCtx: t.ctx, stack: t.stack, request: r, - localAddr: metadata.Source.TCPAddr(), - remoteAddr: metadata.Destination.TCPAddr(), + localAddr: source.TCPAddr(), + remoteAddr: destination.TCPAddr(), } - _ = t.handler.NewConnection(t.ctx, conn, metadata) + go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil) }) ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) if !t.endpointIndependentNat { @@ -104,14 +103,10 @@ func (t *GVisor) Start() error { return } go func() { - var metadata M.Metadata - metadata.Source = M.SocksaddrFromNet(lAddr) - metadata.Destination = M.SocksaddrFromNet(rAddr) - ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, metadata.Destination), time.Duration(t.udpTimeout)*time.Second) - hErr := t.handler.NewPacketConnection(ctx, conn, metadata) - if hErr != nil { - endpoint.Abort() - } + source := M.SocksaddrFromNet(lAddr) + destination := M.SocksaddrFromNet(rAddr) + ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, destination), time.Duration(t.udpTimeout)*time.Second) + t.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil) }() }) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) diff --git a/stack_gvisor_lazy.go b/stack_gvisor_lazy.go index 18312a9..26f9244 100644 --- a/stack_gvisor_lazy.go +++ b/stack_gvisor_lazy.go @@ -71,8 +71,7 @@ func (c *gLazyConn) HandshakeFailure(err error) error { if c.handshakeDone { return nil } - wErr := gWriteUnreachable(c.stack, c.request.Packet(), err) - c.request.Complete(wErr == os.ErrInvalid) + c.request.Complete(gWriteUnreachable(c.stack, c.request.Packet(), err) == os.ErrInvalid) c.handshakeDone = true c.handshakeErr = err return nil @@ -196,9 +195,11 @@ func (c *gLazyConn) Upstream() any { } func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) error { - if errors.Is(err, syscall.ENETUNREACH) { + if errors.Is(err, ErrDrop) { + return nil + } else if errors.Is(err, syscall.ENETUNREACH) { if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { - return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable) + return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetProhibited) } else { return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) } diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index a97eff4..99662c4 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -36,15 +36,14 @@ func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, u return &UDPForwarder{ ctx: ctx, stack: stack, - udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler), + udpNat: udpnat.NewEx[netip.AddrPort](udpTimeout, handler), } } func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - var upstreamMetadata M.Metadata - upstreamMetadata.Source = M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort) - upstreamMetadata.Destination = M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort) - if upstreamMetadata.Source.IsIPv4() { + source := M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort) + destination := M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort) + if source.IsIPv4() { f.cacheProto = header.IPv4ProtocolNumber } else { f.cacheProto = header.IPv6ProtocolNumber @@ -55,11 +54,12 @@ func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac sBuffer.Write(view.AsSlice()) }) f.cacheID = id - f.udpNat.NewPacket( + f.udpNat.NewPacketEx( f.ctx, - upstreamMetadata.Source.AddrPort(), + source.AddrPort(), sBuffer, - upstreamMetadata, + source, + destination, f.newUDPConn, ) return true diff --git a/stack_mixed.go b/stack_mixed.go index 4872c1f..8e1ab8a 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -64,14 +64,10 @@ func (m *Mixed) Start() error { return } go func() { - var metadata M.Metadata - metadata.Source = M.SocksaddrFromNet(lAddr) - metadata.Destination = M.SocksaddrFromNet(rAddr) - ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, metadata.Destination), time.Duration(m.udpTimeout)*time.Second) - hErr := m.handler.NewPacketConnection(ctx, conn, metadata) - if hErr != nil { - endpoint.Abort() - } + source := M.SocksaddrFromNet(lAddr) + destination := M.SocksaddrFromNet(rAddr) + ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, destination), time.Duration(m.udpTimeout)*time.Second) + m.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil) }() }) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) diff --git a/stack_system.go b/stack_system.go index 873b08a..855a4af 100644 --- a/stack_system.go +++ b/stack_system.go @@ -152,7 +152,7 @@ func (s *System) start() error { go s.acceptLoop(tcpListener) } s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout)) - s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler) + s.udpNat = udpnat.NewEx[netip.AddrPort](s.udpTimeout, s.handler) return nil } @@ -300,16 +300,7 @@ func (s *System) acceptLoop(listener net.Listener) { } } } - go func() { - _ = s.handler.NewConnection(s.ctx, conn, M.Metadata{ - Source: M.SocksaddrFromNetIP(session.Source), - Destination: destination, - }) - if tcpConn, isTCPConn := conn.(*net.TCPConn); isTCPConn { - _ = tcpConn.SetLinger(0) - } - _ = conn.Close() - }() + go s.handler.NewConnectionEx(s.ctx, conn, M.SocksaddrFromNet(conn.RemoteAddr()), destination, nil) } } @@ -427,11 +418,7 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip. if data.Len() == 0 { return nil } - metadata := M.Metadata{ - Source: M.SocksaddrFromNetIP(source), - Destination: M.SocksaddrFromNetIP(destination), - } - s.udpNat.NewPacket(s.ctx, source, data.ToOwned(), metadata, func(natConn N.PacketConn) N.PacketWriter { + s.udpNat.NewPacketEx(s.ctx, source, data.ToOwned(), M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), func(natConn N.PacketConn) N.PacketWriter { headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) @@ -459,11 +446,7 @@ func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip. if data.Len() == 0 { return nil } - metadata := M.Metadata{ - Source: M.SocksaddrFromNetIP(source), - Destination: M.SocksaddrFromNetIP(destination), - } - s.udpNat.NewPacket(s.ctx, source, data.ToOwned(), metadata, func(natConn N.PacketConn) N.PacketWriter { + s.udpNat.NewPacketEx(s.ctx, source, data.ToOwned(), M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), func(natConn N.PacketConn) N.PacketWriter { headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) diff --git a/tun.go b/tun.go index 5c0663a..683955a 100644 --- a/tun.go +++ b/tun.go @@ -8,7 +8,6 @@ import ( "strconv" "strings" - E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/logger" N "github.com/sagernet/sing/common/network" @@ -16,9 +15,8 @@ import ( ) type Handler interface { - N.TCPConnectionHandler - N.UDPConnectionHandler - E.Handler + N.TCPConnectionHandlerEx + N.UDPConnectionHandlerEx } type Tun interface { diff --git a/tun_rules.go b/tun_rules.go index cf9346b..93b0430 100644 --- a/tun_rules.go +++ b/tun_rules.go @@ -1,7 +1,6 @@ package tun import ( - "context" "net/netip" "os" "runtime" @@ -20,7 +19,7 @@ const ( userEnd uint32 = 0xFFFFFFFF - 1 ) -func (o *Options) BuildAndroidRules(packageManager PackageManager, errorHandler E.Handler) { +func (o *Options) BuildAndroidRules(packageManager PackageManager) { var includeUser []uint32 if len(o.IncludeAndroidUser) > 0 { o.IncludeAndroidUser = common.Uniq(o.IncludeAndroidUser) @@ -64,7 +63,9 @@ func (o *Options) BuildAndroidRules(packageManager PackageManager, errorHandler } continue } - errorHandler.NewError(context.Background(), E.New("package to include not found: ", packageName)) + if o.Logger != nil { + o.Logger.Debug("package to include not found: ", packageName) + } } } if len(o.ExcludePackage) > 0 { @@ -81,7 +82,9 @@ func (o *Options) BuildAndroidRules(packageManager PackageManager, errorHandler } continue } - errorHandler.NewError(context.Background(), E.New("package to exclude not found: ", packageName)) + if o.Logger != nil { + o.Logger.Debug("package to exclude not found: ", packageName) + } } } }