sing/common/batch/batch.go
2023-04-10 12:23:24 +08:00

112 lines
1.9 KiB
Go

package batch
import (
"context"
"sync"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
type Option[T any] func(b *Batch[T])
type Result[T any] struct {
Value T
Err error
}
type Error struct {
Key string
Err error
}
func (e *Error) Error() string {
return E.Cause(e.Err, e.Key).Error()
}
func WithConcurrencyNum[T any](n int) Option[T] {
return func(b *Batch[T]) {
q := make(chan struct{}, n)
for i := 0; i < n; i++ {
q <- struct{}{}
}
b.queue = q
}
}
// Batch similar to errgroup, but can control the maximum number of concurrent
type Batch[T any] struct {
result map[string]Result[T]
queue chan struct{}
wg sync.WaitGroup
mux sync.Mutex
err *Error
once sync.Once
cancel common.ContextCancelCauseFunc
}
func (b *Batch[T]) Go(key string, fn func() (T, error)) {
b.wg.Add(1)
go func() {
defer b.wg.Done()
if b.queue != nil {
<-b.queue
defer func() {
b.queue <- struct{}{}
}()
}
value, err := fn()
if err != nil {
b.once.Do(func() {
b.err = &Error{key, err}
if b.cancel != nil {
b.cancel(b.err)
}
})
}
ret := Result[T]{value, err}
b.mux.Lock()
defer b.mux.Unlock()
b.result[key] = ret
}()
}
func (b *Batch[T]) Wait() *Error {
b.wg.Wait()
if b.cancel != nil {
b.cancel(nil)
}
return b.err
}
func (b *Batch[T]) WaitAndGetResult() (map[string]Result[T], *Error) {
err := b.Wait()
return b.Result(), err
}
func (b *Batch[T]) Result() map[string]Result[T] {
b.mux.Lock()
defer b.mux.Unlock()
copy := map[string]Result[T]{}
for k, v := range b.result {
copy[k] = v
}
return copy
}
func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) {
ctx, cancel := common.ContextWithCancelCause(ctx)
b := &Batch[T]{
result: map[string]Result[T]{},
}
for _, o := range opts {
o(b)
}
b.cancel = cancel
return b, ctx
}