Fix HandshakeConn write back

This commit is contained in:
世界 2023-04-17 16:20:20 +08:00
parent ad6dd0aa7b
commit 18a3739974
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
2 changed files with 17 additions and 15 deletions

View file

@ -67,7 +67,8 @@ func (r *BufferedReader) ReadBuffer(buffer *buf.Buffer) error {
return err return err
} }
func (r *BufferedReader) WriteTo(w io.Writer) (n int64, err error) { //TODO: add CopyBuffer
/*func (r *BufferedReader) WriteTo(w io.Writer) (n int64, err error) {
if r.buffer == nil { if r.buffer == nil {
return Copy(w, r.upstream) return Copy(w, r.upstream)
} }
@ -76,7 +77,7 @@ func (r *BufferedReader) WriteTo(w io.Writer) (n int64, err error) {
} }
defer r.buffer.Release() defer r.buffer.Release()
return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r.upstream), r.buffer) return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r.upstream), r.buffer)
} }*/
func (r *BufferedReader) ReadCached() *buf.Buffer { func (r *BufferedReader) ReadCached() *buf.Buffer {
buffer := r.buffer buffer := r.buffer
@ -150,12 +151,12 @@ func (w *BufferedWriter) Fallthrough() error {
return nil return nil
} }
func (w *BufferedWriter) ReadFrom(r io.Reader) (n int64, err error) { /*func (w *BufferedWriter) ReadFrom(r io.Reader) (n int64, err error) {
if w.buffer == nil { if w.buffer == nil {
return Copy(w.upstream, r) return Copy(w.upstream, r)
} }
return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r), w.buffer) return CopyExtendedBuffer(NewExtendedWriter(w), NewExtendedReader(r), w.buffer)
} }*/
func (w *BufferedWriter) WriterReplaceable() bool { func (w *BufferedWriter) WriterReplaceable() bool {
return w.buffer == nil return w.buffer == nil

View file

@ -58,6 +58,7 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
} else if dst == nil { } else if dst == nil {
return 0, E.New("nil writer") return 0, E.New("nil writer")
} }
origDst := dst
src = N.UnwrapReader(src) src = N.UnwrapReader(src)
dst = N.UnwrapWriter(dst) dst = N.UnwrapWriter(dst)
if wt, ok := src.(io.WriterTo); ok { if wt, ok := src.(io.WriterTo); ok {
@ -72,19 +73,19 @@ func Copy(dst io.Writer, src io.Reader) (n int64, err error) {
} }
return rt.ReadFrom(src) return rt.ReadFrom(src)
} }
return CopyExtended(NewExtendedWriter(dst), NewExtendedReader(src)) return CopyExtended(origDst, NewExtendedWriter(dst), NewExtendedReader(src))
} }
func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) { func CopyExtended(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
safeSrc := N.IsSafeReader(src) safeSrc := N.IsSafeReader(src)
headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst) headroom := N.CalculateFrontHeadroom(dst) + N.CalculateRearHeadroom(dst)
if safeSrc != nil { if safeSrc != nil {
if headroom == 0 { if headroom == 0 {
return CopyExtendedWithSrcBuffer(dst, safeSrc) return CopyExtendedWithSrcBuffer(origDst, dst, safeSrc)
} }
} }
if N.IsUnsafeWriter(dst) { if N.IsUnsafeWriter(dst) {
return CopyExtendedWithPool(dst, src) return CopyExtendedWithPool(origDst, dst, src)
} }
bufferSize := N.CalculateMTU(src, dst) bufferSize := N.CalculateMTU(src, dst)
if bufferSize > 0 { if bufferSize > 0 {
@ -96,10 +97,10 @@ func CopyExtended(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err erro
defer common.KeepAlive(_buffer) defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer) buffer := common.Dup(_buffer)
defer buffer.Release() defer buffer.Release()
return CopyExtendedBuffer(dst, src, buffer) return CopyExtendedBuffer(origDst, dst, src, buffer)
} }
func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) { func CopyExtendedBuffer(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.Buffer) (n int64, err error) {
buffer.IncRef() buffer.IncRef()
defer buffer.DecRef() defer buffer.DecRef()
frontHeadroom := N.CalculateFrontHeadroom(dst) frontHeadroom := N.CalculateFrontHeadroom(dst)
@ -112,7 +113,7 @@ func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.
err = src.ReadBuffer(readBuffer) err = src.ReadBuffer(readBuffer)
if err != nil { if err != nil {
if !notFirstTime { if !notFirstTime {
err = N.HandshakeFailure(dst, err) err = N.HandshakeFailure(origDst, err)
} }
return return
} }
@ -127,14 +128,14 @@ func CopyExtendedBuffer(dst N.ExtendedWriter, src N.ExtendedReader, buffer *buf.
} }
} }
func CopyExtendedWithSrcBuffer(dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) { func CopyExtendedWithSrcBuffer(origDst io.Writer, dst N.ExtendedWriter, src N.ThreadSafeReader) (n int64, err error) {
var notFirstTime bool var notFirstTime bool
for { for {
var buffer *buf.Buffer var buffer *buf.Buffer
buffer, err = src.ReadBufferThreadSafe() buffer, err = src.ReadBufferThreadSafe()
if err != nil { if err != nil {
if !notFirstTime { if !notFirstTime {
err = N.HandshakeFailure(dst, err) err = N.HandshakeFailure(origDst, err)
} }
return return
} }
@ -149,7 +150,7 @@ func CopyExtendedWithSrcBuffer(dst N.ExtendedWriter, src N.ThreadSafeReader) (n
} }
} }
func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) { func CopyExtendedWithPool(origDst io.Writer, dst N.ExtendedWriter, src N.ExtendedReader) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(dst) frontHeadroom := N.CalculateFrontHeadroom(dst)
rearHeadroom := N.CalculateRearHeadroom(dst) rearHeadroom := N.CalculateRearHeadroom(dst)
bufferSize := N.CalculateMTU(src, dst) bufferSize := N.CalculateMTU(src, dst)
@ -168,7 +169,7 @@ func CopyExtendedWithPool(dst N.ExtendedWriter, src N.ExtendedReader) (n int64,
if err != nil { if err != nil {
buffer.Release() buffer.Release()
if !notFirstTime { if !notFirstTime {
err = N.HandshakeFailure(dst, err) err = N.HandshakeFailure(origDst, err)
} }
return return
} }