From cb26be3e2aa19921c0b8b9b895a3e5bf462e86bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 6 Dec 2023 23:55:54 +0800 Subject: [PATCH] Implement read waiter for UDP --- go.mod | 3 +-- go.sum | 5 ++-- hysteria/packet.go | 48 +++++++++------------------------ hysteria/packet_wait.go | 37 ++++++++++++++++++++++++++ hysteria2/packet.go | 48 +++++++++------------------------ hysteria2/packet_wait.go | 37 ++++++++++++++++++++++++++ tuic/packet.go | 57 ++++++++++++++-------------------------- tuic/packet_wait.go | 37 ++++++++++++++++++++++++++ 8 files changed, 161 insertions(+), 111 deletions(-) create mode 100644 hysteria/packet_wait.go create mode 100644 hysteria2/packet_wait.go create mode 100644 tuic/packet_wait.go diff --git a/go.mod b/go.mod index 9578e08..785c141 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.20 require ( github.com/gofrs/uuid/v5 v5.0.0 github.com/sagernet/quic-go v0.40.0 - github.com/sagernet/sing v0.2.20 + github.com/sagernet/sing v0.3.0-rc.2 golang.org/x/crypto v0.17.0 golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 ) @@ -16,7 +16,6 @@ require ( github.com/onsi/ginkgo/v2 v2.9.7 // indirect github.com/quic-go/qpack v0.4.0 // indirect github.com/quic-go/qtls-go1-20 v0.4.1 // indirect - github.com/stretchr/testify v1.8.4 // indirect golang.org/x/net v0.19.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect diff --git a/go.sum b/go.sum index de3dedc..9e355ca 100644 --- a/go.sum +++ b/go.sum @@ -25,10 +25,11 @@ github.com/sagernet/quic-go v0.40.0 h1:DvQNPb72lzvNQDe9tcUyHTw8eRv6PLtM2mNYmdlzU github.com/sagernet/quic-go v0.40.0/go.mod h1:VqtdhlbkeeG5Okhb3eDMb/9o0EoglReHunNT9ukrJAI= github.com/sagernet/sing v0.2.20 h1:ckcCB/5xu8G8wElNeH74IF6Soac5xWN+eQUXRuonjPQ= github.com/sagernet/sing v0.2.20/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80= +github.com/sagernet/sing v0.3.0-rc.2 h1:l5rq+bTrNhpAPd2Vjzi/sEhil4O6Bb1CKv6LdPLJKug= +github.com/sagernet/sing v0.3.0-rc.2/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 h1:+iq7lrkxmFNBM7xx+Rae2W6uyPfhPeDWD+n+JgppptE= diff --git a/hysteria/packet.go b/hysteria/packet.go index 9d148d8..daa6208 100644 --- a/hysteria/packet.go +++ b/hysteria/packet.go @@ -19,6 +19,7 @@ import ( "github.com/sagernet/sing/common/cache" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" ) @@ -118,17 +119,18 @@ func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage { } type udpPacketConn struct { - ctx context.Context - cancel common.ContextCancelCauseFunc - sessionID uint32 - quicConn quic.Connection - data chan *udpMessage - udpMTU int - udpMTUTime time.Time - packetId atomic.Uint32 - closeOnce sync.Once - defragger *udpDefragger - onDestroy func() + ctx context.Context + cancel common.ContextCancelCauseFunc + sessionID uint32 + quicConn quic.Connection + data chan *udpMessage + udpMTU int + udpMTUTime time.Time + packetId atomic.Uint32 + closeOnce sync.Once + defragger *udpDefragger + onDestroy func() + readWaitOptions N.ReadWaitOptions } func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy func()) *udpPacketConn { @@ -143,18 +145,6 @@ func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy f } } -func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) { - select { - case p := <-c.data: - buffer = p.data - destination = M.ParseSocksaddrHostPort(p.host, p.port) - p.release() - return - case <-c.ctx.Done(): - return nil, M.Socksaddr{}, io.ErrClosedPipe - } -} - func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { select { case p := <-c.data: @@ -167,18 +157,6 @@ func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, } } -func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { - select { - case p := <-c.data: - _, err = newBuffer().ReadOnceFrom(p.data) - destination = M.ParseSocksaddrHostPort(p.host, p.port) - p.releaseMessage() - return - case <-c.ctx.Done(): - return M.Socksaddr{}, io.ErrClosedPipe - } -} - func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { select { case pkt := <-c.data: diff --git a/hysteria/packet_wait.go b/hysteria/packet_wait.go new file mode 100644 index 0000000..f1c54ff --- /dev/null +++ b/hysteria/packet_wait.go @@ -0,0 +1,37 @@ +package hysteria + +import ( + "io" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func (c *udpPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return options.NeedHeadroom() +} + +func (c *udpPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + select { + case p := <-c.data: + destination = M.ParseSocksaddrHostPort(p.host, p.port) + if c.readWaitOptions.NeedHeadroom() { + buffer = c.readWaitOptions.NewPacketBuffer() + _, err = buffer.Write(p.data.Bytes()) + if err != nil { + buffer.Release() + return nil, M.Socksaddr{}, err + } + p.releaseMessage() + c.readWaitOptions.PostReturn(buffer) + } else { + buffer = p.data + p.release() + } + return + case <-c.ctx.Done(): + return nil, M.Socksaddr{}, io.ErrClosedPipe + } +} diff --git a/hysteria2/packet.go b/hysteria2/packet.go index 286bf3b..f77ab33 100644 --- a/hysteria2/packet.go +++ b/hysteria2/packet.go @@ -20,6 +20,7 @@ import ( "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/cache" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) var udpMessagePool = sync.Pool{ @@ -114,17 +115,18 @@ func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage { } type udpPacketConn struct { - ctx context.Context - cancel common.ContextCancelCauseFunc - sessionID uint32 - quicConn quic.Connection - data chan *udpMessage - udpMTU int - udpMTUTime time.Time - packetId atomic.Uint32 - closeOnce sync.Once - defragger *udpDefragger - onDestroy func() + ctx context.Context + cancel common.ContextCancelCauseFunc + sessionID uint32 + quicConn quic.Connection + data chan *udpMessage + udpMTU int + udpMTUTime time.Time + packetId atomic.Uint32 + closeOnce sync.Once + defragger *udpDefragger + onDestroy func() + readWaitOptions N.ReadWaitOptions } func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy func()) *udpPacketConn { @@ -139,18 +141,6 @@ func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy f } } -func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) { - select { - case p := <-c.data: - buffer = p.data - destination = M.ParseSocksaddr(p.destination) - p.release() - return - case <-c.ctx.Done(): - return nil, M.Socksaddr{}, io.ErrClosedPipe - } -} - func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { select { case p := <-c.data: @@ -163,18 +153,6 @@ func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, } } -func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { - select { - case p := <-c.data: - _, err = newBuffer().ReadOnceFrom(p.data) - destination = M.ParseSocksaddr(p.destination) - p.releaseMessage() - return - case <-c.ctx.Done(): - return M.Socksaddr{}, io.ErrClosedPipe - } -} - func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { select { case pkt := <-c.data: diff --git a/hysteria2/packet_wait.go b/hysteria2/packet_wait.go new file mode 100644 index 0000000..bbaa296 --- /dev/null +++ b/hysteria2/packet_wait.go @@ -0,0 +1,37 @@ +package hysteria2 + +import ( + "io" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func (c *udpPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return options.NeedHeadroom() +} + +func (c *udpPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + select { + case p := <-c.data: + destination = M.ParseSocksaddr(p.destination) + if c.readWaitOptions.NeedHeadroom() { + buffer = c.readWaitOptions.NewPacketBuffer() + _, err = buffer.Write(p.data.Bytes()) + if err != nil { + buffer.Release() + return nil, M.Socksaddr{}, err + } + p.releaseMessage() + c.readWaitOptions.PostReturn(buffer) + } else { + buffer = p.data + p.release() + } + return + case <-c.ctx.Done(): + return nil, M.Socksaddr{}, io.ErrClosedPipe + } +} diff --git a/tuic/packet.go b/tuic/packet.go index 8b52de2..ecc15fe 100644 --- a/tuic/packet.go +++ b/tuic/packet.go @@ -19,6 +19,7 @@ import ( "github.com/sagernet/sing/common/cache" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) var udpMessagePool = sync.Pool{ @@ -114,20 +115,26 @@ func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage { return fragments } +var ( + _ N.NetPacketConn = (*udpPacketConn)(nil) + _ N.PacketReadWaiter = (*udpPacketConn)(nil) +) + type udpPacketConn struct { - ctx context.Context - cancel common.ContextCancelCauseFunc - sessionID uint16 - quicConn quic.Connection - data chan *udpMessage - udpStream bool - udpMTU int - udpMTUTime time.Time - packetId atomic.Uint32 - closeOnce sync.Once - isServer bool - defragger *udpDefragger - onDestroy func() + ctx context.Context + cancel common.ContextCancelCauseFunc + sessionID uint16 + quicConn quic.Connection + data chan *udpMessage + udpStream bool + udpMTU int + udpMTUTime time.Time + packetId atomic.Uint32 + closeOnce sync.Once + isServer bool + defragger *udpDefragger + onDestroy func() + readWaitOptions N.ReadWaitOptions } func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, udpStream bool, isServer bool, onDestroy func()) *udpPacketConn { @@ -144,18 +151,6 @@ func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, udpStream b } } -func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) { - select { - case p := <-c.data: - buffer = p.data - destination = p.destination - p.release() - return - case <-c.ctx.Done(): - return nil, M.Socksaddr{}, io.ErrClosedPipe - } -} - func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { select { case p := <-c.data: @@ -168,18 +163,6 @@ func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, } } -func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { - select { - case p := <-c.data: - _, err = newBuffer().ReadOnceFrom(p.data) - destination = p.destination - p.releaseMessage() - return - case <-c.ctx.Done(): - return M.Socksaddr{}, io.ErrClosedPipe - } -} - func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { select { case pkt := <-c.data: diff --git a/tuic/packet_wait.go b/tuic/packet_wait.go new file mode 100644 index 0000000..88f3066 --- /dev/null +++ b/tuic/packet_wait.go @@ -0,0 +1,37 @@ +package tuic + +import ( + "io" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func (c *udpPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return options.NeedHeadroom() +} + +func (c *udpPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + select { + case p := <-c.data: + destination = p.destination + if c.readWaitOptions.NeedHeadroom() { + buffer = c.readWaitOptions.NewPacketBuffer() + _, err = buffer.Write(p.data.Bytes()) + if err != nil { + buffer.Release() + return nil, M.Socksaddr{}, err + } + p.releaseMessage() + c.readWaitOptions.PostReturn(buffer) + } else { + buffer = p.data + p.release() + } + return + case <-c.ctx.Done(): + return nil, M.Socksaddr{}, io.ErrClosedPipe + } +}