Improve task

This commit is contained in:
世界 2022-08-03 17:07:57 +08:00
parent afbe231237
commit d9ca259bec
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 128 additions and 152 deletions

View file

@ -5,7 +5,6 @@ import (
"io" "io"
"net" "net"
"os" "os"
"time"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
@ -146,17 +145,21 @@ func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64,
} }
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error { func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
defer common.Close(conn, dest) var group task.Group
err := task.Run(ctx, func() error { group.Append("upload", func(ctx context.Context) error {
defer rw.CloseRead(conn) defer rw.CloseRead(conn)
defer rw.CloseWrite(dest) defer rw.CloseWrite(dest)
return common.Error(Copy(dest, conn)) return common.Error(Copy(dest, conn))
}, func() error { })
group.Append("download", func(ctx context.Context) error {
defer rw.CloseRead(dest) defer rw.CloseRead(dest)
defer rw.CloseWrite(conn) defer rw.CloseWrite(conn)
return common.Error(Copy(conn, dest)) return common.Error(Copy(conn, dest))
}) })
return err group.Cleanup(func() {
common.Close(conn, dest)
})
return group.Run(ctx)
} }
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
@ -202,48 +205,6 @@ func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
} }
} }
func CopyPacketTimeout(dst N.PacketWriter, src N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) {
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafePacketReader](src)
_, dstUnsafe := common.Cast[N.ThreadUnsafeWriter](dst)
if srcUnsafe {
dstHeadroom := N.CalculateHeadroom(dst)
if dstHeadroom == 0 {
return CopyPacketWithSrcBufferTimeout(dst, unsafeSrc, src, timeout)
}
}
if dstUnsafe {
return CopyPacketWithPoolTimeout(dst, src, timeout)
}
_buffer := buf.StackNewPacket()
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
buffer.IncRef()
defer buffer.DecRef()
var destination M.Socksaddr
for {
buffer.Reset()
err = src.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
return
}
destination, err = src.ReadPacket(buffer)
if err != nil {
return
}
if buffer.IsFull() {
return 0, io.ErrShortBuffer
}
dataLen := buffer.Len()
err = dst.WritePacket(buffer, destination)
if err != nil {
return
}
n += int64(dataLen)
}
}
func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) { func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) {
var buffer *buf.Buffer var buffer *buf.Buffer
var destination M.Socksaddr var destination M.Socksaddr
@ -267,33 +228,6 @@ func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (
} }
} }
func CopyPacketWithSrcBufferTimeout(dst N.PacketWriter, src N.ThreadSafePacketReader, tSrc N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) {
var buffer *buf.Buffer
var destination M.Socksaddr
var notFirstTime bool
for {
err = tSrc.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
buffer, destination, err = src.ReadPacketThreadSafe()
if err != nil {
return
}
dataLen := buffer.Len()
err = dst.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) { func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
var destination M.Socksaddr var destination M.Socksaddr
var notFirstTime bool var notFirstTime bool
@ -322,54 +256,19 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er
} }
} }
func CopyPacketWithPoolTimeout(dst N.PacketWriter, src N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) {
var destination M.Socksaddr
var notFirstTime bool
for {
buffer := buf.NewPacket()
err = src.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
return
}
destination, err = src.ReadPacket(buffer)
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
}
return
}
if buffer.IsFull() {
buffer.Release()
return 0, io.ErrShortBuffer
}
dataLen := buffer.Len()
err = dst.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
return
}
n += int64(dataLen)
notFirstTime = true
}
}
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error { func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
defer common.Close(conn, dest) var group task.Group
return task.Any(ctx, func(ctx context.Context) error { group.Append("upload", func(ctx context.Context) error {
return common.Error(CopyPacket(dest, conn)) return common.Error(CopyPacket(dest, conn))
}, func(ctx context.Context) error { })
group.Append("download", func(ctx context.Context) error {
return common.Error(CopyPacket(conn, dest)) return common.Error(CopyPacket(conn, dest))
}) })
} group.Cleanup(func() {
common.Close(conn, dest)
func CopyPacketConnTimeout(ctx context.Context, conn N.PacketConn, dest N.PacketConn, timeout time.Duration) error {
defer common.Close(conn, dest)
return task.Any(ctx, func(ctx context.Context) error {
return common.Error(CopyPacketTimeout(dest, conn, timeout))
}, func(ctx context.Context) error {
return common.Error(CopyPacketTimeout(conn, dest, timeout))
}) })
group.FastFail()
return group.Run(ctx)
} }
func NewPacketConn(conn net.PacketConn) N.NetPacketConn { func NewPacketConn(conn net.PacketConn) N.NetPacketConn {

View file

@ -13,7 +13,7 @@ type multiError struct {
} }
func (e *multiError) Error() string { func (e *multiError) Error() string {
return "multi error: (" + strings.Join(F.MapToString(e.errors), " | ") + ")" return strings.Join(F.MapToString(e.errors), " | ")
} }
func (e *multiError) UnwrapMulti() []error { func (e *multiError) UnwrapMulti() []error {
@ -50,7 +50,7 @@ func Append(err error, other error, block func(error) error) error {
if other == nil { if other == nil {
return err return err
} }
return Errors(err, block(err)) return Errors(err, block(other))
} }
func IsMulti(err error, targetList ...error) bool { func IsMulti(err error, targetList ...error) bool {

View file

@ -7,48 +7,99 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
) )
func Run(ctx context.Context, tasks ...func() error) error { type taskItem struct {
runtimeCtx, cancel := context.WithCancel(ctx) Name string
wg := &sync.WaitGroup{} Run func(ctx context.Context) error
wg.Add(len(tasks))
var retErr []error
for _, task := range tasks {
currentTask := task
go func() {
if err := currentTask(); err != nil {
retErr = append(retErr, err)
}
wg.Done()
}()
}
go func() {
wg.Wait()
cancel()
}()
select {
case <-ctx.Done():
case <-runtimeCtx.Done():
}
retErr = append(retErr, ctx.Err())
return E.Errors(retErr...)
} }
func Any(ctx context.Context, tasks ...func(ctx context.Context) error) error { type Group struct {
runtimeCtx, cancel := context.WithCancel(ctx) tasks []taskItem
defer cancel() cleanup func()
fastFail bool
}
func (g *Group) Append(name string, f func(ctx context.Context) error) {
g.tasks = append(g.tasks, taskItem{
Name: name,
Run: f,
})
}
func (g *Group) Append0(f func(ctx context.Context) error) {
g.tasks = append(g.tasks, taskItem{
Run: f,
})
}
func (g *Group) Cleanup(f func()) {
g.cleanup = f
}
func (g *Group) FastFail() {
g.fastFail = true
}
func (g *Group) Run(ctx context.Context) error {
var retAccess sync.Mutex
var retErr error var retErr error
for _, task := range tasks {
taskCount := int8(len(g.tasks))
taskCtx, taskFinish := context.WithCancel(context.Background())
var mixedCtx context.Context
var mixedFinish context.CancelFunc
if ctx.Done() != nil || g.fastFail {
mixedCtx, mixedFinish = context.WithCancel(ctx)
} else {
mixedCtx, mixedFinish = taskCtx, taskFinish
}
for _, task := range g.tasks {
currentTask := task currentTask := task
go func() { go func() {
if err := currentTask(runtimeCtx); err != nil { err := currentTask.Run(mixedCtx)
retErr = err retAccess.Lock()
if err != nil {
retErr = E.Append(retErr, err, func(err error) error {
if currentTask.Name == "" {
return err
}
return E.Cause(err, currentTask.Name)
})
if g.fastFail {
mixedFinish()
}
}
taskCount--
currentCount := taskCount
retAccess.Unlock()
if currentCount == 0 {
taskFinish()
} }
cancel()
}() }()
} }
var upstreamErr error
select { select {
case <-ctx.Done(): case <-ctx.Done():
case <-runtimeCtx.Done(): upstreamErr = ctx.Err()
case <-taskCtx.Done():
mixedFinish()
case <-mixedCtx.Done():
} }
return E.Errors(retErr, ctx.Err())
if g.cleanup != nil {
g.cleanup()
}
<-taskCtx.Done()
taskFinish()
mixedFinish()
retErr = E.Append(retErr, upstreamErr, func(err error) error {
return E.Cause(err, "upstream")
})
return retErr
} }

View file

@ -0,0 +1,26 @@
package task
import "context"
func Run(ctx context.Context, tasks ...func() error) error {
var group Group
for _, task := range tasks {
currentTask := task
group.Append0(func(ctx context.Context) error {
return currentTask()
})
}
return group.Run(ctx)
}
func Any(ctx context.Context, tasks ...func(ctx context.Context) error) error {
var group Group
for _, task := range tasks {
currentTask := task
group.Append0(func(ctx context.Context) error {
return currentTask(ctx)
})
}
group.FastFail()
return group.Run(ctx)
}