Compare commits

...

55 commits
v0.5.2 ... dev

Author SHA1 Message Date
世界
d39c2c2fdd
socks: Add custom udp listener 2025-03-26 13:18:24 +08:00
世界
ea82ac275f
Add freelru.GetWithLifetimeNoExpire 2025-03-26 13:18:18 +08:00
世界
ea0ac932ae
Add winiphlpapi 2025-03-26 13:18:17 +08:00
世界
2b41455f5a
Fix udpnat2 handler again 2025-03-26 12:46:15 +08:00
世界
23b0180a1b
Fix crash on udpnat2 handler 2025-03-24 18:11:10 +08:00
世界
ce1b4851a4
Fix socks5 UDP 2025-03-16 10:23:29 +08:00
世界
2238a05966
Fix merge objects 2025-03-16 10:23:29 +08:00
世界
b55d1c78b3
bufio: Add destination NAT packet conn 2025-03-09 15:20:32 +08:00
世界
d54716612c
Fix syscall packet read waiter for Windows 2025-02-28 12:07:45 +08:00
世界
9eafc7fc62
udpnat2: Fix crash 2025-02-10 15:08:18 +08:00
世界
d8153df67f
Add ENOTCONN to IsClosed 2025-02-06 08:41:32 +08:00
世界
d9f6eb136d
Fix set windows system time 2025-01-09 23:30:25 +08:00
世界
4dabb9be97
freelru: Fix GetAndRefreshOrAdd 2025-01-09 15:59:26 +08:00
世界
be9840c70f
listable: Fix incorrect unmarshaling of null to []T{null} 2025-01-09 15:57:12 +08:00
世界
aa7d2543a3
Fix errors usage 2024-12-16 09:20:34 +08:00
世界
33beacc053
Fix socks5 UDP handshake 2024-12-14 18:16:15 +08:00
世界
442cceb9fa
Fix disable UDP fragment 2024-12-12 20:43:56 +08:00
世界
3374a45475
Fix socks5 UDP implementation 2024-12-10 19:53:57 +08:00
世界
73776cf797
Fix lru test 2024-12-10 19:42:55 +08:00
世界
957166799e
Fix CloseOnHandshakeFailure 2024-12-04 17:14:58 +08:00
世界
809d8eca13
freelru: fix PurgeExpired 2024-12-04 11:36:20 +08:00
世界
9f69e7f9f7
E: IsClosedOrCanceled check IsTimeout 2024-12-01 20:19:37 +08:00
世界
478265cd45
badoption: Finish netip options 2024-12-01 14:33:23 +08:00
世界
3f30aaf25e
freelru: purge all expired items 2024-11-30 16:06:59 +08:00
世界
39040e06dc
udpnat2: Fix concurrency 2024-11-28 13:51:17 +08:00
世界
6edd2ce0ea
freelru: Update source and add GetAndRefreshOrAdd 2024-11-28 13:51:17 +08:00
世界
0a2e2a3eaf
udpnat2: Fix timeout 2024-11-27 18:02:22 +08:00
世界
4ba1eb123c
Fix set timeout 2024-11-27 17:28:18 +08:00
世界
c44912a861
freelru: Fix purge 2024-11-27 13:51:08 +08:00
世界
a8f5bf4eb0
udpnat2: Add timeout check 2024-11-26 19:08:35 +08:00
世界
30e9d91b57
Fix AppendClose 2024-11-26 12:21:37 +08:00
世界
7fd3517e4d
udpnat2: Add purge expire ticker 2024-11-26 12:21:37 +08:00
世界
a8285e06a5
udpnat2: Implement set timeout for nat conn 2024-11-26 12:21:37 +08:00
世界
3613ead480
freelru: Add PeekWithLifetime and UpdateLifetime 2024-11-26 11:29:14 +08:00
世界
c8f251c668
Fix copy count 2024-11-24 19:02:21 +08:00
世界
fa5355e99e
bufio: more copy funcs 2024-11-20 11:27:20 +08:00
世界
30fbafd954
udpnat2: Add cache funcs 2024-11-18 12:14:35 +08:00
世界
fdca9b3f8e
badjson: Fix Listable 2024-11-16 16:03:00 +08:00
世界
e52e04f721
Fix HandshakeFailure usages 2024-11-15 16:27:03 +08:00
世界
7f621fdd78
Add freelru.SetUpdateLifetimeOnGet/GetWithLifetime 2024-11-14 17:49:49 +08:00
世界
ae139d9ee1
Update N.PayloadDialer 2024-11-14 17:49:49 +08:00
世界
c432befd02
http: Fix proxying websocket 2024-11-13 19:02:07 +08:00
世界
cc7e630923
control: Refactor interface finder 2024-11-12 20:15:50 +08:00
世界
0998999911
udpnat2: Fix missing shared impl 2024-11-09 11:40:27 +08:00
世界
72ff654ee0
shared: Add SetHealthCheck to interface 2024-11-09 11:40:27 +08:00
世界
11ffb962ae
freelru: Fix impl 2024-11-09 11:40:27 +08:00
世界
fcb19641e6
freelru: Copy shared source 2024-11-09 11:40:27 +08:00
世界
524a6bd0d1
udpnat2: Set upstream to writer 2024-11-09 11:40:27 +08:00
世界
b5f9e70ffd
badjson: Fix Listable 2024-11-09 11:40:27 +08:00
世界
c80c8f907c
badjson: Add context marshaler/unmarshaler 2024-11-05 18:43:05 +08:00
世界
a4eb7fa900
udpnat2: Add SetHandler 2024-11-05 18:43:05 +08:00
世界
7ec09d6045
udpnat2: New synced udp nat service 2024-11-05 18:43:04 +08:00
世界
0641c71805
maphash: copy source from v0.1.0 2024-11-05 18:43:04 +08:00
世界
e7ec021b81
freelru: copy source from v0.14.0 2024-11-05 18:43:04 +08:00
世界
0f2447a95b
Crazy sekai overturns the small pond 2024-11-05 18:43:04 +08:00
76 changed files with 4490 additions and 492 deletions

View file

@ -2,11 +2,10 @@ package baderror
import (
"context"
"errors"
"io"
"net"
"strings"
E "github.com/sagernet/sing/common/exceptions"
)
func Contains(err error, msgList ...string) bool {
@ -22,8 +21,7 @@ func WrapH2(err error) error {
if err == nil {
return nil
}
err = E.Unwrap(err)
if err == io.ErrUnexpectedEOF {
if errors.Is(err, io.ErrUnexpectedEOF) {
return io.EOF
}
if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {

View file

@ -9,19 +9,20 @@ import (
type AddrConn struct {
net.Conn
M.Metadata
Source M.Socksaddr
Destination M.Socksaddr
}
func (c *AddrConn) LocalAddr() net.Addr {
if c.Metadata.Destination.IsValid() {
return c.Metadata.Destination.TCPAddr()
if c.Destination.IsValid() {
return c.Destination.TCPAddr()
}
return c.Conn.LocalAddr()
}
func (c *AddrConn) RemoteAddr() net.Addr {
if c.Metadata.Source.IsValid() {
return c.Metadata.Source.TCPAddr()
if c.Source.IsValid() {
return c.Source.TCPAddr()
}
return c.Conn.RemoteAddr()
}

View file

@ -184,10 +184,12 @@ func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer {
if buffer != nil {
buffer.DecRef()
}
return &N.PacketBuffer{
packet := N.NewPacketBuffer()
*packet = N.PacketBuffer{
Buffer: buffer,
Destination: c.destination,
}
return packet
}
func (c *CachedPacketConn) Upstream() any {

View file

@ -35,14 +35,7 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
if destination.IsFqdn() {
udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
if err != nil {
return err
}
return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
}
return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
}
func (w *ExtendedUDPConn) Upstream() any {

View file

@ -29,28 +29,36 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
if cachedSrc, isCached := source.(N.CachedReader); isCached {
cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil {
if !cachedBuffer.IsEmpty() {
_, err = destination.Write(cachedBuffer.Bytes())
if err != nil {
cachedBuffer.Release()
return
}
}
dataLen := cachedBuffer.Len()
_, err = destination.Write(cachedBuffer.Bytes())
cachedBuffer.Release()
if err != nil {
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
continue
}
}
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
}
break
}
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
}
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
}
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
}
@ -75,6 +83,7 @@ func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
}
// Deprecated: not used
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
buffer.IncRef()
defer buffer.DecRef()
@ -113,19 +122,10 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
}
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
bufferSize := N.CalculateMTU(source, destination)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
options := N.NewReadWaitOptions(source, destination)
var notFirstTime bool
for {
buffer := buf.NewSize(bufferSize)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
buffer := options.NewBuffer()
err = source.ReadBuffer(buffer)
if err != nil {
buffer.Release()
@ -136,7 +136,7 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
return
}
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
options.PostReturn(buffer)
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Leak()
@ -196,18 +196,6 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error
return group.Run(ctx)
}
// Deprecated: not used
func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
switch len(contextList) {
case 0:
return CopyConn(context.Background(), source, destination)
case 1:
return CopyConn(contextList[0], source, destination)
default:
panic("invalid context list")
}
}
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
var readCounters, writeCounters []N.CountFunc
var cachedPackets []*N.PacketBuffer
@ -225,24 +213,24 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
break
}
if cachedPackets != nil {
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets)
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
if err != nil {
return
}
}
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
n += copeN
return
}
func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var (
handled bool
copeN int64
)
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter {
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
FrontHeadroom: frontHeadroom,
RearHeadroom: rearHeadroom,
MTU: N.CalculateMTU(source, destinationConn),
})
needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
if !needCopy || common.LowMemory {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled {
@ -256,28 +244,19 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
return
}
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
bufferSize := N.CalculateMTU(source, destinationConn)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var destination M.Socksaddr
func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
options := N.NewReadWaitOptions(source, destination)
var destinationAddress M.Socksaddr
for {
buffer := buf.NewSize(bufferSize)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
destination, err = source.ReadPacket(buffer)
buffer := options.NewPacketBuffer()
destinationAddress, err = source.ReadPacket(buffer)
if err != nil {
buffer.Release()
return
}
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destinationConn.WritePacket(buffer, destination)
options.PostReturn(buffer)
err = destination.WritePacket(buffer, destinationAddress)
if err != nil {
buffer.Leak()
if !notFirstTime {
@ -285,34 +264,25 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
n += int64(dataLen)
notFirstTime = true
}
}
func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
options := N.NewReadWaitOptions(nil, destination)
var notFirstTime bool
for _, packetBuffer := range packetBuffers {
buffer := buf.NewPacket()
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
_, err = buffer.Write(packetBuffer.Buffer.Bytes())
packetBuffer.Buffer.Release()
if err != nil {
buffer.Release()
continue
}
buffer := options.Copy(packetBuffer.Buffer)
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
err = destination.WritePacket(buffer, packetBuffer.Destination)
N.PutPacketBuffer(packetBuffer)
if err != nil {
buffer.Leak()
if !notFirstTime {
@ -320,7 +290,14 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr
}
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
n += int64(dataLen)
notFirstTime = true
}
return
}
@ -339,15 +316,3 @@ func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.Pack
group.FastFail()
return group.Run(ctx)
}
// Deprecated: not used
func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
switch len(contextList) {
case 0:
return CopyPacketConn(context.Background(), source, destination)
case 1:
return CopyPacketConn(contextList[0], source, destination)
default:
panic("invalid context list")
}
}

View file

@ -120,16 +120,16 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions
var readN int
var from windows.Sockaddr
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
//goland:noinspection GoDirectComparisonOfErrors
if w.readErr != nil {
buffer.Release()
return w.readErr != windows.WSAEWOULDBLOCK
}
if readN > 0 {
buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
}
if w.readErr == windows.WSAEWOULDBLOCK {
return false
}
w.options.PostReturn(buffer)
w.buffer = buffer
if from != nil {
switch fromAddr := from.(type) {
case *windows.SockaddrInet4:

View file

@ -30,6 +30,14 @@ func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.So
}
}
func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
return &destinationNATPacketConn{
NetPacketConn: conn,
origin: origin,
destination: destination,
}
}
type unidirectionalNATPacketConn struct {
N.NetPacketConn
origin M.Socksaddr
@ -144,6 +152,60 @@ func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
type destinationNATPacketConn struct {
N.NetPacketConn
origin M.Socksaddr
destination M.Socksaddr
}
func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.NetPacketConn.ReadFrom(p)
if err != nil {
return
}
if M.SocksaddrFromNet(addr) == c.origin {
addr = c.destination.UDPAddr()
}
return
}
func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if M.SocksaddrFromNet(addr) == c.destination {
addr = c.origin.UDPAddr()
}
return c.NetPacketConn.WriteTo(p, addr)
}
func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.NetPacketConn.ReadPacket(buffer)
if err != nil {
return
}
if destination == c.origin {
destination = c.destination
}
return
}
func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if destination == c.destination {
destination = c.origin
}
return c.NetPacketConn.WritePacket(buffer, destination)
}
func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
}
func (c *destinationNATPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *destinationNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
destination.Port = 0
return destination

View file

@ -38,7 +38,6 @@ func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
var innerErr unix.Errno
err := w.rawConn.Write(func(fd uintptr) (done bool) {
//nolint:staticcheck
//goland:noinspection GoDeprecation
_, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList)))
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
})

View file

@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration {
return i.timeout
}
func (i *Instance) SetTimeout(timeout time.Duration) {
func (i *Instance) SetTimeout(timeout time.Duration) bool {
i.timeout = timeout
i.Update()
return i.Update()
}
func (i *Instance) wait() {

View file

@ -13,7 +13,7 @@ import (
type PacketConn interface {
N.PacketConn
Timeout() time.Duration
SetTimeout(timeout time.Duration)
SetTimeout(timeout time.Duration) bool
}
type TimerPacketConn struct {
@ -24,10 +24,12 @@ type TimerPacketConn struct {
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
oldTimeout := timeoutConn.Timeout()
if timeout < oldTimeout {
timeoutConn.SetTimeout(timeout)
if oldTimeout > 0 && timeout >= oldTimeout {
return ctx, conn
}
if timeoutConn.SetTimeout(timeout) {
return ctx, conn
}
return ctx, conn
}
err := conn.SetReadDeadline(time.Time{})
if err == nil {
@ -58,8 +60,8 @@ func (c *TimerPacketConn) Timeout() time.Duration {
return c.instance.Timeout()
}
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) {
c.instance.SetTimeout(timeout)
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
return c.instance.SetTimeout(timeout)
}
func (c *TimerPacketConn) Close() error {

View file

@ -61,9 +61,9 @@ func (c *TimeoutPacketConn) Timeout() time.Duration {
return c.timeout
}
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
c.timeout = timeout
c.PacketConn.SetReadDeadline(time.Now())
return c.PacketConn.SetReadDeadline(time.Now()) == nil
}
func (c *TimeoutPacketConn) Close() error {

View file

@ -157,6 +157,18 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int {
return -1
}
func Equal[S ~[]E, E comparable](s1, s2 S) bool {
if len(s1) != len(s2) {
return false
}
for i := range s1 {
if s1[i] != s2[i] {
return false
}
}
return true
}
//go:norace
func Dup[T any](obj T) T {
pointer := uintptr(unsafe.Pointer(&obj))
@ -268,6 +280,14 @@ func Reverse[T any](arr []T) []T {
return arr
}
func ReverseMap[K comparable, V comparable](m map[K]V) map[V]K {
ret := make(map[V]K, len(m))
for k, v := range m {
ret[v] = k
}
return ret
}
func Done(ctx context.Context) bool {
select {
case <-ctx.Done():
@ -362,24 +382,3 @@ func Close(closers ...any) error {
}
return retErr
}
// Deprecated: wtf is this?
type Starter interface {
Start() error
}
// Deprecated: wtf is this?
func Start(starters ...any) error {
for _, rawStarter := range starters {
if rawStarter == nil {
continue
}
if starter, isStarter := rawStarter.(Starter); isStarter {
err := starter.Start()
if err != nil {
return err
}
}
}
return nil
}

View file

@ -9,15 +9,15 @@ import (
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error {
var err error
if interfaceIndex == -1 {
if finder == nil {
return os.ErrInvalid
}
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
switch network {
case "tcp6", "udp6":

View file

@ -3,21 +3,57 @@ package control
import (
"net"
"net/netip"
"unsafe"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
)
type InterfaceFinder interface {
Update() error
Interfaces() []Interface
InterfaceIndexByName(name string) (int, error)
InterfaceNameByIndex(index int) (string, error)
InterfaceByAddr(addr netip.Addr) (*Interface, error)
ByName(name string) (*Interface, error)
ByIndex(index int) (*Interface, error)
ByAddr(addr netip.Addr) (*Interface, error)
}
type Interface struct {
Index int
MTU int
Name string
Addresses []netip.Prefix
HardwareAddr net.HardwareAddr
Flags net.Flags
Addresses []netip.Prefix
}
func (i Interface) Equals(other Interface) bool {
return i.Index == other.Index &&
i.MTU == other.MTU &&
i.Name == other.Name &&
common.Equal(i.HardwareAddr, other.HardwareAddr) &&
i.Flags == other.Flags &&
common.Equal(i.Addresses, other.Addresses)
}
func (i Interface) NetInterface() net.Interface {
return *(*net.Interface)(unsafe.Pointer(&i))
}
func InterfaceFromNet(iif net.Interface) (Interface, error) {
ifAddrs, err := iif.Addrs()
if err != nil {
return Interface{}, err
}
return InterfaceFromNetAddrs(iif, common.Map(ifAddrs, M.PrefixFromNet)), nil
}
func InterfaceFromNetAddrs(iif net.Interface, addresses []netip.Prefix) Interface {
return Interface{
Index: iif.Index,
MTU: iif.MTU,
Name: iif.Name,
HardwareAddr: iif.HardwareAddr,
Flags: iif.Flags,
Addresses: addresses,
}
}

View file

@ -3,11 +3,8 @@ package control
import (
"net"
"net/netip"
_ "unsafe"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
@ -27,18 +24,12 @@ func (f *DefaultInterfaceFinder) Update() error {
}
interfaces := make([]Interface, 0, len(netIfs))
for _, netIf := range netIfs {
ifAddrs, err := netIf.Addrs()
var iif Interface
iif, err = InterfaceFromNet(netIf)
if err != nil {
return err
}
interfaces = append(interfaces, Interface{
Index: netIf.Index,
MTU: netIf.MTU,
Name: netIf.Name,
Addresses: common.Map(ifAddrs, M.PrefixFromNet),
HardwareAddr: netIf.HardwareAddr,
Flags: netIf.Flags,
})
interfaces = append(interfaces, iif)
}
f.interfaces = interfaces
return nil
@ -52,46 +43,41 @@ func (f *DefaultInterfaceFinder) Interfaces() []Interface {
return f.interfaces
}
func (f *DefaultInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
func (f *DefaultInterfaceFinder) ByName(name string) (*Interface, error) {
for _, netInterface := range f.interfaces {
if netInterface.Name == name {
return netInterface.Index, nil
return &netInterface, nil
}
}
netInterface, err := net.InterfaceByName(name)
if err != nil {
return 0, err
_, err := net.InterfaceByName(name)
if err == nil {
err = f.Update()
if err != nil {
return nil, err
}
return f.ByName(name)
}
f.Update()
return netInterface.Index, nil
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
}
func (f *DefaultInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
func (f *DefaultInterfaceFinder) ByIndex(index int) (*Interface, error) {
for _, netInterface := range f.interfaces {
if netInterface.Index == index {
return netInterface.Name, nil
return &netInterface, nil
}
}
netInterface, err := net.InterfaceByIndex(index)
if err != nil {
return "", err
_, err := net.InterfaceByIndex(index)
if err == nil {
err = f.Update()
if err != nil {
return nil, err
}
return f.ByIndex(index)
}
f.Update()
return netInterface.Name, nil
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
}
func (f *DefaultInterfaceFinder) InterfaceByAddr(addr netip.Addr) (*Interface, error) {
for _, netInterface := range f.interfaces {
for _, prefix := range netInterface.Addresses {
if prefix.Contains(addr) {
return &netInterface, nil
}
}
}
err := f.Update()
if err != nil {
return nil, err
}
func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) {
for _, netInterface := range f.interfaces {
for _, prefix := range netInterface.Addresses {
if prefix.Contains(addr) {

View file

@ -19,11 +19,11 @@ func bindToInterface(conn syscall.RawConn, network string, address string, finde
if interfaceName == "" {
return os.ErrInvalid
}
var err error
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
if err == nil {

View file

@ -11,19 +11,19 @@ import (
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error {
var err error
if interfaceIndex == -1 {
if finder == nil {
return os.ErrInvalid
}
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
handle := syscall.Handle(fd)
if M.ParseSocksaddr(address).AddrString() == "" {
err = bind4(handle, interfaceIndex)
err := bind4(handle, interfaceIndex)
if err != nil {
return err
}

View file

@ -4,19 +4,26 @@ import (
"os"
"syscall"
N "github.com/sagernet/sing/common/network"
"golang.org/x/sys/unix"
)
func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error {
if N.NetworkName(network) != N.NetworkUDP {
return nil
}
return Raw(conn, func(fd uintptr) error {
switch network {
case "udp4":
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1); err != nil {
if network == "udp" || network == "udp4" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
}
case "udp6":
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1); err != nil {
}
if network == "udp" || network == "udp6" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
}
}

View file

@ -11,17 +11,19 @@ import (
func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error {
switch N.NetworkName(network) {
case N.NetworkUDP:
default:
if N.NetworkName(network) != N.NetworkUDP {
return nil
}
return Raw(conn, func(fd uintptr) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
if network == "udp" || network == "udp4" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}
if network == "udp6" {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
if network == "udp" || network == "udp6" {
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}

View file

@ -25,17 +25,19 @@ const (
func DisableUDPFragment() Func {
return func(network, address string, conn syscall.RawConn) error {
switch N.NetworkName(network) {
case N.NetworkUDP:
default:
if N.NetworkName(network) != N.NetworkUDP {
return nil
}
return Raw(conn, func(fd uintptr) error {
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
if network == "udp" || network == "udp4" {
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}
if network == "udp6" {
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
if network == "udp" || network == "udp6" {
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO)
if err != nil {
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
}
}

View file

@ -12,6 +12,7 @@ import (
F "github.com/sagernet/sing/common/format"
)
// Deprecated: wtf is this?
type Handler interface {
NewError(ctx context.Context, err error)
}
@ -39,11 +40,11 @@ func Extend(cause error, message ...any) error {
}
func IsClosedOrCanceled(err error) bool {
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, context.Canceled, context.DeadlineExceeded)
return IsClosed(err) || IsCanceled(err) || IsTimeout(err)
}
func IsClosed(err error) bool {
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET)
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, syscall.ENOTCONN)
}
func IsCanceled(err error) bool {

View file

@ -1,24 +1,14 @@
package exceptions
import "github.com/sagernet/sing/common"
import (
"errors"
type HasInnerError interface {
Unwrap() error
}
"github.com/sagernet/sing/common"
)
// Deprecated: Use errors.Unwrap instead.
func Unwrap(err error) error {
for {
inner, ok := err.(HasInnerError)
if !ok {
break
}
innerErr := inner.Unwrap()
if innerErr == nil {
break
}
err = innerErr
}
return err
return errors.Unwrap(err)
}
func Cast[T any](err error) (T, bool) {

View file

@ -63,12 +63,5 @@ func IsMulti(err error, targetList ...error) bool {
return true
}
}
err = Unwrap(err)
multiErr, isMulti := err.(MultiError)
if !isMulti {
return false
}
return common.All(multiErr.Unwrap(), func(it error) bool {
return IsMulti(it, targetList...)
})
return false
}

View file

@ -12,7 +12,6 @@ type TimeoutError interface {
func IsTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
//goland:noinspection GoDeprecation
//nolint:staticcheck
return netErr.Temporary() && netErr.Timeout()
}

View file

@ -2,13 +2,14 @@ package badjson
import (
"bytes"
"context"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
func Decode(content []byte) (any, error) {
decoder := json.NewDecoder(bytes.NewReader(content))
func Decode(ctx context.Context, content []byte) (any, error) {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
return decodeJSON(decoder)
}

View file

@ -1,6 +1,7 @@
package badjson
import (
"context"
"os"
"reflect"
@ -9,75 +10,75 @@ import (
"github.com/sagernet/sing/common/json"
)
func Omitempty[T any](value T) (T, error) {
objectContent, err := json.Marshal(value)
func Omitempty[T any](ctx context.Context, value T) (T, error) {
objectContent, err := json.MarshalContext(ctx, value)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal object")
}
rawNewObject, err := Decode(objectContent)
rawNewObject, err := Decode(ctx, objectContent)
if err != nil {
return common.DefaultValue[T](), err
}
newObjectContent, err := json.Marshal(rawNewObject)
newObjectContent, err := json.MarshalContext(ctx, rawNewObject)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal new object")
}
var newObject T
err = json.Unmarshal(newObjectContent, &newObject)
err = json.UnmarshalContext(ctx, newObjectContent, &newObject)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
}
return newObject, nil
}
func Merge[T any](source T, destination T, disableAppend bool) (T, error) {
rawSource, err := json.Marshal(source)
func Merge[T any](ctx context.Context, source T, destination T, disableAppend bool) (T, error) {
rawSource, err := json.MarshalContext(ctx, source)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal source")
}
rawDestination, err := json.Marshal(destination)
rawDestination, err := json.MarshalContext(ctx, destination)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
}
return MergeFrom[T](rawSource, rawDestination, disableAppend)
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}
func MergeFromSource[T any](rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
func MergeFromSource[T any](ctx context.Context, rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
if rawSource == nil {
return destination, nil
}
rawDestination, err := json.Marshal(destination)
rawDestination, err := json.MarshalContext(ctx, destination)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
}
return MergeFrom[T](rawSource, rawDestination, disableAppend)
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}
func MergeFromDestination[T any](source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
func MergeFromDestination[T any](ctx context.Context, source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
if rawDestination == nil {
return source, nil
}
rawSource, err := json.Marshal(source)
rawSource, err := json.MarshalContext(ctx, source)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal source")
}
return MergeFrom[T](rawSource, rawDestination, disableAppend)
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}
func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
rawMerged, err := MergeJSON(rawSource, rawDestination, disableAppend)
func MergeFrom[T any](ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
rawMerged, err := MergeJSON(ctx, rawSource, rawDestination, disableAppend)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "merge options")
}
var merged T
err = json.Unmarshal(rawMerged, &merged)
err = json.UnmarshalContext(ctx, rawMerged, &merged)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
}
return merged, nil
}
func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
if rawSource == nil && rawDestination == nil {
return nil, os.ErrInvalid
} else if rawSource == nil {
@ -85,16 +86,16 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disabl
} else if rawDestination == nil {
return rawSource, nil
}
source, err := Decode(rawSource)
source, err := Decode(ctx, rawSource)
if err != nil {
return nil, E.Cause(err, "decode source")
}
destination, err := Decode(rawDestination)
destination, err := Decode(ctx, rawDestination)
if err != nil {
return nil, E.Cause(err, "decode destination")
}
if source == nil {
return json.Marshal(destination)
return json.MarshalContext(ctx, destination)
} else if destination == nil {
return json.Marshal(source)
}
@ -102,7 +103,7 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disabl
if err != nil {
return nil, err
}
return json.Marshal(merged)
return json.MarshalContext(ctx, merged)
}
func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) {

View file

@ -1,36 +1,44 @@
package badjson
import (
"context"
"reflect"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
cJSON "github.com/sagernet/sing/common/json/internal/contextjson"
)
func MarshallObjects(objects ...any) ([]byte, error) {
return MarshallObjectsContext(context.Background(), objects...)
}
func MarshallObjectsContext(ctx context.Context, objects ...any) ([]byte, error) {
if len(objects) == 1 {
return json.Marshal(objects[0])
}
var content JSONObject
for _, object := range objects {
objectMap, err := newJSONObject(object)
objectMap, err := newJSONObject(ctx, object)
if err != nil {
return nil, err
}
content.PutAll(objectMap)
}
return content.MarshalJSON()
return content.MarshalJSONContext(ctx)
}
func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error {
parentContent, err := newJSONObject(parentObject)
if err != nil {
return err
}
return UnmarshallExcludedContext(context.Background(), inputContent, parentObject, object)
}
func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error {
var content JSONObject
err = content.UnmarshalJSON(inputContent)
err := content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return err
}
for _, key := range parentContent.Keys() {
for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) {
content.Remove(key)
}
if object == nil {
@ -39,20 +47,20 @@ func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error
}
return E.New("unexpected key: ", content.Keys()[0])
}
inputContent, err = content.MarshalJSON()
inputContent, err = content.MarshalJSONContext(ctx)
if err != nil {
return err
}
return json.UnmarshalDisallowUnknownFields(inputContent, object)
return json.UnmarshalContextDisallowUnknownFields(ctx, inputContent, object)
}
func newJSONObject(object any) (*JSONObject, error) {
inputContent, err := json.Marshal(object)
func newJSONObject(ctx context.Context, object any) (*JSONObject, error) {
inputContent, err := json.MarshalContext(ctx, object)
if err != nil {
return nil, err
}
var content JSONObject
err = content.UnmarshalJSON(inputContent)
err = content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return nil, err
}

View file

@ -2,6 +2,7 @@ package badjson
import (
"bytes"
"context"
"strings"
"github.com/sagernet/sing/common"
@ -28,6 +29,10 @@ func (m *JSONObject) IsEmpty() bool {
}
func (m *JSONObject) MarshalJSON() ([]byte, error) {
return m.MarshalJSONContext(context.Background())
}
func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
buffer := new(bytes.Buffer)
buffer.WriteString("{")
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
@ -38,13 +43,13 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) {
})
iLen := len(items)
for i, entry := range items {
keyContent, err := json.Marshal(entry.Key)
keyContent, err := json.MarshalContext(ctx, entry.Key)
if err != nil {
return nil, err
}
buffer.WriteString(strings.TrimSpace(string(keyContent)))
buffer.WriteString(": ")
valueContent, err := json.Marshal(entry.Value)
valueContent, err := json.MarshalContext(ctx, entry.Value)
if err != nil {
return nil, err
}
@ -58,7 +63,11 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) {
}
func (m *JSONObject) UnmarshalJSON(content []byte) error {
decoder := json.NewDecoder(bytes.NewReader(content))
return m.UnmarshalJSONContext(context.Background(), content)
}
func (m *JSONObject) UnmarshalJSONContext(ctx context.Context, content []byte) error {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
m.Clear()
objectStart, err := decoder.Token()
if err != nil {

View file

@ -2,6 +2,7 @@ package badjson
import (
"bytes"
"context"
"strings"
E "github.com/sagernet/sing/common/exceptions"
@ -14,18 +15,22 @@ type TypedMap[K comparable, V any] struct {
}
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
return m.MarshalJSONContext(context.Background())
}
func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
buffer := new(bytes.Buffer)
buffer.WriteString("{")
items := m.Entries()
iLen := len(items)
for i, entry := range items {
keyContent, err := json.Marshal(entry.Key)
keyContent, err := json.MarshalContext(ctx, entry.Key)
if err != nil {
return nil, err
}
buffer.WriteString(strings.TrimSpace(string(keyContent)))
buffer.WriteString(": ")
valueContent, err := json.Marshal(entry.Value)
valueContent, err := json.MarshalContext(ctx, entry.Value)
if err != nil {
return nil, err
}
@ -39,7 +44,11 @@ func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
}
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
decoder := json.NewDecoder(bytes.NewReader(content))
return m.UnmarshalJSONContext(context.Background(), content)
}
func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
m.Clear()
objectStart, err := decoder.Token()
if err != nil {
@ -47,7 +56,7 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
} else if objectStart != json.Delim('{') {
return E.New("expected json object start, but starts with ", objectStart)
}
err = m.decodeJSON(decoder)
err = m.decodeJSON(ctx, decoder)
if err != nil {
return E.Cause(err, "decode json object content")
}
@ -60,18 +69,18 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
return nil
}
func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error {
func (m *TypedMap[K, V]) decodeJSON(ctx context.Context, decoder *json.Decoder) error {
for decoder.More() {
keyToken, err := decoder.Token()
if err != nil {
return err
}
keyContent, err := json.Marshal(keyToken)
keyContent, err := json.MarshalContext(ctx, keyToken)
if err != nil {
return err
}
var entryKey K
err = json.Unmarshal(keyContent, &entryKey)
err = json.UnmarshalContext(ctx, keyContent, &entryKey)
if err != nil {
return err
}

View file

@ -1,30 +1,35 @@
package badoption
import (
"context"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
type Listable[T any] []T
func (l Listable[T]) MarshalJSON() ([]byte, error) {
func (l Listable[T]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
arrayList := []T(l)
if len(arrayList) == 1 {
return json.Marshal(arrayList[0])
}
return json.Marshal(arrayList)
return json.MarshalContext(ctx, arrayList)
}
func (l *Listable[T]) UnmarshalJSON(content []byte) error {
err := json.UnmarshalDisallowUnknownFields(content, (*[]T)(l))
if err == nil {
func (l *Listable[T]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
if string(content) == "null" {
return nil
}
var singleItem T
newError := json.UnmarshalDisallowUnknownFields(content, &singleItem)
if newError != nil {
return E.Errors(err, newError)
err := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem)
if err == nil {
*l = []T{singleItem}
return nil
}
*l = []T{singleItem}
return nil
newErr := json.UnmarshalContextDisallowUnknownFields(ctx, content, (*[]T)(l))
if newErr == nil {
return nil
}
return E.Errors(err, newErr)
}

View file

@ -35,6 +35,13 @@ func (a *Addr) UnmarshalJSON(content []byte) error {
type Prefix netip.Prefix
func (p *Prefix) Build(defaultPrefix netip.Prefix) netip.Prefix {
if p == nil {
return defaultPrefix
}
return netip.Prefix(*p)
}
func (p *Prefix) MarshalJSON() ([]byte, error) {
return json.Marshal(netip.Prefix(*p).String())
}
@ -55,6 +62,13 @@ func (p *Prefix) UnmarshalJSON(content []byte) error {
type Prefixable netip.Prefix
func (p *Prefixable) Build(defaultPrefix netip.Prefix) netip.Prefix {
if p == nil {
return defaultPrefix
}
return netip.Prefix(*p)
}
func (p *Prefixable) MarshalJSON() ([]byte, error) {
prefix := netip.Prefix(*p)
if prefix.Bits() == prefix.Addr().BitLen() {

View file

@ -0,0 +1,23 @@
package json
import (
"context"
"github.com/sagernet/sing/common/json/internal/contextjson"
)
var (
MarshalContext = json.MarshalContext
UnmarshalContext = json.UnmarshalContext
NewEncoderContext = json.NewEncoderContext
NewDecoderContext = json.NewDecoderContext
UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields
)
type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}
type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}

View file

@ -0,0 +1,11 @@
package json
import "context"
type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}
type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}

View file

@ -0,0 +1,43 @@
package json_test
import (
"context"
"testing"
"github.com/sagernet/sing/common/json/internal/contextjson"
"github.com/stretchr/testify/require"
)
type myStruct struct {
value string
}
func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) {
return json.Marshal(ctx.Value("key").(string))
}
func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error {
m.value = ctx.Value("key").(string)
return nil
}
//nolint:staticcheck
func TestMarshalContext(t *testing.T) {
t.Parallel()
ctx := context.WithValue(context.Background(), "key", "value")
var s myStruct
b, err := json.MarshalContext(ctx, &s)
require.NoError(t, err)
require.Equal(t, []byte(`"value"`), b)
}
//nolint:staticcheck
func TestUnmarshalContext(t *testing.T) {
t.Parallel()
ctx := context.WithValue(context.Background(), "key", "value")
var s myStruct
err := json.UnmarshalContext(ctx, []byte(`{}`), &s)
require.NoError(t, err)
require.Equal(t, "value", s.value)
}

View file

@ -8,6 +8,7 @@
package json
import (
"context"
"encoding"
"encoding/base64"
"fmt"
@ -95,10 +96,15 @@ import (
// Instead, they are replaced by the Unicode replacement
// character U+FFFD.
func Unmarshal(data []byte, v any) error {
return UnmarshalContext(context.Background(), data, v)
}
func UnmarshalContext(ctx context.Context, data []byte, v any) error {
// Check for well-formedness.
// Avoids filling out half a data structure
// before discovering a JSON syntax error.
var d decodeState
d.ctx = ctx
err := checkValid(data, &d.scan)
if err != nil {
return err
@ -209,6 +215,7 @@ type errorContext struct {
// decodeState represents the state while decoding a JSON value.
type decodeState struct {
ctx context.Context
data []byte
off int // next read offset in data
opcode int // last read result
@ -428,7 +435,7 @@ func (d *decodeState) valueQuoted() any {
// If it encounters an Unmarshaler, indirect stops and returns that.
// If decodingNull is true, indirect stops at the first settable pointer so it
// can be set to nil.
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) {
// Issue #24153 indicates that it is generally not a guaranteed property
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
// and expect the value to still be settable for values derived from
@ -482,11 +489,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
}
if v.Type().NumMethod() > 0 && v.CanInterface() {
if u, ok := v.Interface().(Unmarshaler); ok {
return u, nil, reflect.Value{}
return u, nil, nil, reflect.Value{}
}
if cu, ok := v.Interface().(ContextUnmarshaler); ok {
return nil, cu, nil, reflect.Value{}
}
if !decodingNull {
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return nil, u, reflect.Value{}
return nil, nil, u, reflect.Value{}
}
}
}
@ -498,14 +508,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
v = v.Elem()
}
}
return nil, nil, v
return nil, nil, nil, v
}
// array consumes an array from d.data[d.off-1:], decoding into v.
// The first byte of the array ('[') has been read already.
func (d *decodeState) array(v reflect.Value) error {
// Check for unmarshaler.
u, ut, pv := indirect(v, false)
u, cu, ut, pv := indirect(v, false)
if u != nil {
start := d.readIndex()
d.skip()
@ -515,6 +525,15 @@ func (d *decodeState) array(v reflect.Value) error {
}
return nil
}
if cu != nil {
start := d.readIndex()
d.skip()
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
d.skip()
@ -612,7 +631,7 @@ var (
// The first byte ('{') of the object has been read already.
func (d *decodeState) object(v reflect.Value) error {
// Check for unmarshaler.
u, ut, pv := indirect(v, false)
u, cu, ut, pv := indirect(v, false)
if u != nil {
start := d.readIndex()
d.skip()
@ -622,6 +641,15 @@ func (d *decodeState) object(v reflect.Value) error {
}
return nil
}
if cu != nil {
start := d.readIndex()
d.skip()
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
d.skip()
@ -870,7 +898,7 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
return nil
}
isNull := item[0] == 'n' // null
u, ut, pv := indirect(v, isNull)
u, cu, ut, pv := indirect(v, isNull)
if u != nil {
err := u.UnmarshalJSON(item)
if err != nil {
@ -878,6 +906,13 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
}
return nil
}
if cu != nil {
err := cu.UnmarshalJSONContext(d.ctx, item)
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil {
if item[0] != '"' {
if fromQuoted {

View file

@ -12,6 +12,7 @@ package json
import (
"bytes"
"context"
"encoding"
"encoding/base64"
"fmt"
@ -156,7 +157,11 @@ import (
// handle them. Passing cyclic structures to Marshal will result in
// an error.
func Marshal(v any) ([]byte, error) {
e := newEncodeState()
return MarshalContext(context.Background(), v)
}
func MarshalContext(ctx context.Context, v any) ([]byte, error) {
e := newEncodeState(ctx)
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: true})
@ -251,6 +256,7 @@ var hex = "0123456789abcdef"
type encodeState struct {
bytes.Buffer // accumulated output
ctx context.Context
// Keep track of what pointers we've seen in the current recursive call
// path, to avoid cycles that could lead to a stack overflow. Only do
// the relatively expensive map operations if ptrLevel is larger than
@ -264,7 +270,7 @@ const startDetectingCyclesAfter = 1000
var encodeStatePool sync.Pool
func newEncodeState() *encodeState {
func newEncodeState(ctx context.Context) *encodeState {
if v := encodeStatePool.Get(); v != nil {
e := v.(*encodeState)
e.Reset()
@ -274,7 +280,7 @@ func newEncodeState() *encodeState {
e.ptrLevel = 0
return e
}
return &encodeState{ptrSeen: make(map[any]struct{})}
return &encodeState{ctx: ctx, ptrSeen: make(map[any]struct{})}
}
// jsonError is an error wrapper type for internal use only.
@ -371,8 +377,9 @@ func typeEncoder(t reflect.Type) encoderFunc {
}
var (
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
contextMarshalerType = reflect.TypeOf((*ContextMarshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
)
// newTypeEncoder constructs an encoderFunc for a type.
@ -385,9 +392,15 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
}
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) {
return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false))
}
if t.Implements(marshalerType) {
return marshalerEncoder
}
if t.Implements(contextMarshalerType) {
return contextMarshalerEncoder
}
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
}
@ -470,6 +483,47 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
}
}
func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Pointer && v.IsNil() {
e.WriteString("null")
return
}
m, ok := v.Interface().(ContextMarshaler)
if !ok {
e.WriteString("null")
return
}
b, err := m.MarshalJSONContext(e.ctx)
if err == nil {
e.Grow(len(b))
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
va := v.Addr()
if va.IsNil() {
e.WriteString("null")
return
}
m := va.Interface().(ContextMarshaler)
b, err := m.MarshalJSONContext(e.ctx)
if err == nil {
e.Grow(len(b))
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Pointer && v.IsNil() {
e.WriteString("null")
@ -827,7 +881,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc {
// Byte slices get special treatment; arrays don't.
if t.Elem().Kind() == reflect.Uint8 {
p := reflect.PointerTo(t.Elem())
if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) {
if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) {
return encodeByteSlice
}
}

View file

@ -0,0 +1,20 @@
package json
import (
"reflect"
"github.com/sagernet/sing/common"
)
func ObjectKeys(object reflect.Type) []string {
switch object.Kind() {
case reflect.Pointer:
return ObjectKeys(object.Elem())
case reflect.Struct:
default:
panic("invalid non-struct input")
}
return common.Map(cachedTypeFields(object).list, func(field field) string {
return field.name
})
}

View file

@ -0,0 +1,26 @@
package json_test
import (
"reflect"
"testing"
json "github.com/sagernet/sing/common/json/internal/contextjson"
"github.com/stretchr/testify/require"
)
type MyObject struct {
Hello string `json:"hello,omitempty"`
MyWorld
MyWorld2 string `json:"-"`
}
type MyWorld struct {
World string `json:"world,omitempty"`
}
func TestObjectKeys(t *testing.T) {
t.Parallel()
keys := json.ObjectKeys(reflect.TypeOf(&MyObject{}))
require.Equal(t, []string{"hello", "world"}, keys)
}

View file

@ -6,6 +6,7 @@ package json
import (
"bytes"
"context"
"errors"
"io"
)
@ -29,7 +30,11 @@ type Decoder struct {
// The decoder introduces its own buffering and may
// read data from r beyond the JSON values requested.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
return NewDecoderContext(context.Background(), r)
}
func NewDecoderContext(ctx context.Context, r io.Reader) *Decoder {
return &Decoder{r: r, d: decodeState{ctx: ctx}}
}
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
@ -183,6 +188,7 @@ func nonSpace(b []byte) bool {
// An Encoder writes JSON values to an output stream.
type Encoder struct {
ctx context.Context
w io.Writer
err error
escapeHTML bool
@ -194,7 +200,11 @@ type Encoder struct {
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w, escapeHTML: true}
return NewEncoderContext(context.Background(), w)
}
func NewEncoderContext(ctx context.Context, w io.Writer) *Encoder {
return &Encoder{ctx: ctx, w: w, escapeHTML: true}
}
// Encode writes the JSON encoding of v to the stream,
@ -207,7 +217,7 @@ func (enc *Encoder) Encode(v any) error {
return enc.err
}
e := newEncodeState()
e := newEncodeState(enc.ctx)
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})

View file

@ -1,5 +1,7 @@
package json
import "context"
func UnmarshalDisallowUnknownFields(data []byte, v any) error {
var d decodeState
d.disallowUnknownFields = true
@ -10,3 +12,15 @@ func UnmarshalDisallowUnknownFields(data []byte, v any) error {
d.init(data)
return d.unmarshal(v)
}
func UnmarshalContextDisallowUnknownFields(ctx context.Context, data []byte, v any) error {
var d decodeState
d.ctx = ctx
d.disallowUnknownFields = true
err := checkValid(data, &d.scan)
if err != nil {
return err
}
d.init(data)
return d.unmarshal(v)
}

View file

@ -2,6 +2,7 @@ package json
import (
"bytes"
"context"
"errors"
"strings"
@ -10,7 +11,11 @@ import (
)
func UnmarshalExtended[T any](content []byte) (T, error) {
decoder := NewDecoder(NewCommentFilter(bytes.NewReader(content)))
return UnmarshalExtendedContext[T](context.Background(), content)
}
func UnmarshalExtendedContext[T any](ctx context.Context, content []byte) (T, error) {
decoder := NewDecoderContext(ctx, NewCommentFilter(bytes.NewReader(content)))
var value T
err := decoder.Decode(&value)
if err == nil {

View file

@ -1,5 +1,6 @@
package metadata
// Deprecated: wtf is this?
type Metadata struct {
Protocol string
Source Socksaddr

View file

@ -4,6 +4,7 @@ import (
"context"
"io"
"net"
"sync"
"time"
"github.com/sagernet/sing/common"
@ -70,8 +71,39 @@ type ExtendedConn interface {
net.Conn
}
type CloseHandlerFunc = func(it error)
func AppendClose(parent CloseHandlerFunc, onClose CloseHandlerFunc) CloseHandlerFunc {
if onClose == nil {
panic("nil onClose")
}
if parent == nil {
return onClose
}
return func(it error) {
onClose(it)
parent(it)
}
}
func OnceClose(onClose CloseHandlerFunc) CloseHandlerFunc {
var once sync.Once
return func(it error) {
once.Do(func() {
onClose(it)
})
}
}
// Deprecated: Use TCPConnectionHandlerEx instead.
type TCPConnectionHandler interface {
NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error
NewConnection(ctx context.Context, conn net.Conn,
//nolint:staticcheck
metadata M.Metadata) error
}
type TCPConnectionHandlerEx interface {
NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc)
}
type NetPacketConn interface {
@ -85,12 +117,26 @@ type BindPacketConn interface {
net.Conn
}
// Deprecated: Use UDPHandlerEx instead.
type UDPHandler interface {
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer,
//nolint:staticcheck
metadata M.Metadata) error
}
type UDPHandlerEx interface {
NewPacketEx(buffer *buf.Buffer, source M.Socksaddr)
}
// Deprecated: Use UDPConnectionHandlerEx instead.
type UDPConnectionHandler interface {
NewPacketConnection(ctx context.Context, conn PacketConn, metadata M.Metadata) error
NewPacketConnection(ctx context.Context, conn PacketConn,
//nolint:staticcheck
metadata M.Metadata) error
}
type UDPConnectionHandlerEx interface {
NewPacketConnectionEx(ctx context.Context, conn PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc)
}
type CachedReader interface {
@ -101,11 +147,6 @@ type CachedPacketReader interface {
ReadCachedPacket() *PacketBuffer
}
type PacketBuffer struct {
Buffer *buf.Buffer
Destination M.Socksaddr
}
type WithUpstreamReader interface {
UpstreamReader() any
}

View file

@ -13,10 +13,6 @@ type Dialer interface {
ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error)
}
type PayloadDialer interface {
DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error)
}
type ParallelDialer interface {
Dialer
DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error)

View file

@ -15,19 +15,39 @@ type ReadWaitOptions struct {
MTU int
}
func NewReadWaitOptions(source any, destination any) ReadWaitOptions {
return ReadWaitOptions{
FrontHeadroom: CalculateFrontHeadroom(destination),
RearHeadroom: CalculateRearHeadroom(destination),
MTU: CalculateMTU(source, destination),
}
}
func (o ReadWaitOptions) NeedHeadroom() bool {
return o.FrontHeadroom > 0 || o.RearHeadroom > 0
}
func (o ReadWaitOptions) Copy(buffer *buf.Buffer) *buf.Buffer {
if o.FrontHeadroom > buffer.Start() ||
o.RearHeadroom > buffer.FreeLen() {
newBuffer := o.newBuffer(buf.UDPBufferSize, false)
newBuffer.Write(buffer.Bytes())
buffer.Release()
return newBuffer
} else {
return buffer
}
}
func (o ReadWaitOptions) NewBuffer() *buf.Buffer {
return o.newBuffer(buf.BufferSize)
return o.newBuffer(buf.BufferSize, true)
}
func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer {
return o.newBuffer(buf.UDPBufferSize)
return o.newBuffer(buf.UDPBufferSize, true)
}
func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer {
func (o ReadWaitOptions) newBuffer(defaultBufferSize int, reserve bool) *buf.Buffer {
var bufferSize int
if o.MTU > 0 {
bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom
@ -38,7 +58,7 @@ func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer {
if o.FrontHeadroom > 0 {
buffer.Resize(o.FrontHeadroom, 0)
}
if o.RearHeadroom > 0 {
if o.RearHeadroom > 0 && reserve {
buffer.Reserve(o.RearHeadroom)
}
return buffer

View file

@ -1,6 +1,9 @@
package network
import (
"io"
"net"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
@ -13,17 +16,75 @@ type HandshakeSuccess interface {
HandshakeSuccess() error
}
func ReportHandshakeFailure(conn any, err error) error {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](conn); isHandshakeConn {
type ConnHandshakeSuccess interface {
ConnHandshakeSuccess(conn net.Conn) error
}
type PacketConnHandshakeSuccess interface {
PacketConnHandshakeSuccess(conn net.PacketConn) error
}
func ReportHandshakeFailure(reporter any, err error) error {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn {
return E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error {
return E.Cause(err, "write handshake failure")
})
}
return nil
}
func CloseOnHandshakeFailure(reporter io.Closer, onClose CloseHandlerFunc, err error) error {
if err != nil {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn {
hErr := handshakeConn.HandshakeFailure(err)
err = E.Append(err, hErr, func(err error) error {
if closer, isCloser := reporter.(io.Closer); isCloser {
err = E.Append(err, closer.Close(), func(err error) error {
return E.Cause(err, "close")
})
}
return E.Cause(err, "write handshake failure")
})
} else {
if tcpConn, isTCPConn := common.Cast[interface {
SetLinger(sec int) error
}](reporter); isTCPConn {
tcpConn.SetLinger(0)
}
}
err = E.Append(err, reporter.Close(), func(err error) error {
return E.Cause(err, "close")
})
}
if onClose != nil {
onClose(err)
}
return err
}
func ReportHandshakeSuccess(conn any) error {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](conn); isHandshakeConn {
// Deprecated: use ReportConnHandshakeSuccess/ReportPacketConnHandshakeSuccess instead
func ReportHandshakeSuccess(reporter any) error {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.HandshakeSuccess()
}
return nil
}
func ReportConnHandshakeSuccess(reporter any, conn net.Conn) error {
if handshakeConn, isHandshakeConn := common.Cast[ConnHandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.ConnHandshakeSuccess(conn)
}
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.HandshakeSuccess()
}
return nil
}
func ReportPacketConnHandshakeSuccess(reporter any, conn net.PacketConn) error {
if handshakeConn, isHandshakeConn := common.Cast[PacketConnHandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.PacketConnHandshakeSuccess(conn)
}
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.HandshakeSuccess()
}
return nil

35
common/network/packet.go Normal file
View file

@ -0,0 +1,35 @@
package network
import (
"sync"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
type PacketBuffer struct {
Buffer *buf.Buffer
Destination M.Socksaddr
}
var packetPool = sync.Pool{
New: func() any {
return new(PacketBuffer)
},
}
func NewPacketBuffer() *PacketBuffer {
return packetPool.Get().(*PacketBuffer)
}
func PutPacketBuffer(packet *PacketBuffer) {
*packet = PacketBuffer{}
packetPool.Put(packet)
}
func ReleaseMultiPacketBuffer(packetBuffers []*PacketBuffer) {
for _, packet := range packetBuffers {
packet.Buffer.Release()
PutPacketBuffer(packet)
}
}

View file

@ -11,6 +11,7 @@ type ThreadUnsafeWriter interface {
}
// Deprecated: Use ReadWaiter interface instead.
type ThreadSafeReader interface {
// Deprecated: Use ReadWaiter interface instead.
ReadBufferThreadSafe() (buffer *buf.Buffer, err error)
@ -18,7 +19,6 @@ type ThreadSafeReader interface {
// Deprecated: Use ReadWaiter interface instead.
type ThreadSafePacketReader interface {
// Deprecated: Use ReadWaiter interface instead.
ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error)
}

View file

@ -7,6 +7,7 @@ import (
)
func SetSystemTime(nowTime time.Time) error {
nowTime = nowTime.UTC()
var systemTime windows.Systemtime
systemTime.Year = uint16(nowTime.Year())
systemTime.Month = uint16(nowTime.Month())

View file

@ -20,6 +20,5 @@ func InitializeSeed() {
func initializeSeed() {
var seed int64
common.Must(binary.Read(rand.Reader, binary.LittleEndian, &seed))
//goland:noinspection GoDeprecation
mRand.Seed(seed)
}

View file

@ -27,7 +27,6 @@ func ToByteReader(reader io.Reader) io.ByteReader {
// Deprecated: Use binary.ReadUvarint instead.
func ReadUVariant(reader io.Reader) (uint64, error) {
//goland:noinspection GoDeprecation
return binary.ReadUvarint(ToByteReader(reader))
}

View file

@ -16,18 +16,23 @@ import (
"github.com/sagernet/sing/common/pipe"
)
// Deprecated: Use N.UDPConnectionHandler instead.
//
//nolint:staticcheck
type Handler interface {
N.UDPConnectionHandler
E.Handler
}
type Service[K comparable] struct {
nat *cache.LruCache[K, *conn]
handler Handler
nat *cache.LruCache[K, *conn]
handler Handler
handlerEx N.UDPConnectionHandlerEx
}
// Deprecated: Use NewEx instead.
func New[K comparable](maxAge int64, handler Handler) *Service[K] {
return &Service[K]{
service := &Service[K]{
nat: cache.New(
cache.WithAge[K, *conn](maxAge),
cache.WithUpdateAgeOnGet[K, *conn](),
@ -37,11 +42,27 @@ func New[K comparable](maxAge int64, handler Handler) *Service[K] {
),
handler: handler,
}
return service
}
func NewEx[K comparable](maxAge int64, handler N.UDPConnectionHandlerEx) *Service[K] {
service := &Service[K]{
nat: cache.New(
cache.WithAge[K, *conn](maxAge),
cache.WithUpdateAgeOnGet[K, *conn](),
cache.WithEvict[K, *conn](func(key K, conn *conn) {
conn.Close()
}),
),
handlerEx: handler,
}
return service
}
func (s *Service[T]) WriteIsThreadUnsafe() {
}
// Deprecated: don't use
func (s *Service[T]) NewPacketDirect(ctx context.Context, key T, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) {
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
return ctx, &DirectBackWriter{conn, natConn}
@ -61,18 +82,30 @@ func (w *DirectBackWriter) Upstream() any {
return w.Source
}
// Deprecated: use NewPacketEx instead.
func (s *Service[T]) NewPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) N.PacketWriter) {
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
return ctx, init(natConn)
})
}
func (s *Service[T]) NewPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) N.PacketWriter) {
s.NewContextPacketEx(ctx, key, buffer, source, destination, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
return ctx, init(natConn)
})
}
// Deprecated: Use NewPacketConnectionEx instead.
func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) {
s.NewContextPacketEx(ctx, key, buffer, metadata.Source, metadata.Destination, init)
}
func (s *Service[T]) NewContextPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) {
c, loaded := s.nat.LoadOrStore(key, func() *conn {
c := &conn{
data: make(chan packet, 64),
localAddr: metadata.Source,
remoteAddr: metadata.Destination,
localAddr: source,
remoteAddr: destination,
readDeadline: pipe.MakeDeadline(),
}
c.ctx, c.cancel = common.ContextWithCancelCause(ctx)
@ -81,26 +114,34 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu
if !loaded {
ctx, c.source = init(c)
go func() {
err := s.handler.NewPacketConnection(ctx, c, metadata)
if err != nil {
s.handler.NewError(ctx, err)
if s.handlerEx != nil {
s.handlerEx.NewPacketConnectionEx(ctx, c, source, destination, func(err error) {
s.nat.Delete(key)
})
} else {
//nolint:staticcheck
err := s.handler.NewPacketConnection(ctx, c, M.Metadata{
Source: source,
Destination: destination,
})
if err != nil {
s.handler.NewError(ctx, err)
}
c.Close()
s.nat.Delete(key)
}
c.Close()
s.nat.Delete(key)
}()
} else {
c.localAddr = metadata.Source
}
if common.Done(c.ctx) {
s.nat.Delete(key)
if !common.Done(ctx) {
s.NewContextPacket(ctx, key, buffer, metadata, init)
s.NewContextPacketEx(ctx, key, buffer, source, destination, init)
}
return
}
c.data <- packet{
data: buffer,
destination: metadata.Destination,
destination: destination,
}
}
@ -172,10 +213,6 @@ func (c *conn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *conn) NeedAdditionalReadDeadline() bool {
return true
}
func (c *conn) Upstream() any {
return c.source
}

138
common/udpnat2/conn.go Normal file
View file

@ -0,0 +1,138 @@
package udpnat
import (
"io"
"net"
"net/netip"
"os"
"sync"
"time"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/canceler"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
"github.com/sagernet/sing/contrab/freelru"
)
type Conn interface {
N.PacketConn
SetHandler(handler N.UDPHandlerEx)
canceler.PacketConn
}
var _ Conn = (*natConn)(nil)
type natConn struct {
cache freelru.Cache[netip.AddrPort, *natConn]
writer N.PacketWriter
localAddr M.Socksaddr
handlerAccess sync.RWMutex
handler N.UDPHandlerEx
packetChan chan *N.PacketBuffer
closeOnce sync.Once
doneChan chan struct{}
readDeadline pipe.Deadline
readWaitOptions N.ReadWaitOptions
}
func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
select {
case p := <-c.packetChan:
_, err = buffer.ReadOnceFrom(p.Buffer)
destination := p.Destination
p.Buffer.Release()
N.PutPacketBuffer(p)
return destination, err
case <-c.doneChan:
return M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
}
func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.writer.WritePacket(buffer, destination)
}
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case packet := <-c.packetChan:
buffer = c.readWaitOptions.Copy(packet.Buffer)
destination = packet.Destination
N.PutPacketBuffer(packet)
return
case <-c.doneChan:
return nil, M.Socksaddr{}, io.ErrClosedPipe
case <-c.readDeadline.Wait():
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
}
}
func (c *natConn) SetHandler(handler N.UDPHandlerEx) {
c.handlerAccess.Lock()
c.handler = handler
c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler)
c.handlerAccess.Unlock()
fetch:
for {
select {
case packet := <-c.packetChan:
c.handler.NewPacketEx(packet.Buffer, packet.Destination)
N.PutPacketBuffer(packet)
continue fetch
default:
break fetch
}
}
}
func (c *natConn) Timeout() time.Duration {
rawConn, lifetime, loaded := c.cache.PeekWithLifetime(c.localAddr.AddrPort())
if !loaded || rawConn != c {
return 0
}
return time.Until(lifetime)
}
func (c *natConn) SetTimeout(timeout time.Duration) bool {
return c.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout)
}
func (c *natConn) Close() error {
c.closeOnce.Do(func() {
close(c.doneChan)
})
return nil
}
func (c *natConn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *natConn) RemoteAddr() net.Addr {
return M.Socksaddr{}
}
func (c *natConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *natConn) SetReadDeadline(t time.Time) error {
c.readDeadline.Set(t)
return nil
}
func (c *natConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *natConn) Upstream() any {
return c.writer
}

103
common/udpnat2/service.go Normal file
View file

@ -0,0 +1,103 @@
package udpnat
import (
"context"
"net/netip"
"time"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
)
type Service struct {
cache freelru.Cache[netip.AddrPort, *natConn]
handler N.UDPConnectionHandlerEx
prepare PrepareFunc
}
type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc)
func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service {
if timeout == 0 {
panic("invalid timeout")
}
var cache freelru.Cache[netip.AddrPort, *natConn]
if !shared {
cache = common.Must1(freelru.NewSynced[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
} else {
cache = common.Must1(freelru.NewSharded[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
}
cache.SetLifetime(timeout)
cache.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool {
select {
case <-conn.doneChan:
return false
default:
return true
}
})
cache.SetOnEvict(func(_ netip.AddrPort, conn *natConn) {
conn.Close()
})
return &Service{
cache: cache,
handler: handler,
prepare: prepare,
}
}
func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) {
conn, _, ok := s.cache.GetAndRefreshOrAdd(source.AddrPort(), func() (*natConn, bool) {
ok, ctx, writer, onClose := s.prepare(source, destination, userData)
if !ok {
return nil, false
}
newConn := &natConn{
cache: s.cache,
writer: writer,
localAddr: source,
packetChan: make(chan *N.PacketBuffer, 64),
doneChan: make(chan struct{}),
readDeadline: pipe.MakeDeadline(),
}
go s.handler.NewPacketConnectionEx(ctx, newConn, source, destination, onClose)
return newConn, true
})
if !ok {
return
}
buffer := conn.readWaitOptions.NewPacketBuffer()
for _, bufferSlice := range bufferSlices {
buffer.Write(bufferSlice)
}
conn.handlerAccess.RLock()
handler := conn.handler
conn.handlerAccess.RUnlock()
if handler != nil {
handler.NewPacketEx(buffer, destination)
return
}
packet := N.NewPacketBuffer()
*packet = N.PacketBuffer{
Buffer: buffer,
Destination: destination,
}
select {
case conn.packetChan <- packet:
default:
packet.Buffer.Release()
N.PutPacketBuffer(packet)
}
}
func (s *Service) Purge() {
s.cache.Purge()
}
func (s *Service) PurgeExpired() {
s.cache.PurgeExpired()
}

View file

@ -1,16 +1,14 @@
//go:build windows
package windnsapi
import (
"runtime"
"testing"
"github.com/stretchr/testify/require"
)
func TestDNSAPI(t *testing.T) {
if runtime.GOOS != "windows" {
t.SkipNow()
}
t.Parallel()
require.NoError(t, FlushResolverCache())
}

View file

@ -0,0 +1,217 @@
//go:build windows
package winiphlpapi
import (
"context"
"encoding/binary"
"net"
"net/netip"
"os"
"time"
"unsafe"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func LoadEStats() error {
err := modiphlpapi.Load()
if err != nil {
return err
}
err = procGetTcpTable.Find()
if err != nil {
return err
}
err = procGetTcp6Table.Find()
if err != nil {
return err
}
err = procGetPerTcp6ConnectionEStats.Find()
if err != nil {
return err
}
err = procGetPerTcp6ConnectionEStats.Find()
if err != nil {
return err
}
err = procSetPerTcpConnectionEStats.Find()
if err != nil {
return err
}
err = procSetPerTcp6ConnectionEStats.Find()
if err != nil {
return err
}
return nil
}
func LoadExtendedTable() error {
err := modiphlpapi.Load()
if err != nil {
return err
}
err = procGetExtendedTcpTable.Find()
if err != nil {
return err
}
err = procGetExtendedUdpTable.Find()
if err != nil {
return err
}
return nil
}
func FindPid(network string, source netip.AddrPort) (uint32, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
if source.Addr().Is4() {
tcpTable, err := GetExtendedTcpTable()
if err != nil {
return 0, err
}
for _, row := range tcpTable {
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
} else {
tcpTable, err := GetExtendedTcp6Table()
if err != nil {
return 0, err
}
for _, row := range tcpTable {
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
}
case N.NetworkUDP:
if source.Addr().Is4() {
udpTable, err := GetExtendedUdpTable()
if err != nil {
return 0, err
}
for _, row := range udpTable {
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
} else {
udpTable, err := GetExtendedUdp6Table()
if err != nil {
return 0, err
}
for _, row := range udpTable {
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
return row.DwOwningPid, nil
}
}
}
}
return 0, E.New("process not found for ", source)
}
func WriteAndWaitAck(ctx context.Context, conn net.Conn, payload []byte) error {
source := M.AddrPortFromNet(conn.LocalAddr())
destination := M.AddrPortFromNet(conn.RemoteAddr())
if source.Addr().Is4() {
tcpTable, err := GetTcpTable()
if err != nil {
return err
}
var tcpRow *MibTcpRow
for _, row := range tcpTable {
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) ||
destination == netip.AddrPortFrom(DwordToAddr(row.DwRemoteAddr), DwordToPort(row.DwRemotePort)) {
tcpRow = &row
break
}
}
if tcpRow == nil {
return E.New("row not found for: ", source)
}
err = SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: true,
})
if err != nil {
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
}
defer SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: false,
})
_, err = conn.Write(payload)
if err != nil {
return err
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
eStstsSendBuffer, err := GetPerTcpConnectionEStatsSendBuffer(tcpRow)
if err != nil {
return err
}
if eStstsSendBuffer.CurRetxQueue == 0 {
return nil
}
time.Sleep(10 * time.Millisecond)
}
} else {
tcpTable, err := GetTcp6Table()
if err != nil {
return err
}
var tcpRow *MibTcp6Row
for _, row := range tcpTable {
if source == netip.AddrPortFrom(netip.AddrFrom16(row.LocalAddr), DwordToPort(row.LocalPort)) ||
destination == netip.AddrPortFrom(netip.AddrFrom16(row.RemoteAddr), DwordToPort(row.RemotePort)) {
tcpRow = &row
break
}
}
if tcpRow == nil {
return E.New("row not found for: ", source)
}
err = SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: true,
})
if err != nil {
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
}
defer SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
EnableCollection: false,
})
_, err = conn.Write(payload)
if err != nil {
return err
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
eStstsSendBuffer, err := GetPerTcp6ConnectionEStatsSendBuffer(tcpRow)
if err != nil {
return err
}
if eStstsSendBuffer.CurRetxQueue == 0 {
return nil
}
time.Sleep(10 * time.Millisecond)
}
}
}
func DwordToAddr(addr uint32) netip.Addr {
return netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(&addr)))
}
func DwordToPort(dword uint32) uint16 {
return binary.BigEndian.Uint16((*[4]byte)(unsafe.Pointer(&dword))[:])
}

View file

@ -0,0 +1,313 @@
//go:build windows
package winiphlpapi
import (
"errors"
"os"
"unsafe"
"golang.org/x/sys/windows"
)
const (
TcpTableBasicListener uint32 = iota
TcpTableBasicConnections
TcpTableBasicAll
TcpTableOwnerPidListener
TcpTableOwnerPidConnections
TcpTableOwnerPidAll
TcpTableOwnerModuleListener
TcpTableOwnerModuleConnections
TcpTableOwnerModuleAll
)
const (
UdpTableBasic uint32 = iota
UdpTableOwnerPid
UdpTableOwnerModule
)
const (
TcpConnectionEstatsSynOpts uint32 = iota
TcpConnectionEstatsData
TcpConnectionEstatsSndCong
TcpConnectionEstatsPath
TcpConnectionEstatsSendBuff
TcpConnectionEstatsRec
TcpConnectionEstatsObsRec
TcpConnectionEstatsBandwidth
TcpConnectionEstatsFineRtt
TcpConnectionEstatsMaximum
)
type MibTcpTable struct {
DwNumEntries uint32
Table [1]MibTcpRow
}
type MibTcpRow struct {
DwState uint32
DwLocalAddr uint32
DwLocalPort uint32
DwRemoteAddr uint32
DwRemotePort uint32
}
type MibTcp6Table struct {
DwNumEntries uint32
Table [1]MibTcp6Row
}
type MibTcp6Row struct {
State uint32
LocalAddr [16]byte
LocalScopeId uint32
LocalPort uint32
RemoteAddr [16]byte
RemoteScopeId uint32
RemotePort uint32
}
type MibTcpTableOwnerPid struct {
DwNumEntries uint32
Table [1]MibTcpRowOwnerPid
}
type MibTcpRowOwnerPid struct {
DwState uint32
DwLocalAddr uint32
DwLocalPort uint32
DwRemoteAddr uint32
DwRemotePort uint32
DwOwningPid uint32
}
type MibTcp6TableOwnerPid struct {
DwNumEntries uint32
Table [1]MibTcp6RowOwnerPid
}
type MibTcp6RowOwnerPid struct {
UcLocalAddr [16]byte
DwLocalScopeId uint32
DwLocalPort uint32
UcRemoteAddr [16]byte
DwRemoteScopeId uint32
DwRemotePort uint32
DwState uint32
DwOwningPid uint32
}
type MibUdpTableOwnerPid struct {
DwNumEntries uint32
Table [1]MibUdpRowOwnerPid
}
type MibUdpRowOwnerPid struct {
DwLocalAddr uint32
DwLocalPort uint32
DwOwningPid uint32
}
type MibUdp6TableOwnerPid struct {
DwNumEntries uint32
Table [1]MibUdp6RowOwnerPid
}
type MibUdp6RowOwnerPid struct {
UcLocalAddr [16]byte
DwLocalScopeId uint32
DwLocalPort uint32
DwOwningPid uint32
}
type TcpEstatsSendBufferRodV0 struct {
CurRetxQueue uint64
MaxRetxQueue uint64
CurAppWQueue uint64
MaxAppWQueue uint64
}
type TcpEstatsSendBuffRwV0 struct {
EnableCollection bool
}
const (
offsetOfMibTcpTable = unsafe.Offsetof(MibTcpTable{}.Table)
offsetOfMibTcp6Table = unsafe.Offsetof(MibTcp6Table{}.Table)
offsetOfMibTcpTableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table)
offsetOfMibTcp6TableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table)
offsetOfMibUdpTableOwnerPid = unsafe.Offsetof(MibUdpTableOwnerPid{}.Table)
offsetOfMibUdp6TableOwnerPid = unsafe.Offsetof(MibUdp6TableOwnerPid{}.Table)
sizeOfTcpEstatsSendBuffRwV0 = unsafe.Sizeof(TcpEstatsSendBuffRwV0{})
sizeOfTcpEstatsSendBufferRodV0 = unsafe.Sizeof(TcpEstatsSendBufferRodV0{})
)
func GetTcpTable() ([]MibTcpRow, error) {
var size uint32
err := getTcpTable(nil, &size, false)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, err
}
for {
table := make([]byte, size)
err = getTcpTable(&table[0], &size, false)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, err
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcpRow)(unsafe.Pointer(&table[offsetOfMibTcpTable])), dwNumEntries), nil
}
}
func GetTcp6Table() ([]MibTcp6Row, error) {
var size uint32
err := getTcp6Table(nil, &size, false)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, err
}
for {
table := make([]byte, size)
err = getTcp6Table(&table[0], &size, false)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, err
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcp6Row)(unsafe.Pointer(&table[offsetOfMibTcp6Table])), dwNumEntries), nil
}
}
func GetExtendedTcpTable() ([]MibTcpRowOwnerPid, error) {
var size uint32
err := getExtendedTcpTable(nil, &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcpTableOwnerPid])), dwNumEntries), nil
}
}
func GetExtendedTcp6Table() ([]MibTcp6RowOwnerPid, error) {
var size uint32
err := getExtendedTcpTable(nil, &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibTcp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcp6TableOwnerPid])), dwNumEntries), nil
}
}
func GetExtendedUdpTable() ([]MibUdpRowOwnerPid, error) {
var size uint32
err := getExtendedUdpTable(nil, &size, false, windows.AF_INET, UdpTableOwnerPid, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET, UdpTableOwnerPid, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibUdpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdpTableOwnerPid])), dwNumEntries), nil
}
}
func GetExtendedUdp6Table() ([]MibUdp6RowOwnerPid, error) {
var size uint32
err := getExtendedUdpTable(nil, &size, false, windows.AF_INET6, UdpTableOwnerPid, 0)
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
for {
table := make([]byte, size)
err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET6, UdpTableOwnerPid, 0)
if err != nil {
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
continue
}
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
}
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
return unsafe.Slice((*MibUdp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdp6TableOwnerPid])), dwNumEntries), nil
}
}
func GetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow) (*TcpEstatsSendBufferRodV0, error) {
var rod TcpEstatsSendBufferRodV0
err := getPerTcpConnectionEStats(row,
TcpConnectionEstatsSendBuff,
0,
0,
0,
0,
0,
0,
uintptr(unsafe.Pointer(&rod)),
0,
uint64(sizeOfTcpEstatsSendBufferRodV0),
)
if err != nil {
return nil, err
}
return &rod, nil
}
func GetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row) (*TcpEstatsSendBufferRodV0, error) {
var rod TcpEstatsSendBufferRodV0
err := getPerTcp6ConnectionEStats(row,
TcpConnectionEstatsSendBuff,
0,
0,
0,
0,
0,
0,
uintptr(unsafe.Pointer(&rod)),
0,
uint64(sizeOfTcpEstatsSendBufferRodV0),
)
if err != nil {
return nil, err
}
return &rod, nil
}
func SetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow, rw *TcpEstatsSendBuffRwV0) error {
return setPerTcpConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0)
}
func SetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row, rw *TcpEstatsSendBuffRwV0) error {
return setPerTcp6ConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0)
}

View file

@ -0,0 +1,90 @@
//go:build windows
package winiphlpapi_test
import (
"context"
"net"
"syscall"
"testing"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/winiphlpapi"
"github.com/stretchr/testify/require"
)
func TestFindPidTcp4(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}
func TestFindPidTcp6(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "[::1]:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}
func TestFindPidUdp4(t *testing.T) {
t.Parallel()
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}
func TestFindPidUdp6(t *testing.T) {
t.Parallel()
conn, err := net.ListenPacket("udp", "[::1]:0")
require.NoError(t, err)
defer conn.Close()
pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr()))
require.NoError(t, err)
require.Equal(t, uint32(syscall.Getpid()), pid)
}
func TestWaitAck4(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello"))
require.NoError(t, err)
}
func TestWaitAck6(t *testing.T) {
t.Parallel()
listener, err := net.Listen("tcp", "[::1]:0")
require.NoError(t, err)
defer listener.Close()
go listener.Accept()
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello"))
require.NoError(t, err)
}

View file

@ -0,0 +1,27 @@
package winiphlpapi
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcptable
//sys getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcpTable
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcp6table
//sys getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcp6Table
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcpconnectionestats
//sys getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcpConnectionEStats
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcp6connectionestats
//sys getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcp6ConnectionEStats
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcpconnectionestats
//sys setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcpConnectionEStats
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcp6connectionestats
//sys setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcp6ConnectionEStats
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
//sys getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedTcpTable
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedudptable
//sys getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedUdpTable

View file

@ -0,0 +1,131 @@
// Code generated by 'go generate'; DO NOT EDIT.
package winiphlpapi
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable")
procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable")
procGetPerTcp6ConnectionEStats = modiphlpapi.NewProc("GetPerTcp6ConnectionEStats")
procGetPerTcpConnectionEStats = modiphlpapi.NewProc("GetPerTcpConnectionEStats")
procGetTcp6Table = modiphlpapi.NewProc("GetTcp6Table")
procGetTcpTable = modiphlpapi.NewProc("GetTcpTable")
procSetPerTcp6ConnectionEStats = modiphlpapi.NewProc("SetPerTcp6ConnectionEStats")
procSetPerTcpConnectionEStats = modiphlpapi.NewProc("SetPerTcpConnectionEStats")
)
func getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) {
var _p0 uint32
if bOrder {
_p0 = 1
}
r0, _, _ := syscall.Syscall6(procGetExtendedTcpTable.Addr(), 6, uintptr(unsafe.Pointer(pTcpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) {
var _p0 uint32
if bOrder {
_p0 = 1
}
r0, _, _ := syscall.Syscall6(procGetExtendedUdpTable.Addr(), 6, uintptr(unsafe.Pointer(pUdpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) {
r0, _, _ := syscall.Syscall12(procGetPerTcp6ConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0)
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) {
r0, _, _ := syscall.Syscall12(procGetPerTcpConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0)
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) {
var _p0 uint32
if order {
_p0 = 1
}
r0, _, _ := syscall.Syscall(procGetTcp6Table.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) {
var _p0 uint32
if order {
_p0 = 1
}
r0, _, _ := syscall.Syscall(procGetTcpTable.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) {
r0, _, _ := syscall.Syscall6(procSetPerTcp6ConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}
func setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) {
r0, _, _ := syscall.Syscall6(procSetPerTcpConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset))
if r0 != 0 {
errcode = syscall.Errno(r0)
}
return
}

201
contrab/freelru/LICENSE Normal file
View file

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

2
contrab/freelru/NOTICE Normal file
View file

@ -0,0 +1,2 @@
Go LRU Hashmap
Copyright 2022 Elasticsearch B.V.

View file

@ -0,0 +1,4 @@
# freelru
upstream: github.com/elastic/go-freelru@v0.16.0
source: github.com/sagernet/go-freelru@1b34934a560d528d1866415d440625ed2a2560fe

102
contrab/freelru/cache.go Normal file
View file

@ -0,0 +1,102 @@
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package freelru
import "time"
type Cache[K comparable, V comparable] interface {
// SetLifetime sets the default lifetime of LRU elements.
// Lifetime 0 means "forever".
SetLifetime(lifetime time.Duration)
// SetOnEvict sets the OnEvict callback function.
// The onEvict function is called for each evicted lru entry.
SetOnEvict(onEvict OnEvictCallback[K, V])
SetHealthCheck(healthCheck HealthCheckCallback[K, V])
// Len returns the number of elements stored in the cache.
Len() int
// AddWithLifetime adds a key:value to the cache with a lifetime.
// Returns true, true if key was updated and eviction occurred.
AddWithLifetime(key K, value V, lifetime time.Duration) (evicted bool)
// Add adds a key:value to the cache.
// Returns true, true if key was updated and eviction occurred.
Add(key K, value V) (evicted bool)
// Get returns the value associated with the key, setting it as the most
// recently used item.
// If the found cache item is already expired, the evict function is called
// and the return value indicates that the key was not found.
Get(key K) (V, bool)
GetWithLifetime(key K) (V, time.Time, bool)
GetWithLifetimeNoExpire(key K) (V, time.Time, bool)
// GetAndRefresh returns the value associated with the key, setting it as the most
// recently used item.
// The lifetime of the found cache item is refreshed, even if it was already expired.
GetAndRefresh(key K) (V, bool)
GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool, bool)
// Peek looks up a key's value from the cache, without changing its recent-ness.
// If the found entry is already expired, the evict function is called.
Peek(key K) (V, bool)
PeekWithLifetime(key K) (V, time.Time, bool)
UpdateLifetime(key K, value V, lifetime time.Duration) bool
// Contains checks for the existence of a key, without changing its recent-ness.
// If the found entry is already expired, the evict function is called.
Contains(key K) bool
// Remove removes the key from the cache.
// The return value indicates whether the key existed or not.
// The evict function is called for the removed entry.
Remove(key K) bool
// RemoveOldest removes the oldest entry from the cache.
// Key, value and an indicator of whether the entry has been removed is returned.
// The evict function is called for the removed entry.
RemoveOldest() (key K, value V, removed bool)
// Keys returns a slice of the keys in the cache, from oldest to newest.
// Expired entries are not included.
// The evict function is called for each expired item.
Keys() []K
// Purge purges all data (key and value) from the LRU.
// The evict function is called for each expired item.
// The LRU metrics are reset.
Purge()
// PurgeExpired purges all expired items from the LRU.
// The evict function is called for each expired item.
PurgeExpired()
// Metrics returns the metrics of the cache.
Metrics() Metrics
// ResetMetrics resets the metrics of the cache and returns the previous state.
ResetMetrics() Metrics
}

767
contrab/freelru/lru.go Normal file
View file

@ -0,0 +1,767 @@
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package freelru
import (
"errors"
"fmt"
"math"
"math/bits"
"time"
)
// OnEvictCallback is the type for the eviction function.
type OnEvictCallback[K comparable, V comparable] func(K, V)
// HashKeyCallback is the function that creates a hash from the passed key.
type HashKeyCallback[K comparable] func(K) uint32
type HealthCheckCallback[K comparable, V comparable] func(K, V) bool
type element[K comparable, V comparable] struct {
key K
value V
// bucketNext and bucketPrev are indexes in the space-dimension doubly-linked list of elements.
// That is to add/remove items to the collision bucket without re-allocations and with O(1)
// complexity.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.latest.prev is last element and
// &l.last.next is the latest element.
nextBucket, prevBucket uint32
// bucketPos is the bucket that an element belongs to.
bucketPos uint32
// next and prev are indexes in the time-dimension doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.latest.prev is last element and
// &l.last.next is the latest element.
next, prev uint32
// expire is the point in time when the element expires.
// Its value is Unix milliseconds since epoch.
expire int64
}
const emptyBucket = math.MaxUint32
// LRU implements a non-thread safe fixed size LRU cache.
type LRU[K comparable, V comparable] struct {
buckets []uint32 // contains positions of bucket lists or 'emptyBucket'
elements []element[K, V]
onEvict OnEvictCallback[K, V]
hash HashKeyCallback[K]
healthCheck HealthCheckCallback[K, V]
lifetime time.Duration
metrics Metrics
// used for element clearing after removal or expiration
emptyKey K
emptyValue V
head uint32 // index of the newest element in the cache
len uint32 // current number of elements in the cache
cap uint32 // max number of elements in the cache
size uint32 // size of the element array (X% larger than cap)
mask uint32 // bitmask to avoid the costly idiv in hashToPos() if size is a 2^n value
}
// Metrics contains metrics about the cache.
type Metrics struct {
Inserts uint64
Collisions uint64
Evictions uint64
Removals uint64
Hits uint64
Misses uint64
}
var _ Cache[int, int] = (*LRU[int, int])(nil)
// SetLifetime sets the default lifetime of LRU elements.
// Lifetime 0 means "forever".
func (lru *LRU[K, V]) SetLifetime(lifetime time.Duration) {
lru.lifetime = lifetime
}
// SetOnEvict sets the OnEvict callback function.
// The onEvict function is called for each evicted lru entry.
// Eviction happens
// - when the cache is full and a new entry is added (oldest entry is evicted)
// - when an entry is removed by Remove() or RemoveOldest()
// - when an entry is recognized as expired
// - when Purge() is called
func (lru *LRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) {
lru.onEvict = onEvict
}
func (lru *LRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) {
lru.healthCheck = healthCheck
}
// New constructs an LRU with the given capacity of elements.
// The hash function calculates a hash value from the keys.
func New[K comparable, V comparable](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, V], error) {
return NewWithSize[K, V](capacity, capacity, hash)
}
// NewWithSize constructs an LRU with the given capacity and size.
// The hash function calculates a hash value from the keys.
// A size greater than the capacity increases memory consumption and decreases the CPU consumption
// by reducing the chance of collisions.
// Size must not be lower than the capacity.
func NewWithSize[K comparable, V comparable](capacity, size uint32, hash HashKeyCallback[K]) (
*LRU[K, V], error,
) {
if capacity == 0 {
return nil, errors.New("capacity must be positive")
}
if size == emptyBucket {
return nil, fmt.Errorf("size must not be %#X", size)
}
if size < capacity {
return nil, fmt.Errorf("size (%d) is smaller than capacity (%d)", size, capacity)
}
if hash == nil {
return nil, errors.New("hash function must be set")
}
buckets := make([]uint32, size)
elements := make([]element[K, V], size)
var lru LRU[K, V]
initLRU(&lru, capacity, size, hash, buckets, elements)
return &lru, nil
}
func initLRU[K comparable, V comparable](lru *LRU[K, V], capacity, size uint32, hash HashKeyCallback[K],
buckets []uint32, elements []element[K, V],
) {
lru.cap = capacity
lru.size = size
lru.hash = hash
lru.buckets = buckets
lru.elements = elements
// If the size is 2^N, we can avoid costly divisions.
if bits.OnesCount32(lru.size) == 1 {
lru.mask = lru.size - 1
}
// Mark all slots as free.
for i := range lru.buckets {
lru.buckets[i] = emptyBucket
}
}
// hashToBucketPos converts a hash value into a position in the elements array.
func (lru *LRU[K, V]) hashToBucketPos(hash uint32) uint32 {
if lru.mask != 0 {
return hash & lru.mask
}
return fastModulo(hash, lru.size)
}
// fastModulo calculates x % n without using the modulo operator (~4x faster).
// Reference: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
func fastModulo(x, n uint32) uint32 {
return uint32((uint64(x) * uint64(n)) >> 32) //nolint:gosec
}
// hashToPos converts a key into a position in the elements array.
func (lru *LRU[K, V]) hashToPos(hash uint32) (bucketPos, elemPos uint32) {
bucketPos = lru.hashToBucketPos(hash)
elemPos = lru.buckets[bucketPos]
return
}
// setHead links the element as the head into the list.
func (lru *LRU[K, V]) setHead(pos uint32) {
// Both calls to setHead() check beforehand that pos != lru.head.
// So if you run into this situation, you likely use FreeLRU in a concurrent situation
// without proper locking. It requires a write lock, even around Get().
// But better use SyncedLRU or SharedLRU in such a case.
if pos == lru.head {
panic(pos)
}
lru.elements[pos].prev = lru.head
lru.elements[pos].next = lru.elements[lru.head].next
lru.elements[lru.elements[lru.head].next].prev = pos
lru.elements[lru.head].next = pos
lru.head = pos
}
// unlinkElement removes the element from the elements list.
func (lru *LRU[K, V]) unlinkElement(pos uint32) {
lru.elements[lru.elements[pos].prev].next = lru.elements[pos].next
lru.elements[lru.elements[pos].next].prev = lru.elements[pos].prev
}
// unlinkBucket removes the element from the buckets list.
func (lru *LRU[K, V]) unlinkBucket(pos uint32) {
prevBucket := lru.elements[pos].prevBucket
nextBucket := lru.elements[pos].nextBucket
if prevBucket == nextBucket && prevBucket == pos { //nolint:gocritic
// The element references itself, so it's the only bucket entry
lru.buckets[lru.elements[pos].bucketPos] = emptyBucket
return
}
lru.elements[prevBucket].nextBucket = nextBucket
lru.elements[nextBucket].prevBucket = prevBucket
lru.buckets[lru.elements[pos].bucketPos] = nextBucket
}
// evict evicts the element at the given position.
func (lru *LRU[K, V]) evict(pos uint32) {
if pos == lru.head {
lru.head = lru.elements[pos].prev
}
lru.unlinkElement(pos)
lru.unlinkBucket(pos)
lru.len--
if lru.onEvict != nil {
// Save k/v for the eviction function.
key := lru.elements[pos].key
value := lru.elements[pos].value
lru.onEvict(key, value)
}
}
// Move element from position 'from' to position 'to'.
// That avoids 'gaps' and new elements can always be simply appended.
func (lru *LRU[K, V]) move(to, from uint32) {
if to == from {
return
}
if from == lru.head {
lru.head = to
}
prev := lru.elements[from].prev
next := lru.elements[from].next
lru.elements[prev].next = to
lru.elements[next].prev = to
prev = lru.elements[from].prevBucket
next = lru.elements[from].nextBucket
lru.elements[prev].nextBucket = to
lru.elements[next].prevBucket = to
lru.elements[to] = lru.elements[from]
if lru.buckets[lru.elements[to].bucketPos] == from {
lru.buckets[lru.elements[to].bucketPos] = to
}
}
// insert stores the k/v at pos.
// It updates the head to point to this position.
func (lru *LRU[K, V]) insert(pos uint32, key K, value V, lifetime time.Duration) {
lru.elements[pos].key = key
lru.elements[pos].value = value
lru.elements[pos].expire = expire(lifetime)
if lru.len == 0 {
lru.elements[pos].prev = pos
lru.elements[pos].next = pos
lru.head = pos
} else if pos != lru.head {
lru.setHead(pos)
}
lru.len++
lru.metrics.Inserts++
}
func now() int64 {
return time.Now().UnixMilli()
}
func expire(lifetime time.Duration) int64 {
if lifetime == 0 {
return 0
}
return now() + lifetime.Milliseconds()
}
// clearKeyAndValue clears stale data to avoid memory leaks
func (lru *LRU[K, V]) clearKeyAndValue(pos uint32) {
lru.elements[pos].key = lru.emptyKey
lru.elements[pos].value = lru.emptyValue
}
func (lru *LRU[K, V]) findKey(hash uint32, key K) (uint32, bool) {
_, startPos := lru.hashToPos(hash)
if startPos == emptyBucket {
return emptyBucket, false
}
pos := startPos
for {
if key == lru.elements[pos].key {
if lru.elements[pos].expire != 0 && lru.elements[pos].expire <= now() || (lru.healthCheck != nil && !lru.healthCheck(key, lru.elements[pos].value)) {
lru.removeAt(pos)
return emptyBucket, false
}
return pos, true
}
pos = lru.elements[pos].nextBucket
if pos == startPos {
// Key not found
return emptyBucket, false
}
}
}
func (lru *LRU[K, V]) findKeyNoExpire(hash uint32, key K) (uint32, bool) {
_, startPos := lru.hashToPos(hash)
if startPos == emptyBucket {
return emptyBucket, false
}
pos := startPos
for {
if key == lru.elements[pos].key {
if lru.healthCheck != nil && !lru.healthCheck(key, lru.elements[pos].value) {
lru.removeAt(pos)
return emptyBucket, false
}
return pos, true
}
pos = lru.elements[pos].nextBucket
if pos == startPos {
// Key not found
return emptyBucket, false
}
}
}
// Len returns the number of elements stored in the cache.
func (lru *LRU[K, V]) Len() int {
return int(lru.len)
}
// AddWithLifetime adds a key:value to the cache with a lifetime.
// Returns true, true if key was updated and eviction occurred.
func (lru *LRU[K, V]) AddWithLifetime(key K, value V, lifetime time.Duration) (evicted bool) {
return lru.addWithLifetime(lru.hash(key), key, value, lifetime)
}
func (lru *LRU[K, V]) addWithLifetime(hash uint32, key K, value V,
lifetime time.Duration,
) (evicted bool) {
bucketPos, startPos := lru.hashToPos(hash)
if startPos == emptyBucket {
pos := lru.len
if pos == lru.cap {
// Capacity reached, evict the oldest entry and
// store the new entry at evicted position.
pos = lru.elements[lru.head].next
lru.evict(pos)
lru.metrics.Evictions++
evicted = true
}
// insert new (first) entry into the bucket
lru.buckets[bucketPos] = pos
lru.elements[pos].bucketPos = bucketPos
lru.elements[pos].nextBucket = pos
lru.elements[pos].prevBucket = pos
lru.insert(pos, key, value, lifetime)
return evicted
}
// Walk through the bucket list to see whether key already exists.
pos := startPos
for {
if lru.elements[pos].key == key {
// Key exists, replace the value and update element to be the head element.
lru.elements[pos].value = value
lru.elements[pos].expire = expire(lifetime)
if pos != lru.head {
lru.unlinkElement(pos)
lru.setHead(pos)
}
// count as insert, even if it's just an update
lru.metrics.Inserts++
return false
}
pos = lru.elements[pos].nextBucket
if pos == startPos {
// Key not found
break
}
}
pos = lru.len
if pos == lru.cap {
// Capacity reached, evict the oldest entry and
// store the new entry at evicted position.
pos = lru.elements[lru.head].next
lru.evict(pos)
lru.metrics.Evictions++
evicted = true
startPos = lru.buckets[bucketPos]
if startPos == emptyBucket {
startPos = pos
}
}
// insert new entry into the existing bucket before startPos
lru.buckets[bucketPos] = pos
lru.elements[pos].bucketPos = bucketPos
lru.elements[pos].nextBucket = startPos
lru.elements[pos].prevBucket = lru.elements[startPos].prevBucket
lru.elements[lru.elements[startPos].prevBucket].nextBucket = pos
lru.elements[startPos].prevBucket = pos
lru.insert(pos, key, value, lifetime)
if lru.elements[pos].prevBucket != pos {
// The bucket now contains more than 1 element.
// That means we have a collision.
lru.metrics.Collisions++
}
return evicted
}
// Add adds a key:value to the cache.
// Returns true, true if key was updated and eviction occurred.
func (lru *LRU[K, V]) Add(key K, value V) (evicted bool) {
return lru.addWithLifetime(lru.hash(key), key, value, lru.lifetime)
}
func (lru *LRU[K, V]) add(hash uint32, key K, value V) (evicted bool) {
return lru.addWithLifetime(hash, key, value, lru.lifetime)
}
// Get returns the value associated with the key, setting it as the most
// recently used item.
// If the found cache item is already expired, the evict function is called
// and the return value indicates that the key was not found.
func (lru *LRU[K, V]) Get(key K) (value V, ok bool) {
return lru.get(lru.hash(key), key)
}
func (lru *LRU[K, V]) get(hash uint32, key K) (value V, ok bool) {
if pos, ok := lru.findKey(hash, key); ok {
if pos != lru.head {
lru.unlinkElement(pos)
lru.setHead(pos)
}
lru.metrics.Hits++
return lru.elements[pos].value, ok
}
lru.metrics.Misses++
return
}
func (lru *LRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) {
return lru.getWithLifetime(lru.hash(key), key)
}
func (lru *LRU[K, V]) getWithLifetime(hash uint32, key K) (value V, lifetime time.Time, ok bool) {
if pos, ok := lru.findKey(hash, key); ok {
if pos != lru.head {
lru.unlinkElement(pos)
lru.setHead(pos)
}
lru.metrics.Hits++
return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok
}
lru.metrics.Misses++
return
}
func (lru *LRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) {
return lru.getWithLifetimeNoExpire(lru.hash(key), key)
}
func (lru *LRU[K, V]) getWithLifetimeNoExpire(hash uint32, key K) (value V, lifetime time.Time, ok bool) {
if pos, ok := lru.findKeyNoExpire(hash, key); ok {
if pos != lru.head {
lru.unlinkElement(pos)
lru.setHead(pos)
}
lru.metrics.Hits++
return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok
}
lru.metrics.Misses++
return
}
// GetAndRefresh returns the value associated with the key, setting it as the most
// recently used item.
// The lifetime of the found cache item is refreshed, even if it was already expired.
func (lru *LRU[K, V]) GetAndRefresh(key K) (V, bool) {
return lru.getAndRefresh(lru.hash(key), key)
}
func (lru *LRU[K, V]) getAndRefresh(hash uint32, key K) (value V, ok bool) {
if pos, ok := lru.findKeyNoExpire(hash, key); ok {
if pos != lru.head {
lru.unlinkElement(pos)
lru.setHead(pos)
}
lru.metrics.Hits++
lru.elements[pos].expire = expire(lru.lifetime)
return lru.elements[pos].value, ok
}
lru.metrics.Misses++
return
}
func (lru *LRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool, bool) {
value, updated, ok := lru.getAndRefreshOrAdd(lru.hash(key), key, constructor)
if !updated && ok {
lru.PurgeExpired()
}
return value, updated, ok
}
func (lru *LRU[K, V]) getAndRefreshOrAdd(hash uint32, key K, constructor func() (V, bool)) (value V, updated bool, ok bool) {
if pos, ok := lru.findKeyNoExpire(hash, key); ok {
if pos != lru.head {
lru.unlinkElement(pos)
lru.setHead(pos)
}
lru.metrics.Hits++
lru.elements[pos].expire = expire(lru.lifetime)
return lru.elements[pos].value, true, true
}
lru.metrics.Misses++
value, ok = constructor()
if !ok {
return
}
lru.addWithLifetime(hash, key, value, lru.lifetime)
return value, false, true
}
// Peek looks up a key's value from the cache, without changing its recent-ness.
// If the found entry is already expired, the evict function is called.
func (lru *LRU[K, V]) Peek(key K) (value V, ok bool) {
return lru.peek(lru.hash(key), key)
}
func (lru *LRU[K, V]) peek(hash uint32, key K) (value V, ok bool) {
if pos, ok := lru.findKey(hash, key); ok {
return lru.elements[pos].value, ok
}
return
}
func (lru *LRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) {
return lru.peekWithLifetime(lru.hash(key), key)
}
func (lru *LRU[K, V]) peekWithLifetime(hash uint32, key K) (value V, lifetime time.Time, ok bool) {
if pos, ok := lru.findKey(hash, key); ok {
return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok
}
return
}
func (lru *LRU[K, V]) UpdateLifetime(key K, value V, lifetime time.Duration) bool {
return lru.updateLifetime(lru.hash(key), key, value, lifetime)
}
func (lru *LRU[K, V]) updateLifetime(hash uint32, key K, value V, lifetime time.Duration) bool {
_, startPos := lru.hashToPos(hash)
if startPos == emptyBucket {
return false
}
pos := startPos
for {
if lru.elements[pos].key == key {
if lru.elements[pos].value != value {
return false
}
lru.elements[pos].expire = expire(lifetime)
if pos != lru.head {
lru.unlinkElement(pos)
lru.setHead(pos)
}
lru.metrics.Inserts++
return true
}
pos = lru.elements[pos].nextBucket
if pos == startPos {
return false
}
}
}
// Contains checks for the existence of a key, without changing its recent-ness.
// If the found entry is already expired, the evict function is called.
func (lru *LRU[K, V]) Contains(key K) (ok bool) {
_, ok = lru.peek(lru.hash(key), key)
return
}
func (lru *LRU[K, V]) contains(hash uint32, key K) (ok bool) {
_, ok = lru.peek(hash, key)
return
}
// Remove removes the key from the cache.
// The return value indicates whether the key existed or not.
// The evict function is called for the removed entry.
func (lru *LRU[K, V]) Remove(key K) (removed bool) {
return lru.remove(lru.hash(key), key)
}
func (lru *LRU[K, V]) remove(hash uint32, key K) (removed bool) {
if pos, ok := lru.findKeyNoExpire(hash, key); ok {
lru.removeAt(pos)
return ok
}
return
}
func (lru *LRU[K, V]) removeAt(pos uint32) {
lru.evict(pos)
lru.move(pos, lru.len)
lru.metrics.Removals++
// remove stale data to avoid memory leaks
lru.clearKeyAndValue(lru.len)
}
// RemoveOldest removes the oldest entry from the cache.
// Key, value and an indicator of whether the entry has been removed is returned.
// The evict function is called for the removed entry.
func (lru *LRU[K, V]) RemoveOldest() (key K, value V, removed bool) {
if lru.len == 0 {
return lru.emptyKey, lru.emptyValue, false
}
pos := lru.elements[lru.head].next
key = lru.elements[pos].key
value = lru.elements[pos].value
lru.removeAt(pos)
return key, value, true
}
// Keys returns a slice of the keys in the cache, from oldest to newest.
// Expired entries are not included.
// The evict function is called for each expired item.
func (lru *LRU[K, V]) Keys() []K {
lru.PurgeExpired()
keys := make([]K, 0, lru.len)
pos := lru.elements[lru.head].next
for i := uint32(0); i < lru.len; i++ {
keys = append(keys, lru.elements[pos].key)
pos = lru.elements[pos].next
}
return keys
}
// Purge purges all data (key and value) from the LRU.
// The evict function is called for each expired item.
// The LRU metrics are reset.
func (lru *LRU[K, V]) Purge() {
l := lru.len
for i := uint32(0); i < l; i++ {
_, _, _ = lru.RemoveOldest()
}
lru.metrics = Metrics{}
}
// PurgeExpired purges all expired items from the LRU.
// The evict function is called for each expired item.
func (lru *LRU[K, V]) PurgeExpired() {
n := now()
loop:
l := lru.len
if l == 0 {
return
}
pos := lru.elements[lru.head].next
for i := uint32(0); i < l; i++ {
if lru.elements[pos].expire != 0 && lru.elements[pos].expire <= n {
lru.removeAt(pos)
goto loop
}
pos = lru.elements[pos].next
}
}
// Metrics returns the metrics of the cache.
func (lru *LRU[K, V]) Metrics() Metrics {
return lru.metrics
}
// ResetMetrics resets the metrics of the cache and returns the previous state.
func (lru *LRU[K, V]) ResetMetrics() Metrics {
metrics := lru.metrics
lru.metrics = Metrics{}
return metrics
}
// just used for debugging
func (lru *LRU[K, V]) dump() {
fmt.Printf("head %d len %d cap %d size %d mask 0x%X\n",
lru.head, lru.len, lru.cap, lru.size, lru.mask)
for i := range lru.buckets {
if lru.buckets[i] == emptyBucket {
continue
}
fmt.Printf(" bucket[%d] -> %d\n", i, lru.buckets[i])
pos := lru.buckets[i]
for {
e := &lru.elements[pos]
fmt.Printf(" pos %d bucketPos %d prevBucket %d nextBucket %d prev %d next %d k %v v %v\n",
pos, e.bucketPos, e.prevBucket, e.nextBucket, e.prev, e.next, e.key, e.value)
pos = e.nextBucket
if pos == lru.buckets[i] {
break
}
}
}
}
func (lru *LRU[K, V]) PrintStats() {
m := &lru.metrics
fmt.Printf("Inserts: %d Collisions: %d (%.2f%%) Evictions: %d Removals: %d Hits: %d (%.2f%%) Misses: %d\n",
m.Inserts, m.Collisions, float64(m.Collisions)/float64(m.Inserts)*100,
m.Evictions, m.Removals,
m.Hits, float64(m.Hits)/float64(m.Hits+m.Misses)*100, m.Misses)
}

100
contrab/freelru/lru_test.go Normal file
View file

@ -0,0 +1,100 @@
package freelru_test
import (
"math/rand"
"testing"
"time"
"github.com/sagernet/sing/common"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
"github.com/stretchr/testify/require"
)
func TestUpdateLifetimeOnGet(t *testing.T) {
t.Parallel()
lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32)
require.NoError(t, err)
lru.AddWithLifetime("hello", "world", 2*time.Second)
time.Sleep(time.Second)
_, ok := lru.GetAndRefresh("hello")
require.True(t, ok)
time.Sleep(time.Second + time.Millisecond*100)
_, ok = lru.Get("hello")
require.True(t, ok)
}
func TestUpdateLifetimeOnGet1(t *testing.T) {
t.Parallel()
lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32)
require.NoError(t, err)
lru.AddWithLifetime("hello", "world", 2*time.Second)
time.Sleep(time.Second)
lru.Peek("hello")
time.Sleep(time.Second + time.Millisecond*100)
_, ok := lru.Get("hello")
require.False(t, ok)
}
func TestUpdateLifetime(t *testing.T) {
t.Parallel()
lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32)
require.NoError(t, err)
lru.Add("hello", "world")
require.True(t, lru.UpdateLifetime("hello", "world", 2*time.Second))
time.Sleep(time.Second)
_, ok := lru.Get("hello")
require.True(t, ok)
time.Sleep(time.Second + time.Millisecond*100)
_, ok = lru.Get("hello")
require.False(t, ok)
}
func TestUpdateLifetime1(t *testing.T) {
t.Parallel()
lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32)
require.NoError(t, err)
lru.Add("hello", "world")
require.False(t, lru.UpdateLifetime("hello", "not world", 2*time.Second))
time.Sleep(2*time.Second + time.Millisecond*100)
_, ok := lru.Get("hello")
require.True(t, ok)
}
func TestUpdateLifetime2(t *testing.T) {
t.Parallel()
lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32)
require.NoError(t, err)
lru.AddWithLifetime("hello", "world", 2*time.Second)
time.Sleep(time.Second)
require.True(t, lru.UpdateLifetime("hello", "world", 2*time.Second))
time.Sleep(time.Second + time.Millisecond*100)
_, ok := lru.Get("hello")
require.True(t, ok)
time.Sleep(time.Second + time.Millisecond*100)
_, ok = lru.Get("hello")
require.False(t, ok)
}
func TestPurgeExpired(t *testing.T) {
t.Parallel()
lru, err := freelru.New[string, *string](1024, maphash.NewHasher[string]().Hash32)
require.NoError(t, err)
lru.SetLifetime(time.Second)
lru.SetOnEvict(func(s string, s2 *string) {
if s2 == nil {
t.Fail()
}
})
for i := 0; i < 100; i++ {
lru.AddWithLifetime("hello_"+F.ToString(i), common.Ptr("world_"+F.ToString(i)), time.Duration(rand.Intn(3000))*time.Millisecond)
}
for i := 0; i < 5; i++ {
time.Sleep(time.Second)
lru.GetAndRefreshOrAdd("hellox"+F.ToString(i), func() (*string, bool) {
return common.Ptr("worldx"), true
})
}
}

View file

@ -0,0 +1,398 @@
package freelru
import (
"errors"
"fmt"
"math/bits"
"runtime"
"sync"
"time"
)
// ShardedLRU is a thread-safe, sharded, fixed size LRU cache.
// Sharding is used to reduce lock contention on high concurrency.
// The downside is that exact LRU behavior is not given (as for the LRU and SynchedLRU types).
type ShardedLRU[K comparable, V comparable] struct {
lrus []LRU[K, V]
mus []sync.RWMutex
hash HashKeyCallback[K]
shards uint32
mask uint32
}
var _ Cache[int, int] = (*ShardedLRU[int, int])(nil)
// SetLifetime sets the default lifetime of LRU elements.
// Lifetime 0 means "forever".
func (lru *ShardedLRU[K, V]) SetLifetime(lifetime time.Duration) {
for shard := range lru.lrus {
lru.mus[shard].Lock()
lru.lrus[shard].SetLifetime(lifetime)
lru.mus[shard].Unlock()
}
}
// SetOnEvict sets the OnEvict callback function.
// The onEvict function is called for each evicted lru entry.
func (lru *ShardedLRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) {
for shard := range lru.lrus {
lru.mus[shard].Lock()
lru.lrus[shard].SetOnEvict(onEvict)
lru.mus[shard].Unlock()
}
}
func (lru *ShardedLRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) {
for shard := range lru.lrus {
lru.mus[shard].Lock()
lru.lrus[shard].SetHealthCheck(healthCheck)
lru.mus[shard].Unlock()
}
}
func nextPowerOfTwo(val uint32) uint32 {
if bits.OnesCount32(val) != 1 {
return 1 << bits.Len32(val)
}
return val
}
// NewSharded creates a new thread-safe sharded LRU hashmap with the given capacity.
func NewSharded[K comparable, V comparable](capacity uint32, hash HashKeyCallback[K]) (*ShardedLRU[K, V],
error,
) {
size := uint32(float64(capacity) * 1.25) // 25% extra space for fewer collisions
return NewShardedWithSize[K, V](uint32(runtime.GOMAXPROCS(0)*16), capacity, size, hash)
}
func NewShardedWithSize[K comparable, V comparable](shards, capacity, size uint32,
hash HashKeyCallback[K]) (
*ShardedLRU[K, V], error,
) {
if capacity == 0 {
return nil, errors.New("capacity must be positive")
}
if size < capacity {
return nil, fmt.Errorf("size (%d) is smaller than capacity (%d)", size, capacity)
}
if size < 1<<31 {
size = nextPowerOfTwo(size) // next power of 2 so the LRUs can avoid costly divisions
} else {
size = 1 << 31 // the highest 2^N value that fits in a uint32
}
shards = nextPowerOfTwo(shards) // next power of 2 so we can avoid costly division for sharding
for shards > size/16 {
shards /= 16
}
if shards == 0 {
shards = 1
}
size /= shards // size per LRU
if size == 0 {
size = 1
}
capacity = (capacity + shards - 1) / shards // size per LRU
if capacity == 0 {
capacity = 1
}
lrus := make([]LRU[K, V], shards)
buckets := make([]uint32, size*shards)
elements := make([]element[K, V], size*shards)
from := 0
to := int(size)
for i := range lrus {
initLRU(&lrus[i], capacity, size, hash, buckets[from:to], elements[from:to])
from = to
to += int(size)
}
return &ShardedLRU[K, V]{
lrus: lrus,
mus: make([]sync.RWMutex, shards),
hash: hash,
shards: shards,
mask: shards - 1,
}, nil
}
// Len returns the number of elements stored in the cache.
func (lru *ShardedLRU[K, V]) Len() (length int) {
for shard := range lru.lrus {
lru.mus[shard].RLock()
length += lru.lrus[shard].Len()
lru.mus[shard].RUnlock()
}
return
}
// AddWithLifetime adds a key:value to the cache with a lifetime.
// Returns true, true if key was updated and eviction occurred.
func (lru *ShardedLRU[K, V]) AddWithLifetime(key K, value V,
lifetime time.Duration,
) (evicted bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].Lock()
evicted = lru.lrus[shard].addWithLifetime(hash, key, value, lifetime)
lru.mus[shard].Unlock()
return
}
// Add adds a key:value to the cache.
// Returns true, true if key was updated and eviction occurred.
func (lru *ShardedLRU[K, V]) Add(key K, value V) (evicted bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].Lock()
evicted = lru.lrus[shard].add(hash, key, value)
lru.mus[shard].Unlock()
return
}
// Get returns the value associated with the key, setting it as the most
// recently used item.
// If the found cache item is already expired, the evict function is called
// and the return value indicates that the key was not found.
func (lru *ShardedLRU[K, V]) Get(key K) (value V, ok bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].Lock()
value, ok = lru.lrus[shard].get(hash, key)
lru.mus[shard].Unlock()
return
}
func (lru *ShardedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].Lock()
value, lifetime, ok = lru.lrus[shard].getWithLifetime(hash, key)
lru.mus[shard].Unlock()
return
}
func (lru *ShardedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].RLock()
value, lifetime, ok = lru.lrus[shard].getWithLifetimeNoExpire(hash, key)
lru.mus[shard].RUnlock()
return
}
// GetAndRefresh returns the value associated with the key, setting it as the most
// recently used item.
// The lifetime of the found cache item is refreshed, even if it was already expired.
func (lru *ShardedLRU[K, V]) GetAndRefresh(key K) (value V, ok bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].Lock()
value, ok = lru.lrus[shard].getAndRefresh(hash, key)
lru.mus[shard].Unlock()
return
}
func (lru *ShardedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool, ok bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].Lock()
value, updated, ok = lru.lrus[shard].getAndRefreshOrAdd(hash, key, constructor)
lru.mus[shard].Unlock()
if !updated && ok {
lru.PurgeExpired()
}
return
}
// Peek looks up a key's value from the cache, without changing its recent-ness.
// If the found entry is already expired, the evict function is called.
func (lru *ShardedLRU[K, V]) Peek(key K) (value V, ok bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].Lock()
value, ok = lru.lrus[shard].peek(hash, key)
lru.mus[shard].Unlock()
return
}
func (lru *ShardedLRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].Lock()
value, lifetime, ok = lru.lrus[shard].peekWithLifetime(hash, key)
lru.mus[shard].Unlock()
return
}
func (lru *ShardedLRU[K, V]) UpdateLifetime(key K, value V, lifetime time.Duration) (ok bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].Lock()
ok = lru.lrus[shard].updateLifetime(hash, key, value, lifetime)
lru.mus[shard].Unlock()
return
}
// Contains checks for the existence of a key, without changing its recent-ness.
// If the found entry is already expired, the evict function is called.
func (lru *ShardedLRU[K, V]) Contains(key K) (ok bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].Lock()
ok = lru.lrus[shard].contains(hash, key)
lru.mus[shard].Unlock()
return
}
// Remove removes the key from the cache.
// The return value indicates whether the key existed or not.
// The evict function is called for the removed entry.
func (lru *ShardedLRU[K, V]) Remove(key K) (removed bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].Lock()
removed = lru.lrus[shard].remove(hash, key)
lru.mus[shard].Unlock()
return
}
// RemoveOldest removes the oldest entry from the cache.
// Key, value and an indicator of whether the entry has been removed is returned.
// The evict function is called for the removed entry.
func (lru *ShardedLRU[K, V]) RemoveOldest() (key K, value V, removed bool) {
hash := lru.hash(key)
shard := (hash >> 16) & lru.mask
lru.mus[shard].Lock()
key, value, removed = lru.lrus[shard].RemoveOldest()
lru.mus[shard].Unlock()
return
}
// Keys returns a slice of the keys in the cache, from oldest to newest.
// Expired entries are not included.
// The evict function is called for each expired item.
func (lru *ShardedLRU[K, V]) Keys() []K {
keys := make([]K, 0, lru.shards*lru.lrus[0].cap)
for shard := range lru.lrus {
lru.mus[shard].Lock()
keys = append(keys, lru.lrus[shard].Keys()...)
lru.mus[shard].Unlock()
}
return keys
}
// Purge purges all data (key and value) from the LRU.
// The evict function is called for each expired item.
// The LRU metrics are reset.
func (lru *ShardedLRU[K, V]) Purge() {
for shard := range lru.lrus {
lru.mus[shard].Lock()
lru.lrus[shard].Purge()
lru.mus[shard].Unlock()
}
}
// PurgeExpired purges all expired items from the LRU.
// The evict function is called for each expired item.
func (lru *ShardedLRU[K, V]) PurgeExpired() {
for shard := range lru.lrus {
lru.mus[shard].Lock()
lru.lrus[shard].PurgeExpired()
lru.mus[shard].Unlock()
}
}
// Metrics returns the metrics of the cache.
func (lru *ShardedLRU[K, V]) Metrics() Metrics {
metrics := Metrics{}
for shard := range lru.lrus {
lru.mus[shard].Lock()
m := lru.lrus[shard].Metrics()
lru.mus[shard].Unlock()
addMetrics(&metrics, m)
}
return metrics
}
// ResetMetrics resets the metrics of the cache and returns the previous state.
func (lru *ShardedLRU[K, V]) ResetMetrics() Metrics {
metrics := Metrics{}
for shard := range lru.lrus {
lru.mus[shard].Lock()
m := lru.lrus[shard].ResetMetrics()
lru.mus[shard].Unlock()
addMetrics(&metrics, m)
}
return metrics
}
func addMetrics(dst *Metrics, src Metrics) {
dst.Inserts += src.Inserts
dst.Collisions += src.Collisions
dst.Evictions += src.Evictions
dst.Removals += src.Removals
dst.Hits += src.Hits
dst.Misses += src.Misses
}
// just used for debugging
func (lru *ShardedLRU[K, V]) dump() {
for shard := range lru.lrus {
fmt.Printf("Shard %d:\n", shard)
lru.mus[shard].RLock()
lru.lrus[shard].dump()
lru.mus[shard].RUnlock()
fmt.Println("")
}
}
func (lru *ShardedLRU[K, V]) PrintStats() {
for shard := range lru.lrus {
fmt.Printf("Shard %d:\n", shard)
lru.mus[shard].RLock()
lru.lrus[shard].PrintStats()
lru.mus[shard].RUnlock()
fmt.Println("")
}
}

View file

@ -0,0 +1,270 @@
package freelru
import (
"sync"
"time"
)
type SyncedLRU[K comparable, V comparable] struct {
mu sync.RWMutex
lru *LRU[K, V]
}
var _ Cache[int, int] = (*SyncedLRU[int, int])(nil)
// SetLifetime sets the default lifetime of LRU elements.
// Lifetime 0 means "forever".
func (lru *SyncedLRU[K, V]) SetLifetime(lifetime time.Duration) {
lru.mu.Lock()
lru.lru.SetLifetime(lifetime)
lru.mu.Unlock()
}
// SetOnEvict sets the OnEvict callback function.
// The onEvict function is called for each evicted lru entry.
func (lru *SyncedLRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) {
lru.mu.Lock()
lru.lru.SetOnEvict(onEvict)
lru.mu.Unlock()
}
func (lru *SyncedLRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) {
lru.mu.Lock()
lru.lru.SetHealthCheck(healthCheck)
lru.mu.Unlock()
}
// NewSynced creates a new thread-safe LRU hashmap with the given capacity.
func NewSynced[K comparable, V comparable](capacity uint32, hash HashKeyCallback[K]) (*SyncedLRU[K, V],
error,
) {
return NewSyncedWithSize[K, V](capacity, capacity, hash)
}
func NewSyncedWithSize[K comparable, V comparable](capacity, size uint32,
hash HashKeyCallback[K],
) (*SyncedLRU[K, V], error) {
lru, err := NewWithSize[K, V](capacity, size, hash)
if err != nil {
return nil, err
}
return &SyncedLRU[K, V]{lru: lru}, nil
}
// Len returns the number of elements stored in the cache.
func (lru *SyncedLRU[K, V]) Len() (length int) {
lru.mu.RLock()
length = lru.lru.Len()
lru.mu.RUnlock()
return
}
// AddWithLifetime adds a key:value to the cache with a lifetime.
// Returns true, true if key was updated and eviction occurred.
func (lru *SyncedLRU[K, V]) AddWithLifetime(key K, value V, lifetime time.Duration) (evicted bool) {
hash := lru.lru.hash(key)
lru.mu.Lock()
evicted = lru.lru.addWithLifetime(hash, key, value, lifetime)
lru.mu.Unlock()
return
}
// Add adds a key:value to the cache.
// Returns true, true if key was updated and eviction occurred.
func (lru *SyncedLRU[K, V]) Add(key K, value V) (evicted bool) {
hash := lru.lru.hash(key)
lru.mu.Lock()
evicted = lru.lru.add(hash, key, value)
lru.mu.Unlock()
return
}
// Get returns the value associated with the key, setting it as the most
// recently used item.
// If the found cache item is already expired, the evict function is called
// and the return value indicates that the key was not found.
func (lru *SyncedLRU[K, V]) Get(key K) (value V, ok bool) {
hash := lru.lru.hash(key)
lru.mu.Lock()
value, ok = lru.lru.get(hash, key)
lru.mu.Unlock()
return
}
func (lru *SyncedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) {
hash := lru.lru.hash(key)
lru.mu.Lock()
value, lifetime, ok = lru.lru.getWithLifetime(hash, key)
lru.mu.Unlock()
return
}
func (lru *SyncedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) {
hash := lru.lru.hash(key)
lru.mu.Lock()
value, lifetime, ok = lru.lru.getWithLifetimeNoExpire(hash, key)
lru.mu.Unlock()
return
}
// GetAndRefresh returns the value associated with the key, setting it as the most
// recently used item.
// The lifetime of the found cache item is refreshed, even if it was already expired.
func (lru *SyncedLRU[K, V]) GetAndRefresh(key K) (value V, ok bool) {
hash := lru.lru.hash(key)
lru.mu.Lock()
value, ok = lru.lru.getAndRefresh(hash, key)
lru.mu.Unlock()
return
}
func (lru *SyncedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool, ok bool) {
hash := lru.lru.hash(key)
lru.mu.Lock()
value, updated, ok = lru.lru.getAndRefreshOrAdd(hash, key, constructor)
if !updated && ok {
lru.lru.PurgeExpired()
}
lru.mu.Unlock()
return
}
// Peek looks up a key's value from the cache, without changing its recent-ness.
// If the found entry is already expired, the evict function is called.
func (lru *SyncedLRU[K, V]) Peek(key K) (value V, ok bool) {
hash := lru.lru.hash(key)
lru.mu.Lock()
value, ok = lru.lru.peek(hash, key)
lru.mu.Unlock()
return
}
func (lru *SyncedLRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) {
hash := lru.lru.hash(key)
lru.mu.Lock()
value, lifetime, ok = lru.lru.peekWithLifetime(hash, key)
lru.mu.Unlock()
return
}
func (lru *SyncedLRU[K, V]) UpdateLifetime(key K, value V, lifetime time.Duration) (ok bool) {
hash := lru.lru.hash(key)
lru.mu.Lock()
ok = lru.lru.updateLifetime(hash, key, value, lifetime)
lru.mu.Unlock()
return
}
// Contains checks for the existence of a key, without changing its recent-ness.
// If the found entry is already expired, the evict function is called.
func (lru *SyncedLRU[K, V]) Contains(key K) (ok bool) {
hash := lru.lru.hash(key)
lru.mu.Lock()
ok = lru.lru.contains(hash, key)
lru.mu.Unlock()
return
}
// Remove removes the key from the cache.
// The return value indicates whether the key existed or not.
// The evict function is being called if the key existed.
func (lru *SyncedLRU[K, V]) Remove(key K) (removed bool) {
hash := lru.lru.hash(key)
lru.mu.Lock()
removed = lru.lru.remove(hash, key)
lru.mu.Unlock()
return
}
// RemoveOldest removes the oldest entry from the cache.
// Key, value and an indicator of whether the entry has been removed is returned.
// The evict function is called for the removed entry.
func (lru *SyncedLRU[K, V]) RemoveOldest() (key K, value V, removed bool) {
lru.mu.Lock()
key, value, removed = lru.lru.RemoveOldest()
lru.mu.Unlock()
return
}
// Keys returns a slice of the keys in the cache, from oldest to newest.
// Expired entries are not included.
// The evict function is called for each expired item.
func (lru *SyncedLRU[K, V]) Keys() (keys []K) {
lru.mu.Lock()
keys = lru.lru.Keys()
lru.mu.Unlock()
return
}
// Purge purges all data (key and value) from the LRU.
// The evict function is called for each expired item.
// The LRU metrics are reset.
func (lru *SyncedLRU[K, V]) Purge() {
lru.mu.Lock()
lru.lru.Purge()
lru.mu.Unlock()
}
// PurgeExpired purges all expired items from the LRU.
// The evict function is called for each expired item.
func (lru *SyncedLRU[K, V]) PurgeExpired() {
lru.mu.Lock()
lru.lru.PurgeExpired()
lru.mu.Unlock()
}
// Metrics returns the metrics of the cache.
func (lru *SyncedLRU[K, V]) Metrics() Metrics {
lru.mu.Lock()
metrics := lru.lru.Metrics()
lru.mu.Unlock()
return metrics
}
// ResetMetrics resets the metrics of the cache and returns the previous state.
func (lru *SyncedLRU[K, V]) ResetMetrics() Metrics {
lru.mu.Lock()
metrics := lru.lru.ResetMetrics()
lru.mu.Unlock()
return metrics
}
// just used for debugging
func (lru *SyncedLRU[K, V]) dump() {
lru.mu.RLock()
lru.lru.dump()
lru.mu.RUnlock()
}
func (lru *SyncedLRU[K, V]) PrintStats() {
lru.mu.RLock()
lru.lru.PrintStats()
lru.mu.RUnlock()
}

View file

@ -0,0 +1,3 @@
# maphash
kanged from github.com/dolthub/maphash@v0.1.0

53
contrab/maphash/hasher.go Normal file
View file

@ -0,0 +1,53 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package maphash
import "unsafe"
// Hasher hashes values of type K.
// Uses runtime AES-based hashing.
type Hasher[K comparable] struct {
hash hashfn
seed uintptr
}
// NewHasher creates a new Hasher[K] with a random seed.
func NewHasher[K comparable]() Hasher[K] {
return Hasher[K]{
hash: getRuntimeHasher[K](),
seed: newHashSeed(),
}
}
// NewSeed returns a copy of |h| with a new hash seed.
func NewSeed[K comparable](h Hasher[K]) Hasher[K] {
return Hasher[K]{
hash: h.hash,
seed: newHashSeed(),
}
}
// Hash hashes |key|.
func (h Hasher[K]) Hash(key K) uint64 {
// promise to the compiler that pointer
// |p| does not escape the stack.
p := noescape(unsafe.Pointer(&key))
return uint64(h.hash(p, h.seed))
}
func (h Hasher[K]) Hash32(key K) uint32 {
p := noescape(unsafe.Pointer(&key))
return uint32(h.hash(p, h.seed))
}

114
contrab/maphash/runtime.go Normal file
View file

@ -0,0 +1,114 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file incorporates work covered by the following copyright and
// permission notice:
//
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.18 || go1.19
// +build go1.18 go1.19
package maphash
import (
"math/rand"
"unsafe"
)
type hashfn func(unsafe.Pointer, uintptr) uintptr
func getRuntimeHasher[K comparable]() (h hashfn) {
a := any(make(map[K]struct{}))
i := (*mapiface)(unsafe.Pointer(&a))
h = i.typ.hasher
return
}
func newHashSeed() uintptr {
return uintptr(rand.Int())
}
// noescape hides a pointer from escape analysis. It is the identity function
// but escape analysis doesn't think the output depends on the input.
// noescape is inlined and currently compiles down to zero instructions.
// USE CAREFULLY!
// This was copied from the runtime (via pkg "strings"); see issues 23382 and 7921.
//
//go:nosplit
//go:nocheckptr
func noescape(p unsafe.Pointer) unsafe.Pointer {
x := uintptr(p)
//nolint:staticcheck
return unsafe.Pointer(x ^ 0)
}
type mapiface struct {
typ *maptype
val *hmap
}
// go/src/runtime/type.go
type maptype struct {
typ _type
key *_type
elem *_type
bucket *_type
// function for hashing keys (ptr to key, seed) -> hash
hasher func(unsafe.Pointer, uintptr) uintptr
keysize uint8
elemsize uint8
bucketsize uint16
flags uint32
}
// go/src/runtime/map.go
type hmap struct {
count int
flags uint8
B uint8
noverflow uint16
// hash seed
hash0 uint32
buckets unsafe.Pointer
oldbuckets unsafe.Pointer
nevacuate uintptr
// true type is *mapextra
// but we don't need this data
extra unsafe.Pointer
}
// go/src/runtime/type.go
type (
tflag uint8
nameOff int32
typeOff int32
)
// go/src/runtime/type.go
type _type struct {
size uintptr
ptrdata uintptr
hash uint32
tflag tflag
align uint8
fieldAlign uint8
kind uint8
equal func(unsafe.Pointer, unsafe.Pointer) bool
gcdata *byte
str nameOff
ptrToThis typeOff
}

View file

@ -4,6 +4,7 @@ import (
std_bufio "bufio"
"context"
"encoding/base64"
"io"
"net"
"net/http"
"strings"
@ -20,15 +21,20 @@ import (
"github.com/sagernet/sing/common/pipe"
)
type Handler = N.TCPConnectionHandler
func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
func HandleConnectionEx(
ctx context.Context,
conn net.Conn,
reader *std_bufio.Reader,
authenticator *auth.Authenticator,
handler N.TCPConnectionHandlerEx,
source M.Socksaddr,
onClose N.CloseHandlerFunc,
) error {
for {
request, err := ReadRequest(reader)
if err != nil {
return E.Cause(err, "read http request")
}
if authenticator != nil {
var (
username string
@ -68,22 +74,23 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
}
if sourceAddress := SourceAddress(request); sourceAddress.IsValid() {
metadata.Source = sourceAddress
source = sourceAddress
}
if request.Method == "CONNECT" {
portStr := request.URL.Port()
if portStr == "" {
portStr = "80"
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port())
if destination.Port == 0 {
switch request.URL.Scheme {
case "https", "wss":
destination.Port = 443
default:
destination.Port = 80
}
}
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), portStr)
_, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n")))
if err != nil {
return E.Cause(err, "write http response")
}
metadata.Protocol = "http"
metadata.Destination = destination
var requestConn net.Conn
if reader.Buffered() > 0 {
buffer := buf.NewSize(reader.Buffered())
@ -95,75 +102,115 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read
} else {
requestConn = conn
}
return handler.NewConnection(ctx, requestConn, metadata)
}
keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
request.RequestURI = ""
removeHopByHopHeaders(request.Header)
removeExtraHTTPHostPort(request)
if hostStr := request.Header.Get("Host"); hostStr != "" {
if hostStr != request.URL.Host {
request.Host = hostStr
handler.NewConnectionEx(ctx, requestConn, source, destination, onClose)
return nil
} else if strings.ToLower(request.Header.Get("Connection")) == "upgrade" {
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port())
if destination.Port == 0 {
switch request.URL.Scheme {
case "https", "wss":
destination.Port = 443
default:
destination.Port = 80
}
}
serverConn, clientConn := pipe.Pipe()
go func() {
handler.NewConnectionEx(ctx, clientConn, source, destination, func(it error) {
if it != nil {
common.Close(serverConn, clientConn)
}
})
}()
err = request.Write(serverConn)
if err != nil {
return E.Cause(err, "http: write upgrade request")
}
if reader.Buffered() > 0 {
_, err = io.CopyN(serverConn, reader, int64(reader.Buffered()))
if err != nil {
return err
}
}
return bufio.CopyConn(ctx, conn, serverConn)
} else {
err = handleHTTPConnection(ctx, handler, conn, request, source)
if err != nil {
return err
}
}
}
}
if request.URL.Scheme == "" || request.URL.Host == "" {
return responseWith(request, http.StatusBadRequest).Write(conn)
}
func handleHTTPConnection(
ctx context.Context,
handler N.TCPConnectionHandlerEx,
conn net.Conn,
request *http.Request, source M.Socksaddr,
) error {
keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
request.RequestURI = ""
var innerErr atomic.TypedValue[error]
httpClient := &http.Client{
Transport: &http.Transport{
DisableCompression: true,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
metadata.Destination = M.ParseSocksaddr(address)
metadata.Protocol = "http"
input, output := pipe.Pipe()
go func() {
hErr := handler.NewConnection(ctx, output, metadata)
if hErr != nil {
innerErr.Store(hErr)
common.Close(input, output)
}
}()
return input, nil
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
requestCtx, cancel := context.WithCancel(ctx)
response, err := httpClient.Do(request.WithContext(requestCtx))
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn))
}
removeHopByHopHeaders(request.Header)
removeExtraHTTPHostPort(request)
removeHopByHopHeaders(response.Header)
if keepAlive {
response.Header.Set("Proxy-Connection", "keep-alive")
response.Header.Set("Connection", "keep-alive")
response.Header.Set("Keep-Alive", "timeout=4")
}
response.Close = !keepAlive
err = response.Write(conn)
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err)
}
cancel()
if !keepAlive {
return conn.Close()
if hostStr := request.Header.Get("Host"); hostStr != "" {
if hostStr != request.URL.Host {
request.Host = hostStr
}
}
if request.URL.Scheme == "" || request.URL.Host == "" {
return responseWith(request, http.StatusBadRequest).Write(conn)
}
var innerErr atomic.TypedValue[error]
httpClient := &http.Client{
Transport: &http.Transport{
DisableCompression: true,
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
input, output := pipe.Pipe()
go handler.NewConnectionEx(ctx, output, source, M.ParseSocksaddr(address), func(it error) {
innerErr.Store(it)
common.Close(input, output)
})
return input, nil
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
defer httpClient.CloseIdleConnections()
requestCtx, cancel := context.WithCancel(ctx)
response, err := httpClient.Do(request.WithContext(requestCtx))
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn))
}
removeHopByHopHeaders(response.Header)
if keepAlive {
response.Header.Set("Proxy-Connection", "keep-alive")
response.Header.Set("Connection", "keep-alive")
response.Header.Set("Keep-Alive", "timeout=4")
}
response.Close = !keepAlive
err = response.Write(conn)
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err)
}
cancel()
if !keepAlive {
return conn.Close()
}
return nil
}
func removeHopByHopHeaders(header http.Header) {

View file

@ -10,6 +10,7 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
@ -19,9 +20,13 @@ import (
"github.com/sagernet/sing/protocol/socks/socks5"
)
type Handler interface {
N.TCPConnectionHandler
N.UDPConnectionHandler
type HandlerEx interface {
N.TCPConnectionHandlerEx
N.UDPConnectionHandlerEx
}
type PacketListener interface {
ListenPacket(listenConfig net.ListenConfig, ctx context.Context, network string, address string) (net.PacketConn, error)
}
func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, username string) (socks4.Response, error) {
@ -79,6 +84,26 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr,
} else if authResponse.Method != socks5.AuthTypeNotRequired {
return socks5.Response{}, E.New("socks5: unsupported auth method: ", authResponse.Method)
}
if command == socks5.CommandUDPAssociate {
if destination.Addr.IsPrivate() {
if destination.Addr.Is6() {
destination.Addr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
} else {
destination.Addr = netip.IPv6Loopback()
}
} else if destination.Addr.IsGlobalUnicast() {
if destination.Addr.Is6() {
destination.Addr = netip.IPv6Unspecified()
} else {
destination.Addr = netip.IPv4Unspecified()
}
} else {
destination.Addr = netip.IPv6Unspecified()
}
destination.Port = 0
}
err = socks5.WriteRequest(conn, socks5.Request{
Command: command,
Destination: destination,
@ -96,18 +121,23 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr,
return response, err
}
func HandleConnection(ctx context.Context, conn net.Conn, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
return HandleConnection0(ctx, conn, std_bufio.NewReader(conn), authenticator, handler, metadata)
}
func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
func HandleConnectionEx(
ctx context.Context, conn net.Conn, reader *std_bufio.Reader,
authenticator *auth.Authenticator,
handler HandlerEx,
packetListener PacketListener,
// resolver TorResolver,
source M.Socksaddr,
onClose N.CloseHandlerFunc,
) error {
version, err := reader.ReadByte()
if err != nil {
return err
}
switch version {
case socks4.Version:
request, err := socks4.ReadRequest0(reader)
var request socks4.Request
request, err = socks4.ReadRequest0(reader)
if err != nil {
return err
}
@ -115,28 +145,23 @@ func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Rea
case socks4.CommandConnect:
if authenticator != nil && !authenticator.Verify(request.Username, "") {
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
Destination: request.Destination,
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
})
if err != nil {
return err
}
return E.New("socks4: authentication failed, username=", request.Username)
}
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeGranted,
Destination: M.SocksaddrFromNet(conn.LocalAddr()),
})
if err != nil {
return err
}
metadata.Protocol = "socks4"
metadata.Destination = request.Destination
return handler.NewConnection(auth.ContextWithUser(ctx, request.Username), conn, metadata)
handler.NewConnectionEx(auth.ContextWithUser(ctx, request.Username), NewLazyConn(conn, version), source, request.Destination, onClose)
return nil
/*case CommandTorResolve, CommandTorResolvePTR:
if resolver == nil {
return E.New("socks4: torsocks: commands not implemented")
}
return handleTorSocks4(ctx, conn, request, resolver)*/
default:
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
Destination: request.Destination,
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
})
if err != nil {
return err
@ -144,7 +169,8 @@ func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Rea
return E.New("socks4: unsupported command ", request.Command)
}
case socks5.Version:
authRequest, err := socks5.ReadAuthRequest0(reader)
var authRequest socks5.AuthRequest
authRequest, err = socks5.ReadAuthRequest0(reader)
if err != nil {
return err
}
@ -169,7 +195,8 @@ func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Rea
return err
}
if authMethod == socks5.AuthTypeUsernamePassword {
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(reader)
var usernamePasswordAuthRequest socks5.UsernamePasswordAuthRequest
usernamePasswordAuthRequest, err = socks5.ReadUsernamePasswordAuthRequest(reader)
if err != nil {
return err
}
@ -188,49 +215,50 @@ func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Rea
return E.New("socks5: authentication failed, username=", usernamePasswordAuthRequest.Username, ", password=", usernamePasswordAuthRequest.Password)
}
}
request, err := socks5.ReadRequest(reader)
var request socks5.Request
request, err = socks5.ReadRequest(reader)
if err != nil {
return err
}
switch request.Command {
case socks5.CommandConnect:
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.SocksaddrFromNet(conn.LocalAddr()),
})
if err != nil {
return err
}
metadata.Protocol = "socks5"
metadata.Destination = request.Destination
return handler.NewConnection(ctx, conn, metadata)
handler.NewConnectionEx(ctx, NewLazyConn(conn, version), source, request.Destination, onClose)
return nil
case socks5.CommandUDPAssociate:
var udpConn *net.UDPConn
udpConn, err = net.ListenUDP(M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), net.UDPAddrFromAddrPort(netip.AddrPortFrom(M.AddrFromNet(conn.LocalAddr()), 0)))
if err != nil {
return err
var (
listenConfig net.ListenConfig
udpConn net.PacketConn
)
if packetListener != nil {
udpConn, err = packetListener.ListenPacket(listenConfig, ctx, M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), M.SocksaddrFrom(M.AddrFromNet(conn.LocalAddr()), 0).String())
} else {
udpConn, err = listenConfig.ListenPacket(ctx, M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), M.SocksaddrFrom(M.AddrFromNet(conn.LocalAddr()), 0).String())
}
if err != nil {
return E.Cause(err, "socks5: listen udp")
}
defer udpConn.Close()
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.SocksaddrFromNet(udpConn.LocalAddr()),
})
if err != nil {
return err
return E.Cause(err, "socks5: write response")
}
metadata.Protocol = "socks5"
metadata.Destination = request.Destination
var innerError error
done := make(chan struct{})
associatePacketConn := NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), request.Destination, conn)
go func() {
innerError = handler.NewPacketConnection(ctx, associatePacketConn, metadata)
close(done)
}()
err = common.Error(io.Copy(io.Discard, conn))
associatePacketConn.Close()
<-done
return E.Errors(innerError, err)
var socksPacketConn N.PacketConn = NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), M.Socksaddr{}, conn)
firstPacket := buf.NewPacket()
var destination M.Socksaddr
destination, err = socksPacketConn.ReadPacket(firstPacket)
if err != nil {
return E.Cause(err, "socks5: read first packet")
}
socksPacketConn = bufio.NewCachedPacketConn(socksPacketConn, firstPacket, destination)
handler.NewPacketConnectionEx(ctx, socksPacketConn, source, destination, onClose)
return nil
/*case CommandTorResolve, CommandTorResolvePTR:
if resolver == nil {
return E.New("socks4: torsocks: commands not implemented")
}
return handleTorSocks5(ctx, conn, request, resolver)*/
default:
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeUnsupported,

View file

@ -0,0 +1,146 @@
package socks
import (
"context"
"net"
"net/netip"
"os"
"strings"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks/socks4"
"github.com/sagernet/sing/protocol/socks/socks5"
)
const (
CommandTorResolve byte = 0xF0
CommandTorResolvePTR byte = 0xF1
)
type TorResolver interface {
LookupIP(ctx context.Context, host string) (netip.Addr, error)
LookupPTR(ctx context.Context, addr netip.Addr) (string, error)
}
func handleTorSocks4(ctx context.Context, conn net.Conn, request socks4.Request, resolver TorResolver) error {
switch request.Command {
case CommandTorResolve:
if !request.Destination.IsFqdn() {
return E.New("socks4: torsocks: invalid destination")
}
ipAddr, err := resolver.LookupIP(ctx, request.Destination.Fqdn)
if err != nil {
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
})
if err != nil {
return err
}
return E.Cause(err, "socks4: torsocks: lookup failed for domain: ", request.Destination.Fqdn)
}
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeGranted,
Destination: M.SocksaddrFrom(ipAddr, 0),
})
if err != nil {
return E.Cause(err, "socks4: torsocks: write response")
}
return nil
case CommandTorResolvePTR:
var ipAddr netip.Addr
if request.Destination.IsIP() {
ipAddr = request.Destination.Addr
} else if strings.HasSuffix(request.Destination.Fqdn, ".in-addr.arpa") {
ipAddr, _ = netip.ParseAddr(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".in-addr.arpa")])
} else if strings.HasSuffix(request.Destination.Fqdn, ".ip6.arpa") {
ipAddr, _ = netip.ParseAddr(strings.ReplaceAll(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".ip6.arpa")], ".", ":"))
}
if !ipAddr.IsValid() {
return E.New("socks4: torsocks: invalid destination")
}
host, err := resolver.LookupPTR(ctx, ipAddr)
if err != nil {
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
})
if err != nil {
return err
}
return E.Cause(err, "socks4: torsocks: lookup PTR failed for ip: ", ipAddr)
}
err = socks4.WriteResponse(conn, socks4.Response{
ReplyCode: socks4.ReplyCodeGranted,
Destination: M.Socksaddr{
Fqdn: host,
},
})
if err != nil {
return E.Cause(err, "socks4: torsocks: write response")
}
return nil
default:
return os.ErrInvalid
}
}
func handleTorSocks5(ctx context.Context, conn net.Conn, request socks5.Request, resolver TorResolver) error {
switch request.Command {
case CommandTorResolve:
if !request.Destination.IsFqdn() {
return E.New("socks5: torsocks: invalid destination")
}
ipAddr, err := resolver.LookupIP(ctx, request.Destination.Fqdn)
if err != nil {
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeFailure,
})
if err != nil {
return err
}
return E.Cause(err, "socks5: torsocks: lookup failed for domain: ", request.Destination.Fqdn)
}
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.SocksaddrFrom(ipAddr, 0),
})
if err != nil {
return E.Cause(err, "socks5: torsocks: write response")
}
return nil
case CommandTorResolvePTR:
var ipAddr netip.Addr
if request.Destination.IsIP() {
ipAddr = request.Destination.Addr
} else if strings.HasSuffix(request.Destination.Fqdn, ".in-addr.arpa") {
ipAddr, _ = netip.ParseAddr(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".in-addr.arpa")])
} else if strings.HasSuffix(request.Destination.Fqdn, ".ip6.arpa") {
ipAddr, _ = netip.ParseAddr(strings.ReplaceAll(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".ip6.arpa")], ".", ":"))
}
if !ipAddr.IsValid() {
return E.New("socks5: torsocks: invalid destination")
}
host, err := resolver.LookupPTR(ctx, ipAddr)
if err != nil {
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeFailure,
})
if err != nil {
return err
}
return E.Cause(err, "socks5: torsocks: lookup PTR failed for ip: ", ipAddr)
}
err = socks5.WriteResponse(conn, socks5.Response{
ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.Socksaddr{
Fqdn: host,
},
})
if err != nil {
return E.Cause(err, "socks5: torsocks: write response")
}
return nil
default:
return os.ErrInvalid
}
}

215
protocol/socks/lazy.go Normal file
View file

@ -0,0 +1,215 @@
package socks
import (
"net"
"os"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/protocol/socks/socks4"
"github.com/sagernet/sing/protocol/socks/socks5"
)
type LazyConn struct {
net.Conn
socksVersion byte
responseWritten bool
}
func NewLazyConn(conn net.Conn, socksVersion byte) *LazyConn {
return &LazyConn{
Conn: conn,
socksVersion: socksVersion,
}
}
func (c *LazyConn) ConnHandshakeSuccess(conn net.Conn) error {
if c.responseWritten {
return nil
}
defer func() {
c.responseWritten = true
}()
switch c.socksVersion {
case socks4.Version:
return socks4.WriteResponse(c.Conn, socks4.Response{
ReplyCode: socks4.ReplyCodeGranted,
Destination: M.SocksaddrFromNet(conn.LocalAddr()),
})
case socks5.Version:
return socks5.WriteResponse(c.Conn, socks5.Response{
ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.SocksaddrFromNet(conn.LocalAddr()),
})
default:
panic("unknown socks version")
}
}
func (c *LazyConn) HandshakeFailure(err error) error {
if c.responseWritten {
return os.ErrInvalid
}
defer func() {
c.responseWritten = true
}()
switch c.socksVersion {
case socks4.Version:
return socks4.WriteResponse(c.Conn, socks4.Response{
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
})
case socks5.Version:
return socks5.WriteResponse(c.Conn, socks5.Response{
ReplyCode: socks5.ReplyCodeForError(err),
})
default:
panic("unknown socks version")
}
}
func (c *LazyConn) Read(p []byte) (n int, err error) {
if !c.responseWritten {
err = c.ConnHandshakeSuccess(c.Conn)
if err != nil {
return
}
}
return c.Conn.Read(p)
}
func (c *LazyConn) Write(p []byte) (n int, err error) {
if !c.responseWritten {
err = c.ConnHandshakeSuccess(c.Conn)
if err != nil {
return
}
}
return c.Conn.Write(p)
}
func (c *LazyConn) ReaderReplaceable() bool {
return c.responseWritten
}
func (c *LazyConn) WriterReplaceable() bool {
return c.responseWritten
}
func (c *LazyConn) Upstream() any {
return c.Conn
}
type LazyAssociatePacketConn struct {
AssociatePacketConn
responseWritten bool
}
func NewLazyAssociatePacketConn(conn net.Conn, underlying net.Conn) *LazyAssociatePacketConn {
return &LazyAssociatePacketConn{
AssociatePacketConn: AssociatePacketConn{
AbstractConn: conn,
conn: bufio.NewExtendedConn(conn),
underlying: underlying,
},
}
}
func (c *LazyAssociatePacketConn) HandshakeSuccess() error {
if c.responseWritten {
return nil
}
defer func() {
c.responseWritten = true
}()
return socks5.WriteResponse(c.underlying, socks5.Response{
ReplyCode: socks5.ReplyCodeSuccess,
Bind: M.SocksaddrFromNet(c.conn.LocalAddr()),
})
}
func (c *LazyAssociatePacketConn) HandshakeFailure(err error) error {
if c.responseWritten {
return os.ErrInvalid
}
defer func() {
c.responseWritten = true
c.conn.Close()
c.underlying.Close()
}()
return socks5.WriteResponse(c.underlying, socks5.Response{
ReplyCode: socks5.ReplyCodeForError(err),
})
}
func (c *LazyAssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if !c.responseWritten {
err = c.HandshakeSuccess()
if err != nil {
return
}
}
return c.AssociatePacketConn.ReadFrom(p)
}
func (c *LazyAssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !c.responseWritten {
err = c.HandshakeSuccess()
if err != nil {
return
}
}
return c.AssociatePacketConn.WriteTo(p, addr)
}
func (c *LazyAssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if !c.responseWritten {
err = c.HandshakeSuccess()
if err != nil {
return
}
}
return c.AssociatePacketConn.ReadPacket(buffer)
}
func (c *LazyAssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if !c.responseWritten {
err := c.HandshakeSuccess()
if err != nil {
return err
}
}
return c.AssociatePacketConn.WritePacket(buffer, destination)
}
func (c *LazyAssociatePacketConn) Read(p []byte) (n int, err error) {
if !c.responseWritten {
err = c.HandshakeSuccess()
if err != nil {
return
}
}
return c.AssociatePacketConn.Read(p)
}
func (c *LazyAssociatePacketConn) Write(p []byte) (n int, err error) {
if !c.responseWritten {
err = c.HandshakeSuccess()
if err != nil {
return
}
}
return c.AssociatePacketConn.Write(p)
}
func (c *LazyAssociatePacketConn) ReaderReplaceable() bool {
return c.responseWritten
}
func (c *LazyAssociatePacketConn) WriterReplaceable() bool {
return c.responseWritten
}
func (c *LazyAssociatePacketConn) Upstream() any {
return &c.AssociatePacketConn
}

View file

@ -1,8 +1,10 @@
package socks5
import (
"errors"
"io"
"net/netip"
"syscall"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
@ -37,6 +39,20 @@ const (
ReplyCodeAddressTypeUnsupported byte = 8
)
func ReplyCodeForError(err error) byte {
if errors.Is(err, syscall.ENETUNREACH) {
return ReplyCodeNetworkUnreachable
} else if errors.Is(err, syscall.EHOSTUNREACH) {
return ReplyCodeHostUnreachable
} else if errors.Is(err, syscall.ECONNREFUSED) {
return ReplyCodeConnectionRefused
} else if errors.Is(err, syscall.EPERM) {
return ReplyCodeNotAllowed
} else {
return ReplyCodeFailure
}
}
// +----+----------+----------+
// |VER | NMETHODS | METHODS |
// +----+----------+----------+