From f8049ca89ba9a87149c8cfc185a51164715abf98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 13 Apr 2023 07:40:08 +0800 Subject: [PATCH] Fix deadline --- common/bufio/deadline/packet_conn.go | 4 -- common/bufio/deadline/packet_reader.go | 89 ++++++++++++++++--------- common/bufio/deadline/reader.go | 92 ++++++++++++++++---------- protocol/http/link.go | 2 +- 4 files changed, 115 insertions(+), 72 deletions(-) diff --git a/common/bufio/deadline/packet_conn.go b/common/bufio/deadline/packet_conn.go index 4c8faa2..2f7649f 100644 --- a/common/bufio/deadline/packet_conn.go +++ b/common/bufio/deadline/packet_conn.go @@ -34,10 +34,6 @@ func (c *PacketConn) ReaderReplaceable() bool { return c.reader.ReaderReplaceable() } -func (c *PacketConn) UpstreamReader() any { - return c.reader.UpstreamReader() -} - func (c *PacketConn) WriterReplaceable() bool { return true } diff --git a/common/bufio/deadline/packet_reader.go b/common/bufio/deadline/packet_reader.go index 27ede1f..5fd3f75 100644 --- a/common/bufio/deadline/packet_reader.go +++ b/common/bufio/deadline/packet_reader.go @@ -6,7 +6,6 @@ import ( "sync" "time" - "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" @@ -14,17 +13,17 @@ import ( ) type TimeoutPacketReader interface { - N.NetPacketConn + N.NetPacketReader SetReadDeadline(t time.Time) error } type PacketReader struct { TimeoutPacketReader deadline time.Time - disablePipe atomic.Bool pipeDeadline pipeDeadline - cacheAccess sync.RWMutex + disablePipe atomic.Bool inRead atomic.Bool + cacheAccess sync.RWMutex cached bool cachedBuffer *buf.Buffer cachedAddr M.Socksaddr @@ -36,8 +35,13 @@ func NewPacketReader(reader TimeoutPacketReader) *PacketReader { } func (r *PacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - if r.disablePipe.Load() || r.deadline.IsZero() { + if r.disablePipe.Load() { return r.TimeoutPacketReader.ReadFrom(p) + } else if r.deadline.IsZero() { + r.inRead.Store(true) + defer r.inRead.Store(false) + n, addr, err = r.TimeoutPacketReader.ReadFrom(p) + return } r.cacheAccess.Lock() if r.cached { @@ -51,28 +55,35 @@ func (r *PacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } r.cacheAccess.Unlock() done := make(chan struct{}) + var access sync.Mutex + var cancel bool go func() { - n, addr, err = r.pipeReadFrom(p, r.pipeDeadline.wait()) - close(done) + n, addr, err = r.pipeReadFrom(p, &access, &cancel, done) }() select { case <-done: return case <-r.pipeDeadline.wait(): - return 0, nil, os.ErrDeadlineExceeded } + access.Lock() + defer access.Unlock() + select { + case <-done: + return + default: + } + cancel = true + return 0, nil, os.ErrDeadlineExceeded } -func (r *PacketReader) pipeReadFrom(p []byte, cancel chan struct{}) (n int, addr net.Addr, err error) { +func (r *PacketReader) pipeReadFrom(p []byte, access *sync.Mutex, cancel *bool, done chan struct{}) (n int, addr net.Addr, err error) { r.cacheAccess.Lock() - r.inRead.Store(true) - defer func() { - r.inRead.Store(false) - r.cacheAccess.Unlock() - }() + defer r.cacheAccess.Unlock() cacheBuffer := buf.NewSize(len(p)) n, addr, err = r.TimeoutPacketReader.ReadFrom(cacheBuffer.Bytes()) - if isClosedChan(cancel) { + access.Lock() + defer access.Unlock() + if *cancel { r.cached = true r.cachedBuffer = cacheBuffer r.cachedAddr = M.SocksaddrFromNet(addr) @@ -81,12 +92,18 @@ func (r *PacketReader) pipeReadFrom(p []byte, cancel chan struct{}) (n int, addr copy(p, cacheBuffer.Bytes()) cacheBuffer.Release() } + close(done) return } func (r *PacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - if r.disablePipe.Load() || r.deadline.IsZero() { + if r.disablePipe.Load() { return r.TimeoutPacketReader.ReadPacket(buffer) + } else if r.deadline.IsZero() { + r.inRead.Store(true) + defer r.inRead.Store(false) + destination, err = r.TimeoutPacketReader.ReadPacket(buffer) + return } r.cacheAccess.Lock() if r.cached { @@ -100,51 +117,61 @@ func (r *PacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, } r.cacheAccess.Unlock() done := make(chan struct{}) + var access sync.Mutex + var cancel bool go func() { - destination, err = r.pipeReadPacket(buffer, r.pipeDeadline.wait()) - close(done) + destination, err = r.pipeReadPacket(buffer, &access, &cancel, done) }() select { case <-done: return case <-r.pipeDeadline.wait(): - return M.Socksaddr{}, os.ErrDeadlineExceeded } + access.Lock() + defer access.Unlock() + select { + case <-done: + return + default: + } + cancel = true + return M.Socksaddr{}, os.ErrDeadlineExceeded } -func (r *PacketReader) pipeReadPacket(buffer *buf.Buffer, cancel chan struct{}) (destination M.Socksaddr, err error) { +func (r *PacketReader) pipeReadPacket(buffer *buf.Buffer, access *sync.Mutex, cancel *bool, done chan struct{}) (destination M.Socksaddr, err error) { r.cacheAccess.Lock() - r.inRead.Store(true) - defer func() { - r.inRead.Store(false) - r.cacheAccess.Unlock() - }() + defer r.cacheAccess.Unlock() cacheBuffer := buf.NewSize(buffer.FreeLen()) destination, err = r.TimeoutPacketReader.ReadPacket(cacheBuffer) - if isClosedChan(cancel) { + access.Lock() + defer access.Unlock() + if *cancel { r.cached = true r.cachedBuffer = cacheBuffer r.cachedAddr = destination r.cachedErr = err } else { - common.Must1(buffer.ReadOnceFrom(cacheBuffer)) + buffer.ReadOnceFrom(cacheBuffer) cacheBuffer.Release() } + close(done) return } func (r *PacketReader) SetReadDeadline(t time.Time) error { - r.deadline = t - r.pipeDeadline.set(t) - if r.disablePipe.Load() || !r.inRead.Load() { + if r.disablePipe.Load() { + return r.TimeoutPacketReader.SetReadDeadline(t) + } else if r.inRead.Load() { r.disablePipe.Store(true) return r.TimeoutPacketReader.SetReadDeadline(t) } + r.deadline = t + r.pipeDeadline.set(t) return nil } func (r *PacketReader) ReaderReplaceable() bool { - return r.deadline.IsZero() + return r.disablePipe.Load() || r.deadline.IsZero() } func (r *PacketReader) UpstreamReader() any { diff --git a/common/bufio/deadline/reader.go b/common/bufio/deadline/reader.go index fed7232..68f58a3 100644 --- a/common/bufio/deadline/reader.go +++ b/common/bufio/deadline/reader.go @@ -6,7 +6,6 @@ import ( "sync" "time" - "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" @@ -22,10 +21,10 @@ type Reader struct { N.ExtendedReader timeoutReader TimeoutReader deadline time.Time - disablePipe atomic.Bool pipeDeadline pipeDeadline - cacheAccess sync.RWMutex + disablePipe atomic.Bool inRead atomic.Bool + cacheAccess sync.RWMutex cached bool cachedBuffer *buf.Buffer cachedErr error @@ -36,8 +35,13 @@ func NewReader(reader TimeoutReader) *Reader { } func (r *Reader) Read(p []byte) (n int, err error) { - if r.disablePipe.Load() || r.deadline.IsZero() { + if r.disablePipe.Load() { return r.ExtendedReader.Read(p) + } else if r.deadline.IsZero() { + r.inRead.Store(true) + defer r.inRead.Store(false) + n, err = r.ExtendedReader.Read(p) + return } r.cacheAccess.Lock() if r.cached { @@ -53,29 +57,35 @@ func (r *Reader) Read(p []byte) (n int, err error) { } r.cacheAccess.Unlock() done := make(chan struct{}) + var access sync.Mutex + var cancel bool go func() { - n, err = r.pipeRead(p, r.pipeDeadline.wait()) - close(done) + n, err = r.pipeRead(p, &access, &cancel, done) }() select { case <-done: return case <-r.pipeDeadline.wait(): - return 0, os.ErrDeadlineExceeded } + access.Lock() + defer access.Unlock() + select { + case <-done: + return + default: + } + cancel = true + return 0, os.ErrDeadlineExceeded } -func (r *Reader) pipeRead(p []byte, cancel chan struct{}) (n int, err error) { +func (r *Reader) pipeRead(p []byte, access *sync.Mutex, cancel *bool, done chan struct{}) (n int, err error) { r.cacheAccess.Lock() - r.inRead.Store(true) - defer func() { - r.inRead.Store(false) - r.cacheAccess.Unlock() - }() - + defer r.cacheAccess.Unlock() buffer := buf.NewSize(len(p)) n, err = buffer.ReadOnceFrom(r.ExtendedReader) - if isClosedChan(cancel) { + access.Lock() + defer access.Unlock() + if *cancel { r.cached = true r.cachedBuffer = buffer r.cachedErr = err @@ -83,11 +93,16 @@ func (r *Reader) pipeRead(p []byte, cancel chan struct{}) (n int, err error) { n = copy(p, buffer.Bytes()) buffer.Release() } + close(done) return } func (r *Reader) ReadBuffer(buffer *buf.Buffer) error { - if r.disablePipe.Load() || r.deadline.IsZero() { + if r.disablePipe.Load() { + return r.ExtendedReader.ReadBuffer(buffer) + } else if r.deadline.IsZero() { + r.inRead.Store(true) + defer r.inRead.Store(false) return r.ExtendedReader.ReadBuffer(buffer) } r.cacheAccess.Lock() @@ -105,51 +120,56 @@ func (r *Reader) ReadBuffer(buffer *buf.Buffer) error { } r.cacheAccess.Unlock() done := make(chan struct{}) + var access sync.Mutex + var cancel bool var err error go func() { - err = r.pipeReadBuffer(buffer, r.pipeDeadline.wait()) - close(done) + err = r.pipeReadBuffer(buffer, &access, &cancel, done) }() select { case <-done: return err case <-r.pipeDeadline.wait(): - return os.ErrDeadlineExceeded } + access.Lock() + defer access.Unlock() + select { + case <-done: + return err + default: + } + cancel = true + return os.ErrDeadlineExceeded } -func (r *Reader) pipeReadBuffer(buffer *buf.Buffer, cancel chan struct{}) error { +func (r *Reader) pipeReadBuffer(buffer *buf.Buffer, access *sync.Mutex, cancel *bool, done chan struct{}) error { r.cacheAccess.Lock() - r.inRead.Store(true) - defer func() { - r.inRead.Store(false) - r.cacheAccess.Unlock() - }() + defer r.cacheAccess.Unlock() cacheBuffer := buf.NewSize(buffer.FreeLen()) err := r.ExtendedReader.ReadBuffer(cacheBuffer) - if isClosedChan(cancel) { + access.Lock() + defer access.Unlock() + if *cancel { r.cached = true r.cachedBuffer = cacheBuffer r.cachedErr = err } else { - common.Must1(buffer.ReadOnceFrom(cacheBuffer)) + buffer.ReadOnceFrom(cacheBuffer) cacheBuffer.Release() } + close(done) return err } func (r *Reader) SetReadDeadline(t time.Time) error { + if r.disablePipe.Load() { + return r.timeoutReader.SetReadDeadline(t) + } else if r.inRead.Load() { + r.disablePipe.Store(true) + return r.timeoutReader.SetReadDeadline(t) + } r.deadline = t r.pipeDeadline.set(t) - if r.disablePipe.Load() || !r.inRead.Load() { - err := r.timeoutReader.SetReadDeadline(t) - if err == os.ErrInvalid { - return nil - } else { - r.disablePipe.Store(true) - } - return err - } return nil } diff --git a/protocol/http/link.go b/protocol/http/link.go index ed8804c..19cb6cf 100644 --- a/protocol/http/link.go +++ b/protocol/http/link.go @@ -11,4 +11,4 @@ import ( func ReadRequest(b *bufio.Reader) (req *http.Request, err error) //go:linkname URLSetPath net/url.(*URL).setPath -func URLSetPath(u *url.URL, p string) error \ No newline at end of file +func URLSetPath(u *url.URL, p string) error