From d9ca259bec6a0cb28d74b4481d3eb277e152e5b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 3 Aug 2022 17:07:57 +0800 Subject: [PATCH] Improve task --- common/bufio/conn.go | 133 ++++----------------------------- common/exceptions/multi.go | 4 +- common/task/task.go | 117 +++++++++++++++++++++-------- common/task/task_deprecated.go | 26 +++++++ 4 files changed, 128 insertions(+), 152 deletions(-) create mode 100644 common/task/task_deprecated.go diff --git a/common/bufio/conn.go b/common/bufio/conn.go index b8a5c03..39b22a8 100644 --- a/common/bufio/conn.go +++ b/common/bufio/conn.go @@ -5,7 +5,6 @@ import ( "io" "net" "os" - "time" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -146,17 +145,21 @@ func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, } func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error { - defer common.Close(conn, dest) - err := task.Run(ctx, func() error { + var group task.Group + group.Append("upload", func(ctx context.Context) error { defer rw.CloseRead(conn) defer rw.CloseWrite(dest) return common.Error(Copy(dest, conn)) - }, func() error { + }) + group.Append("download", func(ctx context.Context) error { defer rw.CloseRead(dest) defer rw.CloseWrite(conn) return common.Error(Copy(conn, dest)) }) - return err + group.Cleanup(func() { + common.Close(conn, dest) + }) + return group.Run(ctx) } func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { @@ -202,48 +205,6 @@ func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { } } -func CopyPacketTimeout(dst N.PacketWriter, src N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) { - unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafePacketReader](src) - _, dstUnsafe := common.Cast[N.ThreadUnsafeWriter](dst) - if srcUnsafe { - dstHeadroom := N.CalculateHeadroom(dst) - if dstHeadroom == 0 { - return CopyPacketWithSrcBufferTimeout(dst, unsafeSrc, src, timeout) - } - } - if dstUnsafe { - return CopyPacketWithPoolTimeout(dst, src, timeout) - } - - _buffer := buf.StackNewPacket() - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - buffer.IncRef() - defer buffer.DecRef() - var destination M.Socksaddr - for { - buffer.Reset() - err = src.SetReadDeadline(time.Now().Add(timeout)) - if err != nil { - return - } - destination, err = src.ReadPacket(buffer) - if err != nil { - return - } - if buffer.IsFull() { - return 0, io.ErrShortBuffer - } - dataLen := buffer.Len() - err = dst.WritePacket(buffer, destination) - if err != nil { - return - } - n += int64(dataLen) - } -} - func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) { var buffer *buf.Buffer var destination M.Socksaddr @@ -267,33 +228,6 @@ func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) ( } } -func CopyPacketWithSrcBufferTimeout(dst N.PacketWriter, src N.ThreadSafePacketReader, tSrc N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) { - var buffer *buf.Buffer - var destination M.Socksaddr - var notFirstTime bool - for { - err = tSrc.SetReadDeadline(time.Now().Add(timeout)) - if err != nil { - if !notFirstTime { - err = N.HandshakeFailure(dst, err) - } - return - } - buffer, destination, err = src.ReadPacketThreadSafe() - if err != nil { - return - } - dataLen := buffer.Len() - err = dst.WritePacket(buffer, destination) - if err != nil { - buffer.Release() - return - } - n += int64(dataLen) - notFirstTime = true - } -} - func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { var destination M.Socksaddr var notFirstTime bool @@ -322,54 +256,19 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er } } -func CopyPacketWithPoolTimeout(dst N.PacketWriter, src N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) { - var destination M.Socksaddr - var notFirstTime bool - for { - buffer := buf.NewPacket() - err = src.SetReadDeadline(time.Now().Add(timeout)) - if err != nil { - return - } - destination, err = src.ReadPacket(buffer) - if err != nil { - buffer.Release() - if !notFirstTime { - err = N.HandshakeFailure(dst, err) - } - return - } - if buffer.IsFull() { - buffer.Release() - return 0, io.ErrShortBuffer - } - dataLen := buffer.Len() - err = dst.WritePacket(buffer, destination) - if err != nil { - buffer.Release() - return - } - n += int64(dataLen) - notFirstTime = true - } -} - func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error { - defer common.Close(conn, dest) - return task.Any(ctx, func(ctx context.Context) error { + var group task.Group + group.Append("upload", func(ctx context.Context) error { return common.Error(CopyPacket(dest, conn)) - }, func(ctx context.Context) error { + }) + group.Append("download", func(ctx context.Context) error { return common.Error(CopyPacket(conn, dest)) }) -} - -func CopyPacketConnTimeout(ctx context.Context, conn N.PacketConn, dest N.PacketConn, timeout time.Duration) error { - defer common.Close(conn, dest) - return task.Any(ctx, func(ctx context.Context) error { - return common.Error(CopyPacketTimeout(dest, conn, timeout)) - }, func(ctx context.Context) error { - return common.Error(CopyPacketTimeout(conn, dest, timeout)) + group.Cleanup(func() { + common.Close(conn, dest) }) + group.FastFail() + return group.Run(ctx) } func NewPacketConn(conn net.PacketConn) N.NetPacketConn { diff --git a/common/exceptions/multi.go b/common/exceptions/multi.go index a491fe1..16e3a00 100644 --- a/common/exceptions/multi.go +++ b/common/exceptions/multi.go @@ -13,7 +13,7 @@ type multiError struct { } func (e *multiError) Error() string { - return "multi error: (" + strings.Join(F.MapToString(e.errors), " | ") + ")" + return strings.Join(F.MapToString(e.errors), " | ") } func (e *multiError) UnwrapMulti() []error { @@ -50,7 +50,7 @@ func Append(err error, other error, block func(error) error) error { if other == nil { return err } - return Errors(err, block(err)) + return Errors(err, block(other)) } func IsMulti(err error, targetList ...error) bool { diff --git a/common/task/task.go b/common/task/task.go index 698ef3f..cf96a79 100644 --- a/common/task/task.go +++ b/common/task/task.go @@ -7,48 +7,99 @@ import ( E "github.com/sagernet/sing/common/exceptions" ) -func Run(ctx context.Context, tasks ...func() error) error { - runtimeCtx, cancel := context.WithCancel(ctx) - wg := &sync.WaitGroup{} - wg.Add(len(tasks)) - var retErr []error - for _, task := range tasks { - currentTask := task - go func() { - if err := currentTask(); err != nil { - retErr = append(retErr, err) - } - wg.Done() - }() - } - go func() { - wg.Wait() - cancel() - }() - select { - case <-ctx.Done(): - case <-runtimeCtx.Done(): - } - retErr = append(retErr, ctx.Err()) - return E.Errors(retErr...) +type taskItem struct { + Name string + Run func(ctx context.Context) error } -func Any(ctx context.Context, tasks ...func(ctx context.Context) error) error { - runtimeCtx, cancel := context.WithCancel(ctx) - defer cancel() +type Group struct { + tasks []taskItem + cleanup func() + fastFail bool +} + +func (g *Group) Append(name string, f func(ctx context.Context) error) { + g.tasks = append(g.tasks, taskItem{ + Name: name, + Run: f, + }) +} + +func (g *Group) Append0(f func(ctx context.Context) error) { + g.tasks = append(g.tasks, taskItem{ + Run: f, + }) +} + +func (g *Group) Cleanup(f func()) { + g.cleanup = f +} + +func (g *Group) FastFail() { + g.fastFail = true +} + +func (g *Group) Run(ctx context.Context) error { + var retAccess sync.Mutex var retErr error - for _, task := range tasks { + + taskCount := int8(len(g.tasks)) + taskCtx, taskFinish := context.WithCancel(context.Background()) + var mixedCtx context.Context + var mixedFinish context.CancelFunc + if ctx.Done() != nil || g.fastFail { + mixedCtx, mixedFinish = context.WithCancel(ctx) + } else { + mixedCtx, mixedFinish = taskCtx, taskFinish + } + + for _, task := range g.tasks { currentTask := task go func() { - if err := currentTask(runtimeCtx); err != nil { - retErr = err + err := currentTask.Run(mixedCtx) + retAccess.Lock() + if err != nil { + retErr = E.Append(retErr, err, func(err error) error { + if currentTask.Name == "" { + return err + } + return E.Cause(err, currentTask.Name) + }) + if g.fastFail { + mixedFinish() + } + } + taskCount-- + currentCount := taskCount + retAccess.Unlock() + if currentCount == 0 { + taskFinish() } - cancel() }() } + + var upstreamErr error + select { case <-ctx.Done(): - case <-runtimeCtx.Done(): + upstreamErr = ctx.Err() + case <-taskCtx.Done(): + mixedFinish() + case <-mixedCtx.Done(): } - return E.Errors(retErr, ctx.Err()) + + if g.cleanup != nil { + g.cleanup() + } + + <-taskCtx.Done() + + taskFinish() + mixedFinish() + + retErr = E.Append(retErr, upstreamErr, func(err error) error { + return E.Cause(err, "upstream") + }) + + return retErr } diff --git a/common/task/task_deprecated.go b/common/task/task_deprecated.go new file mode 100644 index 0000000..50c0ece --- /dev/null +++ b/common/task/task_deprecated.go @@ -0,0 +1,26 @@ +package task + +import "context" + +func Run(ctx context.Context, tasks ...func() error) error { + var group Group + for _, task := range tasks { + currentTask := task + group.Append0(func(ctx context.Context) error { + return currentTask() + }) + } + return group.Run(ctx) +} + +func Any(ctx context.Context, tasks ...func(ctx context.Context) error) error { + var group Group + for _, task := range tasks { + currentTask := task + group.Append0(func(ctx context.Context) error { + return currentTask(ctx) + }) + } + group.FastFail() + return group.Run(ctx) +}