mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 20:47:38 +03:00
feat(wip): udp rework server side
This commit is contained in:
parent
6245f83262
commit
a2fbcc6507
17 changed files with 554 additions and 513 deletions
|
@ -400,7 +400,7 @@ func (c *udpConn) Receive() ([]byte, string, error) {
|
|||
// Send is not thread-safe as it uses a shared send buffer for now.
|
||||
func (c *udpConn) Send(data []byte, addr string) error {
|
||||
// Try no frag first
|
||||
msg := protocol.UDPMessage{
|
||||
msg := &protocol.UDPMessage{
|
||||
SessionID: c.SessionID,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
|
|
|
@ -4,6 +4,7 @@ go 1.20
|
|||
|
||||
require (
|
||||
github.com/quic-go/quic-go v0.0.0-00010101000000-000000000000
|
||||
go.uber.org/goleak v1.2.1
|
||||
golang.org/x/time v0.3.0
|
||||
)
|
||||
|
||||
|
|
|
@ -38,6 +38,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
|
|||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
|
||||
go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8=
|
||||
|
|
|
@ -4,9 +4,9 @@ import (
|
|||
"github.com/apernet/hysteria/core/internal/protocol"
|
||||
)
|
||||
|
||||
func FragUDPMessage(m protocol.UDPMessage, maxSize int) []protocol.UDPMessage {
|
||||
func FragUDPMessage(m *protocol.UDPMessage, maxSize int) []protocol.UDPMessage {
|
||||
if m.Size() <= maxSize {
|
||||
return []protocol.UDPMessage{m}
|
||||
return []protocol.UDPMessage{*m}
|
||||
}
|
||||
fullPayload := m.Data
|
||||
maxPayloadSize := maxSize - m.HeaderSize()
|
||||
|
@ -19,7 +19,7 @@ func FragUDPMessage(m protocol.UDPMessage, maxSize int) []protocol.UDPMessage {
|
|||
if payloadSize > maxPayloadSize {
|
||||
payloadSize = maxPayloadSize
|
||||
}
|
||||
frag := m
|
||||
frag := *m
|
||||
frag.FragID = fragID
|
||||
frag.FragCount = fragCount
|
||||
frag.Data = fullPayload[off : off+payloadSize]
|
||||
|
|
|
@ -124,7 +124,7 @@ func TestFragUDPMessage(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := FragUDPMessage(tt.args.m, tt.args.maxSize); !reflect.DeepEqual(got, tt.want) {
|
||||
if got := FragUDPMessage(&tt.args.m, tt.args.maxSize); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("FragUDPMessage() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
|
|
|
@ -36,7 +36,7 @@ func TestClientNoServer(t *testing.T) {
|
|||
// Try UDP
|
||||
_, err = c.ListenUDP()
|
||||
if !errors.As(err, &cErr) {
|
||||
t.Fatal("expected connect error from ListenUDP")
|
||||
t.Fatal("expected connect error from DialUDP")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -86,7 +86,7 @@ func TestClientServerBadAuth(t *testing.T) {
|
|||
// Try UDP
|
||||
_, err = c.ListenUDP()
|
||||
if !errors.As(err, &aErr) {
|
||||
t.Fatal("expected auth error from ListenUDP")
|
||||
t.Fatal("expected auth error from DialUDP")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -9,31 +9,34 @@ const (
|
|||
URLHost = "hysteria"
|
||||
URLPath = "/auth"
|
||||
|
||||
HeaderAuth = "Hysteria-Auth"
|
||||
HeaderCCRX = "Hysteria-CC-RX"
|
||||
HeaderPadding = "Hysteria-Padding"
|
||||
RequestHeaderAuth = "Hysteria-Auth"
|
||||
ResponseHeaderUDPEnabled = "Hysteria-UDP"
|
||||
CommonHeaderCCRX = "Hysteria-CC-RX"
|
||||
CommonHeaderPadding = "Hysteria-Padding"
|
||||
|
||||
StatusAuthOK = 233
|
||||
)
|
||||
|
||||
func AuthRequestDataFromHeader(h http.Header) (auth string, rx uint64) {
|
||||
auth = h.Get(HeaderAuth)
|
||||
rx, _ = strconv.ParseUint(h.Get(HeaderCCRX), 10, 64)
|
||||
auth = h.Get(RequestHeaderAuth)
|
||||
rx, _ = strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64)
|
||||
return
|
||||
}
|
||||
|
||||
func AuthRequestDataToHeader(h http.Header, auth string, rx uint64) {
|
||||
h.Set(HeaderAuth, auth)
|
||||
h.Set(HeaderCCRX, strconv.FormatUint(rx, 10))
|
||||
h.Set(HeaderPadding, authRequestPadding.String())
|
||||
h.Set(RequestHeaderAuth, auth)
|
||||
h.Set(CommonHeaderCCRX, strconv.FormatUint(rx, 10))
|
||||
h.Set(CommonHeaderPadding, authRequestPadding.String())
|
||||
}
|
||||
|
||||
func AuthResponseDataFromHeader(h http.Header) (rx uint64) {
|
||||
rx, _ = strconv.ParseUint(h.Get(HeaderCCRX), 10, 64)
|
||||
func AuthResponseDataFromHeader(h http.Header) (udp bool, rx uint64) {
|
||||
udp, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled))
|
||||
rx, _ = strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64)
|
||||
return
|
||||
}
|
||||
|
||||
func AuthResponseDataToHeader(h http.Header, rx uint64) {
|
||||
h.Set(HeaderCCRX, strconv.FormatUint(rx, 10))
|
||||
h.Set(HeaderPadding, authResponsePadding.String())
|
||||
func AuthResponseDataToHeader(h http.Header, udp bool, rx uint64) {
|
||||
h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(udp))
|
||||
h.Set(CommonHeaderCCRX, strconv.FormatUint(rx, 10))
|
||||
h.Set(CommonHeaderPadding, authResponsePadding.String())
|
||||
}
|
||||
|
|
|
@ -28,6 +28,4 @@ var (
|
|||
authResponsePadding = padding{Min: 256, Max: 2048}
|
||||
tcpRequestPadding = padding{Min: 64, Max: 512}
|
||||
tcpResponsePadding = padding{Min: 128, Max: 1024}
|
||||
udpRequestPadding = padding{Min: 64, Max: 512}
|
||||
udpResponsePadding = padding{Min: 128, Max: 1024}
|
||||
)
|
||||
|
|
|
@ -13,7 +13,6 @@ import (
|
|||
|
||||
const (
|
||||
FrameTypeTCPRequest = 0x401
|
||||
FrameTypeUDPRequest = 0x402
|
||||
|
||||
// Max length values are for preventing DoS attacks
|
||||
|
||||
|
@ -148,113 +147,6 @@ func WriteTCPResponse(w io.Writer, ok bool, msg string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// UDPRequest format:
|
||||
// 0x402 (QUIC varint)
|
||||
// Padding length (QUIC varint)
|
||||
// Padding (bytes)
|
||||
|
||||
func ReadUDPRequest(r io.Reader) error {
|
||||
bReader := quicvarint.NewReader(r)
|
||||
paddingLen, err := quicvarint.Read(bReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if paddingLen > MaxPaddingLength {
|
||||
return errors.ProtocolError{Message: "invalid padding length"}
|
||||
}
|
||||
if paddingLen > 0 {
|
||||
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func WriteUDPRequest(w io.Writer) error {
|
||||
padding := udpRequestPadding.String()
|
||||
paddingLen := len(padding)
|
||||
sz := int(quicvarint.Len(FrameTypeUDPRequest)) +
|
||||
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
||||
buf := make([]byte, sz)
|
||||
i := varintPut(buf, FrameTypeUDPRequest)
|
||||
i += varintPut(buf[i:], uint64(paddingLen))
|
||||
copy(buf[i:], padding)
|
||||
_, err := w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// UDPResponse format:
|
||||
// Status (byte, 0=ok, 1=error)
|
||||
// Session ID (uint32 BE)
|
||||
// Message length (QUIC varint)
|
||||
// Message (bytes)
|
||||
// Padding length (QUIC varint)
|
||||
// Padding (bytes)
|
||||
|
||||
func ReadUDPResponse(r io.Reader) (bool, uint32, string, error) {
|
||||
var status [1]byte
|
||||
if _, err := io.ReadFull(r, status[:]); err != nil {
|
||||
return false, 0, "", err
|
||||
}
|
||||
var sessionID uint32
|
||||
if err := binary.Read(r, binary.BigEndian, &sessionID); err != nil {
|
||||
return false, 0, "", err
|
||||
}
|
||||
bReader := quicvarint.NewReader(r)
|
||||
msgLen, err := quicvarint.Read(bReader)
|
||||
if err != nil {
|
||||
return false, 0, "", err
|
||||
}
|
||||
if msgLen > MaxMessageLength {
|
||||
return false, 0, "", errors.ProtocolError{Message: "invalid message length"}
|
||||
}
|
||||
var msgBuf []byte
|
||||
// No message is fine
|
||||
if msgLen > 0 {
|
||||
msgBuf = make([]byte, msgLen)
|
||||
_, err = io.ReadFull(r, msgBuf)
|
||||
if err != nil {
|
||||
return false, 0, "", err
|
||||
}
|
||||
}
|
||||
paddingLen, err := quicvarint.Read(bReader)
|
||||
if err != nil {
|
||||
return false, 0, "", err
|
||||
}
|
||||
if paddingLen > MaxPaddingLength {
|
||||
return false, 0, "", errors.ProtocolError{Message: "invalid padding length"}
|
||||
}
|
||||
if paddingLen > 0 {
|
||||
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
||||
if err != nil {
|
||||
return false, 0, "", err
|
||||
}
|
||||
}
|
||||
return status[0] == 0, sessionID, string(msgBuf), nil
|
||||
}
|
||||
|
||||
func WriteUDPResponse(w io.Writer, ok bool, sessionID uint32, msg string) error {
|
||||
padding := udpResponsePadding.String()
|
||||
paddingLen := len(padding)
|
||||
msgLen := len(msg)
|
||||
sz := 1 + 4 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
|
||||
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
||||
buf := make([]byte, sz)
|
||||
if ok {
|
||||
buf[0] = 0
|
||||
} else {
|
||||
buf[0] = 1
|
||||
}
|
||||
binary.BigEndian.PutUint32(buf[1:], sessionID)
|
||||
i := varintPut(buf[5:], uint64(msgLen))
|
||||
i += copy(buf[5+i:], msg)
|
||||
i += varintPut(buf[5+i:], uint64(paddingLen))
|
||||
copy(buf[5+i:], padding)
|
||||
_, err := w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// UDPMessage format:
|
||||
// Session ID (uint32 BE)
|
||||
// Packet ID (uint16 BE)
|
||||
|
|
|
@ -315,179 +315,3 @@ func TestWriteTCPResponse(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadUDPRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal no padding",
|
||||
data: []byte("\x00\x00"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "normal with padding",
|
||||
data: []byte("\x02gg"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "incomplete 1",
|
||||
data: []byte("\x0bhoho"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := bytes.NewReader(tt.data)
|
||||
if err := ReadUDPRequest(r); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ReadUDPRequest() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteUDPRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantW string // Just a prefix, we don't care about the padding
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
wantW: "\x44\x02",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := &bytes.Buffer{}
|
||||
err := WriteUDPRequest(w)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("WriteUDPRequest() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
|
||||
t.Errorf("WriteUDPRequest() gotW = %v, want %v", gotW, tt.wantW)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadUDPResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
want bool
|
||||
want1 uint32
|
||||
want2 string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal ok no padding",
|
||||
data: []byte("\x00\x00\x00\x00\x33\x0bhello world\x00"),
|
||||
want: true,
|
||||
want1: 51,
|
||||
want2: "hello world",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "normal error with padding",
|
||||
data: []byte("\x01\x00\x00\x33\x33\x06stop!!\x05xxxxx"),
|
||||
want: false,
|
||||
want1: 13107,
|
||||
want2: "stop!!",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "normal ok no message with padding",
|
||||
data: []byte("\x00\x00\x00\x00\x33\x00\x05xxxxx"),
|
||||
want: true,
|
||||
want1: 51,
|
||||
want2: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "incomplete 1",
|
||||
data: []byte("\x00\x00\x06"),
|
||||
want: false,
|
||||
want1: 0,
|
||||
want2: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "incomplete 2",
|
||||
data: []byte("\x01\x00\x01\x02\x03\x05jesus\x05x"),
|
||||
want: false,
|
||||
want1: 0,
|
||||
want2: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := bytes.NewReader(tt.data)
|
||||
got, got1, got2, err := ReadUDPResponse(r)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ReadUDPResponse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ReadUDPResponse() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("ReadUDPResponse() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
if got2 != tt.want2 {
|
||||
t.Errorf("ReadUDPResponse() got2 = %v, want %v", got2, tt.want2)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteUDPResponse(t *testing.T) {
|
||||
type args struct {
|
||||
ok bool
|
||||
sessionID uint32
|
||||
msg string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantW string // Just a prefix, we don't care about the padding
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal ok",
|
||||
args: args{ok: true, sessionID: 6, msg: "hello world"},
|
||||
wantW: "\x00\x00\x00\x00\x06\x0bhello world",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "normal error",
|
||||
args: args{ok: false, sessionID: 7, msg: "stop!!"},
|
||||
wantW: "\x01\x00\x00\x00\x07\x06stop!!",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
args: args{ok: true, sessionID: 0, msg: ""},
|
||||
wantW: "\x00\x00\x00\x00\x00\x00",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := &bytes.Buffer{}
|
||||
err := WriteUDPResponse(w, tt.args.ok, tt.args.sessionID, tt.args.msg)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("WriteUDPResponse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
|
||||
t.Errorf("WriteUDPResponse() gotW = %v, want %v", gotW, tt.wantW)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
24
core/internal/utils/atomic.go
Normal file
24
core/internal/utils/atomic.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AtomicTime struct {
|
||||
v atomic.Value
|
||||
}
|
||||
|
||||
func NewAtomicTime(t time.Time) *AtomicTime {
|
||||
a := &AtomicTime{}
|
||||
a.Set(t)
|
||||
return a
|
||||
}
|
||||
|
||||
func (t *AtomicTime) Set(new time.Time) {
|
||||
t.v.Store(new)
|
||||
}
|
||||
|
||||
func (t *AtomicTime) Get() time.Time {
|
||||
return t.v.Load().(time.Time)
|
||||
}
|
|
@ -103,9 +103,13 @@ type QUICConfig struct {
|
|||
}
|
||||
|
||||
// Outbound provides the implementation of how the server should connect to remote servers.
|
||||
// Even though it's called DialUDP, outbound implementations do not necessarily have to
|
||||
// return a "connected" UDP socket that can only send and receive from reqAddr. It's the
|
||||
// address of the first packet to be sent.
|
||||
// It's perfectly fine to have a "full-cone" implementation for UDP.
|
||||
type Outbound interface {
|
||||
DialTCP(reqAddr string) (net.Conn, error)
|
||||
ListenUDP() (UDPConn, error)
|
||||
DialUDP(reqAddr string) (UDPConn, error)
|
||||
}
|
||||
|
||||
// UDPConn is like net.PacketConn, but uses string for addresses.
|
||||
|
@ -125,7 +129,7 @@ func (o *defaultOutbound) DialTCP(reqAddr string) (net.Conn, error) {
|
|||
return defaultOutboundDialer.Dial("tcp", reqAddr)
|
||||
}
|
||||
|
||||
func (o *defaultOutbound) ListenUDP() (UDPConn, error) {
|
||||
func (o *defaultOutbound) DialUDP(reqAddr string) (UDPConn, error) {
|
||||
conn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -171,7 +175,7 @@ type EventLogger interface {
|
|||
Disconnect(addr net.Addr, id string, err error)
|
||||
TCPRequest(addr net.Addr, id, reqAddr string)
|
||||
TCPError(addr net.Addr, id, reqAddr string, err error)
|
||||
UDPRequest(addr net.Addr, id string, sessionID uint32)
|
||||
UDPRequest(addr net.Addr, id string, sessionID uint32, reqAddr string)
|
||||
UDPError(addr net.Addr, id string, sessionID uint32, err error)
|
||||
}
|
||||
|
||||
|
|
|
@ -3,14 +3,11 @@ package server
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/apernet/hysteria/core/internal/congestion"
|
||||
"github.com/apernet/hysteria/core/internal/frag"
|
||||
"github.com/apernet/hysteria/core/internal/protocol"
|
||||
"github.com/apernet/hysteria/core/internal/utils"
|
||||
|
||||
|
@ -21,6 +18,8 @@ import (
|
|||
const (
|
||||
closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError
|
||||
closeErrCodeTrafficLimitReached = 0x107 // HTTP3 ErrCodeExcessiveLoad
|
||||
|
||||
udpSessionIdleTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
type Server interface {
|
||||
|
@ -101,90 +100,21 @@ type h3sHandler struct {
|
|||
authID string
|
||||
|
||||
udpOnce sync.Once
|
||||
udpSM udpSessionManager
|
||||
udpSM *udpSessionManager // Only set after authentication
|
||||
}
|
||||
|
||||
func newH3sHandler(config *Config, conn quic.Connection) *h3sHandler {
|
||||
return &h3sHandler{
|
||||
config: config,
|
||||
conn: conn,
|
||||
udpSM: udpSessionManager{
|
||||
listenFunc: config.Outbound.ListenUDP,
|
||||
m: make(map[uint32]*udpSessionEntry),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type udpSessionEntry struct {
|
||||
Conn UDPConn
|
||||
D *frag.Defragger
|
||||
Closed bool
|
||||
}
|
||||
|
||||
type udpSessionManager struct {
|
||||
listenFunc func() (UDPConn, error)
|
||||
mutex sync.RWMutex
|
||||
m map[uint32]*udpSessionEntry
|
||||
nextID uint32
|
||||
}
|
||||
|
||||
// Add returns the session ID, the UDP connection and a function to close the UDP connection & delete the session.
|
||||
func (m *udpSessionManager) Add() (uint32, UDPConn, func(), error) {
|
||||
conn, err := m.listenFunc()
|
||||
if err != nil {
|
||||
return 0, nil, nil, err
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
id := m.nextID
|
||||
m.nextID++
|
||||
entry := &udpSessionEntry{
|
||||
Conn: conn,
|
||||
D: &frag.Defragger{},
|
||||
Closed: false,
|
||||
}
|
||||
m.m[id] = entry
|
||||
|
||||
return id, conn, func() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
if entry.Closed {
|
||||
// Already closed
|
||||
return
|
||||
}
|
||||
entry.Closed = true
|
||||
_ = conn.Close()
|
||||
delete(m.m, id)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Feed feeds a UDP message to the session manager.
|
||||
// If the message itself is a complete message, or it's the last fragment of a message,
|
||||
// it will be sent to the UDP connection.
|
||||
// The function will then return the number of bytes sent and any error occurred.
|
||||
func (m *udpSessionManager) Feed(msg *protocol.UDPMessage) (int, error) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
entry, ok := m.m[msg.SessionID]
|
||||
if !ok {
|
||||
// No such session, drop the message
|
||||
return 0, nil
|
||||
}
|
||||
dfMsg := entry.D.Feed(msg)
|
||||
if dfMsg == nil {
|
||||
// Not a complete message yet
|
||||
return 0, nil
|
||||
}
|
||||
return entry.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
|
||||
}
|
||||
|
||||
func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath {
|
||||
if h.authenticated {
|
||||
// Already authenticated
|
||||
protocol.AuthResponseDataToHeader(w.Header(), h.config.BandwidthConfig.MaxRx)
|
||||
protocol.AuthResponseDataToHeader(w.Header(), !h.config.DisableUDP, h.config.BandwidthConfig.MaxRx)
|
||||
w.WriteHeader(protocol.StatusAuthOK)
|
||||
return
|
||||
}
|
||||
|
@ -204,18 +134,23 @@ func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
h.conn.SetCongestionControl(congestion.NewBrutalSender(actualTx))
|
||||
}
|
||||
// Auth OK, send response
|
||||
protocol.AuthResponseDataToHeader(w.Header(), h.config.BandwidthConfig.MaxRx)
|
||||
protocol.AuthResponseDataToHeader(w.Header(), !h.config.DisableUDP, h.config.BandwidthConfig.MaxRx)
|
||||
w.WriteHeader(protocol.StatusAuthOK)
|
||||
// Call event logger
|
||||
if h.config.EventLogger != nil {
|
||||
h.config.EventLogger.Connect(h.conn.RemoteAddr(), id, actualTx)
|
||||
}
|
||||
// Start UDP loop if UDP is not disabled
|
||||
// Initialize UDP session manager (if UDP is enabled)
|
||||
// We use sync.Once to make sure that only one goroutine is started,
|
||||
// as ServeHTTP may be called by multiple goroutines simultaneously
|
||||
if !h.config.DisableUDP {
|
||||
h.udpOnce.Do(func() {
|
||||
go h.udpLoop()
|
||||
sm := newUDPSessionManager(
|
||||
&udpsmIO{h.conn, id, h.config.TrafficLogger, h.config.Outbound},
|
||||
&udpsmEventLogger{h.conn, id, h.config.EventLogger},
|
||||
udpSessionIdleTimeout)
|
||||
h.udpSM = sm
|
||||
go sm.Run()
|
||||
})
|
||||
}
|
||||
} else {
|
||||
|
@ -240,9 +175,6 @@ func (h *h3sHandler) ProxyStreamHijacker(ft http3.FrameType, conn quic.Connectio
|
|||
case protocol.FrameTypeTCPRequest:
|
||||
go h.handleTCPRequest(stream)
|
||||
return true, nil
|
||||
case protocol.FrameTypeUDPRequest:
|
||||
go h.handleUDPRequest(stream)
|
||||
return true, nil
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
|
@ -290,125 +222,6 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *h3sHandler) handleUDPRequest(stream quic.Stream) {
|
||||
if h.config.DisableUDP {
|
||||
// UDP is disabled, send error message and close the stream
|
||||
_ = protocol.WriteUDPResponse(stream, false, 0, "UDP is disabled on this server")
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
// Read request
|
||||
err := protocol.ReadUDPRequest(stream)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
// Add to session manager
|
||||
sessionID, conn, connCloseFunc, err := h.udpSM.Add()
|
||||
if err != nil {
|
||||
_ = protocol.WriteUDPResponse(stream, false, 0, err.Error())
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
// Send response
|
||||
_ = protocol.WriteUDPResponse(stream, true, sessionID, "")
|
||||
// Call event logger
|
||||
if h.config.EventLogger != nil {
|
||||
h.config.EventLogger.UDPRequest(h.conn.RemoteAddr(), h.authID, sessionID)
|
||||
}
|
||||
|
||||
// client <- remote direction
|
||||
go func() {
|
||||
udpBuf := make([]byte, protocol.MaxUDPSize)
|
||||
msgBuf := make([]byte, protocol.MaxUDPSize)
|
||||
for {
|
||||
udpN, rAddr, err := conn.ReadFrom(udpBuf)
|
||||
if err != nil {
|
||||
connCloseFunc()
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
if h.config.TrafficLogger != nil {
|
||||
ok := h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN))
|
||||
if !ok {
|
||||
// TrafficLogger requested to disconnect the client
|
||||
_ = h.conn.CloseWithError(closeErrCodeTrafficLimitReached, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
// Try no frag first
|
||||
msg := protocol.UDPMessage{
|
||||
SessionID: sessionID,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
Addr: rAddr,
|
||||
Data: udpBuf[:udpN],
|
||||
}
|
||||
msgN := msg.Serialize(msgBuf)
|
||||
if msgN < 0 {
|
||||
// Message even larger than MaxUDPSize, drop it
|
||||
continue
|
||||
}
|
||||
sendErr := h.conn.SendMessage(msgBuf[:msgN])
|
||||
var errTooLarge quic.ErrMessageTooLarge
|
||||
if errors.As(sendErr, &errTooLarge) {
|
||||
// Message too large, try fragmentation
|
||||
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
|
||||
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
|
||||
for _, fMsg := range fMsgs {
|
||||
msgN = fMsg.Serialize(msgBuf)
|
||||
_ = h.conn.SendMessage(msgBuf[:msgN])
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Hold (drain) the stream until the client closes it.
|
||||
// Closing the stream is the signal to stop the UDP session.
|
||||
_, err = io.Copy(io.Discard, stream)
|
||||
// Call event logger
|
||||
if h.config.EventLogger != nil {
|
||||
h.config.EventLogger.UDPError(h.conn.RemoteAddr(), h.authID, sessionID, err)
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
connCloseFunc()
|
||||
_ = stream.Close()
|
||||
}
|
||||
|
||||
func (h *h3sHandler) udpLoop() {
|
||||
for {
|
||||
msg, err := h.conn.ReceiveMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ok := h.handleUDPMessage(msg)
|
||||
if !ok {
|
||||
// TrafficLogger requested to disconnect the client
|
||||
_ = h.conn.CloseWithError(closeErrCodeTrafficLimitReached, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// client -> remote direction
|
||||
// Returns a bool indicating whether the receiving loop should continue
|
||||
func (h *h3sHandler) handleUDPMessage(msg []byte) (ok bool) {
|
||||
udpMsg, err := protocol.ParseUDPMessage(msg)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
if h.config.TrafficLogger != nil {
|
||||
ok := h.config.TrafficLogger.Log(h.authID, uint64(len(udpMsg.Data)), 0)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
_, _ = h.udpSM.Feed(udpMsg)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if h.config.MasqHandler != nil {
|
||||
h.config.MasqHandler.ServeHTTP(w, r)
|
||||
|
@ -417,3 +230,74 @@ func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) {
|
|||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// udpsmIO is the IO implementation for udpSessionManager with TrafficLogger support
|
||||
type udpsmIO struct {
|
||||
Conn quic.Connection
|
||||
AuthID string
|
||||
TrafficLogger TrafficLogger
|
||||
Outbound Outbound
|
||||
}
|
||||
|
||||
func (io *udpsmIO) ReceiveMessage() (*protocol.UDPMessage, error) {
|
||||
for {
|
||||
msg, err := io.Conn.ReceiveMessage()
|
||||
if err != nil {
|
||||
// Connection error, this will stop the session manager
|
||||
return nil, err
|
||||
}
|
||||
udpMsg, err := protocol.ParseUDPMessage(msg)
|
||||
if err != nil {
|
||||
// Invalid message, this is fine - just wait for the next
|
||||
continue
|
||||
}
|
||||
if io.TrafficLogger != nil {
|
||||
ok := io.TrafficLogger.Log(io.AuthID, uint64(len(udpMsg.Data)), 0)
|
||||
if !ok {
|
||||
// TrafficLogger requested to disconnect the client
|
||||
_ = io.Conn.CloseWithError(closeErrCodeTrafficLimitReached, "")
|
||||
return nil, errDisconnect
|
||||
}
|
||||
}
|
||||
return udpMsg, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (io *udpsmIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
|
||||
if io.TrafficLogger != nil {
|
||||
ok := io.TrafficLogger.Log(io.AuthID, 0, uint64(len(msg.Data)))
|
||||
if !ok {
|
||||
// TrafficLogger requested to disconnect the client
|
||||
_ = io.Conn.CloseWithError(closeErrCodeTrafficLimitReached, "")
|
||||
return errDisconnect
|
||||
}
|
||||
}
|
||||
msgN := msg.Serialize(buf)
|
||||
if msgN < 0 {
|
||||
// Message larger than buffer, silent drop
|
||||
return nil
|
||||
}
|
||||
return io.Conn.SendMessage(buf[:msgN])
|
||||
}
|
||||
|
||||
func (io *udpsmIO) DialUDP(reqAddr string) (UDPConn, error) {
|
||||
return io.Outbound.DialUDP(reqAddr)
|
||||
}
|
||||
|
||||
type udpsmEventLogger struct {
|
||||
Conn quic.Connection
|
||||
AuthID string
|
||||
EventLogger EventLogger
|
||||
}
|
||||
|
||||
func (l *udpsmEventLogger) New(sessionID uint32, reqAddr string) {
|
||||
if l.EventLogger != nil {
|
||||
l.EventLogger.UDPRequest(l.Conn.RemoteAddr(), l.AuthID, sessionID, reqAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *udpsmEventLogger) Closed(sessionID uint32, err error) {
|
||||
if l.EventLogger != nil {
|
||||
l.EventLogger.UDPError(l.Conn.RemoteAddr(), l.AuthID, sessionID, err)
|
||||
}
|
||||
}
|
||||
|
|
218
core/server/udp.go
Normal file
218
core/server/udp.go
Normal file
|
@ -0,0 +1,218 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
|
||||
"github.com/apernet/hysteria/core/internal/frag"
|
||||
"github.com/apernet/hysteria/core/internal/protocol"
|
||||
"github.com/apernet/hysteria/core/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
idleCleanupInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
type udpSessionManagerIO interface {
|
||||
ReceiveMessage() (*protocol.UDPMessage, error)
|
||||
SendMessage([]byte, *protocol.UDPMessage) error
|
||||
DialUDP(reqAddr string) (UDPConn, error)
|
||||
}
|
||||
|
||||
type udpSessionManagerEventLogger interface {
|
||||
New(sessionID uint32, reqAddr string)
|
||||
Closed(sessionID uint32, err error)
|
||||
}
|
||||
|
||||
type udpSessionEntry struct {
|
||||
ID uint32
|
||||
Conn UDPConn
|
||||
D *frag.Defragger
|
||||
Last *utils.AtomicTime
|
||||
Closed bool
|
||||
}
|
||||
|
||||
// Feed feeds a UDP message to the session.
|
||||
// If the message itself is a complete message, or it completes a fragmented message,
|
||||
// the message is written to the session's UDP connection, and the number of bytes
|
||||
// written is returned.
|
||||
// Otherwise, 0 and nil are returned.
|
||||
func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) {
|
||||
e.Last.Set(time.Now())
|
||||
dfMsg := e.D.Feed(msg)
|
||||
if dfMsg == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
|
||||
}
|
||||
|
||||
// ReceiveLoop receives incoming UDP packets, packs them into UDP messages,
|
||||
// and sends using the provided io.
|
||||
// Exit and returns error when either the underlying UDP connection returns
|
||||
// error (e.g. closed), or the provided io returns error when sending.
|
||||
func (e *udpSessionEntry) ReceiveLoop(io udpSessionManagerIO) error {
|
||||
udpBuf := make([]byte, protocol.MaxUDPSize)
|
||||
msgBuf := make([]byte, protocol.MaxUDPSize)
|
||||
for {
|
||||
udpN, rAddr, err := e.Conn.ReadFrom(udpBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.Last.Set(time.Now())
|
||||
|
||||
msg := &protocol.UDPMessage{
|
||||
SessionID: e.ID,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
Addr: rAddr,
|
||||
Data: udpBuf[:udpN],
|
||||
}
|
||||
err = sendMessageAutoFrag(io, msgBuf, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendMessageAutoFrag tries to send a UDP message as a whole first,
|
||||
// but if it fails due to quic.ErrMessageTooLarge, it tries again by
|
||||
// fragmenting the message.
|
||||
func sendMessageAutoFrag(io udpSessionManagerIO, buf []byte, msg *protocol.UDPMessage) error {
|
||||
err := io.SendMessage(buf, msg)
|
||||
var errTooLarge quic.ErrMessageTooLarge
|
||||
if errors.As(err, &errTooLarge) {
|
||||
// Message too large, try fragmentation
|
||||
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
|
||||
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
|
||||
for _, fMsg := range fMsgs {
|
||||
err := io.SendMessage(buf, &fMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// udpSessionManager manages the lifecycle of UDP sessions.
|
||||
// Each UDP session is identified by a SessionID, and corresponds to a UDP connection.
|
||||
// A UDP session is created when a UDP message with a new SessionID is received.
|
||||
// Similar to standard NAT, a UDP session is destroyed when no UDP message is received
|
||||
// for a certain period of time (specified by idleTimeout).
|
||||
type udpSessionManager struct {
|
||||
io udpSessionManagerIO
|
||||
eventLogger udpSessionManagerEventLogger
|
||||
idleTimeout time.Duration
|
||||
|
||||
mutex sync.Mutex
|
||||
m map[uint32]*udpSessionEntry
|
||||
nextID uint32
|
||||
}
|
||||
|
||||
func newUDPSessionManager(
|
||||
io udpSessionManagerIO,
|
||||
eventLogger udpSessionManagerEventLogger,
|
||||
idleTimeout time.Duration,
|
||||
) *udpSessionManager {
|
||||
return &udpSessionManager{
|
||||
io: io,
|
||||
eventLogger: eventLogger,
|
||||
idleTimeout: idleTimeout,
|
||||
m: make(map[uint32]*udpSessionEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// Run runs the session manager main loop.
|
||||
// Exit and returns error when the underlying io returns error (e.g. closed).
|
||||
func (m *udpSessionManager) Run() error {
|
||||
stopCh := make(chan struct{})
|
||||
go m.idleCleanupLoop(stopCh)
|
||||
defer close(stopCh)
|
||||
defer m.cleanup(false)
|
||||
|
||||
for {
|
||||
msg, err := m.io.ReceiveMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.feed(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) idleCleanupLoop(stopCh <-chan struct{}) {
|
||||
ticker := time.NewTicker(idleCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.cleanup(true)
|
||||
case <-stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) cleanup(idleOnly bool) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for sessionID, entry := range m.m {
|
||||
if !idleOnly || now.Sub(entry.Last.Get()) > m.idleTimeout {
|
||||
entry.Closed = true
|
||||
_ = entry.Conn.Close()
|
||||
m.eventLogger.Closed(sessionID, nil)
|
||||
delete(m.m, sessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
|
||||
m.mutex.Lock()
|
||||
|
||||
entry := m.m[msg.SessionID]
|
||||
if entry == nil {
|
||||
// New session
|
||||
m.eventLogger.New(msg.SessionID, msg.Addr)
|
||||
conn, err := m.io.DialUDP(msg.Addr)
|
||||
if err != nil {
|
||||
m.mutex.Unlock()
|
||||
m.eventLogger.Closed(msg.SessionID, err)
|
||||
return
|
||||
}
|
||||
entry = &udpSessionEntry{
|
||||
ID: msg.SessionID,
|
||||
Conn: conn,
|
||||
D: &frag.Defragger{},
|
||||
Last: utils.NewAtomicTime(time.Now()),
|
||||
}
|
||||
// Start the receive loop for this session
|
||||
go func() {
|
||||
err := entry.ReceiveLoop(m.io)
|
||||
// Receive loop stopped, remove the session
|
||||
m.mutex.Lock()
|
||||
if !entry.Closed {
|
||||
entry.Closed = true
|
||||
_ = entry.Conn.Close()
|
||||
m.eventLogger.Closed(entry.ID, err)
|
||||
delete(m.m, entry.ID)
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}()
|
||||
m.m[msg.SessionID] = entry
|
||||
}
|
||||
|
||||
m.mutex.Unlock()
|
||||
|
||||
// Feed the message to the session
|
||||
// Feed (send) errors are ignored for now,
|
||||
// as some are temporary (e.g. invalid address)
|
||||
_, _ = entry.Feed(msg)
|
||||
}
|
191
core/server/udp_test.go
Normal file
191
core/server/udp_test.go
Normal file
|
@ -0,0 +1,191 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/apernet/hysteria/core/internal/protocol"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
type echoUDPConnPkt struct {
|
||||
Data []byte
|
||||
Addr string
|
||||
Close bool
|
||||
}
|
||||
|
||||
type echoUDPConn struct {
|
||||
PktCh chan echoUDPConnPkt
|
||||
}
|
||||
|
||||
func (c *echoUDPConn) ReadFrom(b []byte) (int, string, error) {
|
||||
pkt := <-c.PktCh
|
||||
if pkt.Close {
|
||||
return 0, "", errors.New("closed")
|
||||
}
|
||||
n := copy(b, pkt.Data)
|
||||
return n, pkt.Addr, nil
|
||||
}
|
||||
|
||||
func (c *echoUDPConn) WriteTo(b []byte, addr string) (int, error) {
|
||||
nb := make([]byte, len(b))
|
||||
copy(nb, b)
|
||||
c.PktCh <- echoUDPConnPkt{
|
||||
Data: nb,
|
||||
Addr: addr,
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *echoUDPConn) Close() error {
|
||||
c.PktCh <- echoUDPConnPkt{
|
||||
Close: true,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type udpsmMockIO struct {
|
||||
ReceiveCh <-chan *protocol.UDPMessage
|
||||
SendCh chan<- *protocol.UDPMessage
|
||||
}
|
||||
|
||||
func (io *udpsmMockIO) ReceiveMessage() (*protocol.UDPMessage, error) {
|
||||
m := <-io.ReceiveCh
|
||||
if m == nil {
|
||||
return nil, errors.New("closed")
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (io *udpsmMockIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
|
||||
nMsg := *msg
|
||||
nMsg.Data = make([]byte, len(msg.Data))
|
||||
copy(nMsg.Data, msg.Data)
|
||||
io.SendCh <- &nMsg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (io *udpsmMockIO) DialUDP(reqAddr string) (UDPConn, error) {
|
||||
return &echoUDPConn{
|
||||
PktCh: make(chan echoUDPConnPkt, 10),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type udpsmMockEventNew struct {
|
||||
SessionID uint32
|
||||
ReqAddr string
|
||||
}
|
||||
|
||||
type udpsmMockEventClosed struct {
|
||||
SessionID uint32
|
||||
Err error
|
||||
}
|
||||
|
||||
type udpsmMockEventLogger struct {
|
||||
NewCh chan<- udpsmMockEventNew
|
||||
ClosedCh chan<- udpsmMockEventClosed
|
||||
}
|
||||
|
||||
func (l *udpsmMockEventLogger) New(sessionID uint32, reqAddr string) {
|
||||
l.NewCh <- udpsmMockEventNew{sessionID, reqAddr}
|
||||
}
|
||||
|
||||
func (l *udpsmMockEventLogger) Closed(sessionID uint32, err error) {
|
||||
l.ClosedCh <- udpsmMockEventClosed{sessionID, err}
|
||||
}
|
||||
|
||||
func TestUDPSessionManager(t *testing.T) {
|
||||
msgReceiveCh := make(chan *protocol.UDPMessage, 10)
|
||||
msgSendCh := make(chan *protocol.UDPMessage, 10)
|
||||
io := &udpsmMockIO{
|
||||
ReceiveCh: msgReceiveCh,
|
||||
SendCh: msgSendCh,
|
||||
}
|
||||
eventNewCh := make(chan udpsmMockEventNew, 10)
|
||||
eventClosedCh := make(chan udpsmMockEventClosed, 10)
|
||||
eventLogger := &udpsmMockEventLogger{
|
||||
NewCh: eventNewCh,
|
||||
ClosedCh: eventClosedCh,
|
||||
}
|
||||
sm := newUDPSessionManager(io, eventLogger, 2*time.Second)
|
||||
go sm.Run()
|
||||
|
||||
ms := []*protocol.UDPMessage{
|
||||
{
|
||||
SessionID: 1234,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
Addr: "example.com:5353",
|
||||
Data: []byte("hello"),
|
||||
},
|
||||
{
|
||||
SessionID: 5678,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
Addr: "example.com:9999",
|
||||
Data: []byte("goodbye"),
|
||||
},
|
||||
{
|
||||
SessionID: 1234,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
Addr: "example.com:5353",
|
||||
Data: []byte(" world"),
|
||||
},
|
||||
{
|
||||
SessionID: 5678,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
Addr: "example.com:9999",
|
||||
Data: []byte(" girl"),
|
||||
},
|
||||
}
|
||||
for _, m := range ms {
|
||||
msgReceiveCh <- m
|
||||
}
|
||||
// New event order should be consistent
|
||||
newEvent := <-eventNewCh
|
||||
if newEvent.SessionID != 1234 || newEvent.ReqAddr != "example.com:5353" {
|
||||
t.Error("unexpected new event value")
|
||||
}
|
||||
newEvent = <-eventNewCh
|
||||
if newEvent.SessionID != 5678 || newEvent.ReqAddr != "example.com:9999" {
|
||||
t.Error("unexpected new event value")
|
||||
}
|
||||
// Message order is not guaranteed
|
||||
msgMap := make(map[string]bool)
|
||||
for i := 0; i < 4; i++ {
|
||||
msg := <-msgSendCh
|
||||
msgMap[fmt.Sprintf("%d:%s:%s", msg.SessionID, msg.Addr, string(msg.Data))] = true
|
||||
}
|
||||
if !(msgMap["1234:example.com:5353:hello"] &&
|
||||
msgMap["5678:example.com:9999:goodbye"] &&
|
||||
msgMap["1234:example.com:5353: world"] &&
|
||||
msgMap["5678:example.com:9999: girl"]) {
|
||||
t.Error("unexpected message value")
|
||||
}
|
||||
// Timeout check
|
||||
startTime := time.Now()
|
||||
closedMap := make(map[uint32]bool)
|
||||
for i := 0; i < 2; i++ {
|
||||
closedEvent := <-eventClosedCh
|
||||
closedMap[closedEvent.SessionID] = true
|
||||
}
|
||||
if !(closedMap[1234] && closedMap[5678]) {
|
||||
t.Error("unexpected closed event value", closedMap)
|
||||
}
|
||||
if time.Since(startTime) < 2*time.Second || time.Since(startTime) > 4*time.Second {
|
||||
t.Error("unexpected timeout duration")
|
||||
}
|
||||
|
||||
// Goroutine leak check
|
||||
msgReceiveCh <- nil
|
||||
time.Sleep(1 * time.Second) // Wait for internal routines to exit
|
||||
goleak.VerifyNone(t)
|
||||
}
|
|
@ -75,7 +75,7 @@ func (a *PluggableOutboundAdapter) DialTCP(reqAddr string) (net.Conn, error) {
|
|||
})
|
||||
}
|
||||
|
||||
func (a *PluggableOutboundAdapter) ListenUDP() (server.UDPConn, error) {
|
||||
func (a *PluggableOutboundAdapter) DialUDP() (server.UDPConn, error) {
|
||||
conn, err := a.PluggableOutbound.ListenUDP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -68,10 +68,10 @@ func TestPluggableOutboundAdapter(t *testing.T) {
|
|||
if err != errWrongAddr {
|
||||
t.Fatal("DialTCP with wrong addr should fail, got", err)
|
||||
}
|
||||
// ListenUDP
|
||||
uConn, err := adapter.ListenUDP()
|
||||
// DialUDP
|
||||
uConn, err := adapter.DialUDP()
|
||||
if err != nil {
|
||||
t.Fatal("ListenUDP failed", err)
|
||||
t.Fatal("DialUDP failed", err)
|
||||
}
|
||||
// ReadFrom
|
||||
b := make([]byte, 10)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue