Fix UDP async write

This commit is contained in:
世界 2023-08-03 14:57:58 +08:00
parent b5da22fad2
commit 5c6fb54965
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 77 additions and 15 deletions

View file

@ -4,6 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"io" "io"
"net" "net"
"sync"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -87,6 +88,7 @@ func (c *clientConn) Upstream() any {
type clientPacketConn struct { type clientPacketConn struct {
N.ExtendedConn N.ExtendedConn
access sync.Mutex
destination M.Socksaddr destination M.Socksaddr
requestWritten bool requestWritten bool
responseRead bool responseRead bool
@ -150,8 +152,14 @@ func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
func (c *clientPacketConn) Write(b []byte) (n int, err error) { func (c *clientPacketConn) Write(b []byte) (n int, err error) {
if !c.requestWritten { if !c.requestWritten {
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
return c.writeRequest(b) return c.writeRequest(b)
} }
}
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b))) err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b)))
if err != nil { if err != nil {
return return
@ -178,9 +186,15 @@ func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error { func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error {
if !c.requestWritten { if !c.requestWritten {
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
defer buffer.Release() defer buffer.Release()
return common.Error(c.writeRequest(buffer.Bytes())) return common.Error(c.writeRequest(buffer.Bytes()))
} }
}
bLen := buffer.Len() bLen := buffer.Len()
binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen)) binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen))
return c.ExtendedConn.WriteBuffer(buffer) return c.ExtendedConn.WriteBuffer(buffer)
@ -212,8 +226,14 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !c.requestWritten { if !c.requestWritten {
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
return c.writeRequest(p) return c.writeRequest(p)
} }
}
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
if err != nil { if err != nil {
return return
@ -250,6 +270,7 @@ var _ N.NetPacketConn = (*clientPacketAddrConn)(nil)
type clientPacketAddrConn struct { type clientPacketAddrConn struct {
N.ExtendedConn N.ExtendedConn
access sync.Mutex
destination M.Socksaddr destination M.Socksaddr
requestWritten bool requestWritten bool
responseRead bool responseRead bool
@ -325,8 +346,14 @@ func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa
func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !c.requestWritten { if !c.requestWritten {
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
return c.writeRequest(p, M.SocksaddrFromNet(addr)) return c.writeRequest(p, M.SocksaddrFromNet(addr))
} }
}
err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr)) err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr))
if err != nil { if err != nil {
return return
@ -361,9 +388,15 @@ func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Soc
func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if !c.requestWritten { if !c.requestWritten {
c.access.Lock()
if c.requestWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
defer buffer.Release() defer buffer.Release()
return common.Error(c.writeRequest(buffer.Bytes(), destination)) return common.Error(c.writeRequest(buffer.Bytes(), destination))
} }
}
bLen := buffer.Len() bLen := buffer.Len()
header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2)) header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2))
common.Must( common.Must(

View file

@ -4,6 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"io" "io"
"net" "net"
"sync"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -79,6 +80,7 @@ var (
type serverPacketConn struct { type serverPacketConn struct {
N.ExtendedConn N.ExtendedConn
access sync.Mutex
destination M.Socksaddr destination M.Socksaddr
responseWritten bool responseWritten bool
} }
@ -112,6 +114,12 @@ func (c *serverPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad
pLen := buffer.Len() pLen := buffer.Len()
common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen)))
if !c.responseWritten { if !c.responseWritten {
c.access.Lock()
if c.responseWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
}
buffer.ExtendHeader(1)[0] = statusSuccess buffer.ExtendHeader(1)[0] = statusSuccess
c.responseWritten = true c.responseWritten = true
} }
@ -133,10 +141,17 @@ func (c *serverPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error)
func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !c.responseWritten { if !c.responseWritten {
c.access.Lock()
if c.responseWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
_, err = c.ExtendedConn.Write([]byte{statusSuccess}) _, err = c.ExtendedConn.Write([]byte{statusSuccess})
if err != nil { if err != nil {
return return
} }
c.responseWritten = true
}
} }
err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
if err != nil { if err != nil {
@ -167,6 +182,7 @@ var (
type serverPacketAddrConn struct { type serverPacketAddrConn struct {
N.ExtendedConn N.ExtendedConn
access sync.Mutex
responseWritten bool responseWritten bool
} }
@ -205,10 +221,17 @@ func (c *serverPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
func (c *serverPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *serverPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if !c.responseWritten { if !c.responseWritten {
c.access.Lock()
if c.responseWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
_, err = c.ExtendedConn.Write([]byte{statusSuccess}) _, err = c.ExtendedConn.Write([]byte{statusSuccess})
if err != nil { if err != nil {
return return
} }
c.responseWritten = true
}
} }
err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr)) err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr))
if err != nil { if err != nil {
@ -243,9 +266,15 @@ func (c *serverPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Soc
common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen)))
common.Must(M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination)) common.Must(M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination))
if !c.responseWritten { if !c.responseWritten {
c.access.Lock()
if c.responseWritten {
c.access.Unlock()
} else {
defer c.access.Unlock()
buffer.ExtendHeader(1)[0] = statusSuccess buffer.ExtendHeader(1)[0] = statusSuccess
c.responseWritten = true c.responseWritten = true
} }
}
return c.ExtendedConn.WriteBuffer(buffer) return c.ExtendedConn.WriteBuffer(buffer)
} }