From 5c6fb54965a1470a60c021b8069045d3fdec3030 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 3 Aug 2023 14:57:58 +0800 Subject: [PATCH] Fix UDP async write --- client_conn.go | 47 ++++++++++++++++++++++++++++++++++++++++------- server_conn.go | 45 +++++++++++++++++++++++++++++++++++++-------- 2 files changed, 77 insertions(+), 15 deletions(-) diff --git a/client_conn.go b/client_conn.go index 9a82621..2ccfc0b 100644 --- a/client_conn.go +++ b/client_conn.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "io" "net" + "sync" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -87,6 +88,7 @@ func (c *clientConn) Upstream() any { type clientPacketConn struct { N.ExtendedConn + access sync.Mutex destination M.Socksaddr requestWritten bool responseRead bool @@ -150,7 +152,13 @@ func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) { func (c *clientPacketConn) Write(b []byte) (n int, err error) { if !c.requestWritten { - return c.writeRequest(b) + c.access.Lock() + if c.requestWritten { + c.access.Unlock() + } else { + defer c.access.Unlock() + return c.writeRequest(b) + } } err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b))) if err != nil { @@ -178,8 +186,14 @@ func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) { func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error { if !c.requestWritten { - defer buffer.Release() - return common.Error(c.writeRequest(buffer.Bytes())) + c.access.Lock() + if c.requestWritten { + c.access.Unlock() + } else { + defer c.access.Unlock() + defer buffer.Release() + return common.Error(c.writeRequest(buffer.Bytes())) + } } bLen := buffer.Len() binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen)) @@ -212,7 +226,13 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if !c.requestWritten { - return c.writeRequest(p) + c.access.Lock() + if c.requestWritten { + c.access.Unlock() + } else { + defer c.access.Unlock() + return c.writeRequest(p) + } } err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) if err != nil { @@ -250,6 +270,7 @@ var _ N.NetPacketConn = (*clientPacketAddrConn)(nil) type clientPacketAddrConn struct { N.ExtendedConn + access sync.Mutex destination M.Socksaddr requestWritten bool responseRead bool @@ -325,7 +346,13 @@ func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if !c.requestWritten { - return c.writeRequest(p, M.SocksaddrFromNet(addr)) + c.access.Lock() + if c.requestWritten { + c.access.Unlock() + } else { + defer c.access.Unlock() + return c.writeRequest(p, M.SocksaddrFromNet(addr)) + } } err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr)) if err != nil { @@ -361,8 +388,14 @@ func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Soc func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { if !c.requestWritten { - defer buffer.Release() - return common.Error(c.writeRequest(buffer.Bytes(), destination)) + c.access.Lock() + if c.requestWritten { + c.access.Unlock() + } else { + defer c.access.Unlock() + defer buffer.Release() + return common.Error(c.writeRequest(buffer.Bytes(), destination)) + } } bLen := buffer.Len() header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2)) diff --git a/server_conn.go b/server_conn.go index fa7b4e6..70a9689 100644 --- a/server_conn.go +++ b/server_conn.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "io" "net" + "sync" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -79,6 +80,7 @@ var ( type serverPacketConn struct { N.ExtendedConn + access sync.Mutex destination M.Socksaddr responseWritten bool } @@ -112,6 +114,12 @@ func (c *serverPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad pLen := buffer.Len() common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) if !c.responseWritten { + c.access.Lock() + if c.responseWritten { + c.access.Unlock() + } else { + defer c.access.Unlock() + } buffer.ExtendHeader(1)[0] = statusSuccess c.responseWritten = true } @@ -133,9 +141,16 @@ func (c *serverPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if !c.responseWritten { - _, err = c.ExtendedConn.Write([]byte{statusSuccess}) - if err != nil { - return + c.access.Lock() + if c.responseWritten { + c.access.Unlock() + } else { + defer c.access.Unlock() + _, err = c.ExtendedConn.Write([]byte{statusSuccess}) + if err != nil { + return + } + c.responseWritten = true } } err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) @@ -167,6 +182,7 @@ var ( type serverPacketAddrConn struct { N.ExtendedConn + access sync.Mutex responseWritten bool } @@ -205,9 +221,16 @@ func (c *serverPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err func (c *serverPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if !c.responseWritten { - _, err = c.ExtendedConn.Write([]byte{statusSuccess}) - if err != nil { - return + c.access.Lock() + if c.responseWritten { + c.access.Unlock() + } else { + defer c.access.Unlock() + _, err = c.ExtendedConn.Write([]byte{statusSuccess}) + if err != nil { + return + } + c.responseWritten = true } } err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr)) @@ -243,8 +266,14 @@ func (c *serverPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Soc common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) common.Must(M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination)) if !c.responseWritten { - buffer.ExtendHeader(1)[0] = statusSuccess - c.responseWritten = true + c.access.Lock() + if c.responseWritten { + c.access.Unlock() + } else { + defer c.access.Unlock() + buffer.ExtendHeader(1)[0] = statusSuccess + c.responseWritten = true + } } return c.ExtendedConn.WriteBuffer(buffer) }