diff --git a/common/buf/buffer.go b/common/buf/buffer.go index 0f35263..9c57b71 100644 --- a/common/buf/buffer.go +++ b/common/buf/buffer.go @@ -168,16 +168,13 @@ func (b *Buffer) WriteByte(d byte) error { return nil } -func (b *Buffer) ReadFrom(r io.Reader) (int64, error) { +func (b *Buffer) ReadOnceFrom(r io.Reader) (int64, error) { if b.IsFull() { return 0, io.ErrShortBuffer } n, err := r.Read(b.FreeBytes()) - if err != nil { - return 0, err - } b.end += n - return int64(n), nil + return int64(n), err } func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) { @@ -185,55 +182,46 @@ func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) { return 0, nil, io.ErrShortBuffer } n, addr, err := r.ReadFrom(b.FreeBytes()) - if err != nil { - return 0, nil, err - } b.end += n - return int64(n), addr, nil + return int64(n), addr, err } func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) { if min <= 0 { - return b.ReadFrom(r) + return b.ReadOnceFrom(r) } if b.IsFull() { return 0, io.ErrShortBuffer } n, err := io.ReadAtLeast(r, b.FreeBytes(), min) - if err != nil { - return 0, err - } b.end += n - return int64(n), nil + return int64(n), err } func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) { if b.IsFull() { return 0, io.ErrShortBuffer } - end := b.end + size - n, err = io.ReadFull(r, b.data[b.end:end]) - if err != nil { - return - } - b.end = end + n, err = io.ReadFull(r, b.data[b.end:b.end+size]) + b.end += n return } -func (b *Buffer) ReadAllFrom(reader io.Reader) (int, error) { +func (b *Buffer) ReadFrom(reader io.Reader) (n int64, err error) { for { if b.IsFull() { return 0, io.ErrShortBuffer } - readN, err := reader.Read(b.FreeBytes()) + var readN int + readN, err = reader.Read(b.FreeBytes()) + b.end += readN + n += int64(readN) if err != nil { if errors.Is(err, io.EOF) { - return b.Len(), nil - } else { - return 0, err + err = nil } + return } - b.end += readN } } diff --git a/common/bufio/cache.go b/common/bufio/cache.go index 6cb057b..99f0f2d 100644 --- a/common/bufio/cache.go +++ b/common/bufio/cache.go @@ -154,7 +154,7 @@ func NewCachedPacketConn(conn N.PacketConn, buffer *buf.Buffer, destination M.So func (c *CachedPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { if c.buffer != nil { - _, err = buffer.ReadFrom(c.buffer) + _, err = buffer.ReadOnceFrom(c.buffer) if err != nil { return M.Socksaddr{}, err } diff --git a/common/udpnat/service.go b/common/udpnat/service.go index b7ff39b..d3b9248 100644 --- a/common/udpnat/service.go +++ b/common/udpnat/service.go @@ -144,7 +144,7 @@ func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { } select { case p := <-c.data: - _, err = buffer.ReadFrom(p.data) + _, err = buffer.ReadOnceFrom(p.data) p.data.Release() return p.destination, err case <-c.ctx.Done(): diff --git a/protocol/http/client.go b/protocol/http/client.go index 1336215..cbca5d0 100644 --- a/protocol/http/client.go +++ b/protocol/http/client.go @@ -72,7 +72,7 @@ func (c *Client) DialContext(ctx context.Context, network string, destination M. case http.StatusOK: if reader.Buffered() > 0 { buffer := buf.NewSize(reader.Buffered()) - _, err = buffer.ReadFrom(reader) + _, err = buffer.ReadFullFrom(reader, buffer.FreeLen()) if err != nil { return nil, err }