diff --git a/common/batch/batch.go b/common/batch/batch.go index 4a1da7a..ae974e8 100644 --- a/common/batch/batch.go +++ b/common/batch/batch.go @@ -3,6 +3,9 @@ package batch import ( "context" "sync" + + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" ) type Option[T any] func(b *Batch[T]) @@ -17,6 +20,10 @@ type Error struct { Err error } +func (e *Error) Error() string { + return E.Cause(e.Err, e.Key).Error() +} + func WithConcurrencyNum[T any](n int) Option[T] { return func(b *Batch[T]) { q := make(chan struct{}, n) @@ -35,7 +42,7 @@ type Batch[T any] struct { mux sync.Mutex err *Error once sync.Once - cancel func() + cancel common.ContextCancelCauseFunc } func (b *Batch[T]) Go(key string, fn func() (T, error)) { @@ -54,7 +61,7 @@ func (b *Batch[T]) Go(key string, fn func() (T, error)) { b.once.Do(func() { b.err = &Error{key, err} if b.cancel != nil { - b.cancel() + b.cancel(b.err) } }) } @@ -69,7 +76,7 @@ func (b *Batch[T]) Go(key string, fn func() (T, error)) { func (b *Batch[T]) Wait() *Error { b.wg.Wait() if b.cancel != nil { - b.cancel() + b.cancel(nil) } return b.err } @@ -90,7 +97,7 @@ func (b *Batch[T]) Result() map[string]Result[T] { } func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) { - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := common.ContextWithCancelCause(ctx) b := &Batch[T]{ result: map[string]Result[T]{}, diff --git a/common/bufio/copy.go b/common/bufio/copy.go index 715a8cd..f5b3854 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -185,6 +185,10 @@ func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, } func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error { + return CopyConnContextList([]context.Context{ctx}, conn, dest) +} + +func CopyConnContextList(contextList []context.Context, conn net.Conn, dest net.Conn) error { var group task.Group if _, dstDuplex := common.Cast[rw.WriteCloser](dest); dstDuplex { group.Append("upload", func(ctx context.Context) error { @@ -221,7 +225,7 @@ func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error { group.Cleanup(func() { common.Close(conn, dest) }) - return group.Run(ctx) + return group.RunContextList(contextList) } func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { @@ -335,6 +339,10 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er } func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error { + return CopyPacketConnContextList([]context.Context{ctx}, conn, dest) +} + +func CopyPacketConnContextList(contextList []context.Context, conn N.PacketConn, dest N.PacketConn) error { var group task.Group group.Append("upload", func(ctx context.Context) error { return common.Error(CopyPacket(dest, conn)) @@ -346,5 +354,5 @@ func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) e common.Close(conn, dest) }) group.FastFail() - return group.Run(ctx) + return group.RunContextList(contextList) } diff --git a/common/canceler/instance.go b/common/canceler/instance.go index 2f80be4..05faa91 100644 --- a/common/canceler/instance.go +++ b/common/canceler/instance.go @@ -2,17 +2,21 @@ package canceler import ( "context" + "net" + "os" "time" + + "github.com/sagernet/sing/common" ) type Instance struct { ctx context.Context - cancelFunc context.CancelFunc + cancelFunc common.ContextCancelCauseFunc timer *time.Timer timeout time.Duration } -func New(ctx context.Context, cancelFunc context.CancelFunc, timeout time.Duration) *Instance { +func New(ctx context.Context, cancelFunc common.ContextCancelCauseFunc, timeout time.Duration) *Instance { instance := &Instance{ ctx, cancelFunc, @@ -47,11 +51,15 @@ func (i *Instance) wait() { case <-i.timer.C: case <-i.ctx.Done(): } - i.Close() + i.CloseWithError(os.ErrDeadlineExceeded) } func (i *Instance) Close() error { - i.timer.Stop() - i.cancelFunc() + i.CloseWithError(net.ErrClosed) return nil } + +func (i *Instance) CloseWithError(err error) { + i.timer.Stop() + i.cancelFunc(err) +} diff --git a/common/canceler/packet.go b/common/canceler/packet.go index 7833fdc..ecc2006 100644 --- a/common/canceler/packet.go +++ b/common/canceler/packet.go @@ -33,7 +33,7 @@ func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration if err == nil { return NewTimeoutPacketConn(ctx, conn, timeout) } - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := common.ContextWithCancelCause(ctx) instance := New(ctx, cancel, timeout) return ctx, &TimerPacketConn{conn, instance} } diff --git a/common/canceler/packet_timeout.go b/common/canceler/packet_timeout.go index 83a12bb..561f212 100644 --- a/common/canceler/packet_timeout.go +++ b/common/canceler/packet_timeout.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -13,12 +14,12 @@ import ( type TimeoutPacketConn struct { N.PacketConn timeout time.Duration - cancel context.CancelFunc + cancel common.ContextCancelCauseFunc active time.Time } func NewTimeoutPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) { - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := common.ContextWithCancelCause(ctx) return ctx, &TimeoutPacketConn{ PacketConn: conn, timeout: timeout, @@ -38,7 +39,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa return } else if E.IsTimeout(err) { if time.Since(c.active) > c.timeout { - c.cancel() + c.cancel(err) return } } else { diff --git a/common/context.go b/common/context.go new file mode 100644 index 0000000..e93fa19 --- /dev/null +++ b/common/context.go @@ -0,0 +1,18 @@ +package common + +import ( + "context" + "reflect" +) + +func SelectContext(contextList []context.Context) (int, error) { + chosen, _, _ := reflect.Select(Map(Filter(contextList, func(it context.Context) bool { + return it.Done() != nil + }), func(it context.Context) reflect.SelectCase { + return reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(it.Done()), + } + })) + return chosen, contextList[chosen].Err() +} diff --git a/common/context_compat.go b/common/context_compat.go new file mode 100644 index 0000000..4374ef6 --- /dev/null +++ b/common/context_compat.go @@ -0,0 +1,14 @@ +//go:build go1.20 + +package common + +import "context" + +type ( + ContextCancelCauseFunc = context.CancelCauseFunc +) + +var ( + ContextWithCancelCause = context.WithCancelCause + ContextCause = context.Cause +) diff --git a/common/context_lagacy.go b/common/context_lagacy.go new file mode 100644 index 0000000..14f2f5d --- /dev/null +++ b/common/context_lagacy.go @@ -0,0 +1,16 @@ +//go:build !go1.20 + +package common + +import "context" + +type ContextCancelCauseFunc func(cause error) + +func ContextWithCancelCause(parentContext context.Context) (context.Context, ContextCancelCauseFunc) { + ctx, cancel := context.WithCancel(parentContext) + return ctx, func(_ error) { cancel() } +} + +func ContextCause(context context.Context) error { + return context.Err() +} diff --git a/common/task/task.go b/common/task/task.go index cf96a79..cd37d60 100644 --- a/common/task/task.go +++ b/common/task/task.go @@ -4,6 +4,7 @@ import ( "context" "sync" + "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" ) @@ -12,6 +13,12 @@ type taskItem struct { Run func(ctx context.Context) error } +type errTaskSucceed struct{} + +func (e errTaskSucceed) Error() string { + return "task succeed" +} + type Group struct { tasks []taskItem cleanup func() @@ -39,67 +46,57 @@ func (g *Group) FastFail() { g.fastFail = true } -func (g *Group) Run(ctx context.Context) error { - var retAccess sync.Mutex - var retErr error +func (g *Group) Run(contextList ...context.Context) error { + return g.RunContextList(contextList) +} - taskCount := int8(len(g.tasks)) - taskCtx, taskFinish := context.WithCancel(context.Background()) - var mixedCtx context.Context - var mixedFinish context.CancelFunc - if ctx.Done() != nil || g.fastFail { - mixedCtx, mixedFinish = context.WithCancel(ctx) - } else { - mixedCtx, mixedFinish = taskCtx, taskFinish +func (g *Group) RunContextList(contextList []context.Context) error { + if len(contextList) == 0 { + contextList = append(contextList, context.Background()) } + taskContext, taskFinish := common.ContextWithCancelCause(context.Background()) + taskCancelContext, taskCancel := common.ContextWithCancelCause(context.Background()) + + var errorAccess sync.Mutex + var returnError error + taskCount := int8(len(g.tasks)) + for _, task := range g.tasks { currentTask := task go func() { - err := currentTask.Run(mixedCtx) - retAccess.Lock() + err := currentTask.Run(taskCancelContext) + errorAccess.Lock() if err != nil { - retErr = E.Append(retErr, err, func(err error) error { - if currentTask.Name == "" { - return err - } - return E.Cause(err, currentTask.Name) - }) + if currentTask.Name != "" { + err = E.Cause(err, currentTask.Name) + } + returnError = E.Errors(returnError, err) if g.fastFail { - mixedFinish() + taskCancel(err) } } taskCount-- currentCount := taskCount - retAccess.Unlock() + errorAccess.Unlock() if currentCount == 0 { - taskFinish() + taskCancel(errTaskSucceed{}) + taskFinish(errTaskSucceed{}) } }() } - var upstreamErr error - - select { - case <-ctx.Done(): - upstreamErr = ctx.Err() - case <-taskCtx.Done(): - mixedFinish() - case <-mixedCtx.Done(): + selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskContext}, contextList...)) + if selectedContext != 0 { + returnError = E.Append(returnError, upstreamErr, func(err error) error { + return E.Cause(err, "upstream") + }) } if g.cleanup != nil { g.cleanup() } - <-taskCtx.Done() - - taskFinish() - mixedFinish() - - retErr = E.Append(retErr, upstreamErr, func(err error) error { - return E.Cause(err, "upstream") - }) - - return retErr + <-taskContext.Done() + return returnError } diff --git a/common/udpnat/service.go b/common/udpnat/service.go index 6c44ad3..d8a7614 100644 --- a/common/udpnat/service.go +++ b/common/udpnat/service.go @@ -74,7 +74,7 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu localAddr: metadata.Source, remoteAddr: metadata.Destination, } - c.ctx, c.cancel = context.WithCancel(ctx) + c.ctx, c.cancel = common.ContextWithCancelCause(ctx) return c }) if !loaded { @@ -110,7 +110,7 @@ type packet struct { type conn struct { ctx context.Context - cancel context.CancelFunc + cancel common.ContextCancelCauseFunc data chan packet localAddr M.Socksaddr remoteAddr M.Socksaddr @@ -180,7 +180,7 @@ func (c *conn) Close() error { return os.ErrClosed default: } - c.cancel() + c.cancel(net.ErrClosed) if sourceCloser, sourceIsCloser := c.source.(io.Closer); sourceIsCloser { return sourceCloser.Close() }