feat(wip): udp rework server side

This commit is contained in:
Toby 2023-07-23 11:42:52 -07:00
parent 6245f83262
commit a2fbcc6507
17 changed files with 554 additions and 513 deletions

View file

@ -400,7 +400,7 @@ func (c *udpConn) Receive() ([]byte, string, error) {
// Send is not thread-safe as it uses a shared send buffer for now.
func (c *udpConn) Send(data []byte, addr string) error {
// Try no frag first
msg := protocol.UDPMessage{
msg := &protocol.UDPMessage{
SessionID: c.SessionID,
PacketID: 0,
FragID: 0,

View file

@ -4,6 +4,7 @@ go 1.20
require (
github.com/quic-go/quic-go v0.0.0-00010101000000-000000000000
go.uber.org/goleak v1.2.1
golang.org/x/time v0.3.0
)

View file

@ -38,6 +38,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8=

View file

@ -4,9 +4,9 @@ import (
"github.com/apernet/hysteria/core/internal/protocol"
)
func FragUDPMessage(m protocol.UDPMessage, maxSize int) []protocol.UDPMessage {
func FragUDPMessage(m *protocol.UDPMessage, maxSize int) []protocol.UDPMessage {
if m.Size() <= maxSize {
return []protocol.UDPMessage{m}
return []protocol.UDPMessage{*m}
}
fullPayload := m.Data
maxPayloadSize := maxSize - m.HeaderSize()
@ -19,7 +19,7 @@ func FragUDPMessage(m protocol.UDPMessage, maxSize int) []protocol.UDPMessage {
if payloadSize > maxPayloadSize {
payloadSize = maxPayloadSize
}
frag := m
frag := *m
frag.FragID = fragID
frag.FragCount = fragCount
frag.Data = fullPayload[off : off+payloadSize]

View file

@ -124,7 +124,7 @@ func TestFragUDPMessage(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := FragUDPMessage(tt.args.m, tt.args.maxSize); !reflect.DeepEqual(got, tt.want) {
if got := FragUDPMessage(&tt.args.m, tt.args.maxSize); !reflect.DeepEqual(got, tt.want) {
t.Errorf("FragUDPMessage() = %v, want %v", got, tt.want)
}
})

View file

@ -36,7 +36,7 @@ func TestClientNoServer(t *testing.T) {
// Try UDP
_, err = c.ListenUDP()
if !errors.As(err, &cErr) {
t.Fatal("expected connect error from ListenUDP")
t.Fatal("expected connect error from DialUDP")
}
}
@ -86,7 +86,7 @@ func TestClientServerBadAuth(t *testing.T) {
// Try UDP
_, err = c.ListenUDP()
if !errors.As(err, &aErr) {
t.Fatal("expected auth error from ListenUDP")
t.Fatal("expected auth error from DialUDP")
}
}

View file

@ -9,31 +9,34 @@ const (
URLHost = "hysteria"
URLPath = "/auth"
HeaderAuth = "Hysteria-Auth"
HeaderCCRX = "Hysteria-CC-RX"
HeaderPadding = "Hysteria-Padding"
RequestHeaderAuth = "Hysteria-Auth"
ResponseHeaderUDPEnabled = "Hysteria-UDP"
CommonHeaderCCRX = "Hysteria-CC-RX"
CommonHeaderPadding = "Hysteria-Padding"
StatusAuthOK = 233
)
func AuthRequestDataFromHeader(h http.Header) (auth string, rx uint64) {
auth = h.Get(HeaderAuth)
rx, _ = strconv.ParseUint(h.Get(HeaderCCRX), 10, 64)
auth = h.Get(RequestHeaderAuth)
rx, _ = strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64)
return
}
func AuthRequestDataToHeader(h http.Header, auth string, rx uint64) {
h.Set(HeaderAuth, auth)
h.Set(HeaderCCRX, strconv.FormatUint(rx, 10))
h.Set(HeaderPadding, authRequestPadding.String())
h.Set(RequestHeaderAuth, auth)
h.Set(CommonHeaderCCRX, strconv.FormatUint(rx, 10))
h.Set(CommonHeaderPadding, authRequestPadding.String())
}
func AuthResponseDataFromHeader(h http.Header) (rx uint64) {
rx, _ = strconv.ParseUint(h.Get(HeaderCCRX), 10, 64)
func AuthResponseDataFromHeader(h http.Header) (udp bool, rx uint64) {
udp, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled))
rx, _ = strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64)
return
}
func AuthResponseDataToHeader(h http.Header, rx uint64) {
h.Set(HeaderCCRX, strconv.FormatUint(rx, 10))
h.Set(HeaderPadding, authResponsePadding.String())
func AuthResponseDataToHeader(h http.Header, udp bool, rx uint64) {
h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(udp))
h.Set(CommonHeaderCCRX, strconv.FormatUint(rx, 10))
h.Set(CommonHeaderPadding, authResponsePadding.String())
}

View file

@ -28,6 +28,4 @@ var (
authResponsePadding = padding{Min: 256, Max: 2048}
tcpRequestPadding = padding{Min: 64, Max: 512}
tcpResponsePadding = padding{Min: 128, Max: 1024}
udpRequestPadding = padding{Min: 64, Max: 512}
udpResponsePadding = padding{Min: 128, Max: 1024}
)

View file

@ -13,7 +13,6 @@ import (
const (
FrameTypeTCPRequest = 0x401
FrameTypeUDPRequest = 0x402
// Max length values are for preventing DoS attacks
@ -148,113 +147,6 @@ func WriteTCPResponse(w io.Writer, ok bool, msg string) error {
return err
}
// UDPRequest format:
// 0x402 (QUIC varint)
// Padding length (QUIC varint)
// Padding (bytes)
func ReadUDPRequest(r io.Reader) error {
bReader := quicvarint.NewReader(r)
paddingLen, err := quicvarint.Read(bReader)
if err != nil {
return err
}
if paddingLen > MaxPaddingLength {
return errors.ProtocolError{Message: "invalid padding length"}
}
if paddingLen > 0 {
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
if err != nil {
return err
}
}
return nil
}
func WriteUDPRequest(w io.Writer) error {
padding := udpRequestPadding.String()
paddingLen := len(padding)
sz := int(quicvarint.Len(FrameTypeUDPRequest)) +
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
buf := make([]byte, sz)
i := varintPut(buf, FrameTypeUDPRequest)
i += varintPut(buf[i:], uint64(paddingLen))
copy(buf[i:], padding)
_, err := w.Write(buf)
return err
}
// UDPResponse format:
// Status (byte, 0=ok, 1=error)
// Session ID (uint32 BE)
// Message length (QUIC varint)
// Message (bytes)
// Padding length (QUIC varint)
// Padding (bytes)
func ReadUDPResponse(r io.Reader) (bool, uint32, string, error) {
var status [1]byte
if _, err := io.ReadFull(r, status[:]); err != nil {
return false, 0, "", err
}
var sessionID uint32
if err := binary.Read(r, binary.BigEndian, &sessionID); err != nil {
return false, 0, "", err
}
bReader := quicvarint.NewReader(r)
msgLen, err := quicvarint.Read(bReader)
if err != nil {
return false, 0, "", err
}
if msgLen > MaxMessageLength {
return false, 0, "", errors.ProtocolError{Message: "invalid message length"}
}
var msgBuf []byte
// No message is fine
if msgLen > 0 {
msgBuf = make([]byte, msgLen)
_, err = io.ReadFull(r, msgBuf)
if err != nil {
return false, 0, "", err
}
}
paddingLen, err := quicvarint.Read(bReader)
if err != nil {
return false, 0, "", err
}
if paddingLen > MaxPaddingLength {
return false, 0, "", errors.ProtocolError{Message: "invalid padding length"}
}
if paddingLen > 0 {
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
if err != nil {
return false, 0, "", err
}
}
return status[0] == 0, sessionID, string(msgBuf), nil
}
func WriteUDPResponse(w io.Writer, ok bool, sessionID uint32, msg string) error {
padding := udpResponsePadding.String()
paddingLen := len(padding)
msgLen := len(msg)
sz := 1 + 4 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
buf := make([]byte, sz)
if ok {
buf[0] = 0
} else {
buf[0] = 1
}
binary.BigEndian.PutUint32(buf[1:], sessionID)
i := varintPut(buf[5:], uint64(msgLen))
i += copy(buf[5+i:], msg)
i += varintPut(buf[5+i:], uint64(paddingLen))
copy(buf[5+i:], padding)
_, err := w.Write(buf)
return err
}
// UDPMessage format:
// Session ID (uint32 BE)
// Packet ID (uint16 BE)

View file

@ -315,179 +315,3 @@ func TestWriteTCPResponse(t *testing.T) {
})
}
}
func TestReadUDPRequest(t *testing.T) {
tests := []struct {
name string
data []byte
wantErr bool
}{
{
name: "normal no padding",
data: []byte("\x00\x00"),
wantErr: false,
},
{
name: "normal with padding",
data: []byte("\x02gg"),
wantErr: false,
},
{
name: "incomplete 1",
data: []byte("\x0bhoho"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := bytes.NewReader(tt.data)
if err := ReadUDPRequest(r); (err != nil) != tt.wantErr {
t.Errorf("ReadUDPRequest() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestWriteUDPRequest(t *testing.T) {
tests := []struct {
name string
wantW string // Just a prefix, we don't care about the padding
wantErr bool
}{
{
name: "normal",
wantW: "\x44\x02",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &bytes.Buffer{}
err := WriteUDPRequest(w)
if (err != nil) != tt.wantErr {
t.Errorf("WriteUDPRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
t.Errorf("WriteUDPRequest() gotW = %v, want %v", gotW, tt.wantW)
}
})
}
}
func TestReadUDPResponse(t *testing.T) {
tests := []struct {
name string
data []byte
want bool
want1 uint32
want2 string
wantErr bool
}{
{
name: "normal ok no padding",
data: []byte("\x00\x00\x00\x00\x33\x0bhello world\x00"),
want: true,
want1: 51,
want2: "hello world",
wantErr: false,
},
{
name: "normal error with padding",
data: []byte("\x01\x00\x00\x33\x33\x06stop!!\x05xxxxx"),
want: false,
want1: 13107,
want2: "stop!!",
wantErr: false,
},
{
name: "normal ok no message with padding",
data: []byte("\x00\x00\x00\x00\x33\x00\x05xxxxx"),
want: true,
want1: 51,
want2: "",
wantErr: false,
},
{
name: "incomplete 1",
data: []byte("\x00\x00\x06"),
want: false,
want1: 0,
want2: "",
wantErr: true,
},
{
name: "incomplete 2",
data: []byte("\x01\x00\x01\x02\x03\x05jesus\x05x"),
want: false,
want1: 0,
want2: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := bytes.NewReader(tt.data)
got, got1, got2, err := ReadUDPResponse(r)
if (err != nil) != tt.wantErr {
t.Errorf("ReadUDPResponse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ReadUDPResponse() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("ReadUDPResponse() got1 = %v, want %v", got1, tt.want1)
}
if got2 != tt.want2 {
t.Errorf("ReadUDPResponse() got2 = %v, want %v", got2, tt.want2)
}
})
}
}
func TestWriteUDPResponse(t *testing.T) {
type args struct {
ok bool
sessionID uint32
msg string
}
tests := []struct {
name string
args args
wantW string // Just a prefix, we don't care about the padding
wantErr bool
}{
{
name: "normal ok",
args: args{ok: true, sessionID: 6, msg: "hello world"},
wantW: "\x00\x00\x00\x00\x06\x0bhello world",
wantErr: false,
},
{
name: "normal error",
args: args{ok: false, sessionID: 7, msg: "stop!!"},
wantW: "\x01\x00\x00\x00\x07\x06stop!!",
wantErr: false,
},
{
name: "empty",
args: args{ok: true, sessionID: 0, msg: ""},
wantW: "\x00\x00\x00\x00\x00\x00",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &bytes.Buffer{}
err := WriteUDPResponse(w, tt.args.ok, tt.args.sessionID, tt.args.msg)
if (err != nil) != tt.wantErr {
t.Errorf("WriteUDPResponse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) {
t.Errorf("WriteUDPResponse() gotW = %v, want %v", gotW, tt.wantW)
}
})
}
}

View file

@ -0,0 +1,24 @@
package utils
import (
"sync/atomic"
"time"
)
type AtomicTime struct {
v atomic.Value
}
func NewAtomicTime(t time.Time) *AtomicTime {
a := &AtomicTime{}
a.Set(t)
return a
}
func (t *AtomicTime) Set(new time.Time) {
t.v.Store(new)
}
func (t *AtomicTime) Get() time.Time {
return t.v.Load().(time.Time)
}

View file

@ -103,9 +103,13 @@ type QUICConfig struct {
}
// Outbound provides the implementation of how the server should connect to remote servers.
// Even though it's called DialUDP, outbound implementations do not necessarily have to
// return a "connected" UDP socket that can only send and receive from reqAddr. It's the
// address of the first packet to be sent.
// It's perfectly fine to have a "full-cone" implementation for UDP.
type Outbound interface {
DialTCP(reqAddr string) (net.Conn, error)
ListenUDP() (UDPConn, error)
DialUDP(reqAddr string) (UDPConn, error)
}
// UDPConn is like net.PacketConn, but uses string for addresses.
@ -125,7 +129,7 @@ func (o *defaultOutbound) DialTCP(reqAddr string) (net.Conn, error) {
return defaultOutboundDialer.Dial("tcp", reqAddr)
}
func (o *defaultOutbound) ListenUDP() (UDPConn, error) {
func (o *defaultOutbound) DialUDP(reqAddr string) (UDPConn, error) {
conn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
@ -171,7 +175,7 @@ type EventLogger interface {
Disconnect(addr net.Addr, id string, err error)
TCPRequest(addr net.Addr, id, reqAddr string)
TCPError(addr net.Addr, id, reqAddr string, err error)
UDPRequest(addr net.Addr, id string, sessionID uint32)
UDPRequest(addr net.Addr, id string, sessionID uint32, reqAddr string)
UDPError(addr net.Addr, id string, sessionID uint32, err error)
}

View file

@ -3,14 +3,11 @@ package server
import (
"context"
"crypto/tls"
"errors"
"io"
"math/rand"
"net/http"
"sync"
"time"
"github.com/apernet/hysteria/core/internal/congestion"
"github.com/apernet/hysteria/core/internal/frag"
"github.com/apernet/hysteria/core/internal/protocol"
"github.com/apernet/hysteria/core/internal/utils"
@ -21,6 +18,8 @@ import (
const (
closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError
closeErrCodeTrafficLimitReached = 0x107 // HTTP3 ErrCodeExcessiveLoad
udpSessionIdleTimeout = 60 * time.Second
)
type Server interface {
@ -101,90 +100,21 @@ type h3sHandler struct {
authID string
udpOnce sync.Once
udpSM udpSessionManager
udpSM *udpSessionManager // Only set after authentication
}
func newH3sHandler(config *Config, conn quic.Connection) *h3sHandler {
return &h3sHandler{
config: config,
conn: conn,
udpSM: udpSessionManager{
listenFunc: config.Outbound.ListenUDP,
m: make(map[uint32]*udpSessionEntry),
},
}
}
type udpSessionEntry struct {
Conn UDPConn
D *frag.Defragger
Closed bool
}
type udpSessionManager struct {
listenFunc func() (UDPConn, error)
mutex sync.RWMutex
m map[uint32]*udpSessionEntry
nextID uint32
}
// Add returns the session ID, the UDP connection and a function to close the UDP connection & delete the session.
func (m *udpSessionManager) Add() (uint32, UDPConn, func(), error) {
conn, err := m.listenFunc()
if err != nil {
return 0, nil, nil, err
}
m.mutex.Lock()
defer m.mutex.Unlock()
id := m.nextID
m.nextID++
entry := &udpSessionEntry{
Conn: conn,
D: &frag.Defragger{},
Closed: false,
}
m.m[id] = entry
return id, conn, func() {
m.mutex.Lock()
defer m.mutex.Unlock()
if entry.Closed {
// Already closed
return
}
entry.Closed = true
_ = conn.Close()
delete(m.m, id)
}, nil
}
// Feed feeds a UDP message to the session manager.
// If the message itself is a complete message, or it's the last fragment of a message,
// it will be sent to the UDP connection.
// The function will then return the number of bytes sent and any error occurred.
func (m *udpSessionManager) Feed(msg *protocol.UDPMessage) (int, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
entry, ok := m.m[msg.SessionID]
if !ok {
// No such session, drop the message
return 0, nil
}
dfMsg := entry.D.Feed(msg)
if dfMsg == nil {
// Not a complete message yet
return 0, nil
}
return entry.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
}
func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath {
if h.authenticated {
// Already authenticated
protocol.AuthResponseDataToHeader(w.Header(), h.config.BandwidthConfig.MaxRx)
protocol.AuthResponseDataToHeader(w.Header(), !h.config.DisableUDP, h.config.BandwidthConfig.MaxRx)
w.WriteHeader(protocol.StatusAuthOK)
return
}
@ -204,18 +134,23 @@ func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.conn.SetCongestionControl(congestion.NewBrutalSender(actualTx))
}
// Auth OK, send response
protocol.AuthResponseDataToHeader(w.Header(), h.config.BandwidthConfig.MaxRx)
protocol.AuthResponseDataToHeader(w.Header(), !h.config.DisableUDP, h.config.BandwidthConfig.MaxRx)
w.WriteHeader(protocol.StatusAuthOK)
// Call event logger
if h.config.EventLogger != nil {
h.config.EventLogger.Connect(h.conn.RemoteAddr(), id, actualTx)
}
// Start UDP loop if UDP is not disabled
// Initialize UDP session manager (if UDP is enabled)
// We use sync.Once to make sure that only one goroutine is started,
// as ServeHTTP may be called by multiple goroutines simultaneously
if !h.config.DisableUDP {
h.udpOnce.Do(func() {
go h.udpLoop()
sm := newUDPSessionManager(
&udpsmIO{h.conn, id, h.config.TrafficLogger, h.config.Outbound},
&udpsmEventLogger{h.conn, id, h.config.EventLogger},
udpSessionIdleTimeout)
h.udpSM = sm
go sm.Run()
})
}
} else {
@ -240,9 +175,6 @@ func (h *h3sHandler) ProxyStreamHijacker(ft http3.FrameType, conn quic.Connectio
case protocol.FrameTypeTCPRequest:
go h.handleTCPRequest(stream)
return true, nil
case protocol.FrameTypeUDPRequest:
go h.handleUDPRequest(stream)
return true, nil
default:
return false, nil
}
@ -290,125 +222,6 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
}
}
func (h *h3sHandler) handleUDPRequest(stream quic.Stream) {
if h.config.DisableUDP {
// UDP is disabled, send error message and close the stream
_ = protocol.WriteUDPResponse(stream, false, 0, "UDP is disabled on this server")
_ = stream.Close()
return
}
// Read request
err := protocol.ReadUDPRequest(stream)
if err != nil {
_ = stream.Close()
return
}
// Add to session manager
sessionID, conn, connCloseFunc, err := h.udpSM.Add()
if err != nil {
_ = protocol.WriteUDPResponse(stream, false, 0, err.Error())
_ = stream.Close()
return
}
// Send response
_ = protocol.WriteUDPResponse(stream, true, sessionID, "")
// Call event logger
if h.config.EventLogger != nil {
h.config.EventLogger.UDPRequest(h.conn.RemoteAddr(), h.authID, sessionID)
}
// client <- remote direction
go func() {
udpBuf := make([]byte, protocol.MaxUDPSize)
msgBuf := make([]byte, protocol.MaxUDPSize)
for {
udpN, rAddr, err := conn.ReadFrom(udpBuf)
if err != nil {
connCloseFunc()
_ = stream.Close()
return
}
if h.config.TrafficLogger != nil {
ok := h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN))
if !ok {
// TrafficLogger requested to disconnect the client
_ = h.conn.CloseWithError(closeErrCodeTrafficLimitReached, "")
return
}
}
// Try no frag first
msg := protocol.UDPMessage{
SessionID: sessionID,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: rAddr,
Data: udpBuf[:udpN],
}
msgN := msg.Serialize(msgBuf)
if msgN < 0 {
// Message even larger than MaxUDPSize, drop it
continue
}
sendErr := h.conn.SendMessage(msgBuf[:msgN])
var errTooLarge quic.ErrMessageTooLarge
if errors.As(sendErr, &errTooLarge) {
// Message too large, try fragmentation
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
for _, fMsg := range fMsgs {
msgN = fMsg.Serialize(msgBuf)
_ = h.conn.SendMessage(msgBuf[:msgN])
}
}
}
}()
// Hold (drain) the stream until the client closes it.
// Closing the stream is the signal to stop the UDP session.
_, err = io.Copy(io.Discard, stream)
// Call event logger
if h.config.EventLogger != nil {
h.config.EventLogger.UDPError(h.conn.RemoteAddr(), h.authID, sessionID, err)
}
// Cleanup
connCloseFunc()
_ = stream.Close()
}
func (h *h3sHandler) udpLoop() {
for {
msg, err := h.conn.ReceiveMessage()
if err != nil {
return
}
ok := h.handleUDPMessage(msg)
if !ok {
// TrafficLogger requested to disconnect the client
_ = h.conn.CloseWithError(closeErrCodeTrafficLimitReached, "")
return
}
}
}
// client -> remote direction
// Returns a bool indicating whether the receiving loop should continue
func (h *h3sHandler) handleUDPMessage(msg []byte) (ok bool) {
udpMsg, err := protocol.ParseUDPMessage(msg)
if err != nil {
return true
}
if h.config.TrafficLogger != nil {
ok := h.config.TrafficLogger.Log(h.authID, uint64(len(udpMsg.Data)), 0)
if !ok {
return false
}
}
_, _ = h.udpSM.Feed(udpMsg)
return true
}
func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) {
if h.config.MasqHandler != nil {
h.config.MasqHandler.ServeHTTP(w, r)
@ -417,3 +230,74 @@ func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}
}
// udpsmIO is the IO implementation for udpSessionManager with TrafficLogger support
type udpsmIO struct {
Conn quic.Connection
AuthID string
TrafficLogger TrafficLogger
Outbound Outbound
}
func (io *udpsmIO) ReceiveMessage() (*protocol.UDPMessage, error) {
for {
msg, err := io.Conn.ReceiveMessage()
if err != nil {
// Connection error, this will stop the session manager
return nil, err
}
udpMsg, err := protocol.ParseUDPMessage(msg)
if err != nil {
// Invalid message, this is fine - just wait for the next
continue
}
if io.TrafficLogger != nil {
ok := io.TrafficLogger.Log(io.AuthID, uint64(len(udpMsg.Data)), 0)
if !ok {
// TrafficLogger requested to disconnect the client
_ = io.Conn.CloseWithError(closeErrCodeTrafficLimitReached, "")
return nil, errDisconnect
}
}
return udpMsg, nil
}
}
func (io *udpsmIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
if io.TrafficLogger != nil {
ok := io.TrafficLogger.Log(io.AuthID, 0, uint64(len(msg.Data)))
if !ok {
// TrafficLogger requested to disconnect the client
_ = io.Conn.CloseWithError(closeErrCodeTrafficLimitReached, "")
return errDisconnect
}
}
msgN := msg.Serialize(buf)
if msgN < 0 {
// Message larger than buffer, silent drop
return nil
}
return io.Conn.SendMessage(buf[:msgN])
}
func (io *udpsmIO) DialUDP(reqAddr string) (UDPConn, error) {
return io.Outbound.DialUDP(reqAddr)
}
type udpsmEventLogger struct {
Conn quic.Connection
AuthID string
EventLogger EventLogger
}
func (l *udpsmEventLogger) New(sessionID uint32, reqAddr string) {
if l.EventLogger != nil {
l.EventLogger.UDPRequest(l.Conn.RemoteAddr(), l.AuthID, sessionID, reqAddr)
}
}
func (l *udpsmEventLogger) Closed(sessionID uint32, err error) {
if l.EventLogger != nil {
l.EventLogger.UDPError(l.Conn.RemoteAddr(), l.AuthID, sessionID, err)
}
}

218
core/server/udp.go Normal file
View file

@ -0,0 +1,218 @@
package server
import (
"errors"
"math/rand"
"sync"
"time"
"github.com/quic-go/quic-go"
"github.com/apernet/hysteria/core/internal/frag"
"github.com/apernet/hysteria/core/internal/protocol"
"github.com/apernet/hysteria/core/internal/utils"
)
const (
idleCleanupInterval = 1 * time.Second
)
type udpSessionManagerIO interface {
ReceiveMessage() (*protocol.UDPMessage, error)
SendMessage([]byte, *protocol.UDPMessage) error
DialUDP(reqAddr string) (UDPConn, error)
}
type udpSessionManagerEventLogger interface {
New(sessionID uint32, reqAddr string)
Closed(sessionID uint32, err error)
}
type udpSessionEntry struct {
ID uint32
Conn UDPConn
D *frag.Defragger
Last *utils.AtomicTime
Closed bool
}
// Feed feeds a UDP message to the session.
// If the message itself is a complete message, or it completes a fragmented message,
// the message is written to the session's UDP connection, and the number of bytes
// written is returned.
// Otherwise, 0 and nil are returned.
func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) {
e.Last.Set(time.Now())
dfMsg := e.D.Feed(msg)
if dfMsg == nil {
return 0, nil
}
return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
}
// ReceiveLoop receives incoming UDP packets, packs them into UDP messages,
// and sends using the provided io.
// Exit and returns error when either the underlying UDP connection returns
// error (e.g. closed), or the provided io returns error when sending.
func (e *udpSessionEntry) ReceiveLoop(io udpSessionManagerIO) error {
udpBuf := make([]byte, protocol.MaxUDPSize)
msgBuf := make([]byte, protocol.MaxUDPSize)
for {
udpN, rAddr, err := e.Conn.ReadFrom(udpBuf)
if err != nil {
return err
}
e.Last.Set(time.Now())
msg := &protocol.UDPMessage{
SessionID: e.ID,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: rAddr,
Data: udpBuf[:udpN],
}
err = sendMessageAutoFrag(io, msgBuf, msg)
if err != nil {
return err
}
}
}
// sendMessageAutoFrag tries to send a UDP message as a whole first,
// but if it fails due to quic.ErrMessageTooLarge, it tries again by
// fragmenting the message.
func sendMessageAutoFrag(io udpSessionManagerIO, buf []byte, msg *protocol.UDPMessage) error {
err := io.SendMessage(buf, msg)
var errTooLarge quic.ErrMessageTooLarge
if errors.As(err, &errTooLarge) {
// Message too large, try fragmentation
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
for _, fMsg := range fMsgs {
err := io.SendMessage(buf, &fMsg)
if err != nil {
return err
}
}
return nil
} else {
return err
}
}
// udpSessionManager manages the lifecycle of UDP sessions.
// Each UDP session is identified by a SessionID, and corresponds to a UDP connection.
// A UDP session is created when a UDP message with a new SessionID is received.
// Similar to standard NAT, a UDP session is destroyed when no UDP message is received
// for a certain period of time (specified by idleTimeout).
type udpSessionManager struct {
io udpSessionManagerIO
eventLogger udpSessionManagerEventLogger
idleTimeout time.Duration
mutex sync.Mutex
m map[uint32]*udpSessionEntry
nextID uint32
}
func newUDPSessionManager(
io udpSessionManagerIO,
eventLogger udpSessionManagerEventLogger,
idleTimeout time.Duration,
) *udpSessionManager {
return &udpSessionManager{
io: io,
eventLogger: eventLogger,
idleTimeout: idleTimeout,
m: make(map[uint32]*udpSessionEntry),
}
}
// Run runs the session manager main loop.
// Exit and returns error when the underlying io returns error (e.g. closed).
func (m *udpSessionManager) Run() error {
stopCh := make(chan struct{})
go m.idleCleanupLoop(stopCh)
defer close(stopCh)
defer m.cleanup(false)
for {
msg, err := m.io.ReceiveMessage()
if err != nil {
return err
}
m.feed(msg)
}
}
func (m *udpSessionManager) idleCleanupLoop(stopCh <-chan struct{}) {
ticker := time.NewTicker(idleCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
m.cleanup(true)
case <-stopCh:
return
}
}
}
func (m *udpSessionManager) cleanup(idleOnly bool) {
m.mutex.Lock()
defer m.mutex.Unlock()
now := time.Now()
for sessionID, entry := range m.m {
if !idleOnly || now.Sub(entry.Last.Get()) > m.idleTimeout {
entry.Closed = true
_ = entry.Conn.Close()
m.eventLogger.Closed(sessionID, nil)
delete(m.m, sessionID)
}
}
}
func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
m.mutex.Lock()
entry := m.m[msg.SessionID]
if entry == nil {
// New session
m.eventLogger.New(msg.SessionID, msg.Addr)
conn, err := m.io.DialUDP(msg.Addr)
if err != nil {
m.mutex.Unlock()
m.eventLogger.Closed(msg.SessionID, err)
return
}
entry = &udpSessionEntry{
ID: msg.SessionID,
Conn: conn,
D: &frag.Defragger{},
Last: utils.NewAtomicTime(time.Now()),
}
// Start the receive loop for this session
go func() {
err := entry.ReceiveLoop(m.io)
// Receive loop stopped, remove the session
m.mutex.Lock()
if !entry.Closed {
entry.Closed = true
_ = entry.Conn.Close()
m.eventLogger.Closed(entry.ID, err)
delete(m.m, entry.ID)
}
m.mutex.Unlock()
}()
m.m[msg.SessionID] = entry
}
m.mutex.Unlock()
// Feed the message to the session
// Feed (send) errors are ignored for now,
// as some are temporary (e.g. invalid address)
_, _ = entry.Feed(msg)
}

191
core/server/udp_test.go Normal file
View file

@ -0,0 +1,191 @@
package server
import (
"errors"
"fmt"
"testing"
"time"
"github.com/apernet/hysteria/core/internal/protocol"
"go.uber.org/goleak"
)
type echoUDPConnPkt struct {
Data []byte
Addr string
Close bool
}
type echoUDPConn struct {
PktCh chan echoUDPConnPkt
}
func (c *echoUDPConn) ReadFrom(b []byte) (int, string, error) {
pkt := <-c.PktCh
if pkt.Close {
return 0, "", errors.New("closed")
}
n := copy(b, pkt.Data)
return n, pkt.Addr, nil
}
func (c *echoUDPConn) WriteTo(b []byte, addr string) (int, error) {
nb := make([]byte, len(b))
copy(nb, b)
c.PktCh <- echoUDPConnPkt{
Data: nb,
Addr: addr,
}
return len(b), nil
}
func (c *echoUDPConn) Close() error {
c.PktCh <- echoUDPConnPkt{
Close: true,
}
return nil
}
type udpsmMockIO struct {
ReceiveCh <-chan *protocol.UDPMessage
SendCh chan<- *protocol.UDPMessage
}
func (io *udpsmMockIO) ReceiveMessage() (*protocol.UDPMessage, error) {
m := <-io.ReceiveCh
if m == nil {
return nil, errors.New("closed")
}
return m, nil
}
func (io *udpsmMockIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
nMsg := *msg
nMsg.Data = make([]byte, len(msg.Data))
copy(nMsg.Data, msg.Data)
io.SendCh <- &nMsg
return nil
}
func (io *udpsmMockIO) DialUDP(reqAddr string) (UDPConn, error) {
return &echoUDPConn{
PktCh: make(chan echoUDPConnPkt, 10),
}, nil
}
type udpsmMockEventNew struct {
SessionID uint32
ReqAddr string
}
type udpsmMockEventClosed struct {
SessionID uint32
Err error
}
type udpsmMockEventLogger struct {
NewCh chan<- udpsmMockEventNew
ClosedCh chan<- udpsmMockEventClosed
}
func (l *udpsmMockEventLogger) New(sessionID uint32, reqAddr string) {
l.NewCh <- udpsmMockEventNew{sessionID, reqAddr}
}
func (l *udpsmMockEventLogger) Closed(sessionID uint32, err error) {
l.ClosedCh <- udpsmMockEventClosed{sessionID, err}
}
func TestUDPSessionManager(t *testing.T) {
msgReceiveCh := make(chan *protocol.UDPMessage, 10)
msgSendCh := make(chan *protocol.UDPMessage, 10)
io := &udpsmMockIO{
ReceiveCh: msgReceiveCh,
SendCh: msgSendCh,
}
eventNewCh := make(chan udpsmMockEventNew, 10)
eventClosedCh := make(chan udpsmMockEventClosed, 10)
eventLogger := &udpsmMockEventLogger{
NewCh: eventNewCh,
ClosedCh: eventClosedCh,
}
sm := newUDPSessionManager(io, eventLogger, 2*time.Second)
go sm.Run()
ms := []*protocol.UDPMessage{
{
SessionID: 1234,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: "example.com:5353",
Data: []byte("hello"),
},
{
SessionID: 5678,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: "example.com:9999",
Data: []byte("goodbye"),
},
{
SessionID: 1234,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: "example.com:5353",
Data: []byte(" world"),
},
{
SessionID: 5678,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: "example.com:9999",
Data: []byte(" girl"),
},
}
for _, m := range ms {
msgReceiveCh <- m
}
// New event order should be consistent
newEvent := <-eventNewCh
if newEvent.SessionID != 1234 || newEvent.ReqAddr != "example.com:5353" {
t.Error("unexpected new event value")
}
newEvent = <-eventNewCh
if newEvent.SessionID != 5678 || newEvent.ReqAddr != "example.com:9999" {
t.Error("unexpected new event value")
}
// Message order is not guaranteed
msgMap := make(map[string]bool)
for i := 0; i < 4; i++ {
msg := <-msgSendCh
msgMap[fmt.Sprintf("%d:%s:%s", msg.SessionID, msg.Addr, string(msg.Data))] = true
}
if !(msgMap["1234:example.com:5353:hello"] &&
msgMap["5678:example.com:9999:goodbye"] &&
msgMap["1234:example.com:5353: world"] &&
msgMap["5678:example.com:9999: girl"]) {
t.Error("unexpected message value")
}
// Timeout check
startTime := time.Now()
closedMap := make(map[uint32]bool)
for i := 0; i < 2; i++ {
closedEvent := <-eventClosedCh
closedMap[closedEvent.SessionID] = true
}
if !(closedMap[1234] && closedMap[5678]) {
t.Error("unexpected closed event value", closedMap)
}
if time.Since(startTime) < 2*time.Second || time.Since(startTime) > 4*time.Second {
t.Error("unexpected timeout duration")
}
// Goroutine leak check
msgReceiveCh <- nil
time.Sleep(1 * time.Second) // Wait for internal routines to exit
goleak.VerifyNone(t)
}

View file

@ -75,7 +75,7 @@ func (a *PluggableOutboundAdapter) DialTCP(reqAddr string) (net.Conn, error) {
})
}
func (a *PluggableOutboundAdapter) ListenUDP() (server.UDPConn, error) {
func (a *PluggableOutboundAdapter) DialUDP() (server.UDPConn, error) {
conn, err := a.PluggableOutbound.ListenUDP()
if err != nil {
return nil, err

View file

@ -68,10 +68,10 @@ func TestPluggableOutboundAdapter(t *testing.T) {
if err != errWrongAddr {
t.Fatal("DialTCP with wrong addr should fail, got", err)
}
// ListenUDP
uConn, err := adapter.ListenUDP()
// DialUDP
uConn, err := adapter.DialUDP()
if err != nil {
t.Fatal("ListenUDP failed", err)
t.Fatal("DialUDP failed", err)
}
// ReadFrom
b := make([]byte, 10)