package core

import (
	"bytes"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"math/rand"
	"net"
	"strconv"
	"sync"
	"time"

	"github.com/HyNetwork/hysteria/pkg/obfs"
	"github.com/HyNetwork/hysteria/pkg/pmtud_fix"
	"github.com/HyNetwork/hysteria/pkg/transport"
	"github.com/HyNetwork/hysteria/pkg/utils"
	"github.com/lucas-clemente/quic-go"
	"github.com/lucas-clemente/quic-go/congestion"
	"github.com/lunixbochs/struc"
)

// ErrClosed is returned when dialing on a closed Client and when reading from
// a UDP session whose channel has been closed.
var ErrClosed = errors.New("closed")

// CongestionFactory builds the congestion controller installed on a freshly
// established connection. refBPS is the RecvBPS value taken from the server hello.
type CongestionFactory func(refBPS uint64) congestion.CongestionControl

// Client carries TCP streams and UDP sessions to a Hysteria server over a
// single QUIC connection, re-dialing the server when opening a stream fails
// with a non-temporary error.
//
// Lock order (outermost first): udpControlStreamMux, reconnectMutex,
// udpSessionMutex. Never acquire them in any other order.
type Client struct {
	transport         *transport.ClientTransport
	serverAddr        string
	protocol          string
	sendBPS, recvBPS  uint64
	auth              []byte
	congestionFactory CongestionFactory
	obfuscator        obfs.Obfuscator

	tlsConfig  *tls.Config
	quicConfig *quic.Config

	// reconnectMutex guards the fields below, all of which are rewritten by
	// connectToServer when the client reconnects.
	quicSession    quic.Connection
	reconnectMutex sync.Mutex
	closed         bool
	udpDisabled    bool

	// udpSessionMutex guards udpSessionMap and nextUDPSessionID.
	udpSessionMutex  sync.RWMutex
	udpSessionMap    map[uint32]chan *udpMessage
	nextUDPSessionID uint32
	udpDefragger     defragger

	// udpControlStreamMux guards the cached UDP control stream and the
	// connection it was opened on (used to detect that it went stale).
	udpControlStreamMux sync.Mutex
	udpControlStream    quic.Stream
	udpControlSession   quic.Connection
}

// NewClient dials serverAddr and performs the protocol handshake. It returns
// an error if dialing, the handshake, or authentication fails.
//
// quicConfig is copied before use, so the caller's value is never modified;
// a nil quicConfig is treated as an empty configuration.
func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
	transport *transport.ClientTransport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory,
	obfuscator obfs.Obfuscator,
) (*Client, error) {
	if quicConfig == nil {
		quicConfig = &quic.Config{}
	} else {
		quicConfig = quicConfig.Clone()
	}
	quicConfig.DisablePathMTUDiscovery = quicConfig.DisablePathMTUDiscovery || pmtud_fix.DisablePathMTUDiscovery
	c := &Client{
		transport:         transport,
		serverAddr:        serverAddr,
		protocol:          protocol,
		sendBPS:           sendBPS,
		recvBPS:           recvBPS,
		auth:              auth,
		congestionFactory: congestionFactory,
		obfuscator:        obfuscator,
		tlsConfig:         tlsConfig,
		quicConfig:        quicConfig,
	}
	// No lock needed: c is not shared with anyone yet.
	if err := c.connectToServer(); err != nil {
		return nil, err
	}
	return c, nil
}

// connectToServer dials the server, runs the control-stream handshake and,
// on success, installs the new connection as c.quicSession and starts a
// goroutine reading its datagrams. Except during NewClient, the caller must
// hold c.reconnectMutex.
func (c *Client) connectToServer() error {
	qs, err := c.transport.QUICDial(c.protocol, c.serverAddr, c.tlsConfig, c.quicConfig, c.obfuscator)
	if err != nil {
		return err
	}
	// Control stream
	ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
	stream, err := qs.OpenStreamSync(ctx)
	ctxCancel()
	if err != nil {
		_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
		return err
	}
	ok, msg, err := c.handleControlStream(qs, stream)
	if err != nil {
		_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
		return err
	}
	if !ok {
		_ = qs.CloseWithError(closeErrorCodeAuth, "auth error")
		return fmt.Errorf("auth error: %s", msg)
	}
	// All good. Existing UDP sessions still point at the previous connection,
	// so release their readers before handing out a fresh session map.
	c.udpSessionMutex.Lock()
	c.closeUDPSessionsLocked()
	c.udpSessionMap = make(map[uint32]chan *udpMessage)
	c.udpSessionMutex.Unlock()
	go c.handleMessage(qs)
	c.quicSession = qs
	return nil
}

// closeUDPSessionsLocked closes and removes every channel in c.udpSessionMap,
// so readers blocked in quicPktConn.ReadFrom return ErrClosed. Removing the
// entries also stops each session's CloseFunc from closing its channel twice.
// The caller must hold c.udpSessionMutex for writing.
func (c *Client) closeUDPSessionsLocked() {
	for id, ch := range c.udpSessionMap {
		close(ch)
		delete(c.udpSessionMap, id)
	}
}

// handleControlStream sends the protocol version and the client hello on
// stream, then reads the server hello. It records whether the server allows
// UDP and, if the server accepted the client, installs the congestion
// controller on qs. It returns whether the client was accepted and the
// server's message.
func (c *Client) handleControlStream(qs quic.Connection, stream quic.Stream) (bool, string, error) {
	// Send protocol version
	_, err := stream.Write([]byte{protocolVersion})
	if err != nil {
		return false, "", err
	}
	// Send client hello
	err = struc.Pack(stream, &clientHello{
		Rate: transmissionRate{
			SendBPS: c.sendBPS,
			RecvBPS: c.recvBPS,
		},
		Auth: c.auth,
	})
	if err != nil {
		return false, "", err
	}
	// Receive server hello
	var sh serverHello
	err = struc.Unpack(stream, &sh)
	if err != nil {
		return false, "", err
	}
	ok := sh.Status != serverHelloStatusFailed
	c.udpDisabled = sh.Status == serverHelloStatusTCPOnly
	// Set the congestion accordingly
	if ok && c.congestionFactory != nil {
		qs.SetCongestionControl(c.congestionFactory(sh.Rate.RecvBPS))
	}
	return ok, sh.Message, nil
}

// handleMessage reads datagrams from qs until ReceiveMessage fails, then
// returns. It reassembles fragments and delivers each complete message to the
// channel of its UDP session. Messages that fail to decode, belong to an
// unknown session, or find the session's channel full are dropped.
func (c *Client) handleMessage(qs quic.Connection) {
	for {
		msg, err := qs.ReceiveMessage()
		if err != nil {
			break
		}
		var udpMsg udpMessage
		err = struc.Unpack(bytes.NewBuffer(msg), &udpMsg)
		if err != nil {
			continue
		}
		dfMsg := c.udpDefragger.Feed(udpMsg)
		if dfMsg == nil {
			continue
		}
		c.udpSessionMutex.RLock()
		ch, ok := c.udpSessionMap[dfMsg.SessionID]
		if ok {
			select {
			case ch <- dfMsg:
				// OK
			default:
				// Silently drop the message when the channel is full
			}
		}
		c.udpSessionMutex.RUnlock()
	}
}

// openStreamWithReconnect opens a new stream on the current connection. If
// that fails with a non-temporary error, it reconnects once and retries once.
// It returns ErrClosed after Close. On error, the returned connection and
// stream are nil.
func (c *Client) openStreamWithReconnect() (quic.Connection, quic.Stream, error) {
	c.reconnectMutex.Lock()
	defer c.reconnectMutex.Unlock()
	if c.closed {
		return nil, nil, ErrClosed
	}
	stream, err := c.quicSession.OpenStream()
	if err == nil {
		// All good
		return c.quicSession, &wrappedQUICStream{stream}, nil
	}
	// Something is wrong
	var nErr net.Error
	if errors.As(err, &nErr) && nErr.Temporary() {
		// Temporary error: the connection is presumably still usable, so
		// report the error without reconnecting.
		return nil, nil, err
	}
	// Permanent error, need to reconnect
	if err := c.connectToServer(); err != nil {
		// Still error, oops
		return nil, nil, err
	}
	// We are not going to try again even if it still fails the second time
	stream, err = c.quicSession.OpenStream()
	if err != nil {
		return nil, nil, err
	}
	return c.quicSession, &wrappedQUICStream{stream}, nil
}

// DialTCP asks the server to open a TCP connection to addr ("host:port") and
// returns it as a net.Conn backed by a QUIC stream. The returned conn reports
// the QUIC connection's local and remote addresses.
func (c *Client) DialTCP(addr string) (net.Conn, error) {
	host, port, err := utils.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	session, stream, err := c.openStreamWithReconnect()
	if err != nil {
		return nil, err
	}
	// Send request
	err = struc.Pack(stream, &clientRequest{
		Type: clientRequestTypeTCP,
		Host: host,
		Port: port,
	})
	if err != nil {
		_ = stream.Close()
		return nil, err
	}
	// Read response
	var sr serverResponse
	err = struc.Unpack(stream, &sr)
	if err != nil {
		_ = stream.Close()
		return nil, err
	}
	if !sr.OK {
		_ = stream.Close()
		return nil, fmt.Errorf("connection rejected: %s", sr.Message)
	}
	return &quicConn{
		Orig:             stream,
		PseudoLocalAddr:  session.LocalAddr(),
		PseudoRemoteAddr: session.RemoteAddr(),
	}, nil
}

// obtainsUDPControlStream returns the cached UDP control stream. It opens and
// announces a new one when there is none yet, or when the connection the
// cached one belongs to has been closed (for example after a reconnect).
func (c *Client) obtainsUDPControlStream() (quic.Stream, error) {
	c.udpControlStreamMux.Lock()
	defer c.udpControlStreamMux.Unlock()
	if c.udpControlStream != nil {
		select {
		case <-c.udpControlSession.Context().Done():
			// Stale: its connection is gone. Open a new one below.
			c.udpControlStream = nil
			c.udpControlSession = nil
		default:
			return c.udpControlStream, nil
		}
	}
	session, stream, err := c.openStreamWithReconnect()
	if err != nil {
		return nil, err
	}
	err = struc.Pack(stream, &clientRequest{
		Type: clientRequestTypeUDPControl,
	})
	if err != nil {
		_ = stream.Close()
		return nil, err
	}
	c.udpControlStream = stream
	c.udpControlSession = session
	return stream, nil
}

// DialUDP creates a new UDP session on the current connection. It fails if
// the client is closed or the server reported TCP-only operation. Closing the
// returned UDPConn asks the server, on a best-effort basis, to release the
// session.
func (c *Client) DialUDP() (UDPConn, error) {
	// Hold reconnectMutex so the connection, the udpDisabled flag and the
	// session map all come from the same connectToServer generation.
	c.reconnectMutex.Lock()
	if c.closed {
		c.reconnectMutex.Unlock()
		return nil, ErrClosed
	}
	if c.udpDisabled {
		c.reconnectMutex.Unlock()
		return nil, errors.New("UDP is disabled by server side")
	}
	session := c.quicSession
	// Create a session in the map
	c.udpSessionMutex.Lock()
	sessionID := c.nextUDPSessionID
	c.nextUDPSessionID++
	nCh := make(chan *udpMessage, 1024)
	// Keep a reference to this generation's map so CloseFunc removes the
	// session from the same map it was added to; reconnecting replaces
	// c.udpSessionMap.
	sessionMap := c.udpSessionMap
	sessionMap[sessionID] = nCh
	c.udpSessionMutex.Unlock()
	c.reconnectMutex.Unlock()

	closeFunc := func() {
		c.udpSessionMutex.Lock()
		if ch, ok := sessionMap[sessionID]; ok {
			close(ch)
			delete(sessionMap, sessionID)
		}
		c.udpSessionMutex.Unlock()
		// Tell the server to release this session (best effort).
		go func() {
			udpControlStream, err := c.obtainsUDPControlStream()
			if err != nil {
				return
			}
			c.udpControlStreamMux.Lock()
			defer c.udpControlStreamMux.Unlock()
			err = struc.Pack(udpControlStream, &udpControlRequest{
				SessionID: sessionID,
				Operation: udpControlRequestOperationReleaseSession,
			})
			if err != nil && c.udpControlStream == udpControlStream {
				// The stream is broken; drop it so the next release opens a fresh one.
				c.udpControlStream = nil
				c.udpControlSession = nil
			}
		}()
	}

	return &quicPktConn{
		Session:      session,
		CloseFunc:    closeFunc,
		UDPSessionID: sessionID,
		MsgCh:        nCh,
	}, nil
}

// Close closes the UDP control stream and the QUIC connection, and releases
// readers blocked on UDP sessions. Afterwards DialTCP and DialUDP return
// ErrClosed.
func (c *Client) Close() error {
	// Same order as obtainsUDPControlStream (udpControlStreamMux, then
	// reconnectMutex); acquiring them the other way round can deadlock.
	c.udpControlStreamMux.Lock()
	defer c.udpControlStreamMux.Unlock()
	c.reconnectMutex.Lock()
	defer c.reconnectMutex.Unlock()

	if c.udpControlStream != nil {
		_ = c.udpControlStream.Close()
		c.udpControlStream = nil
		c.udpControlSession = nil
	}
	err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "")
	c.closed = true

	c.udpSessionMutex.Lock()
	c.closeUDPSessionsLocked()
	c.udpSessionMutex.Unlock()
	return err
}

// quicConn adapts a QUIC stream to net.Conn, reporting the addresses of the
// QUIC connection that carries it.
type quicConn struct {
	Orig             quic.Stream
	PseudoLocalAddr  net.Addr
	PseudoRemoteAddr net.Addr
}

func (w *quicConn) Read(b []byte) (n int, err error) {
	return w.Orig.Read(b)
}

func (w *quicConn) Write(b []byte) (n int, err error) {
	return w.Orig.Write(b)
}

func (w *quicConn) Close() error {
	return w.Orig.Close()
}

func (w *quicConn) LocalAddr() net.Addr {
	return w.PseudoLocalAddr
}

func (w *quicConn) RemoteAddr() net.Addr {
	return w.PseudoRemoteAddr
}

func (w *quicConn) SetDeadline(t time.Time) error {
	return w.Orig.SetDeadline(t)
}

func (w *quicConn) SetReadDeadline(t time.Time) error {
	return w.Orig.SetReadDeadline(t)
}

func (w *quicConn) SetWriteDeadline(t time.Time) error {
	return w.Orig.SetWriteDeadline(t)
}

// UDPConn is a UDP session relayed through the server. Addresses are
// "host:port" strings.
type UDPConn interface {
	ReadFrom() ([]byte, string, error)
	WriteTo([]byte, string) error
	Close() error
}

// quicPktConn implements UDPConn on top of QUIC datagrams.
type quicPktConn struct {
	Session      quic.Connection
	CloseFunc    func()
	UDPSessionID uint32
	MsgCh        <-chan *udpMessage
}

// ReadFrom blocks until a message for this session arrives and returns its
// payload and source address. It returns ErrClosed once the session's channel
// has been closed.
func (c *quicPktConn) ReadFrom() ([]byte, string, error) {
	msg, ok := <-c.MsgCh
	if !ok || msg == nil {
		// Closed
		return nil, "", ErrClosed
	}
	return msg.Data, net.JoinHostPort(msg.Host, strconv.Itoa(int(msg.Port))), nil
}

// WriteTo sends p to addr ("host:port") as a single datagram. If the
// connection reports the datagram is too large, it resends p split into
// fragments that share a random non-zero message ID.
func (c *quicPktConn) WriteTo(p []byte, addr string) error {
	host, port, err := utils.SplitHostPort(addr)
	if err != nil {
		return err
	}
	msg := udpMessage{
		SessionID: c.UDPSessionID,
		Host:      host,
		Port:      port,
		FragCount: 1,
		Data:      p,
	}
	// Try without fragmentation first.
	var msgBuf bytes.Buffer
	if err := struc.Pack(&msgBuf, &msg); err != nil {
		return err
	}
	err = c.Session.SendMessage(msgBuf.Bytes())
	errSize, tooLarge := err.(quic.ErrMessageToLarge)
	if !tooLarge {
		// Either success (nil) or an error that fragmenting cannot fix.
		return err
	}
	// Need to fragment; msgID must be > 0 when fragCount > 1.
	msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1
	for _, fragMsg := range fragUDPMessage(msg, int(errSize)) {
		msgBuf.Reset()
		if err := struc.Pack(&msgBuf, &fragMsg); err != nil {
			return err
		}
		if err := c.Session.SendMessage(msgBuf.Bytes()); err != nil {
			return err
		}
	}
	return nil
}

// Close removes the session locally and asks the server to release it.
func (c *quicPktConn) Close() error {
	c.CloseFunc()
	return nil
}