diff --git a/hysteria2/internal/protocol/proxy.go b/hysteria2/internal/protocol/proxy.go index 795b3cb..f76265b 100644 --- a/hysteria2/internal/protocol/proxy.go +++ b/hysteria2/internal/protocol/proxy.go @@ -92,19 +92,28 @@ func WriteTCPRequest(addr string, payload []byte) *buf.Buffer { // Padding length (QUIC varint) // Padding (bytes) -func ReadTCPResponse(r io.Reader) (bool, string, error) { +func ReadTCPResponse(r io.Reader) (ok bool, message string, err error) { var status [1]byte - if _, err := io.ReadFull(r, status[:]); err != nil { - return false, "", err - } - bReader := quicvarint.NewReader(r) - msg, err := ReadVString(bReader) + _, err = io.ReadFull(r, status[:]) if err != nil { - return false, "", err + return + } + ok = status[0] == 0 + bReader := quicvarint.NewReader(r) + messageLen, err := quicvarint.Read(bReader) + if err != nil { + return + } + if messageLen > MaxMessageLength { + return false, "", E.New("invalid message length") + } + message, err = rw.ReadString(r, int(messageLen)) + if err != nil { + return } paddingLen, err := quicvarint.Read(bReader) if err != nil { - return false, "", err + return } if paddingLen > MaxPaddingLength { return false, "", E.New("invalid padding length") @@ -112,16 +121,19 @@ func ReadTCPResponse(r io.Reader) (bool, string, error) { if paddingLen > 0 { _, err = io.CopyN(io.Discard, r, int64(paddingLen)) if err != nil { - return false, "", err + return } } - return status[0] == 0, msg, nil + return } func WriteTCPResponse(ok bool, msg string, payload []byte) *buf.Buffer { padding := tcpResponsePadding.String() paddingLen := len(padding) msgLen := len(msg) + if msgLen > MaxMessageLength { + msgLen = MaxMessageLength + } sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen + int(quicvarint.Len(uint64(paddingLen))) + paddingLen buffer := buf.NewSize(sz + len(payload)) @@ -198,7 +210,7 @@ func ParseUDPMessage(msg []byte) (*UDPMessage, error) { if err != nil { return nil, err } - if lAddr == 0 || lAddr > MaxMessageLength { + if lAddr == 0 || lAddr > MaxAddressLength { return nil, E.New("invalid address length") } bs := buf.Bytes() @@ -212,6 +224,9 @@ func ReadVString(reader io.Reader) (string, error) { if err != nil { return "", err } + if length > MaxAddressLength { + return "", E.New("invalid address length") + } value, err := rw.ReadBytes(reader, int(length)) if err != nil { return "", err