hysteria/app/internal/utils/certloader_test.go

139 lines
3.3 KiB
Go

package utils
import (
"crypto/tls"
"log"
"net/http"
"os"
"os/exec"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
const (
testListen = "127.82.39.147:12947"
testCAFile = "./testcerts/ca"
testCertFile = "./testcerts/cert"
testKeyFile = "./testcerts/key"
)
func TestCertificateLoaderPathError(t *testing.T) {
assert.NoError(t, os.RemoveAll(testCertFile))
assert.NoError(t, os.RemoveAll(testKeyFile))
loader := LocalCertificateLoader{
CertFile: testCertFile,
KeyFile: testKeyFile,
SNIGuard: SNIGuardStrict,
}
err := loader.InitializeCache()
var pathErr *os.PathError
assert.ErrorAs(t, err, &pathErr)
}
func TestCertificateLoaderFullChain(t *testing.T) {
assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain"))
loader := LocalCertificateLoader{
CertFile: testCertFile,
KeyFile: testKeyFile,
SNIGuard: SNIGuardStrict,
}
assert.NoError(t, loader.InitializeCache())
lis, err := tls.Listen("tcp", testListen, &tls.Config{
GetCertificate: loader.GetCertificate,
})
assert.NoError(t, err)
defer lis.Close()
go http.Serve(lis, nil)
assert.Error(t, runTestTLSClient("unmatched-sni.example.com"))
assert.Error(t, runTestTLSClient(""))
assert.NoError(t, runTestTLSClient("example.com"))
}
func TestCertificateLoaderNoSAN(t *testing.T) {
assert.NoError(t, generateTestCertificate(nil, "selfsign"))
loader := LocalCertificateLoader{
CertFile: testCertFile,
KeyFile: testKeyFile,
SNIGuard: SNIGuardDNSSAN,
}
assert.NoError(t, loader.InitializeCache())
lis, err := tls.Listen("tcp", testListen, &tls.Config{
GetCertificate: loader.GetCertificate,
})
assert.NoError(t, err)
defer lis.Close()
go http.Serve(lis, nil)
assert.NoError(t, runTestTLSClient(""))
}
func TestCertificateLoaderReplaceCertificate(t *testing.T) {
assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain"))
loader := LocalCertificateLoader{
CertFile: testCertFile,
KeyFile: testKeyFile,
SNIGuard: SNIGuardStrict,
}
assert.NoError(t, loader.InitializeCache())
lis, err := tls.Listen("tcp", testListen, &tls.Config{
GetCertificate: loader.GetCertificate,
})
assert.NoError(t, err)
defer lis.Close()
go http.Serve(lis, nil)
assert.NoError(t, runTestTLSClient("example.com"))
assert.Error(t, runTestTLSClient("2.example.com"))
assert.NoError(t, generateTestCertificate([]string{"2.example.com"}, "fullchain"))
assert.Error(t, runTestTLSClient("example.com"))
assert.NoError(t, runTestTLSClient("2.example.com"))
}
func generateTestCertificate(dnssan []string, certType string) error {
args := []string{
"certloader_test_gencert.py",
"--ca", testCAFile,
"--cert", testCertFile,
"--key", testKeyFile,
"--type", certType,
}
if len(dnssan) > 0 {
args = append(args, "--dnssan", strings.Join(dnssan, ","))
}
cmd := exec.Command("python", args...)
out, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to generate test certificate: %s", out)
return err
}
return nil
}
func runTestTLSClient(sni string) error {
args := []string{
"certloader_test_tlsclient.py",
"--server", testListen,
"--ca", testCAFile,
}
if sni != "" {
args = append(args, "--sni", sni)
}
cmd := exec.Command("python", args...)
out, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to run test TLS client: %s", out)
return err
}
return nil
}