Fix deadline reader

This commit is contained in:
世界 2023-12-01 12:21:23 +08:00
parent 2dcabf4bfc
commit bca74039ea
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
7 changed files with 153 additions and 77 deletions

View file

@ -14,18 +14,18 @@ type Conn struct {
reader Reader
}
func NewConn(conn net.Conn) *Conn {
func NewConn(conn net.Conn) N.ExtendedConn {
if deadlineConn, isDeadline := conn.(*Conn); isDeadline {
return deadlineConn
}
return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)}
return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewReader(conn)})
}
func NewFallbackConn(conn net.Conn) *Conn {
func NewFallbackConn(conn net.Conn) N.ExtendedConn {
if deadlineConn, isDeadline := conn.(*Conn); isDeadline {
return deadlineConn
}
return &Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)}
return NewSerialConn(&Conn{ExtendedConn: bufio.NewExtendedConn(conn), reader: NewFallbackReader(conn)})
}
func (c *Conn) Read(p []byte) (n int, err error) {

View file

@ -14,18 +14,18 @@ type PacketConn struct {
reader PacketReader
}
func NewPacketConn(conn N.NetPacketConn) *PacketConn {
func NewPacketConn(conn N.NetPacketConn) N.NetPacketConn {
if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline {
return deadlineConn
}
return &PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)}
return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewPacketReader(conn)})
}
func NewFallbackPacketConn(conn N.NetPacketConn) *PacketConn {
func NewFallbackPacketConn(conn N.NetPacketConn) N.NetPacketConn {
if deadlineConn, isDeadline := conn.(*PacketConn); isDeadline {
return deadlineConn
}
return &PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)}
return NewSerialPacketConn(&PacketConn{NetPacketConn: conn, reader: NewFallbackPacketReader(conn)})
}
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {

View file

@ -52,14 +52,13 @@ func (r *packetReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
default:
}
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
case <-r.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
case <-r.done:
go r.pipeReadFrom(len(p))
default:
}
return r.readFrom(p)
}
func (r *packetReader) readFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
@ -106,14 +105,13 @@ func (r *packetReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr,
default:
}
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
case <-r.done:
go r.pipeReadFromBuffer(buffer.FreeLen())
default:
go r.pipeReadFrom(buffer.FreeLen())
}
return r.readPacket(buffer)
}
func (r *packetReader) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
@ -134,17 +132,6 @@ func (r *packetReader) pipeReturnFromBuffer(result *packetReadResult, buffer *bu
}
}
func (r *packetReader) pipeReadFromBuffer(pLen int) {
buffer := buf.NewSize(pLen)
destination, err := r.TimeoutPacketReader.ReadPacket(buffer)
r.result <- &packetReadResult{
buffer: buffer,
destination: destination,
err: err,
}
r.done <- struct{}{}
}
func (r *packetReader) SetReadDeadline(t time.Time) error {
r.deadline.Store(t)
r.pipeDeadline.set(t)

View file

@ -2,6 +2,7 @@ package deadline
import (
"net"
"os"
"time"
"github.com/sagernet/sing/common/atomic"
@ -25,12 +26,15 @@ func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err err
return r.pipeReturnFrom(result, p)
default:
}
if r.disablePipe.Load() {
return r.TimeoutPacketReader.ReadFrom(p)
}
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
case <-r.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
case <-r.done:
if r.deadline.Load().IsZero() {
if r.disablePipe.Load() {
return r.TimeoutPacketReader.ReadFrom(p)
} else if r.deadline.Load().IsZero() {
r.done <- struct{}{}
r.inRead.Store(true)
defer r.inRead.Store(false)
@ -38,9 +42,13 @@ func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err err
return
}
go r.pipeReadFrom(len(p))
default:
}
return r.readFrom(p)
select {
case result := <-r.result:
return r.pipeReturnFrom(result, p)
case <-r.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
}
}
func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
@ -49,22 +57,29 @@ func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Soc
return r.pipeReturnFromBuffer(result, buffer)
default:
}
if r.disablePipe.Load() {
return r.TimeoutPacketReader.ReadPacket(buffer)
}
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
case <-r.done:
if r.deadline.Load().IsZero() {
if r.disablePipe.Load() {
return r.TimeoutPacketReader.ReadPacket(buffer)
} else if r.deadline.Load().IsZero() {
r.done <- struct{}{}
r.inRead.Store(true)
defer r.inRead.Store(false)
destination, err = r.TimeoutPacketReader.ReadPacket(buffer)
return
}
go r.pipeReadFromBuffer(buffer.FreeLen())
default:
go r.pipeReadFrom(buffer.FreeLen())
}
select {
case result := <-r.result:
return r.pipeReturnFromBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
return r.readPacket(buffer)
}
func (r *fallbackPacketReader) SetReadDeadline(t time.Time) error {

View file

@ -54,14 +54,13 @@ func (r *reader) Read(p []byte) (n int, err error) {
default:
}
select {
case result := <-r.result:
return r.pipeReturn(result, p)
case <-r.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded
case <-r.done:
go r.pipeRead(len(p))
default:
}
return r.read(p)
}
func (r *reader) read(p []byte) (n int, err error) {
select {
case result := <-r.result:
return r.pipeReturn(result, p)
@ -99,14 +98,13 @@ func (r *reader) ReadBuffer(buffer *buf.Buffer) error {
default:
}
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return os.ErrDeadlineExceeded
case <-r.done:
go r.pipeReadBuffer(buffer.FreeLen())
default:
go r.pipeRead(buffer.FreeLen())
}
return r.readBuffer(buffer)
}
func (r *reader) readBuffer(buffer *buf.Buffer) error {
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
@ -127,16 +125,6 @@ func (r *reader) pipeReturnBuffer(result *readResult, buffer *buf.Buffer) error
}
}
func (r *reader) pipeReadBuffer(pLen int) {
cacheBuffer := buf.NewSize(pLen)
err := r.ExtendedReader.ReadBuffer(cacheBuffer)
r.result <- &readResult{
buffer: cacheBuffer,
err: err,
}
r.done <- struct{}{}
}
func (r *reader) SetReadDeadline(t time.Time) error {
r.deadline.Store(t)
r.pipeDeadline.set(t)

View file

@ -1,6 +1,7 @@
package deadline
import (
"os"
"time"
"github.com/sagernet/sing/common/atomic"
@ -23,12 +24,15 @@ func (r *fallbackReader) Read(p []byte) (n int, err error) {
return r.pipeReturn(result, p)
default:
}
if r.disablePipe.Load() {
return r.ExtendedReader.Read(p)
}
select {
case result := <-r.result:
return r.pipeReturn(result, p)
case <-r.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded
case <-r.done:
if r.deadline.Load().IsZero() {
if r.disablePipe.Load() {
return r.ExtendedReader.Read(p)
} else if r.deadline.Load().IsZero() {
r.done <- struct{}{}
r.inRead.Store(true)
defer r.inRead.Store(false)
@ -36,9 +40,13 @@ func (r *fallbackReader) Read(p []byte) (n int, err error) {
return
}
go r.pipeRead(len(p))
default:
}
return r.reader.read(p)
select {
case result := <-r.result:
return r.pipeReturn(result, p)
case <-r.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}
func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error {
@ -47,21 +55,28 @@ func (r *fallbackReader) ReadBuffer(buffer *buf.Buffer) error {
return r.pipeReturnBuffer(result, buffer)
default:
}
if r.disablePipe.Load() {
return r.ExtendedReader.ReadBuffer(buffer)
}
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return os.ErrDeadlineExceeded
case <-r.done:
if r.deadline.Load().IsZero() {
if r.disablePipe.Load() {
return r.ExtendedReader.ReadBuffer(buffer)
} else if r.deadline.Load().IsZero() {
r.done <- struct{}{}
r.inRead.Store(true)
defer r.inRead.Store(false)
return r.ExtendedReader.ReadBuffer(buffer)
}
go r.pipeReadBuffer(buffer.FreeLen())
default:
go r.pipeRead(buffer.FreeLen())
}
select {
case result := <-r.result:
return r.pipeReturnBuffer(result, buffer)
case <-r.pipeDeadline.wait():
return os.ErrDeadlineExceeded
}
return r.readBuffer(buffer)
}
func (r *fallbackReader) SetReadDeadline(t time.Time) error {

View file

@ -0,0 +1,71 @@
package deadline
import (
"net"
"sync"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/debug"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type SerialConn struct {
N.ExtendedConn
access sync.Mutex
}
func NewSerialConn(conn N.ExtendedConn) N.ExtendedConn {
if !debug.Enabled {
return conn
}
return &SerialConn{ExtendedConn: conn}
}
func (c *SerialConn) Read(p []byte) (n int, err error) {
if !c.access.TryLock() {
panic("concurrent read on deadline conn")
}
defer c.access.Unlock()
return c.ExtendedConn.Read(p)
}
func (c *SerialConn) ReadBuffer(buffer *buf.Buffer) error {
if !c.access.TryLock() {
panic("concurrent read on deadline conn")
}
defer c.access.Unlock()
return c.ExtendedConn.ReadBuffer(buffer)
}
type SerialPacketConn struct {
N.NetPacketConn
access sync.Mutex
}
func NewSerialPacketConn(conn N.NetPacketConn) N.NetPacketConn {
if !debug.Enabled {
return conn
}
return &SerialPacketConn{NetPacketConn: conn}
}
func (c *SerialPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if !c.access.TryLock() {
panic("concurrent read on deadline conn")
}
defer c.access.Unlock()
return c.NetPacketConn.ReadFrom(p)
}
func (c *SerialPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if !c.access.TryLock() {
panic("concurrent read on deadline conn")
}
defer c.access.Unlock()
return c.NetPacketConn.ReadPacket(buffer)
}
func (c *SerialPacketConn) Upstream() any {
return c.NetPacketConn
}