Improve server API

This commit is contained in:
世界 2023-09-15 14:42:01 +08:00
parent 98205e7e79
commit b55f3531e7
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
4 changed files with 75 additions and 88 deletions

View file

@ -26,7 +26,7 @@ import (
aTLS "github.com/sagernet/sing/common/tls"
)
type ServerOptions struct {
type ServiceOptions struct {
Context context.Context
Logger logger.Logger
SendBPS uint64
@ -34,23 +34,17 @@ type ServerOptions struct {
IgnoreClientBandwidth bool
SalamanderPassword string
TLSConfig aTLS.ServerConfig
Users []User
UDPDisabled bool
Handler ServerHandler
MasqueradeHandler http.Handler
}
type User struct {
Name string
Password string
}
type ServerHandler interface {
N.TCPConnectionHandler
N.UDPConnectionHandler
}
type Server struct {
type Service[U comparable] struct {
ctx context.Context
logger logger.Logger
sendBPS uint64
@ -59,14 +53,14 @@ type Server struct {
salamanderPassword string
tlsConfig aTLS.ServerConfig
quicConfig *quic.Config
userMap map[string]User
userMap map[string]U
udpDisabled bool
handler ServerHandler
masqueradeHandler http.Handler
quicListener io.Closer
}
func NewServer(options ServerOptions) (*Server, error) {
func NewService[U comparable](options ServiceOptions) (*Service[U], error) {
quicConfig := &quic.Config{
DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
EnableDatagrams: !options.UDPDisabled,
@ -78,17 +72,10 @@ func NewServer(options ServerOptions) (*Server, error) {
MaxIdleTimeout: defaultMaxIdleTimeout,
KeepAlivePeriod: defaultKeepAlivePeriod,
}
if len(options.Users) == 0 {
return nil, E.New("missing users")
}
userMap := make(map[string]User)
for _, user := range options.Users {
userMap[user.Password] = user
}
if options.MasqueradeHandler == nil {
options.MasqueradeHandler = http.NotFoundHandler()
}
return &Server{
return &Service[U]{
ctx: options.Context,
logger: options.Logger,
sendBPS: options.SendBPS,
@ -97,14 +84,22 @@ func NewServer(options ServerOptions) (*Server, error) {
salamanderPassword: options.SalamanderPassword,
tlsConfig: options.TLSConfig,
quicConfig: quicConfig,
userMap: userMap,
userMap: make(map[string]U),
udpDisabled: options.UDPDisabled,
handler: options.Handler,
masqueradeHandler: options.MasqueradeHandler,
}, nil
}
func (s *Server) Start(conn net.PacketConn) error {
func (s *Service[U]) UpdateUsers(userList []U, passwordList []string) {
userMap := make(map[string]U)
for i, user := range userList {
userMap[passwordList[i]] = user
}
s.userMap = userMap
}
func (s *Service[U]) Start(conn net.PacketConn) error {
if s.salamanderPassword != "" {
conn = NewSalamanderConn(conn, []byte(s.salamanderPassword))
}
@ -121,13 +116,13 @@ func (s *Server) Start(conn net.PacketConn) error {
return nil
}
func (s *Server) Close() error {
func (s *Service[U]) Close() error {
return common.Close(
s.quicListener,
)
}
func (s *Server) loopConnections(listener qtls.Listener) {
func (s *Service[U]) loopConnections(listener qtls.Listener) {
for {
connection, err := listener.Accept(s.ctx)
if err != nil {
@ -142,9 +137,9 @@ func (s *Server) loopConnections(listener qtls.Listener) {
}
}
func (s *Server) handleConnection(connection quic.Connection) {
session := &serverSession{
Server: s,
func (s *Service[U]) handleConnection(connection quic.Connection) {
session := &serverSession[U]{
Service: s,
ctx: s.ctx,
quicConn: connection,
source: M.SocksaddrFromNet(connection.RemoteAddr()),
@ -159,8 +154,8 @@ func (s *Server) handleConnection(connection quic.Connection) {
_ = connection.CloseWithError(0, "")
}
type serverSession struct {
*Server
type serverSession[U comparable] struct {
*Service[U]
ctx context.Context
quicConn quic.Connection
source M.Socksaddr
@ -168,12 +163,12 @@ type serverSession struct {
connDone chan struct{}
connErr error
authenticated bool
authUser *User
authUser U
udpAccess sync.RWMutex
udpConnMap map[uint32]*udpPacketConn
}
func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (s *serverSession[U]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath {
if s.authenticated {
protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
@ -190,7 +185,7 @@ func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.masqueradeHandler.ServeHTTP(w, r)
return
}
s.authUser = &user
s.authUser = user
s.authenticated = true
if !s.ignoreClientBandwidth && request.Rx > 0 {
var sendBps uint64
@ -231,7 +226,7 @@ func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) {
func (s *serverSession[U]) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) {
if !s.authenticated || err != nil {
return false, nil
}
@ -251,15 +246,12 @@ func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic
return true, nil
}
func (s *serverSession) handleStream(stream quic.Stream) error {
func (s *serverSession[U]) handleStream(stream quic.Stream) error {
destinationString, err := protocol.ReadTCPRequest(stream)
if err != nil {
return E.New("read TCP request")
}
ctx := s.ctx
if s.authUser.Name != "" {
ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
}
ctx := auth.ContextWithUser(s.ctx, s.authUser)
_ = s.handler.NewConnection(ctx, &serverConn{Stream: stream}, M.Metadata{
Source: s.source,
Destination: M.ParseSocksaddr(destinationString),
@ -267,7 +259,7 @@ func (s *serverSession) handleStream(stream quic.Stream) error {
return nil
}
func (s *serverSession) closeWithError(err error) {
func (s *serverSession[U]) closeWithError(err error) {
s.connAccess.Lock()
defer s.connAccess.Unlock()
select {

View file

@ -6,7 +6,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
)
func (s *serverSession) loopMessages() {
func (s *serverSession[U]) loopMessages() {
for {
message, err := s.quicConn.ReceiveMessage(s.ctx)
if err != nil {
@ -21,7 +21,7 @@ func (s *serverSession) loopMessages() {
}
}
func (s *serverSession) handleMessage(data []byte) error {
func (s *serverSession[U]) handleMessage(data []byte) error {
message := allocMessage()
err := decodeUDPMessage(message, data)
if err != nil {
@ -32,7 +32,7 @@ func (s *serverSession) handleMessage(data []byte) error {
return nil
}
func (s *serverSession) handleUDPMessage(message *udpMessage) {
func (s *serverSession[U]) handleUDPMessage(message *udpMessage) {
s.udpAccess.RLock()
udpConn, loaded := s.udpConnMap[message.sessionID]
s.udpAccess.RUnlock()

View file

@ -25,44 +25,38 @@ import (
aTLS "github.com/sagernet/sing/common/tls"
)
type ServerOptions struct {
type ServiceOptions struct {
Context context.Context
Logger logger.Logger
TLSConfig aTLS.ServerConfig
Users []User
CongestionControl string
AuthTimeout time.Duration
ZeroRTTHandshake bool
Heartbeat time.Duration
Handler ServerHandler
Handler ServiceHandler
}
type User struct {
Name string
UUID [16]byte
Password string
}
type ServerHandler interface {
type ServiceHandler interface {
N.TCPConnectionHandler
N.UDPConnectionHandler
}
type Server struct {
type Service[U comparable] struct {
ctx context.Context
logger logger.Logger
tlsConfig aTLS.ServerConfig
heartbeat time.Duration
quicConfig *quic.Config
userMap map[[16]byte]User
userMap map[[16]byte]U
passwordMap map[U]string
congestionControl string
authTimeout time.Duration
handler ServerHandler
handler ServiceHandler
quicListener io.Closer
}
func NewServer(options ServerOptions) (*Server, error) {
func NewService[U comparable](options ServiceOptions) (*Service[U], error) {
if options.AuthTimeout == 0 {
options.AuthTimeout = 3 * time.Second
}
@ -84,27 +78,31 @@ func NewServer(options ServerOptions) (*Server, error) {
default:
return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
}
if len(options.Users) == 0 {
return nil, E.New("missing users")
}
userMap := make(map[[16]byte]User)
for _, user := range options.Users {
userMap[user.UUID] = user
}
return &Server{
return &Service[U]{
ctx: options.Context,
logger: options.Logger,
tlsConfig: options.TLSConfig,
heartbeat: options.Heartbeat,
quicConfig: quicConfig,
userMap: userMap,
userMap: make(map[[16]byte]U),
congestionControl: options.CongestionControl,
authTimeout: options.AuthTimeout,
handler: options.Handler,
}, nil
}
func (s *Server) Start(conn net.PacketConn) error {
func (s *Service[U]) UpdateUsers(userList []U, uuidList [][16]byte, passwordList []string) {
userMap := make(map[[16]byte]U)
passwordMap := make(map[U]string)
for index := range userList {
userMap[uuidList[index]] = userList[index]
passwordMap[userList[index]] = passwordList[index]
}
s.userMap = userMap
s.passwordMap = passwordMap
}
func (s *Service[U]) Start(conn net.PacketConn) error {
if !s.quicConfig.Allow0RTT {
listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig)
if err != nil {
@ -149,16 +147,16 @@ func (s *Server) Start(conn net.PacketConn) error {
return nil
}
func (s *Server) Close() error {
func (s *Service[U]) Close() error {
return common.Close(
s.quicListener,
)
}
func (s *Server) handleConnection(connection quic.Connection) {
func (s *Service[U]) handleConnection(connection quic.Connection) {
setCongestion(s.ctx, connection, s.congestionControl)
session := &serverSession{
Server: s,
session := &serverSession[U]{
Service: s,
ctx: s.ctx,
quicConn: connection,
source: M.SocksaddrFromNet(connection.RemoteAddr()),
@ -169,8 +167,8 @@ func (s *Server) handleConnection(connection quic.Connection) {
session.handle()
}
type serverSession struct {
*Server
type serverSession[U comparable] struct {
*Service[U]
ctx context.Context
quicConn quic.Connection
source M.Socksaddr
@ -178,12 +176,12 @@ type serverSession struct {
connDone chan struct{}
connErr error
authDone chan struct{}
authUser *User
authUser U
udpAccess sync.RWMutex
udpConnMap map[uint16]*udpPacketConn
}
func (s *serverSession) handle() {
func (s *serverSession[U]) handle() {
if s.ctx.Done() != nil {
go func() {
select {
@ -200,7 +198,7 @@ func (s *serverSession) handle() {
go s.loopHeartbeats()
}
func (s *serverSession) loopUniStreams() {
func (s *serverSession[U]) loopUniStreams() {
for {
uniStream, err := s.quicConn.AcceptUniStream(s.ctx)
if err != nil {
@ -215,7 +213,7 @@ func (s *serverSession) loopUniStreams() {
}
}
func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
func (s *serverSession[U]) handleUniStream(stream quic.ReceiveStream) error {
defer stream.CancelRead(0)
buffer := buf.New()
defer buffer.Release()
@ -248,14 +246,14 @@ func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
return E.New("authentication: unknown user ", userUUID)
}
handshakeState := s.quicConn.ConnectionState()
tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32)
tuicToken, err := handshakeState.ExportKeyingMaterial(string(userUUID[:]), []byte(s.passwordMap[user]), 32)
if err != nil {
return E.Cause(err, "authentication: export keying material")
}
if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) {
return E.New("authentication: token mismatch")
}
s.authUser = &user
s.authUser = user
close(s.authDone)
return nil
case CommandPacket:
@ -301,7 +299,7 @@ func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
}
}
func (s *serverSession) handleAuthTimeout() {
func (s *serverSession[U]) handleAuthTimeout() {
select {
case <-s.connDone:
case <-s.authDone:
@ -310,7 +308,7 @@ func (s *serverSession) handleAuthTimeout() {
}
}
func (s *serverSession) loopStreams() {
func (s *serverSession[U]) loopStreams() {
for {
stream, err := s.quicConn.AcceptStream(s.ctx)
if err != nil {
@ -327,7 +325,7 @@ func (s *serverSession) loopStreams() {
}
}
func (s *serverSession) handleStream(stream quic.Stream) error {
func (s *serverSession[U]) handleStream(stream quic.Stream) error {
buffer := buf.NewSize(2 + M.MaxSocksaddrLength)
defer buffer.Release()
_, err := buffer.ReadAtLeastFrom(stream, 2)
@ -360,10 +358,7 @@ func (s *serverSession) handleStream(stream quic.Stream) error {
} else {
conn = bufio.NewCachedConn(conn, buffer)
}
ctx := s.ctx
if s.authUser.Name != "" {
ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
}
ctx := auth.ContextWithUser(s.ctx, s.authUser)
_ = s.handler.NewConnection(ctx, conn, M.Metadata{
Source: s.source,
Destination: destination,
@ -371,7 +366,7 @@ func (s *serverSession) handleStream(stream quic.Stream) error {
return nil
}
func (s *serverSession) loopHeartbeats() {
func (s *serverSession[U]) loopHeartbeats() {
ticker := time.NewTicker(s.heartbeat)
defer ticker.Stop()
for {
@ -387,7 +382,7 @@ func (s *serverSession) loopHeartbeats() {
}
}
func (s *serverSession) closeWithError(err error) {
func (s *serverSession[U]) closeWithError(err error) {
s.connAccess.Lock()
defer s.connAccess.Unlock()
select {

View file

@ -6,7 +6,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
)
func (s *serverSession) loopMessages() {
func (s *serverSession[U]) loopMessages() {
select {
case <-s.connDone:
return
@ -26,7 +26,7 @@ func (s *serverSession) loopMessages() {
}
}
func (s *serverSession) handleMessage(data []byte) error {
func (s *serverSession[U]) handleMessage(data []byte) error {
if len(data) < 2 {
return E.New("invalid message")
}
@ -50,7 +50,7 @@ func (s *serverSession) handleMessage(data []byte) error {
}
}
func (s *serverSession) handleUDPMessage(message *udpMessage, udpStream bool) {
func (s *serverSession[U]) handleUDPMessage(message *udpMessage, udpStream bool) {
s.udpAccess.RLock()
udpConn, loaded := s.udpConnMap[message.sessionID]
s.udpAccess.RUnlock()