feat: add padding to requests & responses

This commit is contained in:
tobyxdd 2023-05-31 21:53:15 -07:00
parent 9f54aade8f
commit ebb9b3217e
5 changed files with 374 additions and 305 deletions

View file

@ -9,8 +9,9 @@ const (
URLHost = "hysteria"
URLPath = "/auth"
HeaderAuth = "Hysteria-Auth"
HeaderCCRX = "Hysteria-CC-RX"
HeaderAuth = "Hysteria-Auth"
HeaderCCRX = "Hysteria-CC-RX"
HeaderPadding = "Hysteria-Padding"
StatusAuthOK = 233
)
@ -24,6 +25,7 @@ func AuthRequestDataFromHeader(h http.Header) (auth string, rx uint64) {
func AuthRequestDataToHeader(h http.Header, auth string, rx uint64) {
h.Set(HeaderAuth, auth)
h.Set(HeaderCCRX, strconv.FormatUint(rx, 10))
h.Set(HeaderPadding, authRequestPadding.String())
}
func AuthResponseDataFromHeader(h http.Header) (rx uint64) {
@ -33,4 +35,5 @@ func AuthResponseDataFromHeader(h http.Header) (rx uint64) {
func AuthResponseDataToHeader(h http.Header, rx uint64) {
h.Set(HeaderCCRX, strconv.FormatUint(rx, 10))
h.Set(HeaderPadding, authResponsePadding.String())
}

View file

@ -0,0 +1,26 @@
package protocol
import (
"math/rand"
"strings"
)
// padding specifies a half-open range [Min, Max).
type padding struct {
Min int
Max int
}
func (p padding) String() string {
n := p.Min + rand.Intn(p.Max-p.Min)
return strings.Repeat("a", n) // No need to randomize since everything is encrypted anyway
}
var (
authRequestPadding = padding{Min: 256, Max: 2048}
authResponsePadding = padding{Min: 256, Max: 2048}
tcpRequestPadding = padding{Min: 64, Max: 512}
tcpResponsePadding = padding{Min: 128, Max: 1024}
udpRequestPadding = padding{Min: 64, Max: 512}
udpResponsePadding = padding{Min: 128, Max: 1024}
)

View file

@ -15,8 +15,11 @@ const (
FrameTypeTCPRequest = 0x401
FrameTypeUDPRequest = 0x402
MaxAddressLength = 2048 // for preventing DoS attack by sending a very large address length
MaxMessageLength = 2048 // for preventing DoS attack by sending a very large message length
// Max length values are for preventing DoS attacks
MaxAddressLength = 2048
MaxMessageLength = 2048
MaxPaddingLength = 4096
MaxUDPSize = 4096
@ -30,28 +33,52 @@ const (
// 0x401 (QUIC varint)
// Address length (QUIC varint)
// Address (bytes)
// Padding length (QUIC varint)
// Padding (bytes)
func ReadTCPRequest(r io.Reader) (string, error) {
bReader := quicvarint.NewReader(r)
l, err := quicvarint.Read(bReader)
addrLen, err := quicvarint.Read(bReader)
if err != nil {
return "", err
}
if l == 0 || l > MaxAddressLength {
if addrLen == 0 || addrLen > MaxAddressLength {
return "", errors.ProtocolError{Message: "invalid address length"}
}
buf := make([]byte, l)
_, err = io.ReadFull(r, buf)
return string(buf), err
addrBuf := make([]byte, addrLen)
_, err = io.ReadFull(r, addrBuf)
if err != nil {
return "", err
}
paddingLen, err := quicvarint.Read(bReader)
if err != nil {
return "", err
}
if paddingLen > MaxPaddingLength {
return "", errors.ProtocolError{Message: "invalid padding length"}
}
if paddingLen > 0 {
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
if err != nil {
return "", err
}
}
return string(addrBuf), nil
}
func WriteTCPRequest(w io.Writer, addr string) error {
l := len(addr)
sz := int(quicvarint.Len(FrameTypeTCPRequest)) + int(quicvarint.Len(uint64(l))) + l
padding := tcpRequestPadding.String()
paddingLen := len(padding)
addrLen := len(addr)
sz := int(quicvarint.Len(FrameTypeTCPRequest)) +
int(quicvarint.Len(uint64(addrLen))) + addrLen +
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
buf := make([]byte, sz)
i := varintPut(buf, FrameTypeTCPRequest)
i += varintPut(buf[i:], uint64(l))
copy(buf[i:], addr)
i += varintPut(buf[i:], uint64(addrLen))
i += copy(buf[i:], addr)
i += varintPut(buf[i:], uint64(paddingLen))
copy(buf[i:], padding)
_, err := w.Write(buf)
return err
}
@ -60,6 +87,8 @@ func WriteTCPRequest(w io.Writer, addr string) error {
// Status (byte, 0=ok, 1=error)
// Message length (QUIC varint)
// Message (bytes)
// Padding length (QUIC varint)
// Padding (bytes)
func ReadTCPResponse(r io.Reader) (bool, string, error) {
var status [1]byte
@ -67,45 +96,90 @@ func ReadTCPResponse(r io.Reader) (bool, string, error) {
return false, "", err
}
bReader := quicvarint.NewReader(r)
l, err := quicvarint.Read(bReader)
msgLen, err := quicvarint.Read(bReader)
if err != nil {
return false, "", err
}
if l == 0 {
// No message is ok
return status[0] == 0, "", nil
}
if l > MaxMessageLength {
if msgLen > MaxMessageLength {
return false, "", errors.ProtocolError{Message: "invalid message length"}
}
buf := make([]byte, l)
_, err = io.ReadFull(r, buf)
return status[0] == 0, string(buf), err
var msgBuf []byte
// No message is fine
if msgLen > 0 {
msgBuf = make([]byte, msgLen)
_, err = io.ReadFull(r, msgBuf)
if err != nil {
return false, "", err
}
}
paddingLen, err := quicvarint.Read(bReader)
if err != nil {
return false, "", err
}
if paddingLen > MaxPaddingLength {
return false, "", errors.ProtocolError{Message: "invalid padding length"}
}
if paddingLen > 0 {
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
if err != nil {
return false, "", err
}
}
return status[0] == 0, string(msgBuf), nil
}
func WriteTCPResponse(w io.Writer, ok bool, msg string) error {
l := len(msg)
sz := 1 + int(quicvarint.Len(uint64(l))) + l
padding := tcpResponsePadding.String()
paddingLen := len(padding)
msgLen := len(msg)
sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
buf := make([]byte, sz)
if ok {
buf[0] = 0
} else {
buf[0] = 1
}
i := varintPut(buf[1:], uint64(l))
copy(buf[1+i:], msg)
i := varintPut(buf[1:], uint64(msgLen))
i += copy(buf[1+i:], msg)
i += varintPut(buf[1+i:], uint64(paddingLen))
copy(buf[1+i:], padding)
_, err := w.Write(buf)
return err
}
// UDPRequest format:
// 0x402 (QUIC varint)
// Padding length (QUIC varint)
// Padding (bytes)
// Nothing to read
func ReadUDPRequest(r io.Reader) error {
bReader := quicvarint.NewReader(r)
paddingLen, err := quicvarint.Read(bReader)
if err != nil {
return err
}
if paddingLen > MaxPaddingLength {
return errors.ProtocolError{Message: "invalid padding length"}
}
if paddingLen > 0 {
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
if err != nil {
return err
}
}
return nil
}
func WriteUDPRequest(w io.Writer) error {
buf := make([]byte, quicvarint.Len(FrameTypeUDPRequest))
varintPut(buf, FrameTypeUDPRequest)
padding := udpRequestPadding.String()
paddingLen := len(padding)
sz := int(quicvarint.Len(FrameTypeUDPRequest)) +
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
buf := make([]byte, sz)
i := varintPut(buf, FrameTypeUDPRequest)
i += varintPut(buf[i:], uint64(paddingLen))
copy(buf[i:], padding)
_, err := w.Write(buf)
return err
}
@ -115,6 +189,8 @@ func WriteUDPRequest(w io.Writer) error {
// Session ID (uint32 BE)
// Message length (QUIC varint)
// Message (bytes)
// Padding length (QUIC varint)
// Padding (bytes)
func ReadUDPResponse(r io.Reader) (bool, uint32, string, error) {
var status [1]byte
@ -126,33 +202,55 @@ func ReadUDPResponse(r io.Reader) (bool, uint32, string, error) {
return false, 0, "", err
}
bReader := quicvarint.NewReader(r)
l, err := quicvarint.Read(bReader)
msgLen, err := quicvarint.Read(bReader)
if err != nil {
return false, 0, "", err
}
if l == 0 {
// No message is ok
return status[0] == 0, sessionID, "", nil
}
if l > MaxMessageLength {
if msgLen > MaxMessageLength {
return false, 0, "", errors.ProtocolError{Message: "invalid message length"}
}
buf := make([]byte, l)
_, err = io.ReadFull(r, buf)
return status[0] == 0, sessionID, string(buf), err
var msgBuf []byte
// No message is fine
if msgLen > 0 {
msgBuf = make([]byte, msgLen)
_, err = io.ReadFull(r, msgBuf)
if err != nil {
return false, 0, "", err
}
}
paddingLen, err := quicvarint.Read(bReader)
if err != nil {
return false, 0, "", err
}
if paddingLen > MaxPaddingLength {
return false, 0, "", errors.ProtocolError{Message: "invalid padding length"}
}
if paddingLen > 0 {
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
if err != nil {
return false, 0, "", err
}
}
return status[0] == 0, sessionID, string(msgBuf), nil
}
func WriteUDPResponse(w io.Writer, ok bool, sessionID uint32, msg string) error {
l := len(msg)
buf := make([]byte, 5+int(quicvarint.Len(uint64(l)))+l)
padding := udpResponsePadding.String()
paddingLen := len(padding)
msgLen := len(msg)
sz := 1 + 4 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
buf := make([]byte, sz)
if ok {
buf[0] = 0
} else {
buf[0] = 1
}
binary.BigEndian.PutUint32(buf[1:], sessionID)
i := varintPut(buf[5:], uint64(l))
copy(buf[5+i:], msg)
i := varintPut(buf[5:], uint64(msgLen))
i += copy(buf[5+i:], msg)
i += varintPut(buf[5+i:], uint64(paddingLen))
copy(buf[5+i:], padding)
_, err := w.Write(buf)
return err
}

View file

@ -2,126 +2,11 @@ package protocol
import (
"bytes"
"io"
"reflect"
"strings"
"testing"
)
func TestReadTCPRequest(t *testing.T) {
type args struct {
r io.Reader
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "normal 1",
args: args{
r: bytes.NewReader([]byte("\x05hello")),
},
want: "hello",
wantErr: false,
},
{
name: "normal 2",
args: args{
r: bytes.NewReader([]byte("\x41\x25We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People")),
},
want: "We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People",
wantErr: false,
},
{
name: "empty",
args: args{
r: bytes.NewReader([]byte("\x00")),
},
want: "",
wantErr: true,
},
{
name: "incomplete",
args: args{
r: bytes.NewReader([]byte("\x06oh no")),
},
want: "oh no\x00",
wantErr: true,
},
{
name: "too long",
args: args{
r: bytes.NewReader([]byte("\x66\x77\x88Whatever")),
},
want: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ReadTCPRequest(tt.args.r)
if (err != nil) != tt.wantErr {
t.Errorf("ReadTCPRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ReadTCPRequest() got = %v, want %v", got, tt.want)
}
})
}
}
func TestWriteTCPRequest(t *testing.T) {
type args struct {
addr string
}
tests := []struct {
name string
args args
wantW string
wantErr bool
}{
{
name: "normal 1",
args: args{
addr: "hello",
},
wantW: "\x44\x01\x05hello",
wantErr: false,
},
{
name: "normal 2",
args: args{
addr: "We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People",
},
wantW: "\x44\x01\x41\x25We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People We the People",
wantErr: false,
},
{
name: "empty",
args: args{
addr: "",
},
wantW: "\x44\x01\x00",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &bytes.Buffer{}
err := WriteTCPRequest(w, tt.args.addr)
if (err != nil) != tt.wantErr {
t.Errorf("WriteTCPRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotW := w.String(); gotW != tt.wantW {
t.Errorf("WriteTCPRequest() gotW = %v, want %v", []byte(gotW), tt.wantW)
}
})
}
}
func TestUDPMessage(t *testing.T) {
t.Run("buffer too small", func(t *testing.T) {
// Make sure Serialize returns -1 when the buffer is too small.
@ -240,75 +125,138 @@ func TestUDPMessageMalformed(t *testing.T) {
}
}
func TestReadTCPResponse(t *testing.T) {
type args struct {
r io.Reader
}
func TestReadTCPRequest(t *testing.T) {
tests := []struct {
name string
args args
data []byte
want string
wantErr bool
}{
{
name: "normal no padding",
data: []byte("\x0egoogle.com:443\x00"),
want: "google.com:443",
wantErr: false,
},
{
name: "normal with padding",
data: []byte("\x0bholy.cc:443\x02gg"),
want: "holy.cc:443",
wantErr: false,
},
{
name: "incomplete 1",
data: []byte("\x0bhoho"),
want: "",
wantErr: true,
},
{
name: "incomplete 2",
data: []byte("\x0bholy.cc:443\x05x"),
want: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := bytes.NewReader(tt.data)
got, err := ReadTCPRequest(r)
if (err != nil) != tt.wantErr {
t.Errorf("ReadTCPRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ReadTCPRequest() got = %v, want %v", got, tt.want)
}
})
}
}
func TestWriteTCPRequest(t *testing.T) {
tests := []struct {
name string
addr string
wantW string // Just a prefix, we don't care about the padding
wantErr bool
}{
{
name: "normal 1",
addr: "google.com:443",
wantW: "\x44\x01\x0egoogle.com:443",
wantErr: false,
},
{
name: "normal 2",
addr: "client-api.arkoselabs.com:8080",
wantW: "\x44\x01\x1eclient-api.arkoselabs.com:8080",
wantErr: false,
},
{
name: "empty",
addr: "",
wantW: "\x44\x01\x00",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &bytes.Buffer{}
err := WriteTCPRequest(w, tt.addr)
if (err != nil) != tt.wantErr {
t.Errorf("WriteTCPRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
t.Errorf("WriteTCPRequest() gotW = %v, want %v", gotW, tt.wantW)
}
})
}
}
func TestReadTCPResponse(t *testing.T) {
tests := []struct {
name string
data []byte
want bool
want1 string
wantErr bool
}{
{
name: "success 1",
args: args{
r: bytes.NewReader([]byte("\x00\x00")),
},
name: "normal ok no padding",
data: []byte("\x00\x0bhello world\x00"),
want: true,
want1: "hello world",
wantErr: false,
},
{
name: "normal error with padding",
data: []byte("\x01\x06stop!!\x05xxxxx"),
want1: "stop!!",
wantErr: false,
},
{
name: "normal ok no message with padding",
data: []byte("\x01\x00\x05xxxxx"),
want1: "",
wantErr: false,
},
{
name: "success 2",
args: args{
r: bytes.NewReader([]byte("\x00\x12are ya winning son")),
},
want: true,
want1: "are ya winning son",
wantErr: false,
},
{
name: "failure 1",
args: args{
r: bytes.NewReader([]byte("\x01\x00")),
},
want: false,
name: "incomplete 1",
data: []byte("\x00\x0bhoho"),
want1: "",
wantErr: false,
},
{
name: "failure 2",
args: args{
r: bytes.NewReader([]byte("\x01\x15you ain't winning son")),
},
want: false,
want1: "you ain't winning son",
wantErr: false,
},
{
name: "incomplete",
args: args{
r: bytes.NewReader([]byte("\x01\x25princess peach is in another castle")),
},
want: false,
want1: "princess peach is in another castle\x00\x00",
wantErr: true,
},
{
name: "too long",
args: args{
r: bytes.NewReader([]byte("\xAA\xBB\xCCrandom stuff")),
},
want: false,
name: "incomplete 2",
data: []byte("\x01\x05jesus\x05x"),
want1: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1, err := ReadTCPResponse(tt.args.r)
r := bytes.NewReader(tt.data)
got, got1, err := ReadTCPResponse(r)
if (err != nil) != tt.wantErr {
t.Errorf("ReadTCPResponse() error = %v, wantErr %v", err, tt.wantErr)
return
@ -331,40 +279,26 @@ func TestWriteTCPResponse(t *testing.T) {
tests := []struct {
name string
args args
wantW string
wantW string // Just a prefix, we don't care about the padding
wantErr bool
}{
{
name: "success 1",
args: args{
ok: true,
msg: "",
},
wantW: "\x00\x00",
name: "normal ok",
args: args{ok: true, msg: "hello world"},
wantW: "\x00\x0bhello world",
wantErr: false,
},
{
name: "success 2",
args: args{
ok: true,
msg: "Welcome XDXDXD",
},
wantW: "\x00\x0EWelcome XDXDXD",
name: "normal error",
args: args{ok: false, msg: "stop!!"},
wantW: "\x01\x06stop!!",
wantErr: false,
},
{
name: "failure 1",
args: args{
ok: false,
msg: "",
},
wantW: "\x01\x00",
},
{
name: "failure 2",
args: args{
ok: false,
msg: "me trying to find who u are: ...",
},
wantW: "\x01\x20me trying to find who u are: ...",
name: "empty",
args: args{ok: true, msg: ""},
wantW: "\x00\x00",
wantErr: false,
},
}
for _, tt := range tests {
@ -375,17 +309,49 @@ func TestWriteTCPResponse(t *testing.T) {
t.Errorf("WriteTCPResponse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotW := w.String(); gotW != tt.wantW {
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
t.Errorf("WriteTCPResponse() gotW = %v, want %v", gotW, tt.wantW)
}
})
}
}
func TestReadUDPRequest(t *testing.T) {
tests := []struct {
name string
data []byte
wantErr bool
}{
{
name: "normal no padding",
data: []byte("\x00\x00"),
wantErr: false,
},
{
name: "normal with padding",
data: []byte("\x02gg"),
wantErr: false,
},
{
name: "incomplete 1",
data: []byte("\x0bhoho"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := bytes.NewReader(tt.data)
if err := ReadUDPRequest(r); (err != nil) != tt.wantErr {
t.Errorf("ReadUDPRequest() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestWriteUDPRequest(t *testing.T) {
tests := []struct {
name string
wantW string
wantW string // Just a prefix, we don't care about the padding
wantErr bool
}{
{
@ -402,7 +368,7 @@ func TestWriteUDPRequest(t *testing.T) {
t.Errorf("WriteUDPRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotW := w.String(); gotW != tt.wantW {
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
t.Errorf("WriteUDPRequest() gotW = %v, want %v", gotW, tt.wantW)
}
})
@ -410,62 +376,49 @@ func TestWriteUDPRequest(t *testing.T) {
}
func TestReadUDPResponse(t *testing.T) {
type args struct {
r io.Reader
}
tests := []struct {
name string
args args
data []byte
want bool
want1 uint32
want2 string
wantErr bool
}{
{
name: "success 1",
args: args{
r: bytes.NewReader([]byte("\x00\x00\x00\x00\x02\x00")),
},
name: "normal ok no padding",
data: []byte("\x00\x00\x00\x00\x33\x0bhello world\x00"),
want: true,
want1: 2,
want1: 51,
want2: "hello world",
wantErr: false,
},
{
name: "normal error with padding",
data: []byte("\x01\x00\x00\x33\x33\x06stop!!\x05xxxxx"),
want: false,
want1: 13107,
want2: "stop!!",
wantErr: false,
},
{
name: "normal ok no message with padding",
data: []byte("\x00\x00\x00\x00\x33\x00\x05xxxxx"),
want: true,
want1: 51,
want2: "",
wantErr: false,
},
{
name: "success 2",
args: args{
r: bytes.NewReader([]byte("\x00\x00\x00\x00\x03\x0EWelcome XDXDXD")),
},
want: true,
want1: 3,
want2: "Welcome XDXDXD",
wantErr: false,
},
{
name: "failure",
args: args{
r: bytes.NewReader([]byte("\x01\x00\x00\x00\x01\x20me trying to find who u are: ...")),
},
want: false,
want1: 1,
want2: "me trying to find who u are: ...",
wantErr: false,
},
{
name: "incomplete",
args: args{
r: bytes.NewReader([]byte("\x00\x00\x00\x00\x02")),
},
name: "incomplete 1",
data: []byte("\x00\x00\x06"),
want: false,
want1: 0,
want2: "",
wantErr: true,
},
{
name: "too long",
args: args{
r: bytes.NewReader([]byte("\x00\x00\x00\x00\x02\xCC\xFF\x66no cap")),
},
name: "incomplete 2",
data: []byte("\x01\x00\x01\x02\x03\x05jesus\x05x"),
want: false,
want1: 0,
want2: "",
@ -474,7 +427,8 @@ func TestReadUDPResponse(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1, got2, err := ReadUDPResponse(tt.args.r)
r := bytes.NewReader(tt.data)
got, got1, got2, err := ReadUDPResponse(r)
if (err != nil) != tt.wantErr {
t.Errorf("ReadUDPResponse() error = %v, wantErr %v", err, tt.wantErr)
return
@ -501,44 +455,26 @@ func TestWriteUDPResponse(t *testing.T) {
tests := []struct {
name string
args args
wantW string
wantW string // Just a prefix, we don't care about the padding
wantErr bool
}{
{
name: "success 1",
args: args{
ok: true,
sessionID: 88,
msg: "",
},
wantW: "\x00\x00\x00\x00\x58\x00",
name: "normal ok",
args: args{ok: true, sessionID: 6, msg: "hello world"},
wantW: "\x00\x00\x00\x00\x06\x0bhello world",
wantErr: false,
},
{
name: "success 2",
args: args{
ok: true,
sessionID: 233,
msg: "together forever",
},
wantW: "\x00\x00\x00\x00\xE9\x10together forever",
name: "normal error",
args: args{ok: false, sessionID: 7, msg: "stop!!"},
wantW: "\x01\x00\x00\x00\x07\x06stop!!",
wantErr: false,
},
{
name: "failure 1",
args: args{
ok: false,
sessionID: 1,
msg: "",
},
wantW: "\x01\x00\x00\x00\x01\x00",
},
{
name: "failure 2",
args: args{
ok: false,
sessionID: 696969,
msg: "run away run away",
},
wantW: "\x01\x00\x0A\xA2\x89\x11run away run away",
name: "empty",
args: args{ok: true, sessionID: 0, msg: ""},
wantW: "\x00\x00\x00\x00\x00\x00",
wantErr: false,
},
}
for _, tt := range tests {
@ -549,7 +485,7 @@ func TestWriteUDPResponse(t *testing.T) {
t.Errorf("WriteUDPResponse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotW := w.String(); gotW != tt.wantW {
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
t.Errorf("WriteUDPResponse() gotW = %v, want %v", gotW, tt.wantW)
}
})

View file

@ -289,6 +289,12 @@ func (h *h3sHandler) handleUDPRequest(stream quic.Stream) {
_ = stream.Close()
return
}
// Read request
err := protocol.ReadUDPRequest(stream)
if err != nil {
_ = stream.Close()
return
}
// Add to session manager
sessionID, conn, connCloseFunc, err := h.udpSM.Add()
if err != nil {