hysteria/core/client/udp.go

185 lines
3.3 KiB
Go

package client
import (
"errors"
"io"
"math/rand"
"sync"
"github.com/apernet/quic-go"
coreErrs "github.com/apernet/hysteria/core/v2/errors"
"github.com/apernet/hysteria/core/v2/internal/frag"
"github.com/apernet/hysteria/core/v2/internal/protocol"
)
// udpMessageChanSize is the buffer capacity given to each UDP session's
// receive channel when the session is created.
const udpMessageChanSize = 1024
// udpIO is the message transport the UDP session manager runs on top of.
type udpIO interface {
	// ReceiveMessage returns the next incoming UDP message. The session
	// manager calls it in a loop and stops at the first error.
	ReceiveMessage() (msg *protocol.UDPMessage, err error)

	// SendMessage sends msg. Sessions pass their own SendBuf as buf.
	// NOTE(review): buf is presumably scratch space for encoding the
	// message; confirm against the implementation.
	SendMessage(buf []byte, msg *protocol.UDPMessage) error
}
// udpConn is a single client-side UDP session multiplexed over a shared
// udpIO by udpSessionManager.
type udpConn struct {
	// ID is the session ID: it is stamped on every outgoing message and is
	// the key under which the manager tracks this session.
	ID uint32

	// D reassembles incoming messages; Receive feeds every message it
	// reads through it.
	D *frag.Defragger

	// ReceiveCh carries incoming messages queued by the manager. It is
	// closed when the session is closed.
	ReceiveCh chan *protocol.UDPMessage

	// SendBuf is handed to SendFunc on every send. It is shared across
	// calls, which is why Send is not thread-safe.
	SendBuf []byte

	// SendFunc transmits one message; Send calls it once per message or
	// fragment.
	SendFunc func(buf []byte, msg *protocol.UDPMessage) error

	// CloseFunc is invoked by Close.
	CloseFunc func()

	// Closed records that the session has been torn down, so that
	// ReceiveCh is closed only once.
	Closed bool
}
// Receive blocks until a complete UDP message is available and returns its
// Data and Addr fields.
//
// Messages are read from ReceiveCh and passed through the defragger; as long
// as Feed returns nil (the message is still incomplete) the next message is
// read. A nil message from the channel, which is what a closed ReceiveCh
// yields, ends the loop and Receive returns io.EOF.
func (u *udpConn) Receive() ([]byte, string, error) {
	for part := <-u.ReceiveCh; part != nil; part = <-u.ReceiveCh {
		if whole := u.D.Feed(part); whole != nil {
			return whole.Data, whole.Addr, nil
		}
		// Incomplete message, wait for more.
	}
	// Closed.
	return nil, "", io.EOF
}
// Send transmits data to addr, fragmenting only when it has to.
//
// The first attempt sends a single unfragmented message (PacketID 0,
// FragCount 1). If SendFunc rejects it with a *quic.DatagramTooLargeError,
// the message is given a random PacketID in [1, 0xFFFF], split by
// frag.FragUDPMessage using the size limit carried in the error, and the
// fragments are sent one by one; the first fragment error aborts the send and
// is returned. Any other result of the first attempt, including nil, is
// returned unchanged.
//
// Send is not thread-safe, as it uses a shared SendBuf.
func (u *udpConn) Send(data []byte, addr string) error {
	whole := &protocol.UDPMessage{
		SessionID: u.ID,
		PacketID:  0,
		FragID:    0,
		FragCount: 1,
		Addr:      addr,
		Data:      data,
	}
	firstErr := u.SendFunc(u.SendBuf, whole)

	var tooLarge *quic.DatagramTooLargeError
	if !errors.As(firstErr, &tooLarge) {
		// Either the send succeeded or it failed for a reason that
		// fragmentation cannot fix.
		return firstErr
	}

	// Message too large: retry as fragments. The +1 keeps the packet ID
	// non-zero, so it never equals the 0 used by the unfragmented attempt.
	whole.PacketID = uint16(rand.Intn(0xFFFF)) + 1
	for _, part := range frag.FragUDPMessage(whole, int(tooLarge.MaxDataLen)) {
		if sendErr := u.SendFunc(u.SendBuf, &part); sendErr != nil {
			return sendErr
		}
	}
	return nil
}
// Close shuts the session down by invoking CloseFunc. It always returns nil.
func (u *udpConn) Close() error {
	u.CloseFunc()
	return nil
}
// udpSessionManager multiplexes UDP sessions over a single udpIO: it hands
// out session IDs, tracks the live sessions, and routes incoming messages to
// them by session ID.
type udpSessionManager struct {
// io is the transport messages are received from and sent through.
io udpIO
// mutex is held whenever m, nextID or closed is accessed.
mutex sync.RWMutex
// m holds the live sessions, keyed by session ID.
m map[uint32]*udpConn
// nextID is the ID the next new session will get; it starts at 1.
nextID uint32
// closed is set once the receive loop has exited; NewUDP then refuses to
// create sessions.
closed bool
}
// newUDPSessionManager creates a session manager on top of io and starts its
// receive loop in a new goroutine.
//
// Session IDs start at 1. The error that run eventually returns is discarded
// here; when the loop ends, its cleanup marks the manager as closed.
func newUDPSessionManager(io udpIO) *udpSessionManager {
	mgr := &udpSessionManager{
		io:     io,
		m:      make(map[uint32]*udpConn),
		nextID: 1,
	}
	go mgr.run()
	return mgr
}
// run is the receive loop: it reads messages from the transport and hands each
// one to feed until ReceiveMessage fails, then returns that error.
//
// On the way out, closeCleanup closes every remaining session and marks the
// manager as closed.
func (m *udpSessionManager) run() error {
	defer m.closeCleanup()

	msg, err := m.io.ReceiveMessage()
	for ; err == nil; msg, err = m.io.ReceiveMessage() {
		m.feed(msg)
	}
	return err
}
// closeCleanup closes every session still registered and then marks the
// manager as closed, all under the write lock. It runs when the receive loop
// exits.
func (m *udpSessionManager) closeCleanup() {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	// close removes the session from m; deleting entries while ranging
	// over a map is permitted in Go.
	for _, session := range m.m {
		m.close(session)
	}
	m.closed = true
}
// feed routes an incoming message to the session named by msg.SessionID.
//
// Messages for unknown sessions are ignored. Delivery never blocks: if the
// session's receive channel is full, the message is dropped. The read lock is
// held for the whole call.
func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	if session, known := m.m[msg.SessionID]; known {
		select {
		case session.ReceiveCh <- msg:
			// Queued for the session's Receive.
		default:
			// Channel full, drop the message.
		}
	}
	// Otherwise: message from an unknown session, ignore it.
}
// NewUDP creates a new UDP session and registers it with the manager.
//
// The session takes the manager's next session ID and gets a fresh
// defragger, a receive channel buffered to udpMessageChanSize, a send buffer
// of protocol.MaxUDPSize bytes, and the transport's SendMessage as its send
// function. Closing the session removes it from the manager.
//
// It returns coreErrs.ClosedError if the manager has already been closed.
func (m *udpSessionManager) NewUDP() (HyUDPConn, error) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	if m.closed {
		return nil, coreErrs.ClosedError{}
	}

	session := &udpConn{
		ID:        m.nextID,
		D:         &frag.Defragger{},
		ReceiveCh: make(chan *protocol.UDPMessage, udpMessageChanSize),
		SendBuf:   make([]byte, protocol.MaxUDPSize),
		SendFunc:  m.io.SendMessage,
	}
	m.nextID++

	// The closure takes the lock itself because it runs later, from the
	// session's Close, not from within this call.
	session.CloseFunc = func() {
		m.mutex.Lock()
		defer m.mutex.Unlock()
		m.close(session)
	}

	m.m[session.ID] = session
	return session, nil
}
// close tears down a single session: it marks it closed, closes its receive
// channel and removes it from the session map. Calling it again for the same
// session is a no-op.
//
// close does no locking of its own; both call sites in this file hold
// m.mutex for writing.
func (m *udpSessionManager) close(conn *udpConn) {
	if conn.Closed {
		return
	}
	conn.Closed = true
	close(conn.ReceiveCh)
	delete(m.m, conn.ID)
}
// Count returns the number of sessions currently registered with the manager.
func (m *udpSessionManager) Count() int {
	m.mutex.RLock()
	n := len(m.m)
	m.mutex.RUnlock()
	return n
}