Fix deadline

This commit is contained in:
世界 2023-04-13 07:40:08 +08:00
parent d88db59703
commit f8049ca89b
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 115 additions and 72 deletions

View file

@ -34,10 +34,6 @@ func (c *PacketConn) ReaderReplaceable() bool {
return c.reader.ReaderReplaceable()
}
func (c *PacketConn) UpstreamReader() any {
return c.reader.UpstreamReader()
}
func (c *PacketConn) WriterReplaceable() bool {
return true
}

View file

@ -6,7 +6,6 @@ import (
"sync"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
@ -14,17 +13,17 @@ import (
)
type TimeoutPacketReader interface {
N.NetPacketConn
N.NetPacketReader
SetReadDeadline(t time.Time) error
}
type PacketReader struct {
TimeoutPacketReader
deadline time.Time
disablePipe atomic.Bool
pipeDeadline pipeDeadline
cacheAccess sync.RWMutex
disablePipe atomic.Bool
inRead atomic.Bool
cacheAccess sync.RWMutex
cached bool
cachedBuffer *buf.Buffer
cachedAddr M.Socksaddr
@ -36,8 +35,13 @@ func NewPacketReader(reader TimeoutPacketReader) *PacketReader {
}
func (r *PacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if r.disablePipe.Load() || r.deadline.IsZero() {
if r.disablePipe.Load() {
return r.TimeoutPacketReader.ReadFrom(p)
} else if r.deadline.IsZero() {
r.inRead.Store(true)
defer r.inRead.Store(false)
n, addr, err = r.TimeoutPacketReader.ReadFrom(p)
return
}
r.cacheAccess.Lock()
if r.cached {
@ -51,28 +55,35 @@ func (r *PacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
}
r.cacheAccess.Unlock()
done := make(chan struct{})
var access sync.Mutex
var cancel bool
go func() {
n, addr, err = r.pipeReadFrom(p, r.pipeDeadline.wait())
close(done)
n, addr, err = r.pipeReadFrom(p, &access, &cancel, done)
}()
select {
case <-done:
return
case <-r.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
}
access.Lock()
defer access.Unlock()
select {
case <-done:
return
default:
}
cancel = true
return 0, nil, os.ErrDeadlineExceeded
}
func (r *PacketReader) pipeReadFrom(p []byte, cancel chan struct{}) (n int, addr net.Addr, err error) {
func (r *PacketReader) pipeReadFrom(p []byte, access *sync.Mutex, cancel *bool, done chan struct{}) (n int, addr net.Addr, err error) {
r.cacheAccess.Lock()
r.inRead.Store(true)
defer func() {
r.inRead.Store(false)
r.cacheAccess.Unlock()
}()
defer r.cacheAccess.Unlock()
cacheBuffer := buf.NewSize(len(p))
n, addr, err = r.TimeoutPacketReader.ReadFrom(cacheBuffer.Bytes())
if isClosedChan(cancel) {
access.Lock()
defer access.Unlock()
if *cancel {
r.cached = true
r.cachedBuffer = cacheBuffer
r.cachedAddr = M.SocksaddrFromNet(addr)
@ -81,12 +92,18 @@ func (r *PacketReader) pipeReadFrom(p []byte, cancel chan struct{}) (n int, addr
copy(p, cacheBuffer.Bytes())
cacheBuffer.Release()
}
close(done)
return
}
func (r *PacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if r.disablePipe.Load() || r.deadline.IsZero() {
if r.disablePipe.Load() {
return r.TimeoutPacketReader.ReadPacket(buffer)
} else if r.deadline.IsZero() {
r.inRead.Store(true)
defer r.inRead.Store(false)
destination, err = r.TimeoutPacketReader.ReadPacket(buffer)
return
}
r.cacheAccess.Lock()
if r.cached {
@ -100,51 +117,61 @@ func (r *PacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr,
}
r.cacheAccess.Unlock()
done := make(chan struct{})
var access sync.Mutex
var cancel bool
go func() {
destination, err = r.pipeReadPacket(buffer, r.pipeDeadline.wait())
close(done)
destination, err = r.pipeReadPacket(buffer, &access, &cancel, done)
}()
select {
case <-done:
return
case <-r.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
access.Lock()
defer access.Unlock()
select {
case <-done:
return
default:
}
cancel = true
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
func (r *PacketReader) pipeReadPacket(buffer *buf.Buffer, cancel chan struct{}) (destination M.Socksaddr, err error) {
func (r *PacketReader) pipeReadPacket(buffer *buf.Buffer, access *sync.Mutex, cancel *bool, done chan struct{}) (destination M.Socksaddr, err error) {
r.cacheAccess.Lock()
r.inRead.Store(true)
defer func() {
r.inRead.Store(false)
r.cacheAccess.Unlock()
}()
defer r.cacheAccess.Unlock()
cacheBuffer := buf.NewSize(buffer.FreeLen())
destination, err = r.TimeoutPacketReader.ReadPacket(cacheBuffer)
if isClosedChan(cancel) {
access.Lock()
defer access.Unlock()
if *cancel {
r.cached = true
r.cachedBuffer = cacheBuffer
r.cachedAddr = destination
r.cachedErr = err
} else {
common.Must1(buffer.ReadOnceFrom(cacheBuffer))
buffer.ReadOnceFrom(cacheBuffer)
cacheBuffer.Release()
}
close(done)
return
}
func (r *PacketReader) SetReadDeadline(t time.Time) error {
r.deadline = t
r.pipeDeadline.set(t)
if r.disablePipe.Load() || !r.inRead.Load() {
if r.disablePipe.Load() {
return r.TimeoutPacketReader.SetReadDeadline(t)
} else if r.inRead.Load() {
r.disablePipe.Store(true)
return r.TimeoutPacketReader.SetReadDeadline(t)
}
r.deadline = t
r.pipeDeadline.set(t)
return nil
}
func (r *PacketReader) ReaderReplaceable() bool {
return r.deadline.IsZero()
return r.disablePipe.Load() || r.deadline.IsZero()
}
func (r *PacketReader) UpstreamReader() any {

View file

@ -6,7 +6,6 @@ import (
"sync"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
@ -22,10 +21,10 @@ type Reader struct {
N.ExtendedReader
timeoutReader TimeoutReader
deadline time.Time
disablePipe atomic.Bool
pipeDeadline pipeDeadline
cacheAccess sync.RWMutex
disablePipe atomic.Bool
inRead atomic.Bool
cacheAccess sync.RWMutex
cached bool
cachedBuffer *buf.Buffer
cachedErr error
@ -36,8 +35,13 @@ func NewReader(reader TimeoutReader) *Reader {
}
func (r *Reader) Read(p []byte) (n int, err error) {
if r.disablePipe.Load() || r.deadline.IsZero() {
if r.disablePipe.Load() {
return r.ExtendedReader.Read(p)
} else if r.deadline.IsZero() {
r.inRead.Store(true)
defer r.inRead.Store(false)
n, err = r.ExtendedReader.Read(p)
return
}
r.cacheAccess.Lock()
if r.cached {
@ -53,29 +57,35 @@ func (r *Reader) Read(p []byte) (n int, err error) {
}
r.cacheAccess.Unlock()
done := make(chan struct{})
var access sync.Mutex
var cancel bool
go func() {
n, err = r.pipeRead(p, r.pipeDeadline.wait())
close(done)
n, err = r.pipeRead(p, &access, &cancel, done)
}()
select {
case <-done:
return
case <-r.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
access.Lock()
defer access.Unlock()
select {
case <-done:
return
default:
}
cancel = true
return 0, os.ErrDeadlineExceeded
}
func (r *Reader) pipeRead(p []byte, cancel chan struct{}) (n int, err error) {
func (r *Reader) pipeRead(p []byte, access *sync.Mutex, cancel *bool, done chan struct{}) (n int, err error) {
r.cacheAccess.Lock()
r.inRead.Store(true)
defer func() {
r.inRead.Store(false)
r.cacheAccess.Unlock()
}()
defer r.cacheAccess.Unlock()
buffer := buf.NewSize(len(p))
n, err = buffer.ReadOnceFrom(r.ExtendedReader)
if isClosedChan(cancel) {
access.Lock()
defer access.Unlock()
if *cancel {
r.cached = true
r.cachedBuffer = buffer
r.cachedErr = err
@ -83,11 +93,16 @@ func (r *Reader) pipeRead(p []byte, cancel chan struct{}) (n int, err error) {
n = copy(p, buffer.Bytes())
buffer.Release()
}
close(done)
return
}
func (r *Reader) ReadBuffer(buffer *buf.Buffer) error {
if r.disablePipe.Load() || r.deadline.IsZero() {
if r.disablePipe.Load() {
return r.ExtendedReader.ReadBuffer(buffer)
} else if r.deadline.IsZero() {
r.inRead.Store(true)
defer r.inRead.Store(false)
return r.ExtendedReader.ReadBuffer(buffer)
}
r.cacheAccess.Lock()
@ -105,51 +120,56 @@ func (r *Reader) ReadBuffer(buffer *buf.Buffer) error {
}
r.cacheAccess.Unlock()
done := make(chan struct{})
var access sync.Mutex
var cancel bool
var err error
go func() {
err = r.pipeReadBuffer(buffer, r.pipeDeadline.wait())
close(done)
err = r.pipeReadBuffer(buffer, &access, &cancel, done)
}()
select {
case <-done:
return err
case <-r.pipeDeadline.wait():
return os.ErrDeadlineExceeded
}
access.Lock()
defer access.Unlock()
select {
case <-done:
return err
default:
}
cancel = true
return os.ErrDeadlineExceeded
}
func (r *Reader) pipeReadBuffer(buffer *buf.Buffer, cancel chan struct{}) error {
func (r *Reader) pipeReadBuffer(buffer *buf.Buffer, access *sync.Mutex, cancel *bool, done chan struct{}) error {
r.cacheAccess.Lock()
r.inRead.Store(true)
defer func() {
r.inRead.Store(false)
r.cacheAccess.Unlock()
}()
defer r.cacheAccess.Unlock()
cacheBuffer := buf.NewSize(buffer.FreeLen())
err := r.ExtendedReader.ReadBuffer(cacheBuffer)
if isClosedChan(cancel) {
access.Lock()
defer access.Unlock()
if *cancel {
r.cached = true
r.cachedBuffer = cacheBuffer
r.cachedErr = err
} else {
common.Must1(buffer.ReadOnceFrom(cacheBuffer))
buffer.ReadOnceFrom(cacheBuffer)
cacheBuffer.Release()
}
close(done)
return err
}
func (r *Reader) SetReadDeadline(t time.Time) error {
if r.disablePipe.Load() {
return r.timeoutReader.SetReadDeadline(t)
} else if r.inRead.Load() {
r.disablePipe.Store(true)
return r.timeoutReader.SetReadDeadline(t)
}
r.deadline = t
r.pipeDeadline.set(t)
if r.disablePipe.Load() || !r.inRead.Load() {
err := r.timeoutReader.SetReadDeadline(t)
if err == os.ErrInvalid {
return nil
} else {
r.disablePipe.Store(true)
}
return err
}
return nil
}

View file

@ -11,4 +11,4 @@ import (
func ReadRequest(b *bufio.Reader) (req *http.Request, err error)
//go:linkname URLSetPath net/url.(*URL).setPath
func URLSetPath(u *url.URL, p string) error
func URLSetPath(u *url.URL, p string) error