sing-quic/hysteria/client.go
2024-02-02 14:24:25 +08:00

366 lines
8.9 KiB
Go

package hysteria
import (
"context"
"io"
"math"
"net"
"os"
"runtime"
"sync"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing-quic"
hyCC "github.com/sagernet/sing-quic/hysteria/congestion"
"github.com/sagernet/sing/common/baderror"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
)
type ClientOptions struct {
Context context.Context
Dialer N.Dialer
Logger logger.Logger
BrutalDebug bool
ServerAddress M.Socksaddr
SendBPS uint64
ReceiveBPS uint64
XPlusPassword string
Password string
TLSConfig aTLS.Config
UDPDisabled bool
// Legacy options
ConnReceiveWindow uint64
StreamReceiveWindow uint64
DisableMTUDiscovery bool
}
type Client struct {
ctx context.Context
dialer N.Dialer
logger logger.Logger
brutalDebug bool
serverAddr M.Socksaddr
sendBPS uint64
receiveBPS uint64
xplusPassword string
password string
tlsConfig aTLS.Config
quicConfig *quic.Config
udpDisabled bool
connAccess sync.RWMutex
conn *clientQUICConnection
}
func NewClient(options ClientOptions) (*Client, error) {
quicConfig := &quic.Config{
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
EnableDatagrams: true,
InitialStreamReceiveWindow: DefaultStreamReceiveWindow,
MaxStreamReceiveWindow: DefaultStreamReceiveWindow,
InitialConnectionReceiveWindow: DefaultConnReceiveWindow,
MaxConnectionReceiveWindow: DefaultConnReceiveWindow,
MaxIdleTimeout: DefaultMaxIdleTimeout,
KeepAlivePeriod: DefaultKeepAlivePeriod,
}
if options.StreamReceiveWindow != 0 {
quicConfig.InitialStreamReceiveWindow = options.StreamReceiveWindow
quicConfig.MaxStreamReceiveWindow = options.StreamReceiveWindow
}
if options.ConnReceiveWindow != 0 {
quicConfig.InitialConnectionReceiveWindow = options.ConnReceiveWindow
quicConfig.MaxConnectionReceiveWindow = options.ConnReceiveWindow
}
if options.DisableMTUDiscovery {
quicConfig.DisablePathMTUDiscovery = true
}
if len(options.TLSConfig.NextProtos()) == 0 {
options.TLSConfig.SetNextProtos([]string{DefaultALPN})
}
if options.SendBPS == 0 {
return nil, E.New("missing upload speed")
} else if options.SendBPS < MinSpeedBPS {
return nil, E.New("invalid upload speed")
}
if options.ReceiveBPS == 0 {
return nil, E.New("missing download speed")
} else if options.ReceiveBPS < MinSpeedBPS {
return nil, E.New("invalid download speed")
}
return &Client{
ctx: options.Context,
dialer: options.Dialer,
logger: options.Logger,
brutalDebug: options.BrutalDebug,
serverAddr: options.ServerAddress,
sendBPS: options.SendBPS,
receiveBPS: options.ReceiveBPS,
xplusPassword: options.XPlusPassword,
password: options.Password,
tlsConfig: options.TLSConfig,
quicConfig: quicConfig,
udpDisabled: options.UDPDisabled,
}, nil
}
func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) {
conn := c.conn
if conn != nil && conn.active() {
return conn, nil
}
c.connAccess.Lock()
defer c.connAccess.Unlock()
conn = c.conn
if conn != nil && conn.active() {
return conn, nil
}
conn, err := c.offerNew(ctx)
if err != nil {
return nil, err
}
return conn, nil
}
func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr)
if err != nil {
return nil, err
}
var packetConn net.PacketConn
packetConn = bufio.NewUnbindPacketConn(udpConn)
if c.xplusPassword != "" {
packetConn = NewXPlusPacketConn(packetConn, []byte(c.xplusPassword))
}
quicConn, err := qtls.Dial(c.ctx, packetConn, udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
if err != nil {
udpConn.Close()
return nil, err
}
controlStream, err := quicConn.OpenStreamSync(ctx)
if err != nil {
packetConn.Close()
return nil, err
}
err = WriteClientHello(controlStream, ClientHello{
SendBPS: c.sendBPS,
RecvBPS: c.receiveBPS,
Auth: c.password,
})
if err != nil {
packetConn.Close()
return nil, err
}
serverHello, err := ReadServerHello(controlStream)
if err != nil {
packetConn.Close()
return nil, err
}
if !serverHello.OK {
packetConn.Close()
return nil, E.New("remote error: ", serverHello.Message)
}
quicConn.SetCongestionControl(hyCC.NewBrutalSender(uint64(math.Min(float64(serverHello.RecvBPS), float64(c.sendBPS))), c.brutalDebug, c.logger))
conn := &clientQUICConnection{
quicConn: quicConn,
rawConn: udpConn,
connDone: make(chan struct{}),
udpDisabled: !quicConn.ConnectionState().SupportsDatagrams,
udpConnMap: make(map[uint32]*udpPacketConn),
}
if !c.udpDisabled {
go c.loopMessages(conn)
}
c.conn = conn
return conn, nil
}
func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
conn, err := c.offer(ctx)
if err != nil {
return nil, err
}
stream, err := conn.quicConn.OpenStream()
if err != nil {
return nil, err
}
return &clientConn{
Stream: stream,
destination: destination,
}, nil
}
func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if c.udpDisabled {
return nil, os.ErrInvalid
}
conn, err := c.offer(ctx)
if err != nil {
return nil, err
}
if conn.udpDisabled {
return nil, E.New("UDP disabled by server")
}
stream, err := conn.quicConn.OpenStream()
if err != nil {
return nil, err
}
buffer := WriteClientRequest(ClientRequest{
UDP: true,
Host: destination.AddrString(),
Port: destination.Port,
}, nil)
_, err = stream.Write(buffer.Bytes())
buffer.Release()
if err != nil {
stream.Close()
return nil, err
}
response, err := ReadServerResponse(stream)
if err != nil {
stream.Close()
return nil, err
}
if !response.OK {
stream.Close()
return nil, E.New("remote error: ", response.Message)
}
clientPacketConn := newUDPPacketConn(c.ctx, conn.quicConn, func() {
stream.CancelRead(0)
stream.Close()
conn.udpAccess.Lock()
delete(conn.udpConnMap, response.UDPSessionID)
conn.udpAccess.Unlock()
})
conn.udpAccess.Lock()
if debug.Enabled {
if _, connExists := conn.udpConnMap[response.UDPSessionID]; connExists {
stream.Close()
return nil, E.New("udp session id duplicated")
}
}
conn.udpConnMap[response.UDPSessionID] = clientPacketConn
conn.udpAccess.Unlock()
clientPacketConn.sessionID = response.UDPSessionID
go func() {
holdBuffer := make([]byte, 1024)
for {
_, hErr := stream.Read(holdBuffer)
if hErr != nil {
break
}
}
clientPacketConn.closeWithError(E.Cause(net.ErrClosed, "hold stream closed"))
}()
return clientPacketConn, nil
}
func (c *Client) CloseWithError(err error) error {
conn := c.conn
if conn != nil {
conn.closeWithError(err)
}
return nil
}
type clientQUICConnection struct {
quicConn quic.Connection
rawConn io.Closer
closeOnce sync.Once
connDone chan struct{}
connErr error
udpDisabled bool
udpAccess sync.RWMutex
udpConnMap map[uint32]*udpPacketConn
}
func (c *clientQUICConnection) active() bool {
select {
case <-c.quicConn.Context().Done():
return false
default:
}
select {
case <-c.connDone:
return false
default:
}
return true
}
func (c *clientQUICConnection) closeWithError(err error) {
c.closeOnce.Do(func() {
c.connErr = err
close(c.connDone)
_ = c.quicConn.CloseWithError(0, "")
_ = c.rawConn.Close()
})
}
type clientConn struct {
quic.Stream
destination M.Socksaddr
requestWritten bool
responseRead bool
}
func (c *clientConn) NeedHandshake() bool {
return !c.requestWritten
}
func (c *clientConn) Read(p []byte) (n int, err error) {
if c.responseRead {
n, err = c.Stream.Read(p)
return n, baderror.WrapQUIC(err)
}
response, err := ReadServerResponse(c.Stream)
if err != nil {
return 0, baderror.WrapQUIC(err)
}
if !response.OK {
err = E.New("remote error: ", response.Message)
return
}
c.responseRead = true
n, err = c.Stream.Read(p)
return n, baderror.WrapQUIC(err)
}
func (c *clientConn) Write(p []byte) (n int, err error) {
if !c.requestWritten {
buffer := WriteClientRequest(ClientRequest{
UDP: false,
Host: c.destination.AddrString(),
Port: c.destination.Port,
}, p)
defer buffer.Release()
_, err = c.Stream.Write(buffer.Bytes())
if err != nil {
return
}
c.requestWritten = true
return len(p), nil
}
n, err = c.Stream.Write(p)
return n, baderror.WrapQUIC(err)
}
func (c *clientConn) LocalAddr() net.Addr {
return M.Socksaddr{}
}
func (c *clientConn) RemoteAddr() net.Addr {
return M.Socksaddr{}
}
func (c *clientConn) Close() error {
c.Stream.CancelRead(0)
return c.Stream.Close()
}