Fix socks packet conn

This commit is contained in:
世界 2022-04-27 11:59:41 +08:00
parent 63ef20617a
commit 82e1dc7058
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
3 changed files with 37 additions and 5 deletions

View file

@ -356,7 +356,7 @@ func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) e
serverConn := c.method.DialPacketConn(udpConn)
return task.Run(ctx, func() error {
var init bool
return socks.CopyPacketConn(serverConn, conn, func(destination *M.AddrPort, n int) {
return socks.CopyPacketConn0(serverConn, conn, func(destination *M.AddrPort, n int) {
if !init {
init = true
logrus.Info("UDP ", conn.LocalAddr(), " ==> ", destination)
@ -365,7 +365,7 @@ func (c *LocalClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) e
}
})
}, func() error {
return socks.CopyPacketConn(conn, serverConn, func(destination *M.AddrPort, n int) {
return socks.CopyPacketConn0(conn, serverConn, func(destination *M.AddrPort, n int) {
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination)
})
})

View file

@ -123,11 +123,11 @@ func (c *localClient) NewPacketConnection(conn socks.PacketConn, _ M.Metadata) e
client := uot.NewClientConn(upstream)
return task.Run(context.Background(), func() error {
return socks.CopyPacketConn(client, conn, func(destination *M.AddrPort, n int) {
return socks.CopyPacketConn0(client, conn, func(destination *M.AddrPort, n int) {
logrus.Trace("UDP ", conn.LocalAddr(), " ==> ", destination)
})
}, func() error {
return socks.CopyPacketConn(conn, client, func(destination *M.AddrPort, n int) {
return socks.CopyPacketConn0(conn, client, func(destination *M.AddrPort, n int) {
logrus.Trace("UDP ", conn.LocalAddr(), " <== ", destination)
})
})

View file

@ -1,6 +1,8 @@
package socks
import (
"context"
"github.com/sagernet/sing/common/task"
"net"
"time"
@ -43,7 +45,37 @@ func (s *PacketConnStub) SetWriteDeadline(t time.Time) error {
return nil
}
func CopyPacketConn(dest PacketConn, conn PacketConn, onAction func(destination *M.AddrPort, n int)) error {
func CopyPacketConn(ctx context.Context, dest PacketConn, conn PacketConn) error {
return task.Run(ctx, func() error {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
for {
destination, err := conn.ReadPacket(buffer)
if err != nil {
return err
}
err = dest.WritePacket(buffer, destination)
if err != nil {
return err
}
}
}, func() error {
_buffer := buf.StackNew()
buffer := common.Dup(_buffer)
for {
destination, err := dest.ReadPacket(buffer)
if err != nil {
return err
}
err = conn.WritePacket(buffer, destination)
if err != nil {
return err
}
}
})
}
func CopyPacketConn0(dest PacketConn, conn PacketConn, onAction func(destination *M.AddrPort, n int)) error {
for {
buffer := buf.New()
destination, err := conn.ReadPacket(buffer)