diff --git a/example/client/main.go b/example/client/main.go index bec0bc3e..aafc50bc 100644 --- a/example/client/main.go +++ b/example/client/main.go @@ -2,12 +2,14 @@ package main import ( "bytes" + "crypto/tls" "flag" "io" "net/http" "sync" "github.com/lucas-clemente/quic-go/h2quic" + "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/utils" ) @@ -26,7 +28,11 @@ func main() { } logger.SetLogTimeFormat("") - roundTripper := &h2quic.RoundTripper{} + roundTripper := &h2quic.RoundTripper{ + TLSClientConfig: &tls.Config{ + RootCAs: testdata.GetRootCA(), + }, + } defer roundTripper.Close() hclient := &http.Client{ Transport: roundTripper, diff --git a/example/main.go b/example/main.go index f008badc..205996f0 100644 --- a/example/main.go +++ b/example/main.go @@ -10,14 +10,13 @@ import ( "log" "mime/multipart" "net/http" - "path" - "runtime" "strings" "sync" _ "net/http/pprof" "github.com/lucas-clemente/quic-go/h2quic" + "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/utils" ) @@ -99,15 +98,6 @@ func init() { }) } -func getBuildDir() string { - _, filename, _, ok := runtime.Caller(0) - if !ok { - panic("Failed to get current frame") - } - - return path.Dir(filename) -} - func main() { // defer profile.Start().Stop() go func() { @@ -118,7 +108,6 @@ func main() { verbose := flag.Bool("v", false, "verbose") bs := binds{} flag.Var(&bs, "bind", "bind to") - certPath := flag.String("certpath", getBuildDir(), "certificate directory") www := flag.String("www", "/var/www", "www data") tcp := flag.Bool("tcp", false, "also listen on TCP") flag.Parse() @@ -132,9 +121,6 @@ func main() { } logger.SetLogTimeFormat("") - certFile := *certPath + "/fullchain.pem" - keyFile := *certPath + "/privkey.pem" - http.Handle("/", http.FileServer(http.Dir(*www))) if len(bs) == 0 { @@ -148,12 +134,13 @@ func main() { go func() { var err error if *tcp { + certFile, keyFile := testdata.GetCertificatePaths() err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil) } else { server := h2quic.Server{ Server: &http.Server{Addr: bCap}, } - err = server.ListenAndServeTLS(certFile, keyFile) + err = server.ListenAndServeTLS(testdata.GetCertificatePaths()) } if err != nil { fmt.Println(err)