mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-03 20:07:38 +03:00
Use multi context in task/conn copy
This commit is contained in:
parent
a82d82e559
commit
b7cd741872
10 changed files with 127 additions and 58 deletions
|
@ -3,6 +3,9 @@ package batch
|
|||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type Option[T any] func(b *Batch[T])
|
||||
|
@ -17,6 +20,10 @@ type Error struct {
|
|||
Err error
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
return E.Cause(e.Err, e.Key).Error()
|
||||
}
|
||||
|
||||
func WithConcurrencyNum[T any](n int) Option[T] {
|
||||
return func(b *Batch[T]) {
|
||||
q := make(chan struct{}, n)
|
||||
|
@ -35,7 +42,7 @@ type Batch[T any] struct {
|
|||
mux sync.Mutex
|
||||
err *Error
|
||||
once sync.Once
|
||||
cancel func()
|
||||
cancel common.ContextCancelCauseFunc
|
||||
}
|
||||
|
||||
func (b *Batch[T]) Go(key string, fn func() (T, error)) {
|
||||
|
@ -54,7 +61,7 @@ func (b *Batch[T]) Go(key string, fn func() (T, error)) {
|
|||
b.once.Do(func() {
|
||||
b.err = &Error{key, err}
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
b.cancel(b.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -69,7 +76,7 @@ func (b *Batch[T]) Go(key string, fn func() (T, error)) {
|
|||
func (b *Batch[T]) Wait() *Error {
|
||||
b.wg.Wait()
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
b.cancel(nil)
|
||||
}
|
||||
return b.err
|
||||
}
|
||||
|
@ -90,7 +97,7 @@ func (b *Batch[T]) Result() map[string]Result[T] {
|
|||
}
|
||||
|
||||
func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
ctx, cancel := common.ContextWithCancelCause(ctx)
|
||||
|
||||
b := &Batch[T]{
|
||||
result: map[string]Result[T]{},
|
||||
|
|
|
@ -185,6 +185,10 @@ func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64,
|
|||
}
|
||||
|
||||
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
||||
return CopyConnContextList([]context.Context{ctx}, conn, dest)
|
||||
}
|
||||
|
||||
func CopyConnContextList(contextList []context.Context, conn net.Conn, dest net.Conn) error {
|
||||
var group task.Group
|
||||
if _, dstDuplex := common.Cast[rw.WriteCloser](dest); dstDuplex {
|
||||
group.Append("upload", func(ctx context.Context) error {
|
||||
|
@ -221,7 +225,7 @@ func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
|||
group.Cleanup(func() {
|
||||
common.Close(conn, dest)
|
||||
})
|
||||
return group.Run(ctx)
|
||||
return group.RunContextList(contextList)
|
||||
}
|
||||
|
||||
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
||||
|
@ -335,6 +339,10 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er
|
|||
}
|
||||
|
||||
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
|
||||
return CopyPacketConnContextList([]context.Context{ctx}, conn, dest)
|
||||
}
|
||||
|
||||
func CopyPacketConnContextList(contextList []context.Context, conn N.PacketConn, dest N.PacketConn) error {
|
||||
var group task.Group
|
||||
group.Append("upload", func(ctx context.Context) error {
|
||||
return common.Error(CopyPacket(dest, conn))
|
||||
|
@ -346,5 +354,5 @@ func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) e
|
|||
common.Close(conn, dest)
|
||||
})
|
||||
group.FastFail()
|
||||
return group.Run(ctx)
|
||||
return group.RunContextList(contextList)
|
||||
}
|
||||
|
|
|
@ -2,17 +2,21 @@ package canceler
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
cancelFunc common.ContextCancelCauseFunc
|
||||
timer *time.Timer
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func New(ctx context.Context, cancelFunc context.CancelFunc, timeout time.Duration) *Instance {
|
||||
func New(ctx context.Context, cancelFunc common.ContextCancelCauseFunc, timeout time.Duration) *Instance {
|
||||
instance := &Instance{
|
||||
ctx,
|
||||
cancelFunc,
|
||||
|
@ -47,11 +51,15 @@ func (i *Instance) wait() {
|
|||
case <-i.timer.C:
|
||||
case <-i.ctx.Done():
|
||||
}
|
||||
i.Close()
|
||||
i.CloseWithError(os.ErrDeadlineExceeded)
|
||||
}
|
||||
|
||||
func (i *Instance) Close() error {
|
||||
i.timer.Stop()
|
||||
i.cancelFunc()
|
||||
i.CloseWithError(net.ErrClosed)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Instance) CloseWithError(err error) {
|
||||
i.timer.Stop()
|
||||
i.cancelFunc(err)
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration
|
|||
if err == nil {
|
||||
return NewTimeoutPacketConn(ctx, conn, timeout)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
ctx, cancel := common.ContextWithCancelCause(ctx)
|
||||
instance := New(ctx, cancel, timeout)
|
||||
return ctx, &TimerPacketConn{conn, instance}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
@ -13,12 +14,12 @@ import (
|
|||
type TimeoutPacketConn struct {
|
||||
N.PacketConn
|
||||
timeout time.Duration
|
||||
cancel context.CancelFunc
|
||||
cancel common.ContextCancelCauseFunc
|
||||
active time.Time
|
||||
}
|
||||
|
||||
func NewTimeoutPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
ctx, cancel := common.ContextWithCancelCause(ctx)
|
||||
return ctx, &TimeoutPacketConn{
|
||||
PacketConn: conn,
|
||||
timeout: timeout,
|
||||
|
@ -38,7 +39,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa
|
|||
return
|
||||
} else if E.IsTimeout(err) {
|
||||
if time.Since(c.active) > c.timeout {
|
||||
c.cancel()
|
||||
c.cancel(err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
|
|
18
common/context.go
Normal file
18
common/context.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func SelectContext(contextList []context.Context) (int, error) {
|
||||
chosen, _, _ := reflect.Select(Map(Filter(contextList, func(it context.Context) bool {
|
||||
return it.Done() != nil
|
||||
}), func(it context.Context) reflect.SelectCase {
|
||||
return reflect.SelectCase{
|
||||
Dir: reflect.SelectRecv,
|
||||
Chan: reflect.ValueOf(it.Done()),
|
||||
}
|
||||
}))
|
||||
return chosen, contextList[chosen].Err()
|
||||
}
|
14
common/context_compat.go
Normal file
14
common/context_compat.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
//go:build go1.20
|
||||
|
||||
package common
|
||||
|
||||
import "context"
|
||||
|
||||
type (
|
||||
ContextCancelCauseFunc = context.CancelCauseFunc
|
||||
)
|
||||
|
||||
var (
|
||||
ContextWithCancelCause = context.WithCancelCause
|
||||
ContextCause = context.Cause
|
||||
)
|
16
common/context_lagacy.go
Normal file
16
common/context_lagacy.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
//go:build !go1.20
|
||||
|
||||
package common
|
||||
|
||||
import "context"
|
||||
|
||||
type ContextCancelCauseFunc func(cause error)
|
||||
|
||||
func ContextWithCancelCause(parentContext context.Context) (context.Context, ContextCancelCauseFunc) {
|
||||
ctx, cancel := context.WithCancel(parentContext)
|
||||
return ctx, func(_ error) { cancel() }
|
||||
}
|
||||
|
||||
func ContextCause(context context.Context) error {
|
||||
return context.Err()
|
||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
|
@ -12,6 +13,12 @@ type taskItem struct {
|
|||
Run func(ctx context.Context) error
|
||||
}
|
||||
|
||||
type errTaskSucceed struct{}
|
||||
|
||||
func (e errTaskSucceed) Error() string {
|
||||
return "task succeed"
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
tasks []taskItem
|
||||
cleanup func()
|
||||
|
@ -39,67 +46,57 @@ func (g *Group) FastFail() {
|
|||
g.fastFail = true
|
||||
}
|
||||
|
||||
func (g *Group) Run(ctx context.Context) error {
|
||||
var retAccess sync.Mutex
|
||||
var retErr error
|
||||
func (g *Group) Run(contextList ...context.Context) error {
|
||||
return g.RunContextList(contextList)
|
||||
}
|
||||
|
||||
taskCount := int8(len(g.tasks))
|
||||
taskCtx, taskFinish := context.WithCancel(context.Background())
|
||||
var mixedCtx context.Context
|
||||
var mixedFinish context.CancelFunc
|
||||
if ctx.Done() != nil || g.fastFail {
|
||||
mixedCtx, mixedFinish = context.WithCancel(ctx)
|
||||
} else {
|
||||
mixedCtx, mixedFinish = taskCtx, taskFinish
|
||||
func (g *Group) RunContextList(contextList []context.Context) error {
|
||||
if len(contextList) == 0 {
|
||||
contextList = append(contextList, context.Background())
|
||||
}
|
||||
|
||||
taskContext, taskFinish := common.ContextWithCancelCause(context.Background())
|
||||
taskCancelContext, taskCancel := common.ContextWithCancelCause(context.Background())
|
||||
|
||||
var errorAccess sync.Mutex
|
||||
var returnError error
|
||||
taskCount := int8(len(g.tasks))
|
||||
|
||||
for _, task := range g.tasks {
|
||||
currentTask := task
|
||||
go func() {
|
||||
err := currentTask.Run(mixedCtx)
|
||||
retAccess.Lock()
|
||||
err := currentTask.Run(taskCancelContext)
|
||||
errorAccess.Lock()
|
||||
if err != nil {
|
||||
retErr = E.Append(retErr, err, func(err error) error {
|
||||
if currentTask.Name == "" {
|
||||
return err
|
||||
}
|
||||
return E.Cause(err, currentTask.Name)
|
||||
})
|
||||
if currentTask.Name != "" {
|
||||
err = E.Cause(err, currentTask.Name)
|
||||
}
|
||||
returnError = E.Errors(returnError, err)
|
||||
if g.fastFail {
|
||||
mixedFinish()
|
||||
taskCancel(err)
|
||||
}
|
||||
}
|
||||
taskCount--
|
||||
currentCount := taskCount
|
||||
retAccess.Unlock()
|
||||
errorAccess.Unlock()
|
||||
if currentCount == 0 {
|
||||
taskFinish()
|
||||
taskCancel(errTaskSucceed{})
|
||||
taskFinish(errTaskSucceed{})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var upstreamErr error
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
upstreamErr = ctx.Err()
|
||||
case <-taskCtx.Done():
|
||||
mixedFinish()
|
||||
case <-mixedCtx.Done():
|
||||
selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskContext}, contextList...))
|
||||
if selectedContext != 0 {
|
||||
returnError = E.Append(returnError, upstreamErr, func(err error) error {
|
||||
return E.Cause(err, "upstream")
|
||||
})
|
||||
}
|
||||
|
||||
if g.cleanup != nil {
|
||||
g.cleanup()
|
||||
}
|
||||
|
||||
<-taskCtx.Done()
|
||||
|
||||
taskFinish()
|
||||
mixedFinish()
|
||||
|
||||
retErr = E.Append(retErr, upstreamErr, func(err error) error {
|
||||
return E.Cause(err, "upstream")
|
||||
})
|
||||
|
||||
return retErr
|
||||
<-taskContext.Done()
|
||||
return returnError
|
||||
}
|
||||
|
|
|
@ -74,7 +74,7 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu
|
|||
localAddr: metadata.Source,
|
||||
remoteAddr: metadata.Destination,
|
||||
}
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
c.ctx, c.cancel = common.ContextWithCancelCause(ctx)
|
||||
return c
|
||||
})
|
||||
if !loaded {
|
||||
|
@ -110,7 +110,7 @@ type packet struct {
|
|||
|
||||
type conn struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
cancel common.ContextCancelCauseFunc
|
||||
data chan packet
|
||||
localAddr M.Socksaddr
|
||||
remoteAddr M.Socksaddr
|
||||
|
@ -180,7 +180,7 @@ func (c *conn) Close() error {
|
|||
return os.ErrClosed
|
||||
default:
|
||||
}
|
||||
c.cancel()
|
||||
c.cancel(net.ErrClosed)
|
||||
if sourceCloser, sourceIsCloser := c.source.(io.Closer); sourceIsCloser {
|
||||
return sourceCloser.Close()
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue