Fix Fix socks5 packet conn

This commit is contained in:
世界 2024-05-21 15:08:19 +08:00
parent f67a0988a6
commit de1b0bd772
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 92 additions and 17 deletions

View file

@ -39,7 +39,7 @@ func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
if !isReadWaiter { if !isReadWaiter {
return nil, false return nil, false
} }
return &BindPacketReadWaiter{readWaiter}, true return &bindPacketReadWaiter{readWaiter}, true
} }
func (c *bindPacketConn) RemoteAddr() net.Addr { func (c *bindPacketConn) RemoteAddr() net.Addr {
@ -104,9 +104,62 @@ func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
if !isReadWaiter { if !isReadWaiter {
return nil, false return nil, false
} }
return &UnbindPacketReadWaiter{readWaiter, c.addr}, true return &unbindPacketReadWaiter{readWaiter, c.addr}, true
} }
func (c *UnbindPacketConn) Upstream() any { func (c *UnbindPacketConn) Upstream() any {
return c.ExtendedConn return c.ExtendedConn
} }
func NewServerPacketConn(conn net.PacketConn) N.ExtendedConn {
return &serverPacketConn{
NetPacketConn: NewPacketConn(conn),
}
}
type serverPacketConn struct {
N.NetPacketConn
remoteAddr M.Socksaddr
}
func (c *serverPacketConn) Read(p []byte) (n int, err error) {
n, addr, err := c.NetPacketConn.ReadFrom(p)
if err != nil {
return
}
c.remoteAddr = M.SocksaddrFromNet(addr)
return
}
func (c *serverPacketConn) ReadBuffer(buffer *buf.Buffer) error {
destination, err := c.NetPacketConn.ReadPacket(buffer)
if err != nil {
return err
}
c.remoteAddr = destination
return nil
}
func (c *serverPacketConn) Write(p []byte) (n int, err error) {
return c.NetPacketConn.WriteTo(p, c.remoteAddr.UDPAddr())
}
func (c *serverPacketConn) WriteBuffer(buffer *buf.Buffer) error {
return c.NetPacketConn.WritePacket(buffer, c.remoteAddr)
}
func (c *serverPacketConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *serverPacketConn) Upstream() any {
return c.NetPacketConn
}
func (c *serverPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn)
if !isReadWaiter {
return nil, false
}
return &serverPacketReadWaiter{c, readWaiter}, true
}

View file

@ -6,33 +6,33 @@ import (
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
var _ N.ReadWaiter = (*BindPacketReadWaiter)(nil) var _ N.ReadWaiter = (*bindPacketReadWaiter)(nil)
type BindPacketReadWaiter struct { type bindPacketReadWaiter struct {
readWaiter N.PacketReadWaiter readWaiter N.PacketReadWaiter
} }
func (w *BindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { func (w *bindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options) return w.readWaiter.InitializeReadWaiter(options)
} }
func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { func (w *bindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, _, err = w.readWaiter.WaitReadPacket() buffer, _, err = w.readWaiter.WaitReadPacket()
return return
} }
var _ N.PacketReadWaiter = (*UnbindPacketReadWaiter)(nil) var _ N.PacketReadWaiter = (*unbindPacketReadWaiter)(nil)
type UnbindPacketReadWaiter struct { type unbindPacketReadWaiter struct {
readWaiter N.ReadWaiter readWaiter N.ReadWaiter
addr M.Socksaddr addr M.Socksaddr
} }
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { func (w *unbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options) return w.readWaiter.InitializeReadWaiter(options)
} }
func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { func (w *unbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, err = w.readWaiter.WaitReadBuffer() buffer, err = w.readWaiter.WaitReadBuffer()
if err != nil { if err != nil {
return return
@ -40,3 +40,23 @@ func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destinati
destination = w.addr destination = w.addr
return return
} }
var _ N.ReadWaiter = (*serverPacketReadWaiter)(nil)
type serverPacketReadWaiter struct {
*serverPacketConn
readWaiter N.PacketReadWaiter
}
func (w *serverPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}
func (w *serverPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, destination, err := w.readWaiter.WaitReadPacket()
if err != nil {
return
}
w.remoteAddr = destination
return
}

View file

@ -154,16 +154,17 @@ type WriterWithMTU interface {
func CalculateMTU(reader any, writer any) int { func CalculateMTU(reader any, writer any) int {
readerMTU := calculateReaderMTU(reader) readerMTU := calculateReaderMTU(reader)
readerHeadroom := calculateReaderFrontHeadroom(reader)
writerMTU := calculateWriterMTU(writer) writerMTU := calculateWriterMTU(writer)
if readerMTU > writerMTU { if readerMTU == 0 && writerMTU == 0 || readerMTU > buf.BufferSize || writerMTU > buf.BufferSize {
return readerMTU + readerHeadroom
}
if writerMTU > buf.BufferSize {
return 0 return 0
} }
readerHeadroom := calculateReaderFrontHeadroom(reader)
if readerMTU > writerMTU {
return readerMTU + readerHeadroom
} else {
return writerMTU + readerHeadroom return writerMTU + readerHeadroom
} }
}
func calculateReaderMTU(reader any) int { func calculateReaderMTU(reader any) int {
var mtu int var mtu int

View file

@ -9,6 +9,7 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -219,7 +220,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
metadata.Destination = request.Destination metadata.Destination = request.Destination
var innerError error var innerError error
done := make(chan struct{}) done := make(chan struct{})
associatePacketConn := NewAssociatePacketConn(udpConn, request.Destination, conn) associatePacketConn := NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), request.Destination, conn)
go func() { go func() {
innerError = handler.NewPacketConnection(ctx, associatePacketConn, metadata) innerError = handler.NewPacketConnection(ctx, associatePacketConn, metadata)
close(done) close(done)