From 667b08ec3e49e1a19a30046293bdf576768e8979 Mon Sep 17 00:00:00 2001 From: Haruue Date: Sat, 24 Aug 2024 17:25:31 +0800 Subject: [PATCH] test: add tests for certloader --- app/internal/utils/certloader_test.go | 139 ++++++++++++++++++ app/internal/utils/certloader_test_gencert.py | 134 +++++++++++++++++ .../utils/certloader_test_tlsclient.py | 60 ++++++++ app/internal/utils/testcerts/.gitignore | 3 + 4 files changed, 336 insertions(+) create mode 100644 app/internal/utils/certloader_test.go create mode 100644 app/internal/utils/certloader_test_gencert.py create mode 100644 app/internal/utils/certloader_test_tlsclient.py create mode 100644 app/internal/utils/testcerts/.gitignore diff --git a/app/internal/utils/certloader_test.go b/app/internal/utils/certloader_test.go new file mode 100644 index 0000000..7c5875c --- /dev/null +++ b/app/internal/utils/certloader_test.go @@ -0,0 +1,139 @@ +package utils + +import ( + "crypto/tls" + "log" + "net/http" + "os" + "os/exec" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + testListen = "127.82.39.147:12947" + testCAFile = "./testcerts/ca" + testCertFile = "./testcerts/cert" + testKeyFile = "./testcerts/key" +) + +func TestCertificateLoaderPathError(t *testing.T) { + assert.NoError(t, os.RemoveAll(testCertFile)) + assert.NoError(t, os.RemoveAll(testKeyFile)) + loader := LocalCertificateLoader{ + CertFile: testCertFile, + KeyFile: testKeyFile, + SNIGuard: SNIGuardStrict, + } + err := loader.InitializeCache() + var pathErr *os.PathError + assert.ErrorAs(t, err, &pathErr) +} + +func TestCertificateLoaderFullChain(t *testing.T) { + assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain")) + + loader := LocalCertificateLoader{ + CertFile: testCertFile, + KeyFile: testKeyFile, + SNIGuard: SNIGuardStrict, + } + assert.NoError(t, loader.InitializeCache()) + + lis, err := tls.Listen("tcp", testListen, &tls.Config{ + GetCertificate: loader.GetCertificate, + }) + assert.NoError(t, err) + defer lis.Close() + go http.Serve(lis, nil) + + assert.Error(t, runTestTLSClient("unmatched-sni.example.com")) + assert.Error(t, runTestTLSClient("")) + assert.NoError(t, runTestTLSClient("example.com")) +} + +func TestCertificateLoaderNoSAN(t *testing.T) { + assert.NoError(t, generateTestCertificate(nil, "selfsign")) + + loader := LocalCertificateLoader{ + CertFile: testCertFile, + KeyFile: testKeyFile, + SNIGuard: SNIGuardDNSSAN, + } + assert.NoError(t, loader.InitializeCache()) + + lis, err := tls.Listen("tcp", testListen, &tls.Config{ + GetCertificate: loader.GetCertificate, + }) + assert.NoError(t, err) + defer lis.Close() + go http.Serve(lis, nil) + + assert.NoError(t, runTestTLSClient("")) +} + +func TestCertificateLoaderReplaceCertificate(t *testing.T) { + assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain")) + + loader := LocalCertificateLoader{ + CertFile: testCertFile, + KeyFile: testKeyFile, + SNIGuard: SNIGuardStrict, + } + assert.NoError(t, loader.InitializeCache()) + + lis, err := tls.Listen("tcp", testListen, &tls.Config{ + GetCertificate: loader.GetCertificate, + }) + assert.NoError(t, err) + defer lis.Close() + go http.Serve(lis, nil) + + assert.NoError(t, runTestTLSClient("example.com")) + assert.Error(t, runTestTLSClient("2.example.com")) + + assert.NoError(t, generateTestCertificate([]string{"2.example.com"}, "fullchain")) + + assert.Error(t, runTestTLSClient("example.com")) + assert.NoError(t, runTestTLSClient("2.example.com")) +} + +func generateTestCertificate(dnssan []string, certType string) error { + args := []string{ + "certloader_test_gencert.py", + "--ca", testCAFile, + "--cert", testCertFile, + "--key", testKeyFile, + "--type", certType, + } + if len(dnssan) > 0 { + args = append(args, "--dnssan", strings.Join(dnssan, ",")) + } + cmd := exec.Command("python3", args...) + out, err := cmd.CombinedOutput() + if err != nil { + log.Printf("Failed to generate test certificate: %s", out) + return err + } + return nil +} + +func runTestTLSClient(sni string) error { + args := []string{ + "certloader_test_tlsclient.py", + "--server", testListen, + "--ca", testCAFile, + } + if sni != "" { + args = append(args, "--sni", sni) + } + cmd := exec.Command("python3", args...) + out, err := cmd.CombinedOutput() + if err != nil { + log.Printf("Failed to run test TLS client: %s", out) + return err + } + return nil +} diff --git a/app/internal/utils/certloader_test_gencert.py b/app/internal/utils/certloader_test_gencert.py new file mode 100644 index 0000000..d4d5695 --- /dev/null +++ b/app/internal/utils/certloader_test_gencert.py @@ -0,0 +1,134 @@ +import argparse +import datetime +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption + + +def create_key(): + return ec.generate_private_key(ec.SECP256R1()) + + +def create_certificate(cert_type, subject, issuer, private_key, public_key, dns_san=None): + serial_number = x509.random_serial_number() + not_valid_before = datetime.datetime.now(datetime.UTC) + not_valid_after = not_valid_before + datetime.timedelta(days=365) + + subject_name = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, subject.get('C', 'ZZ')), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, subject.get('O', 'No Organization')), + x509.NameAttribute(NameOID.COMMON_NAME, subject.get('CN', 'No CommonName')), + ]) + issuer_name = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, issuer.get('C', 'ZZ')), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, issuer.get('O', 'No Organization')), + x509.NameAttribute(NameOID.COMMON_NAME, issuer.get('CN', 'No CommonName')), + ]) + builder = x509.CertificateBuilder() + builder = builder.subject_name(subject_name) + builder = builder.issuer_name(issuer_name) + builder = builder.public_key(public_key) + builder = builder.serial_number(serial_number) + builder = builder.not_valid_before(not_valid_before) + builder = builder.not_valid_after(not_valid_after) + if cert_type == 'root': + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=None), critical=True + ) + elif cert_type == 'intermediate': + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=0), critical=True + ) + elif cert_type == 'leaf': + builder = builder.add_extension( + x509.BasicConstraints(ca=False, path_length=None), critical=True + ) + else: + raise ValueError(f'Invalid cert_type: {cert_type}') + if dns_san: + builder = builder.add_extension( + x509.SubjectAlternativeName([x509.DNSName(d) for d in dns_san.split(',')]), + critical=False + ) + return builder.sign(private_key=private_key, algorithm=hashes.SHA256()) + + +def main(): + parser = argparse.ArgumentParser(description='Generate HTTPS server certificate.') + parser.add_argument('--ca', required=True, + help='Path to write the X509 CA certificate in PEM format') + parser.add_argument('--cert', required=True, + help='Path to write the X509 certificate in PEM format') + parser.add_argument('--key', required=True, + help='Path to write the private key in PEM format') + parser.add_argument('--dnssan', required=False, default=None, + help='Comma-separated list of DNS SANs') + parser.add_argument('--type', required=True, choices=['selfsign', 'fullchain'], + help='Type of certificate to generate') + + args = parser.parse_args() + + key = create_key() + public_key = key.public_key() + + if args.type == 'selfsign': + subject = {"C": "ZZ", "O": "Certificate", "CN": "Certificate"} + cert = create_certificate( + cert_type='root', + subject=subject, + issuer=subject, + private_key=key, + public_key=public_key, + dns_san=args.dnssan) + with open(args.ca, 'wb') as f: + f.write(cert.public_bytes(Encoding.PEM)) + with open(args.cert, 'wb') as f: + f.write(cert.public_bytes(Encoding.PEM)) + with open(args.key, 'wb') as f: + f.write( + key.private_bytes(Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption())) + + elif args.type == 'fullchain': + ca_key = create_key() + ca_public_key = ca_key.public_key() + ca_subject = {"C": "ZZ", "O": "Root CA", "CN": "Root CA"} + ca_cert = create_certificate( + cert_type='root', + subject=ca_subject, + issuer=ca_subject, + private_key=ca_key, + public_key=ca_public_key) + + intermediate_key = create_key() + intermediate_public_key = intermediate_key.public_key() + intermediate_subject = {"C": "ZZ", "O": "Intermediate CA", "CN": "Intermediate CA"} + intermediate_cert = create_certificate( + cert_type='intermediate', + subject=intermediate_subject, + issuer=ca_subject, + private_key=ca_key, + public_key=intermediate_public_key) + + leaf_subject = {"C": "ZZ", "O": "Leaf Certificate", "CN": "Leaf Certificate"} + cert = create_certificate( + cert_type='leaf', + subject=leaf_subject, + issuer=intermediate_subject, + private_key=intermediate_key, + public_key=public_key, + dns_san=args.dnssan) + + with open(args.ca, 'wb') as f: + f.write(ca_cert.public_bytes(Encoding.PEM)) + with open(args.cert, 'wb') as f: + f.write(cert.public_bytes(Encoding.PEM)) + f.write(intermediate_cert.public_bytes(Encoding.PEM)) + with open(args.key, 'wb') as f: + f.write( + key.private_bytes(Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption())) + + +if __name__ == "__main__": + main() diff --git a/app/internal/utils/certloader_test_tlsclient.py b/app/internal/utils/certloader_test_tlsclient.py new file mode 100644 index 0000000..3b7efd6 --- /dev/null +++ b/app/internal/utils/certloader_test_tlsclient.py @@ -0,0 +1,60 @@ +import argparse +import ssl +import socket +import sys + + +def check_tls(server, ca_cert, sni, alpn): + try: + host, port = server.split(":") + port = int(port) + + if ca_cert: + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=ca_cert) + context.check_hostname = sni is not None + context.verify_mode = ssl.CERT_REQUIRED + else: + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + if alpn: + context.set_alpn_protocols([p for p in alpn.split(",")]) + + with socket.create_connection((host, port)) as sock: + with context.wrap_socket(sock, server_hostname=sni) as ssock: + # Verify handshake and certificate + print(f'Connected to {ssock.version()} using {ssock.cipher()}') + print(f'Server certificate validated and details: {ssock.getpeercert()}') + print("OK") + return 0 + except Exception as e: + print(f"Error: {e}") + return 1 + + +def main(): + parser = argparse.ArgumentParser(description="Test TLS Server") + parser.add_argument("--server", required=True, + help="Server address to test (e.g., 127.1.2.3:8443)") + parser.add_argument("--ca", required=False, default=None, + help="CA certificate file used to validate the server certificate" + "Omit to use insecure connection") + parser.add_argument("--sni", required=False, default=None, + help="SNI to send in ClientHello") + parser.add_argument("--alpn", required=False, default='h2', + help="ALPN to send in ClientHello") + + args = parser.parse_args() + + exit_status = check_tls( + server=args.server, + ca_cert=args.ca, + sni=args.sni, + alpn=args.alpn) + + sys.exit(exit_status) + + +if __name__ == "__main__": + main() diff --git a/app/internal/utils/testcerts/.gitignore b/app/internal/utils/testcerts/.gitignore new file mode 100644 index 0000000..082821a --- /dev/null +++ b/app/internal/utils/testcerts/.gitignore @@ -0,0 +1,3 @@ +# This directory is used for certificate generation in certloader_test.go +/* +!/.gitignore