Use multi context in task/conn copy

This commit is contained in:
世界 2023-04-10 12:20:23 +08:00
parent a82d82e559
commit b7cd741872
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
10 changed files with 127 additions and 58 deletions

View file

@ -3,6 +3,9 @@ package batch
import ( import (
"context" "context"
"sync" "sync"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
) )
type Option[T any] func(b *Batch[T]) type Option[T any] func(b *Batch[T])
@ -17,6 +20,10 @@ type Error struct {
Err error Err error
} }
func (e *Error) Error() string {
return E.Cause(e.Err, e.Key).Error()
}
func WithConcurrencyNum[T any](n int) Option[T] { func WithConcurrencyNum[T any](n int) Option[T] {
return func(b *Batch[T]) { return func(b *Batch[T]) {
q := make(chan struct{}, n) q := make(chan struct{}, n)
@ -35,7 +42,7 @@ type Batch[T any] struct {
mux sync.Mutex mux sync.Mutex
err *Error err *Error
once sync.Once once sync.Once
cancel func() cancel common.ContextCancelCauseFunc
} }
func (b *Batch[T]) Go(key string, fn func() (T, error)) { func (b *Batch[T]) Go(key string, fn func() (T, error)) {
@ -54,7 +61,7 @@ func (b *Batch[T]) Go(key string, fn func() (T, error)) {
b.once.Do(func() { b.once.Do(func() {
b.err = &Error{key, err} b.err = &Error{key, err}
if b.cancel != nil { if b.cancel != nil {
b.cancel() b.cancel(b.err)
} }
}) })
} }
@ -69,7 +76,7 @@ func (b *Batch[T]) Go(key string, fn func() (T, error)) {
func (b *Batch[T]) Wait() *Error { func (b *Batch[T]) Wait() *Error {
b.wg.Wait() b.wg.Wait()
if b.cancel != nil { if b.cancel != nil {
b.cancel() b.cancel(nil)
} }
return b.err return b.err
} }
@ -90,7 +97,7 @@ func (b *Batch[T]) Result() map[string]Result[T] {
} }
func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) { func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := common.ContextWithCancelCause(ctx)
b := &Batch[T]{ b := &Batch[T]{
result: map[string]Result[T]{}, result: map[string]Result[T]{},

View file

@ -185,6 +185,10 @@ func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64,
} }
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error { func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
return CopyConnContextList([]context.Context{ctx}, conn, dest)
}
func CopyConnContextList(contextList []context.Context, conn net.Conn, dest net.Conn) error {
var group task.Group var group task.Group
if _, dstDuplex := common.Cast[rw.WriteCloser](dest); dstDuplex { if _, dstDuplex := common.Cast[rw.WriteCloser](dest); dstDuplex {
group.Append("upload", func(ctx context.Context) error { group.Append("upload", func(ctx context.Context) error {
@ -221,7 +225,7 @@ func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
group.Cleanup(func() { group.Cleanup(func() {
common.Close(conn, dest) common.Close(conn, dest)
}) })
return group.Run(ctx) return group.RunContextList(contextList)
} }
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
@ -335,6 +339,10 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er
} }
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error { func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
return CopyPacketConnContextList([]context.Context{ctx}, conn, dest)
}
func CopyPacketConnContextList(contextList []context.Context, conn N.PacketConn, dest N.PacketConn) error {
var group task.Group var group task.Group
group.Append("upload", func(ctx context.Context) error { group.Append("upload", func(ctx context.Context) error {
return common.Error(CopyPacket(dest, conn)) return common.Error(CopyPacket(dest, conn))
@ -346,5 +354,5 @@ func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) e
common.Close(conn, dest) common.Close(conn, dest)
}) })
group.FastFail() group.FastFail()
return group.Run(ctx) return group.RunContextList(contextList)
} }

View file

@ -2,17 +2,21 @@ package canceler
import ( import (
"context" "context"
"net"
"os"
"time" "time"
"github.com/sagernet/sing/common"
) )
type Instance struct { type Instance struct {
ctx context.Context ctx context.Context
cancelFunc context.CancelFunc cancelFunc common.ContextCancelCauseFunc
timer *time.Timer timer *time.Timer
timeout time.Duration timeout time.Duration
} }
func New(ctx context.Context, cancelFunc context.CancelFunc, timeout time.Duration) *Instance { func New(ctx context.Context, cancelFunc common.ContextCancelCauseFunc, timeout time.Duration) *Instance {
instance := &Instance{ instance := &Instance{
ctx, ctx,
cancelFunc, cancelFunc,
@ -47,11 +51,15 @@ func (i *Instance) wait() {
case <-i.timer.C: case <-i.timer.C:
case <-i.ctx.Done(): case <-i.ctx.Done():
} }
i.Close() i.CloseWithError(os.ErrDeadlineExceeded)
} }
func (i *Instance) Close() error { func (i *Instance) Close() error {
i.timer.Stop() i.CloseWithError(net.ErrClosed)
i.cancelFunc()
return nil return nil
} }
func (i *Instance) CloseWithError(err error) {
i.timer.Stop()
i.cancelFunc(err)
}

View file

@ -33,7 +33,7 @@ func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration
if err == nil { if err == nil {
return NewTimeoutPacketConn(ctx, conn, timeout) return NewTimeoutPacketConn(ctx, conn, timeout)
} }
ctx, cancel := context.WithCancel(ctx) ctx, cancel := common.ContextWithCancelCause(ctx)
instance := New(ctx, cancel, timeout) instance := New(ctx, cancel, timeout)
return ctx, &TimerPacketConn{conn, instance} return ctx, &TimerPacketConn{conn, instance}
} }

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"time" "time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -13,12 +14,12 @@ import (
type TimeoutPacketConn struct { type TimeoutPacketConn struct {
N.PacketConn N.PacketConn
timeout time.Duration timeout time.Duration
cancel context.CancelFunc cancel common.ContextCancelCauseFunc
active time.Time active time.Time
} }
func NewTimeoutPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) { func NewTimeoutPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := common.ContextWithCancelCause(ctx)
return ctx, &TimeoutPacketConn{ return ctx, &TimeoutPacketConn{
PacketConn: conn, PacketConn: conn,
timeout: timeout, timeout: timeout,
@ -38,7 +39,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa
return return
} else if E.IsTimeout(err) { } else if E.IsTimeout(err) {
if time.Since(c.active) > c.timeout { if time.Since(c.active) > c.timeout {
c.cancel() c.cancel(err)
return return
} }
} else { } else {

18
common/context.go Normal file
View file

@ -0,0 +1,18 @@
package common
import (
"context"
"reflect"
)
func SelectContext(contextList []context.Context) (int, error) {
chosen, _, _ := reflect.Select(Map(Filter(contextList, func(it context.Context) bool {
return it.Done() != nil
}), func(it context.Context) reflect.SelectCase {
return reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(it.Done()),
}
}))
return chosen, contextList[chosen].Err()
}

14
common/context_compat.go Normal file
View file

@ -0,0 +1,14 @@
//go:build go1.20
package common
import "context"
type (
ContextCancelCauseFunc = context.CancelCauseFunc
)
var (
ContextWithCancelCause = context.WithCancelCause
ContextCause = context.Cause
)

16
common/context_lagacy.go Normal file
View file

@ -0,0 +1,16 @@
//go:build !go1.20
package common
import "context"
type ContextCancelCauseFunc func(cause error)
func ContextWithCancelCause(parentContext context.Context) (context.Context, ContextCancelCauseFunc) {
ctx, cancel := context.WithCancel(parentContext)
return ctx, func(_ error) { cancel() }
}
func ContextCause(context context.Context) error {
return context.Err()
}

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"sync" "sync"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
) )
@ -12,6 +13,12 @@ type taskItem struct {
Run func(ctx context.Context) error Run func(ctx context.Context) error
} }
type errTaskSucceed struct{}
func (e errTaskSucceed) Error() string {
return "task succeed"
}
type Group struct { type Group struct {
tasks []taskItem tasks []taskItem
cleanup func() cleanup func()
@ -39,67 +46,57 @@ func (g *Group) FastFail() {
g.fastFail = true g.fastFail = true
} }
func (g *Group) Run(ctx context.Context) error { func (g *Group) Run(contextList ...context.Context) error {
var retAccess sync.Mutex return g.RunContextList(contextList)
var retErr error }
taskCount := int8(len(g.tasks)) func (g *Group) RunContextList(contextList []context.Context) error {
taskCtx, taskFinish := context.WithCancel(context.Background()) if len(contextList) == 0 {
var mixedCtx context.Context contextList = append(contextList, context.Background())
var mixedFinish context.CancelFunc
if ctx.Done() != nil || g.fastFail {
mixedCtx, mixedFinish = context.WithCancel(ctx)
} else {
mixedCtx, mixedFinish = taskCtx, taskFinish
} }
taskContext, taskFinish := common.ContextWithCancelCause(context.Background())
taskCancelContext, taskCancel := common.ContextWithCancelCause(context.Background())
var errorAccess sync.Mutex
var returnError error
taskCount := int8(len(g.tasks))
for _, task := range g.tasks { for _, task := range g.tasks {
currentTask := task currentTask := task
go func() { go func() {
err := currentTask.Run(mixedCtx) err := currentTask.Run(taskCancelContext)
retAccess.Lock() errorAccess.Lock()
if err != nil { if err != nil {
retErr = E.Append(retErr, err, func(err error) error { if currentTask.Name != "" {
if currentTask.Name == "" { err = E.Cause(err, currentTask.Name)
return err }
} returnError = E.Errors(returnError, err)
return E.Cause(err, currentTask.Name)
})
if g.fastFail { if g.fastFail {
mixedFinish() taskCancel(err)
} }
} }
taskCount-- taskCount--
currentCount := taskCount currentCount := taskCount
retAccess.Unlock() errorAccess.Unlock()
if currentCount == 0 { if currentCount == 0 {
taskFinish() taskCancel(errTaskSucceed{})
taskFinish(errTaskSucceed{})
} }
}() }()
} }
var upstreamErr error selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskContext}, contextList...))
if selectedContext != 0 {
select { returnError = E.Append(returnError, upstreamErr, func(err error) error {
case <-ctx.Done(): return E.Cause(err, "upstream")
upstreamErr = ctx.Err() })
case <-taskCtx.Done():
mixedFinish()
case <-mixedCtx.Done():
} }
if g.cleanup != nil { if g.cleanup != nil {
g.cleanup() g.cleanup()
} }
<-taskCtx.Done() <-taskContext.Done()
return returnError
taskFinish()
mixedFinish()
retErr = E.Append(retErr, upstreamErr, func(err error) error {
return E.Cause(err, "upstream")
})
return retErr
} }

View file

@ -74,7 +74,7 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu
localAddr: metadata.Source, localAddr: metadata.Source,
remoteAddr: metadata.Destination, remoteAddr: metadata.Destination,
} }
c.ctx, c.cancel = context.WithCancel(ctx) c.ctx, c.cancel = common.ContextWithCancelCause(ctx)
return c return c
}) })
if !loaded { if !loaded {
@ -110,7 +110,7 @@ type packet struct {
type conn struct { type conn struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel common.ContextCancelCauseFunc
data chan packet data chan packet
localAddr M.Socksaddr localAddr M.Socksaddr
remoteAddr M.Socksaddr remoteAddr M.Socksaddr
@ -180,7 +180,7 @@ func (c *conn) Close() error {
return os.ErrClosed return os.ErrClosed
default: default:
} }
c.cancel() c.cancel(net.ErrClosed)
if sourceCloser, sourceIsCloser := c.source.(io.Closer); sourceIsCloser { if sourceCloser, sourceIsCloser := c.source.(io.Closer); sourceIsCloser {
return sourceCloser.Close() return sourceCloser.Close()
} }