diff --git a/adapter/inbound.go b/adapter/inbound.go index ab5fb771..1b4affbb 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -2,11 +2,13 @@ package adapter import ( "context" + "net" "net/netip" "github.com/sagernet/sing-box/common/process" "github.com/sagernet/sing-dns" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) type Inbound interface { @@ -15,6 +17,13 @@ type Inbound interface { Tag() string } +type InjectableInbound interface { + Inbound + Network() []string + NewConnection(ctx context.Context, conn net.Conn, metadata InboundContext) error + NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext) error +} + type InboundContext struct { Inbound string InboundType string @@ -29,6 +38,8 @@ type InboundContext struct { // cache + InboundDetour string + LastInbound string OriginDestination M.Socksaddr DomainStrategy dns.DomainStrategy SniffEnabled bool diff --git a/box.go b/box.go index d8b711de..53607ede 100644 --- a/box.go +++ b/box.go @@ -138,7 +138,7 @@ func New(ctx context.Context, options option.Options) (*Box, error) { } outbounds = append(outbounds, out) } - err = router.Initialize(outbounds, func() adapter.Outbound { + err = router.Initialize(inbounds, outbounds, func() adapter.Outbound { out, oErr := outbound.New(ctx, router, logFactory.NewLogger("outbound/direct"), option.Outbound{Type: "direct", Tag: "default"}) common.Must(oErr) outbounds = append(outbounds, out) diff --git a/inbound/default.go b/inbound/default.go index 87477a4e..76e695cf 100644 --- a/inbound/default.go +++ b/inbound/default.go @@ -3,25 +3,18 @@ package inbound import ( "context" "net" - "net/netip" - "os" "sync" - "time" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/proxyproto" "github.com/sagernet/sing-box/common/settings" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - - "github.com/database64128/tfo-go" ) var _ adapter.Inbound = (*myInboundAdapter)(nil) @@ -62,6 +55,10 @@ func (a *myInboundAdapter) Tag() string { return a.tag } +func (a *myInboundAdapter) Network() []string { + return a.network +} + func (a *myInboundAdapter) Start() error { var err error if common.Contains(a.network, N.NetworkTCP) { @@ -102,38 +99,6 @@ func (a *myInboundAdapter) Start() error { return nil } -func (a *myInboundAdapter) ListenTCP() (net.Listener, error) { - var err error - bindAddr := M.SocksaddrFrom(netip.Addr(a.listenOptions.Listen), a.listenOptions.ListenPort) - var tcpListener net.Listener - if !a.listenOptions.TCPFastOpen { - tcpListener, err = net.ListenTCP(M.NetworkFromNetAddr(N.NetworkTCP, bindAddr.Addr), bindAddr.TCPAddr()) - } else { - tcpListener, err = tfo.ListenTCP(M.NetworkFromNetAddr(N.NetworkTCP, bindAddr.Addr), bindAddr.TCPAddr()) - } - if err == nil { - a.logger.Info("tcp server started at ", tcpListener.Addr()) - } - if a.listenOptions.ProxyProtocol { - a.logger.Debug("proxy protocol enabled") - tcpListener = &proxyproto.Listener{Listener: tcpListener} - } - a.tcpListener = tcpListener - return tcpListener, err -} - -func (a *myInboundAdapter) ListenUDP() (net.PacketConn, error) { - bindAddr := M.SocksaddrFrom(netip.Addr(a.listenOptions.Listen), a.listenOptions.ListenPort) - udpConn, err := net.ListenUDP(M.NetworkFromNetAddr(N.NetworkUDP, bindAddr.Addr), bindAddr.UDPAddr()) - if err != nil { - return nil, err - } - a.udpConn = udpConn - a.udpAddr = bindAddr - a.logger.Info("udp server started at ", udpConn.LocalAddr()) - return udpConn, err -} - func (a *myInboundAdapter) Close() error { var err error if a.clearSystemProxy != nil { @@ -170,20 +135,10 @@ func (a *myInboundAdapter) newPacketConnection(ctx context.Context, conn N.Packe return a.router.RoutePacketConnection(ctx, conn, metadata) } -func (a *myInboundAdapter) loopTCPIn() { - tcpListener := a.tcpListener - for { - conn, err := tcpListener.Accept() - if err != nil { - return - } - go a.injectTCP(conn, adapter.InboundContext{}) - } -} - func (a *myInboundAdapter) createMetadata(conn net.Conn, metadata adapter.InboundContext) adapter.InboundContext { metadata.Inbound = a.tag metadata.InboundType = a.protocol + metadata.InboundDetour = a.listenOptions.Detour metadata.SniffEnabled = a.listenOptions.SniffEnabled metadata.SniffOverrideDestination = a.listenOptions.SniffOverrideDestination metadata.DomainStrategy = dns.DomainStrategy(a.listenOptions.DomainStrategy) @@ -199,166 +154,6 @@ func (a *myInboundAdapter) createMetadata(conn net.Conn, metadata adapter.Inboun return metadata } -func (a *myInboundAdapter) injectTCP(conn net.Conn, metadata adapter.InboundContext) { - ctx := log.ContextWithNewID(a.ctx) - metadata = a.createMetadata(conn, metadata) - a.logger.InfoContext(ctx, "inbound connection from ", metadata.Source) - hErr := a.connHandler.NewConnection(ctx, conn, metadata) - if hErr != nil { - conn.Close() - a.NewError(ctx, E.Cause(hErr, "process connection from ", metadata.Source)) - } -} - -func (a *myInboundAdapter) routeTCP(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) { - a.logger.InfoContext(ctx, "inbound connection from ", metadata.Source) - hErr := a.newConnection(ctx, conn, metadata) - if hErr != nil { - conn.Close() - a.NewError(ctx, E.Cause(hErr, "process connection from ", metadata.Source)) - } -} - -func (a *myInboundAdapter) loopUDPIn() { - defer close(a.packetOutboundClosed) - _buffer := buf.StackNewPacket() - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - buffer.IncRef() - defer buffer.DecRef() - packetService := (*myInboundPacketAdapter)(a) - for { - buffer.Reset() - n, addr, err := a.udpConn.ReadFromUDPAddrPort(buffer.FreeBytes()) - if err != nil { - return - } - buffer.Truncate(n) - var metadata adapter.InboundContext - metadata.Inbound = a.tag - metadata.InboundType = a.protocol - metadata.SniffEnabled = a.listenOptions.SniffEnabled - metadata.SniffOverrideDestination = a.listenOptions.SniffOverrideDestination - metadata.DomainStrategy = dns.DomainStrategy(a.listenOptions.DomainStrategy) - metadata.Source = M.SocksaddrFromNetIP(addr) - metadata.OriginDestination = a.udpAddr - err = a.packetHandler.NewPacket(a.ctx, packetService, buffer, metadata) - if err != nil { - a.newError(E.Cause(err, "process packet from ", metadata.Source)) - } - } -} - -func (a *myInboundAdapter) loopUDPOOBIn() { - defer close(a.packetOutboundClosed) - _buffer := buf.StackNewPacket() - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - buffer.IncRef() - defer buffer.DecRef() - packetService := (*myInboundPacketAdapter)(a) - oob := make([]byte, 1024) - for { - buffer.Reset() - n, oobN, _, addr, err := a.udpConn.ReadMsgUDPAddrPort(buffer.FreeBytes(), oob) - if err != nil { - return - } - buffer.Truncate(n) - var metadata adapter.InboundContext - metadata.Inbound = a.tag - metadata.InboundType = a.protocol - metadata.SniffEnabled = a.listenOptions.SniffEnabled - metadata.SniffOverrideDestination = a.listenOptions.SniffOverrideDestination - metadata.DomainStrategy = dns.DomainStrategy(a.listenOptions.DomainStrategy) - metadata.Source = M.SocksaddrFromNetIP(addr) - metadata.OriginDestination = a.udpAddr - err = a.oobPacketHandler.NewPacket(a.ctx, packetService, buffer, oob[:oobN], metadata) - if err != nil { - a.newError(E.Cause(err, "process packet from ", metadata.Source)) - } - } -} - -func (a *myInboundAdapter) loopUDPInThreadSafe() { - defer close(a.packetOutboundClosed) - packetService := (*myInboundPacketAdapter)(a) - for { - buffer := buf.NewPacket() - n, addr, err := a.udpConn.ReadFromUDPAddrPort(buffer.FreeBytes()) - if err != nil { - buffer.Release() - return - } - buffer.Truncate(n) - var metadata adapter.InboundContext - metadata.Inbound = a.tag - metadata.InboundType = a.protocol - metadata.SniffEnabled = a.listenOptions.SniffEnabled - metadata.SniffOverrideDestination = a.listenOptions.SniffOverrideDestination - metadata.DomainStrategy = dns.DomainStrategy(a.listenOptions.DomainStrategy) - metadata.Source = M.SocksaddrFromNetIP(addr) - metadata.OriginDestination = a.udpAddr - err = a.packetHandler.NewPacket(a.ctx, packetService, buffer, metadata) - if err != nil { - buffer.Release() - a.newError(E.Cause(err, "process packet from ", metadata.Source)) - } - } -} - -func (a *myInboundAdapter) loopUDPOOBInThreadSafe() { - defer close(a.packetOutboundClosed) - packetService := (*myInboundPacketAdapter)(a) - oob := make([]byte, 1024) - for { - buffer := buf.NewPacket() - n, oobN, _, addr, err := a.udpConn.ReadMsgUDPAddrPort(buffer.FreeBytes(), oob) - if err != nil { - buffer.Release() - return - } - buffer.Truncate(n) - var metadata adapter.InboundContext - metadata.Inbound = a.tag - metadata.InboundType = a.protocol - metadata.SniffEnabled = a.listenOptions.SniffEnabled - metadata.SniffOverrideDestination = a.listenOptions.SniffOverrideDestination - metadata.DomainStrategy = dns.DomainStrategy(a.listenOptions.DomainStrategy) - metadata.Source = M.SocksaddrFromNetIP(addr) - metadata.OriginDestination = a.udpAddr - err = a.oobPacketHandler.NewPacket(a.ctx, packetService, buffer, oob[:oobN], metadata) - if err != nil { - buffer.Release() - a.newError(E.Cause(err, "process packet from ", metadata.Source)) - } - } -} - -func (a *myInboundAdapter) loopUDPOut() { - for { - select { - case packet := <-a.packetOutbound: - err := a.writePacket(packet.buffer, packet.destination) - if err != nil && !E.IsClosed(err) { - a.newError(E.New("write back udp: ", err)) - } - continue - case <-a.packetOutboundClosed: - } - for { - select { - case packet := <-a.packetOutbound: - packet.buffer.Release() - default: - return - } - } - } -} - func (a *myInboundAdapter) newError(err error) { a.logger.Error(err) } @@ -375,72 +170,3 @@ func NewError(logger log.ContextLogger, ctx context.Context, err error) { } logger.ErrorContext(ctx, err) } - -func (a *myInboundAdapter) writePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - defer buffer.Release() - if destination.IsFqdn() { - udpAddr, err := net.ResolveUDPAddr(N.NetworkUDP, destination.String()) - if err != nil { - return err - } - return common.Error(a.udpConn.WriteTo(buffer.Bytes(), udpAddr)) - } - return common.Error(a.udpConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort())) -} - -type myInboundPacketAdapter myInboundAdapter - -func (s *myInboundPacketAdapter) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { - n, addr, err := s.udpConn.ReadFromUDPAddrPort(buffer.FreeBytes()) - if err != nil { - return M.Socksaddr{}, err - } - buffer.Truncate(n) - return M.SocksaddrFromNetIP(addr), nil -} - -func (s *myInboundPacketAdapter) WriteIsThreadUnsafe() { -} - -type myInboundPacket struct { - buffer *buf.Buffer - destination M.Socksaddr -} - -func (s *myInboundPacketAdapter) Upstream() any { - return s.udpConn -} - -func (s *myInboundPacketAdapter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - s.packetAccess.RLock() - defer s.packetAccess.RUnlock() - - select { - case <-s.packetOutboundClosed: - return os.ErrClosed - default: - } - - s.packetOutbound <- &myInboundPacket{buffer, destination} - return nil -} - -func (s *myInboundPacketAdapter) Close() error { - return s.udpConn.Close() -} - -func (s *myInboundPacketAdapter) LocalAddr() net.Addr { - return s.udpConn.LocalAddr() -} - -func (s *myInboundPacketAdapter) SetDeadline(t time.Time) error { - return s.udpConn.SetDeadline(t) -} - -func (s *myInboundPacketAdapter) SetReadDeadline(t time.Time) error { - return s.udpConn.SetReadDeadline(t) -} - -func (s *myInboundPacketAdapter) SetWriteDeadline(t time.Time) error { - return s.udpConn.SetWriteDeadline(t) -} diff --git a/inbound/default_tcp.go b/inbound/default_tcp.go new file mode 100644 index 00000000..0d32a1c0 --- /dev/null +++ b/inbound/default_tcp.go @@ -0,0 +1,67 @@ +package inbound + +import ( + "context" + "net" + "net/netip" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/proxyproto" + "github.com/sagernet/sing-box/log" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/database64128/tfo-go" +) + +func (a *myInboundAdapter) ListenTCP() (net.Listener, error) { + var err error + bindAddr := M.SocksaddrFrom(netip.Addr(a.listenOptions.Listen), a.listenOptions.ListenPort) + var tcpListener net.Listener + if !a.listenOptions.TCPFastOpen { + tcpListener, err = net.ListenTCP(M.NetworkFromNetAddr(N.NetworkTCP, bindAddr.Addr), bindAddr.TCPAddr()) + } else { + tcpListener, err = tfo.ListenTCP(M.NetworkFromNetAddr(N.NetworkTCP, bindAddr.Addr), bindAddr.TCPAddr()) + } + if err == nil { + a.logger.Info("tcp server started at ", tcpListener.Addr()) + } + if a.listenOptions.ProxyProtocol { + a.logger.Debug("proxy protocol enabled") + tcpListener = &proxyproto.Listener{Listener: tcpListener} + } + a.tcpListener = tcpListener + return tcpListener, err +} + +func (a *myInboundAdapter) loopTCPIn() { + tcpListener := a.tcpListener + for { + conn, err := tcpListener.Accept() + if err != nil { + return + } + go a.injectTCP(conn, adapter.InboundContext{}) + } +} + +func (a *myInboundAdapter) injectTCP(conn net.Conn, metadata adapter.InboundContext) { + ctx := log.ContextWithNewID(a.ctx) + metadata = a.createMetadata(conn, metadata) + a.logger.InfoContext(ctx, "inbound connection from ", metadata.Source) + hErr := a.connHandler.NewConnection(ctx, conn, metadata) + if hErr != nil { + conn.Close() + a.NewError(ctx, E.Cause(hErr, "process connection from ", metadata.Source)) + } +} + +func (a *myInboundAdapter) routeTCP(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) { + a.logger.InfoContext(ctx, "inbound connection from ", metadata.Source) + hErr := a.newConnection(ctx, conn, metadata) + if hErr != nil { + conn.Close() + a.NewError(ctx, E.Cause(hErr, "process connection from ", metadata.Source)) + } +} diff --git a/inbound/default_udp.go b/inbound/default_udp.go new file mode 100644 index 00000000..35850333 --- /dev/null +++ b/inbound/default_udp.go @@ -0,0 +1,237 @@ +package inbound + +import ( + "net" + "net/netip" + "os" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-dns" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func (a *myInboundAdapter) ListenUDP() (net.PacketConn, error) { + bindAddr := M.SocksaddrFrom(netip.Addr(a.listenOptions.Listen), a.listenOptions.ListenPort) + udpConn, err := net.ListenUDP(M.NetworkFromNetAddr(N.NetworkUDP, bindAddr.Addr), bindAddr.UDPAddr()) + if err != nil { + return nil, err + } + a.udpConn = udpConn + a.udpAddr = bindAddr + a.logger.Info("udp server started at ", udpConn.LocalAddr()) + return udpConn, err +} + +func (a *myInboundAdapter) loopUDPIn() { + defer close(a.packetOutboundClosed) + _buffer := buf.StackNewPacket() + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + buffer.IncRef() + defer buffer.DecRef() + packetService := (*myInboundPacketAdapter)(a) + for { + buffer.Reset() + n, addr, err := a.udpConn.ReadFromUDPAddrPort(buffer.FreeBytes()) + if err != nil { + return + } + buffer.Truncate(n) + var metadata adapter.InboundContext + metadata.Inbound = a.tag + metadata.InboundType = a.protocol + metadata.SniffEnabled = a.listenOptions.SniffEnabled + metadata.SniffOverrideDestination = a.listenOptions.SniffOverrideDestination + metadata.DomainStrategy = dns.DomainStrategy(a.listenOptions.DomainStrategy) + metadata.Source = M.SocksaddrFromNetIP(addr) + metadata.OriginDestination = a.udpAddr + err = a.packetHandler.NewPacket(a.ctx, packetService, buffer, metadata) + if err != nil { + a.newError(E.Cause(err, "process packet from ", metadata.Source)) + } + } +} + +func (a *myInboundAdapter) loopUDPOOBIn() { + defer close(a.packetOutboundClosed) + _buffer := buf.StackNewPacket() + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + buffer.IncRef() + defer buffer.DecRef() + packetService := (*myInboundPacketAdapter)(a) + oob := make([]byte, 1024) + for { + buffer.Reset() + n, oobN, _, addr, err := a.udpConn.ReadMsgUDPAddrPort(buffer.FreeBytes(), oob) + if err != nil { + return + } + buffer.Truncate(n) + var metadata adapter.InboundContext + metadata.Inbound = a.tag + metadata.InboundType = a.protocol + metadata.SniffEnabled = a.listenOptions.SniffEnabled + metadata.SniffOverrideDestination = a.listenOptions.SniffOverrideDestination + metadata.DomainStrategy = dns.DomainStrategy(a.listenOptions.DomainStrategy) + metadata.Source = M.SocksaddrFromNetIP(addr) + metadata.OriginDestination = a.udpAddr + err = a.oobPacketHandler.NewPacket(a.ctx, packetService, buffer, oob[:oobN], metadata) + if err != nil { + a.newError(E.Cause(err, "process packet from ", metadata.Source)) + } + } +} + +func (a *myInboundAdapter) loopUDPInThreadSafe() { + defer close(a.packetOutboundClosed) + packetService := (*myInboundPacketAdapter)(a) + for { + buffer := buf.NewPacket() + n, addr, err := a.udpConn.ReadFromUDPAddrPort(buffer.FreeBytes()) + if err != nil { + buffer.Release() + return + } + buffer.Truncate(n) + var metadata adapter.InboundContext + metadata.Inbound = a.tag + metadata.InboundType = a.protocol + metadata.SniffEnabled = a.listenOptions.SniffEnabled + metadata.SniffOverrideDestination = a.listenOptions.SniffOverrideDestination + metadata.DomainStrategy = dns.DomainStrategy(a.listenOptions.DomainStrategy) + metadata.Source = M.SocksaddrFromNetIP(addr) + metadata.OriginDestination = a.udpAddr + err = a.packetHandler.NewPacket(a.ctx, packetService, buffer, metadata) + if err != nil { + buffer.Release() + a.newError(E.Cause(err, "process packet from ", metadata.Source)) + } + } +} + +func (a *myInboundAdapter) loopUDPOOBInThreadSafe() { + defer close(a.packetOutboundClosed) + packetService := (*myInboundPacketAdapter)(a) + oob := make([]byte, 1024) + for { + buffer := buf.NewPacket() + n, oobN, _, addr, err := a.udpConn.ReadMsgUDPAddrPort(buffer.FreeBytes(), oob) + if err != nil { + buffer.Release() + return + } + buffer.Truncate(n) + var metadata adapter.InboundContext + metadata.Inbound = a.tag + metadata.InboundType = a.protocol + metadata.SniffEnabled = a.listenOptions.SniffEnabled + metadata.SniffOverrideDestination = a.listenOptions.SniffOverrideDestination + metadata.DomainStrategy = dns.DomainStrategy(a.listenOptions.DomainStrategy) + metadata.Source = M.SocksaddrFromNetIP(addr) + metadata.OriginDestination = a.udpAddr + err = a.oobPacketHandler.NewPacket(a.ctx, packetService, buffer, oob[:oobN], metadata) + if err != nil { + buffer.Release() + a.newError(E.Cause(err, "process packet from ", metadata.Source)) + } + } +} + +func (a *myInboundAdapter) loopUDPOut() { + for { + select { + case packet := <-a.packetOutbound: + err := a.writePacket(packet.buffer, packet.destination) + if err != nil && !E.IsClosed(err) { + a.newError(E.New("write back udp: ", err)) + } + continue + case <-a.packetOutboundClosed: + } + for { + select { + case packet := <-a.packetOutbound: + packet.buffer.Release() + default: + return + } + } + } +} + +func (a *myInboundAdapter) writePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + defer buffer.Release() + if destination.IsFqdn() { + udpAddr, err := net.ResolveUDPAddr(N.NetworkUDP, destination.String()) + if err != nil { + return err + } + return common.Error(a.udpConn.WriteTo(buffer.Bytes(), udpAddr)) + } + return common.Error(a.udpConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort())) +} + +type myInboundPacketAdapter myInboundAdapter + +func (s *myInboundPacketAdapter) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + n, addr, err := s.udpConn.ReadFromUDPAddrPort(buffer.FreeBytes()) + if err != nil { + return M.Socksaddr{}, err + } + buffer.Truncate(n) + return M.SocksaddrFromNetIP(addr), nil +} + +func (s *myInboundPacketAdapter) WriteIsThreadUnsafe() { +} + +type myInboundPacket struct { + buffer *buf.Buffer + destination M.Socksaddr +} + +func (s *myInboundPacketAdapter) Upstream() any { + return s.udpConn +} + +func (s *myInboundPacketAdapter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + s.packetAccess.RLock() + defer s.packetAccess.RUnlock() + + select { + case <-s.packetOutboundClosed: + return os.ErrClosed + default: + } + + s.packetOutbound <- &myInboundPacket{buffer, destination} + return nil +} + +func (s *myInboundPacketAdapter) Close() error { + return s.udpConn.Close() +} + +func (s *myInboundPacketAdapter) LocalAddr() net.Addr { + return s.udpConn.LocalAddr() +} + +func (s *myInboundPacketAdapter) SetDeadline(t time.Time) error { + return s.udpConn.SetDeadline(t) +} + +func (s *myInboundPacketAdapter) SetReadDeadline(t time.Time) error { + return s.udpConn.SetReadDeadline(t) +} + +func (s *myInboundPacketAdapter) SetWriteDeadline(t time.Time) error { + return s.udpConn.SetWriteDeadline(t) +} diff --git a/inbound/http.go b/inbound/http.go index c6e77b15..4a54fccd 100644 --- a/inbound/http.go +++ b/inbound/http.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "net" + "os" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" @@ -17,7 +18,10 @@ import ( "github.com/sagernet/sing/protocol/http" ) -var _ adapter.Inbound = (*HTTP)(nil) +var ( + _ adapter.Inbound = (*HTTP)(nil) + _ adapter.InjectableInbound = (*HTTP)(nil) +) type HTTP struct { myInboundAdapter @@ -74,6 +78,10 @@ func (h *HTTP) NewConnection(ctx context.Context, conn net.Conn, metadata adapte return http.HandleConnection(ctx, conn, std_bufio.NewReader(conn), h.authenticator, h.upstreamUserHandler(metadata), adapter.UpstreamMetadata(metadata)) } +func (h *HTTP) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return os.ErrInvalid +} + func (a *myInboundAdapter) upstreamUserHandler(metadata adapter.InboundContext) adapter.UpstreamHandlerAdapter { return adapter.NewUpstreamHandler(metadata, a.newUserConnection, a.streamUserPacketConnection, a) } diff --git a/inbound/mixed.go b/inbound/mixed.go index b7f5b3dc..fceefe4d 100644 --- a/inbound/mixed.go +++ b/inbound/mixed.go @@ -4,6 +4,7 @@ import ( std_bufio "bufio" "context" "net" + "os" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" @@ -20,7 +21,10 @@ import ( "github.com/sagernet/sing/protocol/socks/socks5" ) -var _ adapter.Inbound = (*Mixed)(nil) +var ( + _ adapter.Inbound = (*Mixed)(nil) + _ adapter.InjectableInbound = (*Mixed)(nil) +) type Mixed struct { myInboundAdapter @@ -57,3 +61,7 @@ func (h *Mixed) NewConnection(ctx context.Context, conn net.Conn, metadata adapt reader := std_bufio.NewReader(bufio.NewCachedReader(conn, buf.As([]byte{headerType}))) return http.HandleConnection(ctx, conn, reader, h.authenticator, h.upstreamUserHandler(metadata), adapter.UpstreamMetadata(metadata)) } + +func (h *Mixed) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return os.ErrInvalid +} diff --git a/inbound/shadowsocks.go b/inbound/shadowsocks.go index 0c26392c..f36f40cc 100644 --- a/inbound/shadowsocks.go +++ b/inbound/shadowsocks.go @@ -3,6 +3,7 @@ package inbound import ( "context" "net" + "os" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" @@ -30,7 +31,10 @@ func NewShadowsocks(ctx context.Context, router adapter.Router, logger log.Conte } } -var _ adapter.Inbound = (*Shadowsocks)(nil) +var ( + _ adapter.Inbound = (*Shadowsocks)(nil) + _ adapter.InjectableInbound = (*Shadowsocks)(nil) +) type Shadowsocks struct { myInboundAdapter @@ -80,6 +84,10 @@ func (h *Shadowsocks) NewPacket(ctx context.Context, conn N.PacketConn, buffer * return h.service.NewPacket(adapter.WithContext(ctx, &metadata), conn, buffer, adapter.UpstreamMetadata(metadata)) } +func (h *Shadowsocks) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return os.ErrInvalid +} + func (h *Shadowsocks) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) return h.router.RouteConnection(ctx, conn, metadata) diff --git a/inbound/shadowsocks_multi.go b/inbound/shadowsocks_multi.go index 85743f8e..242a8ccf 100644 --- a/inbound/shadowsocks_multi.go +++ b/inbound/shadowsocks_multi.go @@ -21,7 +21,10 @@ import ( N "github.com/sagernet/sing/common/network" ) -var _ adapter.Inbound = (*ShadowsocksMulti)(nil) +var ( + _ adapter.Inbound = (*ShadowsocksMulti)(nil) + _ adapter.InjectableInbound = (*ShadowsocksMulti)(nil) +) type ShadowsocksMulti struct { myInboundAdapter @@ -114,6 +117,10 @@ func (h *ShadowsocksMulti) NewPacket(ctx context.Context, conn N.PacketConn, buf return h.service.NewPacket(adapter.WithContext(ctx, &metadata), conn, buffer, adapter.UpstreamMetadata(metadata)) } +func (h *ShadowsocksMulti) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return os.ErrInvalid +} + func (h *ShadowsocksMulti) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { userIndex, loaded := auth.UserFromContext[int](ctx) if !loaded { diff --git a/inbound/shadowsocks_relay.go b/inbound/shadowsocks_relay.go index 22c67cf0..2f624447 100644 --- a/inbound/shadowsocks_relay.go +++ b/inbound/shadowsocks_relay.go @@ -17,7 +17,10 @@ import ( N "github.com/sagernet/sing/common/network" ) -var _ adapter.Inbound = (*ShadowsocksMulti)(nil) +var ( + _ adapter.Inbound = (*ShadowsocksRelay)(nil) + _ adapter.InjectableInbound = (*ShadowsocksRelay)(nil) +) type ShadowsocksRelay struct { myInboundAdapter @@ -76,6 +79,10 @@ func (h *ShadowsocksRelay) NewPacket(ctx context.Context, conn N.PacketConn, buf return h.service.NewPacket(adapter.WithContext(ctx, &metadata), conn, buffer, adapter.UpstreamMetadata(metadata)) } +func (h *ShadowsocksRelay) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return os.ErrInvalid +} + func (h *ShadowsocksRelay) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { destinationIndex, loaded := auth.UserFromContext[int](ctx) if !loaded { diff --git a/inbound/socks.go b/inbound/socks.go index d23ce5b9..7471341f 100644 --- a/inbound/socks.go +++ b/inbound/socks.go @@ -3,6 +3,7 @@ package inbound import ( "context" "net" + "os" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" @@ -13,7 +14,10 @@ import ( "github.com/sagernet/sing/protocol/socks" ) -var _ adapter.Inbound = (*Socks)(nil) +var ( + _ adapter.Inbound = (*Socks)(nil) + _ adapter.InjectableInbound = (*Socks)(nil) +) type Socks struct { myInboundAdapter @@ -40,3 +44,7 @@ func NewSocks(ctx context.Context, router adapter.Router, logger log.ContextLogg func (h *Socks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { return socks.HandleConnection(ctx, conn, h.authenticator, h.upstreamUserHandler(metadata), adapter.UpstreamMetadata(metadata)) } + +func (h *Socks) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return os.ErrInvalid +} diff --git a/inbound/trojan.go b/inbound/trojan.go index 30a3a484..2160bc89 100644 --- a/inbound/trojan.go +++ b/inbound/trojan.go @@ -20,7 +20,10 @@ import ( "github.com/sagernet/sing/protocol/trojan" ) -var _ adapter.Inbound = (*Trojan)(nil) +var ( + _ adapter.Inbound = (*Trojan)(nil) + _ adapter.InjectableInbound = (*Trojan)(nil) +) type Trojan struct { myInboundAdapter @@ -157,6 +160,10 @@ func (h *Trojan) NewConnection(ctx context.Context, conn net.Conn, metadata adap return h.service.NewConnection(adapter.WithContext(log.ContextWithNewID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata)) } +func (h *Trojan) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return os.ErrInvalid +} + func (h *Trojan) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { userIndex, loaded := auth.UserFromContext[int](ctx) if !loaded { diff --git a/inbound/vmess.go b/inbound/vmess.go index 4d5ce743..255d21d6 100644 --- a/inbound/vmess.go +++ b/inbound/vmess.go @@ -21,7 +21,10 @@ import ( N "github.com/sagernet/sing/common/network" ) -var _ adapter.Inbound = (*VMess)(nil) +var ( + _ adapter.Inbound = (*VMess)(nil) + _ adapter.InjectableInbound = (*VMess)(nil) +) type VMess struct { myInboundAdapter @@ -137,6 +140,10 @@ func (h *VMess) NewConnection(ctx context.Context, conn net.Conn, metadata adapt return h.service.NewConnection(adapter.WithContext(log.ContextWithNewID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata)) } +func (h *VMess) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return os.ErrInvalid +} + func (h *VMess) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { userIndex, loaded := auth.UserFromContext[int](ctx) if !loaded { diff --git a/option/inbound.go b/option/inbound.go index 5a5ae9e1..80092d75 100644 --- a/option/inbound.go +++ b/option/inbound.go @@ -111,5 +111,6 @@ type ListenOptions struct { TCPFastOpen bool `json:"tcp_fast_open,omitempty"` UDPTimeout int64 `json:"udp_timeout,omitempty"` ProxyProtocol bool `json:"proxy_protocol,omitempty"` + Detour string `json:"detour,omitempty"` InboundOptions } diff --git a/route/router.go b/route/router.go index 4cebb51a..42c3fadf 100644 --- a/route/router.go +++ b/route/router.go @@ -65,6 +65,7 @@ type Router struct { ctx context.Context logger log.ContextLogger dnsLogger log.ContextLogger + inboundByTag map[string]adapter.Inbound outbounds []adapter.Outbound outboundByTag map[string]adapter.Outbound rules []adapter.Rule @@ -295,7 +296,11 @@ func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.Cont return router, nil } -func (r *Router) Initialize(outbounds []adapter.Outbound, defaultOutbound func() adapter.Outbound) error { +func (r *Router) Initialize(inbounds []adapter.Inbound, outbounds []adapter.Outbound, defaultOutbound func() adapter.Outbound) error { + inboundByTag := make(map[string]adapter.Inbound) + for _, inbound := range inbounds { + inboundByTag[inbound.Tag()] = inbound + } outboundByTag := make(map[string]adapter.Outbound) for _, detour := range outbounds { outboundByTag[detour.Tag()] = detour @@ -360,6 +365,7 @@ func (r *Router) Initialize(outbounds []adapter.Outbound, defaultOutbound func() r.logger.Info("using ", defaultOutboundForConnection.Type(), "[", description, "] as default outbound for connection") r.logger.Info("using ", defaultOutboundForPacketConnection.Type(), "[", packetDescription, "] as default outbound for packet connection") } + r.inboundByTag = inboundByTag r.outbounds = outbounds r.defaultOutboundForConnection = defaultOutboundForConnection r.defaultOutboundForPacketConnection = defaultOutboundForPacketConnection @@ -498,6 +504,29 @@ func (r *Router) DefaultOutbound(network string) adapter.Outbound { } func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + if metadata.InboundDetour != "" { + if metadata.LastInbound == metadata.InboundDetour { + return E.New("routing loop on detour: ", metadata.InboundDetour) + } + detour := r.inboundByTag[metadata.InboundDetour] + if detour == nil { + return E.New("inbound detour not found: ", metadata.InboundDetour) + } + injectable, isInjectable := detour.(adapter.InjectableInbound) + if !isInjectable { + return E.New("inbound detour is not injectable: ", metadata.InboundDetour) + } + if !common.Contains(injectable.Network(), N.NetworkTCP) { + return E.New("inject: TCP unsupported") + } + metadata.InboundDetour = "" + metadata.LastInbound = metadata.Inbound + err := injectable.NewConnection(ctx, conn, metadata) + if err != nil { + return E.Cause(err, "inject ", detour.Tag()) + } + return nil + } metadata.Network = N.NetworkTCP switch metadata.Destination.Fqdn { case mux.Destination.Fqdn: @@ -555,6 +584,29 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad } func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + if metadata.InboundDetour != "" { + if metadata.LastInbound == metadata.InboundDetour { + return E.New("routing loop on detour: ", metadata.InboundDetour) + } + detour := r.inboundByTag[metadata.InboundDetour] + if detour == nil { + return E.New("inbound detour not found: ", metadata.InboundDetour) + } + injectable, isInjectable := detour.(adapter.InjectableInbound) + if !isInjectable { + return E.New("inbound detour is not injectable: ", metadata.InboundDetour) + } + if !common.Contains(injectable.Network(), N.NetworkUDP) { + return E.New("inject: UDP unsupported") + } + metadata.InboundDetour = "" + metadata.LastInbound = metadata.Inbound + err := injectable.NewPacketConnection(ctx, conn, metadata) + if err != nil { + return E.Cause(err, "inject ", detour.Tag()) + } + return nil + } metadata.Network = N.NetworkUDP if metadata.SniffEnabled { buffer := buf.NewPacket() diff --git a/test/inbound_detour_test.go b/test/inbound_detour_test.go new file mode 100644 index 00000000..e96b77bb --- /dev/null +++ b/test/inbound_detour_test.go @@ -0,0 +1,101 @@ +package main + +import ( + "net/netip" + "testing" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-shadowsocks/shadowaead_2022" +) + +func TestChainedInbound(t *testing.T) { + method := shadowaead_2022.List[0] + password := mkBase64(t, 16) + startInstance(t, option.Options{ + Log: &option.LogOptions{ + Level: "error", + }, + Inbounds: []option.Inbound{ + { + Type: C.TypeMixed, + Tag: "mixed-in", + MixedOptions: option.HTTPMixedInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.ListenAddress(netip.IPv4Unspecified()), + ListenPort: clientPort, + }, + }, + }, + { + Type: C.TypeShadowsocks, + ShadowsocksOptions: option.ShadowsocksInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.ListenAddress(netip.IPv4Unspecified()), + ListenPort: serverPort, + Detour: "detour", + }, + Method: method, + Password: password, + }, + }, + { + Type: C.TypeShadowsocks, + Tag: "detour", + ShadowsocksOptions: option.ShadowsocksInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.ListenAddress(netip.IPv4Unspecified()), + ListenPort: otherPort, + }, + Method: method, + Password: password, + }, + }, + }, + Outbounds: []option.Outbound{ + { + Type: C.TypeDirect, + }, + { + Type: C.TypeShadowsocks, + Tag: "ss-out", + ShadowsocksOptions: option.ShadowsocksOutboundOptions{ + ServerOptions: option.ServerOptions{ + Server: "127.0.0.1", + ServerPort: serverPort, + }, + Method: method, + Password: password, + OutboundDialerOptions: option.OutboundDialerOptions{ + DialerOptions: option.DialerOptions{ + Detour: "detour-out", + }, + }, + }, + }, + { + Type: C.TypeShadowsocks, + Tag: "detour-out", + ShadowsocksOptions: option.ShadowsocksOutboundOptions{ + ServerOptions: option.ServerOptions{ + Server: "127.0.0.1", + ServerPort: serverPort, + }, + Method: method, + Password: password, + }, + }, + }, + Route: &option.RouteOptions{ + Rules: []option.Rule{ + { + DefaultOptions: option.DefaultRule{ + Inbound: []string{"mixed-in"}, + Outbound: "ss-out", + }, + }, + }, + }, + }) + testSuit(t, clientPort, testPort) +}