From b2828dac5f865d7a66f0bb6933be3dfe413baee7 Mon Sep 17 00:00:00 2001
From: wwqgtxx <wwqgtxx@gmail.com>
Date: Wed, 21 Sep 2022 16:57:35 +0800
Subject: [PATCH] Fix buffer overflow

---
 common/buf/buffer.go        | 2 +-
 common/uot/client.go        | 7 ++++---
 common/uot/server.go        | 3 ---
 protocol/trojan/protocol.go | 4 ----
 4 files changed, 5 insertions(+), 11 deletions(-)

diff --git a/common/buf/buffer.go b/common/buf/buffer.go
index 8f2ce1a..4b3a34f 100644
--- a/common/buf/buffer.go
+++ b/common/buf/buffer.go
@@ -199,7 +199,7 @@ func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
 }
 
 func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) {
-	if b.IsFull() {
+	if b.end+size > b.Cap() {
 		return 0, io.ErrShortBuffer
 	}
 	n, err = io.ReadFull(r, b.data[b.end:b.end+size])
diff --git a/common/uot/client.go b/common/uot/client.go
index 05d1c49..cf13999 100644
--- a/common/uot/client.go
+++ b/common/uot/client.go
@@ -36,10 +36,11 @@ func (c *ClientConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
 	if err != nil {
 		return M.Socksaddr{}, err
 	}
-	if buffer.FreeLen() < int(length) {
-		return M.Socksaddr{}, io.ErrShortBuffer
+	_, err = buffer.ReadFullFrom(c, int(length))
+	if err != nil {
+		return M.Socksaddr{}, err
 	}
-	return destination, common.Error(buffer.ReadFullFrom(c, int(length)))
+	return destination, nil
 }
 
 func (c *ClientConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
diff --git a/common/uot/server.go b/common/uot/server.go
index 91721ed..b8605bb 100644
--- a/common/uot/server.go
+++ b/common/uot/server.go
@@ -68,9 +68,6 @@ func (c *ServerConn) loopInput() {
 			break
 		}
 		buffer.FullReset()
-		if int(length) > buffer.FreeLen() {
-			break
-		}
 		_, err = buffer.ReadFullFrom(c.inputReader, int(length))
 		if err != nil {
 			break
diff --git a/protocol/trojan/protocol.go b/protocol/trojan/protocol.go
index f9dc730..b56061f 100644
--- a/protocol/trojan/protocol.go
+++ b/protocol/trojan/protocol.go
@@ -288,10 +288,6 @@ func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) {
 		return M.Socksaddr{}, E.Cause(err, "read chunk length")
 	}
 
-	if buffer.FreeLen() < int(length) {
-		return M.Socksaddr{}, io.ErrShortBuffer
-	}
-
 	err = rw.SkipN(conn, 2)
 	if err != nil {
 		return M.Socksaddr{}, E.Cause(err, "skip crlf")