Fix bad group usages

This commit is contained in:
世界 2024-08-18 11:15:20 +08:00
parent ec1df651e8
commit 96bef0733f
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 40 additions and 26 deletions

View file

@ -157,10 +157,6 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
} }
func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error { func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
return CopyConnContextList([]context.Context{ctx}, source, destination)
}
func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
var group task.Group var group task.Group
if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex { if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex {
group.Append("upload", func(ctx context.Context) error { group.Append("upload", func(ctx context.Context) error {
@ -197,7 +193,19 @@ func CopyConnContextList(contextList []context.Context, source net.Conn, destina
group.Cleanup(func() { group.Cleanup(func() {
common.Close(source, destination) common.Close(source, destination)
}) })
return group.RunContextList(contextList) return group.Run(ctx)
}
// Deprecated: not used
func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
switch len(contextList) {
case 0:
return CopyConn(context.Background(), source, destination)
case 1:
return CopyConn(contextList[0], source, destination)
default:
panic("invalid context list")
}
} }
func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) { func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
@ -318,10 +326,6 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr
} }
func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error { func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
return CopyPacketConnContextList([]context.Context{ctx}, source, destination)
}
func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
var group task.Group var group task.Group
group.Append("upload", func(ctx context.Context) error { group.Append("upload", func(ctx context.Context) error {
return common.Error(CopyPacket(destination, source)) return common.Error(CopyPacket(destination, source))
@ -333,5 +337,17 @@ func CopyPacketConnContextList(contextList []context.Context, source N.PacketCon
common.Close(source, destination) common.Close(source, destination)
}) })
group.FastFail() group.FastFail()
return group.RunContextList(contextList) return group.Run(ctx)
}
// Deprecated: not used
func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
switch len(contextList) {
case 0:
return CopyPacketConn(context.Background(), source, destination)
case 1:
return CopyPacketConn(contextList[0], source, destination)
default:
panic("invalid context list")
}
} }

View file

@ -5,6 +5,7 @@ import (
"reflect" "reflect"
) )
// Deprecated: not used
func SelectContext(contextList []context.Context) (int, error) { func SelectContext(contextList []context.Context) (int, error) {
if len(contextList) == 1 { if len(contextList) == 1 {
<-contextList[0].Done() <-contextList[0].Done()

View file

@ -54,17 +54,9 @@ func (g *Group) Concurrency(n int) {
} }
} }
func (g *Group) Run(contextList ...context.Context) error { func (g *Group) Run(ctx context.Context) error {
return g.RunContextList(contextList)
}
func (g *Group) RunContextList(contextList []context.Context) error {
if len(contextList) == 0 {
contextList = append(contextList, context.Background())
}
taskContext, taskFinish := common.ContextWithCancelCause(context.Background()) taskContext, taskFinish := common.ContextWithCancelCause(context.Background())
taskCancelContext, taskCancel := common.ContextWithCancelCause(contextList[0]) taskCancelContext, taskCancel := common.ContextWithCancelCause(ctx)
var errorAccess sync.Mutex var errorAccess sync.Mutex
var returnError error var returnError error
@ -112,8 +104,13 @@ func (g *Group) RunContextList(contextList []context.Context) error {
}() }()
} }
selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskCancelContext}, contextList[1:]...)) var upstreamErr bool
taskCancel(upstreamErr) select {
case <-taskCancelContext.Done():
case <-ctx.Done():
upstreamErr = true
taskCancel(ctx.Err())
}
if g.cleanup != nil { if g.cleanup != nil {
g.cleanup() g.cleanup()
@ -121,10 +118,8 @@ func (g *Group) RunContextList(contextList []context.Context) error {
<-taskContext.Done() <-taskContext.Done()
if selectedContext != 0 { if upstreamErr {
returnError = E.Append(returnError, upstreamErr, func(err error) error { return ctx.Err()
return E.Cause(err, "upstream")
})
} }
return returnError return returnError

View file

@ -2,6 +2,7 @@ package task
import "context" import "context"
// Deprecated: Use Group instead
func Run(ctx context.Context, tasks ...func() error) error { func Run(ctx context.Context, tasks ...func() error) error {
var group Group var group Group
for _, task := range tasks { for _, task := range tasks {
@ -13,6 +14,7 @@ func Run(ctx context.Context, tasks ...func() error) error {
return group.Run(ctx) return group.Run(ctx)
} }
// Deprecated: Use Group instead
func Any(ctx context.Context, tasks ...func(ctx context.Context) error) error { func Any(ctx context.Context, tasks ...func(ctx context.Context) error) error {
var group Group var group Group
for _, task := range tasks { for _, task := range tasks {