From 297dd632e8fa0c4dfc813de9378c9d44c6737d3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 26 Jan 2025 09:01:00 +0800 Subject: [PATCH] Add TLS fragment support --- adapter/inbound.go | 2 + common/process/searcher.go | 1 + common/process/searcher_windows.go | 199 ++--------------------------- common/tlsfragment/conn.go | 107 ++++++++++++++++ common/tlsfragment/index.go | 131 +++++++++++++++++++ common/tlsfragment/wait_darwin.go | 93 ++++++++++++++ common/tlsfragment/wait_linux.go | 40 ++++++ common/tlsfragment/wait_stub.go | 14 ++ common/tlsfragment/wait_windows.go | 28 ++++ constant/timeout.go | 1 + go.mod | 2 +- go.sum | 4 +- option/rule_action.go | 3 + route/conn.go | 16 +++ route/route.go | 4 + route/rule/rule_action.go | 12 ++ transport/simple-obfs/http.go | 4 + transport/simple-obfs/tls.go | 4 + 18 files changed, 476 insertions(+), 189 deletions(-) create mode 100644 common/tlsfragment/conn.go create mode 100644 common/tlsfragment/index.go create mode 100644 common/tlsfragment/wait_darwin.go create mode 100644 common/tlsfragment/wait_linux.go create mode 100644 common/tlsfragment/wait_stub.go create mode 100644 common/tlsfragment/wait_windows.go diff --git a/adapter/inbound.go b/adapter/inbound.go index 11085099..1218c049 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -72,6 +72,8 @@ type InboundContext struct { UDPDisableDomainUnmapping bool UDPConnect bool UDPTimeout time.Duration + TLSFragment bool + TLSFragmentFallbackDelay time.Duration NetworkStrategy *C.NetworkStrategy NetworkType []C.InterfaceType diff --git a/common/process/searcher.go b/common/process/searcher.go index cee81068..d525b3c1 100644 --- a/common/process/searcher.go +++ b/common/process/searcher.go @@ -23,6 +23,7 @@ type Config struct { } type Info struct { + ProcessID uint32 ProcessPath string PackageName string User string diff --git a/common/process/searcher_windows.go b/common/process/searcher_windows.go index 5b3d59b5..b7d89dda 100644 --- a/common/process/searcher_windows.go +++ b/common/process/searcher_windows.go @@ -2,14 +2,11 @@ package process import ( "context" - "fmt" "net/netip" - "os" "syscall" - "unsafe" E "github.com/sagernet/sing/common/exceptions" - N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/winiphlpapi" "golang.org/x/sys/windows" ) @@ -26,209 +23,39 @@ func NewSearcher(_ Config) (Searcher, error) { return &windowsSearcher{}, nil } -var ( - modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") - procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable") - procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable") - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - procQueryFullProcessImageNameW = modkernel32.NewProc("QueryFullProcessImageNameW") -) - func initWin32API() error { - err := modiphlpapi.Load() - if err != nil { - return E.Cause(err, "load iphlpapi.dll") - } - - err = procGetExtendedTcpTable.Find() - if err != nil { - return E.Cause(err, "load iphlpapi::GetExtendedTcpTable") - } - - err = procGetExtendedUdpTable.Find() - if err != nil { - return E.Cause(err, "load iphlpapi::GetExtendedUdpTable") - } - - err = modkernel32.Load() - if err != nil { - return E.Cause(err, "load kernel32.dll") - } - - err = procQueryFullProcessImageNameW.Find() - if err != nil { - return E.Cause(err, "load kernel32::QueryFullProcessImageNameW") - } - - return nil + return winiphlpapi.LoadExtendedTable() } func (s *windowsSearcher) FindProcessInfo(ctx context.Context, network string, source netip.AddrPort, destination netip.AddrPort) (*Info, error) { - processName, err := findProcessName(network, source.Addr(), int(source.Port())) + pid, err := winiphlpapi.FindPid(network, source) if err != nil { return nil, err } - return &Info{ProcessPath: processName, UserId: -1}, nil -} - -func findProcessName(network string, ip netip.Addr, srcPort int) (string, error) { - family := windows.AF_INET - if ip.Is6() { - family = windows.AF_INET6 - } - - const ( - tcpTablePidConn = 4 - udpTablePid = 1 - ) - - var class int - var fn uintptr - switch network { - case N.NetworkTCP: - fn = procGetExtendedTcpTable.Addr() - class = tcpTablePidConn - case N.NetworkUDP: - fn = procGetExtendedUdpTable.Addr() - class = udpTablePid - default: - return "", os.ErrInvalid - } - - buf, err := getTransportTable(fn, family, class) + path, err := getProcessPath(pid) if err != nil { - return "", err + return &Info{ProcessID: pid, UserId: -1}, err } - - s := newSearcher(family == windows.AF_INET, network == N.NetworkTCP) - - pid, err := s.Search(buf, ip, uint16(srcPort)) - if err != nil { - return "", err - } - return getExecPathFromPID(pid) + return &Info{ProcessID: pid, ProcessPath: path, UserId: -1}, nil } -type searcher struct { - itemSize int - port int - ip int - ipSize int - pid int - tcpState int -} - -func (s *searcher) Search(b []byte, ip netip.Addr, port uint16) (uint32, error) { - n := int(readNativeUint32(b[:4])) - itemSize := s.itemSize - for i := 0; i < n; i++ { - row := b[4+itemSize*i : 4+itemSize*(i+1)] - - if s.tcpState >= 0 { - tcpState := readNativeUint32(row[s.tcpState : s.tcpState+4]) - // MIB_TCP_STATE_ESTAB, only check established connections for TCP - if tcpState != 5 { - continue - } - } - - // according to MSDN, only the lower 16 bits of dwLocalPort are used and the port number is in network endian. - // this field can be illustrated as follows depends on different machine endianess: - // little endian: [ MSB LSB 0 0 ] interpret as native uint32 is ((LSB<<8)|MSB) - // big endian: [ 0 0 MSB LSB ] interpret as native uint32 is ((MSB<<8)|LSB) - // so we need an syscall.Ntohs on the lower 16 bits after read the port as native uint32 - srcPort := syscall.Ntohs(uint16(readNativeUint32(row[s.port : s.port+4]))) - if srcPort != port { - continue - } - - srcIP, _ := netip.AddrFromSlice(row[s.ip : s.ip+s.ipSize]) - // windows binds an unbound udp socket to 0.0.0.0/[::] while first sendto - if ip != srcIP && (!srcIP.IsUnspecified() || s.tcpState != -1) { - continue - } - - pid := readNativeUint32(row[s.pid : s.pid+4]) - return pid, nil - } - return 0, ErrNotFound -} - -func newSearcher(isV4, isTCP bool) *searcher { - var itemSize, port, ip, ipSize, pid int - tcpState := -1 - switch { - case isV4 && isTCP: - // struct MIB_TCPROW_OWNER_PID - itemSize, port, ip, ipSize, pid, tcpState = 24, 8, 4, 4, 20, 0 - case isV4 && !isTCP: - // struct MIB_UDPROW_OWNER_PID - itemSize, port, ip, ipSize, pid = 12, 4, 0, 4, 8 - case !isV4 && isTCP: - // struct MIB_TCP6ROW_OWNER_PID - itemSize, port, ip, ipSize, pid, tcpState = 56, 20, 0, 16, 52, 48 - case !isV4 && !isTCP: - // struct MIB_UDP6ROW_OWNER_PID - itemSize, port, ip, ipSize, pid = 28, 20, 0, 16, 24 - } - - return &searcher{ - itemSize: itemSize, - port: port, - ip: ip, - ipSize: ipSize, - pid: pid, - tcpState: tcpState, - } -} - -func getTransportTable(fn uintptr, family int, class int) ([]byte, error) { - for size, buf := uint32(8), make([]byte, 8); ; { - ptr := unsafe.Pointer(&buf[0]) - err, _, _ := syscall.SyscallN(fn, uintptr(ptr), uintptr(unsafe.Pointer(&size)), 0, uintptr(family), uintptr(class), 0) - - switch err { - case 0: - return buf, nil - case uintptr(syscall.ERROR_INSUFFICIENT_BUFFER): - buf = make([]byte, size) - default: - return nil, fmt.Errorf("syscall error: %d", err) - } - } -} - -func readNativeUint32(b []byte) uint32 { - return *(*uint32)(unsafe.Pointer(&b[0])) -} - -func getExecPathFromPID(pid uint32) (string, error) { - // kernel process starts with a colon in order to distinguish with normal processes +func getProcessPath(pid uint32) (string, error) { switch pid { case 0: - // reserved pid for system idle process return ":System Idle Process", nil case 4: - // reserved pid for windows kernel image return ":System", nil } - h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid) + handle, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid) if err != nil { return "", err } - defer windows.CloseHandle(h) - + defer windows.CloseHandle(handle) + size := uint32(syscall.MAX_LONG_PATH) buf := make([]uint16, syscall.MAX_LONG_PATH) - size := uint32(len(buf)) - r1, _, err := syscall.SyscallN( - procQueryFullProcessImageNameW.Addr(), - uintptr(h), - uintptr(0), - uintptr(unsafe.Pointer(&buf[0])), - uintptr(unsafe.Pointer(&size)), - ) - if r1 == 0 { + err = windows.QueryFullProcessImageName(handle, 0, &buf[0], &size) + if err != nil { return "", err } - return syscall.UTF16ToString(buf[:size]), nil + return windows.UTF16ToString(buf[:size]), nil } diff --git a/common/tlsfragment/conn.go b/common/tlsfragment/conn.go new file mode 100644 index 00000000..6f2a3dad --- /dev/null +++ b/common/tlsfragment/conn.go @@ -0,0 +1,107 @@ +package tf + +import ( + "context" + "math/rand" + "net" + "strings" + "time" + + N "github.com/sagernet/sing/common/network" + + "golang.org/x/net/publicsuffix" +) + +type Conn struct { + net.Conn + tcpConn *net.TCPConn + ctx context.Context + firstPacketWritten bool + fallbackDelay time.Duration +} + +func NewConn(conn net.Conn, ctx context.Context, fallbackDelay time.Duration) (*Conn, error) { + tcpConn, _ := N.UnwrapReader(conn).(*net.TCPConn) + return &Conn{ + Conn: conn, + tcpConn: tcpConn, + ctx: ctx, + fallbackDelay: fallbackDelay, + }, nil +} + +func (c *Conn) Write(b []byte) (n int, err error) { + if !c.firstPacketWritten { + defer func() { + c.firstPacketWritten = true + }() + serverName := indexTLSServerName(b) + if serverName != nil { + if c.tcpConn != nil { + err = c.tcpConn.SetNoDelay(true) + if err != nil { + return + } + } + splits := strings.Split(serverName.ServerName, ".") + currentIndex := serverName.Index + if publicSuffix := publicsuffix.List.PublicSuffix(serverName.ServerName); publicSuffix != "" { + splits = splits[:len(splits)-strings.Count(serverName.ServerName, ".")] + } + if len(splits) > 1 && splits[0] == "..." { + currentIndex += len(splits[0]) + 1 + splits = splits[1:] + } + var splitIndexes []int + for i, split := range splits { + splitAt := rand.Intn(len(split)) + splitIndexes = append(splitIndexes, currentIndex+splitAt) + currentIndex += len(split) + if i != len(splits)-1 { + currentIndex++ + } + } + for i := 0; i <= len(splitIndexes); i++ { + var payload []byte + if i == 0 { + payload = b[:splitIndexes[i]] + } else if i == len(splitIndexes) { + payload = b[splitIndexes[i-1]:] + } else { + payload = b[splitIndexes[i-1]:splitIndexes[i]] + } + if c.tcpConn != nil && i != len(splitIndexes) { + err = writeAndWaitAck(c.ctx, c.tcpConn, payload, c.fallbackDelay) + if err != nil { + return + } + } else { + _, err = c.Conn.Write(payload) + if err != nil { + return + } + } + } + if c.tcpConn != nil { + err = c.tcpConn.SetNoDelay(false) + if err != nil { + return + } + } + return len(b), nil + } + } + return c.Conn.Write(b) +} + +func (c *Conn) ReaderReplaceable() bool { + return true +} + +func (c *Conn) WriterReplaceable() bool { + return c.firstPacketWritten +} + +func (c *Conn) Upstream() any { + return c.Conn +} diff --git a/common/tlsfragment/index.go b/common/tlsfragment/index.go new file mode 100644 index 00000000..59031cec --- /dev/null +++ b/common/tlsfragment/index.go @@ -0,0 +1,131 @@ +package tf + +import ( + "encoding/binary" +) + +const ( + recordLayerHeaderLen int = 5 + handshakeHeaderLen int = 6 + randomDataLen int = 32 + sessionIDHeaderLen int = 1 + cipherSuiteHeaderLen int = 2 + compressMethodHeaderLen int = 1 + extensionsHeaderLen int = 2 + extensionHeaderLen int = 4 + sniExtensionHeaderLen int = 5 + contentType uint8 = 22 + handshakeType uint8 = 1 + sniExtensionType uint16 = 0 + sniNameDNSHostnameType uint8 = 0 + tlsVersionBitmask uint16 = 0xFFFC + tls13 uint16 = 0x0304 +) + +type myServerName struct { + Index int + Length int + ServerName string +} + +func indexTLSServerName(payload []byte) *myServerName { + if len(payload) < recordLayerHeaderLen || payload[0] != contentType { + return nil + } + segmentLen := binary.BigEndian.Uint16(payload[3:5]) + if len(payload) < recordLayerHeaderLen+int(segmentLen) { + return nil + } + serverName := indexTLSServerNameFromHandshake(payload[recordLayerHeaderLen : recordLayerHeaderLen+int(segmentLen)]) + if serverName == nil { + return nil + } + serverName.Length += recordLayerHeaderLen + return serverName +} + +func indexTLSServerNameFromHandshake(hs []byte) *myServerName { + if len(hs) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen { + return nil + } + if hs[0] != handshakeType { + return nil + } + handshakeLen := uint32(hs[1])<<16 | uint32(hs[2])<<8 | uint32(hs[3]) + if len(hs[4:]) != int(handshakeLen) { + return nil + } + tlsVersion := uint16(hs[4])<<8 | uint16(hs[5]) + if tlsVersion&tlsVersionBitmask != 0x0300 && tlsVersion != tls13 { + return nil + } + sessionIDLen := hs[38] + if len(hs) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen+int(sessionIDLen) { + return nil + } + cs := hs[handshakeHeaderLen+randomDataLen+sessionIDHeaderLen+int(sessionIDLen):] + if len(cs) < cipherSuiteHeaderLen { + return nil + } + csLen := uint16(cs[0])<<8 | uint16(cs[1]) + if len(cs) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen { + return nil + } + compressMethodLen := uint16(cs[cipherSuiteHeaderLen+int(csLen)]) + if len(cs) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen+int(compressMethodLen) { + return nil + } + currentIndex := cipherSuiteHeaderLen + int(csLen) + compressMethodHeaderLen + int(compressMethodLen) + serverName := indexTLSServerNameFromExtensions(cs[currentIndex:]) + if serverName == nil { + return nil + } + serverName.Index += currentIndex + return serverName +} + +func indexTLSServerNameFromExtensions(exs []byte) *myServerName { + if len(exs) == 0 { + return nil + } + if len(exs) < extensionsHeaderLen { + return nil + } + exsLen := uint16(exs[0])<<8 | uint16(exs[1]) + exs = exs[extensionsHeaderLen:] + if len(exs) < int(exsLen) { + return nil + } + for currentIndex := extensionsHeaderLen; len(exs) > 0; { + if len(exs) < extensionHeaderLen { + return nil + } + exType := uint16(exs[0])<<8 | uint16(exs[1]) + exLen := uint16(exs[2])<<8 | uint16(exs[3]) + if len(exs) < extensionHeaderLen+int(exLen) { + return nil + } + sex := exs[extensionHeaderLen : extensionHeaderLen+int(exLen)] + + switch exType { + case sniExtensionType: + if len(sex) < sniExtensionHeaderLen { + return nil + } + sniType := sex[2] + if sniType != sniNameDNSHostnameType { + return nil + } + sniLen := uint16(sex[3])<<8 | uint16(sex[4]) + sex = sex[sniExtensionHeaderLen:] + return &myServerName{ + Index: currentIndex + extensionHeaderLen + sniExtensionHeaderLen, + Length: int(sniLen), + ServerName: string(sex), + } + } + exs = exs[4+exLen:] + currentIndex += 4 + int(exLen) + } + return nil +} diff --git a/common/tlsfragment/wait_darwin.go b/common/tlsfragment/wait_darwin.go new file mode 100644 index 00000000..90c65ba2 --- /dev/null +++ b/common/tlsfragment/wait_darwin.go @@ -0,0 +1,93 @@ +package tf + +import ( + "context" + "net" + "time" + + "github.com/sagernet/sing/common/control" + + "golang.org/x/sys/unix" +) + +/* +const tcpMaxNotifyAck = 10 + +type tcpNotifyAckID uint32 + + type tcpNotifyAckComplete struct { + NotifyPending uint32 + NotifyCompleteCount uint32 + NotifyCompleteID [tcpMaxNotifyAck]tcpNotifyAckID + } + +var sizeOfTCPNotifyAckComplete = int(unsafe.Sizeof(tcpNotifyAckComplete{})) + + func getsockoptTCPNotifyAckComplete(fd, level, opt int) (*tcpNotifyAckComplete, error) { + var value tcpNotifyAckComplete + vallen := uint32(sizeOfTCPNotifyAckComplete) + err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen) + return &value, err + } + +//go:linkname getsockopt golang.org/x/sys/unix.getsockopt +func getsockopt(s int, level int, name int, val unsafe.Pointer, vallen *uint32) error + + func waitAck(ctx context.Context, conn *net.TCPConn, _ time.Duration) error { + const TCP_NOTIFY_ACKNOWLEDGEMENT = 0x212 + return control.Conn(conn, func(fd uintptr) error { + err := unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, TCP_NOTIFY_ACKNOWLEDGEMENT, 1) + if err != nil { + if errors.Is(err, unix.EINVAL) { + return waitAckFallback(ctx, conn, 0) + } + return err + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + var ackComplete *tcpNotifyAckComplete + ackComplete, err = getsockoptTCPNotifyAckComplete(int(fd), unix.IPPROTO_TCP, TCP_NOTIFY_ACKNOWLEDGEMENT) + if err != nil { + return err + } + if ackComplete.NotifyPending == 0 { + return nil + } + time.Sleep(10 * time.Millisecond) + } + }) + } +*/ + +func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error { + _, err := conn.Write(payload) + if err != nil { + return err + } + return control.Conn(conn, func(fd uintptr) error { + start := time.Now() + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + unacked, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_NWRITE) + if err != nil { + return err + } + if unacked == 0 { + if time.Since(start) <= 20*time.Millisecond { + // under transparent proxy + time.Sleep(fallbackDelay) + } + return nil + } + time.Sleep(10 * time.Millisecond) + } + }) +} diff --git a/common/tlsfragment/wait_linux.go b/common/tlsfragment/wait_linux.go new file mode 100644 index 00000000..517d6ea5 --- /dev/null +++ b/common/tlsfragment/wait_linux.go @@ -0,0 +1,40 @@ +package tf + +import ( + "context" + "net" + "time" + + "github.com/sagernet/sing/common/control" + + "golang.org/x/sys/unix" +) + +func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error { + _, err := conn.Write(payload) + if err != nil { + return err + } + return control.Conn(conn, func(fd uintptr) error { + start := time.Now() + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + tcpInfo, err := unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO) + if err != nil { + return err + } + if tcpInfo.Unacked == 0 { + if time.Since(start) <= 20*time.Millisecond { + // under transparent proxy + time.Sleep(fallbackDelay) + } + return nil + } + time.Sleep(10 * time.Millisecond) + } + }) +} diff --git a/common/tlsfragment/wait_stub.go b/common/tlsfragment/wait_stub.go new file mode 100644 index 00000000..7e451a04 --- /dev/null +++ b/common/tlsfragment/wait_stub.go @@ -0,0 +1,14 @@ +//go:build !(linux || darwin || windows) + +package tf + +import ( + "context" + "net" + "time" +) + +func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error { + time.Sleep(fallbackDelay) + return nil +} diff --git a/common/tlsfragment/wait_windows.go b/common/tlsfragment/wait_windows.go new file mode 100644 index 00000000..118a204d --- /dev/null +++ b/common/tlsfragment/wait_windows.go @@ -0,0 +1,28 @@ +package tf + +import ( + "context" + "errors" + "net" + "time" + + "github.com/sagernet/sing/common/winiphlpapi" + + "golang.org/x/sys/windows" +) + +func writeAndWaitAck(ctx context.Context, conn *net.TCPConn, payload []byte, fallbackDelay time.Duration) error { + start := time.Now() + err := winiphlpapi.WriteAndWaitAck(ctx, conn, payload) + if err != nil { + if errors.Is(err, windows.ERROR_ACCESS_DENIED) { + time.Sleep(fallbackDelay) + return nil + } + return err + } + if time.Since(start) <= 20*time.Millisecond { + time.Sleep(fallbackDelay) + } + return nil +} diff --git a/constant/timeout.go b/constant/timeout.go index 3b5a452b..eb0fd34c 100644 --- a/constant/timeout.go +++ b/constant/timeout.go @@ -16,6 +16,7 @@ const ( StopTimeout = 5 * time.Second FatalStopTimeout = 10 * time.Second FakeIPMetadataSaveInterval = 10 * time.Second + TLSFragmentFallbackDelay = 500 * time.Millisecond ) var PortProtocols = map[uint16]string{ diff --git a/go.mod b/go.mod index d3aa6559..8748fc08 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff github.com/sagernet/quic-go v0.49.0-beta.1 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 - github.com/sagernet/sing v0.6.5 + github.com/sagernet/sing v0.6.6-0.20250326051824-d39c2c2fddfa github.com/sagernet/sing-mux v0.3.1 github.com/sagernet/sing-quic v0.4.0 github.com/sagernet/sing-shadowsocks v0.2.7 diff --git a/go.sum b/go.sum index f446adfa..85a76256 100644 --- a/go.sum +++ b/go.sum @@ -119,8 +119,8 @@ github.com/sagernet/quic-go v0.49.0-beta.1/go.mod h1:uesWD1Ihrldq1M3XtjuEvIUqi8W github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byLGkEnIYp6grlXfo1QYUfiYFGjewIdc= github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= -github.com/sagernet/sing v0.6.5 h1:TBKTK6Ms0/MNTZm+cTC2hhKunE42XrNIdsxcYtWqeUU= -github.com/sagernet/sing v0.6.5/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.6.6-0.20250326051824-d39c2c2fddfa h1:18mz8gmh0/EL3Bk+hB0Xf3tGOO1p/tP1sjjhSDeyUtU= +github.com/sagernet/sing v0.6.6-0.20250326051824-d39c2c2fddfa/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing-mux v0.3.1 h1:kvCc8HyGAskDHDQ0yQvoTi/7J4cZPB/VJMsAM3MmdQI= github.com/sagernet/sing-mux v0.3.1/go.mod h1:Mkdz8LnDstthz0HWuA/5foncnDIdcNN5KZ6AdJX+x78= github.com/sagernet/sing-quic v0.4.0 h1:E4geazHk/UrJTXMlT+CBCKmn8V86RhtNeczWtfeoEFc= diff --git a/option/rule_action.go b/option/rule_action.go index a715d260..f07d7298 100644 --- a/option/rule_action.go +++ b/option/rule_action.go @@ -150,6 +150,9 @@ type RawRouteOptionsActionOptions struct { UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` UDPConnect bool `json:"udp_connect,omitempty"` UDPTimeout badoption.Duration `json:"udp_timeout,omitempty"` + + TLSFragment bool `json:"tls_fragment,omitempty"` + TLSFragmentFallbackDelay badoption.Duration `json:"tls_fragment_fallback_delay,omitempty"` } type RouteOptionsActionOptions RawRouteOptionsActionOptions diff --git a/route/conn.go b/route/conn.go index 17218387..c2a2eab9 100644 --- a/route/conn.go +++ b/route/conn.go @@ -13,6 +13,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/common/tlsfragment" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -78,6 +79,21 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co m.logger.ErrorContext(ctx, err) return } + if metadata.TLSFragment { + fallbackDelay := metadata.TLSFragmentFallbackDelay + if fallbackDelay == 0 { + fallbackDelay = C.TLSFragmentFallbackDelay + } + var newConn *tf.Conn + newConn, err = tf.NewConn(remoteConn, ctx, fallbackDelay) + if err != nil { + conn.Close() + remoteConn.Close() + m.logger.ErrorContext(ctx, err) + return + } + remoteConn = newConn + } m.access.Lock() element := m.connections.PushBack(conn) m.access.Unlock() diff --git a/route/route.go b/route/route.go index ac81420c..6f4c5c65 100644 --- a/route/route.go +++ b/route/route.go @@ -454,6 +454,10 @@ match: if routeOptions.UDPTimeout > 0 { metadata.UDPTimeout = routeOptions.UDPTimeout } + if routeOptions.TLSFragment { + metadata.TLSFragment = true + metadata.TLSFragmentFallbackDelay = routeOptions.TLSFragmentFallbackDelay + } } switch action := currentRule.Action().(type) { case *rule.RuleActionSniff: diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index 88149aff..96515dad 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -36,6 +36,8 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti FallbackDelay: time.Duration(action.RouteOptions.FallbackDelay), UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping, UDPConnect: action.RouteOptions.UDPConnect, + TLSFragment: action.RouteOptions.TLSFragment, + TLSFragmentFallbackDelay: time.Duration(action.RouteOptions.TLSFragmentFallbackDelay), }, }, nil case C.RuleActionTypeRouteOptions: @@ -47,6 +49,8 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti UDPDisableDomainUnmapping: action.RouteOptionsOptions.UDPDisableDomainUnmapping, UDPConnect: action.RouteOptionsOptions.UDPConnect, UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout), + TLSFragment: action.RouteOptionsOptions.TLSFragment, + TLSFragmentFallbackDelay: time.Duration(action.RouteOptionsOptions.TLSFragmentFallbackDelay), }, nil case C.RuleActionTypeDirect: directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions), false) @@ -142,6 +146,9 @@ func (r *RuleActionRoute) String() string { if r.UDPConnect { descriptions = append(descriptions, "udp-connect") } + if r.TLSFragment { + descriptions = append(descriptions, "tls-fragment") + } return F.ToString("route(", strings.Join(descriptions, ","), ")") } @@ -155,6 +162,8 @@ type RuleActionRouteOptions struct { UDPDisableDomainUnmapping bool UDPConnect bool UDPTimeout time.Duration + TLSFragment bool + TLSFragmentFallbackDelay time.Duration } func (r *RuleActionRouteOptions) Type() string { @@ -187,6 +196,9 @@ func (r *RuleActionRouteOptions) String() string { if r.UDPConnect { descriptions = append(descriptions, "udp-connect") } + if r.UDPTimeout > 0 { + descriptions = append(descriptions, "udp-timeout") + } return F.ToString("route-options(", strings.Join(descriptions, ","), ")") } diff --git a/transport/simple-obfs/http.go b/transport/simple-obfs/http.go index f77a63a8..df38768e 100644 --- a/transport/simple-obfs/http.go +++ b/transport/simple-obfs/http.go @@ -82,6 +82,10 @@ func (ho *HTTPObfs) Write(b []byte) (int, error) { return ho.Conn.Write(b) } +func (ho *HTTPObfs) Upstream() any { + return ho.Conn +} + // NewHTTPObfs return a HTTPObfs func NewHTTPObfs(conn net.Conn, host string, port string) net.Conn { return &HTTPObfs{ diff --git a/transport/simple-obfs/tls.go b/transport/simple-obfs/tls.go index 51756fdb..96564815 100644 --- a/transport/simple-obfs/tls.go +++ b/transport/simple-obfs/tls.go @@ -113,6 +113,10 @@ func (to *TLSObfs) write(b []byte) (int, error) { return len(b), err } +func (to *TLSObfs) Upstream() any { + return to.Conn +} + // NewTLSObfs return a SimpleObfs func NewTLSObfs(conn net.Conn, server string) net.Conn { return &TLSObfs{