mirror of
https://github.com/SagerNet/sing-quic.git
synced 2025-04-03 20:07:39 +03:00
Improve server API
This commit is contained in:
parent
98205e7e79
commit
b55f3531e7
4 changed files with 75 additions and 88 deletions
|
@ -26,7 +26,7 @@ import (
|
|||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
type ServerOptions struct {
|
||||
type ServiceOptions struct {
|
||||
Context context.Context
|
||||
Logger logger.Logger
|
||||
SendBPS uint64
|
||||
|
@ -34,23 +34,17 @@ type ServerOptions struct {
|
|||
IgnoreClientBandwidth bool
|
||||
SalamanderPassword string
|
||||
TLSConfig aTLS.ServerConfig
|
||||
Users []User
|
||||
UDPDisabled bool
|
||||
Handler ServerHandler
|
||||
MasqueradeHandler http.Handler
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
Password string
|
||||
}
|
||||
|
||||
type ServerHandler interface {
|
||||
N.TCPConnectionHandler
|
||||
N.UDPConnectionHandler
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
type Service[U comparable] struct {
|
||||
ctx context.Context
|
||||
logger logger.Logger
|
||||
sendBPS uint64
|
||||
|
@ -59,14 +53,14 @@ type Server struct {
|
|||
salamanderPassword string
|
||||
tlsConfig aTLS.ServerConfig
|
||||
quicConfig *quic.Config
|
||||
userMap map[string]User
|
||||
userMap map[string]U
|
||||
udpDisabled bool
|
||||
handler ServerHandler
|
||||
masqueradeHandler http.Handler
|
||||
quicListener io.Closer
|
||||
}
|
||||
|
||||
func NewServer(options ServerOptions) (*Server, error) {
|
||||
func NewService[U comparable](options ServiceOptions) (*Service[U], error) {
|
||||
quicConfig := &quic.Config{
|
||||
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
|
||||
EnableDatagrams: !options.UDPDisabled,
|
||||
|
@ -78,17 +72,10 @@ func NewServer(options ServerOptions) (*Server, error) {
|
|||
MaxIdleTimeout: defaultMaxIdleTimeout,
|
||||
KeepAlivePeriod: defaultKeepAlivePeriod,
|
||||
}
|
||||
if len(options.Users) == 0 {
|
||||
return nil, E.New("missing users")
|
||||
}
|
||||
userMap := make(map[string]User)
|
||||
for _, user := range options.Users {
|
||||
userMap[user.Password] = user
|
||||
}
|
||||
if options.MasqueradeHandler == nil {
|
||||
options.MasqueradeHandler = http.NotFoundHandler()
|
||||
}
|
||||
return &Server{
|
||||
return &Service[U]{
|
||||
ctx: options.Context,
|
||||
logger: options.Logger,
|
||||
sendBPS: options.SendBPS,
|
||||
|
@ -97,14 +84,22 @@ func NewServer(options ServerOptions) (*Server, error) {
|
|||
salamanderPassword: options.SalamanderPassword,
|
||||
tlsConfig: options.TLSConfig,
|
||||
quicConfig: quicConfig,
|
||||
userMap: userMap,
|
||||
userMap: make(map[string]U),
|
||||
udpDisabled: options.UDPDisabled,
|
||||
handler: options.Handler,
|
||||
masqueradeHandler: options.MasqueradeHandler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Start(conn net.PacketConn) error {
|
||||
func (s *Service[U]) UpdateUsers(userList []U, passwordList []string) {
|
||||
userMap := make(map[string]U)
|
||||
for i, user := range userList {
|
||||
userMap[passwordList[i]] = user
|
||||
}
|
||||
s.userMap = userMap
|
||||
}
|
||||
|
||||
func (s *Service[U]) Start(conn net.PacketConn) error {
|
||||
if s.salamanderPassword != "" {
|
||||
conn = NewSalamanderConn(conn, []byte(s.salamanderPassword))
|
||||
}
|
||||
|
@ -121,13 +116,13 @@ func (s *Server) Start(conn net.PacketConn) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Close() error {
|
||||
func (s *Service[U]) Close() error {
|
||||
return common.Close(
|
||||
s.quicListener,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) loopConnections(listener qtls.Listener) {
|
||||
func (s *Service[U]) loopConnections(listener qtls.Listener) {
|
||||
for {
|
||||
connection, err := listener.Accept(s.ctx)
|
||||
if err != nil {
|
||||
|
@ -142,9 +137,9 @@ func (s *Server) loopConnections(listener qtls.Listener) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(connection quic.Connection) {
|
||||
session := &serverSession{
|
||||
Server: s,
|
||||
func (s *Service[U]) handleConnection(connection quic.Connection) {
|
||||
session := &serverSession[U]{
|
||||
Service: s,
|
||||
ctx: s.ctx,
|
||||
quicConn: connection,
|
||||
source: M.SocksaddrFromNet(connection.RemoteAddr()),
|
||||
|
@ -159,8 +154,8 @@ func (s *Server) handleConnection(connection quic.Connection) {
|
|||
_ = connection.CloseWithError(0, "")
|
||||
}
|
||||
|
||||
type serverSession struct {
|
||||
*Server
|
||||
type serverSession[U comparable] struct {
|
||||
*Service[U]
|
||||
ctx context.Context
|
||||
quicConn quic.Connection
|
||||
source M.Socksaddr
|
||||
|
@ -168,12 +163,12 @@ type serverSession struct {
|
|||
connDone chan struct{}
|
||||
connErr error
|
||||
authenticated bool
|
||||
authUser *User
|
||||
authUser U
|
||||
udpAccess sync.RWMutex
|
||||
udpConnMap map[uint32]*udpPacketConn
|
||||
}
|
||||
|
||||
func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *serverSession[U]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath {
|
||||
if s.authenticated {
|
||||
protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
|
||||
|
@ -190,7 +185,7 @@ func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
s.masqueradeHandler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
s.authUser = &user
|
||||
s.authUser = user
|
||||
s.authenticated = true
|
||||
if !s.ignoreClientBandwidth && request.Rx > 0 {
|
||||
var sendBps uint64
|
||||
|
@ -231,7 +226,7 @@ func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) {
|
||||
func (s *serverSession[U]) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) {
|
||||
if !s.authenticated || err != nil {
|
||||
return false, nil
|
||||
}
|
||||
|
@ -251,15 +246,12 @@ func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic
|
|||
return true, nil
|
||||
}
|
||||
|
||||
func (s *serverSession) handleStream(stream quic.Stream) error {
|
||||
func (s *serverSession[U]) handleStream(stream quic.Stream) error {
|
||||
destinationString, err := protocol.ReadTCPRequest(stream)
|
||||
if err != nil {
|
||||
return E.New("read TCP request")
|
||||
}
|
||||
ctx := s.ctx
|
||||
if s.authUser.Name != "" {
|
||||
ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
|
||||
}
|
||||
ctx := auth.ContextWithUser(s.ctx, s.authUser)
|
||||
_ = s.handler.NewConnection(ctx, &serverConn{Stream: stream}, M.Metadata{
|
||||
Source: s.source,
|
||||
Destination: M.ParseSocksaddr(destinationString),
|
||||
|
@ -267,7 +259,7 @@ func (s *serverSession) handleStream(stream quic.Stream) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *serverSession) closeWithError(err error) {
|
||||
func (s *serverSession[U]) closeWithError(err error) {
|
||||
s.connAccess.Lock()
|
||||
defer s.connAccess.Unlock()
|
||||
select {
|
|
@ -6,7 +6,7 @@ import (
|
|||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
func (s *serverSession) loopMessages() {
|
||||
func (s *serverSession[U]) loopMessages() {
|
||||
for {
|
||||
message, err := s.quicConn.ReceiveMessage(s.ctx)
|
||||
if err != nil {
|
||||
|
@ -21,7 +21,7 @@ func (s *serverSession) loopMessages() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handleMessage(data []byte) error {
|
||||
func (s *serverSession[U]) handleMessage(data []byte) error {
|
||||
message := allocMessage()
|
||||
err := decodeUDPMessage(message, data)
|
||||
if err != nil {
|
||||
|
@ -32,7 +32,7 @@ func (s *serverSession) handleMessage(data []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *serverSession) handleUDPMessage(message *udpMessage) {
|
||||
func (s *serverSession[U]) handleUDPMessage(message *udpMessage) {
|
||||
s.udpAccess.RLock()
|
||||
udpConn, loaded := s.udpConnMap[message.sessionID]
|
||||
s.udpAccess.RUnlock()
|
|
@ -25,44 +25,38 @@ import (
|
|||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
)
|
||||
|
||||
type ServerOptions struct {
|
||||
type ServiceOptions struct {
|
||||
Context context.Context
|
||||
Logger logger.Logger
|
||||
TLSConfig aTLS.ServerConfig
|
||||
Users []User
|
||||
CongestionControl string
|
||||
AuthTimeout time.Duration
|
||||
ZeroRTTHandshake bool
|
||||
Heartbeat time.Duration
|
||||
Handler ServerHandler
|
||||
Handler ServiceHandler
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
UUID [16]byte
|
||||
Password string
|
||||
}
|
||||
|
||||
type ServerHandler interface {
|
||||
type ServiceHandler interface {
|
||||
N.TCPConnectionHandler
|
||||
N.UDPConnectionHandler
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
type Service[U comparable] struct {
|
||||
ctx context.Context
|
||||
logger logger.Logger
|
||||
tlsConfig aTLS.ServerConfig
|
||||
heartbeat time.Duration
|
||||
quicConfig *quic.Config
|
||||
userMap map[[16]byte]User
|
||||
userMap map[[16]byte]U
|
||||
passwordMap map[U]string
|
||||
congestionControl string
|
||||
authTimeout time.Duration
|
||||
handler ServerHandler
|
||||
handler ServiceHandler
|
||||
|
||||
quicListener io.Closer
|
||||
}
|
||||
|
||||
func NewServer(options ServerOptions) (*Server, error) {
|
||||
func NewService[U comparable](options ServiceOptions) (*Service[U], error) {
|
||||
if options.AuthTimeout == 0 {
|
||||
options.AuthTimeout = 3 * time.Second
|
||||
}
|
||||
|
@ -84,27 +78,31 @@ func NewServer(options ServerOptions) (*Server, error) {
|
|||
default:
|
||||
return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
|
||||
}
|
||||
if len(options.Users) == 0 {
|
||||
return nil, E.New("missing users")
|
||||
}
|
||||
userMap := make(map[[16]byte]User)
|
||||
for _, user := range options.Users {
|
||||
userMap[user.UUID] = user
|
||||
}
|
||||
return &Server{
|
||||
return &Service[U]{
|
||||
ctx: options.Context,
|
||||
logger: options.Logger,
|
||||
tlsConfig: options.TLSConfig,
|
||||
heartbeat: options.Heartbeat,
|
||||
quicConfig: quicConfig,
|
||||
userMap: userMap,
|
||||
userMap: make(map[[16]byte]U),
|
||||
congestionControl: options.CongestionControl,
|
||||
authTimeout: options.AuthTimeout,
|
||||
handler: options.Handler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Start(conn net.PacketConn) error {
|
||||
func (s *Service[U]) UpdateUsers(userList []U, uuidList [][16]byte, passwordList []string) {
|
||||
userMap := make(map[[16]byte]U)
|
||||
passwordMap := make(map[U]string)
|
||||
for index := range userList {
|
||||
userMap[uuidList[index]] = userList[index]
|
||||
passwordMap[userList[index]] = passwordList[index]
|
||||
}
|
||||
s.userMap = userMap
|
||||
s.passwordMap = passwordMap
|
||||
}
|
||||
|
||||
func (s *Service[U]) Start(conn net.PacketConn) error {
|
||||
if !s.quicConfig.Allow0RTT {
|
||||
listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig)
|
||||
if err != nil {
|
||||
|
@ -149,16 +147,16 @@ func (s *Server) Start(conn net.PacketConn) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Close() error {
|
||||
func (s *Service[U]) Close() error {
|
||||
return common.Close(
|
||||
s.quicListener,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(connection quic.Connection) {
|
||||
func (s *Service[U]) handleConnection(connection quic.Connection) {
|
||||
setCongestion(s.ctx, connection, s.congestionControl)
|
||||
session := &serverSession{
|
||||
Server: s,
|
||||
session := &serverSession[U]{
|
||||
Service: s,
|
||||
ctx: s.ctx,
|
||||
quicConn: connection,
|
||||
source: M.SocksaddrFromNet(connection.RemoteAddr()),
|
||||
|
@ -169,8 +167,8 @@ func (s *Server) handleConnection(connection quic.Connection) {
|
|||
session.handle()
|
||||
}
|
||||
|
||||
type serverSession struct {
|
||||
*Server
|
||||
type serverSession[U comparable] struct {
|
||||
*Service[U]
|
||||
ctx context.Context
|
||||
quicConn quic.Connection
|
||||
source M.Socksaddr
|
||||
|
@ -178,12 +176,12 @@ type serverSession struct {
|
|||
connDone chan struct{}
|
||||
connErr error
|
||||
authDone chan struct{}
|
||||
authUser *User
|
||||
authUser U
|
||||
udpAccess sync.RWMutex
|
||||
udpConnMap map[uint16]*udpPacketConn
|
||||
}
|
||||
|
||||
func (s *serverSession) handle() {
|
||||
func (s *serverSession[U]) handle() {
|
||||
if s.ctx.Done() != nil {
|
||||
go func() {
|
||||
select {
|
||||
|
@ -200,7 +198,7 @@ func (s *serverSession) handle() {
|
|||
go s.loopHeartbeats()
|
||||
}
|
||||
|
||||
func (s *serverSession) loopUniStreams() {
|
||||
func (s *serverSession[U]) loopUniStreams() {
|
||||
for {
|
||||
uniStream, err := s.quicConn.AcceptUniStream(s.ctx)
|
||||
if err != nil {
|
||||
|
@ -215,7 +213,7 @@ func (s *serverSession) loopUniStreams() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
|
||||
func (s *serverSession[U]) handleUniStream(stream quic.ReceiveStream) error {
|
||||
defer stream.CancelRead(0)
|
||||
buffer := buf.New()
|
||||
defer buffer.Release()
|
||||
|
@ -248,14 +246,14 @@ func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
|
|||
return E.New("authentication: unknown user ", userUUID)
|
||||
}
|
||||
handshakeState := s.quicConn.ConnectionState()
|
||||
tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32)
|
||||
tuicToken, err := handshakeState.ExportKeyingMaterial(string(userUUID[:]), []byte(s.passwordMap[user]), 32)
|
||||
if err != nil {
|
||||
return E.Cause(err, "authentication: export keying material")
|
||||
}
|
||||
if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) {
|
||||
return E.New("authentication: token mismatch")
|
||||
}
|
||||
s.authUser = &user
|
||||
s.authUser = user
|
||||
close(s.authDone)
|
||||
return nil
|
||||
case CommandPacket:
|
||||
|
@ -301,7 +299,7 @@ func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handleAuthTimeout() {
|
||||
func (s *serverSession[U]) handleAuthTimeout() {
|
||||
select {
|
||||
case <-s.connDone:
|
||||
case <-s.authDone:
|
||||
|
@ -310,7 +308,7 @@ func (s *serverSession) handleAuthTimeout() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) loopStreams() {
|
||||
func (s *serverSession[U]) loopStreams() {
|
||||
for {
|
||||
stream, err := s.quicConn.AcceptStream(s.ctx)
|
||||
if err != nil {
|
||||
|
@ -327,7 +325,7 @@ func (s *serverSession) loopStreams() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handleStream(stream quic.Stream) error {
|
||||
func (s *serverSession[U]) handleStream(stream quic.Stream) error {
|
||||
buffer := buf.NewSize(2 + M.MaxSocksaddrLength)
|
||||
defer buffer.Release()
|
||||
_, err := buffer.ReadAtLeastFrom(stream, 2)
|
||||
|
@ -360,10 +358,7 @@ func (s *serverSession) handleStream(stream quic.Stream) error {
|
|||
} else {
|
||||
conn = bufio.NewCachedConn(conn, buffer)
|
||||
}
|
||||
ctx := s.ctx
|
||||
if s.authUser.Name != "" {
|
||||
ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
|
||||
}
|
||||
ctx := auth.ContextWithUser(s.ctx, s.authUser)
|
||||
_ = s.handler.NewConnection(ctx, conn, M.Metadata{
|
||||
Source: s.source,
|
||||
Destination: destination,
|
||||
|
@ -371,7 +366,7 @@ func (s *serverSession) handleStream(stream quic.Stream) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *serverSession) loopHeartbeats() {
|
||||
func (s *serverSession[U]) loopHeartbeats() {
|
||||
ticker := time.NewTicker(s.heartbeat)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
|
@ -387,7 +382,7 @@ func (s *serverSession) loopHeartbeats() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) closeWithError(err error) {
|
||||
func (s *serverSession[U]) closeWithError(err error) {
|
||||
s.connAccess.Lock()
|
||||
defer s.connAccess.Unlock()
|
||||
select {
|
|
@ -6,7 +6,7 @@ import (
|
|||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
func (s *serverSession) loopMessages() {
|
||||
func (s *serverSession[U]) loopMessages() {
|
||||
select {
|
||||
case <-s.connDone:
|
||||
return
|
||||
|
@ -26,7 +26,7 @@ func (s *serverSession) loopMessages() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handleMessage(data []byte) error {
|
||||
func (s *serverSession[U]) handleMessage(data []byte) error {
|
||||
if len(data) < 2 {
|
||||
return E.New("invalid message")
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ func (s *serverSession) handleMessage(data []byte) error {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *serverSession) handleUDPMessage(message *udpMessage, udpStream bool) {
|
||||
func (s *serverSession[U]) handleUDPMessage(message *udpMessage, udpStream bool) {
|
||||
s.udpAccess.RLock()
|
||||
udpConn, loaded := s.udpConnMap[message.sessionID]
|
||||
s.udpAccess.RUnlock()
|
Loading…
Add table
Add a link
Reference in a new issue