From fa3d997246b6747b7b6f432b0b3c4f217d76b05f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 25 Oct 2023 12:00:00 +0800 Subject: [PATCH] Migrate legacy hysteria protocol --- hysteria/client.go | 365 ++++++++++++++ hysteria/client_paclet.go | 47 ++ {hysteria2 => hysteria}/congestion/brutal.go | 0 {hysteria2 => hysteria}/congestion/pacer.go | 0 hysteria/packet.go | 475 +++++++++++++++++++ hysteria/protocol.go | 253 ++++++++++ hysteria/service.go | 376 +++++++++++++++ hysteria/service_packet.go | 43 ++ hysteria/xplus.go | 118 +++++ hysteria2/client.go | 26 +- hysteria2/salamander.go | 20 +- hysteria2/service.go | 30 +- tuic/service.go | 6 +- 13 files changed, 1714 insertions(+), 45 deletions(-) create mode 100644 hysteria/client.go create mode 100644 hysteria/client_paclet.go rename {hysteria2 => hysteria}/congestion/brutal.go (100%) rename {hysteria2 => hysteria}/congestion/pacer.go (100%) create mode 100644 hysteria/packet.go create mode 100644 hysteria/protocol.go create mode 100644 hysteria/service.go create mode 100644 hysteria/service_packet.go create mode 100644 hysteria/xplus.go diff --git a/hysteria/client.go b/hysteria/client.go new file mode 100644 index 0000000..b599922 --- /dev/null +++ b/hysteria/client.go @@ -0,0 +1,365 @@ +package hysteria + +import ( + "context" + "io" + "math" + "net" + "os" + "runtime" + "sync" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing-quic" + hyCC "github.com/sagernet/sing-quic/hysteria/congestion" + "github.com/sagernet/sing/common/baderror" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/debug" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + aTLS "github.com/sagernet/sing/common/tls" +) + +type ClientOptions struct { + Context context.Context + Dialer N.Dialer + Logger logger.Logger + BrutalDebug bool + ServerAddress M.Socksaddr + SendBPS uint64 + ReceiveBPS uint64 + XPlusPassword string + Password string + TLSConfig aTLS.Config + UDPDisabled bool + + // Legacy options + + ConnReceiveWindow uint64 + StreamReceiveWindow uint64 + DisableMTUDiscovery bool +} + +type Client struct { + ctx context.Context + dialer N.Dialer + logger logger.Logger + brutalDebug bool + serverAddr M.Socksaddr + sendBPS uint64 + receiveBPS uint64 + xplusPassword string + password string + tlsConfig aTLS.Config + quicConfig *quic.Config + udpDisabled bool + + connAccess sync.RWMutex + conn *clientQUICConnection +} + +func NewClient(options ClientOptions) (*Client, error) { + quicConfig := &quic.Config{ + DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), + EnableDatagrams: true, + InitialStreamReceiveWindow: DefaultStreamReceiveWindow, + MaxStreamReceiveWindow: DefaultStreamReceiveWindow, + InitialConnectionReceiveWindow: DefaultConnReceiveWindow, + MaxConnectionReceiveWindow: DefaultConnReceiveWindow, + MaxIdleTimeout: DefaultMaxIdleTimeout, + KeepAlivePeriod: DefaultKeepAlivePeriod, + } + if options.StreamReceiveWindow != 0 { + quicConfig.InitialStreamReceiveWindow = options.StreamReceiveWindow + quicConfig.MaxStreamReceiveWindow = options.StreamReceiveWindow + } + if options.ConnReceiveWindow != 0 { + quicConfig.InitialConnectionReceiveWindow = options.ConnReceiveWindow + quicConfig.MaxConnectionReceiveWindow = options.ConnReceiveWindow + } + if options.DisableMTUDiscovery { + quicConfig.DisablePathMTUDiscovery = true + } + if len(options.TLSConfig.NextProtos()) == 0 { + options.TLSConfig.SetNextProtos([]string{DefaultALPN}) + } + if options.SendBPS == 0 { + return nil, E.New("missing upload speed") + } else if options.SendBPS < MinSpeedBPS { + return nil, E.New("invalid upload speed") + } + if options.ReceiveBPS == 0 { + return nil, E.New("missing download speed") + } else if options.ReceiveBPS < MinSpeedBPS { + return nil, E.New("invalid download speed") + } + return &Client{ + ctx: options.Context, + dialer: options.Dialer, + logger: options.Logger, + brutalDebug: options.BrutalDebug, + serverAddr: options.ServerAddress, + sendBPS: options.SendBPS, + receiveBPS: options.ReceiveBPS, + xplusPassword: options.XPlusPassword, + password: options.Password, + tlsConfig: options.TLSConfig, + quicConfig: quicConfig, + udpDisabled: options.UDPDisabled, + }, nil +} + +func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) { + conn := c.conn + if conn != nil && conn.active() { + return conn, nil + } + c.connAccess.Lock() + defer c.connAccess.Unlock() + conn = c.conn + if conn != nil && conn.active() { + return conn, nil + } + conn, err := c.offerNew(ctx) + if err != nil { + return nil, err + } + return conn, nil +} + +func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { + udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr) + if err != nil { + return nil, err + } + var packetConn net.PacketConn + packetConn = bufio.NewUnbindPacketConn(udpConn) + if c.xplusPassword != "" { + packetConn = NewXPlusPacketConn(packetConn, []byte(c.xplusPassword)) + } + quicConn, err := qtls.Dial(c.ctx, packetConn, udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig) + if err != nil { + udpConn.Close() + return nil, err + } + controlStream, err := quicConn.OpenStreamSync(ctx) + if err != nil { + packetConn.Close() + return nil, err + } + err = WriteClientHello(controlStream, ClientHello{ + SendBPS: c.sendBPS, + RecvBPS: c.receiveBPS, + Auth: c.password, + }) + if err != nil { + packetConn.Close() + return nil, err + } + serverHello, err := ReadServerHello(controlStream) + if err != nil { + packetConn.Close() + return nil, err + } + if !serverHello.OK { + packetConn.Close() + return nil, E.New("remote error: ", serverHello.Message) + } + quicConn.SetCongestionControl(hyCC.NewBrutalSender(uint64(math.Min(float64(serverHello.RecvBPS), float64(c.sendBPS))), c.brutalDebug, c.logger)) + conn := &clientQUICConnection{ + quicConn: quicConn, + rawConn: udpConn, + connDone: make(chan struct{}), + udpDisabled: !quicConn.ConnectionState().SupportsDatagrams, + udpConnMap: make(map[uint32]*udpPacketConn), + } + if !c.udpDisabled { + go c.loopMessages(conn) + } + c.conn = conn + return conn, nil +} + +func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { + conn, err := c.offer(ctx) + if err != nil { + return nil, err + } + stream, err := conn.quicConn.OpenStream() + if err != nil { + return nil, err + } + return &clientConn{ + Stream: stream, + destination: destination, + }, nil +} + +func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + if c.udpDisabled { + return nil, os.ErrInvalid + } + conn, err := c.offer(ctx) + if err != nil { + return nil, err + } + if conn.udpDisabled { + return nil, E.New("UDP disabled by server") + } + stream, err := conn.quicConn.OpenStream() + if err != nil { + return nil, err + } + buffer := WriteClientRequest(ClientRequest{ + UDP: true, + Host: destination.AddrString(), + Port: destination.Port, + }, nil) + _, err = stream.Write(buffer.Bytes()) + buffer.Release() + if err != nil { + stream.Close() + return nil, err + } + response, err := ReadServerResponse(stream) + if err != nil { + stream.Close() + return nil, err + } + if !response.OK { + stream.Close() + return nil, E.New("remote error: ", response.Message) + } + clientPacketConn := newUDPPacketConn(c.ctx, conn.quicConn, func() { + stream.CancelRead(0) + stream.Close() + conn.udpAccess.Lock() + delete(conn.udpConnMap, response.UDPSessionID) + conn.udpAccess.Unlock() + }) + conn.udpAccess.Lock() + if debug.Enabled { + if _, connExists := conn.udpConnMap[response.UDPSessionID]; connExists { + stream.Close() + return nil, E.New("udp session id duplicated") + } + } + conn.udpConnMap[response.UDPSessionID] = clientPacketConn + conn.udpAccess.Unlock() + clientPacketConn.sessionID = response.UDPSessionID + go func() { + holdBuffer := make([]byte, 1024) + for { + _, hErr := stream.Read(holdBuffer) + if hErr != nil { + break + } + } + clientPacketConn.closeWithError(E.Cause(net.ErrClosed, "hold stream closed")) + }() + return clientPacketConn, nil +} + +func (c *Client) CloseWithError(err error) error { + conn := c.conn + if conn != nil { + conn.closeWithError(err) + } + return nil +} + +type clientQUICConnection struct { + quicConn quic.Connection + rawConn io.Closer + closeOnce sync.Once + connDone chan struct{} + connErr error + udpDisabled bool + udpAccess sync.RWMutex + udpConnMap map[uint32]*udpPacketConn +} + +func (c *clientQUICConnection) active() bool { + select { + case <-c.quicConn.Context().Done(): + return false + default: + } + select { + case <-c.connDone: + return false + default: + } + return true +} + +func (c *clientQUICConnection) closeWithError(err error) { + c.closeOnce.Do(func() { + c.connErr = err + close(c.connDone) + c.quicConn.CloseWithError(0, "") + }) +} + +type clientConn struct { + quic.Stream + destination M.Socksaddr + requestWritten bool + responseRead bool +} + +func (c *clientConn) NeedHandshake() bool { + return !c.requestWritten +} + +func (c *clientConn) Read(p []byte) (n int, err error) { + if c.responseRead { + n, err = c.Stream.Read(p) + return n, baderror.WrapQUIC(err) + } + response, err := ReadServerResponse(c.Stream) + if err != nil { + return 0, baderror.WrapQUIC(err) + } + if !response.OK { + err = E.New("remote error: ", response.Message) + return + } + c.responseRead = true + n, err = c.Stream.Read(p) + return n, baderror.WrapQUIC(err) +} + +func (c *clientConn) Write(p []byte) (n int, err error) { + if !c.requestWritten { + buffer := WriteClientRequest(ClientRequest{ + UDP: false, + Host: c.destination.AddrString(), + Port: c.destination.Port, + }, p) + defer buffer.Release() + _, err = c.Stream.Write(buffer.Bytes()) + if err != nil { + return + } + c.requestWritten = true + return len(p), nil + } + n, err = c.Stream.Write(p) + return n, baderror.WrapQUIC(err) +} + +func (c *clientConn) LocalAddr() net.Addr { + return M.Socksaddr{} +} + +func (c *clientConn) RemoteAddr() net.Addr { + return M.Socksaddr{} +} + +func (c *clientConn) Close() error { + c.Stream.CancelRead(0) + return c.Stream.Close() +} diff --git a/hysteria/client_paclet.go b/hysteria/client_paclet.go new file mode 100644 index 0000000..a77e9ff --- /dev/null +++ b/hysteria/client_paclet.go @@ -0,0 +1,47 @@ +package hysteria + +import E "github.com/sagernet/sing/common/exceptions" + +func (c *Client) loopMessages(conn *clientQUICConnection) { + for { + message, err := conn.quicConn.ReceiveMessage(c.ctx) + if err != nil { + conn.closeWithError(E.Cause(err, "receive message")) + return + } + go func() { + hErr := c.handleMessage(conn, message) + if hErr != nil { + conn.closeWithError(E.Cause(hErr, "handle message")) + } + }() + } +} + +func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error { + message := allocMessage() + err := decodeUDPMessage(message, data) + if err != nil { + message.release() + return E.Cause(err, "decode UDP message") + } + conn.handleUDPMessage(message) + return nil +} + +func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) { + c.udpAccess.RLock() + udpConn, loaded := c.udpConnMap[message.sessionID] + c.udpAccess.RUnlock() + if !loaded { + message.releaseMessage() + return + } + select { + case <-udpConn.ctx.Done(): + message.releaseMessage() + return + default: + } + udpConn.inputPacket(message) +} diff --git a/hysteria2/congestion/brutal.go b/hysteria/congestion/brutal.go similarity index 100% rename from hysteria2/congestion/brutal.go rename to hysteria/congestion/brutal.go diff --git a/hysteria2/congestion/pacer.go b/hysteria/congestion/pacer.go similarity index 100% rename from hysteria2/congestion/pacer.go rename to hysteria/congestion/pacer.go diff --git a/hysteria/packet.go b/hysteria/packet.go new file mode 100644 index 0000000..b742c4a --- /dev/null +++ b/hysteria/packet.go @@ -0,0 +1,475 @@ +package hysteria + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "io" + "math" + "net" + "os" + "sync" + "time" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/cache" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/rw" +) + +var udpMessagePool = sync.Pool{ + New: func() interface{} { + return new(udpMessage) + }, +} + +func allocMessage() *udpMessage { + message := udpMessagePool.Get().(*udpMessage) + message.referenced = true + return message +} + +func releaseMessages(messages []*udpMessage) { + for _, message := range messages { + if message != nil { + message.release() + } + } +} + +type udpMessage struct { + sessionID uint32 + packetID uint16 + fragmentID uint8 + fragmentTotal uint8 + host string + port uint16 + data *buf.Buffer + referenced bool +} + +func (m *udpMessage) release() { + if !m.referenced { + return + } + *m = udpMessage{} + udpMessagePool.Put(m) +} + +func (m *udpMessage) releaseMessage() { + m.data.Release() + m.release() +} + +func (m *udpMessage) pack() *buf.Buffer { + buffer := buf.NewSize(m.headerSize() + m.data.Len()) + common.Must( + binary.Write(buffer, binary.BigEndian, m.sessionID), + binary.Write(buffer, binary.BigEndian, uint16(len(m.host))), + common.Error(buffer.WriteString(m.host)), + binary.Write(buffer, binary.BigEndian, m.port), + binary.Write(buffer, binary.BigEndian, m.packetID), + binary.Write(buffer, binary.BigEndian, m.fragmentID), + binary.Write(buffer, binary.BigEndian, m.fragmentTotal), + binary.Write(buffer, binary.BigEndian, uint16(m.data.Len())), + common.Error(buffer.Write(m.data.Bytes())), + ) + return buffer +} + +func (m *udpMessage) headerSize() int { + return 14 + len(m.host) +} + +func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage { + if message.data.Len() <= maxPacketSize { + return []*udpMessage{message} + } + var fragments []*udpMessage + originPacket := message.data.Bytes() + udpMTU := maxPacketSize - message.headerSize() + for remaining := len(originPacket); remaining > 0; remaining -= udpMTU { + fragment := allocMessage() + *fragment = *message + if remaining > udpMTU { + fragment.data = buf.As(originPacket[:udpMTU]) + originPacket = originPacket[udpMTU:] + } else { + fragment.data = buf.As(originPacket) + originPacket = nil + } + fragments = append(fragments, fragment) + } + fragmentTotal := uint16(len(fragments)) + for index, fragment := range fragments { + fragment.fragmentID = uint8(index) + fragment.fragmentTotal = uint8(fragmentTotal) + /*if index > 0 { + fragment.destination = "" + // not work in hysteria + }*/ + } + return fragments +} + +type udpPacketConn struct { + ctx context.Context + cancel common.ContextCancelCauseFunc + sessionID uint32 + quicConn quic.Connection + data chan *udpMessage + udpMTU int + udpMTUTime time.Time + packetId atomic.Uint32 + closeOnce sync.Once + defragger *udpDefragger + onDestroy func() +} + +func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy func()) *udpPacketConn { + ctx, cancel := common.ContextWithCancelCause(ctx) + return &udpPacketConn{ + ctx: ctx, + cancel: cancel, + quicConn: quicConn, + data: make(chan *udpMessage, 64), + defragger: newUDPDefragger(), + onDestroy: onDestroy, + } +} + +func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + select { + case p := <-c.data: + buffer = p.data + destination = M.ParseSocksaddrHostPort(p.host, p.port) + p.release() + return + case <-c.ctx.Done(): + return nil, M.Socksaddr{}, io.ErrClosedPipe + } +} + +func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + select { + case p := <-c.data: + _, err = buffer.ReadOnceFrom(p.data) + destination = M.ParseSocksaddrHostPort(p.host, p.port) + p.releaseMessage() + return + case <-c.ctx.Done(): + return M.Socksaddr{}, io.ErrClosedPipe + } +} + +func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { + select { + case p := <-c.data: + _, err = newBuffer().ReadOnceFrom(p.data) + destination = M.ParseSocksaddrHostPort(p.host, p.port) + p.releaseMessage() + return + case <-c.ctx.Done(): + return M.Socksaddr{}, io.ErrClosedPipe + } +} + +func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + select { + case pkt := <-c.data: + n = copy(p, pkt.data.Bytes()) + destination := M.ParseSocksaddrHostPort(pkt.host, pkt.port) + if destination.IsFqdn() { + addr = destination + } else { + addr = destination.UDPAddr() + } + pkt.releaseMessage() + return n, addr, nil + case <-c.ctx.Done(): + return 0, nil, io.ErrClosedPipe + } +} + +func (c *udpPacketConn) needFragment() bool { + nowTime := time.Now() + if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second { + c.udpMTUTime = nowTime + return true + } + return false +} + +func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + defer buffer.Release() + select { + case <-c.ctx.Done(): + return net.ErrClosed + default: + } + if buffer.Len() > 0xffff { + return quic.ErrMessageTooLarge(0xffff) + } + packetId := c.packetId.Add(1) + if packetId > math.MaxUint16 { + c.packetId.Store(0) + packetId = 0 + } + message := allocMessage() + *message = udpMessage{ + sessionID: c.sessionID, + packetID: uint16(packetId), + fragmentTotal: 1, + host: destination.AddrString(), + port: destination.Port, + data: buffer, + } + defer message.releaseMessage() + var err error + if c.needFragment() && buffer.Len() > c.udpMTU { + err = c.writePackets(fragUDPMessage(message, c.udpMTU)) + } else { + err = c.writePacket(message) + } + if err == nil { + return nil + } + var tooLargeErr quic.ErrMessageTooLarge + if !errors.As(err, &tooLargeErr) { + return err + } + c.udpMTU = int(tooLargeErr) + c.udpMTUTime = time.Now() + return c.writePackets(fragUDPMessage(message, c.udpMTU)) +} + +func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + select { + case <-c.ctx.Done(): + return 0, net.ErrClosed + default: + } + if len(p) > 0xffff { + return 0, quic.ErrMessageTooLarge(0xffff) + } + packetId := c.packetId.Add(1) + if packetId > math.MaxUint16 { + c.packetId.Store(0) + packetId = 0 + } + message := allocMessage() + destination := M.SocksaddrFromNet(addr) + *message = udpMessage{ + sessionID: c.sessionID, + packetID: uint16(packetId), + fragmentTotal: 1, + host: destination.AddrString(), + port: destination.Port, + data: buf.As(p), + } + if c.needFragment() && len(p) > c.udpMTU { + err = c.writePackets(fragUDPMessage(message, c.udpMTU)) + if err == nil { + return len(p), nil + } + } else { + err = c.writePacket(message) + } + if err == nil { + return len(p), nil + } + var tooLargeErr quic.ErrMessageTooLarge + if !errors.As(err, &tooLargeErr) { + return + } + c.udpMTU = int(tooLargeErr) + c.udpMTUTime = time.Now() + err = c.writePackets(fragUDPMessage(message, c.udpMTU)) + if err == nil { + return len(p), nil + } + return +} + +func (c *udpPacketConn) inputPacket(message *udpMessage) { + if message.fragmentTotal <= 1 { + select { + case c.data <- message: + default: + } + } else { + newMessage := c.defragger.feed(message) + if newMessage != nil { + select { + case c.data <- newMessage: + default: + } + } + } +} + +func (c *udpPacketConn) writePackets(messages []*udpMessage) error { + defer releaseMessages(messages) + for _, message := range messages { + err := c.writePacket(message) + if err != nil { + return err + } + } + return nil +} + +func (c *udpPacketConn) writePacket(message *udpMessage) error { + buffer := message.pack() + defer buffer.Release() + return c.quicConn.SendMessage(buffer.Bytes()) +} + +func (c *udpPacketConn) Close() error { + c.closeWithError(os.ErrClosed) + return nil +} + +func (c *udpPacketConn) closeWithError(err error) { + c.closeOnce.Do(func() { + c.cancel(err) + c.onDestroy() + }) +} + +func (c *udpPacketConn) LocalAddr() net.Addr { + return c.quicConn.LocalAddr() +} + +func (c *udpPacketConn) SetDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *udpPacketConn) SetReadDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *udpPacketConn) SetWriteDeadline(t time.Time) error { + return os.ErrInvalid +} + +type udpDefragger struct { + packetMap *cache.LruCache[uint16, *packetItem] +} + +func newUDPDefragger() *udpDefragger { + return &udpDefragger{ + packetMap: cache.New( + cache.WithAge[uint16, *packetItem](10), + cache.WithUpdateAgeOnGet[uint16, *packetItem](), + cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) { + releaseMessages(value.messages) + }), + ), + } +} + +type packetItem struct { + access sync.Mutex + messages []*udpMessage + count uint8 +} + +func (d *udpDefragger) feed(m *udpMessage) *udpMessage { + if m.fragmentTotal <= 1 { + return m + } + if m.fragmentID >= m.fragmentTotal { + return nil + } + item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem) + item.access.Lock() + defer item.access.Unlock() + if int(m.fragmentTotal) != len(item.messages) { + releaseMessages(item.messages) + item.messages = make([]*udpMessage, m.fragmentTotal) + item.count = 1 + item.messages[m.fragmentID] = m + return nil + } + if item.messages[m.fragmentID] != nil { + return nil + } + item.messages[m.fragmentID] = m + item.count++ + if int(item.count) != len(item.messages) { + return nil + } + newMessage := allocMessage() + newMessage.sessionID = m.sessionID + newMessage.packetID = m.packetID + newMessage.host = item.messages[0].host + newMessage.port = item.messages[0].port + var finalLength int + for _, message := range item.messages { + finalLength += message.data.Len() + } + if finalLength > 0 { + newMessage.data = buf.NewSize(finalLength) + for _, message := range item.messages { + newMessage.data.Write(message.data.Bytes()) + message.releaseMessage() + } + item.messages = nil + return newMessage + } + item.messages = nil + return nil +} + +func newPacketItem() *packetItem { + return new(packetItem) +} + +func decodeUDPMessage(message *udpMessage, data []byte) error { + reader := bytes.NewReader(data) + err := binary.Read(reader, binary.BigEndian, &message.sessionID) + if err != nil { + return err + } + var hostLen uint16 + err = binary.Read(reader, binary.BigEndian, &hostLen) + if err != nil { + return err + } + message.host, err = rw.ReadString(reader, int(hostLen)) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.port) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.packetID) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.fragmentID) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal) + if err != nil { + return err + } + var dataLen uint16 + err = binary.Read(reader, binary.BigEndian, &dataLen) + if err != nil { + return err + } + if reader.Len() != int(dataLen) { + return E.New("invalid data length") + } + message.data = buf.As(data[len(data)-reader.Len():]) + return nil +} diff --git a/hysteria/protocol.go b/hysteria/protocol.go new file mode 100644 index 0000000..c0243df --- /dev/null +++ b/hysteria/protocol.go @@ -0,0 +1,253 @@ +package hysteria + +import ( + "encoding/binary" + "io" + "time" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/rw" +) + +const ( + MbpsToBps = 125000 + MinSpeedBPS = 16384 + DefaultALPN = "hysteria" + DefaultStreamReceiveWindow = 8388608 // 8MB + DefaultConnReceiveWindow = DefaultStreamReceiveWindow * 5 / 2 // 20MB + DefaultMaxIdleTimeout = 30 * time.Second + DefaultKeepAlivePeriod = 10 * time.Second +) + +const ( + ProtocolVersion = 3 + ProtocolTimeout = 10 * time.Second + ErrorCodeGeneric = 0 + ErrorCodeProtocolError = 1 + ErrorCodeAuthError = 2 +) + +type ClientHello struct { + SendBPS uint64 + RecvBPS uint64 + Auth string +} + +func WriteClientHello(stream io.Writer, hello ClientHello) error { + var requestLen int + requestLen += 1 // version + requestLen += 8 // sendBPS + requestLen += 8 // recvBPS + requestLen += 2 // auth len + requestLen += len(hello.Auth) + request := buf.NewSize(requestLen) + defer request.Release() + common.Must( + request.WriteByte(ProtocolVersion), + binary.Write(request, binary.BigEndian, hello.SendBPS), + binary.Write(request, binary.BigEndian, hello.RecvBPS), + binary.Write(request, binary.BigEndian, uint16(len(hello.Auth))), + common.Error(request.WriteString(hello.Auth)), + ) + return common.Error(stream.Write(request.Bytes())) +} + +func ReadClientHello(reader io.Reader) (*ClientHello, error) { + var version uint8 + err := binary.Read(reader, binary.BigEndian, &version) + if err != nil { + return nil, err + } + if version != ProtocolVersion { + return nil, E.New("unsupported client version: ", version) + } + var clientHello ClientHello + err = binary.Read(reader, binary.BigEndian, &clientHello.SendBPS) + if err != nil { + return nil, err + } + err = binary.Read(reader, binary.BigEndian, &clientHello.RecvBPS) + if err != nil { + return nil, err + } + if clientHello.SendBPS == 0 || clientHello.RecvBPS == 0 { + return nil, E.New("invalid rate from client") + } + var authLen uint16 + err = binary.Read(reader, binary.BigEndian, &authLen) + if err != nil { + return nil, err + } + clientHello.Auth, err = rw.ReadString(reader, int(authLen)) + if err != nil { + return nil, err + } + return &clientHello, nil +} + +type ServerHello struct { + OK bool + SendBPS uint64 + RecvBPS uint64 + Message string +} + +func ReadServerHello(stream io.Reader) (*ServerHello, error) { + var responseLen int + responseLen += 1 // ok + responseLen += 8 // sendBPS + responseLen += 8 // recvBPS + responseLen += 2 // message len + response := buf.NewSize(responseLen) + defer response.Release() + _, err := response.ReadFullFrom(stream, responseLen) + if err != nil { + return nil, err + } + var serverHello ServerHello + serverHello.OK = response.Byte(0) == 1 + serverHello.SendBPS = binary.BigEndian.Uint64(response.Range(1, 9)) + serverHello.RecvBPS = binary.BigEndian.Uint64(response.Range(9, 17)) + messageLen := binary.BigEndian.Uint16(response.Range(17, 19)) + if messageLen == 0 { + return &serverHello, nil + } + message := make([]byte, messageLen) + _, err = io.ReadFull(stream, message) + if err != nil { + return nil, err + } + serverHello.Message = string(message) + return &serverHello, nil +} + +func WriteServerHello(stream io.Writer, hello ServerHello) error { + var responseLen int + responseLen += 1 // ok + responseLen += 8 // sendBPS + responseLen += 8 // recvBPS + responseLen += 2 // message len + responseLen += len(hello.Message) + response := buf.NewSize(responseLen) + defer response.Release() + if hello.OK { + common.Must(response.WriteByte(1)) + } else { + common.Must(response.WriteByte(0)) + } + common.Must( + binary.Write(response, binary.BigEndian, hello.SendBPS), + binary.Write(response, binary.BigEndian, hello.RecvBPS), + binary.Write(response, binary.BigEndian, uint16(len(hello.Message))), + common.Error(response.WriteString(hello.Message)), + ) + return common.Error(stream.Write(response.Bytes())) +} + +type ClientRequest struct { + UDP bool + Host string + Port uint16 +} + +func ReadClientRequest(stream io.Reader) (*ClientRequest, error) { + var clientRequest ClientRequest + err := binary.Read(stream, binary.BigEndian, &clientRequest.UDP) + if err != nil { + return nil, err + } + var hostLen uint16 + err = binary.Read(stream, binary.BigEndian, &hostLen) + if err != nil { + return nil, err + } + host := make([]byte, hostLen) + _, err = io.ReadFull(stream, host) + if err != nil { + return nil, err + } + clientRequest.Host = string(host) + err = binary.Read(stream, binary.BigEndian, &clientRequest.Port) + if err != nil { + return nil, err + } + return &clientRequest, nil +} + +func WriteClientRequest(request ClientRequest, payload []byte) *buf.Buffer { + var requestLen int + requestLen += 1 // udp + requestLen += 2 // host len + requestLen += len(request.Host) + requestLen += 2 // port + buffer := buf.NewSize(requestLen + len(payload)) + if request.UDP { + common.Must(buffer.WriteByte(1)) + } else { + common.Must(buffer.WriteByte(0)) + } + common.Must( + binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))), + common.Error(buffer.WriteString(request.Host)), + binary.Write(buffer, binary.BigEndian, request.Port), + common.Error(buffer.Write(payload)), + ) + return buffer +} + +type ServerResponse struct { + OK bool + UDPSessionID uint32 + Message string +} + +func ReadServerResponse(stream io.Reader) (*ServerResponse, error) { + var responseLen int + responseLen += 1 // ok + responseLen += 4 // udp session id + responseLen += 2 // message len + response := buf.NewSize(responseLen) + defer response.Release() + _, err := response.ReadFullFrom(stream, responseLen) + if err != nil { + return nil, err + } + var serverResponse ServerResponse + serverResponse.OK = response.Byte(0) == 1 + serverResponse.UDPSessionID = binary.BigEndian.Uint32(response.Range(1, 5)) + messageLen := binary.BigEndian.Uint16(response.Range(5, 7)) + if messageLen == 0 { + return &serverResponse, nil + } + message := make([]byte, messageLen) + _, err = io.ReadFull(stream, message) + if err != nil { + return nil, err + } + serverResponse.Message = string(message) + return &serverResponse, nil +} + +func WriteServerResponse(stream quic.Stream, response ServerResponse) error { + var responseLen int + responseLen += 1 // ok + responseLen += 4 // udp session id + responseLen += 2 // message len + responseLen += len(response.Message) + buffer := buf.NewSize(responseLen) + defer buffer.Release() + if response.OK { + common.Must(buffer.WriteByte(1)) + } else { + common.Must(buffer.WriteByte(0)) + } + common.Must( + binary.Write(buffer, binary.BigEndian, response.UDPSessionID), + binary.Write(buffer, binary.BigEndian, uint16(len(response.Message))), + common.Error(buffer.WriteString(response.Message)), + ) + return common.Error(stream.Write(buffer.Bytes())) +} diff --git a/hysteria/service.go b/hysteria/service.go new file mode 100644 index 0000000..92ada68 --- /dev/null +++ b/hysteria/service.go @@ -0,0 +1,376 @@ +package hysteria + +import ( + "context" + "errors" + "io" + "math" + "net" + "os" + "runtime" + "sync" + + "github.com/sagernet/quic-go" + "github.com/sagernet/sing-quic" + hyCC "github.com/sagernet/sing-quic/hysteria/congestion" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/auth" + "github.com/sagernet/sing/common/baderror" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + aTLS "github.com/sagernet/sing/common/tls" +) + +type ServiceOptions struct { + Context context.Context + Logger logger.Logger + BrutalDebug bool + SendBPS uint64 + ReceiveBPS uint64 + XPlusPassword string + TLSConfig aTLS.ServerConfig + UDPDisabled bool + Handler ServerHandler + + // Legacy options + + ConnReceiveWindow uint64 + StreamReceiveWindow uint64 + MaxIncomingStreams int64 + DisableMTUDiscovery bool +} + +type ServerHandler interface { + N.TCPConnectionHandler + N.UDPConnectionHandler +} + +type Service[U comparable] struct { + ctx context.Context + logger logger.Logger + brutalDebug bool + sendBPS uint64 + receiveBPS uint64 + xplusPassword string + tlsConfig aTLS.ServerConfig + quicConfig *quic.Config + userMap map[string]U + udpDisabled bool + handler ServerHandler + quicListener io.Closer +} + +func NewService[U comparable](options ServiceOptions) (*Service[U], error) { + quicConfig := &quic.Config{ + DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), + EnableDatagrams: !options.UDPDisabled, + MaxIncomingStreams: 1 << 60, + InitialStreamReceiveWindow: DefaultStreamReceiveWindow, + MaxStreamReceiveWindow: DefaultStreamReceiveWindow, + InitialConnectionReceiveWindow: DefaultConnReceiveWindow, + MaxConnectionReceiveWindow: DefaultConnReceiveWindow, + MaxIdleTimeout: DefaultMaxIdleTimeout, + KeepAlivePeriod: DefaultKeepAlivePeriod, + } + if options.StreamReceiveWindow != 0 { + quicConfig.InitialStreamReceiveWindow = options.StreamReceiveWindow + quicConfig.MaxStreamReceiveWindow = options.StreamReceiveWindow + } + if options.ConnReceiveWindow != 0 { + quicConfig.InitialConnectionReceiveWindow = options.ConnReceiveWindow + quicConfig.MaxConnectionReceiveWindow = options.ConnReceiveWindow + } + if options.MaxIncomingStreams > 0 { + quicConfig.MaxIncomingStreams = int64(options.MaxIncomingStreams) + } + if options.DisableMTUDiscovery { + quicConfig.DisablePathMTUDiscovery = true + } + if len(options.TLSConfig.NextProtos()) == 0 { + options.TLSConfig.SetNextProtos([]string{DefaultALPN}) + } + if options.SendBPS == 0 { + return nil, E.New("missing upload speed configuration") + } + if options.ReceiveBPS == 0 { + return nil, E.New("missing download speed configuration") + } + return &Service[U]{ + ctx: options.Context, + logger: options.Logger, + brutalDebug: options.BrutalDebug, + sendBPS: options.SendBPS, + receiveBPS: options.ReceiveBPS, + xplusPassword: options.XPlusPassword, + tlsConfig: options.TLSConfig, + quicConfig: quicConfig, + userMap: make(map[string]U), + handler: options.Handler, + udpDisabled: options.UDPDisabled, + }, nil +} + +func (s *Service[U]) UpdateUsers(userList []U, passwordList []string) { + userMap := make(map[string]U) + for i, user := range userList { + userMap[passwordList[i]] = user + } + s.userMap = userMap +} + +func (s *Service[U]) Start(conn net.PacketConn) error { + if s.xplusPassword != "" { + conn = NewXPlusPacketConn(conn, []byte(s.xplusPassword)) + } + listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig) + if err != nil { + return err + } + s.quicListener = listener + go s.loopConnections(listener) + return nil +} + +func (s *Service[U]) Close() error { + return common.Close( + s.quicListener, + ) +} + +func (s *Service[U]) loopConnections(listener qtls.Listener) { + for { + connection, err := listener.Accept(s.ctx) + if err != nil { + if E.IsClosedOrCanceled(err) || errors.Is(err, quic.ErrServerClosed) { + s.logger.Debug(E.Cause(err, "listener closed")) + } else { + s.logger.Error(E.Cause(err, "listener closed")) + } + return + } + session := &serverSession[U]{ + Service: s, + ctx: s.ctx, + quicConn: connection, + source: M.SocksaddrFromNet(connection.RemoteAddr()), + connDone: make(chan struct{}), + udpConnMap: make(map[uint32]*udpPacketConn), + } + go session.handleConnection() + } +} + +type serverSession[U comparable] struct { + *Service[U] + ctx context.Context + quicConn quic.Connection + source M.Socksaddr + connAccess sync.Mutex + connDone chan struct{} + connErr error + authUser U + udpAccess sync.RWMutex + udpConnMap map[uint32]*udpPacketConn + udpSessionID uint32 +} + +func (s *serverSession[U]) handleConnection() { + ctx, cancel := context.WithTimeout(s.ctx, ProtocolTimeout) + controlStream, err := s.quicConn.AcceptStream(ctx) + cancel() + if err != nil { + s.closeWithError0(ErrorCodeProtocolError, err) + return + } + clientHello, err := ReadClientHello(controlStream) + if err != nil { + s.closeWithError0(ErrorCodeProtocolError, E.Cause(err, "read client hello")) + return + } + user, loaded := s.userMap[clientHello.Auth] + if !loaded { + WriteServerHello(controlStream, ServerHello{ + OK: false, + Message: "Wrong password", + }) + s.closeWithError0(ErrorCodeAuthError, E.Cause(err, "authentication failed, auth_str=", clientHello.Auth)) + return + } + err = WriteServerHello(controlStream, ServerHello{ + OK: true, + SendBPS: s.sendBPS, + RecvBPS: s.receiveBPS, + }) + if err != nil { + s.closeWithError(err) + return + } + s.authUser = user + s.quicConn.SetCongestionControl(hyCC.NewBrutalSender(uint64(math.Min(float64(s.sendBPS), float64(clientHello.RecvBPS))), s.brutalDebug, s.logger)) + if !s.udpDisabled { + go s.loopMessages() + } + s.loopStreams() +} + +func (s *serverSession[U]) loopStreams() { + for { + stream, err := s.quicConn.AcceptStream(s.ctx) + if err != nil { + return + } + go func() { + err = s.handleStream(stream) + if err != nil { + stream.CancelRead(0) + stream.Close() + s.logger.Error(E.Cause(err, "handle stream request")) + } + }() + } +} + +func (s *serverSession[U]) handleStream(stream quic.Stream) error { + request, err := ReadClientRequest(stream) + if err != nil { + return E.New("read TCP request") + } + ctx := auth.ContextWithUser(s.ctx, s.authUser) + if !request.UDP { + _ = s.handler.NewConnection(ctx, &serverConn{Stream: stream}, M.Metadata{ + Source: s.source, + Destination: M.ParseSocksaddrHostPort(request.Host, request.Port), + }) + } else { + if s.udpDisabled { + return WriteServerResponse(stream, ServerResponse{ + OK: false, + Message: "UDP disabled by server", + }) + } + var sessionID uint32 + udpConn := newUDPPacketConn(ctx, s.quicConn, func() { + stream.CancelRead(0) + stream.Close() + s.udpAccess.Lock() + delete(s.udpConnMap, sessionID) + s.udpAccess.Unlock() + }) + s.udpAccess.Lock() + s.udpSessionID++ + sessionID = s.udpSessionID + udpConn.sessionID = sessionID + s.udpConnMap[sessionID] = udpConn + s.udpAccess.Unlock() + err = WriteServerResponse(stream, ServerResponse{ + OK: true, + UDPSessionID: sessionID, + }) + if err != nil { + udpConn.closeWithError(E.Cause(err, "write server response")) + return err + } + go s.handler.NewPacketConnection(udpConn.ctx, udpConn, M.Metadata{ + Source: s.source, + Destination: M.ParseSocksaddrHostPort(request.Host, request.Port), + }) + holdBuffer := make([]byte, 1024) + for { + _, hErr := stream.Read(holdBuffer) + if hErr != nil { + break + } + } + udpConn.closeWithError(E.Cause(net.ErrClosed, "hold stream closed")) + } + return nil +} + +func (s *serverSession[U]) closeWithError(err error) { + s.closeWithError0(ErrorCodeGeneric, err) +} + +func (s *serverSession[U]) closeWithError0(errorCode int, err error) { + s.connAccess.Lock() + defer s.connAccess.Unlock() + select { + case <-s.connDone: + return + default: + s.connErr = err + close(s.connDone) + } + if E.IsClosedOrCanceled(err) { + s.logger.Debug(E.Cause(err, "connection failed")) + } else { + s.logger.Error(E.Cause(err, "connection failed")) + } + switch errorCode { + case ErrorCodeProtocolError: + _ = s.quicConn.CloseWithError(quic.ApplicationErrorCode(errorCode), "protocol error") + case ErrorCodeAuthError: + _ = s.quicConn.CloseWithError(quic.ApplicationErrorCode(errorCode), "auth error") + default: + _ = s.quicConn.CloseWithError(quic.ApplicationErrorCode(errorCode), "") + } +} + +type serverConn struct { + quic.Stream + responseWritten bool +} + +func (c *serverConn) HandshakeFailure(err error) error { + if c.responseWritten { + return os.ErrClosed + } + c.responseWritten = true + return WriteServerResponse(c.Stream, ServerResponse{ + OK: false, + Message: err.Error(), + }) +} + +func (c *serverConn) HandshakeSuccess() error { + if c.responseWritten { + return nil + } + c.responseWritten = true + return WriteServerResponse(c.Stream, ServerResponse{ + OK: true, + }) +} + +func (c *serverConn) Read(p []byte) (n int, err error) { + n, err = c.Stream.Read(p) + return n, baderror.WrapQUIC(err) +} + +func (c *serverConn) Write(p []byte) (n int, err error) { + if !c.responseWritten { + c.responseWritten = true + err = WriteServerResponse(c.Stream, ServerResponse{ + OK: true, + }) + if err != nil { + return 0, baderror.WrapQUIC(err) + } + } + n, err = c.Stream.Write(p) + return n, baderror.WrapQUIC(err) +} + +func (c *serverConn) LocalAddr() net.Addr { + return M.Socksaddr{} +} + +func (c *serverConn) RemoteAddr() net.Addr { + return M.Socksaddr{} +} + +func (c *serverConn) Close() error { + c.Stream.CancelRead(0) + return c.Stream.Close() +} diff --git a/hysteria/service_packet.go b/hysteria/service_packet.go new file mode 100644 index 0000000..a1fbab2 --- /dev/null +++ b/hysteria/service_packet.go @@ -0,0 +1,43 @@ +package hysteria + +import ( + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +func (s *serverSession[U]) loopMessages() { + for { + message, err := s.quicConn.ReceiveMessage(s.ctx) + if err != nil { + s.closeWithError(E.Cause(err, "receive message")) + return + } + hErr := s.handleMessage(message) + if hErr != nil { + s.closeWithError(E.Cause(hErr, "handle message")) + return + } + } +} + +func (s *serverSession[U]) handleMessage(data []byte) error { + message := allocMessage() + err := decodeUDPMessage(message, data) + if err != nil { + message.release() + return E.Cause(err, "decode UDP message") + } + return s.handleUDPMessage(message) +} + +func (s *serverSession[U]) handleUDPMessage(message *udpMessage) error { + s.udpAccess.RLock() + udpConn, loaded := s.udpConnMap[message.sessionID] + s.udpAccess.RUnlock() + if !loaded || common.Done(udpConn.ctx) { + message.release() + return E.New("unknown session iD: ", message.sessionID) + } + udpConn.inputPacket(message) + return nil +} diff --git a/hysteria/xplus.go b/hysteria/xplus.go new file mode 100644 index 0000000..14e0eaa --- /dev/null +++ b/hysteria/xplus.go @@ -0,0 +1,118 @@ +package hysteria + +import ( + "crypto/sha256" + "math/rand" + "net" + "sync" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +const xplusSaltLen = 16 + +func NewXPlusPacketConn(conn net.PacketConn, key []byte) net.PacketConn { + vectorisedWriter, isVectorised := bufio.CreateVectorisedPacketWriter(conn) + if isVectorised { + return &VectorisedXPlusConn{ + XPlusPacketConn: XPlusPacketConn{ + PacketConn: conn, + key: key, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + }, + writer: vectorisedWriter, + } + } else { + return &XPlusPacketConn{ + PacketConn: conn, + key: key, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + } + } +} + +type XPlusPacketConn struct { + net.PacketConn + key []byte + randAccess sync.Mutex + rand *rand.Rand +} + +func (c *XPlusPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil { + return + } else if n < xplusSaltLen { + n = 0 + return + } + key := sha256.Sum256(append(c.key, p[:xplusSaltLen]...)) + for i := range p[xplusSaltLen:] { + p[i] = p[xplusSaltLen+i] ^ key[i%sha256.Size] + } + n -= xplusSaltLen + return +} + +func (c *XPlusPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + // can't use unsafe buffer on WriteTo + buffer := buf.NewSize(len(p) + xplusSaltLen) + defer buffer.Release() + salt := buffer.Extend(xplusSaltLen) + c.randAccess.Lock() + _, _ = c.rand.Read(salt) + c.randAccess.Unlock() + key := sha256.Sum256(append(c.key, salt...)) + for i := range p { + common.Must(buffer.WriteByte(p[i] ^ key[i%sha256.Size])) + } + return c.PacketConn.WriteTo(buffer.Bytes(), addr) +} + +func (c *XPlusPacketConn) Upstream() any { + return c.PacketConn +} + +type VectorisedXPlusConn struct { + XPlusPacketConn + writer N.VectorisedPacketWriter +} + +func (c *VectorisedXPlusConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + header := buf.NewSize(xplusSaltLen) + defer header.Release() + salt := header.Extend(xplusSaltLen) + c.randAccess.Lock() + _, _ = c.rand.Read(salt) + c.randAccess.Unlock() + key := sha256.Sum256(append(c.key, salt...)) + for i := range p { + p[i] ^= key[i%sha256.Size] + } + return bufio.WriteVectorisedPacket(c.writer, [][]byte{header.Bytes(), p}, M.SocksaddrFromNet(addr)) +} + +func (c *VectorisedXPlusConn) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error { + header := buf.NewSize(xplusSaltLen) + defer header.Release() + salt := header.Extend(xplusSaltLen) + c.randAccess.Lock() + _, _ = c.rand.Read(salt) + c.randAccess.Unlock() + key := sha256.Sum256(append(c.key, salt...)) + var index int + for _, buffer := range buffers { + data := buffer.Bytes() + for i := range data { + data[i] ^= key[index%sha256.Size] + index++ + } + } + buffers = append([]*buf.Buffer{header}, buffers...) + return c.writer.WriteVectorisedPacket(buffers, destination) +} diff --git a/hysteria2/client.go b/hysteria2/client.go index 29b05d2..df3dd00 100644 --- a/hysteria2/client.go +++ b/hysteria2/client.go @@ -16,7 +16,8 @@ import ( "github.com/sagernet/sing-quic" congestion_meta1 "github.com/sagernet/sing-quic/congestion_meta1" congestion_meta2 "github.com/sagernet/sing-quic/congestion_meta2" - hyCC "github.com/sagernet/sing-quic/hysteria2/congestion" + "github.com/sagernet/sing-quic/hysteria" + hyCC "github.com/sagernet/sing-quic/hysteria/congestion" "github.com/sagernet/sing-quic/hysteria2/internal/protocol" "github.com/sagernet/sing/common/baderror" "github.com/sagernet/sing/common/bufio" @@ -28,13 +29,6 @@ import ( aTLS "github.com/sagernet/sing/common/tls" ) -const ( - defaultStreamReceiveWindow = 8388608 // 8MB - defaultConnReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 20MB - defaultMaxIdleTimeout = 30 * time.Second - defaultKeepAlivePeriod = 10 * time.Second -) - type ClientOptions struct { Context context.Context Dialer N.Dialer @@ -70,13 +64,13 @@ type Client struct { func NewClient(options ClientOptions) (*Client, error) { quicConfig := &quic.Config{ DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), - EnableDatagrams: true, - InitialStreamReceiveWindow: defaultStreamReceiveWindow, - MaxStreamReceiveWindow: defaultStreamReceiveWindow, - InitialConnectionReceiveWindow: defaultConnReceiveWindow, - MaxConnectionReceiveWindow: defaultConnReceiveWindow, - MaxIdleTimeout: defaultMaxIdleTimeout, - KeepAlivePeriod: defaultKeepAlivePeriod, + EnableDatagrams: !options.UDPDisabled, + InitialStreamReceiveWindow: hysteria.DefaultStreamReceiveWindow, + MaxStreamReceiveWindow: hysteria.DefaultStreamReceiveWindow, + InitialConnectionReceiveWindow: hysteria.DefaultConnReceiveWindow, + MaxConnectionReceiveWindow: hysteria.DefaultConnReceiveWindow, + MaxIdleTimeout: hysteria.DefaultMaxIdleTimeout, + KeepAlivePeriod: hysteria.DefaultKeepAlivePeriod, } return &Client{ ctx: options.Context, @@ -176,7 +170,7 @@ func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { quicConn: quicConn, rawConn: udpConn, connDone: make(chan struct{}), - udpDisabled: c.udpDisabled || !authResponse.UDPEnabled, + udpDisabled: !authResponse.UDPEnabled, udpConnMap: make(map[uint32]*udpPacketConn), } if !c.udpDisabled { diff --git a/hysteria2/salamander.go b/hysteria2/salamander.go index 4b4e0f9..c057d1f 100644 --- a/hysteria2/salamander.go +++ b/hysteria2/salamander.go @@ -16,7 +16,7 @@ const salamanderSaltLen = 8 const ObfsTypeSalamander = "salamander" -type Salamander struct { +type SalamanderPacketConn struct { net.PacketConn password []byte } @@ -24,22 +24,22 @@ type Salamander struct { func NewSalamanderConn(conn net.PacketConn, password []byte) net.PacketConn { writer, isVectorised := bufio.CreateVectorisedPacketWriter(conn) if isVectorised { - return &VectorisedSalamander{ - Salamander: Salamander{ + return &VectorisedSalamanderPacketConn{ + SalamanderPacketConn: SalamanderPacketConn{ PacketConn: conn, password: password, }, writer: writer, } } else { - return &Salamander{ + return &SalamanderPacketConn{ PacketConn: conn, password: password, } } } -func (s *Salamander) ReadFrom(p []byte) (n int, addr net.Addr, err error) { +func (s *SalamanderPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { n, addr, err = s.PacketConn.ReadFrom(p) if err != nil { return @@ -54,7 +54,7 @@ func (s *Salamander) ReadFrom(p []byte) (n int, addr net.Addr, err error) { return n - salamanderSaltLen, addr, nil } -func (s *Salamander) WriteTo(p []byte, addr net.Addr) (n int, err error) { +func (s *SalamanderPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { buffer := buf.NewSize(len(p) + salamanderSaltLen) defer buffer.Release() buffer.WriteRandom(salamanderSaltLen) @@ -69,12 +69,12 @@ func (s *Salamander) WriteTo(p []byte, addr net.Addr) (n int, err error) { return len(p), nil } -type VectorisedSalamander struct { - Salamander +type VectorisedSalamanderPacketConn struct { + SalamanderPacketConn writer N.VectorisedPacketWriter } -func (s *VectorisedSalamander) WriteTo(p []byte, addr net.Addr) (n int, err error) { +func (s *VectorisedSalamanderPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { buffer := buf.NewSize(salamanderSaltLen) buffer.WriteRandom(salamanderSaltLen) key := blake2b.Sum256(append(s.password, buffer.Bytes()...)) @@ -88,7 +88,7 @@ func (s *VectorisedSalamander) WriteTo(p []byte, addr net.Addr) (n int, err erro return len(p), nil } -func (s *VectorisedSalamander) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error { +func (s *VectorisedSalamanderPacketConn) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error { header := buf.NewSize(salamanderSaltLen) defer header.Release() header.WriteRandom(salamanderSaltLen) diff --git a/hysteria2/service.go b/hysteria2/service.go index 7c3866d..4e477bc 100644 --- a/hysteria2/service.go +++ b/hysteria2/service.go @@ -2,6 +2,7 @@ package hysteria2 import ( "context" + "errors" "io" "net" "net/http" @@ -16,13 +17,13 @@ import ( "github.com/sagernet/sing-quic" congestion_meta1 "github.com/sagernet/sing-quic/congestion_meta1" congestion_meta2 "github.com/sagernet/sing-quic/congestion_meta2" - hyCC "github.com/sagernet/sing-quic/hysteria2/congestion" + "github.com/sagernet/sing-quic/hysteria" + hyCC "github.com/sagernet/sing-quic/hysteria/congestion" "github.com/sagernet/sing-quic/hysteria2/internal/protocol" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/baderror" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -71,12 +72,12 @@ func NewService[U comparable](options ServiceOptions) (*Service[U], error) { DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), EnableDatagrams: !options.UDPDisabled, MaxIncomingStreams: 1 << 60, - InitialStreamReceiveWindow: defaultStreamReceiveWindow, - MaxStreamReceiveWindow: defaultStreamReceiveWindow, - InitialConnectionReceiveWindow: defaultConnReceiveWindow, - MaxConnectionReceiveWindow: defaultConnReceiveWindow, - MaxIdleTimeout: defaultMaxIdleTimeout, - KeepAlivePeriod: defaultKeepAlivePeriod, + InitialStreamReceiveWindow: hysteria.DefaultStreamReceiveWindow, + MaxStreamReceiveWindow: hysteria.DefaultStreamReceiveWindow, + InitialConnectionReceiveWindow: hysteria.DefaultConnReceiveWindow, + MaxConnectionReceiveWindow: hysteria.DefaultConnReceiveWindow, + MaxIdleTimeout: hysteria.DefaultMaxIdleTimeout, + KeepAlivePeriod: hysteria.DefaultKeepAlivePeriod, } if options.MasqueradeHandler == nil { options.MasqueradeHandler = http.NotFoundHandler() @@ -133,7 +134,7 @@ func (s *Service[U]) loopConnections(listener qtls.Listener) { for { connection, err := listener.Accept(s.ctx) if err != nil { - if E.IsClosedOrCanceled(err) { + if E.IsClosedOrCanceled(err) || errors.Is(err, quic.ErrServerClosed) { s.logger.Debug(E.Cause(err, "listener closed")) } else { s.logger.Error(E.Cause(err, "listener closed")) @@ -195,14 +196,11 @@ func (s *serverSession[U]) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.authUser = user s.authenticated = true if !s.ignoreClientBandwidth && request.Rx > 0 { - var sendBps uint64 - if s.sendBPS > 0 && s.sendBPS < request.Rx { - sendBps = s.sendBPS - } else { - sendBps = request.Rx + rx := request.Rx + if s.sendBPS > 0 && rx > s.sendBPS { + rx = s.sendBPS } - format.ToString(1024 * 1024) - s.quicConn.SetCongestionControl(hyCC.NewBrutalSender(sendBps, s.brutalDebug, s.logger)) + s.quicConn.SetCongestionControl(hyCC.NewBrutalSender(rx, s.brutalDebug, s.logger)) } else { timeFunc := ntp.TimeFuncFromContext(s.ctx) if timeFunc == nil { diff --git a/tuic/service.go b/tuic/service.go index ecd3391..e8a191a 100644 --- a/tuic/service.go +++ b/tuic/service.go @@ -4,10 +4,10 @@ import ( "bytes" "context" "encoding/binary" + "errors" "io" "net" "runtime" - "strings" "sync" "time" @@ -113,7 +113,7 @@ func (s *Service[U]) Start(conn net.PacketConn) error { for { connection, hErr := listener.Accept(s.ctx) if hErr != nil { - if E.IsClosedOrCanceled(hErr) { + if E.IsClosedOrCanceled(hErr) || errors.Is(hErr, quic.ErrServerClosed) { s.logger.Debug(E.Cause(hErr, "listener closed")) } else { s.logger.Error(E.Cause(hErr, "listener closed")) @@ -133,7 +133,7 @@ func (s *Service[U]) Start(conn net.PacketConn) error { for { connection, hErr := listener.Accept(s.ctx) if hErr != nil { - if strings.Contains(hErr.Error(), "server closed") { + if E.IsClosedOrCanceled(hErr) || errors.Is(hErr, quic.ErrServerClosed) { s.logger.Debug(E.Cause(hErr, "listener closed")) } else { s.logger.Error(E.Cause(hErr, "listener closed"))