Add deadline reader/conn

This commit is contained in:
世界 2023-04-08 13:24:15 +08:00
parent cee74ef1f4
commit df54c89b04
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
10 changed files with 707 additions and 3 deletions

15
common/atomic/types.go Normal file
View file

@ -0,0 +1,15 @@
//go:build go1.19
package atomic
import "sync/atomic"
type (
Bool = atomic.Bool
Int32 = atomic.Int32
Int64 = atomic.Int64
Uint32 = atomic.Uint32
Uint64 = atomic.Uint64
Uintptr = atomic.Uintptr
Value = atomic.Value
)

View file

@ -0,0 +1,198 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.19
package atomic
import (
"sync/atomic"
"unsafe"
)
// A Bool is an atomic boolean value.
// The zero value is false.
type Bool struct {
_ noCopy
v uint32
}
// Load atomically loads and returns the value stored in x.
func (x *Bool) Load() bool { return atomic.LoadUint32(&x.v) != 0 }
// Store atomically stores val into x.
func (x *Bool) Store(val bool) { atomic.StoreUint32(&x.v, b32(val)) }
// Swap atomically stores new into x and returns the previous value.
func (x *Bool) Swap(new bool) (old bool) { return atomic.SwapUint32(&x.v, b32(new)) != 0 }
// CompareAndSwap executes the compare-and-swap operation for the boolean value x.
func (x *Bool) CompareAndSwap(old, new bool) (swapped bool) {
return atomic.CompareAndSwapUint32(&x.v, b32(old), b32(new))
}
// b32 returns a uint32 0 or 1 representing b.
func b32(b bool) uint32 {
if b {
return 1
}
return 0
}
// A Pointer is an atomic pointer of type *T. The zero value is a nil *T.
type Pointer[T any] struct {
// Mention *T in a field to disallow conversion between Pointer types.
// See go.dev/issue/56603 for more details.
// Use *T, not T, to avoid spurious recursive type definition errors.
_ [0]*T
_ noCopy
v unsafe.Pointer
}
// Load atomically loads and returns the value stored in x.
func (x *Pointer[T]) Load() *T { return (*T)(atomic.LoadPointer(&x.v)) }
// Store atomically stores val into x.
func (x *Pointer[T]) Store(val *T) { atomic.StorePointer(&x.v, unsafe.Pointer(val)) }
// Swap atomically stores new into x and returns the previous value.
func (x *Pointer[T]) Swap(new *T) (old *T) {
return (*T)(atomic.SwapPointer(&x.v, unsafe.Pointer(new)))
}
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Pointer[T]) CompareAndSwap(old, new *T) (swapped bool) {
return atomic.CompareAndSwapPointer(&x.v, unsafe.Pointer(old), unsafe.Pointer(new))
}
// An Int32 is an atomic int32. The zero value is zero.
type Int32 struct {
_ noCopy
v int32
}
// Load atomically loads and returns the value stored in x.
func (x *Int32) Load() int32 { return atomic.LoadInt32(&x.v) }
// Store atomically stores val into x.
func (x *Int32) Store(val int32) { atomic.StoreInt32(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Int32) Swap(new int32) (old int32) { return atomic.SwapInt32(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Int32) CompareAndSwap(old, new int32) (swapped bool) {
return atomic.CompareAndSwapInt32(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Int32) Add(delta int32) (new int32) { return atomic.AddInt32(&x.v, delta) }
// An Int64 is an atomic int64. The zero value is zero.
type Int64 struct {
_ noCopy
v int64
}
// Load atomically loads and returns the value stored in x.
func (x *Int64) Load() int64 { return atomic.LoadInt64(&x.v) }
// Store atomically stores val into x.
func (x *Int64) Store(val int64) { atomic.StoreInt64(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Int64) Swap(new int64) (old int64) { return atomic.SwapInt64(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Int64) CompareAndSwap(old, new int64) (swapped bool) {
return atomic.CompareAndSwapInt64(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Int64) Add(delta int64) (new int64) { return atomic.AddInt64(&x.v, delta) }
// An Uint32 is an atomic uint32. The zero value is zero.
type Uint32 struct {
_ noCopy
v uint32
}
// Load atomically loads and returns the value stored in x.
func (x *Uint32) Load() uint32 { return atomic.LoadUint32(&x.v) }
// Store atomically stores val into x.
func (x *Uint32) Store(val uint32) { atomic.StoreUint32(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Uint32) Swap(new uint32) (old uint32) { return atomic.SwapUint32(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Uint32) CompareAndSwap(old, new uint32) (swapped bool) {
return atomic.CompareAndSwapUint32(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Uint32) Add(delta uint32) (new uint32) { return atomic.AddUint32(&x.v, delta) }
// An Uint64 is an atomic uint64. The zero value is zero.
type Uint64 struct {
_ noCopy
v uint64
}
// Load atomically loads and returns the value stored in x.
func (x *Uint64) Load() uint64 { return atomic.LoadUint64(&x.v) }
// Store atomically stores val into x.
func (x *Uint64) Store(val uint64) { atomic.StoreUint64(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Uint64) Swap(new uint64) (old uint64) { return atomic.SwapUint64(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Uint64) CompareAndSwap(old, new uint64) (swapped bool) {
return atomic.CompareAndSwapUint64(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Uint64) Add(delta uint64) (new uint64) { return atomic.AddUint64(&x.v, delta) }
// An Uintptr is an atomic uintptr. The zero value is zero.
type Uintptr struct {
_ noCopy
v uintptr
}
// Load atomically loads and returns the value stored in x.
func (x *Uintptr) Load() uintptr { return atomic.LoadUintptr(&x.v) }
// Store atomically stores val into x.
func (x *Uintptr) Store(val uintptr) { atomic.StoreUintptr(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Uintptr) Swap(new uintptr) (old uintptr) { return atomic.SwapUintptr(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Uintptr) CompareAndSwap(old, new uintptr) (swapped bool) {
return atomic.CompareAndSwapUintptr(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Uintptr) Add(delta uintptr) (new uintptr) { return atomic.AddUintptr(&x.v, delta) }
// noCopy may be added to structs which must not be copied
// after the first use.
//
// See https://golang.org/issues/8005#issuecomment-190753527
// for details.
//
// Note that it must not be embedded, due to the Lock and Unlock methods.
type noCopy struct{}
// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock() {}
func (*noCopy) Unlock() {}
type Value = atomic.Value

View file

@ -164,13 +164,13 @@ func (b *Buffer) WriteByte(d byte) error {
return nil
}
func (b *Buffer) ReadOnceFrom(r io.Reader) (int64, error) {
func (b *Buffer) ReadOnceFrom(r io.Reader) (int, error) {
if b.IsFull() {
return 0, io.ErrShortBuffer
}
n, err := r.Read(b.FreeBytes())
b.end += n
return int64(n), err
return n, err
}
func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) {
@ -184,7 +184,8 @@ func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) {
func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
if min <= 0 {
return b.ReadOnceFrom(r)
n, err := b.ReadOnceFrom(r)
return int64(n), err
}
if b.IsFull() {
return 0, io.ErrShortBuffer

View file

@ -0,0 +1,47 @@
package deadline
import (
"net"
"time"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)
type Conn struct {
N.ExtendedConn
reader *Reader
}
func NewConn(conn net.Conn) *Conn {
return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)}
}
func (c *Conn) Read(p []byte) (n int, err error) {
return c.reader.Read(p)
}
func (c *Conn) ReadBuffer(buffer *buf.Buffer) error {
return c.reader.ReadBuffer(buffer)
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.reader.SetReadDeadline(t)
}
func (c *Conn) ReaderReplaceable() bool {
return c.reader.ReaderReplaceable()
}
func (c *Conn) UpstreamReader() any {
return c.reader.UpstreamReader()
}
func (c *Conn) WriterReplaceable() bool {
return true
}
func (c *Conn) Upstream() any {
return c.ExtendedConn
}

View file

@ -0,0 +1,47 @@
package deadline
import (
"net"
"time"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type PacketConn struct {
N.NetPacketConn
reader *PacketReader
}
func NewPacketConn(conn N.NetPacketConn) *PacketConn {
return &PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)}
}
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return c.reader.ReadFrom(p)
}
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
return c.reader.ReadPacket(buffer)
}
func (c *PacketConn) SetReadDeadline(t time.Time) error {
return c.NetPacketConn.SetReadDeadline(t)
}
func (c *PacketConn) ReaderReplaceable() bool {
return c.reader.ReaderReplaceable()
}
func (c *PacketConn) UpstreamReader() any {
return c.reader.UpstreamReader()
}
func (c *PacketConn) WriterReplaceable() bool {
return true
}
func (c *PacketConn) Upstream() any {
return c.NetPacketConn
}

View file

@ -0,0 +1,152 @@
package deadline
import (
"net"
"os"
"sync"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type TimeoutPacketReader interface {
N.NetPacketConn
SetReadDeadline(t time.Time) error
}
type PacketReader struct {
TimeoutPacketReader
deadline time.Time
disablePipe atomic.Bool
pipeDeadline pipeDeadline
cacheAccess sync.RWMutex
inRead atomic.Bool
cached bool
cachedBuffer *buf.Buffer
cachedAddr M.Socksaddr
cachedErr error
}
func NewPacketReader(reader TimeoutPacketReader) *PacketReader {
return &PacketReader{TimeoutPacketReader: reader}
}
func (r *PacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if r.disablePipe.Load() || r.deadline.IsZero() {
return r.TimeoutPacketReader.ReadFrom(p)
}
r.cacheAccess.Lock()
if r.cached {
n = copy(p, r.cachedBuffer.Bytes())
addr = r.cachedAddr.UDPAddr()
err = r.cachedErr
r.cachedBuffer.Release()
r.cached = false
r.cacheAccess.Unlock()
return
}
r.cacheAccess.Unlock()
done := make(chan struct{})
go func() {
n, addr, err = r.pipeReadFrom(p, r.pipeDeadline.wait())
close(done)
}()
select {
case <-done:
return
case <-r.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
}
}
func (r *PacketReader) pipeReadFrom(p []byte, cancel chan struct{}) (n int, addr net.Addr, err error) {
r.cacheAccess.Lock()
r.inRead.Store(true)
defer func() {
r.inRead.Store(false)
r.cacheAccess.Unlock()
}()
cacheBuffer := buf.NewSize(len(p))
n, addr, err = r.TimeoutPacketReader.ReadFrom(cacheBuffer.Bytes())
if isClosedChan(cancel) {
r.cached = true
r.cachedBuffer = cacheBuffer
r.cachedAddr = M.SocksaddrFromNet(addr)
r.cachedErr = err
} else {
copy(p, cacheBuffer.Bytes())
cacheBuffer.Release()
}
return
}
func (r *PacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if r.disablePipe.Load() || r.deadline.IsZero() {
return r.TimeoutPacketReader.ReadPacket(buffer)
}
r.cacheAccess.Lock()
if r.cached {
destination = r.cachedAddr
err = r.cachedErr
buffer.Write(r.cachedBuffer.Bytes())
r.cachedBuffer.Release()
r.cached = false
r.cacheAccess.Unlock()
return
}
r.cacheAccess.Unlock()
done := make(chan struct{})
go func() {
destination, err = r.pipeReadPacket(buffer, r.pipeDeadline.wait())
close(done)
}()
select {
case <-done:
return
case <-r.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
}
func (r *PacketReader) pipeReadPacket(buffer *buf.Buffer, cancel chan struct{}) (destination M.Socksaddr, err error) {
r.cacheAccess.Lock()
r.inRead.Store(true)
defer func() {
r.inRead.Store(false)
r.cacheAccess.Unlock()
}()
cacheBuffer := buf.NewSize(buffer.FreeLen())
destination, err = r.TimeoutPacketReader.ReadPacket(cacheBuffer)
if isClosedChan(cancel) {
r.cached = true
r.cachedBuffer = cacheBuffer
r.cachedAddr = destination
r.cachedErr = err
} else {
common.Must1(buffer.ReadOnceFrom(cacheBuffer))
cacheBuffer.Release()
}
return
}
func (r *PacketReader) SetReadDeadline(t time.Time) error {
r.deadline = t
r.pipeDeadline.set(t)
if r.disablePipe.Load() || !r.inRead.Load() {
r.disablePipe.Store(true)
return r.TimeoutPacketReader.SetReadDeadline(t)
}
return nil
}
func (r *PacketReader) ReaderReplaceable() bool {
return r.deadline.IsZero()
}
func (r *PacketReader) UpstreamReader() any {
return r.TimeoutPacketReader
}

View file

@ -0,0 +1,78 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package deadline
import (
"sync"
"time"
)
// pipeDeadline is an abstraction for handling timeouts.
type pipeDeadline struct {
mu sync.Mutex // Guards timer and cancel
timer *time.Timer
cancel chan struct{} // Must be non-nil
}
func makePipeDeadline() pipeDeadline {
return pipeDeadline{cancel: make(chan struct{})}
}
// set sets the point in time when the deadline will time out.
// A timeout event is signaled by closing the channel returned by waiter.
// Once a timeout has occurred, the deadline can be refreshed by specifying a
// t value in the future.
//
// A zero value for t prevents timeout.
func (d *pipeDeadline) set(t time.Time) {
d.mu.Lock()
defer d.mu.Unlock()
if d.timer != nil && !d.timer.Stop() {
<-d.cancel // Wait for the timer callback to finish and close cancel
}
d.timer = nil
// Time is zero, then there is no deadline.
closed := isClosedChan(d.cancel)
if t.IsZero() {
if closed {
d.cancel = make(chan struct{})
}
return
}
// Time in the future, setup a timer to cancel in the future.
if dur := time.Until(t); dur > 0 {
if closed {
d.cancel = make(chan struct{})
}
d.timer = time.AfterFunc(dur, func() {
close(d.cancel)
})
return
}
// Time in the past, so close immediately.
if !closed {
close(d.cancel)
}
}
// wait returns a channel that is closed when the deadline is exceeded.
func (d *pipeDeadline) wait() chan struct{} {
d.mu.Lock()
defer d.mu.Unlock()
return d.cancel
}
func isClosedChan(c <-chan struct{}) bool {
select {
case <-c:
return true
default:
return false
}
}

View file

@ -0,0 +1,157 @@
package deadline
import (
"io"
"os"
"sync"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)
type TimeoutReader interface {
io.Reader
SetReadDeadline(t time.Time) error
}
type Reader struct {
N.ExtendedReader
timeoutReader TimeoutReader
deadline time.Time
disablePipe atomic.Bool
pipeDeadline pipeDeadline
cacheAccess sync.RWMutex
inRead atomic.Bool
cached bool
cachedBuffer *buf.Buffer
cachedErr error
}
func NewReader(reader TimeoutReader) *Reader {
return &Reader{ExtendedReader: bufio.NewExtendedReader(reader), timeoutReader: reader}
}
func (r *Reader) Read(p []byte) (n int, err error) {
if r.disablePipe.Load() || r.deadline.IsZero() {
return r.ExtendedReader.Read(p)
}
r.cacheAccess.Lock()
if r.cached {
n = copy(p, r.cachedBuffer.Bytes())
err = r.cachedErr
r.cachedBuffer.Advance(n)
if r.cachedBuffer.IsEmpty() {
r.cachedBuffer.Release()
r.cached = false
}
r.cacheAccess.Unlock()
return
}
r.cacheAccess.Unlock()
done := make(chan struct{})
go func() {
n, err = r.pipeRead(p, r.pipeDeadline.wait())
close(done)
}()
select {
case <-done:
return
case <-r.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}
func (r *Reader) pipeRead(p []byte, cancel chan struct{}) (n int, err error) {
r.cacheAccess.Lock()
r.inRead.Store(true)
defer func() {
r.inRead.Store(false)
r.cacheAccess.Unlock()
}()
buffer := buf.NewSize(len(p))
n, err = buffer.ReadOnceFrom(r.ExtendedReader)
if isClosedChan(cancel) {
r.cached = true
r.cachedBuffer = buffer
r.cachedErr = err
} else {
n = copy(p, buffer.Bytes())
buffer.Release()
}
return
}
func (r *Reader) ReadBuffer(buffer *buf.Buffer) error {
if r.disablePipe.Load() || r.deadline.IsZero() {
return r.ExtendedReader.ReadBuffer(buffer)
}
r.cacheAccess.Lock()
if r.cached {
n := copy(buffer.FreeBytes(), r.cachedBuffer.Bytes())
err := r.cachedErr
buffer.Resize(buffer.Start(), n)
r.cachedBuffer.Advance(n)
if r.cachedBuffer.IsEmpty() {
r.cachedBuffer.Release()
r.cached = false
}
r.cacheAccess.Unlock()
return err
}
r.cacheAccess.Unlock()
done := make(chan struct{})
var err error
go func() {
err = r.pipeReadBuffer(buffer, r.pipeDeadline.wait())
close(done)
}()
select {
case <-done:
return err
case <-r.pipeDeadline.wait():
return os.ErrDeadlineExceeded
}
}
func (r *Reader) pipeReadBuffer(buffer *buf.Buffer, cancel chan struct{}) error {
r.cacheAccess.Lock()
r.inRead.Store(true)
defer func() {
r.inRead.Store(false)
r.cacheAccess.Unlock()
}()
cacheBuffer := buf.NewSize(buffer.FreeLen())
err := r.ExtendedReader.ReadBuffer(cacheBuffer)
if isClosedChan(cancel) {
r.cached = true
r.cachedBuffer = cacheBuffer
r.cachedErr = err
} else {
common.Must1(buffer.ReadOnceFrom(cacheBuffer))
cacheBuffer.Release()
}
return err
}
func (r *Reader) SetReadDeadline(t time.Time) error {
r.deadline = t
r.pipeDeadline.set(t)
if r.disablePipe.Load() || !r.inRead.Load() {
r.disablePipe.Store(true)
return r.timeoutReader.SetReadDeadline(t)
}
return nil
}
func (r *Reader) ReaderReplaceable() bool {
return r.disablePipe.Load() || r.deadline.IsZero()
}
func (r *Reader) UpstreamReader() any {
return r.ExtendedReader
}

View file

@ -14,6 +14,13 @@ type FallbackPacketConn struct {
N.PacketConn
}
func NewNetPacketConn(conn N.PacketConn) N.NetPacketConn {
if packetConn, loaded := conn.(N.NetPacketConn); loaded {
return packetConn
}
return &FallbackPacketConn{PacketConn: conn}
}
func (c *FallbackPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
buffer := buf.With(p)
destination, err := c.ReadPacket(buffer)

View file

@ -21,10 +21,12 @@ type TimeoutPacketReader interface {
}
type NetPacketReader interface {
PacketReader
ReadFrom(p []byte) (n int, addr net.Addr, err error)
}
type NetPacketWriter interface {
PacketWriter
WriteTo(p []byte, addr net.Addr) (n int, err error)
}