From 955a8a7470f56fea92e7b3d9e8fae5812c3685e7 Mon Sep 17 00:00:00 2001 From: Haruue Icymoon Date: Sun, 28 Aug 2022 17:55:40 +0800 Subject: [PATCH] feat: 1xRTT UDP relay this commit makes changes on the hysteria protocol, although old clients still be able to connect to a newer server, newer clients will fail if they connect to older servers and trying to submit a udp request, so the protocolVersion should be bumped if this commit finally get merged. this commit changes the way to handle hyClient.DialUDP(). in the past, the hysteria client asks the server to create the sessionID in every call to hyClient.DialUDP(), which requires a extra RTT to wait the server reply. to avoid this extra RTT, the hysteria client just generates and manages the sessionID by theirselves. the server checks the sessionID sent from clients in every udpMessage, and open & initiate a new udp session for every sessionID it not recognized. the way to release udp sessions is also changed in this commit, as every udp session no longer maintains a quic stream, now the client will open a dedicated quic stream to notify the server to release specified udp session. this also changes the behavior of "max_conn_client" in the server config. this commit can be a partial fix for #348, #352 and #414. --- pkg/core/client.go | 105 +++++++++++++-------- pkg/core/protocol.go | 27 +++++- pkg/core/server.go | 9 +- pkg/core/server_client.go | 190 ++++++++++++++++++++++++++------------ 4 files changed, 230 insertions(+), 101 deletions(-) diff --git a/pkg/core/client.go b/pkg/core/client.go index adc8548..63d6427 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -41,9 +41,14 @@ type Client struct { reconnectMutex sync.Mutex closed bool - udpSessionMutex sync.RWMutex - udpSessionMap map[uint32]chan *udpMessage - udpDefragger defragger + udpDisabled bool + udpSessionMutex sync.RWMutex + udpSessionMap map[uint32]chan *udpMessage + nextUDPSessionID uint32 + udpDefragger defragger + + udpControlStreamMux sync.Mutex + udpControlStream quic.Stream } func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, @@ -121,11 +126,13 @@ func (c *Client) handleControlStream(qs quic.Connection, stream quic.Stream) (bo if err != nil { return false, "", err } + ok := sh.Status != serverHelloStatusFailed + c.udpDisabled = sh.Status == serverHelloStatusTCPOnly // Set the congestion accordingly - if sh.OK && c.congestionFactory != nil { + if ok && c.congestionFactory != nil { qs.SetCongestionControl(c.congestionFactory(sh.Rate.RecvBPS)) } - return sh.OK, sh.Message, nil + return ok, sh.Message, nil } func (c *Client) handleMessage(qs quic.Connection) { @@ -194,7 +201,7 @@ func (c *Client) DialTCP(addr string) (net.Conn, error) { } // Send request err = struc.Pack(stream, &clientRequest{ - UDP: false, + Type: clientRequestTypeTCP, Host: host, Port: port, }) @@ -220,62 +227,91 @@ func (c *Client) DialTCP(addr string) (net.Conn, error) { }, nil } -func (c *Client) DialUDP() (UDPConn, error) { - session, stream, err := c.openStreamWithReconnect() +func (c *Client) obtainsUDPControlStream() (quic.Stream, error) { + c.udpControlStreamMux.Lock() + defer c.udpControlStreamMux.Unlock() + + if c.udpControlStream != nil { + return c.udpControlStream, nil + } + + _, stream, err := c.openStreamWithReconnect() if err != nil { return nil, err } - // Send request + err = struc.Pack(stream, &clientRequest{ - UDP: true, + Type: clientRequestTypeUDPControl, }) if err != nil { _ = stream.Close() return nil, err } - // Read response - var sr serverResponse - err = struc.Unpack(stream, &sr) - if err != nil { - _ = stream.Close() - return nil, err - } - if !sr.OK { - _ = stream.Close() - return nil, fmt.Errorf("connection rejected: %s", sr.Message) + + c.udpControlStream = stream + return stream, err +} + +func (c *Client) DialUDP() (UDPConn, error) { + if c.udpDisabled { + return nil, errors.New("UDP is disabled by server side") } // Create a session in the map c.udpSessionMutex.Lock() + sessionID := c.nextUDPSessionID + c.nextUDPSessionID++ nCh := make(chan *udpMessage, 1024) // Store the current session map for CloseFunc below // to ensures that we are adding and removing sessions on the same map, // as reconnecting will reassign the map sessionMap := c.udpSessionMap - sessionMap[sr.UDPSessionID] = nCh + sessionMap[sessionID] = nCh c.udpSessionMutex.Unlock() pktConn := &quicPktConn{ - Session: session, - Stream: stream, + Session: c.quicSession, CloseFunc: func() { c.udpSessionMutex.Lock() - if ch, ok := sessionMap[sr.UDPSessionID]; ok { + if ch, ok := sessionMap[sessionID]; ok { close(ch) - delete(sessionMap, sr.UDPSessionID) + delete(sessionMap, sessionID) } c.udpSessionMutex.Unlock() + + // tell server to release this session + go func() { + udpControlStream, err := c.obtainsUDPControlStream() + if err != nil { + return + } + + c.udpControlStreamMux.Lock() + defer c.udpControlStreamMux.Unlock() + + err = struc.Pack(udpControlStream, &udpControlRequest{ + SessionID: sessionID, + Operation: udpControlRequestOperationReleaseSession, + }) + if err != nil { + return + } + }() }, - UDPSessionID: sr.UDPSessionID, + UDPSessionID: sessionID, MsgCh: nCh, } - go pktConn.Hold() return pktConn, nil } func (c *Client) Close() error { c.reconnectMutex.Lock() defer c.reconnectMutex.Unlock() + c.udpControlStreamMux.Lock() + if c.udpControlStream != nil { + _ = c.udpControlStream.Close() + } + c.udpControlStreamMux.Unlock() err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "") c.closed = true return err @@ -327,24 +363,11 @@ type UDPConn interface { type quicPktConn struct { Session quic.Connection - Stream quic.Stream CloseFunc func() UDPSessionID uint32 MsgCh <-chan *udpMessage } -func (c *quicPktConn) Hold() { - // Hold the stream until it's closed - buf := make([]byte, 1024) - for { - _, err := c.Stream.Read(buf) - if err != nil { - break - } - } - _ = c.Close() -} - func (c *quicPktConn) ReadFrom() ([]byte, string, error) { msg := <-c.MsgCh if msg == nil { @@ -395,5 +418,5 @@ func (c *quicPktConn) WriteTo(p []byte, addr string) error { func (c *quicPktConn) Close() error { c.CloseFunc() - return c.Stream.Close() + return nil } diff --git a/pkg/core/protocol.go b/pkg/core/protocol.go index 7fd64d7..17a0b57 100644 --- a/pkg/core/protocol.go +++ b/pkg/core/protocol.go @@ -25,15 +25,27 @@ type clientHello struct { Auth []byte } +const ( + serverHelloStatusFailed = uint8(0) + serverHelloStatusOK = uint8(1) + serverHelloStatusTCPOnly = uint8(2) +) + type serverHello struct { - OK bool + Status uint8 Rate transmissionRate MessageLen uint16 `struc:"sizeof=Message"` Message string } +const ( + clientRequestTypeTCP = uint8(0) + clientRequestTypeUDPLegacy = uint8(1) + clientRequestTypeUDPControl = uint8(2) +) + type clientRequest struct { - UDP bool + Type uint8 HostLen uint16 `struc:"sizeof=Host"` Host string Port uint16 @@ -74,3 +86,14 @@ type udpMessageV2 struct { DataLen uint16 `struc:"sizeof=Data"` Data []byte } + +const ( + udpControlRequestOperationReleaseSession = uint8(1) +) + +type udpControlRequest struct { + SessionID uint32 + Operation uint8 + DataLen uint16 `struc:"sizeof=Data"` + Data []byte +} diff --git a/pkg/core/server.go b/pkg/core/server.go index b0c4156..9b85094 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -159,9 +159,16 @@ func (s *Server) handleControlStream(cs quic.Connection, stream quic.Stream) ([] } // Auth ok, msg := s.connectFunc(cs.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS) + status := serverHelloStatusFailed + if ok { + status = serverHelloStatusOK + if s.disableUDP { + status = serverHelloStatusTCPOnly + } + } // Response err = struc.Pack(stream, &serverHello{ - OK: ok, + Status: status, Rate: transmissionRate{ SendBPS: serverSendBPS, RecvBPS: serverRecvBPS, diff --git a/pkg/core/server_client.go b/pkg/core/server_client.go index e157ea5..1ac443b 100644 --- a/pkg/core/server_client.go +++ b/pkg/core/server_client.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "encoding/base64" + "errors" + "math" "math/rand" "net" "strconv" @@ -108,18 +110,25 @@ func (c *serverClient) handleStream(stream quic.Stream) { if err != nil { return } - if !req.UDP { + switch req.Type { + case clientRequestTypeTCP: // TCP connection c.handleTCP(stream, req.Host, req.Port) - } else if !c.DisableUDP { - // UDP connection - c.handleUDP(stream) - } else { - // UDP disabled - _ = struc.Pack(stream, &serverResponse{ - OK: false, - Message: "UDP disabled", - }) + case clientRequestTypeUDPLegacy: + if !c.DisableUDP { + // UDP connection + c.handleUDPLegacy(stream) + } else { + // UDP disabled + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: "UDP disabled", + }) + } + case clientRequestTypeUDPControl: + if !c.DisableUDP { + c.handleUDPControlStream(stream) + } } } @@ -153,6 +162,10 @@ func (c *serverClient) handleMessage(msg []byte) { c.udpSessionMutex.RLock() conn, ok := c.udpSessionMap[dfMsg.SessionID] c.udpSessionMutex.RUnlock() + if !ok { + conn = c.handleUDPSessionCreate(udpMsg) + ok = conn != nil + } if ok { // Session found, send the message action, arg := acl.ActionDirect, "" @@ -304,7 +317,32 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) { c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err) } -func (c *serverClient) handleUDP(stream quic.Stream) { +func (c *serverClient) handleUDPSessionCreate(udpMsg udpMessage) transport.PUDPConn { + conn, err := c.Transport.ListenUDP() + if err != nil { + c.CUDPErrorFunc(c.ClientAddr, c.Auth, udpMsg.SessionID, err) + return nil + } + + c.udpSessionMutex.Lock() + if origConn, ok := c.udpSessionMap[udpMsg.SessionID]; ok { + _ = origConn.Close() + } + c.udpSessionMap[udpMsg.SessionID] = conn + // effect same as udpMsg.SessionID >= c.nextUDPSessionID, but allows wrapping around + if udpMsg.SessionID-c.nextUDPSessionID < math.MaxUint32/2 { + c.nextUDPSessionID = udpMsg.SessionID + 1 + } + c.udpSessionMutex.Unlock() + + c.CUDPRequestFunc(c.ClientAddr, c.Auth, udpMsg.SessionID) + + go c.reverseRelayUDP(udpMsg.SessionID, conn) + + return conn +} + +func (c *serverClient) handleUDPLegacy(stream quic.Stream) { // Like in SOCKS5, the stream here is only used to maintain the UDP session. No need to read anything from it conn, err := c.Transport.ListenUDP() if err != nil { @@ -320,6 +358,9 @@ func (c *serverClient) handleUDP(stream quic.Stream) { var id uint32 c.udpSessionMutex.Lock() id = c.nextUDPSessionID + if origConn, ok := c.udpSessionMap[id]; ok { + _ = origConn.Close() + } c.udpSessionMap[id] = conn c.nextUDPSessionID += 1 c.udpSessionMutex.Unlock() @@ -335,52 +376,7 @@ func (c *serverClient) handleUDP(stream quic.Stream) { // Receive UDP packets, send them to the client go func() { - buf := make([]byte, udpBufferSize) - for { - n, rAddr, err := conn.ReadFromUDP(buf) - if n > 0 { - var msgBuf bytes.Buffer - if c.V2 { - msg := udpMessageV2{ - SessionID: id, - Host: rAddr.IP.String(), - Port: uint16(rAddr.Port), - Data: buf[:n], - } - _ = struc.Pack(&msgBuf, &msg) - _ = c.CS.SendMessage(msgBuf.Bytes()) - } else { - msg := udpMessage{ - SessionID: id, - Host: rAddr.IP.String(), - Port: uint16(rAddr.Port), - FragCount: 1, - Data: buf[:n], - } - // try no frag first - _ = struc.Pack(&msgBuf, &msg) - sendErr := c.CS.SendMessage(msgBuf.Bytes()) - if sendErr != nil { - if errSize, ok := sendErr.(quic.ErrMessageToLarge); ok { - // need to frag - msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 - fragMsgs := fragUDPMessage(msg, int(errSize)) - for _, fragMsg := range fragMsgs { - msgBuf.Reset() - _ = struc.Pack(&msgBuf, &fragMsg) - _ = c.CS.SendMessage(msgBuf.Bytes()) - } - } - } - } - if c.DownCounter != nil { - c.DownCounter.Add(float64(n)) - } - } - if err != nil { - break - } - } + c.reverseRelayUDP(id, conn) _ = stream.Close() }() @@ -399,3 +395,83 @@ func (c *serverClient) handleUDP(stream quic.Stream) { delete(c.udpSessionMap, id) c.udpSessionMutex.Unlock() } + +func (c *serverClient) reverseRelayUDP(sessionID uint32, conn transport.PUDPConn) { + buf := make([]byte, udpBufferSize) + for { + n, rAddr, err := conn.ReadFromUDP(buf) + if n > 0 { + var msgBuf bytes.Buffer + if c.V2 { + msg := udpMessageV2{ + SessionID: sessionID, + Host: rAddr.IP.String(), + Port: uint16(rAddr.Port), + Data: buf[:n], + } + _ = struc.Pack(&msgBuf, &msg) + _ = c.CS.SendMessage(msgBuf.Bytes()) + } else { + msg := udpMessage{ + SessionID: sessionID, + Host: rAddr.IP.String(), + Port: uint16(rAddr.Port), + FragCount: 1, + Data: buf[:n], + } + // try no frag first + _ = struc.Pack(&msgBuf, &msg) + sendErr := c.CS.SendMessage(msgBuf.Bytes()) + if sendErr != nil { + if errSize, ok := sendErr.(quic.ErrMessageToLarge); ok { + // need to frag + msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 + fragMsgs := fragUDPMessage(msg, int(errSize)) + for _, fragMsg := range fragMsgs { + msgBuf.Reset() + _ = struc.Pack(&msgBuf, &fragMsg) + _ = c.CS.SendMessage(msgBuf.Bytes()) + } + } + } + } + if c.DownCounter != nil { + c.DownCounter.Add(float64(n)) + } + } + if err != nil { + break + } + } +} + +func (c *serverClient) handleUDPControlStream(stream quic.Stream) { + for { + var request udpControlRequest + err := struc.Unpack(stream, &request) + if err != nil { + break + } + switch request.Operation { + case udpControlRequestOperationReleaseSession: + c.udpSessionMutex.Lock() + conn, ok := c.udpSessionMap[request.SessionID] + if ok { + _ = conn.Close() + delete(c.udpSessionMap, request.SessionID) + } + c.udpSessionMutex.Unlock() + if ok { + c.CUDPErrorFunc(c.ClientAddr, c.Auth, request.SessionID, errors.New("UDP session released by client")) + } + } + } + // Clear all udp session if the control stream is closed + c.udpSessionMutex.Lock() + for sid, conn := range c.udpSessionMap { + _ = conn.Close() + delete(c.udpSessionMap, sid) + c.CUDPErrorFunc(c.ClientAddr, c.Auth, sid, errors.New("UDP session released for control stream closed")) + } + c.udpSessionMutex.Unlock() +}