mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-05 21:07:41 +03:00
Compare commits
No commits in common. "dev" and "v0.5.0-rc.4" have entirely different histories.
dev
...
v0.5.0-rc.
77 changed files with 490 additions and 4510 deletions
|
@ -2,10 +2,11 @@ package baderror
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Contains(err error, msgList ...string) bool {
|
func Contains(err error, msgList ...string) bool {
|
||||||
|
@ -21,7 +22,8 @@ func WrapH2(err error) error {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
err = E.Unwrap(err)
|
||||||
|
if err == io.ErrUnexpectedEOF {
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {
|
if Contains(err, "client disconnected", "body closed by handler", "response body closed", "; CANCEL") {
|
||||||
|
|
|
@ -9,20 +9,19 @@ import (
|
||||||
|
|
||||||
type AddrConn struct {
|
type AddrConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
Source M.Socksaddr
|
M.Metadata
|
||||||
Destination M.Socksaddr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AddrConn) LocalAddr() net.Addr {
|
func (c *AddrConn) LocalAddr() net.Addr {
|
||||||
if c.Destination.IsValid() {
|
if c.Metadata.Destination.IsValid() {
|
||||||
return c.Destination.TCPAddr()
|
return c.Metadata.Destination.TCPAddr()
|
||||||
}
|
}
|
||||||
return c.Conn.LocalAddr()
|
return c.Conn.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AddrConn) RemoteAddr() net.Addr {
|
func (c *AddrConn) RemoteAddr() net.Addr {
|
||||||
if c.Source.IsValid() {
|
if c.Metadata.Source.IsValid() {
|
||||||
return c.Source.TCPAddr()
|
return c.Metadata.Source.TCPAddr()
|
||||||
}
|
}
|
||||||
return c.Conn.RemoteAddr()
|
return c.Conn.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
|
@ -184,12 +184,10 @@ func (c *CachedPacketConn) ReadCachedPacket() *N.PacketBuffer {
|
||||||
if buffer != nil {
|
if buffer != nil {
|
||||||
buffer.DecRef()
|
buffer.DecRef()
|
||||||
}
|
}
|
||||||
packet := N.NewPacketBuffer()
|
return &N.PacketBuffer{
|
||||||
*packet = N.PacketBuffer{
|
|
||||||
Buffer: buffer,
|
Buffer: buffer,
|
||||||
Destination: c.destination,
|
Destination: c.destination,
|
||||||
}
|
}
|
||||||
return packet
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CachedPacketConn) Upstream() any {
|
func (c *CachedPacketConn) Upstream() any {
|
||||||
|
|
|
@ -35,7 +35,14 @@ func (w *ExtendedUDPConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
|
||||||
|
|
||||||
func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
func (w *ExtendedUDPConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||||
defer buffer.Release()
|
defer buffer.Release()
|
||||||
return common.Error(w.UDPConn.WriteToUDPAddrPort(buffer.Bytes(), destination.AddrPort()))
|
if destination.IsFqdn() {
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", destination.String())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return common.Error(w.UDPConn.WriteTo(buffer.Bytes(), udpAddr))
|
||||||
|
}
|
||||||
|
return common.Error(w.UDPConn.WriteToUDP(buffer.Bytes(), destination.UDPAddr()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ExtendedUDPConn) Upstream() any {
|
func (w *ExtendedUDPConn) Upstream() any {
|
||||||
|
|
|
@ -29,35 +29,27 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
|
||||||
if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
||||||
cachedBuffer := cachedSrc.ReadCached()
|
cachedBuffer := cachedSrc.ReadCached()
|
||||||
if cachedBuffer != nil {
|
if cachedBuffer != nil {
|
||||||
dataLen := cachedBuffer.Len()
|
if !cachedBuffer.IsEmpty() {
|
||||||
_, err = destination.Write(cachedBuffer.Bytes())
|
_, err = destination.Write(cachedBuffer.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
cachedBuffer.Release()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
cachedBuffer.Release()
|
cachedBuffer.Release()
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, counter := range readCounters {
|
|
||||||
counter(int64(dataLen))
|
|
||||||
}
|
|
||||||
for _, counter := range writeCounters {
|
|
||||||
counter(int64(dataLen))
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break
|
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
|
||||||
}
|
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
|
||||||
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
|
if srcIsSyscall && dstIsSyscall {
|
||||||
}
|
var handled bool
|
||||||
|
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
||||||
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
if handled {
|
||||||
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
|
return
|
||||||
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
|
}
|
||||||
if srcIsSyscall && dstIsSyscall {
|
|
||||||
var handled bool
|
|
||||||
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
|
|
||||||
if handled {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
|
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
|
||||||
}
|
}
|
||||||
|
@ -83,7 +75,6 @@ func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N
|
||||||
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
|
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: not used
|
|
||||||
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||||
buffer.IncRef()
|
buffer.IncRef()
|
||||||
defer buffer.DecRef()
|
defer buffer.DecRef()
|
||||||
|
@ -122,10 +113,19 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
||||||
options := N.NewReadWaitOptions(source, destination)
|
frontHeadroom := N.CalculateFrontHeadroom(destination)
|
||||||
|
rearHeadroom := N.CalculateRearHeadroom(destination)
|
||||||
|
bufferSize := N.CalculateMTU(source, destination)
|
||||||
|
if bufferSize > 0 {
|
||||||
|
bufferSize += frontHeadroom + rearHeadroom
|
||||||
|
} else {
|
||||||
|
bufferSize = buf.BufferSize
|
||||||
|
}
|
||||||
var notFirstTime bool
|
var notFirstTime bool
|
||||||
for {
|
for {
|
||||||
buffer := options.NewBuffer()
|
buffer := buf.NewSize(bufferSize)
|
||||||
|
buffer.Resize(frontHeadroom, 0)
|
||||||
|
buffer.Reserve(rearHeadroom)
|
||||||
err = source.ReadBuffer(buffer)
|
err = source.ReadBuffer(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
|
@ -136,7 +136,7 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dataLen := buffer.Len()
|
dataLen := buffer.Len()
|
||||||
options.PostReturn(buffer)
|
buffer.OverCap(rearHeadroom)
|
||||||
err = destination.WriteBuffer(buffer)
|
err = destination.WriteBuffer(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Leak()
|
buffer.Leak()
|
||||||
|
@ -196,6 +196,18 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error
|
||||||
return group.Run(ctx)
|
return group.Run(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deprecated: not used
|
||||||
|
func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
|
||||||
|
switch len(contextList) {
|
||||||
|
case 0:
|
||||||
|
return CopyConn(context.Background(), source, destination)
|
||||||
|
case 1:
|
||||||
|
return CopyConn(contextList[0], source, destination)
|
||||||
|
default:
|
||||||
|
panic("invalid context list")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
|
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
|
||||||
var readCounters, writeCounters []N.CountFunc
|
var readCounters, writeCounters []N.CountFunc
|
||||||
var cachedPackets []*N.PacketBuffer
|
var cachedPackets []*N.PacketBuffer
|
||||||
|
@ -213,24 +225,24 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if cachedPackets != nil {
|
if cachedPackets != nil {
|
||||||
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
|
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
|
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||||
n += copeN
|
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
|
||||||
var (
|
var (
|
||||||
handled bool
|
handled bool
|
||||||
copeN int64
|
copeN int64
|
||||||
)
|
)
|
||||||
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
|
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
|
||||||
if isReadWaiter {
|
if isReadWaiter {
|
||||||
needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
|
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
|
||||||
|
FrontHeadroom: frontHeadroom,
|
||||||
|
RearHeadroom: rearHeadroom,
|
||||||
|
MTU: N.CalculateMTU(source, destinationConn),
|
||||||
|
})
|
||||||
if !needCopy || common.LowMemory {
|
if !needCopy || common.LowMemory {
|
||||||
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
|
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
|
||||||
if handled {
|
if handled {
|
||||||
|
@ -244,19 +256,28 @@ func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReade
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
|
||||||
options := N.NewReadWaitOptions(source, destination)
|
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||||
var destinationAddress M.Socksaddr
|
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||||
|
bufferSize := N.CalculateMTU(source, destinationConn)
|
||||||
|
if bufferSize > 0 {
|
||||||
|
bufferSize += frontHeadroom + rearHeadroom
|
||||||
|
} else {
|
||||||
|
bufferSize = buf.UDPBufferSize
|
||||||
|
}
|
||||||
|
var destination M.Socksaddr
|
||||||
for {
|
for {
|
||||||
buffer := options.NewPacketBuffer()
|
buffer := buf.NewSize(bufferSize)
|
||||||
destinationAddress, err = source.ReadPacket(buffer)
|
buffer.Resize(frontHeadroom, 0)
|
||||||
|
buffer.Reserve(rearHeadroom)
|
||||||
|
destination, err = source.ReadPacket(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dataLen := buffer.Len()
|
dataLen := buffer.Len()
|
||||||
options.PostReturn(buffer)
|
buffer.OverCap(rearHeadroom)
|
||||||
err = destination.WritePacket(buffer, destinationAddress)
|
err = destinationConn.WritePacket(buffer, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Leak()
|
buffer.Leak()
|
||||||
if !notFirstTime {
|
if !notFirstTime {
|
||||||
|
@ -264,25 +285,34 @@ func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter,
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
n += int64(dataLen)
|
||||||
for _, counter := range readCounters {
|
for _, counter := range readCounters {
|
||||||
counter(int64(dataLen))
|
counter(int64(dataLen))
|
||||||
}
|
}
|
||||||
for _, counter := range writeCounters {
|
for _, counter := range writeCounters {
|
||||||
counter(int64(dataLen))
|
counter(int64(dataLen))
|
||||||
}
|
}
|
||||||
n += int64(dataLen)
|
|
||||||
notFirstTime = true
|
notFirstTime = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
|
func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
|
||||||
options := N.NewReadWaitOptions(nil, destination)
|
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
|
||||||
|
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
|
||||||
var notFirstTime bool
|
var notFirstTime bool
|
||||||
for _, packetBuffer := range packetBuffers {
|
for _, packetBuffer := range packetBuffers {
|
||||||
buffer := options.Copy(packetBuffer.Buffer)
|
buffer := buf.NewPacket()
|
||||||
|
buffer.Resize(frontHeadroom, 0)
|
||||||
|
buffer.Reserve(rearHeadroom)
|
||||||
|
_, err = buffer.Write(packetBuffer.Buffer.Bytes())
|
||||||
|
packetBuffer.Buffer.Release()
|
||||||
|
if err != nil {
|
||||||
|
buffer.Release()
|
||||||
|
continue
|
||||||
|
}
|
||||||
dataLen := buffer.Len()
|
dataLen := buffer.Len()
|
||||||
err = destination.WritePacket(buffer, packetBuffer.Destination)
|
buffer.OverCap(rearHeadroom)
|
||||||
N.PutPacketBuffer(packetBuffer)
|
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Leak()
|
buffer.Leak()
|
||||||
if !notFirstTime {
|
if !notFirstTime {
|
||||||
|
@ -290,14 +320,7 @@ func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, counter := range readCounters {
|
|
||||||
counter(int64(dataLen))
|
|
||||||
}
|
|
||||||
for _, counter := range writeCounters {
|
|
||||||
counter(int64(dataLen))
|
|
||||||
}
|
|
||||||
n += int64(dataLen)
|
n += int64(dataLen)
|
||||||
notFirstTime = true
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -316,3 +339,15 @@ func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.Pack
|
||||||
group.FastFail()
|
group.FastFail()
|
||||||
return group.Run(ctx)
|
return group.Run(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deprecated: not used
|
||||||
|
func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
|
||||||
|
switch len(contextList) {
|
||||||
|
case 0:
|
||||||
|
return CopyPacketConn(context.Background(), source, destination)
|
||||||
|
case 1:
|
||||||
|
return CopyPacketConn(contextList[0], source, destination)
|
||||||
|
default:
|
||||||
|
panic("invalid context list")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -120,16 +120,16 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions
|
||||||
var readN int
|
var readN int
|
||||||
var from windows.Sockaddr
|
var from windows.Sockaddr
|
||||||
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
|
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
|
||||||
//goland:noinspection GoDirectComparisonOfErrors
|
|
||||||
if w.readErr != nil {
|
|
||||||
buffer.Release()
|
|
||||||
return w.readErr != windows.WSAEWOULDBLOCK
|
|
||||||
}
|
|
||||||
if readN > 0 {
|
if readN > 0 {
|
||||||
buffer.Truncate(readN)
|
buffer.Truncate(readN)
|
||||||
|
w.options.PostReturn(buffer)
|
||||||
|
w.buffer = buffer
|
||||||
|
} else {
|
||||||
|
buffer.Release()
|
||||||
|
}
|
||||||
|
if w.readErr == windows.WSAEWOULDBLOCK {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
w.options.PostReturn(buffer)
|
|
||||||
w.buffer = buffer
|
|
||||||
if from != nil {
|
if from != nil {
|
||||||
switch fromAddr := from.(type) {
|
switch fromAddr := from.(type) {
|
||||||
case *windows.SockaddrInet4:
|
case *windows.SockaddrInet4:
|
||||||
|
|
|
@ -30,14 +30,6 @@ func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.So
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDestinationNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
|
|
||||||
return &destinationNATPacketConn{
|
|
||||||
NetPacketConn: conn,
|
|
||||||
origin: origin,
|
|
||||||
destination: destination,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type unidirectionalNATPacketConn struct {
|
type unidirectionalNATPacketConn struct {
|
||||||
N.NetPacketConn
|
N.NetPacketConn
|
||||||
origin M.Socksaddr
|
origin M.Socksaddr
|
||||||
|
@ -152,60 +144,6 @@ func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
|
||||||
return c.destination.UDPAddr()
|
return c.destination.UDPAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
type destinationNATPacketConn struct {
|
|
||||||
N.NetPacketConn
|
|
||||||
origin M.Socksaddr
|
|
||||||
destination M.Socksaddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
||||||
n, addr, err = c.NetPacketConn.ReadFrom(p)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if M.SocksaddrFromNet(addr) == c.origin {
|
|
||||||
addr = c.destination.UDPAddr()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|
||||||
if M.SocksaddrFromNet(addr) == c.destination {
|
|
||||||
addr = c.origin.UDPAddr()
|
|
||||||
}
|
|
||||||
return c.NetPacketConn.WriteTo(p, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
|
||||||
destination, err = c.NetPacketConn.ReadPacket(buffer)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if destination == c.origin {
|
|
||||||
destination = c.destination
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
|
||||||
if destination == c.destination {
|
|
||||||
destination = c.origin
|
|
||||||
}
|
|
||||||
return c.NetPacketConn.WritePacket(buffer, destination)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
|
|
||||||
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) Upstream() any {
|
|
||||||
return c.NetPacketConn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *destinationNATPacketConn) RemoteAddr() net.Addr {
|
|
||||||
return c.destination.UDPAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
|
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
|
||||||
destination.Port = 0
|
destination.Port = 0
|
||||||
return destination
|
return destination
|
||||||
|
|
|
@ -38,6 +38,7 @@ func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
|
||||||
var innerErr unix.Errno
|
var innerErr unix.Errno
|
||||||
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
err := w.rawConn.Write(func(fd uintptr) (done bool) {
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
|
//goland:noinspection GoDeprecation
|
||||||
_, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList)))
|
_, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList)))
|
||||||
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
|
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
|
||||||
})
|
})
|
||||||
|
|
|
@ -41,9 +41,9 @@ func (i *Instance) Timeout() time.Duration {
|
||||||
return i.timeout
|
return i.timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Instance) SetTimeout(timeout time.Duration) bool {
|
func (i *Instance) SetTimeout(timeout time.Duration) {
|
||||||
i.timeout = timeout
|
i.timeout = timeout
|
||||||
return i.Update()
|
i.Update()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Instance) wait() {
|
func (i *Instance) wait() {
|
||||||
|
|
|
@ -13,7 +13,7 @@ import (
|
||||||
type PacketConn interface {
|
type PacketConn interface {
|
||||||
N.PacketConn
|
N.PacketConn
|
||||||
Timeout() time.Duration
|
Timeout() time.Duration
|
||||||
SetTimeout(timeout time.Duration) bool
|
SetTimeout(timeout time.Duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
type TimerPacketConn struct {
|
type TimerPacketConn struct {
|
||||||
|
@ -24,12 +24,10 @@ type TimerPacketConn struct {
|
||||||
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
|
func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
|
||||||
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
|
if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
|
||||||
oldTimeout := timeoutConn.Timeout()
|
oldTimeout := timeoutConn.Timeout()
|
||||||
if oldTimeout > 0 && timeout >= oldTimeout {
|
if timeout < oldTimeout {
|
||||||
return ctx, conn
|
timeoutConn.SetTimeout(timeout)
|
||||||
}
|
|
||||||
if timeoutConn.SetTimeout(timeout) {
|
|
||||||
return ctx, conn
|
|
||||||
}
|
}
|
||||||
|
return ctx, conn
|
||||||
}
|
}
|
||||||
err := conn.SetReadDeadline(time.Time{})
|
err := conn.SetReadDeadline(time.Time{})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -60,8 +58,8 @@ func (c *TimerPacketConn) Timeout() time.Duration {
|
||||||
return c.instance.Timeout()
|
return c.instance.Timeout()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) bool {
|
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) {
|
||||||
return c.instance.SetTimeout(timeout)
|
c.instance.SetTimeout(timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TimerPacketConn) Close() error {
|
func (c *TimerPacketConn) Close() error {
|
||||||
|
|
|
@ -61,9 +61,9 @@ func (c *TimeoutPacketConn) Timeout() time.Duration {
|
||||||
return c.timeout
|
return c.timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) bool {
|
func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
|
||||||
c.timeout = timeout
|
c.timeout = timeout
|
||||||
return c.PacketConn.SetReadDeadline(time.Now()) == nil
|
c.PacketConn.SetReadDeadline(time.Now())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TimeoutPacketConn) Close() error {
|
func (c *TimeoutPacketConn) Close() error {
|
||||||
|
|
|
@ -157,18 +157,6 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
func Equal[S ~[]E, E comparable](s1, s2 S) bool {
|
|
||||||
if len(s1) != len(s2) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i := range s1 {
|
|
||||||
if s1[i] != s2[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:norace
|
//go:norace
|
||||||
func Dup[T any](obj T) T {
|
func Dup[T any](obj T) T {
|
||||||
pointer := uintptr(unsafe.Pointer(&obj))
|
pointer := uintptr(unsafe.Pointer(&obj))
|
||||||
|
@ -280,14 +268,6 @@ func Reverse[T any](arr []T) []T {
|
||||||
return arr
|
return arr
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReverseMap[K comparable, V comparable](m map[K]V) map[V]K {
|
|
||||||
ret := make(map[V]K, len(m))
|
|
||||||
for k, v := range m {
|
|
||||||
ret[v] = k
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func Done(ctx context.Context) bool {
|
func Done(ctx context.Context) bool {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
@ -382,3 +362,24 @@ func Close(closers ...any) error {
|
||||||
}
|
}
|
||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deprecated: wtf is this?
|
||||||
|
type Starter interface {
|
||||||
|
Start() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: wtf is this?
|
||||||
|
func Start(starters ...any) error {
|
||||||
|
for _, rawStarter := range starters {
|
||||||
|
if rawStarter == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if starter, isStarter := rawStarter.(Starter); isStarter {
|
||||||
|
err := starter.Start()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -9,15 +9,15 @@ import (
|
||||||
|
|
||||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||||
return Raw(conn, func(fd uintptr) error {
|
return Raw(conn, func(fd uintptr) error {
|
||||||
|
var err error
|
||||||
if interfaceIndex == -1 {
|
if interfaceIndex == -1 {
|
||||||
if finder == nil {
|
if finder == nil {
|
||||||
return os.ErrInvalid
|
return os.ErrInvalid
|
||||||
}
|
}
|
||||||
iif, err := finder.ByName(interfaceName)
|
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
interfaceIndex = iif.Index
|
|
||||||
}
|
}
|
||||||
switch network {
|
switch network {
|
||||||
case "tcp6", "udp6":
|
case "tcp6", "udp6":
|
||||||
|
|
|
@ -3,57 +3,20 @@ package control
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type InterfaceFinder interface {
|
type InterfaceFinder interface {
|
||||||
Update() error
|
Update() error
|
||||||
Interfaces() []Interface
|
Interfaces() []Interface
|
||||||
ByName(name string) (*Interface, error)
|
InterfaceIndexByName(name string) (int, error)
|
||||||
ByIndex(index int) (*Interface, error)
|
InterfaceNameByIndex(index int) (string, error)
|
||||||
ByAddr(addr netip.Addr) (*Interface, error)
|
InterfaceByAddr(addr netip.Addr) (*Interface, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
Index int
|
Index int
|
||||||
MTU int
|
MTU int
|
||||||
Name string
|
Name string
|
||||||
HardwareAddr net.HardwareAddr
|
|
||||||
Flags net.Flags
|
|
||||||
Addresses []netip.Prefix
|
Addresses []netip.Prefix
|
||||||
}
|
HardwareAddr net.HardwareAddr
|
||||||
|
|
||||||
func (i Interface) Equals(other Interface) bool {
|
|
||||||
return i.Index == other.Index &&
|
|
||||||
i.MTU == other.MTU &&
|
|
||||||
i.Name == other.Name &&
|
|
||||||
common.Equal(i.HardwareAddr, other.HardwareAddr) &&
|
|
||||||
i.Flags == other.Flags &&
|
|
||||||
common.Equal(i.Addresses, other.Addresses)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i Interface) NetInterface() net.Interface {
|
|
||||||
return *(*net.Interface)(unsafe.Pointer(&i))
|
|
||||||
}
|
|
||||||
|
|
||||||
func InterfaceFromNet(iif net.Interface) (Interface, error) {
|
|
||||||
ifAddrs, err := iif.Addrs()
|
|
||||||
if err != nil {
|
|
||||||
return Interface{}, err
|
|
||||||
}
|
|
||||||
return InterfaceFromNetAddrs(iif, common.Map(ifAddrs, M.PrefixFromNet)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func InterfaceFromNetAddrs(iif net.Interface, addresses []netip.Prefix) Interface {
|
|
||||||
return Interface{
|
|
||||||
Index: iif.Index,
|
|
||||||
MTU: iif.MTU,
|
|
||||||
Name: iif.Name,
|
|
||||||
HardwareAddr: iif.HardwareAddr,
|
|
||||||
Flags: iif.Flags,
|
|
||||||
Addresses: addresses,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,8 +3,11 @@ package control
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
_ "unsafe"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
|
var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
|
||||||
|
@ -24,12 +27,17 @@ func (f *DefaultInterfaceFinder) Update() error {
|
||||||
}
|
}
|
||||||
interfaces := make([]Interface, 0, len(netIfs))
|
interfaces := make([]Interface, 0, len(netIfs))
|
||||||
for _, netIf := range netIfs {
|
for _, netIf := range netIfs {
|
||||||
var iif Interface
|
ifAddrs, err := netIf.Addrs()
|
||||||
iif, err = InterfaceFromNet(netIf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
interfaces = append(interfaces, iif)
|
interfaces = append(interfaces, Interface{
|
||||||
|
Index: netIf.Index,
|
||||||
|
MTU: netIf.MTU,
|
||||||
|
Name: netIf.Name,
|
||||||
|
Addresses: common.Map(ifAddrs, M.PrefixFromNet),
|
||||||
|
HardwareAddr: netIf.HardwareAddr,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
f.interfaces = interfaces
|
f.interfaces = interfaces
|
||||||
return nil
|
return nil
|
||||||
|
@ -43,41 +51,46 @@ func (f *DefaultInterfaceFinder) Interfaces() []Interface {
|
||||||
return f.interfaces
|
return f.interfaces
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DefaultInterfaceFinder) ByName(name string) (*Interface, error) {
|
func (f *DefaultInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
|
||||||
for _, netInterface := range f.interfaces {
|
for _, netInterface := range f.interfaces {
|
||||||
if netInterface.Name == name {
|
if netInterface.Name == name {
|
||||||
return &netInterface, nil
|
return netInterface.Index, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err := net.InterfaceByName(name)
|
netInterface, err := net.InterfaceByName(name)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
err = f.Update()
|
return 0, err
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return f.ByName(name)
|
|
||||||
}
|
}
|
||||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
|
f.Update()
|
||||||
|
return netInterface.Index, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DefaultInterfaceFinder) ByIndex(index int) (*Interface, error) {
|
func (f *DefaultInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
|
||||||
for _, netInterface := range f.interfaces {
|
for _, netInterface := range f.interfaces {
|
||||||
if netInterface.Index == index {
|
if netInterface.Index == index {
|
||||||
return &netInterface, nil
|
return netInterface.Name, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err := net.InterfaceByIndex(index)
|
netInterface, err := net.InterfaceByIndex(index)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
err = f.Update()
|
return "", err
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return f.ByIndex(index)
|
|
||||||
}
|
}
|
||||||
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
|
f.Update()
|
||||||
|
return netInterface.Name, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) {
|
func (f *DefaultInterfaceFinder) InterfaceByAddr(addr netip.Addr) (*Interface, error) {
|
||||||
|
for _, netInterface := range f.interfaces {
|
||||||
|
for _, prefix := range netInterface.Addresses {
|
||||||
|
if prefix.Contains(addr) {
|
||||||
|
return &netInterface, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := f.Update()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
for _, netInterface := range f.interfaces {
|
for _, netInterface := range f.interfaces {
|
||||||
for _, prefix := range netInterface.Addresses {
|
for _, prefix := range netInterface.Addresses {
|
||||||
if prefix.Contains(addr) {
|
if prefix.Contains(addr) {
|
||||||
|
|
|
@ -19,11 +19,11 @@ func bindToInterface(conn syscall.RawConn, network string, address string, finde
|
||||||
if interfaceName == "" {
|
if interfaceName == "" {
|
||||||
return os.ErrInvalid
|
return os.ErrInvalid
|
||||||
}
|
}
|
||||||
iif, err := finder.ByName(interfaceName)
|
var err error
|
||||||
|
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
interfaceIndex = iif.Index
|
|
||||||
}
|
}
|
||||||
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
|
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
|
@ -11,19 +11,19 @@ import (
|
||||||
|
|
||||||
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
|
||||||
return Raw(conn, func(fd uintptr) error {
|
return Raw(conn, func(fd uintptr) error {
|
||||||
|
var err error
|
||||||
if interfaceIndex == -1 {
|
if interfaceIndex == -1 {
|
||||||
if finder == nil {
|
if finder == nil {
|
||||||
return os.ErrInvalid
|
return os.ErrInvalid
|
||||||
}
|
}
|
||||||
iif, err := finder.ByName(interfaceName)
|
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
interfaceIndex = iif.Index
|
|
||||||
}
|
}
|
||||||
handle := syscall.Handle(fd)
|
handle := syscall.Handle(fd)
|
||||||
if M.ParseSocksaddr(address).AddrString() == "" {
|
if M.ParseSocksaddr(address).AddrString() == "" {
|
||||||
err := bind4(handle, interfaceIndex)
|
err = bind4(handle, interfaceIndex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,26 +4,19 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
func DisableUDPFragment() Func {
|
func DisableUDPFragment() Func {
|
||||||
return func(network, address string, conn syscall.RawConn) error {
|
return func(network, address string, conn syscall.RawConn) error {
|
||||||
if N.NetworkName(network) != N.NetworkUDP {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return Raw(conn, func(fd uintptr) error {
|
return Raw(conn, func(fd uintptr) error {
|
||||||
if network == "udp" || network == "udp4" {
|
switch network {
|
||||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1)
|
case "udp4":
|
||||||
if err != nil {
|
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_DONTFRAG, 1); err != nil {
|
||||||
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
|
return os.NewSyscallError("SETSOCKOPT IP_DONTFRAG", err)
|
||||||
}
|
}
|
||||||
}
|
case "udp6":
|
||||||
if network == "udp" || network == "udp6" {
|
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1); err != nil {
|
||||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG, 1)
|
|
||||||
if err != nil {
|
|
||||||
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
|
return os.NewSyscallError("SETSOCKOPT IPV6_DONTFRAG", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,19 +11,17 @@ import (
|
||||||
|
|
||||||
func DisableUDPFragment() Func {
|
func DisableUDPFragment() Func {
|
||||||
return func(network, address string, conn syscall.RawConn) error {
|
return func(network, address string, conn syscall.RawConn) error {
|
||||||
if N.NetworkName(network) != N.NetworkUDP {
|
switch N.NetworkName(network) {
|
||||||
|
case N.NetworkUDP:
|
||||||
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return Raw(conn, func(fd uintptr) error {
|
return Raw(conn, func(fd uintptr) error {
|
||||||
if network == "udp" || network == "udp4" {
|
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
|
||||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
|
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||||
if err != nil {
|
|
||||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if network == "udp" || network == "udp6" {
|
if network == "udp6" {
|
||||||
err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
|
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO); err != nil {
|
||||||
if err != nil {
|
|
||||||
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,19 +25,17 @@ const (
|
||||||
|
|
||||||
func DisableUDPFragment() Func {
|
func DisableUDPFragment() Func {
|
||||||
return func(network, address string, conn syscall.RawConn) error {
|
return func(network, address string, conn syscall.RawConn) error {
|
||||||
if N.NetworkName(network) != N.NetworkUDP {
|
switch N.NetworkName(network) {
|
||||||
|
case N.NetworkUDP:
|
||||||
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return Raw(conn, func(fd uintptr) error {
|
return Raw(conn, func(fd uintptr) error {
|
||||||
if network == "udp" || network == "udp4" {
|
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
|
||||||
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO)
|
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||||
if err != nil {
|
|
||||||
return os.NewSyscallError("SETSOCKOPT IP_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if network == "udp" || network == "udp6" {
|
if network == "udp6" {
|
||||||
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO)
|
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
|
||||||
if err != nil {
|
|
||||||
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
return os.NewSyscallError("SETSOCKOPT IPV6_MTU_DISCOVER IP_PMTUDISC_DO", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,16 +12,3 @@ func (e *causeError) Error() string {
|
||||||
func (e *causeError) Unwrap() error {
|
func (e *causeError) Unwrap() error {
|
||||||
return e.cause
|
return e.cause
|
||||||
}
|
}
|
||||||
|
|
||||||
type causeError1 struct {
|
|
||||||
error
|
|
||||||
cause error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *causeError1) Error() string {
|
|
||||||
return e.error.Error() + ": " + e.cause.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *causeError1) Unwrap() []error {
|
|
||||||
return []error{e.error, e.cause}
|
|
||||||
}
|
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
F "github.com/sagernet/sing/common/format"
|
F "github.com/sagernet/sing/common/format"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
type Handler interface {
|
type Handler interface {
|
||||||
NewError(ctx context.Context, err error)
|
NewError(ctx context.Context, err error)
|
||||||
}
|
}
|
||||||
|
@ -32,13 +31,6 @@ func Cause(cause error, message ...any) error {
|
||||||
return &causeError{F.ToString(message...), cause}
|
return &causeError{F.ToString(message...), cause}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Cause1(err error, cause error) error {
|
|
||||||
if cause == nil {
|
|
||||||
panic("cause on an nil error")
|
|
||||||
}
|
|
||||||
return &causeError1{err, cause}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Extend(cause error, message ...any) error {
|
func Extend(cause error, message ...any) error {
|
||||||
if cause == nil {
|
if cause == nil {
|
||||||
panic("extend on an nil error")
|
panic("extend on an nil error")
|
||||||
|
@ -47,11 +39,11 @@ func Extend(cause error, message ...any) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsClosedOrCanceled(err error) bool {
|
func IsClosedOrCanceled(err error) bool {
|
||||||
return IsClosed(err) || IsCanceled(err) || IsTimeout(err)
|
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, context.Canceled, context.DeadlineExceeded)
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsClosed(err error) bool {
|
func IsClosed(err error) bool {
|
||||||
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET, syscall.ENOTCONN)
|
return IsMulti(err, io.EOF, net.ErrClosed, io.ErrClosedPipe, os.ErrClosed, syscall.EPIPE, syscall.ECONNRESET)
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsCanceled(err error) bool {
|
func IsCanceled(err error) bool {
|
||||||
|
|
|
@ -1,14 +1,24 @@
|
||||||
package exceptions
|
package exceptions
|
||||||
|
|
||||||
import (
|
import "github.com/sagernet/sing/common"
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
type HasInnerError interface {
|
||||||
)
|
Unwrap() error
|
||||||
|
}
|
||||||
|
|
||||||
// Deprecated: Use errors.Unwrap instead.
|
|
||||||
func Unwrap(err error) error {
|
func Unwrap(err error) error {
|
||||||
return errors.Unwrap(err)
|
for {
|
||||||
|
inner, ok := err.(HasInnerError)
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
innerErr := inner.Unwrap()
|
||||||
|
if innerErr == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = innerErr
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func Cast[T any](err error) (T, bool) {
|
func Cast[T any](err error) (T, bool) {
|
||||||
|
|
|
@ -63,5 +63,12 @@ func IsMulti(err error, targetList ...error) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
err = Unwrap(err)
|
||||||
|
multiErr, isMulti := err.(MultiError)
|
||||||
|
if !isMulti {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return common.All(multiErr.Unwrap(), func(it error) bool {
|
||||||
|
return IsMulti(it, targetList...)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ type TimeoutError interface {
|
||||||
func IsTimeout(err error) bool {
|
func IsTimeout(err error) bool {
|
||||||
var netErr net.Error
|
var netErr net.Error
|
||||||
if errors.As(err, &netErr) {
|
if errors.As(err, &netErr) {
|
||||||
|
//goland:noinspection GoDeprecation
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
return netErr.Temporary() && netErr.Timeout()
|
return netErr.Temporary() && netErr.Timeout()
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,14 +2,13 @@ package badjson
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/json"
|
"github.com/sagernet/sing/common/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Decode(ctx context.Context, content []byte) (any, error) {
|
func Decode(content []byte) (any, error) {
|
||||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||||
return decodeJSON(decoder)
|
return decodeJSON(decoder)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package badjson
|
package badjson
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
|
@ -10,75 +9,75 @@ import (
|
||||||
"github.com/sagernet/sing/common/json"
|
"github.com/sagernet/sing/common/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Omitempty[T any](ctx context.Context, value T) (T, error) {
|
func Omitempty[T any](value T) (T, error) {
|
||||||
objectContent, err := json.MarshalContext(ctx, value)
|
objectContent, err := json.Marshal(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "marshal object")
|
return common.DefaultValue[T](), E.Cause(err, "marshal object")
|
||||||
}
|
}
|
||||||
rawNewObject, err := Decode(ctx, objectContent)
|
rawNewObject, err := Decode(objectContent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), err
|
return common.DefaultValue[T](), err
|
||||||
}
|
}
|
||||||
newObjectContent, err := json.MarshalContext(ctx, rawNewObject)
|
newObjectContent, err := json.Marshal(rawNewObject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "marshal new object")
|
return common.DefaultValue[T](), E.Cause(err, "marshal new object")
|
||||||
}
|
}
|
||||||
var newObject T
|
var newObject T
|
||||||
err = json.UnmarshalContext(ctx, newObjectContent, &newObject)
|
err = json.Unmarshal(newObjectContent, &newObject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
|
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
|
||||||
}
|
}
|
||||||
return newObject, nil
|
return newObject, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Merge[T any](ctx context.Context, source T, destination T, disableAppend bool) (T, error) {
|
func Merge[T any](source T, destination T, disableAppend bool) (T, error) {
|
||||||
rawSource, err := json.MarshalContext(ctx, source)
|
rawSource, err := json.Marshal(source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
||||||
}
|
}
|
||||||
rawDestination, err := json.MarshalContext(ctx, destination)
|
rawDestination, err := json.Marshal(destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
||||||
}
|
}
|
||||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
return MergeFrom[T](rawSource, rawDestination, disableAppend)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MergeFromSource[T any](ctx context.Context, rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
|
func MergeFromSource[T any](rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
|
||||||
if rawSource == nil {
|
if rawSource == nil {
|
||||||
return destination, nil
|
return destination, nil
|
||||||
}
|
}
|
||||||
rawDestination, err := json.MarshalContext(ctx, destination)
|
rawDestination, err := json.Marshal(destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
|
||||||
}
|
}
|
||||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
return MergeFrom[T](rawSource, rawDestination, disableAppend)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MergeFromDestination[T any](ctx context.Context, source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
func MergeFromDestination[T any](source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
||||||
if rawDestination == nil {
|
if rawDestination == nil {
|
||||||
return source, nil
|
return source, nil
|
||||||
}
|
}
|
||||||
rawSource, err := json.MarshalContext(ctx, source)
|
rawSource, err := json.Marshal(source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
return common.DefaultValue[T](), E.Cause(err, "marshal source")
|
||||||
}
|
}
|
||||||
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
|
return MergeFrom[T](rawSource, rawDestination, disableAppend)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MergeFrom[T any](ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
|
||||||
rawMerged, err := MergeJSON(ctx, rawSource, rawDestination, disableAppend)
|
rawMerged, err := MergeJSON(rawSource, rawDestination, disableAppend)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "merge options")
|
return common.DefaultValue[T](), E.Cause(err, "merge options")
|
||||||
}
|
}
|
||||||
var merged T
|
var merged T
|
||||||
err = json.UnmarshalContext(ctx, rawMerged, &merged)
|
err = json.Unmarshal(rawMerged, &merged)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
|
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
|
||||||
}
|
}
|
||||||
return merged, nil
|
return merged, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
|
func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
|
||||||
if rawSource == nil && rawDestination == nil {
|
if rawSource == nil && rawDestination == nil {
|
||||||
return nil, os.ErrInvalid
|
return nil, os.ErrInvalid
|
||||||
} else if rawSource == nil {
|
} else if rawSource == nil {
|
||||||
|
@ -86,16 +85,16 @@ func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination js
|
||||||
} else if rawDestination == nil {
|
} else if rawDestination == nil {
|
||||||
return rawSource, nil
|
return rawSource, nil
|
||||||
}
|
}
|
||||||
source, err := Decode(ctx, rawSource)
|
source, err := Decode(rawSource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "decode source")
|
return nil, E.Cause(err, "decode source")
|
||||||
}
|
}
|
||||||
destination, err := Decode(ctx, rawDestination)
|
destination, err := Decode(rawDestination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.Cause(err, "decode destination")
|
return nil, E.Cause(err, "decode destination")
|
||||||
}
|
}
|
||||||
if source == nil {
|
if source == nil {
|
||||||
return json.MarshalContext(ctx, destination)
|
return json.Marshal(destination)
|
||||||
} else if destination == nil {
|
} else if destination == nil {
|
||||||
return json.Marshal(source)
|
return json.Marshal(source)
|
||||||
}
|
}
|
||||||
|
@ -103,7 +102,7 @@ func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination js
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return json.MarshalContext(ctx, merged)
|
return json.Marshal(merged)
|
||||||
}
|
}
|
||||||
|
|
||||||
func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) {
|
func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) {
|
||||||
|
|
|
@ -1,44 +1,36 @@
|
||||||
package badjson
|
package badjson
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/json"
|
"github.com/sagernet/sing/common/json"
|
||||||
cJSON "github.com/sagernet/sing/common/json/internal/contextjson"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func MarshallObjects(objects ...any) ([]byte, error) {
|
func MarshallObjects(objects ...any) ([]byte, error) {
|
||||||
return MarshallObjectsContext(context.Background(), objects...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MarshallObjectsContext(ctx context.Context, objects ...any) ([]byte, error) {
|
|
||||||
if len(objects) == 1 {
|
if len(objects) == 1 {
|
||||||
return json.Marshal(objects[0])
|
return json.Marshal(objects[0])
|
||||||
}
|
}
|
||||||
var content JSONObject
|
var content JSONObject
|
||||||
for _, object := range objects {
|
for _, object := range objects {
|
||||||
objectMap, err := newJSONObject(ctx, object)
|
objectMap, err := newJSONObject(object)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
content.PutAll(objectMap)
|
content.PutAll(objectMap)
|
||||||
}
|
}
|
||||||
return content.MarshalJSONContext(ctx)
|
return content.MarshalJSON()
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error {
|
func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error {
|
||||||
return UnmarshallExcludedContext(context.Background(), inputContent, parentObject, object)
|
parentContent, err := newJSONObject(parentObject)
|
||||||
}
|
|
||||||
|
|
||||||
func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error {
|
|
||||||
var content JSONObject
|
|
||||||
err := content.UnmarshalJSONContext(ctx, inputContent)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, key := range cJSON.ObjectKeys(reflect.TypeOf(parentObject)) {
|
var content JSONObject
|
||||||
|
err = content.UnmarshalJSON(inputContent)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, key := range parentContent.Keys() {
|
||||||
content.Remove(key)
|
content.Remove(key)
|
||||||
}
|
}
|
||||||
if object == nil {
|
if object == nil {
|
||||||
|
@ -47,20 +39,20 @@ func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentO
|
||||||
}
|
}
|
||||||
return E.New("unexpected key: ", content.Keys()[0])
|
return E.New("unexpected key: ", content.Keys()[0])
|
||||||
}
|
}
|
||||||
inputContent, err = content.MarshalJSONContext(ctx)
|
inputContent, err = content.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return json.UnmarshalContextDisallowUnknownFields(ctx, inputContent, object)
|
return json.UnmarshalDisallowUnknownFields(inputContent, object)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newJSONObject(ctx context.Context, object any) (*JSONObject, error) {
|
func newJSONObject(object any) (*JSONObject, error) {
|
||||||
inputContent, err := json.MarshalContext(ctx, object)
|
inputContent, err := json.Marshal(object)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var content JSONObject
|
var content JSONObject
|
||||||
err = content.UnmarshalJSONContext(ctx, inputContent)
|
err = content.UnmarshalJSON(inputContent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package badjson
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
|
@ -29,10 +28,6 @@ func (m *JSONObject) IsEmpty() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *JSONObject) MarshalJSON() ([]byte, error) {
|
func (m *JSONObject) MarshalJSON() ([]byte, error) {
|
||||||
return m.MarshalJSONContext(context.Background())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
|
||||||
buffer := new(bytes.Buffer)
|
buffer := new(bytes.Buffer)
|
||||||
buffer.WriteString("{")
|
buffer.WriteString("{")
|
||||||
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
|
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
|
||||||
|
@ -43,13 +38,13 @@ func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||||
})
|
})
|
||||||
iLen := len(items)
|
iLen := len(items)
|
||||||
for i, entry := range items {
|
for i, entry := range items {
|
||||||
keyContent, err := json.MarshalContext(ctx, entry.Key)
|
keyContent, err := json.Marshal(entry.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
||||||
buffer.WriteString(": ")
|
buffer.WriteString(": ")
|
||||||
valueContent, err := json.MarshalContext(ctx, entry.Value)
|
valueContent, err := json.Marshal(entry.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -63,11 +58,7 @@ func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *JSONObject) UnmarshalJSON(content []byte) error {
|
func (m *JSONObject) UnmarshalJSON(content []byte) error {
|
||||||
return m.UnmarshalJSONContext(context.Background(), content)
|
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||||
}
|
|
||||||
|
|
||||||
func (m *JSONObject) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
|
||||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
|
||||||
m.Clear()
|
m.Clear()
|
||||||
objectStart, err := decoder.Token()
|
objectStart, err := decoder.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -2,7 +2,6 @@ package badjson
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
@ -15,22 +14,18 @@ type TypedMap[K comparable, V any] struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
|
func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
|
||||||
return m.MarshalJSONContext(context.Background())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
|
||||||
buffer := new(bytes.Buffer)
|
buffer := new(bytes.Buffer)
|
||||||
buffer.WriteString("{")
|
buffer.WriteString("{")
|
||||||
items := m.Entries()
|
items := m.Entries()
|
||||||
iLen := len(items)
|
iLen := len(items)
|
||||||
for i, entry := range items {
|
for i, entry := range items {
|
||||||
keyContent, err := json.MarshalContext(ctx, entry.Key)
|
keyContent, err := json.Marshal(entry.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
buffer.WriteString(strings.TrimSpace(string(keyContent)))
|
||||||
buffer.WriteString(": ")
|
buffer.WriteString(": ")
|
||||||
valueContent, err := json.MarshalContext(ctx, entry.Value)
|
valueContent, err := json.Marshal(entry.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -44,11 +39,7 @@ func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
|
func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
|
||||||
return m.UnmarshalJSONContext(context.Background(), content)
|
decoder := json.NewDecoder(bytes.NewReader(content))
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
|
||||||
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
|
|
||||||
m.Clear()
|
m.Clear()
|
||||||
objectStart, err := decoder.Token()
|
objectStart, err := decoder.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -56,7 +47,7 @@ func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byt
|
||||||
} else if objectStart != json.Delim('{') {
|
} else if objectStart != json.Delim('{') {
|
||||||
return E.New("expected json object start, but starts with ", objectStart)
|
return E.New("expected json object start, but starts with ", objectStart)
|
||||||
}
|
}
|
||||||
err = m.decodeJSON(ctx, decoder)
|
err = m.decodeJSON(decoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "decode json object content")
|
return E.Cause(err, "decode json object content")
|
||||||
}
|
}
|
||||||
|
@ -69,18 +60,18 @@ func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byt
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TypedMap[K, V]) decodeJSON(ctx context.Context, decoder *json.Decoder) error {
|
func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error {
|
||||||
for decoder.More() {
|
for decoder.More() {
|
||||||
keyToken, err := decoder.Token()
|
keyToken, err := decoder.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
keyContent, err := json.MarshalContext(ctx, keyToken)
|
keyContent, err := json.Marshal(keyToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var entryKey K
|
var entryKey K
|
||||||
err = json.UnmarshalContext(ctx, keyContent, &entryKey)
|
err = json.Unmarshal(keyContent, &entryKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,35 +1,30 @@
|
||||||
package badoption
|
package badoption
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/json"
|
"github.com/sagernet/sing/common/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Listable[T any] []T
|
type Listable[T any] []T
|
||||||
|
|
||||||
func (l Listable[T]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
func (l Listable[T]) MarshalJSON() ([]byte, error) {
|
||||||
arrayList := []T(l)
|
arrayList := []T(l)
|
||||||
if len(arrayList) == 1 {
|
if len(arrayList) == 1 {
|
||||||
return json.Marshal(arrayList[0])
|
return json.Marshal(arrayList[0])
|
||||||
}
|
}
|
||||||
return json.MarshalContext(ctx, arrayList)
|
return json.Marshal(arrayList)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Listable[T]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
func (l *Listable[T]) UnmarshalJSON(content []byte) error {
|
||||||
if string(content) == "null" {
|
err := json.UnmarshalDisallowUnknownFields(content, (*[]T)(l))
|
||||||
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var singleItem T
|
var singleItem T
|
||||||
err := json.UnmarshalContextDisallowUnknownFields(ctx, content, &singleItem)
|
newError := json.UnmarshalDisallowUnknownFields(content, &singleItem)
|
||||||
if err == nil {
|
if newError != nil {
|
||||||
*l = []T{singleItem}
|
return E.Errors(err, newError)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
newErr := json.UnmarshalContextDisallowUnknownFields(ctx, content, (*[]T)(l))
|
*l = []T{singleItem}
|
||||||
if newErr == nil {
|
return nil
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return E.Errors(err, newErr)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,13 +35,6 @@ func (a *Addr) UnmarshalJSON(content []byte) error {
|
||||||
|
|
||||||
type Prefix netip.Prefix
|
type Prefix netip.Prefix
|
||||||
|
|
||||||
func (p *Prefix) Build(defaultPrefix netip.Prefix) netip.Prefix {
|
|
||||||
if p == nil {
|
|
||||||
return defaultPrefix
|
|
||||||
}
|
|
||||||
return netip.Prefix(*p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Prefix) MarshalJSON() ([]byte, error) {
|
func (p *Prefix) MarshalJSON() ([]byte, error) {
|
||||||
return json.Marshal(netip.Prefix(*p).String())
|
return json.Marshal(netip.Prefix(*p).String())
|
||||||
}
|
}
|
||||||
|
@ -62,13 +55,6 @@ func (p *Prefix) UnmarshalJSON(content []byte) error {
|
||||||
|
|
||||||
type Prefixable netip.Prefix
|
type Prefixable netip.Prefix
|
||||||
|
|
||||||
func (p *Prefixable) Build(defaultPrefix netip.Prefix) netip.Prefix {
|
|
||||||
if p == nil {
|
|
||||||
return defaultPrefix
|
|
||||||
}
|
|
||||||
return netip.Prefix(*p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Prefixable) MarshalJSON() ([]byte, error) {
|
func (p *Prefixable) MarshalJSON() ([]byte, error) {
|
||||||
prefix := netip.Prefix(*p)
|
prefix := netip.Prefix(*p)
|
||||||
if prefix.Bits() == prefix.Addr().BitLen() {
|
if prefix.Bits() == prefix.Addr().BitLen() {
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
package json
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/json/internal/contextjson"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
MarshalContext = json.MarshalContext
|
|
||||||
UnmarshalContext = json.UnmarshalContext
|
|
||||||
NewEncoderContext = json.NewEncoderContext
|
|
||||||
NewDecoderContext = json.NewDecoderContext
|
|
||||||
UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields
|
|
||||||
)
|
|
||||||
|
|
||||||
type ContextMarshaler interface {
|
|
||||||
MarshalJSONContext(ctx context.Context) ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ContextUnmarshaler interface {
|
|
||||||
UnmarshalJSONContext(ctx context.Context, content []byte) error
|
|
||||||
}
|
|
|
@ -1,11 +0,0 @@
|
||||||
package json
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
type ContextMarshaler interface {
|
|
||||||
MarshalJSONContext(ctx context.Context) ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ContextUnmarshaler interface {
|
|
||||||
UnmarshalJSONContext(ctx context.Context, content []byte) error
|
|
||||||
}
|
|
|
@ -1,43 +0,0 @@
|
||||||
package json_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/json/internal/contextjson"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
type myStruct struct {
|
|
||||||
value string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) {
|
|
||||||
return json.Marshal(ctx.Value("key").(string))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error {
|
|
||||||
m.value = ctx.Value("key").(string)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:staticcheck
|
|
||||||
func TestMarshalContext(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
ctx := context.WithValue(context.Background(), "key", "value")
|
|
||||||
var s myStruct
|
|
||||||
b, err := json.MarshalContext(ctx, &s)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, []byte(`"value"`), b)
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:staticcheck
|
|
||||||
func TestUnmarshalContext(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
ctx := context.WithValue(context.Background(), "key", "value")
|
|
||||||
var s myStruct
|
|
||||||
err := json.UnmarshalContext(ctx, []byte(`{}`), &s)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "value", s.value)
|
|
||||||
}
|
|
|
@ -8,7 +8,6 @@
|
||||||
package json
|
package json
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding"
|
"encoding"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -96,15 +95,10 @@ import (
|
||||||
// Instead, they are replaced by the Unicode replacement
|
// Instead, they are replaced by the Unicode replacement
|
||||||
// character U+FFFD.
|
// character U+FFFD.
|
||||||
func Unmarshal(data []byte, v any) error {
|
func Unmarshal(data []byte, v any) error {
|
||||||
return UnmarshalContext(context.Background(), data, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func UnmarshalContext(ctx context.Context, data []byte, v any) error {
|
|
||||||
// Check for well-formedness.
|
// Check for well-formedness.
|
||||||
// Avoids filling out half a data structure
|
// Avoids filling out half a data structure
|
||||||
// before discovering a JSON syntax error.
|
// before discovering a JSON syntax error.
|
||||||
var d decodeState
|
var d decodeState
|
||||||
d.ctx = ctx
|
|
||||||
err := checkValid(data, &d.scan)
|
err := checkValid(data, &d.scan)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -215,7 +209,6 @@ type errorContext struct {
|
||||||
|
|
||||||
// decodeState represents the state while decoding a JSON value.
|
// decodeState represents the state while decoding a JSON value.
|
||||||
type decodeState struct {
|
type decodeState struct {
|
||||||
ctx context.Context
|
|
||||||
data []byte
|
data []byte
|
||||||
off int // next read offset in data
|
off int // next read offset in data
|
||||||
opcode int // last read result
|
opcode int // last read result
|
||||||
|
@ -435,7 +428,7 @@ func (d *decodeState) valueQuoted() any {
|
||||||
// If it encounters an Unmarshaler, indirect stops and returns that.
|
// If it encounters an Unmarshaler, indirect stops and returns that.
|
||||||
// If decodingNull is true, indirect stops at the first settable pointer so it
|
// If decodingNull is true, indirect stops at the first settable pointer so it
|
||||||
// can be set to nil.
|
// can be set to nil.
|
||||||
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) {
|
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
|
||||||
// Issue #24153 indicates that it is generally not a guaranteed property
|
// Issue #24153 indicates that it is generally not a guaranteed property
|
||||||
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
|
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
|
||||||
// and expect the value to still be settable for values derived from
|
// and expect the value to still be settable for values derived from
|
||||||
|
@ -489,14 +482,11 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshal
|
||||||
}
|
}
|
||||||
if v.Type().NumMethod() > 0 && v.CanInterface() {
|
if v.Type().NumMethod() > 0 && v.CanInterface() {
|
||||||
if u, ok := v.Interface().(Unmarshaler); ok {
|
if u, ok := v.Interface().(Unmarshaler); ok {
|
||||||
return u, nil, nil, reflect.Value{}
|
return u, nil, reflect.Value{}
|
||||||
}
|
|
||||||
if cu, ok := v.Interface().(ContextUnmarshaler); ok {
|
|
||||||
return nil, cu, nil, reflect.Value{}
|
|
||||||
}
|
}
|
||||||
if !decodingNull {
|
if !decodingNull {
|
||||||
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
|
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
|
||||||
return nil, nil, u, reflect.Value{}
|
return nil, u, reflect.Value{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -508,14 +498,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshal
|
||||||
v = v.Elem()
|
v = v.Elem()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, nil, nil, v
|
return nil, nil, v
|
||||||
}
|
}
|
||||||
|
|
||||||
// array consumes an array from d.data[d.off-1:], decoding into v.
|
// array consumes an array from d.data[d.off-1:], decoding into v.
|
||||||
// The first byte of the array ('[') has been read already.
|
// The first byte of the array ('[') has been read already.
|
||||||
func (d *decodeState) array(v reflect.Value) error {
|
func (d *decodeState) array(v reflect.Value) error {
|
||||||
// Check for unmarshaler.
|
// Check for unmarshaler.
|
||||||
u, cu, ut, pv := indirect(v, false)
|
u, ut, pv := indirect(v, false)
|
||||||
if u != nil {
|
if u != nil {
|
||||||
start := d.readIndex()
|
start := d.readIndex()
|
||||||
d.skip()
|
d.skip()
|
||||||
|
@ -525,15 +515,6 @@ func (d *decodeState) array(v reflect.Value) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if cu != nil {
|
|
||||||
start := d.readIndex()
|
|
||||||
d.skip()
|
|
||||||
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
|
|
||||||
if err != nil {
|
|
||||||
d.saveError(err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if ut != nil {
|
if ut != nil {
|
||||||
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
|
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
|
||||||
d.skip()
|
d.skip()
|
||||||
|
@ -631,7 +612,7 @@ var (
|
||||||
// The first byte ('{') of the object has been read already.
|
// The first byte ('{') of the object has been read already.
|
||||||
func (d *decodeState) object(v reflect.Value) error {
|
func (d *decodeState) object(v reflect.Value) error {
|
||||||
// Check for unmarshaler.
|
// Check for unmarshaler.
|
||||||
u, cu, ut, pv := indirect(v, false)
|
u, ut, pv := indirect(v, false)
|
||||||
if u != nil {
|
if u != nil {
|
||||||
start := d.readIndex()
|
start := d.readIndex()
|
||||||
d.skip()
|
d.skip()
|
||||||
|
@ -641,15 +622,6 @@ func (d *decodeState) object(v reflect.Value) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if cu != nil {
|
|
||||||
start := d.readIndex()
|
|
||||||
d.skip()
|
|
||||||
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
|
|
||||||
if err != nil {
|
|
||||||
d.saveError(err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if ut != nil {
|
if ut != nil {
|
||||||
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
|
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
|
||||||
d.skip()
|
d.skip()
|
||||||
|
@ -898,7 +870,7 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
isNull := item[0] == 'n' // null
|
isNull := item[0] == 'n' // null
|
||||||
u, cu, ut, pv := indirect(v, isNull)
|
u, ut, pv := indirect(v, isNull)
|
||||||
if u != nil {
|
if u != nil {
|
||||||
err := u.UnmarshalJSON(item)
|
err := u.UnmarshalJSON(item)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -906,13 +878,6 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if cu != nil {
|
|
||||||
err := cu.UnmarshalJSONContext(d.ctx, item)
|
|
||||||
if err != nil {
|
|
||||||
d.saveError(err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if ut != nil {
|
if ut != nil {
|
||||||
if item[0] != '"' {
|
if item[0] != '"' {
|
||||||
if fromQuoted {
|
if fromQuoted {
|
||||||
|
|
|
@ -12,7 +12,6 @@ package json
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding"
|
"encoding"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -157,11 +156,7 @@ import (
|
||||||
// handle them. Passing cyclic structures to Marshal will result in
|
// handle them. Passing cyclic structures to Marshal will result in
|
||||||
// an error.
|
// an error.
|
||||||
func Marshal(v any) ([]byte, error) {
|
func Marshal(v any) ([]byte, error) {
|
||||||
return MarshalContext(context.Background(), v)
|
e := newEncodeState()
|
||||||
}
|
|
||||||
|
|
||||||
func MarshalContext(ctx context.Context, v any) ([]byte, error) {
|
|
||||||
e := newEncodeState(ctx)
|
|
||||||
defer encodeStatePool.Put(e)
|
defer encodeStatePool.Put(e)
|
||||||
|
|
||||||
err := e.marshal(v, encOpts{escapeHTML: true})
|
err := e.marshal(v, encOpts{escapeHTML: true})
|
||||||
|
@ -256,7 +251,6 @@ var hex = "0123456789abcdef"
|
||||||
type encodeState struct {
|
type encodeState struct {
|
||||||
bytes.Buffer // accumulated output
|
bytes.Buffer // accumulated output
|
||||||
|
|
||||||
ctx context.Context
|
|
||||||
// Keep track of what pointers we've seen in the current recursive call
|
// Keep track of what pointers we've seen in the current recursive call
|
||||||
// path, to avoid cycles that could lead to a stack overflow. Only do
|
// path, to avoid cycles that could lead to a stack overflow. Only do
|
||||||
// the relatively expensive map operations if ptrLevel is larger than
|
// the relatively expensive map operations if ptrLevel is larger than
|
||||||
|
@ -270,7 +264,7 @@ const startDetectingCyclesAfter = 1000
|
||||||
|
|
||||||
var encodeStatePool sync.Pool
|
var encodeStatePool sync.Pool
|
||||||
|
|
||||||
func newEncodeState(ctx context.Context) *encodeState {
|
func newEncodeState() *encodeState {
|
||||||
if v := encodeStatePool.Get(); v != nil {
|
if v := encodeStatePool.Get(); v != nil {
|
||||||
e := v.(*encodeState)
|
e := v.(*encodeState)
|
||||||
e.Reset()
|
e.Reset()
|
||||||
|
@ -280,7 +274,7 @@ func newEncodeState(ctx context.Context) *encodeState {
|
||||||
e.ptrLevel = 0
|
e.ptrLevel = 0
|
||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
return &encodeState{ctx: ctx, ptrSeen: make(map[any]struct{})}
|
return &encodeState{ptrSeen: make(map[any]struct{})}
|
||||||
}
|
}
|
||||||
|
|
||||||
// jsonError is an error wrapper type for internal use only.
|
// jsonError is an error wrapper type for internal use only.
|
||||||
|
@ -377,9 +371,8 @@ func typeEncoder(t reflect.Type) encoderFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
|
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
|
||||||
contextMarshalerType = reflect.TypeOf((*ContextMarshaler)(nil)).Elem()
|
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
|
||||||
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// newTypeEncoder constructs an encoderFunc for a type.
|
// newTypeEncoder constructs an encoderFunc for a type.
|
||||||
|
@ -392,15 +385,9 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
|
||||||
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
|
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
|
||||||
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
|
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
|
||||||
}
|
}
|
||||||
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) {
|
|
||||||
return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false))
|
|
||||||
}
|
|
||||||
if t.Implements(marshalerType) {
|
if t.Implements(marshalerType) {
|
||||||
return marshalerEncoder
|
return marshalerEncoder
|
||||||
}
|
}
|
||||||
if t.Implements(contextMarshalerType) {
|
|
||||||
return contextMarshalerEncoder
|
|
||||||
}
|
|
||||||
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
|
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
|
||||||
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
|
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
|
||||||
}
|
}
|
||||||
|
@ -483,47 +470,6 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
|
||||||
if v.Kind() == reflect.Pointer && v.IsNil() {
|
|
||||||
e.WriteString("null")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m, ok := v.Interface().(ContextMarshaler)
|
|
||||||
if !ok {
|
|
||||||
e.WriteString("null")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b, err := m.MarshalJSONContext(e.ctx)
|
|
||||||
if err == nil {
|
|
||||||
e.Grow(len(b))
|
|
||||||
out := availableBuffer(&e.Buffer)
|
|
||||||
out, err = appendCompact(out, b, opts.escapeHTML)
|
|
||||||
e.Buffer.Write(out)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
|
||||||
va := v.Addr()
|
|
||||||
if va.IsNil() {
|
|
||||||
e.WriteString("null")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m := va.Interface().(ContextMarshaler)
|
|
||||||
b, err := m.MarshalJSONContext(e.ctx)
|
|
||||||
if err == nil {
|
|
||||||
e.Grow(len(b))
|
|
||||||
out := availableBuffer(&e.Buffer)
|
|
||||||
out, err = appendCompact(out, b, opts.escapeHTML)
|
|
||||||
e.Buffer.Write(out)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||||
if v.Kind() == reflect.Pointer && v.IsNil() {
|
if v.Kind() == reflect.Pointer && v.IsNil() {
|
||||||
e.WriteString("null")
|
e.WriteString("null")
|
||||||
|
@ -881,7 +827,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc {
|
||||||
// Byte slices get special treatment; arrays don't.
|
// Byte slices get special treatment; arrays don't.
|
||||||
if t.Elem().Kind() == reflect.Uint8 {
|
if t.Elem().Kind() == reflect.Uint8 {
|
||||||
p := reflect.PointerTo(t.Elem())
|
p := reflect.PointerTo(t.Elem())
|
||||||
if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) {
|
if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) {
|
||||||
return encodeByteSlice
|
return encodeByteSlice
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,20 +0,0 @@
|
||||||
package json
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ObjectKeys(object reflect.Type) []string {
|
|
||||||
switch object.Kind() {
|
|
||||||
case reflect.Pointer:
|
|
||||||
return ObjectKeys(object.Elem())
|
|
||||||
case reflect.Struct:
|
|
||||||
default:
|
|
||||||
panic("invalid non-struct input")
|
|
||||||
}
|
|
||||||
return common.Map(cachedTypeFields(object).list, func(field field) string {
|
|
||||||
return field.name
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -1,26 +0,0 @@
|
||||||
package json_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
json "github.com/sagernet/sing/common/json/internal/contextjson"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MyObject struct {
|
|
||||||
Hello string `json:"hello,omitempty"`
|
|
||||||
MyWorld
|
|
||||||
MyWorld2 string `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type MyWorld struct {
|
|
||||||
World string `json:"world,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestObjectKeys(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
keys := json.ObjectKeys(reflect.TypeOf(&MyObject{}))
|
|
||||||
require.Equal(t, []string{"hello", "world"}, keys)
|
|
||||||
}
|
|
|
@ -6,7 +6,6 @@ package json
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
@ -30,11 +29,7 @@ type Decoder struct {
|
||||||
// The decoder introduces its own buffering and may
|
// The decoder introduces its own buffering and may
|
||||||
// read data from r beyond the JSON values requested.
|
// read data from r beyond the JSON values requested.
|
||||||
func NewDecoder(r io.Reader) *Decoder {
|
func NewDecoder(r io.Reader) *Decoder {
|
||||||
return NewDecoderContext(context.Background(), r)
|
return &Decoder{r: r}
|
||||||
}
|
|
||||||
|
|
||||||
func NewDecoderContext(ctx context.Context, r io.Reader) *Decoder {
|
|
||||||
return &Decoder{r: r, d: decodeState{ctx: ctx}}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
|
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
|
||||||
|
@ -188,7 +183,6 @@ func nonSpace(b []byte) bool {
|
||||||
|
|
||||||
// An Encoder writes JSON values to an output stream.
|
// An Encoder writes JSON values to an output stream.
|
||||||
type Encoder struct {
|
type Encoder struct {
|
||||||
ctx context.Context
|
|
||||||
w io.Writer
|
w io.Writer
|
||||||
err error
|
err error
|
||||||
escapeHTML bool
|
escapeHTML bool
|
||||||
|
@ -200,11 +194,7 @@ type Encoder struct {
|
||||||
|
|
||||||
// NewEncoder returns a new encoder that writes to w.
|
// NewEncoder returns a new encoder that writes to w.
|
||||||
func NewEncoder(w io.Writer) *Encoder {
|
func NewEncoder(w io.Writer) *Encoder {
|
||||||
return NewEncoderContext(context.Background(), w)
|
return &Encoder{w: w, escapeHTML: true}
|
||||||
}
|
|
||||||
|
|
||||||
func NewEncoderContext(ctx context.Context, w io.Writer) *Encoder {
|
|
||||||
return &Encoder{ctx: ctx, w: w, escapeHTML: true}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode writes the JSON encoding of v to the stream,
|
// Encode writes the JSON encoding of v to the stream,
|
||||||
|
@ -217,7 +207,7 @@ func (enc *Encoder) Encode(v any) error {
|
||||||
return enc.err
|
return enc.err
|
||||||
}
|
}
|
||||||
|
|
||||||
e := newEncodeState(enc.ctx)
|
e := newEncodeState()
|
||||||
defer encodeStatePool.Put(e)
|
defer encodeStatePool.Put(e)
|
||||||
|
|
||||||
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
|
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
package json
|
package json
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
func UnmarshalDisallowUnknownFields(data []byte, v any) error {
|
func UnmarshalDisallowUnknownFields(data []byte, v any) error {
|
||||||
var d decodeState
|
var d decodeState
|
||||||
d.disallowUnknownFields = true
|
d.disallowUnknownFields = true
|
||||||
|
@ -12,15 +10,3 @@ func UnmarshalDisallowUnknownFields(data []byte, v any) error {
|
||||||
d.init(data)
|
d.init(data)
|
||||||
return d.unmarshal(v)
|
return d.unmarshal(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalContextDisallowUnknownFields(ctx context.Context, data []byte, v any) error {
|
|
||||||
var d decodeState
|
|
||||||
d.ctx = ctx
|
|
||||||
d.disallowUnknownFields = true
|
|
||||||
err := checkValid(data, &d.scan)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
d.init(data)
|
|
||||||
return d.unmarshal(v)
|
|
||||||
}
|
|
||||||
|
|
|
@ -2,7 +2,6 @@ package json
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -11,11 +10,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func UnmarshalExtended[T any](content []byte) (T, error) {
|
func UnmarshalExtended[T any](content []byte) (T, error) {
|
||||||
return UnmarshalExtendedContext[T](context.Background(), content)
|
decoder := NewDecoder(NewCommentFilter(bytes.NewReader(content)))
|
||||||
}
|
|
||||||
|
|
||||||
func UnmarshalExtendedContext[T any](ctx context.Context, content []byte) (T, error) {
|
|
||||||
decoder := NewDecoderContext(ctx, NewCommentFilter(bytes.NewReader(content)))
|
|
||||||
var value T
|
var value T
|
||||||
err := decoder.Decode(&value)
|
err := decoder.Decode(&value)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
package metadata
|
package metadata
|
||||||
|
|
||||||
// Deprecated: wtf is this?
|
|
||||||
type Metadata struct {
|
type Metadata struct {
|
||||||
Protocol string
|
Protocol string
|
||||||
Source Socksaddr
|
Source Socksaddr
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
|
@ -71,39 +70,8 @@ type ExtendedConn interface {
|
||||||
net.Conn
|
net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
type CloseHandlerFunc = func(it error)
|
|
||||||
|
|
||||||
func AppendClose(parent CloseHandlerFunc, onClose CloseHandlerFunc) CloseHandlerFunc {
|
|
||||||
if onClose == nil {
|
|
||||||
panic("nil onClose")
|
|
||||||
}
|
|
||||||
if parent == nil {
|
|
||||||
return onClose
|
|
||||||
}
|
|
||||||
return func(it error) {
|
|
||||||
onClose(it)
|
|
||||||
parent(it)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func OnceClose(onClose CloseHandlerFunc) CloseHandlerFunc {
|
|
||||||
var once sync.Once
|
|
||||||
return func(it error) {
|
|
||||||
once.Do(func() {
|
|
||||||
onClose(it)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated: Use TCPConnectionHandlerEx instead.
|
|
||||||
type TCPConnectionHandler interface {
|
type TCPConnectionHandler interface {
|
||||||
NewConnection(ctx context.Context, conn net.Conn,
|
NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error
|
||||||
//nolint:staticcheck
|
|
||||||
metadata M.Metadata) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type TCPConnectionHandlerEx interface {
|
|
||||||
NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type NetPacketConn interface {
|
type NetPacketConn interface {
|
||||||
|
@ -117,26 +85,12 @@ type BindPacketConn interface {
|
||||||
net.Conn
|
net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use UDPHandlerEx instead.
|
|
||||||
type UDPHandler interface {
|
type UDPHandler interface {
|
||||||
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer,
|
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error
|
||||||
//nolint:staticcheck
|
|
||||||
metadata M.Metadata) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UDPHandlerEx interface {
|
|
||||||
NewPacketEx(buffer *buf.Buffer, source M.Socksaddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated: Use UDPConnectionHandlerEx instead.
|
|
||||||
type UDPConnectionHandler interface {
|
type UDPConnectionHandler interface {
|
||||||
NewPacketConnection(ctx context.Context, conn PacketConn,
|
NewPacketConnection(ctx context.Context, conn PacketConn, metadata M.Metadata) error
|
||||||
//nolint:staticcheck
|
|
||||||
metadata M.Metadata) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type UDPConnectionHandlerEx interface {
|
|
||||||
NewPacketConnectionEx(ctx context.Context, conn PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CachedReader interface {
|
type CachedReader interface {
|
||||||
|
@ -147,6 +101,11 @@ type CachedPacketReader interface {
|
||||||
ReadCachedPacket() *PacketBuffer
|
ReadCachedPacket() *PacketBuffer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PacketBuffer struct {
|
||||||
|
Buffer *buf.Buffer
|
||||||
|
Destination M.Socksaddr
|
||||||
|
}
|
||||||
|
|
||||||
type WithUpstreamReader interface {
|
type WithUpstreamReader interface {
|
||||||
UpstreamReader() any
|
UpstreamReader() any
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,10 @@ type Dialer interface {
|
||||||
ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error)
|
ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PayloadDialer interface {
|
||||||
|
DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
type ParallelDialer interface {
|
type ParallelDialer interface {
|
||||||
Dialer
|
Dialer
|
||||||
DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error)
|
DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error)
|
||||||
|
|
|
@ -15,39 +15,19 @@ type ReadWaitOptions struct {
|
||||||
MTU int
|
MTU int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewReadWaitOptions(source any, destination any) ReadWaitOptions {
|
|
||||||
return ReadWaitOptions{
|
|
||||||
FrontHeadroom: CalculateFrontHeadroom(destination),
|
|
||||||
RearHeadroom: CalculateRearHeadroom(destination),
|
|
||||||
MTU: CalculateMTU(source, destination),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o ReadWaitOptions) NeedHeadroom() bool {
|
func (o ReadWaitOptions) NeedHeadroom() bool {
|
||||||
return o.FrontHeadroom > 0 || o.RearHeadroom > 0
|
return o.FrontHeadroom > 0 || o.RearHeadroom > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o ReadWaitOptions) Copy(buffer *buf.Buffer) *buf.Buffer {
|
|
||||||
if o.FrontHeadroom > buffer.Start() ||
|
|
||||||
o.RearHeadroom > buffer.FreeLen() {
|
|
||||||
newBuffer := o.newBuffer(buf.UDPBufferSize, false)
|
|
||||||
newBuffer.Write(buffer.Bytes())
|
|
||||||
buffer.Release()
|
|
||||||
return newBuffer
|
|
||||||
} else {
|
|
||||||
return buffer
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o ReadWaitOptions) NewBuffer() *buf.Buffer {
|
func (o ReadWaitOptions) NewBuffer() *buf.Buffer {
|
||||||
return o.newBuffer(buf.BufferSize, true)
|
return o.newBuffer(buf.BufferSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer {
|
func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer {
|
||||||
return o.newBuffer(buf.UDPBufferSize, true)
|
return o.newBuffer(buf.UDPBufferSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o ReadWaitOptions) newBuffer(defaultBufferSize int, reserve bool) *buf.Buffer {
|
func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer {
|
||||||
var bufferSize int
|
var bufferSize int
|
||||||
if o.MTU > 0 {
|
if o.MTU > 0 {
|
||||||
bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom
|
bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom
|
||||||
|
@ -58,7 +38,7 @@ func (o ReadWaitOptions) newBuffer(defaultBufferSize int, reserve bool) *buf.Buf
|
||||||
if o.FrontHeadroom > 0 {
|
if o.FrontHeadroom > 0 {
|
||||||
buffer.Resize(o.FrontHeadroom, 0)
|
buffer.Resize(o.FrontHeadroom, 0)
|
||||||
}
|
}
|
||||||
if o.RearHeadroom > 0 && reserve {
|
if o.RearHeadroom > 0 {
|
||||||
buffer.Reserve(o.RearHeadroom)
|
buffer.Reserve(o.RearHeadroom)
|
||||||
}
|
}
|
||||||
return buffer
|
return buffer
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
package network
|
package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
)
|
)
|
||||||
|
@ -16,75 +13,17 @@ type HandshakeSuccess interface {
|
||||||
HandshakeSuccess() error
|
HandshakeSuccess() error
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConnHandshakeSuccess interface {
|
func ReportHandshakeFailure(conn any, err error) error {
|
||||||
ConnHandshakeSuccess(conn net.Conn) error
|
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](conn); isHandshakeConn {
|
||||||
}
|
|
||||||
|
|
||||||
type PacketConnHandshakeSuccess interface {
|
|
||||||
PacketConnHandshakeSuccess(conn net.PacketConn) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReportHandshakeFailure(reporter any, err error) error {
|
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn {
|
|
||||||
return E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error {
|
return E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error {
|
||||||
return E.Cause(err, "write handshake failure")
|
return E.Cause(err, "write handshake failure")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CloseOnHandshakeFailure(reporter io.Closer, onClose CloseHandlerFunc, err error) error {
|
|
||||||
if err != nil {
|
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn {
|
|
||||||
hErr := handshakeConn.HandshakeFailure(err)
|
|
||||||
err = E.Append(err, hErr, func(err error) error {
|
|
||||||
if closer, isCloser := reporter.(io.Closer); isCloser {
|
|
||||||
err = E.Append(err, closer.Close(), func(err error) error {
|
|
||||||
return E.Cause(err, "close")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return E.Cause(err, "write handshake failure")
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
if tcpConn, isTCPConn := common.Cast[interface {
|
|
||||||
SetLinger(sec int) error
|
|
||||||
}](reporter); isTCPConn {
|
|
||||||
tcpConn.SetLinger(0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = E.Append(err, reporter.Close(), func(err error) error {
|
|
||||||
return E.Cause(err, "close")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if onClose != nil {
|
|
||||||
onClose(err)
|
|
||||||
}
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: use ReportConnHandshakeSuccess/ReportPacketConnHandshakeSuccess instead
|
func ReportHandshakeSuccess(conn any) error {
|
||||||
func ReportHandshakeSuccess(reporter any) error {
|
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](conn); isHandshakeConn {
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
|
|
||||||
return handshakeConn.HandshakeSuccess()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReportConnHandshakeSuccess(reporter any, conn net.Conn) error {
|
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[ConnHandshakeSuccess](reporter); isHandshakeConn {
|
|
||||||
return handshakeConn.ConnHandshakeSuccess(conn)
|
|
||||||
}
|
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
|
|
||||||
return handshakeConn.HandshakeSuccess()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReportPacketConnHandshakeSuccess(reporter any, conn net.PacketConn) error {
|
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[PacketConnHandshakeSuccess](reporter); isHandshakeConn {
|
|
||||||
return handshakeConn.PacketConnHandshakeSuccess(conn)
|
|
||||||
}
|
|
||||||
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
|
|
||||||
return handshakeConn.HandshakeSuccess()
|
return handshakeConn.HandshakeSuccess()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -1,35 +0,0 @@
|
||||||
package network
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PacketBuffer struct {
|
|
||||||
Buffer *buf.Buffer
|
|
||||||
Destination M.Socksaddr
|
|
||||||
}
|
|
||||||
|
|
||||||
var packetPool = sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
return new(PacketBuffer)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPacketBuffer() *PacketBuffer {
|
|
||||||
return packetPool.Get().(*PacketBuffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func PutPacketBuffer(packet *PacketBuffer) {
|
|
||||||
*packet = PacketBuffer{}
|
|
||||||
packetPool.Put(packet)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReleaseMultiPacketBuffer(packetBuffers []*PacketBuffer) {
|
|
||||||
for _, packet := range packetBuffers {
|
|
||||||
packet.Buffer.Release()
|
|
||||||
PutPacketBuffer(packet)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -11,7 +11,6 @@ type ThreadUnsafeWriter interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use ReadWaiter interface instead.
|
// Deprecated: Use ReadWaiter interface instead.
|
||||||
|
|
||||||
type ThreadSafeReader interface {
|
type ThreadSafeReader interface {
|
||||||
// Deprecated: Use ReadWaiter interface instead.
|
// Deprecated: Use ReadWaiter interface instead.
|
||||||
ReadBufferThreadSafe() (buffer *buf.Buffer, err error)
|
ReadBufferThreadSafe() (buffer *buf.Buffer, err error)
|
||||||
|
@ -19,6 +18,7 @@ type ThreadSafeReader interface {
|
||||||
|
|
||||||
// Deprecated: Use ReadWaiter interface instead.
|
// Deprecated: Use ReadWaiter interface instead.
|
||||||
type ThreadSafePacketReader interface {
|
type ThreadSafePacketReader interface {
|
||||||
|
// Deprecated: Use ReadWaiter interface instead.
|
||||||
ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error)
|
ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetSystemTime(nowTime time.Time) error {
|
func SetSystemTime(nowTime time.Time) error {
|
||||||
nowTime = nowTime.UTC()
|
|
||||||
var systemTime windows.Systemtime
|
var systemTime windows.Systemtime
|
||||||
systemTime.Year = uint16(nowTime.Year())
|
systemTime.Year = uint16(nowTime.Year())
|
||||||
systemTime.Month = uint16(nowTime.Month())
|
systemTime.Month = uint16(nowTime.Month())
|
||||||
|
|
|
@ -20,5 +20,6 @@ func InitializeSeed() {
|
||||||
func initializeSeed() {
|
func initializeSeed() {
|
||||||
var seed int64
|
var seed int64
|
||||||
common.Must(binary.Read(rand.Reader, binary.LittleEndian, &seed))
|
common.Must(binary.Read(rand.Reader, binary.LittleEndian, &seed))
|
||||||
|
//goland:noinspection GoDeprecation
|
||||||
mRand.Seed(seed)
|
mRand.Seed(seed)
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,7 @@ func ToByteReader(reader io.Reader) io.ByteReader {
|
||||||
|
|
||||||
// Deprecated: Use binary.ReadUvarint instead.
|
// Deprecated: Use binary.ReadUvarint instead.
|
||||||
func ReadUVariant(reader io.Reader) (uint64, error) {
|
func ReadUVariant(reader io.Reader) (uint64, error) {
|
||||||
|
//goland:noinspection GoDeprecation
|
||||||
return binary.ReadUvarint(ToByteReader(reader))
|
return binary.ReadUvarint(ToByteReader(reader))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,23 +16,18 @@ import (
|
||||||
"github.com/sagernet/sing/common/pipe"
|
"github.com/sagernet/sing/common/pipe"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Deprecated: Use N.UDPConnectionHandler instead.
|
|
||||||
//
|
|
||||||
//nolint:staticcheck
|
|
||||||
type Handler interface {
|
type Handler interface {
|
||||||
N.UDPConnectionHandler
|
N.UDPConnectionHandler
|
||||||
E.Handler
|
E.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
type Service[K comparable] struct {
|
type Service[K comparable] struct {
|
||||||
nat *cache.LruCache[K, *conn]
|
nat *cache.LruCache[K, *conn]
|
||||||
handler Handler
|
handler Handler
|
||||||
handlerEx N.UDPConnectionHandlerEx
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use NewEx instead.
|
|
||||||
func New[K comparable](maxAge int64, handler Handler) *Service[K] {
|
func New[K comparable](maxAge int64, handler Handler) *Service[K] {
|
||||||
service := &Service[K]{
|
return &Service[K]{
|
||||||
nat: cache.New(
|
nat: cache.New(
|
||||||
cache.WithAge[K, *conn](maxAge),
|
cache.WithAge[K, *conn](maxAge),
|
||||||
cache.WithUpdateAgeOnGet[K, *conn](),
|
cache.WithUpdateAgeOnGet[K, *conn](),
|
||||||
|
@ -42,27 +37,11 @@ func New[K comparable](maxAge int64, handler Handler) *Service[K] {
|
||||||
),
|
),
|
||||||
handler: handler,
|
handler: handler,
|
||||||
}
|
}
|
||||||
return service
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewEx[K comparable](maxAge int64, handler N.UDPConnectionHandlerEx) *Service[K] {
|
|
||||||
service := &Service[K]{
|
|
||||||
nat: cache.New(
|
|
||||||
cache.WithAge[K, *conn](maxAge),
|
|
||||||
cache.WithUpdateAgeOnGet[K, *conn](),
|
|
||||||
cache.WithEvict[K, *conn](func(key K, conn *conn) {
|
|
||||||
conn.Close()
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
handlerEx: handler,
|
|
||||||
}
|
|
||||||
return service
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service[T]) WriteIsThreadUnsafe() {
|
func (s *Service[T]) WriteIsThreadUnsafe() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: don't use
|
|
||||||
func (s *Service[T]) NewPacketDirect(ctx context.Context, key T, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) {
|
func (s *Service[T]) NewPacketDirect(ctx context.Context, key T, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) {
|
||||||
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
|
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
|
||||||
return ctx, &DirectBackWriter{conn, natConn}
|
return ctx, &DirectBackWriter{conn, natConn}
|
||||||
|
@ -82,30 +61,18 @@ func (w *DirectBackWriter) Upstream() any {
|
||||||
return w.Source
|
return w.Source
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: use NewPacketEx instead.
|
|
||||||
func (s *Service[T]) NewPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) N.PacketWriter) {
|
func (s *Service[T]) NewPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) N.PacketWriter) {
|
||||||
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
|
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
|
||||||
return ctx, init(natConn)
|
return ctx, init(natConn)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service[T]) NewPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) N.PacketWriter) {
|
|
||||||
s.NewContextPacketEx(ctx, key, buffer, source, destination, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
|
|
||||||
return ctx, init(natConn)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deprecated: Use NewPacketConnectionEx instead.
|
|
||||||
func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) {
|
func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) {
|
||||||
s.NewContextPacketEx(ctx, key, buffer, metadata.Source, metadata.Destination, init)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service[T]) NewContextPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) {
|
|
||||||
c, loaded := s.nat.LoadOrStore(key, func() *conn {
|
c, loaded := s.nat.LoadOrStore(key, func() *conn {
|
||||||
c := &conn{
|
c := &conn{
|
||||||
data: make(chan packet, 64),
|
data: make(chan packet, 64),
|
||||||
localAddr: source,
|
localAddr: metadata.Source,
|
||||||
remoteAddr: destination,
|
remoteAddr: metadata.Destination,
|
||||||
readDeadline: pipe.MakeDeadline(),
|
readDeadline: pipe.MakeDeadline(),
|
||||||
}
|
}
|
||||||
c.ctx, c.cancel = common.ContextWithCancelCause(ctx)
|
c.ctx, c.cancel = common.ContextWithCancelCause(ctx)
|
||||||
|
@ -114,34 +81,26 @@ func (s *Service[T]) NewContextPacketEx(ctx context.Context, key T, buffer *buf.
|
||||||
if !loaded {
|
if !loaded {
|
||||||
ctx, c.source = init(c)
|
ctx, c.source = init(c)
|
||||||
go func() {
|
go func() {
|
||||||
if s.handlerEx != nil {
|
err := s.handler.NewPacketConnection(ctx, c, metadata)
|
||||||
s.handlerEx.NewPacketConnectionEx(ctx, c, source, destination, func(err error) {
|
if err != nil {
|
||||||
s.nat.Delete(key)
|
s.handler.NewError(ctx, err)
|
||||||
})
|
|
||||||
} else {
|
|
||||||
//nolint:staticcheck
|
|
||||||
err := s.handler.NewPacketConnection(ctx, c, M.Metadata{
|
|
||||||
Source: source,
|
|
||||||
Destination: destination,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
s.handler.NewError(ctx, err)
|
|
||||||
}
|
|
||||||
c.Close()
|
|
||||||
s.nat.Delete(key)
|
|
||||||
}
|
}
|
||||||
|
c.Close()
|
||||||
|
s.nat.Delete(key)
|
||||||
}()
|
}()
|
||||||
|
} else {
|
||||||
|
c.localAddr = metadata.Source
|
||||||
}
|
}
|
||||||
if common.Done(c.ctx) {
|
if common.Done(c.ctx) {
|
||||||
s.nat.Delete(key)
|
s.nat.Delete(key)
|
||||||
if !common.Done(ctx) {
|
if !common.Done(ctx) {
|
||||||
s.NewContextPacketEx(ctx, key, buffer, source, destination, init)
|
s.NewContextPacket(ctx, key, buffer, metadata, init)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.data <- packet{
|
c.data <- packet{
|
||||||
data: buffer,
|
data: buffer,
|
||||||
destination: destination,
|
destination: metadata.Destination,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -213,6 +172,10 @@ func (c *conn) SetWriteDeadline(t time.Time) error {
|
||||||
return os.ErrInvalid
|
return os.ErrInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) NeedAdditionalReadDeadline() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (c *conn) Upstream() any {
|
func (c *conn) Upstream() any {
|
||||||
return c.source
|
return c.source
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,138 +0,0 @@
|
||||||
package udpnat
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
"github.com/sagernet/sing/common/canceler"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
"github.com/sagernet/sing/common/pipe"
|
|
||||||
"github.com/sagernet/sing/contrab/freelru"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Conn interface {
|
|
||||||
N.PacketConn
|
|
||||||
SetHandler(handler N.UDPHandlerEx)
|
|
||||||
canceler.PacketConn
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Conn = (*natConn)(nil)
|
|
||||||
|
|
||||||
type natConn struct {
|
|
||||||
cache freelru.Cache[netip.AddrPort, *natConn]
|
|
||||||
writer N.PacketWriter
|
|
||||||
localAddr M.Socksaddr
|
|
||||||
handlerAccess sync.RWMutex
|
|
||||||
handler N.UDPHandlerEx
|
|
||||||
packetChan chan *N.PacketBuffer
|
|
||||||
closeOnce sync.Once
|
|
||||||
doneChan chan struct{}
|
|
||||||
readDeadline pipe.Deadline
|
|
||||||
readWaitOptions N.ReadWaitOptions
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
|
|
||||||
select {
|
|
||||||
case p := <-c.packetChan:
|
|
||||||
_, err = buffer.ReadOnceFrom(p.Buffer)
|
|
||||||
destination := p.Destination
|
|
||||||
p.Buffer.Release()
|
|
||||||
N.PutPacketBuffer(p)
|
|
||||||
return destination, err
|
|
||||||
case <-c.doneChan:
|
|
||||||
return M.Socksaddr{}, io.ErrClosedPipe
|
|
||||||
case <-c.readDeadline.Wait():
|
|
||||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
|
||||||
return c.writer.WritePacket(buffer, destination)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
|
|
||||||
c.readWaitOptions = options
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
|
||||||
select {
|
|
||||||
case packet := <-c.packetChan:
|
|
||||||
buffer = c.readWaitOptions.Copy(packet.Buffer)
|
|
||||||
destination = packet.Destination
|
|
||||||
N.PutPacketBuffer(packet)
|
|
||||||
return
|
|
||||||
case <-c.doneChan:
|
|
||||||
return nil, M.Socksaddr{}, io.ErrClosedPipe
|
|
||||||
case <-c.readDeadline.Wait():
|
|
||||||
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) SetHandler(handler N.UDPHandlerEx) {
|
|
||||||
c.handlerAccess.Lock()
|
|
||||||
c.handler = handler
|
|
||||||
c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler)
|
|
||||||
c.handlerAccess.Unlock()
|
|
||||||
fetch:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case packet := <-c.packetChan:
|
|
||||||
c.handler.NewPacketEx(packet.Buffer, packet.Destination)
|
|
||||||
N.PutPacketBuffer(packet)
|
|
||||||
continue fetch
|
|
||||||
default:
|
|
||||||
break fetch
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) Timeout() time.Duration {
|
|
||||||
rawConn, lifetime, loaded := c.cache.PeekWithLifetime(c.localAddr.AddrPort())
|
|
||||||
if !loaded || rawConn != c {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return time.Until(lifetime)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) SetTimeout(timeout time.Duration) bool {
|
|
||||||
return c.cache.UpdateLifetime(c.localAddr.AddrPort(), c, timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) Close() error {
|
|
||||||
c.closeOnce.Do(func() {
|
|
||||||
close(c.doneChan)
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) LocalAddr() net.Addr {
|
|
||||||
return c.localAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) RemoteAddr() net.Addr {
|
|
||||||
return M.Socksaddr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) SetDeadline(t time.Time) error {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) SetReadDeadline(t time.Time) error {
|
|
||||||
c.readDeadline.Set(t)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) SetWriteDeadline(t time.Time) error {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *natConn) Upstream() any {
|
|
||||||
return c.writer
|
|
||||||
}
|
|
|
@ -1,103 +0,0 @@
|
||||||
package udpnat
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
"github.com/sagernet/sing/common/pipe"
|
|
||||||
"github.com/sagernet/sing/contrab/freelru"
|
|
||||||
"github.com/sagernet/sing/contrab/maphash"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Service struct {
|
|
||||||
cache freelru.Cache[netip.AddrPort, *natConn]
|
|
||||||
handler N.UDPConnectionHandlerEx
|
|
||||||
prepare PrepareFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
type PrepareFunc func(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc)
|
|
||||||
|
|
||||||
func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service {
|
|
||||||
if timeout == 0 {
|
|
||||||
panic("invalid timeout")
|
|
||||||
}
|
|
||||||
var cache freelru.Cache[netip.AddrPort, *natConn]
|
|
||||||
if !shared {
|
|
||||||
cache = common.Must1(freelru.NewSynced[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
|
|
||||||
} else {
|
|
||||||
cache = common.Must1(freelru.NewSharded[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32))
|
|
||||||
}
|
|
||||||
cache.SetLifetime(timeout)
|
|
||||||
cache.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool {
|
|
||||||
select {
|
|
||||||
case <-conn.doneChan:
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
})
|
|
||||||
cache.SetOnEvict(func(_ netip.AddrPort, conn *natConn) {
|
|
||||||
conn.Close()
|
|
||||||
})
|
|
||||||
return &Service{
|
|
||||||
cache: cache,
|
|
||||||
handler: handler,
|
|
||||||
prepare: prepare,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) {
|
|
||||||
conn, _, ok := s.cache.GetAndRefreshOrAdd(source.AddrPort(), func() (*natConn, bool) {
|
|
||||||
ok, ctx, writer, onClose := s.prepare(source, destination, userData)
|
|
||||||
if !ok {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
newConn := &natConn{
|
|
||||||
cache: s.cache,
|
|
||||||
writer: writer,
|
|
||||||
localAddr: source,
|
|
||||||
packetChan: make(chan *N.PacketBuffer, 64),
|
|
||||||
doneChan: make(chan struct{}),
|
|
||||||
readDeadline: pipe.MakeDeadline(),
|
|
||||||
}
|
|
||||||
go s.handler.NewPacketConnectionEx(ctx, newConn, source, destination, onClose)
|
|
||||||
return newConn, true
|
|
||||||
})
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
buffer := conn.readWaitOptions.NewPacketBuffer()
|
|
||||||
for _, bufferSlice := range bufferSlices {
|
|
||||||
buffer.Write(bufferSlice)
|
|
||||||
}
|
|
||||||
conn.handlerAccess.RLock()
|
|
||||||
handler := conn.handler
|
|
||||||
conn.handlerAccess.RUnlock()
|
|
||||||
if handler != nil {
|
|
||||||
handler.NewPacketEx(buffer, destination)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
packet := N.NewPacketBuffer()
|
|
||||||
*packet = N.PacketBuffer{
|
|
||||||
Buffer: buffer,
|
|
||||||
Destination: destination,
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case conn.packetChan <- packet:
|
|
||||||
default:
|
|
||||||
packet.Buffer.Release()
|
|
||||||
N.PutPacketBuffer(packet)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) Purge() {
|
|
||||||
s.cache.Purge()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) PurgeExpired() {
|
|
||||||
s.cache.PurgeExpired()
|
|
||||||
}
|
|
|
@ -1,14 +1,16 @@
|
||||||
//go:build windows
|
|
||||||
|
|
||||||
package windnsapi
|
package windnsapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSAPI(t *testing.T) {
|
func TestDNSAPI(t *testing.T) {
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
t.SkipNow()
|
||||||
|
}
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
require.NoError(t, FlushResolverCache())
|
require.NoError(t, FlushResolverCache())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,217 +0,0 @@
|
||||||
//go:build windows
|
|
||||||
|
|
||||||
package winiphlpapi
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
)
|
|
||||||
|
|
||||||
func LoadEStats() error {
|
|
||||||
err := modiphlpapi.Load()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = procGetTcpTable.Find()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = procGetTcp6Table.Find()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = procGetPerTcp6ConnectionEStats.Find()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = procGetPerTcp6ConnectionEStats.Find()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = procSetPerTcpConnectionEStats.Find()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = procSetPerTcp6ConnectionEStats.Find()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadExtendedTable() error {
|
|
||||||
err := modiphlpapi.Load()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = procGetExtendedTcpTable.Find()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = procGetExtendedUdpTable.Find()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func FindPid(network string, source netip.AddrPort) (uint32, error) {
|
|
||||||
switch N.NetworkName(network) {
|
|
||||||
case N.NetworkTCP:
|
|
||||||
if source.Addr().Is4() {
|
|
||||||
tcpTable, err := GetExtendedTcpTable()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
for _, row := range tcpTable {
|
|
||||||
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
|
|
||||||
return row.DwOwningPid, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tcpTable, err := GetExtendedTcp6Table()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
for _, row := range tcpTable {
|
|
||||||
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
|
|
||||||
return row.DwOwningPid, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case N.NetworkUDP:
|
|
||||||
if source.Addr().Is4() {
|
|
||||||
udpTable, err := GetExtendedUdpTable()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
for _, row := range udpTable {
|
|
||||||
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) {
|
|
||||||
return row.DwOwningPid, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
udpTable, err := GetExtendedUdp6Table()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
for _, row := range udpTable {
|
|
||||||
if source == netip.AddrPortFrom(netip.AddrFrom16(row.UcLocalAddr), DwordToPort(row.DwLocalPort)) {
|
|
||||||
return row.DwOwningPid, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0, E.New("process not found for ", source)
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteAndWaitAck(ctx context.Context, conn net.Conn, payload []byte) error {
|
|
||||||
source := M.AddrPortFromNet(conn.LocalAddr())
|
|
||||||
destination := M.AddrPortFromNet(conn.RemoteAddr())
|
|
||||||
if source.Addr().Is4() {
|
|
||||||
tcpTable, err := GetTcpTable()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var tcpRow *MibTcpRow
|
|
||||||
for _, row := range tcpTable {
|
|
||||||
if source == netip.AddrPortFrom(DwordToAddr(row.DwLocalAddr), DwordToPort(row.DwLocalPort)) ||
|
|
||||||
destination == netip.AddrPortFrom(DwordToAddr(row.DwRemoteAddr), DwordToPort(row.DwRemotePort)) {
|
|
||||||
tcpRow = &row
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if tcpRow == nil {
|
|
||||||
return E.New("row not found for: ", source)
|
|
||||||
}
|
|
||||||
err = SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
|
|
||||||
EnableCollection: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
|
|
||||||
}
|
|
||||||
defer SetPerTcpConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
|
|
||||||
EnableCollection: false,
|
|
||||||
})
|
|
||||||
_, err = conn.Write(payload)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
eStstsSendBuffer, err := GetPerTcpConnectionEStatsSendBuffer(tcpRow)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if eStstsSendBuffer.CurRetxQueue == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tcpTable, err := GetTcp6Table()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var tcpRow *MibTcp6Row
|
|
||||||
for _, row := range tcpTable {
|
|
||||||
if source == netip.AddrPortFrom(netip.AddrFrom16(row.LocalAddr), DwordToPort(row.LocalPort)) ||
|
|
||||||
destination == netip.AddrPortFrom(netip.AddrFrom16(row.RemoteAddr), DwordToPort(row.RemotePort)) {
|
|
||||||
tcpRow = &row
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if tcpRow == nil {
|
|
||||||
return E.New("row not found for: ", source)
|
|
||||||
}
|
|
||||||
err = SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
|
|
||||||
EnableCollection: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return os.NewSyscallError("SetPerTcpConnectionEStatsSendBufferV0", err)
|
|
||||||
}
|
|
||||||
defer SetPerTcp6ConnectionEStatsSendBuffer(tcpRow, &TcpEstatsSendBuffRwV0{
|
|
||||||
EnableCollection: false,
|
|
||||||
})
|
|
||||||
_, err = conn.Write(payload)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
eStstsSendBuffer, err := GetPerTcp6ConnectionEStatsSendBuffer(tcpRow)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if eStstsSendBuffer.CurRetxQueue == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func DwordToAddr(addr uint32) netip.Addr {
|
|
||||||
return netip.AddrFrom4(*(*[4]byte)(unsafe.Pointer(&addr)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func DwordToPort(dword uint32) uint16 {
|
|
||||||
return binary.BigEndian.Uint16((*[4]byte)(unsafe.Pointer(&dword))[:])
|
|
||||||
}
|
|
|
@ -1,313 +0,0 @@
|
||||||
//go:build windows
|
|
||||||
|
|
||||||
package winiphlpapi
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
TcpTableBasicListener uint32 = iota
|
|
||||||
TcpTableBasicConnections
|
|
||||||
TcpTableBasicAll
|
|
||||||
TcpTableOwnerPidListener
|
|
||||||
TcpTableOwnerPidConnections
|
|
||||||
TcpTableOwnerPidAll
|
|
||||||
TcpTableOwnerModuleListener
|
|
||||||
TcpTableOwnerModuleConnections
|
|
||||||
TcpTableOwnerModuleAll
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
UdpTableBasic uint32 = iota
|
|
||||||
UdpTableOwnerPid
|
|
||||||
UdpTableOwnerModule
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
TcpConnectionEstatsSynOpts uint32 = iota
|
|
||||||
TcpConnectionEstatsData
|
|
||||||
TcpConnectionEstatsSndCong
|
|
||||||
TcpConnectionEstatsPath
|
|
||||||
TcpConnectionEstatsSendBuff
|
|
||||||
TcpConnectionEstatsRec
|
|
||||||
TcpConnectionEstatsObsRec
|
|
||||||
TcpConnectionEstatsBandwidth
|
|
||||||
TcpConnectionEstatsFineRtt
|
|
||||||
TcpConnectionEstatsMaximum
|
|
||||||
)
|
|
||||||
|
|
||||||
type MibTcpTable struct {
|
|
||||||
DwNumEntries uint32
|
|
||||||
Table [1]MibTcpRow
|
|
||||||
}
|
|
||||||
|
|
||||||
type MibTcpRow struct {
|
|
||||||
DwState uint32
|
|
||||||
DwLocalAddr uint32
|
|
||||||
DwLocalPort uint32
|
|
||||||
DwRemoteAddr uint32
|
|
||||||
DwRemotePort uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type MibTcp6Table struct {
|
|
||||||
DwNumEntries uint32
|
|
||||||
Table [1]MibTcp6Row
|
|
||||||
}
|
|
||||||
|
|
||||||
type MibTcp6Row struct {
|
|
||||||
State uint32
|
|
||||||
LocalAddr [16]byte
|
|
||||||
LocalScopeId uint32
|
|
||||||
LocalPort uint32
|
|
||||||
RemoteAddr [16]byte
|
|
||||||
RemoteScopeId uint32
|
|
||||||
RemotePort uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type MibTcpTableOwnerPid struct {
|
|
||||||
DwNumEntries uint32
|
|
||||||
Table [1]MibTcpRowOwnerPid
|
|
||||||
}
|
|
||||||
|
|
||||||
type MibTcpRowOwnerPid struct {
|
|
||||||
DwState uint32
|
|
||||||
DwLocalAddr uint32
|
|
||||||
DwLocalPort uint32
|
|
||||||
DwRemoteAddr uint32
|
|
||||||
DwRemotePort uint32
|
|
||||||
DwOwningPid uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type MibTcp6TableOwnerPid struct {
|
|
||||||
DwNumEntries uint32
|
|
||||||
Table [1]MibTcp6RowOwnerPid
|
|
||||||
}
|
|
||||||
|
|
||||||
type MibTcp6RowOwnerPid struct {
|
|
||||||
UcLocalAddr [16]byte
|
|
||||||
DwLocalScopeId uint32
|
|
||||||
DwLocalPort uint32
|
|
||||||
UcRemoteAddr [16]byte
|
|
||||||
DwRemoteScopeId uint32
|
|
||||||
DwRemotePort uint32
|
|
||||||
DwState uint32
|
|
||||||
DwOwningPid uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type MibUdpTableOwnerPid struct {
|
|
||||||
DwNumEntries uint32
|
|
||||||
Table [1]MibUdpRowOwnerPid
|
|
||||||
}
|
|
||||||
|
|
||||||
type MibUdpRowOwnerPid struct {
|
|
||||||
DwLocalAddr uint32
|
|
||||||
DwLocalPort uint32
|
|
||||||
DwOwningPid uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type MibUdp6TableOwnerPid struct {
|
|
||||||
DwNumEntries uint32
|
|
||||||
Table [1]MibUdp6RowOwnerPid
|
|
||||||
}
|
|
||||||
|
|
||||||
type MibUdp6RowOwnerPid struct {
|
|
||||||
UcLocalAddr [16]byte
|
|
||||||
DwLocalScopeId uint32
|
|
||||||
DwLocalPort uint32
|
|
||||||
DwOwningPid uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type TcpEstatsSendBufferRodV0 struct {
|
|
||||||
CurRetxQueue uint64
|
|
||||||
MaxRetxQueue uint64
|
|
||||||
CurAppWQueue uint64
|
|
||||||
MaxAppWQueue uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
type TcpEstatsSendBuffRwV0 struct {
|
|
||||||
EnableCollection bool
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
offsetOfMibTcpTable = unsafe.Offsetof(MibTcpTable{}.Table)
|
|
||||||
offsetOfMibTcp6Table = unsafe.Offsetof(MibTcp6Table{}.Table)
|
|
||||||
offsetOfMibTcpTableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table)
|
|
||||||
offsetOfMibTcp6TableOwnerPid = unsafe.Offsetof(MibTcpTableOwnerPid{}.Table)
|
|
||||||
offsetOfMibUdpTableOwnerPid = unsafe.Offsetof(MibUdpTableOwnerPid{}.Table)
|
|
||||||
offsetOfMibUdp6TableOwnerPid = unsafe.Offsetof(MibUdp6TableOwnerPid{}.Table)
|
|
||||||
sizeOfTcpEstatsSendBuffRwV0 = unsafe.Sizeof(TcpEstatsSendBuffRwV0{})
|
|
||||||
sizeOfTcpEstatsSendBufferRodV0 = unsafe.Sizeof(TcpEstatsSendBufferRodV0{})
|
|
||||||
)
|
|
||||||
|
|
||||||
func GetTcpTable() ([]MibTcpRow, error) {
|
|
||||||
var size uint32
|
|
||||||
err := getTcpTable(nil, &size, false)
|
|
||||||
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
table := make([]byte, size)
|
|
||||||
err = getTcpTable(&table[0], &size, false)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
|
|
||||||
return unsafe.Slice((*MibTcpRow)(unsafe.Pointer(&table[offsetOfMibTcpTable])), dwNumEntries), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetTcp6Table() ([]MibTcp6Row, error) {
|
|
||||||
var size uint32
|
|
||||||
err := getTcp6Table(nil, &size, false)
|
|
||||||
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
table := make([]byte, size)
|
|
||||||
err = getTcp6Table(&table[0], &size, false)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
|
|
||||||
return unsafe.Slice((*MibTcp6Row)(unsafe.Pointer(&table[offsetOfMibTcp6Table])), dwNumEntries), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetExtendedTcpTable() ([]MibTcpRowOwnerPid, error) {
|
|
||||||
var size uint32
|
|
||||||
err := getExtendedTcpTable(nil, &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0)
|
|
||||||
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
|
||||||
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
table := make([]byte, size)
|
|
||||||
err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET, TcpTableOwnerPidConnections, 0)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
|
|
||||||
}
|
|
||||||
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
|
|
||||||
return unsafe.Slice((*MibTcpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcpTableOwnerPid])), dwNumEntries), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetExtendedTcp6Table() ([]MibTcp6RowOwnerPid, error) {
|
|
||||||
var size uint32
|
|
||||||
err := getExtendedTcpTable(nil, &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0)
|
|
||||||
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
|
||||||
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
table := make([]byte, size)
|
|
||||||
err = getExtendedTcpTable(&table[0], &size, false, windows.AF_INET6, TcpTableOwnerPidConnections, 0)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, os.NewSyscallError("GetExtendedTcpTable", err)
|
|
||||||
}
|
|
||||||
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
|
|
||||||
return unsafe.Slice((*MibTcp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibTcp6TableOwnerPid])), dwNumEntries), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetExtendedUdpTable() ([]MibUdpRowOwnerPid, error) {
|
|
||||||
var size uint32
|
|
||||||
err := getExtendedUdpTable(nil, &size, false, windows.AF_INET, UdpTableOwnerPid, 0)
|
|
||||||
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
|
||||||
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
table := make([]byte, size)
|
|
||||||
err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET, UdpTableOwnerPid, 0)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
|
|
||||||
}
|
|
||||||
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
|
|
||||||
return unsafe.Slice((*MibUdpRowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdpTableOwnerPid])), dwNumEntries), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetExtendedUdp6Table() ([]MibUdp6RowOwnerPid, error) {
|
|
||||||
var size uint32
|
|
||||||
err := getExtendedUdpTable(nil, &size, false, windows.AF_INET6, UdpTableOwnerPid, 0)
|
|
||||||
if !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
|
||||||
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
table := make([]byte, size)
|
|
||||||
err = getExtendedUdpTable(&table[0], &size, false, windows.AF_INET6, UdpTableOwnerPid, 0)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, os.NewSyscallError("GetExtendedUdpTable", err)
|
|
||||||
}
|
|
||||||
dwNumEntries := int(*(*uint32)(unsafe.Pointer(&table[0])))
|
|
||||||
return unsafe.Slice((*MibUdp6RowOwnerPid)(unsafe.Pointer(&table[offsetOfMibUdp6TableOwnerPid])), dwNumEntries), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow) (*TcpEstatsSendBufferRodV0, error) {
|
|
||||||
var rod TcpEstatsSendBufferRodV0
|
|
||||||
err := getPerTcpConnectionEStats(row,
|
|
||||||
TcpConnectionEstatsSendBuff,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
uintptr(unsafe.Pointer(&rod)),
|
|
||||||
0,
|
|
||||||
uint64(sizeOfTcpEstatsSendBufferRodV0),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &rod, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row) (*TcpEstatsSendBufferRodV0, error) {
|
|
||||||
var rod TcpEstatsSendBufferRodV0
|
|
||||||
err := getPerTcp6ConnectionEStats(row,
|
|
||||||
TcpConnectionEstatsSendBuff,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
uintptr(unsafe.Pointer(&rod)),
|
|
||||||
0,
|
|
||||||
uint64(sizeOfTcpEstatsSendBufferRodV0),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &rod, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func SetPerTcpConnectionEStatsSendBuffer(row *MibTcpRow, rw *TcpEstatsSendBuffRwV0) error {
|
|
||||||
return setPerTcpConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func SetPerTcp6ConnectionEStatsSendBuffer(row *MibTcp6Row, rw *TcpEstatsSendBuffRwV0) error {
|
|
||||||
return setPerTcp6ConnectionEStats(row, TcpConnectionEstatsSendBuff, uintptr(unsafe.Pointer(&rw)), 0, uint64(sizeOfTcpEstatsSendBuffRwV0), 0)
|
|
||||||
}
|
|
|
@ -1,90 +0,0 @@
|
||||||
//go:build windows
|
|
||||||
|
|
||||||
package winiphlpapi_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
"github.com/sagernet/sing/common/winiphlpapi"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestFindPidTcp4(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer listener.Close()
|
|
||||||
go listener.Accept()
|
|
||||||
conn, err := net.Dial("tcp", listener.Addr().String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr()))
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, uint32(syscall.Getpid()), pid)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFindPidTcp6(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
listener, err := net.Listen("tcp", "[::1]:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer listener.Close()
|
|
||||||
go listener.Accept()
|
|
||||||
conn, err := net.Dial("tcp", listener.Addr().String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
pid, err := winiphlpapi.FindPid(N.NetworkTCP, M.AddrPortFromNet(conn.LocalAddr()))
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, uint32(syscall.Getpid()), pid)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFindPidUdp4(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr()))
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, uint32(syscall.Getpid()), pid)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFindPidUdp6(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
conn, err := net.ListenPacket("udp", "[::1]:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
pid, err := winiphlpapi.FindPid(N.NetworkUDP, M.AddrPortFromNet(conn.LocalAddr()))
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, uint32(syscall.Getpid()), pid)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWaitAck4(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer listener.Close()
|
|
||||||
go listener.Accept()
|
|
||||||
conn, err := net.Dial("tcp", listener.Addr().String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWaitAck6(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
listener, err := net.Listen("tcp", "[::1]:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer listener.Close()
|
|
||||||
go listener.Accept()
|
|
||||||
conn, err := net.Dial("tcp", listener.Addr().String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
err = winiphlpapi.WriteAndWaitAck(context.Background(), conn, []byte("hello"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
|
@ -1,27 +0,0 @@
|
||||||
package winiphlpapi
|
|
||||||
|
|
||||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
|
|
||||||
|
|
||||||
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcptable
|
|
||||||
//sys getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcpTable
|
|
||||||
|
|
||||||
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-gettcp6table
|
|
||||||
//sys getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) = iphlpapi.GetTcp6Table
|
|
||||||
|
|
||||||
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcpconnectionestats
|
|
||||||
//sys getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcpConnectionEStats
|
|
||||||
|
|
||||||
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getpertcp6connectionestats
|
|
||||||
//sys getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) = iphlpapi.GetPerTcp6ConnectionEStats
|
|
||||||
|
|
||||||
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcpconnectionestats
|
|
||||||
//sys setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcpConnectionEStats
|
|
||||||
|
|
||||||
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-setpertcp6connectionestats
|
|
||||||
//sys setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) = iphlpapi.SetPerTcp6ConnectionEStats
|
|
||||||
|
|
||||||
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
|
|
||||||
//sys getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedTcpTable
|
|
||||||
|
|
||||||
// https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedudptable
|
|
||||||
//sys getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) = (errcode error) = iphlpapi.GetExtendedUdpTable
|
|
|
@ -1,131 +0,0 @@
|
||||||
// Code generated by 'go generate'; DO NOT EDIT.
|
|
||||||
|
|
||||||
package winiphlpapi
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ unsafe.Pointer
|
|
||||||
|
|
||||||
// Do the interface allocations only once for common
|
|
||||||
// Errno values.
|
|
||||||
const (
|
|
||||||
errnoERROR_IO_PENDING = 997
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
|
||||||
errERROR_EINVAL error = syscall.EINVAL
|
|
||||||
)
|
|
||||||
|
|
||||||
// errnoErr returns common boxed Errno values, to prevent
|
|
||||||
// allocations at runtime.
|
|
||||||
func errnoErr(e syscall.Errno) error {
|
|
||||||
switch e {
|
|
||||||
case 0:
|
|
||||||
return errERROR_EINVAL
|
|
||||||
case errnoERROR_IO_PENDING:
|
|
||||||
return errERROR_IO_PENDING
|
|
||||||
}
|
|
||||||
// TODO: add more here, after collecting data on the common
|
|
||||||
// error values see on Windows. (perhaps when running
|
|
||||||
// all.bat?)
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
|
|
||||||
|
|
||||||
procGetExtendedTcpTable = modiphlpapi.NewProc("GetExtendedTcpTable")
|
|
||||||
procGetExtendedUdpTable = modiphlpapi.NewProc("GetExtendedUdpTable")
|
|
||||||
procGetPerTcp6ConnectionEStats = modiphlpapi.NewProc("GetPerTcp6ConnectionEStats")
|
|
||||||
procGetPerTcpConnectionEStats = modiphlpapi.NewProc("GetPerTcpConnectionEStats")
|
|
||||||
procGetTcp6Table = modiphlpapi.NewProc("GetTcp6Table")
|
|
||||||
procGetTcpTable = modiphlpapi.NewProc("GetTcpTable")
|
|
||||||
procSetPerTcp6ConnectionEStats = modiphlpapi.NewProc("SetPerTcp6ConnectionEStats")
|
|
||||||
procSetPerTcpConnectionEStats = modiphlpapi.NewProc("SetPerTcpConnectionEStats")
|
|
||||||
)
|
|
||||||
|
|
||||||
func getExtendedTcpTable(pTcpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) {
|
|
||||||
var _p0 uint32
|
|
||||||
if bOrder {
|
|
||||||
_p0 = 1
|
|
||||||
}
|
|
||||||
r0, _, _ := syscall.Syscall6(procGetExtendedTcpTable.Addr(), 6, uintptr(unsafe.Pointer(pTcpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved))
|
|
||||||
if r0 != 0 {
|
|
||||||
errcode = syscall.Errno(r0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getExtendedUdpTable(pUdpTable *byte, pdwSize *uint32, bOrder bool, ulAf uint64, tableClass uint32, reserved uint64) (errcode error) {
|
|
||||||
var _p0 uint32
|
|
||||||
if bOrder {
|
|
||||||
_p0 = 1
|
|
||||||
}
|
|
||||||
r0, _, _ := syscall.Syscall6(procGetExtendedUdpTable.Addr(), 6, uintptr(unsafe.Pointer(pUdpTable)), uintptr(unsafe.Pointer(pdwSize)), uintptr(_p0), uintptr(ulAf), uintptr(tableClass), uintptr(reserved))
|
|
||||||
if r0 != 0 {
|
|
||||||
errcode = syscall.Errno(r0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) {
|
|
||||||
r0, _, _ := syscall.Syscall12(procGetPerTcp6ConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0)
|
|
||||||
if r0 != 0 {
|
|
||||||
errcode = syscall.Errno(r0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, ros uintptr, rosVersion uint64, rosSize uint64, rod uintptr, rodVersion uint64, rodSize uint64) (errcode error) {
|
|
||||||
r0, _, _ := syscall.Syscall12(procGetPerTcpConnectionEStats.Addr(), 11, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(ros), uintptr(rosVersion), uintptr(rosSize), uintptr(rod), uintptr(rodVersion), uintptr(rodSize), 0)
|
|
||||||
if r0 != 0 {
|
|
||||||
errcode = syscall.Errno(r0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTcp6Table(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) {
|
|
||||||
var _p0 uint32
|
|
||||||
if order {
|
|
||||||
_p0 = 1
|
|
||||||
}
|
|
||||||
r0, _, _ := syscall.Syscall(procGetTcp6Table.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0))
|
|
||||||
if r0 != 0 {
|
|
||||||
errcode = syscall.Errno(r0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTcpTable(tcpTable *byte, sizePointer *uint32, order bool) (errcode error) {
|
|
||||||
var _p0 uint32
|
|
||||||
if order {
|
|
||||||
_p0 = 1
|
|
||||||
}
|
|
||||||
r0, _, _ := syscall.Syscall(procGetTcpTable.Addr(), 3, uintptr(unsafe.Pointer(tcpTable)), uintptr(unsafe.Pointer(sizePointer)), uintptr(_p0))
|
|
||||||
if r0 != 0 {
|
|
||||||
errcode = syscall.Errno(r0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func setPerTcp6ConnectionEStats(row *MibTcp6Row, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) {
|
|
||||||
r0, _, _ := syscall.Syscall6(procSetPerTcp6ConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset))
|
|
||||||
if r0 != 0 {
|
|
||||||
errcode = syscall.Errno(r0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func setPerTcpConnectionEStats(row *MibTcpRow, estatsType uint32, rw uintptr, rwVersion uint64, rwSize uint64, offset uint64) (errcode error) {
|
|
||||||
r0, _, _ := syscall.Syscall6(procSetPerTcpConnectionEStats.Addr(), 6, uintptr(unsafe.Pointer(row)), uintptr(estatsType), uintptr(rw), uintptr(rwVersion), uintptr(rwSize), uintptr(offset))
|
|
||||||
if r0 != 0 {
|
|
||||||
errcode = syscall.Errno(r0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -1,201 +0,0 @@
|
||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
|
@ -1,2 +0,0 @@
|
||||||
Go LRU Hashmap
|
|
||||||
Copyright 2022 Elasticsearch B.V.
|
|
|
@ -1,4 +0,0 @@
|
||||||
# freelru
|
|
||||||
|
|
||||||
upstream: github.com/elastic/go-freelru@v0.16.0
|
|
||||||
source: github.com/sagernet/go-freelru@1b34934a560d528d1866415d440625ed2a2560fe
|
|
|
@ -1,102 +0,0 @@
|
||||||
// Licensed to Elasticsearch B.V. under one or more contributor
|
|
||||||
// license agreements. See the NOTICE file distributed with
|
|
||||||
// this work for additional information regarding copyright
|
|
||||||
// ownership. Elasticsearch B.V. licenses this file to you under
|
|
||||||
// the Apache License, Version 2.0 (the "License"); you may
|
|
||||||
// not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing,
|
|
||||||
// software distributed under the License is distributed on an
|
|
||||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
// KIND, either express or implied. See the License for the
|
|
||||||
// specific language governing permissions and limitations
|
|
||||||
// under the License.
|
|
||||||
|
|
||||||
package freelru
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
type Cache[K comparable, V comparable] interface {
|
|
||||||
// SetLifetime sets the default lifetime of LRU elements.
|
|
||||||
// Lifetime 0 means "forever".
|
|
||||||
SetLifetime(lifetime time.Duration)
|
|
||||||
|
|
||||||
// SetOnEvict sets the OnEvict callback function.
|
|
||||||
// The onEvict function is called for each evicted lru entry.
|
|
||||||
SetOnEvict(onEvict OnEvictCallback[K, V])
|
|
||||||
|
|
||||||
SetHealthCheck(healthCheck HealthCheckCallback[K, V])
|
|
||||||
|
|
||||||
// Len returns the number of elements stored in the cache.
|
|
||||||
Len() int
|
|
||||||
|
|
||||||
// AddWithLifetime adds a key:value to the cache with a lifetime.
|
|
||||||
// Returns true, true if key was updated and eviction occurred.
|
|
||||||
AddWithLifetime(key K, value V, lifetime time.Duration) (evicted bool)
|
|
||||||
|
|
||||||
// Add adds a key:value to the cache.
|
|
||||||
// Returns true, true if key was updated and eviction occurred.
|
|
||||||
Add(key K, value V) (evicted bool)
|
|
||||||
|
|
||||||
// Get returns the value associated with the key, setting it as the most
|
|
||||||
// recently used item.
|
|
||||||
// If the found cache item is already expired, the evict function is called
|
|
||||||
// and the return value indicates that the key was not found.
|
|
||||||
Get(key K) (V, bool)
|
|
||||||
|
|
||||||
GetWithLifetime(key K) (V, time.Time, bool)
|
|
||||||
|
|
||||||
GetWithLifetimeNoExpire(key K) (V, time.Time, bool)
|
|
||||||
|
|
||||||
// GetAndRefresh returns the value associated with the key, setting it as the most
|
|
||||||
// recently used item.
|
|
||||||
// The lifetime of the found cache item is refreshed, even if it was already expired.
|
|
||||||
GetAndRefresh(key K) (V, bool)
|
|
||||||
|
|
||||||
GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool, bool)
|
|
||||||
|
|
||||||
// Peek looks up a key's value from the cache, without changing its recent-ness.
|
|
||||||
// If the found entry is already expired, the evict function is called.
|
|
||||||
Peek(key K) (V, bool)
|
|
||||||
|
|
||||||
PeekWithLifetime(key K) (V, time.Time, bool)
|
|
||||||
|
|
||||||
UpdateLifetime(key K, value V, lifetime time.Duration) bool
|
|
||||||
|
|
||||||
// Contains checks for the existence of a key, without changing its recent-ness.
|
|
||||||
// If the found entry is already expired, the evict function is called.
|
|
||||||
Contains(key K) bool
|
|
||||||
|
|
||||||
// Remove removes the key from the cache.
|
|
||||||
// The return value indicates whether the key existed or not.
|
|
||||||
// The evict function is called for the removed entry.
|
|
||||||
Remove(key K) bool
|
|
||||||
|
|
||||||
// RemoveOldest removes the oldest entry from the cache.
|
|
||||||
// Key, value and an indicator of whether the entry has been removed is returned.
|
|
||||||
// The evict function is called for the removed entry.
|
|
||||||
RemoveOldest() (key K, value V, removed bool)
|
|
||||||
|
|
||||||
// Keys returns a slice of the keys in the cache, from oldest to newest.
|
|
||||||
// Expired entries are not included.
|
|
||||||
// The evict function is called for each expired item.
|
|
||||||
Keys() []K
|
|
||||||
|
|
||||||
// Purge purges all data (key and value) from the LRU.
|
|
||||||
// The evict function is called for each expired item.
|
|
||||||
// The LRU metrics are reset.
|
|
||||||
Purge()
|
|
||||||
|
|
||||||
// PurgeExpired purges all expired items from the LRU.
|
|
||||||
// The evict function is called for each expired item.
|
|
||||||
PurgeExpired()
|
|
||||||
|
|
||||||
// Metrics returns the metrics of the cache.
|
|
||||||
Metrics() Metrics
|
|
||||||
|
|
||||||
// ResetMetrics resets the metrics of the cache and returns the previous state.
|
|
||||||
ResetMetrics() Metrics
|
|
||||||
}
|
|
|
@ -1,767 +0,0 @@
|
||||||
// Licensed to Elasticsearch B.V. under one or more contributor
|
|
||||||
// license agreements. See the NOTICE file distributed with
|
|
||||||
// this work for additional information regarding copyright
|
|
||||||
// ownership. Elasticsearch B.V. licenses this file to you under
|
|
||||||
// the Apache License, Version 2.0 (the "License"); you may
|
|
||||||
// not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing,
|
|
||||||
// software distributed under the License is distributed on an
|
|
||||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
// KIND, either express or implied. See the License for the
|
|
||||||
// specific language governing permissions and limitations
|
|
||||||
// under the License.
|
|
||||||
|
|
||||||
package freelru
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"math/bits"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OnEvictCallback is the type for the eviction function.
|
|
||||||
type OnEvictCallback[K comparable, V comparable] func(K, V)
|
|
||||||
|
|
||||||
// HashKeyCallback is the function that creates a hash from the passed key.
|
|
||||||
type HashKeyCallback[K comparable] func(K) uint32
|
|
||||||
|
|
||||||
type HealthCheckCallback[K comparable, V comparable] func(K, V) bool
|
|
||||||
|
|
||||||
type element[K comparable, V comparable] struct {
|
|
||||||
key K
|
|
||||||
value V
|
|
||||||
|
|
||||||
// bucketNext and bucketPrev are indexes in the space-dimension doubly-linked list of elements.
|
|
||||||
// That is to add/remove items to the collision bucket without re-allocations and with O(1)
|
|
||||||
// complexity.
|
|
||||||
// To simplify the implementation, internally a list l is implemented
|
|
||||||
// as a ring, such that &l.latest.prev is last element and
|
|
||||||
// &l.last.next is the latest element.
|
|
||||||
nextBucket, prevBucket uint32
|
|
||||||
|
|
||||||
// bucketPos is the bucket that an element belongs to.
|
|
||||||
bucketPos uint32
|
|
||||||
|
|
||||||
// next and prev are indexes in the time-dimension doubly-linked list of elements.
|
|
||||||
// To simplify the implementation, internally a list l is implemented
|
|
||||||
// as a ring, such that &l.latest.prev is last element and
|
|
||||||
// &l.last.next is the latest element.
|
|
||||||
next, prev uint32
|
|
||||||
|
|
||||||
// expire is the point in time when the element expires.
|
|
||||||
// Its value is Unix milliseconds since epoch.
|
|
||||||
expire int64
|
|
||||||
}
|
|
||||||
|
|
||||||
const emptyBucket = math.MaxUint32
|
|
||||||
|
|
||||||
// LRU implements a non-thread safe fixed size LRU cache.
|
|
||||||
type LRU[K comparable, V comparable] struct {
|
|
||||||
buckets []uint32 // contains positions of bucket lists or 'emptyBucket'
|
|
||||||
elements []element[K, V]
|
|
||||||
onEvict OnEvictCallback[K, V]
|
|
||||||
hash HashKeyCallback[K]
|
|
||||||
healthCheck HealthCheckCallback[K, V]
|
|
||||||
lifetime time.Duration
|
|
||||||
metrics Metrics
|
|
||||||
|
|
||||||
// used for element clearing after removal or expiration
|
|
||||||
emptyKey K
|
|
||||||
emptyValue V
|
|
||||||
|
|
||||||
head uint32 // index of the newest element in the cache
|
|
||||||
len uint32 // current number of elements in the cache
|
|
||||||
cap uint32 // max number of elements in the cache
|
|
||||||
size uint32 // size of the element array (X% larger than cap)
|
|
||||||
mask uint32 // bitmask to avoid the costly idiv in hashToPos() if size is a 2^n value
|
|
||||||
}
|
|
||||||
|
|
||||||
// Metrics contains metrics about the cache.
|
|
||||||
type Metrics struct {
|
|
||||||
Inserts uint64
|
|
||||||
Collisions uint64
|
|
||||||
Evictions uint64
|
|
||||||
Removals uint64
|
|
||||||
Hits uint64
|
|
||||||
Misses uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Cache[int, int] = (*LRU[int, int])(nil)
|
|
||||||
|
|
||||||
// SetLifetime sets the default lifetime of LRU elements.
|
|
||||||
// Lifetime 0 means "forever".
|
|
||||||
func (lru *LRU[K, V]) SetLifetime(lifetime time.Duration) {
|
|
||||||
lru.lifetime = lifetime
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetOnEvict sets the OnEvict callback function.
|
|
||||||
// The onEvict function is called for each evicted lru entry.
|
|
||||||
// Eviction happens
|
|
||||||
// - when the cache is full and a new entry is added (oldest entry is evicted)
|
|
||||||
// - when an entry is removed by Remove() or RemoveOldest()
|
|
||||||
// - when an entry is recognized as expired
|
|
||||||
// - when Purge() is called
|
|
||||||
func (lru *LRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) {
|
|
||||||
lru.onEvict = onEvict
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) {
|
|
||||||
lru.healthCheck = healthCheck
|
|
||||||
}
|
|
||||||
|
|
||||||
// New constructs an LRU with the given capacity of elements.
|
|
||||||
// The hash function calculates a hash value from the keys.
|
|
||||||
func New[K comparable, V comparable](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, V], error) {
|
|
||||||
return NewWithSize[K, V](capacity, capacity, hash)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWithSize constructs an LRU with the given capacity and size.
|
|
||||||
// The hash function calculates a hash value from the keys.
|
|
||||||
// A size greater than the capacity increases memory consumption and decreases the CPU consumption
|
|
||||||
// by reducing the chance of collisions.
|
|
||||||
// Size must not be lower than the capacity.
|
|
||||||
func NewWithSize[K comparable, V comparable](capacity, size uint32, hash HashKeyCallback[K]) (
|
|
||||||
*LRU[K, V], error,
|
|
||||||
) {
|
|
||||||
if capacity == 0 {
|
|
||||||
return nil, errors.New("capacity must be positive")
|
|
||||||
}
|
|
||||||
if size == emptyBucket {
|
|
||||||
return nil, fmt.Errorf("size must not be %#X", size)
|
|
||||||
}
|
|
||||||
if size < capacity {
|
|
||||||
return nil, fmt.Errorf("size (%d) is smaller than capacity (%d)", size, capacity)
|
|
||||||
}
|
|
||||||
if hash == nil {
|
|
||||||
return nil, errors.New("hash function must be set")
|
|
||||||
}
|
|
||||||
|
|
||||||
buckets := make([]uint32, size)
|
|
||||||
elements := make([]element[K, V], size)
|
|
||||||
|
|
||||||
var lru LRU[K, V]
|
|
||||||
initLRU(&lru, capacity, size, hash, buckets, elements)
|
|
||||||
|
|
||||||
return &lru, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func initLRU[K comparable, V comparable](lru *LRU[K, V], capacity, size uint32, hash HashKeyCallback[K],
|
|
||||||
buckets []uint32, elements []element[K, V],
|
|
||||||
) {
|
|
||||||
lru.cap = capacity
|
|
||||||
lru.size = size
|
|
||||||
lru.hash = hash
|
|
||||||
lru.buckets = buckets
|
|
||||||
lru.elements = elements
|
|
||||||
|
|
||||||
// If the size is 2^N, we can avoid costly divisions.
|
|
||||||
if bits.OnesCount32(lru.size) == 1 {
|
|
||||||
lru.mask = lru.size - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark all slots as free.
|
|
||||||
for i := range lru.buckets {
|
|
||||||
lru.buckets[i] = emptyBucket
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// hashToBucketPos converts a hash value into a position in the elements array.
|
|
||||||
func (lru *LRU[K, V]) hashToBucketPos(hash uint32) uint32 {
|
|
||||||
if lru.mask != 0 {
|
|
||||||
return hash & lru.mask
|
|
||||||
}
|
|
||||||
return fastModulo(hash, lru.size)
|
|
||||||
}
|
|
||||||
|
|
||||||
// fastModulo calculates x % n without using the modulo operator (~4x faster).
|
|
||||||
// Reference: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
|
|
||||||
func fastModulo(x, n uint32) uint32 {
|
|
||||||
return uint32((uint64(x) * uint64(n)) >> 32) //nolint:gosec
|
|
||||||
}
|
|
||||||
|
|
||||||
// hashToPos converts a key into a position in the elements array.
|
|
||||||
func (lru *LRU[K, V]) hashToPos(hash uint32) (bucketPos, elemPos uint32) {
|
|
||||||
bucketPos = lru.hashToBucketPos(hash)
|
|
||||||
elemPos = lru.buckets[bucketPos]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// setHead links the element as the head into the list.
|
|
||||||
func (lru *LRU[K, V]) setHead(pos uint32) {
|
|
||||||
// Both calls to setHead() check beforehand that pos != lru.head.
|
|
||||||
// So if you run into this situation, you likely use FreeLRU in a concurrent situation
|
|
||||||
// without proper locking. It requires a write lock, even around Get().
|
|
||||||
// But better use SyncedLRU or SharedLRU in such a case.
|
|
||||||
if pos == lru.head {
|
|
||||||
panic(pos)
|
|
||||||
}
|
|
||||||
|
|
||||||
lru.elements[pos].prev = lru.head
|
|
||||||
lru.elements[pos].next = lru.elements[lru.head].next
|
|
||||||
lru.elements[lru.elements[lru.head].next].prev = pos
|
|
||||||
lru.elements[lru.head].next = pos
|
|
||||||
lru.head = pos
|
|
||||||
}
|
|
||||||
|
|
||||||
// unlinkElement removes the element from the elements list.
|
|
||||||
func (lru *LRU[K, V]) unlinkElement(pos uint32) {
|
|
||||||
lru.elements[lru.elements[pos].prev].next = lru.elements[pos].next
|
|
||||||
lru.elements[lru.elements[pos].next].prev = lru.elements[pos].prev
|
|
||||||
}
|
|
||||||
|
|
||||||
// unlinkBucket removes the element from the buckets list.
|
|
||||||
func (lru *LRU[K, V]) unlinkBucket(pos uint32) {
|
|
||||||
prevBucket := lru.elements[pos].prevBucket
|
|
||||||
nextBucket := lru.elements[pos].nextBucket
|
|
||||||
if prevBucket == nextBucket && prevBucket == pos { //nolint:gocritic
|
|
||||||
// The element references itself, so it's the only bucket entry
|
|
||||||
lru.buckets[lru.elements[pos].bucketPos] = emptyBucket
|
|
||||||
return
|
|
||||||
}
|
|
||||||
lru.elements[prevBucket].nextBucket = nextBucket
|
|
||||||
lru.elements[nextBucket].prevBucket = prevBucket
|
|
||||||
lru.buckets[lru.elements[pos].bucketPos] = nextBucket
|
|
||||||
}
|
|
||||||
|
|
||||||
// evict evicts the element at the given position.
|
|
||||||
func (lru *LRU[K, V]) evict(pos uint32) {
|
|
||||||
if pos == lru.head {
|
|
||||||
lru.head = lru.elements[pos].prev
|
|
||||||
}
|
|
||||||
|
|
||||||
lru.unlinkElement(pos)
|
|
||||||
lru.unlinkBucket(pos)
|
|
||||||
lru.len--
|
|
||||||
|
|
||||||
if lru.onEvict != nil {
|
|
||||||
// Save k/v for the eviction function.
|
|
||||||
key := lru.elements[pos].key
|
|
||||||
value := lru.elements[pos].value
|
|
||||||
lru.onEvict(key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Move element from position 'from' to position 'to'.
|
|
||||||
// That avoids 'gaps' and new elements can always be simply appended.
|
|
||||||
func (lru *LRU[K, V]) move(to, from uint32) {
|
|
||||||
if to == from {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if from == lru.head {
|
|
||||||
lru.head = to
|
|
||||||
}
|
|
||||||
|
|
||||||
prev := lru.elements[from].prev
|
|
||||||
next := lru.elements[from].next
|
|
||||||
lru.elements[prev].next = to
|
|
||||||
lru.elements[next].prev = to
|
|
||||||
|
|
||||||
prev = lru.elements[from].prevBucket
|
|
||||||
next = lru.elements[from].nextBucket
|
|
||||||
lru.elements[prev].nextBucket = to
|
|
||||||
lru.elements[next].prevBucket = to
|
|
||||||
|
|
||||||
lru.elements[to] = lru.elements[from]
|
|
||||||
|
|
||||||
if lru.buckets[lru.elements[to].bucketPos] == from {
|
|
||||||
lru.buckets[lru.elements[to].bucketPos] = to
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert stores the k/v at pos.
|
|
||||||
// It updates the head to point to this position.
|
|
||||||
func (lru *LRU[K, V]) insert(pos uint32, key K, value V, lifetime time.Duration) {
|
|
||||||
lru.elements[pos].key = key
|
|
||||||
lru.elements[pos].value = value
|
|
||||||
lru.elements[pos].expire = expire(lifetime)
|
|
||||||
|
|
||||||
if lru.len == 0 {
|
|
||||||
lru.elements[pos].prev = pos
|
|
||||||
lru.elements[pos].next = pos
|
|
||||||
lru.head = pos
|
|
||||||
} else if pos != lru.head {
|
|
||||||
lru.setHead(pos)
|
|
||||||
}
|
|
||||||
lru.len++
|
|
||||||
lru.metrics.Inserts++
|
|
||||||
}
|
|
||||||
|
|
||||||
func now() int64 {
|
|
||||||
return time.Now().UnixMilli()
|
|
||||||
}
|
|
||||||
|
|
||||||
func expire(lifetime time.Duration) int64 {
|
|
||||||
if lifetime == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return now() + lifetime.Milliseconds()
|
|
||||||
}
|
|
||||||
|
|
||||||
// clearKeyAndValue clears stale data to avoid memory leaks
|
|
||||||
func (lru *LRU[K, V]) clearKeyAndValue(pos uint32) {
|
|
||||||
lru.elements[pos].key = lru.emptyKey
|
|
||||||
lru.elements[pos].value = lru.emptyValue
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) findKey(hash uint32, key K) (uint32, bool) {
|
|
||||||
_, startPos := lru.hashToPos(hash)
|
|
||||||
if startPos == emptyBucket {
|
|
||||||
return emptyBucket, false
|
|
||||||
}
|
|
||||||
|
|
||||||
pos := startPos
|
|
||||||
for {
|
|
||||||
if key == lru.elements[pos].key {
|
|
||||||
if lru.elements[pos].expire != 0 && lru.elements[pos].expire <= now() || (lru.healthCheck != nil && !lru.healthCheck(key, lru.elements[pos].value)) {
|
|
||||||
lru.removeAt(pos)
|
|
||||||
return emptyBucket, false
|
|
||||||
}
|
|
||||||
return pos, true
|
|
||||||
}
|
|
||||||
|
|
||||||
pos = lru.elements[pos].nextBucket
|
|
||||||
if pos == startPos {
|
|
||||||
// Key not found
|
|
||||||
return emptyBucket, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) findKeyNoExpire(hash uint32, key K) (uint32, bool) {
|
|
||||||
_, startPos := lru.hashToPos(hash)
|
|
||||||
if startPos == emptyBucket {
|
|
||||||
return emptyBucket, false
|
|
||||||
}
|
|
||||||
|
|
||||||
pos := startPos
|
|
||||||
for {
|
|
||||||
if key == lru.elements[pos].key {
|
|
||||||
if lru.healthCheck != nil && !lru.healthCheck(key, lru.elements[pos].value) {
|
|
||||||
lru.removeAt(pos)
|
|
||||||
return emptyBucket, false
|
|
||||||
}
|
|
||||||
return pos, true
|
|
||||||
}
|
|
||||||
|
|
||||||
pos = lru.elements[pos].nextBucket
|
|
||||||
if pos == startPos {
|
|
||||||
// Key not found
|
|
||||||
return emptyBucket, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Len returns the number of elements stored in the cache.
|
|
||||||
func (lru *LRU[K, V]) Len() int {
|
|
||||||
return int(lru.len)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddWithLifetime adds a key:value to the cache with a lifetime.
|
|
||||||
// Returns true, true if key was updated and eviction occurred.
|
|
||||||
func (lru *LRU[K, V]) AddWithLifetime(key K, value V, lifetime time.Duration) (evicted bool) {
|
|
||||||
return lru.addWithLifetime(lru.hash(key), key, value, lifetime)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) addWithLifetime(hash uint32, key K, value V,
|
|
||||||
lifetime time.Duration,
|
|
||||||
) (evicted bool) {
|
|
||||||
bucketPos, startPos := lru.hashToPos(hash)
|
|
||||||
if startPos == emptyBucket {
|
|
||||||
pos := lru.len
|
|
||||||
|
|
||||||
if pos == lru.cap {
|
|
||||||
// Capacity reached, evict the oldest entry and
|
|
||||||
// store the new entry at evicted position.
|
|
||||||
pos = lru.elements[lru.head].next
|
|
||||||
lru.evict(pos)
|
|
||||||
lru.metrics.Evictions++
|
|
||||||
evicted = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert new (first) entry into the bucket
|
|
||||||
lru.buckets[bucketPos] = pos
|
|
||||||
lru.elements[pos].bucketPos = bucketPos
|
|
||||||
|
|
||||||
lru.elements[pos].nextBucket = pos
|
|
||||||
lru.elements[pos].prevBucket = pos
|
|
||||||
lru.insert(pos, key, value, lifetime)
|
|
||||||
return evicted
|
|
||||||
}
|
|
||||||
|
|
||||||
// Walk through the bucket list to see whether key already exists.
|
|
||||||
pos := startPos
|
|
||||||
for {
|
|
||||||
if lru.elements[pos].key == key {
|
|
||||||
// Key exists, replace the value and update element to be the head element.
|
|
||||||
lru.elements[pos].value = value
|
|
||||||
lru.elements[pos].expire = expire(lifetime)
|
|
||||||
|
|
||||||
if pos != lru.head {
|
|
||||||
lru.unlinkElement(pos)
|
|
||||||
lru.setHead(pos)
|
|
||||||
}
|
|
||||||
// count as insert, even if it's just an update
|
|
||||||
lru.metrics.Inserts++
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
pos = lru.elements[pos].nextBucket
|
|
||||||
if pos == startPos {
|
|
||||||
// Key not found
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pos = lru.len
|
|
||||||
if pos == lru.cap {
|
|
||||||
// Capacity reached, evict the oldest entry and
|
|
||||||
// store the new entry at evicted position.
|
|
||||||
pos = lru.elements[lru.head].next
|
|
||||||
lru.evict(pos)
|
|
||||||
lru.metrics.Evictions++
|
|
||||||
evicted = true
|
|
||||||
startPos = lru.buckets[bucketPos]
|
|
||||||
if startPos == emptyBucket {
|
|
||||||
startPos = pos
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert new entry into the existing bucket before startPos
|
|
||||||
lru.buckets[bucketPos] = pos
|
|
||||||
lru.elements[pos].bucketPos = bucketPos
|
|
||||||
|
|
||||||
lru.elements[pos].nextBucket = startPos
|
|
||||||
lru.elements[pos].prevBucket = lru.elements[startPos].prevBucket
|
|
||||||
lru.elements[lru.elements[startPos].prevBucket].nextBucket = pos
|
|
||||||
lru.elements[startPos].prevBucket = pos
|
|
||||||
lru.insert(pos, key, value, lifetime)
|
|
||||||
|
|
||||||
if lru.elements[pos].prevBucket != pos {
|
|
||||||
// The bucket now contains more than 1 element.
|
|
||||||
// That means we have a collision.
|
|
||||||
lru.metrics.Collisions++
|
|
||||||
}
|
|
||||||
return evicted
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add adds a key:value to the cache.
|
|
||||||
// Returns true, true if key was updated and eviction occurred.
|
|
||||||
func (lru *LRU[K, V]) Add(key K, value V) (evicted bool) {
|
|
||||||
return lru.addWithLifetime(lru.hash(key), key, value, lru.lifetime)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) add(hash uint32, key K, value V) (evicted bool) {
|
|
||||||
return lru.addWithLifetime(hash, key, value, lru.lifetime)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns the value associated with the key, setting it as the most
|
|
||||||
// recently used item.
|
|
||||||
// If the found cache item is already expired, the evict function is called
|
|
||||||
// and the return value indicates that the key was not found.
|
|
||||||
func (lru *LRU[K, V]) Get(key K) (value V, ok bool) {
|
|
||||||
return lru.get(lru.hash(key), key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) get(hash uint32, key K) (value V, ok bool) {
|
|
||||||
if pos, ok := lru.findKey(hash, key); ok {
|
|
||||||
if pos != lru.head {
|
|
||||||
lru.unlinkElement(pos)
|
|
||||||
lru.setHead(pos)
|
|
||||||
}
|
|
||||||
lru.metrics.Hits++
|
|
||||||
return lru.elements[pos].value, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
lru.metrics.Misses++
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) {
|
|
||||||
return lru.getWithLifetime(lru.hash(key), key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) getWithLifetime(hash uint32, key K) (value V, lifetime time.Time, ok bool) {
|
|
||||||
if pos, ok := lru.findKey(hash, key); ok {
|
|
||||||
if pos != lru.head {
|
|
||||||
lru.unlinkElement(pos)
|
|
||||||
lru.setHead(pos)
|
|
||||||
}
|
|
||||||
lru.metrics.Hits++
|
|
||||||
return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok
|
|
||||||
}
|
|
||||||
|
|
||||||
lru.metrics.Misses++
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) {
|
|
||||||
return lru.getWithLifetimeNoExpire(lru.hash(key), key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) getWithLifetimeNoExpire(hash uint32, key K) (value V, lifetime time.Time, ok bool) {
|
|
||||||
if pos, ok := lru.findKeyNoExpire(hash, key); ok {
|
|
||||||
if pos != lru.head {
|
|
||||||
lru.unlinkElement(pos)
|
|
||||||
lru.setHead(pos)
|
|
||||||
}
|
|
||||||
lru.metrics.Hits++
|
|
||||||
return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok
|
|
||||||
}
|
|
||||||
|
|
||||||
lru.metrics.Misses++
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAndRefresh returns the value associated with the key, setting it as the most
|
|
||||||
// recently used item.
|
|
||||||
// The lifetime of the found cache item is refreshed, even if it was already expired.
|
|
||||||
func (lru *LRU[K, V]) GetAndRefresh(key K) (V, bool) {
|
|
||||||
return lru.getAndRefresh(lru.hash(key), key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) getAndRefresh(hash uint32, key K) (value V, ok bool) {
|
|
||||||
if pos, ok := lru.findKeyNoExpire(hash, key); ok {
|
|
||||||
if pos != lru.head {
|
|
||||||
lru.unlinkElement(pos)
|
|
||||||
lru.setHead(pos)
|
|
||||||
}
|
|
||||||
lru.metrics.Hits++
|
|
||||||
lru.elements[pos].expire = expire(lru.lifetime)
|
|
||||||
return lru.elements[pos].value, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
lru.metrics.Misses++
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (V, bool, bool) {
|
|
||||||
value, updated, ok := lru.getAndRefreshOrAdd(lru.hash(key), key, constructor)
|
|
||||||
if !updated && ok {
|
|
||||||
lru.PurgeExpired()
|
|
||||||
}
|
|
||||||
return value, updated, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) getAndRefreshOrAdd(hash uint32, key K, constructor func() (V, bool)) (value V, updated bool, ok bool) {
|
|
||||||
if pos, ok := lru.findKeyNoExpire(hash, key); ok {
|
|
||||||
if pos != lru.head {
|
|
||||||
lru.unlinkElement(pos)
|
|
||||||
lru.setHead(pos)
|
|
||||||
}
|
|
||||||
lru.metrics.Hits++
|
|
||||||
lru.elements[pos].expire = expire(lru.lifetime)
|
|
||||||
return lru.elements[pos].value, true, true
|
|
||||||
}
|
|
||||||
lru.metrics.Misses++
|
|
||||||
value, ok = constructor()
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
lru.addWithLifetime(hash, key, value, lru.lifetime)
|
|
||||||
return value, false, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Peek looks up a key's value from the cache, without changing its recent-ness.
|
|
||||||
// If the found entry is already expired, the evict function is called.
|
|
||||||
func (lru *LRU[K, V]) Peek(key K) (value V, ok bool) {
|
|
||||||
return lru.peek(lru.hash(key), key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) peek(hash uint32, key K) (value V, ok bool) {
|
|
||||||
if pos, ok := lru.findKey(hash, key); ok {
|
|
||||||
return lru.elements[pos].value, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) {
|
|
||||||
return lru.peekWithLifetime(lru.hash(key), key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) peekWithLifetime(hash uint32, key K) (value V, lifetime time.Time, ok bool) {
|
|
||||||
if pos, ok := lru.findKey(hash, key); ok {
|
|
||||||
return lru.elements[pos].value, time.UnixMilli(lru.elements[pos].expire), ok
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) UpdateLifetime(key K, value V, lifetime time.Duration) bool {
|
|
||||||
return lru.updateLifetime(lru.hash(key), key, value, lifetime)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) updateLifetime(hash uint32, key K, value V, lifetime time.Duration) bool {
|
|
||||||
_, startPos := lru.hashToPos(hash)
|
|
||||||
if startPos == emptyBucket {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
pos := startPos
|
|
||||||
for {
|
|
||||||
if lru.elements[pos].key == key {
|
|
||||||
if lru.elements[pos].value != value {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
lru.elements[pos].expire = expire(lifetime)
|
|
||||||
|
|
||||||
if pos != lru.head {
|
|
||||||
lru.unlinkElement(pos)
|
|
||||||
lru.setHead(pos)
|
|
||||||
}
|
|
||||||
lru.metrics.Inserts++
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
pos = lru.elements[pos].nextBucket
|
|
||||||
if pos == startPos {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Contains checks for the existence of a key, without changing its recent-ness.
|
|
||||||
// If the found entry is already expired, the evict function is called.
|
|
||||||
func (lru *LRU[K, V]) Contains(key K) (ok bool) {
|
|
||||||
_, ok = lru.peek(lru.hash(key), key)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) contains(hash uint32, key K) (ok bool) {
|
|
||||||
_, ok = lru.peek(hash, key)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove removes the key from the cache.
|
|
||||||
// The return value indicates whether the key existed or not.
|
|
||||||
// The evict function is called for the removed entry.
|
|
||||||
func (lru *LRU[K, V]) Remove(key K) (removed bool) {
|
|
||||||
return lru.remove(lru.hash(key), key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) remove(hash uint32, key K) (removed bool) {
|
|
||||||
if pos, ok := lru.findKeyNoExpire(hash, key); ok {
|
|
||||||
lru.removeAt(pos)
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) removeAt(pos uint32) {
|
|
||||||
lru.evict(pos)
|
|
||||||
lru.move(pos, lru.len)
|
|
||||||
lru.metrics.Removals++
|
|
||||||
|
|
||||||
// remove stale data to avoid memory leaks
|
|
||||||
lru.clearKeyAndValue(lru.len)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveOldest removes the oldest entry from the cache.
|
|
||||||
// Key, value and an indicator of whether the entry has been removed is returned.
|
|
||||||
// The evict function is called for the removed entry.
|
|
||||||
func (lru *LRU[K, V]) RemoveOldest() (key K, value V, removed bool) {
|
|
||||||
if lru.len == 0 {
|
|
||||||
return lru.emptyKey, lru.emptyValue, false
|
|
||||||
}
|
|
||||||
pos := lru.elements[lru.head].next
|
|
||||||
key = lru.elements[pos].key
|
|
||||||
value = lru.elements[pos].value
|
|
||||||
lru.removeAt(pos)
|
|
||||||
return key, value, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Keys returns a slice of the keys in the cache, from oldest to newest.
|
|
||||||
// Expired entries are not included.
|
|
||||||
// The evict function is called for each expired item.
|
|
||||||
func (lru *LRU[K, V]) Keys() []K {
|
|
||||||
lru.PurgeExpired()
|
|
||||||
|
|
||||||
keys := make([]K, 0, lru.len)
|
|
||||||
pos := lru.elements[lru.head].next
|
|
||||||
for i := uint32(0); i < lru.len; i++ {
|
|
||||||
keys = append(keys, lru.elements[pos].key)
|
|
||||||
pos = lru.elements[pos].next
|
|
||||||
}
|
|
||||||
return keys
|
|
||||||
}
|
|
||||||
|
|
||||||
// Purge purges all data (key and value) from the LRU.
|
|
||||||
// The evict function is called for each expired item.
|
|
||||||
// The LRU metrics are reset.
|
|
||||||
func (lru *LRU[K, V]) Purge() {
|
|
||||||
l := lru.len
|
|
||||||
for i := uint32(0); i < l; i++ {
|
|
||||||
_, _, _ = lru.RemoveOldest()
|
|
||||||
}
|
|
||||||
|
|
||||||
lru.metrics = Metrics{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// PurgeExpired purges all expired items from the LRU.
|
|
||||||
// The evict function is called for each expired item.
|
|
||||||
func (lru *LRU[K, V]) PurgeExpired() {
|
|
||||||
n := now()
|
|
||||||
loop:
|
|
||||||
l := lru.len
|
|
||||||
if l == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
pos := lru.elements[lru.head].next
|
|
||||||
for i := uint32(0); i < l; i++ {
|
|
||||||
if lru.elements[pos].expire != 0 && lru.elements[pos].expire <= n {
|
|
||||||
lru.removeAt(pos)
|
|
||||||
goto loop
|
|
||||||
}
|
|
||||||
pos = lru.elements[pos].next
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Metrics returns the metrics of the cache.
|
|
||||||
func (lru *LRU[K, V]) Metrics() Metrics {
|
|
||||||
return lru.metrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetMetrics resets the metrics of the cache and returns the previous state.
|
|
||||||
func (lru *LRU[K, V]) ResetMetrics() Metrics {
|
|
||||||
metrics := lru.metrics
|
|
||||||
lru.metrics = Metrics{}
|
|
||||||
return metrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// just used for debugging
|
|
||||||
func (lru *LRU[K, V]) dump() {
|
|
||||||
fmt.Printf("head %d len %d cap %d size %d mask 0x%X\n",
|
|
||||||
lru.head, lru.len, lru.cap, lru.size, lru.mask)
|
|
||||||
|
|
||||||
for i := range lru.buckets {
|
|
||||||
if lru.buckets[i] == emptyBucket {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fmt.Printf(" bucket[%d] -> %d\n", i, lru.buckets[i])
|
|
||||||
pos := lru.buckets[i]
|
|
||||||
for {
|
|
||||||
e := &lru.elements[pos]
|
|
||||||
fmt.Printf(" pos %d bucketPos %d prevBucket %d nextBucket %d prev %d next %d k %v v %v\n",
|
|
||||||
pos, e.bucketPos, e.prevBucket, e.nextBucket, e.prev, e.next, e.key, e.value)
|
|
||||||
pos = e.nextBucket
|
|
||||||
if pos == lru.buckets[i] {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *LRU[K, V]) PrintStats() {
|
|
||||||
m := &lru.metrics
|
|
||||||
fmt.Printf("Inserts: %d Collisions: %d (%.2f%%) Evictions: %d Removals: %d Hits: %d (%.2f%%) Misses: %d\n",
|
|
||||||
m.Inserts, m.Collisions, float64(m.Collisions)/float64(m.Inserts)*100,
|
|
||||||
m.Evictions, m.Removals,
|
|
||||||
m.Hits, float64(m.Hits)/float64(m.Hits+m.Misses)*100, m.Misses)
|
|
||||||
}
|
|
|
@ -1,100 +0,0 @@
|
||||||
package freelru_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/rand"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
F "github.com/sagernet/sing/common/format"
|
|
||||||
"github.com/sagernet/sing/contrab/freelru"
|
|
||||||
"github.com/sagernet/sing/contrab/maphash"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestUpdateLifetimeOnGet(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32)
|
|
||||||
require.NoError(t, err)
|
|
||||||
lru.AddWithLifetime("hello", "world", 2*time.Second)
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
_, ok := lru.GetAndRefresh("hello")
|
|
||||||
require.True(t, ok)
|
|
||||||
time.Sleep(time.Second + time.Millisecond*100)
|
|
||||||
_, ok = lru.Get("hello")
|
|
||||||
require.True(t, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateLifetimeOnGet1(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32)
|
|
||||||
require.NoError(t, err)
|
|
||||||
lru.AddWithLifetime("hello", "world", 2*time.Second)
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
lru.Peek("hello")
|
|
||||||
time.Sleep(time.Second + time.Millisecond*100)
|
|
||||||
_, ok := lru.Get("hello")
|
|
||||||
require.False(t, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateLifetime(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32)
|
|
||||||
require.NoError(t, err)
|
|
||||||
lru.Add("hello", "world")
|
|
||||||
require.True(t, lru.UpdateLifetime("hello", "world", 2*time.Second))
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
_, ok := lru.Get("hello")
|
|
||||||
require.True(t, ok)
|
|
||||||
time.Sleep(time.Second + time.Millisecond*100)
|
|
||||||
_, ok = lru.Get("hello")
|
|
||||||
require.False(t, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateLifetime1(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32)
|
|
||||||
require.NoError(t, err)
|
|
||||||
lru.Add("hello", "world")
|
|
||||||
require.False(t, lru.UpdateLifetime("hello", "not world", 2*time.Second))
|
|
||||||
time.Sleep(2*time.Second + time.Millisecond*100)
|
|
||||||
_, ok := lru.Get("hello")
|
|
||||||
require.True(t, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateLifetime2(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32)
|
|
||||||
require.NoError(t, err)
|
|
||||||
lru.AddWithLifetime("hello", "world", 2*time.Second)
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
require.True(t, lru.UpdateLifetime("hello", "world", 2*time.Second))
|
|
||||||
time.Sleep(time.Second + time.Millisecond*100)
|
|
||||||
_, ok := lru.Get("hello")
|
|
||||||
require.True(t, ok)
|
|
||||||
time.Sleep(time.Second + time.Millisecond*100)
|
|
||||||
_, ok = lru.Get("hello")
|
|
||||||
require.False(t, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPurgeExpired(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
lru, err := freelru.New[string, *string](1024, maphash.NewHasher[string]().Hash32)
|
|
||||||
require.NoError(t, err)
|
|
||||||
lru.SetLifetime(time.Second)
|
|
||||||
lru.SetOnEvict(func(s string, s2 *string) {
|
|
||||||
if s2 == nil {
|
|
||||||
t.Fail()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
lru.AddWithLifetime("hello_"+F.ToString(i), common.Ptr("world_"+F.ToString(i)), time.Duration(rand.Intn(3000))*time.Millisecond)
|
|
||||||
}
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
lru.GetAndRefreshOrAdd("hellox"+F.ToString(i), func() (*string, bool) {
|
|
||||||
return common.Ptr("worldx"), true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,398 +0,0 @@
|
||||||
package freelru
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"math/bits"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ShardedLRU is a thread-safe, sharded, fixed size LRU cache.
|
|
||||||
// Sharding is used to reduce lock contention on high concurrency.
|
|
||||||
// The downside is that exact LRU behavior is not given (as for the LRU and SynchedLRU types).
|
|
||||||
type ShardedLRU[K comparable, V comparable] struct {
|
|
||||||
lrus []LRU[K, V]
|
|
||||||
mus []sync.RWMutex
|
|
||||||
hash HashKeyCallback[K]
|
|
||||||
shards uint32
|
|
||||||
mask uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Cache[int, int] = (*ShardedLRU[int, int])(nil)
|
|
||||||
|
|
||||||
// SetLifetime sets the default lifetime of LRU elements.
|
|
||||||
// Lifetime 0 means "forever".
|
|
||||||
func (lru *ShardedLRU[K, V]) SetLifetime(lifetime time.Duration) {
|
|
||||||
for shard := range lru.lrus {
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
lru.lrus[shard].SetLifetime(lifetime)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetOnEvict sets the OnEvict callback function.
|
|
||||||
// The onEvict function is called for each evicted lru entry.
|
|
||||||
func (lru *ShardedLRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) {
|
|
||||||
for shard := range lru.lrus {
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
lru.lrus[shard].SetOnEvict(onEvict)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *ShardedLRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) {
|
|
||||||
for shard := range lru.lrus {
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
lru.lrus[shard].SetHealthCheck(healthCheck)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func nextPowerOfTwo(val uint32) uint32 {
|
|
||||||
if bits.OnesCount32(val) != 1 {
|
|
||||||
return 1 << bits.Len32(val)
|
|
||||||
}
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSharded creates a new thread-safe sharded LRU hashmap with the given capacity.
|
|
||||||
func NewSharded[K comparable, V comparable](capacity uint32, hash HashKeyCallback[K]) (*ShardedLRU[K, V],
|
|
||||||
error,
|
|
||||||
) {
|
|
||||||
size := uint32(float64(capacity) * 1.25) // 25% extra space for fewer collisions
|
|
||||||
|
|
||||||
return NewShardedWithSize[K, V](uint32(runtime.GOMAXPROCS(0)*16), capacity, size, hash)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewShardedWithSize[K comparable, V comparable](shards, capacity, size uint32,
|
|
||||||
hash HashKeyCallback[K]) (
|
|
||||||
*ShardedLRU[K, V], error,
|
|
||||||
) {
|
|
||||||
if capacity == 0 {
|
|
||||||
return nil, errors.New("capacity must be positive")
|
|
||||||
}
|
|
||||||
if size < capacity {
|
|
||||||
return nil, fmt.Errorf("size (%d) is smaller than capacity (%d)", size, capacity)
|
|
||||||
}
|
|
||||||
|
|
||||||
if size < 1<<31 {
|
|
||||||
size = nextPowerOfTwo(size) // next power of 2 so the LRUs can avoid costly divisions
|
|
||||||
} else {
|
|
||||||
size = 1 << 31 // the highest 2^N value that fits in a uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
shards = nextPowerOfTwo(shards) // next power of 2 so we can avoid costly division for sharding
|
|
||||||
|
|
||||||
for shards > size/16 {
|
|
||||||
shards /= 16
|
|
||||||
}
|
|
||||||
if shards == 0 {
|
|
||||||
shards = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
size /= shards // size per LRU
|
|
||||||
if size == 0 {
|
|
||||||
size = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
capacity = (capacity + shards - 1) / shards // size per LRU
|
|
||||||
if capacity == 0 {
|
|
||||||
capacity = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
lrus := make([]LRU[K, V], shards)
|
|
||||||
buckets := make([]uint32, size*shards)
|
|
||||||
elements := make([]element[K, V], size*shards)
|
|
||||||
|
|
||||||
from := 0
|
|
||||||
to := int(size)
|
|
||||||
for i := range lrus {
|
|
||||||
initLRU(&lrus[i], capacity, size, hash, buckets[from:to], elements[from:to])
|
|
||||||
from = to
|
|
||||||
to += int(size)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ShardedLRU[K, V]{
|
|
||||||
lrus: lrus,
|
|
||||||
mus: make([]sync.RWMutex, shards),
|
|
||||||
hash: hash,
|
|
||||||
shards: shards,
|
|
||||||
mask: shards - 1,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Len returns the number of elements stored in the cache.
|
|
||||||
func (lru *ShardedLRU[K, V]) Len() (length int) {
|
|
||||||
for shard := range lru.lrus {
|
|
||||||
lru.mus[shard].RLock()
|
|
||||||
length += lru.lrus[shard].Len()
|
|
||||||
lru.mus[shard].RUnlock()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddWithLifetime adds a key:value to the cache with a lifetime.
|
|
||||||
// Returns true, true if key was updated and eviction occurred.
|
|
||||||
func (lru *ShardedLRU[K, V]) AddWithLifetime(key K, value V,
|
|
||||||
lifetime time.Duration,
|
|
||||||
) (evicted bool) {
|
|
||||||
hash := lru.hash(key)
|
|
||||||
shard := (hash >> 16) & lru.mask
|
|
||||||
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
evicted = lru.lrus[shard].addWithLifetime(hash, key, value, lifetime)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add adds a key:value to the cache.
|
|
||||||
// Returns true, true if key was updated and eviction occurred.
|
|
||||||
func (lru *ShardedLRU[K, V]) Add(key K, value V) (evicted bool) {
|
|
||||||
hash := lru.hash(key)
|
|
||||||
shard := (hash >> 16) & lru.mask
|
|
||||||
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
evicted = lru.lrus[shard].add(hash, key, value)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns the value associated with the key, setting it as the most
|
|
||||||
// recently used item.
|
|
||||||
// If the found cache item is already expired, the evict function is called
|
|
||||||
// and the return value indicates that the key was not found.
|
|
||||||
func (lru *ShardedLRU[K, V]) Get(key K) (value V, ok bool) {
|
|
||||||
hash := lru.hash(key)
|
|
||||||
shard := (hash >> 16) & lru.mask
|
|
||||||
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
value, ok = lru.lrus[shard].get(hash, key)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *ShardedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) {
|
|
||||||
hash := lru.hash(key)
|
|
||||||
shard := (hash >> 16) & lru.mask
|
|
||||||
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
value, lifetime, ok = lru.lrus[shard].getWithLifetime(hash, key)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *ShardedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) {
|
|
||||||
hash := lru.hash(key)
|
|
||||||
shard := (hash >> 16) & lru.mask
|
|
||||||
|
|
||||||
lru.mus[shard].RLock()
|
|
||||||
value, lifetime, ok = lru.lrus[shard].getWithLifetimeNoExpire(hash, key)
|
|
||||||
lru.mus[shard].RUnlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAndRefresh returns the value associated with the key, setting it as the most
|
|
||||||
// recently used item.
|
|
||||||
// The lifetime of the found cache item is refreshed, even if it was already expired.
|
|
||||||
func (lru *ShardedLRU[K, V]) GetAndRefresh(key K) (value V, ok bool) {
|
|
||||||
hash := lru.hash(key)
|
|
||||||
shard := (hash >> 16) & lru.mask
|
|
||||||
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
value, ok = lru.lrus[shard].getAndRefresh(hash, key)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *ShardedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool, ok bool) {
|
|
||||||
hash := lru.hash(key)
|
|
||||||
shard := (hash >> 16) & lru.mask
|
|
||||||
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
value, updated, ok = lru.lrus[shard].getAndRefreshOrAdd(hash, key, constructor)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
if !updated && ok {
|
|
||||||
lru.PurgeExpired()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Peek looks up a key's value from the cache, without changing its recent-ness.
|
|
||||||
// If the found entry is already expired, the evict function is called.
|
|
||||||
func (lru *ShardedLRU[K, V]) Peek(key K) (value V, ok bool) {
|
|
||||||
hash := lru.hash(key)
|
|
||||||
shard := (hash >> 16) & lru.mask
|
|
||||||
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
value, ok = lru.lrus[shard].peek(hash, key)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *ShardedLRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) {
|
|
||||||
hash := lru.hash(key)
|
|
||||||
shard := (hash >> 16) & lru.mask
|
|
||||||
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
value, lifetime, ok = lru.lrus[shard].peekWithLifetime(hash, key)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *ShardedLRU[K, V]) UpdateLifetime(key K, value V, lifetime time.Duration) (ok bool) {
|
|
||||||
hash := lru.hash(key)
|
|
||||||
shard := (hash >> 16) & lru.mask
|
|
||||||
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
ok = lru.lrus[shard].updateLifetime(hash, key, value, lifetime)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Contains checks for the existence of a key, without changing its recent-ness.
|
|
||||||
// If the found entry is already expired, the evict function is called.
|
|
||||||
func (lru *ShardedLRU[K, V]) Contains(key K) (ok bool) {
|
|
||||||
hash := lru.hash(key)
|
|
||||||
shard := (hash >> 16) & lru.mask
|
|
||||||
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
ok = lru.lrus[shard].contains(hash, key)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove removes the key from the cache.
|
|
||||||
// The return value indicates whether the key existed or not.
|
|
||||||
// The evict function is called for the removed entry.
|
|
||||||
func (lru *ShardedLRU[K, V]) Remove(key K) (removed bool) {
|
|
||||||
hash := lru.hash(key)
|
|
||||||
shard := (hash >> 16) & lru.mask
|
|
||||||
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
removed = lru.lrus[shard].remove(hash, key)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveOldest removes the oldest entry from the cache.
|
|
||||||
// Key, value and an indicator of whether the entry has been removed is returned.
|
|
||||||
// The evict function is called for the removed entry.
|
|
||||||
func (lru *ShardedLRU[K, V]) RemoveOldest() (key K, value V, removed bool) {
|
|
||||||
hash := lru.hash(key)
|
|
||||||
shard := (hash >> 16) & lru.mask
|
|
||||||
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
key, value, removed = lru.lrus[shard].RemoveOldest()
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Keys returns a slice of the keys in the cache, from oldest to newest.
|
|
||||||
// Expired entries are not included.
|
|
||||||
// The evict function is called for each expired item.
|
|
||||||
func (lru *ShardedLRU[K, V]) Keys() []K {
|
|
||||||
keys := make([]K, 0, lru.shards*lru.lrus[0].cap)
|
|
||||||
for shard := range lru.lrus {
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
keys = append(keys, lru.lrus[shard].Keys()...)
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
return keys
|
|
||||||
}
|
|
||||||
|
|
||||||
// Purge purges all data (key and value) from the LRU.
|
|
||||||
// The evict function is called for each expired item.
|
|
||||||
// The LRU metrics are reset.
|
|
||||||
func (lru *ShardedLRU[K, V]) Purge() {
|
|
||||||
for shard := range lru.lrus {
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
lru.lrus[shard].Purge()
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// PurgeExpired purges all expired items from the LRU.
|
|
||||||
// The evict function is called for each expired item.
|
|
||||||
func (lru *ShardedLRU[K, V]) PurgeExpired() {
|
|
||||||
for shard := range lru.lrus {
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
lru.lrus[shard].PurgeExpired()
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Metrics returns the metrics of the cache.
|
|
||||||
func (lru *ShardedLRU[K, V]) Metrics() Metrics {
|
|
||||||
metrics := Metrics{}
|
|
||||||
|
|
||||||
for shard := range lru.lrus {
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
m := lru.lrus[shard].Metrics()
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
addMetrics(&metrics, m)
|
|
||||||
}
|
|
||||||
|
|
||||||
return metrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetMetrics resets the metrics of the cache and returns the previous state.
|
|
||||||
func (lru *ShardedLRU[K, V]) ResetMetrics() Metrics {
|
|
||||||
metrics := Metrics{}
|
|
||||||
|
|
||||||
for shard := range lru.lrus {
|
|
||||||
lru.mus[shard].Lock()
|
|
||||||
m := lru.lrus[shard].ResetMetrics()
|
|
||||||
lru.mus[shard].Unlock()
|
|
||||||
|
|
||||||
addMetrics(&metrics, m)
|
|
||||||
}
|
|
||||||
|
|
||||||
return metrics
|
|
||||||
}
|
|
||||||
|
|
||||||
func addMetrics(dst *Metrics, src Metrics) {
|
|
||||||
dst.Inserts += src.Inserts
|
|
||||||
dst.Collisions += src.Collisions
|
|
||||||
dst.Evictions += src.Evictions
|
|
||||||
dst.Removals += src.Removals
|
|
||||||
dst.Hits += src.Hits
|
|
||||||
dst.Misses += src.Misses
|
|
||||||
}
|
|
||||||
|
|
||||||
// just used for debugging
|
|
||||||
func (lru *ShardedLRU[K, V]) dump() {
|
|
||||||
for shard := range lru.lrus {
|
|
||||||
fmt.Printf("Shard %d:\n", shard)
|
|
||||||
lru.mus[shard].RLock()
|
|
||||||
lru.lrus[shard].dump()
|
|
||||||
lru.mus[shard].RUnlock()
|
|
||||||
fmt.Println("")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *ShardedLRU[K, V]) PrintStats() {
|
|
||||||
for shard := range lru.lrus {
|
|
||||||
fmt.Printf("Shard %d:\n", shard)
|
|
||||||
lru.mus[shard].RLock()
|
|
||||||
lru.lrus[shard].PrintStats()
|
|
||||||
lru.mus[shard].RUnlock()
|
|
||||||
fmt.Println("")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,270 +0,0 @@
|
||||||
package freelru
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SyncedLRU[K comparable, V comparable] struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
lru *LRU[K, V]
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Cache[int, int] = (*SyncedLRU[int, int])(nil)
|
|
||||||
|
|
||||||
// SetLifetime sets the default lifetime of LRU elements.
|
|
||||||
// Lifetime 0 means "forever".
|
|
||||||
func (lru *SyncedLRU[K, V]) SetLifetime(lifetime time.Duration) {
|
|
||||||
lru.mu.Lock()
|
|
||||||
lru.lru.SetLifetime(lifetime)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetOnEvict sets the OnEvict callback function.
|
|
||||||
// The onEvict function is called for each evicted lru entry.
|
|
||||||
func (lru *SyncedLRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) {
|
|
||||||
lru.mu.Lock()
|
|
||||||
lru.lru.SetOnEvict(onEvict)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *SyncedLRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) {
|
|
||||||
lru.mu.Lock()
|
|
||||||
lru.lru.SetHealthCheck(healthCheck)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSynced creates a new thread-safe LRU hashmap with the given capacity.
|
|
||||||
func NewSynced[K comparable, V comparable](capacity uint32, hash HashKeyCallback[K]) (*SyncedLRU[K, V],
|
|
||||||
error,
|
|
||||||
) {
|
|
||||||
return NewSyncedWithSize[K, V](capacity, capacity, hash)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSyncedWithSize[K comparable, V comparable](capacity, size uint32,
|
|
||||||
hash HashKeyCallback[K],
|
|
||||||
) (*SyncedLRU[K, V], error) {
|
|
||||||
lru, err := NewWithSize[K, V](capacity, size, hash)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &SyncedLRU[K, V]{lru: lru}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Len returns the number of elements stored in the cache.
|
|
||||||
func (lru *SyncedLRU[K, V]) Len() (length int) {
|
|
||||||
lru.mu.RLock()
|
|
||||||
length = lru.lru.Len()
|
|
||||||
lru.mu.RUnlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddWithLifetime adds a key:value to the cache with a lifetime.
|
|
||||||
// Returns true, true if key was updated and eviction occurred.
|
|
||||||
func (lru *SyncedLRU[K, V]) AddWithLifetime(key K, value V, lifetime time.Duration) (evicted bool) {
|
|
||||||
hash := lru.lru.hash(key)
|
|
||||||
|
|
||||||
lru.mu.Lock()
|
|
||||||
evicted = lru.lru.addWithLifetime(hash, key, value, lifetime)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add adds a key:value to the cache.
|
|
||||||
// Returns true, true if key was updated and eviction occurred.
|
|
||||||
func (lru *SyncedLRU[K, V]) Add(key K, value V) (evicted bool) {
|
|
||||||
hash := lru.lru.hash(key)
|
|
||||||
|
|
||||||
lru.mu.Lock()
|
|
||||||
evicted = lru.lru.add(hash, key, value)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns the value associated with the key, setting it as the most
|
|
||||||
// recently used item.
|
|
||||||
// If the found cache item is already expired, the evict function is called
|
|
||||||
// and the return value indicates that the key was not found.
|
|
||||||
func (lru *SyncedLRU[K, V]) Get(key K) (value V, ok bool) {
|
|
||||||
hash := lru.lru.hash(key)
|
|
||||||
|
|
||||||
lru.mu.Lock()
|
|
||||||
value, ok = lru.lru.get(hash, key)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *SyncedLRU[K, V]) GetWithLifetime(key K) (value V, lifetime time.Time, ok bool) {
|
|
||||||
hash := lru.lru.hash(key)
|
|
||||||
|
|
||||||
lru.mu.Lock()
|
|
||||||
value, lifetime, ok = lru.lru.getWithLifetime(hash, key)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *SyncedLRU[K, V]) GetWithLifetimeNoExpire(key K) (value V, lifetime time.Time, ok bool) {
|
|
||||||
hash := lru.lru.hash(key)
|
|
||||||
|
|
||||||
lru.mu.Lock()
|
|
||||||
value, lifetime, ok = lru.lru.getWithLifetimeNoExpire(hash, key)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAndRefresh returns the value associated with the key, setting it as the most
|
|
||||||
// recently used item.
|
|
||||||
// The lifetime of the found cache item is refreshed, even if it was already expired.
|
|
||||||
func (lru *SyncedLRU[K, V]) GetAndRefresh(key K) (value V, ok bool) {
|
|
||||||
hash := lru.lru.hash(key)
|
|
||||||
|
|
||||||
lru.mu.Lock()
|
|
||||||
value, ok = lru.lru.getAndRefresh(hash, key)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *SyncedLRU[K, V]) GetAndRefreshOrAdd(key K, constructor func() (V, bool)) (value V, updated bool, ok bool) {
|
|
||||||
hash := lru.lru.hash(key)
|
|
||||||
|
|
||||||
lru.mu.Lock()
|
|
||||||
value, updated, ok = lru.lru.getAndRefreshOrAdd(hash, key, constructor)
|
|
||||||
if !updated && ok {
|
|
||||||
lru.lru.PurgeExpired()
|
|
||||||
}
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Peek looks up a key's value from the cache, without changing its recent-ness.
|
|
||||||
// If the found entry is already expired, the evict function is called.
|
|
||||||
func (lru *SyncedLRU[K, V]) Peek(key K) (value V, ok bool) {
|
|
||||||
hash := lru.lru.hash(key)
|
|
||||||
|
|
||||||
lru.mu.Lock()
|
|
||||||
value, ok = lru.lru.peek(hash, key)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *SyncedLRU[K, V]) PeekWithLifetime(key K) (value V, lifetime time.Time, ok bool) {
|
|
||||||
hash := lru.lru.hash(key)
|
|
||||||
|
|
||||||
lru.mu.Lock()
|
|
||||||
value, lifetime, ok = lru.lru.peekWithLifetime(hash, key)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *SyncedLRU[K, V]) UpdateLifetime(key K, value V, lifetime time.Duration) (ok bool) {
|
|
||||||
hash := lru.lru.hash(key)
|
|
||||||
|
|
||||||
lru.mu.Lock()
|
|
||||||
ok = lru.lru.updateLifetime(hash, key, value, lifetime)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Contains checks for the existence of a key, without changing its recent-ness.
|
|
||||||
// If the found entry is already expired, the evict function is called.
|
|
||||||
func (lru *SyncedLRU[K, V]) Contains(key K) (ok bool) {
|
|
||||||
hash := lru.lru.hash(key)
|
|
||||||
|
|
||||||
lru.mu.Lock()
|
|
||||||
ok = lru.lru.contains(hash, key)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove removes the key from the cache.
|
|
||||||
// The return value indicates whether the key existed or not.
|
|
||||||
// The evict function is being called if the key existed.
|
|
||||||
func (lru *SyncedLRU[K, V]) Remove(key K) (removed bool) {
|
|
||||||
hash := lru.lru.hash(key)
|
|
||||||
|
|
||||||
lru.mu.Lock()
|
|
||||||
removed = lru.lru.remove(hash, key)
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveOldest removes the oldest entry from the cache.
|
|
||||||
// Key, value and an indicator of whether the entry has been removed is returned.
|
|
||||||
// The evict function is called for the removed entry.
|
|
||||||
func (lru *SyncedLRU[K, V]) RemoveOldest() (key K, value V, removed bool) {
|
|
||||||
lru.mu.Lock()
|
|
||||||
key, value, removed = lru.lru.RemoveOldest()
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Keys returns a slice of the keys in the cache, from oldest to newest.
|
|
||||||
// Expired entries are not included.
|
|
||||||
// The evict function is called for each expired item.
|
|
||||||
func (lru *SyncedLRU[K, V]) Keys() (keys []K) {
|
|
||||||
lru.mu.Lock()
|
|
||||||
keys = lru.lru.Keys()
|
|
||||||
lru.mu.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Purge purges all data (key and value) from the LRU.
|
|
||||||
// The evict function is called for each expired item.
|
|
||||||
// The LRU metrics are reset.
|
|
||||||
func (lru *SyncedLRU[K, V]) Purge() {
|
|
||||||
lru.mu.Lock()
|
|
||||||
lru.lru.Purge()
|
|
||||||
lru.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// PurgeExpired purges all expired items from the LRU.
|
|
||||||
// The evict function is called for each expired item.
|
|
||||||
func (lru *SyncedLRU[K, V]) PurgeExpired() {
|
|
||||||
lru.mu.Lock()
|
|
||||||
lru.lru.PurgeExpired()
|
|
||||||
lru.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Metrics returns the metrics of the cache.
|
|
||||||
func (lru *SyncedLRU[K, V]) Metrics() Metrics {
|
|
||||||
lru.mu.Lock()
|
|
||||||
metrics := lru.lru.Metrics()
|
|
||||||
lru.mu.Unlock()
|
|
||||||
return metrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetMetrics resets the metrics of the cache and returns the previous state.
|
|
||||||
func (lru *SyncedLRU[K, V]) ResetMetrics() Metrics {
|
|
||||||
lru.mu.Lock()
|
|
||||||
metrics := lru.lru.ResetMetrics()
|
|
||||||
lru.mu.Unlock()
|
|
||||||
return metrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// just used for debugging
|
|
||||||
func (lru *SyncedLRU[K, V]) dump() {
|
|
||||||
lru.mu.RLock()
|
|
||||||
lru.lru.dump()
|
|
||||||
lru.mu.RUnlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (lru *SyncedLRU[K, V]) PrintStats() {
|
|
||||||
lru.mu.RLock()
|
|
||||||
lru.lru.PrintStats()
|
|
||||||
lru.mu.RUnlock()
|
|
||||||
}
|
|
|
@ -1,3 +0,0 @@
|
||||||
# maphash
|
|
||||||
|
|
||||||
kanged from github.com/dolthub/maphash@v0.1.0
|
|
|
@ -1,53 +0,0 @@
|
||||||
// Copyright 2022 Dolthub, Inc.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package maphash
|
|
||||||
|
|
||||||
import "unsafe"
|
|
||||||
|
|
||||||
// Hasher hashes values of type K.
|
|
||||||
// Uses runtime AES-based hashing.
|
|
||||||
type Hasher[K comparable] struct {
|
|
||||||
hash hashfn
|
|
||||||
seed uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewHasher creates a new Hasher[K] with a random seed.
|
|
||||||
func NewHasher[K comparable]() Hasher[K] {
|
|
||||||
return Hasher[K]{
|
|
||||||
hash: getRuntimeHasher[K](),
|
|
||||||
seed: newHashSeed(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSeed returns a copy of |h| with a new hash seed.
|
|
||||||
func NewSeed[K comparable](h Hasher[K]) Hasher[K] {
|
|
||||||
return Hasher[K]{
|
|
||||||
hash: h.hash,
|
|
||||||
seed: newHashSeed(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hash hashes |key|.
|
|
||||||
func (h Hasher[K]) Hash(key K) uint64 {
|
|
||||||
// promise to the compiler that pointer
|
|
||||||
// |p| does not escape the stack.
|
|
||||||
p := noescape(unsafe.Pointer(&key))
|
|
||||||
return uint64(h.hash(p, h.seed))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h Hasher[K]) Hash32(key K) uint32 {
|
|
||||||
p := noescape(unsafe.Pointer(&key))
|
|
||||||
return uint32(h.hash(p, h.seed))
|
|
||||||
}
|
|
|
@ -1,114 +0,0 @@
|
||||||
// Copyright 2022 Dolthub, Inc.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
//
|
|
||||||
// This file incorporates work covered by the following copyright and
|
|
||||||
// permission notice:
|
|
||||||
//
|
|
||||||
// Copyright 2022 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build go1.18 || go1.19
|
|
||||||
// +build go1.18 go1.19
|
|
||||||
|
|
||||||
package maphash
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/rand"
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
type hashfn func(unsafe.Pointer, uintptr) uintptr
|
|
||||||
|
|
||||||
func getRuntimeHasher[K comparable]() (h hashfn) {
|
|
||||||
a := any(make(map[K]struct{}))
|
|
||||||
i := (*mapiface)(unsafe.Pointer(&a))
|
|
||||||
h = i.typ.hasher
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHashSeed() uintptr {
|
|
||||||
return uintptr(rand.Int())
|
|
||||||
}
|
|
||||||
|
|
||||||
// noescape hides a pointer from escape analysis. It is the identity function
|
|
||||||
// but escape analysis doesn't think the output depends on the input.
|
|
||||||
// noescape is inlined and currently compiles down to zero instructions.
|
|
||||||
// USE CAREFULLY!
|
|
||||||
// This was copied from the runtime (via pkg "strings"); see issues 23382 and 7921.
|
|
||||||
//
|
|
||||||
//go:nosplit
|
|
||||||
//go:nocheckptr
|
|
||||||
func noescape(p unsafe.Pointer) unsafe.Pointer {
|
|
||||||
x := uintptr(p)
|
|
||||||
//nolint:staticcheck
|
|
||||||
return unsafe.Pointer(x ^ 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type mapiface struct {
|
|
||||||
typ *maptype
|
|
||||||
val *hmap
|
|
||||||
}
|
|
||||||
|
|
||||||
// go/src/runtime/type.go
|
|
||||||
type maptype struct {
|
|
||||||
typ _type
|
|
||||||
key *_type
|
|
||||||
elem *_type
|
|
||||||
bucket *_type
|
|
||||||
// function for hashing keys (ptr to key, seed) -> hash
|
|
||||||
hasher func(unsafe.Pointer, uintptr) uintptr
|
|
||||||
keysize uint8
|
|
||||||
elemsize uint8
|
|
||||||
bucketsize uint16
|
|
||||||
flags uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// go/src/runtime/map.go
|
|
||||||
type hmap struct {
|
|
||||||
count int
|
|
||||||
flags uint8
|
|
||||||
B uint8
|
|
||||||
noverflow uint16
|
|
||||||
// hash seed
|
|
||||||
hash0 uint32
|
|
||||||
buckets unsafe.Pointer
|
|
||||||
oldbuckets unsafe.Pointer
|
|
||||||
nevacuate uintptr
|
|
||||||
// true type is *mapextra
|
|
||||||
// but we don't need this data
|
|
||||||
extra unsafe.Pointer
|
|
||||||
}
|
|
||||||
|
|
||||||
// go/src/runtime/type.go
|
|
||||||
type (
|
|
||||||
tflag uint8
|
|
||||||
nameOff int32
|
|
||||||
typeOff int32
|
|
||||||
)
|
|
||||||
|
|
||||||
// go/src/runtime/type.go
|
|
||||||
type _type struct {
|
|
||||||
size uintptr
|
|
||||||
ptrdata uintptr
|
|
||||||
hash uint32
|
|
||||||
tflag tflag
|
|
||||||
align uint8
|
|
||||||
fieldAlign uint8
|
|
||||||
kind uint8
|
|
||||||
equal func(unsafe.Pointer, unsafe.Pointer) bool
|
|
||||||
gcdata *byte
|
|
||||||
str nameOff
|
|
||||||
ptrToThis typeOff
|
|
||||||
}
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
std_bufio "bufio"
|
std_bufio "bufio"
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -21,20 +20,15 @@ import (
|
||||||
"github.com/sagernet/sing/common/pipe"
|
"github.com/sagernet/sing/common/pipe"
|
||||||
)
|
)
|
||||||
|
|
||||||
func HandleConnectionEx(
|
type Handler = N.TCPConnectionHandler
|
||||||
ctx context.Context,
|
|
||||||
conn net.Conn,
|
func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
|
||||||
reader *std_bufio.Reader,
|
|
||||||
authenticator *auth.Authenticator,
|
|
||||||
handler N.TCPConnectionHandlerEx,
|
|
||||||
source M.Socksaddr,
|
|
||||||
onClose N.CloseHandlerFunc,
|
|
||||||
) error {
|
|
||||||
for {
|
for {
|
||||||
request, err := ReadRequest(reader)
|
request, err := ReadRequest(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "read http request")
|
return E.Cause(err, "read http request")
|
||||||
}
|
}
|
||||||
|
|
||||||
if authenticator != nil {
|
if authenticator != nil {
|
||||||
var (
|
var (
|
||||||
username string
|
username string
|
||||||
|
@ -74,23 +68,22 @@ func HandleConnectionEx(
|
||||||
}
|
}
|
||||||
|
|
||||||
if sourceAddress := SourceAddress(request); sourceAddress.IsValid() {
|
if sourceAddress := SourceAddress(request); sourceAddress.IsValid() {
|
||||||
source = sourceAddress
|
metadata.Source = sourceAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
if request.Method == "CONNECT" {
|
if request.Method == "CONNECT" {
|
||||||
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port())
|
portStr := request.URL.Port()
|
||||||
if destination.Port == 0 {
|
if portStr == "" {
|
||||||
switch request.URL.Scheme {
|
portStr = "80"
|
||||||
case "https", "wss":
|
|
||||||
destination.Port = 443
|
|
||||||
default:
|
|
||||||
destination.Port = 80
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), portStr)
|
||||||
_, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n")))
|
_, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n")))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "write http response")
|
return E.Cause(err, "write http response")
|
||||||
}
|
}
|
||||||
|
metadata.Protocol = "http"
|
||||||
|
metadata.Destination = destination
|
||||||
|
|
||||||
var requestConn net.Conn
|
var requestConn net.Conn
|
||||||
if reader.Buffered() > 0 {
|
if reader.Buffered() > 0 {
|
||||||
buffer := buf.NewSize(reader.Buffered())
|
buffer := buf.NewSize(reader.Buffered())
|
||||||
|
@ -102,115 +95,75 @@ func HandleConnectionEx(
|
||||||
} else {
|
} else {
|
||||||
requestConn = conn
|
requestConn = conn
|
||||||
}
|
}
|
||||||
handler.NewConnectionEx(ctx, requestConn, source, destination, onClose)
|
return handler.NewConnection(ctx, requestConn, metadata)
|
||||||
return nil
|
}
|
||||||
} else if strings.ToLower(request.Header.Get("Connection")) == "upgrade" {
|
|
||||||
destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), request.URL.Port())
|
keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
|
||||||
if destination.Port == 0 {
|
request.RequestURI = ""
|
||||||
switch request.URL.Scheme {
|
|
||||||
case "https", "wss":
|
removeHopByHopHeaders(request.Header)
|
||||||
destination.Port = 443
|
removeExtraHTTPHostPort(request)
|
||||||
default:
|
|
||||||
destination.Port = 80
|
if hostStr := request.Header.Get("Host"); hostStr != "" {
|
||||||
}
|
if hostStr != request.URL.Host {
|
||||||
}
|
request.Host = hostStr
|
||||||
serverConn, clientConn := pipe.Pipe()
|
|
||||||
go func() {
|
|
||||||
handler.NewConnectionEx(ctx, clientConn, source, destination, func(it error) {
|
|
||||||
if it != nil {
|
|
||||||
common.Close(serverConn, clientConn)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}()
|
|
||||||
err = request.Write(serverConn)
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "http: write upgrade request")
|
|
||||||
}
|
|
||||||
if reader.Buffered() > 0 {
|
|
||||||
_, err = io.CopyN(serverConn, reader, int64(reader.Buffered()))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return bufio.CopyConn(ctx, conn, serverConn)
|
|
||||||
} else {
|
|
||||||
err = handleHTTPConnection(ctx, handler, conn, request, source)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleHTTPConnection(
|
if request.URL.Scheme == "" || request.URL.Host == "" {
|
||||||
ctx context.Context,
|
return responseWith(request, http.StatusBadRequest).Write(conn)
|
||||||
handler N.TCPConnectionHandlerEx,
|
|
||||||
conn net.Conn,
|
|
||||||
request *http.Request, source M.Socksaddr,
|
|
||||||
) error {
|
|
||||||
keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
|
|
||||||
request.RequestURI = ""
|
|
||||||
|
|
||||||
removeHopByHopHeaders(request.Header)
|
|
||||||
removeExtraHTTPHostPort(request)
|
|
||||||
|
|
||||||
if hostStr := request.Header.Get("Host"); hostStr != "" {
|
|
||||||
if hostStr != request.URL.Host {
|
|
||||||
request.Host = hostStr
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if request.URL.Scheme == "" || request.URL.Host == "" {
|
var innerErr atomic.TypedValue[error]
|
||||||
return responseWith(request, http.StatusBadRequest).Write(conn)
|
httpClient := &http.Client{
|
||||||
}
|
Transport: &http.Transport{
|
||||||
|
DisableCompression: true,
|
||||||
var innerErr atomic.TypedValue[error]
|
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
httpClient := &http.Client{
|
metadata.Destination = M.ParseSocksaddr(address)
|
||||||
Transport: &http.Transport{
|
metadata.Protocol = "http"
|
||||||
DisableCompression: true,
|
input, output := pipe.Pipe()
|
||||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
go func() {
|
||||||
input, output := pipe.Pipe()
|
hErr := handler.NewConnection(ctx, output, metadata)
|
||||||
go handler.NewConnectionEx(ctx, output, source, M.ParseSocksaddr(address), func(it error) {
|
if hErr != nil {
|
||||||
innerErr.Store(it)
|
innerErr.Store(hErr)
|
||||||
common.Close(input, output)
|
common.Close(input, output)
|
||||||
})
|
}
|
||||||
return input, nil
|
}()
|
||||||
|
return input, nil
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
return http.ErrUseLastResponse
|
||||||
return http.ErrUseLastResponse
|
},
|
||||||
},
|
}
|
||||||
}
|
requestCtx, cancel := context.WithCancel(ctx)
|
||||||
defer httpClient.CloseIdleConnections()
|
response, err := httpClient.Do(request.WithContext(requestCtx))
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn))
|
||||||
|
}
|
||||||
|
|
||||||
|
removeHopByHopHeaders(response.Header)
|
||||||
|
|
||||||
|
if keepAlive {
|
||||||
|
response.Header.Set("Proxy-Connection", "keep-alive")
|
||||||
|
response.Header.Set("Connection", "keep-alive")
|
||||||
|
response.Header.Set("Keep-Alive", "timeout=4")
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Close = !keepAlive
|
||||||
|
|
||||||
|
err = response.Write(conn)
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
return E.Errors(innerErr.Load(), err)
|
||||||
|
}
|
||||||
|
|
||||||
requestCtx, cancel := context.WithCancel(ctx)
|
|
||||||
response, err := httpClient.Do(request.WithContext(requestCtx))
|
|
||||||
if err != nil {
|
|
||||||
cancel()
|
cancel()
|
||||||
return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn))
|
if !keepAlive {
|
||||||
|
return conn.Close()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
removeHopByHopHeaders(response.Header)
|
|
||||||
|
|
||||||
if keepAlive {
|
|
||||||
response.Header.Set("Proxy-Connection", "keep-alive")
|
|
||||||
response.Header.Set("Connection", "keep-alive")
|
|
||||||
response.Header.Set("Keep-Alive", "timeout=4")
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Close = !keepAlive
|
|
||||||
|
|
||||||
err = response.Write(conn)
|
|
||||||
if err != nil {
|
|
||||||
cancel()
|
|
||||||
return E.Errors(innerErr.Load(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cancel()
|
|
||||||
if !keepAlive {
|
|
||||||
return conn.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeHopByHopHeaders(header http.Header) {
|
func removeHopByHopHeaders(header http.Header) {
|
||||||
|
|
|
@ -10,7 +10,6 @@ import (
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/auth"
|
"github.com/sagernet/sing/common/auth"
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
"github.com/sagernet/sing/common/bufio"
|
"github.com/sagernet/sing/common/bufio"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
@ -20,13 +19,9 @@ import (
|
||||||
"github.com/sagernet/sing/protocol/socks/socks5"
|
"github.com/sagernet/sing/protocol/socks/socks5"
|
||||||
)
|
)
|
||||||
|
|
||||||
type HandlerEx interface {
|
type Handler interface {
|
||||||
N.TCPConnectionHandlerEx
|
N.TCPConnectionHandler
|
||||||
N.UDPConnectionHandlerEx
|
N.UDPConnectionHandler
|
||||||
}
|
|
||||||
|
|
||||||
type PacketListener interface {
|
|
||||||
ListenPacket(listenConfig net.ListenConfig, ctx context.Context, network string, address string) (net.PacketConn, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, username string) (socks4.Response, error) {
|
func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, username string) (socks4.Response, error) {
|
||||||
|
@ -84,26 +79,6 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr,
|
||||||
} else if authResponse.Method != socks5.AuthTypeNotRequired {
|
} else if authResponse.Method != socks5.AuthTypeNotRequired {
|
||||||
return socks5.Response{}, E.New("socks5: unsupported auth method: ", authResponse.Method)
|
return socks5.Response{}, E.New("socks5: unsupported auth method: ", authResponse.Method)
|
||||||
}
|
}
|
||||||
|
|
||||||
if command == socks5.CommandUDPAssociate {
|
|
||||||
if destination.Addr.IsPrivate() {
|
|
||||||
if destination.Addr.Is6() {
|
|
||||||
destination.Addr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
|
|
||||||
} else {
|
|
||||||
destination.Addr = netip.IPv6Loopback()
|
|
||||||
}
|
|
||||||
} else if destination.Addr.IsGlobalUnicast() {
|
|
||||||
if destination.Addr.Is6() {
|
|
||||||
destination.Addr = netip.IPv6Unspecified()
|
|
||||||
} else {
|
|
||||||
destination.Addr = netip.IPv4Unspecified()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
destination.Addr = netip.IPv6Unspecified()
|
|
||||||
}
|
|
||||||
destination.Port = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
err = socks5.WriteRequest(conn, socks5.Request{
|
err = socks5.WriteRequest(conn, socks5.Request{
|
||||||
Command: command,
|
Command: command,
|
||||||
Destination: destination,
|
Destination: destination,
|
||||||
|
@ -121,23 +96,18 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr,
|
||||||
return response, err
|
return response, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleConnectionEx(
|
func HandleConnection(ctx context.Context, conn net.Conn, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
|
||||||
ctx context.Context, conn net.Conn, reader *std_bufio.Reader,
|
return HandleConnection0(ctx, conn, std_bufio.NewReader(conn), authenticator, handler, metadata)
|
||||||
authenticator *auth.Authenticator,
|
}
|
||||||
handler HandlerEx,
|
|
||||||
packetListener PacketListener,
|
func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
|
||||||
// resolver TorResolver,
|
|
||||||
source M.Socksaddr,
|
|
||||||
onClose N.CloseHandlerFunc,
|
|
||||||
) error {
|
|
||||||
version, err := reader.ReadByte()
|
version, err := reader.ReadByte()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
switch version {
|
switch version {
|
||||||
case socks4.Version:
|
case socks4.Version:
|
||||||
var request socks4.Request
|
request, err := socks4.ReadRequest0(reader)
|
||||||
request, err = socks4.ReadRequest0(reader)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -145,23 +115,28 @@ func HandleConnectionEx(
|
||||||
case socks4.CommandConnect:
|
case socks4.CommandConnect:
|
||||||
if authenticator != nil && !authenticator.Verify(request.Username, "") {
|
if authenticator != nil && !authenticator.Verify(request.Username, "") {
|
||||||
err = socks4.WriteResponse(conn, socks4.Response{
|
err = socks4.WriteResponse(conn, socks4.Response{
|
||||||
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
|
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
|
||||||
|
Destination: request.Destination,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return E.New("socks4: authentication failed, username=", request.Username)
|
return E.New("socks4: authentication failed, username=", request.Username)
|
||||||
}
|
}
|
||||||
handler.NewConnectionEx(auth.ContextWithUser(ctx, request.Username), NewLazyConn(conn, version), source, request.Destination, onClose)
|
err = socks4.WriteResponse(conn, socks4.Response{
|
||||||
return nil
|
ReplyCode: socks4.ReplyCodeGranted,
|
||||||
/*case CommandTorResolve, CommandTorResolvePTR:
|
Destination: M.SocksaddrFromNet(conn.LocalAddr()),
|
||||||
if resolver == nil {
|
})
|
||||||
return E.New("socks4: torsocks: commands not implemented")
|
if err != nil {
|
||||||
}
|
return err
|
||||||
return handleTorSocks4(ctx, conn, request, resolver)*/
|
}
|
||||||
|
metadata.Protocol = "socks4"
|
||||||
|
metadata.Destination = request.Destination
|
||||||
|
return handler.NewConnection(auth.ContextWithUser(ctx, request.Username), conn, metadata)
|
||||||
default:
|
default:
|
||||||
err = socks4.WriteResponse(conn, socks4.Response{
|
err = socks4.WriteResponse(conn, socks4.Response{
|
||||||
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
|
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
|
||||||
|
Destination: request.Destination,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -169,8 +144,7 @@ func HandleConnectionEx(
|
||||||
return E.New("socks4: unsupported command ", request.Command)
|
return E.New("socks4: unsupported command ", request.Command)
|
||||||
}
|
}
|
||||||
case socks5.Version:
|
case socks5.Version:
|
||||||
var authRequest socks5.AuthRequest
|
authRequest, err := socks5.ReadAuthRequest0(reader)
|
||||||
authRequest, err = socks5.ReadAuthRequest0(reader)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -195,8 +169,7 @@ func HandleConnectionEx(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if authMethod == socks5.AuthTypeUsernamePassword {
|
if authMethod == socks5.AuthTypeUsernamePassword {
|
||||||
var usernamePasswordAuthRequest socks5.UsernamePasswordAuthRequest
|
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(reader)
|
||||||
usernamePasswordAuthRequest, err = socks5.ReadUsernamePasswordAuthRequest(reader)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -215,50 +188,49 @@ func HandleConnectionEx(
|
||||||
return E.New("socks5: authentication failed, username=", usernamePasswordAuthRequest.Username, ", password=", usernamePasswordAuthRequest.Password)
|
return E.New("socks5: authentication failed, username=", usernamePasswordAuthRequest.Username, ", password=", usernamePasswordAuthRequest.Password)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var request socks5.Request
|
request, err := socks5.ReadRequest(reader)
|
||||||
request, err = socks5.ReadRequest(reader)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
switch request.Command {
|
switch request.Command {
|
||||||
case socks5.CommandConnect:
|
case socks5.CommandConnect:
|
||||||
handler.NewConnectionEx(ctx, NewLazyConn(conn, version), source, request.Destination, onClose)
|
err = socks5.WriteResponse(conn, socks5.Response{
|
||||||
return nil
|
ReplyCode: socks5.ReplyCodeSuccess,
|
||||||
case socks5.CommandUDPAssociate:
|
Bind: M.SocksaddrFromNet(conn.LocalAddr()),
|
||||||
var (
|
})
|
||||||
listenConfig net.ListenConfig
|
|
||||||
udpConn net.PacketConn
|
|
||||||
)
|
|
||||||
if packetListener != nil {
|
|
||||||
udpConn, err = packetListener.ListenPacket(listenConfig, ctx, M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), M.SocksaddrFrom(M.AddrFromNet(conn.LocalAddr()), 0).String())
|
|
||||||
} else {
|
|
||||||
udpConn, err = listenConfig.ListenPacket(ctx, M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), M.SocksaddrFrom(M.AddrFromNet(conn.LocalAddr()), 0).String())
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "socks5: listen udp")
|
return err
|
||||||
}
|
}
|
||||||
|
metadata.Protocol = "socks5"
|
||||||
|
metadata.Destination = request.Destination
|
||||||
|
return handler.NewConnection(ctx, conn, metadata)
|
||||||
|
case socks5.CommandUDPAssociate:
|
||||||
|
var udpConn *net.UDPConn
|
||||||
|
udpConn, err = net.ListenUDP(M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), net.UDPAddrFromAddrPort(netip.AddrPortFrom(M.AddrFromNet(conn.LocalAddr()), 0)))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer udpConn.Close()
|
||||||
err = socks5.WriteResponse(conn, socks5.Response{
|
err = socks5.WriteResponse(conn, socks5.Response{
|
||||||
ReplyCode: socks5.ReplyCodeSuccess,
|
ReplyCode: socks5.ReplyCodeSuccess,
|
||||||
Bind: M.SocksaddrFromNet(udpConn.LocalAddr()),
|
Bind: M.SocksaddrFromNet(udpConn.LocalAddr()),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Cause(err, "socks5: write response")
|
return err
|
||||||
}
|
}
|
||||||
var socksPacketConn N.PacketConn = NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), M.Socksaddr{}, conn)
|
metadata.Protocol = "socks5"
|
||||||
firstPacket := buf.NewPacket()
|
metadata.Destination = request.Destination
|
||||||
var destination M.Socksaddr
|
var innerError error
|
||||||
destination, err = socksPacketConn.ReadPacket(firstPacket)
|
done := make(chan struct{})
|
||||||
if err != nil {
|
associatePacketConn := NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), request.Destination, conn)
|
||||||
return E.Cause(err, "socks5: read first packet")
|
go func() {
|
||||||
}
|
innerError = handler.NewPacketConnection(ctx, associatePacketConn, metadata)
|
||||||
socksPacketConn = bufio.NewCachedPacketConn(socksPacketConn, firstPacket, destination)
|
close(done)
|
||||||
handler.NewPacketConnectionEx(ctx, socksPacketConn, source, destination, onClose)
|
}()
|
||||||
return nil
|
err = common.Error(io.Copy(io.Discard, conn))
|
||||||
/*case CommandTorResolve, CommandTorResolvePTR:
|
associatePacketConn.Close()
|
||||||
if resolver == nil {
|
<-done
|
||||||
return E.New("socks4: torsocks: commands not implemented")
|
return E.Errors(innerError, err)
|
||||||
}
|
|
||||||
return handleTorSocks5(ctx, conn, request, resolver)*/
|
|
||||||
default:
|
default:
|
||||||
err = socks5.WriteResponse(conn, socks5.Response{
|
err = socks5.WriteResponse(conn, socks5.Response{
|
||||||
ReplyCode: socks5.ReplyCodeUnsupported,
|
ReplyCode: socks5.ReplyCodeUnsupported,
|
||||||
|
|
|
@ -1,146 +0,0 @@
|
||||||
package socks
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
"github.com/sagernet/sing/protocol/socks/socks4"
|
|
||||||
"github.com/sagernet/sing/protocol/socks/socks5"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
CommandTorResolve byte = 0xF0
|
|
||||||
CommandTorResolvePTR byte = 0xF1
|
|
||||||
)
|
|
||||||
|
|
||||||
type TorResolver interface {
|
|
||||||
LookupIP(ctx context.Context, host string) (netip.Addr, error)
|
|
||||||
LookupPTR(ctx context.Context, addr netip.Addr) (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleTorSocks4(ctx context.Context, conn net.Conn, request socks4.Request, resolver TorResolver) error {
|
|
||||||
switch request.Command {
|
|
||||||
case CommandTorResolve:
|
|
||||||
if !request.Destination.IsFqdn() {
|
|
||||||
return E.New("socks4: torsocks: invalid destination")
|
|
||||||
}
|
|
||||||
ipAddr, err := resolver.LookupIP(ctx, request.Destination.Fqdn)
|
|
||||||
if err != nil {
|
|
||||||
err = socks4.WriteResponse(conn, socks4.Response{
|
|
||||||
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return E.Cause(err, "socks4: torsocks: lookup failed for domain: ", request.Destination.Fqdn)
|
|
||||||
}
|
|
||||||
err = socks4.WriteResponse(conn, socks4.Response{
|
|
||||||
ReplyCode: socks4.ReplyCodeGranted,
|
|
||||||
Destination: M.SocksaddrFrom(ipAddr, 0),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "socks4: torsocks: write response")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
case CommandTorResolvePTR:
|
|
||||||
var ipAddr netip.Addr
|
|
||||||
if request.Destination.IsIP() {
|
|
||||||
ipAddr = request.Destination.Addr
|
|
||||||
} else if strings.HasSuffix(request.Destination.Fqdn, ".in-addr.arpa") {
|
|
||||||
ipAddr, _ = netip.ParseAddr(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".in-addr.arpa")])
|
|
||||||
} else if strings.HasSuffix(request.Destination.Fqdn, ".ip6.arpa") {
|
|
||||||
ipAddr, _ = netip.ParseAddr(strings.ReplaceAll(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".ip6.arpa")], ".", ":"))
|
|
||||||
}
|
|
||||||
if !ipAddr.IsValid() {
|
|
||||||
return E.New("socks4: torsocks: invalid destination")
|
|
||||||
}
|
|
||||||
host, err := resolver.LookupPTR(ctx, ipAddr)
|
|
||||||
if err != nil {
|
|
||||||
err = socks4.WriteResponse(conn, socks4.Response{
|
|
||||||
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return E.Cause(err, "socks4: torsocks: lookup PTR failed for ip: ", ipAddr)
|
|
||||||
}
|
|
||||||
err = socks4.WriteResponse(conn, socks4.Response{
|
|
||||||
ReplyCode: socks4.ReplyCodeGranted,
|
|
||||||
Destination: M.Socksaddr{
|
|
||||||
Fqdn: host,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "socks4: torsocks: write response")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleTorSocks5(ctx context.Context, conn net.Conn, request socks5.Request, resolver TorResolver) error {
|
|
||||||
switch request.Command {
|
|
||||||
case CommandTorResolve:
|
|
||||||
if !request.Destination.IsFqdn() {
|
|
||||||
return E.New("socks5: torsocks: invalid destination")
|
|
||||||
}
|
|
||||||
ipAddr, err := resolver.LookupIP(ctx, request.Destination.Fqdn)
|
|
||||||
if err != nil {
|
|
||||||
err = socks5.WriteResponse(conn, socks5.Response{
|
|
||||||
ReplyCode: socks5.ReplyCodeFailure,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return E.Cause(err, "socks5: torsocks: lookup failed for domain: ", request.Destination.Fqdn)
|
|
||||||
}
|
|
||||||
err = socks5.WriteResponse(conn, socks5.Response{
|
|
||||||
ReplyCode: socks5.ReplyCodeSuccess,
|
|
||||||
Bind: M.SocksaddrFrom(ipAddr, 0),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "socks5: torsocks: write response")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
case CommandTorResolvePTR:
|
|
||||||
var ipAddr netip.Addr
|
|
||||||
if request.Destination.IsIP() {
|
|
||||||
ipAddr = request.Destination.Addr
|
|
||||||
} else if strings.HasSuffix(request.Destination.Fqdn, ".in-addr.arpa") {
|
|
||||||
ipAddr, _ = netip.ParseAddr(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".in-addr.arpa")])
|
|
||||||
} else if strings.HasSuffix(request.Destination.Fqdn, ".ip6.arpa") {
|
|
||||||
ipAddr, _ = netip.ParseAddr(strings.ReplaceAll(request.Destination.Fqdn[:len(request.Destination.Fqdn)-len(".ip6.arpa")], ".", ":"))
|
|
||||||
}
|
|
||||||
if !ipAddr.IsValid() {
|
|
||||||
return E.New("socks5: torsocks: invalid destination")
|
|
||||||
}
|
|
||||||
host, err := resolver.LookupPTR(ctx, ipAddr)
|
|
||||||
if err != nil {
|
|
||||||
err = socks5.WriteResponse(conn, socks5.Response{
|
|
||||||
ReplyCode: socks5.ReplyCodeFailure,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return E.Cause(err, "socks5: torsocks: lookup PTR failed for ip: ", ipAddr)
|
|
||||||
}
|
|
||||||
err = socks5.WriteResponse(conn, socks5.Response{
|
|
||||||
ReplyCode: socks5.ReplyCodeSuccess,
|
|
||||||
Bind: M.Socksaddr{
|
|
||||||
Fqdn: host,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return E.Cause(err, "socks5: torsocks: write response")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,215 +0,0 @@
|
||||||
package socks
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
"github.com/sagernet/sing/common/bufio"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
"github.com/sagernet/sing/protocol/socks/socks4"
|
|
||||||
"github.com/sagernet/sing/protocol/socks/socks5"
|
|
||||||
)
|
|
||||||
|
|
||||||
type LazyConn struct {
|
|
||||||
net.Conn
|
|
||||||
socksVersion byte
|
|
||||||
responseWritten bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLazyConn(conn net.Conn, socksVersion byte) *LazyConn {
|
|
||||||
return &LazyConn{
|
|
||||||
Conn: conn,
|
|
||||||
socksVersion: socksVersion,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyConn) ConnHandshakeSuccess(conn net.Conn) error {
|
|
||||||
if c.responseWritten {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
c.responseWritten = true
|
|
||||||
}()
|
|
||||||
switch c.socksVersion {
|
|
||||||
case socks4.Version:
|
|
||||||
return socks4.WriteResponse(c.Conn, socks4.Response{
|
|
||||||
ReplyCode: socks4.ReplyCodeGranted,
|
|
||||||
Destination: M.SocksaddrFromNet(conn.LocalAddr()),
|
|
||||||
})
|
|
||||||
case socks5.Version:
|
|
||||||
return socks5.WriteResponse(c.Conn, socks5.Response{
|
|
||||||
ReplyCode: socks5.ReplyCodeSuccess,
|
|
||||||
Bind: M.SocksaddrFromNet(conn.LocalAddr()),
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
panic("unknown socks version")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyConn) HandshakeFailure(err error) error {
|
|
||||||
if c.responseWritten {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
c.responseWritten = true
|
|
||||||
}()
|
|
||||||
switch c.socksVersion {
|
|
||||||
case socks4.Version:
|
|
||||||
return socks4.WriteResponse(c.Conn, socks4.Response{
|
|
||||||
ReplyCode: socks4.ReplyCodeRejectedOrFailed,
|
|
||||||
})
|
|
||||||
case socks5.Version:
|
|
||||||
return socks5.WriteResponse(c.Conn, socks5.Response{
|
|
||||||
ReplyCode: socks5.ReplyCodeForError(err),
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
panic("unknown socks version")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyConn) Read(p []byte) (n int, err error) {
|
|
||||||
if !c.responseWritten {
|
|
||||||
err = c.ConnHandshakeSuccess(c.Conn)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c.Conn.Read(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyConn) Write(p []byte) (n int, err error) {
|
|
||||||
if !c.responseWritten {
|
|
||||||
err = c.ConnHandshakeSuccess(c.Conn)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c.Conn.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyConn) ReaderReplaceable() bool {
|
|
||||||
return c.responseWritten
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyConn) WriterReplaceable() bool {
|
|
||||||
return c.responseWritten
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyConn) Upstream() any {
|
|
||||||
return c.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
type LazyAssociatePacketConn struct {
|
|
||||||
AssociatePacketConn
|
|
||||||
responseWritten bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLazyAssociatePacketConn(conn net.Conn, underlying net.Conn) *LazyAssociatePacketConn {
|
|
||||||
return &LazyAssociatePacketConn{
|
|
||||||
AssociatePacketConn: AssociatePacketConn{
|
|
||||||
AbstractConn: conn,
|
|
||||||
conn: bufio.NewExtendedConn(conn),
|
|
||||||
underlying: underlying,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyAssociatePacketConn) HandshakeSuccess() error {
|
|
||||||
if c.responseWritten {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
c.responseWritten = true
|
|
||||||
}()
|
|
||||||
return socks5.WriteResponse(c.underlying, socks5.Response{
|
|
||||||
ReplyCode: socks5.ReplyCodeSuccess,
|
|
||||||
Bind: M.SocksaddrFromNet(c.conn.LocalAddr()),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyAssociatePacketConn) HandshakeFailure(err error) error {
|
|
||||||
if c.responseWritten {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
c.responseWritten = true
|
|
||||||
c.conn.Close()
|
|
||||||
c.underlying.Close()
|
|
||||||
}()
|
|
||||||
return socks5.WriteResponse(c.underlying, socks5.Response{
|
|
||||||
ReplyCode: socks5.ReplyCodeForError(err),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyAssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
||||||
if !c.responseWritten {
|
|
||||||
err = c.HandshakeSuccess()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c.AssociatePacketConn.ReadFrom(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyAssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|
||||||
if !c.responseWritten {
|
|
||||||
err = c.HandshakeSuccess()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c.AssociatePacketConn.WriteTo(p, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyAssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
|
||||||
if !c.responseWritten {
|
|
||||||
err = c.HandshakeSuccess()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c.AssociatePacketConn.ReadPacket(buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyAssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
|
||||||
if !c.responseWritten {
|
|
||||||
err := c.HandshakeSuccess()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c.AssociatePacketConn.WritePacket(buffer, destination)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyAssociatePacketConn) Read(p []byte) (n int, err error) {
|
|
||||||
if !c.responseWritten {
|
|
||||||
err = c.HandshakeSuccess()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c.AssociatePacketConn.Read(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyAssociatePacketConn) Write(p []byte) (n int, err error) {
|
|
||||||
if !c.responseWritten {
|
|
||||||
err = c.HandshakeSuccess()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c.AssociatePacketConn.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyAssociatePacketConn) ReaderReplaceable() bool {
|
|
||||||
return c.responseWritten
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyAssociatePacketConn) WriterReplaceable() bool {
|
|
||||||
return c.responseWritten
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LazyAssociatePacketConn) Upstream() any {
|
|
||||||
return &c.AssociatePacketConn
|
|
||||||
}
|
|
|
@ -1,10 +1,8 @@
|
||||||
package socks5
|
package socks5
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
@ -39,20 +37,6 @@ const (
|
||||||
ReplyCodeAddressTypeUnsupported byte = 8
|
ReplyCodeAddressTypeUnsupported byte = 8
|
||||||
)
|
)
|
||||||
|
|
||||||
func ReplyCodeForError(err error) byte {
|
|
||||||
if errors.Is(err, syscall.ENETUNREACH) {
|
|
||||||
return ReplyCodeNetworkUnreachable
|
|
||||||
} else if errors.Is(err, syscall.EHOSTUNREACH) {
|
|
||||||
return ReplyCodeHostUnreachable
|
|
||||||
} else if errors.Is(err, syscall.ECONNREFUSED) {
|
|
||||||
return ReplyCodeConnectionRefused
|
|
||||||
} else if errors.Is(err, syscall.EPERM) {
|
|
||||||
return ReplyCodeNotAllowed
|
|
||||||
} else {
|
|
||||||
return ReplyCodeFailure
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// +----+----------+----------+
|
// +----+----------+----------+
|
||||||
// |VER | NMETHODS | METHODS |
|
// |VER | NMETHODS | METHODS |
|
||||||
// +----+----------+----------+
|
// +----+----------+----------+
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue