Use multi context in task/conn copy

This commit is contained in:
世界 2023-04-10 12:20:23 +08:00
parent a82d82e559
commit b7cd741872
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
10 changed files with 127 additions and 58 deletions

View file

@ -3,6 +3,9 @@ package batch
import (
"context"
"sync"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
type Option[T any] func(b *Batch[T])
@ -17,6 +20,10 @@ type Error struct {
Err error
}
func (e *Error) Error() string {
return E.Cause(e.Err, e.Key).Error()
}
func WithConcurrencyNum[T any](n int) Option[T] {
return func(b *Batch[T]) {
q := make(chan struct{}, n)
@ -35,7 +42,7 @@ type Batch[T any] struct {
mux sync.Mutex
err *Error
once sync.Once
cancel func()
cancel common.ContextCancelCauseFunc
}
func (b *Batch[T]) Go(key string, fn func() (T, error)) {
@ -54,7 +61,7 @@ func (b *Batch[T]) Go(key string, fn func() (T, error)) {
b.once.Do(func() {
b.err = &Error{key, err}
if b.cancel != nil {
b.cancel()
b.cancel(b.err)
}
})
}
@ -69,7 +76,7 @@ func (b *Batch[T]) Go(key string, fn func() (T, error)) {
func (b *Batch[T]) Wait() *Error {
b.wg.Wait()
if b.cancel != nil {
b.cancel()
b.cancel(nil)
}
return b.err
}
@ -90,7 +97,7 @@ func (b *Batch[T]) Result() map[string]Result[T] {
}
func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) {
ctx, cancel := context.WithCancel(ctx)
ctx, cancel := common.ContextWithCancelCause(ctx)
b := &Batch[T]{
result: map[string]Result[T]{},

View file

@ -185,6 +185,10 @@ func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64,
}
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
return CopyConnContextList([]context.Context{ctx}, conn, dest)
}
func CopyConnContextList(contextList []context.Context, conn net.Conn, dest net.Conn) error {
var group task.Group
if _, dstDuplex := common.Cast[rw.WriteCloser](dest); dstDuplex {
group.Append("upload", func(ctx context.Context) error {
@ -221,7 +225,7 @@ func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
group.Cleanup(func() {
common.Close(conn, dest)
})
return group.Run(ctx)
return group.RunContextList(contextList)
}
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
@ -335,6 +339,10 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er
}
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
return CopyPacketConnContextList([]context.Context{ctx}, conn, dest)
}
func CopyPacketConnContextList(contextList []context.Context, conn N.PacketConn, dest N.PacketConn) error {
var group task.Group
group.Append("upload", func(ctx context.Context) error {
return common.Error(CopyPacket(dest, conn))
@ -346,5 +354,5 @@ func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) e
common.Close(conn, dest)
})
group.FastFail()
return group.Run(ctx)
return group.RunContextList(contextList)
}

View file

@ -2,17 +2,21 @@ package canceler
import (
"context"
"net"
"os"
"time"
"github.com/sagernet/sing/common"
)
type Instance struct {
ctx context.Context
cancelFunc context.CancelFunc
cancelFunc common.ContextCancelCauseFunc
timer *time.Timer
timeout time.Duration
}
func New(ctx context.Context, cancelFunc context.CancelFunc, timeout time.Duration) *Instance {
func New(ctx context.Context, cancelFunc common.ContextCancelCauseFunc, timeout time.Duration) *Instance {
instance := &Instance{
ctx,
cancelFunc,
@ -47,11 +51,15 @@ func (i *Instance) wait() {
case <-i.timer.C:
case <-i.ctx.Done():
}
i.Close()
i.CloseWithError(os.ErrDeadlineExceeded)
}
func (i *Instance) Close() error {
i.timer.Stop()
i.cancelFunc()
i.CloseWithError(net.ErrClosed)
return nil
}
func (i *Instance) CloseWithError(err error) {
i.timer.Stop()
i.cancelFunc(err)
}

View file

@ -33,7 +33,7 @@ func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration
if err == nil {
return NewTimeoutPacketConn(ctx, conn, timeout)
}
ctx, cancel := context.WithCancel(ctx)
ctx, cancel := common.ContextWithCancelCause(ctx)
instance := New(ctx, cancel, timeout)
return ctx, &TimerPacketConn{conn, instance}
}

View file

@ -4,6 +4,7 @@ import (
"context"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
@ -13,12 +14,12 @@ import (
type TimeoutPacketConn struct {
N.PacketConn
timeout time.Duration
cancel context.CancelFunc
cancel common.ContextCancelCauseFunc
active time.Time
}
func NewTimeoutPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) {
ctx, cancel := context.WithCancel(ctx)
ctx, cancel := common.ContextWithCancelCause(ctx)
return ctx, &TimeoutPacketConn{
PacketConn: conn,
timeout: timeout,
@ -38,7 +39,7 @@ func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksa
return
} else if E.IsTimeout(err) {
if time.Since(c.active) > c.timeout {
c.cancel()
c.cancel(err)
return
}
} else {

18
common/context.go Normal file
View file

@ -0,0 +1,18 @@
package common
import (
"context"
"reflect"
)
func SelectContext(contextList []context.Context) (int, error) {
chosen, _, _ := reflect.Select(Map(Filter(contextList, func(it context.Context) bool {
return it.Done() != nil
}), func(it context.Context) reflect.SelectCase {
return reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(it.Done()),
}
}))
return chosen, contextList[chosen].Err()
}

14
common/context_compat.go Normal file
View file

@ -0,0 +1,14 @@
//go:build go1.20
package common
import "context"
type (
ContextCancelCauseFunc = context.CancelCauseFunc
)
var (
ContextWithCancelCause = context.WithCancelCause
ContextCause = context.Cause
)

16
common/context_lagacy.go Normal file
View file

@ -0,0 +1,16 @@
//go:build !go1.20
package common
import "context"
type ContextCancelCauseFunc func(cause error)
func ContextWithCancelCause(parentContext context.Context) (context.Context, ContextCancelCauseFunc) {
ctx, cancel := context.WithCancel(parentContext)
return ctx, func(_ error) { cancel() }
}
func ContextCause(context context.Context) error {
return context.Err()
}

View file

@ -4,6 +4,7 @@ import (
"context"
"sync"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
@ -12,6 +13,12 @@ type taskItem struct {
Run func(ctx context.Context) error
}
type errTaskSucceed struct{}
func (e errTaskSucceed) Error() string {
return "task succeed"
}
type Group struct {
tasks []taskItem
cleanup func()
@ -39,67 +46,57 @@ func (g *Group) FastFail() {
g.fastFail = true
}
func (g *Group) Run(ctx context.Context) error {
var retAccess sync.Mutex
var retErr error
func (g *Group) Run(contextList ...context.Context) error {
return g.RunContextList(contextList)
}
taskCount := int8(len(g.tasks))
taskCtx, taskFinish := context.WithCancel(context.Background())
var mixedCtx context.Context
var mixedFinish context.CancelFunc
if ctx.Done() != nil || g.fastFail {
mixedCtx, mixedFinish = context.WithCancel(ctx)
} else {
mixedCtx, mixedFinish = taskCtx, taskFinish
func (g *Group) RunContextList(contextList []context.Context) error {
if len(contextList) == 0 {
contextList = append(contextList, context.Background())
}
taskContext, taskFinish := common.ContextWithCancelCause(context.Background())
taskCancelContext, taskCancel := common.ContextWithCancelCause(context.Background())
var errorAccess sync.Mutex
var returnError error
taskCount := int8(len(g.tasks))
for _, task := range g.tasks {
currentTask := task
go func() {
err := currentTask.Run(mixedCtx)
retAccess.Lock()
err := currentTask.Run(taskCancelContext)
errorAccess.Lock()
if err != nil {
retErr = E.Append(retErr, err, func(err error) error {
if currentTask.Name == "" {
return err
}
return E.Cause(err, currentTask.Name)
})
if currentTask.Name != "" {
err = E.Cause(err, currentTask.Name)
}
returnError = E.Errors(returnError, err)
if g.fastFail {
mixedFinish()
taskCancel(err)
}
}
taskCount--
currentCount := taskCount
retAccess.Unlock()
errorAccess.Unlock()
if currentCount == 0 {
taskFinish()
taskCancel(errTaskSucceed{})
taskFinish(errTaskSucceed{})
}
}()
}
var upstreamErr error
select {
case <-ctx.Done():
upstreamErr = ctx.Err()
case <-taskCtx.Done():
mixedFinish()
case <-mixedCtx.Done():
selectedContext, upstreamErr := common.SelectContext(append([]context.Context{taskContext}, contextList...))
if selectedContext != 0 {
returnError = E.Append(returnError, upstreamErr, func(err error) error {
return E.Cause(err, "upstream")
})
}
if g.cleanup != nil {
g.cleanup()
}
<-taskCtx.Done()
taskFinish()
mixedFinish()
retErr = E.Append(retErr, upstreamErr, func(err error) error {
return E.Cause(err, "upstream")
})
return retErr
<-taskContext.Done()
return returnError
}

View file

@ -74,7 +74,7 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu
localAddr: metadata.Source,
remoteAddr: metadata.Destination,
}
c.ctx, c.cancel = context.WithCancel(ctx)
c.ctx, c.cancel = common.ContextWithCancelCause(ctx)
return c
})
if !loaded {
@ -110,7 +110,7 @@ type packet struct {
type conn struct {
ctx context.Context
cancel context.CancelFunc
cancel common.ContextCancelCauseFunc
data chan packet
localAddr M.Socksaddr
remoteAddr M.Socksaddr
@ -180,7 +180,7 @@ func (c *conn) Close() error {
return os.ErrClosed
default:
}
c.cancel()
c.cancel(net.ErrClosed)
if sourceCloser, sourceIsCloser := c.source.(io.Closer); sourceIsCloser {
return sourceCloser.Close()
}