mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 20:07:38 +03:00
Add deadline reader/conn
This commit is contained in:
parent
cee74ef1f4
commit
df54c89b04
10 changed files with 707 additions and 3 deletions
15
common/atomic/types.go
Normal file
15
common/atomic/types.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
//go:build go1.19
|
||||
|
||||
package atomic
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
type (
|
||||
Bool = atomic.Bool
|
||||
Int32 = atomic.Int32
|
||||
Int64 = atomic.Int64
|
||||
Uint32 = atomic.Uint32
|
||||
Uint64 = atomic.Uint64
|
||||
Uintptr = atomic.Uintptr
|
||||
Value = atomic.Value
|
||||
)
|
198
common/atomic/types_compat.go
Normal file
198
common/atomic/types_compat.go
Normal file
|
@ -0,0 +1,198 @@
|
|||
// Copyright 2022 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build !go1.19
|
||||
|
||||
package atomic
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// A Bool is an atomic boolean value.
|
||||
// The zero value is false.
|
||||
type Bool struct {
|
||||
_ noCopy
|
||||
v uint32
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Bool) Load() bool { return atomic.LoadUint32(&x.v) != 0 }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Bool) Store(val bool) { atomic.StoreUint32(&x.v, b32(val)) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Bool) Swap(new bool) (old bool) { return atomic.SwapUint32(&x.v, b32(new)) != 0 }
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for the boolean value x.
|
||||
func (x *Bool) CompareAndSwap(old, new bool) (swapped bool) {
|
||||
return atomic.CompareAndSwapUint32(&x.v, b32(old), b32(new))
|
||||
}
|
||||
|
||||
// b32 returns a uint32 0 or 1 representing b.
|
||||
func b32(b bool) uint32 {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// A Pointer is an atomic pointer of type *T. The zero value is a nil *T.
|
||||
type Pointer[T any] struct {
|
||||
// Mention *T in a field to disallow conversion between Pointer types.
|
||||
// See go.dev/issue/56603 for more details.
|
||||
// Use *T, not T, to avoid spurious recursive type definition errors.
|
||||
_ [0]*T
|
||||
|
||||
_ noCopy
|
||||
v unsafe.Pointer
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Pointer[T]) Load() *T { return (*T)(atomic.LoadPointer(&x.v)) }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Pointer[T]) Store(val *T) { atomic.StorePointer(&x.v, unsafe.Pointer(val)) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Pointer[T]) Swap(new *T) (old *T) {
|
||||
return (*T)(atomic.SwapPointer(&x.v, unsafe.Pointer(new)))
|
||||
}
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for x.
|
||||
func (x *Pointer[T]) CompareAndSwap(old, new *T) (swapped bool) {
|
||||
return atomic.CompareAndSwapPointer(&x.v, unsafe.Pointer(old), unsafe.Pointer(new))
|
||||
}
|
||||
|
||||
// An Int32 is an atomic int32. The zero value is zero.
|
||||
type Int32 struct {
|
||||
_ noCopy
|
||||
v int32
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Int32) Load() int32 { return atomic.LoadInt32(&x.v) }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Int32) Store(val int32) { atomic.StoreInt32(&x.v, val) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Int32) Swap(new int32) (old int32) { return atomic.SwapInt32(&x.v, new) }
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for x.
|
||||
func (x *Int32) CompareAndSwap(old, new int32) (swapped bool) {
|
||||
return atomic.CompareAndSwapInt32(&x.v, old, new)
|
||||
}
|
||||
|
||||
// Add atomically adds delta to x and returns the new value.
|
||||
func (x *Int32) Add(delta int32) (new int32) { return atomic.AddInt32(&x.v, delta) }
|
||||
|
||||
// An Int64 is an atomic int64. The zero value is zero.
|
||||
type Int64 struct {
|
||||
_ noCopy
|
||||
v int64
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Int64) Load() int64 { return atomic.LoadInt64(&x.v) }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Int64) Store(val int64) { atomic.StoreInt64(&x.v, val) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Int64) Swap(new int64) (old int64) { return atomic.SwapInt64(&x.v, new) }
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for x.
|
||||
func (x *Int64) CompareAndSwap(old, new int64) (swapped bool) {
|
||||
return atomic.CompareAndSwapInt64(&x.v, old, new)
|
||||
}
|
||||
|
||||
// Add atomically adds delta to x and returns the new value.
|
||||
func (x *Int64) Add(delta int64) (new int64) { return atomic.AddInt64(&x.v, delta) }
|
||||
|
||||
// An Uint32 is an atomic uint32. The zero value is zero.
|
||||
type Uint32 struct {
|
||||
_ noCopy
|
||||
v uint32
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Uint32) Load() uint32 { return atomic.LoadUint32(&x.v) }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Uint32) Store(val uint32) { atomic.StoreUint32(&x.v, val) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Uint32) Swap(new uint32) (old uint32) { return atomic.SwapUint32(&x.v, new) }
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for x.
|
||||
func (x *Uint32) CompareAndSwap(old, new uint32) (swapped bool) {
|
||||
return atomic.CompareAndSwapUint32(&x.v, old, new)
|
||||
}
|
||||
|
||||
// Add atomically adds delta to x and returns the new value.
|
||||
func (x *Uint32) Add(delta uint32) (new uint32) { return atomic.AddUint32(&x.v, delta) }
|
||||
|
||||
// An Uint64 is an atomic uint64. The zero value is zero.
|
||||
type Uint64 struct {
|
||||
_ noCopy
|
||||
v uint64
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Uint64) Load() uint64 { return atomic.LoadUint64(&x.v) }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Uint64) Store(val uint64) { atomic.StoreUint64(&x.v, val) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Uint64) Swap(new uint64) (old uint64) { return atomic.SwapUint64(&x.v, new) }
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for x.
|
||||
func (x *Uint64) CompareAndSwap(old, new uint64) (swapped bool) {
|
||||
return atomic.CompareAndSwapUint64(&x.v, old, new)
|
||||
}
|
||||
|
||||
// Add atomically adds delta to x and returns the new value.
|
||||
func (x *Uint64) Add(delta uint64) (new uint64) { return atomic.AddUint64(&x.v, delta) }
|
||||
|
||||
// An Uintptr is an atomic uintptr. The zero value is zero.
|
||||
type Uintptr struct {
|
||||
_ noCopy
|
||||
v uintptr
|
||||
}
|
||||
|
||||
// Load atomically loads and returns the value stored in x.
|
||||
func (x *Uintptr) Load() uintptr { return atomic.LoadUintptr(&x.v) }
|
||||
|
||||
// Store atomically stores val into x.
|
||||
func (x *Uintptr) Store(val uintptr) { atomic.StoreUintptr(&x.v, val) }
|
||||
|
||||
// Swap atomically stores new into x and returns the previous value.
|
||||
func (x *Uintptr) Swap(new uintptr) (old uintptr) { return atomic.SwapUintptr(&x.v, new) }
|
||||
|
||||
// CompareAndSwap executes the compare-and-swap operation for x.
|
||||
func (x *Uintptr) CompareAndSwap(old, new uintptr) (swapped bool) {
|
||||
return atomic.CompareAndSwapUintptr(&x.v, old, new)
|
||||
}
|
||||
|
||||
// Add atomically adds delta to x and returns the new value.
|
||||
func (x *Uintptr) Add(delta uintptr) (new uintptr) { return atomic.AddUintptr(&x.v, delta) }
|
||||
|
||||
// noCopy may be added to structs which must not be copied
|
||||
// after the first use.
|
||||
//
|
||||
// See https://golang.org/issues/8005#issuecomment-190753527
|
||||
// for details.
|
||||
//
|
||||
// Note that it must not be embedded, due to the Lock and Unlock methods.
|
||||
type noCopy struct{}
|
||||
|
||||
// Lock is a no-op used by -copylocks checker from `go vet`.
|
||||
func (*noCopy) Lock() {}
|
||||
func (*noCopy) Unlock() {}
|
||||
|
||||
type Value = atomic.Value
|
|
@ -164,13 +164,13 @@ func (b *Buffer) WriteByte(d byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (b *Buffer) ReadOnceFrom(r io.Reader) (int64, error) {
|
||||
func (b *Buffer) ReadOnceFrom(r io.Reader) (int, error) {
|
||||
if b.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
n, err := r.Read(b.FreeBytes())
|
||||
b.end += n
|
||||
return int64(n), err
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) {
|
||||
|
@ -184,7 +184,8 @@ func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) {
|
|||
|
||||
func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
|
||||
if min <= 0 {
|
||||
return b.ReadOnceFrom(r)
|
||||
n, err := b.ReadOnceFrom(r)
|
||||
return int64(n), err
|
||||
}
|
||||
if b.IsFull() {
|
||||
return 0, io.ErrShortBuffer
|
||||
|
|
47
common/bufio/deadline/conn.go
Normal file
47
common/bufio/deadline/conn.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
N.ExtendedConn
|
||||
reader *Reader
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn) *Conn {
|
||||
return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(p []byte) (n int, err error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *Conn) ReadBuffer(buffer *buf.Buffer) error {
|
||||
return c.reader.ReadBuffer(buffer)
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return c.reader.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) ReaderReplaceable() bool {
|
||||
return c.reader.ReaderReplaceable()
|
||||
}
|
||||
|
||||
func (c *Conn) UpstreamReader() any {
|
||||
return c.reader.UpstreamReader()
|
||||
}
|
||||
|
||||
func (c *Conn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Conn) Upstream() any {
|
||||
return c.ExtendedConn
|
||||
}
|
47
common/bufio/deadline/packet_conn.go
Normal file
47
common/bufio/deadline/packet_conn.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type PacketConn struct {
|
||||
N.NetPacketConn
|
||||
reader *PacketReader
|
||||
}
|
||||
|
||||
func NewPacketConn(conn N.NetPacketConn) *PacketConn {
|
||||
return &PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)}
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
return c.reader.ReadFrom(p)
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
return c.reader.ReadPacket(buffer)
|
||||
}
|
||||
|
||||
func (c *PacketConn) SetReadDeadline(t time.Time) error {
|
||||
return c.NetPacketConn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReaderReplaceable() bool {
|
||||
return c.reader.ReaderReplaceable()
|
||||
}
|
||||
|
||||
func (c *PacketConn) UpstreamReader() any {
|
||||
return c.reader.UpstreamReader()
|
||||
}
|
||||
|
||||
func (c *PacketConn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *PacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
152
common/bufio/deadline/packet_reader.go
Normal file
152
common/bufio/deadline/packet_reader.go
Normal file
|
@ -0,0 +1,152 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type TimeoutPacketReader interface {
|
||||
N.NetPacketConn
|
||||
SetReadDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
type PacketReader struct {
|
||||
TimeoutPacketReader
|
||||
deadline time.Time
|
||||
disablePipe atomic.Bool
|
||||
pipeDeadline pipeDeadline
|
||||
cacheAccess sync.RWMutex
|
||||
inRead atomic.Bool
|
||||
cached bool
|
||||
cachedBuffer *buf.Buffer
|
||||
cachedAddr M.Socksaddr
|
||||
cachedErr error
|
||||
}
|
||||
|
||||
func NewPacketReader(reader TimeoutPacketReader) *PacketReader {
|
||||
return &PacketReader{TimeoutPacketReader: reader}
|
||||
}
|
||||
|
||||
func (r *PacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
if r.disablePipe.Load() || r.deadline.IsZero() {
|
||||
return r.TimeoutPacketReader.ReadFrom(p)
|
||||
}
|
||||
r.cacheAccess.Lock()
|
||||
if r.cached {
|
||||
n = copy(p, r.cachedBuffer.Bytes())
|
||||
addr = r.cachedAddr.UDPAddr()
|
||||
err = r.cachedErr
|
||||
r.cachedBuffer.Release()
|
||||
r.cached = false
|
||||
r.cacheAccess.Unlock()
|
||||
return
|
||||
}
|
||||
r.cacheAccess.Unlock()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
n, addr, err = r.pipeReadFrom(p, r.pipeDeadline.wait())
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, nil, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *PacketReader) pipeReadFrom(p []byte, cancel chan struct{}) (n int, addr net.Addr, err error) {
|
||||
r.cacheAccess.Lock()
|
||||
r.inRead.Store(true)
|
||||
defer func() {
|
||||
r.inRead.Store(false)
|
||||
r.cacheAccess.Unlock()
|
||||
}()
|
||||
cacheBuffer := buf.NewSize(len(p))
|
||||
n, addr, err = r.TimeoutPacketReader.ReadFrom(cacheBuffer.Bytes())
|
||||
if isClosedChan(cancel) {
|
||||
r.cached = true
|
||||
r.cachedBuffer = cacheBuffer
|
||||
r.cachedAddr = M.SocksaddrFromNet(addr)
|
||||
r.cachedErr = err
|
||||
} else {
|
||||
copy(p, cacheBuffer.Bytes())
|
||||
cacheBuffer.Release()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *PacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
if r.disablePipe.Load() || r.deadline.IsZero() {
|
||||
return r.TimeoutPacketReader.ReadPacket(buffer)
|
||||
}
|
||||
r.cacheAccess.Lock()
|
||||
if r.cached {
|
||||
destination = r.cachedAddr
|
||||
err = r.cachedErr
|
||||
buffer.Write(r.cachedBuffer.Bytes())
|
||||
r.cachedBuffer.Release()
|
||||
r.cached = false
|
||||
r.cacheAccess.Unlock()
|
||||
return
|
||||
}
|
||||
r.cacheAccess.Unlock()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
destination, err = r.pipeReadPacket(buffer, r.pipeDeadline.wait())
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-r.pipeDeadline.wait():
|
||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *PacketReader) pipeReadPacket(buffer *buf.Buffer, cancel chan struct{}) (destination M.Socksaddr, err error) {
|
||||
r.cacheAccess.Lock()
|
||||
r.inRead.Store(true)
|
||||
defer func() {
|
||||
r.inRead.Store(false)
|
||||
r.cacheAccess.Unlock()
|
||||
}()
|
||||
cacheBuffer := buf.NewSize(buffer.FreeLen())
|
||||
destination, err = r.TimeoutPacketReader.ReadPacket(cacheBuffer)
|
||||
if isClosedChan(cancel) {
|
||||
r.cached = true
|
||||
r.cachedBuffer = cacheBuffer
|
||||
r.cachedAddr = destination
|
||||
r.cachedErr = err
|
||||
} else {
|
||||
common.Must1(buffer.ReadOnceFrom(cacheBuffer))
|
||||
cacheBuffer.Release()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *PacketReader) SetReadDeadline(t time.Time) error {
|
||||
r.deadline = t
|
||||
r.pipeDeadline.set(t)
|
||||
if r.disablePipe.Load() || !r.inRead.Load() {
|
||||
r.disablePipe.Store(true)
|
||||
return r.TimeoutPacketReader.SetReadDeadline(t)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *PacketReader) ReaderReplaceable() bool {
|
||||
return r.deadline.IsZero()
|
||||
}
|
||||
|
||||
func (r *PacketReader) UpstreamReader() any {
|
||||
return r.TimeoutPacketReader
|
||||
}
|
78
common/bufio/deadline/pipe.go
Normal file
78
common/bufio/deadline/pipe.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package deadline
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// pipeDeadline is an abstraction for handling timeouts.
|
||||
type pipeDeadline struct {
|
||||
mu sync.Mutex // Guards timer and cancel
|
||||
timer *time.Timer
|
||||
cancel chan struct{} // Must be non-nil
|
||||
}
|
||||
|
||||
func makePipeDeadline() pipeDeadline {
|
||||
return pipeDeadline{cancel: make(chan struct{})}
|
||||
}
|
||||
|
||||
// set sets the point in time when the deadline will time out.
|
||||
// A timeout event is signaled by closing the channel returned by waiter.
|
||||
// Once a timeout has occurred, the deadline can be refreshed by specifying a
|
||||
// t value in the future.
|
||||
//
|
||||
// A zero value for t prevents timeout.
|
||||
func (d *pipeDeadline) set(t time.Time) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if d.timer != nil && !d.timer.Stop() {
|
||||
<-d.cancel // Wait for the timer callback to finish and close cancel
|
||||
}
|
||||
d.timer = nil
|
||||
|
||||
// Time is zero, then there is no deadline.
|
||||
closed := isClosedChan(d.cancel)
|
||||
if t.IsZero() {
|
||||
if closed {
|
||||
d.cancel = make(chan struct{})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Time in the future, setup a timer to cancel in the future.
|
||||
if dur := time.Until(t); dur > 0 {
|
||||
if closed {
|
||||
d.cancel = make(chan struct{})
|
||||
}
|
||||
d.timer = time.AfterFunc(dur, func() {
|
||||
close(d.cancel)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Time in the past, so close immediately.
|
||||
if !closed {
|
||||
close(d.cancel)
|
||||
}
|
||||
}
|
||||
|
||||
// wait returns a channel that is closed when the deadline is exceeded.
|
||||
func (d *pipeDeadline) wait() chan struct{} {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
return d.cancel
|
||||
}
|
||||
|
||||
func isClosedChan(c <-chan struct{}) bool {
|
||||
select {
|
||||
case <-c:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
157
common/bufio/deadline/reader.go
Normal file
157
common/bufio/deadline/reader.go
Normal file
|
@ -0,0 +1,157 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type TimeoutReader interface {
|
||||
io.Reader
|
||||
SetReadDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
type Reader struct {
|
||||
N.ExtendedReader
|
||||
timeoutReader TimeoutReader
|
||||
deadline time.Time
|
||||
disablePipe atomic.Bool
|
||||
pipeDeadline pipeDeadline
|
||||
cacheAccess sync.RWMutex
|
||||
inRead atomic.Bool
|
||||
cached bool
|
||||
cachedBuffer *buf.Buffer
|
||||
cachedErr error
|
||||
}
|
||||
|
||||
func NewReader(reader TimeoutReader) *Reader {
|
||||
return &Reader{ExtendedReader: bufio.NewExtendedReader(reader), timeoutReader: reader}
|
||||
}
|
||||
|
||||
func (r *Reader) Read(p []byte) (n int, err error) {
|
||||
if r.disablePipe.Load() || r.deadline.IsZero() {
|
||||
return r.ExtendedReader.Read(p)
|
||||
}
|
||||
r.cacheAccess.Lock()
|
||||
if r.cached {
|
||||
n = copy(p, r.cachedBuffer.Bytes())
|
||||
err = r.cachedErr
|
||||
r.cachedBuffer.Advance(n)
|
||||
if r.cachedBuffer.IsEmpty() {
|
||||
r.cachedBuffer.Release()
|
||||
r.cached = false
|
||||
}
|
||||
r.cacheAccess.Unlock()
|
||||
return
|
||||
}
|
||||
r.cacheAccess.Unlock()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
n, err = r.pipeRead(p, r.pipeDeadline.wait())
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Reader) pipeRead(p []byte, cancel chan struct{}) (n int, err error) {
|
||||
r.cacheAccess.Lock()
|
||||
r.inRead.Store(true)
|
||||
defer func() {
|
||||
r.inRead.Store(false)
|
||||
r.cacheAccess.Unlock()
|
||||
}()
|
||||
|
||||
buffer := buf.NewSize(len(p))
|
||||
n, err = buffer.ReadOnceFrom(r.ExtendedReader)
|
||||
if isClosedChan(cancel) {
|
||||
r.cached = true
|
||||
r.cachedBuffer = buffer
|
||||
r.cachedErr = err
|
||||
} else {
|
||||
n = copy(p, buffer.Bytes())
|
||||
buffer.Release()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Reader) ReadBuffer(buffer *buf.Buffer) error {
|
||||
if r.disablePipe.Load() || r.deadline.IsZero() {
|
||||
return r.ExtendedReader.ReadBuffer(buffer)
|
||||
}
|
||||
r.cacheAccess.Lock()
|
||||
if r.cached {
|
||||
n := copy(buffer.FreeBytes(), r.cachedBuffer.Bytes())
|
||||
err := r.cachedErr
|
||||
buffer.Resize(buffer.Start(), n)
|
||||
r.cachedBuffer.Advance(n)
|
||||
if r.cachedBuffer.IsEmpty() {
|
||||
r.cachedBuffer.Release()
|
||||
r.cached = false
|
||||
}
|
||||
r.cacheAccess.Unlock()
|
||||
return err
|
||||
}
|
||||
r.cacheAccess.Unlock()
|
||||
done := make(chan struct{})
|
||||
var err error
|
||||
go func() {
|
||||
err = r.pipeReadBuffer(buffer, r.pipeDeadline.wait())
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return err
|
||||
case <-r.pipeDeadline.wait():
|
||||
return os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Reader) pipeReadBuffer(buffer *buf.Buffer, cancel chan struct{}) error {
|
||||
r.cacheAccess.Lock()
|
||||
r.inRead.Store(true)
|
||||
defer func() {
|
||||
r.inRead.Store(false)
|
||||
r.cacheAccess.Unlock()
|
||||
}()
|
||||
cacheBuffer := buf.NewSize(buffer.FreeLen())
|
||||
err := r.ExtendedReader.ReadBuffer(cacheBuffer)
|
||||
if isClosedChan(cancel) {
|
||||
r.cached = true
|
||||
r.cachedBuffer = cacheBuffer
|
||||
r.cachedErr = err
|
||||
} else {
|
||||
common.Must1(buffer.ReadOnceFrom(cacheBuffer))
|
||||
cacheBuffer.Release()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Reader) SetReadDeadline(t time.Time) error {
|
||||
r.deadline = t
|
||||
r.pipeDeadline.set(t)
|
||||
if r.disablePipe.Load() || !r.inRead.Load() {
|
||||
r.disablePipe.Store(true)
|
||||
return r.timeoutReader.SetReadDeadline(t)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Reader) ReaderReplaceable() bool {
|
||||
return r.disablePipe.Load() || r.deadline.IsZero()
|
||||
}
|
||||
|
||||
func (r *Reader) UpstreamReader() any {
|
||||
return r.ExtendedReader
|
||||
}
|
|
@ -14,6 +14,13 @@ type FallbackPacketConn struct {
|
|||
N.PacketConn
|
||||
}
|
||||
|
||||
func NewNetPacketConn(conn N.PacketConn) N.NetPacketConn {
|
||||
if packetConn, loaded := conn.(N.NetPacketConn); loaded {
|
||||
return packetConn
|
||||
}
|
||||
return &FallbackPacketConn{PacketConn: conn}
|
||||
}
|
||||
|
||||
func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
buffer := buf.With(p)
|
||||
destination, err := c.ReadPacket(buffer)
|
||||
|
|
|
@ -21,10 +21,12 @@ type TimeoutPacketReader interface {
|
|||
}
|
||||
|
||||
type NetPacketReader interface {
|
||||
PacketReader
|
||||
ReadFrom(p []byte) (n int, addr net.Addr, err error)
|
||||
}
|
||||
|
||||
type NetPacketWriter interface {
|
||||
PacketWriter
|
||||
WriteTo(p []byte, addr net.Addr) (n int, err error)
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue