From b55f3531e703ba957f9916d8d95cd85c159c3a0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 15 Sep 2023 14:42:01 +0800 Subject: [PATCH] Improve server API --- hysteria2/{server.go => service.go} | 66 +++++++------- .../{server_packet.go => service_packet.go} | 6 +- tuic/{server.go => service.go} | 85 +++++++++---------- tuic/{server_packet.go => service_packet.go} | 6 +- 4 files changed, 75 insertions(+), 88 deletions(-) rename hysteria2/{server.go => service.go} (86%) rename hysteria2/{server_packet.go => service_packet.go} (87%) rename tuic/{server.go => service.go} (85%) rename tuic/{server_packet.go => service_packet.go} (89%) diff --git a/hysteria2/server.go b/hysteria2/service.go similarity index 86% rename from hysteria2/server.go rename to hysteria2/service.go index 5908a8a..d626827 100644 --- a/hysteria2/server.go +++ b/hysteria2/service.go @@ -26,7 +26,7 @@ import ( aTLS "github.com/sagernet/sing/common/tls" ) -type ServerOptions struct { +type ServiceOptions struct { Context context.Context Logger logger.Logger SendBPS uint64 @@ -34,23 +34,17 @@ type ServerOptions struct { IgnoreClientBandwidth bool SalamanderPassword string TLSConfig aTLS.ServerConfig - Users []User UDPDisabled bool Handler ServerHandler MasqueradeHandler http.Handler } -type User struct { - Name string - Password string -} - type ServerHandler interface { N.TCPConnectionHandler N.UDPConnectionHandler } -type Server struct { +type Service[U comparable] struct { ctx context.Context logger logger.Logger sendBPS uint64 @@ -59,14 +53,14 @@ type Server struct { salamanderPassword string tlsConfig aTLS.ServerConfig quicConfig *quic.Config - userMap map[string]User + userMap map[string]U udpDisabled bool handler ServerHandler masqueradeHandler http.Handler quicListener io.Closer } -func NewServer(options ServerOptions) (*Server, error) { +func NewService[U comparable](options ServiceOptions) (*Service[U], error) { quicConfig := &quic.Config{ DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), EnableDatagrams: !options.UDPDisabled, @@ -78,17 +72,10 @@ func NewServer(options ServerOptions) (*Server, error) { MaxIdleTimeout: defaultMaxIdleTimeout, KeepAlivePeriod: defaultKeepAlivePeriod, } - if len(options.Users) == 0 { - return nil, E.New("missing users") - } - userMap := make(map[string]User) - for _, user := range options.Users { - userMap[user.Password] = user - } if options.MasqueradeHandler == nil { options.MasqueradeHandler = http.NotFoundHandler() } - return &Server{ + return &Service[U]{ ctx: options.Context, logger: options.Logger, sendBPS: options.SendBPS, @@ -97,14 +84,22 @@ func NewServer(options ServerOptions) (*Server, error) { salamanderPassword: options.SalamanderPassword, tlsConfig: options.TLSConfig, quicConfig: quicConfig, - userMap: userMap, + userMap: make(map[string]U), udpDisabled: options.UDPDisabled, handler: options.Handler, masqueradeHandler: options.MasqueradeHandler, }, nil } -func (s *Server) Start(conn net.PacketConn) error { +func (s *Service[U]) UpdateUsers(userList []U, passwordList []string) { + userMap := make(map[string]U) + for i, user := range userList { + userMap[passwordList[i]] = user + } + s.userMap = userMap +} + +func (s *Service[U]) Start(conn net.PacketConn) error { if s.salamanderPassword != "" { conn = NewSalamanderConn(conn, []byte(s.salamanderPassword)) } @@ -121,13 +116,13 @@ func (s *Server) Start(conn net.PacketConn) error { return nil } -func (s *Server) Close() error { +func (s *Service[U]) Close() error { return common.Close( s.quicListener, ) } -func (s *Server) loopConnections(listener qtls.Listener) { +func (s *Service[U]) loopConnections(listener qtls.Listener) { for { connection, err := listener.Accept(s.ctx) if err != nil { @@ -142,9 +137,9 @@ func (s *Server) loopConnections(listener qtls.Listener) { } } -func (s *Server) handleConnection(connection quic.Connection) { - session := &serverSession{ - Server: s, +func (s *Service[U]) handleConnection(connection quic.Connection) { + session := &serverSession[U]{ + Service: s, ctx: s.ctx, quicConn: connection, source: M.SocksaddrFromNet(connection.RemoteAddr()), @@ -159,8 +154,8 @@ func (s *Server) handleConnection(connection quic.Connection) { _ = connection.CloseWithError(0, "") } -type serverSession struct { - *Server +type serverSession[U comparable] struct { + *Service[U] ctx context.Context quicConn quic.Connection source M.Socksaddr @@ -168,12 +163,12 @@ type serverSession struct { connDone chan struct{} connErr error authenticated bool - authUser *User + authUser U udpAccess sync.RWMutex udpConnMap map[uint32]*udpPacketConn } -func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (s *serverSession[U]) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath { if s.authenticated { protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{ @@ -190,7 +185,7 @@ func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.masqueradeHandler.ServeHTTP(w, r) return } - s.authUser = &user + s.authUser = user s.authenticated = true if !s.ignoreClientBandwidth && request.Rx > 0 { var sendBps uint64 @@ -231,7 +226,7 @@ func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) { +func (s *serverSession[U]) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) { if !s.authenticated || err != nil { return false, nil } @@ -251,15 +246,12 @@ func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic return true, nil } -func (s *serverSession) handleStream(stream quic.Stream) error { +func (s *serverSession[U]) handleStream(stream quic.Stream) error { destinationString, err := protocol.ReadTCPRequest(stream) if err != nil { return E.New("read TCP request") } - ctx := s.ctx - if s.authUser.Name != "" { - ctx = auth.ContextWithUser(s.ctx, s.authUser.Name) - } + ctx := auth.ContextWithUser(s.ctx, s.authUser) _ = s.handler.NewConnection(ctx, &serverConn{Stream: stream}, M.Metadata{ Source: s.source, Destination: M.ParseSocksaddr(destinationString), @@ -267,7 +259,7 @@ func (s *serverSession) handleStream(stream quic.Stream) error { return nil } -func (s *serverSession) closeWithError(err error) { +func (s *serverSession[U]) closeWithError(err error) { s.connAccess.Lock() defer s.connAccess.Unlock() select { diff --git a/hysteria2/server_packet.go b/hysteria2/service_packet.go similarity index 87% rename from hysteria2/server_packet.go rename to hysteria2/service_packet.go index d84b592..d15d619 100644 --- a/hysteria2/server_packet.go +++ b/hysteria2/service_packet.go @@ -6,7 +6,7 @@ import ( M "github.com/sagernet/sing/common/metadata" ) -func (s *serverSession) loopMessages() { +func (s *serverSession[U]) loopMessages() { for { message, err := s.quicConn.ReceiveMessage(s.ctx) if err != nil { @@ -21,7 +21,7 @@ func (s *serverSession) loopMessages() { } } -func (s *serverSession) handleMessage(data []byte) error { +func (s *serverSession[U]) handleMessage(data []byte) error { message := allocMessage() err := decodeUDPMessage(message, data) if err != nil { @@ -32,7 +32,7 @@ func (s *serverSession) handleMessage(data []byte) error { return nil } -func (s *serverSession) handleUDPMessage(message *udpMessage) { +func (s *serverSession[U]) handleUDPMessage(message *udpMessage) { s.udpAccess.RLock() udpConn, loaded := s.udpConnMap[message.sessionID] s.udpAccess.RUnlock() diff --git a/tuic/server.go b/tuic/service.go similarity index 85% rename from tuic/server.go rename to tuic/service.go index 54cc74e..2186307 100644 --- a/tuic/server.go +++ b/tuic/service.go @@ -25,44 +25,38 @@ import ( aTLS "github.com/sagernet/sing/common/tls" ) -type ServerOptions struct { +type ServiceOptions struct { Context context.Context Logger logger.Logger TLSConfig aTLS.ServerConfig - Users []User CongestionControl string AuthTimeout time.Duration ZeroRTTHandshake bool Heartbeat time.Duration - Handler ServerHandler + Handler ServiceHandler } -type User struct { - Name string - UUID [16]byte - Password string -} - -type ServerHandler interface { +type ServiceHandler interface { N.TCPConnectionHandler N.UDPConnectionHandler } -type Server struct { +type Service[U comparable] struct { ctx context.Context logger logger.Logger tlsConfig aTLS.ServerConfig heartbeat time.Duration quicConfig *quic.Config - userMap map[[16]byte]User + userMap map[[16]byte]U + passwordMap map[U]string congestionControl string authTimeout time.Duration - handler ServerHandler + handler ServiceHandler quicListener io.Closer } -func NewServer(options ServerOptions) (*Server, error) { +func NewService[U comparable](options ServiceOptions) (*Service[U], error) { if options.AuthTimeout == 0 { options.AuthTimeout = 3 * time.Second } @@ -84,27 +78,31 @@ func NewServer(options ServerOptions) (*Server, error) { default: return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl) } - if len(options.Users) == 0 { - return nil, E.New("missing users") - } - userMap := make(map[[16]byte]User) - for _, user := range options.Users { - userMap[user.UUID] = user - } - return &Server{ + return &Service[U]{ ctx: options.Context, logger: options.Logger, tlsConfig: options.TLSConfig, heartbeat: options.Heartbeat, quicConfig: quicConfig, - userMap: userMap, + userMap: make(map[[16]byte]U), congestionControl: options.CongestionControl, authTimeout: options.AuthTimeout, handler: options.Handler, }, nil } -func (s *Server) Start(conn net.PacketConn) error { +func (s *Service[U]) UpdateUsers(userList []U, uuidList [][16]byte, passwordList []string) { + userMap := make(map[[16]byte]U) + passwordMap := make(map[U]string) + for index := range userList { + userMap[uuidList[index]] = userList[index] + passwordMap[userList[index]] = passwordList[index] + } + s.userMap = userMap + s.passwordMap = passwordMap +} + +func (s *Service[U]) Start(conn net.PacketConn) error { if !s.quicConfig.Allow0RTT { listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig) if err != nil { @@ -149,16 +147,16 @@ func (s *Server) Start(conn net.PacketConn) error { return nil } -func (s *Server) Close() error { +func (s *Service[U]) Close() error { return common.Close( s.quicListener, ) } -func (s *Server) handleConnection(connection quic.Connection) { +func (s *Service[U]) handleConnection(connection quic.Connection) { setCongestion(s.ctx, connection, s.congestionControl) - session := &serverSession{ - Server: s, + session := &serverSession[U]{ + Service: s, ctx: s.ctx, quicConn: connection, source: M.SocksaddrFromNet(connection.RemoteAddr()), @@ -169,8 +167,8 @@ func (s *Server) handleConnection(connection quic.Connection) { session.handle() } -type serverSession struct { - *Server +type serverSession[U comparable] struct { + *Service[U] ctx context.Context quicConn quic.Connection source M.Socksaddr @@ -178,12 +176,12 @@ type serverSession struct { connDone chan struct{} connErr error authDone chan struct{} - authUser *User + authUser U udpAccess sync.RWMutex udpConnMap map[uint16]*udpPacketConn } -func (s *serverSession) handle() { +func (s *serverSession[U]) handle() { if s.ctx.Done() != nil { go func() { select { @@ -200,7 +198,7 @@ func (s *serverSession) handle() { go s.loopHeartbeats() } -func (s *serverSession) loopUniStreams() { +func (s *serverSession[U]) loopUniStreams() { for { uniStream, err := s.quicConn.AcceptUniStream(s.ctx) if err != nil { @@ -215,7 +213,7 @@ func (s *serverSession) loopUniStreams() { } } -func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error { +func (s *serverSession[U]) handleUniStream(stream quic.ReceiveStream) error { defer stream.CancelRead(0) buffer := buf.New() defer buffer.Release() @@ -248,14 +246,14 @@ func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error { return E.New("authentication: unknown user ", userUUID) } handshakeState := s.quicConn.ConnectionState() - tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32) + tuicToken, err := handshakeState.ExportKeyingMaterial(string(userUUID[:]), []byte(s.passwordMap[user]), 32) if err != nil { return E.Cause(err, "authentication: export keying material") } if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) { return E.New("authentication: token mismatch") } - s.authUser = &user + s.authUser = user close(s.authDone) return nil case CommandPacket: @@ -301,7 +299,7 @@ func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error { } } -func (s *serverSession) handleAuthTimeout() { +func (s *serverSession[U]) handleAuthTimeout() { select { case <-s.connDone: case <-s.authDone: @@ -310,7 +308,7 @@ func (s *serverSession) handleAuthTimeout() { } } -func (s *serverSession) loopStreams() { +func (s *serverSession[U]) loopStreams() { for { stream, err := s.quicConn.AcceptStream(s.ctx) if err != nil { @@ -327,7 +325,7 @@ func (s *serverSession) loopStreams() { } } -func (s *serverSession) handleStream(stream quic.Stream) error { +func (s *serverSession[U]) handleStream(stream quic.Stream) error { buffer := buf.NewSize(2 + M.MaxSocksaddrLength) defer buffer.Release() _, err := buffer.ReadAtLeastFrom(stream, 2) @@ -360,10 +358,7 @@ func (s *serverSession) handleStream(stream quic.Stream) error { } else { conn = bufio.NewCachedConn(conn, buffer) } - ctx := s.ctx - if s.authUser.Name != "" { - ctx = auth.ContextWithUser(s.ctx, s.authUser.Name) - } + ctx := auth.ContextWithUser(s.ctx, s.authUser) _ = s.handler.NewConnection(ctx, conn, M.Metadata{ Source: s.source, Destination: destination, @@ -371,7 +366,7 @@ func (s *serverSession) handleStream(stream quic.Stream) error { return nil } -func (s *serverSession) loopHeartbeats() { +func (s *serverSession[U]) loopHeartbeats() { ticker := time.NewTicker(s.heartbeat) defer ticker.Stop() for { @@ -387,7 +382,7 @@ func (s *serverSession) loopHeartbeats() { } } -func (s *serverSession) closeWithError(err error) { +func (s *serverSession[U]) closeWithError(err error) { s.connAccess.Lock() defer s.connAccess.Unlock() select { diff --git a/tuic/server_packet.go b/tuic/service_packet.go similarity index 89% rename from tuic/server_packet.go rename to tuic/service_packet.go index d05c7bf..c1bca31 100644 --- a/tuic/server_packet.go +++ b/tuic/service_packet.go @@ -6,7 +6,7 @@ import ( M "github.com/sagernet/sing/common/metadata" ) -func (s *serverSession) loopMessages() { +func (s *serverSession[U]) loopMessages() { select { case <-s.connDone: return @@ -26,7 +26,7 @@ func (s *serverSession) loopMessages() { } } -func (s *serverSession) handleMessage(data []byte) error { +func (s *serverSession[U]) handleMessage(data []byte) error { if len(data) < 2 { return E.New("invalid message") } @@ -50,7 +50,7 @@ func (s *serverSession) handleMessage(data []byte) error { } } -func (s *serverSession) handleUDPMessage(message *udpMessage, udpStream bool) { +func (s *serverSession[U]) handleUDPMessage(message *udpMessage, udpStream bool) { s.udpAccess.RLock() udpConn, loaded := s.udpConnMap[message.sessionID] s.udpAccess.RUnlock()