diff --git a/common/bufio/bind.go b/common/bufio/bind.go index caa5bbd..9788b4d 100644 --- a/common/bufio/bind.go +++ b/common/bufio/bind.go @@ -39,7 +39,7 @@ func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) { if !isReadWaiter { return nil, false } - return &BindPacketReadWaiter{readWaiter}, true + return &bindPacketReadWaiter{readWaiter}, true } func (c *bindPacketConn) RemoteAddr() net.Addr { @@ -104,9 +104,62 @@ func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) { if !isReadWaiter { return nil, false } - return &UnbindPacketReadWaiter{readWaiter, c.addr}, true + return &unbindPacketReadWaiter{readWaiter, c.addr}, true } func (c *UnbindPacketConn) Upstream() any { return c.ExtendedConn } + +func NewServerPacketConn(conn net.PacketConn) N.ExtendedConn { + return &serverPacketConn{ + NetPacketConn: NewPacketConn(conn), + } +} + +type serverPacketConn struct { + N.NetPacketConn + remoteAddr M.Socksaddr +} + +func (c *serverPacketConn) Read(p []byte) (n int, err error) { + n, addr, err := c.NetPacketConn.ReadFrom(p) + if err != nil { + return + } + c.remoteAddr = M.SocksaddrFromNet(addr) + return +} + +func (c *serverPacketConn) ReadBuffer(buffer *buf.Buffer) error { + destination, err := c.NetPacketConn.ReadPacket(buffer) + if err != nil { + return err + } + c.remoteAddr = destination + return nil +} + +func (c *serverPacketConn) Write(p []byte) (n int, err error) { + return c.NetPacketConn.WriteTo(p, c.remoteAddr.UDPAddr()) +} + +func (c *serverPacketConn) WriteBuffer(buffer *buf.Buffer) error { + return c.NetPacketConn.WritePacket(buffer, c.remoteAddr) +} + +func (c *serverPacketConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *serverPacketConn) Upstream() any { + return c.NetPacketConn +} + +func (c *serverPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) { + readWaiter, isReadWaiter := CreatePacketReadWaiter(c.NetPacketConn) + if !isReadWaiter { + return nil, false + } + return &serverPacketReadWaiter{c, readWaiter}, true +} diff --git a/common/bufio/bind_wait.go b/common/bufio/bind_wait.go index 1396552..779474c 100644 --- a/common/bufio/bind_wait.go +++ b/common/bufio/bind_wait.go @@ -6,33 +6,33 @@ import ( N "github.com/sagernet/sing/common/network" ) -var _ N.ReadWaiter = (*BindPacketReadWaiter)(nil) +var _ N.ReadWaiter = (*bindPacketReadWaiter)(nil) -type BindPacketReadWaiter struct { +type bindPacketReadWaiter struct { readWaiter N.PacketReadWaiter } -func (w *BindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { +func (w *bindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { return w.readWaiter.InitializeReadWaiter(options) } -func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { +func (w *bindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { buffer, _, err = w.readWaiter.WaitReadPacket() return } -var _ N.PacketReadWaiter = (*UnbindPacketReadWaiter)(nil) +var _ N.PacketReadWaiter = (*unbindPacketReadWaiter)(nil) -type UnbindPacketReadWaiter struct { +type unbindPacketReadWaiter struct { readWaiter N.ReadWaiter addr M.Socksaddr } -func (w *UnbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { +func (w *unbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { return w.readWaiter.InitializeReadWaiter(options) } -func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { +func (w *unbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { buffer, err = w.readWaiter.WaitReadBuffer() if err != nil { return @@ -40,3 +40,23 @@ func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destinati destination = w.addr return } + +var _ N.ReadWaiter = (*serverPacketReadWaiter)(nil) + +type serverPacketReadWaiter struct { + *serverPacketConn + readWaiter N.PacketReadWaiter +} + +func (w *serverPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return w.readWaiter.InitializeReadWaiter(options) +} + +func (w *serverPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) { + buffer, destination, err := w.readWaiter.WaitReadPacket() + if err != nil { + return + } + w.remoteAddr = destination + return +} diff --git a/common/network/thread.go b/common/network/thread.go index ac3ca69..22063af 100644 --- a/common/network/thread.go +++ b/common/network/thread.go @@ -154,15 +154,16 @@ type WriterWithMTU interface { func CalculateMTU(reader any, writer any) int { readerMTU := calculateReaderMTU(reader) - readerHeadroom := calculateReaderFrontHeadroom(reader) writerMTU := calculateWriterMTU(writer) - if readerMTU > writerMTU { - return readerMTU + readerHeadroom - } - if writerMTU > buf.BufferSize { + if readerMTU == 0 && writerMTU == 0 || readerMTU > buf.BufferSize || writerMTU > buf.BufferSize { return 0 } - return writerMTU + readerHeadroom + readerHeadroom := calculateReaderFrontHeadroom(reader) + if readerMTU > writerMTU { + return readerMTU + readerHeadroom + } else { + return writerMTU + readerHeadroom + } } func calculateReaderMTU(reader any) int { diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 2a2cecf..9e742a2 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -9,6 +9,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" + "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -219,7 +220,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent metadata.Destination = request.Destination var innerError error done := make(chan struct{}) - associatePacketConn := NewAssociatePacketConn(udpConn, request.Destination, conn) + associatePacketConn := NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), request.Destination, conn) go func() { innerError = handler.NewPacketConnection(ctx, associatePacketConn, metadata) close(done)