mirror of
https://github.com/SagerNet/sing.git
synced 2025-04-05 04:47:40 +03:00
Improve task
This commit is contained in:
parent
afbe231237
commit
d9ca259bec
4 changed files with 128 additions and 152 deletions
|
@ -5,7 +5,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
@ -146,17 +145,21 @@ func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64,
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
func CopyConn(ctx context.Context, conn net.Conn, dest net.Conn) error {
|
||||||
defer common.Close(conn, dest)
|
var group task.Group
|
||||||
err := task.Run(ctx, func() error {
|
group.Append("upload", func(ctx context.Context) error {
|
||||||
defer rw.CloseRead(conn)
|
defer rw.CloseRead(conn)
|
||||||
defer rw.CloseWrite(dest)
|
defer rw.CloseWrite(dest)
|
||||||
return common.Error(Copy(dest, conn))
|
return common.Error(Copy(dest, conn))
|
||||||
}, func() error {
|
})
|
||||||
|
group.Append("download", func(ctx context.Context) error {
|
||||||
defer rw.CloseRead(dest)
|
defer rw.CloseRead(dest)
|
||||||
defer rw.CloseWrite(conn)
|
defer rw.CloseWrite(conn)
|
||||||
return common.Error(Copy(conn, dest))
|
return common.Error(Copy(conn, dest))
|
||||||
})
|
})
|
||||||
return err
|
group.Cleanup(func() {
|
||||||
|
common.Close(conn, dest)
|
||||||
|
})
|
||||||
|
return group.Run(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
||||||
|
@ -202,48 +205,6 @@ func CopyPacket(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyPacketTimeout(dst N.PacketWriter, src N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) {
|
|
||||||
unsafeSrc, srcUnsafe := common.Cast[N.ThreadSafePacketReader](src)
|
|
||||||
_, dstUnsafe := common.Cast[N.ThreadUnsafeWriter](dst)
|
|
||||||
if srcUnsafe {
|
|
||||||
dstHeadroom := N.CalculateHeadroom(dst)
|
|
||||||
if dstHeadroom == 0 {
|
|
||||||
return CopyPacketWithSrcBufferTimeout(dst, unsafeSrc, src, timeout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if dstUnsafe {
|
|
||||||
return CopyPacketWithPoolTimeout(dst, src, timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
_buffer := buf.StackNewPacket()
|
|
||||||
defer common.KeepAlive(_buffer)
|
|
||||||
buffer := common.Dup(_buffer)
|
|
||||||
defer buffer.Release()
|
|
||||||
buffer.IncRef()
|
|
||||||
defer buffer.DecRef()
|
|
||||||
var destination M.Socksaddr
|
|
||||||
for {
|
|
||||||
buffer.Reset()
|
|
||||||
err = src.SetReadDeadline(time.Now().Add(timeout))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
destination, err = src.ReadPacket(buffer)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if buffer.IsFull() {
|
|
||||||
return 0, io.ErrShortBuffer
|
|
||||||
}
|
|
||||||
dataLen := buffer.Len()
|
|
||||||
err = dst.WritePacket(buffer, destination)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n += int64(dataLen)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) {
|
func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (n int64, err error) {
|
||||||
var buffer *buf.Buffer
|
var buffer *buf.Buffer
|
||||||
var destination M.Socksaddr
|
var destination M.Socksaddr
|
||||||
|
@ -267,33 +228,6 @@ func CopyPacketWithSrcBuffer(dst N.PacketWriter, src N.ThreadSafePacketReader) (
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyPacketWithSrcBufferTimeout(dst N.PacketWriter, src N.ThreadSafePacketReader, tSrc N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) {
|
|
||||||
var buffer *buf.Buffer
|
|
||||||
var destination M.Socksaddr
|
|
||||||
var notFirstTime bool
|
|
||||||
for {
|
|
||||||
err = tSrc.SetReadDeadline(time.Now().Add(timeout))
|
|
||||||
if err != nil {
|
|
||||||
if !notFirstTime {
|
|
||||||
err = N.HandshakeFailure(dst, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
buffer, destination, err = src.ReadPacketThreadSafe()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
dataLen := buffer.Len()
|
|
||||||
err = dst.WritePacket(buffer, destination)
|
|
||||||
if err != nil {
|
|
||||||
buffer.Release()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n += int64(dataLen)
|
|
||||||
notFirstTime = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err error) {
|
||||||
var destination M.Socksaddr
|
var destination M.Socksaddr
|
||||||
var notFirstTime bool
|
var notFirstTime bool
|
||||||
|
@ -322,54 +256,19 @@ func CopyPacketWithPool(dst N.PacketWriter, src N.PacketReader) (n int64, err er
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyPacketWithPoolTimeout(dst N.PacketWriter, src N.TimeoutPacketReader, timeout time.Duration) (n int64, err error) {
|
|
||||||
var destination M.Socksaddr
|
|
||||||
var notFirstTime bool
|
|
||||||
for {
|
|
||||||
buffer := buf.NewPacket()
|
|
||||||
err = src.SetReadDeadline(time.Now().Add(timeout))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
destination, err = src.ReadPacket(buffer)
|
|
||||||
if err != nil {
|
|
||||||
buffer.Release()
|
|
||||||
if !notFirstTime {
|
|
||||||
err = N.HandshakeFailure(dst, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if buffer.IsFull() {
|
|
||||||
buffer.Release()
|
|
||||||
return 0, io.ErrShortBuffer
|
|
||||||
}
|
|
||||||
dataLen := buffer.Len()
|
|
||||||
err = dst.WritePacket(buffer, destination)
|
|
||||||
if err != nil {
|
|
||||||
buffer.Release()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n += int64(dataLen)
|
|
||||||
notFirstTime = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
|
func CopyPacketConn(ctx context.Context, conn N.PacketConn, dest N.PacketConn) error {
|
||||||
defer common.Close(conn, dest)
|
var group task.Group
|
||||||
return task.Any(ctx, func(ctx context.Context) error {
|
group.Append("upload", func(ctx context.Context) error {
|
||||||
return common.Error(CopyPacket(dest, conn))
|
return common.Error(CopyPacket(dest, conn))
|
||||||
}, func(ctx context.Context) error {
|
})
|
||||||
|
group.Append("download", func(ctx context.Context) error {
|
||||||
return common.Error(CopyPacket(conn, dest))
|
return common.Error(CopyPacket(conn, dest))
|
||||||
})
|
})
|
||||||
}
|
group.Cleanup(func() {
|
||||||
|
common.Close(conn, dest)
|
||||||
func CopyPacketConnTimeout(ctx context.Context, conn N.PacketConn, dest N.PacketConn, timeout time.Duration) error {
|
|
||||||
defer common.Close(conn, dest)
|
|
||||||
return task.Any(ctx, func(ctx context.Context) error {
|
|
||||||
return common.Error(CopyPacketTimeout(dest, conn, timeout))
|
|
||||||
}, func(ctx context.Context) error {
|
|
||||||
return common.Error(CopyPacketTimeout(conn, dest, timeout))
|
|
||||||
})
|
})
|
||||||
|
group.FastFail()
|
||||||
|
return group.Run(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPacketConn(conn net.PacketConn) N.NetPacketConn {
|
func NewPacketConn(conn net.PacketConn) N.NetPacketConn {
|
||||||
|
|
|
@ -13,7 +13,7 @@ type multiError struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *multiError) Error() string {
|
func (e *multiError) Error() string {
|
||||||
return "multi error: (" + strings.Join(F.MapToString(e.errors), " | ") + ")"
|
return strings.Join(F.MapToString(e.errors), " | ")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *multiError) UnwrapMulti() []error {
|
func (e *multiError) UnwrapMulti() []error {
|
||||||
|
@ -50,7 +50,7 @@ func Append(err error, other error, block func(error) error) error {
|
||||||
if other == nil {
|
if other == nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return Errors(err, block(err))
|
return Errors(err, block(other))
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsMulti(err error, targetList ...error) bool {
|
func IsMulti(err error, targetList ...error) bool {
|
||||||
|
|
|
@ -7,48 +7,99 @@ import (
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Run(ctx context.Context, tasks ...func() error) error {
|
type taskItem struct {
|
||||||
runtimeCtx, cancel := context.WithCancel(ctx)
|
Name string
|
||||||
wg := &sync.WaitGroup{}
|
Run func(ctx context.Context) error
|
||||||
wg.Add(len(tasks))
|
|
||||||
var retErr []error
|
|
||||||
for _, task := range tasks {
|
|
||||||
currentTask := task
|
|
||||||
go func() {
|
|
||||||
if err := currentTask(); err != nil {
|
|
||||||
retErr = append(retErr, err)
|
|
||||||
}
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
wg.Wait()
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
case <-runtimeCtx.Done():
|
|
||||||
}
|
|
||||||
retErr = append(retErr, ctx.Err())
|
|
||||||
return E.Errors(retErr...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Any(ctx context.Context, tasks ...func(ctx context.Context) error) error {
|
type Group struct {
|
||||||
runtimeCtx, cancel := context.WithCancel(ctx)
|
tasks []taskItem
|
||||||
defer cancel()
|
cleanup func()
|
||||||
|
fastFail bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Group) Append(name string, f func(ctx context.Context) error) {
|
||||||
|
g.tasks = append(g.tasks, taskItem{
|
||||||
|
Name: name,
|
||||||
|
Run: f,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Group) Append0(f func(ctx context.Context) error) {
|
||||||
|
g.tasks = append(g.tasks, taskItem{
|
||||||
|
Run: f,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Group) Cleanup(f func()) {
|
||||||
|
g.cleanup = f
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Group) FastFail() {
|
||||||
|
g.fastFail = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Group) Run(ctx context.Context) error {
|
||||||
|
var retAccess sync.Mutex
|
||||||
var retErr error
|
var retErr error
|
||||||
for _, task := range tasks {
|
|
||||||
|
taskCount := int8(len(g.tasks))
|
||||||
|
taskCtx, taskFinish := context.WithCancel(context.Background())
|
||||||
|
var mixedCtx context.Context
|
||||||
|
var mixedFinish context.CancelFunc
|
||||||
|
if ctx.Done() != nil || g.fastFail {
|
||||||
|
mixedCtx, mixedFinish = context.WithCancel(ctx)
|
||||||
|
} else {
|
||||||
|
mixedCtx, mixedFinish = taskCtx, taskFinish
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, task := range g.tasks {
|
||||||
currentTask := task
|
currentTask := task
|
||||||
go func() {
|
go func() {
|
||||||
if err := currentTask(runtimeCtx); err != nil {
|
err := currentTask.Run(mixedCtx)
|
||||||
retErr = err
|
retAccess.Lock()
|
||||||
|
if err != nil {
|
||||||
|
retErr = E.Append(retErr, err, func(err error) error {
|
||||||
|
if currentTask.Name == "" {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return E.Cause(err, currentTask.Name)
|
||||||
|
})
|
||||||
|
if g.fastFail {
|
||||||
|
mixedFinish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
taskCount--
|
||||||
|
currentCount := taskCount
|
||||||
|
retAccess.Unlock()
|
||||||
|
if currentCount == 0 {
|
||||||
|
taskFinish()
|
||||||
}
|
}
|
||||||
cancel()
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var upstreamErr error
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
case <-runtimeCtx.Done():
|
upstreamErr = ctx.Err()
|
||||||
|
case <-taskCtx.Done():
|
||||||
|
mixedFinish()
|
||||||
|
case <-mixedCtx.Done():
|
||||||
}
|
}
|
||||||
return E.Errors(retErr, ctx.Err())
|
|
||||||
|
if g.cleanup != nil {
|
||||||
|
g.cleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
<-taskCtx.Done()
|
||||||
|
|
||||||
|
taskFinish()
|
||||||
|
mixedFinish()
|
||||||
|
|
||||||
|
retErr = E.Append(retErr, upstreamErr, func(err error) error {
|
||||||
|
return E.Cause(err, "upstream")
|
||||||
|
})
|
||||||
|
|
||||||
|
return retErr
|
||||||
}
|
}
|
||||||
|
|
26
common/task/task_deprecated.go
Normal file
26
common/task/task_deprecated.go
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package task
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
func Run(ctx context.Context, tasks ...func() error) error {
|
||||||
|
var group Group
|
||||||
|
for _, task := range tasks {
|
||||||
|
currentTask := task
|
||||||
|
group.Append0(func(ctx context.Context) error {
|
||||||
|
return currentTask()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return group.Run(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Any(ctx context.Context, tasks ...func(ctx context.Context) error) error {
|
||||||
|
var group Group
|
||||||
|
for _, task := range tasks {
|
||||||
|
currentTask := task
|
||||||
|
group.Append0(func(ctx context.Context) error {
|
||||||
|
return currentTask(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
group.FastFail()
|
||||||
|
return group.Run(ctx)
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue