mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-05 13:37:45 +03:00
feat: add padding to requests & responses
This commit is contained in:
parent
9f54aade8f
commit
ebb9b3217e
5 changed files with 374 additions and 305 deletions
|
@ -11,6 +11,7 @@ const (
|
||||||
|
|
||||||
HeaderAuth = "Hysteria-Auth"
|
HeaderAuth = "Hysteria-Auth"
|
||||||
HeaderCCRX = "Hysteria-CC-RX"
|
HeaderCCRX = "Hysteria-CC-RX"
|
||||||
|
HeaderPadding = "Hysteria-Padding"
|
||||||
|
|
||||||
StatusAuthOK = 233
|
StatusAuthOK = 233
|
||||||
)
|
)
|
||||||
|
@ -24,6 +25,7 @@ func AuthRequestDataFromHeader(h http.Header) (auth string, rx uint64) {
|
||||||
func AuthRequestDataToHeader(h http.Header, auth string, rx uint64) {
|
func AuthRequestDataToHeader(h http.Header, auth string, rx uint64) {
|
||||||
h.Set(HeaderAuth, auth)
|
h.Set(HeaderAuth, auth)
|
||||||
h.Set(HeaderCCRX, strconv.FormatUint(rx, 10))
|
h.Set(HeaderCCRX, strconv.FormatUint(rx, 10))
|
||||||
|
h.Set(HeaderPadding, authRequestPadding.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthResponseDataFromHeader(h http.Header) (rx uint64) {
|
func AuthResponseDataFromHeader(h http.Header) (rx uint64) {
|
||||||
|
@ -33,4 +35,5 @@ func AuthResponseDataFromHeader(h http.Header) (rx uint64) {
|
||||||
|
|
||||||
func AuthResponseDataToHeader(h http.Header, rx uint64) {
|
func AuthResponseDataToHeader(h http.Header, rx uint64) {
|
||||||
h.Set(HeaderCCRX, strconv.FormatUint(rx, 10))
|
h.Set(HeaderCCRX, strconv.FormatUint(rx, 10))
|
||||||
|
h.Set(HeaderPadding, authResponsePadding.String())
|
||||||
}
|
}
|
||||||
|
|
26
core/internal/protocol/padding.go
Normal file
26
core/internal/protocol/padding.go
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package protocol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// padding specifies a half-open range [Min, Max).
|
||||||
|
type padding struct {
|
||||||
|
Min int
|
||||||
|
Max int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p padding) String() string {
|
||||||
|
n := p.Min + rand.Intn(p.Max-p.Min)
|
||||||
|
return strings.Repeat("a", n) // No need to randomize since everything is encrypted anyway
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
authRequestPadding = padding{Min: 256, Max: 2048}
|
||||||
|
authResponsePadding = padding{Min: 256, Max: 2048}
|
||||||
|
tcpRequestPadding = padding{Min: 64, Max: 512}
|
||||||
|
tcpResponsePadding = padding{Min: 128, Max: 1024}
|
||||||
|
udpRequestPadding = padding{Min: 64, Max: 512}
|
||||||
|
udpResponsePadding = padding{Min: 128, Max: 1024}
|
||||||
|
)
|
|
@ -15,8 +15,11 @@ const (
|
||||||
FrameTypeTCPRequest = 0x401
|
FrameTypeTCPRequest = 0x401
|
||||||
FrameTypeUDPRequest = 0x402
|
FrameTypeUDPRequest = 0x402
|
||||||
|
|
||||||
MaxAddressLength = 2048 // for preventing DoS attack by sending a very large address length
|
// Max length values are for preventing DoS attacks
|
||||||
MaxMessageLength = 2048 // for preventing DoS attack by sending a very large message length
|
|
||||||
|
MaxAddressLength = 2048
|
||||||
|
MaxMessageLength = 2048
|
||||||
|
MaxPaddingLength = 4096
|
||||||
|
|
||||||
MaxUDPSize = 4096
|
MaxUDPSize = 4096
|
||||||
|
|
||||||
|
@ -30,28 +33,52 @@ const (
|
||||||
// 0x401 (QUIC varint)
|
// 0x401 (QUIC varint)
|
||||||
// Address length (QUIC varint)
|
// Address length (QUIC varint)
|
||||||
// Address (bytes)
|
// Address (bytes)
|
||||||
|
// Padding length (QUIC varint)
|
||||||
|
// Padding (bytes)
|
||||||
|
|
||||||
func ReadTCPRequest(r io.Reader) (string, error) {
|
func ReadTCPRequest(r io.Reader) (string, error) {
|
||||||
bReader := quicvarint.NewReader(r)
|
bReader := quicvarint.NewReader(r)
|
||||||
l, err := quicvarint.Read(bReader)
|
addrLen, err := quicvarint.Read(bReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if l == 0 || l > MaxAddressLength {
|
if addrLen == 0 || addrLen > MaxAddressLength {
|
||||||
return "", errors.ProtocolError{Message: "invalid address length"}
|
return "", errors.ProtocolError{Message: "invalid address length"}
|
||||||
}
|
}
|
||||||
buf := make([]byte, l)
|
addrBuf := make([]byte, addrLen)
|
||||||
_, err = io.ReadFull(r, buf)
|
_, err = io.ReadFull(r, addrBuf)
|
||||||
return string(buf), err
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
paddingLen, err := quicvarint.Read(bReader)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if paddingLen > MaxPaddingLength {
|
||||||
|
return "", errors.ProtocolError{Message: "invalid padding length"}
|
||||||
|
}
|
||||||
|
if paddingLen > 0 {
|
||||||
|
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(addrBuf), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteTCPRequest(w io.Writer, addr string) error {
|
func WriteTCPRequest(w io.Writer, addr string) error {
|
||||||
l := len(addr)
|
padding := tcpRequestPadding.String()
|
||||||
sz := int(quicvarint.Len(FrameTypeTCPRequest)) + int(quicvarint.Len(uint64(l))) + l
|
paddingLen := len(padding)
|
||||||
|
addrLen := len(addr)
|
||||||
|
sz := int(quicvarint.Len(FrameTypeTCPRequest)) +
|
||||||
|
int(quicvarint.Len(uint64(addrLen))) + addrLen +
|
||||||
|
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
||||||
buf := make([]byte, sz)
|
buf := make([]byte, sz)
|
||||||
i := varintPut(buf, FrameTypeTCPRequest)
|
i := varintPut(buf, FrameTypeTCPRequest)
|
||||||
i += varintPut(buf[i:], uint64(l))
|
i += varintPut(buf[i:], uint64(addrLen))
|
||||||
copy(buf[i:], addr)
|
i += copy(buf[i:], addr)
|
||||||
|
i += varintPut(buf[i:], uint64(paddingLen))
|
||||||
|
copy(buf[i:], padding)
|
||||||
_, err := w.Write(buf)
|
_, err := w.Write(buf)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -60,6 +87,8 @@ func WriteTCPRequest(w io.Writer, addr string) error {
|
||||||
// Status (byte, 0=ok, 1=error)
|
// Status (byte, 0=ok, 1=error)
|
||||||
// Message length (QUIC varint)
|
// Message length (QUIC varint)
|
||||||
// Message (bytes)
|
// Message (bytes)
|
||||||
|
// Padding length (QUIC varint)
|
||||||
|
// Padding (bytes)
|
||||||
|
|
||||||
func ReadTCPResponse(r io.Reader) (bool, string, error) {
|
func ReadTCPResponse(r io.Reader) (bool, string, error) {
|
||||||
var status [1]byte
|
var status [1]byte
|
||||||
|
@ -67,45 +96,90 @@ func ReadTCPResponse(r io.Reader) (bool, string, error) {
|
||||||
return false, "", err
|
return false, "", err
|
||||||
}
|
}
|
||||||
bReader := quicvarint.NewReader(r)
|
bReader := quicvarint.NewReader(r)
|
||||||
l, err := quicvarint.Read(bReader)
|
msgLen, err := quicvarint.Read(bReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", err
|
return false, "", err
|
||||||
}
|
}
|
||||||
if l == 0 {
|
if msgLen > MaxMessageLength {
|
||||||
// No message is ok
|
|
||||||
return status[0] == 0, "", nil
|
|
||||||
}
|
|
||||||
if l > MaxMessageLength {
|
|
||||||
return false, "", errors.ProtocolError{Message: "invalid message length"}
|
return false, "", errors.ProtocolError{Message: "invalid message length"}
|
||||||
}
|
}
|
||||||
buf := make([]byte, l)
|
var msgBuf []byte
|
||||||
_, err = io.ReadFull(r, buf)
|
// No message is fine
|
||||||
return status[0] == 0, string(buf), err
|
if msgLen > 0 {
|
||||||
|
msgBuf = make([]byte, msgLen)
|
||||||
|
_, err = io.ReadFull(r, msgBuf)
|
||||||
|
if err != nil {
|
||||||
|
return false, "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
paddingLen, err := quicvarint.Read(bReader)
|
||||||
|
if err != nil {
|
||||||
|
return false, "", err
|
||||||
|
}
|
||||||
|
if paddingLen > MaxPaddingLength {
|
||||||
|
return false, "", errors.ProtocolError{Message: "invalid padding length"}
|
||||||
|
}
|
||||||
|
if paddingLen > 0 {
|
||||||
|
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
||||||
|
if err != nil {
|
||||||
|
return false, "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return status[0] == 0, string(msgBuf), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteTCPResponse(w io.Writer, ok bool, msg string) error {
|
func WriteTCPResponse(w io.Writer, ok bool, msg string) error {
|
||||||
l := len(msg)
|
padding := tcpResponsePadding.String()
|
||||||
sz := 1 + int(quicvarint.Len(uint64(l))) + l
|
paddingLen := len(padding)
|
||||||
|
msgLen := len(msg)
|
||||||
|
sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
|
||||||
|
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
||||||
buf := make([]byte, sz)
|
buf := make([]byte, sz)
|
||||||
if ok {
|
if ok {
|
||||||
buf[0] = 0
|
buf[0] = 0
|
||||||
} else {
|
} else {
|
||||||
buf[0] = 1
|
buf[0] = 1
|
||||||
}
|
}
|
||||||
i := varintPut(buf[1:], uint64(l))
|
i := varintPut(buf[1:], uint64(msgLen))
|
||||||
copy(buf[1+i:], msg)
|
i += copy(buf[1+i:], msg)
|
||||||
|
i += varintPut(buf[1+i:], uint64(paddingLen))
|
||||||
|
copy(buf[1+i:], padding)
|
||||||
_, err := w.Write(buf)
|
_, err := w.Write(buf)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// UDPRequest format:
|
// UDPRequest format:
|
||||||
// 0x402 (QUIC varint)
|
// 0x402 (QUIC varint)
|
||||||
|
// Padding length (QUIC varint)
|
||||||
|
// Padding (bytes)
|
||||||
|
|
||||||
// Nothing to read
|
func ReadUDPRequest(r io.Reader) error {
|
||||||
|
bReader := quicvarint.NewReader(r)
|
||||||
|
paddingLen, err := quicvarint.Read(bReader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if paddingLen > MaxPaddingLength {
|
||||||
|
return errors.ProtocolError{Message: "invalid padding length"}
|
||||||
|
}
|
||||||
|
if paddingLen > 0 {
|
||||||
|
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func WriteUDPRequest(w io.Writer) error {
|
func WriteUDPRequest(w io.Writer) error {
|
||||||
buf := make([]byte, quicvarint.Len(FrameTypeUDPRequest))
|
padding := udpRequestPadding.String()
|
||||||
varintPut(buf, FrameTypeUDPRequest)
|
paddingLen := len(padding)
|
||||||
|
sz := int(quicvarint.Len(FrameTypeUDPRequest)) +
|
||||||
|
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
||||||
|
buf := make([]byte, sz)
|
||||||
|
i := varintPut(buf, FrameTypeUDPRequest)
|
||||||
|
i += varintPut(buf[i:], uint64(paddingLen))
|
||||||
|
copy(buf[i:], padding)
|
||||||
_, err := w.Write(buf)
|
_, err := w.Write(buf)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -115,6 +189,8 @@ func WriteUDPRequest(w io.Writer) error {
|
||||||
// Session ID (uint32 BE)
|
// Session ID (uint32 BE)
|
||||||
// Message length (QUIC varint)
|
// Message length (QUIC varint)
|
||||||
// Message (bytes)
|
// Message (bytes)
|
||||||
|
// Padding length (QUIC varint)
|
||||||
|
// Padding (bytes)
|
||||||
|
|
||||||
func ReadUDPResponse(r io.Reader) (bool, uint32, string, error) {
|
func ReadUDPResponse(r io.Reader) (bool, uint32, string, error) {
|
||||||
var status [1]byte
|
var status [1]byte
|
||||||
|
@ -126,33 +202,55 @@ func ReadUDPResponse(r io.Reader) (bool, uint32, string, error) {
|
||||||
return false, 0, "", err
|
return false, 0, "", err
|
||||||
}
|
}
|
||||||
bReader := quicvarint.NewReader(r)
|
bReader := quicvarint.NewReader(r)
|
||||||
l, err := quicvarint.Read(bReader)
|
msgLen, err := quicvarint.Read(bReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, 0, "", err
|
return false, 0, "", err
|
||||||
}
|
}
|
||||||
if l == 0 {
|
if msgLen > MaxMessageLength {
|
||||||
// No message is ok
|
|
||||||
return status[0] == 0, sessionID, "", nil
|
|
||||||
}
|
|
||||||
if l > MaxMessageLength {
|
|
||||||
return false, 0, "", errors.ProtocolError{Message: "invalid message length"}
|
return false, 0, "", errors.ProtocolError{Message: "invalid message length"}
|
||||||
}
|
}
|
||||||
buf := make([]byte, l)
|
var msgBuf []byte
|
||||||
_, err = io.ReadFull(r, buf)
|
// No message is fine
|
||||||
return status[0] == 0, sessionID, string(buf), err
|
if msgLen > 0 {
|
||||||
|
msgBuf = make([]byte, msgLen)
|
||||||
|
_, err = io.ReadFull(r, msgBuf)
|
||||||
|
if err != nil {
|
||||||
|
return false, 0, "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
paddingLen, err := quicvarint.Read(bReader)
|
||||||
|
if err != nil {
|
||||||
|
return false, 0, "", err
|
||||||
|
}
|
||||||
|
if paddingLen > MaxPaddingLength {
|
||||||
|
return false, 0, "", errors.ProtocolError{Message: "invalid padding length"}
|
||||||
|
}
|
||||||
|
if paddingLen > 0 {
|
||||||
|
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
||||||
|
if err != nil {
|
||||||
|
return false, 0, "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return status[0] == 0, sessionID, string(msgBuf), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteUDPResponse(w io.Writer, ok bool, sessionID uint32, msg string) error {
|
func WriteUDPResponse(w io.Writer, ok bool, sessionID uint32, msg string) error {
|
||||||
l := len(msg)
|
padding := udpResponsePadding.String()
|
||||||
buf := make([]byte, 5+int(quicvarint.Len(uint64(l)))+l)
|
paddingLen := len(padding)
|
||||||
|
msgLen := len(msg)
|
||||||
|
sz := 1 + 4 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
|
||||||
|
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
||||||
|
buf := make([]byte, sz)
|
||||||
if ok {
|
if ok {
|
||||||
buf[0] = 0
|
buf[0] = 0
|
||||||
} else {
|
} else {
|
||||||
buf[0] = 1
|
buf[0] = 1
|
||||||
}
|
}
|
||||||
binary.BigEndian.PutUint32(buf[1:], sessionID)
|
binary.BigEndian.PutUint32(buf[1:], sessionID)
|
||||||
i := varintPut(buf[5:], uint64(l))
|
i := varintPut(buf[5:], uint64(msgLen))
|
||||||
copy(buf[5+i:], msg)
|
i += copy(buf[5+i:], msg)
|
||||||
|
i += varintPut(buf[5+i:], uint64(paddingLen))
|
||||||
|
copy(buf[5+i:], padding)
|
||||||
_, err := w.Write(buf)
|
_, err := w.Write(buf)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,126 +2,11 @@ package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestReadTCPRequest(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
r io.Reader
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "normal 1",
|
|
||||||
args: args{
|
|
||||||
r: bytes.NewReader([]byte("\x05hello")),
|
|
||||||
},
|
|
||||||
want: "hello",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "normal 2",
|
|
||||||
args: args{
|
|
||||||
r: bytes.NewReader([]byte("\x41\x25We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People")),
|
|
||||||
},
|
|
||||||
want: "We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty",
|
|
||||||
args: args{
|
|
||||||
r: bytes.NewReader([]byte("\x00")),
|
|
||||||
},
|
|
||||||
want: "",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "incomplete",
|
|
||||||
args: args{
|
|
||||||
r: bytes.NewReader([]byte("\x06oh no")),
|
|
||||||
},
|
|
||||||
want: "oh no\x00",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "too long",
|
|
||||||
args: args{
|
|
||||||
r: bytes.NewReader([]byte("\x66\x77\x88Whatever")),
|
|
||||||
},
|
|
||||||
want: "",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := ReadTCPRequest(tt.args.r)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("ReadTCPRequest() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("ReadTCPRequest() got = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriteTCPRequest(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
addr string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantW string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "normal 1",
|
|
||||||
args: args{
|
|
||||||
addr: "hello",
|
|
||||||
},
|
|
||||||
wantW: "\x44\x01\x05hello",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "normal 2",
|
|
||||||
args: args{
|
|
||||||
addr: "We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People",
|
|
||||||
},
|
|
||||||
wantW: "\x44\x01\x41\x25We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty",
|
|
||||||
args: args{
|
|
||||||
addr: "",
|
|
||||||
},
|
|
||||||
wantW: "\x44\x01\x00",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
w := &bytes.Buffer{}
|
|
||||||
err := WriteTCPRequest(w, tt.args.addr)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("WriteTCPRequest() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if gotW := w.String(); gotW != tt.wantW {
|
|
||||||
t.Errorf("WriteTCPRequest() gotW = %v, want %v", []byte(gotW), tt.wantW)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUDPMessage(t *testing.T) {
|
func TestUDPMessage(t *testing.T) {
|
||||||
t.Run("buffer too small", func(t *testing.T) {
|
t.Run("buffer too small", func(t *testing.T) {
|
||||||
// Make sure Serialize returns -1 when the buffer is too small.
|
// Make sure Serialize returns -1 when the buffer is too small.
|
||||||
|
@ -240,75 +125,138 @@ func TestUDPMessageMalformed(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadTCPResponse(t *testing.T) {
|
func TestReadTCPRequest(t *testing.T) {
|
||||||
type args struct {
|
|
||||||
r io.Reader
|
|
||||||
}
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
data []byte
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal no padding",
|
||||||
|
data: []byte("\x0egoogle.com:443\x00"),
|
||||||
|
want: "google.com:443",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normal with padding",
|
||||||
|
data: []byte("\x0bholy.cc:443\x02gg"),
|
||||||
|
want: "holy.cc:443",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete 1",
|
||||||
|
data: []byte("\x0bhoho"),
|
||||||
|
want: "",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete 2",
|
||||||
|
data: []byte("\x0bholy.cc:443\x05x"),
|
||||||
|
want: "",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r := bytes.NewReader(tt.data)
|
||||||
|
got, err := ReadTCPRequest(r)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ReadTCPRequest() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("ReadTCPRequest() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteTCPRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr string
|
||||||
|
wantW string // Just a prefix, we don't care about the padding
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal 1",
|
||||||
|
addr: "google.com:443",
|
||||||
|
wantW: "\x44\x01\x0egoogle.com:443",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normal 2",
|
||||||
|
addr: "client-api.arkoselabs.com:8080",
|
||||||
|
wantW: "\x44\x01\x1eclient-api.arkoselabs.com:8080",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
addr: "",
|
||||||
|
wantW: "\x44\x01\x00",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
w := &bytes.Buffer{}
|
||||||
|
err := WriteTCPRequest(w, tt.addr)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("WriteTCPRequest() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
|
||||||
|
t.Errorf("WriteTCPRequest() gotW = %v, want %v", gotW, tt.wantW)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadTCPResponse(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
want bool
|
want bool
|
||||||
want1 string
|
want1 string
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "success 1",
|
name: "normal ok no padding",
|
||||||
args: args{
|
data: []byte("\x00\x0bhello world\x00"),
|
||||||
r: bytes.NewReader([]byte("\x00\x00")),
|
|
||||||
},
|
|
||||||
want: true,
|
want: true,
|
||||||
|
want1: "hello world",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normal error with padding",
|
||||||
|
data: []byte("\x01\x06stop!!\x05xxxxx"),
|
||||||
|
want1: "stop!!",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normal ok no message with padding",
|
||||||
|
data: []byte("\x01\x00\x05xxxxx"),
|
||||||
want1: "",
|
want1: "",
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "success 2",
|
name: "incomplete 1",
|
||||||
args: args{
|
data: []byte("\x00\x0bhoho"),
|
||||||
r: bytes.NewReader([]byte("\x00\x12are ya winning son")),
|
|
||||||
},
|
|
||||||
want: true,
|
|
||||||
want1: "are ya winning son",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "failure 1",
|
|
||||||
args: args{
|
|
||||||
r: bytes.NewReader([]byte("\x01\x00")),
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
want1: "",
|
want1: "",
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "failure 2",
|
|
||||||
args: args{
|
|
||||||
r: bytes.NewReader([]byte("\x01\x15you ain't winning son")),
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
want1: "you ain't winning son",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "incomplete",
|
|
||||||
args: args{
|
|
||||||
r: bytes.NewReader([]byte("\x01\x25princess peach is in another castle")),
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
want1: "princess peach is in another castle\x00\x00",
|
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "too long",
|
name: "incomplete 2",
|
||||||
args: args{
|
data: []byte("\x01\x05jesus\x05x"),
|
||||||
r: bytes.NewReader([]byte("\xAA\xBB\xCCrandom stuff")),
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
want1: "",
|
want1: "",
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, got1, err := ReadTCPResponse(tt.args.r)
|
r := bytes.NewReader(tt.data)
|
||||||
|
got, got1, err := ReadTCPResponse(r)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("ReadTCPResponse() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("ReadTCPResponse() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -331,40 +279,26 @@ func TestWriteTCPResponse(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
wantW string
|
wantW string // Just a prefix, we don't care about the padding
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "success 1",
|
name: "normal ok",
|
||||||
args: args{
|
args: args{ok: true, msg: "hello world"},
|
||||||
ok: true,
|
wantW: "\x00\x0bhello world",
|
||||||
msg: "",
|
wantErr: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "normal error",
|
||||||
|
args: args{ok: false, msg: "stop!!"},
|
||||||
|
wantW: "\x01\x06stop!!",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
args: args{ok: true, msg: ""},
|
||||||
wantW: "\x00\x00",
|
wantW: "\x00\x00",
|
||||||
},
|
wantErr: false,
|
||||||
{
|
|
||||||
name: "success 2",
|
|
||||||
args: args{
|
|
||||||
ok: true,
|
|
||||||
msg: "Welcome XDXDXD",
|
|
||||||
},
|
|
||||||
wantW: "\x00\x0EWelcome XDXDXD",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "failure 1",
|
|
||||||
args: args{
|
|
||||||
ok: false,
|
|
||||||
msg: "",
|
|
||||||
},
|
|
||||||
wantW: "\x01\x00",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "failure 2",
|
|
||||||
args: args{
|
|
||||||
ok: false,
|
|
||||||
msg: "me trying to find who u are: ...",
|
|
||||||
},
|
|
||||||
wantW: "\x01\x20me trying to find who u are: ...",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -375,17 +309,49 @@ func TestWriteTCPResponse(t *testing.T) {
|
||||||
t.Errorf("WriteTCPResponse() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("WriteTCPResponse() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if gotW := w.String(); gotW != tt.wantW {
|
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
|
||||||
t.Errorf("WriteTCPResponse() gotW = %v, want %v", gotW, tt.wantW)
|
t.Errorf("WriteTCPResponse() gotW = %v, want %v", gotW, tt.wantW)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReadUDPRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal no padding",
|
||||||
|
data: []byte("\x00\x00"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normal with padding",
|
||||||
|
data: []byte("\x02gg"),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete 1",
|
||||||
|
data: []byte("\x0bhoho"),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r := bytes.NewReader(tt.data)
|
||||||
|
if err := ReadUDPRequest(r); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ReadUDPRequest() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteUDPRequest(t *testing.T) {
|
func TestWriteUDPRequest(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
wantW string
|
wantW string // Just a prefix, we don't care about the padding
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
|
@ -402,7 +368,7 @@ func TestWriteUDPRequest(t *testing.T) {
|
||||||
t.Errorf("WriteUDPRequest() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("WriteUDPRequest() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if gotW := w.String(); gotW != tt.wantW {
|
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
|
||||||
t.Errorf("WriteUDPRequest() gotW = %v, want %v", gotW, tt.wantW)
|
t.Errorf("WriteUDPRequest() gotW = %v, want %v", gotW, tt.wantW)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -410,62 +376,49 @@ func TestWriteUDPRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadUDPResponse(t *testing.T) {
|
func TestReadUDPResponse(t *testing.T) {
|
||||||
type args struct {
|
|
||||||
r io.Reader
|
|
||||||
}
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
data []byte
|
||||||
want bool
|
want bool
|
||||||
want1 uint32
|
want1 uint32
|
||||||
want2 string
|
want2 string
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "success 1",
|
name: "normal ok no padding",
|
||||||
args: args{
|
data: []byte("\x00\x00\x00\x00\x33\x0bhello world\x00"),
|
||||||
r: bytes.NewReader([]byte("\x00\x00\x00\x00\x02\x00")),
|
|
||||||
},
|
|
||||||
want: true,
|
want: true,
|
||||||
want1: 2,
|
want1: 51,
|
||||||
|
want2: "hello world",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normal error with padding",
|
||||||
|
data: []byte("\x01\x00\x00\x33\x33\x06stop!!\x05xxxxx"),
|
||||||
|
want: false,
|
||||||
|
want1: 13107,
|
||||||
|
want2: "stop!!",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "normal ok no message with padding",
|
||||||
|
data: []byte("\x00\x00\x00\x00\x33\x00\x05xxxxx"),
|
||||||
|
want: true,
|
||||||
|
want1: 51,
|
||||||
want2: "",
|
want2: "",
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "success 2",
|
name: "incomplete 1",
|
||||||
args: args{
|
data: []byte("\x00\x00\x06"),
|
||||||
r: bytes.NewReader([]byte("\x00\x00\x00\x00\x03\x0EWelcome XDXDXD")),
|
|
||||||
},
|
|
||||||
want: true,
|
|
||||||
want1: 3,
|
|
||||||
want2: "Welcome XDXDXD",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "failure",
|
|
||||||
args: args{
|
|
||||||
r: bytes.NewReader([]byte("\x01\x00\x00\x00\x01\x20me trying to find who u are: ...")),
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
want1: 1,
|
|
||||||
want2: "me trying to find who u are: ...",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "incomplete",
|
|
||||||
args: args{
|
|
||||||
r: bytes.NewReader([]byte("\x00\x00\x00\x00\x02")),
|
|
||||||
},
|
|
||||||
want: false,
|
want: false,
|
||||||
want1: 0,
|
want1: 0,
|
||||||
want2: "",
|
want2: "",
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "too long",
|
name: "incomplete 2",
|
||||||
args: args{
|
data: []byte("\x01\x00\x01\x02\x03\x05jesus\x05x"),
|
||||||
r: bytes.NewReader([]byte("\x00\x00\x00\x00\x02\xCC\xFF\x66no cap")),
|
|
||||||
},
|
|
||||||
want: false,
|
want: false,
|
||||||
want1: 0,
|
want1: 0,
|
||||||
want2: "",
|
want2: "",
|
||||||
|
@ -474,7 +427,8 @@ func TestReadUDPResponse(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, got1, got2, err := ReadUDPResponse(tt.args.r)
|
r := bytes.NewReader(tt.data)
|
||||||
|
got, got1, got2, err := ReadUDPResponse(r)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("ReadUDPResponse() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("ReadUDPResponse() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -501,44 +455,26 @@ func TestWriteUDPResponse(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
wantW string
|
wantW string // Just a prefix, we don't care about the padding
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "success 1",
|
name: "normal ok",
|
||||||
args: args{
|
args: args{ok: true, sessionID: 6, msg: "hello world"},
|
||||||
ok: true,
|
wantW: "\x00\x00\x00\x00\x06\x0bhello world",
|
||||||
sessionID: 88,
|
wantErr: false,
|
||||||
msg: "",
|
|
||||||
},
|
|
||||||
wantW: "\x00\x00\x00\x00\x58\x00",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "success 2",
|
name: "normal error",
|
||||||
args: args{
|
args: args{ok: false, sessionID: 7, msg: "stop!!"},
|
||||||
ok: true,
|
wantW: "\x01\x00\x00\x00\x07\x06stop!!",
|
||||||
sessionID: 233,
|
wantErr: false,
|
||||||
msg: "together forever",
|
|
||||||
},
|
|
||||||
wantW: "\x00\x00\x00\x00\xE9\x10together forever",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "failure 1",
|
name: "empty",
|
||||||
args: args{
|
args: args{ok: true, sessionID: 0, msg: ""},
|
||||||
ok: false,
|
wantW: "\x00\x00\x00\x00\x00\x00",
|
||||||
sessionID: 1,
|
wantErr: false,
|
||||||
msg: "",
|
|
||||||
},
|
|
||||||
wantW: "\x01\x00\x00\x00\x01\x00",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "failure 2",
|
|
||||||
args: args{
|
|
||||||
ok: false,
|
|
||||||
sessionID: 696969,
|
|
||||||
msg: "run away run away",
|
|
||||||
},
|
|
||||||
wantW: "\x01\x00\x0A\xA2\x89\x11run away run away",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -549,7 +485,7 @@ func TestWriteUDPResponse(t *testing.T) {
|
||||||
t.Errorf("WriteUDPResponse() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("WriteUDPResponse() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if gotW := w.String(); gotW != tt.wantW {
|
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
|
||||||
t.Errorf("WriteUDPResponse() gotW = %v, want %v", gotW, tt.wantW)
|
t.Errorf("WriteUDPResponse() gotW = %v, want %v", gotW, tt.wantW)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -289,6 +289,12 @@ func (h *h3sHandler) handleUDPRequest(stream quic.Stream) {
|
||||||
_ = stream.Close()
|
_ = stream.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Read request
|
||||||
|
err := protocol.ReadUDPRequest(stream)
|
||||||
|
if err != nil {
|
||||||
|
_ = stream.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
// Add to session manager
|
// Add to session manager
|
||||||
sessionID, conn, connCloseFunc, err := h.udpSM.Add()
|
sessionID, conn, connCloseFunc, err := h.udpSM.Add()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue