Fix HandshakeConn write back

This commit is contained in:
世界 2023-04-17 16:20:20 +08:00
parent ad6dd0aa7b
commit 18a3739974
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 17 additions and 15 deletions

View file

@ -67,7 +67,8 @@ func (r *BufferedReader) ReadBuffer(buffer *buf.Buffer) error {
return err
}
func (r *BufferedReader) WriteTo(w io.Writer) (n int64, err error) {
//TODO: add CopyBuffer
/*func (r *BufferedReader) WriteTo(w io.Writer) (n int64, err error) {
if r.buffer == nil {
return Copy(w, r.upstream)
}
@ -76,7 +77,7 @@ func (r *BufferedReader) WriteTo(w io.Writer) (n int64, err error) {
}
defer r.buffer.Release()
return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r.upstream), r.buffer)
}
}*/
func (r *BufferedReader) ReadCached() *buf.Buffer {
buffer := r.buffer
@ -150,12 +151,12 @@ func (w *BufferedWriter) Fallthrough() error {
return nil
}
func (w *BufferedWriter) ReadFrom(r io.Reader) (n int64, err error) {
/*func (w *BufferedWriter) ReadFrom(r io.Reader) (n int64, err error) {
if w.buffer == nil {
return Copy(w.upstream, r)
}
return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r), w.buffer)
}
}*/
func (w *BufferedWriter) WriterReplaceable() bool {
return w.buffer == nil

View file

@ -58,6 +58,7 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
} else if dst == nil {
return 0, E.New("nil writer")
}
origDst := dst
src = N.UnwrapReader(src)
dst = N.UnwrapWriter(dst)
if wt, ok := src.(io.WriterTo); ok {
@ -72,19 +73,19 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
}
return rt.ReadFrom(src)
}
return CopyExtended(NewExtendedWriter(dst), NewExtendedReader(src))
return CopyExtended(origDst, NewExtendedWriter(dst), NewExtendedReader(src))
}
func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
func CopyExtended(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
safeSrc := N.IsSafeReader(src)
headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst)
if safeSrc != nil {
if headroom == 0 {
return CopyExtendedWithSrcBuffer(dst, safeSrc)
return CopyExtendedWithSrcBuffer(origDst, dst, safeSrc)
}
}
if N.IsUnsafeWriter(dst) {
return CopyExtendedWithPool(dst, src)
return CopyExtendedWithPool(origDst, dst, src)
}
bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 {
@ -96,10 +97,10 @@ func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err erro
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
return CopyExtendedBuffer(dst, src, buffer)
return CopyExtendedBuffer(origDst, dst, src, buffer)
}
func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) {
func CopyExtendedBuffer(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) {
buffer.IncRef()
defer buffer.DecRef()
frontHeadroom := N.CalculateFrontHeadroom(dst)
@ -112,7 +113,7 @@ func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.
err = src.ReadBuffer(readBuffer)
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
err = N.HandshakeFailure(origDst, err)
}
return
}
@ -127,14 +128,14 @@ func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.
}
}
func CopyExtendedWithSrcBuffer(dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) {
func CopyExtendedWithSrcBuffer(origDst io.Writer, dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) {
var notFirstTime bool
for {
var buffer *buf.Buffer
buffer, err = src.ReadBufferThreadSafe()
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
err = N.HandshakeFailure(origDst, err)
}
return
}
@ -149,7 +150,7 @@ func CopyExtendedWithSrcBuffer(dst N.ExtendedWriter, src N.ThreadSafeReader) (n
}
}
func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
func CopyExtendedWithPool(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst)
bufferSize := N.CalculateMTU(src, dst)
@ -168,7 +169,7 @@ func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64,
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(dst, err)
err = N.HandshakeFailure(origDst, err)
}
return
}