hysteria/internal/core/client.go

192 lines
4.8 KiB
Go

package core
import (
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/internal/utils"
"io"
"net"
"sync"
"sync/atomic"
)
// ErrClosed is returned when a stream is requested from a Client after Close
// has been called on it.
var ErrClosed = errors.New("client closed")
// Client is a connection to a hysteria server. It holds one QUIC session at a
// time; openStreamWithReconnect replaces that session by dialing again when
// opening a stream fails with a non-temporary error.
type Client struct {
// These must remain the first fields: 64-bit sync/atomic operations need
// 64-bit alignment, which on 32-bit platforms Go only guarantees for the
// first word of an allocated struct.
// NOTE(review): nothing in this file increments these counters, so Stats
// reports zeros unless they are updated elsewhere — confirm.
inboundBytes, outboundBytes uint64 // atomic
// reconnectMutex serializes Close and openStreamWithReconnect, which read
// and write closed and quicSession.
reconnectMutex sync.Mutex
closed bool // set by Close; openStreamWithReconnect refuses to proceed once true
quicSession quic.Session // current session; replaced by connectToServer on reconnect
serverAddr string // resolved as a UDP address and passed to quic.Dial as the host
username, password string // sent as the Credential in the auth request
tlsConfig *tls.Config
quicConfig *quic.Config
sendBPS, recvBPS uint64 // sent as Speed.SendBps / Speed.ReceiveBps in the auth request
congestionFactory CongestionFactory // optional; when non-nil, called with the server-reported ReceiveBps after auth
obfuscator Obfuscator // optional; when non-nil, wraps the UDP PacketConn
}
// NewClient builds a Client for the server at serverAddr and immediately
// performs the first connection: QUIC dial, control stream, and authentication
// with username/password plus the sendBPS/recvBPS speed values.
//
// congestionFactory and obfuscator may be nil. If the initial connection
// fails, the error is returned and no Client is produced.
func NewClient(serverAddr string, username string, password string, tlsConfig *tls.Config, quicConfig *quic.Config,
	sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, obfuscator Obfuscator) (*Client, error) {
	client := new(Client)
	client.serverAddr = serverAddr
	client.username, client.password = username, password
	client.tlsConfig, client.quicConfig = tlsConfig, quicConfig
	client.sendBPS, client.recvBPS = sendBPS, recvBPS
	client.congestionFactory = congestionFactory
	client.obfuscator = obfuscator

	err := client.connectToServer()
	if err != nil {
		return nil, err
	}
	return client, nil
}
// Dial opens a new stream to the server (reconnecting first if needed), sends
// a ClientConnectRequest for addr and waits for the server's response.
//
// With packet set, the request type is ConnectionType_Packet and the stream
// is returned wrapped in a utils.PacketReadWriteCloser; otherwise the request
// type is ConnectionType_Stream and the raw stream is returned. On any failure
// after the stream is opened, the stream is closed before the error is
// returned.
func (c *Client) Dial(packet bool, addr string) (io.ReadWriteCloser, error) {
	stream, err := c.openStreamWithReconnect()
	if err != nil {
		return nil, err
	}
	// abort closes the partially set-up stream and hands err back to the caller.
	abort := func(cause error) (io.ReadWriteCloser, error) {
		_ = stream.Close()
		return nil, cause
	}

	req := &ClientConnectRequest{Address: addr, Type: ConnectionType_Stream}
	if packet {
		req.Type = ConnectionType_Packet
	}
	if err := writeClientConnectRequest(stream, req); err != nil {
		return abort(err)
	}

	resp, err := readServerConnectResponse(stream)
	if err != nil {
		return abort(err)
	}
	if resp.Result != ConnectResult_CONN_SUCCESS {
		return abort(fmt.Errorf("server rejected the connection %s (msg: %s)",
			resp.Result.String(), resp.Message))
	}

	if !packet {
		return stream, nil
	}
	return &utils.PacketReadWriteCloser{Orig: stream}, nil
}
// Stats returns the inbound and outbound byte counters, in that order. Each
// counter is read atomically, but the pair is not read as one consistent
// snapshot.
// NOTE(review): nothing in this file increments these counters — confirm they
// are updated elsewhere, otherwise this always reports zeros.
func (c *Client) Stats() (inbound, outbound uint64) {
	inbound = atomic.LoadUint64(&c.inboundBytes)
	outbound = atomic.LoadUint64(&c.outboundBytes)
	return inbound, outbound
}
// Close marks the client as closed, so later Dial calls fail with ErrClosed,
// and closes the current QUIC session with closeErrorCodeGeneric. Both steps
// happen under reconnectMutex, so no reconnect can race with them. The result
// of closing the session is returned.
func (c *Client) Close() error {
	c.reconnectMutex.Lock()
	defer c.reconnectMutex.Unlock()

	c.closed = true
	return c.quicSession.CloseWithError(closeErrorCodeGeneric, "generic")
}
// connectToServer establishes a new QUIC session to c.serverAddr and, on
// success, stores it in c.quicSession. The steps are:
//  1. Open a fresh UDP socket, wrapped by c.obfuscator when one is set.
//  2. Dial the server.
//  3. Open the control stream within controlStreamTimeout.
//  4. Authenticate via handleControlStream.
//
// Any failure after the dial closes the session with
// closeErrorCodeProtocolFailure and returns the error; c.quicSession is left
// unchanged in that case.
//
// quic.Dial (unlike quic.DialAddr) does not take ownership of the PacketConn
// it is given, so this function releases the UDP socket itself. The socket is
// closed immediately on any failure. On success it is closed once the
// session's context is done, i.e. after the session has shut down for any
// reason (Close, reconnect, or a network error).
func (c *Client) connectToServer() error {
	serverUDPAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
	if err != nil {
		return err
	}
	udpConn, err := net.ListenPacket("udp", "")
	if err != nil {
		return err
	}
	packetConn := udpConn
	if c.obfuscator != nil {
		// Wrap PacketConn with obfuscator. Cleanup below closes the raw
		// socket directly, so it is released however the wrapper behaves.
		packetConn = &obfsPacketConn{
			Orig:       udpConn,
			Obfuscator: c.obfuscator,
		}
	}
	qs, err := quic.Dial(packetConn, serverUDPAddr, c.serverAddr, c.tlsConfig, c.quicConfig)
	if err != nil {
		_ = udpConn.Close()
		return err
	}
	// fail tears down the half-established session (closing with reason),
	// then releases the socket, and returns err for the caller to propagate.
	fail := func(reason string, err error) error {
		_ = qs.CloseWithError(closeErrorCodeProtocolFailure, reason)
		_ = udpConn.Close()
		return err
	}

	// Control stream
	ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout)
	ctlStream, err := qs.OpenStreamSync(ctx)
	ctxCancel()
	if err != nil {
		return fail("control stream error", err)
	}
	result, msg, err := c.handleControlStream(qs, ctlStream)
	if err != nil {
		return fail("control stream handling error", err)
	}
	if result != AuthResult_AUTH_SUCCESS {
		return fail("authentication failure",
			fmt.Errorf("authentication failure %s (msg: %s)", result.String(), msg))
	}

	// All good. The socket lives exactly as long as the session: release it
	// once the session is closed, so reconnects do not accumulate sockets.
	go func() {
		<-qs.Context().Done()
		_ = udpConn.Close()
	}()
	c.quicSession = qs
	return nil
}
// handleControlStream runs the client side of the authentication exchange on
// the control stream. It writes a ClientAuthRequest carrying the client's
// credential and its SendBps/ReceiveBps, then reads the server's
// auth response.
//
// When the server reports AUTH_SUCCESS and a congestion factory is configured,
// the session's congestion control is built from the server-reported
// Speed.ReceiveBps.
//
// It returns the server's result and message. A non-nil error means the
// exchange itself failed: an I/O or decoding error, or a successful response
// that lacks the Speed needed to configure congestion control. A malformed or
// hostile server could otherwise trigger a nil-pointer panic here.
func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (AuthResult, string, error) {
	req := &ClientAuthRequest{
		Credential: &Credential{
			Username: c.username,
			Password: c.password,
		},
		Speed: &Speed{
			SendBps:    c.sendBPS,
			ReceiveBps: c.recvBPS,
		},
	}
	if err := writeClientAuthRequest(stream, req); err != nil {
		return 0, "", err
	}
	// Response
	resp, err := readServerAuthResponse(stream)
	if err != nil {
		return 0, "", err
	}
	// Set the congestion accordingly
	if resp.Result == AuthResult_AUTH_SUCCESS && c.congestionFactory != nil {
		if resp.Speed == nil {
			return 0, "", errors.New("server auth response is missing speed")
		}
		qs.SetCongestion(c.congestionFactory(resp.Speed.ReceiveBps))
	}
	return resp.Result, resp.Message, nil
}
// openStreamWithReconnect opens a stream on the current session, replacing
// the session once if it appears dead. The whole operation holds
// reconnectMutex, so concurrent callers never reconnect in parallel and
// cannot race with Close.
//
// Error handling:
//   - A client that has already been closed fails with ErrClosed.
//   - An error that is (or wraps) a net.Error reporting Temporary() is
//     returned as-is, without reconnecting.
//   - Any other error is treated as permanent: the old session is closed
//     (best effort) and a new one is dialed. The stream is then opened on the
//     new session, and a second failure is returned without further retries.
func (c *Client) openStreamWithReconnect() (quic.Stream, error) {
	c.reconnectMutex.Lock()
	defer c.reconnectMutex.Unlock()
	if c.closed {
		return nil, ErrClosed
	}
	stream, err := c.quicSession.OpenStream()
	if err == nil {
		// All good
		return stream, nil
	}
	// Something is wrong. quic-go reports "try again later" conditions (e.g.
	// the stream limit) as temporary net.Errors; errors.As also accepts
	// wrapped errors.
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Temporary() {
		return nil, err
	}
	// Permanent error: make sure the old session is fully torn down (it is
	// normally already closed, in which case this is a no-op) before it is
	// replaced, so nothing tied to it lingers.
	_ = c.quicSession.CloseWithError(closeErrorCodeGeneric, "generic")
	if reconnectErr := c.connectToServer(); reconnectErr != nil {
		// Still error, oops
		return nil, reconnectErr
	}
	// We are not going to try again even if it still fails the second time
	return c.quicSession.OpenStream()
}