Fix deadline

This commit is contained in:
世界 2023-04-13 07:40:08 +08:00
parent d88db59703
commit f8049ca89b
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 115 additions and 72 deletions

View file

@ -34,10 +34,6 @@ func (c *PacketConn) ReaderReplaceable() bool {
return c.reader.ReaderReplaceable() return c.reader.ReaderReplaceable()
} }
func (c *PacketConn) UpstreamReader() any {
return c.reader.UpstreamReader()
}
func (c *PacketConn) WriterReplaceable() bool { func (c *PacketConn) WriterReplaceable() bool {
return true return true
} }

View file

@ -6,7 +6,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -14,17 +13,17 @@ import (
) )
type TimeoutPacketReader interface { type TimeoutPacketReader interface {
N.NetPacketConn N.NetPacketReader
SetReadDeadline(t time.Time) error SetReadDeadline(t time.Time) error
} }
type PacketReader struct { type PacketReader struct {
TimeoutPacketReader TimeoutPacketReader
deadline time.Time deadline time.Time
disablePipe atomic.Bool
pipeDeadline pipeDeadline pipeDeadline pipeDeadline
cacheAccess sync.RWMutex disablePipe atomic.Bool
inRead atomic.Bool inRead atomic.Bool
cacheAccess sync.RWMutex
cached bool cached bool
cachedBuffer *buf.Buffer cachedBuffer *buf.Buffer
cachedAddr M.Socksaddr cachedAddr M.Socksaddr
@ -36,8 +35,13 @@ func NewPacketReader(reader TimeoutPacketReader) *PacketReader {
} }
func (r *PacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (r *PacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if r.disablePipe.Load() || r.deadline.IsZero() { if r.disablePipe.Load() {
return r.TimeoutPacketReader.ReadFrom(p) return r.TimeoutPacketReader.ReadFrom(p)
} else if r.deadline.IsZero() {
r.inRead.Store(true)
defer r.inRead.Store(false)
n, addr, err = r.TimeoutPacketReader.ReadFrom(p)
return
} }
r.cacheAccess.Lock() r.cacheAccess.Lock()
if r.cached { if r.cached {
@ -51,28 +55,35 @@ func (r *PacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
} }
r.cacheAccess.Unlock() r.cacheAccess.Unlock()
done := make(chan struct{}) done := make(chan struct{})
var access sync.Mutex
var cancel bool
go func() { go func() {
n, addr, err = r.pipeReadFrom(p, r.pipeDeadline.wait()) n, addr, err = r.pipeReadFrom(p, &access, &cancel, done)
close(done)
}() }()
select { select {
case <-done: case <-done:
return return
case <-r.pipeDeadline.wait(): case <-r.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
} }
access.Lock()
defer access.Unlock()
select {
case <-done:
return
default:
}
cancel = true
return 0, nil, os.ErrDeadlineExceeded
} }
func (r *PacketReader) pipeReadFrom(p []byte, cancel chan struct{}) (n int, addr net.Addr, err error) { func (r *PacketReader) pipeReadFrom(p []byte, access *sync.Mutex, cancel *bool, done chan struct{}) (n int, addr net.Addr, err error) {
r.cacheAccess.Lock() r.cacheAccess.Lock()
r.inRead.Store(true) defer r.cacheAccess.Unlock()
defer func() {
r.inRead.Store(false)
r.cacheAccess.Unlock()
}()
cacheBuffer := buf.NewSize(len(p)) cacheBuffer := buf.NewSize(len(p))
n, addr, err = r.TimeoutPacketReader.ReadFrom(cacheBuffer.Bytes()) n, addr, err = r.TimeoutPacketReader.ReadFrom(cacheBuffer.Bytes())
if isClosedChan(cancel) { access.Lock()
defer access.Unlock()
if *cancel {
r.cached = true r.cached = true
r.cachedBuffer = cacheBuffer r.cachedBuffer = cacheBuffer
r.cachedAddr = M.SocksaddrFromNet(addr) r.cachedAddr = M.SocksaddrFromNet(addr)
@ -81,12 +92,18 @@ func (r *PacketReader) pipeReadFrom(p []byte, cancel chan struct{}) (n int, addr
copy(p, cacheBuffer.Bytes()) copy(p, cacheBuffer.Bytes())
cacheBuffer.Release() cacheBuffer.Release()
} }
close(done)
return return
} }
func (r *PacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { func (r *PacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if r.disablePipe.Load() || r.deadline.IsZero() { if r.disablePipe.Load() {
return r.TimeoutPacketReader.ReadPacket(buffer) return r.TimeoutPacketReader.ReadPacket(buffer)
} else if r.deadline.IsZero() {
r.inRead.Store(true)
defer r.inRead.Store(false)
destination, err = r.TimeoutPacketReader.ReadPacket(buffer)
return
} }
r.cacheAccess.Lock() r.cacheAccess.Lock()
if r.cached { if r.cached {
@ -100,51 +117,61 @@ func (r *PacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr,
} }
r.cacheAccess.Unlock() r.cacheAccess.Unlock()
done := make(chan struct{}) done := make(chan struct{})
var access sync.Mutex
var cancel bool
go func() { go func() {
destination, err = r.pipeReadPacket(buffer, r.pipeDeadline.wait()) destination, err = r.pipeReadPacket(buffer, &access, &cancel, done)
close(done)
}() }()
select { select {
case <-done: case <-done:
return return
case <-r.pipeDeadline.wait(): case <-r.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
} }
access.Lock()
defer access.Unlock()
select {
case <-done:
return
default:
}
cancel = true
return M.Socksaddr{}, os.ErrDeadlineExceeded
} }
func (r *PacketReader) pipeReadPacket(buffer *buf.Buffer, cancel chan struct{}) (destination M.Socksaddr, err error) { func (r *PacketReader) pipeReadPacket(buffer *buf.Buffer, access *sync.Mutex, cancel *bool, done chan struct{}) (destination M.Socksaddr, err error) {
r.cacheAccess.Lock() r.cacheAccess.Lock()
r.inRead.Store(true) defer r.cacheAccess.Unlock()
defer func() {
r.inRead.Store(false)
r.cacheAccess.Unlock()
}()
cacheBuffer := buf.NewSize(buffer.FreeLen()) cacheBuffer := buf.NewSize(buffer.FreeLen())
destination, err = r.TimeoutPacketReader.ReadPacket(cacheBuffer) destination, err = r.TimeoutPacketReader.ReadPacket(cacheBuffer)
if isClosedChan(cancel) { access.Lock()
defer access.Unlock()
if *cancel {
r.cached = true r.cached = true
r.cachedBuffer = cacheBuffer r.cachedBuffer = cacheBuffer
r.cachedAddr = destination r.cachedAddr = destination
r.cachedErr = err r.cachedErr = err
} else { } else {
common.Must1(buffer.ReadOnceFrom(cacheBuffer)) buffer.ReadOnceFrom(cacheBuffer)
cacheBuffer.Release() cacheBuffer.Release()
} }
close(done)
return return
} }
func (r *PacketReader) SetReadDeadline(t time.Time) error { func (r *PacketReader) SetReadDeadline(t time.Time) error {
r.deadline = t if r.disablePipe.Load() {
r.pipeDeadline.set(t) return r.TimeoutPacketReader.SetReadDeadline(t)
if r.disablePipe.Load() || !r.inRead.Load() { } else if r.inRead.Load() {
r.disablePipe.Store(true) r.disablePipe.Store(true)
return r.TimeoutPacketReader.SetReadDeadline(t) return r.TimeoutPacketReader.SetReadDeadline(t)
} }
r.deadline = t
r.pipeDeadline.set(t)
return nil return nil
} }
func (r *PacketReader) ReaderReplaceable() bool { func (r *PacketReader) ReaderReplaceable() bool {
return r.deadline.IsZero() return r.disablePipe.Load() || r.deadline.IsZero()
} }
func (r *PacketReader) UpstreamReader() any { func (r *PacketReader) UpstreamReader() any {

View file

@ -6,7 +6,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
@ -22,10 +21,10 @@ type Reader struct {
N.ExtendedReader N.ExtendedReader
timeoutReader TimeoutReader timeoutReader TimeoutReader
deadline time.Time deadline time.Time
disablePipe atomic.Bool
pipeDeadline pipeDeadline pipeDeadline pipeDeadline
cacheAccess sync.RWMutex disablePipe atomic.Bool
inRead atomic.Bool inRead atomic.Bool
cacheAccess sync.RWMutex
cached bool cached bool
cachedBuffer *buf.Buffer cachedBuffer *buf.Buffer
cachedErr error cachedErr error
@ -36,8 +35,13 @@ func NewReader(reader TimeoutReader) *Reader {
} }
func (r *Reader) Read(p []byte) (n int, err error) { func (r *Reader) Read(p []byte) (n int, err error) {
if r.disablePipe.Load() || r.deadline.IsZero() { if r.disablePipe.Load() {
return r.ExtendedReader.Read(p) return r.ExtendedReader.Read(p)
} else if r.deadline.IsZero() {
r.inRead.Store(true)
defer r.inRead.Store(false)
n, err = r.ExtendedReader.Read(p)
return
} }
r.cacheAccess.Lock() r.cacheAccess.Lock()
if r.cached { if r.cached {
@ -53,29 +57,35 @@ func (r *Reader) Read(p []byte) (n int, err error) {
} }
r.cacheAccess.Unlock() r.cacheAccess.Unlock()
done := make(chan struct{}) done := make(chan struct{})
var access sync.Mutex
var cancel bool
go func() { go func() {
n, err = r.pipeRead(p, r.pipeDeadline.wait()) n, err = r.pipeRead(p, &access, &cancel, done)
close(done)
}() }()
select { select {
case <-done: case <-done:
return return
case <-r.pipeDeadline.wait(): case <-r.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded
} }
access.Lock()
defer access.Unlock()
select {
case <-done:
return
default:
}
cancel = true
return 0, os.ErrDeadlineExceeded
} }
func (r *Reader) pipeRead(p []byte, cancel chan struct{}) (n int, err error) { func (r *Reader) pipeRead(p []byte, access *sync.Mutex, cancel *bool, done chan struct{}) (n int, err error) {
r.cacheAccess.Lock() r.cacheAccess.Lock()
r.inRead.Store(true) defer r.cacheAccess.Unlock()
defer func() {
r.inRead.Store(false)
r.cacheAccess.Unlock()
}()
buffer := buf.NewSize(len(p)) buffer := buf.NewSize(len(p))
n, err = buffer.ReadOnceFrom(r.ExtendedReader) n, err = buffer.ReadOnceFrom(r.ExtendedReader)
if isClosedChan(cancel) { access.Lock()
defer access.Unlock()
if *cancel {
r.cached = true r.cached = true
r.cachedBuffer = buffer r.cachedBuffer = buffer
r.cachedErr = err r.cachedErr = err
@ -83,11 +93,16 @@ func (r *Reader) pipeRead(p []byte, cancel chan struct{}) (n int, err error) {
n = copy(p, buffer.Bytes()) n = copy(p, buffer.Bytes())
buffer.Release() buffer.Release()
} }
close(done)
return return
} }
func (r *Reader) ReadBuffer(buffer *buf.Buffer) error { func (r *Reader) ReadBuffer(buffer *buf.Buffer) error {
if r.disablePipe.Load() || r.deadline.IsZero() { if r.disablePipe.Load() {
return r.ExtendedReader.ReadBuffer(buffer)
} else if r.deadline.IsZero() {
r.inRead.Store(true)
defer r.inRead.Store(false)
return r.ExtendedReader.ReadBuffer(buffer) return r.ExtendedReader.ReadBuffer(buffer)
} }
r.cacheAccess.Lock() r.cacheAccess.Lock()
@ -105,51 +120,56 @@ func (r *Reader) ReadBuffer(buffer *buf.Buffer) error {
} }
r.cacheAccess.Unlock() r.cacheAccess.Unlock()
done := make(chan struct{}) done := make(chan struct{})
var access sync.Mutex
var cancel bool
var err error var err error
go func() { go func() {
err = r.pipeReadBuffer(buffer, r.pipeDeadline.wait()) err = r.pipeReadBuffer(buffer, &access, &cancel, done)
close(done)
}() }()
select { select {
case <-done: case <-done:
return err return err
case <-r.pipeDeadline.wait(): case <-r.pipeDeadline.wait():
return os.ErrDeadlineExceeded
} }
access.Lock()
defer access.Unlock()
select {
case <-done:
return err
default:
}
cancel = true
return os.ErrDeadlineExceeded
} }
func (r *Reader) pipeReadBuffer(buffer *buf.Buffer, cancel chan struct{}) error { func (r *Reader) pipeReadBuffer(buffer *buf.Buffer, access *sync.Mutex, cancel *bool, done chan struct{}) error {
r.cacheAccess.Lock() r.cacheAccess.Lock()
r.inRead.Store(true) defer r.cacheAccess.Unlock()
defer func() {
r.inRead.Store(false)
r.cacheAccess.Unlock()
}()
cacheBuffer := buf.NewSize(buffer.FreeLen()) cacheBuffer := buf.NewSize(buffer.FreeLen())
err := r.ExtendedReader.ReadBuffer(cacheBuffer) err := r.ExtendedReader.ReadBuffer(cacheBuffer)
if isClosedChan(cancel) { access.Lock()
defer access.Unlock()
if *cancel {
r.cached = true r.cached = true
r.cachedBuffer = cacheBuffer r.cachedBuffer = cacheBuffer
r.cachedErr = err r.cachedErr = err
} else { } else {
common.Must1(buffer.ReadOnceFrom(cacheBuffer)) buffer.ReadOnceFrom(cacheBuffer)
cacheBuffer.Release() cacheBuffer.Release()
} }
close(done)
return err return err
} }
func (r *Reader) SetReadDeadline(t time.Time) error { func (r *Reader) SetReadDeadline(t time.Time) error {
if r.disablePipe.Load() {
return r.timeoutReader.SetReadDeadline(t)
} else if r.inRead.Load() {
r.disablePipe.Store(true)
return r.timeoutReader.SetReadDeadline(t)
}
r.deadline = t r.deadline = t
r.pipeDeadline.set(t) r.pipeDeadline.set(t)
if r.disablePipe.Load() || !r.inRead.Load() {
err := r.timeoutReader.SetReadDeadline(t)
if err == os.ErrInvalid {
return nil
} else {
r.disablePipe.Store(true)
}
return err
}
return nil return nil
} }

View file

@ -11,4 +11,4 @@ import (
func ReadRequest(b *bufio.Reader) (req *http.Request, err error) func ReadRequest(b *bufio.Reader) (req *http.Request, err error)
//go:linkname URLSetPath net/url.(*URL).setPath //go:linkname URLSetPath net/url.(*URL).setPath
func URLSetPath(u *url.URL, p string) error func URLSetPath(u *url.URL, p string) error