mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 04:27:39 +03:00
feat: 1xRTT UDP relay
this commit makes changes on the hysteria protocol, although old clients still be able to connect to a newer server, newer clients will fail if they connect to older servers and trying to submit a udp request, so the protocolVersion should be bumped if this commit finally get merged. this commit changes the way to handle hyClient.DialUDP(). in the past, the hysteria client asks the server to create the sessionID in every call to hyClient.DialUDP(), which requires a extra RTT to wait the server reply. to avoid this extra RTT, the hysteria client just generates and manages the sessionID by theirselves. the server checks the sessionID sent from clients in every udpMessage, and open & initiate a new udp session for every sessionID it not recognized. the way to release udp sessions is also changed in this commit, as every udp session no longer maintains a quic stream, now the client will open a dedicated quic stream to notify the server to release specified udp session. this also changes the behavior of "max_conn_client" in the server config. this commit can be a partial fix for #348, #352 and #414.
This commit is contained in:
parent
f7de18fd43
commit
955a8a7470
4 changed files with 230 additions and 101 deletions
|
@ -41,9 +41,14 @@ type Client struct {
|
|||
reconnectMutex sync.Mutex
|
||||
closed bool
|
||||
|
||||
udpSessionMutex sync.RWMutex
|
||||
udpSessionMap map[uint32]chan *udpMessage
|
||||
udpDefragger defragger
|
||||
udpDisabled bool
|
||||
udpSessionMutex sync.RWMutex
|
||||
udpSessionMap map[uint32]chan *udpMessage
|
||||
nextUDPSessionID uint32
|
||||
udpDefragger defragger
|
||||
|
||||
udpControlStreamMux sync.Mutex
|
||||
udpControlStream quic.Stream
|
||||
}
|
||||
|
||||
func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
|
||||
|
@ -121,11 +126,13 @@ func (c *Client) handleControlStream(qs quic.Connection, stream quic.Stream) (bo
|
|||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
ok := sh.Status != serverHelloStatusFailed
|
||||
c.udpDisabled = sh.Status == serverHelloStatusTCPOnly
|
||||
// Set the congestion accordingly
|
||||
if sh.OK && c.congestionFactory != nil {
|
||||
if ok && c.congestionFactory != nil {
|
||||
qs.SetCongestionControl(c.congestionFactory(sh.Rate.RecvBPS))
|
||||
}
|
||||
return sh.OK, sh.Message, nil
|
||||
return ok, sh.Message, nil
|
||||
}
|
||||
|
||||
func (c *Client) handleMessage(qs quic.Connection) {
|
||||
|
@ -194,7 +201,7 @@ func (c *Client) DialTCP(addr string) (net.Conn, error) {
|
|||
}
|
||||
// Send request
|
||||
err = struc.Pack(stream, &clientRequest{
|
||||
UDP: false,
|
||||
Type: clientRequestTypeTCP,
|
||||
Host: host,
|
||||
Port: port,
|
||||
})
|
||||
|
@ -220,62 +227,91 @@ func (c *Client) DialTCP(addr string) (net.Conn, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) DialUDP() (UDPConn, error) {
|
||||
session, stream, err := c.openStreamWithReconnect()
|
||||
func (c *Client) obtainsUDPControlStream() (quic.Stream, error) {
|
||||
c.udpControlStreamMux.Lock()
|
||||
defer c.udpControlStreamMux.Unlock()
|
||||
|
||||
if c.udpControlStream != nil {
|
||||
return c.udpControlStream, nil
|
||||
}
|
||||
|
||||
_, stream, err := c.openStreamWithReconnect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Send request
|
||||
|
||||
err = struc.Pack(stream, &clientRequest{
|
||||
UDP: true,
|
||||
Type: clientRequestTypeUDPControl,
|
||||
})
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
// Read response
|
||||
var sr serverResponse
|
||||
err = struc.Unpack(stream, &sr)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !sr.OK {
|
||||
_ = stream.Close()
|
||||
return nil, fmt.Errorf("connection rejected: %s", sr.Message)
|
||||
|
||||
c.udpControlStream = stream
|
||||
return stream, err
|
||||
}
|
||||
|
||||
func (c *Client) DialUDP() (UDPConn, error) {
|
||||
if c.udpDisabled {
|
||||
return nil, errors.New("UDP is disabled by server side")
|
||||
}
|
||||
|
||||
// Create a session in the map
|
||||
c.udpSessionMutex.Lock()
|
||||
sessionID := c.nextUDPSessionID
|
||||
c.nextUDPSessionID++
|
||||
nCh := make(chan *udpMessage, 1024)
|
||||
// Store the current session map for CloseFunc below
|
||||
// to ensures that we are adding and removing sessions on the same map,
|
||||
// as reconnecting will reassign the map
|
||||
sessionMap := c.udpSessionMap
|
||||
sessionMap[sr.UDPSessionID] = nCh
|
||||
sessionMap[sessionID] = nCh
|
||||
c.udpSessionMutex.Unlock()
|
||||
|
||||
pktConn := &quicPktConn{
|
||||
Session: session,
|
||||
Stream: stream,
|
||||
Session: c.quicSession,
|
||||
CloseFunc: func() {
|
||||
c.udpSessionMutex.Lock()
|
||||
if ch, ok := sessionMap[sr.UDPSessionID]; ok {
|
||||
if ch, ok := sessionMap[sessionID]; ok {
|
||||
close(ch)
|
||||
delete(sessionMap, sr.UDPSessionID)
|
||||
delete(sessionMap, sessionID)
|
||||
}
|
||||
c.udpSessionMutex.Unlock()
|
||||
|
||||
// tell server to release this session
|
||||
go func() {
|
||||
udpControlStream, err := c.obtainsUDPControlStream()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.udpControlStreamMux.Lock()
|
||||
defer c.udpControlStreamMux.Unlock()
|
||||
|
||||
err = struc.Pack(udpControlStream, &udpControlRequest{
|
||||
SessionID: sessionID,
|
||||
Operation: udpControlRequestOperationReleaseSession,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
},
|
||||
UDPSessionID: sr.UDPSessionID,
|
||||
UDPSessionID: sessionID,
|
||||
MsgCh: nCh,
|
||||
}
|
||||
go pktConn.Hold()
|
||||
return pktConn, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.reconnectMutex.Lock()
|
||||
defer c.reconnectMutex.Unlock()
|
||||
c.udpControlStreamMux.Lock()
|
||||
if c.udpControlStream != nil {
|
||||
_ = c.udpControlStream.Close()
|
||||
}
|
||||
c.udpControlStreamMux.Unlock()
|
||||
err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "")
|
||||
c.closed = true
|
||||
return err
|
||||
|
@ -327,24 +363,11 @@ type UDPConn interface {
|
|||
|
||||
type quicPktConn struct {
|
||||
Session quic.Connection
|
||||
Stream quic.Stream
|
||||
CloseFunc func()
|
||||
UDPSessionID uint32
|
||||
MsgCh <-chan *udpMessage
|
||||
}
|
||||
|
||||
func (c *quicPktConn) Hold() {
|
||||
// Hold the stream until it's closed
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
_, err := c.Stream.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
func (c *quicPktConn) ReadFrom() ([]byte, string, error) {
|
||||
msg := <-c.MsgCh
|
||||
if msg == nil {
|
||||
|
@ -395,5 +418,5 @@ func (c *quicPktConn) WriteTo(p []byte, addr string) error {
|
|||
|
||||
func (c *quicPktConn) Close() error {
|
||||
c.CloseFunc()
|
||||
return c.Stream.Close()
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -25,15 +25,27 @@ type clientHello struct {
|
|||
Auth []byte
|
||||
}
|
||||
|
||||
const (
|
||||
serverHelloStatusFailed = uint8(0)
|
||||
serverHelloStatusOK = uint8(1)
|
||||
serverHelloStatusTCPOnly = uint8(2)
|
||||
)
|
||||
|
||||
type serverHello struct {
|
||||
OK bool
|
||||
Status uint8
|
||||
Rate transmissionRate
|
||||
MessageLen uint16 `struc:"sizeof=Message"`
|
||||
Message string
|
||||
}
|
||||
|
||||
const (
|
||||
clientRequestTypeTCP = uint8(0)
|
||||
clientRequestTypeUDPLegacy = uint8(1)
|
||||
clientRequestTypeUDPControl = uint8(2)
|
||||
)
|
||||
|
||||
type clientRequest struct {
|
||||
UDP bool
|
||||
Type uint8
|
||||
HostLen uint16 `struc:"sizeof=Host"`
|
||||
Host string
|
||||
Port uint16
|
||||
|
@ -74,3 +86,14 @@ type udpMessageV2 struct {
|
|||
DataLen uint16 `struc:"sizeof=Data"`
|
||||
Data []byte
|
||||
}
|
||||
|
||||
const (
|
||||
udpControlRequestOperationReleaseSession = uint8(1)
|
||||
)
|
||||
|
||||
type udpControlRequest struct {
|
||||
SessionID uint32
|
||||
Operation uint8
|
||||
DataLen uint16 `struc:"sizeof=Data"`
|
||||
Data []byte
|
||||
}
|
||||
|
|
|
@ -159,9 +159,16 @@ func (s *Server) handleControlStream(cs quic.Connection, stream quic.Stream) ([]
|
|||
}
|
||||
// Auth
|
||||
ok, msg := s.connectFunc(cs.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS)
|
||||
status := serverHelloStatusFailed
|
||||
if ok {
|
||||
status = serverHelloStatusOK
|
||||
if s.disableUDP {
|
||||
status = serverHelloStatusTCPOnly
|
||||
}
|
||||
}
|
||||
// Response
|
||||
err = struc.Pack(stream, &serverHello{
|
||||
OK: ok,
|
||||
Status: status,
|
||||
Rate: transmissionRate{
|
||||
SendBPS: serverSendBPS,
|
||||
RecvBPS: serverRecvBPS,
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
|
@ -108,18 +110,25 @@ func (c *serverClient) handleStream(stream quic.Stream) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !req.UDP {
|
||||
switch req.Type {
|
||||
case clientRequestTypeTCP:
|
||||
// TCP connection
|
||||
c.handleTCP(stream, req.Host, req.Port)
|
||||
} else if !c.DisableUDP {
|
||||
// UDP connection
|
||||
c.handleUDP(stream)
|
||||
} else {
|
||||
// UDP disabled
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: "UDP disabled",
|
||||
})
|
||||
case clientRequestTypeUDPLegacy:
|
||||
if !c.DisableUDP {
|
||||
// UDP connection
|
||||
c.handleUDPLegacy(stream)
|
||||
} else {
|
||||
// UDP disabled
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: "UDP disabled",
|
||||
})
|
||||
}
|
||||
case clientRequestTypeUDPControl:
|
||||
if !c.DisableUDP {
|
||||
c.handleUDPControlStream(stream)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -153,6 +162,10 @@ func (c *serverClient) handleMessage(msg []byte) {
|
|||
c.udpSessionMutex.RLock()
|
||||
conn, ok := c.udpSessionMap[dfMsg.SessionID]
|
||||
c.udpSessionMutex.RUnlock()
|
||||
if !ok {
|
||||
conn = c.handleUDPSessionCreate(udpMsg)
|
||||
ok = conn != nil
|
||||
}
|
||||
if ok {
|
||||
// Session found, send the message
|
||||
action, arg := acl.ActionDirect, ""
|
||||
|
@ -304,7 +317,32 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
|
|||
c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err)
|
||||
}
|
||||
|
||||
func (c *serverClient) handleUDP(stream quic.Stream) {
|
||||
func (c *serverClient) handleUDPSessionCreate(udpMsg udpMessage) transport.PUDPConn {
|
||||
conn, err := c.Transport.ListenUDP()
|
||||
if err != nil {
|
||||
c.CUDPErrorFunc(c.ClientAddr, c.Auth, udpMsg.SessionID, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.udpSessionMutex.Lock()
|
||||
if origConn, ok := c.udpSessionMap[udpMsg.SessionID]; ok {
|
||||
_ = origConn.Close()
|
||||
}
|
||||
c.udpSessionMap[udpMsg.SessionID] = conn
|
||||
// effect same as udpMsg.SessionID >= c.nextUDPSessionID, but allows wrapping around
|
||||
if udpMsg.SessionID-c.nextUDPSessionID < math.MaxUint32/2 {
|
||||
c.nextUDPSessionID = udpMsg.SessionID + 1
|
||||
}
|
||||
c.udpSessionMutex.Unlock()
|
||||
|
||||
c.CUDPRequestFunc(c.ClientAddr, c.Auth, udpMsg.SessionID)
|
||||
|
||||
go c.reverseRelayUDP(udpMsg.SessionID, conn)
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func (c *serverClient) handleUDPLegacy(stream quic.Stream) {
|
||||
// Like in SOCKS5, the stream here is only used to maintain the UDP session. No need to read anything from it
|
||||
conn, err := c.Transport.ListenUDP()
|
||||
if err != nil {
|
||||
|
@ -320,6 +358,9 @@ func (c *serverClient) handleUDP(stream quic.Stream) {
|
|||
var id uint32
|
||||
c.udpSessionMutex.Lock()
|
||||
id = c.nextUDPSessionID
|
||||
if origConn, ok := c.udpSessionMap[id]; ok {
|
||||
_ = origConn.Close()
|
||||
}
|
||||
c.udpSessionMap[id] = conn
|
||||
c.nextUDPSessionID += 1
|
||||
c.udpSessionMutex.Unlock()
|
||||
|
@ -335,52 +376,7 @@ func (c *serverClient) handleUDP(stream quic.Stream) {
|
|||
|
||||
// Receive UDP packets, send them to the client
|
||||
go func() {
|
||||
buf := make([]byte, udpBufferSize)
|
||||
for {
|
||||
n, rAddr, err := conn.ReadFromUDP(buf)
|
||||
if n > 0 {
|
||||
var msgBuf bytes.Buffer
|
||||
if c.V2 {
|
||||
msg := udpMessageV2{
|
||||
SessionID: id,
|
||||
Host: rAddr.IP.String(),
|
||||
Port: uint16(rAddr.Port),
|
||||
Data: buf[:n],
|
||||
}
|
||||
_ = struc.Pack(&msgBuf, &msg)
|
||||
_ = c.CS.SendMessage(msgBuf.Bytes())
|
||||
} else {
|
||||
msg := udpMessage{
|
||||
SessionID: id,
|
||||
Host: rAddr.IP.String(),
|
||||
Port: uint16(rAddr.Port),
|
||||
FragCount: 1,
|
||||
Data: buf[:n],
|
||||
}
|
||||
// try no frag first
|
||||
_ = struc.Pack(&msgBuf, &msg)
|
||||
sendErr := c.CS.SendMessage(msgBuf.Bytes())
|
||||
if sendErr != nil {
|
||||
if errSize, ok := sendErr.(quic.ErrMessageToLarge); ok {
|
||||
// need to frag
|
||||
msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
|
||||
fragMsgs := fragUDPMessage(msg, int(errSize))
|
||||
for _, fragMsg := range fragMsgs {
|
||||
msgBuf.Reset()
|
||||
_ = struc.Pack(&msgBuf, &fragMsg)
|
||||
_ = c.CS.SendMessage(msgBuf.Bytes())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if c.DownCounter != nil {
|
||||
c.DownCounter.Add(float64(n))
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
c.reverseRelayUDP(id, conn)
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
|
@ -399,3 +395,83 @@ func (c *serverClient) handleUDP(stream quic.Stream) {
|
|||
delete(c.udpSessionMap, id)
|
||||
c.udpSessionMutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *serverClient) reverseRelayUDP(sessionID uint32, conn transport.PUDPConn) {
|
||||
buf := make([]byte, udpBufferSize)
|
||||
for {
|
||||
n, rAddr, err := conn.ReadFromUDP(buf)
|
||||
if n > 0 {
|
||||
var msgBuf bytes.Buffer
|
||||
if c.V2 {
|
||||
msg := udpMessageV2{
|
||||
SessionID: sessionID,
|
||||
Host: rAddr.IP.String(),
|
||||
Port: uint16(rAddr.Port),
|
||||
Data: buf[:n],
|
||||
}
|
||||
_ = struc.Pack(&msgBuf, &msg)
|
||||
_ = c.CS.SendMessage(msgBuf.Bytes())
|
||||
} else {
|
||||
msg := udpMessage{
|
||||
SessionID: sessionID,
|
||||
Host: rAddr.IP.String(),
|
||||
Port: uint16(rAddr.Port),
|
||||
FragCount: 1,
|
||||
Data: buf[:n],
|
||||
}
|
||||
// try no frag first
|
||||
_ = struc.Pack(&msgBuf, &msg)
|
||||
sendErr := c.CS.SendMessage(msgBuf.Bytes())
|
||||
if sendErr != nil {
|
||||
if errSize, ok := sendErr.(quic.ErrMessageToLarge); ok {
|
||||
// need to frag
|
||||
msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
|
||||
fragMsgs := fragUDPMessage(msg, int(errSize))
|
||||
for _, fragMsg := range fragMsgs {
|
||||
msgBuf.Reset()
|
||||
_ = struc.Pack(&msgBuf, &fragMsg)
|
||||
_ = c.CS.SendMessage(msgBuf.Bytes())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if c.DownCounter != nil {
|
||||
c.DownCounter.Add(float64(n))
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverClient) handleUDPControlStream(stream quic.Stream) {
|
||||
for {
|
||||
var request udpControlRequest
|
||||
err := struc.Unpack(stream, &request)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
switch request.Operation {
|
||||
case udpControlRequestOperationReleaseSession:
|
||||
c.udpSessionMutex.Lock()
|
||||
conn, ok := c.udpSessionMap[request.SessionID]
|
||||
if ok {
|
||||
_ = conn.Close()
|
||||
delete(c.udpSessionMap, request.SessionID)
|
||||
}
|
||||
c.udpSessionMutex.Unlock()
|
||||
if ok {
|
||||
c.CUDPErrorFunc(c.ClientAddr, c.Auth, request.SessionID, errors.New("UDP session released by client"))
|
||||
}
|
||||
}
|
||||
}
|
||||
// Clear all udp session if the control stream is closed
|
||||
c.udpSessionMutex.Lock()
|
||||
for sid, conn := range c.udpSessionMap {
|
||||
_ = conn.Close()
|
||||
delete(c.udpSessionMap, sid)
|
||||
c.CUDPErrorFunc(c.ClientAddr, c.Auth, sid, errors.New("UDP session released for control stream closed"))
|
||||
}
|
||||
c.udpSessionMutex.Unlock()
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue