test: add tests for certloader

This commit is contained in:
Haruue 2024-08-24 17:25:31 +08:00
parent bcf830c29a
commit 667b08ec3e
No known key found for this signature in database
GPG key ID: F6083B28CBCBC148
4 changed files with 336 additions and 0 deletions

View file

@ -0,0 +1,139 @@
package utils
import (
"crypto/tls"
"log"
"net/http"
"os"
"os/exec"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
const (
testListen = "127.82.39.147:12947"
testCAFile = "./testcerts/ca"
testCertFile = "./testcerts/cert"
testKeyFile = "./testcerts/key"
)
func TestCertificateLoaderPathError(t *testing.T) {
assert.NoError(t, os.RemoveAll(testCertFile))
assert.NoError(t, os.RemoveAll(testKeyFile))
loader := LocalCertificateLoader{
CertFile: testCertFile,
KeyFile: testKeyFile,
SNIGuard: SNIGuardStrict,
}
err := loader.InitializeCache()
var pathErr *os.PathError
assert.ErrorAs(t, err, &pathErr)
}
func TestCertificateLoaderFullChain(t *testing.T) {
assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain"))
loader := LocalCertificateLoader{
CertFile: testCertFile,
KeyFile: testKeyFile,
SNIGuard: SNIGuardStrict,
}
assert.NoError(t, loader.InitializeCache())
lis, err := tls.Listen("tcp", testListen, &tls.Config{
GetCertificate: loader.GetCertificate,
})
assert.NoError(t, err)
defer lis.Close()
go http.Serve(lis, nil)
assert.Error(t, runTestTLSClient("unmatched-sni.example.com"))
assert.Error(t, runTestTLSClient(""))
assert.NoError(t, runTestTLSClient("example.com"))
}
func TestCertificateLoaderNoSAN(t *testing.T) {
assert.NoError(t, generateTestCertificate(nil, "selfsign"))
loader := LocalCertificateLoader{
CertFile: testCertFile,
KeyFile: testKeyFile,
SNIGuard: SNIGuardDNSSAN,
}
assert.NoError(t, loader.InitializeCache())
lis, err := tls.Listen("tcp", testListen, &tls.Config{
GetCertificate: loader.GetCertificate,
})
assert.NoError(t, err)
defer lis.Close()
go http.Serve(lis, nil)
assert.NoError(t, runTestTLSClient(""))
}
func TestCertificateLoaderReplaceCertificate(t *testing.T) {
assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain"))
loader := LocalCertificateLoader{
CertFile: testCertFile,
KeyFile: testKeyFile,
SNIGuard: SNIGuardStrict,
}
assert.NoError(t, loader.InitializeCache())
lis, err := tls.Listen("tcp", testListen, &tls.Config{
GetCertificate: loader.GetCertificate,
})
assert.NoError(t, err)
defer lis.Close()
go http.Serve(lis, nil)
assert.NoError(t, runTestTLSClient("example.com"))
assert.Error(t, runTestTLSClient("2.example.com"))
assert.NoError(t, generateTestCertificate([]string{"2.example.com"}, "fullchain"))
assert.Error(t, runTestTLSClient("example.com"))
assert.NoError(t, runTestTLSClient("2.example.com"))
}
func generateTestCertificate(dnssan []string, certType string) error {
args := []string{
"certloader_test_gencert.py",
"--ca", testCAFile,
"--cert", testCertFile,
"--key", testKeyFile,
"--type", certType,
}
if len(dnssan) > 0 {
args = append(args, "--dnssan", strings.Join(dnssan, ","))
}
cmd := exec.Command("python3", args...)
out, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to generate test certificate: %s", out)
return err
}
return nil
}
func runTestTLSClient(sni string) error {
args := []string{
"certloader_test_tlsclient.py",
"--server", testListen,
"--ca", testCAFile,
}
if sni != "" {
args = append(args, "--sni", sni)
}
cmd := exec.Command("python3", args...)
out, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to run test TLS client: %s", out)
return err
}
return nil
}

View file

@ -0,0 +1,134 @@
import argparse
import datetime
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption
def create_key():
return ec.generate_private_key(ec.SECP256R1())
def create_certificate(cert_type, subject, issuer, private_key, public_key, dns_san=None):
serial_number = x509.random_serial_number()
not_valid_before = datetime.datetime.now(datetime.UTC)
not_valid_after = not_valid_before + datetime.timedelta(days=365)
subject_name = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, subject.get('C', 'ZZ')),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, subject.get('O', 'No Organization')),
x509.NameAttribute(NameOID.COMMON_NAME, subject.get('CN', 'No CommonName')),
])
issuer_name = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, issuer.get('C', 'ZZ')),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, issuer.get('O', 'No Organization')),
x509.NameAttribute(NameOID.COMMON_NAME, issuer.get('CN', 'No CommonName')),
])
builder = x509.CertificateBuilder()
builder = builder.subject_name(subject_name)
builder = builder.issuer_name(issuer_name)
builder = builder.public_key(public_key)
builder = builder.serial_number(serial_number)
builder = builder.not_valid_before(not_valid_before)
builder = builder.not_valid_after(not_valid_after)
if cert_type == 'root':
builder = builder.add_extension(
x509.BasicConstraints(ca=True, path_length=None), critical=True
)
elif cert_type == 'intermediate':
builder = builder.add_extension(
x509.BasicConstraints(ca=True, path_length=0), critical=True
)
elif cert_type == 'leaf':
builder = builder.add_extension(
x509.BasicConstraints(ca=False, path_length=None), critical=True
)
else:
raise ValueError(f'Invalid cert_type: {cert_type}')
if dns_san:
builder = builder.add_extension(
x509.SubjectAlternativeName([x509.DNSName(d) for d in dns_san.split(',')]),
critical=False
)
return builder.sign(private_key=private_key, algorithm=hashes.SHA256())
def main():
parser = argparse.ArgumentParser(description='Generate HTTPS server certificate.')
parser.add_argument('--ca', required=True,
help='Path to write the X509 CA certificate in PEM format')
parser.add_argument('--cert', required=True,
help='Path to write the X509 certificate in PEM format')
parser.add_argument('--key', required=True,
help='Path to write the private key in PEM format')
parser.add_argument('--dnssan', required=False, default=None,
help='Comma-separated list of DNS SANs')
parser.add_argument('--type', required=True, choices=['selfsign', 'fullchain'],
help='Type of certificate to generate')
args = parser.parse_args()
key = create_key()
public_key = key.public_key()
if args.type == 'selfsign':
subject = {"C": "ZZ", "O": "Certificate", "CN": "Certificate"}
cert = create_certificate(
cert_type='root',
subject=subject,
issuer=subject,
private_key=key,
public_key=public_key,
dns_san=args.dnssan)
with open(args.ca, 'wb') as f:
f.write(cert.public_bytes(Encoding.PEM))
with open(args.cert, 'wb') as f:
f.write(cert.public_bytes(Encoding.PEM))
with open(args.key, 'wb') as f:
f.write(
key.private_bytes(Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption()))
elif args.type == 'fullchain':
ca_key = create_key()
ca_public_key = ca_key.public_key()
ca_subject = {"C": "ZZ", "O": "Root CA", "CN": "Root CA"}
ca_cert = create_certificate(
cert_type='root',
subject=ca_subject,
issuer=ca_subject,
private_key=ca_key,
public_key=ca_public_key)
intermediate_key = create_key()
intermediate_public_key = intermediate_key.public_key()
intermediate_subject = {"C": "ZZ", "O": "Intermediate CA", "CN": "Intermediate CA"}
intermediate_cert = create_certificate(
cert_type='intermediate',
subject=intermediate_subject,
issuer=ca_subject,
private_key=ca_key,
public_key=intermediate_public_key)
leaf_subject = {"C": "ZZ", "O": "Leaf Certificate", "CN": "Leaf Certificate"}
cert = create_certificate(
cert_type='leaf',
subject=leaf_subject,
issuer=intermediate_subject,
private_key=intermediate_key,
public_key=public_key,
dns_san=args.dnssan)
with open(args.ca, 'wb') as f:
f.write(ca_cert.public_bytes(Encoding.PEM))
with open(args.cert, 'wb') as f:
f.write(cert.public_bytes(Encoding.PEM))
f.write(intermediate_cert.public_bytes(Encoding.PEM))
with open(args.key, 'wb') as f:
f.write(
key.private_bytes(Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption()))
if __name__ == "__main__":
main()

View file

@ -0,0 +1,60 @@
import argparse
import ssl
import socket
import sys
def check_tls(server, ca_cert, sni, alpn):
try:
host, port = server.split(":")
port = int(port)
if ca_cert:
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=ca_cert)
context.check_hostname = sni is not None
context.verify_mode = ssl.CERT_REQUIRED
else:
context = ssl.create_default_context()
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
if alpn:
context.set_alpn_protocols([p for p in alpn.split(",")])
with socket.create_connection((host, port)) as sock:
with context.wrap_socket(sock, server_hostname=sni) as ssock:
# Verify handshake and certificate
print(f'Connected to {ssock.version()} using {ssock.cipher()}')
print(f'Server certificate validated and details: {ssock.getpeercert()}')
print("OK")
return 0
except Exception as e:
print(f"Error: {e}")
return 1
def main():
parser = argparse.ArgumentParser(description="Test TLS Server")
parser.add_argument("--server", required=True,
help="Server address to test (e.g., 127.1.2.3:8443)")
parser.add_argument("--ca", required=False, default=None,
help="CA certificate file used to validate the server certificate"
"Omit to use insecure connection")
parser.add_argument("--sni", required=False, default=None,
help="SNI to send in ClientHello")
parser.add_argument("--alpn", required=False, default='h2',
help="ALPN to send in ClientHello")
args = parser.parse_args()
exit_status = check_tls(
server=args.server,
ca_cert=args.ca,
sni=args.sni,
alpn=args.alpn)
sys.exit(exit_status)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,3 @@
# This directory is used for certificate generation in certloader_test.go
/*
!/.gitignore