Add handshake interface support for gVisor UDP

This commit is contained in:
世界 2023-07-23 14:18:36 +08:00
parent 0a68b9f1d8
commit aa8760b454
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
11 changed files with 156 additions and 61 deletions

View file

@ -6,6 +6,7 @@ import (
"context"
"net/netip"
"time"
"unsafe"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
@ -70,44 +71,10 @@ func (t *GVisor) Start() error {
if err != nil {
return err
}
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
})
tErr := ipStack.CreateNIC(defaultNIC, linkEndpoint)
if tErr != nil {
return E.New("create nic: ", wrapStackError(tErr))
ipStack, err := newGVisorStack(linkEndpoint)
if err != nil {
return err
}
ipStack.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
})
ipStack.SetSpoofing(defaultNIC, true)
ipStack.SetPromiscuousMode(defaultNIC, true)
bufSize := 20 * 1024
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
mOpt := tcpip.TCPModerateReceiveBufferOption(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
handshakeCtx, cancel := context.WithCancel(context.Background())
@ -162,11 +129,12 @@ func (t *GVisor) Start() error {
endpoint.Abort()
return
}
gConn := &gUDPConn{udpConn, ipStack, (*gRequest)(unsafe.Pointer(request)).pkt.IncRef()}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(&gUDPConn{udpConn}), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(t.udpTimeout)*time.Second)
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(gConn), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(t.udpTimeout)*time.Second)
hErr := t.handler.NewPacketConnection(ctx, conn, metadata)
if hErr != nil {
endpoint.Abort()
@ -207,3 +175,44 @@ func AddrFromAddress(address tcpip.Address) netip.Addr {
return netip.AddrFrom4(address.As4())
}
}
func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
})
tErr := ipStack.CreateNIC(defaultNIC, ep)
if tErr != nil {
return nil, E.New("create nic: ", wrapStackError(tErr))
}
ipStack.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
})
ipStack.SetSpoofing(defaultNIC, true)
ipStack.SetPromiscuousMode(defaultNIC, true)
bufSize := 20 * 1024
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
mOpt := tcpip.TCPModerateReceiveBufferOption(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
return ipStack, nil
}

View file

@ -27,28 +27,6 @@ func (c *gTCPConn) Write(b []byte) (n int, err error) {
return
}
type gUDPConn struct {
*gonet.UDPConn
}
func (c *gUDPConn) Read(b []byte) (n int, err error) {
n, err = c.UDPConn.Read(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func (c *gUDPConn) Write(b []byte) (n int, err error) {
n, err = c.UDPConn.Write(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func wrapStackError(err tcpip.Error) error {
switch err.(type) {
case *tcpip.ErrClosedForSend,

View file

@ -4,11 +4,15 @@ package tun
import (
"context"
"errors"
"math"
"net/netip"
"os"
"syscall"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/checksum"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
@ -78,6 +82,7 @@ type UDPBackWriter struct {
source tcpip.Address
sourcePort uint16
sourceNetwork tcpip.NetworkProtocolNumber
packet stack.PacketBufferPtr
}
func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
@ -141,3 +146,98 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock
route.Stats().UDP.PacketsSent.Increment()
return nil
}
func (w *UDPBackWriter) HandshakeFailure(err error) error {
if w.packet == nil {
return os.ErrClosed
}
err = gWriteUnreachable(w.stack, w.packet, err)
w.packet.DecRef()
w.packet = nil
return err
}
type gRequest struct {
stack *stack.Stack
id stack.TransportEndpointID
pkt stack.PacketBufferPtr
}
type gUDPConn struct {
*gonet.UDPConn
stack *stack.Stack
packet stack.PacketBufferPtr
}
func (c *gUDPConn) Read(b []byte) (n int, err error) {
n, err = c.UDPConn.Read(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func (c *gUDPConn) Write(b []byte) (n int, err error) {
n, err = c.UDPConn.Write(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func (c *gUDPConn) Close() error {
c.packet.DecRef()
c.packet = nil
return c.UDPConn.Close()
}
func (c *gUDPConn) HandshakeFailure(err error) error {
if c.packet == nil {
return os.ErrClosed
}
err = gWriteUnreachable(c.stack, c.packet, err)
c.packet.DecRef()
c.packet = nil
return err
}
func gWriteUnreachable(gStack *stack.Stack, packet stack.PacketBufferPtr, err error) error {
if errors.Is(err, syscall.ENETUNREACH) {
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable)
} else {
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
}
} else if errors.Is(err, syscall.EHOSTUNREACH) {
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable)
} else {
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
}
} else if errors.Is(err, syscall.ECONNREFUSED) {
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable)
} else {
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable)
}
}
return nil
}
func gWriteUnreachable4(gStack *stack.Stack, packet stack.PacketBufferPtr, icmpCode stack.RejectIPv4WithICMPType) error {
err := gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true)
if err != nil {
return wrapStackError(err)
}
return nil
}
func gWriteUnreachable6(gStack *stack.Stack, packet stack.PacketBufferPtr, icmpCode stack.RejectIPv6WithICMPType) error {
err := gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true)
if err != nil {
return wrapStackError(err)
}
return nil
}

View file

@ -91,6 +91,15 @@ func (s *System) Close() error {
}
func (s *System) Start() error {
err := s.start()
if err != nil {
return err
}
go s.tunLoop()
return nil
}
func (s *System) start() error {
err := fixWindowsFirewall()
if err != nil {
return E.Cause(err, "fix windows firewall for system stack")
@ -125,7 +134,6 @@ func (s *System) Start() error {
}
s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout))
s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler)
go s.tunLoop()
return nil
}