mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-04 21:17:47 +03:00
feat(wip): udp rework client side
This commit is contained in:
parent
f142a24047
commit
cbedb27f0f
7 changed files with 391 additions and 391 deletions
|
@ -3,18 +3,13 @@ package client
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
coreErrs "github.com/apernet/hysteria/core/errors"
|
||||
"github.com/apernet/hysteria/core/internal/congestion"
|
||||
"github.com/apernet/hysteria/core/internal/frag"
|
||||
"github.com/apernet/hysteria/core/internal/protocol"
|
||||
"github.com/apernet/hysteria/core/internal/utils"
|
||||
|
||||
|
@ -23,8 +18,6 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
udpMessageChanSize = 1024
|
||||
|
||||
closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError
|
||||
closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError
|
||||
)
|
||||
|
@ -48,94 +41,25 @@ func NewClient(config *Config) (Client, error) {
|
|||
c := &clientImpl{
|
||||
config: config,
|
||||
}
|
||||
c.conn = &autoReconnectConn{
|
||||
Connect: c.connect,
|
||||
if err := c.connect(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
type clientImpl struct {
|
||||
config *Config
|
||||
conn *autoReconnectConn
|
||||
|
||||
udpSM udpSessionManager
|
||||
pktConn net.PacketConn
|
||||
conn quic.Connection
|
||||
|
||||
udpSM *udpSessionManager
|
||||
}
|
||||
|
||||
type udpSessionEntry struct {
|
||||
Ch chan *protocol.UDPMessage
|
||||
D *frag.Defragger
|
||||
Closed bool
|
||||
}
|
||||
|
||||
type udpSessionManager struct {
|
||||
mutex sync.RWMutex
|
||||
m map[uint32]*udpSessionEntry
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) Init() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
m.m = make(map[uint32]*udpSessionEntry)
|
||||
}
|
||||
|
||||
// Add returns both a channel for receiving messages and a function to close the channel & delete the session.
|
||||
func (m *udpSessionManager) Add(id uint32) (<-chan *protocol.UDPMessage, func()) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// Important: make sure we add and delete the channel in the same map,
|
||||
// as the map may be replaced by Init() at any time.
|
||||
currentM := m.m
|
||||
|
||||
entry := &udpSessionEntry{
|
||||
Ch: make(chan *protocol.UDPMessage, udpMessageChanSize),
|
||||
D: &frag.Defragger{},
|
||||
Closed: false,
|
||||
}
|
||||
currentM[id] = entry
|
||||
|
||||
return entry.Ch, func() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
if entry.Closed {
|
||||
// Double close a channel will panic,
|
||||
// so we need a flag to make sure we only close it once.
|
||||
return
|
||||
}
|
||||
entry.Closed = true
|
||||
close(entry.Ch)
|
||||
delete(currentM, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) Feed(msg *protocol.UDPMessage) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
entry, ok := m.m[msg.SessionID]
|
||||
if !ok {
|
||||
// No such session, drop the message
|
||||
return
|
||||
}
|
||||
dfMsg := entry.D.Feed(msg)
|
||||
if dfMsg == nil {
|
||||
// Not a complete message yet
|
||||
return
|
||||
}
|
||||
select {
|
||||
case entry.Ch <- dfMsg:
|
||||
// OK
|
||||
default:
|
||||
// Channel is full, drop the message
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientImpl) connect() (quic.Connection, func(), error) {
|
||||
// Use a new packet conn for each connection,
|
||||
// remember to close it after the QUIC connection is closed.
|
||||
func (c *clientImpl) connect() error {
|
||||
pktConn, err := c.config.ConnFactory.New(c.config.ServerAddr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
// Convert config to TLS config & QUIC config
|
||||
tlsConfig := &tls.Config{
|
||||
|
@ -185,15 +109,15 @@ func (c *clientImpl) connect() (quic.Connection, func(), error) {
|
|||
_ = conn.CloseWithError(closeErrCodeProtocolError, "")
|
||||
}
|
||||
_ = pktConn.Close()
|
||||
return nil, nil, &coreErrs.ConnectError{Err: err}
|
||||
return &coreErrs.ConnectError{Err: err}
|
||||
}
|
||||
if resp.StatusCode != protocol.StatusAuthOK {
|
||||
_ = conn.CloseWithError(closeErrCodeProtocolError, "")
|
||||
_ = pktConn.Close()
|
||||
return nil, nil, &coreErrs.AuthError{StatusCode: resp.StatusCode}
|
||||
return &coreErrs.AuthError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
// Auth OK
|
||||
serverRx := protocol.AuthResponseDataFromHeader(resp.Header)
|
||||
udpEnabled, serverRx := protocol.AuthResponseDataFromHeader(resp.Header)
|
||||
// actualTx = min(serverRx, clientTx)
|
||||
actualTx := serverRx
|
||||
if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx {
|
||||
|
@ -205,46 +129,20 @@ func (c *clientImpl) connect() (quic.Connection, func(), error) {
|
|||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
c.udpSM.Init()
|
||||
go c.udpLoop(conn)
|
||||
|
||||
return conn, func() {
|
||||
_ = conn.CloseWithError(closeErrCodeOK, "")
|
||||
_ = pktConn.Close()
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *clientImpl) udpLoop(conn quic.Connection) {
|
||||
for {
|
||||
msg, err := conn.ReceiveMessage()
|
||||
if err != nil {
|
||||
return
|
||||
c.pktConn = pktConn
|
||||
c.conn = conn
|
||||
if udpEnabled {
|
||||
c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn})
|
||||
go func() {
|
||||
c.udpSM.Run()
|
||||
// TODO: Mark connection as closed
|
||||
}()
|
||||
}
|
||||
c.handleUDPMessage(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// client <- remote direction
|
||||
func (c *clientImpl) handleUDPMessage(msg []byte) {
|
||||
udpMsg, err := protocol.ParseUDPMessage(msg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.udpSM.Feed(udpMsg)
|
||||
}
|
||||
|
||||
// openStream wraps the stream with QStream, which handles Close() properly
|
||||
func (c *clientImpl) openStream() (quic.Connection, quic.Stream, error) {
|
||||
qc, stream, err := c.conn.OpenStream()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return qc, &utils.QStream{Stream: stream}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientImpl) DialTCP(addr string) (net.Conn, error) {
|
||||
qc, stream, err := c.openStream()
|
||||
stream, err := c.openStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -260,8 +158,8 @@ func (c *clientImpl) DialTCP(addr string) (net.Conn, error) {
|
|||
// to the first Read() call.
|
||||
return &tcpConn{
|
||||
Orig: stream,
|
||||
PseudoLocalAddr: qc.LocalAddr(),
|
||||
PseudoRemoteAddr: qc.RemoteAddr(),
|
||||
PseudoLocalAddr: c.conn.LocalAddr(),
|
||||
PseudoRemoteAddr: c.conn.RemoteAddr(),
|
||||
Established: false,
|
||||
}, nil
|
||||
}
|
||||
|
@ -277,49 +175,23 @@ func (c *clientImpl) DialTCP(addr string) (net.Conn, error) {
|
|||
}
|
||||
return &tcpConn{
|
||||
Orig: stream,
|
||||
PseudoLocalAddr: qc.LocalAddr(),
|
||||
PseudoRemoteAddr: qc.RemoteAddr(),
|
||||
PseudoLocalAddr: c.conn.LocalAddr(),
|
||||
PseudoRemoteAddr: c.conn.RemoteAddr(),
|
||||
Established: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *clientImpl) ListenUDP() (HyUDPConn, error) {
|
||||
qc, stream, err := c.openStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if c.udpSM == nil {
|
||||
return nil, coreErrs.DialError{Message: "UDP not enabled"}
|
||||
}
|
||||
// Send request
|
||||
err = protocol.WriteUDPRequest(stream)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
// Read response
|
||||
ok, sessionID, msg, err := protocol.ReadUDPResponse(stream)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
_ = stream.Close()
|
||||
return nil, coreErrs.DialError{Message: msg}
|
||||
}
|
||||
|
||||
ch, closeFunc := c.udpSM.Add(sessionID)
|
||||
uc := &udpConn{
|
||||
QC: qc,
|
||||
Stream: stream,
|
||||
SessionID: sessionID,
|
||||
Ch: ch,
|
||||
CloseFunc: closeFunc,
|
||||
SendBuf: make([]byte, protocol.MaxUDPSize),
|
||||
}
|
||||
go uc.Hold()
|
||||
return uc, nil
|
||||
return c.udpSM.NewUDP()
|
||||
}
|
||||
|
||||
func (c *clientImpl) Close() error {
|
||||
return c.conn.Close()
|
||||
_ = c.conn.CloseWithError(closeErrCodeOK, "")
|
||||
_ = c.pktConn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
type tcpConn struct {
|
||||
|
@ -372,72 +244,40 @@ func (c *tcpConn) SetWriteDeadline(t time.Time) error {
|
|||
return c.Orig.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
type udpConn struct {
|
||||
QC quic.Connection
|
||||
Stream quic.Stream
|
||||
SessionID uint32
|
||||
Ch <-chan *protocol.UDPMessage
|
||||
CloseFunc func()
|
||||
SendBuf []byte
|
||||
type udpIOImpl struct {
|
||||
Conn quic.Connection
|
||||
}
|
||||
|
||||
func (c *udpConn) Hold() {
|
||||
// Hold (drain) the stream until someone closes it.
|
||||
// Closing the stream is the signal to stop the UDP session.
|
||||
_, _ = io.Copy(io.Discard, c.Stream)
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
func (c *udpConn) Receive() ([]byte, string, error) {
|
||||
msg := <-c.Ch
|
||||
if msg == nil {
|
||||
// Closed
|
||||
return nil, "", io.EOF
|
||||
}
|
||||
return msg.Data, msg.Addr, nil
|
||||
}
|
||||
|
||||
// Send is not thread-safe as it uses a shared send buffer for now.
|
||||
func (c *udpConn) Send(data []byte, addr string) error {
|
||||
// Try no frag first
|
||||
msg := &protocol.UDPMessage{
|
||||
SessionID: c.SessionID,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
Addr: addr,
|
||||
Data: data,
|
||||
}
|
||||
n := msg.Serialize(c.SendBuf)
|
||||
if n < 0 {
|
||||
// Message even larger than MaxUDPSize, drop it
|
||||
// Maybe we should return an error in the future?
|
||||
return nil
|
||||
}
|
||||
sendErr := c.QC.SendMessage(c.SendBuf[:n])
|
||||
if sendErr == nil {
|
||||
// All good
|
||||
return nil
|
||||
}
|
||||
var errTooLarge quic.ErrMessageTooLarge
|
||||
if errors.As(sendErr, &errTooLarge) {
|
||||
// Message too large, try fragmentation
|
||||
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
|
||||
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
|
||||
for _, fMsg := range fMsgs {
|
||||
n = fMsg.Serialize(c.SendBuf)
|
||||
err := c.QC.SendMessage(c.SendBuf[:n])
|
||||
func (io *udpIOImpl) ReceiveMessage() (*protocol.UDPMessage, error) {
|
||||
for {
|
||||
msg, err := io.Conn.ReceiveMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
// Connection error, this will stop the session manager
|
||||
return nil, err
|
||||
}
|
||||
udpMsg, err := protocol.ParseUDPMessage(msg)
|
||||
if err != nil {
|
||||
// Invalid message, this is fine - just wait for the next
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
return udpMsg, nil
|
||||
}
|
||||
// Other error
|
||||
return sendErr
|
||||
}
|
||||
|
||||
func (c *udpConn) Close() error {
|
||||
c.CloseFunc()
|
||||
return c.Stream.Close()
|
||||
func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
|
||||
msgN := msg.Serialize(buf)
|
||||
if msgN < 0 {
|
||||
// Message larger than buffer, silent drop
|
||||
return nil
|
||||
}
|
||||
return io.Conn.SendMessage(buf[:msgN])
|
||||
}
|
||||
|
||||
// openStream wraps the stream with QStream, which handles Close() properly
|
||||
func (c *clientImpl) openStream() (quic.Stream, error) {
|
||||
stream, err := c.conn.OpenStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &utils.QStream{Stream: stream}, nil
|
||||
}
|
||||
|
|
|
@ -1,68 +0,0 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
// autoReconnectConn is a wrapper of quic.Connection that automatically reconnects
|
||||
// when a non-temporary error (usually a timeout) occurs.
|
||||
type autoReconnectConn struct {
|
||||
// Connect is called whenever a new QUIC connection is needed.
|
||||
// It should return a new QUIC connection, a function to close the connection
|
||||
// (and potentially other underlying resources), and an error if one occurred.
|
||||
Connect func() (quic.Connection, func(), error)
|
||||
|
||||
conn quic.Connection
|
||||
closeFunc func()
|
||||
connMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *autoReconnectConn) OpenStream() (quic.Connection, quic.Stream, error) {
|
||||
c.connMutex.Lock()
|
||||
defer c.connMutex.Unlock()
|
||||
// First time?
|
||||
if c.conn == nil {
|
||||
conn, closeFunc, err := c.Connect()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
c.conn = conn
|
||||
c.closeFunc = closeFunc
|
||||
}
|
||||
stream, err := c.conn.OpenStream()
|
||||
if err == nil {
|
||||
// All is good
|
||||
return c.conn, stream, nil
|
||||
} else if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
|
||||
// Temporary error, just pass the error to the caller
|
||||
return nil, nil, err
|
||||
} else {
|
||||
// Permanent error
|
||||
// Close the previous connection,
|
||||
// reconnect and try again (only once)
|
||||
c.closeFunc()
|
||||
conn, closeFunc, err := c.Connect()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
c.conn = conn
|
||||
c.closeFunc = closeFunc
|
||||
stream, err = c.conn.OpenStream()
|
||||
return c.conn, stream, err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *autoReconnectConn) Close() error {
|
||||
c.connMutex.Lock()
|
||||
defer c.connMutex.Unlock()
|
||||
if c.conn == nil {
|
||||
return nil
|
||||
}
|
||||
c.closeFunc()
|
||||
c.conn = nil
|
||||
c.closeFunc = nil
|
||||
return nil
|
||||
}
|
177
core/client/udp.go
Normal file
177
core/client/udp.go
Normal file
|
@ -0,0 +1,177 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"sync"
|
||||
|
||||
"github.com/apernet/hysteria/core/internal/frag"
|
||||
"github.com/apernet/hysteria/core/internal/protocol"
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
const (
|
||||
udpMessageChanSize = 1024
|
||||
)
|
||||
|
||||
type udpIO interface {
|
||||
ReceiveMessage() (*protocol.UDPMessage, error)
|
||||
SendMessage([]byte, *protocol.UDPMessage) error
|
||||
}
|
||||
|
||||
type udpConn struct {
|
||||
ID uint32
|
||||
D *frag.Defragger
|
||||
ReceiveCh chan *protocol.UDPMessage
|
||||
SendBuf []byte
|
||||
SendFunc func([]byte, *protocol.UDPMessage) error
|
||||
CloseFunc func()
|
||||
Closed bool
|
||||
}
|
||||
|
||||
func (u *udpConn) Receive() ([]byte, string, error) {
|
||||
for {
|
||||
msg := <-u.ReceiveCh
|
||||
if msg == nil {
|
||||
// Closed
|
||||
return nil, "", io.EOF
|
||||
}
|
||||
dfMsg := u.D.Feed(msg)
|
||||
if dfMsg == nil {
|
||||
// Incomplete message, wait for more
|
||||
continue
|
||||
}
|
||||
return dfMsg.Data, dfMsg.Addr, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Send is not thread-safe, as it uses a shared SendBuf.
|
||||
func (u *udpConn) Send(data []byte, addr string) error {
|
||||
// Try no frag first
|
||||
msg := &protocol.UDPMessage{
|
||||
SessionID: u.ID,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
Addr: addr,
|
||||
Data: data,
|
||||
}
|
||||
err := u.SendFunc(u.SendBuf, msg)
|
||||
var errTooLarge quic.ErrMessageTooLarge
|
||||
if errors.As(err, &errTooLarge) {
|
||||
// Message too large, try fragmentation
|
||||
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
|
||||
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
|
||||
for _, fMsg := range fMsgs {
|
||||
err := u.SendFunc(u.SendBuf, &fMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (u *udpConn) Close() error {
|
||||
u.CloseFunc()
|
||||
return nil
|
||||
}
|
||||
|
||||
type udpSessionManager struct {
|
||||
io udpIO
|
||||
|
||||
mutex sync.Mutex
|
||||
m map[uint32]*udpConn
|
||||
nextID uint32
|
||||
}
|
||||
|
||||
func newUDPSessionManager(io udpIO) *udpSessionManager {
|
||||
return &udpSessionManager{
|
||||
io: io,
|
||||
m: make(map[uint32]*udpConn),
|
||||
nextID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Run runs the session manager main loop.
|
||||
// Exit and returns error when the underlying io returns error (e.g. closed).
|
||||
func (m *udpSessionManager) Run() error {
|
||||
defer m.cleanup()
|
||||
|
||||
for {
|
||||
msg, err := m.io.ReceiveMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.feed(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) cleanup() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
for _, conn := range m.m {
|
||||
m.close(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
conn, ok := m.m[msg.SessionID]
|
||||
if !ok {
|
||||
// Ignore message from unknown session
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case conn.ReceiveCh <- msg:
|
||||
// OK
|
||||
default:
|
||||
// Channel full, drop the message
|
||||
}
|
||||
}
|
||||
|
||||
// NewUDP creates a new UDP session.
|
||||
func (m *udpSessionManager) NewUDP() (HyUDPConn, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
id := m.nextID
|
||||
m.nextID++
|
||||
|
||||
conn := &udpConn{
|
||||
ID: id,
|
||||
D: &frag.Defragger{},
|
||||
ReceiveCh: make(chan *protocol.UDPMessage, udpMessageChanSize),
|
||||
SendBuf: make([]byte, protocol.MaxUDPSize),
|
||||
SendFunc: m.io.SendMessage,
|
||||
}
|
||||
conn.CloseFunc = func() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
if !conn.Closed {
|
||||
m.close(conn)
|
||||
}
|
||||
}
|
||||
m.m[id] = conn
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) close(conn *udpConn) {
|
||||
conn.Closed = true
|
||||
close(conn.ReceiveCh)
|
||||
delete(m.m, conn.ID)
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) Count() int {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
return len(m.m)
|
||||
}
|
|
@ -106,25 +106,4 @@ func TestServerMasquerade(t *testing.T) {
|
|||
if nErr, ok := err.(net.Error); !ok || !nErr.Timeout() {
|
||||
t.Fatal("expected timeout, got", err)
|
||||
}
|
||||
|
||||
// Try UDP request
|
||||
udpStream, err := conn.OpenStream()
|
||||
if err != nil {
|
||||
t.Fatal("error opening stream:", err)
|
||||
}
|
||||
defer udpStream.Close()
|
||||
err = protocol.WriteUDPRequest(udpStream)
|
||||
if err != nil {
|
||||
t.Fatal("error sending request:", err)
|
||||
}
|
||||
|
||||
// We should receive nothing
|
||||
_ = udpStream.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, err = udpStream.Read(buf)
|
||||
if n != 0 {
|
||||
t.Fatal("expected no response, got", n)
|
||||
}
|
||||
if nErr, ok := err.(net.Error); !ok || !nErr.Timeout() {
|
||||
t.Fatal("expected timeout, got", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -287,7 +287,7 @@ func (l *channelEventLogger) TCPError(addr net.Addr, id, reqAddr string, err err
|
|||
}
|
||||
}
|
||||
|
||||
func (l *channelEventLogger) UDPRequest(addr net.Addr, id string, sessionID uint32) {
|
||||
func (l *channelEventLogger) UDPRequest(addr net.Addr, id string, sessionID uint32, reqAddr string) {
|
||||
if l.UDPRequestEventCh != nil {
|
||||
l.UDPRequestEventCh <- udpRequestEvent{
|
||||
Addr: addr,
|
||||
|
|
|
@ -113,7 +113,6 @@ type udpSessionManager struct {
|
|||
|
||||
mutex sync.Mutex
|
||||
m map[uint32]*udpSessionEntry
|
||||
nextID uint32
|
||||
}
|
||||
|
||||
func newUDPSessionManager(io udpIO, eventLogger udpEventLogger, idleTimeout time.Duration) *udpSessionManager {
|
||||
|
@ -212,3 +211,9 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
|
|||
// as some are temporary (e.g. invalid address)
|
||||
_, _ = entry.Feed(msg)
|
||||
}
|
||||
|
||||
func (m *udpSessionManager) Count() int {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
return len(m.m)
|
||||
}
|
||||
|
|
|
@ -10,6 +10,11 @@ import (
|
|||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
var (
|
||||
errUDPBlocked = errors.New("blocked")
|
||||
errUDPClosed = errors.New("closed")
|
||||
)
|
||||
|
||||
type echoUDPConnPkt struct {
|
||||
Data []byte
|
||||
Addr string
|
||||
|
@ -23,7 +28,7 @@ type echoUDPConn struct {
|
|||
func (c *echoUDPConn) ReadFrom(b []byte) (int, string, error) {
|
||||
pkt := <-c.PktCh
|
||||
if pkt.Close {
|
||||
return 0, "", errors.New("closed")
|
||||
return 0, "", errUDPClosed
|
||||
}
|
||||
n := copy(b, pkt.Data)
|
||||
return n, pkt.Addr, nil
|
||||
|
@ -49,12 +54,14 @@ func (c *echoUDPConn) Close() error {
|
|||
type udpMockIO struct {
|
||||
ReceiveCh <-chan *protocol.UDPMessage
|
||||
SendCh chan<- *protocol.UDPMessage
|
||||
UDPClose bool // ReadFrom() returns error immediately
|
||||
BlockUDP bool // Block UDP connection creation
|
||||
}
|
||||
|
||||
func (io *udpMockIO) ReceiveMessage() (*protocol.UDPMessage, error) {
|
||||
m := <-io.ReceiveCh
|
||||
if m == nil {
|
||||
return nil, errors.New("closed")
|
||||
return nil, errUDPClosed
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
@ -68,9 +75,18 @@ func (io *udpMockIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
|
|||
}
|
||||
|
||||
func (io *udpMockIO) UDP(reqAddr string) (UDPConn, error) {
|
||||
return &echoUDPConn{
|
||||
if io.BlockUDP {
|
||||
return nil, errUDPBlocked
|
||||
}
|
||||
conn := &echoUDPConn{
|
||||
PktCh: make(chan echoUDPConnPkt, 10),
|
||||
}, nil
|
||||
}
|
||||
if io.UDPClose {
|
||||
conn.PktCh <- echoUDPConnPkt{
|
||||
Close: true,
|
||||
}
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
type udpMockEventNew struct {
|
||||
|
@ -112,6 +128,7 @@ func TestUDPSessionManager(t *testing.T) {
|
|||
sm := newUDPSessionManager(io, eventLogger, 2*time.Second)
|
||||
go sm.Run()
|
||||
|
||||
t.Run("session creation & timeout", func(t *testing.T) {
|
||||
ms := []*protocol.UDPMessage{
|
||||
{
|
||||
SessionID: 1234,
|
||||
|
@ -183,9 +200,59 @@ func TestUDPSessionManager(t *testing.T) {
|
|||
if time.Since(startTime) < 2*time.Second || time.Since(startTime) > 4*time.Second {
|
||||
t.Error("unexpected timeout duration")
|
||||
}
|
||||
})
|
||||
|
||||
// Goroutine leak check
|
||||
t.Run("UDP connection close", func(t *testing.T) {
|
||||
// Close UDP connection immediately after creation
|
||||
io.UDPClose = true
|
||||
|
||||
msgReceiveCh <- &protocol.UDPMessage{
|
||||
SessionID: 8888,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
Addr: "mygod.org:1514",
|
||||
Data: []byte("goodnight"),
|
||||
}
|
||||
// Should have both new and close events immediately
|
||||
newEvent := <-eventNewCh
|
||||
if newEvent.SessionID != 8888 || newEvent.ReqAddr != "mygod.org:1514" {
|
||||
t.Error("unexpected new event value")
|
||||
}
|
||||
closeEvent := <-eventCloseCh
|
||||
if closeEvent.SessionID != 8888 || closeEvent.Err != errUDPClosed {
|
||||
t.Error("unexpected close event value")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UDP IO failure", func(t *testing.T) {
|
||||
// Block UDP connection creation
|
||||
io.BlockUDP = true
|
||||
|
||||
msgReceiveCh <- &protocol.UDPMessage{
|
||||
SessionID: 9999,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
Addr: "xxx.net:12450",
|
||||
Data: []byte("nope"),
|
||||
}
|
||||
// Should have both new and close events immediately
|
||||
newEvent := <-eventNewCh
|
||||
if newEvent.SessionID != 9999 || newEvent.ReqAddr != "xxx.net:12450" {
|
||||
t.Error("unexpected new event value")
|
||||
}
|
||||
closeEvent := <-eventCloseCh
|
||||
if closeEvent.SessionID != 9999 || closeEvent.Err != errUDPBlocked {
|
||||
t.Error("unexpected close event value")
|
||||
}
|
||||
})
|
||||
|
||||
// Leak checks
|
||||
msgReceiveCh <- nil
|
||||
time.Sleep(1 * time.Second) // Wait for internal routines to exit
|
||||
time.Sleep(1 * time.Second) // Give some time for the goroutines to exit
|
||||
if sm.Count() != 0 {
|
||||
t.Error("session count should be 0")
|
||||
}
|
||||
goleak.VerifyNone(t)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue