package cs import ( "context" "crypto/tls" "errors" "fmt" "net" "github.com/apernet/hysteria/core/congestion" "github.com/apernet/hysteria/core/acl" "github.com/apernet/hysteria/core/pmtud" "github.com/apernet/hysteria/core/transport" "github.com/lunixbochs/struc" "github.com/quic-go/quic-go" ) type ( ConnectFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) DisconnectFunc func(addr net.Addr, auth []byte, err error) TCPRequestFunc func(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string) TCPErrorFunc func(addr net.Addr, auth []byte, reqAddr string, err error) UDPRequestFunc func(addr net.Addr, auth []byte, sessionID uint32) UDPErrorFunc func(addr net.Addr, auth []byte, sessionID uint32, err error) ) type TrafficCounter interface { Rx(auth string, n int) Tx(auth string, n int) IncConn(auth string) // increase connection count DecConn(auth string) // decrease connection count } type Server struct { transport *transport.ServerTransport sendBPS, recvBPS uint64 disableUDP bool aclEngine *acl.Engine connectFunc ConnectFunc disconnectFunc DisconnectFunc tcpRequestFunc TCPRequestFunc tcpErrorFunc TCPErrorFunc udpRequestFunc UDPRequestFunc udpErrorFunc UDPErrorFunc trafficCounter TrafficCounter pktConn net.PacketConn listener quic.Listener } func NewServer(tlsConfig *tls.Config, quicConfig *quic.Config, pktConn net.PacketConn, transport *transport.ServerTransport, sendBPS uint64, recvBPS uint64, disableUDP bool, aclEngine *acl.Engine, connectFunc ConnectFunc, disconnectFunc DisconnectFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc, udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc, trafficCounter TrafficCounter, ) (*Server, error) { quicConfig.DisablePathMTUDiscovery = quicConfig.DisablePathMTUDiscovery || pmtud.DisablePathMTUDiscovery listener, err := quic.Listen(pktConn, tlsConfig, quicConfig) if err != nil { _ = pktConn.Close() return nil, err } s := &Server{ pktConn: pktConn, listener: listener, transport: transport, sendBPS: sendBPS, recvBPS: recvBPS, disableUDP: disableUDP, aclEngine: aclEngine, connectFunc: connectFunc, disconnectFunc: disconnectFunc, tcpRequestFunc: tcpRequestFunc, tcpErrorFunc: tcpErrorFunc, udpRequestFunc: udpRequestFunc, udpErrorFunc: udpErrorFunc, trafficCounter: trafficCounter, } return s, nil } func (s *Server) Serve() error { for { cc, err := s.listener.Accept(context.Background()) if err != nil { return err } go s.handleClient(cc) } } func (s *Server) Close() error { err := s.listener.Close() _ = s.pktConn.Close() return err } func (s *Server) handleClient(cc quic.Connection) { // Expect the client to create a control stream to send its own information ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout) stream, err := cc.AcceptStream(ctx) ctxCancel() if err != nil { _ = qErrorProtocol.Send(cc) return } // Handle the control stream auth, ok, err := s.handleControlStream(cc, stream) if err != nil { _ = qErrorProtocol.Send(cc) return } if !ok { _ = qErrorAuth.Send(cc) return } // Start accepting streams and messages sc := newServerClient(cc, s.transport, auth, s.disableUDP, s.aclEngine, s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc, s.trafficCounter) err = sc.Run() _ = qErrorGeneric.Send(cc) s.disconnectFunc(cc.RemoteAddr(), auth, err) } // Auth & negotiate speed func (s *Server) handleControlStream(cc quic.Connection, stream quic.Stream) ([]byte, bool, error) { // Check version vb := make([]byte, 1) _, err := stream.Read(vb) if err != nil { return nil, false, err } if vb[0] != protocolVersion { return nil, false, fmt.Errorf("unsupported protocol version %d, expecting %d", vb[0], protocolVersion) } // Parse client hello var ch clientHello err = struc.Unpack(stream, &ch) if err != nil { return nil, false, err } // Speed if ch.Rate.SendBPS == 0 || ch.Rate.RecvBPS == 0 { return nil, false, errors.New("invalid rate from client") } serverSendBPS, serverRecvBPS := ch.Rate.RecvBPS, ch.Rate.SendBPS if s.sendBPS > 0 && serverSendBPS > s.sendBPS { serverSendBPS = s.sendBPS } if s.recvBPS > 0 && serverRecvBPS > s.recvBPS { serverRecvBPS = s.recvBPS } // Auth ok, msg := s.connectFunc(cc.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS) // Response err = struc.Pack(stream, &serverHello{ OK: ok, Rate: maxRate{ SendBPS: serverSendBPS, RecvBPS: serverRecvBPS, }, Message: msg, }) if err != nil { return nil, false, err } // Set the congestion accordingly if ok { cc.SetCongestionControl(congestion.NewBrutalSender(serverSendBPS)) } return ch.Auth, ok, nil }