mirror of
https://github.com/SagerNet/sing-tun.git
synced 2025-04-03 03:47:39 +03:00
Add handshake interface support for gVisor UDP
This commit is contained in:
parent
0a68b9f1d8
commit
aa8760b454
11 changed files with 156 additions and 61 deletions
243
stack_gvisor_udp.go
Normal file
243
stack_gvisor_udp.go
Normal file
|
@ -0,0 +1,243 @@
|
|||
//go:build with_gvisor
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"net/netip"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/checksum"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/udpnat"
|
||||
)
|
||||
|
||||
type UDPForwarder struct {
|
||||
ctx context.Context
|
||||
stack *stack.Stack
|
||||
udpNat *udpnat.Service[netip.AddrPort]
|
||||
|
||||
// cache
|
||||
cacheProto tcpip.NetworkProtocolNumber
|
||||
cacheID stack.TransportEndpointID
|
||||
cachePacket stack.PacketBufferPtr
|
||||
}
|
||||
|
||||
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
|
||||
return &UDPForwarder{
|
||||
ctx: ctx,
|
||||
stack: stack,
|
||||
udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
var upstreamMetadata M.Metadata
|
||||
upstreamMetadata.Source = M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
|
||||
upstreamMetadata.Destination = M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
|
||||
if upstreamMetadata.Source.IsIPv4() {
|
||||
f.cacheProto = header.IPv4ProtocolNumber
|
||||
} else {
|
||||
f.cacheProto = header.IPv6ProtocolNumber
|
||||
}
|
||||
gBuffer := pkt.Data().ToBuffer()
|
||||
sBuffer := buf.NewSize(int(gBuffer.Size()))
|
||||
gBuffer.Apply(func(view *buffer.View) {
|
||||
sBuffer.Write(view.AsSlice())
|
||||
})
|
||||
f.cacheID = id
|
||||
f.cachePacket = pkt
|
||||
f.udpNat.NewPacket(
|
||||
f.ctx,
|
||||
upstreamMetadata.Source.AddrPort(),
|
||||
sBuffer,
|
||||
upstreamMetadata,
|
||||
f.newUDPConn,
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter {
|
||||
return &UDPBackWriter{
|
||||
stack: f.stack,
|
||||
source: f.cacheID.RemoteAddress,
|
||||
sourcePort: f.cacheID.RemotePort,
|
||||
sourceNetwork: f.cacheProto,
|
||||
}
|
||||
}
|
||||
|
||||
type UDPBackWriter struct {
|
||||
stack *stack.Stack
|
||||
source tcpip.Address
|
||||
sourcePort uint16
|
||||
sourceNetwork tcpip.NetworkProtocolNumber
|
||||
packet stack.PacketBufferPtr
|
||||
}
|
||||
|
||||
func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
|
||||
destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port)
|
||||
} else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4AddressSizeBits) {
|
||||
return E.New("send IPv6 packet to IPv4 connection")
|
||||
}
|
||||
|
||||
defer packetBuffer.Release()
|
||||
|
||||
route, err := w.stack.FindRoute(
|
||||
defaultNIC,
|
||||
AddressFromAddr(destination.Addr),
|
||||
w.source,
|
||||
w.sourceNetwork,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
return wrapStackError(err)
|
||||
}
|
||||
defer route.Release()
|
||||
|
||||
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
|
||||
Payload: buffer.MakeWithData(packetBuffer.Bytes()),
|
||||
})
|
||||
defer packet.DecRef()
|
||||
|
||||
packet.TransportProtocolNumber = header.UDPProtocolNumber
|
||||
udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
|
||||
pLen := uint16(packet.Size())
|
||||
udpHdr.Encode(&header.UDPFields{
|
||||
SrcPort: destination.Port,
|
||||
DstPort: w.sourcePort,
|
||||
Length: pLen,
|
||||
})
|
||||
|
||||
if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber {
|
||||
xsum := udpHdr.CalculateChecksum(checksum.Combine(
|
||||
route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
|
||||
packet.Data().Checksum(),
|
||||
))
|
||||
if xsum != math.MaxUint16 {
|
||||
xsum = ^xsum
|
||||
}
|
||||
udpHdr.SetChecksum(xsum)
|
||||
}
|
||||
|
||||
err = route.WritePacket(stack.NetworkHeaderParams{
|
||||
Protocol: header.UDPProtocolNumber,
|
||||
TTL: route.DefaultTTL(),
|
||||
TOS: 0,
|
||||
}, packet)
|
||||
|
||||
if err != nil {
|
||||
route.Stats().UDP.PacketSendErrors.Increment()
|
||||
return wrapStackError(err)
|
||||
}
|
||||
|
||||
route.Stats().UDP.PacketsSent.Increment()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *UDPBackWriter) HandshakeFailure(err error) error {
|
||||
if w.packet == nil {
|
||||
return os.ErrClosed
|
||||
}
|
||||
err = gWriteUnreachable(w.stack, w.packet, err)
|
||||
w.packet.DecRef()
|
||||
w.packet = nil
|
||||
return err
|
||||
}
|
||||
|
||||
type gRequest struct {
|
||||
stack *stack.Stack
|
||||
id stack.TransportEndpointID
|
||||
pkt stack.PacketBufferPtr
|
||||
}
|
||||
|
||||
type gUDPConn struct {
|
||||
*gonet.UDPConn
|
||||
stack *stack.Stack
|
||||
packet stack.PacketBufferPtr
|
||||
}
|
||||
|
||||
func (c *gUDPConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.UDPConn.Read(b)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
err = wrapError(err)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *gUDPConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.UDPConn.Write(b)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
err = wrapError(err)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *gUDPConn) Close() error {
|
||||
c.packet.DecRef()
|
||||
c.packet = nil
|
||||
return c.UDPConn.Close()
|
||||
}
|
||||
|
||||
func (c *gUDPConn) HandshakeFailure(err error) error {
|
||||
if c.packet == nil {
|
||||
return os.ErrClosed
|
||||
}
|
||||
err = gWriteUnreachable(c.stack, c.packet, err)
|
||||
c.packet.DecRef()
|
||||
c.packet = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func gWriteUnreachable(gStack *stack.Stack, packet stack.PacketBufferPtr, err error) error {
|
||||
if errors.Is(err, syscall.ENETUNREACH) {
|
||||
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
|
||||
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable)
|
||||
} else {
|
||||
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
|
||||
}
|
||||
} else if errors.Is(err, syscall.EHOSTUNREACH) {
|
||||
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
|
||||
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable)
|
||||
} else {
|
||||
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
|
||||
}
|
||||
} else if errors.Is(err, syscall.ECONNREFUSED) {
|
||||
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
|
||||
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable)
|
||||
} else {
|
||||
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func gWriteUnreachable4(gStack *stack.Stack, packet stack.PacketBufferPtr, icmpCode stack.RejectIPv4WithICMPType) error {
|
||||
err := gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true)
|
||||
if err != nil {
|
||||
return wrapStackError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func gWriteUnreachable6(gStack *stack.Stack, packet stack.PacketBufferPtr, icmpCode stack.RejectIPv6WithICMPType) error {
|
||||
err := gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true)
|
||||
if err != nil {
|
||||
return wrapStackError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue