diff --git a/common/baderror/baderror.go b/common/baderror/baderror.go index c5ab530..d90ddeb 100644 --- a/common/baderror/baderror.go +++ b/common/baderror/baderror.go @@ -2,11 +2,10 @@ package baderror import ( "context" + "errors" "io" "net" "strings" - - E "github.com/sagernet/sing/common/exceptions" ) func Contains(err error, msgList ...string) bool { @@ -22,8 +21,7 @@ func WrapH2(err error) error { if err == nil { return nil } - err = E.Unwrap(err) - if err == io.ErrUnexpectedEOF { + if errors.Is(err, io.ErrUnexpectedEOF) { return io.EOF } if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") { diff --git a/common/bufio/copy.go b/common/bufio/copy.go index 506eab1..a21fbc4 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -29,21 +29,18 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) { if cachedSrc, isCached := source.(N.CachedReader); isCached { cachedBuffer := cachedSrc.ReadCached() if cachedBuffer != nil { - if !cachedBuffer.IsEmpty() { - dataLen := cachedBuffer.Len() - for _, counter := range readCounters { - counter(int64(dataLen)) - } - _, err = destination.Write(cachedBuffer.Bytes()) - if err != nil { - cachedBuffer.Release() - return - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - } + dataLen := cachedBuffer.Len() + _, err = destination.Write(cachedBuffer.Bytes()) cachedBuffer.Release() + if err != nil { + return + } + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } continue } } @@ -139,9 +136,6 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, return } dataLen := buffer.Len() - for _, counter := range readCounters { - counter(int64(dataLen)) - } options.PostReturn(buffer) err = destination.WriteBuffer(buffer) if err != nil { @@ -152,6 +146,9 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, return } n += int64(dataLen) + for _, counter := range readCounters { + counter(int64(dataLen)) + } for _, counter := range writeCounters { counter(int64(dataLen)) } @@ -258,9 +255,6 @@ func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, return } dataLen := buffer.Len() - for _, counter := range readCounters { - counter(int64(dataLen)) - } options.PostReturn(buffer) err = destination.WritePacket(buffer, destinationAddress) if err != nil { @@ -270,6 +264,9 @@ func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, } return } + for _, counter := range readCounters { + counter(int64(dataLen)) + } for _, counter := range writeCounters { counter(int64(dataLen)) } @@ -282,9 +279,6 @@ func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter options := N.NewReadWaitOptions(nil, destination) var notFirstTime bool for _, packetBuffer := range packetBuffers { - for _, counter := range readCounters { - counter(int64(packetBuffer.Buffer.Len())) - } buffer := options.Copy(packetBuffer.Buffer) dataLen := buffer.Len() err = destination.WritePacket(buffer, packetBuffer.Destination) @@ -296,6 +290,9 @@ func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter } return } + for _, counter := range readCounters { + counter(int64(dataLen)) + } for _, counter := range writeCounters { counter(int64(dataLen)) } diff --git a/common/bufio/copy_direct_windows.go b/common/bufio/copy_direct_windows.go index 482b649..7da9213 100644 --- a/common/bufio/copy_direct_windows.go +++ b/common/bufio/copy_direct_windows.go @@ -120,16 +120,16 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions var readN int var from windows.Sockaddr readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0) + //goland:noinspection GoDirectComparisonOfErrors + if w.readErr != nil { + buffer.Release() + return w.readErr != windows.WSAEWOULDBLOCK + } if readN > 0 { buffer.Truncate(readN) - w.options.PostReturn(buffer) - w.buffer = buffer - } else { - buffer.Release() - } - if w.readErr == windows.WSAEWOULDBLOCK { - return false } + w.options.PostReturn(buffer) + w.buffer = buffer if from != nil { switch fromAddr := from.(type) { case *windows.SockaddrInet4: diff --git a/common/bufio/nat.go b/common/bufio/nat.go index cafeb06..6e5ab64 100644 --- a/common/bufio/nat.go +++ b/common/bufio/nat.go @@ -30,6 +30,14 @@ func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.So } } +func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn { + return &destinationNATPacketConn{ + NetPacketConn: conn, + origin: origin, + destination: destination, + } +} + type unidirectionalNATPacketConn struct { N.NetPacketConn origin M.Socksaddr @@ -144,6 +152,60 @@ func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr { return c.destination.UDPAddr() } +type destinationNATPacketConn struct { + N.NetPacketConn + origin M.Socksaddr + destination M.Socksaddr +} + +func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.NetPacketConn.ReadFrom(p) + if err != nil { + return + } + if M.SocksaddrFromNet(addr) == c.origin { + addr = c.destination.UDPAddr() + } + return +} + +func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if M.SocksaddrFromNet(addr) == c.destination { + addr = c.origin.UDPAddr() + } + return c.NetPacketConn.WriteTo(p, addr) +} + +func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + destination, err = c.NetPacketConn.ReadPacket(buffer) + if err != nil { + return + } + if destination == c.origin { + destination = c.destination + } + return +} + +func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if destination == c.destination { + destination = c.origin + } + return c.NetPacketConn.WritePacket(buffer, destination) +} + +func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) { + c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) +} + +func (c *destinationNATPacketConn) Upstream() any { + return c.NetPacketConn +} + +func (c *destinationNATPacketConn) RemoteAddr() net.Addr { + return c.destination.UDPAddr() +} + func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr { destination.Port = 0 return destination diff --git a/common/canceler/instance.go b/common/canceler/instance.go index 05faa91..c47270d 100644 --- a/common/canceler/instance.go +++ b/common/canceler/instance.go @@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration { return i.timeout } -func (i *Instance) SetTimeout(timeout time.Duration) { +func (i *Instance) SetTimeout(timeout time.Duration) bool { i.timeout = timeout - i.Update() + return i.Update() } func (i *Instance) wait() { diff --git a/common/canceler/packet.go b/common/canceler/packet.go index fb4ad84..519f283 100644 --- a/common/canceler/packet.go +++ b/common/canceler/packet.go @@ -13,7 +13,7 @@ import ( type PacketConn interface { N.PacketConn Timeout() time.Duration - SetTimeout(timeout time.Duration) + SetTimeout(timeout time.Duration) bool } type TimerPacketConn struct { @@ -24,10 +24,12 @@ type TimerPacketConn struct { func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) { if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn { oldTimeout := timeoutConn.Timeout() - if timeout < oldTimeout { - timeoutConn.SetTimeout(timeout) + if oldTimeout > 0 && timeout >= oldTimeout { + return ctx, conn + } + if timeoutConn.SetTimeout(timeout) { + return ctx, conn } - return ctx, conn } err := conn.SetReadDeadline(time.Time{}) if err == nil { @@ -58,8 +60,8 @@ func (c *TimerPacketConn) Timeout() time.Duration { return c.instance.Timeout() } -func (c *TimerPacketConn) SetTimeout(timeout time.Duration) { - c.instance.SetTimeout(timeout) +func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool { + return c.instance.SetTimeout(timeout) } func (c *TimerPacketConn) Close() error { diff --git a/common/canceler/packet_timeout.go b/common/canceler/packet_timeout.go index ab5c760..a679567 100644 --- a/common/canceler/packet_timeout.go +++ b/common/canceler/packet_timeout.go @@ -61,9 +61,9 @@ func (c *TimeoutPacketConn) Timeout() time.Duration { return c.timeout } -func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) { +func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool { c.timeout = timeout - c.PacketConn.SetReadDeadline(time.Now()) + return c.PacketConn.SetReadDeadline(time.Now()) == nil } func (c *TimeoutPacketConn) Close() error { diff --git a/common/control/frag_darwin.go b/common/control/frag_darwin.go index 567a821..f76b241 100644 --- a/common/control/frag_darwin.go +++ b/common/control/frag_darwin.go @@ -4,19 +4,26 @@ import ( "os" "syscall" + N "github.com/sagernet/sing/common/network" + "golang.org/x/sys/unix" ) func DisableUDPFragment() Func { return func(network, address string, conn syscall.RawConn) error { + if N.NetworkName(network) != N.NetworkUDP { + return nil + } return Raw(conn, func(fd uintptr) error { - switch network { - case "udp4": - if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1); err != nil { + if network == "udp" || network == "udp4" { + err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1) + if err != nil { return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err) } - case "udp6": - if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1); err != nil { + } + if network == "udp" || network == "udp6" { + err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1) + if err != nil { return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err) } } diff --git a/common/control/frag_linux.go b/common/control/frag_linux.go index 5cb5fca..3bf9d57 100644 --- a/common/control/frag_linux.go +++ b/common/control/frag_linux.go @@ -11,17 +11,19 @@ import ( func DisableUDPFragment() Func { return func(network, address string, conn syscall.RawConn) error { - switch N.NetworkName(network) { - case N.NetworkUDP: - default: + if N.NetworkName(network) != N.NetworkUDP { return nil } return Raw(conn, func(fd uintptr) error { - if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil { - return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err) + if network == "udp" || network == "udp4" { + err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO) + if err != nil { + return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err) + } } - if network == "udp6" { - if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil { + if network == "udp" || network == "udp6" { + err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO) + if err != nil { return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err) } } diff --git a/common/control/frag_windows.go b/common/control/frag_windows.go index bf02948..8e9bb83 100644 --- a/common/control/frag_windows.go +++ b/common/control/frag_windows.go @@ -25,17 +25,19 @@ const ( func DisableUDPFragment() Func { return func(network, address string, conn syscall.RawConn) error { - switch N.NetworkName(network) { - case N.NetworkUDP: - default: + if N.NetworkName(network) != N.NetworkUDP { return nil } return Raw(conn, func(fd uintptr) error { - if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil { - return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err) + if network == "udp" || network == "udp4" { + err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO) + if err != nil { + return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err) + } } - if network == "udp6" { - if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil { + if network == "udp" || network == "udp6" { + err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO) + if err != nil { return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err) } } diff --git a/common/exceptions/cause.go b/common/exceptions/cause.go index fe7adf3..b48b6e2 100644 --- a/common/exceptions/cause.go +++ b/common/exceptions/cause.go @@ -12,3 +12,16 @@ func (e *causeError) Error() string { func (e *causeError) Unwrap() error { return e.cause } + +type causeError1 struct { + error + cause error +} + +func (e *causeError1) Error() string { + return e.error.Error() + ": " + e.cause.Error() +} + +func (e *causeError1) Unwrap() []error { + return []error{e.error, e.cause} +} diff --git a/common/exceptions/error.go b/common/exceptions/error.go index 16b075a..24f0c29 100644 --- a/common/exceptions/error.go +++ b/common/exceptions/error.go @@ -32,6 +32,13 @@ func Cause(cause error, message ...any) error { return &causeError{F.ToString(message...), cause} } +func Cause1(err error, cause error) error { + if cause == nil { + panic("cause on an nil error") + } + return &causeError1{err, cause} +} + func Extend(cause error, message ...any) error { if cause == nil { panic("extend on an nil error") @@ -40,11 +47,11 @@ func Extend(cause error, message ...any) error { } func IsClosedOrCanceled(err error) bool { - return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, context.Canceled, context.DeadlineExceeded) + return IsClosed(err) || IsCanceled(err) || IsTimeout(err) } func IsClosed(err error) bool { - return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET) + return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, syscall.ENOTCONN) } func IsCanceled(err error) bool { diff --git a/common/exceptions/inner.go b/common/exceptions/inner.go index 3799af9..58bcc5a 100644 --- a/common/exceptions/inner.go +++ b/common/exceptions/inner.go @@ -1,24 +1,14 @@ package exceptions -import "github.com/sagernet/sing/common" +import ( + "errors" -type HasInnerError interface { - Unwrap() error -} + "github.com/sagernet/sing/common" +) +// Deprecated: Use errors.Unwrap instead. func Unwrap(err error) error { - for { - inner, ok := err.(HasInnerError) - if !ok { - break - } - innerErr := inner.Unwrap() - if innerErr == nil { - break - } - err = innerErr - } - return err + return errors.Unwrap(err) } func Cast[T any](err error) (T, bool) { diff --git a/common/exceptions/multi.go b/common/exceptions/multi.go index 2cdec05..78fb3b6 100644 --- a/common/exceptions/multi.go +++ b/common/exceptions/multi.go @@ -63,12 +63,5 @@ func IsMulti(err error, targetList ...error) bool { return true } } - err = Unwrap(err) - multiErr, isMulti := err.(MultiError) - if !isMulti { - return false - } - return common.All(multiErr.Unwrap(), func(it error) bool { - return IsMulti(it, targetList...) - }) + return false } diff --git a/common/json/badjson/merge_objects.go b/common/json/badjson/merge_objects.go index fa6c2d4..5b23209 100644 --- a/common/json/badjson/merge_objects.go +++ b/common/json/badjson/merge_objects.go @@ -2,9 +2,11 @@ package badjson import ( "context" + "reflect" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" + cJSON "github.com/sagernet/sing/common/json/internal/contextjson" ) func MarshallObjects(objects ...any) ([]byte, error) { @@ -31,16 +33,12 @@ func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error } func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error { - parentContent, err := newJSONObject(ctx, parentObject) - if err != nil { - return err - } var content JSONObject - err = content.UnmarshalJSONContext(ctx, inputContent) + err := content.UnmarshalJSONContext(ctx, inputContent) if err != nil { return err } - for _, key := range parentContent.Keys() { + for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) { content.Remove(key) } if object == nil { diff --git a/common/json/badoption/listable.go b/common/json/badoption/listable.go index df02217..c48df12 100644 --- a/common/json/badoption/listable.go +++ b/common/json/badoption/listable.go @@ -18,6 +18,9 @@ func (l Listable[T]) MarshalJSONContext(ctx context.Context) ([]byte, error) { } func (l *Listable[T]) UnmarshalJSONContext(ctx context.Context, content []byte) error { + if string(content) == "null" { + return nil + } var singleItem T err := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem) if err == nil { diff --git a/common/json/badoption/netip.go b/common/json/badoption/netip.go index 61aef82..f22df43 100644 --- a/common/json/badoption/netip.go +++ b/common/json/badoption/netip.go @@ -35,6 +35,13 @@ func (a *Addr) UnmarshalJSON(content []byte) error { type Prefix netip.Prefix +func (p *Prefix) Build(defaultPrefix netip.Prefix) netip.Prefix { + if p == nil { + return defaultPrefix + } + return netip.Prefix(*p) +} + func (p *Prefix) MarshalJSON() ([]byte, error) { return json.Marshal(netip.Prefix(*p).String()) } @@ -55,6 +62,13 @@ func (p *Prefix) UnmarshalJSON(content []byte) error { type Prefixable netip.Prefix +func (p *Prefixable) Build(defaultPrefix netip.Prefix) netip.Prefix { + if p == nil { + return defaultPrefix + } + return netip.Prefix(*p) +} + func (p *Prefixable) MarshalJSON() ([]byte, error) { prefix := netip.Prefix(*p) if prefix.Bits() == prefix.Addr().BitLen() { diff --git a/common/json/internal/contextjson/keys.go b/common/json/internal/contextjson/keys.go new file mode 100644 index 0000000..589007f --- /dev/null +++ b/common/json/internal/contextjson/keys.go @@ -0,0 +1,20 @@ +package json + +import ( + "reflect" + + "github.com/sagernet/sing/common" +) + +func ObjectKeys(object reflect.Type) []string { + switch object.Kind() { + case reflect.Pointer: + return ObjectKeys(object.Elem()) + case reflect.Struct: + default: + panic("invalid non-struct input") + } + return common.Map(cachedTypeFields(object).list, func(field field) string { + return field.name + }) +} diff --git a/common/json/internal/contextjson/keys_test.go b/common/json/internal/contextjson/keys_test.go new file mode 100644 index 0000000..11c72cb --- /dev/null +++ b/common/json/internal/contextjson/keys_test.go @@ -0,0 +1,26 @@ +package json_test + +import ( + "reflect" + "testing" + + json "github.com/sagernet/sing/common/json/internal/contextjson" + + "github.com/stretchr/testify/require" +) + +type MyObject struct { + Hello string `json:"hello,omitempty"` + MyWorld + MyWorld2 string `json:"-"` +} + +type MyWorld struct { + World string `json:"world,omitempty"` +} + +func TestObjectKeys(t *testing.T) { + t.Parallel() + keys := json.ObjectKeys(reflect.TypeOf(&MyObject{})) + require.Equal(t, []string{"hello", "world"}, keys) +} diff --git a/common/network/conn.go b/common/network/conn.go index c289bf6..ab3961f 100644 --- a/common/network/conn.go +++ b/common/network/conn.go @@ -74,9 +74,10 @@ type ExtendedConn interface { type CloseHandlerFunc = func(it error) func AppendClose(parent CloseHandlerFunc, onClose CloseHandlerFunc) CloseHandlerFunc { + if onClose == nil { + panic("nil onClose") + } if parent == nil { - return parent - } else if onClose == nil { return onClose } return func(it error) { diff --git a/common/network/handshake.go b/common/network/handshake.go index d2203e0..273b9e3 100644 --- a/common/network/handshake.go +++ b/common/network/handshake.go @@ -33,7 +33,7 @@ func ReportHandshakeFailure(reporter any, err error) error { return nil } -func CloseOnHandshakeFailure(reporter any, onClose CloseHandlerFunc, err error) error { +func CloseOnHandshakeFailure(reporter io.Closer, onClose CloseHandlerFunc, err error) error { if err != nil { if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn { hErr := handshakeConn.HandshakeFailure(err) @@ -51,12 +51,10 @@ func CloseOnHandshakeFailure(reporter any, onClose CloseHandlerFunc, err error) }](reporter); isTCPConn { tcpConn.SetLinger(0) } - if closer, isCloser := reporter.(io.Closer); isCloser { - err = E.Append(err, closer.Close(), func(err error) error { - return E.Cause(err, "close") - }) - } } + err = E.Append(err, reporter.Close(), func(err error) error { + return E.Cause(err, "close") + }) } if onClose != nil { onClose(err) diff --git a/common/ntp/time_windows.go b/common/ntp/time_windows.go index 5aab23e..82d8798 100644 --- a/common/ntp/time_windows.go +++ b/common/ntp/time_windows.go @@ -7,6 +7,7 @@ import ( ) func SetSystemTime(nowTime time.Time) error { + nowTime = nowTime.UTC() var systemTime windows.Systemtime systemTime.Year = uint16(nowTime.Year()) systemTime.Month = uint16(nowTime.Month()) diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 9d5bfa9..e8987f4 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -3,27 +3,35 @@ package udpnat import ( "io" "net" + "net/netip" "os" + "sync" "time" "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/canceler" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/pipe" + "github.com/sagernet/sing/contrab/freelru" ) type Conn interface { N.PacketConn SetHandler(handler N.UDPHandlerEx) + canceler.PacketConn } var _ Conn = (*natConn)(nil) type natConn struct { + cache freelru.Cache[netip.AddrPort, *natConn] writer N.PacketWriter localAddr M.Socksaddr + handlerAccess sync.RWMutex handler N.UDPHandlerEx packetChan chan *N.PacketBuffer + closeOnce sync.Once doneChan chan struct{} readDeadline pipe.Deadline readWaitOptions N.ReadWaitOptions @@ -48,26 +56,6 @@ func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error return c.writer.WritePacket(buffer, destination) } -func (c *natConn) SetHandler(handler N.UDPHandlerEx) { - select { - case <-c.doneChan: - default: - } - c.handler = handler - c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler) -fetch: - for { - select { - case packet := <-c.packetChan: - c.handler.NewPacketEx(packet.Buffer, packet.Destination) - N.PutPacketBuffer(packet) - continue fetch - default: - break fetch - } - } -} - func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { c.readWaitOptions = options return false @@ -87,12 +75,40 @@ func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, } } -func (c *natConn) Close() error { - select { - case <-c.doneChan: - default: - close(c.doneChan) +func (c *natConn) SetHandler(handler N.UDPHandlerEx) { + c.handlerAccess.Lock() + c.handler = handler + c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler) + c.handlerAccess.Unlock() +fetch: + for { + select { + case packet := <-c.packetChan: + c.handler.NewPacketEx(packet.Buffer, packet.Destination) + N.PutPacketBuffer(packet) + continue fetch + default: + break fetch + } } +} + +func (c *natConn) Timeout() time.Duration { + rawConn, lifetime, loaded := c.cache.PeekWithLifetime(c.localAddr.AddrPort()) + if !loaded || rawConn != c { + return 0 + } + return time.Until(lifetime) +} + +func (c *natConn) SetTimeout(timeout time.Duration) bool { + return c.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout) +} + +func (c *natConn) Close() error { + c.closeOnce.Do(func() { + close(c.doneChan) + }) return nil } diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index ac5da1d..3367bd7 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -17,27 +17,21 @@ type Service struct { cache freelru.Cache[netip.AddrPort, *natConn] handler N.UDPConnectionHandlerEx prepare PrepareFunc - metrics Metrics } type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) -type Metrics struct { - Creates uint64 - Rejects uint64 - Inputs uint64 - Drops uint64 -} - func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service { + if timeout == 0 { + panic("invalid timeout") + } var cache freelru.Cache[netip.AddrPort, *natConn] if !shared { - cache = common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + cache = common.Must1(freelru.NewSynced[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) } else { cache = common.Must1(freelru.NewSharded[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) } cache.SetLifetime(timeout) - cache.SetUpdateLifetimeOnGet(true) cache.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool { select { case <-conn.doneChan: @@ -57,30 +51,34 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur } func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) { - conn, loaded := s.cache.Get(source.AddrPort()) - if !loaded { + conn, _, ok := s.cache.GetAndRefreshOrAdd(source.AddrPort(), func() (*natConn, bool) { ok, ctx, writer, onClose := s.prepare(source, destination, userData) if !ok { - s.metrics.Rejects++ - return + return nil, false } - conn = &natConn{ + newConn := &natConn{ + cache: s.cache, writer: writer, localAddr: source, packetChan: make(chan *N.PacketBuffer, 64), doneChan: make(chan struct{}), readDeadline: pipe.MakeDeadline(), } - s.cache.Add(source.AddrPort(), conn) - go s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose) - s.metrics.Creates++ + go s.handler.NewPacketConnectionEx(ctx, newConn, source, destination, onClose) + return newConn, true + }) + if !ok { + return } buffer := conn.readWaitOptions.NewPacketBuffer() for _, bufferSlice := range bufferSlices { buffer.Write(bufferSlice) } - if conn.handler != nil { - conn.handler.NewPacketEx(buffer, destination) + conn.handlerAccess.RLock() + handler := conn.handler + conn.handlerAccess.RUnlock() + if handler != nil { + handler.NewPacketEx(buffer, destination) return } packet := N.NewPacketBuffer() @@ -90,11 +88,9 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati } select { case conn.packetChan <- packet: - s.metrics.Inputs++ default: packet.Buffer.Release() N.PutPacketBuffer(packet) - s.metrics.Drops++ } } @@ -105,13 +101,3 @@ func (s *Service) Purge() { func (s *Service) PurgeExpired() { s.cache.PurgeExpired() } - -func (s *Service) Metrics() Metrics { - return s.metrics -} - -func (s *Service) ResetMetrics() Metrics { - metrics := s.metrics - s.metrics = Metrics{} - return metrics -} diff --git a/common/windnsapi/dnsapi_test.go b/common/windnsapi/dnsapi_test.go index adf582d..c5ea831 100644 --- a/common/windnsapi/dnsapi_test.go +++ b/common/windnsapi/dnsapi_test.go @@ -1,16 +1,14 @@ +//go:build windows + package windnsapi import ( - "runtime" "testing" "github.com/stretchr/testify/require" ) func TestDNSAPI(t *testing.T) { - if runtime.GOOS != "windows" { - t.SkipNow() - } t.Parallel() require.NoError(t, FlushResolverCache()) } diff --git a/common/winiphlpapi/helper.go b/common/winiphlpapi/helper.go new file mode 100644 index 0000000..6bd4e8f --- /dev/null +++ b/common/winiphlpapi/helper.go @@ -0,0 +1,217 @@ +//go:build windows + +package winiphlpapi + +import ( + "context" + "encoding/binary" + "net" + "net/netip" + "os" + "time" + "unsafe" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func LoadEStats() error { + err := modiphlpapi.Load() + if err != nil { + return err + } + err = procGetTcpTable.Find() + if err != nil { + return err + } + err = procGetTcp6Table.Find() + if err != nil { + return err + } + err = procGetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + err = procGetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + err = procSetPerTcpConnectionEStats.Find() + if err != nil { + return err + } + err = procSetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + return nil +} + +func LoadExtendedTable() error { + err := modiphlpapi.Load() + if err != nil { + return err + } + err = procGetExtendedTcpTable.Find() + if err != nil { + return err + } + err = procGetExtendedUdpTable.Find() + if err != nil { + return err + } + return nil +} + +func FindPid(network string, source netip.AddrPort) (uint32, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + if source.Addr().Is4() { + tcpTable, err := GetExtendedTcpTable() + if err != nil { + return 0, err + } + for _, row := range tcpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } else { + tcpTable, err := GetExtendedTcp6Table() + if err != nil { + return 0, err + } + for _, row := range tcpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } + case N.NetworkUDP: + if source.Addr().Is4() { + udpTable, err := GetExtendedUdpTable() + if err != nil { + return 0, err + } + for _, row := range udpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } else { + udpTable, err := GetExtendedUdp6Table() + if err != nil { + return 0, err + } + for _, row := range udpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } + } + return 0, E.New("process not found for ", source) +} + +func WriteAndWaitAck(ctx context.Context, conn net.Conn, payload []byte) error { + source := M.AddrPortFromNet(conn.LocalAddr()) + destination := M.AddrPortFromNet(conn.RemoteAddr()) + if source.Addr().Is4() { + tcpTable, err := GetTcpTable() + if err != nil { + return err + } + var tcpRow *MibTcpRow + for _, row := range tcpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) || + destination == netip.AddrPortFrom(DwordToAddr(row.DwRemoteAddr), DwordToPort(row.DwRemotePort)) { + tcpRow = &row + break + } + } + if tcpRow == nil { + return E.New("row not found for: ", source) + } + err = SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: true, + }) + if err != nil { + return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err) + } + defer SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: false, + }) + _, err = conn.Write(payload) + if err != nil { + return err + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + eStstsSendBuffer, err := GetPerTcpConnectionEStatsSendBuffer(tcpRow) + if err != nil { + return err + } + if eStstsSendBuffer.CurRetxQueue == 0 { + return nil + } + time.Sleep(10 * time.Millisecond) + } + } else { + tcpTable, err := GetTcp6Table() + if err != nil { + return err + } + var tcpRow *MibTcp6Row + for _, row := range tcpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.LocalAddr), DwordToPort(row.LocalPort)) || + destination == netip.AddrPortFrom(netip.AddrFrom16(row.RemoteAddr), DwordToPort(row.RemotePort)) { + tcpRow = &row + break + } + } + if tcpRow == nil { + return E.New("row not found for: ", source) + } + err = SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: true, + }) + if err != nil { + return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err) + } + defer SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: false, + }) + _, err = conn.Write(payload) + if err != nil { + return err + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + eStstsSendBuffer, err := GetPerTcp6ConnectionEStatsSendBuffer(tcpRow) + if err != nil { + return err + } + if eStstsSendBuffer.CurRetxQueue == 0 { + return nil + } + time.Sleep(10 * time.Millisecond) + } + } +} + +func DwordToAddr(addr uint32) netip.Addr { + return netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(&addr))) +} + +func DwordToPort(dword uint32) uint16 { + return binary.BigEndian.Uint16((*[4]byte)(unsafe.Pointer(&dword))[:]) +} diff --git a/common/winiphlpapi/iphlpapi.go b/common/winiphlpapi/iphlpapi.go new file mode 100644 index 0000000..74e5b90 --- /dev/null +++ b/common/winiphlpapi/iphlpapi.go @@ -0,0 +1,313 @@ +//go:build windows + +package winiphlpapi + +import ( + "errors" + "os" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + TcpTableBasicListener uint32 = iota + TcpTableBasicConnections + TcpTableBasicAll + TcpTableOwnerPidListener + TcpTableOwnerPidConnections + TcpTableOwnerPidAll + TcpTableOwnerModuleListener + TcpTableOwnerModuleConnections + TcpTableOwnerModuleAll +) + +const ( + UdpTableBasic uint32 = iota + UdpTableOwnerPid + UdpTableOwnerModule +) + +const ( + TcpConnectionEstatsSynOpts uint32 = iota + TcpConnectionEstatsData + TcpConnectionEstatsSndCong + TcpConnectionEstatsPath + TcpConnectionEstatsSendBuff + TcpConnectionEstatsRec + TcpConnectionEstatsObsRec + TcpConnectionEstatsBandwidth + TcpConnectionEstatsFineRtt + TcpConnectionEstatsMaximum +) + +type MibTcpTable struct { + DwNumEntries uint32 + Table [1]MibTcpRow +} + +type MibTcpRow struct { + DwState uint32 + DwLocalAddr uint32 + DwLocalPort uint32 + DwRemoteAddr uint32 + DwRemotePort uint32 +} + +type MibTcp6Table struct { + DwNumEntries uint32 + Table [1]MibTcp6Row +} + +type MibTcp6Row struct { + State uint32 + LocalAddr [16]byte + LocalScopeId uint32 + LocalPort uint32 + RemoteAddr [16]byte + RemoteScopeId uint32 + RemotePort uint32 +} + +type MibTcpTableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibTcpRowOwnerPid +} + +type MibTcpRowOwnerPid struct { + DwState uint32 + DwLocalAddr uint32 + DwLocalPort uint32 + DwRemoteAddr uint32 + DwRemotePort uint32 + DwOwningPid uint32 +} + +type MibTcp6TableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibTcp6RowOwnerPid +} + +type MibTcp6RowOwnerPid struct { + UcLocalAddr [16]byte + DwLocalScopeId uint32 + DwLocalPort uint32 + UcRemoteAddr [16]byte + DwRemoteScopeId uint32 + DwRemotePort uint32 + DwState uint32 + DwOwningPid uint32 +} + +type MibUdpTableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibUdpRowOwnerPid +} + +type MibUdpRowOwnerPid struct { + DwLocalAddr uint32 + DwLocalPort uint32 + DwOwningPid uint32 +} + +type MibUdp6TableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibUdp6RowOwnerPid +} + +type MibUdp6RowOwnerPid struct { + UcLocalAddr [16]byte + DwLocalScopeId uint32 + DwLocalPort uint32 + DwOwningPid uint32 +} + +type TcpEstatsSendBufferRodV0 struct { + CurRetxQueue uint64 + MaxRetxQueue uint64 + CurAppWQueue uint64 + MaxAppWQueue uint64 +} + +type TcpEstatsSendBuffRwV0 struct { + EnableCollection bool +} + +const ( + offsetOfMibTcpTable = unsafe.Offsetof(MibTcpTable{}.Table) + offsetOfMibTcp6Table = unsafe.Offsetof(MibTcp6Table{}.Table) + offsetOfMibTcpTableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table) + offsetOfMibTcp6TableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table) + offsetOfMibUdpTableOwnerPid = unsafe.Offsetof(MibUdpTableOwnerPid{}.Table) + offsetOfMibUdp6TableOwnerPid = unsafe.Offsetof(MibUdp6TableOwnerPid{}.Table) + sizeOfTcpEstatsSendBuffRwV0 = unsafe.Sizeof(TcpEstatsSendBuffRwV0{}) + sizeOfTcpEstatsSendBufferRodV0 = unsafe.Sizeof(TcpEstatsSendBufferRodV0{}) +) + +func GetTcpTable() ([]MibTcpRow, error) { + var size uint32 + err := getTcpTable(nil, &size, false) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, err + } + for { + table := make([]byte, size) + err = getTcpTable(&table[0], &size, false) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, err + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcpRow)(unsafe.Pointer(&table[offsetOfMibTcpTable])), dwNumEntries), nil + } +} + +func GetTcp6Table() ([]MibTcp6Row, error) { + var size uint32 + err := getTcp6Table(nil, &size, false) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, err + } + for { + table := make([]byte, size) + err = getTcp6Table(&table[0], &size, false) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, err + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcp6Row)(unsafe.Pointer(&table[offsetOfMibTcp6Table])), dwNumEntries), nil + } +} + +func GetExtendedTcpTable() ([]MibTcpRowOwnerPid, error) { + var size uint32 + err := getExtendedTcpTable(nil, &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcpTableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedTcp6Table() ([]MibTcp6RowOwnerPid, error) { + var size uint32 + err := getExtendedTcpTable(nil, &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcp6TableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedUdpTable() ([]MibUdpRowOwnerPid, error) { + var size uint32 + err := getExtendedUdpTable(nil, &size, false, windows.AF_INET, UdpTableOwnerPid, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET, UdpTableOwnerPid, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibUdpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdpTableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedUdp6Table() ([]MibUdp6RowOwnerPid, error) { + var size uint32 + err := getExtendedUdpTable(nil, &size, false, windows.AF_INET6, UdpTableOwnerPid, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET6, UdpTableOwnerPid, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibUdp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdp6TableOwnerPid])), dwNumEntries), nil + } +} + +func GetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow) (*TcpEstatsSendBufferRodV0, error) { + var rod TcpEstatsSendBufferRodV0 + err := getPerTcpConnectionEStats(row, + TcpConnectionEstatsSendBuff, + 0, + 0, + 0, + 0, + 0, + 0, + uintptr(unsafe.Pointer(&rod)), + 0, + uint64(sizeOfTcpEstatsSendBufferRodV0), + ) + if err != nil { + return nil, err + } + return &rod, nil +} + +func GetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row) (*TcpEstatsSendBufferRodV0, error) { + var rod TcpEstatsSendBufferRodV0 + err := getPerTcp6ConnectionEStats(row, + TcpConnectionEstatsSendBuff, + 0, + 0, + 0, + 0, + 0, + 0, + uintptr(unsafe.Pointer(&rod)), + 0, + uint64(sizeOfTcpEstatsSendBufferRodV0), + ) + if err != nil { + return nil, err + } + return &rod, nil +} + +func SetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow, rw *TcpEstatsSendBuffRwV0) error { + return setPerTcpConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0) +} + +func SetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row, rw *TcpEstatsSendBuffRwV0) error { + return setPerTcp6ConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0) +} diff --git a/common/winiphlpapi/iphlpapi_test.go b/common/winiphlpapi/iphlpapi_test.go new file mode 100644 index 0000000..5fc3b74 --- /dev/null +++ b/common/winiphlpapi/iphlpapi_test.go @@ -0,0 +1,90 @@ +//go:build windows + +package winiphlpapi_test + +import ( + "context" + "net" + "syscall" + "testing" + + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/winiphlpapi" + + "github.com/stretchr/testify/require" +) + +func TestFindPidTcp4(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidTcp6(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "[::1]:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidUdp4(t *testing.T) { + t.Parallel() + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidUdp6(t *testing.T) { + t.Parallel() + conn, err := net.ListenPacket("udp", "[::1]:0") + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestWaitAck4(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello")) + require.NoError(t, err) +} + +func TestWaitAck6(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "[::1]:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello")) + require.NoError(t, err) +} diff --git a/common/winiphlpapi/syscall_windows.go b/common/winiphlpapi/syscall_windows.go new file mode 100644 index 0000000..f6aab14 --- /dev/null +++ b/common/winiphlpapi/syscall_windows.go @@ -0,0 +1,27 @@ +package winiphlpapi + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcptable +//sys getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcpTable + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcp6table +//sys getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcp6Table + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcpconnectionestats +//sys getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcpConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcp6connectionestats +//sys getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcp6ConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcpconnectionestats +//sys setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcpConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcp6connectionestats +//sys setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcp6ConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable +//sys getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedTcpTable + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedudptable +//sys getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedUdpTable diff --git a/common/winiphlpapi/zsyscall_windows.go b/common/winiphlpapi/zsyscall_windows.go new file mode 100644 index 0000000..e5e9308 --- /dev/null +++ b/common/winiphlpapi/zsyscall_windows.go @@ -0,0 +1,131 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package winiphlpapi + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") + + procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable") + procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable") + procGetPerTcp6ConnectionEStats = modiphlpapi.NewProc("GetPerTcp6ConnectionEStats") + procGetPerTcpConnectionEStats = modiphlpapi.NewProc("GetPerTcpConnectionEStats") + procGetTcp6Table = modiphlpapi.NewProc("GetTcp6Table") + procGetTcpTable = modiphlpapi.NewProc("GetTcpTable") + procSetPerTcp6ConnectionEStats = modiphlpapi.NewProc("SetPerTcp6ConnectionEStats") + procSetPerTcpConnectionEStats = modiphlpapi.NewProc("SetPerTcpConnectionEStats") +) + +func getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) { + var _p0 uint32 + if bOrder { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procGetExtendedTcpTable.Addr(), 6, uintptr(unsafe.Pointer(pTcpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) { + var _p0 uint32 + if bOrder { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procGetExtendedUdpTable.Addr(), 6, uintptr(unsafe.Pointer(pUdpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) { + r0, _, _ := syscall.Syscall12(procGetPerTcp6ConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) { + r0, _, _ := syscall.Syscall12(procGetPerTcpConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) { + var _p0 uint32 + if order { + _p0 = 1 + } + r0, _, _ := syscall.Syscall(procGetTcp6Table.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) { + var _p0 uint32 + if order { + _p0 = 1 + } + r0, _, _ := syscall.Syscall(procGetTcpTable.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) { + r0, _, _ := syscall.Syscall6(procSetPerTcp6ConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) { + r0, _, _ := syscall.Syscall6(procSetPerTcpConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} diff --git a/contrab/freelru/LICENSE b/contrab/freelru/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/contrab/freelru/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/contrab/freelru/NOTICE b/contrab/freelru/NOTICE new file mode 100644 index 0000000..9b1c877 --- /dev/null +++ b/contrab/freelru/NOTICE @@ -0,0 +1,2 @@ +Go LRU Hashmap +Copyright 2022 Elasticsearch B.V. diff --git a/contrab/freelru/README.md b/contrab/freelru/README.md index 206fbba..72bec81 100644 --- a/contrab/freelru/README.md +++ b/contrab/freelru/README.md @@ -1,3 +1,4 @@ # freelru -kanged from github.com/elastic/go-freelru@v0.14.0 \ No newline at end of file +upstream: github.com/elastic/go-freelru@v0.16.0 +source: github.com/sagernet/go-freelru@1b34934a560d528d1866415d440625ed2a2560fe diff --git a/contrab/freelru/cache.go b/contrab/freelru/cache.go index 707e5bc..22488e0 100644 --- a/contrab/freelru/cache.go +++ b/contrab/freelru/cache.go @@ -19,19 +19,17 @@ package freelru import "time" -type Cache[K comparable, V any] interface { +type Cache[K comparable, V comparable] interface { // SetLifetime sets the default lifetime of LRU elements. // Lifetime 0 means "forever". SetLifetime(lifetime time.Duration) - SetUpdateLifetimeOnGet(update bool) - - SetHealthCheck(healthCheck HealthCheckCallback[K, V]) - // SetOnEvict sets the OnEvict callback function. // The onEvict function is called for each evicted lru entry. SetOnEvict(onEvict OnEvictCallback[K, V]) + SetHealthCheck(healthCheck HealthCheckCallback[K, V]) + // Len returns the number of elements stored in the cache. Len() int @@ -49,12 +47,25 @@ type Cache[K comparable, V any] interface { // and the return value indicates that the key was not found. Get(key K) (V, bool) - GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) + GetWithLifetime(key K) (V, time.Time, bool) + + GetWithLifetimeNoExpire(key K) (V, time.Time, bool) + + // GetAndRefresh returns the value associated with the key, setting it as the most + // recently used item. + // The lifetime of the found cache item is refreshed, even if it was already expired. + GetAndRefresh(key K) (V, bool) + + GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool, bool) // Peek looks up a key's value from the cache, without changing its recent-ness. // If the found entry is already expired, the evict function is called. Peek(key K) (V, bool) + PeekWithLifetime(key K) (V, time.Time, bool) + + UpdateLifetime(key K, value V, lifetime time.Duration) bool + // Contains checks for the existence of a key, without changing its recent-ness. // If the found entry is already expired, the evict function is called. Contains(key K) bool diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index 7672460..055057c 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -26,14 +26,14 @@ import ( ) // OnEvictCallback is the type for the eviction function. -type OnEvictCallback[K comparable, V any] func(K, V) +type OnEvictCallback[K comparable, V comparable] func(K, V) // HashKeyCallback is the function that creates a hash from the passed key. type HashKeyCallback[K comparable] func(K) uint32 -type HealthCheckCallback[K comparable, V any] func(K, V) bool +type HealthCheckCallback[K comparable, V comparable] func(K, V) bool -type element[K comparable, V any] struct { +type element[K comparable, V comparable] struct { key K value V @@ -62,15 +62,14 @@ type element[K comparable, V any] struct { const emptyBucket = math.MaxUint32 // LRU implements a non-thread safe fixed size LRU cache. -type LRU[K comparable, V any] struct { - buckets []uint32 // contains positions of bucket lists or 'emptyBucket' - elements []element[K, V] - onEvict OnEvictCallback[K, V] - hash HashKeyCallback[K] - healthCheck HealthCheckCallback[K, V] - lifetime time.Duration - updateLifetimeOnGet bool - metrics Metrics +type LRU[K comparable, V comparable] struct { + buckets []uint32 // contains positions of bucket lists or 'emptyBucket' + elements []element[K, V] + onEvict OnEvictCallback[K, V] + hash HashKeyCallback[K] + healthCheck HealthCheckCallback[K, V] + lifetime time.Duration + metrics Metrics // used for element clearing after removal or expiration emptyKey K @@ -101,10 +100,6 @@ func (lru *LRU[K, V]) SetLifetime(lifetime time.Duration) { lru.lifetime = lifetime } -func (lru *LRU[K, V]) SetUpdateLifetimeOnGet(update bool) { - lru.updateLifetimeOnGet = update -} - // SetOnEvict sets the OnEvict callback function. // The onEvict function is called for each evicted lru entry. // Eviction happens @@ -122,7 +117,7 @@ func (lru *LRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) { // New constructs an LRU with the given capacity of elements. // The hash function calculates a hash value from the keys. -func New[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, V], error) { +func New[K comparable, V comparable](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, V], error) { return NewWithSize[K, V](capacity, capacity, hash) } @@ -131,7 +126,7 @@ func New[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, // A size greater than the capacity increases memory consumption and decreases the CPU consumption // by reducing the chance of collisions. // Size must not be lower than the capacity. -func NewWithSize[K comparable, V any](capacity, size uint32, hash HashKeyCallback[K]) ( +func NewWithSize[K comparable, V comparable](capacity, size uint32, hash HashKeyCallback[K]) ( *LRU[K, V], error, ) { if capacity == 0 { @@ -156,7 +151,7 @@ func NewWithSize[K comparable, V any](capacity, size uint32, hash HashKeyCallbac return &lru, nil } -func initLRU[K comparable, V any](lru *LRU[K, V], capacity, size uint32, hash HashKeyCallback[K], +func initLRU[K comparable, V comparable](lru *LRU[K, V], capacity, size uint32, hash HashKeyCallback[K], buckets []uint32, elements []element[K, V], ) { lru.cap = capacity @@ -181,7 +176,13 @@ func (lru *LRU[K, V]) hashToBucketPos(hash uint32) uint32 { if lru.mask != 0 { return hash & lru.mask } - return hash % lru.size + return fastModulo(hash, lru.size) +} + +// fastModulo calculates x % n without using the modulo operator (~4x faster). +// Reference: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ +func fastModulo(x, n uint32) uint32 { + return uint32((uint64(x) * uint64(n)) >> 32) //nolint:gosec } // hashToPos converts a key into a position in the elements array. @@ -308,30 +309,50 @@ func (lru *LRU[K, V]) clearKeyAndValue(pos uint32) { lru.elements[pos].value = lru.emptyValue } -func (lru *LRU[K, V]) findKey(hash uint32, key K, updateLifetimeOnGet bool) (uint32, int64, bool) { +func (lru *LRU[K, V]) findKey(hash uint32, key K) (uint32, bool) { _, startPos := lru.hashToPos(hash) if startPos == emptyBucket { - return emptyBucket, 0, false + return emptyBucket, false } pos := startPos for { if key == lru.elements[pos].key { - elem := lru.elements[pos] - if (elem.expire != 0 && elem.expire <= now()) || (lru.healthCheck != nil && !lru.healthCheck(key, elem.value)) { + if lru.elements[pos].expire != 0 && lru.elements[pos].expire <= now() || (lru.healthCheck != nil && !lru.healthCheck(key, lru.elements[pos].value)) { lru.removeAt(pos) - return emptyBucket, elem.expire, false + return emptyBucket, false } - if updateLifetimeOnGet { - lru.elements[pos].expire = expire(lru.lifetime) - } - return pos, elem.expire, true + return pos, true } pos = lru.elements[pos].nextBucket if pos == startPos { // Key not found - return emptyBucket, 0, false + return emptyBucket, false + } + } +} + +func (lru *LRU[K, V]) findKeyNoExpire(hash uint32, key K) (uint32, bool) { + _, startPos := lru.hashToPos(hash) + if startPos == emptyBucket { + return emptyBucket, false + } + + pos := startPos + for { + if key == lru.elements[pos].key { + if lru.healthCheck != nil && !lru.healthCheck(key, lru.elements[pos].value) { + lru.removeAt(pos) + return emptyBucket, false + } + return pos, true + } + + pos = lru.elements[pos].nextBucket + if pos == startPos { + // Key not found + return emptyBucket, false } } } @@ -444,30 +465,108 @@ func (lru *LRU[K, V]) add(hash uint32, key K, value V) (evicted bool) { // If the found cache item is already expired, the evict function is called // and the return value indicates that the key was not found. func (lru *LRU[K, V]) Get(key K) (value V, ok bool) { - value, _, ok = lru.get(lru.hash(key), key) - return + return lru.get(lru.hash(key), key) } -func (lru *LRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) { - value, expireMills, ok := lru.get(lru.hash(key), key) - lifetime = time.UnixMilli(expireMills) - return -} - -func (lru *LRU[K, V]) get(hash uint32, key K) (value V, expire int64, ok bool) { - if pos, expire, ok := lru.findKey(hash, key, lru.updateLifetimeOnGet); ok { +func (lru *LRU[K, V]) get(hash uint32, key K) (value V, ok bool) { + if pos, ok := lru.findKey(hash, key); ok { if pos != lru.head { lru.unlinkElement(pos) lru.setHead(pos) } lru.metrics.Hits++ - return lru.elements[pos].value, expire, ok + return lru.elements[pos].value, ok } lru.metrics.Misses++ return } +func (lru *LRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) { + return lru.getWithLifetime(lru.hash(key), key) +} + +func (lru *LRU[K, V]) getWithLifetime(hash uint32, key K) (value V, lifetime time.Time, ok bool) { + if pos, ok := lru.findKey(hash, key); ok { + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + lru.metrics.Hits++ + return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok + } + + lru.metrics.Misses++ + return +} + +func (lru *LRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) { + return lru.getWithLifetimeNoExpire(lru.hash(key), key) +} + +func (lru *LRU[K, V]) getWithLifetimeNoExpire(hash uint32, key K) (value V, lifetime time.Time, ok bool) { + if pos, ok := lru.findKeyNoExpire(hash, key); ok { + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + lru.metrics.Hits++ + return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok + } + + lru.metrics.Misses++ + return +} + +// GetAndRefresh returns the value associated with the key, setting it as the most +// recently used item. +// The lifetime of the found cache item is refreshed, even if it was already expired. +func (lru *LRU[K, V]) GetAndRefresh(key K) (V, bool) { + return lru.getAndRefresh(lru.hash(key), key) +} + +func (lru *LRU[K, V]) getAndRefresh(hash uint32, key K) (value V, ok bool) { + if pos, ok := lru.findKeyNoExpire(hash, key); ok { + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + lru.metrics.Hits++ + lru.elements[pos].expire = expire(lru.lifetime) + return lru.elements[pos].value, ok + } + + lru.metrics.Misses++ + return +} + +func (lru *LRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool, bool) { + value, updated, ok := lru.getAndRefreshOrAdd(lru.hash(key), key, constructor) + if !updated && ok { + lru.PurgeExpired() + } + return value, updated, ok +} + +func (lru *LRU[K, V]) getAndRefreshOrAdd(hash uint32, key K, constructor func() (V, bool)) (value V, updated bool, ok bool) { + if pos, ok := lru.findKeyNoExpire(hash, key); ok { + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + lru.metrics.Hits++ + lru.elements[pos].expire = expire(lru.lifetime) + return lru.elements[pos].value, true, true + } + lru.metrics.Misses++ + value, ok = constructor() + if !ok { + return + } + lru.addWithLifetime(hash, key, value, lru.lifetime) + return value, false, true +} + // Peek looks up a key's value from the cache, without changing its recent-ness. // If the found entry is already expired, the evict function is called. func (lru *LRU[K, V]) Peek(key K) (value V, ok bool) { @@ -475,13 +574,58 @@ func (lru *LRU[K, V]) Peek(key K) (value V, ok bool) { } func (lru *LRU[K, V]) peek(hash uint32, key K) (value V, ok bool) { - if pos, _, ok := lru.findKey(hash, key, false); ok { + if pos, ok := lru.findKey(hash, key); ok { return lru.elements[pos].value, ok } return } +func (lru *LRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) { + return lru.peekWithLifetime(lru.hash(key), key) +} + +func (lru *LRU[K, V]) peekWithLifetime(hash uint32, key K) (value V, lifetime time.Time, ok bool) { + if pos, ok := lru.findKey(hash, key); ok { + return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok + } + + return +} + +func (lru *LRU[K, V]) UpdateLifetime(key K, value V, lifetime time.Duration) bool { + return lru.updateLifetime(lru.hash(key), key, value, lifetime) +} + +func (lru *LRU[K, V]) updateLifetime(hash uint32, key K, value V, lifetime time.Duration) bool { + _, startPos := lru.hashToPos(hash) + if startPos == emptyBucket { + return false + } + pos := startPos + for { + if lru.elements[pos].key == key { + if lru.elements[pos].value != value { + return false + } + + lru.elements[pos].expire = expire(lifetime) + + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + lru.metrics.Inserts++ + return true + } + + pos = lru.elements[pos].nextBucket + if pos == startPos { + return false + } + } +} + // Contains checks for the existence of a key, without changing its recent-ness. // If the found entry is already expired, the evict function is called. func (lru *LRU[K, V]) Contains(key K) (ok bool) { @@ -502,7 +646,7 @@ func (lru *LRU[K, V]) Remove(key K) (removed bool) { } func (lru *LRU[K, V]) remove(hash uint32, key K) (removed bool) { - if pos, _, ok := lru.findKey(hash, key, false); ok { + if pos, ok := lru.findKeyNoExpire(hash, key); ok { lru.removeAt(pos) return ok } @@ -552,7 +696,8 @@ func (lru *LRU[K, V]) Keys() []K { // The evict function is called for each expired item. // The LRU metrics are reset. func (lru *LRU[K, V]) Purge() { - for i := uint32(0); i < lru.len; i++ { + l := lru.len + for i := uint32(0); i < l; i++ { _, _, _ = lru.RemoveOldest() } @@ -562,14 +707,19 @@ func (lru *LRU[K, V]) Purge() { // PurgeExpired purges all expired items from the LRU. // The evict function is called for each expired item. func (lru *LRU[K, V]) PurgeExpired() { - for i := uint32(0); i < lru.len; i++ { - pos := lru.elements[lru.head].next - if lru.elements[pos].expire != 0 { - if lru.elements[pos].expire > now() { - return // no more expired items - } + n := now() +loop: + l := lru.len + if l == 0 { + return + } + pos := lru.elements[lru.head].next + for i := uint32(0); i < l; i++ { + if lru.elements[pos].expire != 0 && lru.elements[pos].expire <= n { lru.removeAt(pos) + goto loop } + pos = lru.elements[pos].next } } diff --git a/contrab/freelru/lru_test.go b/contrab/freelru/lru_test.go index fa3c157..3d4a328 100644 --- a/contrab/freelru/lru_test.go +++ b/contrab/freelru/lru_test.go @@ -1,34 +1,35 @@ package freelru_test import ( + "math/rand" "testing" "time" + "github.com/sagernet/sing/common" + F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/contrab/freelru" "github.com/sagernet/sing/contrab/maphash" "github.com/stretchr/testify/require" ) -func TestMyChange0(t *testing.T) { +func TestUpdateLifetimeOnGet(t *testing.T) { t.Parallel() lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) require.NoError(t, err) - lru.SetUpdateLifetimeOnGet(true) lru.AddWithLifetime("hello", "world", 2*time.Second) time.Sleep(time.Second) - _, ok := lru.Get("hello") + _, ok := lru.GetAndRefresh("hello") require.True(t, ok) time.Sleep(time.Second + time.Millisecond*100) _, ok = lru.Get("hello") require.True(t, ok) } -func TestMyChange1(t *testing.T) { +func TestUpdateLifetimeOnGet1(t *testing.T) { t.Parallel() lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) require.NoError(t, err) - lru.SetUpdateLifetimeOnGet(true) lru.AddWithLifetime("hello", "world", 2*time.Second) time.Sleep(time.Second) lru.Peek("hello") @@ -36,3 +37,64 @@ func TestMyChange1(t *testing.T) { _, ok := lru.Get("hello") require.False(t, ok) } + +func TestUpdateLifetime(t *testing.T) { + t.Parallel() + lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) + require.NoError(t, err) + lru.Add("hello", "world") + require.True(t, lru.UpdateLifetime("hello", "world", 2*time.Second)) + time.Sleep(time.Second) + _, ok := lru.Get("hello") + require.True(t, ok) + time.Sleep(time.Second + time.Millisecond*100) + _, ok = lru.Get("hello") + require.False(t, ok) +} + +func TestUpdateLifetime1(t *testing.T) { + t.Parallel() + lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) + require.NoError(t, err) + lru.Add("hello", "world") + require.False(t, lru.UpdateLifetime("hello", "not world", 2*time.Second)) + time.Sleep(2*time.Second + time.Millisecond*100) + _, ok := lru.Get("hello") + require.True(t, ok) +} + +func TestUpdateLifetime2(t *testing.T) { + t.Parallel() + lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) + require.NoError(t, err) + lru.AddWithLifetime("hello", "world", 2*time.Second) + time.Sleep(time.Second) + require.True(t, lru.UpdateLifetime("hello", "world", 2*time.Second)) + time.Sleep(time.Second + time.Millisecond*100) + _, ok := lru.Get("hello") + require.True(t, ok) + time.Sleep(time.Second + time.Millisecond*100) + _, ok = lru.Get("hello") + require.False(t, ok) +} + +func TestPurgeExpired(t *testing.T) { + t.Parallel() + lru, err := freelru.New[string, *string](1024, maphash.NewHasher[string]().Hash32) + require.NoError(t, err) + lru.SetLifetime(time.Second) + lru.SetOnEvict(func(s string, s2 *string) { + if s2 == nil { + t.Fail() + } + }) + for i := 0; i < 100; i++ { + lru.AddWithLifetime("hello_"+F.ToString(i), common.Ptr("world_"+F.ToString(i)), time.Duration(rand.Intn(3000))*time.Millisecond) + } + for i := 0; i < 5; i++ { + time.Sleep(time.Second) + lru.GetAndRefreshOrAdd("hellox"+F.ToString(i), func() (*string, bool) { + return common.Ptr("worldx"), true + }) + } +} diff --git a/contrab/freelru/sharedlru.go b/contrab/freelru/shardedlru.go similarity index 80% rename from contrab/freelru/sharedlru.go rename to contrab/freelru/shardedlru.go index db1d8cd..fa325d7 100644 --- a/contrab/freelru/sharedlru.go +++ b/contrab/freelru/shardedlru.go @@ -12,7 +12,7 @@ import ( // ShardedLRU is a thread-safe, sharded, fixed size LRU cache. // Sharding is used to reduce lock contention on high concurrency. // The downside is that exact LRU behavior is not given (as for the LRU and SynchedLRU types). -type ShardedLRU[K comparable, V any] struct { +type ShardedLRU[K comparable, V comparable] struct { lrus []LRU[K, V] mus []sync.RWMutex hash HashKeyCallback[K] @@ -32,14 +32,6 @@ func (lru *ShardedLRU[K, V]) SetLifetime(lifetime time.Duration) { } } -func (lru *ShardedLRU[K, V]) SetUpdateLifetimeOnGet(update bool) { - for shard := range lru.lrus { - lru.mus[shard].Lock() - lru.lrus[shard].SetUpdateLifetimeOnGet(update) - lru.mus[shard].Unlock() - } -} - // SetOnEvict sets the OnEvict callback function. // The onEvict function is called for each evicted lru entry. func (lru *ShardedLRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) { @@ -66,7 +58,7 @@ func nextPowerOfTwo(val uint32) uint32 { } // NewSharded creates a new thread-safe sharded LRU hashmap with the given capacity. -func NewSharded[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*ShardedLRU[K, V], +func NewSharded[K comparable, V comparable](capacity uint32, hash HashKeyCallback[K]) (*ShardedLRU[K, V], error, ) { size := uint32(float64(capacity) * 1.25) // 25% extra space for fewer collisions @@ -74,7 +66,7 @@ func NewSharded[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) ( return NewShardedWithSize[K, V](uint32(runtime.GOMAXPROCS(0)*16), capacity, size, hash) } -func NewShardedWithSize[K comparable, V any](shards, capacity, size uint32, +func NewShardedWithSize[K comparable, V comparable](shards, capacity, size uint32, hash HashKeyCallback[K]) ( *ShardedLRU[K, V], error, ) { @@ -178,7 +170,7 @@ func (lru *ShardedLRU[K, V]) Get(key K) (value V, ok bool) { shard := (hash >> 16) & lru.mask lru.mus[shard].Lock() - value, _, ok = lru.lrus[shard].get(hash, key) + value, ok = lru.lrus[shard].get(hash, key) lru.mus[shard].Unlock() return @@ -189,9 +181,48 @@ func (lru *ShardedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time shard := (hash >> 16) & lru.mask lru.mus[shard].Lock() - value, expireMills, ok := lru.lrus[shard].get(hash, key) + value, lifetime, ok = lru.lrus[shard].getWithLifetime(hash, key) lru.mus[shard].Unlock() - lifetime = time.UnixMilli(expireMills) + + return +} + +func (lru *ShardedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].RLock() + value, lifetime, ok = lru.lrus[shard].getWithLifetimeNoExpire(hash, key) + lru.mus[shard].RUnlock() + + return +} + +// GetAndRefresh returns the value associated with the key, setting it as the most +// recently used item. +// The lifetime of the found cache item is refreshed, even if it was already expired. +func (lru *ShardedLRU[K, V]) GetAndRefresh(key K) (value V, ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + value, ok = lru.lrus[shard].getAndRefresh(hash, key) + lru.mus[shard].Unlock() + + return +} + +func (lru *ShardedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool, ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + value, updated, ok = lru.lrus[shard].getAndRefreshOrAdd(hash, key, constructor) + lru.mus[shard].Unlock() + + if !updated && ok { + lru.PurgeExpired() + } return } @@ -208,6 +239,28 @@ func (lru *ShardedLRU[K, V]) Peek(key K) (value V, ok bool) { return } +func (lru *ShardedLRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + value, lifetime, ok = lru.lrus[shard].peekWithLifetime(hash, key) + lru.mus[shard].Unlock() + + return +} + +func (lru *ShardedLRU[K, V]) UpdateLifetime(key K, value V, lifetime time.Duration) (ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + ok = lru.lrus[shard].updateLifetime(hash, key, value, lifetime) + lru.mus[shard].Unlock() + + return +} + // Contains checks for the existence of a key, without changing its recent-ness. // If the found entry is already expired, the evict function is called. func (lru *ShardedLRU[K, V]) Contains(key K) (ok bool) { diff --git a/contrab/freelru/syncedlru.go b/contrab/freelru/syncedlru.go new file mode 100644 index 0000000..4364907 --- /dev/null +++ b/contrab/freelru/syncedlru.go @@ -0,0 +1,270 @@ +package freelru + +import ( + "sync" + "time" +) + +type SyncedLRU[K comparable, V comparable] struct { + mu sync.RWMutex + lru *LRU[K, V] +} + +var _ Cache[int, int] = (*SyncedLRU[int, int])(nil) + +// SetLifetime sets the default lifetime of LRU elements. +// Lifetime 0 means "forever". +func (lru *SyncedLRU[K, V]) SetLifetime(lifetime time.Duration) { + lru.mu.Lock() + lru.lru.SetLifetime(lifetime) + lru.mu.Unlock() +} + +// SetOnEvict sets the OnEvict callback function. +// The onEvict function is called for each evicted lru entry. +func (lru *SyncedLRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) { + lru.mu.Lock() + lru.lru.SetOnEvict(onEvict) + lru.mu.Unlock() +} + +func (lru *SyncedLRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) { + lru.mu.Lock() + lru.lru.SetHealthCheck(healthCheck) + lru.mu.Unlock() +} + +// NewSynced creates a new thread-safe LRU hashmap with the given capacity. +func NewSynced[K comparable, V comparable](capacity uint32, hash HashKeyCallback[K]) (*SyncedLRU[K, V], + error, +) { + return NewSyncedWithSize[K, V](capacity, capacity, hash) +} + +func NewSyncedWithSize[K comparable, V comparable](capacity, size uint32, + hash HashKeyCallback[K], +) (*SyncedLRU[K, V], error) { + lru, err := NewWithSize[K, V](capacity, size, hash) + if err != nil { + return nil, err + } + return &SyncedLRU[K, V]{lru: lru}, nil +} + +// Len returns the number of elements stored in the cache. +func (lru *SyncedLRU[K, V]) Len() (length int) { + lru.mu.RLock() + length = lru.lru.Len() + lru.mu.RUnlock() + + return +} + +// AddWithLifetime adds a key:value to the cache with a lifetime. +// Returns true, true if key was updated and eviction occurred. +func (lru *SyncedLRU[K, V]) AddWithLifetime(key K, value V, lifetime time.Duration) (evicted bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + evicted = lru.lru.addWithLifetime(hash, key, value, lifetime) + lru.mu.Unlock() + + return +} + +// Add adds a key:value to the cache. +// Returns true, true if key was updated and eviction occurred. +func (lru *SyncedLRU[K, V]) Add(key K, value V) (evicted bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + evicted = lru.lru.add(hash, key, value) + lru.mu.Unlock() + + return +} + +// Get returns the value associated with the key, setting it as the most +// recently used item. +// If the found cache item is already expired, the evict function is called +// and the return value indicates that the key was not found. +func (lru *SyncedLRU[K, V]) Get(key K) (value V, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, ok = lru.lru.get(hash, key) + lru.mu.Unlock() + + return +} + +func (lru *SyncedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, lifetime, ok = lru.lru.getWithLifetime(hash, key) + lru.mu.Unlock() + + return +} + +func (lru *SyncedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, lifetime, ok = lru.lru.getWithLifetimeNoExpire(hash, key) + lru.mu.Unlock() + + return +} + +// GetAndRefresh returns the value associated with the key, setting it as the most +// recently used item. +// The lifetime of the found cache item is refreshed, even if it was already expired. +func (lru *SyncedLRU[K, V]) GetAndRefresh(key K) (value V, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, ok = lru.lru.getAndRefresh(hash, key) + lru.mu.Unlock() + + return +} + +func (lru *SyncedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, updated, ok = lru.lru.getAndRefreshOrAdd(hash, key, constructor) + if !updated && ok { + lru.lru.PurgeExpired() + } + lru.mu.Unlock() + + return +} + +// Peek looks up a key's value from the cache, without changing its recent-ness. +// If the found entry is already expired, the evict function is called. +func (lru *SyncedLRU[K, V]) Peek(key K) (value V, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, ok = lru.lru.peek(hash, key) + lru.mu.Unlock() + + return +} + +func (lru *SyncedLRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, lifetime, ok = lru.lru.peekWithLifetime(hash, key) + lru.mu.Unlock() + + return +} + +func (lru *SyncedLRU[K, V]) UpdateLifetime(key K, value V, lifetime time.Duration) (ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + ok = lru.lru.updateLifetime(hash, key, value, lifetime) + lru.mu.Unlock() + + return +} + +// Contains checks for the existence of a key, without changing its recent-ness. +// If the found entry is already expired, the evict function is called. +func (lru *SyncedLRU[K, V]) Contains(key K) (ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + ok = lru.lru.contains(hash, key) + lru.mu.Unlock() + + return +} + +// Remove removes the key from the cache. +// The return value indicates whether the key existed or not. +// The evict function is being called if the key existed. +func (lru *SyncedLRU[K, V]) Remove(key K) (removed bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + removed = lru.lru.remove(hash, key) + lru.mu.Unlock() + + return +} + +// RemoveOldest removes the oldest entry from the cache. +// Key, value and an indicator of whether the entry has been removed is returned. +// The evict function is called for the removed entry. +func (lru *SyncedLRU[K, V]) RemoveOldest() (key K, value V, removed bool) { + lru.mu.Lock() + key, value, removed = lru.lru.RemoveOldest() + lru.mu.Unlock() + + return +} + +// Keys returns a slice of the keys in the cache, from oldest to newest. +// Expired entries are not included. +// The evict function is called for each expired item. +func (lru *SyncedLRU[K, V]) Keys() (keys []K) { + lru.mu.Lock() + keys = lru.lru.Keys() + lru.mu.Unlock() + + return +} + +// Purge purges all data (key and value) from the LRU. +// The evict function is called for each expired item. +// The LRU metrics are reset. +func (lru *SyncedLRU[K, V]) Purge() { + lru.mu.Lock() + lru.lru.Purge() + lru.mu.Unlock() +} + +// PurgeExpired purges all expired items from the LRU. +// The evict function is called for each expired item. +func (lru *SyncedLRU[K, V]) PurgeExpired() { + lru.mu.Lock() + lru.lru.PurgeExpired() + lru.mu.Unlock() +} + +// Metrics returns the metrics of the cache. +func (lru *SyncedLRU[K, V]) Metrics() Metrics { + lru.mu.Lock() + metrics := lru.lru.Metrics() + lru.mu.Unlock() + return metrics +} + +// ResetMetrics resets the metrics of the cache and returns the previous state. +func (lru *SyncedLRU[K, V]) ResetMetrics() Metrics { + lru.mu.Lock() + metrics := lru.lru.ResetMetrics() + lru.mu.Unlock() + return metrics +} + +// just used for debugging +func (lru *SyncedLRU[K, V]) dump() { + lru.mu.RLock() + lru.lru.dump() + lru.mu.RUnlock() +} + +func (lru *SyncedLRU[K, V]) PrintStats() { + lru.mu.RLock() + lru.lru.PrintStats() + lru.mu.RUnlock() +} diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 9528573..fd5817b 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -21,17 +21,14 @@ import ( "github.com/sagernet/sing/common/pipe" ) -// Deprecated: Use HandleConnectionEx instead. -func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, - //nolint:staticcheck - handler N.TCPConnectionHandler, metadata M.Metadata, -) error { - return HandleConnectionEx(ctx, conn, reader, authenticator, handler, nil, metadata.Source, nil) -} - -func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, - //nolint:staticcheck - handler N.TCPConnectionHandler, handlerEx N.TCPConnectionHandlerEx, source M.Socksaddr, onClose N.CloseHandlerFunc, +func HandleConnectionEx( + ctx context.Context, + conn net.Conn, + reader *std_bufio.Reader, + authenticator *auth.Authenticator, + handler N.TCPConnectionHandlerEx, + source M.Socksaddr, + onClose N.CloseHandlerFunc, ) error { for { request, err := ReadRequest(reader) @@ -105,13 +102,8 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re } else { requestConn = conn } - if handler != nil { - //nolint:staticcheck - return handler.NewConnection(ctx, requestConn, M.Metadata{Protocol: "http", Source: source, Destination: destination}) - } else { - handlerEx.NewConnectionEx(ctx, requestConn, source, destination, onClose) - return nil - } + handler.NewConnectionEx(ctx, requestConn, source, destination, onClose) + return nil } else if strings.ToLower(request.Header.Get("Connection")) == "upgrade" { destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port()) if destination.Port == 0 { @@ -124,19 +116,11 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re } serverConn, clientConn := pipe.Pipe() go func() { - if handler != nil { - //nolint:staticcheck - err := handler.NewConnection(ctx, clientConn, M.Metadata{Protocol: "http", Source: source, Destination: destination}) - if err != nil { + handler.NewConnectionEx(ctx, clientConn, source, destination, func(it error) { + if it != nil { common.Close(serverConn, clientConn) } - } else { - handlerEx.NewConnectionEx(ctx, clientConn, source, destination, func(it error) { - if it != nil { - common.Close(serverConn, clientConn) - } - }) - } + }) }() err = request.Write(serverConn) if err != nil { @@ -150,7 +134,7 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re } return bufio.CopyConn(ctx, conn, serverConn) } else { - err = handleHTTPConnection(ctx, handler, handlerEx, conn, request, source) + err = handleHTTPConnection(ctx, handler, conn, request, source) if err != nil { return err } @@ -160,9 +144,7 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re func handleHTTPConnection( ctx context.Context, - //nolint:staticcheck - handler N.TCPConnectionHandler, - handlerEx N.TCPConnectionHandlerEx, + handler N.TCPConnectionHandlerEx, conn net.Conn, request *http.Request, source M.Socksaddr, ) error { @@ -188,21 +170,10 @@ func handleHTTPConnection( DisableCompression: true, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { input, output := pipe.Pipe() - if handler != nil { - go func() { - //nolint:staticcheck - hErr := handler.NewConnection(ctx, output, M.Metadata{Protocol: "http", Source: source, Destination: M.ParseSocksaddr(address)}) - if hErr != nil { - innerErr.Store(hErr) - common.Close(input, output) - } - }() - } else { - go handlerEx.NewConnectionEx(ctx, output, source, M.ParseSocksaddr(address), func(it error) { - innerErr.Store(it) - common.Close(input, output) - }) - } + go handler.NewConnectionEx(ctx, output, source, M.ParseSocksaddr(address), func(it error) { + innerErr.Store(it) + common.Close(input, output) + }) return input, nil }, }, diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 7232eea..dc9c057 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" + "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -19,19 +20,15 @@ import ( "github.com/sagernet/sing/protocol/socks/socks5" ) -// Deprecated: Use HandlerEx instead. -// -//nolint:staticcheck -type Handler interface { - N.TCPConnectionHandler - N.UDPConnectionHandler -} - type HandlerEx interface { N.TCPConnectionHandlerEx N.UDPConnectionHandlerEx } +type PacketListener interface { + ListenPacket(listenConfig net.ListenConfig, ctx context.Context, network string, address string) (net.PacketConn, error) +} + func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, username string) (socks4.Response, error) { err := socks4.WriteRequest(conn, socks4.Request{ Command: command, @@ -87,6 +84,26 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr, } else if authResponse.Method != socks5.AuthTypeNotRequired { return socks5.Response{}, E.New("socks5: unsupported auth method: ", authResponse.Method) } + + if command == socks5.CommandUDPAssociate { + if destination.Addr.IsPrivate() { + if destination.Addr.Is6() { + destination.Addr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) + } else { + destination.Addr = netip.IPv6Loopback() + } + } else if destination.Addr.IsGlobalUnicast() { + if destination.Addr.Is6() { + destination.Addr = netip.IPv6Unspecified() + } else { + destination.Addr = netip.IPv4Unspecified() + } + } else { + destination.Addr = netip.IPv6Unspecified() + } + destination.Port = 0 + } + err = socks5.WriteRequest(conn, socks5.Request{ Command: command, Destination: destination, @@ -104,23 +121,13 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr, return response, err } -// Deprecated: use HandleConnectionEx instead. -func HandleConnection(ctx context.Context, conn net.Conn, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { - return HandleConnection0(ctx, conn, std_bufio.NewReader(conn), authenticator, handler, metadata) -} - -// Deprecated: Use HandleConnectionEx instead. -func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { - return HandleConnectionEx(ctx, conn, reader, authenticator, handler, nil, metadata.Source, metadata.Destination, nil) -} - func HandleConnectionEx( ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, - //nolint:staticcheck - handler Handler, - handlerEx HandlerEx, - source M.Socksaddr, destination M.Socksaddr, + handler HandlerEx, + packetListener PacketListener, + // resolver TorResolver, + source M.Socksaddr, onClose N.CloseHandlerFunc, ) error { version, err := reader.ReadByte() @@ -145,21 +152,13 @@ func HandleConnectionEx( } return E.New("socks4: authentication failed, username=", request.Username) } - destination = request.Destination - if handlerEx != nil { - handlerEx.NewConnectionEx(auth.ContextWithUser(ctx, request.Username), NewLazyConn(conn, version), source, destination, onClose) - } else { - err = socks4.WriteResponse(conn, socks4.Response{ - ReplyCode: socks4.ReplyCodeGranted, - Destination: M.SocksaddrFromNet(conn.LocalAddr()), - }) - if err != nil { - return err - } - //nolint:staticcheck - return handler.NewConnection(auth.ContextWithUser(ctx, request.Username), conn, M.Metadata{Protocol: "socks4", Source: source, Destination: destination}) - } + handler.NewConnectionEx(auth.ContextWithUser(ctx, request.Username), NewLazyConn(conn, version), source, request.Destination, onClose) return nil + /*case CommandTorResolve, CommandTorResolvePTR: + if resolver == nil { + return E.New("socks4: torsocks: commands not implemented") + } + return handleTorSocks4(ctx, conn, request, resolver)*/ default: err = socks4.WriteResponse(conn, socks4.Response{ ReplyCode: socks4.ReplyCodeRejectedOrFailed, @@ -223,53 +222,43 @@ func HandleConnectionEx( } switch request.Command { case socks5.CommandConnect: - destination = request.Destination - if handlerEx != nil { - handlerEx.NewConnectionEx(ctx, NewLazyConn(conn, version), source, destination, onClose) - return nil - } else { - err = socks5.WriteResponse(conn, socks5.Response{ - ReplyCode: socks5.ReplyCodeSuccess, - Bind: M.SocksaddrFromNet(conn.LocalAddr()), - }) - if err != nil { - return err - } - //nolint:staticcheck - return handler.NewConnection(ctx, conn, M.Metadata{Protocol: "socks5", Source: source, Destination: destination}) - } + handler.NewConnectionEx(ctx, NewLazyConn(conn, version), source, request.Destination, onClose) + return nil case socks5.CommandUDPAssociate: - var udpConn *net.UDPConn - udpConn, err = net.ListenUDP(M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), net.UDPAddrFromAddrPort(netip.AddrPortFrom(M.AddrFromNet(conn.LocalAddr()), 0))) - if err != nil { - return err - } - if handlerEx == nil { - defer udpConn.Close() - err = socks5.WriteResponse(conn, socks5.Response{ - ReplyCode: socks5.ReplyCodeSuccess, - Bind: M.SocksaddrFromNet(udpConn.LocalAddr()), - }) - if err != nil { - return err - } - destination = request.Destination - associatePacketConn := NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), destination, conn) - var innerError error - done := make(chan struct{}) - go func() { - //nolint:staticcheck - innerError = handler.NewPacketConnection(ctx, associatePacketConn, M.Metadata{Protocol: "socks5", Source: source, Destination: destination}) - close(done) - }() - err = common.Error(io.Copy(io.Discard, conn)) - associatePacketConn.Close() - <-done - return E.Errors(innerError, err) + var ( + listenConfig net.ListenConfig + udpConn net.PacketConn + ) + if packetListener != nil { + udpConn, err = packetListener.ListenPacket(listenConfig, ctx, M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), M.SocksaddrFrom(M.AddrFromNet(conn.LocalAddr()), 0).String()) } else { - handlerEx.NewPacketConnectionEx(ctx, NewLazyAssociatePacketConn(bufio.NewServerPacketConn(udpConn), destination, conn), source, destination, onClose) - return nil + udpConn, err = listenConfig.ListenPacket(ctx, M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), M.SocksaddrFrom(M.AddrFromNet(conn.LocalAddr()), 0).String()) } + if err != nil { + return E.Cause(err, "socks5: listen udp") + } + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFromNet(udpConn.LocalAddr()), + }) + if err != nil { + return E.Cause(err, "socks5: write response") + } + var socksPacketConn N.PacketConn = NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), M.Socksaddr{}, conn) + firstPacket := buf.NewPacket() + var destination M.Socksaddr + destination, err = socksPacketConn.ReadPacket(firstPacket) + if err != nil { + return E.Cause(err, "socks5: read first packet") + } + socksPacketConn = bufio.NewCachedPacketConn(socksPacketConn, firstPacket, destination) + handler.NewPacketConnectionEx(ctx, socksPacketConn, source, destination, onClose) + return nil + /*case CommandTorResolve, CommandTorResolvePTR: + if resolver == nil { + return E.New("socks4: torsocks: commands not implemented") + } + return handleTorSocks5(ctx, conn, request, resolver)*/ default: err = socks5.WriteResponse(conn, socks5.Response{ ReplyCode: socks5.ReplyCodeUnsupported, diff --git a/protocol/socks/handshake_tor.go b/protocol/socks/handshake_tor.go new file mode 100644 index 0000000..5d66322 --- /dev/null +++ b/protocol/socks/handshake_tor.go @@ -0,0 +1,146 @@ +package socks + +import ( + "context" + "net" + "net/netip" + "os" + "strings" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/protocol/socks/socks4" + "github.com/sagernet/sing/protocol/socks/socks5" +) + +const ( + CommandTorResolve byte = 0xF0 + CommandTorResolvePTR byte = 0xF1 +) + +type TorResolver interface { + LookupIP(ctx context.Context, host string) (netip.Addr, error) + LookupPTR(ctx context.Context, addr netip.Addr) (string, error) +} + +func handleTorSocks4(ctx context.Context, conn net.Conn, request socks4.Request, resolver TorResolver) error { + switch request.Command { + case CommandTorResolve: + if !request.Destination.IsFqdn() { + return E.New("socks4: torsocks: invalid destination") + } + ipAddr, err := resolver.LookupIP(ctx, request.Destination.Fqdn) + if err != nil { + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeRejectedOrFailed, + }) + if err != nil { + return err + } + return E.Cause(err, "socks4: torsocks: lookup failed for domain: ", request.Destination.Fqdn) + } + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeGranted, + Destination: M.SocksaddrFrom(ipAddr, 0), + }) + if err != nil { + return E.Cause(err, "socks4: torsocks: write response") + } + return nil + case CommandTorResolvePTR: + var ipAddr netip.Addr + if request.Destination.IsIP() { + ipAddr = request.Destination.Addr + } else if strings.HasSuffix(request.Destination.Fqdn, ".in-addr.arpa") { + ipAddr, _ = netip.ParseAddr(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".in-addr.arpa")]) + } else if strings.HasSuffix(request.Destination.Fqdn, ".ip6.arpa") { + ipAddr, _ = netip.ParseAddr(strings.ReplaceAll(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".ip6.arpa")], ".", ":")) + } + if !ipAddr.IsValid() { + return E.New("socks4: torsocks: invalid destination") + } + host, err := resolver.LookupPTR(ctx, ipAddr) + if err != nil { + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeRejectedOrFailed, + }) + if err != nil { + return err + } + return E.Cause(err, "socks4: torsocks: lookup PTR failed for ip: ", ipAddr) + } + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeGranted, + Destination: M.Socksaddr{ + Fqdn: host, + }, + }) + if err != nil { + return E.Cause(err, "socks4: torsocks: write response") + } + return nil + default: + return os.ErrInvalid + } +} + +func handleTorSocks5(ctx context.Context, conn net.Conn, request socks5.Request, resolver TorResolver) error { + switch request.Command { + case CommandTorResolve: + if !request.Destination.IsFqdn() { + return E.New("socks5: torsocks: invalid destination") + } + ipAddr, err := resolver.LookupIP(ctx, request.Destination.Fqdn) + if err != nil { + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeFailure, + }) + if err != nil { + return err + } + return E.Cause(err, "socks5: torsocks: lookup failed for domain: ", request.Destination.Fqdn) + } + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFrom(ipAddr, 0), + }) + if err != nil { + return E.Cause(err, "socks5: torsocks: write response") + } + return nil + case CommandTorResolvePTR: + var ipAddr netip.Addr + if request.Destination.IsIP() { + ipAddr = request.Destination.Addr + } else if strings.HasSuffix(request.Destination.Fqdn, ".in-addr.arpa") { + ipAddr, _ = netip.ParseAddr(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".in-addr.arpa")]) + } else if strings.HasSuffix(request.Destination.Fqdn, ".ip6.arpa") { + ipAddr, _ = netip.ParseAddr(strings.ReplaceAll(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".ip6.arpa")], ".", ":")) + } + if !ipAddr.IsValid() { + return E.New("socks5: torsocks: invalid destination") + } + host, err := resolver.LookupPTR(ctx, ipAddr) + if err != nil { + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeFailure, + }) + if err != nil { + return err + } + return E.Cause(err, "socks5: torsocks: lookup PTR failed for ip: ", ipAddr) + } + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.Socksaddr{ + Fqdn: host, + }, + }) + if err != nil { + return E.Cause(err, "socks5: torsocks: write response") + } + return nil + default: + return os.ErrInvalid + } +} diff --git a/protocol/socks/lazy.go b/protocol/socks/lazy.go index f98ac3d..28782bc 100644 --- a/protocol/socks/lazy.go +++ b/protocol/socks/lazy.go @@ -105,12 +105,11 @@ type LazyAssociatePacketConn struct { responseWritten bool } -func NewLazyAssociatePacketConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *LazyAssociatePacketConn { +func NewLazyAssociatePacketConn(conn net.Conn, underlying net.Conn) *LazyAssociatePacketConn { return &LazyAssociatePacketConn{ AssociatePacketConn: AssociatePacketConn{ AbstractConn: conn, conn: bufio.NewExtendedConn(conn), - remoteAddr: remoteAddr, underlying: underlying, }, }