Add system stack

This commit is contained in:
世界 2022-09-06 19:24:47 +08:00
parent 0efafc9963
commit 2f15b0cd3f
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
26 changed files with 1713 additions and 79 deletions

View file

@ -41,25 +41,20 @@ type GVisorTun interface {
}
func NewGVisor(
ctx context.Context,
tun Tun,
tunMtu uint32,
endpointIndependentNat bool,
udpTimeout int64,
handler Handler,
options StackOptions,
) (Stack, error) {
gTun, isGTun := tun.(GVisorTun)
gTun, isGTun := options.Tun.(GVisorTun)
if !isGTun {
return nil, ErrGVisorUnsupported
}
return &GVisor{
ctx: ctx,
ctx: options.Context,
tun: gTun,
tunMtu: tunMtu,
endpointIndependentNat: endpointIndependentNat,
udpTimeout: udpTimeout,
handler: handler,
tunMtu: options.MTU,
endpointIndependentNat: options.EndpointIndependentNat,
udpTimeout: options.UDPTimeout,
handler: options.Handler,
}, nil
}

View file

@ -2,15 +2,8 @@
package tun
import "context"
func NewGVisor(
ctx context.Context,
tun Tun,
tunMtu uint32,
endpointIndependentNat bool,
endpointIndependentNatTimeout int64,
handler Handler,
options StackOptions,
) (Stack, error) {
return nil, ErrGVisorUnsupported
}

View file

@ -2,15 +2,8 @@
package tun
import "context"
func NewGVisor(
ctx context.Context,
tun Tun,
tunMtu uint32,
endpointIndependentNat bool,
endpointIndependentNatTimeout int64,
handler Handler,
options StackOptions,
) (Stack, error) {
return nil, ErrGVisorNotIncluded
}

View file

@ -0,0 +1,40 @@
package clashtcpip
import (
"encoding/binary"
)
type ICMPType = byte
const (
ICMPTypePingRequest byte = 0x8
ICMPTypePingResponse byte = 0x0
)
type ICMPPacket []byte
func (p ICMPPacket) Type() ICMPType {
return p[0]
}
func (p ICMPPacket) SetType(v ICMPType) {
p[0] = v
}
func (p ICMPPacket) Code() byte {
return p[1]
}
func (p ICMPPacket) Checksum() uint16 {
return binary.BigEndian.Uint16(p[2:])
}
func (p ICMPPacket) SetChecksum(sum [2]byte) {
p[2] = sum[0]
p[3] = sum[1]
}
func (p ICMPPacket) ResetChecksum() {
p.SetChecksum(zeroChecksum)
p.SetChecksum(Checksum(0, p))
}

View file

@ -0,0 +1,172 @@
package clashtcpip
import (
"encoding/binary"
)
type ICMPv6Packet []byte
const (
ICMPv6HeaderSize = 4
ICMPv6MinimumSize = 8
ICMPv6PayloadOffset = 8
ICMPv6EchoMinimumSize = 8
ICMPv6ErrorHeaderSize = 8
ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize
ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize
ICMPv6ChecksumOffset = 2
icmpv6PointerOffset = 4
icmpv6MTUOffset = 4
icmpv6IdentOffset = 4
icmpv6SequenceOffset = 6
NDPHopLimit = 255
)
type ICMPv6Type byte
const (
ICMPv6DstUnreachable ICMPv6Type = 1
ICMPv6PacketTooBig ICMPv6Type = 2
ICMPv6TimeExceeded ICMPv6Type = 3
ICMPv6ParamProblem ICMPv6Type = 4
ICMPv6EchoRequest ICMPv6Type = 128
ICMPv6EchoReply ICMPv6Type = 129
ICMPv6RouterSolicit ICMPv6Type = 133
ICMPv6RouterAdvert ICMPv6Type = 134
ICMPv6NeighborSolicit ICMPv6Type = 135
ICMPv6NeighborAdvert ICMPv6Type = 136
ICMPv6RedirectMsg ICMPv6Type = 137
ICMPv6MulticastListenerQuery ICMPv6Type = 130
ICMPv6MulticastListenerReport ICMPv6Type = 131
ICMPv6MulticastListenerDone ICMPv6Type = 132
)
func (typ ICMPv6Type) IsErrorType() bool {
return typ&0x80 == 0
}
type ICMPv6Code byte
const (
ICMPv6NetworkUnreachable ICMPv6Code = 0
ICMPv6Prohibited ICMPv6Code = 1
ICMPv6BeyondScope ICMPv6Code = 2
ICMPv6AddressUnreachable ICMPv6Code = 3
ICMPv6PortUnreachable ICMPv6Code = 4
ICMPv6Policy ICMPv6Code = 5
ICMPv6RejectRoute ICMPv6Code = 6
)
const (
ICMPv6HopLimitExceeded ICMPv6Code = 0
ICMPv6ReassemblyTimeout ICMPv6Code = 1
)
const (
ICMPv6ErroneousHeader ICMPv6Code = 0
ICMPv6UnknownHeader ICMPv6Code = 1
ICMPv6UnknownOption ICMPv6Code = 2
)
const ICMPv6UnusedCode ICMPv6Code = 0
func (b ICMPv6Packet) Type() ICMPv6Type {
return ICMPv6Type(b[0])
}
func (b ICMPv6Packet) SetType(t ICMPv6Type) {
b[0] = byte(t)
}
func (b ICMPv6Packet) Code() ICMPv6Code {
return ICMPv6Code(b[1])
}
func (b ICMPv6Packet) SetCode(c ICMPv6Code) {
b[1] = byte(c)
}
func (b ICMPv6Packet) TypeSpecific() uint32 {
return binary.BigEndian.Uint32(b[icmpv6PointerOffset:])
}
func (b ICMPv6Packet) SetTypeSpecific(val uint32) {
binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val)
}
func (b ICMPv6Packet) Checksum() uint16 {
return binary.BigEndian.Uint16(b[ICMPv6ChecksumOffset:])
}
func (b ICMPv6Packet) SetChecksum(sum [2]byte) {
_ = b[ICMPv6ChecksumOffset+1]
b[ICMPv6ChecksumOffset] = sum[0]
b[ICMPv6ChecksumOffset+1] = sum[1]
}
func (ICMPv6Packet) SourcePort() uint16 {
return 0
}
func (ICMPv6Packet) DestinationPort() uint16 {
return 0
}
func (ICMPv6Packet) SetSourcePort(uint16) {
}
func (ICMPv6Packet) SetDestinationPort(uint16) {
}
func (b ICMPv6Packet) MTU() uint32 {
return binary.BigEndian.Uint32(b[icmpv6MTUOffset:])
}
func (b ICMPv6Packet) SetMTU(mtu uint32) {
binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu)
}
func (b ICMPv6Packet) Ident() uint16 {
return binary.BigEndian.Uint16(b[icmpv6IdentOffset:])
}
func (b ICMPv6Packet) SetIdent(ident uint16) {
binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident)
}
func (b ICMPv6Packet) Sequence() uint16 {
return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:])
}
func (b ICMPv6Packet) SetSequence(sequence uint16) {
binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
}
func (b ICMPv6Packet) MessageBody() []byte {
return b[ICMPv6HeaderSize:]
}
func (b ICMPv6Packet) Payload() []byte {
return b[ICMPv6PayloadOffset:]
}
func (b ICMPv6Packet) ResetChecksum(psum uint32) {
b.SetChecksum(zeroChecksum)
b.SetChecksum(Checksum(psum, b))
}

209
internal/clashtcpip/ip.go Normal file
View file

@ -0,0 +1,209 @@
package clashtcpip
import (
"encoding/binary"
"errors"
"net/netip"
)
type IPProtocol = byte
type IP interface {
Payload() []byte
SourceIP() netip.Addr
DestinationIP() netip.Addr
SetSourceIP(ip netip.Addr)
SetDestinationIP(ip netip.Addr)
Protocol() IPProtocol
DecTimeToLive()
ResetChecksum()
PseudoSum() uint32
}
// IPProtocol type
const (
ICMP IPProtocol = 0x01
TCP IPProtocol = 0x06
UDP IPProtocol = 0x11
ICMPv6 IPProtocol = 0x3a
)
const (
FlagDontFragment = 1 << 1
FlagMoreFragment = 1 << 2
)
const (
IPv4HeaderSize = 20
IPv4Version = 4
IPv4OptionsOffset = 20
IPv4PacketMinLength = IPv4OptionsOffset
)
var (
ErrInvalidLength = errors.New("invalid packet length")
ErrInvalidIPVersion = errors.New("invalid ip version")
ErrInvalidChecksum = errors.New("invalid checksum")
)
type IPv4Packet []byte
func (p IPv4Packet) TotalLen() uint16 {
return binary.BigEndian.Uint16(p[2:])
}
func (p IPv4Packet) SetTotalLength(length uint16) {
binary.BigEndian.PutUint16(p[2:], length)
}
func (p IPv4Packet) HeaderLen() uint16 {
return uint16(p[0]&0xf) * 4
}
func (p IPv4Packet) SetHeaderLen(length uint16) {
p[0] &= 0xF0
p[0] |= byte(length / 4)
}
func (p IPv4Packet) TypeOfService() byte {
return p[1]
}
func (p IPv4Packet) SetTypeOfService(tos byte) {
p[1] = tos
}
func (p IPv4Packet) Identification() uint16 {
return binary.BigEndian.Uint16(p[4:])
}
func (p IPv4Packet) SetIdentification(id uint16) {
binary.BigEndian.PutUint16(p[4:], id)
}
func (p IPv4Packet) FragmentOffset() uint16 {
return binary.BigEndian.Uint16([]byte{p[6] & 0x7, p[7]}) * 8
}
func (p IPv4Packet) SetFragmentOffset(offset uint32) {
flags := p.Flags()
binary.BigEndian.PutUint16(p[6:], uint16(offset/8))
p.SetFlags(flags)
}
func (p IPv4Packet) DataLen() uint16 {
return p.TotalLen() - p.HeaderLen()
}
func (p IPv4Packet) Payload() []byte {
return p[p.HeaderLen():p.TotalLen()]
}
func (p IPv4Packet) Protocol() IPProtocol {
return p[9]
}
func (p IPv4Packet) SetProtocol(protocol IPProtocol) {
p[9] = protocol
}
func (p IPv4Packet) Flags() byte {
return p[6] >> 5
}
func (p IPv4Packet) SetFlags(flags byte) {
p[6] &= 0x1F
p[6] |= flags << 5
}
func (p IPv4Packet) SourceIP() netip.Addr {
return netip.AddrFrom4([4]byte{p[12], p[13], p[14], p[15]})
}
func (p IPv4Packet) SetSourceIP(ip netip.Addr) {
if ip.Is4() {
copy(p[12:16], ip.AsSlice())
}
}
func (p IPv4Packet) DestinationIP() netip.Addr {
return netip.AddrFrom4([4]byte{p[16], p[17], p[18], p[19]})
}
func (p IPv4Packet) SetDestinationIP(ip netip.Addr) {
if ip.Is4() {
copy(p[16:20], ip.AsSlice())
}
}
func (p IPv4Packet) Checksum() uint16 {
return binary.BigEndian.Uint16(p[10:])
}
func (p IPv4Packet) SetChecksum(sum [2]byte) {
p[10] = sum[0]
p[11] = sum[1]
}
func (p IPv4Packet) TimeToLive() uint8 {
return p[8]
}
func (p IPv4Packet) SetTimeToLive(ttl uint8) {
p[8] = ttl
}
func (p IPv4Packet) DecTimeToLive() {
p[8] = p[8] - uint8(1)
}
func (p IPv4Packet) ResetChecksum() {
p.SetChecksum(zeroChecksum)
p.SetChecksum(Checksum(0, p[:p.HeaderLen()]))
}
// PseudoSum for tcp checksum
func (p IPv4Packet) PseudoSum() uint32 {
sum := Sum(p[12:20])
sum += uint32(p.Protocol())
sum += uint32(p.DataLen())
return sum
}
func (p IPv4Packet) Valid() bool {
return len(p) >= IPv4HeaderSize && uint16(len(p)) >= p.TotalLen()
}
func (p IPv4Packet) Verify() error {
if len(p) < IPv4PacketMinLength {
return ErrInvalidLength
}
checksum := []byte{p[10], p[11]}
headerLength := uint16(p[0]&0xF) * 4
packetLength := binary.BigEndian.Uint16(p[2:])
if p[0]>>4 != 4 {
return ErrInvalidIPVersion
}
if uint16(len(p)) < packetLength || packetLength < headerLength {
return ErrInvalidLength
}
p[10] = 0
p[11] = 0
defer copy(p[10:12], checksum)
answer := Checksum(0, p[:headerLength])
if answer[0] != checksum[0] || answer[1] != checksum[1] {
return ErrInvalidChecksum
}
return nil
}
var _ IP = (*IPv4Packet)(nil)

141
internal/clashtcpip/ipv6.go Normal file
View file

@ -0,0 +1,141 @@
package clashtcpip
import (
"encoding/binary"
"net/netip"
)
const (
versTCFL = 0
IPv6PayloadLenOffset = 4
IPv6NextHeaderOffset = 6
hopLimit = 7
v6SrcAddr = 8
v6DstAddr = v6SrcAddr + IPv6AddressSize
IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize
)
const (
versIHL = 0
tos = 1
ipVersionShift = 4
ipIHLMask = 0x0f
IPv4IHLStride = 4
)
type IPv6Packet []byte
const (
IPv6MinimumSize = IPv6FixedHeaderSize
IPv6AddressSize = 16
IPv6Version = 6
IPv6MinimumMTU = 1280
)
func (b IPv6Packet) PayloadLength() uint16 {
return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:])
}
func (b IPv6Packet) HopLimit() uint8 {
return b[hopLimit]
}
func (b IPv6Packet) NextHeader() byte {
return b[IPv6NextHeaderOffset]
}
func (b IPv6Packet) Protocol() IPProtocol {
return b.NextHeader()
}
func (b IPv6Packet) Payload() []byte {
return b[IPv6MinimumSize:][:b.PayloadLength()]
}
func (b IPv6Packet) SourceIP() netip.Addr {
addr, _ := netip.AddrFromSlice(b[v6SrcAddr:][:IPv6AddressSize])
return addr
}
func (b IPv6Packet) DestinationIP() netip.Addr {
addr, _ := netip.AddrFromSlice(b[v6DstAddr:][:IPv6AddressSize])
return addr
}
func (IPv6Packet) Checksum() uint16 {
return 0
}
func (b IPv6Packet) TOS() (uint8, uint32) {
v := binary.BigEndian.Uint32(b[versTCFL:])
return uint8(v >> 20), v & 0xfffff
}
func (b IPv6Packet) SetTOS(t uint8, l uint32) {
vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff)
binary.BigEndian.PutUint32(b[versTCFL:], vtf)
}
func (b IPv6Packet) SetPayloadLength(payloadLength uint16) {
binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength)
}
func (b IPv6Packet) SetSourceIP(addr netip.Addr) {
if addr.Is6() {
copy(b[v6SrcAddr:][:IPv6AddressSize], addr.AsSlice())
}
}
func (b IPv6Packet) SetDestinationIP(addr netip.Addr) {
if addr.Is6() {
copy(b[v6DstAddr:][:IPv6AddressSize], addr.AsSlice())
}
}
func (b IPv6Packet) SetHopLimit(v uint8) {
b[hopLimit] = v
}
func (b IPv6Packet) SetNextHeader(v byte) {
b[IPv6NextHeaderOffset] = v
}
func (b IPv6Packet) SetProtocol(p IPProtocol) {
b.SetNextHeader(p)
}
func (b IPv6Packet) DecTimeToLive() {
b[hopLimit] = b[hopLimit] - uint8(1)
}
func (IPv6Packet) SetChecksum(uint16) {
}
func (IPv6Packet) ResetChecksum() {
}
func (b IPv6Packet) PseudoSum() uint32 {
sum := Sum(b[v6SrcAddr:IPv6FixedHeaderSize])
sum += uint32(b.Protocol())
sum += uint32(b.PayloadLength())
return sum
}
func (b IPv6Packet) Valid() bool {
return len(b) >= IPv6MinimumSize && len(b) >= int(b.PayloadLength())+IPv6MinimumSize
}
func IPVersion(b []byte) int {
if len(b) < versIHL+1 {
return -1
}
return int(b[versIHL] >> ipVersionShift)
}
var _ IP = (*IPv6Packet)(nil)

View file

@ -0,0 +1,90 @@
package clashtcpip
import (
"encoding/binary"
"net"
)
const (
TCPFin uint16 = 1 << 0
TCPSyn uint16 = 1 << 1
TCPRst uint16 = 1 << 2
TCPPuh uint16 = 1 << 3
TCPAck uint16 = 1 << 4
TCPUrg uint16 = 1 << 5
TCPEce uint16 = 1 << 6
TCPEwr uint16 = 1 << 7
TCPNs uint16 = 1 << 8
)
const TCPHeaderSize = 20
type TCPPacket []byte
func (p TCPPacket) SourcePort() uint16 {
return binary.BigEndian.Uint16(p)
}
func (p TCPPacket) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(p, port)
}
func (p TCPPacket) DestinationPort() uint16 {
return binary.BigEndian.Uint16(p[2:])
}
func (p TCPPacket) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(p[2:], port)
}
func (p TCPPacket) Flags() uint16 {
return uint16(p[13] | (p[12] & 0x1))
}
func (p TCPPacket) Checksum() uint16 {
return binary.BigEndian.Uint16(p[16:])
}
func (p TCPPacket) SetChecksum(sum [2]byte) {
p[16] = sum[0]
p[17] = sum[1]
}
func (p TCPPacket) ResetChecksum(psum uint32) {
p.SetChecksum(zeroChecksum)
p.SetChecksum(Checksum(psum, p))
}
func (p TCPPacket) Valid() bool {
return len(p) >= TCPHeaderSize
}
func (p TCPPacket) Verify(sourceAddress net.IP, targetAddress net.IP) error {
var checksum [2]byte
checksum[0] = p[16]
checksum[1] = p[17]
// reset checksum
p[16] = 0
p[17] = 0
// restore checksum
defer func() {
p[16] = checksum[0]
p[17] = checksum[1]
}()
// check checksum
s := uint32(0)
s += Sum(sourceAddress)
s += Sum(targetAddress)
s += uint32(TCP)
s += uint32(len(p))
check := Checksum(s, p)
if checksum[0] != check[0] || checksum[1] != check[1] {
return ErrInvalidChecksum
}
return nil
}

View file

@ -0,0 +1,24 @@
package clashtcpip
var zeroChecksum = [2]byte{0x00, 0x00}
var SumFnc = SumCompat
func Sum(b []byte) uint32 {
return SumFnc(b)
}
// Checksum for Internet Protocol family headers
func Checksum(sum uint32, b []byte) (answer [2]byte) {
sum += Sum(b)
sum = (sum >> 16) + (sum & 0xffff)
sum += sum >> 16
sum = ^sum
answer[0] = byte(sum >> 8)
answer[1] = byte(sum)
return
}
func SetIPv4(packet []byte) {
packet[0] = (packet[0] & 0x0f) | (4 << 4)
}

View file

@ -0,0 +1,26 @@
//go:build !noasm
package clashtcpip
import (
"unsafe"
"golang.org/x/sys/cpu"
)
//go:noescape
func sumAsmAvx2(data unsafe.Pointer, length uintptr) uintptr
func SumAVX2(data []byte) uint32 {
if len(data) == 0 {
return 0
}
return uint32(sumAsmAvx2(unsafe.Pointer(&data[0]), uintptr(len(data))))
}
func init() {
if cpu.X86.HasAVX2 {
SumFnc = SumAVX2
}
}

View file

@ -0,0 +1,140 @@
#include "textflag.h"
DATA endian_swap_mask<>+0(SB)/8, $0x607040502030001
DATA endian_swap_mask<>+8(SB)/8, $0xE0F0C0D0A0B0809
DATA endian_swap_mask<>+16(SB)/8, $0x607040502030001
DATA endian_swap_mask<>+24(SB)/8, $0xE0F0C0D0A0B0809
GLOBL endian_swap_mask<>(SB), RODATA, $32
// func sumAsmAvx2(data unsafe.Pointer, length uintptr) uintptr
//
// args (8 bytes aligned):
// data unsafe.Pointer - 8 bytes - 0 offset
// length uintptr - 8 bytes - 8 offset
// result uintptr - 8 bytes - 16 offset
#define PDATA AX
#define LENGTH CX
#define RESULT BX
TEXT ·sumAsmAvx2(SB),NOSPLIT,$0-24
MOVQ data+0(FP), PDATA
MOVQ length+8(FP), LENGTH
XORQ RESULT, RESULT
#define VSUM Y0
#define ENDIAN_SWAP_MASK Y1
BEGIN:
VMOVDQU endian_swap_mask<>(SB), ENDIAN_SWAP_MASK
VPXOR VSUM, VSUM, VSUM
#define LOADED_0 Y2
#define LOADED_1 Y3
#define LOADED_2 Y4
#define LOADED_3 Y5
BATCH_64:
CMPQ LENGTH, $64
JB BATCH_32
VPMOVZXWD (PDATA), LOADED_0
VPMOVZXWD 16(PDATA), LOADED_1
VPMOVZXWD 32(PDATA), LOADED_2
VPMOVZXWD 48(PDATA), LOADED_3
VPSHUFB ENDIAN_SWAP_MASK, LOADED_0, LOADED_0
VPSHUFB ENDIAN_SWAP_MASK, LOADED_1, LOADED_1
VPSHUFB ENDIAN_SWAP_MASK, LOADED_2, LOADED_2
VPSHUFB ENDIAN_SWAP_MASK, LOADED_3, LOADED_3
VPADDD LOADED_0, VSUM, VSUM
VPADDD LOADED_1, VSUM, VSUM
VPADDD LOADED_2, VSUM, VSUM
VPADDD LOADED_3, VSUM, VSUM
ADDQ $-64, LENGTH
ADDQ $64, PDATA
JMP BATCH_64
#undef LOADED_0
#undef LOADED_1
#undef LOADED_2
#undef LOADED_3
#define LOADED_0 Y2
#define LOADED_1 Y3
BATCH_32:
CMPQ LENGTH, $32
JB BATCH_16
VPMOVZXWD (PDATA), LOADED_0
VPMOVZXWD 16(PDATA), LOADED_1
VPSHUFB ENDIAN_SWAP_MASK, LOADED_0, LOADED_0
VPSHUFB ENDIAN_SWAP_MASK, LOADED_1, LOADED_1
VPADDD LOADED_0, VSUM, VSUM
VPADDD LOADED_1, VSUM, VSUM
ADDQ $-32, LENGTH
ADDQ $32, PDATA
JMP BATCH_32
#undef LOADED_0
#undef LOADED_1
#define LOADED Y2
BATCH_16:
CMPQ LENGTH, $16
JB COLLECT
VPMOVZXWD (PDATA), LOADED
VPSHUFB ENDIAN_SWAP_MASK, LOADED, LOADED
VPADDD LOADED, VSUM, VSUM
ADDQ $-16, LENGTH
ADDQ $16, PDATA
JMP BATCH_16
#undef LOADED
#define EXTRACTED Y2
#define EXTRACTED_128 X2
#define TEMP_64 DX
COLLECT:
VEXTRACTI128 $0, VSUM, EXTRACTED_128
VPEXTRD $0, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VPEXTRD $1, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VPEXTRD $2, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VPEXTRD $3, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VEXTRACTI128 $1, VSUM, EXTRACTED_128
VPEXTRD $0, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VPEXTRD $1, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VPEXTRD $2, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VPEXTRD $3, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
#undef EXTRACTED
#undef EXTRACTED_128
#undef TEMP_64
#define TEMP DX
#define TEMP2 SI
BATCH_2:
CMPQ LENGTH, $2
JB BATCH_1
XORQ TEMP, TEMP
MOVW (PDATA), TEMP
MOVQ TEMP, TEMP2
SHRW $8, TEMP2
SHLW $8, TEMP
ORW TEMP2, TEMP
ADDL TEMP, RESULT
ADDQ $-2, LENGTH
ADDQ $2, PDATA
JMP BATCH_2
#undef TEMP
#define TEMP DX
BATCH_1:
CMPQ LENGTH, $0
JZ RETURN
XORQ TEMP, TEMP
MOVB (PDATA), TEMP
SHLW $8, TEMP
ADDL TEMP, RESULT
#undef TEMP
RETURN:
MOVQ RESULT, result+16(FP)
RET

View file

@ -0,0 +1,51 @@
package clashtcpip
import (
"crypto/rand"
"testing"
"golang.org/x/sys/cpu"
)
func Test_SumAVX2(t *testing.T) {
if !cpu.X86.HasAVX2 {
t.Skipf("AVX2 unavailable")
}
bytes := make([]byte, chunkSize)
for size := 0; size <= chunkSize; size++ {
for count := 0; count < chunkCount; count++ {
_, err := rand.Reader.Read(bytes[:size])
if err != nil {
t.Skipf("Rand read failed: %v", err)
}
compat := SumCompat(bytes[:size])
avx := SumAVX2(bytes[:size])
if compat != avx {
t.Errorf("Sum of length=%d mismatched", size)
}
}
}
}
func Benchmark_SumAVX2(b *testing.B) {
if !cpu.X86.HasAVX2 {
b.Skipf("AVX2 unavailable")
}
bytes := make([]byte, chunkSize)
_, err := rand.Reader.Read(bytes)
if err != nil {
b.Skipf("Rand read failed: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
SumAVX2(bytes)
}
}

View file

@ -0,0 +1,24 @@
package clashtcpip
import (
"unsafe"
"golang.org/x/sys/cpu"
)
//go:noescape
func sumAsmNeon(data unsafe.Pointer, length uintptr) uintptr
func SumNeon(data []byte) uint32 {
if len(data) == 0 {
return 0
}
return uint32(sumAsmNeon(unsafe.Pointer(&data[0]), uintptr(len(data))))
}
func init() {
if cpu.ARM64.HasASIMD {
SumFnc = SumNeon
}
}

View file

@ -0,0 +1,118 @@
#include "textflag.h"
// func sumAsmNeon(data unsafe.Pointer, length uintptr) uintptr
//
// args (8 bytes aligned):
// data unsafe.Pointer - 8 bytes - 0 offset
// length uintptr - 8 bytes - 8 offset
// result uintptr - 8 bytes - 16 offset
#define PDATA R0
#define LENGTH R1
#define RESULT R2
#define VSUM V0
TEXT ·sumAsmNeon(SB),NOSPLIT,$0-24
MOVD data+0(FP), PDATA
MOVD length+8(FP), LENGTH
MOVD $0, RESULT
VMOVQ $0, $0, VSUM
#define LOADED_0 V1
#define LOADED_1 V2
#define LOADED_2 V3
#define LOADED_3 V4
BATCH_32:
CMP $32, LENGTH
BLO BATCH_16
VLD1 (PDATA), [LOADED_0.B8, LOADED_1.B8, LOADED_2.B8, LOADED_3.B8]
VREV16 LOADED_0.B8, LOADED_0.B8
VREV16 LOADED_1.B8, LOADED_1.B8
VREV16 LOADED_2.B8, LOADED_2.B8
VREV16 LOADED_3.B8, LOADED_3.B8
VUSHLL $0, LOADED_0.H4, LOADED_0.S4
VUSHLL $0, LOADED_1.H4, LOADED_1.S4
VUSHLL $0, LOADED_2.H4, LOADED_2.S4
VUSHLL $0, LOADED_3.H4, LOADED_3.S4
VADD LOADED_0.S4, VSUM.S4, VSUM.S4
VADD LOADED_1.S4, VSUM.S4, VSUM.S4
VADD LOADED_2.S4, VSUM.S4, VSUM.S4
VADD LOADED_3.S4, VSUM.S4, VSUM.S4
ADD $-32, LENGTH
ADD $32, PDATA
B BATCH_32
#undef LOADED_0
#undef LOADED_1
#undef LOADED_2
#undef LOADED_3
#define LOADED_0 V1
#define LOADED_1 V2
BATCH_16:
CMP $16, LENGTH
BLO BATCH_8
VLD1 (PDATA), [LOADED_0.B8, LOADED_1.B8]
VREV16 LOADED_0.B8, LOADED_0.B8
VREV16 LOADED_1.B8, LOADED_1.B8
VUSHLL $0, LOADED_0.H4, LOADED_0.S4
VUSHLL $0, LOADED_1.H4, LOADED_1.S4
VADD LOADED_0.S4, VSUM.S4, VSUM.S4
VADD LOADED_1.S4, VSUM.S4, VSUM.S4
ADD $-16, LENGTH
ADD $16, PDATA
B BATCH_16
#undef LOADED_0
#undef LOADED_1
#define LOADED_0 V1
BATCH_8:
CMP $8, LENGTH
BLO BATCH_2
VLD1 (PDATA), [LOADED_0.B8]
VREV16 LOADED_0.B8, LOADED_0.B8
VUSHLL $0, LOADED_0.H4, LOADED_0.S4
VADD LOADED_0.S4, VSUM.S4, VSUM.S4
ADD $-8, LENGTH
ADD $8, PDATA
B BATCH_8
#undef LOADED_0
#define LOADED_L R3
#define LOADED_H R4
BATCH_2:
CMP $2, LENGTH
BLO BATCH_1
MOVBU (PDATA), LOADED_H
MOVBU 1(PDATA), LOADED_L
LSL $8, LOADED_H
ORR LOADED_H, LOADED_L, LOADED_L
ADD LOADED_L, RESULT, RESULT
ADD $2, PDATA
ADD $-2, LENGTH
B BATCH_2
#undef LOADED_H
#undef LOADED_L
#define LOADED R3
BATCH_1:
CMP $1, LENGTH
BLO COLLECT
MOVBU (PDATA), LOADED
LSL $8, LOADED
ADD LOADED, RESULT, RESULT
#define EXTRACTED R3
COLLECT:
VMOV VSUM.S[0], EXTRACTED
ADD EXTRACTED, RESULT
VMOV VSUM.S[1], EXTRACTED
ADD EXTRACTED, RESULT
VMOV VSUM.S[2], EXTRACTED
ADD EXTRACTED, RESULT
VMOV VSUM.S[3], EXTRACTED
ADD EXTRACTED, RESULT
#undef VSUM
#undef PDATA
#undef LENGTH
RETURN:
MOVD RESULT, result+16(FP)
RET

View file

@ -0,0 +1,51 @@
package clashtcpip
import (
"crypto/rand"
"testing"
"golang.org/x/sys/cpu"
)
func Test_SumNeon(t *testing.T) {
if !cpu.ARM64.HasASIMD {
t.Skipf("Neon unavailable")
}
bytes := make([]byte, chunkSize)
for size := 0; size <= chunkSize; size++ {
for count := 0; count < chunkCount; count++ {
_, err := rand.Reader.Read(bytes[:size])
if err != nil {
t.Skipf("Rand read failed: %v", err)
}
compat := SumCompat(bytes[:size])
neon := SumNeon(bytes[:size])
if compat != neon {
t.Errorf("Sum of length=%d mismatched", size)
}
}
}
}
func Benchmark_SumNeon(b *testing.B) {
if !cpu.ARM64.HasASIMD {
b.Skipf("Neon unavailable")
}
bytes := make([]byte, chunkSize)
_, err := rand.Reader.Read(bytes)
if err != nil {
b.Skipf("Rand read failed: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
SumNeon(bytes)
}
}

View file

@ -0,0 +1,14 @@
package clashtcpip
func SumCompat(b []byte) (sum uint32) {
n := len(b)
if n&1 != 0 {
n--
sum += uint32(b[n]) << 8
}
for i := 0; i < n; i += 2 {
sum += (uint32(b[i]) << 8) | uint32(b[i+1])
}
return
}

View file

@ -0,0 +1,26 @@
package clashtcpip
import (
"crypto/rand"
"testing"
)
const (
chunkSize = 9000
chunkCount = 10
)
func Benchmark_SumCompat(b *testing.B) {
bytes := make([]byte, chunkSize)
_, err := rand.Reader.Read(bytes)
if err != nil {
b.Skipf("Rand read failed: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
SumCompat(bytes)
}
}

View file

@ -0,0 +1,55 @@
package clashtcpip
import (
"encoding/binary"
)
const UDPHeaderSize = 8
type UDPPacket []byte
func (p UDPPacket) Length() uint16 {
return binary.BigEndian.Uint16(p[4:])
}
func (p UDPPacket) SetLength(length uint16) {
binary.BigEndian.PutUint16(p[4:], length)
}
func (p UDPPacket) SourcePort() uint16 {
return binary.BigEndian.Uint16(p)
}
func (p UDPPacket) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(p, port)
}
func (p UDPPacket) DestinationPort() uint16 {
return binary.BigEndian.Uint16(p[2:])
}
func (p UDPPacket) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(p[2:], port)
}
func (p UDPPacket) Payload() []byte {
return p[UDPHeaderSize:p.Length()]
}
func (p UDPPacket) Checksum() uint16 {
return binary.BigEndian.Uint16(p[6:])
}
func (p UDPPacket) SetChecksum(sum [2]byte) {
p[6] = sum[0]
p[7] = sum[1]
}
func (p UDPPacket) ResetChecksum(psum uint32) {
p.SetChecksum(zeroChecksum)
p.SetChecksum(Checksum(psum, p))
}
func (p UDPPacket) Valid() bool {
return len(p) >= UDPHeaderSize && uint16(len(p)) >= p.Length()
}

30
lwip.go
View file

@ -7,7 +7,6 @@ import (
"net"
"net/netip"
"os"
"runtime"
lwip "github.com/sagernet/go-tun2socks/core"
"github.com/sagernet/sing/common"
@ -28,19 +27,15 @@ type LWIP struct {
}
func NewLWIP(
ctx context.Context,
tun Tun,
tunMtu uint32,
udpTimeout int64,
handler Handler,
options StackOptions,
) (Stack, error) {
return &LWIP{
ctx: ctx,
tun: tun,
tunMtu: tunMtu,
handler: handler,
ctx: options.Context,
tun: options.Tun,
tunMtu: options.MTU,
handler: options.Handler,
stack: lwip.NewLWIPStack(),
udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler),
udpNat: udpnat.New[netip.AddrPort](options.UDPTimeout, options.Handler),
}, nil
}
@ -57,10 +52,7 @@ func (l *LWIP) loopIn() {
l.loopInWintun(winTun)
return
}
mtu := int(l.tunMtu)
if runtime.GOOS == "darwin" {
mtu += 4
}
mtu := int(l.tunMtu) + PacketOffset
_buffer := buf.StackNewSize(mtu)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
@ -71,13 +63,7 @@ func (l *LWIP) loopIn() {
if err != nil {
return
}
var packet []byte
if runtime.GOOS == "darwin" {
packet = data[4:n]
} else {
packet = data[:n]
}
_, err = l.stack.Write(packet)
_, err = l.stack.Write(data[PacketOffset:n])
if err != nil {
if err.Error() == "stack closed" {
return

View file

@ -2,14 +2,8 @@
package tun
import "context"
func NewLWIP(
ctx context.Context,
tun Tun,
tunMtu uint32,
udpTimeout int64,
handler Handler,
options StackOptions,
) (Stack, error) {
return nil, ErrLWIPNotIncluded
}

View file

@ -2,6 +2,7 @@ package tun
import (
"context"
"net/netip"
E "github.com/sagernet/sing/common/exceptions"
)
@ -17,20 +18,29 @@ type Stack interface {
Close() error
}
type StackOptions struct {
Context context.Context
Tun Tun
Name string
MTU uint32
Inet4Address []netip.Prefix
Inet6Address []netip.Prefix
EndpointIndependentNat bool
UDPTimeout int64
Handler Handler
}
func NewStack(
ctx context.Context,
stack string,
tun Tun,
tunMtu uint32,
endpointIndependentNat bool,
udpTimeout int64,
handler Handler,
options StackOptions,
) (Stack, error) {
switch stack {
case "gvisor", "":
return NewGVisor(ctx, tun, tunMtu, endpointIndependentNat, udpTimeout, handler)
return NewGVisor(options)
case "system":
return NewSystem(options)
case "lwip":
return NewLWIP(ctx, tun, tunMtu, udpTimeout, handler)
return NewLWIP(options)
default:
return nil, E.New("unknown stack: ", stack)
}

409
system.go Normal file
View file

@ -0,0 +1,409 @@
package tun
import (
"context"
"net"
"net/netip"
"time"
"github.com/sagernet/sing-tun/internal/clashtcpip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
)
type System struct {
ctx context.Context
tun Tun
mtu uint32
handler Handler
inet4Prefixes []netip.Prefix
inet6Prefixes []netip.Prefix
inet4ServerAddress netip.Addr
inet4Address netip.Addr
inet6ServerAddress netip.Addr
inet6Address netip.Addr
udpTimeout int64
tcpListener net.Listener
tcpListener6 net.Listener
tcpPort uint16
tcpPort6 uint16
tcpNat *TCPNat
udpNat *udpnat.Service[netip.AddrPort]
}
type Session struct {
SourceAddress netip.Addr
DestinationAddress netip.Addr
SourcePort uint16
DestinationPort uint16
}
func NewSystem(options StackOptions) (Stack, error) {
stack := &System{
ctx: options.Context,
tun: options.Tun,
mtu: options.MTU,
udpTimeout: options.UDPTimeout,
handler: options.Handler,
inet4Prefixes: options.Inet4Address,
inet6Prefixes: options.Inet6Address,
}
if len(options.Inet4Address) > 0 {
if options.Inet4Address[0].Bits() == 32 {
return nil, E.New("need one more IPv4 address in first prefix for system stack")
}
stack.inet4ServerAddress = options.Inet4Address[0].Addr()
stack.inet4Address = stack.inet4ServerAddress.Next()
}
if len(options.Inet6Address) > 0 {
if options.Inet6Address[0].Bits() == 128 {
return nil, E.New("need one more IPv6 address in first prefix for system stack")
}
stack.inet6ServerAddress = options.Inet6Address[0].Addr()
stack.inet6Address = stack.inet6ServerAddress.Next()
}
if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() {
return nil, E.New("missing interface address")
}
return stack, nil
}
func (s *System) Close() error {
return common.Close(
s.tcpListener,
s.tcpListener6,
)
}
func (s *System) Start() error {
if s.inet4Address.IsValid() {
tcpListener, err := net.Listen("tcp4", net.JoinHostPort(s.inet4ServerAddress.String(), "0"))
if err != nil {
return err
}
s.tcpListener = tcpListener
s.tcpPort = M.SocksaddrFromNet(tcpListener.Addr()).Port
go s.acceptLoop(tcpListener)
}
if s.inet6Address.IsValid() {
tcpListener, err := net.Listen("tcp6", net.JoinHostPort(s.inet6ServerAddress.String(), "0"))
if err != nil {
return err
}
s.tcpListener6 = tcpListener
s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port
go s.acceptLoop(tcpListener)
}
s.tcpNat = NewNat()
s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler)
go s.tunLoop()
return nil
}
func (s *System) tunLoop() {
if winTun, isWinTun := s.tun.(WinTun); isWinTun {
s.wintunLoop(winTun)
return
}
_packetBuffer := buf.StackNewSize(int(s.mtu))
defer common.KeepAlive(_packetBuffer)
packetBuffer := common.Dup(_packetBuffer)
defer packetBuffer.Release()
packetSlice := packetBuffer.Slice()
for {
n, err := s.tun.Read(packetSlice)
if err != nil {
return
}
if n < clashtcpip.IPv4PacketMinLength {
continue
}
packet := packetSlice[PacketOffset:n]
switch packet[0] >> 4 {
case 4:
s.processIPv4(packet)
case 6:
s.processIPv6(packet)
}
}
}
func (s *System) wintunLoop(winTun WinTun) {
for {
packet, release, err := winTun.ReadPacket()
if err != nil {
return
}
if len(packet) < clashtcpip.IPv4PacketMinLength {
release()
continue
}
switch packet[0] >> 4 {
case 4:
s.processIPv4(packet)
case 6:
s.processIPv6(packet)
}
release()
}
}
func (s *System) acceptLoop(listener net.Listener) {
for {
conn, err := listener.Accept()
if err != nil {
return
}
connPort := M.SocksaddrFromNet(conn.RemoteAddr()).Port
session := s.tcpNat.LookupBack(connPort)
if session == nil {
s.handler.NewError(context.Background(), E.New("unknown session with port ", connPort))
continue
}
destination := M.SocksaddrFromNetIP(session.Destination)
if destination.Addr.Is4() {
for _, prefix := range s.inet4Prefixes {
if prefix.Contains(destination.Addr) {
destination.Addr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
break
}
}
} else {
for _, prefix := range s.inet6Prefixes {
if prefix.Contains(destination.Addr) {
destination.Addr = netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
break
}
}
}
go func() {
s.handler.NewConnection(context.Background(), conn, M.Metadata{
Source: M.SocksaddrFromNetIP(session.Source),
Destination: destination,
})
conn.Close()
time.Sleep(time.Second)
s.tcpNat.Revoke(connPort, session)
}()
}
}
func (s *System) NewError(ctx context.Context, err error) {
s.handler.NewError(ctx, err)
}
func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error {
if !packet.Valid() {
return E.New("ipv4: invalid packet")
}
if packet.TimeToLive() == 0x00 {
return E.New("ipv4: TTL exceeded")
}
switch packet.Protocol() {
case clashtcpip.TCP:
return s.processIPv4TCP(packet, packet.Payload())
case clashtcpip.UDP:
return s.processIPv4UDP(packet, packet.Payload())
case clashtcpip.ICMP:
return s.processIPv4ICMP(packet, packet.Payload())
default:
return nil
}
}
func (s *System) processIPv6(packet clashtcpip.IPv6Packet) error {
if !packet.Valid() {
return E.New("ipv6: invalid packet")
}
if packet.HopLimit() == 0x00 {
return E.New("ipv6: TTL exceeded")
}
switch packet.Protocol() {
case clashtcpip.TCP:
return s.processIPv6TCP(packet, packet.Payload())
case clashtcpip.UDP:
return s.processIPv6UDP(packet, packet.Payload())
case clashtcpip.ICMPv6:
return s.processIPv6ICMP(packet, packet.Payload())
default:
return nil
}
}
func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error {
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort {
session := s.tcpNat.LookupBack(destination.Port())
if session == nil {
return E.New("session not found: ", destination.Port())
}
packet.SetSourceIP(session.Destination.Addr())
header.SetSourcePort(session.Destination.Port())
packet.SetDestinationIP(session.Source.Addr())
header.SetDestinationPort(session.Source.Port())
} else {
natPort := s.tcpNat.Lookup(source, destination)
packet.SetSourceIP(s.inet4Address)
header.SetSourcePort(natPort)
packet.SetDestinationIP(s.inet4ServerAddress)
header.SetDestinationPort(s.tcpPort)
}
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(s.tun.Write(packet))
}
func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error {
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 {
session := s.tcpNat.LookupBack(destination.Port())
if session == nil {
return E.New("session not found: ", destination.Port())
}
packet.SetSourceIP(session.Destination.Addr())
header.SetSourcePort(session.Destination.Port())
packet.SetDestinationIP(session.Source.Addr())
header.SetDestinationPort(session.Source.Port())
} else {
natPort := s.tcpNat.Lookup(source, destination)
packet.SetSourceIP(s.inet6Address)
header.SetSourcePort(natPort)
packet.SetDestinationIP(s.inet6ServerAddress)
header.SetDestinationPort(s.tcpPort6)
}
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(s.tun.Write(packet))
}
func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error {
if packet.Flags()&clashtcpip.FlagMoreFragment != 0 {
return E.New("ipv4: fragment dropped")
}
if packet.FragmentOffset() != 0 {
return E.New("ipv4: fragment dropped")
}
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() || destination.Addr().IsMulticast() {
return nil
}
data := buf.As(header.Payload()).ToOwned()
metadata := M.Metadata{
Source: M.SocksaddrFromNetIP(source),
Destination: M.SocksaddrFromNetIP(destination),
}
s.udpNat.NewPacket(s.ctx, source, data, metadata, func(natConn N.PacketConn) N.PacketWriter {
hdr := buf.As(packet[:packet.HeaderLen()+clashtcpip.UDPHeaderSize]).ToOwned()
return &systemPacketWriter4{s.tun, hdr, source}
})
return nil
}
func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error {
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() || destination.Addr().IsMulticast() {
return nil
}
data := buf.As(header.Payload()).ToOwned()
metadata := M.Metadata{
Source: M.SocksaddrFromNetIP(source),
Destination: M.SocksaddrFromNetIP(destination),
}
s.udpNat.NewPacket(s.ctx, source, data, metadata, func(natConn N.PacketConn) N.PacketWriter {
hdr := buf.As(packet[:len(packet)-len(header.Payload())]).ToOwned()
return &systemPacketWriter6{s.tun, hdr, source}
})
return nil
}
func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error {
if header.Type() != clashtcpip.ICMPTypePingRequest || header.Code() != 0 {
return nil
}
header.SetType(clashtcpip.ICMPTypePingResponse)
sourceAddress := packet.SourceIP()
packet.SetSourceIP(packet.DestinationIP())
packet.SetDestinationIP(sourceAddress)
header.ResetChecksum()
packet.ResetChecksum()
return common.Error(s.tun.Write(packet))
}
func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error {
if header.Type() != clashtcpip.ICMPv6EchoRequest || header.Code() != 0 {
return nil
}
header.SetType(clashtcpip.ICMPv6EchoReply)
sourceAddress := packet.SourceIP()
packet.SetSourceIP(packet.DestinationIP())
packet.SetDestinationIP(sourceAddress)
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(s.tun.Write(packet))
}
type systemPacketWriter4 struct {
tun Tun
header *buf.Buffer
source netip.AddrPort
}
func (w *systemPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
newPacket := buf.StackNewSize(w.header.Len() + buffer.Len())
defer newPacket.Release()
newPacket.Write(w.header.Bytes())
newPacket.Write(buffer.Bytes())
ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes())
ipHdr.SetTotalLength(uint16(newPacket.Len()))
ipHdr.SetDestinationIP(ipHdr.SourceIP())
ipHdr.SetSourceIP(destination.Unwrap().Addr)
udpHdr := clashtcpip.UDPPacket(ipHdr.Payload())
udpHdr.SetDestinationPort(udpHdr.SourcePort())
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize))
udpHdr.ResetChecksum(ipHdr.PseudoSum())
ipHdr.ResetChecksum()
return common.Error(w.tun.Write(newPacket.Bytes()))
}
func (w *systemPacketWriter4) Close() error {
w.header.Release()
return nil
}
type systemPacketWriter6 struct {
tun Tun
header *buf.Buffer
source netip.AddrPort
}
func (w *systemPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
newPacket := buf.StackNewSize(w.header.Len() + buffer.Len())
defer newPacket.Release()
newPacket.Write(w.header.Bytes())
newPacket.Write(buffer.Bytes())
ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes())
udpLen := uint16(clashtcpip.UDPHeaderSize + buffer.Len())
ipHdr.SetPayloadLength(udpLen)
ipHdr.SetDestinationIP(ipHdr.SourceIP())
ipHdr.SetSourceIP(destination.Addr)
udpHdr := clashtcpip.UDPPacket(ipHdr.Payload())
udpHdr.SetDestinationPort(udpHdr.SourcePort())
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(udpLen)
udpHdr.ResetChecksum(ipHdr.PseudoSum())
return common.Error(w.tun.Write(newPacket.Bytes()))
}
func (w *systemPacketWriter6) Close() error {
w.header.Release()
return nil
}

68
system_nat.go Normal file
View file

@ -0,0 +1,68 @@
package tun
import (
"net/netip"
"sync"
)
type TCPNat struct {
portIndex uint16
portAccess sync.RWMutex
addrAccess sync.RWMutex
addrMap map[netip.AddrPort]uint16
portMap map[uint16]*TCPSession
}
type TCPSession struct {
Source netip.AddrPort
Destination netip.AddrPort
}
func NewNat() *TCPNat {
return &TCPNat{
portIndex: 10000,
addrMap: make(map[netip.AddrPort]uint16),
portMap: make(map[uint16]*TCPSession),
}
}
func (n *TCPNat) LookupBack(port uint16) *TCPSession {
n.portAccess.RLock()
defer n.portAccess.RUnlock()
return n.portMap[port]
}
func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort) uint16 {
n.addrAccess.RLock()
port, loaded := n.addrMap[source]
n.addrAccess.RUnlock()
if loaded {
return port
}
n.addrAccess.Lock()
nextPort := n.portIndex
if nextPort == 0 {
nextPort = 10000
n.portIndex = 10001
} else {
n.portIndex++
}
n.addrMap[source] = nextPort
n.addrAccess.Unlock()
n.portAccess.Lock()
n.portMap[nextPort] = &TCPSession{
Source: source,
Destination: destination,
}
n.portAccess.Unlock()
return nextPort
}
func (n *TCPNat) Revoke(natPort uint16, session *TCPSession) {
n.addrAccess.Lock()
delete(n.addrMap, session.Source)
n.addrAccess.Unlock()
n.portAccess.Lock()
delete(n.portMap, natPort)
n.portAccess.Unlock()
}

View file

@ -19,6 +19,8 @@ import (
"golang.org/x/sys/unix"
)
const PacketOffset = 4
type NativeTun struct {
tunFile *os.File
tunWriter N.VectorisedWriter

5
tun_nondarwin.go Normal file
View file

@ -0,0 +1,5 @@
//go:build !darwin
package tun
const PacketOffset = 0

View file

@ -62,38 +62,37 @@ func (t *NativeTun) configure() error {
if err != nil {
return E.Cause(err, "set ipv4 address")
}
err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), []netip.Addr{t.options.Inet4Address[0].Addr().Next()}, nil)
if err != nil {
return E.Cause(err, "set ipv4 dns")
}
}
if len(t.options.Inet6Address) > 0 {
err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET6), t.options.Inet6Address)
if err != nil {
return E.Cause(err, "set ipv6 address")
}
}
err := luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), []netip.Addr{t.options.Inet4Address[0].Addr().Next()}, nil)
if err != nil {
return E.Cause(err, "set ipv4 dns")
}
err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), []netip.Addr{t.options.Inet6Address[0].Addr().Next()}, nil)
if err != nil {
return E.Cause(err, "set ipv6 dns")
err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), []netip.Addr{t.options.Inet6Address[0].Addr().Next()}, nil)
if err != nil {
return E.Cause(err, "set ipv6 dns")
}
}
if t.options.AutoRoute {
if len(t.options.Inet4Address) > 0 {
err = luid.AddRoute(netip.PrefixFrom(netip.IPv4Unspecified(), 0), netip.IPv4Unspecified(), 0)
err := luid.AddRoute(netip.PrefixFrom(netip.IPv4Unspecified(), 0), netip.IPv4Unspecified(), 0)
if err != nil {
return E.Cause(err, "set ipv4 route")
}
}
if len(t.options.Inet6Address) > 0 {
err = luid.AddRoute(netip.PrefixFrom(netip.IPv6Unspecified(), 0), netip.IPv6Unspecified(), 0)
err := luid.AddRoute(netip.PrefixFrom(netip.IPv6Unspecified(), 0), netip.IPv6Unspecified(), 0)
if err != nil {
return E.Cause(err, "set ipv6 route")
}
}
}
if len(t.options.Inet4Address) > 0 {
var inetIf *winipcfg.MibIPInterfaceRow
inetIf, err = luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET))
inetIf, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET))
if err != nil {
return err
}
@ -113,8 +112,7 @@ func (t *NativeTun) configure() error {
}
}
if len(t.options.Inet6Address) > 0 {
var inet6If *winipcfg.MibIPInterfaceRow
inet6If, err = luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET6))
inet6If, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET6))
if err != nil {
return err
}