From 284cb5ce98bfbdec223fdbe42a84871780fdd4ce Mon Sep 17 00:00:00 2001 From: dyhkwong <50692134+dyhkwong@users.noreply.github.com> Date: Fri, 17 May 2024 20:28:11 +0800 Subject: [PATCH] Fix socks5 packet conn --- common/bufio/nat.go | 8 +++ protocol/socks/client.go | 3 +- protocol/socks/packet.go | 91 +++++++++++++---------------- protocol/socks/packet_vectorised.go | 18 ++---- protocol/socks/packet_wait.go | 10 ++-- 5 files changed, 59 insertions(+), 71 deletions(-) diff --git a/common/bufio/nat.go b/common/bufio/nat.go index 6069344..cafeb06 100644 --- a/common/bufio/nat.go +++ b/common/bufio/nat.go @@ -63,6 +63,10 @@ func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) } +func (c *unidirectionalNATPacketConn) RemoteAddr() net.Addr { + return c.destination.UDPAddr() +} + func (c *unidirectionalNATPacketConn) Upstream() any { return c.NetPacketConn } @@ -136,6 +140,10 @@ func (c *bidirectionalNATPacketConn) Upstream() any { return c.NetPacketConn } +func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr { + return c.destination.UDPAddr() +} + func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr { destination.Port = 0 return destination diff --git a/protocol/socks/client.go b/protocol/socks/client.go index fd0a34d..6004f83 100644 --- a/protocol/socks/client.go +++ b/protocol/socks/client.go @@ -7,7 +7,6 @@ import ( "os" "strings" - "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -148,7 +147,7 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock tcpConn.Close() return nil, err } - return NewAssociatePacketConn(bufio.NewUnbindPacketConn(udpConn), address, tcpConn), nil + return NewAssociatePacketConn(udpConn, address, tcpConn), nil } return nil, os.ErrInvalid } diff --git a/protocol/socks/packet.go b/protocol/socks/packet.go index 21860ea..130f0ac 100644 --- a/protocol/socks/packet.go +++ b/protocol/socks/packet.go @@ -21,54 +21,41 @@ import ( var ErrInvalidPacket = E.New("socks5: invalid packet") type AssociatePacketConn struct { - N.NetPacketConn + N.AbstractConn + conn N.ExtendedConn remoteAddr M.Socksaddr underlying net.Conn } -func NewAssociatePacketConn(conn net.PacketConn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn { +func NewAssociatePacketConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn { return &AssociatePacketConn{ - NetPacketConn: bufio.NewPacketConn(conn), - remoteAddr: remoteAddr, - underlying: underlying, + AbstractConn: conn, + conn: bufio.NewExtendedConn(conn), + remoteAddr: remoteAddr, + underlying: underlying, } } -// Deprecated: NewAssociatePacketConn(bufio.NewUnbindPacketConn(conn), remoteAddr, underlying) instead. -func NewAssociateConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn { - return &AssociatePacketConn{ - NetPacketConn: bufio.NewUnbindPacketConn(conn), - remoteAddr: remoteAddr, - underlying: underlying, - } -} - -func (c *AssociatePacketConn) RemoteAddr() net.Addr { - return c.remoteAddr.UDPAddr() -} - -//warn:unsafe func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - n, addr, err = c.NetPacketConn.ReadFrom(p) + n, err = c.conn.Read(p) if err != nil { return } if n < 3 { return 0, nil, ErrInvalidPacket } - c.remoteAddr = M.SocksaddrFromNet(addr) reader := bytes.NewReader(p[3:n]) destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) if err != nil { return } + c.remoteAddr = destination addr = destination.UDPAddr() index := 3 + int(reader.Size()) - reader.Len() n = copy(p, p[index:n]) return } -//warn:unsafe func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { destination := M.SocksaddrFromNet(addr) buffer := buf.NewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination) + len(p)) @@ -82,7 +69,34 @@ func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error if err != nil { return } - return bufio.WritePacketBuffer(c.NetPacketConn, buffer, c.remoteAddr) + return c.conn.Write(buffer.Bytes()) +} + +func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + err = c.conn.ReadBuffer(buffer) + if err != nil { + return + } + if buffer.Len() < 3 { + return M.Socksaddr{}, ErrInvalidPacket + } + buffer.Advance(3) + destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) + if err != nil { + return + } + c.remoteAddr = destination + return destination.Unwrap(), nil +} + +func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + header := buf.With(buffer.ExtendHeader(3 + M.SocksaddrSerializer.AddrPortLen(destination))) + common.Must(header.WriteZeroN(3)) + err := M.SocksaddrSerializer.WriteAddrPort(header, destination) + if err != nil { + return err + } + return c.conn.WriteBuffer(buffer) } func (c *AssociatePacketConn) Read(b []byte) (n int, err error) { @@ -94,35 +108,12 @@ func (c *AssociatePacketConn) Write(b []byte) (n int, err error) { return c.WriteTo(b, c.remoteAddr) } -func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - destination, err = c.NetPacketConn.ReadPacket(buffer) - if err != nil { - return M.Socksaddr{}, err - } - if buffer.Len() < 3 { - return M.Socksaddr{}, ErrInvalidPacket - } - c.remoteAddr = destination - buffer.Advance(3) - destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) - if err != nil { - return - } - return destination.Unwrap(), nil -} - -func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - header := buf.With(buffer.ExtendHeader(3 + M.SocksaddrSerializer.AddrPortLen(destination))) - common.Must(header.WriteZeroN(3)) - err := M.SocksaddrSerializer.WriteAddrPort(header, destination) - if err != nil { - return err - } - return common.Error(bufio.WritePacketBuffer(c.NetPacketConn, buffer, c.remoteAddr)) +func (c *AssociatePacketConn) RemoteAddr() net.Addr { + return c.remoteAddr.UDPAddr() } func (c *AssociatePacketConn) Upstream() any { - return c.NetPacketConn + return c.conn } func (c *AssociatePacketConn) FrontHeadroom() int { @@ -131,7 +122,7 @@ func (c *AssociatePacketConn) FrontHeadroom() int { func (c *AssociatePacketConn) Close() error { return common.Close( - c.NetPacketConn, + c.conn, c.underlying, ) } diff --git a/protocol/socks/packet_vectorised.go b/protocol/socks/packet_vectorised.go index c73286b..fb71728 100644 --- a/protocol/socks/packet_vectorised.go +++ b/protocol/socks/packet_vectorised.go @@ -17,23 +17,13 @@ type VectorisedAssociatePacketConn struct { N.VectorisedPacketWriter } -func NewVectorisedAssociatePacketConn(conn net.PacketConn, writer N.VectorisedPacketWriter, remoteAddr M.Socksaddr, underlying net.Conn) *VectorisedAssociatePacketConn { - return &VectorisedAssociatePacketConn{ - AssociatePacketConn{ - NetPacketConn: bufio.NewPacketConn(conn), - remoteAddr: remoteAddr, - underlying: underlying, - }, - writer, - } -} - func NewVectorisedAssociateConn(conn net.Conn, writer N.VectorisedWriter, remoteAddr M.Socksaddr, underlying net.Conn) *VectorisedAssociatePacketConn { return &VectorisedAssociatePacketConn{ AssociatePacketConn{ - NetPacketConn: bufio.NewUnbindPacketConn(conn), - remoteAddr: remoteAddr, - underlying: underlying, + AbstractConn: conn, + conn: bufio.NewExtendedConn(conn), + remoteAddr: remoteAddr, + underlying: underlying, }, &bufio.UnbindVectorisedPacketWriter{VectorisedWriter: writer}, } diff --git a/protocol/socks/packet_wait.go b/protocol/socks/packet_wait.go index 32b1341..ecf673a 100644 --- a/protocol/socks/packet_wait.go +++ b/protocol/socks/packet_wait.go @@ -10,7 +10,7 @@ import ( var _ N.PacketReadWaitCreator = (*AssociatePacketConn)(nil) func (c *AssociatePacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) { - readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(c.NetPacketConn) + readWaiter, isReadWaiter := bufio.CreateReadWaiter(c.conn) if !isReadWaiter { return nil, false } @@ -21,7 +21,7 @@ var _ N.PacketReadWaiter = (*AssociatePacketReadWaiter)(nil) type AssociatePacketReadWaiter struct { conn *AssociatePacketConn - readWaiter N.PacketReadWaiter + readWaiter N.ReadWaiter } func (w *AssociatePacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { @@ -29,7 +29,7 @@ func (w *AssociatePacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptio } func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { - buffer, destination, err = w.readWaiter.WaitReadPacket() + buffer, err = w.readWaiter.WaitReadBuffer() if err != nil { return } @@ -37,12 +37,12 @@ func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destin buffer.Release() return nil, M.Socksaddr{}, ErrInvalidPacket } - w.conn.remoteAddr = destination buffer.Advance(3) destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) if err != nil { buffer.Release() - return nil, M.Socksaddr{}, err + return } + w.conn.remoteAddr = destination return }