mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-06 05:17:38 +03:00
Fix deadline reader
This commit is contained in:
parent
2dcabf4bfc
commit
bca74039ea
7 changed files with 153 additions and 77 deletions
|
@ -14,18 +14,18 @@ type Conn struct {
|
|||
reader Reader
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn) *Conn {
|
||||
func NewConn(conn net.Conn) N.ExtendedConn {
|
||||
if deadlineConn, isDeadline := conn.(*Conn); isDeadline {
|
||||
return deadlineConn
|
||||
}
|
||||
return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)}
|
||||
return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)})
|
||||
}
|
||||
|
||||
func NewFallbackConn(conn net.Conn) *Conn {
|
||||
func NewFallbackConn(conn net.Conn) N.ExtendedConn {
|
||||
if deadlineConn, isDeadline := conn.(*Conn); isDeadline {
|
||||
return deadlineConn
|
||||
}
|
||||
return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)}
|
||||
return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)})
|
||||
}
|
||||
|
||||
func (c *Conn) Read(p []byte) (n int, err error) {
|
||||
|
|
|
@ -14,18 +14,18 @@ type PacketConn struct {
|
|||
reader PacketReader
|
||||
}
|
||||
|
||||
func NewPacketConn(conn N.NetPacketConn) *PacketConn {
|
||||
func NewPacketConn(conn N.NetPacketConn) N.NetPacketConn {
|
||||
if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline {
|
||||
return deadlineConn
|
||||
}
|
||||
return &PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)}
|
||||
return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)})
|
||||
}
|
||||
|
||||
func NewFallbackPacketConn(conn N.NetPacketConn) *PacketConn {
|
||||
func NewFallbackPacketConn(conn N.NetPacketConn) N.NetPacketConn {
|
||||
if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline {
|
||||
return deadlineConn
|
||||
}
|
||||
return &PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)}
|
||||
return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)})
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
|
|
|
@ -52,14 +52,13 @@ func (r *packetReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, nil, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
go r.pipeReadFrom(len(p))
|
||||
default:
|
||||
}
|
||||
return r.readFrom(p)
|
||||
}
|
||||
|
||||
func (r *packetReader) readFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
|
@ -106,14 +105,13 @@ func (r *packetReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr,
|
|||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
go r.pipeReadFromBuffer(buffer.FreeLen())
|
||||
default:
|
||||
go r.pipeReadFrom(buffer.FreeLen())
|
||||
}
|
||||
return r.readPacket(buffer)
|
||||
}
|
||||
|
||||
func (r *packetReader) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
|
@ -134,17 +132,6 @@ func (r *packetReader) pipeReturnFromBuffer(result *packetReadResult, buffer *bu
|
|||
}
|
||||
}
|
||||
|
||||
func (r *packetReader) pipeReadFromBuffer(pLen int) {
|
||||
buffer := buf.NewSize(pLen)
|
||||
destination, err := r.TimeoutPacketReader.ReadPacket(buffer)
|
||||
r.result <- &packetReadResult{
|
||||
buffer: buffer,
|
||||
destination: destination,
|
||||
err: err,
|
||||
}
|
||||
r.done <- struct{}{}
|
||||
}
|
||||
|
||||
func (r *packetReader) SetReadDeadline(t time.Time) error {
|
||||
r.deadline.Store(t)
|
||||
r.pipeDeadline.set(t)
|
||||
|
|
|
@ -2,6 +2,7 @@ package deadline
|
|||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
|
@ -25,12 +26,15 @@ func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err err
|
|||
return r.pipeReturnFrom(result, p)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, nil, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
if r.disablePipe.Load() {
|
||||
return r.TimeoutPacketReader.ReadFrom(p)
|
||||
}
|
||||
select {
|
||||
case <-r.done:
|
||||
if r.deadline.Load().IsZero() {
|
||||
} else if r.deadline.Load().IsZero() {
|
||||
r.done <- struct{}{}
|
||||
r.inRead.Store(true)
|
||||
defer r.inRead.Store(false)
|
||||
|
@ -38,9 +42,13 @@ func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err err
|
|||
return
|
||||
}
|
||||
go r.pipeReadFrom(len(p))
|
||||
default:
|
||||
}
|
||||
return r.readFrom(p)
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFrom(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, nil, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
|
@ -49,22 +57,29 @@ func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Soc
|
|||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
if r.disablePipe.Load() {
|
||||
return r.TimeoutPacketReader.ReadPacket(buffer)
|
||||
}
|
||||
select {
|
||||
case <-r.done:
|
||||
if r.deadline.Load().IsZero() {
|
||||
} else if r.deadline.Load().IsZero() {
|
||||
r.done <- struct{}{}
|
||||
r.inRead.Store(true)
|
||||
defer r.inRead.Store(false)
|
||||
destination, err = r.TimeoutPacketReader.ReadPacket(buffer)
|
||||
return
|
||||
}
|
||||
go r.pipeReadFromBuffer(buffer.FreeLen())
|
||||
default:
|
||||
go r.pipeReadFrom(buffer.FreeLen())
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnFromBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return M.Socksaddr{}, os.ErrDeadlineExceeded
|
||||
}
|
||||
return r.readPacket(buffer)
|
||||
}
|
||||
|
||||
func (r *fallbackPacketReader) SetReadDeadline(t time.Time) error {
|
||||
|
|
|
@ -54,14 +54,13 @@ func (r *reader) Read(p []byte) (n int, err error) {
|
|||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
go r.pipeRead(len(p))
|
||||
default:
|
||||
}
|
||||
return r.read(p)
|
||||
}
|
||||
|
||||
func (r *reader) read(p []byte) (n int, err error) {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
|
@ -99,14 +98,13 @@ func (r *reader) ReadBuffer(buffer *buf.Buffer) error {
|
|||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
go r.pipeReadBuffer(buffer.FreeLen())
|
||||
default:
|
||||
go r.pipeRead(buffer.FreeLen())
|
||||
}
|
||||
return r.readBuffer(buffer)
|
||||
}
|
||||
|
||||
func (r *reader) readBuffer(buffer *buf.Buffer) error {
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
|
@ -127,16 +125,6 @@ func (r *reader) pipeReturnBuffer(result *readResult, buffer *buf.Buffer) error
|
|||
}
|
||||
}
|
||||
|
||||
func (r *reader) pipeReadBuffer(pLen int) {
|
||||
cacheBuffer := buf.NewSize(pLen)
|
||||
err := r.ExtendedReader.ReadBuffer(cacheBuffer)
|
||||
r.result <- &readResult{
|
||||
buffer: cacheBuffer,
|
||||
err: err,
|
||||
}
|
||||
r.done <- struct{}{}
|
||||
}
|
||||
|
||||
func (r *reader) SetReadDeadline(t time.Time) error {
|
||||
r.deadline.Store(t)
|
||||
r.pipeDeadline.set(t)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
|
@ -23,12 +24,15 @@ func (r *fallbackReader) Read(p []byte) (n int, err error) {
|
|||
return r.pipeReturn(result, p)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
if r.disablePipe.Load() {
|
||||
return r.ExtendedReader.Read(p)
|
||||
}
|
||||
select {
|
||||
case <-r.done:
|
||||
if r.deadline.Load().IsZero() {
|
||||
} else if r.deadline.Load().IsZero() {
|
||||
r.done <- struct{}{}
|
||||
r.inRead.Store(true)
|
||||
defer r.inRead.Store(false)
|
||||
|
@ -36,9 +40,13 @@ func (r *fallbackReader) Read(p []byte) (n int, err error) {
|
|||
return
|
||||
}
|
||||
go r.pipeRead(len(p))
|
||||
default:
|
||||
}
|
||||
return r.reader.read(p)
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturn(result, p)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error {
|
||||
|
@ -47,21 +55,28 @@ func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error {
|
|||
return r.pipeReturnBuffer(result, buffer)
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return os.ErrDeadlineExceeded
|
||||
case <-r.done:
|
||||
if r.disablePipe.Load() {
|
||||
return r.ExtendedReader.ReadBuffer(buffer)
|
||||
}
|
||||
select {
|
||||
case <-r.done:
|
||||
if r.deadline.Load().IsZero() {
|
||||
} else if r.deadline.Load().IsZero() {
|
||||
r.done <- struct{}{}
|
||||
r.inRead.Store(true)
|
||||
defer r.inRead.Store(false)
|
||||
return r.ExtendedReader.ReadBuffer(buffer)
|
||||
}
|
||||
go r.pipeReadBuffer(buffer.FreeLen())
|
||||
default:
|
||||
go r.pipeRead(buffer.FreeLen())
|
||||
}
|
||||
select {
|
||||
case result := <-r.result:
|
||||
return r.pipeReturnBuffer(result, buffer)
|
||||
case <-r.pipeDeadline.wait():
|
||||
return os.ErrDeadlineExceeded
|
||||
}
|
||||
return r.readBuffer(buffer)
|
||||
}
|
||||
|
||||
func (r *fallbackReader) SetReadDeadline(t time.Time) error {
|
||||
|
|
71
common/bufio/deadline/serial.go
Normal file
71
common/bufio/deadline/serial.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package deadline
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/debug"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type SerialConn struct {
|
||||
N.ExtendedConn
|
||||
access sync.Mutex
|
||||
}
|
||||
|
||||
func NewSerialConn(conn N.ExtendedConn) N.ExtendedConn {
|
||||
if !debug.Enabled {
|
||||
return conn
|
||||
}
|
||||
return &SerialConn{ExtendedConn: conn}
|
||||
}
|
||||
|
||||
func (c *SerialConn) Read(p []byte) (n int, err error) {
|
||||
if !c.access.TryLock() {
|
||||
panic("concurrent read on deadline conn")
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.ExtendedConn.Read(p)
|
||||
}
|
||||
|
||||
func (c *SerialConn) ReadBuffer(buffer *buf.Buffer) error {
|
||||
if !c.access.TryLock() {
|
||||
panic("concurrent read on deadline conn")
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.ExtendedConn.ReadBuffer(buffer)
|
||||
}
|
||||
|
||||
type SerialPacketConn struct {
|
||||
N.NetPacketConn
|
||||
access sync.Mutex
|
||||
}
|
||||
|
||||
func NewSerialPacketConn(conn N.NetPacketConn) N.NetPacketConn {
|
||||
if !debug.Enabled {
|
||||
return conn
|
||||
}
|
||||
return &SerialPacketConn{NetPacketConn: conn}
|
||||
}
|
||||
|
||||
func (c *SerialPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
if !c.access.TryLock() {
|
||||
panic("concurrent read on deadline conn")
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.NetPacketConn.ReadFrom(p)
|
||||
}
|
||||
|
||||
func (c *SerialPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
if !c.access.TryLock() {
|
||||
panic("concurrent read on deadline conn")
|
||||
}
|
||||
defer c.access.Unlock()
|
||||
return c.NetPacketConn.ReadPacket(buffer)
|
||||
}
|
||||
|
||||
func (c *SerialPacketConn) Upstream() any {
|
||||
return c.NetPacketConn
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue