// Mirror of https://github.com/refraction-networking/uquic.git
// (synced 2025-04-01 19:27:35 +03:00; 327 lines, 8.2 KiB, Go).

package self_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/x509"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"runtime/pprof"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
tls "github.com/refraction-networking/utls"
|
|
|
|
quic "github.com/refraction-networking/uquic"
|
|
"github.com/refraction-networking/uquic/integrationtests/tools"
|
|
"github.com/refraction-networking/uquic/internal/protocol"
|
|
"github.com/refraction-networking/uquic/internal/utils"
|
|
"github.com/refraction-networking/uquic/internal/wire"
|
|
"github.com/refraction-networking/uquic/logging"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
// alpn is the ALPN protocol identifier placed in NextProtos of the TLS
// configs of this package; it is taken from tools.ALPN.
const alpn = tools.ALPN
|
|
|
|
// Sizes of the pseudo-random payloads.
const (
	// dataLen is the size of the regular payload: 500 KB.
	dataLen = 500 << 10
	// dataLenLong is the size of the long payload: 50 MB.
	dataLenLong = 50 << 20
)
|
|
|
|
var (
|
|
// PRData contains dataLen bytes of pseudo-random data.
|
|
PRData = GeneratePRData(dataLen)
|
|
// PRDataLong contains dataLenLong bytes of pseudo-random data.
|
|
PRDataLong = GeneratePRData(dataLenLong)
|
|
)
|
|
|
|
// GeneratePRData returns l bytes of deterministic pseudo-random data.
//
// Each byte is the low 8 bits of the next state of a Lehmer generator with
// multiplier 48271 and modulus 2^31-1, starting from state 1, so two calls
// with the same length return identical data.
// See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
func GeneratePRData(l int) []byte {
	const (
		multiplier = 48271
		modulus    = 1<<31 - 1
	)
	data := make([]byte, l)
	state := uint64(1)
	for i := range data {
		state = state * multiplier % modulus
		data[i] = byte(state)
	}
	return data
}
|
|
|
|
// logBufSize is the initial size of the log buffer: 100 MB.
const logBufSize = 100 << 20
|
|
|
|
// syncedBuffer wraps a bytes.Buffer and serializes Write, Bytes and Reset
// with a mutex. The other methods promoted from the embedded *bytes.Buffer
// are NOT protected by the mutex.
type syncedBuffer struct {
	mutex sync.Mutex

	*bytes.Buffer
}

// Write appends p to the buffer while holding the mutex.
// It returns the result of the underlying bytes.Buffer.Write.
func (b *syncedBuffer) Write(p []byte) (int, error) {
	b.mutex.Lock()
	defer b.mutex.Unlock()
	return b.Buffer.Write(p)
}

// Bytes returns a copy of the unread portion of the buffer.
//
// A copy is returned (rather than the slice handed out by bytes.Buffer.Bytes)
// because that slice aliases the buffer's storage: once the mutex is released,
// a later Write or Reset would change the bytes the caller is still reading.
func (b *syncedBuffer) Bytes() []byte {
	b.mutex.Lock()
	defer b.mutex.Unlock()
	src := b.Buffer.Bytes()
	snapshot := make([]byte, len(src))
	copy(snapshot, src)
	return snapshot
}

// Reset empties the buffer while holding the mutex.
func (b *syncedBuffer) Reset() {
	b.mutex.Lock()
	defer b.mutex.Unlock()
	b.Buffer.Reset()
}
|
|
|
|
var (
|
|
logFileName string // the log file set in the ginkgo flags
|
|
logBufOnce sync.Once
|
|
logBuf *syncedBuffer
|
|
versionParam string
|
|
|
|
enableQlog bool
|
|
|
|
version quic.Version
|
|
tlsConfig *tls.Config
|
|
tlsConfigLongChain *tls.Config
|
|
tlsClientConfig *tls.Config
|
|
tlsClientConfigWithoutServerName *tls.Config
|
|
)
|
|
|
|
// read the logfile command line flag
|
|
// to set call ginkgo -- -logfile=log.txt
|
|
func init() {
|
|
flag.StringVar(&logFileName, "logfile", "", "log file")
|
|
flag.StringVar(&versionParam, "version", "1", "QUIC version")
|
|
flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
|
|
|
|
ca, caPrivateKey, err := tools.GenerateCA()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
tlsConfig = &tls.Config{
|
|
Certificates: []tls.Certificate{{
|
|
Certificate: [][]byte{leafCert.Raw},
|
|
PrivateKey: leafPrivateKey,
|
|
}},
|
|
NextProtos: []string{alpn},
|
|
}
|
|
tlsConfLongChain, err := tools.GenerateTLSConfigWithLongCertChain(ca, caPrivateKey)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
tlsConfigLongChain = tlsConfLongChain
|
|
|
|
root := x509.NewCertPool()
|
|
root.AddCert(ca)
|
|
tlsClientConfig = &tls.Config{
|
|
ServerName: "localhost",
|
|
RootCAs: root,
|
|
NextProtos: []string{alpn},
|
|
}
|
|
tlsClientConfigWithoutServerName = &tls.Config{
|
|
RootCAs: root,
|
|
NextProtos: []string{alpn},
|
|
}
|
|
}
|
|
|
|
var _ = BeforeSuite(func() {
|
|
switch versionParam {
|
|
case "1":
|
|
version = quic.Version1
|
|
case "2":
|
|
version = quic.Version2
|
|
default:
|
|
Fail(fmt.Sprintf("unknown QUIC version: %s", versionParam))
|
|
}
|
|
fmt.Printf("Using QUIC version: %s\n", version)
|
|
protocol.SupportedVersions = []quic.Version{version}
|
|
})
|
|
|
|
func getTLSConfig() *tls.Config {
|
|
return tlsConfig.Clone()
|
|
}
|
|
|
|
func getTLSConfigWithLongCertChain() *tls.Config {
|
|
return tlsConfigLongChain.Clone()
|
|
}
|
|
|
|
func getTLSClientConfig() *tls.Config {
|
|
return tlsClientConfig.Clone()
|
|
}
|
|
|
|
func getTLSClientConfigWithoutServerName() *tls.Config {
|
|
return tlsClientConfigWithoutServerName.Clone()
|
|
}
|
|
|
|
func getQuicConfig(conf *quic.Config) *quic.Config {
|
|
if conf == nil {
|
|
conf = &quic.Config{}
|
|
} else {
|
|
conf = conf.Clone()
|
|
}
|
|
if !enableQlog {
|
|
return conf
|
|
}
|
|
if conf.Tracer == nil {
|
|
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
|
return logging.NewMultiplexedConnectionTracer(
|
|
tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
|
|
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
|
|
&logging.ConnectionTracer{},
|
|
)
|
|
}
|
|
return conf
|
|
}
|
|
origTracer := conf.Tracer
|
|
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
|
return logging.NewMultiplexedConnectionTracer(
|
|
tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
|
|
origTracer(ctx, p, connID),
|
|
)
|
|
}
|
|
return conf
|
|
}
|
|
|
|
func addTracer(tr *quic.Transport) {
|
|
if !enableQlog {
|
|
return
|
|
}
|
|
if tr.Tracer == nil {
|
|
tr.Tracer = logging.NewMultiplexedTracer(
|
|
tools.QlogTracer(GinkgoWriter),
|
|
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
|
|
&logging.Tracer{},
|
|
)
|
|
return
|
|
}
|
|
origTracer := tr.Tracer
|
|
tr.Tracer = logging.NewMultiplexedTracer(
|
|
tools.QlogTracer(GinkgoWriter),
|
|
origTracer,
|
|
)
|
|
}
|
|
|
|
var _ = BeforeEach(func() {
|
|
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
|
|
|
|
if debugLog() {
|
|
logBufOnce.Do(func() {
|
|
logBuf = &syncedBuffer{Buffer: bytes.NewBuffer(make([]byte, 0, logBufSize))}
|
|
})
|
|
utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
|
|
log.SetOutput(logBuf)
|
|
}
|
|
})
|
|
|
|
// areHandshakesRunning reports whether the goroutine profile (debug level 1)
// contains the string "RunHandshake", i.e. whether some goroutine currently
// has a function with that name on its stack.
func areHandshakesRunning() bool {
	dump := new(bytes.Buffer)
	profile := pprof.Lookup("goroutine")
	// The profile is written to an in-memory buffer; the error is ignored.
	_ = profile.WriteTo(dump, 1)
	stacks := dump.String()
	return strings.Contains(stacks, "RunHandshake")
}
|
|
|
|
// areTransportsRunning reports whether some goroutine currently has a
// (*Transport).listen method on its stack, according to the goroutine profile
// (debug level 1).
//
// Frames in the profile are qualified with the package import path. The match
// therefore deliberately omits the package qualifier: the previous pattern
// "quic-go.(*Transport).listen" cannot match frames of a package imported
// from a path ending in "uquic" (as this file does), which made the check
// always report false.
func areTransportsRunning() bool {
	var b bytes.Buffer
	// The profile is written to an in-memory buffer; the error is ignored.
	_ = pprof.Lookup("goroutine").WriteTo(&b, 1)
	return strings.Contains(b.String(), ".(*Transport).listen")
}
|
|
|
|
var _ = AfterEach(func() {
|
|
Expect(areHandshakesRunning()).To(BeFalse())
|
|
Eventually(areTransportsRunning).Should(BeFalse())
|
|
|
|
if debugLog() {
|
|
logFile, err := os.Create(logFileName)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
logFile.Write(logBuf.Bytes())
|
|
logFile.Close()
|
|
logBuf.Reset()
|
|
}
|
|
})
|
|
|
|
// Debug says if this test is being logged
|
|
func debugLog() bool {
|
|
return len(logFileName) > 0
|
|
}
|
|
|
|
func scaleDuration(d time.Duration) time.Duration {
|
|
scaleFactor := 1
|
|
if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
|
|
scaleFactor = f
|
|
}
|
|
Expect(scaleFactor).ToNot(BeZero())
|
|
return time.Duration(scaleFactor) * d
|
|
}
|
|
|
|
func newTracer(tracer *logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
|
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return tracer }
|
|
}
|
|
|
|
type packet struct {
|
|
time time.Time
|
|
hdr *logging.ExtendedHeader
|
|
frames []logging.Frame
|
|
}
|
|
|
|
type shortHeaderPacket struct {
|
|
time time.Time
|
|
hdr *logging.ShortHeader
|
|
frames []logging.Frame
|
|
}
|
|
|
|
type packetCounter struct {
|
|
closed chan struct{}
|
|
sentShortHdr, rcvdShortHdr []shortHeaderPacket
|
|
rcvdLongHdr []packet
|
|
}
|
|
|
|
func (t *packetCounter) getSentShortHeaderPackets() []shortHeaderPacket {
|
|
<-t.closed
|
|
return t.sentShortHdr
|
|
}
|
|
|
|
func (t *packetCounter) getRcvdLongHeaderPackets() []packet {
|
|
<-t.closed
|
|
return t.rcvdLongHdr
|
|
}
|
|
|
|
func (t *packetCounter) getRcvdShortHeaderPackets() []shortHeaderPacket {
|
|
<-t.closed
|
|
return t.rcvdShortHdr
|
|
}
|
|
|
|
func newPacketTracer() (*packetCounter, *logging.ConnectionTracer) {
|
|
c := &packetCounter{closed: make(chan struct{})}
|
|
return c, &logging.ConnectionTracer{
|
|
ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
|
|
c.rcvdLongHdr = append(c.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames})
|
|
},
|
|
ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
|
|
c.rcvdShortHdr = append(c.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
|
|
},
|
|
SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) {
|
|
if ack != nil {
|
|
frames = append(frames, ack)
|
|
}
|
|
c.sentShortHdr = append(c.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
|
|
},
|
|
Close: func() { close(c.closed) },
|
|
}
|
|
}
|
|
|
|
func TestSelf(t *testing.T) {
|
|
RegisterFailHandler(Fail)
|
|
RunSpecs(t, "Self integration tests")
|
|
}
|