hysteria2: Fix ReadTCPResponse crash

This commit is contained in:
世界 2024-05-12 16:14:28 +08:00
parent a7af781687
commit e18a59987e
No known key found for this signature in database
GPG key ID: CD109927C34A63C4

View file

@ -92,19 +92,28 @@ func WriteTCPRequest(addr string, payload []byte) *buf.Buffer {
// Padding length (QUIC varint) // Padding length (QUIC varint)
// Padding (bytes) // Padding (bytes)
func ReadTCPResponse(r io.Reader) (bool, string, error) { func ReadTCPResponse(r io.Reader) (ok bool, message string, err error) {
var status [1]byte var status [1]byte
if _, err := io.ReadFull(r, status[:]); err != nil { _, err = io.ReadFull(r, status[:])
return false, "", err
}
bReader := quicvarint.NewReader(r)
msg, err := ReadVString(bReader)
if err != nil { if err != nil {
return false, "", err return
}
ok = status[0] == 0
bReader := quicvarint.NewReader(r)
messageLen, err := quicvarint.Read(bReader)
if err != nil {
return
}
if messageLen > MaxMessageLength {
return false, "", E.New("invalid message length")
}
message, err = rw.ReadString(r, int(messageLen))
if err != nil {
return
} }
paddingLen, err := quicvarint.Read(bReader) paddingLen, err := quicvarint.Read(bReader)
if err != nil { if err != nil {
return false, "", err return
} }
if paddingLen > MaxPaddingLength { if paddingLen > MaxPaddingLength {
return false, "", E.New("invalid padding length") return false, "", E.New("invalid padding length")
@ -112,16 +121,19 @@ func ReadTCPResponse(r io.Reader) (bool, string, error) {
if paddingLen > 0 { if paddingLen > 0 {
_, err = io.CopyN(io.Discard, r, int64(paddingLen)) _, err = io.CopyN(io.Discard, r, int64(paddingLen))
if err != nil { if err != nil {
return false, "", err return
} }
} }
return status[0] == 0, msg, nil return
} }
func WriteTCPResponse(ok bool, msg string, payload []byte) *buf.Buffer { func WriteTCPResponse(ok bool, msg string, payload []byte) *buf.Buffer {
padding := tcpResponsePadding.String() padding := tcpResponsePadding.String()
paddingLen := len(padding) paddingLen := len(padding)
msgLen := len(msg) msgLen := len(msg)
if msgLen > MaxMessageLength {
msgLen = MaxMessageLength
}
sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen + sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
int(quicvarint.Len(uint64(paddingLen))) + paddingLen int(quicvarint.Len(uint64(paddingLen))) + paddingLen
buffer := buf.NewSize(sz + len(payload)) buffer := buf.NewSize(sz + len(payload))
@ -198,7 +210,7 @@ func ParseUDPMessage(msg []byte) (*UDPMessage, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if lAddr == 0 || lAddr > MaxMessageLength { if lAddr == 0 || lAddr > MaxAddressLength {
return nil, E.New("invalid address length") return nil, E.New("invalid address length")
} }
bs := buf.Bytes() bs := buf.Bytes()
@ -212,6 +224,9 @@ func ReadVString(reader io.Reader) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
if length > MaxAddressLength {
return "", E.New("invalid address length")
}
value, err := rw.ReadBytes(reader, int(length)) value, err := rw.ReadBytes(reader, int(length))
if err != nil { if err != nil {
return "", err return "", err