diff --git a/gvisor.go b/gvisor.go index e91810b..eac1a18 100644 --- a/gvisor.go +++ b/gvisor.go @@ -41,25 +41,20 @@ type GVisorTun interface { } func NewGVisor( - ctx context.Context, - tun Tun, - tunMtu uint32, - endpointIndependentNat bool, - udpTimeout int64, - handler Handler, + options StackOptions, ) (Stack, error) { - gTun, isGTun := tun.(GVisorTun) + gTun, isGTun := options.Tun.(GVisorTun) if !isGTun { return nil, ErrGVisorUnsupported } return &GVisor{ - ctx: ctx, + ctx: options.Context, tun: gTun, - tunMtu: tunMtu, - endpointIndependentNat: endpointIndependentNat, - udpTimeout: udpTimeout, - handler: handler, + tunMtu: options.MTU, + endpointIndependentNat: options.EndpointIndependentNat, + udpTimeout: options.UDPTimeout, + handler: options.Handler, }, nil } diff --git a/gvisor_other.go b/gvisor_other.go index 29647ab..e30e891 100644 --- a/gvisor_other.go +++ b/gvisor_other.go @@ -2,15 +2,8 @@ package tun -import "context" - func NewGVisor( - ctx context.Context, - tun Tun, - tunMtu uint32, - endpointIndependentNat bool, - endpointIndependentNatTimeout int64, - handler Handler, + options StackOptions, ) (Stack, error) { return nil, ErrGVisorUnsupported } diff --git a/gvisor_stub.go b/gvisor_stub.go index c2fe71d..2a022c3 100644 --- a/gvisor_stub.go +++ b/gvisor_stub.go @@ -2,15 +2,8 @@ package tun -import "context" - func NewGVisor( - ctx context.Context, - tun Tun, - tunMtu uint32, - endpointIndependentNat bool, - endpointIndependentNatTimeout int64, - handler Handler, + options StackOptions, ) (Stack, error) { return nil, ErrGVisorNotIncluded } diff --git a/internal/clashtcpip/icmp.go b/internal/clashtcpip/icmp.go new file mode 100644 index 0000000..0050fd5 --- /dev/null +++ b/internal/clashtcpip/icmp.go @@ -0,0 +1,40 @@ +package clashtcpip + +import ( + "encoding/binary" +) + +type ICMPType = byte + +const ( + ICMPTypePingRequest byte = 0x8 + ICMPTypePingResponse byte = 0x0 +) + +type ICMPPacket []byte + +func (p ICMPPacket) Type() ICMPType { + return p[0] +} + +func (p ICMPPacket) SetType(v ICMPType) { + p[0] = v +} + +func (p ICMPPacket) Code() byte { + return p[1] +} + +func (p ICMPPacket) Checksum() uint16 { + return binary.BigEndian.Uint16(p[2:]) +} + +func (p ICMPPacket) SetChecksum(sum [2]byte) { + p[2] = sum[0] + p[3] = sum[1] +} + +func (p ICMPPacket) ResetChecksum() { + p.SetChecksum(zeroChecksum) + p.SetChecksum(Checksum(0, p)) +} diff --git a/internal/clashtcpip/icmpv6.go b/internal/clashtcpip/icmpv6.go new file mode 100644 index 0000000..67f92d1 --- /dev/null +++ b/internal/clashtcpip/icmpv6.go @@ -0,0 +1,172 @@ +package clashtcpip + +import ( + "encoding/binary" +) + +type ICMPv6Packet []byte + +const ( + ICMPv6HeaderSize = 4 + + ICMPv6MinimumSize = 8 + + ICMPv6PayloadOffset = 8 + + ICMPv6EchoMinimumSize = 8 + + ICMPv6ErrorHeaderSize = 8 + + ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + + ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + + ICMPv6ChecksumOffset = 2 + + icmpv6PointerOffset = 4 + + icmpv6MTUOffset = 4 + + icmpv6IdentOffset = 4 + + icmpv6SequenceOffset = 6 + + NDPHopLimit = 255 +) + +type ICMPv6Type byte + +const ( + ICMPv6DstUnreachable ICMPv6Type = 1 + ICMPv6PacketTooBig ICMPv6Type = 2 + ICMPv6TimeExceeded ICMPv6Type = 3 + ICMPv6ParamProblem ICMPv6Type = 4 + ICMPv6EchoRequest ICMPv6Type = 128 + ICMPv6EchoReply ICMPv6Type = 129 + + ICMPv6RouterSolicit ICMPv6Type = 133 + ICMPv6RouterAdvert ICMPv6Type = 134 + ICMPv6NeighborSolicit ICMPv6Type = 135 + ICMPv6NeighborAdvert ICMPv6Type = 136 + ICMPv6RedirectMsg ICMPv6Type = 137 + + ICMPv6MulticastListenerQuery ICMPv6Type = 130 + ICMPv6MulticastListenerReport ICMPv6Type = 131 + ICMPv6MulticastListenerDone ICMPv6Type = 132 +) + +func (typ ICMPv6Type) IsErrorType() bool { + return typ&0x80 == 0 +} + +type ICMPv6Code byte + +const ( + ICMPv6NetworkUnreachable ICMPv6Code = 0 + ICMPv6Prohibited ICMPv6Code = 1 + ICMPv6BeyondScope ICMPv6Code = 2 + ICMPv6AddressUnreachable ICMPv6Code = 3 + ICMPv6PortUnreachable ICMPv6Code = 4 + ICMPv6Policy ICMPv6Code = 5 + ICMPv6RejectRoute ICMPv6Code = 6 +) + +const ( + ICMPv6HopLimitExceeded ICMPv6Code = 0 + ICMPv6ReassemblyTimeout ICMPv6Code = 1 +) + +const ( + ICMPv6ErroneousHeader ICMPv6Code = 0 + + ICMPv6UnknownHeader ICMPv6Code = 1 + + ICMPv6UnknownOption ICMPv6Code = 2 +) + +const ICMPv6UnusedCode ICMPv6Code = 0 + +func (b ICMPv6Packet) Type() ICMPv6Type { + return ICMPv6Type(b[0]) +} + +func (b ICMPv6Packet) SetType(t ICMPv6Type) { + b[0] = byte(t) +} + +func (b ICMPv6Packet) Code() ICMPv6Code { + return ICMPv6Code(b[1]) +} + +func (b ICMPv6Packet) SetCode(c ICMPv6Code) { + b[1] = byte(c) +} + +func (b ICMPv6Packet) TypeSpecific() uint32 { + return binary.BigEndian.Uint32(b[icmpv6PointerOffset:]) +} + +func (b ICMPv6Packet) SetTypeSpecific(val uint32) { + binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val) +} + +func (b ICMPv6Packet) Checksum() uint16 { + return binary.BigEndian.Uint16(b[ICMPv6ChecksumOffset:]) +} + +func (b ICMPv6Packet) SetChecksum(sum [2]byte) { + _ = b[ICMPv6ChecksumOffset+1] + b[ICMPv6ChecksumOffset] = sum[0] + b[ICMPv6ChecksumOffset+1] = sum[1] +} + +func (ICMPv6Packet) SourcePort() uint16 { + return 0 +} + +func (ICMPv6Packet) DestinationPort() uint16 { + return 0 +} + +func (ICMPv6Packet) SetSourcePort(uint16) { +} + +func (ICMPv6Packet) SetDestinationPort(uint16) { +} + +func (b ICMPv6Packet) MTU() uint32 { + return binary.BigEndian.Uint32(b[icmpv6MTUOffset:]) +} + +func (b ICMPv6Packet) SetMTU(mtu uint32) { + binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu) +} + +func (b ICMPv6Packet) Ident() uint16 { + return binary.BigEndian.Uint16(b[icmpv6IdentOffset:]) +} + +func (b ICMPv6Packet) SetIdent(ident uint16) { + binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident) +} + +func (b ICMPv6Packet) Sequence() uint16 { + return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:]) +} + +func (b ICMPv6Packet) SetSequence(sequence uint16) { + binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence) +} + +func (b ICMPv6Packet) MessageBody() []byte { + return b[ICMPv6HeaderSize:] +} + +func (b ICMPv6Packet) Payload() []byte { + return b[ICMPv6PayloadOffset:] +} + +func (b ICMPv6Packet) ResetChecksum(psum uint32) { + b.SetChecksum(zeroChecksum) + b.SetChecksum(Checksum(psum, b)) +} diff --git a/internal/clashtcpip/ip.go b/internal/clashtcpip/ip.go new file mode 100644 index 0000000..ad65679 --- /dev/null +++ b/internal/clashtcpip/ip.go @@ -0,0 +1,209 @@ +package clashtcpip + +import ( + "encoding/binary" + "errors" + "net/netip" +) + +type IPProtocol = byte + +type IP interface { + Payload() []byte + SourceIP() netip.Addr + DestinationIP() netip.Addr + SetSourceIP(ip netip.Addr) + SetDestinationIP(ip netip.Addr) + Protocol() IPProtocol + DecTimeToLive() + ResetChecksum() + PseudoSum() uint32 +} + +// IPProtocol type +const ( + ICMP IPProtocol = 0x01 + TCP IPProtocol = 0x06 + UDP IPProtocol = 0x11 + ICMPv6 IPProtocol = 0x3a +) + +const ( + FlagDontFragment = 1 << 1 + FlagMoreFragment = 1 << 2 +) + +const ( + IPv4HeaderSize = 20 + + IPv4Version = 4 + + IPv4OptionsOffset = 20 + IPv4PacketMinLength = IPv4OptionsOffset +) + +var ( + ErrInvalidLength = errors.New("invalid packet length") + ErrInvalidIPVersion = errors.New("invalid ip version") + ErrInvalidChecksum = errors.New("invalid checksum") +) + +type IPv4Packet []byte + +func (p IPv4Packet) TotalLen() uint16 { + return binary.BigEndian.Uint16(p[2:]) +} + +func (p IPv4Packet) SetTotalLength(length uint16) { + binary.BigEndian.PutUint16(p[2:], length) +} + +func (p IPv4Packet) HeaderLen() uint16 { + return uint16(p[0]&0xf) * 4 +} + +func (p IPv4Packet) SetHeaderLen(length uint16) { + p[0] &= 0xF0 + p[0] |= byte(length / 4) +} + +func (p IPv4Packet) TypeOfService() byte { + return p[1] +} + +func (p IPv4Packet) SetTypeOfService(tos byte) { + p[1] = tos +} + +func (p IPv4Packet) Identification() uint16 { + return binary.BigEndian.Uint16(p[4:]) +} + +func (p IPv4Packet) SetIdentification(id uint16) { + binary.BigEndian.PutUint16(p[4:], id) +} + +func (p IPv4Packet) FragmentOffset() uint16 { + return binary.BigEndian.Uint16([]byte{p[6] & 0x7, p[7]}) * 8 +} + +func (p IPv4Packet) SetFragmentOffset(offset uint32) { + flags := p.Flags() + binary.BigEndian.PutUint16(p[6:], uint16(offset/8)) + p.SetFlags(flags) +} + +func (p IPv4Packet) DataLen() uint16 { + return p.TotalLen() - p.HeaderLen() +} + +func (p IPv4Packet) Payload() []byte { + return p[p.HeaderLen():p.TotalLen()] +} + +func (p IPv4Packet) Protocol() IPProtocol { + return p[9] +} + +func (p IPv4Packet) SetProtocol(protocol IPProtocol) { + p[9] = protocol +} + +func (p IPv4Packet) Flags() byte { + return p[6] >> 5 +} + +func (p IPv4Packet) SetFlags(flags byte) { + p[6] &= 0x1F + p[6] |= flags << 5 +} + +func (p IPv4Packet) SourceIP() netip.Addr { + return netip.AddrFrom4([4]byte{p[12], p[13], p[14], p[15]}) +} + +func (p IPv4Packet) SetSourceIP(ip netip.Addr) { + if ip.Is4() { + copy(p[12:16], ip.AsSlice()) + } +} + +func (p IPv4Packet) DestinationIP() netip.Addr { + return netip.AddrFrom4([4]byte{p[16], p[17], p[18], p[19]}) +} + +func (p IPv4Packet) SetDestinationIP(ip netip.Addr) { + if ip.Is4() { + copy(p[16:20], ip.AsSlice()) + } +} + +func (p IPv4Packet) Checksum() uint16 { + return binary.BigEndian.Uint16(p[10:]) +} + +func (p IPv4Packet) SetChecksum(sum [2]byte) { + p[10] = sum[0] + p[11] = sum[1] +} + +func (p IPv4Packet) TimeToLive() uint8 { + return p[8] +} + +func (p IPv4Packet) SetTimeToLive(ttl uint8) { + p[8] = ttl +} + +func (p IPv4Packet) DecTimeToLive() { + p[8] = p[8] - uint8(1) +} + +func (p IPv4Packet) ResetChecksum() { + p.SetChecksum(zeroChecksum) + p.SetChecksum(Checksum(0, p[:p.HeaderLen()])) +} + +// PseudoSum for tcp checksum +func (p IPv4Packet) PseudoSum() uint32 { + sum := Sum(p[12:20]) + sum += uint32(p.Protocol()) + sum += uint32(p.DataLen()) + return sum +} + +func (p IPv4Packet) Valid() bool { + return len(p) >= IPv4HeaderSize && uint16(len(p)) >= p.TotalLen() +} + +func (p IPv4Packet) Verify() error { + if len(p) < IPv4PacketMinLength { + return ErrInvalidLength + } + + checksum := []byte{p[10], p[11]} + headerLength := uint16(p[0]&0xF) * 4 + packetLength := binary.BigEndian.Uint16(p[2:]) + + if p[0]>>4 != 4 { + return ErrInvalidIPVersion + } + + if uint16(len(p)) < packetLength || packetLength < headerLength { + return ErrInvalidLength + } + + p[10] = 0 + p[11] = 0 + defer copy(p[10:12], checksum) + + answer := Checksum(0, p[:headerLength]) + + if answer[0] != checksum[0] || answer[1] != checksum[1] { + return ErrInvalidChecksum + } + + return nil +} + +var _ IP = (*IPv4Packet)(nil) diff --git a/internal/clashtcpip/ipv6.go b/internal/clashtcpip/ipv6.go new file mode 100644 index 0000000..20147e5 --- /dev/null +++ b/internal/clashtcpip/ipv6.go @@ -0,0 +1,141 @@ +package clashtcpip + +import ( + "encoding/binary" + "net/netip" +) + +const ( + versTCFL = 0 + + IPv6PayloadLenOffset = 4 + + IPv6NextHeaderOffset = 6 + hopLimit = 7 + v6SrcAddr = 8 + v6DstAddr = v6SrcAddr + IPv6AddressSize + + IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize +) + +const ( + versIHL = 0 + tos = 1 + ipVersionShift = 4 + ipIHLMask = 0x0f + IPv4IHLStride = 4 +) + +type IPv6Packet []byte + +const ( + IPv6MinimumSize = IPv6FixedHeaderSize + + IPv6AddressSize = 16 + + IPv6Version = 6 + + IPv6MinimumMTU = 1280 +) + +func (b IPv6Packet) PayloadLength() uint16 { + return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:]) +} + +func (b IPv6Packet) HopLimit() uint8 { + return b[hopLimit] +} + +func (b IPv6Packet) NextHeader() byte { + return b[IPv6NextHeaderOffset] +} + +func (b IPv6Packet) Protocol() IPProtocol { + return b.NextHeader() +} + +func (b IPv6Packet) Payload() []byte { + return b[IPv6MinimumSize:][:b.PayloadLength()] +} + +func (b IPv6Packet) SourceIP() netip.Addr { + addr, _ := netip.AddrFromSlice(b[v6SrcAddr:][:IPv6AddressSize]) + return addr +} + +func (b IPv6Packet) DestinationIP() netip.Addr { + addr, _ := netip.AddrFromSlice(b[v6DstAddr:][:IPv6AddressSize]) + return addr +} + +func (IPv6Packet) Checksum() uint16 { + return 0 +} + +func (b IPv6Packet) TOS() (uint8, uint32) { + v := binary.BigEndian.Uint32(b[versTCFL:]) + return uint8(v >> 20), v & 0xfffff +} + +func (b IPv6Packet) SetTOS(t uint8, l uint32) { + vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff) + binary.BigEndian.PutUint32(b[versTCFL:], vtf) +} + +func (b IPv6Packet) SetPayloadLength(payloadLength uint16) { + binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength) +} + +func (b IPv6Packet) SetSourceIP(addr netip.Addr) { + if addr.Is6() { + copy(b[v6SrcAddr:][:IPv6AddressSize], addr.AsSlice()) + } +} + +func (b IPv6Packet) SetDestinationIP(addr netip.Addr) { + if addr.Is6() { + copy(b[v6DstAddr:][:IPv6AddressSize], addr.AsSlice()) + } +} + +func (b IPv6Packet) SetHopLimit(v uint8) { + b[hopLimit] = v +} + +func (b IPv6Packet) SetNextHeader(v byte) { + b[IPv6NextHeaderOffset] = v +} + +func (b IPv6Packet) SetProtocol(p IPProtocol) { + b.SetNextHeader(p) +} + +func (b IPv6Packet) DecTimeToLive() { + b[hopLimit] = b[hopLimit] - uint8(1) +} + +func (IPv6Packet) SetChecksum(uint16) { +} + +func (IPv6Packet) ResetChecksum() { +} + +func (b IPv6Packet) PseudoSum() uint32 { + sum := Sum(b[v6SrcAddr:IPv6FixedHeaderSize]) + sum += uint32(b.Protocol()) + sum += uint32(b.PayloadLength()) + return sum +} + +func (b IPv6Packet) Valid() bool { + return len(b) >= IPv6MinimumSize && len(b) >= int(b.PayloadLength())+IPv6MinimumSize +} + +func IPVersion(b []byte) int { + if len(b) < versIHL+1 { + return -1 + } + return int(b[versIHL] >> ipVersionShift) +} + +var _ IP = (*IPv6Packet)(nil) diff --git a/internal/clashtcpip/tcp.go b/internal/clashtcpip/tcp.go new file mode 100644 index 0000000..3e0ee73 --- /dev/null +++ b/internal/clashtcpip/tcp.go @@ -0,0 +1,90 @@ +package clashtcpip + +import ( + "encoding/binary" + "net" +) + +const ( + TCPFin uint16 = 1 << 0 + TCPSyn uint16 = 1 << 1 + TCPRst uint16 = 1 << 2 + TCPPuh uint16 = 1 << 3 + TCPAck uint16 = 1 << 4 + TCPUrg uint16 = 1 << 5 + TCPEce uint16 = 1 << 6 + TCPEwr uint16 = 1 << 7 + TCPNs uint16 = 1 << 8 +) + +const TCPHeaderSize = 20 + +type TCPPacket []byte + +func (p TCPPacket) SourcePort() uint16 { + return binary.BigEndian.Uint16(p) +} + +func (p TCPPacket) SetSourcePort(port uint16) { + binary.BigEndian.PutUint16(p, port) +} + +func (p TCPPacket) DestinationPort() uint16 { + return binary.BigEndian.Uint16(p[2:]) +} + +func (p TCPPacket) SetDestinationPort(port uint16) { + binary.BigEndian.PutUint16(p[2:], port) +} + +func (p TCPPacket) Flags() uint16 { + return uint16(p[13] | (p[12] & 0x1)) +} + +func (p TCPPacket) Checksum() uint16 { + return binary.BigEndian.Uint16(p[16:]) +} + +func (p TCPPacket) SetChecksum(sum [2]byte) { + p[16] = sum[0] + p[17] = sum[1] +} + +func (p TCPPacket) ResetChecksum(psum uint32) { + p.SetChecksum(zeroChecksum) + p.SetChecksum(Checksum(psum, p)) +} + +func (p TCPPacket) Valid() bool { + return len(p) >= TCPHeaderSize +} + +func (p TCPPacket) Verify(sourceAddress net.IP, targetAddress net.IP) error { + var checksum [2]byte + checksum[0] = p[16] + checksum[1] = p[17] + + // reset checksum + p[16] = 0 + p[17] = 0 + + // restore checksum + defer func() { + p[16] = checksum[0] + p[17] = checksum[1] + }() + + // check checksum + s := uint32(0) + s += Sum(sourceAddress) + s += Sum(targetAddress) + s += uint32(TCP) + s += uint32(len(p)) + + check := Checksum(s, p) + if checksum[0] != check[0] || checksum[1] != check[1] { + return ErrInvalidChecksum + } + + return nil +} diff --git a/internal/clashtcpip/tcpip.go b/internal/clashtcpip/tcpip.go new file mode 100644 index 0000000..2994637 --- /dev/null +++ b/internal/clashtcpip/tcpip.go @@ -0,0 +1,24 @@ +package clashtcpip + +var zeroChecksum = [2]byte{0x00, 0x00} + +var SumFnc = SumCompat + +func Sum(b []byte) uint32 { + return SumFnc(b) +} + +// Checksum for Internet Protocol family headers +func Checksum(sum uint32, b []byte) (answer [2]byte) { + sum += Sum(b) + sum = (sum >> 16) + (sum & 0xffff) + sum += sum >> 16 + sum = ^sum + answer[0] = byte(sum >> 8) + answer[1] = byte(sum) + return +} + +func SetIPv4(packet []byte) { + packet[0] = (packet[0] & 0x0f) | (4 << 4) +} diff --git a/internal/clashtcpip/tcpip_amd64.go b/internal/clashtcpip/tcpip_amd64.go new file mode 100644 index 0000000..711a85c --- /dev/null +++ b/internal/clashtcpip/tcpip_amd64.go @@ -0,0 +1,26 @@ +//go:build !noasm + +package clashtcpip + +import ( + "unsafe" + + "golang.org/x/sys/cpu" +) + +//go:noescape +func sumAsmAvx2(data unsafe.Pointer, length uintptr) uintptr + +func SumAVX2(data []byte) uint32 { + if len(data) == 0 { + return 0 + } + + return uint32(sumAsmAvx2(unsafe.Pointer(&data[0]), uintptr(len(data)))) +} + +func init() { + if cpu.X86.HasAVX2 { + SumFnc = SumAVX2 + } +} diff --git a/internal/clashtcpip/tcpip_amd64.s b/internal/clashtcpip/tcpip_amd64.s new file mode 100644 index 0000000..100820b --- /dev/null +++ b/internal/clashtcpip/tcpip_amd64.s @@ -0,0 +1,140 @@ +#include "textflag.h" + +DATA endian_swap_mask<>+0(SB)/8, $0x607040502030001 +DATA endian_swap_mask<>+8(SB)/8, $0xE0F0C0D0A0B0809 +DATA endian_swap_mask<>+16(SB)/8, $0x607040502030001 +DATA endian_swap_mask<>+24(SB)/8, $0xE0F0C0D0A0B0809 +GLOBL endian_swap_mask<>(SB), RODATA, $32 + +// func sumAsmAvx2(data unsafe.Pointer, length uintptr) uintptr +// +// args (8 bytes aligned): +// data unsafe.Pointer - 8 bytes - 0 offset +// length uintptr - 8 bytes - 8 offset +// result uintptr - 8 bytes - 16 offset +#define PDATA AX +#define LENGTH CX +#define RESULT BX +TEXT ·sumAsmAvx2(SB),NOSPLIT,$0-24 + MOVQ data+0(FP), PDATA + MOVQ length+8(FP), LENGTH + XORQ RESULT, RESULT + +#define VSUM Y0 +#define ENDIAN_SWAP_MASK Y1 +BEGIN: + VMOVDQU endian_swap_mask<>(SB), ENDIAN_SWAP_MASK + VPXOR VSUM, VSUM, VSUM + +#define LOADED_0 Y2 +#define LOADED_1 Y3 +#define LOADED_2 Y4 +#define LOADED_3 Y5 +BATCH_64: + CMPQ LENGTH, $64 + JB BATCH_32 + VPMOVZXWD (PDATA), LOADED_0 + VPMOVZXWD 16(PDATA), LOADED_1 + VPMOVZXWD 32(PDATA), LOADED_2 + VPMOVZXWD 48(PDATA), LOADED_3 + VPSHUFB ENDIAN_SWAP_MASK, LOADED_0, LOADED_0 + VPSHUFB ENDIAN_SWAP_MASK, LOADED_1, LOADED_1 + VPSHUFB ENDIAN_SWAP_MASK, LOADED_2, LOADED_2 + VPSHUFB ENDIAN_SWAP_MASK, LOADED_3, LOADED_3 + VPADDD LOADED_0, VSUM, VSUM + VPADDD LOADED_1, VSUM, VSUM + VPADDD LOADED_2, VSUM, VSUM + VPADDD LOADED_3, VSUM, VSUM + ADDQ $-64, LENGTH + ADDQ $64, PDATA + JMP BATCH_64 +#undef LOADED_0 +#undef LOADED_1 +#undef LOADED_2 +#undef LOADED_3 + +#define LOADED_0 Y2 +#define LOADED_1 Y3 +BATCH_32: + CMPQ LENGTH, $32 + JB BATCH_16 + VPMOVZXWD (PDATA), LOADED_0 + VPMOVZXWD 16(PDATA), LOADED_1 + VPSHUFB ENDIAN_SWAP_MASK, LOADED_0, LOADED_0 + VPSHUFB ENDIAN_SWAP_MASK, LOADED_1, LOADED_1 + VPADDD LOADED_0, VSUM, VSUM + VPADDD LOADED_1, VSUM, VSUM + ADDQ $-32, LENGTH + ADDQ $32, PDATA + JMP BATCH_32 +#undef LOADED_0 +#undef LOADED_1 + +#define LOADED Y2 +BATCH_16: + CMPQ LENGTH, $16 + JB COLLECT + VPMOVZXWD (PDATA), LOADED + VPSHUFB ENDIAN_SWAP_MASK, LOADED, LOADED + VPADDD LOADED, VSUM, VSUM + ADDQ $-16, LENGTH + ADDQ $16, PDATA + JMP BATCH_16 +#undef LOADED + +#define EXTRACTED Y2 +#define EXTRACTED_128 X2 +#define TEMP_64 DX +COLLECT: + VEXTRACTI128 $0, VSUM, EXTRACTED_128 + VPEXTRD $0, EXTRACTED_128, TEMP_64 + ADDL TEMP_64, RESULT + VPEXTRD $1, EXTRACTED_128, TEMP_64 + ADDL TEMP_64, RESULT + VPEXTRD $2, EXTRACTED_128, TEMP_64 + ADDL TEMP_64, RESULT + VPEXTRD $3, EXTRACTED_128, TEMP_64 + ADDL TEMP_64, RESULT + VEXTRACTI128 $1, VSUM, EXTRACTED_128 + VPEXTRD $0, EXTRACTED_128, TEMP_64 + ADDL TEMP_64, RESULT + VPEXTRD $1, EXTRACTED_128, TEMP_64 + ADDL TEMP_64, RESULT + VPEXTRD $2, EXTRACTED_128, TEMP_64 + ADDL TEMP_64, RESULT + VPEXTRD $3, EXTRACTED_128, TEMP_64 + ADDL TEMP_64, RESULT +#undef EXTRACTED +#undef EXTRACTED_128 +#undef TEMP_64 + +#define TEMP DX +#define TEMP2 SI +BATCH_2: + CMPQ LENGTH, $2 + JB BATCH_1 + XORQ TEMP, TEMP + MOVW (PDATA), TEMP + MOVQ TEMP, TEMP2 + SHRW $8, TEMP2 + SHLW $8, TEMP + ORW TEMP2, TEMP + ADDL TEMP, RESULT + ADDQ $-2, LENGTH + ADDQ $2, PDATA + JMP BATCH_2 +#undef TEMP + +#define TEMP DX +BATCH_1: + CMPQ LENGTH, $0 + JZ RETURN + XORQ TEMP, TEMP + MOVB (PDATA), TEMP + SHLW $8, TEMP + ADDL TEMP, RESULT +#undef TEMP + +RETURN: + MOVQ RESULT, result+16(FP) + RET diff --git a/internal/clashtcpip/tcpip_amd64_test.go b/internal/clashtcpip/tcpip_amd64_test.go new file mode 100644 index 0000000..e02a1b9 --- /dev/null +++ b/internal/clashtcpip/tcpip_amd64_test.go @@ -0,0 +1,51 @@ +package clashtcpip + +import ( + "crypto/rand" + "testing" + + "golang.org/x/sys/cpu" +) + +func Test_SumAVX2(t *testing.T) { + if !cpu.X86.HasAVX2 { + t.Skipf("AVX2 unavailable") + } + + bytes := make([]byte, chunkSize) + + for size := 0; size <= chunkSize; size++ { + for count := 0; count < chunkCount; count++ { + _, err := rand.Reader.Read(bytes[:size]) + if err != nil { + t.Skipf("Rand read failed: %v", err) + } + + compat := SumCompat(bytes[:size]) + avx := SumAVX2(bytes[:size]) + + if compat != avx { + t.Errorf("Sum of length=%d mismatched", size) + } + } + } +} + +func Benchmark_SumAVX2(b *testing.B) { + if !cpu.X86.HasAVX2 { + b.Skipf("AVX2 unavailable") + } + + bytes := make([]byte, chunkSize) + + _, err := rand.Reader.Read(bytes) + if err != nil { + b.Skipf("Rand read failed: %v", err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + SumAVX2(bytes) + } +} diff --git a/internal/clashtcpip/tcpip_arm64.go b/internal/clashtcpip/tcpip_arm64.go new file mode 100644 index 0000000..543803f --- /dev/null +++ b/internal/clashtcpip/tcpip_arm64.go @@ -0,0 +1,24 @@ +package clashtcpip + +import ( + "unsafe" + + "golang.org/x/sys/cpu" +) + +//go:noescape +func sumAsmNeon(data unsafe.Pointer, length uintptr) uintptr + +func SumNeon(data []byte) uint32 { + if len(data) == 0 { + return 0 + } + + return uint32(sumAsmNeon(unsafe.Pointer(&data[0]), uintptr(len(data)))) +} + +func init() { + if cpu.ARM64.HasASIMD { + SumFnc = SumNeon + } +} diff --git a/internal/clashtcpip/tcpip_arm64.s b/internal/clashtcpip/tcpip_arm64.s new file mode 100644 index 0000000..f6d57cf --- /dev/null +++ b/internal/clashtcpip/tcpip_arm64.s @@ -0,0 +1,118 @@ +#include "textflag.h" + +// func sumAsmNeon(data unsafe.Pointer, length uintptr) uintptr +// +// args (8 bytes aligned): +// data unsafe.Pointer - 8 bytes - 0 offset +// length uintptr - 8 bytes - 8 offset +// result uintptr - 8 bytes - 16 offset +#define PDATA R0 +#define LENGTH R1 +#define RESULT R2 +#define VSUM V0 +TEXT ·sumAsmNeon(SB),NOSPLIT,$0-24 + MOVD data+0(FP), PDATA + MOVD length+8(FP), LENGTH + MOVD $0, RESULT + VMOVQ $0, $0, VSUM + +#define LOADED_0 V1 +#define LOADED_1 V2 +#define LOADED_2 V3 +#define LOADED_3 V4 +BATCH_32: + CMP $32, LENGTH + BLO BATCH_16 + VLD1 (PDATA), [LOADED_0.B8, LOADED_1.B8, LOADED_2.B8, LOADED_3.B8] + VREV16 LOADED_0.B8, LOADED_0.B8 + VREV16 LOADED_1.B8, LOADED_1.B8 + VREV16 LOADED_2.B8, LOADED_2.B8 + VREV16 LOADED_3.B8, LOADED_3.B8 + VUSHLL $0, LOADED_0.H4, LOADED_0.S4 + VUSHLL $0, LOADED_1.H4, LOADED_1.S4 + VUSHLL $0, LOADED_2.H4, LOADED_2.S4 + VUSHLL $0, LOADED_3.H4, LOADED_3.S4 + VADD LOADED_0.S4, VSUM.S4, VSUM.S4 + VADD LOADED_1.S4, VSUM.S4, VSUM.S4 + VADD LOADED_2.S4, VSUM.S4, VSUM.S4 + VADD LOADED_3.S4, VSUM.S4, VSUM.S4 + ADD $-32, LENGTH + ADD $32, PDATA + B BATCH_32 +#undef LOADED_0 +#undef LOADED_1 +#undef LOADED_2 +#undef LOADED_3 + +#define LOADED_0 V1 +#define LOADED_1 V2 +BATCH_16: + CMP $16, LENGTH + BLO BATCH_8 + VLD1 (PDATA), [LOADED_0.B8, LOADED_1.B8] + VREV16 LOADED_0.B8, LOADED_0.B8 + VREV16 LOADED_1.B8, LOADED_1.B8 + VUSHLL $0, LOADED_0.H4, LOADED_0.S4 + VUSHLL $0, LOADED_1.H4, LOADED_1.S4 + VADD LOADED_0.S4, VSUM.S4, VSUM.S4 + VADD LOADED_1.S4, VSUM.S4, VSUM.S4 + ADD $-16, LENGTH + ADD $16, PDATA + B BATCH_16 +#undef LOADED_0 +#undef LOADED_1 + +#define LOADED_0 V1 +BATCH_8: + CMP $8, LENGTH + BLO BATCH_2 + VLD1 (PDATA), [LOADED_0.B8] + VREV16 LOADED_0.B8, LOADED_0.B8 + VUSHLL $0, LOADED_0.H4, LOADED_0.S4 + VADD LOADED_0.S4, VSUM.S4, VSUM.S4 + ADD $-8, LENGTH + ADD $8, PDATA + B BATCH_8 +#undef LOADED_0 + +#define LOADED_L R3 +#define LOADED_H R4 +BATCH_2: + CMP $2, LENGTH + BLO BATCH_1 + MOVBU (PDATA), LOADED_H + MOVBU 1(PDATA), LOADED_L + LSL $8, LOADED_H + ORR LOADED_H, LOADED_L, LOADED_L + ADD LOADED_L, RESULT, RESULT + ADD $2, PDATA + ADD $-2, LENGTH + B BATCH_2 +#undef LOADED_H +#undef LOADED_L + +#define LOADED R3 +BATCH_1: + CMP $1, LENGTH + BLO COLLECT + MOVBU (PDATA), LOADED + LSL $8, LOADED + ADD LOADED, RESULT, RESULT + +#define EXTRACTED R3 +COLLECT: + VMOV VSUM.S[0], EXTRACTED + ADD EXTRACTED, RESULT + VMOV VSUM.S[1], EXTRACTED + ADD EXTRACTED, RESULT + VMOV VSUM.S[2], EXTRACTED + ADD EXTRACTED, RESULT + VMOV VSUM.S[3], EXTRACTED + ADD EXTRACTED, RESULT +#undef VSUM +#undef PDATA +#undef LENGTH + +RETURN: + MOVD RESULT, result+16(FP) + RET diff --git a/internal/clashtcpip/tcpip_arm64_test.go b/internal/clashtcpip/tcpip_arm64_test.go new file mode 100644 index 0000000..bfe07a6 --- /dev/null +++ b/internal/clashtcpip/tcpip_arm64_test.go @@ -0,0 +1,51 @@ +package clashtcpip + +import ( + "crypto/rand" + "testing" + + "golang.org/x/sys/cpu" +) + +func Test_SumNeon(t *testing.T) { + if !cpu.ARM64.HasASIMD { + t.Skipf("Neon unavailable") + } + + bytes := make([]byte, chunkSize) + + for size := 0; size <= chunkSize; size++ { + for count := 0; count < chunkCount; count++ { + _, err := rand.Reader.Read(bytes[:size]) + if err != nil { + t.Skipf("Rand read failed: %v", err) + } + + compat := SumCompat(bytes[:size]) + neon := SumNeon(bytes[:size]) + + if compat != neon { + t.Errorf("Sum of length=%d mismatched", size) + } + } + } +} + +func Benchmark_SumNeon(b *testing.B) { + if !cpu.ARM64.HasASIMD { + b.Skipf("Neon unavailable") + } + + bytes := make([]byte, chunkSize) + + _, err := rand.Reader.Read(bytes) + if err != nil { + b.Skipf("Rand read failed: %v", err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + SumNeon(bytes) + } +} diff --git a/internal/clashtcpip/tcpip_compat.go b/internal/clashtcpip/tcpip_compat.go new file mode 100644 index 0000000..a72a489 --- /dev/null +++ b/internal/clashtcpip/tcpip_compat.go @@ -0,0 +1,14 @@ +package clashtcpip + +func SumCompat(b []byte) (sum uint32) { + n := len(b) + if n&1 != 0 { + n-- + sum += uint32(b[n]) << 8 + } + + for i := 0; i < n; i += 2 { + sum += (uint32(b[i]) << 8) | uint32(b[i+1]) + } + return +} diff --git a/internal/clashtcpip/tcpip_compat_test.go b/internal/clashtcpip/tcpip_compat_test.go new file mode 100644 index 0000000..828d886 --- /dev/null +++ b/internal/clashtcpip/tcpip_compat_test.go @@ -0,0 +1,26 @@ +package clashtcpip + +import ( + "crypto/rand" + "testing" +) + +const ( + chunkSize = 9000 + chunkCount = 10 +) + +func Benchmark_SumCompat(b *testing.B) { + bytes := make([]byte, chunkSize) + + _, err := rand.Reader.Read(bytes) + if err != nil { + b.Skipf("Rand read failed: %v", err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + SumCompat(bytes) + } +} diff --git a/internal/clashtcpip/udp.go b/internal/clashtcpip/udp.go new file mode 100644 index 0000000..f5773a1 --- /dev/null +++ b/internal/clashtcpip/udp.go @@ -0,0 +1,55 @@ +package clashtcpip + +import ( + "encoding/binary" +) + +const UDPHeaderSize = 8 + +type UDPPacket []byte + +func (p UDPPacket) Length() uint16 { + return binary.BigEndian.Uint16(p[4:]) +} + +func (p UDPPacket) SetLength(length uint16) { + binary.BigEndian.PutUint16(p[4:], length) +} + +func (p UDPPacket) SourcePort() uint16 { + return binary.BigEndian.Uint16(p) +} + +func (p UDPPacket) SetSourcePort(port uint16) { + binary.BigEndian.PutUint16(p, port) +} + +func (p UDPPacket) DestinationPort() uint16 { + return binary.BigEndian.Uint16(p[2:]) +} + +func (p UDPPacket) SetDestinationPort(port uint16) { + binary.BigEndian.PutUint16(p[2:], port) +} + +func (p UDPPacket) Payload() []byte { + return p[UDPHeaderSize:p.Length()] +} + +func (p UDPPacket) Checksum() uint16 { + return binary.BigEndian.Uint16(p[6:]) +} + +func (p UDPPacket) SetChecksum(sum [2]byte) { + p[6] = sum[0] + p[7] = sum[1] +} + +func (p UDPPacket) ResetChecksum(psum uint32) { + p.SetChecksum(zeroChecksum) + p.SetChecksum(Checksum(psum, p)) +} + +func (p UDPPacket) Valid() bool { + return len(p) >= UDPHeaderSize && uint16(len(p)) >= p.Length() +} diff --git a/lwip.go b/lwip.go index 21de9d0..7a37c3a 100644 --- a/lwip.go +++ b/lwip.go @@ -7,7 +7,6 @@ import ( "net" "net/netip" "os" - "runtime" lwip "github.com/sagernet/go-tun2socks/core" "github.com/sagernet/sing/common" @@ -28,19 +27,15 @@ type LWIP struct { } func NewLWIP( - ctx context.Context, - tun Tun, - tunMtu uint32, - udpTimeout int64, - handler Handler, + options StackOptions, ) (Stack, error) { return &LWIP{ - ctx: ctx, - tun: tun, - tunMtu: tunMtu, - handler: handler, + ctx: options.Context, + tun: options.Tun, + tunMtu: options.MTU, + handler: options.Handler, stack: lwip.NewLWIPStack(), - udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler), + udpNat: udpnat.New[netip.AddrPort](options.UDPTimeout, options.Handler), }, nil } @@ -57,10 +52,7 @@ func (l *LWIP) loopIn() { l.loopInWintun(winTun) return } - mtu := int(l.tunMtu) - if runtime.GOOS == "darwin" { - mtu += 4 - } + mtu := int(l.tunMtu) + PacketOffset _buffer := buf.StackNewSize(mtu) defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) @@ -71,13 +63,7 @@ func (l *LWIP) loopIn() { if err != nil { return } - var packet []byte - if runtime.GOOS == "darwin" { - packet = data[4:n] - } else { - packet = data[:n] - } - _, err = l.stack.Write(packet) + _, err = l.stack.Write(data[PacketOffset:n]) if err != nil { if err.Error() == "stack closed" { return diff --git a/lwip_stub.go b/lwip_stub.go index 2259057..1810329 100644 --- a/lwip_stub.go +++ b/lwip_stub.go @@ -2,14 +2,8 @@ package tun -import "context" - func NewLWIP( - ctx context.Context, - tun Tun, - tunMtu uint32, - udpTimeout int64, - handler Handler, + options StackOptions, ) (Stack, error) { return nil, ErrLWIPNotIncluded } diff --git a/stack.go b/stack.go index d4012e7..996d010 100644 --- a/stack.go +++ b/stack.go @@ -2,6 +2,7 @@ package tun import ( "context" + "net/netip" E "github.com/sagernet/sing/common/exceptions" ) @@ -17,20 +18,29 @@ type Stack interface { Close() error } +type StackOptions struct { + Context context.Context + Tun Tun + Name string + MTU uint32 + Inet4Address []netip.Prefix + Inet6Address []netip.Prefix + EndpointIndependentNat bool + UDPTimeout int64 + Handler Handler +} + func NewStack( - ctx context.Context, stack string, - tun Tun, - tunMtu uint32, - endpointIndependentNat bool, - udpTimeout int64, - handler Handler, + options StackOptions, ) (Stack, error) { switch stack { case "gvisor", "": - return NewGVisor(ctx, tun, tunMtu, endpointIndependentNat, udpTimeout, handler) + return NewGVisor(options) + case "system": + return NewSystem(options) case "lwip": - return NewLWIP(ctx, tun, tunMtu, udpTimeout, handler) + return NewLWIP(options) default: return nil, E.New("unknown stack: ", stack) } diff --git a/system.go b/system.go new file mode 100644 index 0000000..daa7ae2 --- /dev/null +++ b/system.go @@ -0,0 +1,409 @@ +package tun + +import ( + "context" + "net" + "net/netip" + "time" + + "github.com/sagernet/sing-tun/internal/clashtcpip" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/udpnat" +) + +type System struct { + ctx context.Context + tun Tun + mtu uint32 + handler Handler + inet4Prefixes []netip.Prefix + inet6Prefixes []netip.Prefix + inet4ServerAddress netip.Addr + inet4Address netip.Addr + inet6ServerAddress netip.Addr + inet6Address netip.Addr + udpTimeout int64 + tcpListener net.Listener + tcpListener6 net.Listener + tcpPort uint16 + tcpPort6 uint16 + tcpNat *TCPNat + udpNat *udpnat.Service[netip.AddrPort] +} + +type Session struct { + SourceAddress netip.Addr + DestinationAddress netip.Addr + SourcePort uint16 + DestinationPort uint16 +} + +func NewSystem(options StackOptions) (Stack, error) { + stack := &System{ + ctx: options.Context, + tun: options.Tun, + mtu: options.MTU, + udpTimeout: options.UDPTimeout, + handler: options.Handler, + inet4Prefixes: options.Inet4Address, + inet6Prefixes: options.Inet6Address, + } + if len(options.Inet4Address) > 0 { + if options.Inet4Address[0].Bits() == 32 { + return nil, E.New("need one more IPv4 address in first prefix for system stack") + } + stack.inet4ServerAddress = options.Inet4Address[0].Addr() + stack.inet4Address = stack.inet4ServerAddress.Next() + } + if len(options.Inet6Address) > 0 { + if options.Inet6Address[0].Bits() == 128 { + return nil, E.New("need one more IPv6 address in first prefix for system stack") + } + stack.inet6ServerAddress = options.Inet6Address[0].Addr() + stack.inet6Address = stack.inet6ServerAddress.Next() + } + if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() { + return nil, E.New("missing interface address") + } + return stack, nil +} + +func (s *System) Close() error { + return common.Close( + s.tcpListener, + s.tcpListener6, + ) +} + +func (s *System) Start() error { + if s.inet4Address.IsValid() { + tcpListener, err := net.Listen("tcp4", net.JoinHostPort(s.inet4ServerAddress.String(), "0")) + if err != nil { + return err + } + s.tcpListener = tcpListener + s.tcpPort = M.SocksaddrFromNet(tcpListener.Addr()).Port + go s.acceptLoop(tcpListener) + } + if s.inet6Address.IsValid() { + tcpListener, err := net.Listen("tcp6", net.JoinHostPort(s.inet6ServerAddress.String(), "0")) + if err != nil { + return err + } + s.tcpListener6 = tcpListener + s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port + go s.acceptLoop(tcpListener) + } + s.tcpNat = NewNat() + s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler) + go s.tunLoop() + return nil +} + +func (s *System) tunLoop() { + if winTun, isWinTun := s.tun.(WinTun); isWinTun { + s.wintunLoop(winTun) + return + } + _packetBuffer := buf.StackNewSize(int(s.mtu)) + defer common.KeepAlive(_packetBuffer) + packetBuffer := common.Dup(_packetBuffer) + defer packetBuffer.Release() + packetSlice := packetBuffer.Slice() + for { + n, err := s.tun.Read(packetSlice) + if err != nil { + return + } + if n < clashtcpip.IPv4PacketMinLength { + continue + } + packet := packetSlice[PacketOffset:n] + switch packet[0] >> 4 { + case 4: + s.processIPv4(packet) + case 6: + s.processIPv6(packet) + } + } +} + +func (s *System) wintunLoop(winTun WinTun) { + for { + packet, release, err := winTun.ReadPacket() + if err != nil { + return + } + if len(packet) < clashtcpip.IPv4PacketMinLength { + release() + continue + } + switch packet[0] >> 4 { + case 4: + s.processIPv4(packet) + case 6: + s.processIPv6(packet) + } + release() + } +} + +func (s *System) acceptLoop(listener net.Listener) { + for { + conn, err := listener.Accept() + if err != nil { + return + } + connPort := M.SocksaddrFromNet(conn.RemoteAddr()).Port + session := s.tcpNat.LookupBack(connPort) + if session == nil { + s.handler.NewError(context.Background(), E.New("unknown session with port ", connPort)) + continue + } + destination := M.SocksaddrFromNetIP(session.Destination) + if destination.Addr.Is4() { + for _, prefix := range s.inet4Prefixes { + if prefix.Contains(destination.Addr) { + destination.Addr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) + break + } + } + } else { + for _, prefix := range s.inet6Prefixes { + if prefix.Contains(destination.Addr) { + destination.Addr = netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) + break + } + } + } + go func() { + s.handler.NewConnection(context.Background(), conn, M.Metadata{ + Source: M.SocksaddrFromNetIP(session.Source), + Destination: destination, + }) + conn.Close() + time.Sleep(time.Second) + s.tcpNat.Revoke(connPort, session) + }() + } +} + +func (s *System) NewError(ctx context.Context, err error) { + s.handler.NewError(ctx, err) +} + +func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error { + if !packet.Valid() { + return E.New("ipv4: invalid packet") + } + if packet.TimeToLive() == 0x00 { + return E.New("ipv4: TTL exceeded") + } + switch packet.Protocol() { + case clashtcpip.TCP: + return s.processIPv4TCP(packet, packet.Payload()) + case clashtcpip.UDP: + return s.processIPv4UDP(packet, packet.Payload()) + case clashtcpip.ICMP: + return s.processIPv4ICMP(packet, packet.Payload()) + default: + return nil + } +} + +func (s *System) processIPv6(packet clashtcpip.IPv6Packet) error { + if !packet.Valid() { + return E.New("ipv6: invalid packet") + } + if packet.HopLimit() == 0x00 { + return E.New("ipv6: TTL exceeded") + } + switch packet.Protocol() { + case clashtcpip.TCP: + return s.processIPv6TCP(packet, packet.Payload()) + case clashtcpip.UDP: + return s.processIPv6UDP(packet, packet.Payload()) + case clashtcpip.ICMPv6: + return s.processIPv6ICMP(packet, packet.Payload()) + default: + return nil + } +} + +func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error { + source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) + destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) + if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort { + session := s.tcpNat.LookupBack(destination.Port()) + if session == nil { + return E.New("session not found: ", destination.Port()) + } + packet.SetSourceIP(session.Destination.Addr()) + header.SetSourcePort(session.Destination.Port()) + packet.SetDestinationIP(session.Source.Addr()) + header.SetDestinationPort(session.Source.Port()) + } else { + natPort := s.tcpNat.Lookup(source, destination) + packet.SetSourceIP(s.inet4Address) + header.SetSourcePort(natPort) + packet.SetDestinationIP(s.inet4ServerAddress) + header.SetDestinationPort(s.tcpPort) + } + header.ResetChecksum(packet.PseudoSum()) + packet.ResetChecksum() + return common.Error(s.tun.Write(packet)) +} + +func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error { + source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) + destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) + if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 { + session := s.tcpNat.LookupBack(destination.Port()) + if session == nil { + return E.New("session not found: ", destination.Port()) + } + packet.SetSourceIP(session.Destination.Addr()) + header.SetSourcePort(session.Destination.Port()) + packet.SetDestinationIP(session.Source.Addr()) + header.SetDestinationPort(session.Source.Port()) + } else { + natPort := s.tcpNat.Lookup(source, destination) + packet.SetSourceIP(s.inet6Address) + header.SetSourcePort(natPort) + packet.SetDestinationIP(s.inet6ServerAddress) + header.SetDestinationPort(s.tcpPort6) + } + header.ResetChecksum(packet.PseudoSum()) + packet.ResetChecksum() + return common.Error(s.tun.Write(packet)) +} + +func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error { + if packet.Flags()&clashtcpip.FlagMoreFragment != 0 { + return E.New("ipv4: fragment dropped") + } + if packet.FragmentOffset() != 0 { + return E.New("ipv4: fragment dropped") + } + source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) + destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) + if !destination.Addr().IsGlobalUnicast() || destination.Addr().IsMulticast() { + return nil + } + data := buf.As(header.Payload()).ToOwned() + metadata := M.Metadata{ + Source: M.SocksaddrFromNetIP(source), + Destination: M.SocksaddrFromNetIP(destination), + } + s.udpNat.NewPacket(s.ctx, source, data, metadata, func(natConn N.PacketConn) N.PacketWriter { + hdr := buf.As(packet[:packet.HeaderLen()+clashtcpip.UDPHeaderSize]).ToOwned() + return &systemPacketWriter4{s.tun, hdr, source} + }) + return nil +} + +func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error { + source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) + destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) + if !destination.Addr().IsGlobalUnicast() || destination.Addr().IsMulticast() { + return nil + } + data := buf.As(header.Payload()).ToOwned() + metadata := M.Metadata{ + Source: M.SocksaddrFromNetIP(source), + Destination: M.SocksaddrFromNetIP(destination), + } + s.udpNat.NewPacket(s.ctx, source, data, metadata, func(natConn N.PacketConn) N.PacketWriter { + hdr := buf.As(packet[:len(packet)-len(header.Payload())]).ToOwned() + return &systemPacketWriter6{s.tun, hdr, source} + }) + return nil +} + +func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error { + if header.Type() != clashtcpip.ICMPTypePingRequest || header.Code() != 0 { + return nil + } + header.SetType(clashtcpip.ICMPTypePingResponse) + sourceAddress := packet.SourceIP() + packet.SetSourceIP(packet.DestinationIP()) + packet.SetDestinationIP(sourceAddress) + header.ResetChecksum() + packet.ResetChecksum() + return common.Error(s.tun.Write(packet)) +} + +func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error { + if header.Type() != clashtcpip.ICMPv6EchoRequest || header.Code() != 0 { + return nil + } + header.SetType(clashtcpip.ICMPv6EchoReply) + sourceAddress := packet.SourceIP() + packet.SetSourceIP(packet.DestinationIP()) + packet.SetDestinationIP(sourceAddress) + header.ResetChecksum(packet.PseudoSum()) + packet.ResetChecksum() + return common.Error(s.tun.Write(packet)) +} + +type systemPacketWriter4 struct { + tun Tun + header *buf.Buffer + source netip.AddrPort +} + +func (w *systemPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + newPacket := buf.StackNewSize(w.header.Len() + buffer.Len()) + defer newPacket.Release() + newPacket.Write(w.header.Bytes()) + newPacket.Write(buffer.Bytes()) + ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes()) + ipHdr.SetTotalLength(uint16(newPacket.Len())) + ipHdr.SetDestinationIP(ipHdr.SourceIP()) + ipHdr.SetSourceIP(destination.Unwrap().Addr) + udpHdr := clashtcpip.UDPPacket(ipHdr.Payload()) + udpHdr.SetDestinationPort(udpHdr.SourcePort()) + udpHdr.SetSourcePort(destination.Port) + udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize)) + udpHdr.ResetChecksum(ipHdr.PseudoSum()) + ipHdr.ResetChecksum() + return common.Error(w.tun.Write(newPacket.Bytes())) +} + +func (w *systemPacketWriter4) Close() error { + w.header.Release() + return nil +} + +type systemPacketWriter6 struct { + tun Tun + header *buf.Buffer + source netip.AddrPort +} + +func (w *systemPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + newPacket := buf.StackNewSize(w.header.Len() + buffer.Len()) + defer newPacket.Release() + newPacket.Write(w.header.Bytes()) + newPacket.Write(buffer.Bytes()) + ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes()) + udpLen := uint16(clashtcpip.UDPHeaderSize + buffer.Len()) + ipHdr.SetPayloadLength(udpLen) + ipHdr.SetDestinationIP(ipHdr.SourceIP()) + ipHdr.SetSourceIP(destination.Addr) + udpHdr := clashtcpip.UDPPacket(ipHdr.Payload()) + udpHdr.SetDestinationPort(udpHdr.SourcePort()) + udpHdr.SetSourcePort(destination.Port) + udpHdr.SetLength(udpLen) + udpHdr.ResetChecksum(ipHdr.PseudoSum()) + return common.Error(w.tun.Write(newPacket.Bytes())) +} + +func (w *systemPacketWriter6) Close() error { + w.header.Release() + return nil +} diff --git a/system_nat.go b/system_nat.go new file mode 100644 index 0000000..adac1a6 --- /dev/null +++ b/system_nat.go @@ -0,0 +1,68 @@ +package tun + +import ( + "net/netip" + "sync" +) + +type TCPNat struct { + portIndex uint16 + portAccess sync.RWMutex + addrAccess sync.RWMutex + addrMap map[netip.AddrPort]uint16 + portMap map[uint16]*TCPSession +} + +type TCPSession struct { + Source netip.AddrPort + Destination netip.AddrPort +} + +func NewNat() *TCPNat { + return &TCPNat{ + portIndex: 10000, + addrMap: make(map[netip.AddrPort]uint16), + portMap: make(map[uint16]*TCPSession), + } +} + +func (n *TCPNat) LookupBack(port uint16) *TCPSession { + n.portAccess.RLock() + defer n.portAccess.RUnlock() + return n.portMap[port] +} + +func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint16 { + n.addrAccess.RLock() + port, loaded := n.addrMap[source] + n.addrAccess.RUnlock() + if loaded { + return port + } + n.addrAccess.Lock() + nextPort := n.portIndex + if nextPort == 0 { + nextPort = 10000 + n.portIndex = 10001 + } else { + n.portIndex++ + } + n.addrMap[source] = nextPort + n.addrAccess.Unlock() + n.portAccess.Lock() + n.portMap[nextPort] = &TCPSession{ + Source: source, + Destination: destination, + } + n.portAccess.Unlock() + return nextPort +} + +func (n *TCPNat) Revoke(natPort uint16, session *TCPSession) { + n.addrAccess.Lock() + delete(n.addrMap, session.Source) + n.addrAccess.Unlock() + n.portAccess.Lock() + delete(n.portMap, natPort) + n.portAccess.Unlock() +} diff --git a/tun_darwin.go b/tun_darwin.go index 2e0b6a8..7a0582d 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -19,6 +19,8 @@ import ( "golang.org/x/sys/unix" ) +const PacketOffset = 4 + type NativeTun struct { tunFile *os.File tunWriter N.VectorisedWriter diff --git a/tun_nondarwin.go b/tun_nondarwin.go new file mode 100644 index 0000000..0faa2c9 --- /dev/null +++ b/tun_nondarwin.go @@ -0,0 +1,5 @@ +//go:build !darwin + +package tun + +const PacketOffset = 0 diff --git a/tun_windows.go b/tun_windows.go index 34d73b7..3ecb706 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -62,38 +62,37 @@ func (t *NativeTun) configure() error { if err != nil { return E.Cause(err, "set ipv4 address") } + err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), []netip.Addr{t.options.Inet4Address[0].Addr().Next()}, nil) + if err != nil { + return E.Cause(err, "set ipv4 dns") + } } if len(t.options.Inet6Address) > 0 { err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET6), t.options.Inet6Address) if err != nil { return E.Cause(err, "set ipv6 address") } - } - err := luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), []netip.Addr{t.options.Inet4Address[0].Addr().Next()}, nil) - if err != nil { - return E.Cause(err, "set ipv4 dns") - } - err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), []netip.Addr{t.options.Inet6Address[0].Addr().Next()}, nil) - if err != nil { - return E.Cause(err, "set ipv6 dns") + err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), []netip.Addr{t.options.Inet6Address[0].Addr().Next()}, nil) + if err != nil { + return E.Cause(err, "set ipv6 dns") + } } if t.options.AutoRoute { if len(t.options.Inet4Address) > 0 { - err = luid.AddRoute(netip.PrefixFrom(netip.IPv4Unspecified(), 0), netip.IPv4Unspecified(), 0) + err := luid.AddRoute(netip.PrefixFrom(netip.IPv4Unspecified(), 0), netip.IPv4Unspecified(), 0) if err != nil { return E.Cause(err, "set ipv4 route") } } if len(t.options.Inet6Address) > 0 { - err = luid.AddRoute(netip.PrefixFrom(netip.IPv6Unspecified(), 0), netip.IPv6Unspecified(), 0) + err := luid.AddRoute(netip.PrefixFrom(netip.IPv6Unspecified(), 0), netip.IPv6Unspecified(), 0) if err != nil { return E.Cause(err, "set ipv6 route") } } } if len(t.options.Inet4Address) > 0 { - var inetIf *winipcfg.MibIPInterfaceRow - inetIf, err = luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET)) + inetIf, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET)) if err != nil { return err } @@ -113,8 +112,7 @@ func (t *NativeTun) configure() error { } } if len(t.options.Inet6Address) > 0 { - var inet6If *winipcfg.MibIPInterfaceRow - inet6If, err = luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET6)) + inet6If, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET6)) if err != nil { return err }