From aa8760b454ed954c52f7429c7a057df7636e21eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 23 Jul 2023 14:18:36 +0800 Subject: [PATCH] Add handshake interface support for gVisor UDP --- gvisor.go => stack_gvisor.go | 85 ++++++++------- gvisor_err.go => stack_gvisor_err.go | 22 ---- gvisor_log.go => stack_gvisor_log.go | 0 gvisor_stub.go => stack_gvisor_stub.go | 0 gvisor_udp.go => stack_gvisor_udp.go | 100 ++++++++++++++++++ lwip.go => stack_lwip.go | 0 lwip_stub.go => stack_lwip_stub.go | 0 system.go => stack_system.go | 10 +- system_nat.go => stack_system_nat.go | 0 ...onwindows.go => stack_system_nonwindows.go | 0 system_windows.go => stack_system_windows.go | 0 11 files changed, 156 insertions(+), 61 deletions(-) rename gvisor.go => stack_gvisor.go (91%) rename gvisor_err.go => stack_gvisor_err.go (73%) rename gvisor_log.go => stack_gvisor_log.go (100%) rename gvisor_stub.go => stack_gvisor_stub.go (100%) rename gvisor_udp.go => stack_gvisor_udp.go (58%) rename lwip.go => stack_lwip.go (100%) rename lwip_stub.go => stack_lwip_stub.go (100%) rename system.go => stack_system.go (99%) rename system_nat.go => stack_system_nat.go (100%) rename system_nonwindows.go => stack_system_nonwindows.go (100%) rename system_windows.go => stack_system_windows.go (100%) diff --git a/gvisor.go b/stack_gvisor.go similarity index 91% rename from gvisor.go rename to stack_gvisor.go index a486741..dd9b2e2 100644 --- a/gvisor.go +++ b/stack_gvisor.go @@ -6,6 +6,7 @@ import ( "context" "net/netip" "time" + "unsafe" "github.com/sagernet/gvisor/pkg/tcpip" "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" @@ -70,44 +71,10 @@ func (t *GVisor) Start() error { if err != nil { return err } - ipStack := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - ipv6.NewProtocol, - }, - TransportProtocols: []stack.TransportProtocolFactory{ - tcp.NewProtocol, - udp.NewProtocol, - icmp.NewProtocol4, - icmp.NewProtocol6, - }, - }) - tErr := ipStack.CreateNIC(defaultNIC, linkEndpoint) - if tErr != nil { - return E.New("create nic: ", wrapStackError(tErr)) + ipStack, err := newGVisorStack(linkEndpoint) + if err != nil { + return err } - ipStack.SetRouteTable([]tcpip.Route{ - {Destination: header.IPv4EmptySubnet, NIC: defaultNIC}, - {Destination: header.IPv6EmptySubnet, NIC: defaultNIC}, - }) - ipStack.SetSpoofing(defaultNIC, true) - ipStack.SetPromiscuousMode(defaultNIC, true) - bufSize := 20 * 1024 - ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{ - Min: 1, - Default: bufSize, - Max: bufSize, - }) - ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{ - Min: 1, - Default: bufSize, - Max: bufSize, - }) - sOpt := tcpip.TCPSACKEnabled(true) - ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt) - mOpt := tcpip.TCPModerateReceiveBufferOption(true) - ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt) - tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) { var wq waiter.Queue handshakeCtx, cancel := context.WithCancel(context.Background()) @@ -162,11 +129,12 @@ func (t *GVisor) Start() error { endpoint.Abort() return } + gConn := &gUDPConn{udpConn, ipStack, (*gRequest)(unsafe.Pointer(request)).pkt.IncRef()} go func() { var metadata M.Metadata metadata.Source = M.SocksaddrFromNet(lAddr) metadata.Destination = M.SocksaddrFromNet(rAddr) - ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(&gUDPConn{udpConn}), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(t.udpTimeout)*time.Second) + ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(gConn), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(t.udpTimeout)*time.Second) hErr := t.handler.NewPacketConnection(ctx, conn, metadata) if hErr != nil { endpoint.Abort() @@ -207,3 +175,44 @@ func AddrFromAddress(address tcpip.Address) netip.Addr { return netip.AddrFrom4(address.As4()) } } + +func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) { + ipStack := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, + }, + TransportProtocols: []stack.TransportProtocolFactory{ + tcp.NewProtocol, + udp.NewProtocol, + icmp.NewProtocol4, + icmp.NewProtocol6, + }, + }) + tErr := ipStack.CreateNIC(defaultNIC, ep) + if tErr != nil { + return nil, E.New("create nic: ", wrapStackError(tErr)) + } + ipStack.SetRouteTable([]tcpip.Route{ + {Destination: header.IPv4EmptySubnet, NIC: defaultNIC}, + {Destination: header.IPv6EmptySubnet, NIC: defaultNIC}, + }) + ipStack.SetSpoofing(defaultNIC, true) + ipStack.SetPromiscuousMode(defaultNIC, true) + bufSize := 20 * 1024 + ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: bufSize, + Max: bufSize, + }) + ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{ + Min: 1, + Default: bufSize, + Max: bufSize, + }) + sOpt := tcpip.TCPSACKEnabled(true) + ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt) + mOpt := tcpip.TCPModerateReceiveBufferOption(true) + ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt) + return ipStack, nil +} diff --git a/gvisor_err.go b/stack_gvisor_err.go similarity index 73% rename from gvisor_err.go rename to stack_gvisor_err.go index 6fe49cc..51ce23c 100644 --- a/gvisor_err.go +++ b/stack_gvisor_err.go @@ -27,28 +27,6 @@ func (c *gTCPConn) Write(b []byte) (n int, err error) { return } -type gUDPConn struct { - *gonet.UDPConn -} - -func (c *gUDPConn) Read(b []byte) (n int, err error) { - n, err = c.UDPConn.Read(b) - if err == nil { - return - } - err = wrapError(err) - return -} - -func (c *gUDPConn) Write(b []byte) (n int, err error) { - n, err = c.UDPConn.Write(b) - if err == nil { - return - } - err = wrapError(err) - return -} - func wrapStackError(err tcpip.Error) error { switch err.(type) { case *tcpip.ErrClosedForSend, diff --git a/gvisor_log.go b/stack_gvisor_log.go similarity index 100% rename from gvisor_log.go rename to stack_gvisor_log.go diff --git a/gvisor_stub.go b/stack_gvisor_stub.go similarity index 100% rename from gvisor_stub.go rename to stack_gvisor_stub.go diff --git a/gvisor_udp.go b/stack_gvisor_udp.go similarity index 58% rename from gvisor_udp.go rename to stack_gvisor_udp.go index 077cf0b..0846e5d 100644 --- a/gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -4,11 +4,15 @@ package tun import ( "context" + "errors" "math" "net/netip" + "os" + "syscall" "github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/tcpip" + "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" "github.com/sagernet/gvisor/pkg/tcpip/checksum" "github.com/sagernet/gvisor/pkg/tcpip/header" "github.com/sagernet/gvisor/pkg/tcpip/stack" @@ -78,6 +82,7 @@ type UDPBackWriter struct { source tcpip.Address sourcePort uint16 sourceNetwork tcpip.NetworkProtocolNumber + packet stack.PacketBufferPtr } func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error { @@ -141,3 +146,98 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock route.Stats().UDP.PacketsSent.Increment() return nil } + +func (w *UDPBackWriter) HandshakeFailure(err error) error { + if w.packet == nil { + return os.ErrClosed + } + err = gWriteUnreachable(w.stack, w.packet, err) + w.packet.DecRef() + w.packet = nil + return err +} + +type gRequest struct { + stack *stack.Stack + id stack.TransportEndpointID + pkt stack.PacketBufferPtr +} + +type gUDPConn struct { + *gonet.UDPConn + stack *stack.Stack + packet stack.PacketBufferPtr +} + +func (c *gUDPConn) Read(b []byte) (n int, err error) { + n, err = c.UDPConn.Read(b) + if err == nil { + return + } + err = wrapError(err) + return +} + +func (c *gUDPConn) Write(b []byte) (n int, err error) { + n, err = c.UDPConn.Write(b) + if err == nil { + return + } + err = wrapError(err) + return +} + +func (c *gUDPConn) Close() error { + c.packet.DecRef() + c.packet = nil + return c.UDPConn.Close() +} + +func (c *gUDPConn) HandshakeFailure(err error) error { + if c.packet == nil { + return os.ErrClosed + } + err = gWriteUnreachable(c.stack, c.packet, err) + c.packet.DecRef() + c.packet = nil + return err +} + +func gWriteUnreachable(gStack *stack.Stack, packet stack.PacketBufferPtr, err error) error { + if errors.Is(err, syscall.ENETUNREACH) { + if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { + return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable) + } else { + return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) + } + } else if errors.Is(err, syscall.EHOSTUNREACH) { + if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { + return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable) + } else { + return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) + } + } else if errors.Is(err, syscall.ECONNREFUSED) { + if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { + return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable) + } else { + return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable) + } + } + return nil +} + +func gWriteUnreachable4(gStack *stack.Stack, packet stack.PacketBufferPtr, icmpCode stack.RejectIPv4WithICMPType) error { + err := gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true) + if err != nil { + return wrapStackError(err) + } + return nil +} + +func gWriteUnreachable6(gStack *stack.Stack, packet stack.PacketBufferPtr, icmpCode stack.RejectIPv6WithICMPType) error { + err := gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true) + if err != nil { + return wrapStackError(err) + } + return nil +} diff --git a/lwip.go b/stack_lwip.go similarity index 100% rename from lwip.go rename to stack_lwip.go diff --git a/lwip_stub.go b/stack_lwip_stub.go similarity index 100% rename from lwip_stub.go rename to stack_lwip_stub.go diff --git a/system.go b/stack_system.go similarity index 99% rename from system.go rename to stack_system.go index a454dd5..e8a64ca 100644 --- a/system.go +++ b/stack_system.go @@ -91,6 +91,15 @@ func (s *System) Close() error { } func (s *System) Start() error { + err := s.start() + if err != nil { + return err + } + go s.tunLoop() + return nil +} + +func (s *System) start() error { err := fixWindowsFirewall() if err != nil { return E.Cause(err, "fix windows firewall for system stack") @@ -125,7 +134,6 @@ func (s *System) Start() error { } s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout)) s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler) - go s.tunLoop() return nil } diff --git a/system_nat.go b/stack_system_nat.go similarity index 100% rename from system_nat.go rename to stack_system_nat.go diff --git a/system_nonwindows.go b/stack_system_nonwindows.go similarity index 100% rename from system_nonwindows.go rename to stack_system_nonwindows.go diff --git a/system_windows.go b/stack_system_windows.go similarity index 100% rename from system_windows.go rename to stack_system_windows.go