Fix buffer read (break change)

This commit is contained in:
世界 2022-07-26 10:52:01 +08:00
parent c15de13f4f
commit bc109486f1
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 17 additions and 29 deletions

View file

@ -168,16 +168,13 @@ func (b *Buffer) WriteByte(d byte) error {
return nil return nil
} }
func (b *Buffer) ReadFrom(r io.Reader) (int64, error) { func (b *Buffer) ReadOnceFrom(r io.Reader) (int64, error) {
if b.IsFull() { if b.IsFull() {
return 0, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
n, err := r.Read(b.FreeBytes()) n, err := r.Read(b.FreeBytes())
if err != nil {
return 0, err
}
b.end += n b.end += n
return int64(n), nil return int64(n), err
} }
func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) { func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) {
@ -185,55 +182,46 @@ func (b *Buffer) ReadPacketFrom(r net.PacketConn) (int64, net.Addr, error) {
return 0, nil, io.ErrShortBuffer return 0, nil, io.ErrShortBuffer
} }
n, addr, err := r.ReadFrom(b.FreeBytes()) n, addr, err := r.ReadFrom(b.FreeBytes())
if err != nil {
return 0, nil, err
}
b.end += n b.end += n
return int64(n), addr, nil return int64(n), addr, err
} }
func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) { func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
if min <= 0 { if min <= 0 {
return b.ReadFrom(r) return b.ReadOnceFrom(r)
} }
if b.IsFull() { if b.IsFull() {
return 0, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
n, err := io.ReadAtLeast(r, b.FreeBytes(), min) n, err := io.ReadAtLeast(r, b.FreeBytes(), min)
if err != nil {
return 0, err
}
b.end += n b.end += n
return int64(n), nil return int64(n), err
} }
func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) { func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) {
if b.IsFull() { if b.IsFull() {
return 0, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
end := b.end + size n, err = io.ReadFull(r, b.data[b.end:b.end+size])
n, err = io.ReadFull(r, b.data[b.end:end]) b.end += n
if err != nil {
return
}
b.end = end
return return
} }
func (b *Buffer) ReadAllFrom(reader io.Reader) (int, error) { func (b *Buffer) ReadFrom(reader io.Reader) (n int64, err error) {
for { for {
if b.IsFull() { if b.IsFull() {
return 0, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
readN, err := reader.Read(b.FreeBytes()) var readN int
readN, err = reader.Read(b.FreeBytes())
b.end += readN
n += int64(readN)
if err != nil { if err != nil {
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
return b.Len(), nil err = nil
} else {
return 0, err
} }
return
} }
b.end += readN
} }
} }

View file

@ -154,7 +154,7 @@ func NewCachedPacketConn(conn N.PacketConn, buffer *buf.Buffer, destination M.So
func (c *CachedPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { func (c *CachedPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if c.buffer != nil { if c.buffer != nil {
_, err = buffer.ReadFrom(c.buffer) _, err = buffer.ReadOnceFrom(c.buffer)
if err != nil { if err != nil {
return M.Socksaddr{}, err return M.Socksaddr{}, err
} }

View file

@ -144,7 +144,7 @@ func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
} }
select { select {
case p := <-c.data: case p := <-c.data:
_, err = buffer.ReadFrom(p.data) _, err = buffer.ReadOnceFrom(p.data)
p.data.Release() p.data.Release()
return p.destination, err return p.destination, err
case <-c.ctx.Done(): case <-c.ctx.Done():

View file

@ -72,7 +72,7 @@ func (c *Client) DialContext(ctx context.Context, network string, destination M.
case http.StatusOK: case http.StatusOK:
if reader.Buffered() > 0 { if reader.Buffered() > 0 {
buffer := buf.NewSize(reader.Buffered()) buffer := buf.NewSize(reader.Buffered())
_, err = buffer.ReadFrom(reader) _, err = buffer.ReadFullFrom(reader, buffer.FreeLen())
if err != nil { if err != nil {
return nil, err return nil, err
} }