diff --git a/common/bufio/buffer.go b/common/bufio/buffer.go index 4f9ee94..30f4c1e 100644 --- a/common/bufio/buffer.go +++ b/common/bufio/buffer.go @@ -3,6 +3,7 @@ package bufio import ( "io" "os" + "sync" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -75,6 +76,7 @@ func (r *BufferedReader) Upstream() any { type BufferedWriter struct { upstream io.Writer buffer *buf.Buffer + access sync.Mutex } func NewBufferedWriter(upstream io.Writer, buffer *buf.Buffer) *BufferedWriter { @@ -85,6 +87,11 @@ func NewBufferedWriter(upstream io.Writer, buffer *buf.Buffer) *BufferedWriter { } func (w *BufferedWriter) Write(p []byte) (n int, err error) { + w.access.Lock() + defer w.access.Unlock() + if w.buffer == nil { + return w.upstream.Write(p) + } for { var writeN int writeN, err = w.buffer.Write(p[n:]) @@ -100,6 +107,30 @@ func (w *BufferedWriter) Write(p []byte) (n int, err error) { } } +func (w *BufferedWriter) Fallthrough() error { + w.access.Lock() + defer w.access.Unlock() + if w.buffer == nil { + return nil + } + if !w.buffer.IsEmpty() { + _, err := w.upstream.Write(w.buffer.Bytes()) + if err != nil { + return err + } + } + w.buffer.Release() + w.buffer = nil + return nil +} + func (w *BufferedWriter) ReadFrom(r io.Reader) (n int64, err error) { + if w.buffer == nil { + return Copy(w.upstream, r) + } return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r), w.buffer) } + +func (w *BufferedWriter) WriterReplaceable() bool { + return w.buffer == nil +}