hysteria/core/utils/pipe.go
2022-11-24 00:22:44 -08:00

94 lines
1.6 KiB
Go

package utils
import (
"io"
"net"
"time"
)
// PipeBufferSize is the size, in bytes, of the copy buffer allocated by Pipe
// and by each relay goroutine of PipePairWithTimeout.
const PipeBufferSize = 32 * 1024

// Pipe copies data from src to dst until a Read or Write fails, and returns
// that error. Unlike io.Copy, reaching the end of src is reported as io.EOF
// rather than nil.
//
// If count is non-nil it is called with the size of every chunk read from src,
// before that chunk is written to dst.
//
// A Write that returns no error but does not accept the whole chunk is
// reported as io.ErrShortWrite (the same rule io.Copy applies), so bytes are
// never dropped silently.
func Pipe(src, dst io.ReadWriter, count func(int)) error {
	buf := make([]byte, PipeBufferSize)
	for {
		rn, err := src.Read(buf)
		if rn > 0 {
			if count != nil {
				count(rn)
			}
			wn, werr := dst.Write(buf[:rn])
			if werr == nil && wn != rn {
				werr = io.ErrShortWrite
			}
			if werr != nil {
				return werr
			}
		}
		// Check err only after forwarding: a Reader may return data together
		// with an error (including io.EOF), and that data must not be lost.
		if err != nil {
			return err
		}
	}
}
// Pipe2Way copies data in both directions between rw1 and rw2, running each
// direction through Pipe in its own goroutine. It returns the error of
// whichever direction stops first; a clean end of input shows up as io.EOF.
//
// count, if non-nil, receives the size of each chunk read: a positive value
// for data flowing from rw1 to rw2 and a negative value for data flowing from
// rw2 to rw1. It is invoked from both goroutines, so it must be safe for
// concurrent use.
//
// The direction that has not failed keeps running after Pipe2Way returns. It
// stops only when its own Read or Write fails (for example, once the caller
// closes the underlying connections).
func Pipe2Way(rw1, rw2 io.ReadWriter, count func(int)) error {
	// Capacity 2 lets the slower direction deliver its result and exit even
	// though nobody receives it.
	results := make(chan error, 2)

	forward := count
	var backward func(int)
	if forward != nil {
		backward = func(n int) { forward(-n) }
	}

	go func() { results <- Pipe(rw1, rw2, forward) }()
	go func() { results <- Pipe(rw2, rw1, backward) }()

	// Only the first result matters to the caller.
	first := <-results
	return first
}
// PipePairWithTimeout relays data in both directions between conn and stream
// and returns the first error seen by either direction. A clean close of
// either side is reported as io.EOF rather than nil.
//
// If timeout is positive, it is an idle timeout enforced through conn's
// deadline. The deadline is moved to now+timeout before every read from conn
// and before every write to conn, so data arriving from either side keeps the
// pair alive. Because a net.Conn deadline also applies to pending I/O, a
// refresh by one direction extends the other direction's blocked call too.
// A zero or negative timeout disables deadlines entirely.
//
// A Write that returns no error but does not accept the whole chunk is
// reported as io.ErrShortWrite instead of silently dropping bytes.
//
// Neither conn nor stream is closed here. When this returns, the other relay
// goroutine may still be blocked in I/O; the caller is expected to close both
// ends, which makes that goroutine exit.
func PipePairWithTimeout(conn net.Conn, stream io.ReadWriteCloser, timeout time.Duration) error {
	// Room for both results so the goroutine that finishes second never
	// blocks on the send and can exit once its I/O fails.
	errChan := make(chan error, 2)

	// refreshDeadline pushes conn's read and write deadline timeout into the
	// future. Only positive timeouts count; a negative one would otherwise set
	// a deadline in the past and fail the relay immediately.
	refreshDeadline := func() {
		if timeout > 0 {
			// Error ignored on purpose: the idle timeout is best effort.
			_ = conn.SetDeadline(time.Now().Add(timeout))
		}
	}

	// writeAll writes p to w and turns a write that does not report exactly
	// len(p) bytes without an error into io.ErrShortWrite, as io.Copy does.
	writeAll := func(w io.Writer, p []byte) error {
		n, err := w.Write(p)
		if err == nil && n != len(p) {
			err = io.ErrShortWrite
		}
		return err
	}

	// TCP to stream
	go func() {
		buf := make([]byte, PipeBufferSize)
		for {
			refreshDeadline()
			rn, err := conn.Read(buf)
			if rn > 0 {
				if werr := writeAll(stream, buf[:rn]); werr != nil {
					errChan <- werr
					return
				}
			}
			if err != nil {
				errChan <- err
				return
			}
		}
	}()

	// Stream to TCP
	go func() {
		buf := make([]byte, PipeBufferSize)
		for {
			rn, err := stream.Read(buf)
			if rn > 0 {
				// Refresh before writing. The write then gets a full timeout
				// window, instead of whatever remained of a deadline set by an
				// earlier operation before a quiet period.
				refreshDeadline()
				if werr := writeAll(conn, buf[:rn]); werr != nil {
					errChan <- werr
					return
				}
			}
			if err != nil {
				errChan <- err
				return
			}
		}
	}()

	return <-errChan
}