Fix Fix socks5 packet conn

This commit is contained in:
世界 2024-05-21 15:08:19 +08:00
parent f67a0988a6
commit de1b0bd772
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 92 additions and 17 deletions

View file

@ -39,7 +39,7 @@ func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
if !isReadWaiter {
return nil, false
}
return &BindPacketReadWaiter{readWaiter}, true
return &bindPacketReadWaiter{readWaiter}, true
}
func (c *bindPacketConn) RemoteAddr() net.Addr {
@ -104,9 +104,62 @@ func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
if !isReadWaiter {
return nil, false
}
return &UnbindPacketReadWaiter{readWaiter, c.addr}, true
return &unbindPacketReadWaiter{readWaiter, c.addr}, true
}
func (c *UnbindPacketConn) Upstream() any {
return c.ExtendedConn
}
func NewServerPacketConn(conn net.PacketConn) N.ExtendedConn {
return &serverPacketConn{
NetPacketConn: NewPacketConn(conn),
}
}
type serverPacketConn struct {
N.NetPacketConn
remoteAddr M.Socksaddr
}
func (c *serverPacketConn) Read(p []byte) (n int, err error) {
n, addr, err := c.NetPacketConn.ReadFrom(p)
if err != nil {
return
}
c.remoteAddr = M.SocksaddrFromNet(addr)
return
}
func (c *serverPacketConn) ReadBuffer(buffer *buf.Buffer) error {
destination, err := c.NetPacketConn.ReadPacket(buffer)
if err != nil {
return err
}
c.remoteAddr = destination
return nil
}
func (c *serverPacketConn) Write(p []byte) (n int, err error) {
return c.NetPacketConn.WriteTo(p, c.remoteAddr.UDPAddr())
}
func (c *serverPacketConn) WriteBuffer(buffer *buf.Buffer) error {
return c.NetPacketConn.WritePacket(buffer, c.remoteAddr)
}
func (c *serverPacketConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *serverPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *serverPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
if !isReadWaiter {
return nil, false
}
return &serverPacketReadWaiter{c, readWaiter}, true
}

View file

@ -6,33 +6,33 @@ import (
N "github.com/sagernet/sing/common/network"
)
var _ N.ReadWaiter = (*BindPacketReadWaiter)(nil)
var _ N.ReadWaiter = (*bindPacketReadWaiter)(nil)
type BindPacketReadWaiter struct {
type bindPacketReadWaiter struct {
readWaiter N.PacketReadWaiter
}
func (w *BindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
func (w *bindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}
func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
func (w *bindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, _, err = w.readWaiter.WaitReadPacket()
return
}
var _ N.PacketReadWaiter = (*UnbindPacketReadWaiter)(nil)
var _ N.PacketReadWaiter = (*unbindPacketReadWaiter)(nil)
type UnbindPacketReadWaiter struct {
type unbindPacketReadWaiter struct {
readWaiter N.ReadWaiter
addr M.Socksaddr
}
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
func (w *unbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}
func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
func (w *unbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, err = w.readWaiter.WaitReadBuffer()
if err != nil {
return
@ -40,3 +40,23 @@ func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destinati
destination = w.addr
return
}
var _ N.ReadWaiter = (*serverPacketReadWaiter)(nil)
type serverPacketReadWaiter struct {
*serverPacketConn
readWaiter N.PacketReadWaiter
}
func (w *serverPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}
func (w *serverPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, destination, err := w.readWaiter.WaitReadPacket()
if err != nil {
return
}
w.remoteAddr = destination
return
}

View file

@ -154,15 +154,16 @@ type WriterWithMTU interface {
func CalculateMTU(reader any, writer any) int {
readerMTU := calculateReaderMTU(reader)
readerHeadroom := calculateReaderFrontHeadroom(reader)
writerMTU := calculateWriterMTU(writer)
if readerMTU > writerMTU {
return readerMTU + readerHeadroom
}
if writerMTU > buf.BufferSize {
if readerMTU == 0 && writerMTU == 0 || readerMTU > buf.BufferSize || writerMTU > buf.BufferSize {
return 0
}
return writerMTU + readerHeadroom
readerHeadroom := calculateReaderFrontHeadroom(reader)
if readerMTU > writerMTU {
return readerMTU + readerHeadroom
} else {
return writerMTU + readerHeadroom
}
}
func calculateReaderMTU(reader any) int {

View file

@ -9,6 +9,7 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
@ -219,7 +220,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
metadata.Destination = request.Destination
var innerError error
done := make(chan struct{})
associatePacketConn := NewAssociatePacketConn(udpConn, request.Destination, conn)
associatePacketConn := NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), request.Destination, conn)
go func() {
innerError = handler.NewPacketConnection(ctx, associatePacketConn, metadata)
close(done)