mirror of
https://github.com/apernet/hysteria.git
synced 2025-04-03 20:47:38 +03:00
test: add tests for certloader
This commit is contained in:
parent
bcf830c29a
commit
667b08ec3e
4 changed files with 336 additions and 0 deletions
139
app/internal/utils/certloader_test.go
Normal file
139
app/internal/utils/certloader_test.go
Normal file
|
@ -0,0 +1,139 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
testListen = "127.82.39.147:12947"
|
||||
testCAFile = "./testcerts/ca"
|
||||
testCertFile = "./testcerts/cert"
|
||||
testKeyFile = "./testcerts/key"
|
||||
)
|
||||
|
||||
func TestCertificateLoaderPathError(t *testing.T) {
|
||||
assert.NoError(t, os.RemoveAll(testCertFile))
|
||||
assert.NoError(t, os.RemoveAll(testKeyFile))
|
||||
loader := LocalCertificateLoader{
|
||||
CertFile: testCertFile,
|
||||
KeyFile: testKeyFile,
|
||||
SNIGuard: SNIGuardStrict,
|
||||
}
|
||||
err := loader.InitializeCache()
|
||||
var pathErr *os.PathError
|
||||
assert.ErrorAs(t, err, &pathErr)
|
||||
}
|
||||
|
||||
func TestCertificateLoaderFullChain(t *testing.T) {
|
||||
assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain"))
|
||||
|
||||
loader := LocalCertificateLoader{
|
||||
CertFile: testCertFile,
|
||||
KeyFile: testKeyFile,
|
||||
SNIGuard: SNIGuardStrict,
|
||||
}
|
||||
assert.NoError(t, loader.InitializeCache())
|
||||
|
||||
lis, err := tls.Listen("tcp", testListen, &tls.Config{
|
||||
GetCertificate: loader.GetCertificate,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
defer lis.Close()
|
||||
go http.Serve(lis, nil)
|
||||
|
||||
assert.Error(t, runTestTLSClient("unmatched-sni.example.com"))
|
||||
assert.Error(t, runTestTLSClient(""))
|
||||
assert.NoError(t, runTestTLSClient("example.com"))
|
||||
}
|
||||
|
||||
func TestCertificateLoaderNoSAN(t *testing.T) {
|
||||
assert.NoError(t, generateTestCertificate(nil, "selfsign"))
|
||||
|
||||
loader := LocalCertificateLoader{
|
||||
CertFile: testCertFile,
|
||||
KeyFile: testKeyFile,
|
||||
SNIGuard: SNIGuardDNSSAN,
|
||||
}
|
||||
assert.NoError(t, loader.InitializeCache())
|
||||
|
||||
lis, err := tls.Listen("tcp", testListen, &tls.Config{
|
||||
GetCertificate: loader.GetCertificate,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
defer lis.Close()
|
||||
go http.Serve(lis, nil)
|
||||
|
||||
assert.NoError(t, runTestTLSClient(""))
|
||||
}
|
||||
|
||||
func TestCertificateLoaderReplaceCertificate(t *testing.T) {
|
||||
assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain"))
|
||||
|
||||
loader := LocalCertificateLoader{
|
||||
CertFile: testCertFile,
|
||||
KeyFile: testKeyFile,
|
||||
SNIGuard: SNIGuardStrict,
|
||||
}
|
||||
assert.NoError(t, loader.InitializeCache())
|
||||
|
||||
lis, err := tls.Listen("tcp", testListen, &tls.Config{
|
||||
GetCertificate: loader.GetCertificate,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
defer lis.Close()
|
||||
go http.Serve(lis, nil)
|
||||
|
||||
assert.NoError(t, runTestTLSClient("example.com"))
|
||||
assert.Error(t, runTestTLSClient("2.example.com"))
|
||||
|
||||
assert.NoError(t, generateTestCertificate([]string{"2.example.com"}, "fullchain"))
|
||||
|
||||
assert.Error(t, runTestTLSClient("example.com"))
|
||||
assert.NoError(t, runTestTLSClient("2.example.com"))
|
||||
}
|
||||
|
||||
func generateTestCertificate(dnssan []string, certType string) error {
|
||||
args := []string{
|
||||
"certloader_test_gencert.py",
|
||||
"--ca", testCAFile,
|
||||
"--cert", testCertFile,
|
||||
"--key", testKeyFile,
|
||||
"--type", certType,
|
||||
}
|
||||
if len(dnssan) > 0 {
|
||||
args = append(args, "--dnssan", strings.Join(dnssan, ","))
|
||||
}
|
||||
cmd := exec.Command("python3", args...)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Printf("Failed to generate test certificate: %s", out)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runTestTLSClient(sni string) error {
|
||||
args := []string{
|
||||
"certloader_test_tlsclient.py",
|
||||
"--server", testListen,
|
||||
"--ca", testCAFile,
|
||||
}
|
||||
if sni != "" {
|
||||
args = append(args, "--sni", sni)
|
||||
}
|
||||
cmd := exec.Command("python3", args...)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Printf("Failed to run test TLS client: %s", out)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
134
app/internal/utils/certloader_test_gencert.py
Normal file
134
app/internal/utils/certloader_test_gencert.py
Normal file
|
@ -0,0 +1,134 @@
|
|||
import argparse
|
||||
import datetime
|
||||
from cryptography import x509
|
||||
from cryptography.x509.oid import NameOID
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption
|
||||
|
||||
|
||||
def create_key():
|
||||
return ec.generate_private_key(ec.SECP256R1())
|
||||
|
||||
|
||||
def create_certificate(cert_type, subject, issuer, private_key, public_key, dns_san=None):
|
||||
serial_number = x509.random_serial_number()
|
||||
not_valid_before = datetime.datetime.now(datetime.UTC)
|
||||
not_valid_after = not_valid_before + datetime.timedelta(days=365)
|
||||
|
||||
subject_name = x509.Name([
|
||||
x509.NameAttribute(NameOID.COUNTRY_NAME, subject.get('C', 'ZZ')),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, subject.get('O', 'No Organization')),
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, subject.get('CN', 'No CommonName')),
|
||||
])
|
||||
issuer_name = x509.Name([
|
||||
x509.NameAttribute(NameOID.COUNTRY_NAME, issuer.get('C', 'ZZ')),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, issuer.get('O', 'No Organization')),
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, issuer.get('CN', 'No CommonName')),
|
||||
])
|
||||
builder = x509.CertificateBuilder()
|
||||
builder = builder.subject_name(subject_name)
|
||||
builder = builder.issuer_name(issuer_name)
|
||||
builder = builder.public_key(public_key)
|
||||
builder = builder.serial_number(serial_number)
|
||||
builder = builder.not_valid_before(not_valid_before)
|
||||
builder = builder.not_valid_after(not_valid_after)
|
||||
if cert_type == 'root':
|
||||
builder = builder.add_extension(
|
||||
x509.BasicConstraints(ca=True, path_length=None), critical=True
|
||||
)
|
||||
elif cert_type == 'intermediate':
|
||||
builder = builder.add_extension(
|
||||
x509.BasicConstraints(ca=True, path_length=0), critical=True
|
||||
)
|
||||
elif cert_type == 'leaf':
|
||||
builder = builder.add_extension(
|
||||
x509.BasicConstraints(ca=False, path_length=None), critical=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid cert_type: {cert_type}')
|
||||
if dns_san:
|
||||
builder = builder.add_extension(
|
||||
x509.SubjectAlternativeName([x509.DNSName(d) for d in dns_san.split(',')]),
|
||||
critical=False
|
||||
)
|
||||
return builder.sign(private_key=private_key, algorithm=hashes.SHA256())
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Generate HTTPS server certificate.')
|
||||
parser.add_argument('--ca', required=True,
|
||||
help='Path to write the X509 CA certificate in PEM format')
|
||||
parser.add_argument('--cert', required=True,
|
||||
help='Path to write the X509 certificate in PEM format')
|
||||
parser.add_argument('--key', required=True,
|
||||
help='Path to write the private key in PEM format')
|
||||
parser.add_argument('--dnssan', required=False, default=None,
|
||||
help='Comma-separated list of DNS SANs')
|
||||
parser.add_argument('--type', required=True, choices=['selfsign', 'fullchain'],
|
||||
help='Type of certificate to generate')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
key = create_key()
|
||||
public_key = key.public_key()
|
||||
|
||||
if args.type == 'selfsign':
|
||||
subject = {"C": "ZZ", "O": "Certificate", "CN": "Certificate"}
|
||||
cert = create_certificate(
|
||||
cert_type='root',
|
||||
subject=subject,
|
||||
issuer=subject,
|
||||
private_key=key,
|
||||
public_key=public_key,
|
||||
dns_san=args.dnssan)
|
||||
with open(args.ca, 'wb') as f:
|
||||
f.write(cert.public_bytes(Encoding.PEM))
|
||||
with open(args.cert, 'wb') as f:
|
||||
f.write(cert.public_bytes(Encoding.PEM))
|
||||
with open(args.key, 'wb') as f:
|
||||
f.write(
|
||||
key.private_bytes(Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption()))
|
||||
|
||||
elif args.type == 'fullchain':
|
||||
ca_key = create_key()
|
||||
ca_public_key = ca_key.public_key()
|
||||
ca_subject = {"C": "ZZ", "O": "Root CA", "CN": "Root CA"}
|
||||
ca_cert = create_certificate(
|
||||
cert_type='root',
|
||||
subject=ca_subject,
|
||||
issuer=ca_subject,
|
||||
private_key=ca_key,
|
||||
public_key=ca_public_key)
|
||||
|
||||
intermediate_key = create_key()
|
||||
intermediate_public_key = intermediate_key.public_key()
|
||||
intermediate_subject = {"C": "ZZ", "O": "Intermediate CA", "CN": "Intermediate CA"}
|
||||
intermediate_cert = create_certificate(
|
||||
cert_type='intermediate',
|
||||
subject=intermediate_subject,
|
||||
issuer=ca_subject,
|
||||
private_key=ca_key,
|
||||
public_key=intermediate_public_key)
|
||||
|
||||
leaf_subject = {"C": "ZZ", "O": "Leaf Certificate", "CN": "Leaf Certificate"}
|
||||
cert = create_certificate(
|
||||
cert_type='leaf',
|
||||
subject=leaf_subject,
|
||||
issuer=intermediate_subject,
|
||||
private_key=intermediate_key,
|
||||
public_key=public_key,
|
||||
dns_san=args.dnssan)
|
||||
|
||||
with open(args.ca, 'wb') as f:
|
||||
f.write(ca_cert.public_bytes(Encoding.PEM))
|
||||
with open(args.cert, 'wb') as f:
|
||||
f.write(cert.public_bytes(Encoding.PEM))
|
||||
f.write(intermediate_cert.public_bytes(Encoding.PEM))
|
||||
with open(args.key, 'wb') as f:
|
||||
f.write(
|
||||
key.private_bytes(Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
60
app/internal/utils/certloader_test_tlsclient.py
Normal file
60
app/internal/utils/certloader_test_tlsclient.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
import argparse
|
||||
import ssl
|
||||
import socket
|
||||
import sys
|
||||
|
||||
|
||||
def check_tls(server, ca_cert, sni, alpn):
|
||||
try:
|
||||
host, port = server.split(":")
|
||||
port = int(port)
|
||||
|
||||
if ca_cert:
|
||||
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=ca_cert)
|
||||
context.check_hostname = sni is not None
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
else:
|
||||
context = ssl.create_default_context()
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
if alpn:
|
||||
context.set_alpn_protocols([p for p in alpn.split(",")])
|
||||
|
||||
with socket.create_connection((host, port)) as sock:
|
||||
with context.wrap_socket(sock, server_hostname=sni) as ssock:
|
||||
# Verify handshake and certificate
|
||||
print(f'Connected to {ssock.version()} using {ssock.cipher()}')
|
||||
print(f'Server certificate validated and details: {ssock.getpeercert()}')
|
||||
print("OK")
|
||||
return 0
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return 1
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Test TLS Server")
|
||||
parser.add_argument("--server", required=True,
|
||||
help="Server address to test (e.g., 127.1.2.3:8443)")
|
||||
parser.add_argument("--ca", required=False, default=None,
|
||||
help="CA certificate file used to validate the server certificate"
|
||||
"Omit to use insecure connection")
|
||||
parser.add_argument("--sni", required=False, default=None,
|
||||
help="SNI to send in ClientHello")
|
||||
parser.add_argument("--alpn", required=False, default='h2',
|
||||
help="ALPN to send in ClientHello")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
exit_status = check_tls(
|
||||
server=args.server,
|
||||
ca_cert=args.ca,
|
||||
sni=args.sni,
|
||||
alpn=args.alpn)
|
||||
|
||||
sys.exit(exit_status)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
3
app/internal/utils/testcerts/.gitignore
vendored
Normal file
3
app/internal/utils/testcerts/.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
# This directory is used for certificate generation in certloader_test.go
|
||||
/*
|
||||
!/.gitignore
|
Loading…
Add table
Add a link
Reference in a new issue