From 96bef0733f3defe86e0b6ca714cebb487d0b9fc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 18 Aug 2024 11:15:20 +0800 Subject: [PATCH] Fix bad group usages --- common/bufio/copy.go | 36 ++++++++++++++++++++++++---------- common/context.go | 1 + common/task/task.go | 27 +++++++++++-------------- common/task/task_deprecated.go | 2 ++ 4 files changed, 40 insertions(+), 26 deletions(-) diff --git a/common/bufio/copy.go b/common/bufio/copy.go index 309f56d..ebb03fe 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -157,10 +157,6 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, } func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error { - return CopyConnContextList([]context.Context{ctx}, source, destination) -} - -func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error { var group task.Group if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex { group.Append("upload", func(ctx context.Context) error { @@ -197,7 +193,19 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina group.Cleanup(func() { common.Close(source, destination) }) - return group.RunContextList(contextList) + return group.Run(ctx) +} + +// Deprecated: not used +func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error { + switch len(contextList) { + case 0: + return CopyConn(context.Background(), source, destination) + case 1: + return CopyConn(contextList[0], source, destination) + default: + panic("invalid context list") + } } func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) { @@ -318,10 +326,6 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr } func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error { - return CopyPacketConnContextList([]context.Context{ctx}, source, destination) -} - -func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error { var group task.Group group.Append("upload", func(ctx context.Context) error { return common.Error(CopyPacket(destination, source)) @@ -333,5 +337,17 @@ func CopyPacketConnContextList(contextList []context.Context, source N.PacketCon common.Close(source, destination) }) group.FastFail() - return group.RunContextList(contextList) + return group.Run(ctx) +} + +// Deprecated: not used +func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error { + switch len(contextList) { + case 0: + return CopyPacketConn(context.Background(), source, destination) + case 1: + return CopyPacketConn(contextList[0], source, destination) + default: + panic("invalid context list") + } } diff --git a/common/context.go b/common/context.go index c5200e4..fa7cbd0 100644 --- a/common/context.go +++ b/common/context.go @@ -5,6 +5,7 @@ import ( "reflect" ) +// Deprecated: not used func SelectContext(contextList []context.Context) (int, error) { if len(contextList) == 1 { <-contextList[0].Done() diff --git a/common/task/task.go b/common/task/task.go index cbddb7a..32040d4 100644 --- a/common/task/task.go +++ b/common/task/task.go @@ -54,17 +54,9 @@ func (g *Group) Concurrency(n int) { } } -func (g *Group) Run(contextList ...context.Context) error { - return g.RunContextList(contextList) -} - -func (g *Group) RunContextList(contextList []context.Context) error { - if len(contextList) == 0 { - contextList = append(contextList, context.Background()) - } - +func (g *Group) Run(ctx context.Context) error { taskContext, taskFinish := common.ContextWithCancelCause(context.Background()) - taskCancelContext, taskCancel := common.ContextWithCancelCause(contextList[0]) + taskCancelContext, taskCancel := common.ContextWithCancelCause(ctx) var errorAccess sync.Mutex var returnError error @@ -112,8 +104,13 @@ func (g *Group) RunContextList(contextList []context.Context) error { }() } - selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskCancelContext}, contextList[1:]...)) - taskCancel(upstreamErr) + var upstreamErr bool + select { + case <-taskCancelContext.Done(): + case <-ctx.Done(): + upstreamErr = true + taskCancel(ctx.Err()) + } if g.cleanup != nil { g.cleanup() @@ -121,10 +118,8 @@ func (g *Group) RunContextList(contextList []context.Context) error { <-taskContext.Done() - if selectedContext != 0 { - returnError = E.Append(returnError, upstreamErr, func(err error) error { - return E.Cause(err, "upstream") - }) + if upstreamErr { + return ctx.Err() } return returnError diff --git a/common/task/task_deprecated.go b/common/task/task_deprecated.go index 50c0ece..a712eab 100644 --- a/common/task/task_deprecated.go +++ b/common/task/task_deprecated.go @@ -2,6 +2,7 @@ package task import "context" +// Deprecated: Use Group instead func Run(ctx context.Context, tasks ...func() error) error { var group Group for _, task := range tasks { @@ -13,6 +14,7 @@ func Run(ctx context.Context, tasks ...func() error) error { return group.Run(ctx) } +// Deprecated: Use Group instead func Any(ctx context.Context, tasks ...func(ctx context.Context) error) error { var group Group for _, task := range tasks {