From 0f2447a95b368f4a1cfca0c91e095e88bf7ad997 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 20 Oct 2024 13:29:42 +0800 Subject: [PATCH 01/55] Crazy sekai overturns the small pond --- common/bufio/addr_conn.go | 11 +- common/bufio/vectorised_unix.go | 1 - common/exceptions/error.go | 1 + common/exceptions/timeout.go | 1 - common/metadata/metadata.go | 1 + common/network/conn.go | 51 ++++++- common/network/handshake.go | 65 ++++++++- common/network/thread.go | 2 +- common/random/seed_go119.go | 1 - common/rw/varint.go | 1 - common/udpnat/service.go | 69 ++++++++-- protocol/http/handshake.go | 160 +++++++++++++--------- protocol/socks/handshake.go | 131 +++++++++++------- protocol/socks/lazy.go | 215 ++++++++++++++++++++++++++++++ protocol/socks/socks5/protocol.go | 16 +++ 15 files changed, 588 insertions(+), 138 deletions(-) create mode 100644 protocol/socks/lazy.go diff --git a/common/bufio/addr_conn.go b/common/bufio/addr_conn.go index d74ce9f..4d095b5 100644 --- a/common/bufio/addr_conn.go +++ b/common/bufio/addr_conn.go @@ -9,19 +9,20 @@ import ( type AddrConn struct { net.Conn - M.Metadata + Source M.Socksaddr + Destination M.Socksaddr } func (c *AddrConn) LocalAddr() net.Addr { - if c.Metadata.Destination.IsValid() { - return c.Metadata.Destination.TCPAddr() + if c.Destination.IsValid() { + return c.Destination.TCPAddr() } return c.Conn.LocalAddr() } func (c *AddrConn) RemoteAddr() net.Addr { - if c.Metadata.Source.IsValid() { - return c.Metadata.Source.TCPAddr() + if c.Source.IsValid() { + return c.Source.TCPAddr() } return c.Conn.RemoteAddr() } diff --git a/common/bufio/vectorised_unix.go b/common/bufio/vectorised_unix.go index 6bb5d7d..b0697f4 100644 --- a/common/bufio/vectorised_unix.go +++ b/common/bufio/vectorised_unix.go @@ -38,7 +38,6 @@ func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error { var innerErr unix.Errno err := w.rawConn.Write(func(fd uintptr) (done bool) { //nolint:staticcheck - //goland:noinspection GoDeprecation _, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList))) return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK }) diff --git a/common/exceptions/error.go b/common/exceptions/error.go index 5d056e6..16b075a 100644 --- a/common/exceptions/error.go +++ b/common/exceptions/error.go @@ -12,6 +12,7 @@ import ( F "github.com/sagernet/sing/common/format" ) +// Deprecated: wtf is this? type Handler interface { NewError(ctx context.Context, err error) } diff --git a/common/exceptions/timeout.go b/common/exceptions/timeout.go index f2ae6c3..222123a 100644 --- a/common/exceptions/timeout.go +++ b/common/exceptions/timeout.go @@ -12,7 +12,6 @@ type TimeoutError interface { func IsTimeout(err error) bool { var netErr net.Error if errors.As(err, &netErr) { - //goland:noinspection GoDeprecation //nolint:staticcheck return netErr.Temporary() && netErr.Timeout() } diff --git a/common/metadata/metadata.go b/common/metadata/metadata.go index db2d7d0..a67cb6a 100644 --- a/common/metadata/metadata.go +++ b/common/metadata/metadata.go @@ -1,5 +1,6 @@ package metadata +// Deprecated: wtf is this? type Metadata struct { Protocol string Source Socksaddr diff --git a/common/network/conn.go b/common/network/conn.go index a920ab6..01fe135 100644 --- a/common/network/conn.go +++ b/common/network/conn.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "sync" "time" "github.com/sagernet/sing/common" @@ -70,8 +71,38 @@ type ExtendedConn interface { net.Conn } +type CloseHandlerFunc = func(it error) + +func AppendClose(parent CloseHandlerFunc, onClose CloseHandlerFunc) CloseHandlerFunc { + if parent == nil { + return parent + } else if onClose == nil { + return onClose + } + return func(it error) { + onClose(it) + parent(it) + } +} + +func OnceClose(onClose CloseHandlerFunc) CloseHandlerFunc { + var once sync.Once + return func(it error) { + once.Do(func() { + onClose(it) + }) + } +} + +// Deprecated: Use TCPConnectionHandlerEx instead. type TCPConnectionHandler interface { - NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error + NewConnection(ctx context.Context, conn net.Conn, + //nolint:staticcheck + metadata M.Metadata) error +} + +type TCPConnectionHandlerEx interface { + NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc) } type NetPacketConn interface { @@ -85,12 +116,26 @@ type BindPacketConn interface { net.Conn } +// Deprecated: Use UDPHandlerEx instead. type UDPHandler interface { - NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error + NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, + //nolint:staticcheck + metadata M.Metadata) error } +type UDPHandlerEx interface { + NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr) error +} + +// Deprecated: Use UDPConnectionHandlerEx instead. type UDPConnectionHandler interface { - NewPacketConnection(ctx context.Context, conn PacketConn, metadata M.Metadata) error + NewPacketConnection(ctx context.Context, conn PacketConn, + //nolint:staticcheck + metadata M.Metadata) error +} + +type UDPConnectionHandlerEx interface { + NewPacketConnectionEx(ctx context.Context, conn PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc) } type CachedReader interface { diff --git a/common/network/handshake.go b/common/network/handshake.go index 674211d..5f13492 100644 --- a/common/network/handshake.go +++ b/common/network/handshake.go @@ -1,6 +1,9 @@ package network import ( + "io" + "net" + "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" ) @@ -13,17 +16,71 @@ type HandshakeSuccess interface { HandshakeSuccess() error } -func ReportHandshakeFailure(conn any, err error) error { - if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](conn); isHandshakeConn { +type ConnHandshakeSuccess interface { + ConnHandshakeSuccess(conn net.Conn) error +} + +type PacketConnHandshakeSuccess interface { + PacketConnHandshakeSuccess(conn net.PacketConn) error +} + +func ReportHandshakeFailure(reporter any, err error) error { + if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn { return E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error { return E.Cause(err, "write handshake failure") }) } + return nil +} + +func CloseOnHandshakeFailure(reporter any, onClose CloseHandlerFunc, err error) error { + if err != nil { + if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn { + err = E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error { + return E.Cause(err, "write handshake failure") + }) + } else { + if tcpConn, isTCPConn := common.Cast[interface { + SetLinger(sec int) error + }](reporter); isTCPConn { + tcpConn.SetLinger(0) + } + if closer, isCloser := reporter.(io.Closer); isCloser { + err = E.Append(err, closer.Close(), func(err error) error { + return E.Cause(err, "close") + }) + } + } + } + if onClose != nil { + onClose(err) + } return err } -func ReportHandshakeSuccess(conn any) error { - if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](conn); isHandshakeConn { +// Deprecated: use ReportConnHandshakeSuccess/ReportPacketConnHandshakeSuccess instead +func ReportHandshakeSuccess(reporter any) error { + if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn { + return handshakeConn.HandshakeSuccess() + } + return nil +} + +func ReportConnHandshakeSuccess(reporter any, conn net.Conn) error { + if handshakeConn, isHandshakeConn := common.Cast[ConnHandshakeSuccess](reporter); isHandshakeConn { + return handshakeConn.ConnHandshakeSuccess(conn) + } + if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn { + return handshakeConn.HandshakeSuccess() + } + return nil +} + +func ReportPacketConnHandshakeSuccess(reporter any, conn net.PacketConn) error { + if handshakeConn, isHandshakeConn := common.Cast[PacketConnHandshakeSuccess](reporter); isHandshakeConn { + return handshakeConn.PacketConnHandshakeSuccess(conn) + } + if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn { return handshakeConn.HandshakeSuccess() } return nil diff --git a/common/network/thread.go b/common/network/thread.go index 22063af..4e47da7 100644 --- a/common/network/thread.go +++ b/common/network/thread.go @@ -11,6 +11,7 @@ type ThreadUnsafeWriter interface { } // Deprecated: Use ReadWaiter interface instead. + type ThreadSafeReader interface { // Deprecated: Use ReadWaiter interface instead. ReadBufferThreadSafe() (buffer *buf.Buffer, err error) @@ -18,7 +19,6 @@ type ThreadSafeReader interface { // Deprecated: Use ReadWaiter interface instead. type ThreadSafePacketReader interface { - // Deprecated: Use ReadWaiter interface instead. ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error) } diff --git a/common/random/seed_go119.go b/common/random/seed_go119.go index c0da2ef..d339fca 100644 --- a/common/random/seed_go119.go +++ b/common/random/seed_go119.go @@ -20,6 +20,5 @@ func InitializeSeed() { func initializeSeed() { var seed int64 common.Must(binary.Read(rand.Reader, binary.LittleEndian, &seed)) - //goland:noinspection GoDeprecation mRand.Seed(seed) } diff --git a/common/rw/varint.go b/common/rw/varint.go index f9f5ca9..d19f162 100644 --- a/common/rw/varint.go +++ b/common/rw/varint.go @@ -27,7 +27,6 @@ func ToByteReader(reader io.Reader) io.ByteReader { // Deprecated: Use binary.ReadUvarint instead. func ReadUVariant(reader io.Reader) (uint64, error) { - //goland:noinspection GoDeprecation return binary.ReadUvarint(ToByteReader(reader)) } diff --git a/common/udpnat/service.go b/common/udpnat/service.go index bdd917d..a5b37db 100644 --- a/common/udpnat/service.go +++ b/common/udpnat/service.go @@ -16,18 +16,23 @@ import ( "github.com/sagernet/sing/common/pipe" ) +// Deprecated: Use N.UDPConnectionHandler instead. +// +//nolint:staticcheck type Handler interface { N.UDPConnectionHandler E.Handler } type Service[K comparable] struct { - nat *cache.LruCache[K, *conn] - handler Handler + nat *cache.LruCache[K, *conn] + handler Handler + handlerEx N.UDPConnectionHandlerEx } +// Deprecated: Use NewEx instead. func New[K comparable](maxAge int64, handler Handler) *Service[K] { - return &Service[K]{ + service := &Service[K]{ nat: cache.New( cache.WithAge[K, *conn](maxAge), cache.WithUpdateAgeOnGet[K, *conn](), @@ -37,11 +42,27 @@ func New[K comparable](maxAge int64, handler Handler) *Service[K] { ), handler: handler, } + return service +} + +func NewEx[K comparable](maxAge int64, handler N.UDPConnectionHandlerEx) *Service[K] { + service := &Service[K]{ + nat: cache.New( + cache.WithAge[K, *conn](maxAge), + cache.WithUpdateAgeOnGet[K, *conn](), + cache.WithEvict[K, *conn](func(key K, conn *conn) { + conn.Close() + }), + ), + handlerEx: handler, + } + return service } func (s *Service[T]) WriteIsThreadUnsafe() { } +// Deprecated: don't use func (s *Service[T]) NewPacketDirect(ctx context.Context, key T, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) { s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { return ctx, &DirectBackWriter{conn, natConn} @@ -61,18 +82,30 @@ func (w *DirectBackWriter) Upstream() any { return w.Source } +// Deprecated: use NewPacketEx instead. func (s *Service[T]) NewPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) N.PacketWriter) { s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { return ctx, init(natConn) }) } +func (s *Service[T]) NewPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) N.PacketWriter) { + s.NewContextPacketEx(ctx, key, buffer, source, destination, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { + return ctx, init(natConn) + }) +} + +// Deprecated: Use NewPacketConnectionEx instead. func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) { + s.NewContextPacketEx(ctx, key, buffer, metadata.Source, metadata.Destination, init) +} + +func (s *Service[T]) NewContextPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) { c, loaded := s.nat.LoadOrStore(key, func() *conn { c := &conn{ data: make(chan packet, 64), - localAddr: metadata.Source, - remoteAddr: metadata.Destination, + localAddr: source, + remoteAddr: destination, readDeadline: pipe.MakeDeadline(), } c.ctx, c.cancel = common.ContextWithCancelCause(ctx) @@ -81,26 +114,36 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu if !loaded { ctx, c.source = init(c) go func() { - err := s.handler.NewPacketConnection(ctx, c, metadata) - if err != nil { - s.handler.NewError(ctx, err) + if s.handlerEx != nil { + s.handlerEx.NewPacketConnectionEx(ctx, c, source, destination, func(err error) { + s.nat.Delete(key) + }) + } else { + //nolint:staticcheck + err := s.handler.NewPacketConnection(ctx, c, M.Metadata{ + Source: source, + Destination: destination, + }) + if err != nil { + s.handler.NewError(ctx, err) + } + c.Close() + s.nat.Delete(key) } - c.Close() - s.nat.Delete(key) }() } else { - c.localAddr = metadata.Source + c.localAddr = source } if common.Done(c.ctx) { s.nat.Delete(key) if !common.Done(ctx) { - s.NewContextPacket(ctx, key, buffer, metadata, init) + s.NewContextPacketEx(ctx, key, buffer, source, destination, init) } return } c.data <- packet{ data: buffer, - destination: metadata.Destination, + destination: destination, } } diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 8a156ad..955a722 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -20,9 +20,18 @@ import ( "github.com/sagernet/sing/common/pipe" ) -type Handler = N.TCPConnectionHandler +// Deprecated: Use HandleConnectionEx instead. +func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, + //nolint:staticcheck + handler N.TCPConnectionHandler, metadata M.Metadata, +) error { + return HandleConnectionEx(ctx, conn, reader, authenticator, handler, nil, metadata.Source, nil) +} -func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { +func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, + //nolint:staticcheck + handler N.TCPConnectionHandler, handlerEx N.TCPConnectionHandlerEx, source M.Socksaddr, onClose N.CloseHandlerFunc, +) error { for { request, err := ReadRequest(reader) if err != nil { @@ -68,7 +77,7 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read } if sourceAddress := SourceAddress(request); sourceAddress.IsValid() { - metadata.Source = sourceAddress + source = sourceAddress } if request.Method == "CONNECT" { @@ -81,9 +90,6 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read if err != nil { return E.Cause(err, "write http response") } - metadata.Protocol = "http" - metadata.Destination = destination - var requestConn net.Conn if reader.Buffered() > 0 { buffer := buf.NewSize(reader.Buffered()) @@ -95,75 +101,105 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read } else { requestConn = conn } - return handler.NewConnection(ctx, requestConn, metadata) - } - - keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" - request.RequestURI = "" - - removeHopByHopHeaders(request.Header) - removeExtraHTTPHostPort(request) - - if hostStr := request.Header.Get("Host"); hostStr != "" { - if hostStr != request.URL.Host { - request.Host = hostStr + if handler != nil { + //nolint:staticcheck + return handler.NewConnection(ctx, requestConn, M.Metadata{Protocol: "http", Source: source, Destination: destination}) + } else { + handlerEx.NewConnectionEx(ctx, requestConn, source, destination, onClose) + return nil } } - if request.URL.Scheme == "" || request.URL.Host == "" { - return responseWith(request, http.StatusBadRequest).Write(conn) + err = handleHTTPConnection(ctx, handler, handlerEx, conn, request, source) + if err != nil { + return err } + } +} - var innerErr atomic.TypedValue[error] - httpClient := &http.Client{ - Transport: &http.Transport{ - DisableCompression: true, - DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - metadata.Destination = M.ParseSocksaddr(address) - metadata.Protocol = "http" - input, output := pipe.Pipe() +func handleHTTPConnection( + ctx context.Context, + //nolint:staticcheck + handler N.TCPConnectionHandler, + handlerEx N.TCPConnectionHandlerEx, + conn net.Conn, + request *http.Request, source M.Socksaddr, +) error { + keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" + request.RequestURI = "" + + removeHopByHopHeaders(request.Header) + removeExtraHTTPHostPort(request) + + if hostStr := request.Header.Get("Host"); hostStr != "" { + if hostStr != request.URL.Host { + request.Host = hostStr + } + } + + if request.URL.Scheme == "" || request.URL.Host == "" { + return responseWith(request, http.StatusBadRequest).Write(conn) + } + + var innerErr atomic.TypedValue[error] + httpClient := &http.Client{ + Transport: &http.Transport{ + DisableCompression: true, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + input, output := pipe.Pipe() + if handler != nil { go func() { - hErr := handler.NewConnection(ctx, output, metadata) + //nolint:staticcheck + hErr := handler.NewConnection(ctx, output, M.Metadata{Protocol: "http", Source: source, Destination: M.ParseSocksaddr(address)}) if hErr != nil { innerErr.Store(hErr) common.Close(input, output) } }() - return input, nil - }, + } else { + go handlerEx.NewConnectionEx(ctx, output, source, M.ParseSocksaddr(address), func(it error) { + innerErr.Store(it) + common.Close(input, output) + }) + } + return input, nil }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - requestCtx, cancel := context.WithCancel(ctx) - response, err := httpClient.Do(request.WithContext(requestCtx)) - if err != nil { - cancel() - return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn)) - } - - removeHopByHopHeaders(response.Header) - - if keepAlive { - response.Header.Set("Proxy-Connection", "keep-alive") - response.Header.Set("Connection", "keep-alive") - response.Header.Set("Keep-Alive", "timeout=4") - } - - response.Close = !keepAlive - - err = response.Write(conn) - if err != nil { - cancel() - return E.Errors(innerErr.Load(), err) - } - - cancel() - if !keepAlive { - return conn.Close() - } + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, } + defer httpClient.CloseIdleConnections() + + requestCtx, cancel := context.WithCancel(ctx) + response, err := httpClient.Do(request.WithContext(requestCtx)) + if err != nil { + cancel() + return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn)) + } + + removeHopByHopHeaders(response.Header) + + if keepAlive { + response.Header.Set("Proxy-Connection", "keep-alive") + response.Header.Set("Connection", "keep-alive") + response.Header.Set("Keep-Alive", "timeout=4") + } + + response.Close = !keepAlive + + err = response.Write(conn) + if err != nil { + cancel() + return E.Errors(innerErr.Load(), err) + } + + cancel() + if !keepAlive { + return conn.Close() + } + + return nil } func removeHopByHopHeaders(header http.Header) { diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 8ee2542..7232eea 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -19,11 +19,19 @@ import ( "github.com/sagernet/sing/protocol/socks/socks5" ) +// Deprecated: Use HandlerEx instead. +// +//nolint:staticcheck type Handler interface { N.TCPConnectionHandler N.UDPConnectionHandler } +type HandlerEx interface { + N.TCPConnectionHandlerEx + N.UDPConnectionHandlerEx +} + func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, username string) (socks4.Response, error) { err := socks4.WriteRequest(conn, socks4.Request{ Command: command, @@ -96,18 +104,33 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr, return response, err } +// Deprecated: use HandleConnectionEx instead. func HandleConnection(ctx context.Context, conn net.Conn, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { return HandleConnection0(ctx, conn, std_bufio.NewReader(conn), authenticator, handler, metadata) } +// Deprecated: Use HandleConnectionEx instead. func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { + return HandleConnectionEx(ctx, conn, reader, authenticator, handler, nil, metadata.Source, metadata.Destination, nil) +} + +func HandleConnectionEx( + ctx context.Context, conn net.Conn, reader *std_bufio.Reader, + authenticator *auth.Authenticator, + //nolint:staticcheck + handler Handler, + handlerEx HandlerEx, + source M.Socksaddr, destination M.Socksaddr, + onClose N.CloseHandlerFunc, +) error { version, err := reader.ReadByte() if err != nil { return err } switch version { case socks4.Version: - request, err := socks4.ReadRequest0(reader) + var request socks4.Request + request, err = socks4.ReadRequest0(reader) if err != nil { return err } @@ -115,28 +138,31 @@ func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Rea case socks4.CommandConnect: if authenticator != nil && !authenticator.Verify(request.Username, "") { err = socks4.WriteResponse(conn, socks4.Response{ - ReplyCode: socks4.ReplyCodeRejectedOrFailed, - Destination: request.Destination, + ReplyCode: socks4.ReplyCodeRejectedOrFailed, }) if err != nil { return err } return E.New("socks4: authentication failed, username=", request.Username) } - err = socks4.WriteResponse(conn, socks4.Response{ - ReplyCode: socks4.ReplyCodeGranted, - Destination: M.SocksaddrFromNet(conn.LocalAddr()), - }) - if err != nil { - return err + destination = request.Destination + if handlerEx != nil { + handlerEx.NewConnectionEx(auth.ContextWithUser(ctx, request.Username), NewLazyConn(conn, version), source, destination, onClose) + } else { + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeGranted, + Destination: M.SocksaddrFromNet(conn.LocalAddr()), + }) + if err != nil { + return err + } + //nolint:staticcheck + return handler.NewConnection(auth.ContextWithUser(ctx, request.Username), conn, M.Metadata{Protocol: "socks4", Source: source, Destination: destination}) } - metadata.Protocol = "socks4" - metadata.Destination = request.Destination - return handler.NewConnection(auth.ContextWithUser(ctx, request.Username), conn, metadata) + return nil default: err = socks4.WriteResponse(conn, socks4.Response{ - ReplyCode: socks4.ReplyCodeRejectedOrFailed, - Destination: request.Destination, + ReplyCode: socks4.ReplyCodeRejectedOrFailed, }) if err != nil { return err @@ -144,7 +170,8 @@ func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Rea return E.New("socks4: unsupported command ", request.Command) } case socks5.Version: - authRequest, err := socks5.ReadAuthRequest0(reader) + var authRequest socks5.AuthRequest + authRequest, err = socks5.ReadAuthRequest0(reader) if err != nil { return err } @@ -169,7 +196,8 @@ func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Rea return err } if authMethod == socks5.AuthTypeUsernamePassword { - usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(reader) + var usernamePasswordAuthRequest socks5.UsernamePasswordAuthRequest + usernamePasswordAuthRequest, err = socks5.ReadUsernamePasswordAuthRequest(reader) if err != nil { return err } @@ -188,49 +216,60 @@ func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Rea return E.New("socks5: authentication failed, username=", usernamePasswordAuthRequest.Username, ", password=", usernamePasswordAuthRequest.Password) } } - request, err := socks5.ReadRequest(reader) + var request socks5.Request + request, err = socks5.ReadRequest(reader) if err != nil { return err } switch request.Command { case socks5.CommandConnect: - err = socks5.WriteResponse(conn, socks5.Response{ - ReplyCode: socks5.ReplyCodeSuccess, - Bind: M.SocksaddrFromNet(conn.LocalAddr()), - }) - if err != nil { - return err + destination = request.Destination + if handlerEx != nil { + handlerEx.NewConnectionEx(ctx, NewLazyConn(conn, version), source, destination, onClose) + return nil + } else { + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFromNet(conn.LocalAddr()), + }) + if err != nil { + return err + } + //nolint:staticcheck + return handler.NewConnection(ctx, conn, M.Metadata{Protocol: "socks5", Source: source, Destination: destination}) } - metadata.Protocol = "socks5" - metadata.Destination = request.Destination - return handler.NewConnection(ctx, conn, metadata) case socks5.CommandUDPAssociate: var udpConn *net.UDPConn udpConn, err = net.ListenUDP(M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), net.UDPAddrFromAddrPort(netip.AddrPortFrom(M.AddrFromNet(conn.LocalAddr()), 0))) if err != nil { return err } - defer udpConn.Close() - err = socks5.WriteResponse(conn, socks5.Response{ - ReplyCode: socks5.ReplyCodeSuccess, - Bind: M.SocksaddrFromNet(udpConn.LocalAddr()), - }) - if err != nil { - return err + if handlerEx == nil { + defer udpConn.Close() + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFromNet(udpConn.LocalAddr()), + }) + if err != nil { + return err + } + destination = request.Destination + associatePacketConn := NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), destination, conn) + var innerError error + done := make(chan struct{}) + go func() { + //nolint:staticcheck + innerError = handler.NewPacketConnection(ctx, associatePacketConn, M.Metadata{Protocol: "socks5", Source: source, Destination: destination}) + close(done) + }() + err = common.Error(io.Copy(io.Discard, conn)) + associatePacketConn.Close() + <-done + return E.Errors(innerError, err) + } else { + handlerEx.NewPacketConnectionEx(ctx, NewLazyAssociatePacketConn(bufio.NewServerPacketConn(udpConn), destination, conn), source, destination, onClose) + return nil } - metadata.Protocol = "socks5" - metadata.Destination = request.Destination - var innerError error - done := make(chan struct{}) - associatePacketConn := NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), request.Destination, conn) - go func() { - innerError = handler.NewPacketConnection(ctx, associatePacketConn, metadata) - close(done) - }() - err = common.Error(io.Copy(io.Discard, conn)) - associatePacketConn.Close() - <-done - return E.Errors(innerError, err) default: err = socks5.WriteResponse(conn, socks5.Response{ ReplyCode: socks5.ReplyCodeUnsupported, diff --git a/protocol/socks/lazy.go b/protocol/socks/lazy.go new file mode 100644 index 0000000..e687475 --- /dev/null +++ b/protocol/socks/lazy.go @@ -0,0 +1,215 @@ +package socks + +import ( + "net" + + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/protocol/socks/socks4" + "github.com/sagernet/sing/protocol/socks/socks5" +) + +type LazyConn struct { + net.Conn + socksVersion byte + responseWritten bool +} + +func NewLazyConn(conn net.Conn, socksVersion byte) *LazyConn { + return &LazyConn{ + Conn: conn, + socksVersion: socksVersion, + } +} + +func (c *LazyConn) ConnHandshakeSuccess(conn net.Conn) error { + if c.responseWritten { + return nil + } + defer func() { + c.responseWritten = true + }() + switch c.socksVersion { + case socks4.Version: + return socks4.WriteResponse(c.Conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeGranted, + Destination: M.SocksaddrFromNet(conn.LocalAddr()), + }) + case socks5.Version: + return socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFromNet(conn.LocalAddr()), + }) + default: + panic("unknown socks version") + } +} + +func (c *LazyConn) HandshakeFailure(err error) error { + if c.responseWritten { + return nil + } + defer func() { + c.responseWritten = true + }() + switch c.socksVersion { + case socks4.Version: + return socks4.WriteResponse(c.Conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeRejectedOrFailed, + }) + case socks5.Version: + return socks5.WriteResponse(c.Conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeForError(err), + }) + default: + panic("unknown socks version") + } +} + +func (c *LazyConn) Read(p []byte) (n int, err error) { + if !c.responseWritten { + err = c.ConnHandshakeSuccess(c.Conn) + if err != nil { + return + } + } + return c.Conn.Read(p) +} + +func (c *LazyConn) Write(p []byte) (n int, err error) { + if !c.responseWritten { + err = c.ConnHandshakeSuccess(c.Conn) + if err != nil { + return + } + } + return c.Conn.Write(p) +} + +func (c *LazyConn) ReaderReplaceable() bool { + return c.responseWritten +} + +func (c *LazyConn) WriterReplaceable() bool { + return c.responseWritten +} + +func (c *LazyConn) Upstream() any { + return c.Conn +} + +type LazyAssociatePacketConn struct { + AssociatePacketConn + responseWritten bool +} + +func NewLazyAssociatePacketConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *LazyAssociatePacketConn { + return &LazyAssociatePacketConn{ + AssociatePacketConn: AssociatePacketConn{ + AbstractConn: conn, + conn: bufio.NewExtendedConn(conn), + remoteAddr: remoteAddr, + underlying: underlying, + }, + } +} + +func (c *LazyAssociatePacketConn) HandshakeSuccess() error { + if c.responseWritten { + return nil + } + defer func() { + c.responseWritten = true + }() + return socks5.WriteResponse(c.underlying, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFromNet(c.conn.LocalAddr()), + }) +} + +func (c *LazyAssociatePacketConn) HandshakeFailure(err error) error { + if c.responseWritten { + return nil + } + defer func() { + c.responseWritten = true + c.conn.Close() + c.underlying.Close() + }() + return socks5.WriteResponse(c.underlying, socks5.Response{ + ReplyCode: socks5.ReplyCodeForError(err), + }) +} + +func (c *LazyAssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if !c.responseWritten { + err = c.HandshakeSuccess() + if err != nil { + return + } + } + return c.AssociatePacketConn.ReadFrom(p) +} + +func (c *LazyAssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if !c.responseWritten { + err = c.HandshakeSuccess() + if err != nil { + return + } + } + return c.AssociatePacketConn.WriteTo(p, addr) +} + +func (c *LazyAssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + if !c.responseWritten { + err = c.HandshakeSuccess() + if err != nil { + return + } + } + return c.AssociatePacketConn.ReadPacket(buffer) +} + +func (c *LazyAssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if !c.responseWritten { + err := c.HandshakeSuccess() + if err != nil { + return err + } + } + return c.AssociatePacketConn.WritePacket(buffer, destination) +} + +func (c *LazyAssociatePacketConn) Read(p []byte) (n int, err error) { + if !c.responseWritten { + err = c.HandshakeSuccess() + if err != nil { + return + } + } + return c.AssociatePacketConn.Read(p) +} + +func (c *LazyAssociatePacketConn) Write(p []byte) (n int, err error) { + if !c.responseWritten { + err = c.HandshakeSuccess() + if err != nil { + return + } + } + return c.AssociatePacketConn.Write(p) +} + +func (c *LazyAssociatePacketConn) ReaderReplaceable() bool { + return c.responseWritten +} + +func (c *LazyAssociatePacketConn) WriterReplaceable() bool { + return c.responseWritten +} + +func (c *LazyAssociatePacketConn) Upstream() any { + return c.underlying +} diff --git a/protocol/socks/socks5/protocol.go b/protocol/socks/socks5/protocol.go index 29ff3db..cb04227 100644 --- a/protocol/socks/socks5/protocol.go +++ b/protocol/socks/socks5/protocol.go @@ -1,8 +1,10 @@ package socks5 import ( + "errors" "io" "net/netip" + "syscall" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -37,6 +39,20 @@ const ( ReplyCodeAddressTypeUnsupported byte = 8 ) +func ReplyCodeForError(err error) byte { + if errors.Is(err, syscall.ENETUNREACH) { + return ReplyCodeNetworkUnreachable + } else if errors.Is(err, syscall.EHOSTUNREACH) { + return ReplyCodeHostUnreachable + } else if errors.Is(err, syscall.ECONNREFUSED) { + return ReplyCodeConnectionRefused + } else if errors.Is(err, syscall.EPERM) { + return ReplyCodeNotAllowed + } else { + return ReplyCodeFailure + } +} + // +----+----------+----------+ // |VER | NMETHODS | METHODS | // +----+----------+----------+ From e7ec021b81b99577254e6c60c0d2ef1b68bb6cbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 21 Oct 2024 18:00:58 +0800 Subject: [PATCH 02/55] freelru: copy source from v0.14.0 --- contrab/freelru/README.md | 3 + contrab/freelru/cache.go | 85 ++++++ contrab/freelru/lru.go | 591 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 679 insertions(+) create mode 100644 contrab/freelru/README.md create mode 100644 contrab/freelru/cache.go create mode 100644 contrab/freelru/lru.go diff --git a/contrab/freelru/README.md b/contrab/freelru/README.md new file mode 100644 index 0000000..206fbba --- /dev/null +++ b/contrab/freelru/README.md @@ -0,0 +1,3 @@ +# freelru + +kanged from github.com/elastic/go-freelru@v0.14.0 \ No newline at end of file diff --git a/contrab/freelru/cache.go b/contrab/freelru/cache.go new file mode 100644 index 0000000..59435ee --- /dev/null +++ b/contrab/freelru/cache.go @@ -0,0 +1,85 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package freelru + +import "time" + +type Cache[K comparable, V any] interface { + // SetLifetime sets the default lifetime of LRU elements. + // Lifetime 0 means "forever". + SetLifetime(lifetime time.Duration) + + // SetOnEvict sets the OnEvict callback function. + // The onEvict function is called for each evicted lru entry. + SetOnEvict(onEvict OnEvictCallback[K, V]) + + // Len returns the number of elements stored in the cache. + Len() int + + // AddWithLifetime adds a key:value to the cache with a lifetime. + // Returns true, true if key was updated and eviction occurred. + AddWithLifetime(key K, value V, lifetime time.Duration) (evicted bool) + + // Add adds a key:value to the cache. + // Returns true, true if key was updated and eviction occurred. + Add(key K, value V) (evicted bool) + + // Get returns the value associated with the key, setting it as the most + // recently used item. + // If the found cache item is already expired, the evict function is called + // and the return value indicates that the key was not found. + Get(key K) (V, bool) + + // Peek looks up a key's value from the cache, without changing its recent-ness. + // If the found entry is already expired, the evict function is called. + Peek(key K) (V, bool) + + // Contains checks for the existence of a key, without changing its recent-ness. + // If the found entry is already expired, the evict function is called. + Contains(key K) bool + + // Remove removes the key from the cache. + // The return value indicates whether the key existed or not. + // The evict function is called for the removed entry. + Remove(key K) bool + + // RemoveOldest removes the oldest entry from the cache. + // Key, value and an indicator of whether the entry has been removed is returned. + // The evict function is called for the removed entry. + RemoveOldest() (key K, value V, removed bool) + + // Keys returns a slice of the keys in the cache, from oldest to newest. + // Expired entries are not included. + // The evict function is called for each expired item. + Keys() []K + + // Purge purges all data (key and value) from the LRU. + // The evict function is called for each expired item. + // The LRU metrics are reset. + Purge() + + // PurgeExpired purges all expired items from the LRU. + // The evict function is called for each expired item. + PurgeExpired() + + // Metrics returns the metrics of the cache. + Metrics() Metrics + + // ResetMetrics resets the metrics of the cache and returns the previous state. + ResetMetrics() Metrics +} diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go new file mode 100644 index 0000000..af8b8e9 --- /dev/null +++ b/contrab/freelru/lru.go @@ -0,0 +1,591 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package freelru + +import ( + "errors" + "fmt" + "math" + "math/bits" + "time" +) + +// OnEvictCallback is the type for the eviction function. +type OnEvictCallback[K comparable, V any] func(K, V) + +// HashKeyCallback is the function that creates a hash from the passed key. +type HashKeyCallback[K comparable] func(K) uint32 + +type element[K comparable, V any] struct { + key K + value V + + // bucketNext and bucketPrev are indexes in the space-dimension doubly-linked list of elements. + // That is to add/remove items to the collision bucket without re-allocations and with O(1) + // complexity. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.latest.prev is last element and + // &l.last.next is the latest element. + nextBucket, prevBucket uint32 + + // bucketPos is the bucket that an element belongs to. + bucketPos uint32 + + // next and prev are indexes in the time-dimension doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.latest.prev is last element and + // &l.last.next is the latest element. + next, prev uint32 + + // expire is the point in time when the element expires. + // Its value is Unix milliseconds since epoch. + expire int64 +} + +const emptyBucket = math.MaxUint32 + +// LRU implements a non-thread safe fixed size LRU cache. +type LRU[K comparable, V any] struct { + buckets []uint32 // contains positions of bucket lists or 'emptyBucket' + elements []element[K, V] + onEvict OnEvictCallback[K, V] + hash HashKeyCallback[K] + lifetime time.Duration + metrics Metrics + + // used for element clearing after removal or expiration + emptyKey K + emptyValue V + + head uint32 // index of the newest element in the cache + len uint32 // current number of elements in the cache + cap uint32 // max number of elements in the cache + size uint32 // size of the element array (X% larger than cap) + mask uint32 // bitmask to avoid the costly idiv in hashToPos() if size is a 2^n value +} + +// Metrics contains metrics about the cache. +type Metrics struct { + Inserts uint64 + Collisions uint64 + Evictions uint64 + Removals uint64 + Hits uint64 + Misses uint64 +} + +var _ Cache[int, int] = (*LRU[int, int])(nil) + +// SetLifetime sets the default lifetime of LRU elements. +// Lifetime 0 means "forever". +func (lru *LRU[K, V]) SetLifetime(lifetime time.Duration) { + lru.lifetime = lifetime +} + +// SetOnEvict sets the OnEvict callback function. +// The onEvict function is called for each evicted lru entry. +// Eviction happens +// - when the cache is full and a new entry is added (oldest entry is evicted) +// - when an entry is removed by Remove() or RemoveOldest() +// - when an entry is recognized as expired +// - when Purge() is called +func (lru *LRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) { + lru.onEvict = onEvict +} + +// New constructs an LRU with the given capacity of elements. +// The hash function calculates a hash value from the keys. +func New[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, V], error) { + return NewWithSize[K, V](capacity, capacity, hash) +} + +// NewWithSize constructs an LRU with the given capacity and size. +// The hash function calculates a hash value from the keys. +// A size greater than the capacity increases memory consumption and decreases the CPU consumption +// by reducing the chance of collisions. +// Size must not be lower than the capacity. +func NewWithSize[K comparable, V any](capacity, size uint32, hash HashKeyCallback[K]) ( + *LRU[K, V], error) { + if capacity == 0 { + return nil, errors.New("capacity must be positive") + } + if size == emptyBucket { + return nil, fmt.Errorf("size must not be %#X", size) + } + if size < capacity { + return nil, fmt.Errorf("size (%d) is smaller than capacity (%d)", size, capacity) + } + if hash == nil { + return nil, errors.New("hash function must be set") + } + + buckets := make([]uint32, size) + elements := make([]element[K, V], size) + + var lru LRU[K, V] + initLRU(&lru, capacity, size, hash, buckets, elements) + + return &lru, nil +} + +func initLRU[K comparable, V any](lru *LRU[K, V], capacity, size uint32, hash HashKeyCallback[K], + buckets []uint32, elements []element[K, V]) { + lru.cap = capacity + lru.size = size + lru.hash = hash + lru.buckets = buckets + lru.elements = elements + + // If the size is 2^N, we can avoid costly divisions. + if bits.OnesCount32(lru.size) == 1 { + lru.mask = lru.size - 1 + } + + // Mark all slots as free. + for i := range lru.buckets { + lru.buckets[i] = emptyBucket + } +} + +// hashToBucketPos converts a hash value into a position in the elements array. +func (lru *LRU[K, V]) hashToBucketPos(hash uint32) uint32 { + if lru.mask != 0 { + return hash & lru.mask + } + return hash % lru.size +} + +// hashToPos converts a key into a position in the elements array. +func (lru *LRU[K, V]) hashToPos(hash uint32) (bucketPos, elemPos uint32) { + bucketPos = lru.hashToBucketPos(hash) + elemPos = lru.buckets[bucketPos] + return +} + +// setHead links the element as the head into the list. +func (lru *LRU[K, V]) setHead(pos uint32) { + // Both calls to setHead() check beforehand that pos != lru.head. + // So if you run into this situation, you likely use FreeLRU in a concurrent situation + // without proper locking. It requires a write lock, even around Get(). + // But better use SyncedLRU or SharedLRU in such a case. + if pos == lru.head { + panic(pos) + } + + lru.elements[pos].prev = lru.head + lru.elements[pos].next = lru.elements[lru.head].next + lru.elements[lru.elements[lru.head].next].prev = pos + lru.elements[lru.head].next = pos + lru.head = pos +} + +// unlinkElement removes the element from the elements list. +func (lru *LRU[K, V]) unlinkElement(pos uint32) { + lru.elements[lru.elements[pos].prev].next = lru.elements[pos].next + lru.elements[lru.elements[pos].next].prev = lru.elements[pos].prev +} + +// unlinkBucket removes the element from the buckets list. +func (lru *LRU[K, V]) unlinkBucket(pos uint32) { + prevBucket := lru.elements[pos].prevBucket + nextBucket := lru.elements[pos].nextBucket + if prevBucket == nextBucket && prevBucket == pos { //nolint:gocritic + // The element references itself, so it's the only bucket entry + lru.buckets[lru.elements[pos].bucketPos] = emptyBucket + return + } + lru.elements[prevBucket].nextBucket = nextBucket + lru.elements[nextBucket].prevBucket = prevBucket + lru.buckets[lru.elements[pos].bucketPos] = nextBucket +} + +// evict evicts the element at the given position. +func (lru *LRU[K, V]) evict(pos uint32) { + if pos == lru.head { + lru.head = lru.elements[pos].prev + } + + lru.unlinkElement(pos) + lru.unlinkBucket(pos) + lru.len-- + + if lru.onEvict != nil { + // Save k/v for the eviction function. + key := lru.elements[pos].key + value := lru.elements[pos].value + lru.onEvict(key, value) + } +} + +// Move element from position 'from' to position 'to'. +// That avoids 'gaps' and new elements can always be simply appended. +func (lru *LRU[K, V]) move(to, from uint32) { + if to == from { + return + } + if from == lru.head { + lru.head = to + } + + prev := lru.elements[from].prev + next := lru.elements[from].next + lru.elements[prev].next = to + lru.elements[next].prev = to + + prev = lru.elements[from].prevBucket + next = lru.elements[from].nextBucket + lru.elements[prev].nextBucket = to + lru.elements[next].prevBucket = to + + lru.elements[to] = lru.elements[from] + + if lru.buckets[lru.elements[to].bucketPos] == from { + lru.buckets[lru.elements[to].bucketPos] = to + } +} + +// insert stores the k/v at pos. +// It updates the head to point to this position. +func (lru *LRU[K, V]) insert(pos uint32, key K, value V, lifetime time.Duration) { + lru.elements[pos].key = key + lru.elements[pos].value = value + lru.elements[pos].expire = expire(lifetime) + + if lru.len == 0 { + lru.elements[pos].prev = pos + lru.elements[pos].next = pos + lru.head = pos + } else if pos != lru.head { + lru.setHead(pos) + } + lru.len++ + lru.metrics.Inserts++ +} + +func now() int64 { + return time.Now().UnixMilli() +} + +func expire(lifetime time.Duration) int64 { + if lifetime == 0 { + return 0 + } + return now() + lifetime.Milliseconds() +} + +// clearKeyAndValue clears stale data to avoid memory leaks +func (lru *LRU[K, V]) clearKeyAndValue(pos uint32) { + lru.elements[pos].key = lru.emptyKey + lru.elements[pos].value = lru.emptyValue +} + +func (lru *LRU[K, V]) findKey(hash uint32, key K) (uint32, bool) { + _, startPos := lru.hashToPos(hash) + if startPos == emptyBucket { + return emptyBucket, false + } + + pos := startPos + for { + if key == lru.elements[pos].key { + if lru.elements[pos].expire != 0 && lru.elements[pos].expire <= now() { + lru.removeAt(pos) + return emptyBucket, false + } + return pos, true + } + + pos = lru.elements[pos].nextBucket + if pos == startPos { + // Key not found + return emptyBucket, false + } + } +} + +// Len returns the number of elements stored in the cache. +func (lru *LRU[K, V]) Len() int { + return int(lru.len) +} + +// AddWithLifetime adds a key:value to the cache with a lifetime. +// Returns true, true if key was updated and eviction occurred. +func (lru *LRU[K, V]) AddWithLifetime(key K, value V, lifetime time.Duration) (evicted bool) { + return lru.addWithLifetime(lru.hash(key), key, value, lifetime) +} + +func (lru *LRU[K, V]) addWithLifetime(hash uint32, key K, value V, + lifetime time.Duration) (evicted bool) { + bucketPos, startPos := lru.hashToPos(hash) + if startPos == emptyBucket { + pos := lru.len + + if pos == lru.cap { + // Capacity reached, evict the oldest entry and + // store the new entry at evicted position. + pos = lru.elements[lru.head].next + lru.evict(pos) + lru.metrics.Evictions++ + evicted = true + } + + // insert new (first) entry into the bucket + lru.buckets[bucketPos] = pos + lru.elements[pos].bucketPos = bucketPos + + lru.elements[pos].nextBucket = pos + lru.elements[pos].prevBucket = pos + lru.insert(pos, key, value, lifetime) + return evicted + } + + // Walk through the bucket list to see whether key already exists. + pos := startPos + for { + if lru.elements[pos].key == key { + // Key exists, replace the value and update element to be the head element. + lru.elements[pos].value = value + lru.elements[pos].expire = expire(lifetime) + + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + // count as insert, even if it's just an update + lru.metrics.Inserts++ + return false + } + + pos = lru.elements[pos].nextBucket + if pos == startPos { + // Key not found + break + } + } + + pos = lru.len + if pos == lru.cap { + // Capacity reached, evict the oldest entry and + // store the new entry at evicted position. + pos = lru.elements[lru.head].next + lru.evict(pos) + lru.metrics.Evictions++ + evicted = true + startPos = lru.buckets[bucketPos] + if startPos == emptyBucket { + startPos = pos + } + } + + // insert new entry into the existing bucket before startPos + lru.buckets[bucketPos] = pos + lru.elements[pos].bucketPos = bucketPos + + lru.elements[pos].nextBucket = startPos + lru.elements[pos].prevBucket = lru.elements[startPos].prevBucket + lru.elements[lru.elements[startPos].prevBucket].nextBucket = pos + lru.elements[startPos].prevBucket = pos + lru.insert(pos, key, value, lifetime) + + if lru.elements[pos].prevBucket != pos { + // The bucket now contains more than 1 element. + // That means we have a collision. + lru.metrics.Collisions++ + } + return evicted +} + +// Add adds a key:value to the cache. +// Returns true, true if key was updated and eviction occurred. +func (lru *LRU[K, V]) Add(key K, value V) (evicted bool) { + return lru.addWithLifetime(lru.hash(key), key, value, lru.lifetime) +} + +func (lru *LRU[K, V]) add(hash uint32, key K, value V) (evicted bool) { + return lru.addWithLifetime(hash, key, value, lru.lifetime) +} + +// Get returns the value associated with the key, setting it as the most +// recently used item. +// If the found cache item is already expired, the evict function is called +// and the return value indicates that the key was not found. +func (lru *LRU[K, V]) Get(key K) (value V, ok bool) { + return lru.get(lru.hash(key), key) +} + +func (lru *LRU[K, V]) get(hash uint32, key K) (value V, ok bool) { + if pos, ok := lru.findKey(hash, key); ok { + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + lru.metrics.Hits++ + return lru.elements[pos].value, ok + } + + lru.metrics.Misses++ + return +} + +// Peek looks up a key's value from the cache, without changing its recent-ness. +// If the found entry is already expired, the evict function is called. +func (lru *LRU[K, V]) Peek(key K) (value V, ok bool) { + return lru.peek(lru.hash(key), key) +} + +func (lru *LRU[K, V]) peek(hash uint32, key K) (value V, ok bool) { + if pos, ok := lru.findKey(hash, key); ok { + return lru.elements[pos].value, ok + } + + return +} + +// Contains checks for the existence of a key, without changing its recent-ness. +// If the found entry is already expired, the evict function is called. +func (lru *LRU[K, V]) Contains(key K) (ok bool) { + _, ok = lru.peek(lru.hash(key), key) + return +} + +func (lru *LRU[K, V]) contains(hash uint32, key K) (ok bool) { + _, ok = lru.peek(hash, key) + return +} + +// Remove removes the key from the cache. +// The return value indicates whether the key existed or not. +// The evict function is called for the removed entry. +func (lru *LRU[K, V]) Remove(key K) (removed bool) { + return lru.remove(lru.hash(key), key) +} + +func (lru *LRU[K, V]) remove(hash uint32, key K) (removed bool) { + if pos, ok := lru.findKey(hash, key); ok { + lru.removeAt(pos) + return ok + } + + return +} + +func (lru *LRU[K, V]) removeAt(pos uint32) { + lru.evict(pos) + lru.move(pos, lru.len) + lru.metrics.Removals++ + + // remove stale data to avoid memory leaks + lru.clearKeyAndValue(lru.len) +} + +// RemoveOldest removes the oldest entry from the cache. +// Key, value and an indicator of whether the entry has been removed is returned. +// The evict function is called for the removed entry. +func (lru *LRU[K, V]) RemoveOldest() (key K, value V, removed bool) { + if lru.len == 0 { + return lru.emptyKey, lru.emptyValue, false + } + pos := lru.elements[lru.head].next + key = lru.elements[pos].key + value = lru.elements[pos].value + lru.removeAt(pos) + return key, value, true +} + +// Keys returns a slice of the keys in the cache, from oldest to newest. +// Expired entries are not included. +// The evict function is called for each expired item. +func (lru *LRU[K, V]) Keys() []K { + lru.PurgeExpired() + + keys := make([]K, 0, lru.len) + pos := lru.elements[lru.head].next + for i := uint32(0); i < lru.len; i++ { + keys = append(keys, lru.elements[pos].key) + pos = lru.elements[pos].next + } + return keys +} + +// Purge purges all data (key and value) from the LRU. +// The evict function is called for each expired item. +// The LRU metrics are reset. +func (lru *LRU[K, V]) Purge() { + for i := uint32(0); i < lru.len; i++ { + _, _, _ = lru.RemoveOldest() + } + + lru.metrics = Metrics{} +} + +// PurgeExpired purges all expired items from the LRU. +// The evict function is called for each expired item. +func (lru *LRU[K, V]) PurgeExpired() { + for i := uint32(0); i < lru.len; i++ { + pos := lru.elements[lru.head].next + if lru.elements[pos].expire != 0 { + if lru.elements[pos].expire > now() { + return // no more expired items + } + lru.removeAt(pos) + } + } +} + +// Metrics returns the metrics of the cache. +func (lru *LRU[K, V]) Metrics() Metrics { + return lru.metrics +} + +// ResetMetrics resets the metrics of the cache and returns the previous state. +func (lru *LRU[K, V]) ResetMetrics() Metrics { + metrics := lru.metrics + lru.metrics = Metrics{} + return metrics +} + +// just used for debugging +func (lru *LRU[K, V]) dump() { + fmt.Printf("head %d len %d cap %d size %d mask 0x%X\n", + lru.head, lru.len, lru.cap, lru.size, lru.mask) + + for i := range lru.buckets { + if lru.buckets[i] == emptyBucket { + continue + } + fmt.Printf(" bucket[%d] -> %d\n", i, lru.buckets[i]) + pos := lru.buckets[i] + for { + e := &lru.elements[pos] + fmt.Printf(" pos %d bucketPos %d prevBucket %d nextBucket %d prev %d next %d k %v v %v\n", + pos, e.bucketPos, e.prevBucket, e.nextBucket, e.prev, e.next, e.key, e.value) + pos = e.nextBucket + if pos == lru.buckets[i] { + break + } + } + } +} + +func (lru *LRU[K, V]) PrintStats() { + m := &lru.metrics + fmt.Printf("Inserts: %d Collisions: %d (%.2f%%) Evictions: %d Removals: %d Hits: %d (%.2f%%) Misses: %d\n", + m.Inserts, m.Collisions, float64(m.Collisions)/float64(m.Inserts)*100, + m.Evictions, m.Removals, + m.Hits, float64(m.Hits)/float64(m.Hits+m.Misses)*100, m.Misses) +} From 0641c71805e3d52639144566d44d3292c4b0d3ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 21 Oct 2024 18:11:45 +0800 Subject: [PATCH 03/55] maphash: copy source from v0.1.0 --- contrab/maphash/README.md | 3 + contrab/maphash/hasher.go | 48 ++++++++++++++++ contrab/maphash/runtime.go | 111 +++++++++++++++++++++++++++++++++++++ 3 files changed, 162 insertions(+) create mode 100644 contrab/maphash/README.md create mode 100644 contrab/maphash/hasher.go create mode 100644 contrab/maphash/runtime.go diff --git a/contrab/maphash/README.md b/contrab/maphash/README.md new file mode 100644 index 0000000..a4cb792 --- /dev/null +++ b/contrab/maphash/README.md @@ -0,0 +1,3 @@ +# maphash + +kanged from github.com/dolthub/maphash@v0.1.0 \ No newline at end of file diff --git a/contrab/maphash/hasher.go b/contrab/maphash/hasher.go new file mode 100644 index 0000000..ef53596 --- /dev/null +++ b/contrab/maphash/hasher.go @@ -0,0 +1,48 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package maphash + +import "unsafe" + +// Hasher hashes values of type K. +// Uses runtime AES-based hashing. +type Hasher[K comparable] struct { + hash hashfn + seed uintptr +} + +// NewHasher creates a new Hasher[K] with a random seed. +func NewHasher[K comparable]() Hasher[K] { + return Hasher[K]{ + hash: getRuntimeHasher[K](), + seed: newHashSeed(), + } +} + +// NewSeed returns a copy of |h| with a new hash seed. +func NewSeed[K comparable](h Hasher[K]) Hasher[K] { + return Hasher[K]{ + hash: h.hash, + seed: newHashSeed(), + } +} + +// Hash hashes |key|. +func (h Hasher[K]) Hash(key K) uint64 { + // promise to the compiler that pointer + // |p| does not escape the stack. + p := noescape(unsafe.Pointer(&key)) + return uint64(h.hash(p, h.seed)) +} diff --git a/contrab/maphash/runtime.go b/contrab/maphash/runtime.go new file mode 100644 index 0000000..29cd6a8 --- /dev/null +++ b/contrab/maphash/runtime.go @@ -0,0 +1,111 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file incorporates work covered by the following copyright and +// permission notice: +// +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.18 || go1.19 +// +build go1.18 go1.19 + +package maphash + +import ( + "math/rand" + "unsafe" +) + +type hashfn func(unsafe.Pointer, uintptr) uintptr + +func getRuntimeHasher[K comparable]() (h hashfn) { + a := any(make(map[K]struct{})) + i := (*mapiface)(unsafe.Pointer(&a)) + h = i.typ.hasher + return +} + +func newHashSeed() uintptr { + return uintptr(rand.Int()) +} + +// noescape hides a pointer from escape analysis. It is the identity function +// but escape analysis doesn't think the output depends on the input. +// noescape is inlined and currently compiles down to zero instructions. +// USE CAREFULLY! +// This was copied from the runtime (via pkg "strings"); see issues 23382 and 7921. +// +//go:nosplit +//go:nocheckptr +func noescape(p unsafe.Pointer) unsafe.Pointer { + x := uintptr(p) + return unsafe.Pointer(x ^ 0) +} + +type mapiface struct { + typ *maptype + val *hmap +} + +// go/src/runtime/type.go +type maptype struct { + typ _type + key *_type + elem *_type + bucket *_type + // function for hashing keys (ptr to key, seed) -> hash + hasher func(unsafe.Pointer, uintptr) uintptr + keysize uint8 + elemsize uint8 + bucketsize uint16 + flags uint32 +} + +// go/src/runtime/map.go +type hmap struct { + count int + flags uint8 + B uint8 + noverflow uint16 + // hash seed + hash0 uint32 + buckets unsafe.Pointer + oldbuckets unsafe.Pointer + nevacuate uintptr + // true type is *mapextra + // but we don't need this data + extra unsafe.Pointer +} + +// go/src/runtime/type.go +type tflag uint8 +type nameOff int32 +type typeOff int32 + +// go/src/runtime/type.go +type _type struct { + size uintptr + ptrdata uintptr + hash uint32 + tflag tflag + align uint8 + fieldAlign uint8 + kind uint8 + equal func(unsafe.Pointer, unsafe.Pointer) bool + gcdata *byte + str nameOff + ptrToThis typeOff +} From 7ec09d604527ddec6cdbe4212a8da98a649c0a10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 21 Oct 2024 19:24:36 +0800 Subject: [PATCH 04/55] udpnat2: New synced udp nat service --- common/bufio/conn.go | 9 +--- common/network/conn.go | 2 +- common/network/direct.go | 20 ++++++-- common/udpnat/service.go | 6 --- common/udpnat2/conn.go | 90 +++++++++++++++++++++++++++++++++++ common/udpnat2/packet.go | 28 +++++++++++ common/udpnat2/service.go | 93 +++++++++++++++++++++++++++++++++++++ contrab/freelru/lru.go | 46 +++++++++++------- contrab/freelru/lru_test.go | 35 ++++++++++++++ contrab/maphash/hasher.go | 5 ++ contrab/maphash/runtime.go | 9 ++-- protocol/socks/lazy.go | 4 +- 12 files changed, 307 insertions(+), 40 deletions(-) create mode 100644 common/udpnat2/conn.go create mode 100644 common/udpnat2/packet.go create mode 100644 common/udpnat2/service.go create mode 100644 contrab/freelru/lru_test.go diff --git a/common/bufio/conn.go b/common/bufio/conn.go index 1b92589..5736eef 100644 --- a/common/bufio/conn.go +++ b/common/bufio/conn.go @@ -35,14 +35,7 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { defer buffer.Release() - if destination.IsFqdn() { - udpAddr, err := net.ResolveUDPAddr("udp", destination.String()) - if err != nil { - return err - } - return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr)) - } - return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr())) + return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort())) } func (w *ExtendedUDPConn) Upstream() any { diff --git a/common/network/conn.go b/common/network/conn.go index 01fe135..c795a19 100644 --- a/common/network/conn.go +++ b/common/network/conn.go @@ -124,7 +124,7 @@ type UDPHandler interface { } type UDPHandlerEx interface { - NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr) error + NewPacketEx(buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr) } // Deprecated: Use UDPConnectionHandlerEx instead. diff --git a/common/network/direct.go b/common/network/direct.go index 24f38d7..1122d70 100644 --- a/common/network/direct.go +++ b/common/network/direct.go @@ -19,15 +19,27 @@ func (o ReadWaitOptions) NeedHeadroom() bool { return o.FrontHeadroom > 0 || o.RearHeadroom > 0 } +func (o ReadWaitOptions) Copy(buffer *buf.Buffer) *buf.Buffer { + if o.FrontHeadroom > buffer.Start() || + o.RearHeadroom > buffer.FreeLen() { + newBuffer := o.newBuffer(buf.UDPBufferSize, false) + newBuffer.Write(buffer.Bytes()) + buffer.Release() + return newBuffer + } else { + return buffer + } +} + func (o ReadWaitOptions) NewBuffer() *buf.Buffer { - return o.newBuffer(buf.BufferSize) + return o.newBuffer(buf.BufferSize, true) } func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer { - return o.newBuffer(buf.UDPBufferSize) + return o.newBuffer(buf.UDPBufferSize, true) } -func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer { +func (o ReadWaitOptions) newBuffer(defaultBufferSize int, reserve bool) *buf.Buffer { var bufferSize int if o.MTU > 0 { bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom @@ -38,7 +50,7 @@ func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer { if o.FrontHeadroom > 0 { buffer.Resize(o.FrontHeadroom, 0) } - if o.RearHeadroom > 0 { + if o.RearHeadroom > 0 && reserve { buffer.Reserve(o.RearHeadroom) } return buffer diff --git a/common/udpnat/service.go b/common/udpnat/service.go index a5b37db..6f95dbb 100644 --- a/common/udpnat/service.go +++ b/common/udpnat/service.go @@ -131,8 +131,6 @@ func (s *Service[T]) NewContextPacketEx(ctx context.Context, key T, buffer *buf. s.nat.Delete(key) } }() - } else { - c.localAddr = source } if common.Done(c.ctx) { s.nat.Delete(key) @@ -215,10 +213,6 @@ func (c *conn) SetWriteDeadline(t time.Time) error { return os.ErrInvalid } -func (c *conn) NeedAdditionalReadDeadline() bool { - return true -} - func (c *conn) Upstream() any { return c.source } diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go new file mode 100644 index 0000000..a5ca8ac --- /dev/null +++ b/common/udpnat2/conn.go @@ -0,0 +1,90 @@ +package udpnat + +import ( + "io" + "net" + "os" + "time" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/pipe" +) + +type natConn struct { + writer N.PacketWriter + localAddr M.Socksaddr + packetChan chan *Packet + doneChan chan struct{} + readDeadline pipe.Deadline + readWaitOptions N.ReadWaitOptions +} + +func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { + select { + case p := <-c.packetChan: + _, err = buffer.ReadOnceFrom(p.Buffer) + destination := p.Destination + p.Buffer.Release() + PutPacket(p) + return destination, err + case <-c.doneChan: + return M.Socksaddr{}, io.ErrClosedPipe + case <-c.readDeadline.Wait(): + return M.Socksaddr{}, os.ErrDeadlineExceeded + } +} + +func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + return c.writer.WritePacket(buffer, destination) +} + +func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return false +} + +func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + select { + case packet := <-c.packetChan: + buffer = c.readWaitOptions.Copy(packet.Buffer) + destination = packet.Destination + PutPacket(packet) + return + case <-c.doneChan: + return nil, M.Socksaddr{}, io.ErrClosedPipe + case <-c.readDeadline.Wait(): + return nil, M.Socksaddr{}, os.ErrDeadlineExceeded + } +} + +func (c *natConn) Close() error { + select { + case <-c.doneChan: + default: + close(c.doneChan) + } + return nil +} + +func (c *natConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *natConn) RemoteAddr() net.Addr { + return M.Socksaddr{} +} + +func (c *natConn) SetDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *natConn) SetReadDeadline(t time.Time) error { + c.readDeadline.Set(t) + return nil +} + +func (c *natConn) SetWriteDeadline(t time.Time) error { + return os.ErrInvalid +} diff --git a/common/udpnat2/packet.go b/common/udpnat2/packet.go new file mode 100644 index 0000000..1d56ff4 --- /dev/null +++ b/common/udpnat2/packet.go @@ -0,0 +1,28 @@ +package udpnat + +import ( + "sync" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" +) + +var packetPool = sync.Pool{ + New: func() any { + return new(Packet) + }, +} + +type Packet struct { + Buffer *buf.Buffer + Destination M.Socksaddr +} + +func NewPacket() *Packet { + return packetPool.Get().(*Packet) +} + +func PutPacket(packet *Packet) { + *packet = Packet{} + packetPool.Put(packet) +} diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go new file mode 100644 index 0000000..85b3641 --- /dev/null +++ b/common/udpnat2/service.go @@ -0,0 +1,93 @@ +package udpnat + +import ( + "context" + "net/netip" + "time" + + "github.com/sagernet/sing/common" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/pipe" + "github.com/sagernet/sing/contrab/freelru" + "github.com/sagernet/sing/contrab/maphash" +) + +type Service struct { + nat *freelru.LRU[netip.AddrPort, *natConn] + handler N.UDPConnectionHandlerEx + prepare PrepareFunc + metrics Metrics +} + +type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) + +type Metrics struct { + Creates uint64 + Rejects uint64 + Inputs uint64 + Drops uint64 +} + +func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration) *Service { + nat := common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + nat.SetLifetime(timeout) + nat.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool { + select { + case <-conn.doneChan: + return false + default: + return true + } + }) + nat.SetOnEvict(func(_ netip.AddrPort, conn *natConn) { + conn.Close() + }) + return &Service{ + nat: nat, + handler: handler, + prepare: prepare, + } +} + +func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) { + conn, loaded := s.nat.Get(source.AddrPort()) + if !loaded { + ok, ctx, writer, onClose := s.prepare(source, destination, userData) + if !ok { + s.metrics.Rejects++ + return + } + conn = &natConn{ + writer: writer, + localAddr: source, + packetChan: make(chan *Packet, 64), + doneChan: make(chan struct{}), + readDeadline: pipe.MakeDeadline(), + } + s.nat.Add(source.AddrPort(), conn) + s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose) + s.metrics.Creates++ + } + packet := NewPacket() + buffer := conn.readWaitOptions.NewPacketBuffer() + for _, bufferSlice := range bufferSlices { + buffer.Write(bufferSlice) + } + *packet = Packet{ + Buffer: buffer, + Destination: destination, + } + select { + case conn.packetChan <- packet: + s.metrics.Inputs++ + default: + packet.Buffer.Release() + PutPacket(packet) + s.metrics.Drops++ + } +} + +func (s *Service) Metrics() Metrics { + return s.metrics +} diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index af8b8e9..045cc3e 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -31,6 +31,8 @@ type OnEvictCallback[K comparable, V any] func(K, V) // HashKeyCallback is the function that creates a hash from the passed key. type HashKeyCallback[K comparable] func(K) uint32 +type HealthCheckCallback[K comparable, V any] func(K, V) bool + type element[K comparable, V any] struct { key K value V @@ -61,12 +63,13 @@ const emptyBucket = math.MaxUint32 // LRU implements a non-thread safe fixed size LRU cache. type LRU[K comparable, V any] struct { - buckets []uint32 // contains positions of bucket lists or 'emptyBucket' - elements []element[K, V] - onEvict OnEvictCallback[K, V] - hash HashKeyCallback[K] - lifetime time.Duration - metrics Metrics + buckets []uint32 // contains positions of bucket lists or 'emptyBucket' + elements []element[K, V] + onEvict OnEvictCallback[K, V] + hash HashKeyCallback[K] + healthCheck HealthCheckCallback[K, V] + lifetime time.Duration + metrics Metrics // used for element clearing after removal or expiration emptyKey K @@ -108,6 +111,10 @@ func (lru *LRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) { lru.onEvict = onEvict } +func (lru *LRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) { + lru.healthCheck = healthCheck +} + // New constructs an LRU with the given capacity of elements. // The hash function calculates a hash value from the keys. func New[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, V], error) { @@ -120,7 +127,8 @@ func New[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, // by reducing the chance of collisions. // Size must not be lower than the capacity. func NewWithSize[K comparable, V any](capacity, size uint32, hash HashKeyCallback[K]) ( - *LRU[K, V], error) { + *LRU[K, V], error, +) { if capacity == 0 { return nil, errors.New("capacity must be positive") } @@ -144,7 +152,8 @@ func NewWithSize[K comparable, V any](capacity, size uint32, hash HashKeyCallbac } func initLRU[K comparable, V any](lru *LRU[K, V], capacity, size uint32, hash HashKeyCallback[K], - buckets []uint32, elements []element[K, V]) { + buckets []uint32, elements []element[K, V], +) { lru.cap = capacity lru.size = size lru.hash = hash @@ -294,7 +303,7 @@ func (lru *LRU[K, V]) clearKeyAndValue(pos uint32) { lru.elements[pos].value = lru.emptyValue } -func (lru *LRU[K, V]) findKey(hash uint32, key K) (uint32, bool) { +func (lru *LRU[K, V]) findKey(hash uint32, key K, updateLifetimeOnGet bool) (uint32, bool) { _, startPos := lru.hashToPos(hash) if startPos == emptyBucket { return emptyBucket, false @@ -303,10 +312,14 @@ func (lru *LRU[K, V]) findKey(hash uint32, key K) (uint32, bool) { pos := startPos for { if key == lru.elements[pos].key { - if lru.elements[pos].expire != 0 && lru.elements[pos].expire <= now() { + elem := lru.elements[pos] + if (elem.expire != 0 && elem.expire <= now()) || (lru.healthCheck != nil && !lru.healthCheck(key, elem.value)) { lru.removeAt(pos) return emptyBucket, false } + if updateLifetimeOnGet { + lru.elements[pos].expire = expire(lru.lifetime) + } return pos, true } @@ -330,7 +343,8 @@ func (lru *LRU[K, V]) AddWithLifetime(key K, value V, lifetime time.Duration) (e } func (lru *LRU[K, V]) addWithLifetime(hash uint32, key K, value V, - lifetime time.Duration) (evicted bool) { + lifetime time.Duration, +) (evicted bool) { bucketPos, startPos := lru.hashToPos(hash) if startPos == emptyBucket { pos := lru.len @@ -425,11 +439,11 @@ func (lru *LRU[K, V]) add(hash uint32, key K, value V) (evicted bool) { // If the found cache item is already expired, the evict function is called // and the return value indicates that the key was not found. func (lru *LRU[K, V]) Get(key K) (value V, ok bool) { - return lru.get(lru.hash(key), key) + return lru.get(lru.hash(key), key, true) } -func (lru *LRU[K, V]) get(hash uint32, key K) (value V, ok bool) { - if pos, ok := lru.findKey(hash, key); ok { +func (lru *LRU[K, V]) get(hash uint32, key K, updateLifetime bool) (value V, ok bool) { + if pos, ok := lru.findKey(hash, key, updateLifetime); ok { if pos != lru.head { lru.unlinkElement(pos) lru.setHead(pos) @@ -449,7 +463,7 @@ func (lru *LRU[K, V]) Peek(key K) (value V, ok bool) { } func (lru *LRU[K, V]) peek(hash uint32, key K) (value V, ok bool) { - if pos, ok := lru.findKey(hash, key); ok { + if pos, ok := lru.findKey(hash, key, false); ok { return lru.elements[pos].value, ok } @@ -476,7 +490,7 @@ func (lru *LRU[K, V]) Remove(key K) (removed bool) { } func (lru *LRU[K, V]) remove(hash uint32, key K) (removed bool) { - if pos, ok := lru.findKey(hash, key); ok { + if pos, ok := lru.findKey(hash, key, false); ok { lru.removeAt(pos) return ok } diff --git a/contrab/freelru/lru_test.go b/contrab/freelru/lru_test.go new file mode 100644 index 0000000..4c4ba59 --- /dev/null +++ b/contrab/freelru/lru_test.go @@ -0,0 +1,35 @@ +package freelru_test + +import ( + "testing" + "time" + + "github.com/sagernet/sing/contrab/freelru" + "github.com/sagernet/sing/contrab/maphash" + + "github.com/stretchr/testify/require" +) + +func TestMyChange0(t *testing.T) { + t.Parallel() + lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) + require.NoError(t, err) + lru.AddWithLifetime("hello", "world", 2*time.Second) + time.Sleep(time.Second) + lru.Get("hello") + time.Sleep(time.Second + time.Millisecond*100) + _, ok := lru.Get("hello") + require.True(t, ok) +} + +func TestMyChange1(t *testing.T) { + t.Parallel() + lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) + require.NoError(t, err) + lru.AddWithLifetime("hello", "world", 2*time.Second) + time.Sleep(time.Second) + lru.Peek("hello") + time.Sleep(time.Second + time.Millisecond*100) + _, ok := lru.Get("hello") + require.False(t, ok) +} diff --git a/contrab/maphash/hasher.go b/contrab/maphash/hasher.go index ef53596..cc60b2e 100644 --- a/contrab/maphash/hasher.go +++ b/contrab/maphash/hasher.go @@ -46,3 +46,8 @@ func (h Hasher[K]) Hash(key K) uint64 { p := noescape(unsafe.Pointer(&key)) return uint64(h.hash(p, h.seed)) } + +func (h Hasher[K]) Hash32(key K) uint32 { + p := noescape(unsafe.Pointer(&key)) + return uint32(h.hash(p, h.seed)) +} diff --git a/contrab/maphash/runtime.go b/contrab/maphash/runtime.go index 29cd6a8..f2aa2e0 100644 --- a/contrab/maphash/runtime.go +++ b/contrab/maphash/runtime.go @@ -52,6 +52,7 @@ func newHashSeed() uintptr { //go:nocheckptr func noescape(p unsafe.Pointer) unsafe.Pointer { x := uintptr(p) + //nolint:staticcheck return unsafe.Pointer(x ^ 0) } @@ -91,9 +92,11 @@ type hmap struct { } // go/src/runtime/type.go -type tflag uint8 -type nameOff int32 -type typeOff int32 +type ( + tflag uint8 + nameOff int32 + typeOff int32 +) // go/src/runtime/type.go type _type struct { diff --git a/protocol/socks/lazy.go b/protocol/socks/lazy.go index e687475..3468981 100644 --- a/protocol/socks/lazy.go +++ b/protocol/socks/lazy.go @@ -37,7 +37,7 @@ func (c *LazyConn) ConnHandshakeSuccess(conn net.Conn) error { Destination: M.SocksaddrFromNet(conn.LocalAddr()), }) case socks5.Version: - return socks5.WriteResponse(conn, socks5.Response{ + return socks5.WriteResponse(c.Conn, socks5.Response{ ReplyCode: socks5.ReplyCodeSuccess, Bind: M.SocksaddrFromNet(conn.LocalAddr()), }) @@ -211,5 +211,5 @@ func (c *LazyAssociatePacketConn) WriterReplaceable() bool { } func (c *LazyAssociatePacketConn) Upstream() any { - return c.underlying + return &c.AssociatePacketConn } From a4eb7fa900c79a3b28005016efc2a77bccc7a8bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 23 Oct 2024 13:30:48 +0800 Subject: [PATCH 05/55] udpnat2: Add SetHandler --- common/bufio/cache.go | 4 +++- common/network/conn.go | 7 +------ common/network/packet.go | 35 +++++++++++++++++++++++++++++++ common/udpnat2/conn.go | 44 ++++++++++++++++++++++++++------------- common/udpnat2/packet.go | 28 ------------------------- common/udpnat2/service.go | 25 +++++++++++++--------- 6 files changed, 84 insertions(+), 59 deletions(-) create mode 100644 common/network/packet.go delete mode 100644 common/udpnat2/packet.go diff --git a/common/bufio/cache.go b/common/bufio/cache.go index ace7259..ce62d4d 100644 --- a/common/bufio/cache.go +++ b/common/bufio/cache.go @@ -184,10 +184,12 @@ func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer { if buffer != nil { buffer.DecRef() } - return &N.PacketBuffer{ + packet := N.NewPacketBuffer() + *packet = N.PacketBuffer{ Buffer: buffer, Destination: c.destination, } + return packet } func (c *CachedPacketConn) Upstream() any { diff --git a/common/network/conn.go b/common/network/conn.go index c795a19..c289bf6 100644 --- a/common/network/conn.go +++ b/common/network/conn.go @@ -124,7 +124,7 @@ type UDPHandler interface { } type UDPHandlerEx interface { - NewPacketEx(buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr) + NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) } // Deprecated: Use UDPConnectionHandlerEx instead. @@ -146,11 +146,6 @@ type CachedPacketReader interface { ReadCachedPacket() *PacketBuffer } -type PacketBuffer struct { - Buffer *buf.Buffer - Destination M.Socksaddr -} - type WithUpstreamReader interface { UpstreamReader() any } diff --git a/common/network/packet.go b/common/network/packet.go new file mode 100644 index 0000000..5b85214 --- /dev/null +++ b/common/network/packet.go @@ -0,0 +1,35 @@ +package network + +import ( + "sync" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" +) + +type PacketBuffer struct { + Buffer *buf.Buffer + Destination M.Socksaddr +} + +var packetPool = sync.Pool{ + New: func() any { + return new(PacketBuffer) + }, +} + +func NewPacketBuffer() *PacketBuffer { + return packetPool.Get().(*PacketBuffer) +} + +func PutPacketBuffer(packet *PacketBuffer) { + *packet = PacketBuffer{} + packetPool.Put(packet) +} + +func ReleaseMultiPacketBuffer(packetBuffers []*PacketBuffer) { + for _, packet := range packetBuffers { + packet.Buffer.Release() + PutPacketBuffer(packet) + } +} diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index a5ca8ac..a96f4c8 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -12,22 +12,23 @@ import ( "github.com/sagernet/sing/common/pipe" ) -type natConn struct { +type Conn struct { writer N.PacketWriter localAddr M.Socksaddr - packetChan chan *Packet + handler N.UDPHandlerEx + packetChan chan *N.PacketBuffer doneChan chan struct{} readDeadline pipe.Deadline readWaitOptions N.ReadWaitOptions } -func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { +func (c *Conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { select { case p := <-c.packetChan: _, err = buffer.ReadOnceFrom(p.Buffer) destination := p.Destination p.Buffer.Release() - PutPacket(p) + N.PutPacketBuffer(p) return destination, err case <-c.doneChan: return M.Socksaddr{}, io.ErrClosedPipe @@ -36,21 +37,36 @@ func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { } } -func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { +func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return c.writer.WritePacket(buffer, destination) } -func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { +func (c *Conn) SetHandler(handler N.UDPHandlerEx) { + c.handler = handler +fetch: + for { + select { + case packet := <-c.packetChan: + c.handler.NewPacketEx(packet.Buffer, packet.Destination) + N.PutPacketBuffer(packet) + continue fetch + default: + break fetch + } + } +} + +func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { c.readWaitOptions = options return false } -func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { +func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { select { case packet := <-c.packetChan: buffer = c.readWaitOptions.Copy(packet.Buffer) destination = packet.Destination - PutPacket(packet) + N.PutPacketBuffer(packet) return case <-c.doneChan: return nil, M.Socksaddr{}, io.ErrClosedPipe @@ -59,7 +75,7 @@ func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, } } -func (c *natConn) Close() error { +func (c *Conn) Close() error { select { case <-c.doneChan: default: @@ -68,23 +84,23 @@ func (c *natConn) Close() error { return nil } -func (c *natConn) LocalAddr() net.Addr { +func (c *Conn) LocalAddr() net.Addr { return c.localAddr } -func (c *natConn) RemoteAddr() net.Addr { +func (c *Conn) RemoteAddr() net.Addr { return M.Socksaddr{} } -func (c *natConn) SetDeadline(t time.Time) error { +func (c *Conn) SetDeadline(t time.Time) error { return os.ErrInvalid } -func (c *natConn) SetReadDeadline(t time.Time) error { +func (c *Conn) SetReadDeadline(t time.Time) error { c.readDeadline.Set(t) return nil } -func (c *natConn) SetWriteDeadline(t time.Time) error { +func (c *Conn) SetWriteDeadline(t time.Time) error { return os.ErrInvalid } diff --git a/common/udpnat2/packet.go b/common/udpnat2/packet.go deleted file mode 100644 index 1d56ff4..0000000 --- a/common/udpnat2/packet.go +++ /dev/null @@ -1,28 +0,0 @@ -package udpnat - -import ( - "sync" - - "github.com/sagernet/sing/common/buf" - M "github.com/sagernet/sing/common/metadata" -) - -var packetPool = sync.Pool{ - New: func() any { - return new(Packet) - }, -} - -type Packet struct { - Buffer *buf.Buffer - Destination M.Socksaddr -} - -func NewPacket() *Packet { - return packetPool.Get().(*Packet) -} - -func PutPacket(packet *Packet) { - *packet = Packet{} - packetPool.Put(packet) -} diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 85b3641..8c8afc9 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -14,7 +14,7 @@ import ( ) type Service struct { - nat *freelru.LRU[netip.AddrPort, *natConn] + nat *freelru.LRU[netip.AddrPort, *Conn] handler N.UDPConnectionHandlerEx prepare PrepareFunc metrics Metrics @@ -30,9 +30,9 @@ type Metrics struct { } func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration) *Service { - nat := common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + nat := common.Must1(freelru.New[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) nat.SetLifetime(timeout) - nat.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool { + nat.SetHealthCheck(func(port netip.AddrPort, conn *Conn) bool { select { case <-conn.doneChan: return false @@ -40,7 +40,7 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur return true } }) - nat.SetOnEvict(func(_ netip.AddrPort, conn *natConn) { + nat.SetOnEvict(func(_ netip.AddrPort, conn *Conn) { conn.Close() }) return &Service{ @@ -55,26 +55,31 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati if !loaded { ok, ctx, writer, onClose := s.prepare(source, destination, userData) if !ok { + println(2) s.metrics.Rejects++ return } - conn = &natConn{ + conn = &Conn{ writer: writer, localAddr: source, - packetChan: make(chan *Packet, 64), + packetChan: make(chan *N.PacketBuffer, 64), doneChan: make(chan struct{}), readDeadline: pipe.MakeDeadline(), } s.nat.Add(source.AddrPort(), conn) - s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose) + go s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose) s.metrics.Creates++ } - packet := NewPacket() buffer := conn.readWaitOptions.NewPacketBuffer() for _, bufferSlice := range bufferSlices { buffer.Write(bufferSlice) } - *packet = Packet{ + if conn.handler != nil { + conn.handler.NewPacketEx(buffer, destination) + return + } + packet := N.NewPacketBuffer() + *packet = N.PacketBuffer{ Buffer: buffer, Destination: destination, } @@ -83,7 +88,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati s.metrics.Inputs++ default: packet.Buffer.Release() - PutPacket(packet) + N.PutPacketBuffer(packet) s.metrics.Drops++ } } From c80c8f907c56b8bed317df8364a842120ba1a353 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 31 Oct 2024 20:21:00 +0800 Subject: [PATCH 06/55] badjson: Add context marshaler/unmarshaler --- common/json/badjson/json.go | 5 +- common/json/badjson/merge.go | 45 ++++++------- common/json/badjson/merge_objects.go | 28 +++++--- common/json/badjson/object.go | 15 ++++- common/json/badjson/typed.go | 23 +++++-- common/json/context_ext.go | 23 +++++++ common/json/internal/contextjson/context.go | 11 ++++ .../json/internal/contextjson/context_test.go | 43 ++++++++++++ common/json/internal/contextjson/decode.go | 49 ++++++++++++-- common/json/internal/contextjson/encode.go | 66 +++++++++++++++++-- common/json/internal/contextjson/stream.go | 16 ++++- common/json/internal/contextjson/unmarshal.go | 14 ++++ common/json/unmarshal.go | 7 +- 13 files changed, 285 insertions(+), 60 deletions(-) create mode 100644 common/json/context_ext.go create mode 100644 common/json/internal/contextjson/context.go create mode 100644 common/json/internal/contextjson/context_test.go diff --git a/common/json/badjson/json.go b/common/json/badjson/json.go index 04dba1e..35f33a8 100644 --- a/common/json/badjson/json.go +++ b/common/json/badjson/json.go @@ -2,13 +2,14 @@ package badjson import ( "bytes" + "context" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" ) -func Decode(content []byte) (any, error) { - decoder := json.NewDecoder(bytes.NewReader(content)) +func Decode(ctx context.Context, content []byte) (any, error) { + decoder := json.NewDecoderContext(ctx, bytes.NewReader(content)) return decodeJSON(decoder) } diff --git a/common/json/badjson/merge.go b/common/json/badjson/merge.go index ee7193e..ac1f12f 100644 --- a/common/json/badjson/merge.go +++ b/common/json/badjson/merge.go @@ -1,6 +1,7 @@ package badjson import ( + "context" "os" "reflect" @@ -9,75 +10,75 @@ import ( "github.com/sagernet/sing/common/json" ) -func Omitempty[T any](value T) (T, error) { +func Omitempty[T any](ctx context.Context, value T) (T, error) { objectContent, err := json.Marshal(value) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal object") } - rawNewObject, err := Decode(objectContent) + rawNewObject, err := Decode(ctx, objectContent) if err != nil { return common.DefaultValue[T](), err } - newObjectContent, err := json.Marshal(rawNewObject) + newObjectContent, err := json.MarshalContext(ctx, rawNewObject) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal new object") } var newObject T - err = json.Unmarshal(newObjectContent, &newObject) + err = json.UnmarshalContext(ctx, newObjectContent, &newObject) if err != nil { return common.DefaultValue[T](), E.Cause(err, "unmarshal new object") } return newObject, nil } -func Merge[T any](source T, destination T, disableAppend bool) (T, error) { - rawSource, err := json.Marshal(source) +func Merge[T any](ctx context.Context, source T, destination T, disableAppend bool) (T, error) { + rawSource, err := json.MarshalContext(ctx, source) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal source") } - rawDestination, err := json.Marshal(destination) + rawDestination, err := json.MarshalContext(ctx, destination) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal destination") } - return MergeFrom[T](rawSource, rawDestination, disableAppend) + return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend) } -func MergeFromSource[T any](rawSource json.RawMessage, destination T, disableAppend bool) (T, error) { +func MergeFromSource[T any](ctx context.Context, rawSource json.RawMessage, destination T, disableAppend bool) (T, error) { if rawSource == nil { return destination, nil } - rawDestination, err := json.Marshal(destination) + rawDestination, err := json.MarshalContext(ctx, destination) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal destination") } - return MergeFrom[T](rawSource, rawDestination, disableAppend) + return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend) } -func MergeFromDestination[T any](source T, rawDestination json.RawMessage, disableAppend bool) (T, error) { +func MergeFromDestination[T any](ctx context.Context, source T, rawDestination json.RawMessage, disableAppend bool) (T, error) { if rawDestination == nil { return source, nil } - rawSource, err := json.Marshal(source) + rawSource, err := json.MarshalContext(ctx, source) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal source") } - return MergeFrom[T](rawSource, rawDestination, disableAppend) + return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend) } -func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) { - rawMerged, err := MergeJSON(rawSource, rawDestination, disableAppend) +func MergeFrom[T any](ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) { + rawMerged, err := MergeJSON(ctx, rawSource, rawDestination, disableAppend) if err != nil { return common.DefaultValue[T](), E.Cause(err, "merge options") } var merged T - err = json.Unmarshal(rawMerged, &merged) + err = json.UnmarshalContext(ctx, rawMerged, &merged) if err != nil { return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options") } return merged, nil } -func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) { +func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) { if rawSource == nil && rawDestination == nil { return nil, os.ErrInvalid } else if rawSource == nil { @@ -85,16 +86,16 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disabl } else if rawDestination == nil { return rawSource, nil } - source, err := Decode(rawSource) + source, err := Decode(ctx, rawSource) if err != nil { return nil, E.Cause(err, "decode source") } - destination, err := Decode(rawDestination) + destination, err := Decode(ctx, rawDestination) if err != nil { return nil, E.Cause(err, "decode destination") } if source == nil { - return json.Marshal(destination) + return json.MarshalContext(ctx, destination) } else if destination == nil { return json.Marshal(source) } @@ -102,7 +103,7 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disabl if err != nil { return nil, err } - return json.Marshal(merged) + return json.MarshalContext(ctx, merged) } func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) { diff --git a/common/json/badjson/merge_objects.go b/common/json/badjson/merge_objects.go index 37a5daf..fa6c2d4 100644 --- a/common/json/badjson/merge_objects.go +++ b/common/json/badjson/merge_objects.go @@ -1,32 +1,42 @@ package badjson import ( + "context" + E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" ) func MarshallObjects(objects ...any) ([]byte, error) { + return MarshallObjectsContext(context.Background(), objects...) +} + +func MarshallObjectsContext(ctx context.Context, objects ...any) ([]byte, error) { if len(objects) == 1 { return json.Marshal(objects[0]) } var content JSONObject for _, object := range objects { - objectMap, err := newJSONObject(object) + objectMap, err := newJSONObject(ctx, object) if err != nil { return nil, err } content.PutAll(objectMap) } - return content.MarshalJSON() + return content.MarshalJSONContext(ctx) } func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error { - parentContent, err := newJSONObject(parentObject) + return UnmarshallExcludedContext(context.Background(), inputContent, parentObject, object) +} + +func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error { + parentContent, err := newJSONObject(ctx, parentObject) if err != nil { return err } var content JSONObject - err = content.UnmarshalJSON(inputContent) + err = content.UnmarshalJSONContext(ctx, inputContent) if err != nil { return err } @@ -39,20 +49,20 @@ func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error } return E.New("unexpected key: ", content.Keys()[0]) } - inputContent, err = content.MarshalJSON() + inputContent, err = content.MarshalJSONContext(ctx) if err != nil { return err } - return json.UnmarshalDisallowUnknownFields(inputContent, object) + return json.UnmarshalContextDisallowUnknownFields(ctx, inputContent, object) } -func newJSONObject(object any) (*JSONObject, error) { - inputContent, err := json.Marshal(object) +func newJSONObject(ctx context.Context, object any) (*JSONObject, error) { + inputContent, err := json.MarshalContext(ctx, object) if err != nil { return nil, err } var content JSONObject - err = content.UnmarshalJSON(inputContent) + err = content.UnmarshalJSONContext(ctx, inputContent) if err != nil { return nil, err } diff --git a/common/json/badjson/object.go b/common/json/badjson/object.go index 61d5862..3f5dab4 100644 --- a/common/json/badjson/object.go +++ b/common/json/badjson/object.go @@ -2,6 +2,7 @@ package badjson import ( "bytes" + "context" "strings" "github.com/sagernet/sing/common" @@ -28,6 +29,10 @@ func (m *JSONObject) IsEmpty() bool { } func (m *JSONObject) MarshalJSON() ([]byte, error) { + return m.MarshalJSONContext(context.Background()) +} + +func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) { buffer := new(bytes.Buffer) buffer.WriteString("{") items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool { @@ -38,13 +43,13 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) { }) iLen := len(items) for i, entry := range items { - keyContent, err := json.Marshal(entry.Key) + keyContent, err := json.MarshalContext(ctx, entry.Key) if err != nil { return nil, err } buffer.WriteString(strings.TrimSpace(string(keyContent))) buffer.WriteString(": ") - valueContent, err := json.Marshal(entry.Value) + valueContent, err := json.MarshalContext(ctx, entry.Value) if err != nil { return nil, err } @@ -58,7 +63,11 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) { } func (m *JSONObject) UnmarshalJSON(content []byte) error { - decoder := json.NewDecoder(bytes.NewReader(content)) + return m.UnmarshalJSONContext(context.Background(), content) +} + +func (m *JSONObject) UnmarshalJSONContext(ctx context.Context, content []byte) error { + decoder := json.NewDecoderContext(ctx, bytes.NewReader(content)) m.Clear() objectStart, err := decoder.Token() if err != nil { diff --git a/common/json/badjson/typed.go b/common/json/badjson/typed.go index 66f41a6..aef85c9 100644 --- a/common/json/badjson/typed.go +++ b/common/json/badjson/typed.go @@ -2,6 +2,7 @@ package badjson import ( "bytes" + "context" "strings" E "github.com/sagernet/sing/common/exceptions" @@ -14,18 +15,22 @@ type TypedMap[K comparable, V any] struct { } func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) { + return m.MarshalJSONContext(context.Background()) +} + +func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error) { buffer := new(bytes.Buffer) buffer.WriteString("{") items := m.Entries() iLen := len(items) for i, entry := range items { - keyContent, err := json.Marshal(entry.Key) + keyContent, err := json.MarshalContext(ctx, entry.Key) if err != nil { return nil, err } buffer.WriteString(strings.TrimSpace(string(keyContent))) buffer.WriteString(": ") - valueContent, err := json.Marshal(entry.Value) + valueContent, err := json.MarshalContext(ctx, entry.Value) if err != nil { return nil, err } @@ -39,7 +44,11 @@ func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) { } func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error { - decoder := json.NewDecoder(bytes.NewReader(content)) + return m.UnmarshalJSONContext(context.Background(), content) +} + +func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byte) error { + decoder := json.NewDecoderContext(ctx, bytes.NewReader(content)) m.Clear() objectStart, err := decoder.Token() if err != nil { @@ -47,7 +56,7 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error { } else if objectStart != json.Delim('{') { return E.New("expected json object start, but starts with ", objectStart) } - err = m.decodeJSON(decoder) + err = m.decodeJSON(ctx, decoder) if err != nil { return E.Cause(err, "decode json object content") } @@ -60,18 +69,18 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error { return nil } -func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error { +func (m *TypedMap[K, V]) decodeJSON(ctx context.Context, decoder *json.Decoder) error { for decoder.More() { keyToken, err := decoder.Token() if err != nil { return err } - keyContent, err := json.Marshal(keyToken) + keyContent, err := json.MarshalContext(ctx, keyToken) if err != nil { return err } var entryKey K - err = json.Unmarshal(keyContent, &entryKey) + err = json.UnmarshalContext(ctx, keyContent, &entryKey) if err != nil { return err } diff --git a/common/json/context_ext.go b/common/json/context_ext.go new file mode 100644 index 0000000..aec149a --- /dev/null +++ b/common/json/context_ext.go @@ -0,0 +1,23 @@ +package json + +import ( + "context" + + "github.com/sagernet/sing/common/json/internal/contextjson" +) + +var ( + MarshalContext = json.MarshalContext + UnmarshalContext = json.UnmarshalContext + NewEncoderContext = json.NewEncoderContext + NewDecoderContext = json.NewDecoderContext + UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields +) + +type ContextMarshaler interface { + MarshalJSONContext(ctx context.Context) ([]byte, error) +} + +type ContextUnmarshaler interface { + UnmarshalJSONContext(ctx context.Context, content []byte) error +} diff --git a/common/json/internal/contextjson/context.go b/common/json/internal/contextjson/context.go new file mode 100644 index 0000000..ded69d7 --- /dev/null +++ b/common/json/internal/contextjson/context.go @@ -0,0 +1,11 @@ +package json + +import "context" + +type ContextMarshaler interface { + MarshalJSONContext(ctx context.Context) ([]byte, error) +} + +type ContextUnmarshaler interface { + UnmarshalJSONContext(ctx context.Context, content []byte) error +} diff --git a/common/json/internal/contextjson/context_test.go b/common/json/internal/contextjson/context_test.go new file mode 100644 index 0000000..cffecbb --- /dev/null +++ b/common/json/internal/contextjson/context_test.go @@ -0,0 +1,43 @@ +package json_test + +import ( + "context" + "testing" + + "github.com/sagernet/sing/common/json/internal/contextjson" + + "github.com/stretchr/testify/require" +) + +type myStruct struct { + value string +} + +func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) { + return json.Marshal(ctx.Value("key").(string)) +} + +func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error { + m.value = ctx.Value("key").(string) + return nil +} + +//nolint:staticcheck +func TestMarshalContext(t *testing.T) { + t.Parallel() + ctx := context.WithValue(context.Background(), "key", "value") + var s myStruct + b, err := json.MarshalContext(ctx, &s) + require.NoError(t, err) + require.Equal(t, []byte(`"value"`), b) +} + +//nolint:staticcheck +func TestUnmarshalContext(t *testing.T) { + t.Parallel() + ctx := context.WithValue(context.Background(), "key", "value") + var s myStruct + err := json.UnmarshalContext(ctx, []byte(`{}`), &s) + require.NoError(t, err) + require.Equal(t, "value", s.value) +} diff --git a/common/json/internal/contextjson/decode.go b/common/json/internal/contextjson/decode.go index 8457171..20c7ac6 100644 --- a/common/json/internal/contextjson/decode.go +++ b/common/json/internal/contextjson/decode.go @@ -8,6 +8,7 @@ package json import ( + "context" "encoding" "encoding/base64" "fmt" @@ -95,10 +96,15 @@ import ( // Instead, they are replaced by the Unicode replacement // character U+FFFD. func Unmarshal(data []byte, v any) error { + return UnmarshalContext(context.Background(), data, v) +} + +func UnmarshalContext(ctx context.Context, data []byte, v any) error { // Check for well-formedness. // Avoids filling out half a data structure // before discovering a JSON syntax error. var d decodeState + d.ctx = ctx err := checkValid(data, &d.scan) if err != nil { return err @@ -209,6 +215,7 @@ type errorContext struct { // decodeState represents the state while decoding a JSON value. type decodeState struct { + ctx context.Context data []byte off int // next read offset in data opcode int // last read result @@ -428,7 +435,7 @@ func (d *decodeState) valueQuoted() any { // If it encounters an Unmarshaler, indirect stops and returns that. // If decodingNull is true, indirect stops at the first settable pointer so it // can be set to nil. -func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { +func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) { // Issue #24153 indicates that it is generally not a guaranteed property // that you may round-trip a reflect.Value by calling Value.Addr().Elem() // and expect the value to still be settable for values derived from @@ -482,11 +489,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm } if v.Type().NumMethod() > 0 && v.CanInterface() { if u, ok := v.Interface().(Unmarshaler); ok { - return u, nil, reflect.Value{} + return u, nil, nil, reflect.Value{} + } + if cu, ok := v.Interface().(ContextUnmarshaler); ok { + return nil, cu, nil, reflect.Value{} } if !decodingNull { if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { - return nil, u, reflect.Value{} + return nil, nil, u, reflect.Value{} } } } @@ -498,14 +508,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm v = v.Elem() } } - return nil, nil, v + return nil, nil, nil, v } // array consumes an array from d.data[d.off-1:], decoding into v. // The first byte of the array ('[') has been read already. func (d *decodeState) array(v reflect.Value) error { // Check for unmarshaler. - u, ut, pv := indirect(v, false) + u, cu, ut, pv := indirect(v, false) if u != nil { start := d.readIndex() d.skip() @@ -515,6 +525,15 @@ func (d *decodeState) array(v reflect.Value) error { } return nil } + if cu != nil { + start := d.readIndex() + d.skip() + err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off]) + if err != nil { + d.saveError(err) + } + return nil + } if ut != nil { d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) d.skip() @@ -612,7 +631,7 @@ var ( // The first byte ('{') of the object has been read already. func (d *decodeState) object(v reflect.Value) error { // Check for unmarshaler. - u, ut, pv := indirect(v, false) + u, cu, ut, pv := indirect(v, false) if u != nil { start := d.readIndex() d.skip() @@ -622,6 +641,15 @@ func (d *decodeState) object(v reflect.Value) error { } return nil } + if cu != nil { + start := d.readIndex() + d.skip() + err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off]) + if err != nil { + d.saveError(err) + } + return nil + } if ut != nil { d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)}) d.skip() @@ -870,7 +898,7 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool return nil } isNull := item[0] == 'n' // null - u, ut, pv := indirect(v, isNull) + u, cu, ut, pv := indirect(v, isNull) if u != nil { err := u.UnmarshalJSON(item) if err != nil { @@ -878,6 +906,13 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool } return nil } + if cu != nil { + err := cu.UnmarshalJSONContext(d.ctx, item) + if err != nil { + d.saveError(err) + } + return nil + } if ut != nil { if item[0] != '"' { if fromQuoted { diff --git a/common/json/internal/contextjson/encode.go b/common/json/internal/contextjson/encode.go index 296177a..27f901b 100644 --- a/common/json/internal/contextjson/encode.go +++ b/common/json/internal/contextjson/encode.go @@ -12,6 +12,7 @@ package json import ( "bytes" + "context" "encoding" "encoding/base64" "fmt" @@ -156,7 +157,11 @@ import ( // handle them. Passing cyclic structures to Marshal will result in // an error. func Marshal(v any) ([]byte, error) { - e := newEncodeState() + return MarshalContext(context.Background(), v) +} + +func MarshalContext(ctx context.Context, v any) ([]byte, error) { + e := newEncodeState(ctx) defer encodeStatePool.Put(e) err := e.marshal(v, encOpts{escapeHTML: true}) @@ -251,6 +256,7 @@ var hex = "0123456789abcdef" type encodeState struct { bytes.Buffer // accumulated output + ctx context.Context // Keep track of what pointers we've seen in the current recursive call // path, to avoid cycles that could lead to a stack overflow. Only do // the relatively expensive map operations if ptrLevel is larger than @@ -264,7 +270,7 @@ const startDetectingCyclesAfter = 1000 var encodeStatePool sync.Pool -func newEncodeState() *encodeState { +func newEncodeState(ctx context.Context) *encodeState { if v := encodeStatePool.Get(); v != nil { e := v.(*encodeState) e.Reset() @@ -274,7 +280,7 @@ func newEncodeState() *encodeState { e.ptrLevel = 0 return e } - return &encodeState{ptrSeen: make(map[any]struct{})} + return &encodeState{ctx: ctx, ptrSeen: make(map[any]struct{})} } // jsonError is an error wrapper type for internal use only. @@ -371,8 +377,9 @@ func typeEncoder(t reflect.Type) encoderFunc { } var ( - marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() - textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() + contextMarshalerType = reflect.TypeOf((*ContextMarshaler)(nil)).Elem() + textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() ) // newTypeEncoder constructs an encoderFunc for a type. @@ -385,9 +392,15 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) { return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false)) } + if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) { + return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false)) + } if t.Implements(marshalerType) { return marshalerEncoder } + if t.Implements(contextMarshalerType) { + return contextMarshalerEncoder + } if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) { return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false)) } @@ -470,6 +483,47 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { } } +func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { + if v.Kind() == reflect.Pointer && v.IsNil() { + e.WriteString("null") + return + } + m, ok := v.Interface().(ContextMarshaler) + if !ok { + e.WriteString("null") + return + } + b, err := m.MarshalJSONContext(e.ctx) + if err == nil { + e.Grow(len(b)) + out := availableBuffer(&e.Buffer) + out, err = appendCompact(out, b, opts.escapeHTML) + e.Buffer.Write(out) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err, "MarshalJSON"}) + } +} + +func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { + va := v.Addr() + if va.IsNil() { + e.WriteString("null") + return + } + m := va.Interface().(ContextMarshaler) + b, err := m.MarshalJSONContext(e.ctx) + if err == nil { + e.Grow(len(b)) + out := availableBuffer(&e.Buffer) + out, err = appendCompact(out, b, opts.escapeHTML) + e.Buffer.Write(out) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err, "MarshalJSON"}) + } +} + func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { if v.Kind() == reflect.Pointer && v.IsNil() { e.WriteString("null") @@ -827,7 +881,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc { // Byte slices get special treatment; arrays don't. if t.Elem().Kind() == reflect.Uint8 { p := reflect.PointerTo(t.Elem()) - if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) { + if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) { return encodeByteSlice } } diff --git a/common/json/internal/contextjson/stream.go b/common/json/internal/contextjson/stream.go index a670ab1..2849dbf 100644 --- a/common/json/internal/contextjson/stream.go +++ b/common/json/internal/contextjson/stream.go @@ -6,6 +6,7 @@ package json import ( "bytes" + "context" "errors" "io" ) @@ -29,7 +30,11 @@ type Decoder struct { // The decoder introduces its own buffering and may // read data from r beyond the JSON values requested. func NewDecoder(r io.Reader) *Decoder { - return &Decoder{r: r} + return NewDecoderContext(context.Background(), r) +} + +func NewDecoderContext(ctx context.Context, r io.Reader) *Decoder { + return &Decoder{r: r, d: decodeState{ctx: ctx}} } // UseNumber causes the Decoder to unmarshal a number into an interface{} as a @@ -183,6 +188,7 @@ func nonSpace(b []byte) bool { // An Encoder writes JSON values to an output stream. type Encoder struct { + ctx context.Context w io.Writer err error escapeHTML bool @@ -194,7 +200,11 @@ type Encoder struct { // NewEncoder returns a new encoder that writes to w. func NewEncoder(w io.Writer) *Encoder { - return &Encoder{w: w, escapeHTML: true} + return NewEncoderContext(context.Background(), w) +} + +func NewEncoderContext(ctx context.Context, w io.Writer) *Encoder { + return &Encoder{ctx: ctx, w: w, escapeHTML: true} } // Encode writes the JSON encoding of v to the stream, @@ -207,7 +217,7 @@ func (enc *Encoder) Encode(v any) error { return enc.err } - e := newEncodeState() + e := newEncodeState(enc.ctx) defer encodeStatePool.Put(e) err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML}) diff --git a/common/json/internal/contextjson/unmarshal.go b/common/json/internal/contextjson/unmarshal.go index 2940539..04c13cb 100644 --- a/common/json/internal/contextjson/unmarshal.go +++ b/common/json/internal/contextjson/unmarshal.go @@ -1,5 +1,7 @@ package json +import "context" + func UnmarshalDisallowUnknownFields(data []byte, v any) error { var d decodeState d.disallowUnknownFields = true @@ -10,3 +12,15 @@ func UnmarshalDisallowUnknownFields(data []byte, v any) error { d.init(data) return d.unmarshal(v) } + +func UnmarshalContextDisallowUnknownFields(ctx context.Context, data []byte, v any) error { + var d decodeState + d.ctx = ctx + d.disallowUnknownFields = true + err := checkValid(data, &d.scan) + if err != nil { + return err + } + d.init(data) + return d.unmarshal(v) +} diff --git a/common/json/unmarshal.go b/common/json/unmarshal.go index 7505ebc..94a2d76 100644 --- a/common/json/unmarshal.go +++ b/common/json/unmarshal.go @@ -2,6 +2,7 @@ package json import ( "bytes" + "context" "errors" "strings" @@ -10,7 +11,11 @@ import ( ) func UnmarshalExtended[T any](content []byte) (T, error) { - decoder := NewDecoder(NewCommentFilter(bytes.NewReader(content))) + return UnmarshalExtendedContext[T](context.Background(), content) +} + +func UnmarshalExtendedContext[T any](ctx context.Context, content []byte) (T, error) { + decoder := NewDecoderContext(ctx, NewCommentFilter(bytes.NewReader(content))) var value T err := decoder.Decode(&value) if err == nil { From b5f9e70ffd60ca1b5d770708875a13a0f2c62a0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 7 Nov 2024 21:16:56 +0800 Subject: [PATCH 07/55] badjson: Fix Listable --- common/json/badoption/listable.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/common/json/badoption/listable.go b/common/json/badoption/listable.go index a4e88e3..60a044c 100644 --- a/common/json/badoption/listable.go +++ b/common/json/badoption/listable.go @@ -1,6 +1,8 @@ package badoption import ( + "context" + E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" ) @@ -15,13 +17,13 @@ func (l Listable[T]) MarshalJSON() ([]byte, error) { return json.Marshal(arrayList) } -func (l *Listable[T]) UnmarshalJSON(content []byte) error { - err := json.UnmarshalDisallowUnknownFields(content, (*[]T)(l)) +func (l *Listable[T]) UnmarshalJSONContext(ctx context.Context, content []byte) error { + err := json.UnmarshalContextDisallowUnknownFields(ctx, content, (*[]T)(l)) if err == nil { return nil } var singleItem T - newError := json.UnmarshalDisallowUnknownFields(content, &singleItem) + newError := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem) if newError != nil { return E.Errors(err, newError) } From 524a6bd0d1c63737ab4353681fa3db200b863188 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 8 Nov 2024 10:22:04 +0800 Subject: [PATCH 08/55] udpnat2: Set upstream to writer --- common/udpnat2/conn.go | 4 ++++ common/udpnat2/service.go | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index a96f4c8..5d474e6 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -104,3 +104,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error { func (c *Conn) SetWriteDeadline(t time.Time) error { return os.ErrInvalid } + +func (c *Conn) Upstream() any { + return c.writer +} diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 8c8afc9..e2a1482 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -55,7 +55,6 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati if !loaded { ok, ctx, writer, onClose := s.prepare(source, destination, userData) if !ok { - println(2) s.metrics.Rejects++ return } From fcb19641e6af7753840fa21ff35bdeeb5f72de4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 9 Nov 2024 11:36:15 +0800 Subject: [PATCH 09/55] freelru: Copy shared source --- contrab/freelru/sharedlru.go | 315 +++++++++++++++++++++++++++++++++++ 1 file changed, 315 insertions(+) create mode 100644 contrab/freelru/sharedlru.go diff --git a/contrab/freelru/sharedlru.go b/contrab/freelru/sharedlru.go new file mode 100644 index 0000000..d3ba09c --- /dev/null +++ b/contrab/freelru/sharedlru.go @@ -0,0 +1,315 @@ +package freelru + +import ( + "errors" + "fmt" + "math/bits" + "runtime" + "sync" + "time" +) + +// ShardedLRU is a thread-safe, sharded, fixed size LRU cache. +// Sharding is used to reduce lock contention on high concurrency. +// The downside is that exact LRU behavior is not given (as for the LRU and SynchedLRU types). +type ShardedLRU[K comparable, V any] struct { + lrus []LRU[K, V] + mus []sync.RWMutex + hash HashKeyCallback[K] + shards uint32 + mask uint32 +} + +var _ Cache[int, int] = (*ShardedLRU[int, int])(nil) + +// SetLifetime sets the default lifetime of LRU elements. +// Lifetime 0 means "forever". +func (lru *ShardedLRU[K, V]) SetLifetime(lifetime time.Duration) { + for shard := range lru.lrus { + lru.mus[shard].Lock() + lru.lrus[shard].SetLifetime(lifetime) + lru.mus[shard].Unlock() + } +} + +// SetOnEvict sets the OnEvict callback function. +// The onEvict function is called for each evicted lru entry. +func (lru *ShardedLRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) { + for shard := range lru.lrus { + lru.mus[shard].Lock() + lru.lrus[shard].SetOnEvict(onEvict) + lru.mus[shard].Unlock() + } +} + +func nextPowerOfTwo(val uint32) uint32 { + if bits.OnesCount32(val) != 1 { + return 1 << bits.Len32(val) + } + return val +} + +// NewSharded creates a new thread-safe sharded LRU hashmap with the given capacity. +func NewSharded[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*ShardedLRU[K, V], + error) { + size := uint32(float64(capacity) * 1.25) // 25% extra space for fewer collisions + + return NewShardedWithSize[K, V](uint32(runtime.GOMAXPROCS(0)*16), capacity, size, hash) +} + +func NewShardedWithSize[K comparable, V any](shards, capacity, size uint32, + hash HashKeyCallback[K]) ( + *ShardedLRU[K, V], error) { + if capacity == 0 { + return nil, errors.New("capacity must be positive") + } + if size < capacity { + return nil, fmt.Errorf("size (%d) is smaller than capacity (%d)", size, capacity) + } + + if size < 1<<31 { + size = nextPowerOfTwo(size) // next power of 2 so the LRUs can avoid costly divisions + } else { + size = 1 << 31 // the highest 2^N value that fits in a uint32 + } + + shards = nextPowerOfTwo(shards) // next power of 2 so we can avoid costly division for sharding + + for shards > size/16 { + shards /= 16 + } + if shards == 0 { + shards = 1 + } + + size /= shards // size per LRU + if size == 0 { + size = 1 + } + + capacity = (capacity + shards - 1) / shards // size per LRU + if capacity == 0 { + capacity = 1 + } + + lrus := make([]LRU[K, V], shards) + buckets := make([]uint32, size*shards) + elements := make([]element[K, V], size*shards) + + from := 0 + to := int(size) + for i := range lrus { + initLRU(&lrus[i], capacity, size, hash, buckets[from:to], elements[from:to]) + from = to + to += int(size) + } + + return &ShardedLRU[K, V]{ + lrus: lrus, + mus: make([]sync.RWMutex, shards), + hash: hash, + shards: shards, + mask: shards - 1, + }, nil +} + +// Len returns the number of elements stored in the cache. +func (lru *ShardedLRU[K, V]) Len() (length int) { + for shard := range lru.lrus { + lru.mus[shard].RLock() + length += lru.lrus[shard].Len() + lru.mus[shard].RUnlock() + } + return +} + +// AddWithLifetime adds a key:value to the cache with a lifetime. +// Returns true, true if key was updated and eviction occurred. +func (lru *ShardedLRU[K, V]) AddWithLifetime(key K, value V, + lifetime time.Duration) (evicted bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + evicted = lru.lrus[shard].addWithLifetime(hash, key, value, lifetime) + lru.mus[shard].Unlock() + + return +} + +// Add adds a key:value to the cache. +// Returns true, true if key was updated and eviction occurred. +func (lru *ShardedLRU[K, V]) Add(key K, value V) (evicted bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + evicted = lru.lrus[shard].add(hash, key, value) + lru.mus[shard].Unlock() + + return +} + +// Get returns the value associated with the key, setting it as the most +// recently used item. +// If the found cache item is already expired, the evict function is called +// and the return value indicates that the key was not found. +func (lru *ShardedLRU[K, V]) Get(key K) (value V, ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + value, ok = lru.lrus[shard].get(hash, key) + lru.mus[shard].Unlock() + + return +} + +// Peek looks up a key's value from the cache, without changing its recent-ness. +// If the found entry is already expired, the evict function is called. +func (lru *ShardedLRU[K, V]) Peek(key K) (value V, ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + value, ok = lru.lrus[shard].peek(hash, key) + lru.mus[shard].Unlock() + + return +} + +// Contains checks for the existence of a key, without changing its recent-ness. +// If the found entry is already expired, the evict function is called. +func (lru *ShardedLRU[K, V]) Contains(key K) (ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + ok = lru.lrus[shard].contains(hash, key) + lru.mus[shard].Unlock() + + return +} + +// Remove removes the key from the cache. +// The return value indicates whether the key existed or not. +// The evict function is called for the removed entry. +func (lru *ShardedLRU[K, V]) Remove(key K) (removed bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + removed = lru.lrus[shard].remove(hash, key) + lru.mus[shard].Unlock() + + return +} + +// RemoveOldest removes the oldest entry from the cache. +// Key, value and an indicator of whether the entry has been removed is returned. +// The evict function is called for the removed entry. +func (lru *ShardedLRU[K, V]) RemoveOldest() (key K, value V, removed bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + key, value, removed = lru.lrus[shard].RemoveOldest() + lru.mus[shard].Unlock() + + return +} + +// Keys returns a slice of the keys in the cache, from oldest to newest. +// Expired entries are not included. +// The evict function is called for each expired item. +func (lru *ShardedLRU[K, V]) Keys() []K { + keys := make([]K, 0, lru.shards*lru.lrus[0].cap) + for shard := range lru.lrus { + lru.mus[shard].Lock() + keys = append(keys, lru.lrus[shard].Keys()...) + lru.mus[shard].Unlock() + } + + return keys +} + +// Purge purges all data (key and value) from the LRU. +// The evict function is called for each expired item. +// The LRU metrics are reset. +func (lru *ShardedLRU[K, V]) Purge() { + for shard := range lru.lrus { + lru.mus[shard].Lock() + lru.lrus[shard].Purge() + lru.mus[shard].Unlock() + } +} + +// PurgeExpired purges all expired items from the LRU. +// The evict function is called for each expired item. +func (lru *ShardedLRU[K, V]) PurgeExpired() { + for shard := range lru.lrus { + lru.mus[shard].Lock() + lru.lrus[shard].PurgeExpired() + lru.mus[shard].Unlock() + } +} + +// Metrics returns the metrics of the cache. +func (lru *ShardedLRU[K, V]) Metrics() Metrics { + metrics := Metrics{} + + for shard := range lru.lrus { + lru.mus[shard].Lock() + m := lru.lrus[shard].Metrics() + lru.mus[shard].Unlock() + + addMetrics(&metrics, m) + } + + return metrics +} + +// ResetMetrics resets the metrics of the cache and returns the previous state. +func (lru *ShardedLRU[K, V]) ResetMetrics() Metrics { + metrics := Metrics{} + + for shard := range lru.lrus { + lru.mus[shard].Lock() + m := lru.lrus[shard].ResetMetrics() + lru.mus[shard].Unlock() + + addMetrics(&metrics, m) + } + + return metrics +} + +func addMetrics(dst *Metrics, src Metrics) { + dst.Inserts += src.Inserts + dst.Collisions += src.Collisions + dst.Evictions += src.Evictions + dst.Removals += src.Removals + dst.Hits += src.Hits + dst.Misses += src.Misses +} + +// just used for debugging +func (lru *ShardedLRU[K, V]) dump() { + for shard := range lru.lrus { + fmt.Printf("Shard %d:\n", shard) + lru.mus[shard].RLock() + lru.lrus[shard].dump() + lru.mus[shard].RUnlock() + fmt.Println("") + } +} + +func (lru *ShardedLRU[K, V]) PrintStats() { + for shard := range lru.lrus { + fmt.Printf("Shard %d:\n", shard) + lru.mus[shard].RLock() + lru.lrus[shard].PrintStats() + lru.mus[shard].RUnlock() + fmt.Println("") + } +} From 11ffb962aec1ea57c119c8d4614ef09c43efb414 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 9 Nov 2024 11:36:22 +0800 Subject: [PATCH 10/55] freelru: Fix impl --- contrab/freelru/lru.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index 045cc3e..5a1ceb6 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -439,11 +439,11 @@ func (lru *LRU[K, V]) add(hash uint32, key K, value V) (evicted bool) { // If the found cache item is already expired, the evict function is called // and the return value indicates that the key was not found. func (lru *LRU[K, V]) Get(key K) (value V, ok bool) { - return lru.get(lru.hash(key), key, true) + return lru.get(lru.hash(key), key) } -func (lru *LRU[K, V]) get(hash uint32, key K, updateLifetime bool) (value V, ok bool) { - if pos, ok := lru.findKey(hash, key, updateLifetime); ok { +func (lru *LRU[K, V]) get(hash uint32, key K) (value V, ok bool) { + if pos, ok := lru.findKey(hash, key, true); ok { if pos != lru.head { lru.unlinkElement(pos) lru.setHead(pos) From 72ff654ee06cd8e8b98e8eedc316e1238efeea22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 9 Nov 2024 11:38:32 +0800 Subject: [PATCH 11/55] shared: Add SetHealthCheck to interface --- contrab/freelru/cache.go | 2 ++ contrab/freelru/sharedlru.go | 17 ++++++++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/contrab/freelru/cache.go b/contrab/freelru/cache.go index 59435ee..7363383 100644 --- a/contrab/freelru/cache.go +++ b/contrab/freelru/cache.go @@ -24,6 +24,8 @@ type Cache[K comparable, V any] interface { // Lifetime 0 means "forever". SetLifetime(lifetime time.Duration) + SetHealthCheck(healthCheck HealthCheckCallback[K, V]) + // SetOnEvict sets the OnEvict callback function. // The onEvict function is called for each evicted lru entry. SetOnEvict(onEvict OnEvictCallback[K, V]) diff --git a/contrab/freelru/sharedlru.go b/contrab/freelru/sharedlru.go index d3ba09c..86cd54e 100644 --- a/contrab/freelru/sharedlru.go +++ b/contrab/freelru/sharedlru.go @@ -42,6 +42,14 @@ func (lru *ShardedLRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) { } } +func (lru *ShardedLRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) { + for shard := range lru.lrus { + lru.mus[shard].Lock() + lru.lrus[shard].SetHealthCheck(healthCheck) + lru.mus[shard].Unlock() + } +} + func nextPowerOfTwo(val uint32) uint32 { if bits.OnesCount32(val) != 1 { return 1 << bits.Len32(val) @@ -51,7 +59,8 @@ func nextPowerOfTwo(val uint32) uint32 { // NewSharded creates a new thread-safe sharded LRU hashmap with the given capacity. func NewSharded[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*ShardedLRU[K, V], - error) { + error, +) { size := uint32(float64(capacity) * 1.25) // 25% extra space for fewer collisions return NewShardedWithSize[K, V](uint32(runtime.GOMAXPROCS(0)*16), capacity, size, hash) @@ -59,7 +68,8 @@ func NewSharded[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) ( func NewShardedWithSize[K comparable, V any](shards, capacity, size uint32, hash HashKeyCallback[K]) ( - *ShardedLRU[K, V], error) { + *ShardedLRU[K, V], error, +) { if capacity == 0 { return nil, errors.New("capacity must be positive") } @@ -126,7 +136,8 @@ func (lru *ShardedLRU[K, V]) Len() (length int) { // AddWithLifetime adds a key:value to the cache with a lifetime. // Returns true, true if key was updated and eviction occurred. func (lru *ShardedLRU[K, V]) AddWithLifetime(key K, value V, - lifetime time.Duration) (evicted bool) { + lifetime time.Duration, +) (evicted bool) { hash := lru.hash(key) shard := (hash >> 16) & lru.mask From 099899991126ef393fd8ba7e972fe9975181420b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 9 Nov 2024 11:39:59 +0800 Subject: [PATCH 12/55] udpnat2: Fix missing shared impl --- common/udpnat2/service.go | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index e2a1482..68df666 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -14,7 +14,7 @@ import ( ) type Service struct { - nat *freelru.LRU[netip.AddrPort, *Conn] + cache freelru.Cache[netip.AddrPort, *Conn] handler N.UDPConnectionHandlerEx prepare PrepareFunc metrics Metrics @@ -29,10 +29,15 @@ type Metrics struct { Drops uint64 } -func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration) *Service { - nat := common.Must1(freelru.New[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) - nat.SetLifetime(timeout) - nat.SetHealthCheck(func(port netip.AddrPort, conn *Conn) bool { +func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service { + var cache freelru.Cache[netip.AddrPort, *Conn] + if !shared { + cache = common.Must1(freelru.New[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + } else { + cache = common.Must1(freelru.NewSharded[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + } + cache.SetLifetime(timeout) + cache.SetHealthCheck(func(port netip.AddrPort, conn *Conn) bool { select { case <-conn.doneChan: return false @@ -40,18 +45,18 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur return true } }) - nat.SetOnEvict(func(_ netip.AddrPort, conn *Conn) { + cache.SetOnEvict(func(_ netip.AddrPort, conn *Conn) { conn.Close() }) return &Service{ - nat: nat, + cache: cache, handler: handler, prepare: prepare, } } func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) { - conn, loaded := s.nat.Get(source.AddrPort()) + conn, loaded := s.cache.Get(source.AddrPort()) if !loaded { ok, ctx, writer, onClose := s.prepare(source, destination, userData) if !ok { @@ -65,7 +70,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati doneChan: make(chan struct{}), readDeadline: pipe.MakeDeadline(), } - s.nat.Add(source.AddrPort(), conn) + s.cache.Add(source.AddrPort(), conn) go s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose) s.metrics.Creates++ } From cc7e6309230fe6909280476318ebcf47616efa04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 11 Nov 2024 00:06:25 +0800 Subject: [PATCH 13/55] control: Refactor interface finder --- common/cond.go | 41 +++++++++--------- common/control/bind_darwin.go | 4 +- common/control/bind_finder.go | 44 +++++++++++++++++-- common/control/bind_finder_default.go | 62 +++++++++++---------------- common/control/bind_linux.go | 4 +- common/control/bind_windows.go | 6 +-- 6 files changed, 91 insertions(+), 70 deletions(-) diff --git a/common/cond.go b/common/cond.go index 6fe11bc..5558715 100644 --- a/common/cond.go +++ b/common/cond.go @@ -157,6 +157,18 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int { return -1 } +func Equal[S ~[]E, E comparable](s1, s2 S) bool { + if len(s1) != len(s2) { + return false + } + for i := range s1 { + if s1[i] != s2[i] { + return false + } + } + return true +} + //go:norace func Dup[T any](obj T) T { pointer := uintptr(unsafe.Pointer(&obj)) @@ -268,6 +280,14 @@ func Reverse[T any](arr []T) []T { return arr } +func ReverseMap[K comparable, V comparable](m map[K]V) map[V]K { + ret := make(map[V]K, len(m)) + for k, v := range m { + ret[v] = k + } + return ret +} + func Done(ctx context.Context) bool { select { case <-ctx.Done(): @@ -362,24 +382,3 @@ func Close(closers ...any) error { } return retErr } - -// Deprecated: wtf is this? -type Starter interface { - Start() error -} - -// Deprecated: wtf is this? -func Start(starters ...any) error { - for _, rawStarter := range starters { - if rawStarter == nil { - continue - } - if starter, isStarter := rawStarter.(Starter); isStarter { - err := starter.Start() - if err != nil { - return err - } - } - } - return nil -} diff --git a/common/control/bind_darwin.go b/common/control/bind_darwin.go index bff6c29..2fb3db9 100644 --- a/common/control/bind_darwin.go +++ b/common/control/bind_darwin.go @@ -9,15 +9,15 @@ import ( func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error { return Raw(conn, func(fd uintptr) error { - var err error if interfaceIndex == -1 { if finder == nil { return os.ErrInvalid } - interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + iif, err := finder.ByName(interfaceName) if err != nil { return err } + interfaceIndex = iif.Index } switch network { case "tcp6", "udp6": diff --git a/common/control/bind_finder.go b/common/control/bind_finder.go index 9b013d3..f956f19 100644 --- a/common/control/bind_finder.go +++ b/common/control/bind_finder.go @@ -3,21 +3,57 @@ package control import ( "net" "net/netip" + "unsafe" + + "github.com/sagernet/sing/common" + M "github.com/sagernet/sing/common/metadata" ) type InterfaceFinder interface { Update() error Interfaces() []Interface - InterfaceIndexByName(name string) (int, error) - InterfaceNameByIndex(index int) (string, error) - InterfaceByAddr(addr netip.Addr) (*Interface, error) + ByName(name string) (*Interface, error) + ByIndex(index int) (*Interface, error) + ByAddr(addr netip.Addr) (*Interface, error) } type Interface struct { Index int MTU int Name string - Addresses []netip.Prefix HardwareAddr net.HardwareAddr Flags net.Flags + Addresses []netip.Prefix +} + +func (i Interface) Equals(other Interface) bool { + return i.Index == other.Index && + i.MTU == other.MTU && + i.Name == other.Name && + common.Equal(i.HardwareAddr, other.HardwareAddr) && + i.Flags == other.Flags && + common.Equal(i.Addresses, other.Addresses) +} + +func (i Interface) NetInterface() net.Interface { + return *(*net.Interface)(unsafe.Pointer(&i)) +} + +func InterfaceFromNet(iif net.Interface) (Interface, error) { + ifAddrs, err := iif.Addrs() + if err != nil { + return Interface{}, err + } + return InterfaceFromNetAddrs(iif, common.Map(ifAddrs, M.PrefixFromNet)), nil +} + +func InterfaceFromNetAddrs(iif net.Interface, addresses []netip.Prefix) Interface { + return Interface{ + Index: iif.Index, + MTU: iif.MTU, + Name: iif.Name, + HardwareAddr: iif.HardwareAddr, + Flags: iif.Flags, + Addresses: addresses, + } } diff --git a/common/control/bind_finder_default.go b/common/control/bind_finder_default.go index 804497b..cfc481e 100644 --- a/common/control/bind_finder_default.go +++ b/common/control/bind_finder_default.go @@ -3,11 +3,8 @@ package control import ( "net" "net/netip" - _ "unsafe" - "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" ) var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil) @@ -27,18 +24,12 @@ func (f *DefaultInterfaceFinder) Update() error { } interfaces := make([]Interface, 0, len(netIfs)) for _, netIf := range netIfs { - ifAddrs, err := netIf.Addrs() + var iif Interface + iif, err = InterfaceFromNet(netIf) if err != nil { return err } - interfaces = append(interfaces, Interface{ - Index: netIf.Index, - MTU: netIf.MTU, - Name: netIf.Name, - Addresses: common.Map(ifAddrs, M.PrefixFromNet), - HardwareAddr: netIf.HardwareAddr, - Flags: netIf.Flags, - }) + interfaces = append(interfaces, iif) } f.interfaces = interfaces return nil @@ -52,46 +43,41 @@ func (f *DefaultInterfaceFinder) Interfaces() []Interface { return f.interfaces } -func (f *DefaultInterfaceFinder) InterfaceIndexByName(name string) (int, error) { +func (f *DefaultInterfaceFinder) ByName(name string) (*Interface, error) { for _, netInterface := range f.interfaces { if netInterface.Name == name { - return netInterface.Index, nil + return &netInterface, nil } } - netInterface, err := net.InterfaceByName(name) - if err != nil { - return 0, err + _, err := net.InterfaceByName(name) + if err == nil { + err = f.Update() + if err != nil { + return nil, err + } + return f.ByName(name) } - f.Update() - return netInterface.Index, nil + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")} } -func (f *DefaultInterfaceFinder) InterfaceNameByIndex(index int) (string, error) { +func (f *DefaultInterfaceFinder) ByIndex(index int) (*Interface, error) { for _, netInterface := range f.interfaces { if netInterface.Index == index { - return netInterface.Name, nil + return &netInterface, nil } } - netInterface, err := net.InterfaceByIndex(index) - if err != nil { - return "", err + _, err := net.InterfaceByIndex(index) + if err == nil { + err = f.Update() + if err != nil { + return nil, err + } + return f.ByIndex(index) } - f.Update() - return netInterface.Name, nil + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")} } -func (f *DefaultInterfaceFinder) InterfaceByAddr(addr netip.Addr) (*Interface, error) { - for _, netInterface := range f.interfaces { - for _, prefix := range netInterface.Addresses { - if prefix.Contains(addr) { - return &netInterface, nil - } - } - } - err := f.Update() - if err != nil { - return nil, err - } +func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) { for _, netInterface := range f.interfaces { for _, prefix := range netInterface.Addresses { if prefix.Contains(addr) { diff --git a/common/control/bind_linux.go b/common/control/bind_linux.go index c92bf6b..c5e668d 100644 --- a/common/control/bind_linux.go +++ b/common/control/bind_linux.go @@ -19,11 +19,11 @@ func bindToInterface(conn syscall.RawConn, network string, address string, finde if interfaceName == "" { return os.ErrInvalid } - var err error - interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + iif, err := finder.ByName(interfaceName) if err != nil { return err } + interfaceIndex = iif.Index } err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex) if err == nil { diff --git a/common/control/bind_windows.go b/common/control/bind_windows.go index a499556..cf83386 100644 --- a/common/control/bind_windows.go +++ b/common/control/bind_windows.go @@ -11,19 +11,19 @@ import ( func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error { return Raw(conn, func(fd uintptr) error { - var err error if interfaceIndex == -1 { if finder == nil { return os.ErrInvalid } - interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + iif, err := finder.ByName(interfaceName) if err != nil { return err } + interfaceIndex = iif.Index } handle := syscall.Handle(fd) if M.ParseSocksaddr(address).AddrString() == "" { - err = bind4(handle, interfaceIndex) + err := bind4(handle, interfaceIndex) if err != nil { return err } From c432befd02cbabcb0b94e98d2d934c385c67c16b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 13 Nov 2024 13:49:56 +0800 Subject: [PATCH 14/55] http: Fix proxying websocket --- protocol/http/handshake.go | 62 +++++++++++++++++++++++++++++++------- 1 file changed, 51 insertions(+), 11 deletions(-) diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 955a722..9528573 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -4,6 +4,7 @@ import ( std_bufio "bufio" "context" "encoding/base64" + "io" "net" "net/http" "strings" @@ -37,7 +38,6 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re if err != nil { return E.Cause(err, "read http request") } - if authenticator != nil { var ( username string @@ -81,11 +81,15 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re } if request.Method == "CONNECT" { - portStr := request.URL.Port() - if portStr == "" { - portStr = "80" + destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port()) + if destination.Port == 0 { + switch request.URL.Scheme { + case "https", "wss": + destination.Port = 443 + default: + destination.Port = 80 + } } - destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), portStr) _, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n"))) if err != nil { return E.Cause(err, "write http response") @@ -108,11 +112,48 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re handlerEx.NewConnectionEx(ctx, requestConn, source, destination, onClose) return nil } - } - - err = handleHTTPConnection(ctx, handler, handlerEx, conn, request, source) - if err != nil { - return err + } else if strings.ToLower(request.Header.Get("Connection")) == "upgrade" { + destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port()) + if destination.Port == 0 { + switch request.URL.Scheme { + case "https", "wss": + destination.Port = 443 + default: + destination.Port = 80 + } + } + serverConn, clientConn := pipe.Pipe() + go func() { + if handler != nil { + //nolint:staticcheck + err := handler.NewConnection(ctx, clientConn, M.Metadata{Protocol: "http", Source: source, Destination: destination}) + if err != nil { + common.Close(serverConn, clientConn) + } + } else { + handlerEx.NewConnectionEx(ctx, clientConn, source, destination, func(it error) { + if it != nil { + common.Close(serverConn, clientConn) + } + }) + } + }() + err = request.Write(serverConn) + if err != nil { + return E.Cause(err, "http: write upgrade request") + } + if reader.Buffered() > 0 { + _, err = io.CopyN(serverConn, reader, int64(reader.Buffered())) + if err != nil { + return err + } + } + return bufio.CopyConn(ctx, conn, serverConn) + } else { + err = handleHTTPConnection(ctx, handler, handlerEx, conn, request, source) + if err != nil { + return err + } } } } @@ -198,7 +239,6 @@ func handleHTTPConnection( if !keepAlive { return conn.Close() } - return nil } From ae139d9ee1a0d6540f4c2a0108a55ea88c205e7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 14 Nov 2024 17:43:18 +0800 Subject: [PATCH 15/55] Update N.PayloadDialer --- common/network/dialer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/network/dialer.go b/common/network/dialer.go index d32cda9..7c3f3e2 100644 --- a/common/network/dialer.go +++ b/common/network/dialer.go @@ -14,7 +14,7 @@ type Dialer interface { } type PayloadDialer interface { - DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) + DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr, payload [][]byte) (net.Conn, error) } type ParallelDialer interface { From 7f621fdd781cb50645aca876e319e53bd9c95756 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 14 Nov 2024 12:29:09 +0800 Subject: [PATCH 16/55] Add freelru.SetUpdateLifetimeOnGet/GetWithLifetime --- common/udpnat2/service.go | 1 + contrab/freelru/cache.go | 4 +++ contrab/freelru/lru.go | 48 ++++++++++++++++++++++-------------- contrab/freelru/lru_test.go | 7 ++++-- contrab/freelru/sharedlru.go | 21 +++++++++++++++- 5 files changed, 60 insertions(+), 21 deletions(-) diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 68df666..e06cebc 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -37,6 +37,7 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur cache = common.Must1(freelru.NewSharded[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) } cache.SetLifetime(timeout) + cache.SetUpdateLifetimeOnGet(true) cache.SetHealthCheck(func(port netip.AddrPort, conn *Conn) bool { select { case <-conn.doneChan: diff --git a/contrab/freelru/cache.go b/contrab/freelru/cache.go index 7363383..707e5bc 100644 --- a/contrab/freelru/cache.go +++ b/contrab/freelru/cache.go @@ -24,6 +24,8 @@ type Cache[K comparable, V any] interface { // Lifetime 0 means "forever". SetLifetime(lifetime time.Duration) + SetUpdateLifetimeOnGet(update bool) + SetHealthCheck(healthCheck HealthCheckCallback[K, V]) // SetOnEvict sets the OnEvict callback function. @@ -47,6 +49,8 @@ type Cache[K comparable, V any] interface { // and the return value indicates that the key was not found. Get(key K) (V, bool) + GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) + // Peek looks up a key's value from the cache, without changing its recent-ness. // If the found entry is already expired, the evict function is called. Peek(key K) (V, bool) diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index 5a1ceb6..7672460 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -63,13 +63,14 @@ const emptyBucket = math.MaxUint32 // LRU implements a non-thread safe fixed size LRU cache. type LRU[K comparable, V any] struct { - buckets []uint32 // contains positions of bucket lists or 'emptyBucket' - elements []element[K, V] - onEvict OnEvictCallback[K, V] - hash HashKeyCallback[K] - healthCheck HealthCheckCallback[K, V] - lifetime time.Duration - metrics Metrics + buckets []uint32 // contains positions of bucket lists or 'emptyBucket' + elements []element[K, V] + onEvict OnEvictCallback[K, V] + hash HashKeyCallback[K] + healthCheck HealthCheckCallback[K, V] + lifetime time.Duration + updateLifetimeOnGet bool + metrics Metrics // used for element clearing after removal or expiration emptyKey K @@ -100,6 +101,10 @@ func (lru *LRU[K, V]) SetLifetime(lifetime time.Duration) { lru.lifetime = lifetime } +func (lru *LRU[K, V]) SetUpdateLifetimeOnGet(update bool) { + lru.updateLifetimeOnGet = update +} + // SetOnEvict sets the OnEvict callback function. // The onEvict function is called for each evicted lru entry. // Eviction happens @@ -303,10 +308,10 @@ func (lru *LRU[K, V]) clearKeyAndValue(pos uint32) { lru.elements[pos].value = lru.emptyValue } -func (lru *LRU[K, V]) findKey(hash uint32, key K, updateLifetimeOnGet bool) (uint32, bool) { +func (lru *LRU[K, V]) findKey(hash uint32, key K, updateLifetimeOnGet bool) (uint32, int64, bool) { _, startPos := lru.hashToPos(hash) if startPos == emptyBucket { - return emptyBucket, false + return emptyBucket, 0, false } pos := startPos @@ -315,18 +320,18 @@ func (lru *LRU[K, V]) findKey(hash uint32, key K, updateLifetimeOnGet bool) (uin elem := lru.elements[pos] if (elem.expire != 0 && elem.expire <= now()) || (lru.healthCheck != nil && !lru.healthCheck(key, elem.value)) { lru.removeAt(pos) - return emptyBucket, false + return emptyBucket, elem.expire, false } if updateLifetimeOnGet { lru.elements[pos].expire = expire(lru.lifetime) } - return pos, true + return pos, elem.expire, true } pos = lru.elements[pos].nextBucket if pos == startPos { // Key not found - return emptyBucket, false + return emptyBucket, 0, false } } } @@ -439,17 +444,24 @@ func (lru *LRU[K, V]) add(hash uint32, key K, value V) (evicted bool) { // If the found cache item is already expired, the evict function is called // and the return value indicates that the key was not found. func (lru *LRU[K, V]) Get(key K) (value V, ok bool) { - return lru.get(lru.hash(key), key) + value, _, ok = lru.get(lru.hash(key), key) + return } -func (lru *LRU[K, V]) get(hash uint32, key K) (value V, ok bool) { - if pos, ok := lru.findKey(hash, key, true); ok { +func (lru *LRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) { + value, expireMills, ok := lru.get(lru.hash(key), key) + lifetime = time.UnixMilli(expireMills) + return +} + +func (lru *LRU[K, V]) get(hash uint32, key K) (value V, expire int64, ok bool) { + if pos, expire, ok := lru.findKey(hash, key, lru.updateLifetimeOnGet); ok { if pos != lru.head { lru.unlinkElement(pos) lru.setHead(pos) } lru.metrics.Hits++ - return lru.elements[pos].value, ok + return lru.elements[pos].value, expire, ok } lru.metrics.Misses++ @@ -463,7 +475,7 @@ func (lru *LRU[K, V]) Peek(key K) (value V, ok bool) { } func (lru *LRU[K, V]) peek(hash uint32, key K) (value V, ok bool) { - if pos, ok := lru.findKey(hash, key, false); ok { + if pos, _, ok := lru.findKey(hash, key, false); ok { return lru.elements[pos].value, ok } @@ -490,7 +502,7 @@ func (lru *LRU[K, V]) Remove(key K) (removed bool) { } func (lru *LRU[K, V]) remove(hash uint32, key K) (removed bool) { - if pos, ok := lru.findKey(hash, key, false); ok { + if pos, _, ok := lru.findKey(hash, key, false); ok { lru.removeAt(pos) return ok } diff --git a/contrab/freelru/lru_test.go b/contrab/freelru/lru_test.go index 4c4ba59..fa3c157 100644 --- a/contrab/freelru/lru_test.go +++ b/contrab/freelru/lru_test.go @@ -14,18 +14,21 @@ func TestMyChange0(t *testing.T) { t.Parallel() lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) require.NoError(t, err) + lru.SetUpdateLifetimeOnGet(true) lru.AddWithLifetime("hello", "world", 2*time.Second) time.Sleep(time.Second) - lru.Get("hello") - time.Sleep(time.Second + time.Millisecond*100) _, ok := lru.Get("hello") require.True(t, ok) + time.Sleep(time.Second + time.Millisecond*100) + _, ok = lru.Get("hello") + require.True(t, ok) } func TestMyChange1(t *testing.T) { t.Parallel() lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) require.NoError(t, err) + lru.SetUpdateLifetimeOnGet(true) lru.AddWithLifetime("hello", "world", 2*time.Second) time.Sleep(time.Second) lru.Peek("hello") diff --git a/contrab/freelru/sharedlru.go b/contrab/freelru/sharedlru.go index 86cd54e..db1d8cd 100644 --- a/contrab/freelru/sharedlru.go +++ b/contrab/freelru/sharedlru.go @@ -32,6 +32,14 @@ func (lru *ShardedLRU[K, V]) SetLifetime(lifetime time.Duration) { } } +func (lru *ShardedLRU[K, V]) SetUpdateLifetimeOnGet(update bool) { + for shard := range lru.lrus { + lru.mus[shard].Lock() + lru.lrus[shard].SetUpdateLifetimeOnGet(update) + lru.mus[shard].Unlock() + } +} + // SetOnEvict sets the OnEvict callback function. // The onEvict function is called for each evicted lru entry. func (lru *ShardedLRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) { @@ -170,12 +178,23 @@ func (lru *ShardedLRU[K, V]) Get(key K) (value V, ok bool) { shard := (hash >> 16) & lru.mask lru.mus[shard].Lock() - value, ok = lru.lrus[shard].get(hash, key) + value, _, ok = lru.lrus[shard].get(hash, key) lru.mus[shard].Unlock() return } +func (lru *ShardedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + value, expireMills, ok := lru.lrus[shard].get(hash, key) + lru.mus[shard].Unlock() + lifetime = time.UnixMilli(expireMills) + return +} + // Peek looks up a key's value from the cache, without changing its recent-ness. // If the found entry is already expired, the evict function is called. func (lru *ShardedLRU[K, V]) Peek(key K) (value V, ok bool) { From e52e04f7217bfbc3b1792df0b50184c5ade03986 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 14 Nov 2024 19:45:41 +0800 Subject: [PATCH 17/55] Fix HandshakeFailure usages --- common/network/dialer.go | 4 ---- common/network/handshake.go | 8 +++++++- protocol/socks/lazy.go | 5 +++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/common/network/dialer.go b/common/network/dialer.go index 7c3f3e2..0dbebf4 100644 --- a/common/network/dialer.go +++ b/common/network/dialer.go @@ -13,10 +13,6 @@ type Dialer interface { ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) } -type PayloadDialer interface { - DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr, payload [][]byte) (net.Conn, error) -} - type ParallelDialer interface { Dialer DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) diff --git a/common/network/handshake.go b/common/network/handshake.go index 5f13492..d2203e0 100644 --- a/common/network/handshake.go +++ b/common/network/handshake.go @@ -36,7 +36,13 @@ func ReportHandshakeFailure(reporter any, err error) error { func CloseOnHandshakeFailure(reporter any, onClose CloseHandlerFunc, err error) error { if err != nil { if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn { - err = E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error { + hErr := handshakeConn.HandshakeFailure(err) + err = E.Append(err, hErr, func(err error) error { + if closer, isCloser := reporter.(io.Closer); isCloser { + err = E.Append(err, closer.Close(), func(err error) error { + return E.Cause(err, "close") + }) + } return E.Cause(err, "write handshake failure") }) } else { diff --git a/protocol/socks/lazy.go b/protocol/socks/lazy.go index 3468981..f98ac3d 100644 --- a/protocol/socks/lazy.go +++ b/protocol/socks/lazy.go @@ -2,6 +2,7 @@ package socks import ( "net" + "os" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" @@ -48,7 +49,7 @@ func (c *LazyConn) ConnHandshakeSuccess(conn net.Conn) error { func (c *LazyConn) HandshakeFailure(err error) error { if c.responseWritten { - return nil + return os.ErrInvalid } defer func() { c.responseWritten = true @@ -130,7 +131,7 @@ func (c *LazyAssociatePacketConn) HandshakeSuccess() error { func (c *LazyAssociatePacketConn) HandshakeFailure(err error) error { if c.responseWritten { - return nil + return os.ErrInvalid } defer func() { c.responseWritten = true From fdca9b3f8e355dcf0caa27a6207109406e7be4bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 16 Nov 2024 12:16:33 +0800 Subject: [PATCH 18/55] badjson: Fix Listable --- common/json/badjson/merge.go | 2 +- common/json/badoption/listable.go | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/common/json/badjson/merge.go b/common/json/badjson/merge.go index ac1f12f..30bc74d 100644 --- a/common/json/badjson/merge.go +++ b/common/json/badjson/merge.go @@ -11,7 +11,7 @@ import ( ) func Omitempty[T any](ctx context.Context, value T) (T, error) { - objectContent, err := json.Marshal(value) + objectContent, err := json.MarshalContext(ctx, value) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal object") } diff --git a/common/json/badoption/listable.go b/common/json/badoption/listable.go index 60a044c..df02217 100644 --- a/common/json/badoption/listable.go +++ b/common/json/badoption/listable.go @@ -9,24 +9,24 @@ import ( type Listable[T any] []T -func (l Listable[T]) MarshalJSON() ([]byte, error) { +func (l Listable[T]) MarshalJSONContext(ctx context.Context) ([]byte, error) { arrayList := []T(l) if len(arrayList) == 1 { return json.Marshal(arrayList[0]) } - return json.Marshal(arrayList) + return json.MarshalContext(ctx, arrayList) } func (l *Listable[T]) UnmarshalJSONContext(ctx context.Context, content []byte) error { - err := json.UnmarshalContextDisallowUnknownFields(ctx, content, (*[]T)(l)) + var singleItem T + err := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem) if err == nil { + *l = []T{singleItem} return nil } - var singleItem T - newError := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem) - if newError != nil { - return E.Errors(err, newError) + newErr := json.UnmarshalContextDisallowUnknownFields(ctx, content, (*[]T)(l)) + if newErr == nil { + return nil } - *l = []T{singleItem} - return nil + return E.Errors(err, newErr) } From 30fbafd9546cd816eed4c39f399afe86a6b2d2c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 18 Nov 2024 12:14:35 +0800 Subject: [PATCH 19/55] udpnat2: Add cache funcs --- common/udpnat2/service.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index e06cebc..4bbef75 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -98,6 +98,20 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati } } +func (s *Service) Purge() { + s.cache.Purge() +} + +func (s *Service) PurgeExpired() { + s.cache.PurgeExpired() +} + func (s *Service) Metrics() Metrics { return s.metrics } + +func (s *Service) ResetMetrics() Metrics { + metrics := s.metrics + s.metrics = Metrics{} + return metrics +} From fa5355e99ec3fc5de6974acba13fbf92eff3d38f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 20 Nov 2024 11:27:20 +0800 Subject: [PATCH 20/55] bufio: more copy funcs --- common/bufio/copy.go | 146 +++++++++++++++----------------------- common/network/direct.go | 8 +++ common/udpnat2/conn.go | 38 ++++++---- common/udpnat2/service.go | 14 ++-- 4 files changed, 97 insertions(+), 109 deletions(-) diff --git a/common/bufio/copy.go b/common/bufio/copy.go index ebb03fe..506eab1 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -30,27 +30,38 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) { cachedBuffer := cachedSrc.ReadCached() if cachedBuffer != nil { if !cachedBuffer.IsEmpty() { + dataLen := cachedBuffer.Len() + for _, counter := range readCounters { + counter(int64(dataLen)) + } _, err = destination.Write(cachedBuffer.Bytes()) if err != nil { cachedBuffer.Release() return } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } } cachedBuffer.Release() continue } } - srcSyscallConn, srcIsSyscall := source.(syscall.Conn) - dstSyscallConn, dstIsSyscall := destination.(syscall.Conn) - if srcIsSyscall && dstIsSyscall { - var handled bool - handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters) - if handled { - return - } - } break } + return CopyWithCounters(destination, source, originSource, readCounters, writeCounters) +} + +func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { + srcSyscallConn, srcIsSyscall := source.(syscall.Conn) + dstSyscallConn, dstIsSyscall := destination.(syscall.Conn) + if srcIsSyscall && dstIsSyscall { + var handled bool + handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters) + if handled { + return + } + } return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters) } @@ -75,6 +86,7 @@ func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters) } +// Deprecated: not used func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { buffer.IncRef() defer buffer.DecRef() @@ -113,19 +125,10 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so } func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { - frontHeadroom := N.CalculateFrontHeadroom(destination) - rearHeadroom := N.CalculateRearHeadroom(destination) - bufferSize := N.CalculateMTU(source, destination) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.BufferSize - } + options := N.NewReadWaitOptions(source, destination) var notFirstTime bool for { - buffer := buf.NewSize(bufferSize) - buffer.Resize(frontHeadroom, 0) - buffer.Reserve(rearHeadroom) + buffer := options.NewBuffer() err = source.ReadBuffer(buffer) if err != nil { buffer.Release() @@ -136,7 +139,10 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, return } dataLen := buffer.Len() - buffer.OverCap(rearHeadroom) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + options.PostReturn(buffer) err = destination.WriteBuffer(buffer) if err != nil { buffer.Leak() @@ -146,9 +152,6 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, return } n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } for _, counter := range writeCounters { counter(int64(dataLen)) } @@ -196,18 +199,6 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error return group.Run(ctx) } -// Deprecated: not used -func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error { - switch len(contextList) { - case 0: - return CopyConn(context.Background(), source, destination) - case 1: - return CopyConn(contextList[0], source, destination) - default: - panic("invalid context list") - } -} - func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) { var readCounters, writeCounters []N.CountFunc var cachedPackets []*N.PacketBuffer @@ -225,24 +216,24 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, break } if cachedPackets != nil { - n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets) + n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters) if err != nil { return } } - frontHeadroom := N.CalculateFrontHeadroom(destinationConn) - rearHeadroom := N.CalculateRearHeadroom(destinationConn) + copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters) + n += copeN + return +} + +func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { var ( handled bool copeN int64 ) readWaiter, isReadWaiter := CreatePacketReadWaiter(source) if isReadWaiter { - needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{ - FrontHeadroom: frontHeadroom, - RearHeadroom: rearHeadroom, - MTU: N.CalculateMTU(source, destinationConn), - }) + needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn)) if !needCopy || common.LowMemory { handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0) if handled { @@ -256,28 +247,22 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, return } -func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { - frontHeadroom := N.CalculateFrontHeadroom(destinationConn) - rearHeadroom := N.CalculateRearHeadroom(destinationConn) - bufferSize := N.CalculateMTU(source, destinationConn) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.UDPBufferSize - } - var destination M.Socksaddr +func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { + options := N.NewReadWaitOptions(source, destination) + var destinationAddress M.Socksaddr for { - buffer := buf.NewSize(bufferSize) - buffer.Resize(frontHeadroom, 0) - buffer.Reserve(rearHeadroom) - destination, err = source.ReadPacket(buffer) + buffer := options.NewPacketBuffer() + destinationAddress, err = source.ReadPacket(buffer) if err != nil { buffer.Release() return } dataLen := buffer.Len() - buffer.OverCap(rearHeadroom) - err = destinationConn.WritePacket(buffer, destination) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + options.PostReturn(buffer) + err = destination.WritePacket(buffer, destinationAddress) if err != nil { buffer.Leak() if !notFirstTime { @@ -285,34 +270,25 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri } return } - n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } for _, counter := range writeCounters { counter(int64(dataLen)) } + n += int64(dataLen) notFirstTime = true } } -func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) { - frontHeadroom := N.CalculateFrontHeadroom(destinationConn) - rearHeadroom := N.CalculateRearHeadroom(destinationConn) +func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { + options := N.NewReadWaitOptions(nil, destination) var notFirstTime bool for _, packetBuffer := range packetBuffers { - buffer := buf.NewPacket() - buffer.Resize(frontHeadroom, 0) - buffer.Reserve(rearHeadroom) - _, err = buffer.Write(packetBuffer.Buffer.Bytes()) - packetBuffer.Buffer.Release() - if err != nil { - buffer.Release() - continue + for _, counter := range readCounters { + counter(int64(packetBuffer.Buffer.Len())) } + buffer := options.Copy(packetBuffer.Buffer) dataLen := buffer.Len() - buffer.OverCap(rearHeadroom) - err = destinationConn.WritePacket(buffer, packetBuffer.Destination) + err = destination.WritePacket(buffer, packetBuffer.Destination) + N.PutPacketBuffer(packetBuffer) if err != nil { buffer.Leak() if !notFirstTime { @@ -320,7 +296,11 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr } return } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } n += int64(dataLen) + notFirstTime = true } return } @@ -339,15 +319,3 @@ func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.Pack group.FastFail() return group.Run(ctx) } - -// Deprecated: not used -func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error { - switch len(contextList) { - case 0: - return CopyPacketConn(context.Background(), source, destination) - case 1: - return CopyPacketConn(contextList[0], source, destination) - default: - panic("invalid context list") - } -} diff --git a/common/network/direct.go b/common/network/direct.go index 1122d70..f587cd6 100644 --- a/common/network/direct.go +++ b/common/network/direct.go @@ -15,6 +15,14 @@ type ReadWaitOptions struct { MTU int } +func NewReadWaitOptions(source any, destination any) ReadWaitOptions { + return ReadWaitOptions{ + FrontHeadroom: CalculateFrontHeadroom(destination), + RearHeadroom: CalculateRearHeadroom(destination), + MTU: CalculateMTU(source, destination), + } +} + func (o ReadWaitOptions) NeedHeadroom() bool { return o.FrontHeadroom > 0 || o.RearHeadroom > 0 } diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 5d474e6..9d5bfa9 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -12,7 +12,14 @@ import ( "github.com/sagernet/sing/common/pipe" ) -type Conn struct { +type Conn interface { + N.PacketConn + SetHandler(handler N.UDPHandlerEx) +} + +var _ Conn = (*natConn)(nil) + +type natConn struct { writer N.PacketWriter localAddr M.Socksaddr handler N.UDPHandlerEx @@ -22,7 +29,7 @@ type Conn struct { readWaitOptions N.ReadWaitOptions } -func (c *Conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { +func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { select { case p := <-c.packetChan: _, err = buffer.ReadOnceFrom(p.Buffer) @@ -37,12 +44,17 @@ func (c *Conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { } } -func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { +func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return c.writer.WritePacket(buffer, destination) } -func (c *Conn) SetHandler(handler N.UDPHandlerEx) { +func (c *natConn) SetHandler(handler N.UDPHandlerEx) { + select { + case <-c.doneChan: + default: + } c.handler = handler + c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler) fetch: for { select { @@ -56,12 +68,12 @@ fetch: } } -func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { +func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { c.readWaitOptions = options return false } -func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { +func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { select { case packet := <-c.packetChan: buffer = c.readWaitOptions.Copy(packet.Buffer) @@ -75,7 +87,7 @@ func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er } } -func (c *Conn) Close() error { +func (c *natConn) Close() error { select { case <-c.doneChan: default: @@ -84,27 +96,27 @@ func (c *Conn) Close() error { return nil } -func (c *Conn) LocalAddr() net.Addr { +func (c *natConn) LocalAddr() net.Addr { return c.localAddr } -func (c *Conn) RemoteAddr() net.Addr { +func (c *natConn) RemoteAddr() net.Addr { return M.Socksaddr{} } -func (c *Conn) SetDeadline(t time.Time) error { +func (c *natConn) SetDeadline(t time.Time) error { return os.ErrInvalid } -func (c *Conn) SetReadDeadline(t time.Time) error { +func (c *natConn) SetReadDeadline(t time.Time) error { c.readDeadline.Set(t) return nil } -func (c *Conn) SetWriteDeadline(t time.Time) error { +func (c *natConn) SetWriteDeadline(t time.Time) error { return os.ErrInvalid } -func (c *Conn) Upstream() any { +func (c *natConn) Upstream() any { return c.writer } diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 4bbef75..ac5da1d 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -14,7 +14,7 @@ import ( ) type Service struct { - cache freelru.Cache[netip.AddrPort, *Conn] + cache freelru.Cache[netip.AddrPort, *natConn] handler N.UDPConnectionHandlerEx prepare PrepareFunc metrics Metrics @@ -30,15 +30,15 @@ type Metrics struct { } func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service { - var cache freelru.Cache[netip.AddrPort, *Conn] + var cache freelru.Cache[netip.AddrPort, *natConn] if !shared { - cache = common.Must1(freelru.New[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + cache = common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) } else { - cache = common.Must1(freelru.NewSharded[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + cache = common.Must1(freelru.NewSharded[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) } cache.SetLifetime(timeout) cache.SetUpdateLifetimeOnGet(true) - cache.SetHealthCheck(func(port netip.AddrPort, conn *Conn) bool { + cache.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool { select { case <-conn.doneChan: return false @@ -46,7 +46,7 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur return true } }) - cache.SetOnEvict(func(_ netip.AddrPort, conn *Conn) { + cache.SetOnEvict(func(_ netip.AddrPort, conn *natConn) { conn.Close() }) return &Service{ @@ -64,7 +64,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati s.metrics.Rejects++ return } - conn = &Conn{ + conn = &natConn{ writer: writer, localAddr: source, packetChan: make(chan *N.PacketBuffer, 64), From c8f251c66880057da79cf9e78005cc753ddfad2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 24 Nov 2024 13:57:11 +0800 Subject: [PATCH 21/55] Fix copy count --- common/bufio/copy.go | 43 ++++++++++++++++++++----------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/common/bufio/copy.go b/common/bufio/copy.go index 506eab1..a21fbc4 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -29,21 +29,18 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) { if cachedSrc, isCached := source.(N.CachedReader); isCached { cachedBuffer := cachedSrc.ReadCached() if cachedBuffer != nil { - if !cachedBuffer.IsEmpty() { - dataLen := cachedBuffer.Len() - for _, counter := range readCounters { - counter(int64(dataLen)) - } - _, err = destination.Write(cachedBuffer.Bytes()) - if err != nil { - cachedBuffer.Release() - return - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - } + dataLen := cachedBuffer.Len() + _, err = destination.Write(cachedBuffer.Bytes()) cachedBuffer.Release() + if err != nil { + return + } + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } continue } } @@ -139,9 +136,6 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, return } dataLen := buffer.Len() - for _, counter := range readCounters { - counter(int64(dataLen)) - } options.PostReturn(buffer) err = destination.WriteBuffer(buffer) if err != nil { @@ -152,6 +146,9 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, return } n += int64(dataLen) + for _, counter := range readCounters { + counter(int64(dataLen)) + } for _, counter := range writeCounters { counter(int64(dataLen)) } @@ -258,9 +255,6 @@ func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, return } dataLen := buffer.Len() - for _, counter := range readCounters { - counter(int64(dataLen)) - } options.PostReturn(buffer) err = destination.WritePacket(buffer, destinationAddress) if err != nil { @@ -270,6 +264,9 @@ func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, } return } + for _, counter := range readCounters { + counter(int64(dataLen)) + } for _, counter := range writeCounters { counter(int64(dataLen)) } @@ -282,9 +279,6 @@ func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter options := N.NewReadWaitOptions(nil, destination) var notFirstTime bool for _, packetBuffer := range packetBuffers { - for _, counter := range readCounters { - counter(int64(packetBuffer.Buffer.Len())) - } buffer := options.Copy(packetBuffer.Buffer) dataLen := buffer.Len() err = destination.WritePacket(buffer, packetBuffer.Destination) @@ -296,6 +290,9 @@ func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter } return } + for _, counter := range readCounters { + counter(int64(dataLen)) + } for _, counter := range writeCounters { counter(int64(dataLen)) } From 3613ead48056b744f174a8c1770f8dc6a52a0b76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 26 Nov 2024 11:29:14 +0800 Subject: [PATCH 22/55] freelru: Add `PeekWithLifetime` and `UpdateLifetime` --- contrab/freelru/cache.go | 4 +++ contrab/freelru/lru.go | 60 ++++++++++++++++++++++++++++++------ contrab/freelru/lru_test.go | 44 ++++++++++++++++++++++++-- contrab/freelru/sharedlru.go | 32 ++++++++++++------- 4 files changed, 117 insertions(+), 23 deletions(-) diff --git a/contrab/freelru/cache.go b/contrab/freelru/cache.go index 707e5bc..338e6dc 100644 --- a/contrab/freelru/cache.go +++ b/contrab/freelru/cache.go @@ -55,6 +55,10 @@ type Cache[K comparable, V any] interface { // If the found entry is already expired, the evict function is called. Peek(key K) (V, bool) + PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) + + UpdateLifetime(key K, value V, lifetime time.Duration) bool + // Contains checks for the existence of a key, without changing its recent-ness. // If the found entry is already expired, the evict function is called. Contains(key K) bool diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index 7672460..b3e6e29 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -62,7 +62,7 @@ type element[K comparable, V any] struct { const emptyBucket = math.MaxUint32 // LRU implements a non-thread safe fixed size LRU cache. -type LRU[K comparable, V any] struct { +type LRU[K comparable, V comparable] struct { buckets []uint32 // contains positions of bucket lists or 'emptyBucket' elements []element[K, V] onEvict OnEvictCallback[K, V] @@ -122,7 +122,7 @@ func (lru *LRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) { // New constructs an LRU with the given capacity of elements. // The hash function calculates a hash value from the keys. -func New[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, V], error) { +func New[K comparable, V comparable](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, V], error) { return NewWithSize[K, V](capacity, capacity, hash) } @@ -131,7 +131,7 @@ func New[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, // A size greater than the capacity increases memory consumption and decreases the CPU consumption // by reducing the chance of collisions. // Size must not be lower than the capacity. -func NewWithSize[K comparable, V any](capacity, size uint32, hash HashKeyCallback[K]) ( +func NewWithSize[K comparable, V comparable](capacity, size uint32, hash HashKeyCallback[K]) ( *LRU[K, V], error, ) { if capacity == 0 { @@ -156,7 +156,7 @@ func NewWithSize[K comparable, V any](capacity, size uint32, hash HashKeyCallbac return &lru, nil } -func initLRU[K comparable, V any](lru *LRU[K, V], capacity, size uint32, hash HashKeyCallback[K], +func initLRU[K comparable, V comparable](lru *LRU[K, V], capacity, size uint32, hash HashKeyCallback[K], buckets []uint32, elements []element[K, V], ) { lru.cap = capacity @@ -471,26 +471,66 @@ func (lru *LRU[K, V]) get(hash uint32, key K) (value V, expire int64, ok bool) { // Peek looks up a key's value from the cache, without changing its recent-ness. // If the found entry is already expired, the evict function is called. func (lru *LRU[K, V]) Peek(key K) (value V, ok bool) { - return lru.peek(lru.hash(key), key) + value, _, ok = lru.peek(lru.hash(key), key) + return } -func (lru *LRU[K, V]) peek(hash uint32, key K) (value V, ok bool) { - if pos, _, ok := lru.findKey(hash, key, false); ok { - return lru.elements[pos].value, ok +func (lru *LRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) { + value, expireMills, ok := lru.peek(lru.hash(key), key) + lifetime = time.UnixMilli(expireMills) + return +} + +func (lru *LRU[K, V]) peek(hash uint32, key K) (value V, expire int64, ok bool) { + if pos, expireMills, ok := lru.findKey(hash, key, false); ok { + return lru.elements[pos].value, expireMills, ok } return } +func (lru *LRU[K, V]) UpdateLifetime(key K, value V, lifetime time.Duration) bool { + return lru.updateLifetime(lru.hash(key), key, value, lifetime) +} + +func (lru *LRU[K, V]) updateLifetime(hash uint32, key K, value V, lifetime time.Duration) bool { + _, startPos := lru.hashToPos(hash) + if startPos == emptyBucket { + return false + } + pos := startPos + for { + if lru.elements[pos].key == key { + if lru.elements[pos].value != value { + return false + } + + lru.elements[pos].expire = expire(lifetime) + + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + lru.metrics.Inserts++ + return true + } + + pos = lru.elements[pos].nextBucket + if pos == startPos { + return false + } + } +} + // Contains checks for the existence of a key, without changing its recent-ness. // If the found entry is already expired, the evict function is called. func (lru *LRU[K, V]) Contains(key K) (ok bool) { - _, ok = lru.peek(lru.hash(key), key) + _, _, ok = lru.peek(lru.hash(key), key) return } func (lru *LRU[K, V]) contains(hash uint32, key K) (ok bool) { - _, ok = lru.peek(hash, key) + _, _, ok = lru.peek(hash, key) return } diff --git a/contrab/freelru/lru_test.go b/contrab/freelru/lru_test.go index fa3c157..36e9a05 100644 --- a/contrab/freelru/lru_test.go +++ b/contrab/freelru/lru_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestMyChange0(t *testing.T) { +func TestUpdateLifetimeOnGet(t *testing.T) { t.Parallel() lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) require.NoError(t, err) @@ -24,7 +24,7 @@ func TestMyChange0(t *testing.T) { require.True(t, ok) } -func TestMyChange1(t *testing.T) { +func TestUpdateLifetimeOnGet1(t *testing.T) { t.Parallel() lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) require.NoError(t, err) @@ -36,3 +36,43 @@ func TestMyChange1(t *testing.T) { _, ok := lru.Get("hello") require.False(t, ok) } + +func TestUpdateLifetime(t *testing.T) { + t.Parallel() + lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) + require.NoError(t, err) + lru.Add("hello", "world") + require.True(t, lru.UpdateLifetime("hello", "world", 2*time.Second)) + time.Sleep(time.Second) + _, ok := lru.Get("hello") + require.True(t, ok) + time.Sleep(time.Second + time.Millisecond*100) + _, ok = lru.Get("hello") + require.False(t, ok) +} + +func TestUpdateLifetime1(t *testing.T) { + t.Parallel() + lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) + require.NoError(t, err) + lru.Add("hello", "world") + require.False(t, lru.UpdateLifetime("hello", "not world", 2*time.Second)) + time.Sleep(2*time.Second + time.Millisecond*100) + _, ok := lru.Get("hello") + require.True(t, ok) +} + +func TestUpdateLifetime2(t *testing.T) { + t.Parallel() + lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) + require.NoError(t, err) + lru.AddWithLifetime("hello", "world", 2*time.Second) + time.Sleep(time.Second) + require.True(t, lru.UpdateLifetime("hello", "world", 2*time.Second)) + time.Sleep(time.Second + time.Millisecond*100) + _, ok := lru.Get("hello") + require.True(t, ok) + time.Sleep(time.Second + time.Millisecond*100) + _, ok = lru.Get("hello") + require.False(t, ok) +} diff --git a/contrab/freelru/sharedlru.go b/contrab/freelru/sharedlru.go index db1d8cd..1b43dc2 100644 --- a/contrab/freelru/sharedlru.go +++ b/contrab/freelru/sharedlru.go @@ -12,7 +12,7 @@ import ( // ShardedLRU is a thread-safe, sharded, fixed size LRU cache. // Sharding is used to reduce lock contention on high concurrency. // The downside is that exact LRU behavior is not given (as for the LRU and SynchedLRU types). -type ShardedLRU[K comparable, V any] struct { +type ShardedLRU[K comparable, V comparable] struct { lrus []LRU[K, V] mus []sync.RWMutex hash HashKeyCallback[K] @@ -66,7 +66,7 @@ func nextPowerOfTwo(val uint32) uint32 { } // NewSharded creates a new thread-safe sharded LRU hashmap with the given capacity. -func NewSharded[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*ShardedLRU[K, V], +func NewSharded[K comparable, V comparable](capacity uint32, hash HashKeyCallback[K]) (*ShardedLRU[K, V], error, ) { size := uint32(float64(capacity) * 1.25) // 25% extra space for fewer collisions @@ -74,7 +74,7 @@ func NewSharded[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) ( return NewShardedWithSize[K, V](uint32(runtime.GOMAXPROCS(0)*16), capacity, size, hash) } -func NewShardedWithSize[K comparable, V any](shards, capacity, size uint32, +func NewShardedWithSize[K comparable, V comparable](shards, capacity, size uint32, hash HashKeyCallback[K]) ( *ShardedLRU[K, V], error, ) { @@ -174,13 +174,7 @@ func (lru *ShardedLRU[K, V]) Add(key K, value V) (evicted bool) { // If the found cache item is already expired, the evict function is called // and the return value indicates that the key was not found. func (lru *ShardedLRU[K, V]) Get(key K) (value V, ok bool) { - hash := lru.hash(key) - shard := (hash >> 16) & lru.mask - - lru.mus[shard].Lock() - value, _, ok = lru.lrus[shard].get(hash, key) - lru.mus[shard].Unlock() - + value, _, ok = lru.GetWithLifetime(key) return } @@ -198,11 +192,27 @@ func (lru *ShardedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time // Peek looks up a key's value from the cache, without changing its recent-ness. // If the found entry is already expired, the evict function is called. func (lru *ShardedLRU[K, V]) Peek(key K) (value V, ok bool) { + value, _, ok = lru.PeekWithLifetime(key) + return +} + +func (lru *ShardedLRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) { hash := lru.hash(key) shard := (hash >> 16) & lru.mask lru.mus[shard].Lock() - value, ok = lru.lrus[shard].peek(hash, key) + value, expireMills, ok := lru.lrus[shard].peek(hash, key) + lru.mus[shard].Unlock() + lifetime = time.UnixMilli(expireMills) + return +} + +func (lru *ShardedLRU[K, V]) UpdateLifetime(key K, value V, lifetime time.Duration) (ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + ok = lru.lrus[shard].updateLifetime(hash, key, value, lifetime) lru.mus[shard].Unlock() return From a8285e06a59a8c9a70316000c2f57590efb7528e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 26 Nov 2024 11:30:38 +0800 Subject: [PATCH 23/55] udpnat2: Implement set timeout for nat conn --- common/canceler/instance.go | 4 +-- common/canceler/packet.go | 14 ++++---- common/canceler/packet_timeout.go | 4 +-- common/udpnat2/conn.go | 53 ++++++++++++++++++++----------- common/udpnat2/service.go | 1 + 5 files changed, 47 insertions(+), 29 deletions(-) diff --git a/common/canceler/instance.go b/common/canceler/instance.go index 05faa91..c47270d 100644 --- a/common/canceler/instance.go +++ b/common/canceler/instance.go @@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration { return i.timeout } -func (i *Instance) SetTimeout(timeout time.Duration) { +func (i *Instance) SetTimeout(timeout time.Duration) bool { i.timeout = timeout - i.Update() + return i.Update() } func (i *Instance) wait() { diff --git a/common/canceler/packet.go b/common/canceler/packet.go index fb4ad84..46cf9a0 100644 --- a/common/canceler/packet.go +++ b/common/canceler/packet.go @@ -13,7 +13,7 @@ import ( type PacketConn interface { N.PacketConn Timeout() time.Duration - SetTimeout(timeout time.Duration) + SetTimeout(timeout time.Duration) bool } type TimerPacketConn struct { @@ -24,10 +24,12 @@ type TimerPacketConn struct { func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) { if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn { oldTimeout := timeoutConn.Timeout() - if timeout < oldTimeout { - timeoutConn.SetTimeout(timeout) + if timeout >= oldTimeout { + return ctx, conn + } + if timeoutConn.SetTimeout(timeout) { + return ctx, conn } - return ctx, conn } err := conn.SetReadDeadline(time.Time{}) if err == nil { @@ -58,8 +60,8 @@ func (c *TimerPacketConn) Timeout() time.Duration { return c.instance.Timeout() } -func (c *TimerPacketConn) SetTimeout(timeout time.Duration) { - c.instance.SetTimeout(timeout) +func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool { + return c.instance.SetTimeout(timeout) } func (c *TimerPacketConn) Close() error { diff --git a/common/canceler/packet_timeout.go b/common/canceler/packet_timeout.go index ab5c760..a679567 100644 --- a/common/canceler/packet_timeout.go +++ b/common/canceler/packet_timeout.go @@ -61,9 +61,9 @@ func (c *TimeoutPacketConn) Timeout() time.Duration { return c.timeout } -func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) { +func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool { c.timeout = timeout - c.PacketConn.SetReadDeadline(time.Now()) + return c.PacketConn.SetReadDeadline(time.Now()) == nil } func (c *TimeoutPacketConn) Close() error { diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 9d5bfa9..8ae4557 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -4,9 +4,11 @@ import ( "io" "net" "os" + "sync" "time" "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/canceler" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/pipe" @@ -15,15 +17,18 @@ import ( type Conn interface { N.PacketConn SetHandler(handler N.UDPHandlerEx) + canceler.PacketConn } var _ Conn = (*natConn)(nil) type natConn struct { + service *Service writer N.PacketWriter localAddr M.Socksaddr handler N.UDPHandlerEx packetChan chan *N.PacketBuffer + closeOnce sync.Once doneChan chan struct{} readDeadline pipe.Deadline readWaitOptions N.ReadWaitOptions @@ -48,6 +53,25 @@ func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error return c.writer.WritePacket(buffer, destination) } +func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return false +} + +func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + select { + case packet := <-c.packetChan: + buffer = c.readWaitOptions.Copy(packet.Buffer) + destination = packet.Destination + N.PutPacketBuffer(packet) + return + case <-c.doneChan: + return nil, M.Socksaddr{}, io.ErrClosedPipe + case <-c.readDeadline.Wait(): + return nil, M.Socksaddr{}, os.ErrDeadlineExceeded + } +} + func (c *natConn) SetHandler(handler N.UDPHandlerEx) { select { case <-c.doneChan: @@ -68,31 +92,22 @@ fetch: } } -func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { - c.readWaitOptions = options - return false +func (c *natConn) Timeout() time.Duration { + rawConn, lifetime, loaded := c.service.cache.PeekWithLifetime(c.localAddr.AddrPort()) + if !loaded || rawConn != c { + return 0 + } + return time.Until(lifetime) } -func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { - select { - case packet := <-c.packetChan: - buffer = c.readWaitOptions.Copy(packet.Buffer) - destination = packet.Destination - N.PutPacketBuffer(packet) - return - case <-c.doneChan: - return nil, M.Socksaddr{}, io.ErrClosedPipe - case <-c.readDeadline.Wait(): - return nil, M.Socksaddr{}, os.ErrDeadlineExceeded - } +func (c *natConn) SetTimeout(timeout time.Duration) bool { + return c.service.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout) } func (c *natConn) Close() error { - select { - case <-c.doneChan: - default: + c.closeOnce.Do(func() { close(c.doneChan) - } + }) return nil } diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index ac5da1d..f5485af 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -65,6 +65,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati return } conn = &natConn{ + service: s, writer: writer, localAddr: source, packetChan: make(chan *N.PacketBuffer, 64), From 7fd3517e4d4266c75201e0160255b4ba58f765c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 26 Nov 2024 11:30:51 +0800 Subject: [PATCH 24/55] udpnat2: Add purge expire ticker --- common/udpnat2/service.go | 37 ++++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index f5485af..492a04e 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -3,6 +3,7 @@ package udpnat import ( "context" "net/netip" + "sync" "time" "github.com/sagernet/sing/common" @@ -18,6 +19,10 @@ type Service struct { handler N.UDPConnectionHandlerEx prepare PrepareFunc metrics Metrics + + timeout time.Duration + closeOnce sync.Once + doneChan chan struct{} } type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) @@ -50,12 +55,38 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur conn.Close() }) return &Service{ - cache: cache, - handler: handler, - prepare: prepare, + cache: cache, + handler: handler, + prepare: prepare, + timeout: timeout, + doneChan: make(chan struct{}), } } +func (s *Service) Start() error { + ticker := time.NewTicker(s.timeout) + go func() { + defer ticker.Stop() + for { + select { + case <-ticker.C: + s.PurgeExpired() + case <-s.doneChan: + s.Purge() + return + } + } + }() + return nil +} + +func (s *Service) Close() error { + s.closeOnce.Do(func() { + close(s.doneChan) + }) + return nil +} + func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) { conn, loaded := s.cache.Get(source.AddrPort()) if !loaded { From 30e9d91b57d9b752fff581791bd28c8361ac072c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 26 Nov 2024 12:00:56 +0800 Subject: [PATCH 25/55] Fix AppendClose --- common/network/conn.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/network/conn.go b/common/network/conn.go index c289bf6..ab3961f 100644 --- a/common/network/conn.go +++ b/common/network/conn.go @@ -74,9 +74,10 @@ type ExtendedConn interface { type CloseHandlerFunc = func(it error) func AppendClose(parent CloseHandlerFunc, onClose CloseHandlerFunc) CloseHandlerFunc { + if onClose == nil { + panic("nil onClose") + } if parent == nil { - return parent - } else if onClose == nil { return onClose } return func(it error) { From a8f5bf4eb026563845e2d0886b4dcef5af534dd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 26 Nov 2024 19:08:35 +0800 Subject: [PATCH 26/55] udpnat2: Add timeout check --- common/udpnat2/service.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 492a04e..b04d233 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -35,6 +35,9 @@ type Metrics struct { } func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service { + if timeout == 0 { + panic("invalid timeout") + } var cache freelru.Cache[netip.AddrPort, *natConn] if !shared { cache = common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) @@ -80,11 +83,10 @@ func (s *Service) Start() error { return nil } -func (s *Service) Close() error { +func (s *Service) Close() { s.closeOnce.Do(func() { close(s.doneChan) }) - return nil } func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) { From c44912a8610463cc8e4a0be8148e6e11a62456d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 27 Nov 2024 13:51:08 +0800 Subject: [PATCH 27/55] freelru: Fix purge --- contrab/freelru/lru.go | 6 ++++-- contrab/freelru/lru_test.go | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index b3e6e29..8096015 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -592,7 +592,8 @@ func (lru *LRU[K, V]) Keys() []K { // The evict function is called for each expired item. // The LRU metrics are reset. func (lru *LRU[K, V]) Purge() { - for i := uint32(0); i < lru.len; i++ { + lruLen := lru.len + for i := uint32(0); i < lruLen; i++ { _, _, _ = lru.RemoveOldest() } @@ -602,7 +603,8 @@ func (lru *LRU[K, V]) Purge() { // PurgeExpired purges all expired items from the LRU. // The evict function is called for each expired item. func (lru *LRU[K, V]) PurgeExpired() { - for i := uint32(0); i < lru.len; i++ { + lruLen := lru.len + for i := uint32(0); i < lruLen; i++ { pos := lru.elements[lru.head].next if lru.elements[pos].expire != 0 { if lru.elements[pos].expire > now() { diff --git a/contrab/freelru/lru_test.go b/contrab/freelru/lru_test.go index 36e9a05..2650802 100644 --- a/contrab/freelru/lru_test.go +++ b/contrab/freelru/lru_test.go @@ -76,3 +76,20 @@ func TestUpdateLifetime2(t *testing.T) { _, ok = lru.Get("hello") require.False(t, ok) } + +func TestPeekWithLifetime(t *testing.T) { + t.Parallel() + lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) + require.NoError(t, err) + lru.SetLifetime(time.Second) + lru.Add("1", "") + time.Sleep(300 * time.Millisecond) + lru.Add("2", "") + time.Sleep(300 * time.Millisecond) + lru.Add("3", "") + time.Sleep(300 * time.Millisecond) + lru.Add("4", "") + time.Sleep(time.Second) + lru.PurgeExpired() + require.Equal(t, 0, lru.Len()) +} From 4ba1eb123cdb76433620ae2c3e4840a9899499a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 27 Nov 2024 17:28:18 +0800 Subject: [PATCH 28/55] Fix set timeout --- common/canceler/packet.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/canceler/packet.go b/common/canceler/packet.go index 46cf9a0..519f283 100644 --- a/common/canceler/packet.go +++ b/common/canceler/packet.go @@ -24,7 +24,7 @@ type TimerPacketConn struct { func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) { if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn { oldTimeout := timeoutConn.Timeout() - if timeout >= oldTimeout { + if oldTimeout > 0 && timeout >= oldTimeout { return ctx, conn } if timeoutConn.SetTimeout(timeout) { From 0a2e2a3eaf749cb202e80ac2b9c8505014f5ee50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 27 Nov 2024 18:02:22 +0800 Subject: [PATCH 29/55] udpnat2: Fix timeout --- common/udpnat2/service.go | 37 ++++--------------------------------- 1 file changed, 4 insertions(+), 33 deletions(-) diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index b04d233..5e52930 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -3,7 +3,6 @@ package udpnat import ( "context" "net/netip" - "sync" "time" "github.com/sagernet/sing/common" @@ -19,10 +18,6 @@ type Service struct { handler N.UDPConnectionHandlerEx prepare PrepareFunc metrics Metrics - - timeout time.Duration - closeOnce sync.Once - doneChan chan struct{} } type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) @@ -58,37 +53,12 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur conn.Close() }) return &Service{ - cache: cache, - handler: handler, - prepare: prepare, - timeout: timeout, - doneChan: make(chan struct{}), + cache: cache, + handler: handler, + prepare: prepare, } } -func (s *Service) Start() error { - ticker := time.NewTicker(s.timeout) - go func() { - defer ticker.Stop() - for { - select { - case <-ticker.C: - s.PurgeExpired() - case <-s.doneChan: - s.Purge() - return - } - } - }() - return nil -} - -func (s *Service) Close() { - s.closeOnce.Do(func() { - close(s.doneChan) - }) -} - func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) { conn, loaded := s.cache.Get(source.AddrPort()) if !loaded { @@ -105,6 +75,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati doneChan: make(chan struct{}), readDeadline: pipe.MakeDeadline(), } + s.PurgeExpired() s.cache.Add(source.AddrPort(), conn) go s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose) s.metrics.Creates++ From 6edd2ce0ea478cb757edb9a060427635d5ed6d5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 28 Nov 2024 13:14:16 +0800 Subject: [PATCH 30/55] freelru: Update source and add GetAndRefreshOrAdd --- contrab/freelru/LICENSE | 201 ++++++++++++++ contrab/freelru/NOTICE | 2 + contrab/freelru/README.md | 3 +- contrab/freelru/cache.go | 19 +- contrab/freelru/lru.go | 186 +++++++++---- .../freelru/{sharedlru.go => shardedlru.go} | 57 +++- contrab/freelru/syncedlru.go | 257 ++++++++++++++++++ 7 files changed, 652 insertions(+), 73 deletions(-) create mode 100644 contrab/freelru/LICENSE create mode 100644 contrab/freelru/NOTICE rename contrab/freelru/{sharedlru.go => shardedlru.go} (88%) create mode 100644 contrab/freelru/syncedlru.go diff --git a/contrab/freelru/LICENSE b/contrab/freelru/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/contrab/freelru/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/contrab/freelru/NOTICE b/contrab/freelru/NOTICE new file mode 100644 index 0000000..9b1c877 --- /dev/null +++ b/contrab/freelru/NOTICE @@ -0,0 +1,2 @@ +Go LRU Hashmap +Copyright 2022 Elasticsearch B.V. diff --git a/contrab/freelru/README.md b/contrab/freelru/README.md index 206fbba..72bec81 100644 --- a/contrab/freelru/README.md +++ b/contrab/freelru/README.md @@ -1,3 +1,4 @@ # freelru -kanged from github.com/elastic/go-freelru@v0.14.0 \ No newline at end of file +upstream: github.com/elastic/go-freelru@v0.16.0 +source: github.com/sagernet/go-freelru@1b34934a560d528d1866415d440625ed2a2560fe diff --git a/contrab/freelru/cache.go b/contrab/freelru/cache.go index 338e6dc..319e662 100644 --- a/contrab/freelru/cache.go +++ b/contrab/freelru/cache.go @@ -19,19 +19,17 @@ package freelru import "time" -type Cache[K comparable, V any] interface { +type Cache[K comparable, V comparable] interface { // SetLifetime sets the default lifetime of LRU elements. // Lifetime 0 means "forever". SetLifetime(lifetime time.Duration) - SetUpdateLifetimeOnGet(update bool) - - SetHealthCheck(healthCheck HealthCheckCallback[K, V]) - // SetOnEvict sets the OnEvict callback function. // The onEvict function is called for each evicted lru entry. SetOnEvict(onEvict OnEvictCallback[K, V]) + SetHealthCheck(healthCheck HealthCheckCallback[K, V]) + // Len returns the number of elements stored in the cache. Len() int @@ -49,13 +47,20 @@ type Cache[K comparable, V any] interface { // and the return value indicates that the key was not found. Get(key K) (V, bool) - GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) + GetWithLifetime(key K) (V, time.Time, bool) + + // GetAndRefresh returns the value associated with the key, setting it as the most + // recently used item. + // The lifetime of the found cache item is refreshed, even if it was already expired. + GetAndRefresh(key K) (V, bool) + + GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool) // Peek looks up a key's value from the cache, without changing its recent-ness. // If the found entry is already expired, the evict function is called. Peek(key K) (V, bool) - PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) + PeekWithLifetime(key K) (V, time.Time, bool) UpdateLifetime(key K, value V, lifetime time.Duration) bool diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index 8096015..7f732ea 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -26,14 +26,14 @@ import ( ) // OnEvictCallback is the type for the eviction function. -type OnEvictCallback[K comparable, V any] func(K, V) +type OnEvictCallback[K comparable, V comparable] func(K, V) // HashKeyCallback is the function that creates a hash from the passed key. type HashKeyCallback[K comparable] func(K) uint32 -type HealthCheckCallback[K comparable, V any] func(K, V) bool +type HealthCheckCallback[K comparable, V comparable] func(K, V) bool -type element[K comparable, V any] struct { +type element[K comparable, V comparable] struct { key K value V @@ -63,14 +63,13 @@ const emptyBucket = math.MaxUint32 // LRU implements a non-thread safe fixed size LRU cache. type LRU[K comparable, V comparable] struct { - buckets []uint32 // contains positions of bucket lists or 'emptyBucket' - elements []element[K, V] - onEvict OnEvictCallback[K, V] - hash HashKeyCallback[K] - healthCheck HealthCheckCallback[K, V] - lifetime time.Duration - updateLifetimeOnGet bool - metrics Metrics + buckets []uint32 // contains positions of bucket lists or 'emptyBucket' + elements []element[K, V] + onEvict OnEvictCallback[K, V] + hash HashKeyCallback[K] + healthCheck HealthCheckCallback[K, V] + lifetime time.Duration + metrics Metrics // used for element clearing after removal or expiration emptyKey K @@ -101,10 +100,6 @@ func (lru *LRU[K, V]) SetLifetime(lifetime time.Duration) { lru.lifetime = lifetime } -func (lru *LRU[K, V]) SetUpdateLifetimeOnGet(update bool) { - lru.updateLifetimeOnGet = update -} - // SetOnEvict sets the OnEvict callback function. // The onEvict function is called for each evicted lru entry. // Eviction happens @@ -181,7 +176,13 @@ func (lru *LRU[K, V]) hashToBucketPos(hash uint32) uint32 { if lru.mask != 0 { return hash & lru.mask } - return hash % lru.size + return fastModulo(hash, lru.size) +} + +// fastModulo calculates x % n without using the modulo operator (~4x faster). +// Reference: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ +func fastModulo(x, n uint32) uint32 { + return uint32((uint64(x) * uint64(n)) >> 32) //nolint:gosec } // hashToPos converts a key into a position in the elements array. @@ -308,30 +309,50 @@ func (lru *LRU[K, V]) clearKeyAndValue(pos uint32) { lru.elements[pos].value = lru.emptyValue } -func (lru *LRU[K, V]) findKey(hash uint32, key K, updateLifetimeOnGet bool) (uint32, int64, bool) { +func (lru *LRU[K, V]) findKey(hash uint32, key K) (uint32, bool) { _, startPos := lru.hashToPos(hash) if startPos == emptyBucket { - return emptyBucket, 0, false + return emptyBucket, false } pos := startPos for { if key == lru.elements[pos].key { - elem := lru.elements[pos] - if (elem.expire != 0 && elem.expire <= now()) || (lru.healthCheck != nil && !lru.healthCheck(key, elem.value)) { + if lru.elements[pos].expire != 0 && lru.elements[pos].expire <= now() || (lru.healthCheck != nil && !lru.healthCheck(key, lru.elements[pos].value)) { lru.removeAt(pos) - return emptyBucket, elem.expire, false + return emptyBucket, false } - if updateLifetimeOnGet { - lru.elements[pos].expire = expire(lru.lifetime) - } - return pos, elem.expire, true + return pos, true } pos = lru.elements[pos].nextBucket if pos == startPos { // Key not found - return emptyBucket, 0, false + return emptyBucket, false + } + } +} + +func (lru *LRU[K, V]) findKeyNoExpire(hash uint32, key K) (uint32, bool) { + _, startPos := lru.hashToPos(hash) + if startPos == emptyBucket { + return emptyBucket, false + } + + pos := startPos + for { + if key == lru.elements[pos].key { + if lru.healthCheck != nil && !lru.healthCheck(key, lru.elements[pos].value) { + lru.removeAt(pos) + return emptyBucket, false + } + return pos, true + } + + pos = lru.elements[pos].nextBucket + if pos == startPos { + // Key not found + return emptyBucket, false } } } @@ -444,46 +465,109 @@ func (lru *LRU[K, V]) add(hash uint32, key K, value V) (evicted bool) { // If the found cache item is already expired, the evict function is called // and the return value indicates that the key was not found. func (lru *LRU[K, V]) Get(key K) (value V, ok bool) { - value, _, ok = lru.get(lru.hash(key), key) - return + return lru.get(lru.hash(key), key) } -func (lru *LRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) { - value, expireMills, ok := lru.get(lru.hash(key), key) - lifetime = time.UnixMilli(expireMills) - return -} - -func (lru *LRU[K, V]) get(hash uint32, key K) (value V, expire int64, ok bool) { - if pos, expire, ok := lru.findKey(hash, key, lru.updateLifetimeOnGet); ok { +func (lru *LRU[K, V]) get(hash uint32, key K) (value V, ok bool) { + if pos, ok := lru.findKey(hash, key); ok { if pos != lru.head { lru.unlinkElement(pos) lru.setHead(pos) } lru.metrics.Hits++ - return lru.elements[pos].value, expire, ok + return lru.elements[pos].value, ok } lru.metrics.Misses++ return } +func (lru *LRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) { + return lru.getWithLifetime(lru.hash(key), key) +} + +func (lru *LRU[K, V]) getWithLifetime(hash uint32, key K) (value V, lifetime time.Time, ok bool) { + if pos, ok := lru.findKey(hash, key); ok { + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + lru.metrics.Hits++ + return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok + } + + lru.metrics.Misses++ + return +} + +// GetAndRefresh returns the value associated with the key, setting it as the most +// recently used item. +// The lifetime of the found cache item is refreshed, even if it was already expired. +func (lru *LRU[K, V]) GetAndRefresh(key K) (V, bool) { + return lru.getAndRefresh(lru.hash(key), key) +} + +func (lru *LRU[K, V]) getAndRefresh(hash uint32, key K) (value V, ok bool) { + if pos, ok := lru.findKeyNoExpire(hash, key); ok { + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + lru.metrics.Hits++ + lru.elements[pos].expire = expire(lru.lifetime) + return lru.elements[pos].value, ok + } + + lru.metrics.Misses++ + return +} + +func (lru *LRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool) { + return lru.getAndRefreshOrAdd(lru.hash(key), key, constructor) +} + +func (lru *LRU[K, V]) getAndRefreshOrAdd(hash uint32, key K, constructor func() (V, bool)) (value V, ok bool) { + if pos, ok := lru.findKeyNoExpire(hash, key); ok { + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + lru.metrics.Hits++ + lru.elements[pos].expire = expire(lru.lifetime) + return lru.elements[pos].value, ok + } + + lru.metrics.Misses++ + value, ok = constructor() + if !ok { + return + } + lru.addWithLifetime(hash, key, value, lru.lifetime) + lru.PurgeExpired() + return value, false +} + // Peek looks up a key's value from the cache, without changing its recent-ness. // If the found entry is already expired, the evict function is called. func (lru *LRU[K, V]) Peek(key K) (value V, ok bool) { - value, _, ok = lru.peek(lru.hash(key), key) + return lru.peek(lru.hash(key), key) +} + +func (lru *LRU[K, V]) peek(hash uint32, key K) (value V, ok bool) { + if pos, ok := lru.findKey(hash, key); ok { + return lru.elements[pos].value, ok + } + return } func (lru *LRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) { - value, expireMills, ok := lru.peek(lru.hash(key), key) - lifetime = time.UnixMilli(expireMills) - return + return lru.peekWithLifetime(lru.hash(key), key) } -func (lru *LRU[K, V]) peek(hash uint32, key K) (value V, expire int64, ok bool) { - if pos, expireMills, ok := lru.findKey(hash, key, false); ok { - return lru.elements[pos].value, expireMills, ok +func (lru *LRU[K, V]) peekWithLifetime(hash uint32, key K) (value V, lifetime time.Time, ok bool) { + if pos, ok := lru.findKey(hash, key); ok { + return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok } return @@ -525,12 +609,12 @@ func (lru *LRU[K, V]) updateLifetime(hash uint32, key K, value V, lifetime time. // Contains checks for the existence of a key, without changing its recent-ness. // If the found entry is already expired, the evict function is called. func (lru *LRU[K, V]) Contains(key K) (ok bool) { - _, _, ok = lru.peek(lru.hash(key), key) + _, ok = lru.peek(lru.hash(key), key) return } func (lru *LRU[K, V]) contains(hash uint32, key K) (ok bool) { - _, _, ok = lru.peek(hash, key) + _, ok = lru.peek(hash, key) return } @@ -542,7 +626,7 @@ func (lru *LRU[K, V]) Remove(key K) (removed bool) { } func (lru *LRU[K, V]) remove(hash uint32, key K) (removed bool) { - if pos, _, ok := lru.findKey(hash, key, false); ok { + if pos, ok := lru.findKeyNoExpire(hash, key); ok { lru.removeAt(pos) return ok } @@ -592,8 +676,8 @@ func (lru *LRU[K, V]) Keys() []K { // The evict function is called for each expired item. // The LRU metrics are reset. func (lru *LRU[K, V]) Purge() { - lruLen := lru.len - for i := uint32(0); i < lruLen; i++ { + l := lru.len + for i := uint32(0); i < l; i++ { _, _, _ = lru.RemoveOldest() } @@ -603,8 +687,8 @@ func (lru *LRU[K, V]) Purge() { // PurgeExpired purges all expired items from the LRU. // The evict function is called for each expired item. func (lru *LRU[K, V]) PurgeExpired() { - lruLen := lru.len - for i := uint32(0); i < lruLen; i++ { + l := lru.len + for i := uint32(0); i < l; i++ { pos := lru.elements[lru.head].next if lru.elements[pos].expire != 0 { if lru.elements[pos].expire > now() { diff --git a/contrab/freelru/sharedlru.go b/contrab/freelru/shardedlru.go similarity index 88% rename from contrab/freelru/sharedlru.go rename to contrab/freelru/shardedlru.go index 1b43dc2..e6aca65 100644 --- a/contrab/freelru/sharedlru.go +++ b/contrab/freelru/shardedlru.go @@ -32,14 +32,6 @@ func (lru *ShardedLRU[K, V]) SetLifetime(lifetime time.Duration) { } } -func (lru *ShardedLRU[K, V]) SetUpdateLifetimeOnGet(update bool) { - for shard := range lru.lrus { - lru.mus[shard].Lock() - lru.lrus[shard].SetUpdateLifetimeOnGet(update) - lru.mus[shard].Unlock() - } -} - // SetOnEvict sets the OnEvict callback function. // The onEvict function is called for each evicted lru entry. func (lru *ShardedLRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) { @@ -174,7 +166,13 @@ func (lru *ShardedLRU[K, V]) Add(key K, value V) (evicted bool) { // If the found cache item is already expired, the evict function is called // and the return value indicates that the key was not found. func (lru *ShardedLRU[K, V]) Get(key K) (value V, ok bool) { - value, _, ok = lru.GetWithLifetime(key) + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + value, ok = lru.lrus[shard].get(hash, key) + lru.mus[shard].Unlock() + return } @@ -183,16 +181,47 @@ func (lru *ShardedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time shard := (hash >> 16) & lru.mask lru.mus[shard].Lock() - value, expireMills, ok := lru.lrus[shard].get(hash, key) + value, lifetime, ok = lru.lrus[shard].getWithLifetime(hash, key) lru.mus[shard].Unlock() - lifetime = time.UnixMilli(expireMills) + + return +} + +// GetAndRefresh returns the value associated with the key, setting it as the most +// recently used item. +// The lifetime of the found cache item is refreshed, even if it was already expired. +func (lru *ShardedLRU[K, V]) GetAndRefresh(key K) (value V, ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + value, ok = lru.lrus[shard].getAndRefresh(hash, key) + lru.mus[shard].Unlock() + + return +} + +func (lru *ShardedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + value, updated = lru.lrus[shard].getAndRefreshOrAdd(hash, key, constructor) + lru.mus[shard].Unlock() + return } // Peek looks up a key's value from the cache, without changing its recent-ness. // If the found entry is already expired, the evict function is called. func (lru *ShardedLRU[K, V]) Peek(key K) (value V, ok bool) { - value, _, ok = lru.PeekWithLifetime(key) + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].Lock() + value, ok = lru.lrus[shard].peek(hash, key) + lru.mus[shard].Unlock() + return } @@ -201,9 +230,9 @@ func (lru *ShardedLRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Tim shard := (hash >> 16) & lru.mask lru.mus[shard].Lock() - value, expireMills, ok := lru.lrus[shard].peek(hash, key) + value, lifetime, ok = lru.lrus[shard].peekWithLifetime(hash, key) lru.mus[shard].Unlock() - lifetime = time.UnixMilli(expireMills) + return } diff --git a/contrab/freelru/syncedlru.go b/contrab/freelru/syncedlru.go new file mode 100644 index 0000000..d42da70 --- /dev/null +++ b/contrab/freelru/syncedlru.go @@ -0,0 +1,257 @@ +package freelru + +import ( + "sync" + "time" +) + +type SyncedLRU[K comparable, V comparable] struct { + mu sync.RWMutex + lru *LRU[K, V] +} + +var _ Cache[int, int] = (*SyncedLRU[int, int])(nil) + +// SetLifetime sets the default lifetime of LRU elements. +// Lifetime 0 means "forever". +func (lru *SyncedLRU[K, V]) SetLifetime(lifetime time.Duration) { + lru.mu.Lock() + lru.lru.SetLifetime(lifetime) + lru.mu.Unlock() +} + +// SetOnEvict sets the OnEvict callback function. +// The onEvict function is called for each evicted lru entry. +func (lru *SyncedLRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) { + lru.mu.Lock() + lru.lru.SetOnEvict(onEvict) + lru.mu.Unlock() +} + +func (lru *SyncedLRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) { + lru.mu.Lock() + lru.lru.SetHealthCheck(healthCheck) + lru.mu.Unlock() +} + +// NewSynced creates a new thread-safe LRU hashmap with the given capacity. +func NewSynced[K comparable, V comparable](capacity uint32, hash HashKeyCallback[K]) (*SyncedLRU[K, V], + error, +) { + return NewSyncedWithSize[K, V](capacity, capacity, hash) +} + +func NewSyncedWithSize[K comparable, V comparable](capacity, size uint32, + hash HashKeyCallback[K], +) (*SyncedLRU[K, V], error) { + lru, err := NewWithSize[K, V](capacity, size, hash) + if err != nil { + return nil, err + } + return &SyncedLRU[K, V]{lru: lru}, nil +} + +// Len returns the number of elements stored in the cache. +func (lru *SyncedLRU[K, V]) Len() (length int) { + lru.mu.RLock() + length = lru.lru.Len() + lru.mu.RUnlock() + + return +} + +// AddWithLifetime adds a key:value to the cache with a lifetime. +// Returns true, true if key was updated and eviction occurred. +func (lru *SyncedLRU[K, V]) AddWithLifetime(key K, value V, lifetime time.Duration) (evicted bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + evicted = lru.lru.addWithLifetime(hash, key, value, lifetime) + lru.mu.Unlock() + + return +} + +// Add adds a key:value to the cache. +// Returns true, true if key was updated and eviction occurred. +func (lru *SyncedLRU[K, V]) Add(key K, value V) (evicted bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + evicted = lru.lru.add(hash, key, value) + lru.mu.Unlock() + + return +} + +// Get returns the value associated with the key, setting it as the most +// recently used item. +// If the found cache item is already expired, the evict function is called +// and the return value indicates that the key was not found. +func (lru *SyncedLRU[K, V]) Get(key K) (value V, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, ok = lru.lru.get(hash, key) + lru.mu.Unlock() + + return +} + +func (lru *SyncedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, lifetime, ok = lru.lru.getWithLifetime(hash, key) + lru.mu.Unlock() + + return +} + +// GetAndRefresh returns the value associated with the key, setting it as the most +// recently used item. +// The lifetime of the found cache item is refreshed, even if it was already expired. +func (lru *SyncedLRU[K, V]) GetAndRefresh(key K) (value V, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, ok = lru.lru.getAndRefresh(hash, key) + lru.mu.Unlock() + + return +} + +func (lru *SyncedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, updated = lru.lru.getAndRefreshOrAdd(hash, key, constructor) + lru.mu.Unlock() + + return +} + +// Peek looks up a key's value from the cache, without changing its recent-ness. +// If the found entry is already expired, the evict function is called. +func (lru *SyncedLRU[K, V]) Peek(key K) (value V, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, ok = lru.lru.peek(hash, key) + lru.mu.Unlock() + + return +} + +func (lru *SyncedLRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, lifetime, ok = lru.lru.peekWithLifetime(hash, key) + lru.mu.Unlock() + + return +} + +func (lru *SyncedLRU[K, V]) UpdateLifetime(key K, value V, lifetime time.Duration) (ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + ok = lru.lru.updateLifetime(hash, key, value, lifetime) + lru.mu.Unlock() + + return +} + +// Contains checks for the existence of a key, without changing its recent-ness. +// If the found entry is already expired, the evict function is called. +func (lru *SyncedLRU[K, V]) Contains(key K) (ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + ok = lru.lru.contains(hash, key) + lru.mu.Unlock() + + return +} + +// Remove removes the key from the cache. +// The return value indicates whether the key existed or not. +// The evict function is being called if the key existed. +func (lru *SyncedLRU[K, V]) Remove(key K) (removed bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + removed = lru.lru.remove(hash, key) + lru.mu.Unlock() + + return +} + +// RemoveOldest removes the oldest entry from the cache. +// Key, value and an indicator of whether the entry has been removed is returned. +// The evict function is called for the removed entry. +func (lru *SyncedLRU[K, V]) RemoveOldest() (key K, value V, removed bool) { + lru.mu.Lock() + key, value, removed = lru.lru.RemoveOldest() + lru.mu.Unlock() + + return +} + +// Keys returns a slice of the keys in the cache, from oldest to newest. +// Expired entries are not included. +// The evict function is called for each expired item. +func (lru *SyncedLRU[K, V]) Keys() (keys []K) { + lru.mu.Lock() + keys = lru.lru.Keys() + lru.mu.Unlock() + + return +} + +// Purge purges all data (key and value) from the LRU. +// The evict function is called for each expired item. +// The LRU metrics are reset. +func (lru *SyncedLRU[K, V]) Purge() { + lru.mu.Lock() + lru.lru.Purge() + lru.mu.Unlock() +} + +// PurgeExpired purges all expired items from the LRU. +// The evict function is called for each expired item. +func (lru *SyncedLRU[K, V]) PurgeExpired() { + lru.mu.Lock() + lru.lru.PurgeExpired() + lru.mu.Unlock() +} + +// Metrics returns the metrics of the cache. +func (lru *SyncedLRU[K, V]) Metrics() Metrics { + lru.mu.Lock() + metrics := lru.lru.Metrics() + lru.mu.Unlock() + return metrics +} + +// ResetMetrics resets the metrics of the cache and returns the previous state. +func (lru *SyncedLRU[K, V]) ResetMetrics() Metrics { + lru.mu.Lock() + metrics := lru.lru.ResetMetrics() + lru.mu.Unlock() + return metrics +} + +// just used for debugging +func (lru *SyncedLRU[K, V]) dump() { + lru.mu.RLock() + lru.lru.dump() + lru.mu.RUnlock() +} + +func (lru *SyncedLRU[K, V]) PrintStats() { + lru.mu.RLock() + lru.lru.PrintStats() + lru.mu.RUnlock() +} From 39040e06dcd5dd887ca4eafd8be55ed1411160ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 28 Nov 2024 13:16:33 +0800 Subject: [PATCH 31/55] udpnat2: Fix concurrency --- common/udpnat2/conn.go | 8 +++++--- common/udpnat2/service.go | 42 ++++++++++----------------------------- 2 files changed, 16 insertions(+), 34 deletions(-) diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 8ae4557..4ed7b74 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -3,6 +3,7 @@ package udpnat import ( "io" "net" + "net/netip" "os" "sync" "time" @@ -12,6 +13,7 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/pipe" + "github.com/sagernet/sing/contrab/freelru" ) type Conn interface { @@ -23,7 +25,7 @@ type Conn interface { var _ Conn = (*natConn)(nil) type natConn struct { - service *Service + cache freelru.Cache[netip.AddrPort, *natConn] writer N.PacketWriter localAddr M.Socksaddr handler N.UDPHandlerEx @@ -93,7 +95,7 @@ fetch: } func (c *natConn) Timeout() time.Duration { - rawConn, lifetime, loaded := c.service.cache.PeekWithLifetime(c.localAddr.AddrPort()) + rawConn, lifetime, loaded := c.cache.PeekWithLifetime(c.localAddr.AddrPort()) if !loaded || rawConn != c { return 0 } @@ -101,7 +103,7 @@ func (c *natConn) Timeout() time.Duration { } func (c *natConn) SetTimeout(timeout time.Duration) bool { - return c.service.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout) + return c.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout) } func (c *natConn) Close() error { diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 5e52930..07e786d 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -17,18 +17,10 @@ type Service struct { cache freelru.Cache[netip.AddrPort, *natConn] handler N.UDPConnectionHandlerEx prepare PrepareFunc - metrics Metrics } type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) -type Metrics struct { - Creates uint64 - Rejects uint64 - Inputs uint64 - Drops uint64 -} - func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service { if timeout == 0 { panic("invalid timeout") @@ -40,7 +32,6 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur cache = common.Must1(freelru.NewSharded[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) } cache.SetLifetime(timeout) - cache.SetUpdateLifetimeOnGet(true) cache.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool { select { case <-conn.doneChan: @@ -60,25 +51,26 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur } func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) { - conn, loaded := s.cache.Get(source.AddrPort()) - if !loaded { + conn, loaded := s.cache.GetAndRefreshOrAdd(source.AddrPort(), func() (*natConn, bool) { ok, ctx, writer, onClose := s.prepare(source, destination, userData) if !ok { - s.metrics.Rejects++ - return + return nil, false } - conn = &natConn{ - service: s, + newConn := &natConn{ + cache: s.cache, writer: writer, localAddr: source, packetChan: make(chan *N.PacketBuffer, 64), doneChan: make(chan struct{}), readDeadline: pipe.MakeDeadline(), } - s.PurgeExpired() - s.cache.Add(source.AddrPort(), conn) - go s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose) - s.metrics.Creates++ + go s.handler.NewPacketConnectionEx(ctx, newConn, source, destination, onClose) + return newConn, true + }) + if !loaded { + if conn == nil { + return + } } buffer := conn.readWaitOptions.NewPacketBuffer() for _, bufferSlice := range bufferSlices { @@ -95,11 +87,9 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati } select { case conn.packetChan <- packet: - s.metrics.Inputs++ default: packet.Buffer.Release() N.PutPacketBuffer(packet) - s.metrics.Drops++ } } @@ -110,13 +100,3 @@ func (s *Service) Purge() { func (s *Service) PurgeExpired() { s.cache.PurgeExpired() } - -func (s *Service) Metrics() Metrics { - return s.metrics -} - -func (s *Service) ResetMetrics() Metrics { - metrics := s.metrics - s.metrics = Metrics{} - return metrics -} From 3f30aaf25ec5f55da9dffcb2c553bbd20a7a80ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 30 Nov 2024 16:06:50 +0800 Subject: [PATCH 32/55] freelru: purge all expired items --- contrab/freelru/lru.go | 13 +++++++++---- contrab/freelru/lru_test.go | 19 +++++++------------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index 7f732ea..5642225 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -688,14 +688,19 @@ func (lru *LRU[K, V]) Purge() { // The evict function is called for each expired item. func (lru *LRU[K, V]) PurgeExpired() { l := lru.len + if l == 0 { + return + } + n := now() + pos := lru.head for i := uint32(0); i < l; i++ { - pos := lru.elements[lru.head].next + next := lru.elements[pos].next if lru.elements[pos].expire != 0 { - if lru.elements[pos].expire > now() { - return // no more expired items + if lru.elements[pos].expire <= n { + lru.removeAt(pos) } - lru.removeAt(pos) } + pos = next } } diff --git a/contrab/freelru/lru_test.go b/contrab/freelru/lru_test.go index 2650802..0b548b5 100644 --- a/contrab/freelru/lru_test.go +++ b/contrab/freelru/lru_test.go @@ -14,10 +14,9 @@ func TestUpdateLifetimeOnGet(t *testing.T) { t.Parallel() lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) require.NoError(t, err) - lru.SetUpdateLifetimeOnGet(true) lru.AddWithLifetime("hello", "world", 2*time.Second) time.Sleep(time.Second) - _, ok := lru.Get("hello") + _, ok := lru.GetAndRefresh("hello") require.True(t, ok) time.Sleep(time.Second + time.Millisecond*100) _, ok = lru.Get("hello") @@ -28,7 +27,6 @@ func TestUpdateLifetimeOnGet1(t *testing.T) { t.Parallel() lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) require.NoError(t, err) - lru.SetUpdateLifetimeOnGet(true) lru.AddWithLifetime("hello", "world", 2*time.Second) time.Sleep(time.Second) lru.Peek("hello") @@ -82,14 +80,11 @@ func TestPeekWithLifetime(t *testing.T) { lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) require.NoError(t, err) lru.SetLifetime(time.Second) - lru.Add("1", "") - time.Sleep(300 * time.Millisecond) - lru.Add("2", "") - time.Sleep(300 * time.Millisecond) - lru.Add("3", "") - time.Sleep(300 * time.Millisecond) - lru.Add("4", "") - time.Sleep(time.Second) + lru.AddWithLifetime("hello", "world", 10*time.Second) + lru.Add("hello1", "") + lru.Add("hello2", "") + lru.Add("hello3", "") + time.Sleep(2 * time.Second) lru.PurgeExpired() - require.Equal(t, 0, lru.Len()) + require.Equal(t, 1, lru.Len()) } From 478265cd459d57c64186a4214b918843601c56b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 1 Dec 2024 14:33:23 +0800 Subject: [PATCH 33/55] badoption: Finish netip options --- common/json/badoption/netip.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/common/json/badoption/netip.go b/common/json/badoption/netip.go index 61aef82..f22df43 100644 --- a/common/json/badoption/netip.go +++ b/common/json/badoption/netip.go @@ -35,6 +35,13 @@ func (a *Addr) UnmarshalJSON(content []byte) error { type Prefix netip.Prefix +func (p *Prefix) Build(defaultPrefix netip.Prefix) netip.Prefix { + if p == nil { + return defaultPrefix + } + return netip.Prefix(*p) +} + func (p *Prefix) MarshalJSON() ([]byte, error) { return json.Marshal(netip.Prefix(*p).String()) } @@ -55,6 +62,13 @@ func (p *Prefix) UnmarshalJSON(content []byte) error { type Prefixable netip.Prefix +func (p *Prefixable) Build(defaultPrefix netip.Prefix) netip.Prefix { + if p == nil { + return defaultPrefix + } + return netip.Prefix(*p) +} + func (p *Prefixable) MarshalJSON() ([]byte, error) { prefix := netip.Prefix(*p) if prefix.Bits() == prefix.Addr().BitLen() { From 9f69e7f9f7e2c4fc26828d98ec9f81ab72fdec98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 1 Dec 2024 20:19:37 +0800 Subject: [PATCH 34/55] E: IsClosedOrCanceled check IsTimeout --- common/exceptions/error.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/exceptions/error.go b/common/exceptions/error.go index 16b075a..3f08ac4 100644 --- a/common/exceptions/error.go +++ b/common/exceptions/error.go @@ -40,7 +40,7 @@ func Extend(cause error, message ...any) error { } func IsClosedOrCanceled(err error) bool { - return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, context.Canceled, context.DeadlineExceeded) + return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, context.Canceled, context.DeadlineExceeded) || IsTimeout(err) } func IsClosed(err error) bool { From 809d8eca139712f6c833cea813674a1cb1154ba5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 3 Dec 2024 11:50:46 +0800 Subject: [PATCH 35/55] freelru: fix PurgeExpired --- contrab/freelru/lru.go | 15 +++++++-------- contrab/freelru/lru_test.go | 28 +++++++++++++++++++--------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index 5642225..55e8937 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -687,20 +687,19 @@ func (lru *LRU[K, V]) Purge() { // PurgeExpired purges all expired items from the LRU. // The evict function is called for each expired item. func (lru *LRU[K, V]) PurgeExpired() { + n := now() +loop: l := lru.len if l == 0 { return } - n := now() - pos := lru.head + pos := lru.elements[lru.head].next for i := uint32(0); i < l; i++ { - next := lru.elements[pos].next - if lru.elements[pos].expire != 0 { - if lru.elements[pos].expire <= n { - lru.removeAt(pos) - } + if lru.elements[pos].expire != 0 && lru.elements[pos].expire <= n { + lru.removeAt(pos) + goto loop } - pos = next + pos = lru.elements[pos].next } } diff --git a/contrab/freelru/lru_test.go b/contrab/freelru/lru_test.go index 0b548b5..d3ea64a 100644 --- a/contrab/freelru/lru_test.go +++ b/contrab/freelru/lru_test.go @@ -1,6 +1,9 @@ package freelru_test import ( + "github.com/sagernet/sing/common" + F "github.com/sagernet/sing/common/format" + "math/rand/v2" "testing" "time" @@ -75,16 +78,23 @@ func TestUpdateLifetime2(t *testing.T) { require.False(t, ok) } -func TestPeekWithLifetime(t *testing.T) { +func TestPurgeExpired(t *testing.T) { t.Parallel() - lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) + lru, err := freelru.New[string, *string](1024, maphash.NewHasher[string]().Hash32) require.NoError(t, err) lru.SetLifetime(time.Second) - lru.AddWithLifetime("hello", "world", 10*time.Second) - lru.Add("hello1", "") - lru.Add("hello2", "") - lru.Add("hello3", "") - time.Sleep(2 * time.Second) - lru.PurgeExpired() - require.Equal(t, 1, lru.Len()) + lru.SetOnEvict(func(s string, s2 *string) { + if s2 == nil { + t.Fail() + } + }) + for i := 0; i < 100; i++ { + lru.AddWithLifetime("hello_"+F.ToString(i), common.Ptr("world_"+F.ToString(i)), time.Duration(rand.Int32N(3000))*time.Millisecond) + } + for i := 0; i < 5; i++ { + time.Sleep(time.Second) + lru.GetAndRefreshOrAdd("hellox"+F.ToString(i), func() (*string, bool) { + return common.Ptr("worldx"), true + }) + } } From 957166799ec39663c00da4f974d6627e5fd9d0ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 4 Dec 2024 17:14:58 +0800 Subject: [PATCH 36/55] Fix CloseOnHandshakeFailure --- common/network/handshake.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/common/network/handshake.go b/common/network/handshake.go index d2203e0..273b9e3 100644 --- a/common/network/handshake.go +++ b/common/network/handshake.go @@ -33,7 +33,7 @@ func ReportHandshakeFailure(reporter any, err error) error { return nil } -func CloseOnHandshakeFailure(reporter any, onClose CloseHandlerFunc, err error) error { +func CloseOnHandshakeFailure(reporter io.Closer, onClose CloseHandlerFunc, err error) error { if err != nil { if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn { hErr := handshakeConn.HandshakeFailure(err) @@ -51,12 +51,10 @@ func CloseOnHandshakeFailure(reporter any, onClose CloseHandlerFunc, err error) }](reporter); isTCPConn { tcpConn.SetLinger(0) } - if closer, isCloser := reporter.(io.Closer); isCloser { - err = E.Append(err, closer.Close(), func(err error) error { - return E.Cause(err, "close") - }) - } } + err = E.Append(err, reporter.Close(), func(err error) error { + return E.Cause(err, "close") + }) } if onClose != nil { onClose(err) From 73776cf797eec17d913b974dad5a55f9429a6a4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 10 Dec 2024 19:42:33 +0800 Subject: [PATCH 37/55] Fix lru test --- contrab/freelru/lru_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/contrab/freelru/lru_test.go b/contrab/freelru/lru_test.go index d3ea64a..3d4a328 100644 --- a/contrab/freelru/lru_test.go +++ b/contrab/freelru/lru_test.go @@ -1,12 +1,12 @@ package freelru_test import ( - "github.com/sagernet/sing/common" - F "github.com/sagernet/sing/common/format" - "math/rand/v2" + "math/rand" "testing" "time" + "github.com/sagernet/sing/common" + F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/contrab/freelru" "github.com/sagernet/sing/contrab/maphash" @@ -89,7 +89,7 @@ func TestPurgeExpired(t *testing.T) { } }) for i := 0; i < 100; i++ { - lru.AddWithLifetime("hello_"+F.ToString(i), common.Ptr("world_"+F.ToString(i)), time.Duration(rand.Int32N(3000))*time.Millisecond) + lru.AddWithLifetime("hello_"+F.ToString(i), common.Ptr("world_"+F.ToString(i)), time.Duration(rand.Intn(3000))*time.Millisecond) } for i := 0; i < 5; i++ { time.Sleep(time.Second) From 3374a45475c4a5926dd24930260d21f1aed3be71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 10 Dec 2024 19:42:37 +0800 Subject: [PATCH 38/55] Fix socks5 UDP implementation --- protocol/http/handshake.go | 67 +++++++---------------- protocol/socks/handshake.go | 103 +++++++++--------------------------- protocol/socks/lazy.go | 3 +- 3 files changed, 46 insertions(+), 127 deletions(-) diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 9528573..fd5817b 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -21,17 +21,14 @@ import ( "github.com/sagernet/sing/common/pipe" ) -// Deprecated: Use HandleConnectionEx instead. -func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, - //nolint:staticcheck - handler N.TCPConnectionHandler, metadata M.Metadata, -) error { - return HandleConnectionEx(ctx, conn, reader, authenticator, handler, nil, metadata.Source, nil) -} - -func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, - //nolint:staticcheck - handler N.TCPConnectionHandler, handlerEx N.TCPConnectionHandlerEx, source M.Socksaddr, onClose N.CloseHandlerFunc, +func HandleConnectionEx( + ctx context.Context, + conn net.Conn, + reader *std_bufio.Reader, + authenticator *auth.Authenticator, + handler N.TCPConnectionHandlerEx, + source M.Socksaddr, + onClose N.CloseHandlerFunc, ) error { for { request, err := ReadRequest(reader) @@ -105,13 +102,8 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re } else { requestConn = conn } - if handler != nil { - //nolint:staticcheck - return handler.NewConnection(ctx, requestConn, M.Metadata{Protocol: "http", Source: source, Destination: destination}) - } else { - handlerEx.NewConnectionEx(ctx, requestConn, source, destination, onClose) - return nil - } + handler.NewConnectionEx(ctx, requestConn, source, destination, onClose) + return nil } else if strings.ToLower(request.Header.Get("Connection")) == "upgrade" { destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port()) if destination.Port == 0 { @@ -124,19 +116,11 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re } serverConn, clientConn := pipe.Pipe() go func() { - if handler != nil { - //nolint:staticcheck - err := handler.NewConnection(ctx, clientConn, M.Metadata{Protocol: "http", Source: source, Destination: destination}) - if err != nil { + handler.NewConnectionEx(ctx, clientConn, source, destination, func(it error) { + if it != nil { common.Close(serverConn, clientConn) } - } else { - handlerEx.NewConnectionEx(ctx, clientConn, source, destination, func(it error) { - if it != nil { - common.Close(serverConn, clientConn) - } - }) - } + }) }() err = request.Write(serverConn) if err != nil { @@ -150,7 +134,7 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re } return bufio.CopyConn(ctx, conn, serverConn) } else { - err = handleHTTPConnection(ctx, handler, handlerEx, conn, request, source) + err = handleHTTPConnection(ctx, handler, conn, request, source) if err != nil { return err } @@ -160,9 +144,7 @@ func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Re func handleHTTPConnection( ctx context.Context, - //nolint:staticcheck - handler N.TCPConnectionHandler, - handlerEx N.TCPConnectionHandlerEx, + handler N.TCPConnectionHandlerEx, conn net.Conn, request *http.Request, source M.Socksaddr, ) error { @@ -188,21 +170,10 @@ func handleHTTPConnection( DisableCompression: true, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { input, output := pipe.Pipe() - if handler != nil { - go func() { - //nolint:staticcheck - hErr := handler.NewConnection(ctx, output, M.Metadata{Protocol: "http", Source: source, Destination: M.ParseSocksaddr(address)}) - if hErr != nil { - innerErr.Store(hErr) - common.Close(input, output) - } - }() - } else { - go handlerEx.NewConnectionEx(ctx, output, source, M.ParseSocksaddr(address), func(it error) { - innerErr.Store(it) - common.Close(input, output) - }) - } + go handler.NewConnectionEx(ctx, output, source, M.ParseSocksaddr(address), func(it error) { + innerErr.Store(it) + common.Close(input, output) + }) return input, nil }, }, diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 7232eea..8a0ff86 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -19,14 +19,6 @@ import ( "github.com/sagernet/sing/protocol/socks/socks5" ) -// Deprecated: Use HandlerEx instead. -// -//nolint:staticcheck -type Handler interface { - N.TCPConnectionHandler - N.UDPConnectionHandler -} - type HandlerEx interface { N.TCPConnectionHandlerEx N.UDPConnectionHandlerEx @@ -87,6 +79,26 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr, } else if authResponse.Method != socks5.AuthTypeNotRequired { return socks5.Response{}, E.New("socks5: unsupported auth method: ", authResponse.Method) } + + if command == socks5.CommandUDPAssociate { + if destination.Addr.IsPrivate() { + if destination.Addr.Is6() { + destination.Addr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) + } else { + destination.Addr = netip.IPv6Loopback() + } + } else if destination.Addr.IsGlobalUnicast() { + if destination.Addr.Is6() { + destination.Addr = netip.IPv6Unspecified() + } else { + destination.Addr = netip.IPv4Unspecified() + } + } else { + destination.Addr = netip.IPv6Unspecified() + } + destination.Port = 0 + } + err = socks5.WriteRequest(conn, socks5.Request{ Command: command, Destination: destination, @@ -104,23 +116,11 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr, return response, err } -// Deprecated: use HandleConnectionEx instead. -func HandleConnection(ctx context.Context, conn net.Conn, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { - return HandleConnection0(ctx, conn, std_bufio.NewReader(conn), authenticator, handler, metadata) -} - -// Deprecated: Use HandleConnectionEx instead. -func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { - return HandleConnectionEx(ctx, conn, reader, authenticator, handler, nil, metadata.Source, metadata.Destination, nil) -} - func HandleConnectionEx( ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, - //nolint:staticcheck - handler Handler, - handlerEx HandlerEx, - source M.Socksaddr, destination M.Socksaddr, + handler HandlerEx, + source M.Socksaddr, onClose N.CloseHandlerFunc, ) error { version, err := reader.ReadByte() @@ -145,20 +145,7 @@ func HandleConnectionEx( } return E.New("socks4: authentication failed, username=", request.Username) } - destination = request.Destination - if handlerEx != nil { - handlerEx.NewConnectionEx(auth.ContextWithUser(ctx, request.Username), NewLazyConn(conn, version), source, destination, onClose) - } else { - err = socks4.WriteResponse(conn, socks4.Response{ - ReplyCode: socks4.ReplyCodeGranted, - Destination: M.SocksaddrFromNet(conn.LocalAddr()), - }) - if err != nil { - return err - } - //nolint:staticcheck - return handler.NewConnection(auth.ContextWithUser(ctx, request.Username), conn, M.Metadata{Protocol: "socks4", Source: source, Destination: destination}) - } + handler.NewConnectionEx(auth.ContextWithUser(ctx, request.Username), NewLazyConn(conn, version), source, request.Destination, onClose) return nil default: err = socks4.WriteResponse(conn, socks4.Response{ @@ -223,53 +210,15 @@ func HandleConnectionEx( } switch request.Command { case socks5.CommandConnect: - destination = request.Destination - if handlerEx != nil { - handlerEx.NewConnectionEx(ctx, NewLazyConn(conn, version), source, destination, onClose) - return nil - } else { - err = socks5.WriteResponse(conn, socks5.Response{ - ReplyCode: socks5.ReplyCodeSuccess, - Bind: M.SocksaddrFromNet(conn.LocalAddr()), - }) - if err != nil { - return err - } - //nolint:staticcheck - return handler.NewConnection(ctx, conn, M.Metadata{Protocol: "socks5", Source: source, Destination: destination}) - } + handler.NewConnectionEx(ctx, NewLazyConn(conn, version), source, request.Destination, onClose) + return nil case socks5.CommandUDPAssociate: var udpConn *net.UDPConn udpConn, err = net.ListenUDP(M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), net.UDPAddrFromAddrPort(netip.AddrPortFrom(M.AddrFromNet(conn.LocalAddr()), 0))) if err != nil { return err } - if handlerEx == nil { - defer udpConn.Close() - err = socks5.WriteResponse(conn, socks5.Response{ - ReplyCode: socks5.ReplyCodeSuccess, - Bind: M.SocksaddrFromNet(udpConn.LocalAddr()), - }) - if err != nil { - return err - } - destination = request.Destination - associatePacketConn := NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), destination, conn) - var innerError error - done := make(chan struct{}) - go func() { - //nolint:staticcheck - innerError = handler.NewPacketConnection(ctx, associatePacketConn, M.Metadata{Protocol: "socks5", Source: source, Destination: destination}) - close(done) - }() - err = common.Error(io.Copy(io.Discard, conn)) - associatePacketConn.Close() - <-done - return E.Errors(innerError, err) - } else { - handlerEx.NewPacketConnectionEx(ctx, NewLazyAssociatePacketConn(bufio.NewServerPacketConn(udpConn), destination, conn), source, destination, onClose) - return nil - } + handler.NewPacketConnectionEx(ctx, NewLazyAssociatePacketConn(bufio.NewServerPacketConn(udpConn), conn), source, M.Socksaddr{}, onClose) default: err = socks5.WriteResponse(conn, socks5.Response{ ReplyCode: socks5.ReplyCodeUnsupported, diff --git a/protocol/socks/lazy.go b/protocol/socks/lazy.go index f98ac3d..28782bc 100644 --- a/protocol/socks/lazy.go +++ b/protocol/socks/lazy.go @@ -105,12 +105,11 @@ type LazyAssociatePacketConn struct { responseWritten bool } -func NewLazyAssociatePacketConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *LazyAssociatePacketConn { +func NewLazyAssociatePacketConn(conn net.Conn, underlying net.Conn) *LazyAssociatePacketConn { return &LazyAssociatePacketConn{ AssociatePacketConn: AssociatePacketConn{ AbstractConn: conn, conn: bufio.NewExtendedConn(conn), - remoteAddr: remoteAddr, underlying: underlying, }, } From 442cceb9fa07e84980d7c7ec5a4c8ac1553932d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 12 Dec 2024 20:40:01 +0800 Subject: [PATCH 39/55] Fix disable UDP fragment --- common/control/frag_darwin.go | 17 ++++++++++++----- common/control/frag_linux.go | 16 +++++++++------- common/control/frag_windows.go | 16 +++++++++------- 3 files changed, 30 insertions(+), 19 deletions(-) diff --git a/common/control/frag_darwin.go b/common/control/frag_darwin.go index 567a821..f76b241 100644 --- a/common/control/frag_darwin.go +++ b/common/control/frag_darwin.go @@ -4,19 +4,26 @@ import ( "os" "syscall" + N "github.com/sagernet/sing/common/network" + "golang.org/x/sys/unix" ) func DisableUDPFragment() Func { return func(network, address string, conn syscall.RawConn) error { + if N.NetworkName(network) != N.NetworkUDP { + return nil + } return Raw(conn, func(fd uintptr) error { - switch network { - case "udp4": - if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1); err != nil { + if network == "udp" || network == "udp4" { + err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1) + if err != nil { return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err) } - case "udp6": - if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1); err != nil { + } + if network == "udp" || network == "udp6" { + err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1) + if err != nil { return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err) } } diff --git a/common/control/frag_linux.go b/common/control/frag_linux.go index 5cb5fca..3bf9d57 100644 --- a/common/control/frag_linux.go +++ b/common/control/frag_linux.go @@ -11,17 +11,19 @@ import ( func DisableUDPFragment() Func { return func(network, address string, conn syscall.RawConn) error { - switch N.NetworkName(network) { - case N.NetworkUDP: - default: + if N.NetworkName(network) != N.NetworkUDP { return nil } return Raw(conn, func(fd uintptr) error { - if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil { - return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err) + if network == "udp" || network == "udp4" { + err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO) + if err != nil { + return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err) + } } - if network == "udp6" { - if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil { + if network == "udp" || network == "udp6" { + err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO) + if err != nil { return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err) } } diff --git a/common/control/frag_windows.go b/common/control/frag_windows.go index bf02948..8e9bb83 100644 --- a/common/control/frag_windows.go +++ b/common/control/frag_windows.go @@ -25,17 +25,19 @@ const ( func DisableUDPFragment() Func { return func(network, address string, conn syscall.RawConn) error { - switch N.NetworkName(network) { - case N.NetworkUDP: - default: + if N.NetworkName(network) != N.NetworkUDP { return nil } return Raw(conn, func(fd uintptr) error { - if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil { - return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err) + if network == "udp" || network == "udp4" { + err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO) + if err != nil { + return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err) + } } - if network == "udp6" { - if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil { + if network == "udp" || network == "udp6" { + err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO) + if err != nil { return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err) } } From 33beacc0538c8de18efd2e93892984c0027da9a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Dec 2024 18:16:15 +0800 Subject: [PATCH 40/55] Fix socks5 UDP handshake --- protocol/socks/handshake.go | 1 + 1 file changed, 1 insertion(+) diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 8a0ff86..080752a 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -219,6 +219,7 @@ func HandleConnectionEx( return err } handler.NewPacketConnectionEx(ctx, NewLazyAssociatePacketConn(bufio.NewServerPacketConn(udpConn), conn), source, M.Socksaddr{}, onClose) + return nil default: err = socks5.WriteResponse(conn, socks5.Response{ ReplyCode: socks5.ReplyCodeUnsupported, From aa7d2543a3ce2b3a639c6bf8cb8f241159f4c487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 14 Dec 2024 18:16:15 +0800 Subject: [PATCH 41/55] Fix errors usage --- common/baderror/baderror.go | 6 ++---- common/exceptions/inner.go | 22 ++++++---------------- common/exceptions/multi.go | 9 +-------- 3 files changed, 9 insertions(+), 28 deletions(-) diff --git a/common/baderror/baderror.go b/common/baderror/baderror.go index c5ab530..d90ddeb 100644 --- a/common/baderror/baderror.go +++ b/common/baderror/baderror.go @@ -2,11 +2,10 @@ package baderror import ( "context" + "errors" "io" "net" "strings" - - E "github.com/sagernet/sing/common/exceptions" ) func Contains(err error, msgList ...string) bool { @@ -22,8 +21,7 @@ func WrapH2(err error) error { if err == nil { return nil } - err = E.Unwrap(err) - if err == io.ErrUnexpectedEOF { + if errors.Is(err, io.ErrUnexpectedEOF) { return io.EOF } if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") { diff --git a/common/exceptions/inner.go b/common/exceptions/inner.go index 3799af9..58bcc5a 100644 --- a/common/exceptions/inner.go +++ b/common/exceptions/inner.go @@ -1,24 +1,14 @@ package exceptions -import "github.com/sagernet/sing/common" +import ( + "errors" -type HasInnerError interface { - Unwrap() error -} + "github.com/sagernet/sing/common" +) +// Deprecated: Use errors.Unwrap instead. func Unwrap(err error) error { - for { - inner, ok := err.(HasInnerError) - if !ok { - break - } - innerErr := inner.Unwrap() - if innerErr == nil { - break - } - err = innerErr - } - return err + return errors.Unwrap(err) } func Cast[T any](err error) (T, bool) { diff --git a/common/exceptions/multi.go b/common/exceptions/multi.go index 2cdec05..78fb3b6 100644 --- a/common/exceptions/multi.go +++ b/common/exceptions/multi.go @@ -63,12 +63,5 @@ func IsMulti(err error, targetList ...error) bool { return true } } - err = Unwrap(err) - multiErr, isMulti := err.(MultiError) - if !isMulti { - return false - } - return common.All(multiErr.Unwrap(), func(it error) bool { - return IsMulti(it, targetList...) - }) + return false } From be9840c70ff38c03b730b347c9d0f4d1f1e195ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 8 Jan 2025 11:04:24 +0800 Subject: [PATCH 42/55] listable: Fix incorrect unmarshaling of null to []T{null} --- common/json/badoption/listable.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/common/json/badoption/listable.go b/common/json/badoption/listable.go index df02217..c48df12 100644 --- a/common/json/badoption/listable.go +++ b/common/json/badoption/listable.go @@ -18,6 +18,9 @@ func (l Listable[T]) MarshalJSONContext(ctx context.Context) ([]byte, error) { } func (l *Listable[T]) UnmarshalJSONContext(ctx context.Context, content []byte) error { + if string(content) == "null" { + return nil + } var singleItem T err := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem) if err == nil { From 4dabb9be9740dbcd5f99f2f16b46d55eb9c29193 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 9 Jan 2025 15:57:55 +0800 Subject: [PATCH 43/55] freelru: Fix GetAndRefreshOrAdd --- common/udpnat2/service.go | 8 +++----- contrab/freelru/cache.go | 2 +- contrab/freelru/lru.go | 16 +++++++++------- contrab/freelru/shardedlru.go | 7 +++++-- contrab/freelru/syncedlru.go | 7 +++++-- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 07e786d..1402211 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -51,7 +51,7 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur } func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) { - conn, loaded := s.cache.GetAndRefreshOrAdd(source.AddrPort(), func() (*natConn, bool) { + conn, _, ok := s.cache.GetAndRefreshOrAdd(source.AddrPort(), func() (*natConn, bool) { ok, ctx, writer, onClose := s.prepare(source, destination, userData) if !ok { return nil, false @@ -67,10 +67,8 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati go s.handler.NewPacketConnectionEx(ctx, newConn, source, destination, onClose) return newConn, true }) - if !loaded { - if conn == nil { - return - } + if !ok { + return } buffer := conn.readWaitOptions.NewPacketBuffer() for _, bufferSlice := range bufferSlices { diff --git a/contrab/freelru/cache.go b/contrab/freelru/cache.go index 319e662..e1877fb 100644 --- a/contrab/freelru/cache.go +++ b/contrab/freelru/cache.go @@ -54,7 +54,7 @@ type Cache[K comparable, V comparable] interface { // The lifetime of the found cache item is refreshed, even if it was already expired. GetAndRefresh(key K) (V, bool) - GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool) + GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool, bool) // Peek looks up a key's value from the cache, without changing its recent-ness. // If the found entry is already expired, the evict function is called. diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index 55e8937..6c31857 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -522,11 +522,15 @@ func (lru *LRU[K, V]) getAndRefresh(hash uint32, key K) (value V, ok bool) { return } -func (lru *LRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool) { - return lru.getAndRefreshOrAdd(lru.hash(key), key, constructor) +func (lru *LRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool, bool) { + value, updated, ok := lru.getAndRefreshOrAdd(lru.hash(key), key, constructor) + if !updated && ok { + lru.PurgeExpired() + } + return value, updated, ok } -func (lru *LRU[K, V]) getAndRefreshOrAdd(hash uint32, key K, constructor func() (V, bool)) (value V, ok bool) { +func (lru *LRU[K, V]) getAndRefreshOrAdd(hash uint32, key K, constructor func() (V, bool)) (value V, updated bool, ok bool) { if pos, ok := lru.findKeyNoExpire(hash, key); ok { if pos != lru.head { lru.unlinkElement(pos) @@ -534,17 +538,15 @@ func (lru *LRU[K, V]) getAndRefreshOrAdd(hash uint32, key K, constructor func() } lru.metrics.Hits++ lru.elements[pos].expire = expire(lru.lifetime) - return lru.elements[pos].value, ok + return lru.elements[pos].value, true, true } - lru.metrics.Misses++ value, ok = constructor() if !ok { return } lru.addWithLifetime(hash, key, value, lru.lifetime) - lru.PurgeExpired() - return value, false + return value, false, true } // Peek looks up a key's value from the cache, without changing its recent-ness. diff --git a/contrab/freelru/shardedlru.go b/contrab/freelru/shardedlru.go index e6aca65..db97efa 100644 --- a/contrab/freelru/shardedlru.go +++ b/contrab/freelru/shardedlru.go @@ -201,14 +201,17 @@ func (lru *ShardedLRU[K, V]) GetAndRefresh(key K) (value V, ok bool) { return } -func (lru *ShardedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool) { +func (lru *ShardedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool, ok bool) { hash := lru.hash(key) shard := (hash >> 16) & lru.mask lru.mus[shard].Lock() - value, updated = lru.lrus[shard].getAndRefreshOrAdd(hash, key, constructor) + value, updated, ok = lru.lrus[shard].getAndRefreshOrAdd(hash, key, constructor) lru.mus[shard].Unlock() + if !updated && ok { + lru.PurgeExpired() + } return } diff --git a/contrab/freelru/syncedlru.go b/contrab/freelru/syncedlru.go index d42da70..38a854a 100644 --- a/contrab/freelru/syncedlru.go +++ b/contrab/freelru/syncedlru.go @@ -121,11 +121,14 @@ func (lru *SyncedLRU[K, V]) GetAndRefresh(key K) (value V, ok bool) { return } -func (lru *SyncedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool) { +func (lru *SyncedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool, ok bool) { hash := lru.lru.hash(key) lru.mu.Lock() - value, updated = lru.lru.getAndRefreshOrAdd(hash, key, constructor) + value, updated, ok = lru.lru.getAndRefreshOrAdd(hash, key, constructor) + if !updated && ok { + lru.lru.PurgeExpired() + } lru.mu.Unlock() return From d9f6eb136d177971b9cfc2ae528eb1fa95486e49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 9 Jan 2025 23:30:25 +0800 Subject: [PATCH 44/55] Fix set windows system time --- common/ntp/time_windows.go | 1 + 1 file changed, 1 insertion(+) diff --git a/common/ntp/time_windows.go b/common/ntp/time_windows.go index 5aab23e..82d8798 100644 --- a/common/ntp/time_windows.go +++ b/common/ntp/time_windows.go @@ -7,6 +7,7 @@ import ( ) func SetSystemTime(nowTime time.Time) error { + nowTime = nowTime.UTC() var systemTime windows.Systemtime systemTime.Year = uint16(nowTime.Year()) systemTime.Month = uint16(nowTime.Month()) From d8153df67f67e059a4def62c3904298b870e95db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 5 Feb 2025 09:47:22 +0800 Subject: [PATCH 45/55] Add ENOTCONN to IsClosed --- common/exceptions/error.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/exceptions/error.go b/common/exceptions/error.go index 3f08ac4..0f33792 100644 --- a/common/exceptions/error.go +++ b/common/exceptions/error.go @@ -40,11 +40,11 @@ func Extend(cause error, message ...any) error { } func IsClosedOrCanceled(err error) bool { - return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, context.Canceled, context.DeadlineExceeded) || IsTimeout(err) + return IsClosed(err) || IsCanceled(err) || IsTimeout(err) } func IsClosed(err error) bool { - return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET) + return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, syscall.ENOTCONN) } func IsCanceled(err error) bool { From 9eafc7fc62b10528df821cdfbf4e4e8f122f4b7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 10 Feb 2025 15:08:18 +0800 Subject: [PATCH 46/55] udpnat2: Fix crash --- common/udpnat2/service.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 1402211..f658fa2 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -27,7 +27,7 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur } var cache freelru.Cache[netip.AddrPort, *natConn] if !shared { - cache = common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + cache = common.Must1(freelru.NewSynced[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) } else { cache = common.Must1(freelru.NewSharded[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) } From d54716612ceee75a28d728e9c682bf6e926b2056 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 28 Feb 2025 12:06:56 +0800 Subject: [PATCH 47/55] Fix syscall packet read waiter for Windows --- common/bufio/copy_direct_windows.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/common/bufio/copy_direct_windows.go b/common/bufio/copy_direct_windows.go index 482b649..7da9213 100644 --- a/common/bufio/copy_direct_windows.go +++ b/common/bufio/copy_direct_windows.go @@ -120,16 +120,16 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions var readN int var from windows.Sockaddr readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0) + //goland:noinspection GoDirectComparisonOfErrors + if w.readErr != nil { + buffer.Release() + return w.readErr != windows.WSAEWOULDBLOCK + } if readN > 0 { buffer.Truncate(readN) - w.options.PostReturn(buffer) - w.buffer = buffer - } else { - buffer.Release() - } - if w.readErr == windows.WSAEWOULDBLOCK { - return false } + w.options.PostReturn(buffer) + w.buffer = buffer if from != nil { switch fromAddr := from.(type) { case *windows.SockaddrInet4: From b55d1c78b381f4ae51321d2b535d418da5b09086 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 9 Mar 2025 15:18:37 +0800 Subject: [PATCH 48/55] bufio: Add destination NAT packet conn --- common/bufio/nat.go | 62 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/common/bufio/nat.go b/common/bufio/nat.go index cafeb06..6e5ab64 100644 --- a/common/bufio/nat.go +++ b/common/bufio/nat.go @@ -30,6 +30,14 @@ func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.So } } +func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn { + return &destinationNATPacketConn{ + NetPacketConn: conn, + origin: origin, + destination: destination, + } +} + type unidirectionalNATPacketConn struct { N.NetPacketConn origin M.Socksaddr @@ -144,6 +152,60 @@ func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr { return c.destination.UDPAddr() } +type destinationNATPacketConn struct { + N.NetPacketConn + origin M.Socksaddr + destination M.Socksaddr +} + +func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.NetPacketConn.ReadFrom(p) + if err != nil { + return + } + if M.SocksaddrFromNet(addr) == c.origin { + addr = c.destination.UDPAddr() + } + return +} + +func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if M.SocksaddrFromNet(addr) == c.destination { + addr = c.origin.UDPAddr() + } + return c.NetPacketConn.WriteTo(p, addr) +} + +func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + destination, err = c.NetPacketConn.ReadPacket(buffer) + if err != nil { + return + } + if destination == c.origin { + destination = c.destination + } + return +} + +func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if destination == c.destination { + destination = c.origin + } + return c.NetPacketConn.WritePacket(buffer, destination) +} + +func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) { + c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) +} + +func (c *destinationNATPacketConn) Upstream() any { + return c.NetPacketConn +} + +func (c *destinationNATPacketConn) RemoteAddr() net.Addr { + return c.destination.UDPAddr() +} + func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr { destination.Port = 0 return destination From 2238a05966ae4ad930e79d88ea0c5a9c8e26f329 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 10 Feb 2025 18:59:17 +0800 Subject: [PATCH 49/55] Fix merge objects --- common/json/badjson/merge_objects.go | 10 +++---- common/json/internal/contextjson/keys.go | 20 ++++++++++++++ common/json/internal/contextjson/keys_test.go | 26 +++++++++++++++++++ 3 files changed, 50 insertions(+), 6 deletions(-) create mode 100644 common/json/internal/contextjson/keys.go create mode 100644 common/json/internal/contextjson/keys_test.go diff --git a/common/json/badjson/merge_objects.go b/common/json/badjson/merge_objects.go index fa6c2d4..5b23209 100644 --- a/common/json/badjson/merge_objects.go +++ b/common/json/badjson/merge_objects.go @@ -2,9 +2,11 @@ package badjson import ( "context" + "reflect" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" + cJSON "github.com/sagernet/sing/common/json/internal/contextjson" ) func MarshallObjects(objects ...any) ([]byte, error) { @@ -31,16 +33,12 @@ func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error } func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error { - parentContent, err := newJSONObject(ctx, parentObject) - if err != nil { - return err - } var content JSONObject - err = content.UnmarshalJSONContext(ctx, inputContent) + err := content.UnmarshalJSONContext(ctx, inputContent) if err != nil { return err } - for _, key := range parentContent.Keys() { + for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) { content.Remove(key) } if object == nil { diff --git a/common/json/internal/contextjson/keys.go b/common/json/internal/contextjson/keys.go new file mode 100644 index 0000000..589007f --- /dev/null +++ b/common/json/internal/contextjson/keys.go @@ -0,0 +1,20 @@ +package json + +import ( + "reflect" + + "github.com/sagernet/sing/common" +) + +func ObjectKeys(object reflect.Type) []string { + switch object.Kind() { + case reflect.Pointer: + return ObjectKeys(object.Elem()) + case reflect.Struct: + default: + panic("invalid non-struct input") + } + return common.Map(cachedTypeFields(object).list, func(field field) string { + return field.name + }) +} diff --git a/common/json/internal/contextjson/keys_test.go b/common/json/internal/contextjson/keys_test.go new file mode 100644 index 0000000..11c72cb --- /dev/null +++ b/common/json/internal/contextjson/keys_test.go @@ -0,0 +1,26 @@ +package json_test + +import ( + "reflect" + "testing" + + json "github.com/sagernet/sing/common/json/internal/contextjson" + + "github.com/stretchr/testify/require" +) + +type MyObject struct { + Hello string `json:"hello,omitempty"` + MyWorld + MyWorld2 string `json:"-"` +} + +type MyWorld struct { + World string `json:"world,omitempty"` +} + +func TestObjectKeys(t *testing.T) { + t.Parallel() + keys := json.ObjectKeys(reflect.TypeOf(&MyObject{})) + require.Equal(t, []string{"hello", "world"}, keys) +} From ce1b4851a451781a8075b29dc7fb0bc17d061652 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 16 Mar 2025 10:20:59 +0800 Subject: [PATCH 50/55] Fix socks5 UDP --- protocol/socks/handshake.go | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 080752a..73a16b8 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" + "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -216,9 +217,24 @@ func HandleConnectionEx( var udpConn *net.UDPConn udpConn, err = net.ListenUDP(M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), net.UDPAddrFromAddrPort(netip.AddrPortFrom(M.AddrFromNet(conn.LocalAddr()), 0))) if err != nil { - return err + return E.Cause(err, "socks5: listen udp") } - handler.NewPacketConnectionEx(ctx, NewLazyAssociatePacketConn(bufio.NewServerPacketConn(udpConn), conn), source, M.Socksaddr{}, onClose) + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFromNet(udpConn.LocalAddr()), + }) + if err != nil { + return E.Cause(err, "socks5: write response") + } + var socksPacketConn N.PacketConn = NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), M.Socksaddr{}, conn) + firstPacket := buf.NewPacket() + var destination M.Socksaddr + destination, err = socksPacketConn.ReadPacket(firstPacket) + if err != nil { + return E.Cause(err, "socks5: read first packet") + } + socksPacketConn = bufio.NewCachedPacketConn(socksPacketConn, firstPacket, destination) + handler.NewPacketConnectionEx(ctx, socksPacketConn, source, destination, onClose) return nil default: err = socks5.WriteResponse(conn, socks5.Response{ From 23b0180a1b7bbf0b702f73610a9b6193d225a2bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 24 Mar 2025 18:06:52 +0800 Subject: [PATCH 51/55] Fix crash on udpnat2 handler --- common/udpnat2/conn.go | 18 +++--------------- common/udpnat2/service.go | 7 +++++-- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 4ed7b74..114b446 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -28,6 +28,7 @@ type natConn struct { cache freelru.Cache[netip.AddrPort, *natConn] writer N.PacketWriter localAddr M.Socksaddr + handlerAccess sync.RWMutex handler N.UDPHandlerEx packetChan chan *N.PacketBuffer closeOnce sync.Once @@ -75,23 +76,10 @@ func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, } func (c *natConn) SetHandler(handler N.UDPHandlerEx) { - select { - case <-c.doneChan: - default: - } + c.handlerAccess.Lock() + defer c.handlerAccess.Unlock() c.handler = handler c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler) -fetch: - for { - select { - case packet := <-c.packetChan: - c.handler.NewPacketEx(packet.Buffer, packet.Destination) - N.PutPacketBuffer(packet) - continue fetch - default: - break fetch - } - } } func (c *natConn) Timeout() time.Duration { diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index f658fa2..3367bd7 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -74,8 +74,11 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati for _, bufferSlice := range bufferSlices { buffer.Write(bufferSlice) } - if conn.handler != nil { - conn.handler.NewPacketEx(buffer, destination) + conn.handlerAccess.RLock() + handler := conn.handler + conn.handlerAccess.RUnlock() + if handler != nil { + handler.NewPacketEx(buffer, destination) return } packet := N.NewPacketBuffer() From 2b41455f5ab42c3290cfeff27c66c5e8477d9208 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 26 Mar 2025 12:46:15 +0800 Subject: [PATCH 52/55] Fix udpnat2 handler again --- common/udpnat2/conn.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 114b446..e8987f4 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -77,9 +77,20 @@ func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, func (c *natConn) SetHandler(handler N.UDPHandlerEx) { c.handlerAccess.Lock() - defer c.handlerAccess.Unlock() c.handler = handler c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler) + c.handlerAccess.Unlock() +fetch: + for { + select { + case packet := <-c.packetChan: + c.handler.NewPacketEx(packet.Buffer, packet.Destination) + N.PutPacketBuffer(packet) + continue fetch + default: + break fetch + } + } } func (c *natConn) Timeout() time.Duration { From ea0ac932aeeb0469641723f688cfb58868af2ed0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 29 Jan 2025 19:59:59 +0800 Subject: [PATCH 53/55] Add winiphlpapi --- common/windnsapi/dnsapi_test.go | 6 +- common/winiphlpapi/helper.go | 217 +++++++++++++++++ common/winiphlpapi/iphlpapi.go | 313 +++++++++++++++++++++++++ common/winiphlpapi/iphlpapi_test.go | 90 +++++++ common/winiphlpapi/syscall_windows.go | 27 +++ common/winiphlpapi/zsyscall_windows.go | 131 +++++++++++ 6 files changed, 780 insertions(+), 4 deletions(-) create mode 100644 common/winiphlpapi/helper.go create mode 100644 common/winiphlpapi/iphlpapi.go create mode 100644 common/winiphlpapi/iphlpapi_test.go create mode 100644 common/winiphlpapi/syscall_windows.go create mode 100644 common/winiphlpapi/zsyscall_windows.go diff --git a/common/windnsapi/dnsapi_test.go b/common/windnsapi/dnsapi_test.go index adf582d..c5ea831 100644 --- a/common/windnsapi/dnsapi_test.go +++ b/common/windnsapi/dnsapi_test.go @@ -1,16 +1,14 @@ +//go:build windows + package windnsapi import ( - "runtime" "testing" "github.com/stretchr/testify/require" ) func TestDNSAPI(t *testing.T) { - if runtime.GOOS != "windows" { - t.SkipNow() - } t.Parallel() require.NoError(t, FlushResolverCache()) } diff --git a/common/winiphlpapi/helper.go b/common/winiphlpapi/helper.go new file mode 100644 index 0000000..6bd4e8f --- /dev/null +++ b/common/winiphlpapi/helper.go @@ -0,0 +1,217 @@ +//go:build windows + +package winiphlpapi + +import ( + "context" + "encoding/binary" + "net" + "net/netip" + "os" + "time" + "unsafe" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func LoadEStats() error { + err := modiphlpapi.Load() + if err != nil { + return err + } + err = procGetTcpTable.Find() + if err != nil { + return err + } + err = procGetTcp6Table.Find() + if err != nil { + return err + } + err = procGetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + err = procGetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + err = procSetPerTcpConnectionEStats.Find() + if err != nil { + return err + } + err = procSetPerTcp6ConnectionEStats.Find() + if err != nil { + return err + } + return nil +} + +func LoadExtendedTable() error { + err := modiphlpapi.Load() + if err != nil { + return err + } + err = procGetExtendedTcpTable.Find() + if err != nil { + return err + } + err = procGetExtendedUdpTable.Find() + if err != nil { + return err + } + return nil +} + +func FindPid(network string, source netip.AddrPort) (uint32, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + if source.Addr().Is4() { + tcpTable, err := GetExtendedTcpTable() + if err != nil { + return 0, err + } + for _, row := range tcpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } else { + tcpTable, err := GetExtendedTcp6Table() + if err != nil { + return 0, err + } + for _, row := range tcpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } + case N.NetworkUDP: + if source.Addr().Is4() { + udpTable, err := GetExtendedUdpTable() + if err != nil { + return 0, err + } + for _, row := range udpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } else { + udpTable, err := GetExtendedUdp6Table() + if err != nil { + return 0, err + } + for _, row := range udpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) { + return row.DwOwningPid, nil + } + } + } + } + return 0, E.New("process not found for ", source) +} + +func WriteAndWaitAck(ctx context.Context, conn net.Conn, payload []byte) error { + source := M.AddrPortFromNet(conn.LocalAddr()) + destination := M.AddrPortFromNet(conn.RemoteAddr()) + if source.Addr().Is4() { + tcpTable, err := GetTcpTable() + if err != nil { + return err + } + var tcpRow *MibTcpRow + for _, row := range tcpTable { + if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) || + destination == netip.AddrPortFrom(DwordToAddr(row.DwRemoteAddr), DwordToPort(row.DwRemotePort)) { + tcpRow = &row + break + } + } + if tcpRow == nil { + return E.New("row not found for: ", source) + } + err = SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: true, + }) + if err != nil { + return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err) + } + defer SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: false, + }) + _, err = conn.Write(payload) + if err != nil { + return err + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + eStstsSendBuffer, err := GetPerTcpConnectionEStatsSendBuffer(tcpRow) + if err != nil { + return err + } + if eStstsSendBuffer.CurRetxQueue == 0 { + return nil + } + time.Sleep(10 * time.Millisecond) + } + } else { + tcpTable, err := GetTcp6Table() + if err != nil { + return err + } + var tcpRow *MibTcp6Row + for _, row := range tcpTable { + if source == netip.AddrPortFrom(netip.AddrFrom16(row.LocalAddr), DwordToPort(row.LocalPort)) || + destination == netip.AddrPortFrom(netip.AddrFrom16(row.RemoteAddr), DwordToPort(row.RemotePort)) { + tcpRow = &row + break + } + } + if tcpRow == nil { + return E.New("row not found for: ", source) + } + err = SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: true, + }) + if err != nil { + return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err) + } + defer SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{ + EnableCollection: false, + }) + _, err = conn.Write(payload) + if err != nil { + return err + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + eStstsSendBuffer, err := GetPerTcp6ConnectionEStatsSendBuffer(tcpRow) + if err != nil { + return err + } + if eStstsSendBuffer.CurRetxQueue == 0 { + return nil + } + time.Sleep(10 * time.Millisecond) + } + } +} + +func DwordToAddr(addr uint32) netip.Addr { + return netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(&addr))) +} + +func DwordToPort(dword uint32) uint16 { + return binary.BigEndian.Uint16((*[4]byte)(unsafe.Pointer(&dword))[:]) +} diff --git a/common/winiphlpapi/iphlpapi.go b/common/winiphlpapi/iphlpapi.go new file mode 100644 index 0000000..74e5b90 --- /dev/null +++ b/common/winiphlpapi/iphlpapi.go @@ -0,0 +1,313 @@ +//go:build windows + +package winiphlpapi + +import ( + "errors" + "os" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + TcpTableBasicListener uint32 = iota + TcpTableBasicConnections + TcpTableBasicAll + TcpTableOwnerPidListener + TcpTableOwnerPidConnections + TcpTableOwnerPidAll + TcpTableOwnerModuleListener + TcpTableOwnerModuleConnections + TcpTableOwnerModuleAll +) + +const ( + UdpTableBasic uint32 = iota + UdpTableOwnerPid + UdpTableOwnerModule +) + +const ( + TcpConnectionEstatsSynOpts uint32 = iota + TcpConnectionEstatsData + TcpConnectionEstatsSndCong + TcpConnectionEstatsPath + TcpConnectionEstatsSendBuff + TcpConnectionEstatsRec + TcpConnectionEstatsObsRec + TcpConnectionEstatsBandwidth + TcpConnectionEstatsFineRtt + TcpConnectionEstatsMaximum +) + +type MibTcpTable struct { + DwNumEntries uint32 + Table [1]MibTcpRow +} + +type MibTcpRow struct { + DwState uint32 + DwLocalAddr uint32 + DwLocalPort uint32 + DwRemoteAddr uint32 + DwRemotePort uint32 +} + +type MibTcp6Table struct { + DwNumEntries uint32 + Table [1]MibTcp6Row +} + +type MibTcp6Row struct { + State uint32 + LocalAddr [16]byte + LocalScopeId uint32 + LocalPort uint32 + RemoteAddr [16]byte + RemoteScopeId uint32 + RemotePort uint32 +} + +type MibTcpTableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibTcpRowOwnerPid +} + +type MibTcpRowOwnerPid struct { + DwState uint32 + DwLocalAddr uint32 + DwLocalPort uint32 + DwRemoteAddr uint32 + DwRemotePort uint32 + DwOwningPid uint32 +} + +type MibTcp6TableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibTcp6RowOwnerPid +} + +type MibTcp6RowOwnerPid struct { + UcLocalAddr [16]byte + DwLocalScopeId uint32 + DwLocalPort uint32 + UcRemoteAddr [16]byte + DwRemoteScopeId uint32 + DwRemotePort uint32 + DwState uint32 + DwOwningPid uint32 +} + +type MibUdpTableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibUdpRowOwnerPid +} + +type MibUdpRowOwnerPid struct { + DwLocalAddr uint32 + DwLocalPort uint32 + DwOwningPid uint32 +} + +type MibUdp6TableOwnerPid struct { + DwNumEntries uint32 + Table [1]MibUdp6RowOwnerPid +} + +type MibUdp6RowOwnerPid struct { + UcLocalAddr [16]byte + DwLocalScopeId uint32 + DwLocalPort uint32 + DwOwningPid uint32 +} + +type TcpEstatsSendBufferRodV0 struct { + CurRetxQueue uint64 + MaxRetxQueue uint64 + CurAppWQueue uint64 + MaxAppWQueue uint64 +} + +type TcpEstatsSendBuffRwV0 struct { + EnableCollection bool +} + +const ( + offsetOfMibTcpTable = unsafe.Offsetof(MibTcpTable{}.Table) + offsetOfMibTcp6Table = unsafe.Offsetof(MibTcp6Table{}.Table) + offsetOfMibTcpTableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table) + offsetOfMibTcp6TableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table) + offsetOfMibUdpTableOwnerPid = unsafe.Offsetof(MibUdpTableOwnerPid{}.Table) + offsetOfMibUdp6TableOwnerPid = unsafe.Offsetof(MibUdp6TableOwnerPid{}.Table) + sizeOfTcpEstatsSendBuffRwV0 = unsafe.Sizeof(TcpEstatsSendBuffRwV0{}) + sizeOfTcpEstatsSendBufferRodV0 = unsafe.Sizeof(TcpEstatsSendBufferRodV0{}) +) + +func GetTcpTable() ([]MibTcpRow, error) { + var size uint32 + err := getTcpTable(nil, &size, false) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, err + } + for { + table := make([]byte, size) + err = getTcpTable(&table[0], &size, false) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, err + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcpRow)(unsafe.Pointer(&table[offsetOfMibTcpTable])), dwNumEntries), nil + } +} + +func GetTcp6Table() ([]MibTcp6Row, error) { + var size uint32 + err := getTcp6Table(nil, &size, false) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, err + } + for { + table := make([]byte, size) + err = getTcp6Table(&table[0], &size, false) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, err + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcp6Row)(unsafe.Pointer(&table[offsetOfMibTcp6Table])), dwNumEntries), nil + } +} + +func GetExtendedTcpTable() ([]MibTcpRowOwnerPid, error) { + var size uint32 + err := getExtendedTcpTable(nil, &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcpTableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedTcp6Table() ([]MibTcp6RowOwnerPid, error) { + var size uint32 + err := getExtendedTcpTable(nil, &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedTcpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibTcp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcp6TableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedUdpTable() ([]MibUdpRowOwnerPid, error) { + var size uint32 + err := getExtendedUdpTable(nil, &size, false, windows.AF_INET, UdpTableOwnerPid, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET, UdpTableOwnerPid, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibUdpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdpTableOwnerPid])), dwNumEntries), nil + } +} + +func GetExtendedUdp6Table() ([]MibUdp6RowOwnerPid, error) { + var size uint32 + err := getExtendedUdpTable(nil, &size, false, windows.AF_INET6, UdpTableOwnerPid, 0) + if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + for { + table := make([]byte, size) + err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET6, UdpTableOwnerPid, 0) + if err != nil { + if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + continue + } + return nil, os.NewSyscallError("GetExtendedUdpTable", err) + } + dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0]))) + return unsafe.Slice((*MibUdp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdp6TableOwnerPid])), dwNumEntries), nil + } +} + +func GetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow) (*TcpEstatsSendBufferRodV0, error) { + var rod TcpEstatsSendBufferRodV0 + err := getPerTcpConnectionEStats(row, + TcpConnectionEstatsSendBuff, + 0, + 0, + 0, + 0, + 0, + 0, + uintptr(unsafe.Pointer(&rod)), + 0, + uint64(sizeOfTcpEstatsSendBufferRodV0), + ) + if err != nil { + return nil, err + } + return &rod, nil +} + +func GetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row) (*TcpEstatsSendBufferRodV0, error) { + var rod TcpEstatsSendBufferRodV0 + err := getPerTcp6ConnectionEStats(row, + TcpConnectionEstatsSendBuff, + 0, + 0, + 0, + 0, + 0, + 0, + uintptr(unsafe.Pointer(&rod)), + 0, + uint64(sizeOfTcpEstatsSendBufferRodV0), + ) + if err != nil { + return nil, err + } + return &rod, nil +} + +func SetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow, rw *TcpEstatsSendBuffRwV0) error { + return setPerTcpConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0) +} + +func SetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row, rw *TcpEstatsSendBuffRwV0) error { + return setPerTcp6ConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0) +} diff --git a/common/winiphlpapi/iphlpapi_test.go b/common/winiphlpapi/iphlpapi_test.go new file mode 100644 index 0000000..5fc3b74 --- /dev/null +++ b/common/winiphlpapi/iphlpapi_test.go @@ -0,0 +1,90 @@ +//go:build windows + +package winiphlpapi_test + +import ( + "context" + "net" + "syscall" + "testing" + + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/winiphlpapi" + + "github.com/stretchr/testify/require" +) + +func TestFindPidTcp4(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidTcp6(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "[::1]:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidUdp4(t *testing.T) { + t.Parallel() + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestFindPidUdp6(t *testing.T) { + t.Parallel() + conn, err := net.ListenPacket("udp", "[::1]:0") + require.NoError(t, err) + defer conn.Close() + pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr())) + require.NoError(t, err) + require.Equal(t, uint32(syscall.Getpid()), pid) +} + +func TestWaitAck4(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello")) + require.NoError(t, err) +} + +func TestWaitAck6(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "[::1]:0") + require.NoError(t, err) + defer listener.Close() + go listener.Accept() + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello")) + require.NoError(t, err) +} diff --git a/common/winiphlpapi/syscall_windows.go b/common/winiphlpapi/syscall_windows.go new file mode 100644 index 0000000..f6aab14 --- /dev/null +++ b/common/winiphlpapi/syscall_windows.go @@ -0,0 +1,27 @@ +package winiphlpapi + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcptable +//sys getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcpTable + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcp6table +//sys getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcp6Table + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcpconnectionestats +//sys getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcpConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcp6connectionestats +//sys getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcp6ConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcpconnectionestats +//sys setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcpConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcp6connectionestats +//sys setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcp6ConnectionEStats + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable +//sys getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedTcpTable + +// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedudptable +//sys getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedUdpTable diff --git a/common/winiphlpapi/zsyscall_windows.go b/common/winiphlpapi/zsyscall_windows.go new file mode 100644 index 0000000..e5e9308 --- /dev/null +++ b/common/winiphlpapi/zsyscall_windows.go @@ -0,0 +1,131 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package winiphlpapi + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") + + procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable") + procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable") + procGetPerTcp6ConnectionEStats = modiphlpapi.NewProc("GetPerTcp6ConnectionEStats") + procGetPerTcpConnectionEStats = modiphlpapi.NewProc("GetPerTcpConnectionEStats") + procGetTcp6Table = modiphlpapi.NewProc("GetTcp6Table") + procGetTcpTable = modiphlpapi.NewProc("GetTcpTable") + procSetPerTcp6ConnectionEStats = modiphlpapi.NewProc("SetPerTcp6ConnectionEStats") + procSetPerTcpConnectionEStats = modiphlpapi.NewProc("SetPerTcpConnectionEStats") +) + +func getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) { + var _p0 uint32 + if bOrder { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procGetExtendedTcpTable.Addr(), 6, uintptr(unsafe.Pointer(pTcpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) { + var _p0 uint32 + if bOrder { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procGetExtendedUdpTable.Addr(), 6, uintptr(unsafe.Pointer(pUdpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) { + r0, _, _ := syscall.Syscall12(procGetPerTcp6ConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) { + r0, _, _ := syscall.Syscall12(procGetPerTcpConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) { + var _p0 uint32 + if order { + _p0 = 1 + } + r0, _, _ := syscall.Syscall(procGetTcp6Table.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) { + var _p0 uint32 + if order { + _p0 = 1 + } + r0, _, _ := syscall.Syscall(procGetTcpTable.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) { + r0, _, _ := syscall.Syscall6(procSetPerTcp6ConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) { + r0, _, _ := syscall.Syscall6(procSetPerTcpConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset)) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} From ea82ac275feb97afaf073cfd9118e5528a6afbf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 24 Feb 2025 18:25:10 +0800 Subject: [PATCH 54/55] Add freelru.GetWithLifetimeNoExpire --- contrab/freelru/cache.go | 2 ++ contrab/freelru/lru.go | 18 ++++++++++++++++++ contrab/freelru/shardedlru.go | 11 +++++++++++ contrab/freelru/syncedlru.go | 10 ++++++++++ 4 files changed, 41 insertions(+) diff --git a/contrab/freelru/cache.go b/contrab/freelru/cache.go index e1877fb..22488e0 100644 --- a/contrab/freelru/cache.go +++ b/contrab/freelru/cache.go @@ -49,6 +49,8 @@ type Cache[K comparable, V comparable] interface { GetWithLifetime(key K) (V, time.Time, bool) + GetWithLifetimeNoExpire(key K) (V, time.Time, bool) + // GetAndRefresh returns the value associated with the key, setting it as the most // recently used item. // The lifetime of the found cache item is refreshed, even if it was already expired. diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index 6c31857..055057c 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -500,6 +500,24 @@ func (lru *LRU[K, V]) getWithLifetime(hash uint32, key K) (value V, lifetime tim return } +func (lru *LRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) { + return lru.getWithLifetimeNoExpire(lru.hash(key), key) +} + +func (lru *LRU[K, V]) getWithLifetimeNoExpire(hash uint32, key K) (value V, lifetime time.Time, ok bool) { + if pos, ok := lru.findKeyNoExpire(hash, key); ok { + if pos != lru.head { + lru.unlinkElement(pos) + lru.setHead(pos) + } + lru.metrics.Hits++ + return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok + } + + lru.metrics.Misses++ + return +} + // GetAndRefresh returns the value associated with the key, setting it as the most // recently used item. // The lifetime of the found cache item is refreshed, even if it was already expired. diff --git a/contrab/freelru/shardedlru.go b/contrab/freelru/shardedlru.go index db97efa..fa325d7 100644 --- a/contrab/freelru/shardedlru.go +++ b/contrab/freelru/shardedlru.go @@ -187,6 +187,17 @@ func (lru *ShardedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time return } +func (lru *ShardedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.hash(key) + shard := (hash >> 16) & lru.mask + + lru.mus[shard].RLock() + value, lifetime, ok = lru.lrus[shard].getWithLifetimeNoExpire(hash, key) + lru.mus[shard].RUnlock() + + return +} + // GetAndRefresh returns the value associated with the key, setting it as the most // recently used item. // The lifetime of the found cache item is refreshed, even if it was already expired. diff --git a/contrab/freelru/syncedlru.go b/contrab/freelru/syncedlru.go index 38a854a..4364907 100644 --- a/contrab/freelru/syncedlru.go +++ b/contrab/freelru/syncedlru.go @@ -108,6 +108,16 @@ func (lru *SyncedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, return } +func (lru *SyncedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) { + hash := lru.lru.hash(key) + + lru.mu.Lock() + value, lifetime, ok = lru.lru.getWithLifetimeNoExpire(hash, key) + lru.mu.Unlock() + + return +} + // GetAndRefresh returns the value associated with the key, setting it as the most // recently used item. // The lifetime of the found cache item is refreshed, even if it was already expired. From d39c2c2fddfa195cb59f5b0569382eaaad1892e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 19 Mar 2025 20:00:11 +0800 Subject: [PATCH 55/55] socks: Add custom udp listener --- protocol/socks/handshake.go | 27 +++++- protocol/socks/handshake_tor.go | 146 ++++++++++++++++++++++++++++++++ 2 files changed, 171 insertions(+), 2 deletions(-) create mode 100644 protocol/socks/handshake_tor.go diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 73a16b8..dc9c057 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -25,6 +25,10 @@ type HandlerEx interface { N.UDPConnectionHandlerEx } +type PacketListener interface { + ListenPacket(listenConfig net.ListenConfig, ctx context.Context, network string, address string) (net.PacketConn, error) +} + func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, username string) (socks4.Response, error) { err := socks4.WriteRequest(conn, socks4.Request{ Command: command, @@ -121,6 +125,8 @@ func HandleConnectionEx( ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler HandlerEx, + packetListener PacketListener, + // resolver TorResolver, source M.Socksaddr, onClose N.CloseHandlerFunc, ) error { @@ -148,6 +154,11 @@ func HandleConnectionEx( } handler.NewConnectionEx(auth.ContextWithUser(ctx, request.Username), NewLazyConn(conn, version), source, request.Destination, onClose) return nil + /*case CommandTorResolve, CommandTorResolvePTR: + if resolver == nil { + return E.New("socks4: torsocks: commands not implemented") + } + return handleTorSocks4(ctx, conn, request, resolver)*/ default: err = socks4.WriteResponse(conn, socks4.Response{ ReplyCode: socks4.ReplyCodeRejectedOrFailed, @@ -214,8 +225,15 @@ func HandleConnectionEx( handler.NewConnectionEx(ctx, NewLazyConn(conn, version), source, request.Destination, onClose) return nil case socks5.CommandUDPAssociate: - var udpConn *net.UDPConn - udpConn, err = net.ListenUDP(M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), net.UDPAddrFromAddrPort(netip.AddrPortFrom(M.AddrFromNet(conn.LocalAddr()), 0))) + var ( + listenConfig net.ListenConfig + udpConn net.PacketConn + ) + if packetListener != nil { + udpConn, err = packetListener.ListenPacket(listenConfig, ctx, M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), M.SocksaddrFrom(M.AddrFromNet(conn.LocalAddr()), 0).String()) + } else { + udpConn, err = listenConfig.ListenPacket(ctx, M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), M.SocksaddrFrom(M.AddrFromNet(conn.LocalAddr()), 0).String()) + } if err != nil { return E.Cause(err, "socks5: listen udp") } @@ -236,6 +254,11 @@ func HandleConnectionEx( socksPacketConn = bufio.NewCachedPacketConn(socksPacketConn, firstPacket, destination) handler.NewPacketConnectionEx(ctx, socksPacketConn, source, destination, onClose) return nil + /*case CommandTorResolve, CommandTorResolvePTR: + if resolver == nil { + return E.New("socks4: torsocks: commands not implemented") + } + return handleTorSocks5(ctx, conn, request, resolver)*/ default: err = socks5.WriteResponse(conn, socks5.Response{ ReplyCode: socks5.ReplyCodeUnsupported, diff --git a/protocol/socks/handshake_tor.go b/protocol/socks/handshake_tor.go new file mode 100644 index 0000000..5d66322 --- /dev/null +++ b/protocol/socks/handshake_tor.go @@ -0,0 +1,146 @@ +package socks + +import ( + "context" + "net" + "net/netip" + "os" + "strings" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/protocol/socks/socks4" + "github.com/sagernet/sing/protocol/socks/socks5" +) + +const ( + CommandTorResolve byte = 0xF0 + CommandTorResolvePTR byte = 0xF1 +) + +type TorResolver interface { + LookupIP(ctx context.Context, host string) (netip.Addr, error) + LookupPTR(ctx context.Context, addr netip.Addr) (string, error) +} + +func handleTorSocks4(ctx context.Context, conn net.Conn, request socks4.Request, resolver TorResolver) error { + switch request.Command { + case CommandTorResolve: + if !request.Destination.IsFqdn() { + return E.New("socks4: torsocks: invalid destination") + } + ipAddr, err := resolver.LookupIP(ctx, request.Destination.Fqdn) + if err != nil { + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeRejectedOrFailed, + }) + if err != nil { + return err + } + return E.Cause(err, "socks4: torsocks: lookup failed for domain: ", request.Destination.Fqdn) + } + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeGranted, + Destination: M.SocksaddrFrom(ipAddr, 0), + }) + if err != nil { + return E.Cause(err, "socks4: torsocks: write response") + } + return nil + case CommandTorResolvePTR: + var ipAddr netip.Addr + if request.Destination.IsIP() { + ipAddr = request.Destination.Addr + } else if strings.HasSuffix(request.Destination.Fqdn, ".in-addr.arpa") { + ipAddr, _ = netip.ParseAddr(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".in-addr.arpa")]) + } else if strings.HasSuffix(request.Destination.Fqdn, ".ip6.arpa") { + ipAddr, _ = netip.ParseAddr(strings.ReplaceAll(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".ip6.arpa")], ".", ":")) + } + if !ipAddr.IsValid() { + return E.New("socks4: torsocks: invalid destination") + } + host, err := resolver.LookupPTR(ctx, ipAddr) + if err != nil { + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeRejectedOrFailed, + }) + if err != nil { + return err + } + return E.Cause(err, "socks4: torsocks: lookup PTR failed for ip: ", ipAddr) + } + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeGranted, + Destination: M.Socksaddr{ + Fqdn: host, + }, + }) + if err != nil { + return E.Cause(err, "socks4: torsocks: write response") + } + return nil + default: + return os.ErrInvalid + } +} + +func handleTorSocks5(ctx context.Context, conn net.Conn, request socks5.Request, resolver TorResolver) error { + switch request.Command { + case CommandTorResolve: + if !request.Destination.IsFqdn() { + return E.New("socks5: torsocks: invalid destination") + } + ipAddr, err := resolver.LookupIP(ctx, request.Destination.Fqdn) + if err != nil { + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeFailure, + }) + if err != nil { + return err + } + return E.Cause(err, "socks5: torsocks: lookup failed for domain: ", request.Destination.Fqdn) + } + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFrom(ipAddr, 0), + }) + if err != nil { + return E.Cause(err, "socks5: torsocks: write response") + } + return nil + case CommandTorResolvePTR: + var ipAddr netip.Addr + if request.Destination.IsIP() { + ipAddr = request.Destination.Addr + } else if strings.HasSuffix(request.Destination.Fqdn, ".in-addr.arpa") { + ipAddr, _ = netip.ParseAddr(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".in-addr.arpa")]) + } else if strings.HasSuffix(request.Destination.Fqdn, ".ip6.arpa") { + ipAddr, _ = netip.ParseAddr(strings.ReplaceAll(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".ip6.arpa")], ".", ":")) + } + if !ipAddr.IsValid() { + return E.New("socks5: torsocks: invalid destination") + } + host, err := resolver.LookupPTR(ctx, ipAddr) + if err != nil { + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeFailure, + }) + if err != nil { + return err + } + return E.Cause(err, "socks5: torsocks: lookup PTR failed for ip: ", ipAddr) + } + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.Socksaddr{ + Fqdn: host, + }, + }) + if err != nil { + return E.Cause(err, "socks5: torsocks: write response") + } + return nil + default: + return os.ErrInvalid + } +}