hysteria/core/server/udp_test.go

261 lines
6.1 KiB
Go

package server
import (
"errors"
"fmt"
"testing"
"time"
"github.com/apernet/hysteria/core/internal/protocol"
"go.uber.org/goleak"
)
var (
errUDPBlocked = errors.New("blocked")
errUDPClosed = errors.New("closed")
)
type echoUDPConnPkt struct {
Data []byte
Addr string
Close bool
}
type echoUDPConn struct {
PktCh chan echoUDPConnPkt
}
func (c *echoUDPConn) ReadFrom(b []byte) (int, string, error) {
pkt := <-c.PktCh
if pkt.Close {
return 0, "", errUDPClosed
}
n := copy(b, pkt.Data)
return n, pkt.Addr, nil
}
func (c *echoUDPConn) WriteTo(b []byte, addr string) (int, error) {
nb := make([]byte, len(b))
copy(nb, b)
c.PktCh <- echoUDPConnPkt{
Data: nb,
Addr: addr,
}
return len(b), nil
}
func (c *echoUDPConn) Close() error {
c.PktCh <- echoUDPConnPkt{
Close: true,
}
return nil
}
type udpMockIO struct {
ReceiveCh <-chan *protocol.UDPMessage
SendCh chan<- *protocol.UDPMessage
UDPClose bool // ReadFrom() returns error immediately
BlockUDP bool // Block UDP connection creation
}
func (io *udpMockIO) ReceiveMessage() (*protocol.UDPMessage, error) {
m := <-io.ReceiveCh
if m == nil {
return nil, errUDPClosed
}
return m, nil
}
func (io *udpMockIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
nMsg := *msg
nMsg.Data = make([]byte, len(msg.Data))
copy(nMsg.Data, msg.Data)
io.SendCh <- &nMsg
return nil
}
func (io *udpMockIO) UDP(reqAddr string) (UDPConn, error) {
if io.BlockUDP {
return nil, errUDPBlocked
}
conn := &echoUDPConn{
PktCh: make(chan echoUDPConnPkt, 10),
}
if io.UDPClose {
conn.PktCh <- echoUDPConnPkt{
Close: true,
}
}
return conn, nil
}
type udpMockEventNew struct {
SessionID uint32
ReqAddr string
}
type udpMockEventClose struct {
SessionID uint32
Err error
}
type udpMockEventLogger struct {
NewCh chan<- udpMockEventNew
CloseCh chan<- udpMockEventClose
}
func (l *udpMockEventLogger) New(sessionID uint32, reqAddr string) {
l.NewCh <- udpMockEventNew{sessionID, reqAddr}
}
func (l *udpMockEventLogger) Close(sessionID uint32, err error) {
l.CloseCh <- udpMockEventClose{sessionID, err}
}
func TestUDPSessionManager(t *testing.T) {
msgReceiveCh := make(chan *protocol.UDPMessage, 10)
msgSendCh := make(chan *protocol.UDPMessage, 10)
io := &udpMockIO{
ReceiveCh: msgReceiveCh,
SendCh: msgSendCh,
}
eventNewCh := make(chan udpMockEventNew, 10)
eventCloseCh := make(chan udpMockEventClose, 10)
eventLogger := &udpMockEventLogger{
NewCh: eventNewCh,
CloseCh: eventCloseCh,
}
sm := newUDPSessionManager(io, eventLogger, 2*time.Second)
go sm.Run()
t.Run("session creation & timeout", func(t *testing.T) {
ms := []*protocol.UDPMessage{
{
SessionID: 1234,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: "example.com:5353",
Data: []byte("hello"),
},
{
SessionID: 5678,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: "example.com:9999",
Data: []byte("goodbye"),
},
{
SessionID: 1234,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: "example.com:5353",
Data: []byte(" world"),
},
{
SessionID: 5678,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: "example.com:9999",
Data: []byte(" girl"),
},
}
for _, m := range ms {
msgReceiveCh <- m
}
// New event order should be consistent with message order
for i := 0; i < 2; i++ {
newEvent := <-eventNewCh
if newEvent.SessionID != ms[i].SessionID || newEvent.ReqAddr != ms[i].Addr {
t.Errorf("unexpected new event value: %d:%s", newEvent.SessionID, newEvent.ReqAddr)
}
}
// Message order is not guaranteed so we use a map
msgMap := make(map[string]bool)
for i := 0; i < 4; i++ {
msg := <-msgSendCh
msgMap[fmt.Sprintf("%d:%s:%s", msg.SessionID, msg.Addr, string(msg.Data))] = true
}
for _, m := range ms {
key := fmt.Sprintf("%d:%s:%s", m.SessionID, m.Addr, string(m.Data))
if !msgMap[key] {
t.Errorf("missing message: %s", key)
}
}
// Timeout check
startTime := time.Now()
closeMap := make(map[uint32]bool)
for i := 0; i < 2; i++ {
closeEvent := <-eventCloseCh
closeMap[closeEvent.SessionID] = true
}
for i := 0; i < 2; i++ {
if !closeMap[ms[i].SessionID] {
t.Errorf("missing close event: %d", ms[i].SessionID)
}
}
dur := time.Since(startTime)
if dur < 2*time.Second || dur > 4*time.Second {
t.Errorf("unexpected timeout duration: %s", dur)
}
})
t.Run("UDP connection close", func(t *testing.T) {
// Close UDP connection immediately after creation
io.UDPClose = true
m := &protocol.UDPMessage{
SessionID: 8888,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: "mygod.org:1514",
Data: []byte("goodnight"),
}
msgReceiveCh <- m
// Should have both new and close events immediately
newEvent := <-eventNewCh
if newEvent.SessionID != m.SessionID || newEvent.ReqAddr != m.Addr {
t.Errorf("unexpected new event value: %d:%s", newEvent.SessionID, newEvent.ReqAddr)
}
closeEvent := <-eventCloseCh
if closeEvent.SessionID != m.SessionID || closeEvent.Err != errUDPClosed {
t.Errorf("unexpected close event value: %d:%s", closeEvent.SessionID, closeEvent.Err)
}
})
t.Run("UDP IO failure", func(t *testing.T) {
// Block UDP connection creation
io.BlockUDP = true
m := &protocol.UDPMessage{
SessionID: 9999,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: "xxx.net:12450",
Data: []byte("nope"),
}
msgReceiveCh <- m
// Should have both new and close events immediately
newEvent := <-eventNewCh
if newEvent.SessionID != m.SessionID || newEvent.ReqAddr != m.Addr {
t.Errorf("unexpected new event value: %d:%s", newEvent.SessionID, newEvent.ReqAddr)
}
closeEvent := <-eventCloseCh
if closeEvent.SessionID != m.SessionID || closeEvent.Err != errUDPBlocked {
t.Errorf("unexpected close event value: %d:%s", closeEvent.SessionID, closeEvent.Err)
}
})
// Leak checks
msgReceiveCh <- nil
time.Sleep(1 * time.Second) // Give some time for the goroutines to exit
if sm.Count() != 0 {
t.Error("session count should be 0")
}
goleak.VerifyNone(t)
}