package utils

import (
	"crypto/tls"
	"log"
	"net/http"
	"os"
	"os/exec"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
)

// Fixture locations shared by the certificate loader tests. Every test
// rewrites the same certificate files and listens on the same fixed address,
// so these tests would conflict with each other if run in parallel.
const (
	testListen   = "127.82.39.147:12947"
	testCAFile   = "./testcerts/ca"
	testCertFile = "./testcerts/cert"
	testKeyFile  = "./testcerts/key"
)

// TestCertificateLoaderPathError checks that InitializeCache surfaces an
// *os.PathError when the configured certificate and key files do not exist.
func TestCertificateLoaderPathError(t *testing.T) {
	// Make sure neither file is present, whatever a previous test left behind.
	assert.NoError(t, os.RemoveAll(testCertFile))
	assert.NoError(t, os.RemoveAll(testKeyFile))

	loader := LocalCertificateLoader{
		CertFile: testCertFile,
		KeyFile:  testKeyFile,
		SNIGuard: SNIGuardStrict,
	}
	err := loader.InitializeCache()

	var pathErr *os.PathError
	assert.ErrorAs(t, err, &pathErr)
}

// TestCertificateLoaderFullChain serves a full-chain certificate for
// example.com under SNIGuardStrict and checks that only a client presenting
// the matching SNI can complete a handshake.
func TestCertificateLoaderFullChain(t *testing.T) {
	if !assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain")) {
		t.FailNow()
	}

	loader := LocalCertificateLoader{
		CertFile: testCertFile,
		KeyFile:  testKeyFile,
		SNIGuard: SNIGuardStrict,
	}
	if !assert.NoError(t, loader.InitializeCache()) {
		t.FailNow()
	}

	startTestTLSServer(t, loader.GetCertificate)

	assert.Error(t, runTestTLSClient("unmatched-sni.example.com"))
	assert.Error(t, runTestTLSClient(""))
	assert.NoError(t, runTestTLSClient("example.com"))
}

// TestCertificateLoaderNoSAN serves a self-signed certificate without any DNS
// SAN under SNIGuardDNSSAN and checks that a client sending no SNI succeeds.
func TestCertificateLoaderNoSAN(t *testing.T) {
	if !assert.NoError(t, generateTestCertificate(nil, "selfsign")) {
		t.FailNow()
	}

	loader := LocalCertificateLoader{
		CertFile: testCertFile,
		KeyFile:  testKeyFile,
		SNIGuard: SNIGuardDNSSAN,
	}
	if !assert.NoError(t, loader.InitializeCache()) {
		t.FailNow()
	}

	startTestTLSServer(t, loader.GetCertificate)

	assert.NoError(t, runTestTLSClient(""))
}

// TestCertificateLoaderReplaceCertificate checks that overwriting the
// certificate files on disk while the server is running changes which SNI
// the loader accepts, without re-creating the loader.
func TestCertificateLoaderReplaceCertificate(t *testing.T) {
	if !assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain")) {
		t.FailNow()
	}

	loader := LocalCertificateLoader{
		CertFile: testCertFile,
		KeyFile:  testKeyFile,
		SNIGuard: SNIGuardStrict,
	}
	if !assert.NoError(t, loader.InitializeCache()) {
		t.FailNow()
	}

	startTestTLSServer(t, loader.GetCertificate)

	assert.NoError(t, runTestTLSClient("example.com"))
	assert.Error(t, runTestTLSClient("2.example.com"))

	// Swap the files in place; the loader is expected to pick up the new pair.
	assert.NoError(t, generateTestCertificate([]string{"2.example.com"}, "fullchain"))

	assert.Error(t, runTestTLSClient("example.com"))
	assert.NoError(t, runTestTLSClient("2.example.com"))
}

// startTestTLSServer listens on testListen with a TLS config whose
// certificate is chosen by getCertificate, then serves HTTP on it in the
// background. The listener is closed when the test finishes, which also
// makes the serving goroutine return. A listen failure aborts the test
// immediately: carrying on would call methods on a nil listener and panic.
func startTestTLSServer(t *testing.T, getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error)) {
	t.Helper()

	lis, err := tls.Listen("tcp", testListen, &tls.Config{
		GetCertificate: getCertificate,
	})
	if err != nil {
		t.Fatalf("listen on %s: %v", testListen, err)
	}
	t.Cleanup(func() { _ = lis.Close() })

	// http.Serve always returns a non-nil error; here it only does so once
	// the listener is closed at teardown, so the error is not interesting.
	go func() { _ = http.Serve(lis, nil) }()
}

// generateTestCertificate runs certloader_test_gencert.py to write a fresh
// CA, certificate and key to testCAFile, testCertFile and testKeyFile.
// certType is passed through as --type (the tests use "fullchain" and
// "selfsign"); dnssan, when non-empty, is passed as a comma-separated
// --dnssan list. It returns the command's error, if any.
func generateTestCertificate(dnssan []string, certType string) error {
	args := []string{
		"certloader_test_gencert.py",
		"--ca", testCAFile,
		"--cert", testCertFile,
		"--key", testKeyFile,
		"--type", certType,
	}
	if len(dnssan) > 0 {
		args = append(args, "--dnssan", strings.Join(dnssan, ","))
	}
	return runTestPythonScript("Failed to generate test certificate", args)
}

// runTestTLSClient runs certloader_test_tlsclient.py against testListen,
// trusting testCAFile. When sni is non-empty it is sent as --sni. A nil
// result means the script exited successfully.
func runTestTLSClient(sni string) error {
	args := []string{
		"certloader_test_tlsclient.py",
		"--server", testListen,
		"--ca", testCAFile,
	}
	if sni != "" {
		args = append(args, "--sni", sni)
	}
	return runTestPythonScript("Failed to run test TLS client", args)
}

// runTestPythonScript executes `python` with args (args[0] is the script
// path) and returns the command's error. On failure the combined
// stdout/stderr is logged after the given description to aid diagnosis.
func runTestPythonScript(description string, args []string) error {
	out, err := exec.Command("python", args...).CombinedOutput()
	if err != nil {
		log.Printf("%s: %s", description, out)
		return err
	}
	return nil
}