move the RTTStats to the utils package

The RTTStats are used by the logging package. In order to instrument the
congestion package, the RTTStats can't be part of that package any more
(to avoid an import loop).
This commit is contained in:
Marten Seemann 2020-07-22 14:18:57 +07:00
parent ce16603a24
commit 741dc28d74
29 changed files with 129 additions and 139 deletions

View file

@ -1,7 +1,6 @@
package ackhandler package ackhandler
import ( import (
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/logging" "github.com/lucas-clemente/quic-go/logging"
@ -11,7 +10,7 @@ import (
// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler // NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler
func NewAckHandler( func NewAckHandler(
initialPacketNumber protocol.PacketNumber, initialPacketNumber protocol.PacketNumber,
rttStats *congestion.RTTStats, rttStats *utils.RTTStats,
pers protocol.Perspective, pers protocol.Perspective,
traceCallback func(quictrace.Event), traceCallback func(quictrace.Event),
tracer logging.ConnectionTracer, tracer logging.ConnectionTracer,

View file

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
@ -45,7 +44,7 @@ var _ ReceivedPacketHandler = &receivedPacketHandler{}
func newReceivedPacketHandler( func newReceivedPacketHandler(
sentPackets sentPacketTracker, sentPackets sentPacketTracker,
rttStats *congestion.RTTStats, rttStats *utils.RTTStats,
logger utils.Logger, logger utils.Logger,
version protocol.VersionNumber, version protocol.VersionNumber,
) ReceivedPacketHandler { ) ReceivedPacketHandler {

View file

@ -5,7 +5,6 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
@ -22,7 +21,7 @@ var _ = Describe("Received Packet Handler", func() {
sentPackets = NewMockSentPacketTracker(mockCtrl) sentPackets = NewMockSentPacketTracker(mockCtrl)
handler = newReceivedPacketHandler( handler = newReceivedPacketHandler(
sentPackets, sentPackets,
&congestion.RTTStats{}, &utils.RTTStats{},
utils.DefaultLogger, utils.DefaultLogger,
protocol.VersionWhatever, protocol.VersionWhatever,
) )

View file

@ -3,7 +3,6 @@ package ackhandler
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
@ -17,7 +16,7 @@ type receivedPacketTracker struct {
packetHistory *receivedPacketHistory packetHistory *receivedPacketHistory
maxAckDelay time.Duration maxAckDelay time.Duration
rttStats *congestion.RTTStats rttStats *utils.RTTStats
hasNewAck bool // true as soon as we received an ack-eliciting new packet hasNewAck bool // true as soon as we received an ack-eliciting new packet
ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets
@ -32,7 +31,7 @@ type receivedPacketTracker struct {
} }
func newReceivedPacketTracker( func newReceivedPacketTracker(
rttStats *congestion.RTTStats, rttStats *utils.RTTStats,
logger utils.Logger, logger utils.Logger,
version protocol.VersionNumber, version protocol.VersionNumber,
) *receivedPacketTracker { ) *receivedPacketTracker {

View file

@ -3,7 +3,6 @@ package ackhandler
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
@ -15,11 +14,11 @@ import (
var _ = Describe("Received Packet Tracker", func() { var _ = Describe("Received Packet Tracker", func() {
var ( var (
tracker *receivedPacketTracker tracker *receivedPacketTracker
rttStats *congestion.RTTStats rttStats *utils.RTTStats
) )
BeforeEach(func() { BeforeEach(func() {
rttStats = &congestion.RTTStats{} rttStats = &utils.RTTStats{}
tracker = newReceivedPacketTracker(rttStats, utils.DefaultLogger, protocol.VersionWhatever) tracker = newReceivedPacketTracker(rttStats, utils.DefaultLogger, protocol.VersionWhatever)
}) })

View file

@ -69,7 +69,7 @@ type sentPacketHandler struct {
bytesInFlight protocol.ByteCount bytesInFlight protocol.ByteCount
congestion congestion.SendAlgorithmWithDebugInfos congestion congestion.SendAlgorithmWithDebugInfos
rttStats *congestion.RTTStats rttStats *utils.RTTStats
// The number of times a PTO has been sent without receiving an ack. // The number of times a PTO has been sent without receiving an ack.
ptoCount uint32 ptoCount uint32
@ -93,7 +93,7 @@ var _ sentPacketTracker = &sentPacketHandler{}
func newSentPacketHandler( func newSentPacketHandler(
initialPacketNumber protocol.PacketNumber, initialPacketNumber protocol.PacketNumber,
rttStats *congestion.RTTStats, rttStats *utils.RTTStats,
pers protocol.Perspective, pers protocol.Perspective,
traceCallback func(quictrace.Event), traceCallback func(quictrace.Event),
tracer logging.ConnectionTracer, tracer logging.ConnectionTracer,

View file

@ -5,7 +5,6 @@ import (
"time" "time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
@ -27,7 +26,7 @@ var _ = Describe("SentPacketHandler", func() {
JustBeforeEach(func() { JustBeforeEach(func() {
lostPackets = nil lostPackets = nil
rttStats := &congestion.RTTStats{} rttStats := &utils.RTTStats{}
handler = newSentPacketHandler(42, rttStats, perspective, nil, nil, utils.DefaultLogger) handler = newSentPacketHandler(42, rttStats, perspective, nil, nil, utils.DefaultLogger)
streamFrame = wire.StreamFrame{ streamFrame = wire.StreamFrame{
StreamID: 5, StreamID: 5,

View file

@ -20,7 +20,7 @@ const (
type cubicSender struct { type cubicSender struct {
hybridSlowStart HybridSlowStart hybridSlowStart HybridSlowStart
rttStats *RTTStats rttStats *utils.RTTStats
cubic *Cubic cubic *Cubic
pacer *pacer pacer *pacer
clock Clock clock Clock
@ -63,11 +63,11 @@ var _ SendAlgorithm = &cubicSender{}
var _ SendAlgorithmWithDebugInfos = &cubicSender{} var _ SendAlgorithmWithDebugInfos = &cubicSender{}
// NewCubicSender makes a new cubic sender // NewCubicSender makes a new cubic sender
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool) *cubicSender { func NewCubicSender(clock Clock, rttStats *utils.RTTStats, reno bool) *cubicSender {
return newCubicSender(clock, rttStats, reno, initialCongestionWindow, maxCongestionWindow) return newCubicSender(clock, rttStats, reno, initialCongestionWindow, maxCongestionWindow)
} }
func newCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) *cubicSender { func newCubicSender(clock Clock, rttStats *utils.RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) *cubicSender {
c := &cubicSender{ c := &cubicSender{
rttStats: rttStats, rttStats: rttStats,
largestSentPacketNumber: protocol.InvalidPacketNumber, largestSentPacketNumber: protocol.InvalidPacketNumber,

View file

@ -31,7 +31,7 @@ var _ = Describe("Cubic Sender", func() {
bytesInFlight protocol.ByteCount bytesInFlight protocol.ByteCount
packetNumber protocol.PacketNumber packetNumber protocol.PacketNumber
ackedPacketNumber protocol.PacketNumber ackedPacketNumber protocol.PacketNumber
rttStats *RTTStats rttStats *utils.RTTStats
) )
BeforeEach(func() { BeforeEach(func() {
@ -39,7 +39,7 @@ var _ = Describe("Cubic Sender", func() {
packetNumber = 1 packetNumber = 1
ackedPacketNumber = 0 ackedPacketNumber = 0
clock = mockClock{} clock = mockClock{}
rttStats = NewRTTStats() rttStats = utils.NewRTTStats()
sender = newCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*maxDatagramSize, MaxCongestionWindow) sender = newCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*maxDatagramSize, MaxCongestionWindow)
}) })

View file

@ -4,7 +4,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
@ -25,7 +24,7 @@ type baseFlowController struct {
epochStartTime time.Time epochStartTime time.Time
epochStartOffset protocol.ByteCount epochStartOffset protocol.ByteCount
rttStats *congestion.RTTStats rttStats *utils.RTTStats
logger utils.Logger logger utils.Logger
} }

View file

@ -5,7 +5,8 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -27,7 +28,7 @@ var _ = Describe("Base Flow controller", func() {
BeforeEach(func() { BeforeEach(func() {
controller = &baseFlowController{} controller = &baseFlowController{}
controller.rttStats = &congestion.RTTStats{} controller.rttStats = &utils.RTTStats{}
}) })
Context("send flow control", func() { Context("send flow control", func() {

View file

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
@ -24,7 +23,7 @@ func NewConnectionFlowController(
receiveWindow protocol.ByteCount, receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount,
queueWindowUpdate func(), queueWindowUpdate func(),
rttStats *congestion.RTTStats, rttStats *utils.RTTStats,
logger utils.Logger, logger utils.Logger,
) ConnectionFlowController { ) ConnectionFlowController {
return &connectionFlowController{ return &connectionFlowController{

View file

@ -3,7 +3,6 @@ package flowcontrol
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@ -25,13 +24,13 @@ var _ = Describe("Connection Flow controller", func() {
BeforeEach(func() { BeforeEach(func() {
queuedWindowUpdate = false queuedWindowUpdate = false
controller = &connectionFlowController{} controller = &connectionFlowController{}
controller.rttStats = &congestion.RTTStats{} controller.rttStats = &utils.RTTStats{}
controller.logger = utils.DefaultLogger controller.logger = utils.DefaultLogger
controller.queueWindowUpdate = func() { queuedWindowUpdate = true } controller.queueWindowUpdate = func() { queuedWindowUpdate = true }
}) })
Context("Constructor", func() { Context("Constructor", func() {
rttStats := &congestion.RTTStats{} rttStats := &utils.RTTStats{}
It("sets the send and receive windows", func() { It("sets the send and receive windows", func() {
receiveWindow := protocol.ByteCount(2000) receiveWindow := protocol.ByteCount(2000)

View file

@ -3,7 +3,6 @@ package flowcontrol
import ( import (
"fmt" "fmt"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
@ -31,7 +30,7 @@ func NewStreamFlowController(
maxReceiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount, initialSendWindow protocol.ByteCount,
queueWindowUpdate func(protocol.StreamID), queueWindowUpdate func(protocol.StreamID),
rttStats *congestion.RTTStats, rttStats *utils.RTTStats,
logger utils.Logger, logger utils.Logger,
) StreamFlowController { ) StreamFlowController {
return &streamFlowController{ return &streamFlowController{

View file

@ -3,7 +3,6 @@ package flowcontrol
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@ -18,7 +17,7 @@ var _ = Describe("Stream Flow controller", func() {
BeforeEach(func() { BeforeEach(func() {
queuedWindowUpdate = false queuedWindowUpdate = false
rttStats := &congestion.RTTStats{} rttStats := &utils.RTTStats{}
controller = &streamFlowController{ controller = &streamFlowController{
streamID: 10, streamID: 10,
connection: NewConnectionFlowController(1000, 1000, func() {}, rttStats, utils.DefaultLogger).(*connectionFlowController), connection: NewConnectionFlowController(1000, 1000, func() {}, rttStats, utils.DefaultLogger).(*connectionFlowController),
@ -30,7 +29,7 @@ var _ = Describe("Stream Flow controller", func() {
}) })
Context("Constructor", func() { Context("Constructor", func() {
rttStats := &congestion.RTTStats{} rttStats := &utils.RTTStats{}
receiveWindow := protocol.ByteCount(2000) receiveWindow := protocol.ByteCount(2000)
maxReceiveWindow := protocol.ByteCount(3000) maxReceiveWindow := protocol.ByteCount(3000)
sendWindow := protocol.ByteCount(4000) sendWindow := protocol.ByteCount(4000)

View file

@ -10,7 +10,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
@ -92,7 +91,7 @@ type cryptoSetup struct {
// for clients: to see if a ServerHello is a HelloRetryRequest // for clients: to see if a ServerHello is a HelloRetryRequest
writeRecord chan struct{} writeRecord chan struct{}
rttStats *congestion.RTTStats rttStats *utils.RTTStats
tracer logging.ConnectionTracer tracer logging.ConnectionTracer
logger utils.Logger logger utils.Logger
@ -136,7 +135,7 @@ func NewCryptoSetupClient(
runner handshakeRunner, runner handshakeRunner,
tlsConf *tls.Config, tlsConf *tls.Config,
enable0RTT bool, enable0RTT bool,
rttStats *congestion.RTTStats, rttStats *utils.RTTStats,
tracer logging.ConnectionTracer, tracer logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { ) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
@ -168,7 +167,7 @@ func NewCryptoSetupServer(
runner handshakeRunner, runner handshakeRunner,
tlsConf *tls.Config, tlsConf *tls.Config,
enable0RTT bool, enable0RTT bool,
rttStats *congestion.RTTStats, rttStats *utils.RTTStats,
tracer logging.ConnectionTracer, tracer logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
) CryptoSetup { ) CryptoSetup {
@ -197,7 +196,7 @@ func newCryptoSetup(
runner handshakeRunner, runner handshakeRunner,
tlsConf *tls.Config, tlsConf *tls.Config,
enable0RTT bool, enable0RTT bool,
rttStats *congestion.RTTStats, rttStats *utils.RTTStats,
tracer logging.ConnectionTracer, tracer logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
perspective protocol.Perspective, perspective protocol.Perspective,

View file

@ -11,7 +11,6 @@ import (
"math/big" "math/big"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/testdata"
@ -99,7 +98,7 @@ var _ = Describe("Crypto Setup TLS", func() {
NewMockHandshakeRunner(mockCtrl), NewMockHandshakeRunner(mockCtrl),
tlsConf, tlsConf,
false, false,
&congestion.RTTStats{}, &utils.RTTStats{},
nil, nil,
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -133,7 +132,7 @@ var _ = Describe("Crypto Setup TLS", func() {
runner, runner,
testdata.GetTLSConfig(), testdata.GetTLSConfig(),
false, false,
&congestion.RTTStats{}, &utils.RTTStats{},
nil, nil,
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -173,7 +172,7 @@ var _ = Describe("Crypto Setup TLS", func() {
runner, runner,
testdata.GetTLSConfig(), testdata.GetTLSConfig(),
false, false,
&congestion.RTTStats{}, &utils.RTTStats{},
nil, nil,
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -216,7 +215,7 @@ var _ = Describe("Crypto Setup TLS", func() {
runner, runner,
serverConf, serverConf,
false, false,
&congestion.RTTStats{}, &utils.RTTStats{},
nil, nil,
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -252,7 +251,7 @@ var _ = Describe("Crypto Setup TLS", func() {
NewMockHandshakeRunner(mockCtrl), NewMockHandshakeRunner(mockCtrl),
serverConf, serverConf,
false, false,
&congestion.RTTStats{}, &utils.RTTStats{},
nil, nil,
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -287,8 +286,8 @@ var _ = Describe("Crypto Setup TLS", func() {
} }
} }
newRTTStatsWithRTT := func(rtt time.Duration) *congestion.RTTStats { newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
rttStats := &congestion.RTTStats{} rttStats := &utils.RTTStats{}
rttStats.UpdateRTT(rtt, 0, time.Now()) rttStats.UpdateRTT(rtt, 0, time.Now())
ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt)) ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt))
return rttStats return rttStats
@ -328,7 +327,7 @@ var _ = Describe("Crypto Setup TLS", func() {
handshakeWithTLSConf := func( handshakeWithTLSConf := func(
clientConf, serverConf *tls.Config, clientConf, serverConf *tls.Config,
clientRTTStats, serverRTTStats *congestion.RTTStats, clientRTTStats, serverRTTStats *utils.RTTStats,
clientTransportParameters, serverTransportParameters *wire.TransportParameters, clientTransportParameters, serverTransportParameters *wire.TransportParameters,
enable0RTT bool, enable0RTT bool,
) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) { ) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
@ -399,7 +398,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("handshakes", func() { It("handshakes", func() {
_, _, clientErr, _, serverErr := handshakeWithTLSConf( _, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf, clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{}, &utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{},
false, false,
) )
@ -411,7 +410,7 @@ var _ = Describe("Crypto Setup TLS", func() {
serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
_, _, clientErr, _, serverErr := handshakeWithTLSConf( _, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf, clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{}, &utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{},
false, false,
) )
@ -424,7 +423,7 @@ var _ = Describe("Crypto Setup TLS", func() {
serverConf.ClientAuth = qtls.RequireAnyClientCert serverConf.ClientAuth = qtls.RequireAnyClientCert
_, _, clientErr, _, serverErr := handshakeWithTLSConf( _, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf, clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{}, &utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{},
false, false,
) )
@ -445,7 +444,7 @@ var _ = Describe("Crypto Setup TLS", func() {
runner, runner,
&tls.Config{InsecureSkipVerify: true}, &tls.Config{InsecureSkipVerify: true},
false, false,
&congestion.RTTStats{}, &utils.RTTStats{},
nil, nil,
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
@ -487,7 +486,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner, cRunner,
clientConf, clientConf,
false, false,
&congestion.RTTStats{}, &utils.RTTStats{},
nil, nil,
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
@ -511,7 +510,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sRunner, sRunner,
serverConf, serverConf,
false, false,
&congestion.RTTStats{}, &utils.RTTStats{},
nil, nil,
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -544,7 +543,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner, cRunner,
clientConf, clientConf,
false, false,
&congestion.RTTStats{}, &utils.RTTStats{},
nil, nil,
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
@ -564,7 +563,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sRunner, sRunner,
serverConf, serverConf,
false, false,
&congestion.RTTStats{}, &utils.RTTStats{},
nil, nil,
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -604,7 +603,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner, cRunner,
clientConf, clientConf,
false, false,
&congestion.RTTStats{}, &utils.RTTStats{},
nil, nil,
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
@ -624,7 +623,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sRunner, sRunner,
serverConf, serverConf,
false, false,
&congestion.RTTStats{}, &utils.RTTStats{},
nil, nil,
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -661,7 +660,7 @@ var _ = Describe("Crypto Setup TLS", func() {
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf, clientConf, serverConf,
clientOrigRTTStats, &congestion.RTTStats{}, clientOrigRTTStats, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{},
false, false,
) )
@ -674,10 +673,10 @@ var _ = Describe("Crypto Setup TLS", func() {
csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &congestion.RTTStats{} clientRTTStats := &utils.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf, clientConf, serverConf,
clientRTTStats, &congestion.RTTStats{}, clientRTTStats, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{},
false, false,
) )
@ -702,7 +701,7 @@ var _ = Describe("Crypto Setup TLS", func() {
clientConf.ClientSessionCache = csc clientConf.ClientSessionCache = csc
_, client, clientErr, server, serverErr := handshakeWithTLSConf( _, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf, clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{}, &utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{},
false, false,
) )
@ -716,7 +715,7 @@ var _ = Describe("Crypto Setup TLS", func() {
csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Get(gomock.Any()).Return(state, true)
_, client, clientErr, server, serverErr = handshakeWithTLSConf( _, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf, clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{}, &utils.RTTStats{}, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{}, &wire.TransportParameters{},
false, false,
) )
@ -759,8 +758,8 @@ var _ = Describe("Crypto Setup TLS", func() {
csc.EXPECT().Put(gomock.Any(), nil) csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &congestion.RTTStats{} clientRTTStats := &utils.RTTStats{}
serverRTTStats := &congestion.RTTStats{} serverRTTStats := &utils.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf, clientConf, serverConf,
clientRTTStats, serverRTTStats, clientRTTStats, serverRTTStats,
@ -797,7 +796,7 @@ var _ = Describe("Crypto Setup TLS", func() {
const initialMaxData protocol.ByteCount = 1337 const initialMaxData protocol.ByteCount = 1337
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf, clientConf, serverConf,
clientOrigRTTStats, &congestion.RTTStats{}, clientOrigRTTStats, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData}, &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
true, true,
) )
@ -812,10 +811,10 @@ var _ = Describe("Crypto Setup TLS", func() {
csc.EXPECT().Put(gomock.Any(), nil) csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &congestion.RTTStats{} clientRTTStats := &utils.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf, clientConf, serverConf,
clientRTTStats, &congestion.RTTStats{}, clientRTTStats, &utils.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData + 1}, &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData + 1},
true, true,
) )

View file

@ -6,9 +6,9 @@ import (
"time" "time"
"unsafe" "unsafe"
"github.com/marten-seemann/qtls" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/congestion" "github.com/marten-seemann/qtls"
) )
func init() { func init() {
@ -46,7 +46,7 @@ func tlsConfigToQtlsConfig(
c *tls.Config, c *tls.Config,
recordLayer qtls.RecordLayer, recordLayer qtls.RecordLayer,
extHandler tlsExtensionHandler, extHandler tlsExtensionHandler,
rttStats *congestion.RTTStats, rttStats *utils.RTTStats,
getDataForSessionState func() []byte, getDataForSessionState func() []byte,
setDataFromSessionState func([]byte), setDataFromSessionState func([]byte),
accept0RTT func([]byte) bool, accept0RTT func([]byte) bool,

View file

@ -6,7 +6,8 @@ import (
"net" "net"
"unsafe" "unsafe"
"github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls" "github.com/marten-seemann/qtls"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -30,19 +31,19 @@ func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not im
var _ = Describe("qtls.Config", func() { var _ = Describe("qtls.Config", func() {
It("sets MinVersion and MaxVersion", func() { It("sets MinVersion and MaxVersion", func() {
tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12} tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13)) Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
}) })
It("works when called with a nil config", func() { It("works when called with a nil config", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
Expect(qtlsConf).ToNot(BeNil()) Expect(qtlsConf).ToNot(BeNil())
}) })
It("sets the setter and getter function for TLS extensions", func() { It("sets the setter and getter function for TLS extensions", func() {
extHandler := &mockExtensionHandler{} extHandler := &mockExtensionHandler{}
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler, utils.NewRTTStats(), nil, nil, nil, nil, false)
Expect(extHandler.get).To(BeFalse()) Expect(extHandler.get).To(BeFalse())
qtlsConf.GetExtensions(10) qtlsConf.GetExtensions(10)
Expect(extHandler.get).To(BeTrue()) Expect(extHandler.get).To(BeTrue())
@ -53,7 +54,7 @@ var _ = Describe("qtls.Config", func() {
It("sets the Accept0RTT callback", func() { It("sets the Accept0RTT callback", func() {
accept0RTT := func([]byte) bool { return true } accept0RTT := func([]byte) bool { return true }
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, accept0RTT, nil, false) qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, accept0RTT, nil, false)
Expect(qtlsConf.Accept0RTT).ToNot(BeNil()) Expect(qtlsConf.Accept0RTT).ToNot(BeNil())
Expect(qtlsConf.Accept0RTT(nil)).To(BeTrue()) Expect(qtlsConf.Accept0RTT(nil)).To(BeTrue())
}) })
@ -61,32 +62,32 @@ var _ = Describe("qtls.Config", func() {
It("sets the Accept0RTT callback", func() { It("sets the Accept0RTT callback", func() {
var called bool var called bool
rejected0RTT := func() { called = true } rejected0RTT := func() { called = true }
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, rejected0RTT, false) qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, rejected0RTT, false)
Expect(qtlsConf.Rejected0RTT).ToNot(BeNil()) Expect(qtlsConf.Rejected0RTT).ToNot(BeNil())
qtlsConf.Rejected0RTT() qtlsConf.Rejected0RTT()
Expect(called).To(BeTrue()) Expect(called).To(BeTrue())
}) })
It("enables 0-RTT", func() { It("enables 0-RTT", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
Expect(qtlsConf.Enable0RTT).To(BeFalse()) Expect(qtlsConf.Enable0RTT).To(BeFalse())
Expect(qtlsConf.MaxEarlyData).To(BeZero()) Expect(qtlsConf.MaxEarlyData).To(BeZero())
qtlsConf = tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, true) qtlsConf = tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, true)
Expect(qtlsConf.Enable0RTT).To(BeTrue()) Expect(qtlsConf.Enable0RTT).To(BeTrue())
Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(0xffffffff))) Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(0xffffffff)))
}) })
It("initializes such that the session ticket key remains constant", func() { It("initializes such that the session ticket key remains constant", func() {
tlsConf := &tls.Config{} tlsConf := &tls.Config{}
qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
Expect(qtlsConf1.SessionTicketKey).ToNot(BeZero()) // should now contain a random value Expect(qtlsConf1.SessionTicketKey).ToNot(BeZero()) // should now contain a random value
Expect(qtlsConf1.SessionTicketKey).To(Equal(qtlsConf2.SessionTicketKey)) Expect(qtlsConf1.SessionTicketKey).To(Equal(qtlsConf2.SessionTicketKey))
}) })
Context("GetConfigForClient callback", func() { Context("GetConfigForClient callback", func() {
It("doesn't set it if absent", func() { It("doesn't set it if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
Expect(qtlsConf.GetConfigForClient).To(BeNil()) Expect(qtlsConf.GetConfigForClient).To(BeNil())
}) })
@ -97,7 +98,7 @@ var _ = Describe("qtls.Config", func() {
}, },
} }
extHandler := &mockExtensionHandler{} extHandler := &mockExtensionHandler{}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler, utils.NewRTTStats(), nil, nil, nil, nil, false)
Expect(qtlsConf.GetConfigForClient).ToNot(BeNil()) Expect(qtlsConf.GetConfigForClient).ToNot(BeNil())
confForClient, err := qtlsConf.GetConfigForClient(nil) confForClient, err := qtlsConf.GetConfigForClient(nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -117,7 +118,7 @@ var _ = Describe("qtls.Config", func() {
return nil, testErr return nil, testErr
}, },
} }
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
_, err := qtlsConf.GetConfigForClient(nil) _, err := qtlsConf.GetConfigForClient(nil)
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
@ -128,7 +129,7 @@ var _ = Describe("qtls.Config", func() {
return nil, nil return nil, nil
}, },
} }
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil()) Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
}) })
}) })
@ -140,7 +141,7 @@ var _ = Describe("qtls.Config", func() {
return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil
}, },
} }
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsCert, err := qtlsConf.GetCertificate(nil) qtlsCert, err := qtlsConf.GetCertificate(nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(qtlsCert).ToNot(BeNil()) Expect(qtlsCert).ToNot(BeNil())
@ -148,7 +149,7 @@ var _ = Describe("qtls.Config", func() {
}) })
It("doesn't set it if absent", func() { It("doesn't set it if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
Expect(qtlsConf.GetCertificate).To(BeNil()) Expect(qtlsConf.GetCertificate).To(BeNil())
}) })
@ -158,7 +159,7 @@ var _ = Describe("qtls.Config", func() {
return nil, errors.New("test") return nil, errors.New("test")
}, },
} }
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
_, err := qtlsConf.GetCertificate(nil) _, err := qtlsConf.GetCertificate(nil)
Expect(err).To(MatchError("test")) Expect(err).To(MatchError("test"))
}) })
@ -169,21 +170,21 @@ var _ = Describe("qtls.Config", func() {
return nil, nil return nil, nil
}, },
} }
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
Expect(qtlsConf.GetCertificate(nil)).To(BeNil()) Expect(qtlsConf.GetCertificate(nil)).To(BeNil())
}) })
}) })
Context("ClientSessionCache", func() { Context("ClientSessionCache", func() {
It("doesn't set if absent", func() { It("doesn't set if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
Expect(qtlsConf.ClientSessionCache).To(BeNil()) Expect(qtlsConf.ClientSessionCache).To(BeNil())
}) })
It("puts a nil session state", func() { It("puts a nil session state", func() {
csc := NewMockClientSessionCache(mockCtrl) csc := NewMockClientSessionCache(mockCtrl)
tlsConf := &tls.Config{ClientSessionCache: csc} tlsConf := &tls.Config{ClientSessionCache: csc}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, congestion.NewRTTStats(), nil, nil, nil, nil, false) qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
// put something // put something
csc.EXPECT().Put("foobar", nil) csc.EXPECT().Put("foobar", nil)
qtlsConf.ClientSessionCache.Put("foobar", nil) qtlsConf.ClientSessionCache.Put("foobar", nil)

View file

@ -9,7 +9,6 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/logging" "github.com/lucas-clemente/quic-go/logging"
@ -72,7 +71,7 @@ type updatableAEAD struct {
headerDecrypter headerProtector headerDecrypter headerProtector
headerEncrypter headerProtector headerEncrypter headerProtector
rttStats *congestion.RTTStats rttStats *utils.RTTStats
tracer logging.ConnectionTracer tracer logging.ConnectionTracer
logger utils.Logger logger utils.Logger
@ -84,7 +83,7 @@ type updatableAEAD struct {
var _ ShortHeaderOpener = &updatableAEAD{} var _ ShortHeaderOpener = &updatableAEAD{}
var _ ShortHeaderSealer = &updatableAEAD{} var _ ShortHeaderSealer = &updatableAEAD{}
func newUpdatableAEAD(rttStats *congestion.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger) *updatableAEAD { func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger) *updatableAEAD {
return &updatableAEAD{ return &updatableAEAD{
firstPacketNumber: protocol.InvalidPacketNumber, firstPacketNumber: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber,

View file

@ -6,7 +6,6 @@ import (
"os" "os"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls" "github.com/marten-seemann/qtls"
@ -18,7 +17,7 @@ import (
var _ = Describe("Updatable AEAD", func() { var _ = Describe("Updatable AEAD", func() {
It("ChaCha test vector from the draft", func() { It("ChaCha test vector from the draft", func() {
secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b") secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b")
aead := newUpdatableAEAD(&congestion.RTTStats{}, nil, nil) aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil)
chacha := cipherSuites[2] chacha := cipherSuites[2]
Expect(chacha.ID).To(Equal(qtls.TLS_CHACHA20_POLY1305_SHA256)) Expect(chacha.ID).To(Equal(qtls.TLS_CHACHA20_POLY1305_SHA256))
aead.SetWriteKey(chacha, secret) aead.SetWriteKey(chacha, secret)
@ -37,7 +36,7 @@ var _ = Describe("Updatable AEAD", func() {
cs := cipherSuites[i] cs := cipherSuites[i]
Context(fmt.Sprintf("using %s", qtls.CipherSuiteName(cs.ID)), func() { Context(fmt.Sprintf("using %s", qtls.CipherSuiteName(cs.ID)), func() {
getPeers := func(rttStats *congestion.RTTStats) (client, server *updatableAEAD) { getPeers := func(rttStats *utils.RTTStats) (client, server *updatableAEAD) {
trafficSecret1 := make([]byte, 16) trafficSecret1 := make([]byte, 16)
trafficSecret2 := make([]byte, 16) trafficSecret2 := make([]byte, 16)
rand.Read(trafficSecret1) rand.Read(trafficSecret1)
@ -54,7 +53,7 @@ var _ = Describe("Updatable AEAD", func() {
Context("header protection", func() { Context("header protection", func() {
It("encrypts and decrypts the header", func() { It("encrypts and decrypts the header", func() {
server, client := getPeers(&congestion.RTTStats{}) server, client := getPeers(&utils.RTTStats{})
var lastFiveBitsDifferent int var lastFiveBitsDifferent int
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
sample := make([]byte, 16) sample := make([]byte, 16)
@ -77,10 +76,10 @@ var _ = Describe("Updatable AEAD", func() {
Context("message encryption", func() { Context("message encryption", func() {
var msg, ad []byte var msg, ad []byte
var server, client *updatableAEAD var server, client *updatableAEAD
var rttStats *congestion.RTTStats var rttStats *utils.RTTStats
BeforeEach(func() { BeforeEach(func() {
rttStats = &congestion.RTTStats{} rttStats = &utils.RTTStats{}
server, client = getPeers(rttStats) server, client = getPeers(rttStats)
msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
ad = []byte("Donec in velit neque.") ad = []byte("Donec in velit neque.")

View file

@ -5,15 +5,16 @@
package mocks package mocks
import ( import (
net "net" "net"
reflect "reflect" "reflect"
time "time" "time"
gomock "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/utils"
congestion "github.com/lucas-clemente/quic-go/internal/congestion"
protocol "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/golang/mock/gomock"
wire "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/protocol"
logging "github.com/lucas-clemente/quic-go/logging" "github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/logging"
) )
// MockConnectionTracer is a mock of ConnectionTracer interface // MockConnectionTracer is a mock of ConnectionTracer interface
@ -256,7 +257,7 @@ func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interfa
} }
// UpdatedMetrics mocks base method // UpdatedMetrics mocks base method
func (m *MockConnectionTracer) UpdatedMetrics(arg0 *congestion.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3)
} }

View file

@ -1,10 +1,9 @@
package congestion package utils
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
const ( const (
@ -56,7 +55,7 @@ func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration {
if r.SmoothedRTT() == 0 { if r.SmoothedRTT() == 0 {
return 2 * defaultInitialRTT return 2 * defaultInitialRTT
} }
pto := r.SmoothedRTT() + utils.MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity) pto := r.SmoothedRTT() + MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity)
if includeMaxAckDelay { if includeMaxAckDelay {
pto += r.MaxAckDelay() pto += r.MaxAckDelay()
} }
@ -65,7 +64,7 @@ func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration {
// UpdateRTT updates the RTT based on a new sample. // UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if sendDelta == utils.InfDuration || sendDelta <= 0 { if sendDelta == InfDuration || sendDelta <= 0 {
return return
} }
@ -91,7 +90,7 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
r.smoothedRTT = sample r.smoothedRTT = sample
r.meanDeviation = sample / 2 r.meanDeviation = sample / 2
} else { } else {
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(utils.AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
} }
} }
@ -123,6 +122,6 @@ func (r *RTTStats) OnConnectionMigration() {
// is larger. The mean deviation is increased to the most recent deviation if // is larger. The mean deviation is increased to the most recent deviation if
// it's larger. // it's larger.
func (r *RTTStats) ExpireSmoothedMetrics() { func (r *RTTStats) ExpireSmoothedMetrics() {
r.meanDeviation = utils.MaxDuration(r.meanDeviation, utils.AbsDuration(r.smoothedRTT-r.latestRTT)) r.meanDeviation = MaxDuration(r.meanDeviation, AbsDuration(r.smoothedRTT-r.latestRTT))
r.smoothedRTT = utils.MaxDuration(r.smoothedRTT, r.latestRTT) r.smoothedRTT = MaxDuration(r.smoothedRTT, r.latestRTT)
} }

View file

@ -1,10 +1,10 @@
package congestion package utils
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -113,7 +113,7 @@ var _ = Describe("RTT stats", func() {
badSendDeltas := []time.Duration{ badSendDeltas := []time.Duration{
0, 0,
utils.InfDuration, InfDuration,
-1000 * time.Microsecond, -1000 * time.Microsecond,
} }
// log.StartCapturingLogs(); // log.StartCapturingLogs();

View file

@ -6,7 +6,8 @@ import (
"net" "net"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
@ -49,7 +50,7 @@ type (
ApplicationError = qerr.ErrorCode ApplicationError = qerr.ErrorCode
// The RTTStats contain statistics used by the congestion controller. // The RTTStats contain statistics used by the congestion controller.
RTTStats = congestion.RTTStats RTTStats = utils.RTTStats
) )
const ( const (

View file

@ -5,13 +5,14 @@
package logging package logging
import ( import (
gomock "github.com/golang/mock/gomock" "net"
congestion "github.com/lucas-clemente/quic-go/internal/congestion" "reflect"
protocol "github.com/lucas-clemente/quic-go/internal/protocol" "time"
wire "github.com/lucas-clemente/quic-go/internal/wire"
net "net" "github.com/golang/mock/gomock"
reflect "reflect" "github.com/lucas-clemente/quic-go/internal/protocol"
time "time" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
) )
// MockConnectionTracer is a mock of ConnectionTracer interface // MockConnectionTracer is a mock of ConnectionTracer interface
@ -254,7 +255,7 @@ func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interfa
} }
// UpdatedMetrics mocks base method // UpdatedMetrics mocks base method
func (m *MockConnectionTracer) UpdatedMetrics(arg0 *congestion.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3)
} }

View file

@ -9,7 +9,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/logging" "github.com/lucas-clemente/quic-go/logging"
@ -286,7 +287,7 @@ func (t *connectionTracer) DroppedPacket(pt logging.PacketType, size protocol.By
t.mutex.Unlock() t.mutex.Unlock()
} }
func (t *connectionTracer) UpdatedMetrics(rttStats *congestion.RTTStats, cwnd, bytesInFlight protocol.ByteCount, packetsInFlight int) { func (t *connectionTracer) UpdatedMetrics(rttStats *utils.RTTStats, cwnd, bytesInFlight protocol.ByteCount, packetsInFlight int) {
m := &metrics{ m := &metrics{
MinRTT: rttStats.MinRTT(), MinRTT: rttStats.MinRTT(),
SmoothedRTT: rttStats.SmoothedRTT(), SmoothedRTT: rttStats.SmoothedRTT(),

View file

@ -10,7 +10,8 @@ import (
"os" "os"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/logging" "github.com/lucas-clemente/quic-go/logging"
@ -438,7 +439,7 @@ var _ = Describe("Tracing", func() {
It("records metrics updates", func() { It("records metrics updates", func() {
now := time.Now() now := time.Now()
rttStats := congestion.NewRTTStats() rttStats := utils.NewRTTStats()
rttStats.UpdateRTT(15*time.Millisecond, 0, now) rttStats.UpdateRTT(15*time.Millisecond, 0, now)
rttStats.UpdateRTT(20*time.Millisecond, 0, now) rttStats.UpdateRTT(20*time.Millisecond, 0, now)
rttStats.UpdateRTT(25*time.Millisecond, 0, now) rttStats.UpdateRTT(25*time.Millisecond, 0, now)
@ -472,13 +473,13 @@ var _ = Describe("Tracing", func() {
It("only logs the diff between two metrics updates", func() { It("only logs the diff between two metrics updates", func() {
now := time.Now() now := time.Now()
rttStats := congestion.NewRTTStats() rttStats := utils.NewRTTStats()
rttStats.UpdateRTT(15*time.Millisecond, 0, now) rttStats.UpdateRTT(15*time.Millisecond, 0, now)
rttStats.UpdateRTT(20*time.Millisecond, 0, now) rttStats.UpdateRTT(20*time.Millisecond, 0, now)
rttStats.UpdateRTT(25*time.Millisecond, 0, now) rttStats.UpdateRTT(25*time.Millisecond, 0, now)
Expect(rttStats.MinRTT()).To(Equal(15 * time.Millisecond)) Expect(rttStats.MinRTT()).To(Equal(15 * time.Millisecond))
rttStats2 := congestion.NewRTTStats() rttStats2 := utils.NewRTTStats()
rttStats2.UpdateRTT(15*time.Millisecond, 0, now) rttStats2.UpdateRTT(15*time.Millisecond, 0, now)
rttStats2.UpdateRTT(15*time.Millisecond, 0, now) rttStats2.UpdateRTT(15*time.Millisecond, 0, now)
rttStats2.UpdateRTT(15*time.Millisecond, 0, now) rttStats2.UpdateRTT(15*time.Millisecond, 0, now)

View file

@ -13,7 +13,6 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/ackhandler" "github.com/lucas-clemente/quic-go/internal/ackhandler"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/logutils" "github.com/lucas-clemente/quic-go/internal/logutils"
@ -144,7 +143,7 @@ type session struct {
connIDManager *connIDManager connIDManager *connIDManager
connIDGenerator *connIDGenerator connIDGenerator *connIDGenerator
rttStats *congestion.RTTStats rttStats *utils.RTTStats
cryptoStreamManager *cryptoStreamManager cryptoStreamManager *cryptoStreamManager
sentPacketHandler ackhandler.SentPacketHandler sentPacketHandler ackhandler.SentPacketHandler
@ -472,7 +471,7 @@ func (s *session) preSetup() {
s.sendQueue = newSendQueue(s.conn) s.sendQueue = newSendQueue(s.conn)
s.retransmissionQueue = newRetransmissionQueue(s.version) s.retransmissionQueue = newRetransmissionQueue(s.version)
s.frameParser = wire.NewFrameParser(s.version) s.frameParser = wire.NewFrameParser(s.version)
s.rttStats = &congestion.RTTStats{} s.rttStats = &utils.RTTStats{}
s.connFlowController = flowcontrol.NewConnectionFlowController( s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.InitialMaxData, protocol.InitialMaxData,
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),