diff --git a/pkg/core/client.go b/pkg/core/client.go index adc8548..63d6427 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -41,9 +41,14 @@ type Client struct { reconnectMutex sync.Mutex closed bool - udpSessionMutex sync.RWMutex - udpSessionMap map[uint32]chan *udpMessage - udpDefragger defragger + udpDisabled bool + udpSessionMutex sync.RWMutex + udpSessionMap map[uint32]chan *udpMessage + nextUDPSessionID uint32 + udpDefragger defragger + + udpControlStreamMux sync.Mutex + udpControlStream quic.Stream } func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, @@ -121,11 +126,13 @@ func (c *Client) handleControlStream(qs quic.Connection, stream quic.Stream) (bo if err != nil { return false, "", err } + ok := sh.Status != serverHelloStatusFailed + c.udpDisabled = sh.Status == serverHelloStatusTCPOnly // Set the congestion accordingly - if sh.OK && c.congestionFactory != nil { + if ok && c.congestionFactory != nil { qs.SetCongestionControl(c.congestionFactory(sh.Rate.RecvBPS)) } - return sh.OK, sh.Message, nil + return ok, sh.Message, nil } func (c *Client) handleMessage(qs quic.Connection) { @@ -194,7 +201,7 @@ func (c *Client) DialTCP(addr string) (net.Conn, error) { } // Send request err = struc.Pack(stream, &clientRequest{ - UDP: false, + Type: clientRequestTypeTCP, Host: host, Port: port, }) @@ -220,62 +227,91 @@ func (c *Client) DialTCP(addr string) (net.Conn, error) { }, nil } -func (c *Client) DialUDP() (UDPConn, error) { - session, stream, err := c.openStreamWithReconnect() +func (c *Client) obtainsUDPControlStream() (quic.Stream, error) { + c.udpControlStreamMux.Lock() + defer c.udpControlStreamMux.Unlock() + + if c.udpControlStream != nil { + return c.udpControlStream, nil + } + + _, stream, err := c.openStreamWithReconnect() if err != nil { return nil, err } - // Send request + err = struc.Pack(stream, &clientRequest{ - UDP: true, + Type: clientRequestTypeUDPControl, }) if err != nil { _ = stream.Close() return nil, err } - // Read response - var sr serverResponse - err = struc.Unpack(stream, &sr) - if err != nil { - _ = stream.Close() - return nil, err - } - if !sr.OK { - _ = stream.Close() - return nil, fmt.Errorf("connection rejected: %s", sr.Message) + + c.udpControlStream = stream + return stream, err +} + +func (c *Client) DialUDP() (UDPConn, error) { + if c.udpDisabled { + return nil, errors.New("UDP is disabled by server side") } // Create a session in the map c.udpSessionMutex.Lock() + sessionID := c.nextUDPSessionID + c.nextUDPSessionID++ nCh := make(chan *udpMessage, 1024) // Store the current session map for CloseFunc below // to ensures that we are adding and removing sessions on the same map, // as reconnecting will reassign the map sessionMap := c.udpSessionMap - sessionMap[sr.UDPSessionID] = nCh + sessionMap[sessionID] = nCh c.udpSessionMutex.Unlock() pktConn := &quicPktConn{ - Session: session, - Stream: stream, + Session: c.quicSession, CloseFunc: func() { c.udpSessionMutex.Lock() - if ch, ok := sessionMap[sr.UDPSessionID]; ok { + if ch, ok := sessionMap[sessionID]; ok { close(ch) - delete(sessionMap, sr.UDPSessionID) + delete(sessionMap, sessionID) } c.udpSessionMutex.Unlock() + + // tell server to release this session + go func() { + udpControlStream, err := c.obtainsUDPControlStream() + if err != nil { + return + } + + c.udpControlStreamMux.Lock() + defer c.udpControlStreamMux.Unlock() + + err = struc.Pack(udpControlStream, &udpControlRequest{ + SessionID: sessionID, + Operation: udpControlRequestOperationReleaseSession, + }) + if err != nil { + return + } + }() }, - UDPSessionID: sr.UDPSessionID, + UDPSessionID: sessionID, MsgCh: nCh, } - go pktConn.Hold() return pktConn, nil } func (c *Client) Close() error { c.reconnectMutex.Lock() defer c.reconnectMutex.Unlock() + c.udpControlStreamMux.Lock() + if c.udpControlStream != nil { + _ = c.udpControlStream.Close() + } + c.udpControlStreamMux.Unlock() err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "") c.closed = true return err @@ -327,24 +363,11 @@ type UDPConn interface { type quicPktConn struct { Session quic.Connection - Stream quic.Stream CloseFunc func() UDPSessionID uint32 MsgCh <-chan *udpMessage } -func (c *quicPktConn) Hold() { - // Hold the stream until it's closed - buf := make([]byte, 1024) - for { - _, err := c.Stream.Read(buf) - if err != nil { - break - } - } - _ = c.Close() -} - func (c *quicPktConn) ReadFrom() ([]byte, string, error) { msg := <-c.MsgCh if msg == nil { @@ -395,5 +418,5 @@ func (c *quicPktConn) WriteTo(p []byte, addr string) error { func (c *quicPktConn) Close() error { c.CloseFunc() - return c.Stream.Close() + return nil } diff --git a/pkg/core/protocol.go b/pkg/core/protocol.go index 7fd64d7..17a0b57 100644 --- a/pkg/core/protocol.go +++ b/pkg/core/protocol.go @@ -25,15 +25,27 @@ type clientHello struct { Auth []byte } +const ( + serverHelloStatusFailed = uint8(0) + serverHelloStatusOK = uint8(1) + serverHelloStatusTCPOnly = uint8(2) +) + type serverHello struct { - OK bool + Status uint8 Rate transmissionRate MessageLen uint16 `struc:"sizeof=Message"` Message string } +const ( + clientRequestTypeTCP = uint8(0) + clientRequestTypeUDPLegacy = uint8(1) + clientRequestTypeUDPControl = uint8(2) +) + type clientRequest struct { - UDP bool + Type uint8 HostLen uint16 `struc:"sizeof=Host"` Host string Port uint16 @@ -74,3 +86,14 @@ type udpMessageV2 struct { DataLen uint16 `struc:"sizeof=Data"` Data []byte } + +const ( + udpControlRequestOperationReleaseSession = uint8(1) +) + +type udpControlRequest struct { + SessionID uint32 + Operation uint8 + DataLen uint16 `struc:"sizeof=Data"` + Data []byte +} diff --git a/pkg/core/server.go b/pkg/core/server.go index b0c4156..9b85094 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -159,9 +159,16 @@ func (s *Server) handleControlStream(cs quic.Connection, stream quic.Stream) ([] } // Auth ok, msg := s.connectFunc(cs.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS) + status := serverHelloStatusFailed + if ok { + status = serverHelloStatusOK + if s.disableUDP { + status = serverHelloStatusTCPOnly + } + } // Response err = struc.Pack(stream, &serverHello{ - OK: ok, + Status: status, Rate: transmissionRate{ SendBPS: serverSendBPS, RecvBPS: serverRecvBPS, diff --git a/pkg/core/server_client.go b/pkg/core/server_client.go index e157ea5..1ac443b 100644 --- a/pkg/core/server_client.go +++ b/pkg/core/server_client.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "encoding/base64" + "errors" + "math" "math/rand" "net" "strconv" @@ -108,18 +110,25 @@ func (c *serverClient) handleStream(stream quic.Stream) { if err != nil { return } - if !req.UDP { + switch req.Type { + case clientRequestTypeTCP: // TCP connection c.handleTCP(stream, req.Host, req.Port) - } else if !c.DisableUDP { - // UDP connection - c.handleUDP(stream) - } else { - // UDP disabled - _ = struc.Pack(stream, &serverResponse{ - OK: false, - Message: "UDP disabled", - }) + case clientRequestTypeUDPLegacy: + if !c.DisableUDP { + // UDP connection + c.handleUDPLegacy(stream) + } else { + // UDP disabled + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: "UDP disabled", + }) + } + case clientRequestTypeUDPControl: + if !c.DisableUDP { + c.handleUDPControlStream(stream) + } } } @@ -153,6 +162,10 @@ func (c *serverClient) handleMessage(msg []byte) { c.udpSessionMutex.RLock() conn, ok := c.udpSessionMap[dfMsg.SessionID] c.udpSessionMutex.RUnlock() + if !ok { + conn = c.handleUDPSessionCreate(udpMsg) + ok = conn != nil + } if ok { // Session found, send the message action, arg := acl.ActionDirect, "" @@ -304,7 +317,32 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) { c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err) } -func (c *serverClient) handleUDP(stream quic.Stream) { +func (c *serverClient) handleUDPSessionCreate(udpMsg udpMessage) transport.PUDPConn { + conn, err := c.Transport.ListenUDP() + if err != nil { + c.CUDPErrorFunc(c.ClientAddr, c.Auth, udpMsg.SessionID, err) + return nil + } + + c.udpSessionMutex.Lock() + if origConn, ok := c.udpSessionMap[udpMsg.SessionID]; ok { + _ = origConn.Close() + } + c.udpSessionMap[udpMsg.SessionID] = conn + // effect same as udpMsg.SessionID >= c.nextUDPSessionID, but allows wrapping around + if udpMsg.SessionID-c.nextUDPSessionID < math.MaxUint32/2 { + c.nextUDPSessionID = udpMsg.SessionID + 1 + } + c.udpSessionMutex.Unlock() + + c.CUDPRequestFunc(c.ClientAddr, c.Auth, udpMsg.SessionID) + + go c.reverseRelayUDP(udpMsg.SessionID, conn) + + return conn +} + +func (c *serverClient) handleUDPLegacy(stream quic.Stream) { // Like in SOCKS5, the stream here is only used to maintain the UDP session. No need to read anything from it conn, err := c.Transport.ListenUDP() if err != nil { @@ -320,6 +358,9 @@ func (c *serverClient) handleUDP(stream quic.Stream) { var id uint32 c.udpSessionMutex.Lock() id = c.nextUDPSessionID + if origConn, ok := c.udpSessionMap[id]; ok { + _ = origConn.Close() + } c.udpSessionMap[id] = conn c.nextUDPSessionID += 1 c.udpSessionMutex.Unlock() @@ -335,52 +376,7 @@ func (c *serverClient) handleUDP(stream quic.Stream) { // Receive UDP packets, send them to the client go func() { - buf := make([]byte, udpBufferSize) - for { - n, rAddr, err := conn.ReadFromUDP(buf) - if n > 0 { - var msgBuf bytes.Buffer - if c.V2 { - msg := udpMessageV2{ - SessionID: id, - Host: rAddr.IP.String(), - Port: uint16(rAddr.Port), - Data: buf[:n], - } - _ = struc.Pack(&msgBuf, &msg) - _ = c.CS.SendMessage(msgBuf.Bytes()) - } else { - msg := udpMessage{ - SessionID: id, - Host: rAddr.IP.String(), - Port: uint16(rAddr.Port), - FragCount: 1, - Data: buf[:n], - } - // try no frag first - _ = struc.Pack(&msgBuf, &msg) - sendErr := c.CS.SendMessage(msgBuf.Bytes()) - if sendErr != nil { - if errSize, ok := sendErr.(quic.ErrMessageToLarge); ok { - // need to frag - msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 - fragMsgs := fragUDPMessage(msg, int(errSize)) - for _, fragMsg := range fragMsgs { - msgBuf.Reset() - _ = struc.Pack(&msgBuf, &fragMsg) - _ = c.CS.SendMessage(msgBuf.Bytes()) - } - } - } - } - if c.DownCounter != nil { - c.DownCounter.Add(float64(n)) - } - } - if err != nil { - break - } - } + c.reverseRelayUDP(id, conn) _ = stream.Close() }() @@ -399,3 +395,83 @@ func (c *serverClient) handleUDP(stream quic.Stream) { delete(c.udpSessionMap, id) c.udpSessionMutex.Unlock() } + +func (c *serverClient) reverseRelayUDP(sessionID uint32, conn transport.PUDPConn) { + buf := make([]byte, udpBufferSize) + for { + n, rAddr, err := conn.ReadFromUDP(buf) + if n > 0 { + var msgBuf bytes.Buffer + if c.V2 { + msg := udpMessageV2{ + SessionID: sessionID, + Host: rAddr.IP.String(), + Port: uint16(rAddr.Port), + Data: buf[:n], + } + _ = struc.Pack(&msgBuf, &msg) + _ = c.CS.SendMessage(msgBuf.Bytes()) + } else { + msg := udpMessage{ + SessionID: sessionID, + Host: rAddr.IP.String(), + Port: uint16(rAddr.Port), + FragCount: 1, + Data: buf[:n], + } + // try no frag first + _ = struc.Pack(&msgBuf, &msg) + sendErr := c.CS.SendMessage(msgBuf.Bytes()) + if sendErr != nil { + if errSize, ok := sendErr.(quic.ErrMessageToLarge); ok { + // need to frag + msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 + fragMsgs := fragUDPMessage(msg, int(errSize)) + for _, fragMsg := range fragMsgs { + msgBuf.Reset() + _ = struc.Pack(&msgBuf, &fragMsg) + _ = c.CS.SendMessage(msgBuf.Bytes()) + } + } + } + } + if c.DownCounter != nil { + c.DownCounter.Add(float64(n)) + } + } + if err != nil { + break + } + } +} + +func (c *serverClient) handleUDPControlStream(stream quic.Stream) { + for { + var request udpControlRequest + err := struc.Unpack(stream, &request) + if err != nil { + break + } + switch request.Operation { + case udpControlRequestOperationReleaseSession: + c.udpSessionMutex.Lock() + conn, ok := c.udpSessionMap[request.SessionID] + if ok { + _ = conn.Close() + delete(c.udpSessionMap, request.SessionID) + } + c.udpSessionMutex.Unlock() + if ok { + c.CUDPErrorFunc(c.ClientAddr, c.Auth, request.SessionID, errors.New("UDP session released by client")) + } + } + } + // Clear all udp session if the control stream is closed + c.udpSessionMutex.Lock() + for sid, conn := range c.udpSessionMap { + _ = conn.Close() + delete(c.udpSessionMap, sid) + c.CUDPErrorFunc(c.ClientAddr, c.Auth, sid, errors.New("UDP session released for control stream closed")) + } + c.udpSessionMutex.Unlock() +}