From e2392d8d402352d1e71bdab609691af3b24a44a3 Mon Sep 17 00:00:00 2001 From: shij Date: Fri, 16 Jun 2023 16:35:05 +0800 Subject: [PATCH 001/141] Fix isConnect logic mistake Co-authored-by: jevin.shi --- common/uot/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/uot/server.go b/common/uot/server.go index 6457878..b620cce 100644 --- a/common/uot/server.go +++ b/common/uot/server.go @@ -60,7 +60,7 @@ func (c *ServerConn) loopInput() { for { var destination M.Socksaddr var err error - if !c.isConnect { + if c.isConnect { destination = c.destination } else { destination, err = AddrParser.ReadAddrPort(c.inputReader) From dc27334e9a7deb7ab350d3ddf6045e03fbf2819d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 19 Jun 2023 13:18:48 +0800 Subject: [PATCH 002/141] Unwrap 4in6 address in socks packet conn --- protocol/socks/packet.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/protocol/socks/packet.go b/protocol/socks/packet.go index 8bfba10..08028b0 100644 --- a/protocol/socks/packet.go +++ b/protocol/socks/packet.go @@ -89,15 +89,18 @@ func (c *AssociatePacketConn) Write(b []byte) (n int, err error) { return c.WriteTo(b, c.remoteAddr) } -func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { - destination, err := c.NetPacketConn.ReadPacket(buffer) +func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + destination, err = c.NetPacketConn.ReadPacket(buffer) if err != nil { return M.Socksaddr{}, err } c.remoteAddr = destination buffer.Advance(3) - dest, err := M.SocksaddrSerializer.ReadAddrPort(buffer) - return dest, err + destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + return + } + return destination.Unwrap(), nil } func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { From d852e9c03d073caee30675bd521e2526b39355ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 22 Jun 2023 14:55:16 +0800 Subject: [PATCH 003/141] Fix build on go1.21 --- common/buf/ptr.go | 4 ++-- common/buf/ptr_go120.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 common/buf/ptr_go120.go diff --git a/common/buf/ptr.go b/common/buf/ptr.go index 901c9e3..60e8fcf 100644 --- a/common/buf/ptr.go +++ b/common/buf/ptr.go @@ -1,4 +1,4 @@ -//go:build !disable_unsafe +//go:build !disable_unsafe && go1.21 package buf @@ -23,7 +23,7 @@ func init() { if !common.UnsafeBuffer { return } - debugVars := *(*[]dbgVar)(unsafe.Pointer(&dbgvars)) + debugVars := *(*[]*dbgVar)(unsafe.Pointer(&dbgvars)) for _, v := range debugVars { if v.name == "invalidptr" { *v.value = 0 diff --git a/common/buf/ptr_go120.go b/common/buf/ptr_go120.go new file mode 100644 index 0000000..575c723 --- /dev/null +++ b/common/buf/ptr_go120.go @@ -0,0 +1,34 @@ +//go:build !disable_unsafe && !go1.21 + +package buf + +import ( + "unsafe" + + "github.com/sagernet/sing/common" +) + +type dbgVar struct { + name string + value *int32 +} + +//go:linkname dbgvars runtime.dbgvars +var dbgvars any + +// go.info.runtime.dbgvars: relocation target go.info.[]github.com/sagernet/sing/common/buf.dbgVar not defined +// var dbgvars []dbgVar + +func init() { + if !common.UnsafeBuffer { + return + } + debugVars := *(*[]dbgVar)(unsafe.Pointer(&dbgvars)) + for _, v := range debugVars { + if v.name == "invalidptr" { + *v.value = 0 + return + } + } + panic("can't disable invalidptr") +} From c68251b6d0592efd8b0d9e947a218fb2761ee7d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 3 Jul 2023 08:21:04 +0800 Subject: [PATCH 004/141] Deprecate stack buffer --- common/buf/buffer.go | 34 ++------ common/buf/hex.go | 9 -- common/buf/pool.go | 43 +--------- common/buf/ptr.go | 34 -------- common/buf/ptr_go120.go | 34 -------- common/bufio/copy.go | 59 +------------ common/bufio/once.go | 127 ---------------------------- common/bufio/vectorised.go | 4 +- common/cond.go | 16 ++-- common/exceptions/cause.go | 3 - common/unsafe_default.go | 10 --- common/unsafe_disable.go | 5 -- common/uot/conn.go | 8 +- common/uot/server.go | 9 +- protocol/socks/packet.go | 4 +- protocol/socks/packet_vectorised.go | 4 +- protocol/socks/socks4/protocol.go | 8 +- protocol/socks/socks5/protocol.go | 16 +--- 18 files changed, 28 insertions(+), 399 deletions(-) delete mode 100644 common/buf/hex.go delete mode 100644 common/buf/ptr.go delete mode 100644 common/buf/ptr_go120.go delete mode 100644 common/bufio/once.go delete mode 100644 common/unsafe_default.go delete mode 100644 common/unsafe_disable.go diff --git a/common/buf/buffer.go b/common/buf/buffer.go index 10fa709..c374147 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -54,41 +54,19 @@ func NewSize(size int) *Buffer { } } +// Deprecated: use New instead. func StackNew() *Buffer { - if common.UnsafeBuffer { - return &Buffer{ - data: make([]byte, BufferSize), - start: ReversedHeader, - end: ReversedHeader, - } - } else { - return New() - } + return New() } +// Deprecated: use NewPacket instead. func StackNewPacket() *Buffer { - if common.UnsafeBuffer { - return &Buffer{ - data: make([]byte, UDPBufferSize), - start: ReversedHeader, - end: ReversedHeader, - } - } else { - return NewPacket() - } + return NewPacket() } +// Deprecated: use NewSize instead. func StackNewSize(size int) *Buffer { - if size == 0 { - return &Buffer{} - } - if common.UnsafeBuffer { - return &Buffer{ - data: Make(size), - } - } else { - return NewSize(size) - } + return NewSize(size) } func As(data []byte) *Buffer { diff --git a/common/buf/hex.go b/common/buf/hex.go deleted file mode 100644 index ca54f67..0000000 --- a/common/buf/hex.go +++ /dev/null @@ -1,9 +0,0 @@ -package buf - -import "encoding/hex" - -func EncodeHexString(src []byte) string { - dst := Make(hex.EncodedLen(len(src))) - hex.Encode(dst, src) - return string(dst) -} diff --git a/common/buf/pool.go b/common/buf/pool.go index a729989..37f1232 100644 --- a/common/buf/pool.go +++ b/common/buf/pool.go @@ -11,46 +11,7 @@ func Put(buf []byte) error { return DefaultAllocator.Put(buf) } +// Deprecated: use array instead. func Make(size int) []byte { - if size == 0 { - return nil - } - var buffer []byte - switch { - case size <= 2: - buffer = make([]byte, 2) - case size <= 4: - buffer = make([]byte, 4) - case size <= 8: - buffer = make([]byte, 8) - case size <= 16: - buffer = make([]byte, 16) - case size <= 32: - buffer = make([]byte, 32) - case size <= 64: - buffer = make([]byte, 64) - case size <= 128: - buffer = make([]byte, 128) - case size <= 256: - buffer = make([]byte, 256) - case size <= 512: - buffer = make([]byte, 512) - case size <= 1024: - buffer = make([]byte, 1024) - case size <= 2048: - buffer = make([]byte, 2048) - case size <= 4096: - buffer = make([]byte, 4096) - case size <= 8192: - buffer = make([]byte, 8192) - case size <= 16384: - buffer = make([]byte, 16384) - case size <= 32768: - buffer = make([]byte, 32768) - case size <= 65535: - buffer = make([]byte, 65535) - default: - return make([]byte, size) - } - return buffer[:size] + return make([]byte, size) } diff --git a/common/buf/ptr.go b/common/buf/ptr.go deleted file mode 100644 index 60e8fcf..0000000 --- a/common/buf/ptr.go +++ /dev/null @@ -1,34 +0,0 @@ -//go:build !disable_unsafe && go1.21 - -package buf - -import ( - "unsafe" - - "github.com/sagernet/sing/common" -) - -type dbgVar struct { - name string - value *int32 -} - -//go:linkname dbgvars runtime.dbgvars -var dbgvars any - -// go.info.runtime.dbgvars: relocation target go.info.[]github.com/sagernet/sing/common/buf.dbgVar not defined -// var dbgvars []dbgVar - -func init() { - if !common.UnsafeBuffer { - return - } - debugVars := *(*[]*dbgVar)(unsafe.Pointer(&dbgvars)) - for _, v := range debugVars { - if v.name == "invalidptr" { - *v.value = 0 - return - } - } - panic("can't disable invalidptr") -} diff --git a/common/buf/ptr_go120.go b/common/buf/ptr_go120.go deleted file mode 100644 index 575c723..0000000 --- a/common/buf/ptr_go120.go +++ /dev/null @@ -1,34 +0,0 @@ -//go:build !disable_unsafe && !go1.21 - -package buf - -import ( - "unsafe" - - "github.com/sagernet/sing/common" -) - -type dbgVar struct { - name string - value *int32 -} - -//go:linkname dbgvars runtime.dbgvars -var dbgvars any - -// go.info.runtime.dbgvars: relocation target go.info.[]github.com/sagernet/sing/common/buf.dbgVar not defined -// var dbgvars []dbgVar - -func init() { - if !common.UnsafeBuffer { - return - } - debugVars := *(*[]dbgVar)(unsafe.Pointer(&dbgvars)) - for _, v := range debugVars { - if v.name == "invalidptr" { - *v.value = 0 - return - } - } - panic("can't disable invalidptr") -} diff --git a/common/bufio/copy.go b/common/bufio/copy.go index c0ff6dd..ea279eb 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -71,20 +71,7 @@ func CopyExtended(originDestination io.Writer, destination N.ExtendedWriter, sou return } } - if !common.UnsafeBuffer || N.IsUnsafeWriter(destination) { - return CopyExtendedWithPool(originDestination, destination, source, readCounters, writeCounters) - } - bufferSize := N.CalculateMTU(source, destination) - if bufferSize > 0 { - bufferSize += headroom - } else { - bufferSize = buf.BufferSize - } - _buffer := buf.StackNewSize(bufferSize) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - return CopyExtendedBuffer(originDestination, destination, source, buffer, readCounters, writeCounters) + return CopyExtendedWithPool(originDestination, destination, source, readCounters, writeCounters) } func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { @@ -291,49 +278,7 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, return } } - if N.IsUnsafeWriter(destinationConn) { - return CopyPacketWithPool(destinationConn, source, readCounters, writeCounters) - } - bufferSize := N.CalculateMTU(source, destinationConn) - if bufferSize > 0 { - bufferSize += headroom - } else { - bufferSize = buf.UDPBufferSize - } - _buffer := buf.StackNewSize(bufferSize) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - buffer.IncRef() - defer buffer.DecRef() - var destination M.Socksaddr - var notFirstTime bool - readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) - for { - readBuffer.Resize(frontHeadroom, 0) - destination, err = source.ReadPacket(readBuffer) - if err != nil { - if !notFirstTime { - err = N.HandshakeFailure(destinationConn, err) - } - return - } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) - err = destinationConn.WritePacket(buffer, destination) - if err != nil { - return - } - n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - notFirstTime = true - } + return CopyPacketWithPool(destinationConn, source, readCounters, writeCounters) } func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { diff --git a/common/bufio/once.go b/common/bufio/once.go deleted file mode 100644 index 5bfd0aa..0000000 --- a/common/bufio/once.go +++ /dev/null @@ -1,127 +0,0 @@ -package bufio - -import ( - "io" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - N "github.com/sagernet/sing/common/network" -) - -func CopyTimes(dst io.Writer, src io.Reader, times int) (n int64, err error) { - return CopyExtendedTimes(NewExtendedWriter(N.UnwrapWriter(dst)), NewExtendedReader(N.UnwrapReader(src)), times) -} - -func CopyExtendedTimes(dst N.ExtendedWriter, src N.ExtendedReader, times int) (n int64, err error) { - frontHeadroom := N.CalculateFrontHeadroom(dst) - rearHeadroom := N.CalculateRearHeadroom(dst) - bufferSize := N.CalculateMTU(src, dst) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.BufferSize - } - dstUnsafe := N.IsUnsafeWriter(dst) - var buffer *buf.Buffer - if !dstUnsafe { - _buffer := buf.StackNewSize(bufferSize) - defer common.KeepAlive(_buffer) - buffer = common.Dup(_buffer) - defer buffer.Release() - buffer.IncRef() - defer buffer.DecRef() - } - notFirstTime := true - for i := 0; i < times; i++ { - if dstUnsafe { - buffer = buf.NewSize(bufferSize) - } - readBufferRaw := buffer.Slice() - readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom]) - readBuffer.Resize(frontHeadroom, 0) - err = src.ReadBuffer(readBuffer) - if err != nil { - buffer.Release() - if !notFirstTime { - err = N.HandshakeFailure(dst, err) - } - return - } - dataLen := readBuffer.Len() - buffer.Resize(readBuffer.Start(), dataLen) - err = dst.WriteBuffer(buffer) - if err != nil { - buffer.Release() - return - } - n += int64(dataLen) - notFirstTime = true - } - return -} - -type ReadFromWriter interface { - io.ReaderFrom - io.Writer -} - -func ReadFrom0(readerFrom ReadFromWriter, reader io.Reader) (n int64, err error) { - n, err = CopyTimes(readerFrom, reader, 1) - if err != nil { - return - } - var rn int64 - rn, err = readerFrom.ReadFrom(reader) - if err != nil { - return - } - n += rn - return -} - -func ReadFromN(readerFrom ReadFromWriter, reader io.Reader, times int) (n int64, err error) { - n, err = CopyTimes(readerFrom, reader, times) - if err != nil { - return - } - var rn int64 - rn, err = readerFrom.ReadFrom(reader) - if err != nil { - return - } - n += rn - return -} - -type WriteToReader interface { - io.WriterTo - io.Reader -} - -func WriteTo0(writerTo WriteToReader, writer io.Writer) (n int64, err error) { - n, err = CopyTimes(writer, writerTo, 1) - if err != nil { - return - } - var wn int64 - wn, err = writerTo.WriteTo(writer) - if err != nil { - return - } - n += wn - return -} - -func WriteToN(writerTo WriteToReader, writer io.Writer, times int) (n int64, err error) { - n, err = CopyTimes(writer, writerTo, times) - if err != nil { - return - } - var wn int64 - wn, err = writerTo.WriteTo(writer) - if err != nil { - return - } - n += wn - return -} diff --git a/common/bufio/vectorised.go b/common/bufio/vectorised.go index ef875fd..bc6c623 100644 --- a/common/bufio/vectorised.go +++ b/common/bufio/vectorised.go @@ -74,9 +74,7 @@ func (w *BufferedVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error if bufferLen > 65535 { bufferBytes = make([]byte, bufferLen) } else { - _buffer := buf.StackNewSize(bufferLen) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(bufferLen) defer buffer.Release() bufferBytes = buffer.FreeBytes() } diff --git a/common/cond.go b/common/cond.go index 24458a5..81843dd 100644 --- a/common/cond.go +++ b/common/cond.go @@ -159,20 +159,14 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int { //go:norace func Dup[T any](obj T) T { - if UnsafeBuffer { - pointer := uintptr(unsafe.Pointer(&obj)) - //nolint:staticcheck - //goland:noinspection GoVetUnsafePointer - return *(*T)(unsafe.Pointer(pointer)) - } else { - return obj - } + pointer := uintptr(unsafe.Pointer(&obj)) + //nolint:staticcheck + //goland:noinspection GoVetUnsafePointer + return *(*T)(unsafe.Pointer(pointer)) } func KeepAlive(obj any) { - if UnsafeBuffer { - runtime.KeepAlive(obj) - } + runtime.KeepAlive(obj) } func Uniq[T comparable](arr []T) []T { diff --git a/common/exceptions/cause.go b/common/exceptions/cause.go index 27211f2..fe7adf3 100644 --- a/common/exceptions/cause.go +++ b/common/exceptions/cause.go @@ -6,9 +6,6 @@ type causeError struct { } func (e *causeError) Error() string { - if e.cause == nil { - return e.message - } return e.message + ": " + e.cause.Error() } diff --git a/common/unsafe_default.go b/common/unsafe_default.go deleted file mode 100644 index 3f3952d..0000000 --- a/common/unsafe_default.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build !unsafe_buffer && !disable_unsafe_buffer - -package common - -import "runtime" - -// net/*Conn in windows keeps the buffer pointer passed in during io operations, so we disable it by default. -// https://github.com/golang/go/blob/4068be56ce7721a3d75606ea986d11e9ca27077a/src/internal/poll/fd_windows.go#L876 - -const UnsafeBuffer = runtime.GOOS != "windows" diff --git a/common/unsafe_disable.go b/common/unsafe_disable.go deleted file mode 100644 index 5795b0f..0000000 --- a/common/unsafe_disable.go +++ /dev/null @@ -1,5 +0,0 @@ -//go:build disable_unsafe_buffer - -package common - -const UnsafeBuffer = false diff --git a/common/uot/conn.go b/common/uot/conn.go index 7bc3fb8..81382a4 100644 --- a/common/uot/conn.go +++ b/common/uot/conn.go @@ -75,9 +75,7 @@ func (c *Conn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if c.writer == nil { bufferLen += len(p) } - _buffer := buf.StackNewSize(bufferLen) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(bufferLen) defer buffer.Release() if !c.isConnect { common.Must(AddrParser.WriteAddrPort(buffer, destination)) @@ -124,9 +122,7 @@ func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { if c.writer == nil { headerLen += buffer.Len() } - _header := buf.StackNewSize(headerLen) - defer common.KeepAlive(_header) - header := common.Dup(_header) + header := buf.NewSize(headerLen) defer header.Release() if !c.isConnect { common.Must(AddrParser.WriteAddrPort(header, destination)) diff --git a/common/uot/server.go b/common/uot/server.go index b620cce..78cfa6d 100644 --- a/common/uot/server.go +++ b/common/uot/server.go @@ -5,7 +5,6 @@ import ( "io" "net" - "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" ) @@ -53,9 +52,7 @@ func (c *ServerConn) loopInput() { c.isConnect = request.IsConnect c.destination = request.Destination } - _buffer := buf.StackNew() - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewPacket() defer buffer.Release() for { var destination M.Socksaddr @@ -95,9 +92,7 @@ func (c *ServerConn) loopInput() { //warn:unsafe func (c *ServerConn) loopOutput() { - _buffer := buf.StackNew() - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewPacket() defer buffer.Release() for { buffer.FullReset() diff --git a/protocol/socks/packet.go b/protocol/socks/packet.go index 08028b0..555ee79 100644 --- a/protocol/socks/packet.go +++ b/protocol/socks/packet.go @@ -64,9 +64,7 @@ func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err erro //warn:unsafe func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { destination := M.SocksaddrFromNet(addr) - _buffer := buf.StackNewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination) + len(p)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination) + len(p)) defer buffer.Release() common.Must(buffer.WriteZeroN(3)) err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) diff --git a/protocol/socks/packet_vectorised.go b/protocol/socks/packet_vectorised.go index 78bf918..12684cc 100644 --- a/protocol/socks/packet_vectorised.go +++ b/protocol/socks/packet_vectorised.go @@ -40,9 +40,7 @@ func NewVectorisedAssociateConn(conn net.Conn, writer N.VectorisedWriter, remote } func (v *VectorisedAssociatePacketConn) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error { - _header := buf.StackNewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination)) - defer common.KeepAlive(_header) - header := common.Dup(_header) + header := buf.NewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination)) defer header.Release() common.Must(header.WriteZeroN(3)) common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) diff --git a/protocol/socks/socks4/protocol.go b/protocol/socks/socks4/protocol.go index f5c8f7e..8b3879d 100644 --- a/protocol/socks/socks4/protocol.go +++ b/protocol/socks/socks4/protocol.go @@ -85,9 +85,7 @@ func WriteRequest(writer io.Writer, request Request) error { requestLen += len(request.Username) } - _buffer := buf.StackNewSize(requestLen) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(requestLen) defer buffer.Release() common.Must( @@ -145,9 +143,7 @@ func ReadResponse(reader io.Reader) (response Response, err error) { } func WriteResponse(writer io.Writer, response Response) error { - _buffer := buf.StackNewSize(8) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(8) defer buffer.Release() common.Must( buffer.WriteByte(0), diff --git a/protocol/socks/socks5/protocol.go b/protocol/socks/socks5/protocol.go index 79ab96f..bce361b 100644 --- a/protocol/socks/socks5/protocol.go +++ b/protocol/socks/socks5/protocol.go @@ -48,9 +48,7 @@ type AuthRequest struct { } func WriteAuthRequest(writer io.Writer, request AuthRequest) error { - _buffer := buf.StackNewSize(len(request.Methods) + 2) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(len(request.Methods) + 2) defer buffer.Release() common.Must( buffer.WriteByte(Version), @@ -120,9 +118,7 @@ type UsernamePasswordAuthRequest struct { } func WriteUsernamePasswordAuthRequest(writer io.Writer, request UsernamePasswordAuthRequest) error { - _buffer := buf.StackNewSize(3 + len(request.Username) + len(request.Password)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(3 + len(request.Username) + len(request.Password)) defer buffer.Release() common.Must( buffer.WriteByte(1), @@ -191,9 +187,7 @@ type Request struct { } func WriteRequest(writer io.Writer, request Request) error { - _buffer := buf.StackNewSize(3 + M.SocksaddrSerializer.AddrPortLen(request.Destination)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(3 + M.SocksaddrSerializer.AddrPortLen(request.Destination)) defer buffer.Release() common.Must( buffer.WriteByte(Version), @@ -244,9 +238,7 @@ func WriteResponse(writer io.Writer, response Response) error { bind.Addr = netip.IPv4Unspecified() } - _buffer := buf.StackNewSize(3 + M.SocksaddrSerializer.AddrPortLen(bind)) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) + buffer := buf.NewSize(3 + M.SocksaddrSerializer.AddrPortLen(bind)) defer buffer.Release() common.Must( buffer.WriteByte(Version), From f8874e3e1c7ff838527e06c2527ed8a9b4a11a00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 7 Jul 2023 13:39:24 +0800 Subject: [PATCH 005/141] Fix direct copy --- common/bufio/copy_direct_posix.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/bufio/copy_direct_posix.go b/common/bufio/copy_direct_posix.go index 4479988..d682558 100644 --- a/common/bufio/copy_direct_posix.go +++ b/common/bufio/copy_direct_posix.go @@ -223,7 +223,7 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buf case *syscall.SockaddrInet4: w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port)) case *syscall.SockaddrInet6: - w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)) + w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap() } } return true From f494f694c70576b4762e0e6f0cbf0941e7123cc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 7 Jul 2023 13:56:57 +0800 Subject: [PATCH 006/141] Fix multi error --- common/exceptions/multi.go | 1 + 1 file changed, 1 insertion(+) diff --git a/common/exceptions/multi.go b/common/exceptions/multi.go index a42f00c..a11a901 100644 --- a/common/exceptions/multi.go +++ b/common/exceptions/multi.go @@ -23,6 +23,7 @@ func (e *multiError) Unwrap() []error { func Errors(errors ...error) error { errors = common.FilterNotNil(errors) errors = ExpandAll(errors) + errors = common.FilterNotNil(errors) errors = common.UniqBy(errors, error.Error) switch len(errors) { case 0: From 37622ea16f5017f726c9c8aa756d34c71d0de555 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 10 Jul 2023 12:18:18 +0800 Subject: [PATCH 007/141] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index f7f8ac3..f1298ea 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /.idea/ /vendor/ +.DS_Store From 8807070904b24b063f853da5248126d86f1ade0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 15 Jul 2023 14:41:05 +0800 Subject: [PATCH 008/141] abx: Accept independent attr --- common/abx/reader.go | 60 +++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/common/abx/reader.go b/common/abx/reader.go index 1629572..16f839c 100644 --- a/common/abx/reader.go +++ b/common/abx/reader.go @@ -18,7 +18,6 @@ var _ xml.TokenReader = (*Reader)(nil) type Reader struct { reader *bytes.Reader stringRefs []string - attrs []xml.Attr } func NewReader(content []byte) (xml.TokenReader, bool) { @@ -47,7 +46,7 @@ func (r *Reader) Token() (token xml.Token, err error) { return } var attrs []xml.Attr - attrs, err = r.pullAttributes() + attrs, err = r.readAttributes() if err != nil { return } @@ -93,35 +92,41 @@ func (r *Reader) Token() (token xml.Token, err error) { _, err = r.readUTF() return case ATTRIBUTE: - return nil, E.New("unexpected attribute") + _, err = r.readAttribute() + return } return nil, E.New("unknown token type ", tokenType, " with type ", eventType) } -func (r *Reader) pullAttributes() ([]xml.Attr, error) { - err := r.pullAttribute() - if err != nil { - return nil, err +func (r *Reader) readAttributes() ([]xml.Attr, error) { + var attrs []xml.Attr + for { + attr, err := r.readAttribute() + if err == io.EOF { + break + } + attrs = append(attrs, attr) } - attrs := r.attrs - r.attrs = nil return attrs, nil } -func (r *Reader) pullAttribute() error { +func (r *Reader) readAttribute() (xml.Attr, error) { event, err := r.reader.ReadByte() if err != nil { - return nil + return xml.Attr{}, nil } tokenType := event & 0x0f eventType := event & 0xf0 if tokenType != ATTRIBUTE { - return r.reader.UnreadByte() + err = r.reader.UnreadByte() + if err != nil { + return xml.Attr{}, nil + } + return xml.Attr{}, io.EOF } - var name string - name, err = r.readInternedUTF() + name, err := r.readInternedUTF() if err != nil { - return err + return xml.Attr{}, err } var value string switch eventType { @@ -134,74 +139,73 @@ func (r *Reader) pullAttribute() error { case TypeString: value, err = r.readUTF() if err != nil { - return err + return xml.Attr{}, err } case TypeStringInterned: value, err = r.readInternedUTF() if err != nil { - return err + return xml.Attr{}, err } case TypeBytesHex: var data []byte data, err = r.readBytes() if err != nil { - return err + return xml.Attr{}, err } value = hex.EncodeToString(data) case TypeBytesBase64: var data []byte data, err = r.readBytes() if err != nil { - return err + return xml.Attr{}, err } value = base64.StdEncoding.EncodeToString(data) case TypeInt: var data int32 err = binary.Read(r.reader, binary.BigEndian, &data) if err != nil { - return err + return xml.Attr{}, err } value = strconv.FormatInt(int64(data), 10) case TypeIntHex: var data int32 err = binary.Read(r.reader, binary.BigEndian, &data) if err != nil { - return err + return xml.Attr{}, err } value = "0x" + strconv.FormatInt(int64(data), 16) case TypeLong: var data int64 err = binary.Read(r.reader, binary.BigEndian, &data) if err != nil { - return err + return xml.Attr{}, err } value = strconv.FormatInt(data, 10) case TypeLongHex: var data int64 err = binary.Read(r.reader, binary.BigEndian, &data) if err != nil { - return err + return xml.Attr{}, err } value = "0x" + strconv.FormatInt(data, 16) case TypeFloat: var data float32 err = binary.Read(r.reader, binary.BigEndian, &data) if err != nil { - return err + return xml.Attr{}, err } value = strconv.FormatFloat(float64(data), 'g', -1, 32) case TypeDouble: var data float64 err = binary.Read(r.reader, binary.BigEndian, &data) if err != nil { - return err + return xml.Attr{}, err } value = strconv.FormatFloat(data, 'g', -1, 64) default: - return E.New("unexpected attribute type, ", eventType) + return xml.Attr{}, E.New("unexpected attribute type, ", eventType) } - r.attrs = append(r.attrs, xml.Attr{Name: xml.Name{Local: name}, Value: value}) - return r.pullAttribute() + return xml.Attr{Name: xml.Name{Local: name}, Value: value}, nil } func (r *Reader) readUnsignedShort() (uint16, error) { From 32f9f628a030c59e3f03104e7b3d6c422abc7347 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 23 Jul 2023 12:43:02 +0800 Subject: [PATCH 009/141] Fix handshake conn interface --- common/bufio/copy.go | 81 ++++++++++++++++------------- common/bufio/copy_direct_posix.go | 23 ++++---- common/bufio/copy_direct_windows.go | 4 +- 3 files changed, 57 insertions(+), 51 deletions(-) diff --git a/common/bufio/copy.go b/common/bufio/copy.go index ea279eb..80cc823 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -22,7 +22,7 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) { } else if destination == nil { return 0, E.New("nil writer") } - originDestination := destination + originSource := source var readCounters, writeCounters []N.CountFunc for { source, readCounters = N.UnwrapCountReader(source, readCounters) @@ -52,29 +52,29 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) { } break } - return CopyExtended(originDestination, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters) + return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters) } -func CopyExtended(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { safeSrc := N.IsSafeReader(source) headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination) if safeSrc != nil { if headroom == 0 { - return CopyExtendedWithSrcBuffer(originDestination, destination, safeSrc, readCounters, writeCounters) + return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters) } } readWaiter, isReadWaiter := CreateReadWaiter(source) if isReadWaiter { var handled bool - handled, n, err = copyWaitWithPool(originDestination, destination, readWaiter, readCounters, writeCounters) + handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters) if handled { return } } - return CopyExtendedWithPool(originDestination, destination, source, readCounters, writeCounters) + return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters) } -func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { buffer.IncRef() defer buffer.DecRef() frontHeadroom := N.CalculateFrontHeadroom(destination) @@ -90,15 +90,15 @@ func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWrite err = nil return } - if !notFirstTime { - err = N.HandshakeFailure(originDestination, err) - } return } dataLen := readBuffer.Len() buffer.Resize(readBuffer.Start(), dataLen) err = destination.WriteBuffer(buffer) if err != nil { + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -112,7 +112,7 @@ func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWrite } } -func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { var notFirstTime bool for { var buffer *buf.Buffer @@ -122,15 +122,15 @@ func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.Extend err = nil return } - if !notFirstTime { - err = N.HandshakeFailure(originDestination, err) - } return } dataLen := buffer.Len() err = destination.WriteBuffer(buffer) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -144,7 +144,7 @@ func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.Extend } } -func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { frontHeadroom := N.CalculateFrontHeadroom(destination) rearHeadroom := N.CalculateRearHeadroom(destination) bufferSize := N.CalculateMTU(source, destination) @@ -166,9 +166,6 @@ func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWri err = nil return } - if !notFirstTime { - err = N.HandshakeFailure(originDestination, err) - } return } dataLen := readBuffer.Len() @@ -176,6 +173,9 @@ func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWri err = destination.WriteBuffer(buffer) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -236,6 +236,7 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) { var readCounters, writeCounters []N.CountFunc var cachedPackets []*N.PacketBuffer + originSource := source for { source, readCounters = N.UnwrapCountPacketReader(source, readCounters) destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters) @@ -249,8 +250,9 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, break } if cachedPackets != nil { - n, err = WritePacketWithPool(destinationConn, cachedPackets) + n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets) if err != nil { + println("err in write cached packets") return } } @@ -261,36 +263,34 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, if safeSrc != nil { if headroom == 0 { var copyN int64 - copyN, err = CopyPacketWithSrcBuffer(destinationConn, safeSrc, readCounters, writeCounters) + copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0) n += copyN return } } + var ( + handled bool + copeN int64 + ) readWaiter, isReadWaiter := CreatePacketReadWaiter(source) if isReadWaiter { - var ( - handled bool - copeN int64 - ) - handled, copeN, err = copyPacketWaitWithPool(destinationConn, readWaiter, readCounters, writeCounters) + handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0) if handled { n += copeN return } } - return CopyPacketWithPool(destinationConn, source, readCounters, writeCounters) + copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0) + n += copeN + return } -func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { var buffer *buf.Buffer var destination M.Socksaddr - var notFirstTime bool for { buffer, destination, err = source.ReadPacketThreadSafe() if err != nil { - if !notFirstTime { - err = N.HandshakeFailure(destinationConn, err) - } return } dataLen := buffer.Len() @@ -300,6 +300,9 @@ func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafe err = destinationConn.WritePacket(buffer, destination) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -313,7 +316,7 @@ func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafe } } -func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { +func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { frontHeadroom := N.CalculateFrontHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn) bufferSize := N.CalculateMTU(source, destinationConn) @@ -323,7 +326,6 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r bufferSize = buf.UDPBufferSize } var destination M.Socksaddr - var notFirstTime bool for { buffer := buf.NewSize(bufferSize) readBufferRaw := buffer.Slice() @@ -332,9 +334,6 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r destination, err = source.ReadPacket(readBuffer) if err != nil { buffer.Release() - if !notFirstTime { - err = N.HandshakeFailure(destinationConn, err) - } return } dataLen := readBuffer.Len() @@ -342,6 +341,9 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r err = destinationConn.WritePacket(buffer, destination) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -355,9 +357,10 @@ func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, r } } -func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) { +func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) { frontHeadroom := N.CalculateFrontHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn) + var notFirstTime bool for _, packetBuffer := range packetBuffers { buffer := buf.NewPacket() readBufferRaw := buffer.Slice() @@ -366,6 +369,7 @@ func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.Pack _, err = readBuffer.Write(packetBuffer.Buffer.Bytes()) packetBuffer.Buffer.Release() if err != nil { + buffer.Release() continue } dataLen := readBuffer.Len() @@ -373,6 +377,9 @@ func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.Pack err = destinationConn.WritePacket(buffer, packetBuffer.Destination) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) diff --git a/common/bufio/copy_direct_posix.go b/common/bufio/copy_direct_posix.go index d682558..63643a1 100644 --- a/common/bufio/copy_direct_posix.go +++ b/common/bufio/copy_direct_posix.go @@ -15,7 +15,7 @@ import ( N "github.com/sagernet/sing/common/network" ) -func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { +func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { handled = true frontHeadroom := N.CalculateFrontHeadroom(destination) rearHeadroom := N.CalculateRearHeadroom(destination) @@ -45,9 +45,6 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, err = nil return } - if !notFirstTime { - err = N.HandshakeFailure(originDestination, err) - } return } dataLen := readBuffer.Len() @@ -55,6 +52,9 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, err = destination.WriteBuffer(buffer) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) @@ -68,7 +68,7 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, } } -func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { +func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) { handled = true frontHeadroom := N.CalculateFrontHeadroom(destinationConn) rearHeadroom := N.CalculateRearHeadroom(destinationConn) @@ -79,10 +79,9 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW bufferSize = buf.UDPBufferSize } var ( - buffer *buf.Buffer - readBuffer *buf.Buffer - destination M.Socksaddr - notFirstTime bool + buffer *buf.Buffer + readBuffer *buf.Buffer + destination M.Socksaddr ) source.InitializeReadWaiter(func() *buf.Buffer { buffer = buf.NewSize(bufferSize) @@ -95,9 +94,6 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW for { destination, err = source.WaitReadPacket() if err != nil { - if !notFirstTime { - err = N.HandshakeFailure(destinationConn, err) - } return } dataLen := readBuffer.Len() @@ -105,6 +101,9 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW err = destinationConn.WritePacket(buffer, destination) if err != nil { buffer.Release() + if !notFirstTime { + err = N.HandshakeFailure(originSource, err) + } return } n += int64(dataLen) diff --git a/common/bufio/copy_direct_windows.go b/common/bufio/copy_direct_windows.go index 9c0743f..22a2de0 100644 --- a/common/bufio/copy_direct_windows.go +++ b/common/bufio/copy_direct_windows.go @@ -6,11 +6,11 @@ import ( N "github.com/sagernet/sing/common/network" ) -func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { +func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { return } -func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { +func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) { return } From 221477cf17707fad9dad191033810884353d34f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 31 Jul 2023 08:57:58 +0800 Subject: [PATCH 010/141] Fix http proxy server --- protocol/http/handshake.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 280b934..78b2479 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -28,6 +28,11 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read return E.Cause(err, "read http request") } + if hostStr := request.Header.Get("Host"); hostStr != "" { + request.Host = hostStr + request.URL.Host = hostStr + } + if authenticator != nil { var authOk bool authorization := request.Header.Get("Proxy-Authorization") From 26d3f3d91bdadb6600c50cdeed3d6c03629757c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 31 Jul 2023 09:01:40 +0800 Subject: [PATCH 011/141] Update dependencies --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index b6c5874..edf21aa 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/sagernet/sing go 1.18 -require golang.org/x/sys v0.7.0 +require golang.org/x/sys v0.10.0 diff --git a/go.sum b/go.sum index 7c22a2b..55a3ff2 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From 83ce0be4d4aafc78d7cfd70357104a6dcdc5456e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 1 Aug 2023 15:46:12 +0800 Subject: [PATCH 012/141] Remove debug log --- common/bufio/copy.go | 1 - 1 file changed, 1 deletion(-) diff --git a/common/bufio/copy.go b/common/bufio/copy.go index 80cc823..b5cb412 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -252,7 +252,6 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, if cachedPackets != nil { n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets) if err != nil { - println("err in write cached packets") return } } From c6a69b4912ee817c997e43fd45eb7018d99315f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 2 Aug 2023 18:59:22 +0800 Subject: [PATCH 013/141] Fix "Fix http proxy server" --- protocol/http/handshake.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 78b2479..455de7c 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -28,11 +28,6 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read return E.Cause(err, "read http request") } - if hostStr := request.Header.Get("Host"); hostStr != "" { - request.Host = hostStr - request.URL.Host = hostStr - } - if authenticator != nil { var authOk bool authorization := request.Header.Get("Proxy-Authorization") @@ -89,6 +84,12 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read removeHopByHopHeaders(request.Header) removeExtraHTTPHostPort(request) + if hostStr := request.Header.Get("Host"); hostStr != "" { + if hostStr != request.URL.Host { + request.Host = hostStr + } + } + if request.URL.Scheme == "" || request.URL.Host == "" { return responseWith(request, http.StatusBadRequest).Write(conn) } From a755de3bbd49cf23f65b969f3bc188b4f117fb23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 2 Aug 2023 19:39:18 +0800 Subject: [PATCH 014/141] Improve serializer --- common/metadata/family.go | 7 +++-- common/metadata/serializer.go | 55 +++++++++++++++++++++-------------- 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/common/metadata/family.go b/common/metadata/family.go index c6551c2..c2a5bed 100644 --- a/common/metadata/family.go +++ b/common/metadata/family.go @@ -3,7 +3,8 @@ package metadata type Family = byte const ( - AddressFamilyIPv4 Family = 0x01 - AddressFamilyIPv6 Family = 0x04 - AddressFamilyFqdn Family = 0x03 + AddressFamilyIPv4 Family = 0x01 + AddressFamilyIPv6 Family = 0x04 + AddressFamilyFqdn Family = 0x03 + AddressFamilyEmpty Family = 0xff ) diff --git a/common/metadata/serializer.go b/common/metadata/serializer.go index 62d31c1..9dec47d 100644 --- a/common/metadata/serializer.go +++ b/common/metadata/serializer.go @@ -50,14 +50,20 @@ func NewSerializer(options ...SerializerOption) *Serializer { func (s *Serializer) WriteAddress(buffer *buf.Buffer, addr Socksaddr) error { var af Family - if addr.IsIPv4() { + if !addr.IsValid() { + af = AddressFamilyFqdn + } else if addr.IsIPv4() { af = AddressFamilyIPv4 } else if addr.IsIPv6() { af = AddressFamilyIPv6 } else { af = AddressFamilyFqdn } - err := buffer.WriteByte(s.familyByteMap[af]) + afByte, loaded := s.familyByteMap[af] + if !loaded { + return E.New("unsupported address") + } + err := buffer.WriteByte(afByte) if err != nil { return err } @@ -70,7 +76,14 @@ func (s *Serializer) WriteAddress(buffer *buf.Buffer, addr Socksaddr) error { } func (s *Serializer) AddressLen(addr Socksaddr) int { - if addr.IsIPv4() { + if !addr.IsValid() { + _, supportEmpty := s.familyByteMap[AddressFamilyEmpty] + if supportEmpty { + return 1 + } else { + return 2 + } + } else if addr.IsIPv4() { return 5 } else if addr.IsIPv6() { return 17 @@ -129,26 +142,24 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) { return Socksaddr{}, E.Cause(err, "read fqdn") } return ParseSocksaddrHostPort(fqdn, 0), nil - default: - switch family { - case AddressFamilyIPv4: - var addr [4]byte - err = common.Error(reader.Read(addr[:])) - if err != nil { - return Socksaddr{}, E.Cause(err, "read ipv4 address") - } - return Socksaddr{Addr: netip.AddrFrom4(addr)}, nil - case AddressFamilyIPv6: - var addr [16]byte - err = common.Error(reader.Read(addr[:])) - if err != nil { - return Socksaddr{}, E.Cause(err, "read ipv6 address") - } - - return Socksaddr{Addr: netip.AddrFrom16(addr)}.Unwrap(), nil - default: - return Socksaddr{}, E.New("unknown address family: ", af) + case AddressFamilyIPv4: + var addr [4]byte + _, err = io.ReadFull(reader, addr[:]) + if err != nil { + return Socksaddr{}, E.Cause(err, "read ipv4 address") } + return Socksaddr{Addr: netip.AddrFrom4(addr)}, nil + case AddressFamilyIPv6: + var addr [16]byte + _, err = io.ReadFull(reader, addr[:]) + if err != nil { + return Socksaddr{}, E.Cause(err, "read ipv6 address") + } + return Socksaddr{Addr: netip.AddrFrom16(addr)}.Unwrap(), nil + case AddressFamilyEmpty: + return Socksaddr{}, nil + default: + return Socksaddr{}, E.New("unknown address family: ", af) } } From 4db0062caa0af2636330b2571458867e8381f746 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 7 Aug 2023 16:02:48 +0800 Subject: [PATCH 015/141] Add pause manager --- service/pause/context.go | 22 +++++++++ service/pause/default.go | 101 +++++++++++++++++++++++++++++++++++++++ service/pause/manager.go | 12 +++++ 3 files changed, 135 insertions(+) create mode 100644 service/pause/context.go create mode 100644 service/pause/default.go create mode 100644 service/pause/manager.go diff --git a/service/pause/context.go b/service/pause/context.go new file mode 100644 index 0000000..9baeded --- /dev/null +++ b/service/pause/context.go @@ -0,0 +1,22 @@ +package pause + +import ( + "context" + + "github.com/sagernet/sing/service" +) + +func ManagerFromContext(ctx context.Context) Manager { + return service.FromContext[Manager](ctx) +} + +func ContextWithManager(ctx context.Context, manager Manager) context.Context { + return service.ContextWith[Manager](ctx, manager) +} + +func ContextWithDefaultManager(ctx context.Context) context.Context { + if service.FromContext[Manager](ctx) != nil { + return ctx + } + return service.ContextWith[Manager](ctx, NewDefaultManager(ctx)) +} diff --git a/service/pause/default.go b/service/pause/default.go new file mode 100644 index 0000000..beb7e0f --- /dev/null +++ b/service/pause/default.go @@ -0,0 +1,101 @@ +package pause + +import ( + "context" + "sync" +) + +type defaultManager struct { + ctx context.Context + access sync.Mutex + devicePause chan struct{} + networkPause chan struct{} +} + +func NewDefaultManager(ctx context.Context) Manager { + devicePauseChan := make(chan struct{}) + networkPauseChan := make(chan struct{}) + close(devicePauseChan) + close(networkPauseChan) + return &defaultManager{ + ctx: ctx, + devicePause: devicePauseChan, + networkPause: networkPauseChan, + } +} + +func (d *defaultManager) DevicePause() { + d.access.Lock() + defer d.access.Unlock() + select { + case <-d.devicePause: + d.devicePause = make(chan struct{}) + default: + } +} + +func (d *defaultManager) DeviceWake() { + d.access.Lock() + defer d.access.Unlock() + select { + case <-d.devicePause: + default: + close(d.devicePause) + } +} + +func (d *defaultManager) DevicePauseChan() <-chan struct{} { + return d.devicePause +} + +func (d *defaultManager) NetworkPause() { + d.access.Lock() + defer d.access.Unlock() + select { + case <-d.networkPause: + d.networkPause = make(chan struct{}) + default: + } +} + +func (d *defaultManager) NetworkWake() { + d.access.Lock() + defer d.access.Unlock() + select { + case <-d.networkPause: + default: + close(d.networkPause) + } +} + +func (d *defaultManager) NetworkPauseChan() <-chan struct{} { + return d.networkPause +} + +func (d *defaultManager) IsPaused() bool { + select { + case <-d.devicePause: + default: + return true + } + + select { + case <-d.networkPause: + default: + return true + } + + return false +} + +func (d *defaultManager) WaitActive() { + select { + case <-d.devicePause: + case <-d.ctx.Done(): + } + + select { + case <-d.networkPause: + case <-d.ctx.Done(): + } +} diff --git a/service/pause/manager.go b/service/pause/manager.go new file mode 100644 index 0000000..9f7df65 --- /dev/null +++ b/service/pause/manager.go @@ -0,0 +1,12 @@ +package pause + +type Manager interface { + DevicePause() + DeviceWake() + DevicePauseChan() <-chan struct{} + NetworkPause() + NetworkWake() + NetworkPauseChan() <-chan struct{} + IsPaused() bool + WaitActive() +} From 620f3a3b882d3c6590aa523fba2cf29ca112a88b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 20 Aug 2023 13:05:20 +0800 Subject: [PATCH 016/141] Fix serializer --- common/metadata/serializer.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/common/metadata/serializer.go b/common/metadata/serializer.go index 9dec47d..b858f5c 100644 --- a/common/metadata/serializer.go +++ b/common/metadata/serializer.go @@ -51,7 +51,7 @@ func NewSerializer(options ...SerializerOption) *Serializer { func (s *Serializer) WriteAddress(buffer *buf.Buffer, addr Socksaddr) error { var af Family if !addr.IsValid() { - af = AddressFamilyFqdn + af = AddressFamilyEmpty } else if addr.IsIPv4() { af = AddressFamilyIPv4 } else if addr.IsIPv6() { @@ -67,9 +67,10 @@ func (s *Serializer) WriteAddress(buffer *buf.Buffer, addr Socksaddr) error { if err != nil { return err } - if addr.Addr.IsValid() { + switch af { + case AddressFamilyIPv4, AddressFamilyIPv6: _, err = buffer.Write(addr.Addr.AsSlice()) - } else { + case AddressFamilyFqdn: err = WriteSocksString(buffer, addr.Fqdn) } return err @@ -77,12 +78,7 @@ func (s *Serializer) WriteAddress(buffer *buf.Buffer, addr Socksaddr) error { func (s *Serializer) AddressLen(addr Socksaddr) int { if !addr.IsValid() { - _, supportEmpty := s.familyByteMap[AddressFamilyEmpty] - if supportEmpty { - return 1 - } else { - return 2 - } + return 1 } else if addr.IsIPv4() { return 5 } else if addr.IsIPv6() { @@ -113,7 +109,7 @@ func (s *Serializer) WriteAddrPort(writer io.Writer, destination Socksaddr) erro } if s.portFirst { err = s.WriteAddress(buffer, destination) - } else { + } else if destination.IsValid() { err = s.WritePort(buffer, destination.Port) } if err != nil { @@ -126,7 +122,11 @@ func (s *Serializer) WriteAddrPort(writer io.Writer, destination Socksaddr) erro } func (s *Serializer) AddrPortLen(destination Socksaddr) int { - return s.AddressLen(destination) + 2 + if destination.IsValid() { + return s.AddressLen(destination) + 2 + } else { + return s.AddressLen(destination) + } } func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) { @@ -184,7 +184,7 @@ func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err } if s.portFirst { addr, err = s.ReadAddress(reader) - } else { + } else if addr.IsValid() { port, err = s.ReadPort(reader) } if err != nil { From 8d731e68853a5674cc1316f68ade719a595f5d2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 24 Aug 2023 19:58:37 +0800 Subject: [PATCH 017/141] Add LRUCache.Clear --- common/cache/lrucache.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/common/cache/lrucache.go b/common/cache/lrucache.go index 37fa26b..a8912e8 100644 --- a/common/cache/lrucache.go +++ b/common/cache/lrucache.go @@ -258,6 +258,14 @@ func (c *LruCache[K, V]) Delete(key K) { c.mu.Unlock() } +func (c *LruCache[K, V]) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + for element := c.lru.Front(); element != nil; element = element.Next() { + c.deleteElement(element) + } +} + func (c *LruCache[K, V]) maybeDeleteOldest() { if !c.staleReturn && c.maxAge > 0 { now := time.Now().Unix() From 30bf19f2833c4a73bab97a06a6cb0b2b81fe9896 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 30 Aug 2023 21:26:30 +0800 Subject: [PATCH 018/141] Check nil buffer in CopyPacketWithSrcBuffer --- common/bufio/copy.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/common/bufio/copy.go b/common/bufio/copy.go index b5cb412..926211e 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -5,6 +5,7 @@ import ( "errors" "io" "net" + "reflect" "syscall" "github.com/sagernet/sing/common" @@ -292,6 +293,9 @@ func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.Pack if err != nil { return } + if buffer == nil { + panic("nil buffer returned from " + reflect.TypeOf(source).String()) + } dataLen := buffer.Len() if dataLen == 0 { continue From 0eec7bbe1934bc13adb954e77c882ac0327d63b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 3 Sep 2023 19:41:21 +0800 Subject: [PATCH 019/141] Add must register func for service --- service/context.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/service/context.go b/service/context.go index 327d8ba..6817582 100644 --- a/service/context.go +++ b/service/context.go @@ -68,3 +68,19 @@ func ContextWithPtr[T any](ctx context.Context, servicePtr *T) context.Context { registry.Register(common.DefaultValue[*T](), servicePtr) return ctx } + +func MustRegister[T any](ctx context.Context, service T) { + registry := RegistryFromContext(ctx) + if registry == nil { + panic("missing service registry in context") + } + registry.Register(common.DefaultValue[*T](), service) +} + +func MustRegisterPtr[T any](ctx context.Context, servicePtr *T) { + registry := RegistryFromContext(ctx) + if registry == nil { + panic("missing service registry in context") + } + registry.Register(common.DefaultValue[*T](), servicePtr) +} From 03c21c0a1205daedc7c4f701b44f5eba8319ddca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 7 Sep 2023 09:06:46 +0800 Subject: [PATCH 020/141] Fix WriteAddrPort usage --- common/uot/lazy.go | 14 +++++++++++--- common/uot/protocol.go | 19 ++++++++++++------- protocol/socks/packet.go | 5 ++++- protocol/socks/packet_vectorised.go | 5 ++++- protocol/socks/socks5/protocol.go | 10 ++++++++-- 5 files changed, 39 insertions(+), 14 deletions(-) diff --git a/common/uot/lazy.go b/common/uot/lazy.go index 1b1b54a..8853564 100644 --- a/common/uot/lazy.go +++ b/common/uot/lazy.go @@ -30,7 +30,11 @@ func NewLazyConn(conn net.Conn, request Request) *Conn { func (c *LazyClientConn) Write(p []byte) (n int, err error) { if !c.requestWritten { - request := EncodeRequest(c.request) + var request *buf.Buffer + request, err = EncodeRequest(c.request) + if err != nil { + return + } err = c.writer.WriteVectorised([]*buf.Buffer{request, buf.As(p)}) if err != nil { return @@ -43,8 +47,12 @@ func (c *LazyClientConn) Write(p []byte) (n int, err error) { func (c *LazyClientConn) WriteVectorised(buffers []*buf.Buffer) error { if !c.requestWritten { - request := EncodeRequest(c.request) - err := c.writer.WriteVectorised(append([]*buf.Buffer{request}, buffers...)) + request, err := EncodeRequest(c.request) + if err != nil { + return err + } + + err = c.writer.WriteVectorised(append([]*buf.Buffer{request}, buffers...)) c.requestWritten = true return err } diff --git a/common/uot/protocol.go b/common/uot/protocol.go index 6ac2d84..4f968c7 100644 --- a/common/uot/protocol.go +++ b/common/uot/protocol.go @@ -51,20 +51,25 @@ func ReadRequest(reader io.Reader) (*Request, error) { return &request, nil } -func EncodeRequest(request Request) *buf.Buffer { +func EncodeRequest(request Request) (*buf.Buffer, error) { var bufferLen int bufferLen += 1 // isConnect bufferLen += M.SocksaddrSerializer.AddrPortLen(request.Destination) buffer := buf.NewSize(bufferLen) - common.Must( - binary.Write(buffer, binary.BigEndian, request.IsConnect), - M.SocksaddrSerializer.WriteAddrPort(buffer, request.Destination), - ) - return buffer + common.Must(binary.Write(buffer, binary.BigEndian, request.IsConnect)) + err := M.SocksaddrSerializer.WriteAddrPort(buffer, request.Destination) + if err != nil { + buffer.Release() + return nil, err + } + return buffer, nil } func WriteRequest(writer io.Writer, request Request) error { - buffer := EncodeRequest(request) + buffer, err := EncodeRequest(request) + if err != nil { + return err + } defer buffer.Release() return common.Error(writer.Write(buffer.Bytes())) } diff --git a/protocol/socks/packet.go b/protocol/socks/packet.go index 555ee79..4df672c 100644 --- a/protocol/socks/packet.go +++ b/protocol/socks/packet.go @@ -104,7 +104,10 @@ func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Sock func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { header := buf.With(buffer.ExtendHeader(3 + M.SocksaddrSerializer.AddrPortLen(destination))) common.Must(header.WriteZeroN(3)) - common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } return common.Error(bufio.WritePacketBuffer(c.NetPacketConn, buffer, c.remoteAddr)) } diff --git a/protocol/socks/packet_vectorised.go b/protocol/socks/packet_vectorised.go index 12684cc..c73286b 100644 --- a/protocol/socks/packet_vectorised.go +++ b/protocol/socks/packet_vectorised.go @@ -43,7 +43,10 @@ func (v *VectorisedAssociatePacketConn) WriteVectorisedPacket(buffers []*buf.Buf header := buf.NewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination)) defer header.Release() common.Must(header.WriteZeroN(3)) - common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } return v.VectorisedPacketWriter.WriteVectorisedPacket(append([]*buf.Buffer{header}, buffers...), destination) } diff --git a/protocol/socks/socks5/protocol.go b/protocol/socks/socks5/protocol.go index bce361b..67d9797 100644 --- a/protocol/socks/socks5/protocol.go +++ b/protocol/socks/socks5/protocol.go @@ -193,8 +193,11 @@ func WriteRequest(writer io.Writer, request Request) error { buffer.WriteByte(Version), buffer.WriteByte(request.Command), buffer.WriteZero(), - M.SocksaddrSerializer.WriteAddrPort(buffer, request.Destination), ) + err := M.SocksaddrSerializer.WriteAddrPort(buffer, request.Destination) + if err != nil { + return err + } return rw.WriteBytes(writer, buffer.Bytes()) } @@ -244,8 +247,11 @@ func WriteResponse(writer io.Writer, response Response) error { buffer.WriteByte(Version), buffer.WriteByte(response.ReplyCode), buffer.WriteZero(), - M.SocksaddrSerializer.WriteAddrPort(buffer, bind), ) + err := M.SocksaddrSerializer.WriteAddrPort(buffer, bind) + if err != nil { + return err + } return rw.WriteBytes(writer, buffer.Bytes()) } From b0849c43a60063f6229bcd477cdb081bb22c9aae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 8 Sep 2023 11:22:32 +0800 Subject: [PATCH 021/141] Add HandshakeSuccess interface --- common/bufio/copy.go | 12 ++++++------ common/bufio/copy_direct_posix.go | 4 ++-- common/network/handshake.go | 17 ++++++++++++++--- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/common/bufio/copy.go b/common/bufio/copy.go index 926211e..3bdb164 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -98,7 +98,7 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so err = destination.WriteBuffer(buffer) if err != nil { if !notFirstTime { - err = N.HandshakeFailure(originSource, err) + err = N.ReportHandshakeFailure(originSource, err) } return } @@ -130,7 +130,7 @@ func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWri if err != nil { buffer.Release() if !notFirstTime { - err = N.HandshakeFailure(originSource, err) + err = N.ReportHandshakeFailure(originSource, err) } return } @@ -175,7 +175,7 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, if err != nil { buffer.Release() if !notFirstTime { - err = N.HandshakeFailure(originSource, err) + err = N.ReportHandshakeFailure(originSource, err) } return } @@ -304,7 +304,7 @@ func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.Pack if err != nil { buffer.Release() if !notFirstTime { - err = N.HandshakeFailure(originSource, err) + err = N.ReportHandshakeFailure(originSource, err) } return } @@ -345,7 +345,7 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri if err != nil { buffer.Release() if !notFirstTime { - err = N.HandshakeFailure(originSource, err) + err = N.ReportHandshakeFailure(originSource, err) } return } @@ -381,7 +381,7 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr if err != nil { buffer.Release() if !notFirstTime { - err = N.HandshakeFailure(originSource, err) + err = N.ReportHandshakeFailure(originSource, err) } return } diff --git a/common/bufio/copy_direct_posix.go b/common/bufio/copy_direct_posix.go index 63643a1..3501e66 100644 --- a/common/bufio/copy_direct_posix.go +++ b/common/bufio/copy_direct_posix.go @@ -53,7 +53,7 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour if err != nil { buffer.Release() if !notFirstTime { - err = N.HandshakeFailure(originSource, err) + err = N.ReportHandshakeFailure(originSource, err) } return } @@ -102,7 +102,7 @@ func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.Packe if err != nil { buffer.Release() if !notFirstTime { - err = N.HandshakeFailure(originSource, err) + err = N.ReportHandshakeFailure(originSource, err) } return } diff --git a/common/network/handshake.go b/common/network/handshake.go index 1d7ede2..674211d 100644 --- a/common/network/handshake.go +++ b/common/network/handshake.go @@ -5,15 +5,26 @@ import ( E "github.com/sagernet/sing/common/exceptions" ) -type HandshakeConn interface { +type HandshakeFailure interface { HandshakeFailure(err error) error } -func HandshakeFailure(conn any, err error) error { - if handshakeConn, isHandshakeConn := common.Cast[HandshakeConn](conn); isHandshakeConn { +type HandshakeSuccess interface { + HandshakeSuccess() error +} + +func ReportHandshakeFailure(conn any, err error) error { + if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](conn); isHandshakeConn { return E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error { return E.Cause(err, "write handshake failure") }) } return err } + +func ReportHandshakeSuccess(conn any) error { + if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](conn); isHandshakeConn { + return handshakeConn.HandshakeSuccess() + } + return nil +} From 1453c7c8c20d0be7c59187a3a0155c773353fe2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 12 Sep 2023 13:08:51 +0800 Subject: [PATCH 022/141] Fix quic error wrapper --- common/baderror/baderror.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/baderror/baderror.go b/common/baderror/baderror.go index 74e37a2..952dac8 100644 --- a/common/baderror/baderror.go +++ b/common/baderror/baderror.go @@ -55,7 +55,7 @@ func WrapQUIC(err error) error { if err == nil { return nil } - if Contains(err, "canceled with error code 0") { + if Contains(err, "canceled by local with error code 0") { return net.ErrClosed } return err From b1cca65a05285854e7ec4f327ca90b04988124ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 20 Sep 2023 13:49:54 +0800 Subject: [PATCH 023/141] Add memory package --- common/memory/memory.go | 16 ++++++++++++++++ common/memory/memory_darwin.go | 18 ++++++++++++++++++ common/memory/memory_stub.go | 9 +++++++++ 3 files changed, 43 insertions(+) create mode 100644 common/memory/memory.go create mode 100644 common/memory/memory_darwin.go create mode 100644 common/memory/memory_stub.go diff --git a/common/memory/memory.go b/common/memory/memory.go new file mode 100644 index 0000000..63b8d35 --- /dev/null +++ b/common/memory/memory.go @@ -0,0 +1,16 @@ +package memory + +import "runtime" + +func Total() uint64 { + if nativeAvailable { + return usageNative() + } + return Inuse() +} + +func Inuse() uint64 { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + return memStats.StackInuse + memStats.HeapInuse + memStats.HeapIdle - memStats.HeapReleased +} diff --git a/common/memory/memory_darwin.go b/common/memory/memory_darwin.go new file mode 100644 index 0000000..4fe2e50 --- /dev/null +++ b/common/memory/memory_darwin.go @@ -0,0 +1,18 @@ +package memory + +// #include +import "C" +import "unsafe" + +const nativeAvailable = true + +func usageNative() uint64 { + var memoryUsageInByte uint64 + var vmInfo C.task_vm_info_data_t + var count C.mach_msg_type_number_t = C.TASK_VM_INFO_COUNT + var kernelReturn C.kern_return_t = C.task_info(C.vm_map_t(C.mach_task_self_), C.TASK_VM_INFO, (*C.integer_t)(unsafe.Pointer(&vmInfo)), &count) + if kernelReturn == C.KERN_SUCCESS { + memoryUsageInByte = uint64(vmInfo.phys_footprint) + } + return memoryUsageInByte +} diff --git a/common/memory/memory_stub.go b/common/memory/memory_stub.go new file mode 100644 index 0000000..3781f58 --- /dev/null +++ b/common/memory/memory_stub.go @@ -0,0 +1,9 @@ +//go:build (darwin && !cgo) || !darwin + +package memory + +const nativeAvailable = false + +func usageNative() uint64 { + return 0 +} From bc044ee31d3a5031f58a70f3b8b2f0d79f64c892 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Wed, 20 Sep 2023 14:03:51 +0800 Subject: [PATCH 024/141] [dependencies] Update actions/checkout action to v4 Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- .github/workflows/debug.yml | 2 +- .github/workflows/lint.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/debug.yml b/.github/workflows/debug.yml index 93d7659..cac1570 100644 --- a/.github/workflows/debug.yml +++ b/.github/workflows/debug.yml @@ -18,7 +18,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - name: Get latest go version diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 0a1c8f2..2cb57d9 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -18,7 +18,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - name: Get latest go version From 3c4a2b06a988c24d9e782cdb087bd24f1a2ba401 Mon Sep 17 00:00:00 2001 From: stT-e5gna2z5MBS <143945532+stT-e5gna2z5MBS@users.noreply.github.com> Date: Wed, 20 Sep 2023 02:05:54 -0400 Subject: [PATCH 025/141] BindToInterfaceFunc: allow block()/AutoDetectInterfaceFunc() to return error to avoid Tun traffic loopback, AutoDetectInterfaceFunc() should return error when no valid interface is found --- common/control/bind.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/common/control/bind.go b/common/control/bind.go index 94c621c..4a79185 100644 --- a/common/control/bind.go +++ b/common/control/bind.go @@ -15,9 +15,12 @@ func BindToInterface(finder InterfaceFinder, interfaceName string, interfaceInde } } -func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int)) Func { +func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, address string) (interfaceName string, interfaceIndex int, err error)) Func { return func(network, address string, conn syscall.RawConn) error { - interfaceName, interfaceIndex := block(network, address) + interfaceName, interfaceIndex, err := block(network, address) + if err != nil { + return err + } return BindToInterface0(finder, conn, network, address, interfaceName, interfaceIndex) } } @@ -25,10 +28,10 @@ func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, addr const useInterfaceName = runtime.GOOS == "linux" || runtime.GOOS == "android" func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { - if addr := M.ParseSocksaddr(address).Addr; addr.IsValid() && N.IsVirtual(addr) { + if interfaceName == "" && interfaceIndex == -1 { return nil } - if interfaceName == "" && interfaceIndex == -1 { + if addr := M.ParseSocksaddr(address).Addr; addr.IsValid() && N.IsVirtual(addr) { return nil } if interfaceName != "" && useInterfaceName || interfaceIndex != -1 && !useInterfaceName { From 57f342a8470a5151a213edda2becee54cb292f14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 20 Sep 2023 22:12:20 +0800 Subject: [PATCH 026/141] Update dependencies --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index edf21aa..2fbe8c9 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/sagernet/sing go 1.18 -require golang.org/x/sys v0.10.0 +require golang.org/x/sys v0.12.0 diff --git a/go.sum b/go.sum index 55a3ff2..63a0250 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= -golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From 494f88c9b8bffe79818d1c4d4fd9e09251a91489 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 22 Sep 2023 00:04:40 +0800 Subject: [PATCH 027/141] using io.ReadFull in uot's ReadFrom --- common/uot/conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/uot/conn.go b/common/uot/conn.go index 81382a4..0eb4739 100644 --- a/common/uot/conn.go +++ b/common/uot/conn.go @@ -58,7 +58,7 @@ func (c *Conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { err = E.Cause(io.ErrShortBuffer, "UoT read") return } - n, err = c.Conn.Read(p[:length]) + n, err = io.ReadFull(c.Conn, p[:length]) if err == nil { addr = destination.UDPAddr() } From 5b05b5c147d9650e8accb4441e216c72a61f4859 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 25 Sep 2023 17:28:53 +0800 Subject: [PATCH 028/141] Fix socks5 handshake --- protocol/socks/handshake.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index c355697..cec1bf8 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -171,6 +171,9 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent if err != nil { return err } + if response.Status != socks5.UsernamePasswordStatusSuccess { + return E.New("socks5: authentication failed, username=", usernamePasswordAuthRequest.Username, ", password=", usernamePasswordAuthRequest.Password) + } } request, err := socks5.ReadRequest(conn) if err != nil { From e781e86e32ce290bfab79e6fd49670a5865239d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 25 Sep 2023 20:43:42 +0800 Subject: [PATCH 029/141] Reject socks4 unauthenticated request --- protocol/socks/handshake.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index cec1bf8..39f4cb7 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -110,6 +110,16 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent } switch request.Command { case socks4.CommandConnect: + if authenticator != nil && !authenticator.Verify(request.Username, "") { + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeRejectedOrFailed, + Destination: request.Destination, + }) + if err != nil { + return err + } + return E.New("socks4: authentication failed, username=", request.Username) + } err = socks4.WriteResponse(conn, socks4.Response{ ReplyCode: socks4.ReplyCodeGranted, Destination: M.SocksaddrFromNet(conn.LocalAddr()), From 63b82af61fdd162cbff5e783901f9325ca752c73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 27 Sep 2023 14:53:46 +0800 Subject: [PATCH 030/141] Panic on bad error usage --- common/exceptions/error.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/exceptions/error.go b/common/exceptions/error.go index cf5a3da..5d056e6 100644 --- a/common/exceptions/error.go +++ b/common/exceptions/error.go @@ -26,14 +26,14 @@ func New(message ...any) error { func Cause(cause error, message ...any) error { if cause == nil { - return nil + panic("cause on an nil error") } return &causeError{F.ToString(message...), cause} } func Extend(cause error, message ...any) error { if cause == nil { - return nil + panic("extend on an nil error") } return &extendedError{F.ToString(message...), cause} } From e727641a9831c330d33b5c187e5988dafaf7b38f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 30 Sep 2023 21:37:12 +0800 Subject: [PATCH 031/141] Fix concurrent access on task returnError --- common/task/task.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/common/task/task.go b/common/task/task.go index dfc8cde..5403481 100644 --- a/common/task/task.go +++ b/common/task/task.go @@ -87,16 +87,18 @@ func (g *Group) RunContextList(contextList []context.Context) error { } selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskCancelContext}, contextList...)) - if selectedContext != 0 { - returnError = E.Append(returnError, upstreamErr, func(err error) error { - return E.Cause(err, "upstream") - }) - } if g.cleanup != nil { g.cleanup() } <-taskContext.Done() + + if selectedContext != 0 { + returnError = E.Append(returnError, upstreamErr, func(err error) error { + return E.Cause(err, "upstream") + }) + } + return returnError } From e0ec961fb1abbe66165945f4f2d900754a65ce60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 1 Oct 2023 14:36:31 +0800 Subject: [PATCH 032/141] Fix HTTP server leak --- protocol/http/handshake.go | 43 +++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 455de7c..3d4b024 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -21,7 +21,6 @@ import ( type Handler = N.TCPConnectionHandler func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator auth.Authenticator, handler Handler, metadata M.Metadata) error { - var httpClient *http.Client for { request, err := ReadRequest(reader) if err != nil { @@ -95,28 +94,26 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read } var innerErr error - if httpClient == nil { - httpClient = &http.Client{ - Transport: &http.Transport{ - DisableCompression: true, - DialContext: func(context context.Context, network, address string) (net.Conn, error) { - metadata.Destination = M.ParseSocksaddr(address) - metadata.Protocol = "http" - input, output := net.Pipe() - go func() { - hErr := handler.NewConnection(ctx, output, metadata) - if hErr != nil { - innerErr = hErr - common.Close(input, output) - } - }() - return input, nil - }, + httpClient := &http.Client{ + Transport: &http.Transport{ + DisableCompression: true, + DialContext: func(context context.Context, network, address string) (net.Conn, error) { + metadata.Destination = M.ParseSocksaddr(address) + metadata.Protocol = "http" + input, output := net.Pipe() + go func() { + hErr := handler.NewConnection(ctx, output, metadata) + if hErr != nil { + innerErr = hErr + common.Close(input, output) + } + }() + return input, nil }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, } response, err := httpClient.Do(request) @@ -139,6 +136,8 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read return E.Errors(innerErr, err) } + httpClient.CloseIdleConnections() + if !keepAlive { return conn.Close() } From d16ad133622f75a5a5e22cd4b1504023980f852c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 6 Oct 2023 17:01:01 +0800 Subject: [PATCH 033/141] Update dependencies --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 2fbe8c9..e0ec5f9 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/sagernet/sing go 1.18 -require golang.org/x/sys v0.12.0 +require golang.org/x/sys v0.13.0 diff --git a/go.sum b/go.sum index 63a0250..d4673ec 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From 96a05f9afefe9f990452eb33f24b9e44f4970dfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 10 Oct 2023 15:06:28 +0800 Subject: [PATCH 034/141] Fix task cancel context --- common/task/task.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/common/task/task.go b/common/task/task.go index 5403481..b2bb7cf 100644 --- a/common/task/task.go +++ b/common/task/task.go @@ -88,6 +88,10 @@ func (g *Group) RunContextList(contextList []context.Context) error { selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskCancelContext}, contextList...)) + if selectedContext == 0 { + taskCancel(upstreamErr) + } + if g.cleanup != nil { g.cleanup() } From 49f5dfd767e1a7f55bf4ccf11d56eb2752546379 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 11 Oct 2023 12:04:11 +0800 Subject: [PATCH 035/141] Fix "Fix task cancel context" --- common/task/task.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/task/task.go b/common/task/task.go index b2bb7cf..d37d9e7 100644 --- a/common/task/task.go +++ b/common/task/task.go @@ -88,7 +88,7 @@ func (g *Group) RunContextList(contextList []context.Context) error { selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskCancelContext}, contextList...)) - if selectedContext == 0 { + if selectedContext != 0 { taskCancel(upstreamErr) } From 570295cd12f535b7a4c41f0e4f77fc704647afe5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 21 Oct 2023 12:00:00 +0800 Subject: [PATCH 036/141] Fix invalid address check in UoT conn --- common/uot/conn.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/common/uot/conn.go b/common/uot/conn.go index 0eb4739..fd7d898 100644 --- a/common/uot/conn.go +++ b/common/uot/conn.go @@ -78,7 +78,10 @@ func (c *Conn) WriteTo(p []byte, addr net.Addr) (n int, err error) { buffer := buf.NewSize(bufferLen) defer buffer.Release() if !c.isConnect { - common.Must(AddrParser.WriteAddrPort(buffer, destination)) + err = AddrParser.WriteAddrPort(buffer, destination) + if err != nil { + return + } } common.Must(binary.Write(buffer, binary.BigEndian, uint16(len(p)))) if c.writer == nil { @@ -125,7 +128,10 @@ func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { header := buf.NewSize(headerLen) defer header.Release() if !c.isConnect { - common.Must(AddrParser.WriteAddrPort(header, destination)) + err := AddrParser.WriteAddrPort(header, destination) + if err != nil { + return err + } } common.Must(binary.Write(header, binary.BigEndian, uint16(buffer.Len()))) if c.writer == nil { From 27518fdf125489a43364d80bf3c758b59fe49522 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 21 Oct 2023 12:00:00 +0800 Subject: [PATCH 037/141] Improve linux bind interface --- common/control/bind.go | 33 +++--------------------------- common/control/bind_darwin.go | 16 +++++++++++---- common/control/bind_linux.go | 37 +++++++++++++++++++++++++++++++++- common/control/bind_other.go | 2 +- common/control/bind_windows.go | 15 ++++++++++++-- 5 files changed, 65 insertions(+), 38 deletions(-) diff --git a/common/control/bind.go b/common/control/bind.go index 4a79185..b8451db 100644 --- a/common/control/bind.go +++ b/common/control/bind.go @@ -1,10 +1,9 @@ package control import ( - "os" - "runtime" "syscall" + E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) @@ -25,38 +24,12 @@ func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, addr } } -const useInterfaceName = runtime.GOOS == "linux" || runtime.GOOS == "android" - func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { if interfaceName == "" && interfaceIndex == -1 { - return nil + return E.New("interface not found: ", interfaceName) } if addr := M.ParseSocksaddr(address).Addr; addr.IsValid() && N.IsVirtual(addr) { return nil } - if interfaceName != "" && useInterfaceName || interfaceIndex != -1 && !useInterfaceName { - return bindToInterface(conn, network, address, interfaceName, interfaceIndex) - } - if finder == nil { - return os.ErrInvalid - } - var err error - if useInterfaceName { - interfaceName, err = finder.InterfaceNameByIndex(interfaceIndex) - } else { - interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) - } - if err != nil { - return err - } - if useInterfaceName { - if interfaceName == "" { - return nil - } - } else { - if interfaceIndex == -1 { - return nil - } - } - return bindToInterface(conn, network, address, interfaceName, interfaceIndex) + return bindToInterface(conn, network, address, finder, interfaceName, interfaceIndex) } diff --git a/common/control/bind_darwin.go b/common/control/bind_darwin.go index 8262ac7..f5be42d 100644 --- a/common/control/bind_darwin.go +++ b/common/control/bind_darwin.go @@ -1,16 +1,24 @@ package control import ( + "os" "syscall" "golang.org/x/sys/unix" ) -func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { - if interfaceIndex == -1 { - return nil - } +func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error { return Raw(conn, func(fd uintptr) error { + var err error + if interfaceIndex == -1 { + if finder == nil { + return os.ErrInvalid + } + interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + if err != nil { + return err + } + } switch network { case "tcp6", "udp6": return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, interfaceIndex) diff --git a/common/control/bind_linux.go b/common/control/bind_linux.go index 6ebca49..51529a0 100644 --- a/common/control/bind_linux.go +++ b/common/control/bind_linux.go @@ -1,13 +1,48 @@ package control import ( + "os" "syscall" + "github.com/sagernet/sing/common/atomic" + E "github.com/sagernet/sing/common/exceptions" + "golang.org/x/sys/unix" ) -func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { +var ifIndexDisabled atomic.Bool + +func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error { return Raw(conn, func(fd uintptr) error { + var err error + if !ifIndexDisabled.Load() { + if interfaceIndex == -1 { + if finder == nil { + return os.ErrInvalid + } + interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + if err != nil { + return err + } + } + err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex) + if err == nil { + return nil + } else if E.IsMulti(err, unix.ENOPROTOOPT, unix.EINVAL) { + ifIndexDisabled.Store(true) + } else { + return err + } + } + if interfaceName == "" { + if finder == nil { + return os.ErrInvalid + } + interfaceName, err = finder.InterfaceNameByIndex(interfaceIndex) + if err != nil { + return err + } + } return unix.BindToDevice(int(fd), interfaceName) }) } diff --git a/common/control/bind_other.go b/common/control/bind_other.go index 27d0497..539ef1c 100644 --- a/common/control/bind_other.go +++ b/common/control/bind_other.go @@ -4,6 +4,6 @@ package control import "syscall" -func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { +func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error { return nil } diff --git a/common/control/bind_windows.go b/common/control/bind_windows.go index 5e23bf1..7029c80 100644 --- a/common/control/bind_windows.go +++ b/common/control/bind_windows.go @@ -2,17 +2,28 @@ package control import ( "encoding/binary" + "os" "syscall" "unsafe" M "github.com/sagernet/sing/common/metadata" ) -func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { +func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error { return Raw(conn, func(fd uintptr) error { + var err error + if interfaceIndex == -1 { + if finder == nil { + return os.ErrInvalid + } + interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + if err != nil { + return err + } + } handle := syscall.Handle(fd) if M.ParseSocksaddr(address).AddrString() == "" { - err := bind4(handle, interfaceIndex) + err = bind4(handle, interfaceIndex) if err != nil { return err } From 8002db54c028ca0ee47d5ebf7a888f62e4a0f848 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 21 Oct 2023 12:00:00 +0800 Subject: [PATCH 038/141] Add http.parseBasicAuth func stub --- protocol/http/link.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/protocol/http/link.go b/protocol/http/link.go index 19cb6cf..554c9c9 100644 --- a/protocol/http/link.go +++ b/protocol/http/link.go @@ -12,3 +12,6 @@ func ReadRequest(b *bufio.Reader) (req *http.Request, err error) //go:linkname URLSetPath net/url.(*URL).setPath func URLSetPath(u *url.URL, p string) error + +//go:linkname ParseBasicAuth net/http.parseBasicAuth +func ParseBasicAuth(auth string) (username, password string, ok bool) From 38cdffccc56de7f35594cc83e93e65ec18798a6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 5 Nov 2023 15:35:18 +0800 Subject: [PATCH 039/141] Update dependencies --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index e0ec5f9..849732b 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/sagernet/sing go 1.18 -require golang.org/x/sys v0.13.0 +require golang.org/x/sys v0.14.0 diff --git a/go.sum b/go.sum index d4673ec..5867dc0 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= From 81c1436b69689dcce452802f05b6c41fe4a5d3d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 5 Nov 2023 15:28:49 +0800 Subject: [PATCH 040/141] Add remove for filemanager --- service/filemanager/default.go | 10 ++++++++++ service/filemanager/manager.go | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/service/filemanager/default.go b/service/filemanager/default.go index 96e7c4c..f7b984e 100644 --- a/service/filemanager/default.go +++ b/service/filemanager/default.go @@ -154,6 +154,16 @@ func (m *defaultManager) MkdirAll(path string, perm os.FileMode) error { return nil } +func (m *defaultManager) Remove(path string) error { + path = m.BasePath(path) + return os.Remove(path) +} + +func (m *defaultManager) RemoveAll(path string) error { + path = m.BasePath(path) + return os.RemoveAll(path) +} + func fixRootDirectory(p string) string { if len(p) == len(`\\?\c:`) { if os.IsPathSeparator(p[0]) && os.IsPathSeparator(p[1]) && p[2] == '?' && os.IsPathSeparator(p[3]) && p[5] == ':' { diff --git a/service/filemanager/manager.go b/service/filemanager/manager.go index f4af5b9..6367146 100644 --- a/service/filemanager/manager.go +++ b/service/filemanager/manager.go @@ -14,6 +14,8 @@ type Manager interface { CreateTemp(pattern string) (*os.File, error) Mkdir(path string, perm os.FileMode) error MkdirAll(path string, perm os.FileMode) error + Remove(path string) error + RemoveAll(path string) error } func BasePath(ctx context.Context, name string) string { @@ -64,6 +66,22 @@ func MkdirAll(ctx context.Context, path string, perm os.FileMode) error { return manager.MkdirAll(path, perm) } +func Remove(ctx context.Context, path string) error { + manager := service.FromContext[Manager](ctx) + if manager == nil { + return os.Remove(path) + } + return manager.Remove(path) +} + +func RemoveAll(ctx context.Context, path string) error { + manager := service.FromContext[Manager](ctx) + if manager == nil { + return os.RemoveAll(path) + } + return manager.RemoveAll(path) +} + func WriteFile(ctx context.Context, name string, data []byte, perm os.FileMode) error { manager := service.FromContext[Manager](ctx) if manager == nil { From d6fe25153cdefa9cd3f87f4bc7aa39af3764c4df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 8 Oct 2023 12:00:00 +0800 Subject: [PATCH 041/141] Add `NetConn() net.Conn` support for cast --- common/upstream.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/common/upstream.go b/common/upstream.go index 30070cc..0bf7a5b 100644 --- a/common/upstream.go +++ b/common/upstream.go @@ -1,9 +1,15 @@ package common +import "net" + type WithUpstream interface { Upstream() any } +type stdWithUpstreamNetConn interface { + NetConn() net.Conn +} + func Cast[T any](obj any) (T, bool) { if c, ok := obj.(T); ok { return c, true @@ -11,6 +17,9 @@ func Cast[T any](obj any) (T, bool) { if u, ok := obj.(WithUpstream); ok { return Cast[T](u.Upstream()) } + if u, ok := obj.(stdWithUpstreamNetConn); ok { + return Cast[T](u.NetConn()) + } return DefaultValue[T](), false } From 5b9d6eba3885289e2877072bc4d224e13e19add7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 17 Nov 2023 23:08:00 +0800 Subject: [PATCH 042/141] Add filemanager.chown --- service/filemanager/default.go | 7 +++++++ service/filemanager/manager.go | 9 +++++++++ 2 files changed, 16 insertions(+) diff --git a/service/filemanager/default.go b/service/filemanager/default.go index f7b984e..b77c086 100644 --- a/service/filemanager/default.go +++ b/service/filemanager/default.go @@ -93,6 +93,13 @@ func (m *defaultManager) CreateTemp(pattern string) (*os.File, error) { return file, nil } +func (m *defaultManager) Chown(path string) error { + if m.chown { + return os.Chown(path, m.userID, m.groupID) + } + return nil +} + func (m *defaultManager) Mkdir(path string, perm os.FileMode) error { path = m.BasePath(path) err := os.Mkdir(path, perm) diff --git a/service/filemanager/manager.go b/service/filemanager/manager.go index 6367146..2de9a6c 100644 --- a/service/filemanager/manager.go +++ b/service/filemanager/manager.go @@ -12,6 +12,7 @@ type Manager interface { OpenFile(name string, flag int, perm os.FileMode) (*os.File, error) Create(name string) (*os.File, error) CreateTemp(pattern string) (*os.File, error) + Chown(name string) error Mkdir(path string, perm os.FileMode) error MkdirAll(path string, perm os.FileMode) error Remove(path string) error @@ -50,6 +51,14 @@ func CreateTemp(ctx context.Context, pattern string) (*os.File, error) { return manager.CreateTemp(pattern) } +func Chown(ctx context.Context, name string) error { + manager := service.FromContext[Manager](ctx) + if manager == nil { + return nil + } + return manager.Chown(name) +} + func Mkdir(ctx context.Context, path string, perm os.FileMode) error { manager := service.FromContext[Manager](ctx) if manager == nil { From e50e7ae2d3e4437f8351d30712476f072ae1a847 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 24 Nov 2023 19:45:43 +0800 Subject: [PATCH 043/141] Fix "Fix HTTP server leak" --- protocol/http/handshake.go | 49 +++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 3d4b024..7d90e05 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -21,6 +21,7 @@ import ( type Handler = N.TCPConnectionHandler func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator auth.Authenticator, handler Handler, metadata M.Metadata) error { + var httpClient *http.Client for { request, err := ReadRequest(reader) if err != nil { @@ -94,30 +95,33 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read } var innerErr error - httpClient := &http.Client{ - Transport: &http.Transport{ - DisableCompression: true, - DialContext: func(context context.Context, network, address string) (net.Conn, error) { - metadata.Destination = M.ParseSocksaddr(address) - metadata.Protocol = "http" - input, output := net.Pipe() - go func() { - hErr := handler.NewConnection(ctx, output, metadata) - if hErr != nil { - innerErr = hErr - common.Close(input, output) - } - }() - return input, nil + if httpClient == nil { + httpClient = &http.Client{ + Transport: &http.Transport{ + DisableCompression: true, + DialContext: func(context context.Context, network, address string) (net.Conn, error) { + metadata.Destination = M.ParseSocksaddr(address) + metadata.Protocol = "http" + input, output := net.Pipe() + go func() { + hErr := handler.NewConnection(ctx, output, metadata) + if hErr != nil { + innerErr = hErr + common.Close(input, output) + } + }() + return input, nil + }, }, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } } - - response, err := httpClient.Do(request) + requestCtx, cancel := context.WithCancel(ctx) + response, err := httpClient.Do(request.WithContext(requestCtx)) if err != nil { + cancel() return E.Errors(innerErr, err, responseWith(request, http.StatusBadGateway).Write(conn)) } @@ -133,10 +137,11 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read err = response.Write(conn) if err != nil { + cancel() return E.Errors(innerErr, err) } - httpClient.CloseIdleConnections() + cancel() if !keepAlive { return conn.Close() From 0d98e82146cb6cd93530a2c0834d8e41830fdbc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 28 Oct 2023 20:59:48 +0800 Subject: [PATCH 044/141] Add unidirectional NATPacketConn --- common/bufio/nat.go | 63 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 11 deletions(-) diff --git a/common/bufio/nat.go b/common/bufio/nat.go index d652094..43e8d40 100644 --- a/common/bufio/nat.go +++ b/common/bufio/nat.go @@ -9,21 +9,62 @@ import ( N "github.com/sagernet/sing/common/network" ) -type NATPacketConn struct { +type NATPacketConn interface { N.NetPacketConn - origin M.Socksaddr - destination M.Socksaddr + UpdateDestination(destinationAddress netip.Addr) } -func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) *NATPacketConn { - return &NATPacketConn{ +func NewUnidirectionalNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn { + return &unidirectionalNATPacketConn{ NetPacketConn: conn, origin: origin, destination: destination, } } -func (c *NATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { +func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn { + return &bidirectionalNATPacketConn{ + NetPacketConn: conn, + origin: origin, + destination: destination, + } +} + +type unidirectionalNATPacketConn struct { + N.NetPacketConn + origin M.Socksaddr + destination M.Socksaddr +} + +func (c *unidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if M.SocksaddrFromNet(addr) == c.destination { + addr = c.origin.UDPAddr() + } + return c.NetPacketConn.WriteTo(p, addr) +} + +func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if destination == c.destination { + destination = c.origin + } + return c.NetPacketConn.WritePacket(buffer, destination) +} + +func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) { + c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) +} + +func (c *unidirectionalNATPacketConn) Upstream() any { + return c.NetPacketConn +} + +type bidirectionalNATPacketConn struct { + N.NetPacketConn + origin M.Socksaddr + destination M.Socksaddr +} + +func (c *bidirectionalNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { n, addr, err = c.NetPacketConn.ReadFrom(p) if err == nil && M.SocksaddrFromNet(addr) == c.origin { addr = c.destination.UDPAddr() @@ -31,14 +72,14 @@ func (c *NATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { return } -func (c *NATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { +func (c *bidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if M.SocksaddrFromNet(addr) == c.destination { addr = c.origin.UDPAddr() } return c.NetPacketConn.WriteTo(p, addr) } -func (c *NATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { +func (c *bidirectionalNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { destination, err = c.NetPacketConn.ReadPacket(buffer) if destination == c.origin { destination = c.destination @@ -46,17 +87,17 @@ func (c *NATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, return } -func (c *NATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { +func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { if destination == c.destination { destination = c.origin } return c.NetPacketConn.WritePacket(buffer, destination) } -func (c *NATPacketConn) UpdateDestination(destinationAddress netip.Addr) { +func (c *bidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) { c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) } -func (c *NATPacketConn) Upstream() any { +func (c *bidirectionalNATPacketConn) Upstream() any { return c.NetPacketConn } From 7c05b33b2d479af3be00fb471c9b85649b0291d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 8 Nov 2023 11:29:54 +0800 Subject: [PATCH 045/141] Add common.Top func --- common/upstream.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/common/upstream.go b/common/upstream.go index 0bf7a5b..9af3604 100644 --- a/common/upstream.go +++ b/common/upstream.go @@ -31,3 +31,13 @@ func MustCast[T any](obj any) T { } return value } + +func Top(obj any) any { + if u, ok := obj.(WithUpstream); ok { + return Top(u.Upstream()) + } + if u, ok := obj.(stdWithUpstreamNetConn); ok { + return Top(u.NetConn()) + } + return obj +} From 2dcabf4bfcbcd3d0522d1d24d70e6e0a94d7099f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 19 Nov 2023 11:24:32 +0800 Subject: [PATCH 046/141] Add binary.NativeEndian wrapper --- common/native_endian_big.go | 12 ++++++++++++ common/native_endian_little.go | 12 ++++++++++++ common/native_endian_std.go | 7 +++++++ 3 files changed, 31 insertions(+) create mode 100644 common/native_endian_big.go create mode 100644 common/native_endian_little.go create mode 100644 common/native_endian_std.go diff --git a/common/native_endian_big.go b/common/native_endian_big.go new file mode 100644 index 0000000..f4030d6 --- /dev/null +++ b/common/native_endian_big.go @@ -0,0 +1,12 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.21 && (armbe || arm64be || m68k || mips || mips64 || mips64p32 || ppc || ppc64 || s390 || s390x || shbe || sparc || sparc64) + +package common + +import "encoding/binary" + +// NativeEndian is the native-endian implementation of ByteOrder and AppendByteOrder. +var NativeEndian = binary.BigEndian diff --git a/common/native_endian_little.go b/common/native_endian_little.go new file mode 100644 index 0000000..a9d0b9b --- /dev/null +++ b/common/native_endian_little.go @@ -0,0 +1,12 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.21 && (386 || amd64 || amd64p32 || alpha || arm || arm64 || loong64 || mipsle || mips64le || mips64p32le || nios2 || ppc64le || riscv || riscv64 || sh || wasm) + +package common + +import "encoding/binary" + +// NativeEndian is the native-endian implementation of ByteOrder and AppendByteOrder. +var NativeEndian = binary.LittleEndian diff --git a/common/native_endian_std.go b/common/native_endian_std.go new file mode 100644 index 0000000..071da51 --- /dev/null +++ b/common/native_endian_std.go @@ -0,0 +1,7 @@ +//go:build go1.21 + +package common + +import "encoding/binary" + +var NativeEndian = binary.NativeEndian From bca74039ead55337b4656ca11e98948cb871273f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 1 Dec 2023 12:21:23 +0800 Subject: [PATCH 047/141] Fix deadline reader --- common/bufio/deadline/conn.go | 8 +-- common/bufio/deadline/packet_conn.go | 8 +-- common/bufio/deadline/packet_reader.go | 31 +++----- .../bufio/deadline/packet_reader_fallback.go | 41 +++++++---- common/bufio/deadline/reader.go | 30 +++----- common/bufio/deadline/reader_fallback.go | 41 +++++++---- common/bufio/deadline/serial.go | 71 +++++++++++++++++++ 7 files changed, 153 insertions(+), 77 deletions(-) create mode 100644 common/bufio/deadline/serial.go diff --git a/common/bufio/deadline/conn.go b/common/bufio/deadline/conn.go index 7ad1a9e..484d297 100644 --- a/common/bufio/deadline/conn.go +++ b/common/bufio/deadline/conn.go @@ -14,18 +14,18 @@ type Conn struct { reader Reader } -func NewConn(conn net.Conn) *Conn { +func NewConn(conn net.Conn) N.ExtendedConn { if deadlineConn, isDeadline := conn.(*Conn); isDeadline { return deadlineConn } - return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)} + return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)}) } -func NewFallbackConn(conn net.Conn) *Conn { +func NewFallbackConn(conn net.Conn) N.ExtendedConn { if deadlineConn, isDeadline := conn.(*Conn); isDeadline { return deadlineConn } - return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)} + return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)}) } func (c *Conn) Read(p []byte) (n int, err error) { diff --git a/common/bufio/deadline/packet_conn.go b/common/bufio/deadline/packet_conn.go index 7c92845..a0e9808 100644 --- a/common/bufio/deadline/packet_conn.go +++ b/common/bufio/deadline/packet_conn.go @@ -14,18 +14,18 @@ type PacketConn struct { reader PacketReader } -func NewPacketConn(conn N.NetPacketConn) *PacketConn { +func NewPacketConn(conn N.NetPacketConn) N.NetPacketConn { if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline { return deadlineConn } - return &PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)} + return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)}) } -func NewFallbackPacketConn(conn N.NetPacketConn) *PacketConn { +func NewFallbackPacketConn(conn N.NetPacketConn) N.NetPacketConn { if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline { return deadlineConn } - return &PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)} + return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)}) } func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { diff --git a/common/bufio/deadline/packet_reader.go b/common/bufio/deadline/packet_reader.go index 36b4e87..088a811 100644 --- a/common/bufio/deadline/packet_reader.go +++ b/common/bufio/deadline/packet_reader.go @@ -52,14 +52,13 @@ func (r *packetReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) { default: } select { + case result := <-r.result: + return r.pipeReturnFrom(result, p) + case <-r.pipeDeadline.wait(): + return 0, nil, os.ErrDeadlineExceeded case <-r.done: go r.pipeReadFrom(len(p)) - default: } - return r.readFrom(p) -} - -func (r *packetReader) readFrom(p []byte) (n int, addr net.Addr, err error) { select { case result := <-r.result: return r.pipeReturnFrom(result, p) @@ -106,14 +105,13 @@ func (r *packetReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, default: } select { + case result := <-r.result: + return r.pipeReturnFromBuffer(result, buffer) + case <-r.pipeDeadline.wait(): + return M.Socksaddr{}, os.ErrDeadlineExceeded case <-r.done: - go r.pipeReadFromBuffer(buffer.FreeLen()) - default: + go r.pipeReadFrom(buffer.FreeLen()) } - return r.readPacket(buffer) -} - -func (r *packetReader) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { select { case result := <-r.result: return r.pipeReturnFromBuffer(result, buffer) @@ -134,17 +132,6 @@ func (r *packetReader) pipeReturnFromBuffer(result *packetReadResult, buffer *bu } } -func (r *packetReader) pipeReadFromBuffer(pLen int) { - buffer := buf.NewSize(pLen) - destination, err := r.TimeoutPacketReader.ReadPacket(buffer) - r.result <- &packetReadResult{ - buffer: buffer, - destination: destination, - err: err, - } - r.done <- struct{}{} -} - func (r *packetReader) SetReadDeadline(t time.Time) error { r.deadline.Store(t) r.pipeDeadline.set(t) diff --git a/common/bufio/deadline/packet_reader_fallback.go b/common/bufio/deadline/packet_reader_fallback.go index 276b784..c20f568 100644 --- a/common/bufio/deadline/packet_reader_fallback.go +++ b/common/bufio/deadline/packet_reader_fallback.go @@ -2,6 +2,7 @@ package deadline import ( "net" + "os" "time" "github.com/sagernet/sing/common/atomic" @@ -25,12 +26,15 @@ func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err err return r.pipeReturnFrom(result, p) default: } - if r.disablePipe.Load() { - return r.TimeoutPacketReader.ReadFrom(p) - } select { + case result := <-r.result: + return r.pipeReturnFrom(result, p) + case <-r.pipeDeadline.wait(): + return 0, nil, os.ErrDeadlineExceeded case <-r.done: - if r.deadline.Load().IsZero() { + if r.disablePipe.Load() { + return r.TimeoutPacketReader.ReadFrom(p) + } else if r.deadline.Load().IsZero() { r.done <- struct{}{} r.inRead.Store(true) defer r.inRead.Store(false) @@ -38,9 +42,13 @@ func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err err return } go r.pipeReadFrom(len(p)) - default: } - return r.readFrom(p) + select { + case result := <-r.result: + return r.pipeReturnFrom(result, p) + case <-r.pipeDeadline.wait(): + return 0, nil, os.ErrDeadlineExceeded + } } func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { @@ -49,22 +57,29 @@ func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Soc return r.pipeReturnFromBuffer(result, buffer) default: } - if r.disablePipe.Load() { - return r.TimeoutPacketReader.ReadPacket(buffer) - } select { + case result := <-r.result: + return r.pipeReturnFromBuffer(result, buffer) + case <-r.pipeDeadline.wait(): + return M.Socksaddr{}, os.ErrDeadlineExceeded case <-r.done: - if r.deadline.Load().IsZero() { + if r.disablePipe.Load() { + return r.TimeoutPacketReader.ReadPacket(buffer) + } else if r.deadline.Load().IsZero() { r.done <- struct{}{} r.inRead.Store(true) defer r.inRead.Store(false) destination, err = r.TimeoutPacketReader.ReadPacket(buffer) return } - go r.pipeReadFromBuffer(buffer.FreeLen()) - default: + go r.pipeReadFrom(buffer.FreeLen()) + } + select { + case result := <-r.result: + return r.pipeReturnFromBuffer(result, buffer) + case <-r.pipeDeadline.wait(): + return M.Socksaddr{}, os.ErrDeadlineExceeded } - return r.readPacket(buffer) } func (r *fallbackPacketReader) SetReadDeadline(t time.Time) error { diff --git a/common/bufio/deadline/reader.go b/common/bufio/deadline/reader.go index b6d3c7d..a7a6252 100644 --- a/common/bufio/deadline/reader.go +++ b/common/bufio/deadline/reader.go @@ -54,14 +54,13 @@ func (r *reader) Read(p []byte) (n int, err error) { default: } select { + case result := <-r.result: + return r.pipeReturn(result, p) + case <-r.pipeDeadline.wait(): + return 0, os.ErrDeadlineExceeded case <-r.done: go r.pipeRead(len(p)) - default: } - return r.read(p) -} - -func (r *reader) read(p []byte) (n int, err error) { select { case result := <-r.result: return r.pipeReturn(result, p) @@ -99,14 +98,13 @@ func (r *reader) ReadBuffer(buffer *buf.Buffer) error { default: } select { + case result := <-r.result: + return r.pipeReturnBuffer(result, buffer) + case <-r.pipeDeadline.wait(): + return os.ErrDeadlineExceeded case <-r.done: - go r.pipeReadBuffer(buffer.FreeLen()) - default: + go r.pipeRead(buffer.FreeLen()) } - return r.readBuffer(buffer) -} - -func (r *reader) readBuffer(buffer *buf.Buffer) error { select { case result := <-r.result: return r.pipeReturnBuffer(result, buffer) @@ -127,16 +125,6 @@ func (r *reader) pipeReturnBuffer(result *readResult, buffer *buf.Buffer) error } } -func (r *reader) pipeReadBuffer(pLen int) { - cacheBuffer := buf.NewSize(pLen) - err := r.ExtendedReader.ReadBuffer(cacheBuffer) - r.result <- &readResult{ - buffer: cacheBuffer, - err: err, - } - r.done <- struct{}{} -} - func (r *reader) SetReadDeadline(t time.Time) error { r.deadline.Store(t) r.pipeDeadline.set(t) diff --git a/common/bufio/deadline/reader_fallback.go b/common/bufio/deadline/reader_fallback.go index 182ab40..a28b315 100644 --- a/common/bufio/deadline/reader_fallback.go +++ b/common/bufio/deadline/reader_fallback.go @@ -1,6 +1,7 @@ package deadline import ( + "os" "time" "github.com/sagernet/sing/common/atomic" @@ -23,12 +24,15 @@ func (r *fallbackReader) Read(p []byte) (n int, err error) { return r.pipeReturn(result, p) default: } - if r.disablePipe.Load() { - return r.ExtendedReader.Read(p) - } select { + case result := <-r.result: + return r.pipeReturn(result, p) + case <-r.pipeDeadline.wait(): + return 0, os.ErrDeadlineExceeded case <-r.done: - if r.deadline.Load().IsZero() { + if r.disablePipe.Load() { + return r.ExtendedReader.Read(p) + } else if r.deadline.Load().IsZero() { r.done <- struct{}{} r.inRead.Store(true) defer r.inRead.Store(false) @@ -36,9 +40,13 @@ func (r *fallbackReader) Read(p []byte) (n int, err error) { return } go r.pipeRead(len(p)) - default: } - return r.reader.read(p) + select { + case result := <-r.result: + return r.pipeReturn(result, p) + case <-r.pipeDeadline.wait(): + return 0, os.ErrDeadlineExceeded + } } func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error { @@ -47,21 +55,28 @@ func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error { return r.pipeReturnBuffer(result, buffer) default: } - if r.disablePipe.Load() { - return r.ExtendedReader.ReadBuffer(buffer) - } select { + case result := <-r.result: + return r.pipeReturnBuffer(result, buffer) + case <-r.pipeDeadline.wait(): + return os.ErrDeadlineExceeded case <-r.done: - if r.deadline.Load().IsZero() { + if r.disablePipe.Load() { + return r.ExtendedReader.ReadBuffer(buffer) + } else if r.deadline.Load().IsZero() { r.done <- struct{}{} r.inRead.Store(true) defer r.inRead.Store(false) return r.ExtendedReader.ReadBuffer(buffer) } - go r.pipeReadBuffer(buffer.FreeLen()) - default: + go r.pipeRead(buffer.FreeLen()) + } + select { + case result := <-r.result: + return r.pipeReturnBuffer(result, buffer) + case <-r.pipeDeadline.wait(): + return os.ErrDeadlineExceeded } - return r.readBuffer(buffer) } func (r *fallbackReader) SetReadDeadline(t time.Time) error { diff --git a/common/bufio/deadline/serial.go b/common/bufio/deadline/serial.go new file mode 100644 index 0000000..64b3250 --- /dev/null +++ b/common/bufio/deadline/serial.go @@ -0,0 +1,71 @@ +package deadline + +import ( + "net" + "sync" + + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/debug" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type SerialConn struct { + N.ExtendedConn + access sync.Mutex +} + +func NewSerialConn(conn N.ExtendedConn) N.ExtendedConn { + if !debug.Enabled { + return conn + } + return &SerialConn{ExtendedConn: conn} +} + +func (c *SerialConn) Read(p []byte) (n int, err error) { + if !c.access.TryLock() { + panic("concurrent read on deadline conn") + } + defer c.access.Unlock() + return c.ExtendedConn.Read(p) +} + +func (c *SerialConn) ReadBuffer(buffer *buf.Buffer) error { + if !c.access.TryLock() { + panic("concurrent read on deadline conn") + } + defer c.access.Unlock() + return c.ExtendedConn.ReadBuffer(buffer) +} + +type SerialPacketConn struct { + N.NetPacketConn + access sync.Mutex +} + +func NewSerialPacketConn(conn N.NetPacketConn) N.NetPacketConn { + if !debug.Enabled { + return conn + } + return &SerialPacketConn{NetPacketConn: conn} +} + +func (c *SerialPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if !c.access.TryLock() { + panic("concurrent read on deadline conn") + } + defer c.access.Unlock() + return c.NetPacketConn.ReadFrom(p) +} + +func (c *SerialPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + if !c.access.TryLock() { + panic("concurrent read on deadline conn") + } + defer c.access.Unlock() + return c.NetPacketConn.ReadPacket(buffer) +} + +func (c *SerialPacketConn) Upstream() any { + return c.NetPacketConn +} From 0ba5576c7be82790bbfb6fac24afaaebb89b6a4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 3 Dec 2023 16:52:53 +0800 Subject: [PATCH 048/141] Fix not set Host header for HTTP proxy client --- protocol/http/client.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/protocol/http/client.go b/protocol/http/client.go index 25351fd..013da4a 100644 --- a/protocol/http/client.go +++ b/protocol/http/client.go @@ -81,6 +81,10 @@ func (c *Client) DialContext(ctx context.Context, network string, destination M. } } for key, valueList := range c.headers { + if key == "Host" { + request.Host = valueList[0] + continue + } request.Header.Set(key, valueList[0]) for _, value := range valueList[1:] { request.Header.Add(key, value) From 1ee2a5bd0ef9b79b255ffc1f6355d1425e115028 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 6 Dec 2023 18:28:06 +0800 Subject: [PATCH 049/141] Update dependencies --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 849732b..25deb68 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/sagernet/sing go 1.18 -require golang.org/x/sys v0.14.0 +require golang.org/x/sys v0.15.0 diff --git a/go.sum b/go.sum index 5867dc0..063d2d3 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= -golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= From 6b69046063f3629f2d20c5795a3e5ff3cecc44df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 6 Dec 2023 18:18:03 +0800 Subject: [PATCH 050/141] Fix fallback packet conn --- common/bufio/fallback.go | 60 ++++++++++++++++++++++++++++++++++++---- common/bufio/io.go | 16 ++--------- 2 files changed, 56 insertions(+), 20 deletions(-) diff --git a/common/bufio/fallback.go b/common/bufio/fallback.go index 4ea87cf..bd4ab46 100644 --- a/common/bufio/fallback.go +++ b/common/bufio/fallback.go @@ -3,6 +3,7 @@ package bufio import ( "net" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -12,13 +13,17 @@ var _ N.NetPacketConn = (*FallbackPacketConn)(nil) type FallbackPacketConn struct { N.PacketConn + writer N.NetPacketWriter } func NewNetPacketConn(conn N.PacketConn) N.NetPacketConn { if packetConn, loaded := conn.(N.NetPacketConn); loaded { return packetConn } - return &FallbackPacketConn{PacketConn: conn} + return &FallbackPacketConn{ + PacketConn: conn, + writer: NewNetPacketWriter(conn), + } } func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { @@ -36,11 +41,7 @@ func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error } func (c *FallbackPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - err = c.WritePacket(buf.As(p), M.SocksaddrFromNet(addr)) - if err == nil { - n = len(p) - } - return + return c.writer.WriteTo(p, addr) } func (c *FallbackPacketConn) ReaderReplaceable() bool { @@ -54,3 +55,50 @@ func (c *FallbackPacketConn) WriterReplaceable() bool { func (c *FallbackPacketConn) Upstream() any { return c.PacketConn } + +func (c *FallbackPacketConn) UpstreamWriter() any { + return c.writer +} + +var _ N.NetPacketWriter = (*FallbackPacketWriter)(nil) + +type FallbackPacketWriter struct { + N.PacketWriter + frontHeadroom int + rearHeadroom int +} + +func NewNetPacketWriter(writer N.PacketWriter) N.NetPacketWriter { + if packetWriter, loaded := writer.(N.NetPacketWriter); loaded { + return packetWriter + } + return &FallbackPacketWriter{ + PacketWriter: writer, + frontHeadroom: N.CalculateFrontHeadroom(writer), + rearHeadroom: N.CalculateRearHeadroom(writer), + } +} + +func (c *FallbackPacketWriter) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.frontHeadroom > 0 || c.rearHeadroom > 0 { + buffer := buf.NewSize(len(p) + c.frontHeadroom + c.rearHeadroom) + buffer.Resize(c.frontHeadroom, 0) + common.Must1(buffer.Write(p)) + err = c.PacketWriter.WritePacket(buffer, M.SocksaddrFromNet(addr)) + } else { + err = c.PacketWriter.WritePacket(buf.As(p), M.SocksaddrFromNet(addr)) + } + if err != nil { + return + } + n = len(p) + return +} + +func (c *FallbackPacketWriter) WriterReplaceable() bool { + return true +} + +func (c *FallbackPacketWriter) Upstream() any { + return c.PacketWriter +} diff --git a/common/bufio/io.go b/common/bufio/io.go index 1e5d89b..a25a7cc 100644 --- a/common/bufio/io.go +++ b/common/bufio/io.go @@ -37,13 +37,7 @@ func WriteBuffer(writer N.ExtendedWriter, buffer *buf.Buffer) (n int, err error) frontHeadroom := N.CalculateFrontHeadroom(writer) rearHeadroom := N.CalculateRearHeadroom(writer) if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() { - bufferSize := N.CalculateMTU(nil, writer) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.BufferSize - } - newBuffer := buf.NewSize(bufferSize) + newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom) newBuffer.Resize(frontHeadroom, 0) common.Must1(newBuffer.Write(buffer.Bytes())) buffer.Release() @@ -69,13 +63,7 @@ func WritePacketBuffer(writer N.PacketWriter, buffer *buf.Buffer, destination M. frontHeadroom := N.CalculateFrontHeadroom(writer) rearHeadroom := N.CalculateRearHeadroom(writer) if frontHeadroom > buffer.Start() || rearHeadroom > buffer.FreeLen() { - bufferSize := N.CalculateMTU(nil, writer) - if bufferSize > 0 { - bufferSize += frontHeadroom + rearHeadroom - } else { - bufferSize = buf.BufferSize - } - newBuffer := buf.NewSize(bufferSize) + newBuffer := buf.NewSize(buffer.Len() + frontHeadroom + rearHeadroom) newBuffer.Resize(frontHeadroom, 0) common.Must1(newBuffer.Write(buffer.Bytes())) buffer.Release() From 01c915e1e4389180c81c42fca04b2901dbb74a36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 8 Dec 2023 11:01:19 +0800 Subject: [PATCH 051/141] Fix "Fix not set Host header for HTTP proxy client" --- protocol/http/client.go | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/protocol/http/client.go b/protocol/http/client.go index 013da4a..269a288 100644 --- a/protocol/http/client.go +++ b/protocol/http/client.go @@ -67,13 +67,24 @@ func (c *Client) DialContext(ctx context.Context, network string, destination M. } request := &http.Request{ Method: http.MethodConnect, - URL: &url.URL{ - Host: destination.String(), - }, Header: http.Header{ "Proxy-Connection": []string{"Keep-Alive"}, }, } + var host string + if c.headers != nil { + host = c.headers.Get("Host") + c.headers.Del("Host") + } + if host != "" && host != destination.Fqdn { + if c.path != "" { + return nil, E.New("Host header and path are not allowed at the same time") + } + request.Host = host + request.URL = &url.URL{Opaque: destination.String()} + } else { + request.URL = &url.URL{Host: destination.String()} + } if c.path != "" { err = URLSetPath(request.URL, c.path) if err != nil { @@ -81,10 +92,6 @@ func (c *Client) DialContext(ctx context.Context, network string, destination M. } } for key, valueList := range c.headers { - if key == "Host" { - request.Host = valueList[0] - continue - } request.Header.Set(key, valueList[0]) for _, value := range valueList[1:] { request.Header.Add(key, value) From 0d701cfff0dfea4756d12107a9ac2d31a8a9e182 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 8 Dec 2023 14:51:14 +0800 Subject: [PATCH 052/141] Fix buffer WriteZeroN --- common/buf/buffer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/buf/buffer.go b/common/buf/buffer.go index c374147..fe8e0cb 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -230,7 +230,7 @@ func (b *Buffer) WriteZeroN(n int) error { if b.end+n > b.Cap() { return io.ErrShortBuffer } - for i := b.end; i <= b.end+n; i++ { + for i := b.end; i < b.end+n; i++ { b.data[i] = 0 } b.end += n From 5a3d0edd1cbe7b955933b273473bd00ec9311213 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 8 Dec 2023 14:54:42 +0800 Subject: [PATCH 053/141] Update quic bad error list --- common/baderror/baderror.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/common/baderror/baderror.go b/common/baderror/baderror.go index 952dac8..c5ab530 100644 --- a/common/baderror/baderror.go +++ b/common/baderror/baderror.go @@ -55,7 +55,10 @@ func WrapQUIC(err error) error { if err == nil { return nil } - if Contains(err, "canceled by local with error code 0") { + if Contains(err, + "canceled by remote with error code 0", + "canceled by local with error code 0", + ) { return net.ErrClosed } return err From 544863e3f490ba37426808523740160301776762 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 21 Dec 2023 10:29:16 +0800 Subject: [PATCH 054/141] Try to fix HTTP server leak again --- protocol/http/handshake.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 7d90e05..d179ff1 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -99,7 +99,7 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read httpClient = &http.Client{ Transport: &http.Transport{ DisableCompression: true, - DialContext: func(context context.Context, network, address string) (net.Conn, error) { + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { metadata.Destination = M.ParseSocksaddr(address) metadata.Protocol = "http" input, output := net.Pipe() From 349d7d31b3c5c1858151092853ce964ccd5127f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 24 Dec 2023 08:01:43 +0800 Subject: [PATCH 055/141] Fix calculate host for HTTP connect client --- protocol/http/client.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/protocol/http/client.go b/protocol/http/client.go index 269a288..727d4d9 100644 --- a/protocol/http/client.go +++ b/protocol/http/client.go @@ -23,6 +23,7 @@ type Client struct { serverAddr M.Socksaddr username string password string + host string path string headers http.Header } @@ -48,6 +49,12 @@ func NewClient(options Options) *Client { if options.Dialer == nil { client.dialer = N.SystemDialer } + var host string + if client.headers != nil { + host = client.headers.Get("Host") + client.headers.Del("Host") + client.host = host + } return client } @@ -71,16 +78,11 @@ func (c *Client) DialContext(ctx context.Context, network string, destination M. "Proxy-Connection": []string{"Keep-Alive"}, }, } - var host string - if c.headers != nil { - host = c.headers.Get("Host") - c.headers.Del("Host") - } - if host != "" && host != destination.Fqdn { + if c.host != "" && c.host != destination.Fqdn { if c.path != "" { return nil, E.New("Host header and path are not allowed at the same time") } - request.Host = host + request.Host = c.host request.URL = &url.URL{Opaque: destination.String()} } else { request.URL = &url.URL{Host: destination.String()} From 028dcd722c1013264b9d440e036f27162663a845 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 29 Nov 2023 11:53:09 +0800 Subject: [PATCH 056/141] Add serialize support for domain matcher --- common/domain/matcher.go | 80 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 2 deletions(-) diff --git a/common/domain/matcher.go b/common/domain/matcher.go index 95dc0dc..258f3f9 100644 --- a/common/domain/matcher.go +++ b/common/domain/matcher.go @@ -1,8 +1,12 @@ package domain import ( + "encoding/binary" + "io" "sort" "unicode/utf8" + + "github.com/sagernet/sing/common/rw" ) type Matcher struct { @@ -27,15 +31,87 @@ func NewMatcher(domains []string, domainSuffix []string) *Matcher { domainList = append(domainList, reverseDomain(domain)) } sort.Strings(domainList) - return &Matcher{ - newSuccinctSet(domainList), + return &Matcher{newSuccinctSet(domainList)} +} + +func ReadMatcher(reader io.Reader) (*Matcher, error) { + var version uint8 + err := binary.Read(reader, binary.BigEndian, &version) + if err != nil { + return nil, err } + leavesLength, err := rw.ReadUVariant(reader) + if err != nil { + return nil, err + } + leaves := make([]uint64, leavesLength) + err = binary.Read(reader, binary.BigEndian, leaves) + if err != nil { + return nil, err + } + labelBitmapLength, err := rw.ReadUVariant(reader) + if err != nil { + return nil, err + } + labelBitmap := make([]uint64, labelBitmapLength) + err = binary.Read(reader, binary.BigEndian, labelBitmap) + if err != nil { + return nil, err + } + labelsLength, err := rw.ReadUVariant(reader) + if err != nil { + return nil, err + } + labels := make([]byte, labelsLength) + _, err = io.ReadFull(reader, labels) + if err != nil { + return nil, err + } + set := &succinctSet{ + leaves: leaves, + labelBitmap: labelBitmap, + labels: labels, + } + set.init() + return &Matcher{set}, nil } func (m *Matcher) Match(domain string) bool { return m.set.Has(reverseDomain(domain)) } +func (m *Matcher) Write(writer io.Writer) error { + err := binary.Write(writer, binary.BigEndian, byte(1)) + if err != nil { + return err + } + err = rw.WriteUVariant(writer, uint64(len(m.set.leaves))) + if err != nil { + return err + } + err = binary.Write(writer, binary.BigEndian, m.set.leaves) + if err != nil { + return err + } + err = rw.WriteUVariant(writer, uint64(len(m.set.labelBitmap))) + if err != nil { + return err + } + err = binary.Write(writer, binary.BigEndian, m.set.labelBitmap) + if err != nil { + return err + } + err = rw.WriteUVariant(writer, uint64(len(m.set.labels))) + if err != nil { + return err + } + _, err = writer.Write(m.set.labels) + if err != nil { + return err + } + return nil +} + func reverseDomain(domain string) string { l := len(domain) b := make([]byte, l) From 99d07d6e5a9638ad9b841cb408269553f7179400 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 30 Nov 2023 16:08:57 +0800 Subject: [PATCH 057/141] Add concurrency limit for task --- common/task/task.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/common/task/task.go b/common/task/task.go index d37d9e7..6b266d2 100644 --- a/common/task/task.go +++ b/common/task/task.go @@ -23,6 +23,7 @@ type Group struct { tasks []taskItem cleanup func() fastFail bool + queue chan struct{} } func (g *Group) Append(name string, f func(ctx context.Context) error) { @@ -46,6 +47,13 @@ func (g *Group) FastFail() { g.fastFail = true } +func (g *Group) Concurrency(n int) { + g.queue = make(chan struct{}, n) + for i := 0; i < n; i++ { + g.queue <- struct{}{} + } +} + func (g *Group) Run(contextList ...context.Context) error { return g.RunContextList(contextList) } @@ -65,6 +73,14 @@ func (g *Group) RunContextList(contextList []context.Context) error { for _, task := range g.tasks { currentTask := task go func() { + if g.queue != nil { + <-g.queue + select { + case <-taskCancelContext.Done(): + return + default: + } + } err := currentTask.Run(taskCancelContext) errorAccess.Lock() if err != nil { @@ -83,6 +99,9 @@ func (g *Group) RunContextList(contextList []context.Context) error { taskCancel(errTaskSucceed{}) taskFinish(errTaskSucceed{}) } + if g.queue != nil { + g.queue <- struct{}{} + } }() } From d7ce998e7ea6f7f59f542ebdadba7100406057a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 1 Dec 2023 12:51:05 +0800 Subject: [PATCH 058/141] Remove legacy buffer header --- common/buf/buffer.go | 16 +++++----------- common/bufio/buffer.go | 2 +- common/bufio/chunk.go | 6 +++--- common/uot/server.go | 4 ++-- 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/common/buf/buffer.go b/common/buf/buffer.go index fe8e0cb..d0f4ba4 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -11,8 +11,6 @@ import ( E "github.com/sagernet/sing/common/exceptions" ) -const ReversedHeader = 1024 - type Buffer struct { data []byte start int @@ -25,8 +23,6 @@ type Buffer struct { func New() *Buffer { return &Buffer{ data: Get(BufferSize), - start: ReversedHeader, - end: ReversedHeader, managed: true, } } @@ -34,8 +30,6 @@ func New() *Buffer { func NewPacket() *Buffer { return &Buffer{ data: Get(UDPBufferSize), - start: ReversedHeader, - end: ReversedHeader, managed: true, } } @@ -277,15 +271,15 @@ func (b *Buffer) Resize(start, end int) { } func (b *Buffer) Reset() { - b.start = ReversedHeader - b.end = ReversedHeader -} - -func (b *Buffer) FullReset() { b.start = 0 b.end = 0 } +// Deprecated: use Reset instead. +func (b *Buffer) FullReset() { + b.Reset() +} + func (b *Buffer) IncRef() { atomic.AddInt32(&b.refs, 1) } diff --git a/common/bufio/buffer.go b/common/bufio/buffer.go index 47c35a9..cdd2896 100644 --- a/common/bufio/buffer.go +++ b/common/bufio/buffer.go @@ -37,7 +37,7 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) { if err != nil { return } - w.buffer.FullReset() + w.buffer.Reset() } } diff --git a/common/bufio/chunk.go b/common/bufio/chunk.go index 56a733a..cd11f63 100644 --- a/common/bufio/chunk.go +++ b/common/bufio/chunk.go @@ -30,7 +30,7 @@ func (c *ChunkReader) ReadBuffer(buffer *buf.Buffer) error { } else if !c.cache.IsEmpty() { return common.Error(buffer.ReadFrom(c.cache)) } - c.cache.FullReset() + c.cache.Reset() err := c.upstream.ReadBuffer(c.cache) if err != nil { c.cache.Release() @@ -46,7 +46,7 @@ func (c *ChunkReader) Read(p []byte) (n int, err error) { } else if !c.cache.IsEmpty() { return c.cache.Read(p) } - c.cache.FullReset() + c.cache.Reset() err = c.upstream.ReadBuffer(c.cache) if err != nil { c.cache.Release() @@ -70,7 +70,7 @@ func (c *ChunkReader) ReadChunk() (*buf.Buffer, error) { } else if !c.cache.IsEmpty() { return c.cache, nil } - c.cache.FullReset() + c.cache.Reset() err := c.upstream.ReadBuffer(c.cache) if err != nil { c.cache.Release() diff --git a/common/uot/server.go b/common/uot/server.go index 78cfa6d..57dfc38 100644 --- a/common/uot/server.go +++ b/common/uot/server.go @@ -77,7 +77,7 @@ func (c *ServerConn) loopInput() { if err != nil { break } - buffer.FullReset() + buffer.Reset() _, err = buffer.ReadFullFrom(c.inputReader, int(length)) if err != nil { break @@ -95,7 +95,7 @@ func (c *ServerConn) loopOutput() { buffer := buf.NewPacket() defer buffer.Release() for { - buffer.FullReset() + buffer.Reset() n, addr, err := buffer.ReadPacketFrom(c) if err != nil { break From f23499eaea74cdc76686a6b6b2401de043e6138e Mon Sep 17 00:00:00 2001 From: H1JK Date: Sat, 18 Nov 2023 13:22:35 +0800 Subject: [PATCH 059/141] Pool allocate arrays instead of slices This is inspired by https://go-review.googlesource.com/c/net/+/539915 --- common/buf/alloc.go | 120 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 104 insertions(+), 16 deletions(-) diff --git a/common/buf/alloc.go b/common/buf/alloc.go index 5d0b248..69048c4 100644 --- a/common/buf/alloc.go +++ b/common/buf/alloc.go @@ -8,7 +8,7 @@ import ( "sync" ) -var DefaultAllocator = newDefaultAllocer() +var DefaultAllocator = newDefaultAllocator() type Allocator interface { Get(size int) []byte @@ -17,22 +17,34 @@ type Allocator interface { // defaultAllocator for incoming frames, optimized to prevent overwriting after zeroing type defaultAllocator struct { - buffers []sync.Pool + buffers [17]sync.Pool } // NewAllocator initiates a []byte allocator for frames less than 65536 bytes, // the waste(memory fragmentation) of space allocation is guaranteed to be // no more than 50%. -func newDefaultAllocer() Allocator { - alloc := new(defaultAllocator) - alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K - for k := range alloc.buffers { - i := k - alloc.buffers[k].New = func() any { - return make([]byte, 1< 64K + {New: func() any { return new([1]byte) }}, + {New: func() any { return new([1 << 1]byte) }}, + {New: func() any { return new([1 << 2]byte) }}, + {New: func() any { return new([1 << 3]byte) }}, + {New: func() any { return new([1 << 4]byte) }}, + {New: func() any { return new([1 << 5]byte) }}, + {New: func() any { return new([1 << 6]byte) }}, + {New: func() any { return new([1 << 7]byte) }}, + {New: func() any { return new([1 << 8]byte) }}, + {New: func() any { return new([1 << 9]byte) }}, + {New: func() any { return new([1 << 10]byte) }}, + {New: func() any { return new([1 << 11]byte) }}, + {New: func() any { return new([1 << 12]byte) }}, + {New: func() any { return new([1 << 13]byte) }}, + {New: func() any { return new([1 << 14]byte) }}, + {New: func() any { return new([1 << 15]byte) }}, + {New: func() any { return new([1 << 16]byte) }}, + }, } - return alloc } // Get a []byte from pool with most appropriate cap @@ -41,12 +53,50 @@ func (alloc *defaultAllocator) Get(size int) []byte { return nil } - bits := msb(size) - if size == 1< 65536 || cap(buf) != 1< Date: Sat, 18 Nov 2023 21:37:58 +0800 Subject: [PATCH 060/141] Shrink buf pool range --- common/buf/alloc.go | 85 ++++++++++++++++----------------------------- 1 file changed, 30 insertions(+), 55 deletions(-) diff --git a/common/buf/alloc.go b/common/buf/alloc.go index 69048c4..b556d93 100644 --- a/common/buf/alloc.go +++ b/common/buf/alloc.go @@ -17,7 +17,7 @@ type Allocator interface { // defaultAllocator for incoming frames, optimized to prevent overwriting after zeroing type defaultAllocator struct { - buffers [17]sync.Pool + buffers [11]sync.Pool } // NewAllocator initiates a []byte allocator for frames less than 65536 bytes, @@ -25,13 +25,7 @@ type defaultAllocator struct { // no more than 50%. func newDefaultAllocator() Allocator { return &defaultAllocator{ - buffers: [...]sync.Pool{ // 1B -> 64K - {New: func() any { return new([1]byte) }}, - {New: func() any { return new([1 << 1]byte) }}, - {New: func() any { return new([1 << 2]byte) }}, - {New: func() any { return new([1 << 3]byte) }}, - {New: func() any { return new([1 << 4]byte) }}, - {New: func() any { return new([1 << 5]byte) }}, + buffers: [...]sync.Pool{ // 64B -> 64K {New: func() any { return new([1 << 6]byte) }}, {New: func() any { return new([1 << 7]byte) }}, {New: func() any { return new([1 << 8]byte) }}, @@ -53,46 +47,38 @@ func (alloc *defaultAllocator) Get(size int) []byte { return nil } - index := msb(size) - if size != 1< 64 { + index = msb(size) + if size != 1< 65536 || cap(buf) != 1< Date: Tue, 5 Dec 2023 14:18:53 +0800 Subject: [PATCH 061/141] contentjson: Import from go1.21.4 --- common/json/internal/contextjson/decode.go | 1293 +++++++++++++++++++ common/json/internal/contextjson/encode.go | 1283 ++++++++++++++++++ common/json/internal/contextjson/fold.go | 48 + common/json/internal/contextjson/indent.go | 174 +++ common/json/internal/contextjson/scanner.go | 610 +++++++++ common/json/internal/contextjson/stream.go | 513 ++++++++ common/json/internal/contextjson/tables.go | 218 ++++ common/json/internal/contextjson/tags.go | 38 + 8 files changed, 4177 insertions(+) create mode 100644 common/json/internal/contextjson/decode.go create mode 100644 common/json/internal/contextjson/encode.go create mode 100644 common/json/internal/contextjson/fold.go create mode 100644 common/json/internal/contextjson/indent.go create mode 100644 common/json/internal/contextjson/scanner.go create mode 100644 common/json/internal/contextjson/stream.go create mode 100644 common/json/internal/contextjson/tables.go create mode 100644 common/json/internal/contextjson/tags.go diff --git a/common/json/internal/contextjson/decode.go b/common/json/internal/contextjson/decode.go new file mode 100644 index 0000000..c222bec --- /dev/null +++ b/common/json/internal/contextjson/decode.go @@ -0,0 +1,1293 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Represents JSON data structure using native Go types: booleans, floats, +// strings, arrays, and maps. + +package json + +import ( + "encoding" + "encoding/base64" + "fmt" + "reflect" + "strconv" + "strings" + "unicode" + "unicode/utf16" + "unicode/utf8" +) + +// Unmarshal parses the JSON-encoded data and stores the result +// in the value pointed to by v. If v is nil or not a pointer, +// Unmarshal returns an InvalidUnmarshalError. +// +// Unmarshal uses the inverse of the encodings that +// Marshal uses, allocating maps, slices, and pointers as necessary, +// with the following additional rules: +// +// To unmarshal JSON into a pointer, Unmarshal first handles the case of +// the JSON being the JSON literal null. In that case, Unmarshal sets +// the pointer to nil. Otherwise, Unmarshal unmarshals the JSON into +// the value pointed at by the pointer. If the pointer is nil, Unmarshal +// allocates a new value for it to point to. +// +// To unmarshal JSON into a value implementing the Unmarshaler interface, +// Unmarshal calls that value's UnmarshalJSON method, including +// when the input is a JSON null. +// Otherwise, if the value implements encoding.TextUnmarshaler +// and the input is a JSON quoted string, Unmarshal calls that value's +// UnmarshalText method with the unquoted form of the string. +// +// To unmarshal JSON into a struct, Unmarshal matches incoming object +// keys to the keys used by Marshal (either the struct field name or its tag), +// preferring an exact match but also accepting a case-insensitive match. By +// default, object keys which don't have a corresponding struct field are +// ignored (see Decoder.DisallowUnknownFields for an alternative). +// +// To unmarshal JSON into an interface value, +// Unmarshal stores one of these in the interface value: +// +// bool, for JSON booleans +// float64, for JSON numbers +// string, for JSON strings +// []interface{}, for JSON arrays +// map[string]interface{}, for JSON objects +// nil for JSON null +// +// To unmarshal a JSON array into a slice, Unmarshal resets the slice length +// to zero and then appends each element to the slice. +// As a special case, to unmarshal an empty JSON array into a slice, +// Unmarshal replaces the slice with a new empty slice. +// +// To unmarshal a JSON array into a Go array, Unmarshal decodes +// JSON array elements into corresponding Go array elements. +// If the Go array is smaller than the JSON array, +// the additional JSON array elements are discarded. +// If the JSON array is smaller than the Go array, +// the additional Go array elements are set to zero values. +// +// To unmarshal a JSON object into a map, Unmarshal first establishes a map to +// use. If the map is nil, Unmarshal allocates a new map. Otherwise Unmarshal +// reuses the existing map, keeping existing entries. Unmarshal then stores +// key-value pairs from the JSON object into the map. The map's key type must +// either be any string type, an integer, implement json.Unmarshaler, or +// implement encoding.TextUnmarshaler. +// +// If the JSON-encoded data contain a syntax error, Unmarshal returns a SyntaxError. +// +// If a JSON value is not appropriate for a given target type, +// or if a JSON number overflows the target type, Unmarshal +// skips that field and completes the unmarshaling as best it can. +// If no more serious errors are encountered, Unmarshal returns +// an UnmarshalTypeError describing the earliest such error. In any +// case, it's not guaranteed that all the remaining fields following +// the problematic one will be unmarshaled into the target object. +// +// The JSON null value unmarshals into an interface, map, pointer, or slice +// by setting that Go value to nil. Because null is often used in JSON to mean +// “not present,” unmarshaling a JSON null into any other Go type has no effect +// on the value and produces no error. +// +// When unmarshaling quoted strings, invalid UTF-8 or +// invalid UTF-16 surrogate pairs are not treated as an error. +// Instead, they are replaced by the Unicode replacement +// character U+FFFD. +func Unmarshal(data []byte, v any) error { + // Check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + var d decodeState + err := checkValid(data, &d.scan) + if err != nil { + return err + } + + d.init(data) + return d.unmarshal(v) +} + +// Unmarshaler is the interface implemented by types +// that can unmarshal a JSON description of themselves. +// The input can be assumed to be a valid encoding of +// a JSON value. UnmarshalJSON must copy the JSON data +// if it wishes to retain the data after returning. +// +// By convention, to approximate the behavior of Unmarshal itself, +// Unmarshalers implement UnmarshalJSON([]byte("null")) as a no-op. +type Unmarshaler interface { + UnmarshalJSON([]byte) error +} + +// An UnmarshalTypeError describes a JSON value that was +// not appropriate for a value of a specific Go type. +type UnmarshalTypeError struct { + Value string // description of JSON value - "bool", "array", "number -5" + Type reflect.Type // type of Go value it could not be assigned to + Offset int64 // error occurred after reading Offset bytes + Struct string // name of the struct type containing the field + Field string // the full path from root node to the field +} + +func (e *UnmarshalTypeError) Error() string { + if e.Struct != "" || e.Field != "" { + return "json: cannot unmarshal " + e.Value + " into Go struct field " + e.Struct + "." + e.Field + " of type " + e.Type.String() + } + return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String() +} + +// An UnmarshalFieldError describes a JSON object key that +// led to an unexported (and therefore unwritable) struct field. +// +// Deprecated: No longer used; kept for compatibility. +type UnmarshalFieldError struct { + Key string + Type reflect.Type + Field reflect.StructField +} + +func (e *UnmarshalFieldError) Error() string { + return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String() +} + +// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal. +// (The argument to Unmarshal must be a non-nil pointer.) +type InvalidUnmarshalError struct { + Type reflect.Type +} + +func (e *InvalidUnmarshalError) Error() string { + if e.Type == nil { + return "json: Unmarshal(nil)" + } + + if e.Type.Kind() != reflect.Pointer { + return "json: Unmarshal(non-pointer " + e.Type.String() + ")" + } + return "json: Unmarshal(nil " + e.Type.String() + ")" +} + +func (d *decodeState) unmarshal(v any) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return &InvalidUnmarshalError{reflect.TypeOf(v)} + } + + d.scan.reset() + d.scanWhile(scanSkipSpace) + // We decode rv not rv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + err := d.value(rv) + if err != nil { + return d.addErrorContext(err) + } + return d.savedError +} + +// A Number represents a JSON number literal. +type Number string + +// String returns the literal text of the number. +func (n Number) String() string { return string(n) } + +// Float64 returns the number as a float64. +func (n Number) Float64() (float64, error) { + return strconv.ParseFloat(string(n), 64) +} + +// Int64 returns the number as an int64. +func (n Number) Int64() (int64, error) { + return strconv.ParseInt(string(n), 10, 64) +} + +// An errorContext provides context for type errors during decoding. +type errorContext struct { + Struct reflect.Type + FieldStack []string +} + +// decodeState represents the state while decoding a JSON value. +type decodeState struct { + data []byte + off int // next read offset in data + opcode int // last read result + scan scanner + errorContext *errorContext + savedError error + useNumber bool + disallowUnknownFields bool +} + +// readIndex returns the position of the last byte read. +func (d *decodeState) readIndex() int { + return d.off - 1 +} + +// phasePanicMsg is used as a panic message when we end up with something that +// shouldn't happen. It can indicate a bug in the JSON decoder, or that +// something is editing the data slice while the decoder executes. +const phasePanicMsg = "JSON decoder out of sync - data changing underfoot?" + +func (d *decodeState) init(data []byte) *decodeState { + d.data = data + d.off = 0 + d.savedError = nil + if d.errorContext != nil { + d.errorContext.Struct = nil + // Reuse the allocated space for the FieldStack slice. + d.errorContext.FieldStack = d.errorContext.FieldStack[:0] + } + return d +} + +// saveError saves the first err it is called with, +// for reporting at the end of the unmarshal. +func (d *decodeState) saveError(err error) { + if d.savedError == nil { + d.savedError = d.addErrorContext(err) + } +} + +// addErrorContext returns a new error enhanced with information from d.errorContext +func (d *decodeState) addErrorContext(err error) error { + if d.errorContext != nil && (d.errorContext.Struct != nil || len(d.errorContext.FieldStack) > 0) { + switch err := err.(type) { + case *UnmarshalTypeError: + err.Struct = d.errorContext.Struct.Name() + err.Field = strings.Join(d.errorContext.FieldStack, ".") + } + } + return err +} + +// skip scans to the end of what was started. +func (d *decodeState) skip() { + s, data, i := &d.scan, d.data, d.off + depth := len(s.parseState) + for { + op := s.step(s, data[i]) + i++ + if len(s.parseState) < depth { + d.off = i + d.opcode = op + return + } + } +} + +// scanNext processes the byte at d.data[d.off]. +func (d *decodeState) scanNext() { + if d.off < len(d.data) { + d.opcode = d.scan.step(&d.scan, d.data[d.off]) + d.off++ + } else { + d.opcode = d.scan.eof() + d.off = len(d.data) + 1 // mark processed EOF with len+1 + } +} + +// scanWhile processes bytes in d.data[d.off:] until it +// receives a scan code not equal to op. +func (d *decodeState) scanWhile(op int) { + s, data, i := &d.scan, d.data, d.off + for i < len(data) { + newOp := s.step(s, data[i]) + i++ + if newOp != op { + d.opcode = newOp + d.off = i + return + } + } + + d.off = len(data) + 1 // mark processed EOF with len+1 + d.opcode = d.scan.eof() +} + +// rescanLiteral is similar to scanWhile(scanContinue), but it specialises the +// common case where we're decoding a literal. The decoder scans the input +// twice, once for syntax errors and to check the length of the value, and the +// second to perform the decoding. +// +// Only in the second step do we use decodeState to tokenize literals, so we +// know there aren't any syntax errors. We can take advantage of that knowledge, +// and scan a literal's bytes much more quickly. +func (d *decodeState) rescanLiteral() { + data, i := d.data, d.off +Switch: + switch data[i-1] { + case '"': // string + for ; i < len(data); i++ { + switch data[i] { + case '\\': + i++ // escaped char + case '"': + i++ // tokenize the closing quote too + break Switch + } + } + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-': // number + for ; i < len(data); i++ { + switch data[i] { + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + '.', 'e', 'E', '+', '-': + default: + break Switch + } + } + case 't': // true + i += len("rue") + case 'f': // false + i += len("alse") + case 'n': // null + i += len("ull") + } + if i < len(data) { + d.opcode = stateEndValue(&d.scan, data[i]) + } else { + d.opcode = scanEnd + } + d.off = i + 1 +} + +// value consumes a JSON value from d.data[d.off-1:], decoding into v, and +// reads the following byte ahead. If v is invalid, the value is discarded. +// The first byte of the value has been read already. +func (d *decodeState) value(v reflect.Value) error { + switch d.opcode { + default: + panic(phasePanicMsg) + + case scanBeginArray: + if v.IsValid() { + if err := d.array(v); err != nil { + return err + } + } else { + d.skip() + } + d.scanNext() + + case scanBeginObject: + if v.IsValid() { + if err := d.object(v); err != nil { + return err + } + } else { + d.skip() + } + d.scanNext() + + case scanBeginLiteral: + // All bytes inside literal return scanContinue op code. + start := d.readIndex() + d.rescanLiteral() + + if v.IsValid() { + if err := d.literalStore(d.data[start:d.readIndex()], v, false); err != nil { + return err + } + } + } + return nil +} + +type unquotedValue struct{} + +// valueQuoted is like value but decodes a +// quoted string literal or literal null into an interface value. +// If it finds anything other than a quoted string literal or null, +// valueQuoted returns unquotedValue{}. +func (d *decodeState) valueQuoted() any { + switch d.opcode { + default: + panic(phasePanicMsg) + + case scanBeginArray, scanBeginObject: + d.skip() + d.scanNext() + + case scanBeginLiteral: + v := d.literalInterface() + switch v.(type) { + case nil, string: + return v + } + } + return unquotedValue{} +} + +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// If it encounters an Unmarshaler, indirect stops and returns that. +// If decodingNull is true, indirect stops at the first settable pointer so it +// can be set to nil. +func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { + // Issue #24153 indicates that it is generally not a guaranteed property + // that you may round-trip a reflect.Value by calling Value.Addr().Elem() + // and expect the value to still be settable for values derived from + // unexported embedded struct fields. + // + // The logic below effectively does this when it first addresses the value + // (to satisfy possible pointer methods) and continues to dereference + // subsequent pointers as necessary. + // + // After the first round-trip, we set v back to the original value to + // preserve the original RW flags contained in reflect.Value. + v0 := v + haveAddr := false + + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Pointer && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true + v = v.Addr() + } + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Pointer && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Pointer) { + haveAddr = false + v = e + continue + } + } + + if v.Kind() != reflect.Pointer { + break + } + + if decodingNull && v.CanSet() { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v interface{} + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v { + v = v.Elem() + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if v.Type().NumMethod() > 0 && v.CanInterface() { + if u, ok := v.Interface().(Unmarshaler); ok { + return u, nil, reflect.Value{} + } + if !decodingNull { + if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { + return nil, u, reflect.Value{} + } + } + } + + if haveAddr { + v = v0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + v = v.Elem() + } + } + return nil, nil, v +} + +// array consumes an array from d.data[d.off-1:], decoding into v. +// The first byte of the array ('[') has been read already. +func (d *decodeState) array(v reflect.Value) error { + // Check for unmarshaler. + u, ut, pv := indirect(v, false) + if u != nil { + start := d.readIndex() + d.skip() + return u.UnmarshalJSON(d.data[start:d.off]) + } + if ut != nil { + d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + } + v = pv + + // Check type of target. + switch v.Kind() { + case reflect.Interface: + if v.NumMethod() == 0 { + // Decoding into nil interface? Switch to non-reflect code. + ai := d.arrayInterface() + v.Set(reflect.ValueOf(ai)) + return nil + } + // Otherwise it's invalid. + fallthrough + default: + d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + case reflect.Array, reflect.Slice: + break + } + + i := 0 + for { + // Look ahead for ] - can only happen on first iteration. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndArray { + break + } + + // Expand slice length, growing the slice if necessary. + if v.Kind() == reflect.Slice { + if i >= v.Cap() { + v.Grow(1) + } + if i >= v.Len() { + v.SetLen(i + 1) + } + } + + if i < v.Len() { + // Decode into element. + if err := d.value(v.Index(i)); err != nil { + return err + } + } else { + // Ran out of fixed array: skip. + if err := d.value(reflect.Value{}); err != nil { + return err + } + } + i++ + + // Next token must be , or ]. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndArray { + break + } + if d.opcode != scanArrayValue { + panic(phasePanicMsg) + } + } + + if i < v.Len() { + if v.Kind() == reflect.Array { + for ; i < v.Len(); i++ { + v.Index(i).SetZero() // zero remainder of array + } + } else { + v.SetLen(i) // truncate the slice + } + } + if i == 0 && v.Kind() == reflect.Slice { + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + } + return nil +} + +var ( + nullLiteral = []byte("null") + textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() +) + +// object consumes an object from d.data[d.off-1:], decoding into v. +// The first byte ('{') of the object has been read already. +func (d *decodeState) object(v reflect.Value) error { + // Check for unmarshaler. + u, ut, pv := indirect(v, false) + if u != nil { + start := d.readIndex() + d.skip() + return u.UnmarshalJSON(d.data[start:d.off]) + } + if ut != nil { + d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + } + v = pv + t := v.Type() + + // Decoding into nil interface? Switch to non-reflect code. + if v.Kind() == reflect.Interface && v.NumMethod() == 0 { + oi := d.objectInterface() + v.Set(reflect.ValueOf(oi)) + return nil + } + + var fields structFields + + // Check type of target: + // struct or + // map[T1]T2 where T1 is string, an integer type, + // or an encoding.TextUnmarshaler + switch v.Kind() { + case reflect.Map: + // Map key must either have string kind, have an integer kind, + // or be an encoding.TextUnmarshaler. + switch t.Key().Kind() { + case reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + default: + if !reflect.PointerTo(t.Key()).Implements(textUnmarshalerType) { + d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)}) + d.skip() + return nil + } + } + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + case reflect.Struct: + fields = cachedTypeFields(t) + // ok + default: + d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)}) + d.skip() + return nil + } + + var mapElem reflect.Value + var origErrorContext errorContext + if d.errorContext != nil { + origErrorContext = *d.errorContext + } + + for { + // Read opening " of string key or closing }. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if d.opcode != scanBeginLiteral { + panic(phasePanicMsg) + } + + // Read key. + start := d.readIndex() + d.rescanLiteral() + item := d.data[start:d.readIndex()] + key, ok := unquoteBytes(item) + if !ok { + panic(phasePanicMsg) + } + + // Figure out field corresponding to key. + var subv reflect.Value + destring := false // whether the value is wrapped in a string to be decoded first + + if v.Kind() == reflect.Map { + elemType := t.Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.SetZero() + } + subv = mapElem + } else { + f := fields.byExactName[string(key)] + if f == nil { + f = fields.byFoldedName[string(foldName(key))] + } + if f != nil { + subv = v + destring = f.quoted + for _, i := range f.index { + if subv.Kind() == reflect.Pointer { + if subv.IsNil() { + // If a struct embeds a pointer to an unexported type, + // it is not possible to set a newly allocated value + // since the field is unexported. + // + // See https://golang.org/issue/21357 + if !subv.CanSet() { + d.saveError(fmt.Errorf("json: cannot set embedded pointer to unexported struct: %v", subv.Type().Elem())) + // Invalidate subv to ensure d.value(subv) skips over + // the JSON value without assigning it to subv. + subv = reflect.Value{} + destring = false + break + } + subv.Set(reflect.New(subv.Type().Elem())) + } + subv = subv.Elem() + } + subv = subv.Field(i) + } + if d.errorContext == nil { + d.errorContext = new(errorContext) + } + d.errorContext.FieldStack = append(d.errorContext.FieldStack, f.name) + d.errorContext.Struct = t + } else if d.disallowUnknownFields { + d.saveError(fmt.Errorf("json: unknown field %q", key)) + } + } + + // Read : before value. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode != scanObjectKey { + panic(phasePanicMsg) + } + d.scanWhile(scanSkipSpace) + + if destring { + switch qv := d.valueQuoted().(type) { + case nil: + if err := d.literalStore(nullLiteral, subv, false); err != nil { + return err + } + case string: + if err := d.literalStore([]byte(qv), subv, true); err != nil { + return err + } + default: + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", subv.Type())) + } + } else { + if err := d.value(subv); err != nil { + return err + } + } + + // Write value back to map; + // if using struct, subv points into struct already. + if v.Kind() == reflect.Map { + kt := t.Key() + var kv reflect.Value + switch { + case reflect.PointerTo(kt).Implements(textUnmarshalerType): + kv = reflect.New(kt) + if err := d.literalStore(item, kv, true); err != nil { + return err + } + kv = kv.Elem() + case kt.Kind() == reflect.String: + kv = reflect.ValueOf(key).Convert(kt) + default: + switch kt.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + s := string(key) + n, err := strconv.ParseInt(s, 10, 64) + if err != nil || reflect.Zero(kt).OverflowInt(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)}) + break + } + kv = reflect.ValueOf(n).Convert(kt) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + s := string(key) + n, err := strconv.ParseUint(s, 10, 64) + if err != nil || reflect.Zero(kt).OverflowUint(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)}) + break + } + kv = reflect.ValueOf(n).Convert(kt) + default: + panic("json: Unexpected key type") // should never occur + } + } + if kv.IsValid() { + v.SetMapIndex(kv, subv) + } + } + + // Next token must be , or }. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.errorContext != nil { + // Reset errorContext to its original state. + // Keep the same underlying array for FieldStack, to reuse the + // space and avoid unnecessary allocs. + d.errorContext.FieldStack = d.errorContext.FieldStack[:len(origErrorContext.FieldStack)] + d.errorContext.Struct = origErrorContext.Struct + } + if d.opcode == scanEndObject { + break + } + if d.opcode != scanObjectValue { + panic(phasePanicMsg) + } + } + return nil +} + +// convertNumber converts the number literal s to a float64 or a Number +// depending on the setting of d.useNumber. +func (d *decodeState) convertNumber(s string) (any, error) { + if d.useNumber { + return Number(s), nil + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil, &UnmarshalTypeError{Value: "number " + s, Type: reflect.TypeOf(0.0), Offset: int64(d.off)} + } + return f, nil +} + +var numberType = reflect.TypeOf(Number("")) + +// literalStore decodes a literal stored in item into v. +// +// fromQuoted indicates whether this literal came from unwrapping a +// string from the ",string" struct tag option. this is used only to +// produce more helpful error messages. +func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) error { + // Check for unmarshaler. + if len(item) == 0 { + // Empty string given + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + return nil + } + isNull := item[0] == 'n' // null + u, ut, pv := indirect(v, isNull) + if u != nil { + return u.UnmarshalJSON(item) + } + if ut != nil { + if item[0] != '"' { + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + return nil + } + val := "number" + switch item[0] { + case 'n': + val = "null" + case 't', 'f': + val = "bool" + } + d.saveError(&UnmarshalTypeError{Value: val, Type: v.Type(), Offset: int64(d.readIndex())}) + return nil + } + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + panic(phasePanicMsg) + } + return ut.UnmarshalText(s) + } + + v = pv + + switch c := item[0]; c { + case 'n': // null + // The main parser checks that only true and false can reach here, + // but if this was a quoted string input, it could be anything. + if fromQuoted && string(item) != "null" { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + break + } + switch v.Kind() { + case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice: + v.SetZero() + // otherwise, ignore null for primitives/string + } + case 't', 'f': // true, false + value := item[0] == 't' + // The main parser checks that only true and false can reach here, + // but if this was a quoted string input, it could be anything. + if fromQuoted && string(item) != "true" && string(item) != "false" { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + break + } + switch v.Kind() { + default: + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())}) + } + case reflect.Bool: + v.SetBool(value) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(value)) + } else { + d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())}) + } + } + + case '"': // string + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + panic(phasePanicMsg) + } + switch v.Kind() { + default: + d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())}) + case reflect.Slice: + if v.Type().Elem().Kind() != reflect.Uint8 { + d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + b := make([]byte, base64.StdEncoding.DecodedLen(len(s))) + n, err := base64.StdEncoding.Decode(b, s) + if err != nil { + d.saveError(err) + break + } + v.SetBytes(b[:n]) + case reflect.String: + if v.Type() == numberType && !isValidNumber(string(s)) { + return fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", item) + } + v.SetString(string(s)) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(string(s))) + } else { + d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())}) + } + } + + default: // number + if c != '-' && (c < '0' || c > '9') { + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + panic(phasePanicMsg) + } + s := string(item) + switch v.Kind() { + default: + if v.Kind() == reflect.String && v.Type() == numberType { + // s must be a valid number, because it's + // already been tokenized. + v.SetString(s) + break + } + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())}) + case reflect.Interface: + n, err := d.convertNumber(s) + if err != nil { + d.saveError(err) + break + } + if v.NumMethod() != 0 { + d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.Set(reflect.ValueOf(n)) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(s, 10, 64) + if err != nil || v.OverflowInt(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.SetInt(n) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n, err := strconv.ParseUint(s, 10, 64) + if err != nil || v.OverflowUint(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.SetUint(n) + + case reflect.Float32, reflect.Float64: + n, err := strconv.ParseFloat(s, v.Type().Bits()) + if err != nil || v.OverflowFloat(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.SetFloat(n) + } + } + return nil +} + +// The xxxInterface routines build up a value to be stored +// in an empty interface. They are not strictly necessary, +// but they avoid the weight of reflection in this common case. + +// valueInterface is like value but returns interface{} +func (d *decodeState) valueInterface() (val any) { + switch d.opcode { + default: + panic(phasePanicMsg) + case scanBeginArray: + val = d.arrayInterface() + d.scanNext() + case scanBeginObject: + val = d.objectInterface() + d.scanNext() + case scanBeginLiteral: + val = d.literalInterface() + } + return +} + +// arrayInterface is like array but returns []interface{}. +func (d *decodeState) arrayInterface() []any { + v := make([]any, 0) + for { + // Look ahead for ] - can only happen on first iteration. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndArray { + break + } + + v = append(v, d.valueInterface()) + + // Next token must be , or ]. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndArray { + break + } + if d.opcode != scanArrayValue { + panic(phasePanicMsg) + } + } + return v +} + +// objectInterface is like object but returns map[string]interface{}. +func (d *decodeState) objectInterface() map[string]any { + m := make(map[string]any) + for { + // Read opening " of string key or closing }. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if d.opcode != scanBeginLiteral { + panic(phasePanicMsg) + } + + // Read string key. + start := d.readIndex() + d.rescanLiteral() + item := d.data[start:d.readIndex()] + key, ok := unquote(item) + if !ok { + panic(phasePanicMsg) + } + + // Read : before value. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode != scanObjectKey { + panic(phasePanicMsg) + } + d.scanWhile(scanSkipSpace) + + // Read value. + m[key] = d.valueInterface() + + // Next token must be , or }. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndObject { + break + } + if d.opcode != scanObjectValue { + panic(phasePanicMsg) + } + } + return m +} + +// literalInterface consumes and returns a literal from d.data[d.off-1:] and +// it reads the following byte ahead. The first byte of the literal has been +// read already (that's how the caller knows it's a literal). +func (d *decodeState) literalInterface() any { + // All bytes inside literal return scanContinue op code. + start := d.readIndex() + d.rescanLiteral() + + item := d.data[start:d.readIndex()] + + switch c := item[0]; c { + case 'n': // null + return nil + + case 't', 'f': // true, false + return c == 't' + + case '"': // string + s, ok := unquote(item) + if !ok { + panic(phasePanicMsg) + } + return s + + default: // number + if c != '-' && (c < '0' || c > '9') { + panic(phasePanicMsg) + } + n, err := d.convertNumber(string(item)) + if err != nil { + d.saveError(err) + } + return n + } +} + +// getu4 decodes \uXXXX from the beginning of s, returning the hex value, +// or it returns -1. +func getu4(s []byte) rune { + if len(s) < 6 || s[0] != '\\' || s[1] != 'u' { + return -1 + } + var r rune + for _, c := range s[2:6] { + switch { + case '0' <= c && c <= '9': + c = c - '0' + case 'a' <= c && c <= 'f': + c = c - 'a' + 10 + case 'A' <= c && c <= 'F': + c = c - 'A' + 10 + default: + return -1 + } + r = r*16 + rune(c) + } + return r +} + +// unquote converts a quoted JSON string literal s into an actual string t. +// The rules are different than for Go, so cannot use strconv.Unquote. +func unquote(s []byte) (t string, ok bool) { + s, ok = unquoteBytes(s) + t = string(s) + return +} + +func unquoteBytes(s []byte) (t []byte, ok bool) { + if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' { + return + } + s = s[1 : len(s)-1] + + // Check for unusual characters. If there are none, + // then no unquoting is needed, so return a slice of the + // original bytes. + r := 0 + for r < len(s) { + c := s[r] + if c == '\\' || c == '"' || c < ' ' { + break + } + if c < utf8.RuneSelf { + r++ + continue + } + rr, size := utf8.DecodeRune(s[r:]) + if rr == utf8.RuneError && size == 1 { + break + } + r += size + } + if r == len(s) { + return s, true + } + + b := make([]byte, len(s)+2*utf8.UTFMax) + w := copy(b, s[0:r]) + for r < len(s) { + // Out of room? Can only happen if s is full of + // malformed UTF-8 and we're replacing each + // byte with RuneError. + if w >= len(b)-2*utf8.UTFMax { + nb := make([]byte, (len(b)+utf8.UTFMax)*2) + copy(nb, b[0:w]) + b = nb + } + switch c := s[r]; { + case c == '\\': + r++ + if r >= len(s) { + return + } + switch s[r] { + default: + return + case '"', '\\', '/', '\'': + b[w] = s[r] + r++ + w++ + case 'b': + b[w] = '\b' + r++ + w++ + case 'f': + b[w] = '\f' + r++ + w++ + case 'n': + b[w] = '\n' + r++ + w++ + case 'r': + b[w] = '\r' + r++ + w++ + case 't': + b[w] = '\t' + r++ + w++ + case 'u': + r-- + rr := getu4(s[r:]) + if rr < 0 { + return + } + r += 6 + if utf16.IsSurrogate(rr) { + rr1 := getu4(s[r:]) + if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar { + // A valid pair; consume. + r += 6 + w += utf8.EncodeRune(b[w:], dec) + break + } + // Invalid surrogate; fall back to replacement rune. + rr = unicode.ReplacementChar + } + w += utf8.EncodeRune(b[w:], rr) + } + + // Quote, control characters are invalid. + case c == '"', c < ' ': + return + + // ASCII + case c < utf8.RuneSelf: + b[w] = c + r++ + w++ + + // Coerce to well-formed UTF-8. + default: + rr, size := utf8.DecodeRune(s[r:]) + r += size + w += utf8.EncodeRune(b[w:], rr) + } + } + return b[0:w], true +} diff --git a/common/json/internal/contextjson/encode.go b/common/json/internal/contextjson/encode.go new file mode 100644 index 0000000..6da0bd9 --- /dev/null +++ b/common/json/internal/contextjson/encode.go @@ -0,0 +1,1283 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package json implements encoding and decoding of JSON as defined in +// RFC 7159. The mapping between JSON and Go values is described +// in the documentation for the Marshal and Unmarshal functions. +// +// See "JSON and Go" for an introduction to this package: +// https://golang.org/doc/articles/json_and_go.html +package json + +import ( + "bytes" + "encoding" + "encoding/base64" + "fmt" + "math" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "unicode" + "unicode/utf8" +) + +// Marshal returns the JSON encoding of v. +// +// Marshal traverses the value v recursively. +// If an encountered value implements the Marshaler interface +// and is not a nil pointer, Marshal calls its MarshalJSON method +// to produce JSON. If no MarshalJSON method is present but the +// value implements encoding.TextMarshaler instead, Marshal calls +// its MarshalText method and encodes the result as a JSON string. +// The nil pointer exception is not strictly necessary +// but mimics a similar, necessary exception in the behavior of +// UnmarshalJSON. +// +// Otherwise, Marshal uses the following type-dependent default encodings: +// +// Boolean values encode as JSON booleans. +// +// Floating point, integer, and Number values encode as JSON numbers. +// NaN and +/-Inf values will return an [UnsupportedValueError]. +// +// String values encode as JSON strings coerced to valid UTF-8, +// replacing invalid bytes with the Unicode replacement rune. +// So that the JSON will be safe to embed inside HTML