Fix UDP async write

This commit is contained in:
世界 2023-08-03 14:57:58 +08:00
parent b5da22fad2
commit 5c6fb54965
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 77 additions and 15 deletions

View file

@ -4,6 +4,7 @@ import (
"encoding/binary"
"io"
"net"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
@ -87,6 +88,7 @@ func (c *clientConn) Upstream() any {
type clientPacketConn struct {
N.ExtendedConn
access sync.Mutex
destination M.Socksaddr
requestWritten bool
responseRead bool
@ -150,7 +152,13 @@ func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
func (c *clientPacketConn) Write(b []byte) (n int, err error) {
if !c.requestWritten {
return c.writeRequest(b)
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
return c.writeRequest(b)
}
}
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b)))
if err != nil {
@ -178,8 +186,14 @@ func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error {
if !c.requestWritten {
defer buffer.Release()
return common.Error(c.writeRequest(buffer.Bytes()))
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
defer buffer.Release()
return common.Error(c.writeRequest(buffer.Bytes()))
}
}
bLen := buffer.Len()
binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen))
@ -212,7 +226,13 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !c.requestWritten {
return c.writeRequest(p)
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
return c.writeRequest(p)
}
}
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
if err != nil {
@ -250,6 +270,7 @@ var _ N.NetPacketConn = (*clientPacketAddrConn)(nil)
type clientPacketAddrConn struct {
N.ExtendedConn
access sync.Mutex
destination M.Socksaddr
requestWritten bool
responseRead bool
@ -325,7 +346,13 @@ func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa
func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !c.requestWritten {
return c.writeRequest(p, M.SocksaddrFromNet(addr))
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
return c.writeRequest(p, M.SocksaddrFromNet(addr))
}
}
err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr))
if err != nil {
@ -361,8 +388,14 @@ func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Soc
func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if !c.requestWritten {
defer buffer.Release()
return common.Error(c.writeRequest(buffer.Bytes(), destination))
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
defer buffer.Release()
return common.Error(c.writeRequest(buffer.Bytes(), destination))
}
}
bLen := buffer.Len()
header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2))

View file

@ -4,6 +4,7 @@ import (
"encoding/binary"
"io"
"net"
"sync"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
@ -79,6 +80,7 @@ var (
type serverPacketConn struct {
N.ExtendedConn
access sync.Mutex
destination M.Socksaddr
responseWritten bool
}
@ -112,6 +114,12 @@ func (c *serverPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
pLen := buffer.Len()
common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen)))
if !c.responseWritten {
c.access.Lock()
if c.responseWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
}
buffer.ExtendHeader(1)[0] = statusSuccess
c.responseWritten = true
}
@ -133,9 +141,16 @@ func (c *serverPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !c.responseWritten {
_, err = c.ExtendedConn.Write([]byte{statusSuccess})
if err != nil {
return
c.access.Lock()
if c.responseWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
_, err = c.ExtendedConn.Write([]byte{statusSuccess})
if err != nil {
return
}
c.responseWritten = true
}
}
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
@ -167,6 +182,7 @@ var (
type serverPacketAddrConn struct {
N.ExtendedConn
access sync.Mutex
responseWritten bool
}
@ -205,9 +221,16 @@ func (c *serverPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
func (c *serverPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !c.responseWritten {
_, err = c.ExtendedConn.Write([]byte{statusSuccess})
if err != nil {
return
c.access.Lock()
if c.responseWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
_, err = c.ExtendedConn.Write([]byte{statusSuccess})
if err != nil {
return
}
c.responseWritten = true
}
}
err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr))
@ -243,8 +266,14 @@ func (c *serverPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Soc
common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen)))
common.Must(M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination))
if !c.responseWritten {
buffer.ExtendHeader(1)[0] = statusSuccess
c.responseWritten = true
c.access.Lock()
if c.responseWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
buffer.ExtendHeader(1)[0] = statusSuccess
c.responseWritten = true
}
}
return c.ExtendedConn.WriteBuffer(buffer)
}