diff --git a/service.go b/service.go index 9431fe2..7a39088 100644 --- a/service.go +++ b/service.go @@ -9,6 +9,7 @@ import ( "os" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/debug" @@ -20,12 +21,23 @@ import ( ) type ServiceConfig struct { - Version int - Password string - HandshakeServer M.Socksaddr - HandshakeDialer N.Dialer - Handler Handler - Logger logger.ContextLogger + Version int + Password string // for protocol version 2 + Users []User // for protocol version 3 + Handshake HandshakeConfig + HandshakeForServerName map[string]HandshakeConfig // for protocol version 2/3 + Handler Handler + Logger logger.ContextLogger +} + +type User struct { + Name string + Password string +} + +type HandshakeConfig struct { + Server M.Socksaddr + Dialer N.Dialer } type Handler interface { @@ -34,28 +46,39 @@ type Handler interface { } type Service struct { - version int - password string - handshakeServer M.Socksaddr - handshakeDialer N.Dialer - handler Handler - logger logger.ContextLogger + version int + password string + users []User + handshake HandshakeConfig + handshakeForServerName map[string]HandshakeConfig + handler Handler + logger logger.ContextLogger } func NewService(config ServiceConfig) (*Service, error) { service := &Service{ - version: config.Version, - password: config.Password, - handshakeServer: config.HandshakeServer, - handshakeDialer: config.HandshakeDialer, - handler: config.Handler, - logger: config.Logger, + version: config.Version, + password: config.Password, + users: config.Users, + handshake: config.Handshake, + handshakeForServerName: config.HandshakeForServerName, + handler: config.Handler, + logger: config.Logger, } - if !service.handshakeServer.IsValid() || service.handler == nil || service.logger == nil { + + if !service.handshake.Server.IsValid() { + return nil, E.New("missing default handshake information") + } + + if service.handler == nil || service.logger == nil { return nil, os.ErrInvalid } switch config.Version { - case 1, 2, 3: + case 1, 2: + case 3: + if len(service.users) == 0 { + return nil, E.New("missing users") + } default: return nil, E.New("unknown protocol version: ", config.Version) } @@ -63,15 +86,26 @@ func NewService(config ServiceConfig) (*Service, error) { return service, nil } -func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { - handshakeConn, err := s.handshakeDialer.DialContext(ctx, N.NetworkTCP, s.handshakeServer) - if err != nil { - return E.Cause(err, "server handshake") +func (s *Service) selectHandshake(clientHelloFrame *buf.Buffer) HandshakeConfig { + serverName, err := extractServerName(clientHelloFrame.Bytes()) + if err == nil { + if customHandshake, found := s.handshakeForServerName[serverName]; found { + return customHandshake + } } + return s.handshake +} + +func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { switch s.version { default: fallthrough case 1: + handshakeConn, err := s.handshake.Dialer.DialContext(ctx, N.NetworkTCP, s.handshake.Server) + if err != nil { + return E.Cause(err, "server handshake") + } + var group task.Group group.Append("client handshake", func(ctx context.Context) error { return copyUntilHandshakeFinished(handshakeConn, conn) @@ -90,10 +124,20 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M s.logger.TraceContext(ctx, "handshake finished") return s.handler.NewConnection(ctx, conn, metadata) case 2: + clientHelloFrame, err := extractFrame(conn) + if err != nil { + return E.Cause(err, "read client handshake") + } + + handshakeConfig := s.selectHandshake(clientHelloFrame) + handshakeConn, err := handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server) + if err != nil { + return E.Cause(err, "server handshake") + } hashConn := newHashWriteConn(conn, s.password) go bufio.Copy(hashConn, handshakeConn) var request *buf.Buffer - request, err = copyUntilHandshakeFinishedV2(ctx, s.logger, handshakeConn, conn, hashConn, 2) + request, err = copyUntilHandshakeFinishedV2(ctx, s.logger, handshakeConn, bufio.NewCachedConn(conn, clientHelloFrame), hashConn, 2) if err == nil { s.logger.TraceContext(ctx, "handshake finished") handshakeConn.Close() @@ -106,21 +150,30 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M return err } case 3: - var clientHelloFrame *buf.Buffer - clientHelloFrame, err = extractFrame(conn) + clientHelloFrame, err := extractFrame(conn) if err != nil { return E.Cause(err, "read client handshake") } + + handshakeConfig := s.selectHandshake(clientHelloFrame) + handshakeConn, err := handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server) + if err != nil { + return E.Cause(err, "server handshake") + } + _, err = handshakeConn.Write(clientHelloFrame.Bytes()) if err != nil { clientHelloFrame.Release() return E.Cause(err, "write client handshake") } - err = verifyClientHello(clientHelloFrame.Bytes(), s.password) + user, err := verifyClientHello(clientHelloFrame.Bytes(), s.users) if err != nil { s.logger.WarnContext(ctx, E.Cause(err, "client hello verify failed")) return bufio.CopyConn(ctx, conn, handshakeConn) } + if user.Name != "" { + ctx = auth.ContextWithUser(ctx, user.Name) + } s.logger.TraceContext(ctx, "client hello verify success") clientHelloFrame.Release() @@ -152,12 +205,12 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M if debug.Enabled { s.logger.TraceContext(ctx, "client authenticated. server random extracted: ", hex.EncodeToString(serverRandom)) } - hmacWrite := hmac.New(sha1.New, []byte(s.password)) + hmacWrite := hmac.New(sha1.New, []byte(user.Password)) hmacWrite.Write(serverRandom) - hmacAdd := hmac.New(sha1.New, []byte(s.password)) + hmacAdd := hmac.New(sha1.New, []byte(user.Password)) hmacAdd.Write(serverRandom) hmacAdd.Write([]byte("S")) - hmacVerify := hmac.New(sha1.New, []byte(s.password)) + hmacVerify := hmac.New(sha1.New, []byte(user.Password)) hmacVerifyReset := func() { hmacVerify.Reset() hmacVerify.Write(serverRandom) @@ -177,7 +230,7 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.M return cErr }) group.Append("server handshake relay", func(ctx context.Context) error { - cErr := copyByFrameWithModification(handshakeConn, conn, s.password, serverRandom, hmacWrite) + cErr := copyByFrameWithModification(handshakeConn, conn, user.Password, serverRandom, hmacWrite) if E.IsClosedOrCanceled(cErr) && handshakeFinished { return nil } diff --git a/v3_server.go b/v3_server.go index 8dbf95b..f700206 100644 --- a/v3_server.go +++ b/v3_server.go @@ -2,8 +2,10 @@ package shadowtls import ( "bytes" + "context" "crypto/hmac" "crypto/sha1" + "crypto/tls" "encoding/binary" "hash" "io" @@ -32,26 +34,42 @@ func extractFrame(conn net.Conn) (*buf.Buffer, error) { return buffer, err } -func verifyClientHello(frame []byte, password string) error { +func extractServerName(frame []byte) (string, error) { + var hello *tls.ClientHelloInfo + err := tls.Server(bufio.NewReadOnlyConn(bytes.NewReader(frame)), &tls.Config{ + GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { + hello = argHello + return nil, nil + }, + }).HandshakeContext(context.Background()) + if hello != nil { + return hello.ServerName, nil + } + return "", err +} + +func verifyClientHello(frame []byte, users []User) (*User, error) { const minLen = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize + 1 + tlsSessionIDSize const hmacIndex = sessionIDLengthIndex + 1 + tlsSessionIDSize - hmacSize if len(frame) < minLen { - return io.ErrUnexpectedEOF + return nil, io.ErrUnexpectedEOF } else if frame[0] != handshake { - return E.New("unexpected record type") + return nil, E.New("unexpected record type") } else if frame[5] != clientHello { - return E.New("unexpected handshake type") + return nil, E.New("unexpected handshake type") } else if frame[sessionIDLengthIndex] != tlsSessionIDSize { - return E.New("unexpected session id length") + return nil, E.New("unexpected session id length") } - hmacSHA1Hash := hmac.New(sha1.New, []byte(password)) - hmacSHA1Hash.Write(frame[tlsHeaderSize:hmacIndex]) - hmacSHA1Hash.Write(rw.ZeroBytes[:4]) - hmacSHA1Hash.Write(frame[hmacIndex+hmacSize:]) - if !hmac.Equal(frame[hmacIndex:hmacIndex+hmacSize], hmacSHA1Hash.Sum(nil)[:hmacSize]) { - return E.New("hmac mismatch") + for _, user := range users { + hmacSHA1Hash := hmac.New(sha1.New, []byte(user.Password)) + hmacSHA1Hash.Write(frame[tlsHeaderSize:hmacIndex]) + hmacSHA1Hash.Write(rw.ZeroBytes[:4]) + hmacSHA1Hash.Write(frame[hmacIndex+hmacSize:]) + if hmac.Equal(frame[hmacIndex:hmacIndex+hmacSize], hmacSHA1Hash.Sum(nil)[:hmacSize]) { + return &user, nil + } } - return nil + return nil, E.New("hmac mismatch") } func extractServerRandom(frame []byte) []byte {