mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-04 13:07:39 +03:00
feat(wip): udp rework server side
This commit is contained in:
parent
6245f83262
commit
a2fbcc6507
17 changed files with 554 additions and 513 deletions
|
@ -400,7 +400,7 @@ func (c *udpConn) Receive() ([]byte, string, error) {
|
||||||
// Send is not thread-safe as it uses a shared send buffer for now.
|
// Send is not thread-safe as it uses a shared send buffer for now.
|
||||||
func (c *udpConn) Send(data []byte, addr string) error {
|
func (c *udpConn) Send(data []byte, addr string) error {
|
||||||
// Try no frag first
|
// Try no frag first
|
||||||
msg := protocol.UDPMessage{
|
msg := &protocol.UDPMessage{
|
||||||
SessionID: c.SessionID,
|
SessionID: c.SessionID,
|
||||||
PacketID: 0,
|
PacketID: 0,
|
||||||
FragID: 0,
|
FragID: 0,
|
||||||
|
|
|
@ -4,6 +4,7 @@ go 1.20
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/quic-go/quic-go v0.0.0-00010101000000-000000000000
|
github.com/quic-go/quic-go v0.0.0-00010101000000-000000000000
|
||||||
|
go.uber.org/goleak v1.2.1
|
||||||
golang.org/x/time v0.3.0
|
golang.org/x/time v0.3.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -38,6 +38,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
|
||||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||||
|
go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
|
||||||
|
go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8=
|
golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8=
|
||||||
|
|
|
@ -4,9 +4,9 @@ import (
|
||||||
"github.com/apernet/hysteria/core/internal/protocol"
|
"github.com/apernet/hysteria/core/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
func FragUDPMessage(m protocol.UDPMessage, maxSize int) []protocol.UDPMessage {
|
func FragUDPMessage(m *protocol.UDPMessage, maxSize int) []protocol.UDPMessage {
|
||||||
if m.Size() <= maxSize {
|
if m.Size() <= maxSize {
|
||||||
return []protocol.UDPMessage{m}
|
return []protocol.UDPMessage{*m}
|
||||||
}
|
}
|
||||||
fullPayload := m.Data
|
fullPayload := m.Data
|
||||||
maxPayloadSize := maxSize - m.HeaderSize()
|
maxPayloadSize := maxSize - m.HeaderSize()
|
||||||
|
@ -19,7 +19,7 @@ func FragUDPMessage(m protocol.UDPMessage, maxSize int) []protocol.UDPMessage {
|
||||||
if payloadSize > maxPayloadSize {
|
if payloadSize > maxPayloadSize {
|
||||||
payloadSize = maxPayloadSize
|
payloadSize = maxPayloadSize
|
||||||
}
|
}
|
||||||
frag := m
|
frag := *m
|
||||||
frag.FragID = fragID
|
frag.FragID = fragID
|
||||||
frag.FragCount = fragCount
|
frag.FragCount = fragCount
|
||||||
frag.Data = fullPayload[off : off+payloadSize]
|
frag.Data = fullPayload[off : off+payloadSize]
|
||||||
|
|
|
@ -124,7 +124,7 @@ func TestFragUDPMessage(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if got := FragUDPMessage(tt.args.m, tt.args.maxSize); !reflect.DeepEqual(got, tt.want) {
|
if got := FragUDPMessage(&tt.args.m, tt.args.maxSize); !reflect.DeepEqual(got, tt.want) {
|
||||||
t.Errorf("FragUDPMessage() = %v, want %v", got, tt.want)
|
t.Errorf("FragUDPMessage() = %v, want %v", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -36,7 +36,7 @@ func TestClientNoServer(t *testing.T) {
|
||||||
// Try UDP
|
// Try UDP
|
||||||
_, err = c.ListenUDP()
|
_, err = c.ListenUDP()
|
||||||
if !errors.As(err, &cErr) {
|
if !errors.As(err, &cErr) {
|
||||||
t.Fatal("expected connect error from ListenUDP")
|
t.Fatal("expected connect error from DialUDP")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ func TestClientServerBadAuth(t *testing.T) {
|
||||||
// Try UDP
|
// Try UDP
|
||||||
_, err = c.ListenUDP()
|
_, err = c.ListenUDP()
|
||||||
if !errors.As(err, &aErr) {
|
if !errors.As(err, &aErr) {
|
||||||
t.Fatal("expected auth error from ListenUDP")
|
t.Fatal("expected auth error from DialUDP")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,31 +9,34 @@ const (
|
||||||
URLHost = "hysteria"
|
URLHost = "hysteria"
|
||||||
URLPath = "/auth"
|
URLPath = "/auth"
|
||||||
|
|
||||||
HeaderAuth = "Hysteria-Auth"
|
RequestHeaderAuth = "Hysteria-Auth"
|
||||||
HeaderCCRX = "Hysteria-CC-RX"
|
ResponseHeaderUDPEnabled = "Hysteria-UDP"
|
||||||
HeaderPadding = "Hysteria-Padding"
|
CommonHeaderCCRX = "Hysteria-CC-RX"
|
||||||
|
CommonHeaderPadding = "Hysteria-Padding"
|
||||||
|
|
||||||
StatusAuthOK = 233
|
StatusAuthOK = 233
|
||||||
)
|
)
|
||||||
|
|
||||||
func AuthRequestDataFromHeader(h http.Header) (auth string, rx uint64) {
|
func AuthRequestDataFromHeader(h http.Header) (auth string, rx uint64) {
|
||||||
auth = h.Get(HeaderAuth)
|
auth = h.Get(RequestHeaderAuth)
|
||||||
rx, _ = strconv.ParseUint(h.Get(HeaderCCRX), 10, 64)
|
rx, _ = strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthRequestDataToHeader(h http.Header, auth string, rx uint64) {
|
func AuthRequestDataToHeader(h http.Header, auth string, rx uint64) {
|
||||||
h.Set(HeaderAuth, auth)
|
h.Set(RequestHeaderAuth, auth)
|
||||||
h.Set(HeaderCCRX, strconv.FormatUint(rx, 10))
|
h.Set(CommonHeaderCCRX, strconv.FormatUint(rx, 10))
|
||||||
h.Set(HeaderPadding, authRequestPadding.String())
|
h.Set(CommonHeaderPadding, authRequestPadding.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthResponseDataFromHeader(h http.Header) (rx uint64) {
|
func AuthResponseDataFromHeader(h http.Header) (udp bool, rx uint64) {
|
||||||
rx, _ = strconv.ParseUint(h.Get(HeaderCCRX), 10, 64)
|
udp, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled))
|
||||||
|
rx, _ = strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthResponseDataToHeader(h http.Header, rx uint64) {
|
func AuthResponseDataToHeader(h http.Header, udp bool, rx uint64) {
|
||||||
h.Set(HeaderCCRX, strconv.FormatUint(rx, 10))
|
h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(udp))
|
||||||
h.Set(HeaderPadding, authResponsePadding.String())
|
h.Set(CommonHeaderCCRX, strconv.FormatUint(rx, 10))
|
||||||
|
h.Set(CommonHeaderPadding, authResponsePadding.String())
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,4 @@ var (
|
||||||
authResponsePadding = padding{Min: 256, Max: 2048}
|
authResponsePadding = padding{Min: 256, Max: 2048}
|
||||||
tcpRequestPadding = padding{Min: 64, Max: 512}
|
tcpRequestPadding = padding{Min: 64, Max: 512}
|
||||||
tcpResponsePadding = padding{Min: 128, Max: 1024}
|
tcpResponsePadding = padding{Min: 128, Max: 1024}
|
||||||
udpRequestPadding = padding{Min: 64, Max: 512}
|
|
||||||
udpResponsePadding = padding{Min: 128, Max: 1024}
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,7 +13,6 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
FrameTypeTCPRequest = 0x401
|
FrameTypeTCPRequest = 0x401
|
||||||
FrameTypeUDPRequest = 0x402
|
|
||||||
|
|
||||||
// Max length values are for preventing DoS attacks
|
// Max length values are for preventing DoS attacks
|
||||||
|
|
||||||
|
@ -148,113 +147,6 @@ func WriteTCPResponse(w io.Writer, ok bool, msg string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// UDPRequest format:
|
|
||||||
// 0x402 (QUIC varint)
|
|
||||||
// Padding length (QUIC varint)
|
|
||||||
// Padding (bytes)
|
|
||||||
|
|
||||||
func ReadUDPRequest(r io.Reader) error {
|
|
||||||
bReader := quicvarint.NewReader(r)
|
|
||||||
paddingLen, err := quicvarint.Read(bReader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if paddingLen > MaxPaddingLength {
|
|
||||||
return errors.ProtocolError{Message: "invalid padding length"}
|
|
||||||
}
|
|
||||||
if paddingLen > 0 {
|
|
||||||
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteUDPRequest(w io.Writer) error {
|
|
||||||
padding := udpRequestPadding.String()
|
|
||||||
paddingLen := len(padding)
|
|
||||||
sz := int(quicvarint.Len(FrameTypeUDPRequest)) +
|
|
||||||
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
|
||||||
buf := make([]byte, sz)
|
|
||||||
i := varintPut(buf, FrameTypeUDPRequest)
|
|
||||||
i += varintPut(buf[i:], uint64(paddingLen))
|
|
||||||
copy(buf[i:], padding)
|
|
||||||
_, err := w.Write(buf)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// UDPResponse format:
|
|
||||||
// Status (byte, 0=ok, 1=error)
|
|
||||||
// Session ID (uint32 BE)
|
|
||||||
// Message length (QUIC varint)
|
|
||||||
// Message (bytes)
|
|
||||||
// Padding length (QUIC varint)
|
|
||||||
// Padding (bytes)
|
|
||||||
|
|
||||||
func ReadUDPResponse(r io.Reader) (bool, uint32, string, error) {
|
|
||||||
var status [1]byte
|
|
||||||
if _, err := io.ReadFull(r, status[:]); err != nil {
|
|
||||||
return false, 0, "", err
|
|
||||||
}
|
|
||||||
var sessionID uint32
|
|
||||||
if err := binary.Read(r, binary.BigEndian, &sessionID); err != nil {
|
|
||||||
return false, 0, "", err
|
|
||||||
}
|
|
||||||
bReader := quicvarint.NewReader(r)
|
|
||||||
msgLen, err := quicvarint.Read(bReader)
|
|
||||||
if err != nil {
|
|
||||||
return false, 0, "", err
|
|
||||||
}
|
|
||||||
if msgLen > MaxMessageLength {
|
|
||||||
return false, 0, "", errors.ProtocolError{Message: "invalid message length"}
|
|
||||||
}
|
|
||||||
var msgBuf []byte
|
|
||||||
// No message is fine
|
|
||||||
if msgLen > 0 {
|
|
||||||
msgBuf = make([]byte, msgLen)
|
|
||||||
_, err = io.ReadFull(r, msgBuf)
|
|
||||||
if err != nil {
|
|
||||||
return false, 0, "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
paddingLen, err := quicvarint.Read(bReader)
|
|
||||||
if err != nil {
|
|
||||||
return false, 0, "", err
|
|
||||||
}
|
|
||||||
if paddingLen > MaxPaddingLength {
|
|
||||||
return false, 0, "", errors.ProtocolError{Message: "invalid padding length"}
|
|
||||||
}
|
|
||||||
if paddingLen > 0 {
|
|
||||||
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
|
||||||
if err != nil {
|
|
||||||
return false, 0, "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return status[0] == 0, sessionID, string(msgBuf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteUDPResponse(w io.Writer, ok bool, sessionID uint32, msg string) error {
|
|
||||||
padding := udpResponsePadding.String()
|
|
||||||
paddingLen := len(padding)
|
|
||||||
msgLen := len(msg)
|
|
||||||
sz := 1 + 4 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
|
|
||||||
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
|
||||||
buf := make([]byte, sz)
|
|
||||||
if ok {
|
|
||||||
buf[0] = 0
|
|
||||||
} else {
|
|
||||||
buf[0] = 1
|
|
||||||
}
|
|
||||||
binary.BigEndian.PutUint32(buf[1:], sessionID)
|
|
||||||
i := varintPut(buf[5:], uint64(msgLen))
|
|
||||||
i += copy(buf[5+i:], msg)
|
|
||||||
i += varintPut(buf[5+i:], uint64(paddingLen))
|
|
||||||
copy(buf[5+i:], padding)
|
|
||||||
_, err := w.Write(buf)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// UDPMessage format:
|
// UDPMessage format:
|
||||||
// Session ID (uint32 BE)
|
// Session ID (uint32 BE)
|
||||||
// Packet ID (uint16 BE)
|
// Packet ID (uint16 BE)
|
||||||
|
|
|
@ -315,179 +315,3 @@ func TestWriteTCPResponse(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadUDPRequest(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
data []byte
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "normal no padding",
|
|
||||||
data: []byte("\x00\x00"),
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "normal with padding",
|
|
||||||
data: []byte("\x02gg"),
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "incomplete 1",
|
|
||||||
data: []byte("\x0bhoho"),
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r := bytes.NewReader(tt.data)
|
|
||||||
if err := ReadUDPRequest(r); (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("ReadUDPRequest() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriteUDPRequest(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
wantW string // Just a prefix, we don't care about the padding
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "normal",
|
|
||||||
wantW: "\x44\x02",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
w := &bytes.Buffer{}
|
|
||||||
err := WriteUDPRequest(w)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("WriteUDPRequest() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
|
|
||||||
t.Errorf("WriteUDPRequest() gotW = %v, want %v", gotW, tt.wantW)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadUDPResponse(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
data []byte
|
|
||||||
want bool
|
|
||||||
want1 uint32
|
|
||||||
want2 string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "normal ok no padding",
|
|
||||||
data: []byte("\x00\x00\x00\x00\x33\x0bhello world\x00"),
|
|
||||||
want: true,
|
|
||||||
want1: 51,
|
|
||||||
want2: "hello world",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "normal error with padding",
|
|
||||||
data: []byte("\x01\x00\x00\x33\x33\x06stop!!\x05xxxxx"),
|
|
||||||
want: false,
|
|
||||||
want1: 13107,
|
|
||||||
want2: "stop!!",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "normal ok no message with padding",
|
|
||||||
data: []byte("\x00\x00\x00\x00\x33\x00\x05xxxxx"),
|
|
||||||
want: true,
|
|
||||||
want1: 51,
|
|
||||||
want2: "",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "incomplete 1",
|
|
||||||
data: []byte("\x00\x00\x06"),
|
|
||||||
want: false,
|
|
||||||
want1: 0,
|
|
||||||
want2: "",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "incomplete 2",
|
|
||||||
data: []byte("\x01\x00\x01\x02\x03\x05jesus\x05x"),
|
|
||||||
want: false,
|
|
||||||
want1: 0,
|
|
||||||
want2: "",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r := bytes.NewReader(tt.data)
|
|
||||||
got, got1, got2, err := ReadUDPResponse(r)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("ReadUDPResponse() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("ReadUDPResponse() got = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
if got1 != tt.want1 {
|
|
||||||
t.Errorf("ReadUDPResponse() got1 = %v, want %v", got1, tt.want1)
|
|
||||||
}
|
|
||||||
if got2 != tt.want2 {
|
|
||||||
t.Errorf("ReadUDPResponse() got2 = %v, want %v", got2, tt.want2)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriteUDPResponse(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
ok bool
|
|
||||||
sessionID uint32
|
|
||||||
msg string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantW string // Just a prefix, we don't care about the padding
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "normal ok",
|
|
||||||
args: args{ok: true, sessionID: 6, msg: "hello world"},
|
|
||||||
wantW: "\x00\x00\x00\x00\x06\x0bhello world",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "normal error",
|
|
||||||
args: args{ok: false, sessionID: 7, msg: "stop!!"},
|
|
||||||
wantW: "\x01\x00\x00\x00\x07\x06stop!!",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty",
|
|
||||||
args: args{ok: true, sessionID: 0, msg: ""},
|
|
||||||
wantW: "\x00\x00\x00\x00\x00\x00",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
w := &bytes.Buffer{}
|
|
||||||
err := WriteUDPResponse(w, tt.args.ok, tt.args.sessionID, tt.args.msg)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("WriteUDPResponse() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
|
|
||||||
t.Errorf("WriteUDPResponse() gotW = %v, want %v", gotW, tt.wantW)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
24
core/internal/utils/atomic.go
Normal file
24
core/internal/utils/atomic.go
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AtomicTime struct {
|
||||||
|
v atomic.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAtomicTime(t time.Time) *AtomicTime {
|
||||||
|
a := &AtomicTime{}
|
||||||
|
a.Set(t)
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *AtomicTime) Set(new time.Time) {
|
||||||
|
t.v.Store(new)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *AtomicTime) Get() time.Time {
|
||||||
|
return t.v.Load().(time.Time)
|
||||||
|
}
|
|
@ -103,9 +103,13 @@ type QUICConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Outbound provides the implementation of how the server should connect to remote servers.
|
// Outbound provides the implementation of how the server should connect to remote servers.
|
||||||
|
// Even though it's called DialUDP, outbound implementations do not necessarily have to
|
||||||
|
// return a "connected" UDP socket that can only send and receive from reqAddr. It's the
|
||||||
|
// address of the first packet to be sent.
|
||||||
|
// It's perfectly fine to have a "full-cone" implementation for UDP.
|
||||||
type Outbound interface {
|
type Outbound interface {
|
||||||
DialTCP(reqAddr string) (net.Conn, error)
|
DialTCP(reqAddr string) (net.Conn, error)
|
||||||
ListenUDP() (UDPConn, error)
|
DialUDP(reqAddr string) (UDPConn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UDPConn is like net.PacketConn, but uses string for addresses.
|
// UDPConn is like net.PacketConn, but uses string for addresses.
|
||||||
|
@ -125,7 +129,7 @@ func (o *defaultOutbound) DialTCP(reqAddr string) (net.Conn, error) {
|
||||||
return defaultOutboundDialer.Dial("tcp", reqAddr)
|
return defaultOutboundDialer.Dial("tcp", reqAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *defaultOutbound) ListenUDP() (UDPConn, error) {
|
func (o *defaultOutbound) DialUDP(reqAddr string) (UDPConn, error) {
|
||||||
conn, err := net.ListenUDP("udp", nil)
|
conn, err := net.ListenUDP("udp", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -171,7 +175,7 @@ type EventLogger interface {
|
||||||
Disconnect(addr net.Addr, id string, err error)
|
Disconnect(addr net.Addr, id string, err error)
|
||||||
TCPRequest(addr net.Addr, id, reqAddr string)
|
TCPRequest(addr net.Addr, id, reqAddr string)
|
||||||
TCPError(addr net.Addr, id, reqAddr string, err error)
|
TCPError(addr net.Addr, id, reqAddr string, err error)
|
||||||
UDPRequest(addr net.Addr, id string, sessionID uint32)
|
UDPRequest(addr net.Addr, id string, sessionID uint32, reqAddr string)
|
||||||
UDPError(addr net.Addr, id string, sessionID uint32, err error)
|
UDPError(addr net.Addr, id string, sessionID uint32, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,14 +3,11 @@ package server
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"math/rand"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/apernet/hysteria/core/internal/congestion"
|
"github.com/apernet/hysteria/core/internal/congestion"
|
||||||
"github.com/apernet/hysteria/core/internal/frag"
|
|
||||||
"github.com/apernet/hysteria/core/internal/protocol"
|
"github.com/apernet/hysteria/core/internal/protocol"
|
||||||
"github.com/apernet/hysteria/core/internal/utils"
|
"github.com/apernet/hysteria/core/internal/utils"
|
||||||
|
|
||||||
|
@ -21,6 +18,8 @@ import (
|
||||||
const (
|
const (
|
||||||
closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError
|
closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError
|
||||||
closeErrCodeTrafficLimitReached = 0x107 // HTTP3 ErrCodeExcessiveLoad
|
closeErrCodeTrafficLimitReached = 0x107 // HTTP3 ErrCodeExcessiveLoad
|
||||||
|
|
||||||
|
udpSessionIdleTimeout = 60 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server interface {
|
type Server interface {
|
||||||
|
@ -101,90 +100,21 @@ type h3sHandler struct {
|
||||||
authID string
|
authID string
|
||||||
|
|
||||||
udpOnce sync.Once
|
udpOnce sync.Once
|
||||||
udpSM udpSessionManager
|
udpSM *udpSessionManager // Only set after authentication
|
||||||
}
|
}
|
||||||
|
|
||||||
func newH3sHandler(config *Config, conn quic.Connection) *h3sHandler {
|
func newH3sHandler(config *Config, conn quic.Connection) *h3sHandler {
|
||||||
return &h3sHandler{
|
return &h3sHandler{
|
||||||
config: config,
|
config: config,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
udpSM: udpSessionManager{
|
|
||||||
listenFunc: config.Outbound.ListenUDP,
|
|
||||||
m: make(map[uint32]*udpSessionEntry),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpSessionEntry struct {
|
|
||||||
Conn UDPConn
|
|
||||||
D *frag.Defragger
|
|
||||||
Closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type udpSessionManager struct {
|
|
||||||
listenFunc func() (UDPConn, error)
|
|
||||||
mutex sync.RWMutex
|
|
||||||
m map[uint32]*udpSessionEntry
|
|
||||||
nextID uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add returns the session ID, the UDP connection and a function to close the UDP connection & delete the session.
|
|
||||||
func (m *udpSessionManager) Add() (uint32, UDPConn, func(), error) {
|
|
||||||
conn, err := m.listenFunc()
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
id := m.nextID
|
|
||||||
m.nextID++
|
|
||||||
entry := &udpSessionEntry{
|
|
||||||
Conn: conn,
|
|
||||||
D: &frag.Defragger{},
|
|
||||||
Closed: false,
|
|
||||||
}
|
|
||||||
m.m[id] = entry
|
|
||||||
|
|
||||||
return id, conn, func() {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
if entry.Closed {
|
|
||||||
// Already closed
|
|
||||||
return
|
|
||||||
}
|
|
||||||
entry.Closed = true
|
|
||||||
_ = conn.Close()
|
|
||||||
delete(m.m, id)
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Feed feeds a UDP message to the session manager.
|
|
||||||
// If the message itself is a complete message, or it's the last fragment of a message,
|
|
||||||
// it will be sent to the UDP connection.
|
|
||||||
// The function will then return the number of bytes sent and any error occurred.
|
|
||||||
func (m *udpSessionManager) Feed(msg *protocol.UDPMessage) (int, error) {
|
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
entry, ok := m.m[msg.SessionID]
|
|
||||||
if !ok {
|
|
||||||
// No such session, drop the message
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
dfMsg := entry.D.Feed(msg)
|
|
||||||
if dfMsg == nil {
|
|
||||||
// Not a complete message yet
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return entry.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath {
|
if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath {
|
||||||
if h.authenticated {
|
if h.authenticated {
|
||||||
// Already authenticated
|
// Already authenticated
|
||||||
protocol.AuthResponseDataToHeader(w.Header(), h.config.BandwidthConfig.MaxRx)
|
protocol.AuthResponseDataToHeader(w.Header(), !h.config.DisableUDP, h.config.BandwidthConfig.MaxRx)
|
||||||
w.WriteHeader(protocol.StatusAuthOK)
|
w.WriteHeader(protocol.StatusAuthOK)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -204,18 +134,23 @@ func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
h.conn.SetCongestionControl(congestion.NewBrutalSender(actualTx))
|
h.conn.SetCongestionControl(congestion.NewBrutalSender(actualTx))
|
||||||
}
|
}
|
||||||
// Auth OK, send response
|
// Auth OK, send response
|
||||||
protocol.AuthResponseDataToHeader(w.Header(), h.config.BandwidthConfig.MaxRx)
|
protocol.AuthResponseDataToHeader(w.Header(), !h.config.DisableUDP, h.config.BandwidthConfig.MaxRx)
|
||||||
w.WriteHeader(protocol.StatusAuthOK)
|
w.WriteHeader(protocol.StatusAuthOK)
|
||||||
// Call event logger
|
// Call event logger
|
||||||
if h.config.EventLogger != nil {
|
if h.config.EventLogger != nil {
|
||||||
h.config.EventLogger.Connect(h.conn.RemoteAddr(), id, actualTx)
|
h.config.EventLogger.Connect(h.conn.RemoteAddr(), id, actualTx)
|
||||||
}
|
}
|
||||||
// Start UDP loop if UDP is not disabled
|
// Initialize UDP session manager (if UDP is enabled)
|
||||||
// We use sync.Once to make sure that only one goroutine is started,
|
// We use sync.Once to make sure that only one goroutine is started,
|
||||||
// as ServeHTTP may be called by multiple goroutines simultaneously
|
// as ServeHTTP may be called by multiple goroutines simultaneously
|
||||||
if !h.config.DisableUDP {
|
if !h.config.DisableUDP {
|
||||||
h.udpOnce.Do(func() {
|
h.udpOnce.Do(func() {
|
||||||
go h.udpLoop()
|
sm := newUDPSessionManager(
|
||||||
|
&udpsmIO{h.conn, id, h.config.TrafficLogger, h.config.Outbound},
|
||||||
|
&udpsmEventLogger{h.conn, id, h.config.EventLogger},
|
||||||
|
udpSessionIdleTimeout)
|
||||||
|
h.udpSM = sm
|
||||||
|
go sm.Run()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -240,9 +175,6 @@ func (h *h3sHandler) ProxyStreamHijacker(ft http3.FrameType, conn quic.Connectio
|
||||||
case protocol.FrameTypeTCPRequest:
|
case protocol.FrameTypeTCPRequest:
|
||||||
go h.handleTCPRequest(stream)
|
go h.handleTCPRequest(stream)
|
||||||
return true, nil
|
return true, nil
|
||||||
case protocol.FrameTypeUDPRequest:
|
|
||||||
go h.handleUDPRequest(stream)
|
|
||||||
return true, nil
|
|
||||||
default:
|
default:
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
@ -290,125 +222,6 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *h3sHandler) handleUDPRequest(stream quic.Stream) {
|
|
||||||
if h.config.DisableUDP {
|
|
||||||
// UDP is disabled, send error message and close the stream
|
|
||||||
_ = protocol.WriteUDPResponse(stream, false, 0, "UDP is disabled on this server")
|
|
||||||
_ = stream.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Read request
|
|
||||||
err := protocol.ReadUDPRequest(stream)
|
|
||||||
if err != nil {
|
|
||||||
_ = stream.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Add to session manager
|
|
||||||
sessionID, conn, connCloseFunc, err := h.udpSM.Add()
|
|
||||||
if err != nil {
|
|
||||||
_ = protocol.WriteUDPResponse(stream, false, 0, err.Error())
|
|
||||||
_ = stream.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Send response
|
|
||||||
_ = protocol.WriteUDPResponse(stream, true, sessionID, "")
|
|
||||||
// Call event logger
|
|
||||||
if h.config.EventLogger != nil {
|
|
||||||
h.config.EventLogger.UDPRequest(h.conn.RemoteAddr(), h.authID, sessionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// client <- remote direction
|
|
||||||
go func() {
|
|
||||||
udpBuf := make([]byte, protocol.MaxUDPSize)
|
|
||||||
msgBuf := make([]byte, protocol.MaxUDPSize)
|
|
||||||
for {
|
|
||||||
udpN, rAddr, err := conn.ReadFrom(udpBuf)
|
|
||||||
if err != nil {
|
|
||||||
connCloseFunc()
|
|
||||||
_ = stream.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if h.config.TrafficLogger != nil {
|
|
||||||
ok := h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN))
|
|
||||||
if !ok {
|
|
||||||
// TrafficLogger requested to disconnect the client
|
|
||||||
_ = h.conn.CloseWithError(closeErrCodeTrafficLimitReached, "")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Try no frag first
|
|
||||||
msg := protocol.UDPMessage{
|
|
||||||
SessionID: sessionID,
|
|
||||||
PacketID: 0,
|
|
||||||
FragID: 0,
|
|
||||||
FragCount: 1,
|
|
||||||
Addr: rAddr,
|
|
||||||
Data: udpBuf[:udpN],
|
|
||||||
}
|
|
||||||
msgN := msg.Serialize(msgBuf)
|
|
||||||
if msgN < 0 {
|
|
||||||
// Message even larger than MaxUDPSize, drop it
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
sendErr := h.conn.SendMessage(msgBuf[:msgN])
|
|
||||||
var errTooLarge quic.ErrMessageTooLarge
|
|
||||||
if errors.As(sendErr, &errTooLarge) {
|
|
||||||
// Message too large, try fragmentation
|
|
||||||
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
|
|
||||||
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
|
|
||||||
for _, fMsg := range fMsgs {
|
|
||||||
msgN = fMsg.Serialize(msgBuf)
|
|
||||||
_ = h.conn.SendMessage(msgBuf[:msgN])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Hold (drain) the stream until the client closes it.
|
|
||||||
// Closing the stream is the signal to stop the UDP session.
|
|
||||||
_, err = io.Copy(io.Discard, stream)
|
|
||||||
// Call event logger
|
|
||||||
if h.config.EventLogger != nil {
|
|
||||||
h.config.EventLogger.UDPError(h.conn.RemoteAddr(), h.authID, sessionID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cleanup
|
|
||||||
connCloseFunc()
|
|
||||||
_ = stream.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *h3sHandler) udpLoop() {
|
|
||||||
for {
|
|
||||||
msg, err := h.conn.ReceiveMessage()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ok := h.handleUDPMessage(msg)
|
|
||||||
if !ok {
|
|
||||||
// TrafficLogger requested to disconnect the client
|
|
||||||
_ = h.conn.CloseWithError(closeErrCodeTrafficLimitReached, "")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// client -> remote direction
|
|
||||||
// Returns a bool indicating whether the receiving loop should continue
|
|
||||||
func (h *h3sHandler) handleUDPMessage(msg []byte) (ok bool) {
|
|
||||||
udpMsg, err := protocol.ParseUDPMessage(msg)
|
|
||||||
if err != nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if h.config.TrafficLogger != nil {
|
|
||||||
ok := h.config.TrafficLogger.Log(h.authID, uint64(len(udpMsg.Data)), 0)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_, _ = h.udpSM.Feed(udpMsg)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
if h.config.MasqHandler != nil {
|
if h.config.MasqHandler != nil {
|
||||||
h.config.MasqHandler.ServeHTTP(w, r)
|
h.config.MasqHandler.ServeHTTP(w, r)
|
||||||
|
@ -417,3 +230,74 @@ func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// udpsmIO is the IO implementation for udpSessionManager with TrafficLogger support
|
||||||
|
type udpsmIO struct {
|
||||||
|
Conn quic.Connection
|
||||||
|
AuthID string
|
||||||
|
TrafficLogger TrafficLogger
|
||||||
|
Outbound Outbound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (io *udpsmIO) ReceiveMessage() (*protocol.UDPMessage, error) {
|
||||||
|
for {
|
||||||
|
msg, err := io.Conn.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
// Connection error, this will stop the session manager
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
udpMsg, err := protocol.ParseUDPMessage(msg)
|
||||||
|
if err != nil {
|
||||||
|
// Invalid message, this is fine - just wait for the next
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if io.TrafficLogger != nil {
|
||||||
|
ok := io.TrafficLogger.Log(io.AuthID, uint64(len(udpMsg.Data)), 0)
|
||||||
|
if !ok {
|
||||||
|
// TrafficLogger requested to disconnect the client
|
||||||
|
_ = io.Conn.CloseWithError(closeErrCodeTrafficLimitReached, "")
|
||||||
|
return nil, errDisconnect
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return udpMsg, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (io *udpsmIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
|
||||||
|
if io.TrafficLogger != nil {
|
||||||
|
ok := io.TrafficLogger.Log(io.AuthID, 0, uint64(len(msg.Data)))
|
||||||
|
if !ok {
|
||||||
|
// TrafficLogger requested to disconnect the client
|
||||||
|
_ = io.Conn.CloseWithError(closeErrCodeTrafficLimitReached, "")
|
||||||
|
return errDisconnect
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msgN := msg.Serialize(buf)
|
||||||
|
if msgN < 0 {
|
||||||
|
// Message larger than buffer, silent drop
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return io.Conn.SendMessage(buf[:msgN])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (io *udpsmIO) DialUDP(reqAddr string) (UDPConn, error) {
|
||||||
|
return io.Outbound.DialUDP(reqAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpsmEventLogger struct {
|
||||||
|
Conn quic.Connection
|
||||||
|
AuthID string
|
||||||
|
EventLogger EventLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *udpsmEventLogger) New(sessionID uint32, reqAddr string) {
|
||||||
|
if l.EventLogger != nil {
|
||||||
|
l.EventLogger.UDPRequest(l.Conn.RemoteAddr(), l.AuthID, sessionID, reqAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *udpsmEventLogger) Closed(sessionID uint32, err error) {
|
||||||
|
if l.EventLogger != nil {
|
||||||
|
l.EventLogger.UDPError(l.Conn.RemoteAddr(), l.AuthID, sessionID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
218
core/server/udp.go
Normal file
218
core/server/udp.go
Normal file
|
@ -0,0 +1,218 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/quic-go/quic-go"
|
||||||
|
|
||||||
|
"github.com/apernet/hysteria/core/internal/frag"
|
||||||
|
"github.com/apernet/hysteria/core/internal/protocol"
|
||||||
|
"github.com/apernet/hysteria/core/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
idleCleanupInterval = 1 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type udpSessionManagerIO interface {
|
||||||
|
ReceiveMessage() (*protocol.UDPMessage, error)
|
||||||
|
SendMessage([]byte, *protocol.UDPMessage) error
|
||||||
|
DialUDP(reqAddr string) (UDPConn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpSessionManagerEventLogger interface {
|
||||||
|
New(sessionID uint32, reqAddr string)
|
||||||
|
Closed(sessionID uint32, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpSessionEntry struct {
|
||||||
|
ID uint32
|
||||||
|
Conn UDPConn
|
||||||
|
D *frag.Defragger
|
||||||
|
Last *utils.AtomicTime
|
||||||
|
Closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Feed feeds a UDP message to the session.
|
||||||
|
// If the message itself is a complete message, or it completes a fragmented message,
|
||||||
|
// the message is written to the session's UDP connection, and the number of bytes
|
||||||
|
// written is returned.
|
||||||
|
// Otherwise, 0 and nil are returned.
|
||||||
|
func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) {
|
||||||
|
e.Last.Set(time.Now())
|
||||||
|
dfMsg := e.D.Feed(msg)
|
||||||
|
if dfMsg == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceiveLoop receives incoming UDP packets, packs them into UDP messages,
|
||||||
|
// and sends using the provided io.
|
||||||
|
// Exit and returns error when either the underlying UDP connection returns
|
||||||
|
// error (e.g. closed), or the provided io returns error when sending.
|
||||||
|
func (e *udpSessionEntry) ReceiveLoop(io udpSessionManagerIO) error {
|
||||||
|
udpBuf := make([]byte, protocol.MaxUDPSize)
|
||||||
|
msgBuf := make([]byte, protocol.MaxUDPSize)
|
||||||
|
for {
|
||||||
|
udpN, rAddr, err := e.Conn.ReadFrom(udpBuf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
e.Last.Set(time.Now())
|
||||||
|
|
||||||
|
msg := &protocol.UDPMessage{
|
||||||
|
SessionID: e.ID,
|
||||||
|
PacketID: 0,
|
||||||
|
FragID: 0,
|
||||||
|
FragCount: 1,
|
||||||
|
Addr: rAddr,
|
||||||
|
Data: udpBuf[:udpN],
|
||||||
|
}
|
||||||
|
err = sendMessageAutoFrag(io, msgBuf, msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendMessageAutoFrag tries to send a UDP message as a whole first,
|
||||||
|
// but if it fails due to quic.ErrMessageTooLarge, it tries again by
|
||||||
|
// fragmenting the message.
|
||||||
|
func sendMessageAutoFrag(io udpSessionManagerIO, buf []byte, msg *protocol.UDPMessage) error {
|
||||||
|
err := io.SendMessage(buf, msg)
|
||||||
|
var errTooLarge quic.ErrMessageTooLarge
|
||||||
|
if errors.As(err, &errTooLarge) {
|
||||||
|
// Message too large, try fragmentation
|
||||||
|
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
|
||||||
|
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
|
||||||
|
for _, fMsg := range fMsgs {
|
||||||
|
err := io.SendMessage(buf, &fMsg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpSessionManager manages the lifecycle of UDP sessions.
|
||||||
|
// Each UDP session is identified by a SessionID, and corresponds to a UDP connection.
|
||||||
|
// A UDP session is created when a UDP message with a new SessionID is received.
|
||||||
|
// Similar to standard NAT, a UDP session is destroyed when no UDP message is received
|
||||||
|
// for a certain period of time (specified by idleTimeout).
|
||||||
|
type udpSessionManager struct {
|
||||||
|
io udpSessionManagerIO
|
||||||
|
eventLogger udpSessionManagerEventLogger
|
||||||
|
idleTimeout time.Duration
|
||||||
|
|
||||||
|
mutex sync.Mutex
|
||||||
|
m map[uint32]*udpSessionEntry
|
||||||
|
nextID uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPSessionManager(
|
||||||
|
io udpSessionManagerIO,
|
||||||
|
eventLogger udpSessionManagerEventLogger,
|
||||||
|
idleTimeout time.Duration,
|
||||||
|
) *udpSessionManager {
|
||||||
|
return &udpSessionManager{
|
||||||
|
io: io,
|
||||||
|
eventLogger: eventLogger,
|
||||||
|
idleTimeout: idleTimeout,
|
||||||
|
m: make(map[uint32]*udpSessionEntry),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run runs the session manager main loop.
|
||||||
|
// Exit and returns error when the underlying io returns error (e.g. closed).
|
||||||
|
func (m *udpSessionManager) Run() error {
|
||||||
|
stopCh := make(chan struct{})
|
||||||
|
go m.idleCleanupLoop(stopCh)
|
||||||
|
defer close(stopCh)
|
||||||
|
defer m.cleanup(false)
|
||||||
|
|
||||||
|
for {
|
||||||
|
msg, err := m.io.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.feed(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpSessionManager) idleCleanupLoop(stopCh <-chan struct{}) {
|
||||||
|
ticker := time.NewTicker(idleCleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
m.cleanup(true)
|
||||||
|
case <-stopCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpSessionManager) cleanup(idleOnly bool) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for sessionID, entry := range m.m {
|
||||||
|
if !idleOnly || now.Sub(entry.Last.Get()) > m.idleTimeout {
|
||||||
|
entry.Closed = true
|
||||||
|
_ = entry.Conn.Close()
|
||||||
|
m.eventLogger.Closed(sessionID, nil)
|
||||||
|
delete(m.m, sessionID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
|
||||||
|
entry := m.m[msg.SessionID]
|
||||||
|
if entry == nil {
|
||||||
|
// New session
|
||||||
|
m.eventLogger.New(msg.SessionID, msg.Addr)
|
||||||
|
conn, err := m.io.DialUDP(msg.Addr)
|
||||||
|
if err != nil {
|
||||||
|
m.mutex.Unlock()
|
||||||
|
m.eventLogger.Closed(msg.SessionID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entry = &udpSessionEntry{
|
||||||
|
ID: msg.SessionID,
|
||||||
|
Conn: conn,
|
||||||
|
D: &frag.Defragger{},
|
||||||
|
Last: utils.NewAtomicTime(time.Now()),
|
||||||
|
}
|
||||||
|
// Start the receive loop for this session
|
||||||
|
go func() {
|
||||||
|
err := entry.ReceiveLoop(m.io)
|
||||||
|
// Receive loop stopped, remove the session
|
||||||
|
m.mutex.Lock()
|
||||||
|
if !entry.Closed {
|
||||||
|
entry.Closed = true
|
||||||
|
_ = entry.Conn.Close()
|
||||||
|
m.eventLogger.Closed(entry.ID, err)
|
||||||
|
delete(m.m, entry.ID)
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}()
|
||||||
|
m.m[msg.SessionID] = entry
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
// Feed the message to the session
|
||||||
|
// Feed (send) errors are ignored for now,
|
||||||
|
// as some are temporary (e.g. invalid address)
|
||||||
|
_, _ = entry.Feed(msg)
|
||||||
|
}
|
191
core/server/udp_test.go
Normal file
191
core/server/udp_test.go
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/apernet/hysteria/core/internal/protocol"
|
||||||
|
"go.uber.org/goleak"
|
||||||
|
)
|
||||||
|
|
||||||
|
type echoUDPConnPkt struct {
|
||||||
|
Data []byte
|
||||||
|
Addr string
|
||||||
|
Close bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type echoUDPConn struct {
|
||||||
|
PktCh chan echoUDPConnPkt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *echoUDPConn) ReadFrom(b []byte) (int, string, error) {
|
||||||
|
pkt := <-c.PktCh
|
||||||
|
if pkt.Close {
|
||||||
|
return 0, "", errors.New("closed")
|
||||||
|
}
|
||||||
|
n := copy(b, pkt.Data)
|
||||||
|
return n, pkt.Addr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *echoUDPConn) WriteTo(b []byte, addr string) (int, error) {
|
||||||
|
nb := make([]byte, len(b))
|
||||||
|
copy(nb, b)
|
||||||
|
c.PktCh <- echoUDPConnPkt{
|
||||||
|
Data: nb,
|
||||||
|
Addr: addr,
|
||||||
|
}
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *echoUDPConn) Close() error {
|
||||||
|
c.PktCh <- echoUDPConnPkt{
|
||||||
|
Close: true,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpsmMockIO struct {
|
||||||
|
ReceiveCh <-chan *protocol.UDPMessage
|
||||||
|
SendCh chan<- *protocol.UDPMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (io *udpsmMockIO) ReceiveMessage() (*protocol.UDPMessage, error) {
|
||||||
|
m := <-io.ReceiveCh
|
||||||
|
if m == nil {
|
||||||
|
return nil, errors.New("closed")
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (io *udpsmMockIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
|
||||||
|
nMsg := *msg
|
||||||
|
nMsg.Data = make([]byte, len(msg.Data))
|
||||||
|
copy(nMsg.Data, msg.Data)
|
||||||
|
io.SendCh <- &nMsg
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (io *udpsmMockIO) DialUDP(reqAddr string) (UDPConn, error) {
|
||||||
|
return &echoUDPConn{
|
||||||
|
PktCh: make(chan echoUDPConnPkt, 10),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpsmMockEventNew struct {
|
||||||
|
SessionID uint32
|
||||||
|
ReqAddr string
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpsmMockEventClosed struct {
|
||||||
|
SessionID uint32
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpsmMockEventLogger struct {
|
||||||
|
NewCh chan<- udpsmMockEventNew
|
||||||
|
ClosedCh chan<- udpsmMockEventClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *udpsmMockEventLogger) New(sessionID uint32, reqAddr string) {
|
||||||
|
l.NewCh <- udpsmMockEventNew{sessionID, reqAddr}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *udpsmMockEventLogger) Closed(sessionID uint32, err error) {
|
||||||
|
l.ClosedCh <- udpsmMockEventClosed{sessionID, err}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPSessionManager(t *testing.T) {
|
||||||
|
msgReceiveCh := make(chan *protocol.UDPMessage, 10)
|
||||||
|
msgSendCh := make(chan *protocol.UDPMessage, 10)
|
||||||
|
io := &udpsmMockIO{
|
||||||
|
ReceiveCh: msgReceiveCh,
|
||||||
|
SendCh: msgSendCh,
|
||||||
|
}
|
||||||
|
eventNewCh := make(chan udpsmMockEventNew, 10)
|
||||||
|
eventClosedCh := make(chan udpsmMockEventClosed, 10)
|
||||||
|
eventLogger := &udpsmMockEventLogger{
|
||||||
|
NewCh: eventNewCh,
|
||||||
|
ClosedCh: eventClosedCh,
|
||||||
|
}
|
||||||
|
sm := newUDPSessionManager(io, eventLogger, 2*time.Second)
|
||||||
|
go sm.Run()
|
||||||
|
|
||||||
|
ms := []*protocol.UDPMessage{
|
||||||
|
{
|
||||||
|
SessionID: 1234,
|
||||||
|
PacketID: 0,
|
||||||
|
FragID: 0,
|
||||||
|
FragCount: 1,
|
||||||
|
Addr: "example.com:5353",
|
||||||
|
Data: []byte("hello"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: 5678,
|
||||||
|
PacketID: 0,
|
||||||
|
FragID: 0,
|
||||||
|
FragCount: 1,
|
||||||
|
Addr: "example.com:9999",
|
||||||
|
Data: []byte("goodbye"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: 1234,
|
||||||
|
PacketID: 0,
|
||||||
|
FragID: 0,
|
||||||
|
FragCount: 1,
|
||||||
|
Addr: "example.com:5353",
|
||||||
|
Data: []byte(" world"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SessionID: 5678,
|
||||||
|
PacketID: 0,
|
||||||
|
FragID: 0,
|
||||||
|
FragCount: 1,
|
||||||
|
Addr: "example.com:9999",
|
||||||
|
Data: []byte(" girl"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, m := range ms {
|
||||||
|
msgReceiveCh <- m
|
||||||
|
}
|
||||||
|
// New event order should be consistent
|
||||||
|
newEvent := <-eventNewCh
|
||||||
|
if newEvent.SessionID != 1234 || newEvent.ReqAddr != "example.com:5353" {
|
||||||
|
t.Error("unexpected new event value")
|
||||||
|
}
|
||||||
|
newEvent = <-eventNewCh
|
||||||
|
if newEvent.SessionID != 5678 || newEvent.ReqAddr != "example.com:9999" {
|
||||||
|
t.Error("unexpected new event value")
|
||||||
|
}
|
||||||
|
// Message order is not guaranteed
|
||||||
|
msgMap := make(map[string]bool)
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
msg := <-msgSendCh
|
||||||
|
msgMap[fmt.Sprintf("%d:%s:%s", msg.SessionID, msg.Addr, string(msg.Data))] = true
|
||||||
|
}
|
||||||
|
if !(msgMap["1234:example.com:5353:hello"] &&
|
||||||
|
msgMap["5678:example.com:9999:goodbye"] &&
|
||||||
|
msgMap["1234:example.com:5353: world"] &&
|
||||||
|
msgMap["5678:example.com:9999: girl"]) {
|
||||||
|
t.Error("unexpected message value")
|
||||||
|
}
|
||||||
|
// Timeout check
|
||||||
|
startTime := time.Now()
|
||||||
|
closedMap := make(map[uint32]bool)
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
closedEvent := <-eventClosedCh
|
||||||
|
closedMap[closedEvent.SessionID] = true
|
||||||
|
}
|
||||||
|
if !(closedMap[1234] && closedMap[5678]) {
|
||||||
|
t.Error("unexpected closed event value", closedMap)
|
||||||
|
}
|
||||||
|
if time.Since(startTime) < 2*time.Second || time.Since(startTime) > 4*time.Second {
|
||||||
|
t.Error("unexpected timeout duration")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Goroutine leak check
|
||||||
|
msgReceiveCh <- nil
|
||||||
|
time.Sleep(1 * time.Second) // Wait for internal routines to exit
|
||||||
|
goleak.VerifyNone(t)
|
||||||
|
}
|
|
@ -75,7 +75,7 @@ func (a *PluggableOutboundAdapter) DialTCP(reqAddr string) (net.Conn, error) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *PluggableOutboundAdapter) ListenUDP() (server.UDPConn, error) {
|
func (a *PluggableOutboundAdapter) DialUDP() (server.UDPConn, error) {
|
||||||
conn, err := a.PluggableOutbound.ListenUDP()
|
conn, err := a.PluggableOutbound.ListenUDP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -68,10 +68,10 @@ func TestPluggableOutboundAdapter(t *testing.T) {
|
||||||
if err != errWrongAddr {
|
if err != errWrongAddr {
|
||||||
t.Fatal("DialTCP with wrong addr should fail, got", err)
|
t.Fatal("DialTCP with wrong addr should fail, got", err)
|
||||||
}
|
}
|
||||||
// ListenUDP
|
// DialUDP
|
||||||
uConn, err := adapter.ListenUDP()
|
uConn, err := adapter.DialUDP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("ListenUDP failed", err)
|
t.Fatal("DialUDP failed", err)
|
||||||
}
|
}
|
||||||
// ReadFrom
|
// ReadFrom
|
||||||
b := make([]byte, 10)
|
b := make([]byte, 10)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue