start path MTU discovery when the handshake completes

This commit is contained in:
Marten Seemann 2021-01-31 14:27:25 +08:00
parent cb1eab22de
commit ac87292e87
8 changed files with 208 additions and 28 deletions

View file

@ -29,7 +29,10 @@ var _ = Describe("Packetization", func() {
server, err = quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{AcceptToken: func(net.Addr, *quic.Token) bool { return true }}),
getQuicConfig(&quic.Config{
AcceptToken: func(net.Addr, *quic.Token) bool { return true },
DisablePathMTUDiscovery: true,
}),
)
Expect(err).ToNot(HaveOccurred())
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
@ -64,7 +67,7 @@ var _ = Describe("Packetization", func() {
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(nil),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
Expect(err).ToNot(HaveOccurred())

View file

@ -130,7 +130,7 @@ var _ = Describe("Timeout tests", func() {
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(nil),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
@ -159,7 +159,7 @@ var _ = Describe("Timeout tests", func() {
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout}),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}),
)
Expect(err).ToNot(HaveOccurred())
strIn, err := sess.AcceptStream(context.Background())
@ -200,7 +200,7 @@ var _ = Describe("Timeout tests", func() {
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(nil),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
@ -218,7 +218,11 @@ var _ = Describe("Timeout tests", func() {
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, Tracer: newTracer(func() logging.ConnectionTracer { return tr })}),
getQuicConfig(&quic.Config{
MaxIdleTimeout: idleTimeout,
Tracer: newTracer(func() logging.ConnectionTracer { return tr }),
DisablePathMTUDiscovery: true,
}),
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
@ -246,7 +250,7 @@ var _ = Describe("Timeout tests", func() {
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(nil),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
@ -276,7 +280,7 @@ var _ = Describe("Timeout tests", func() {
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout}),
getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}),
)
Expect(err).ToNot(HaveOccurred())
@ -317,7 +321,7 @@ var _ = Describe("Timeout tests", func() {
server, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(nil),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
@ -345,8 +349,9 @@ var _ = Describe("Timeout tests", func() {
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
MaxIdleTimeout: idleTimeout,
KeepAlive: true,
MaxIdleTimeout: idleTimeout,
KeepAlive: true,
DisablePathMTUDiscovery: true,
}),
)
Expect(err).ToNot(HaveOccurred())
@ -417,7 +422,7 @@ var _ = Describe("Timeout tests", func() {
ln, err := quic.Listen(
&faultyConn{PacketConn: conn, MaxPackets: maxPackets},
getTLSConfig(),
getQuicConfig(nil),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
Expect(err).ToNot(HaveOccurred())
@ -434,8 +439,9 @@ var _ = Describe("Timeout tests", func() {
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
HandshakeIdleTimeout: handshakeTimeout,
MaxIdleTimeout: handshakeTimeout,
HandshakeIdleTimeout: handshakeTimeout,
MaxIdleTimeout: handshakeTimeout,
DisablePathMTUDiscovery: true,
}),
)
if err != nil {
@ -467,9 +473,10 @@ var _ = Describe("Timeout tests", func() {
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{
HandshakeIdleTimeout: handshakeTimeout,
MaxIdleTimeout: handshakeTimeout,
KeepAlive: true,
HandshakeIdleTimeout: handshakeTimeout,
MaxIdleTimeout: handshakeTimeout,
KeepAlive: true,
DisablePathMTUDiscovery: true,
}),
)
Expect(err).ToNot(HaveOccurred())
@ -494,7 +501,7 @@ var _ = Describe("Timeout tests", func() {
ln.Addr(),
"localhost",
getTLSClientConfig(),
getQuicConfig(nil),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)
if err != nil {
clientErrChan <- err

View file

@ -0,0 +1,80 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: mtu_discoverer.go
// Package quic is a generated GoMock package.
package quic
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
ackhandler "github.com/lucas-clemente/quic-go/internal/ackhandler"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockMtuDiscoverer is a mock of MtuDiscoverer interface.
type MockMtuDiscoverer struct {
ctrl *gomock.Controller
recorder *MockMtuDiscovererMockRecorder
}
// MockMtuDiscovererMockRecorder is the mock recorder for MockMtuDiscoverer.
type MockMtuDiscovererMockRecorder struct {
mock *MockMtuDiscoverer
}
// NewMockMtuDiscoverer creates a new mock instance.
func NewMockMtuDiscoverer(ctrl *gomock.Controller) *MockMtuDiscoverer {
mock := &MockMtuDiscoverer{ctrl: ctrl}
mock.recorder = &MockMtuDiscovererMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockMtuDiscoverer) EXPECT() *MockMtuDiscovererMockRecorder {
return m.recorder
}
// GetPing mocks base method.
func (m *MockMtuDiscoverer) GetPing() (ackhandler.Frame, protocol.ByteCount) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPing")
ret0, _ := ret[0].(ackhandler.Frame)
ret1, _ := ret[1].(protocol.ByteCount)
return ret0, ret1
}
// GetPing indicates an expected call of GetPing.
func (mr *MockMtuDiscovererMockRecorder) GetPing() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPing", reflect.TypeOf((*MockMtuDiscoverer)(nil).GetPing))
}
// NextProbeTime mocks base method.
func (m *MockMtuDiscoverer) NextProbeTime() time.Time {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NextProbeTime")
ret0, _ := ret[0].(time.Time)
return ret0
}
// NextProbeTime indicates an expected call of NextProbeTime.
func (mr *MockMtuDiscovererMockRecorder) NextProbeTime() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextProbeTime", reflect.TypeOf((*MockMtuDiscoverer)(nil).NextProbeTime))
}
// ShouldSendProbe mocks base method.
func (m *MockMtuDiscoverer) ShouldSendProbe(now time.Time) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ShouldSendProbe", now)
ret0, _ := ret[0].(bool)
return ret0
}
// ShouldSendProbe indicates an expected call of ShouldSendProbe.
func (mr *MockMtuDiscovererMockRecorder) ShouldSendProbe(now interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendProbe", reflect.TypeOf((*MockMtuDiscoverer)(nil).ShouldSendProbe), now)
}

View file

@ -8,6 +8,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
ackhandler "github.com/lucas-clemente/quic-go/internal/ackhandler"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
qerr "github.com/lucas-clemente/quic-go/internal/qerr"
wire "github.com/lucas-clemente/quic-go/internal/wire"
@ -108,6 +109,21 @@ func (mr *MockPackerMockRecorder) PackConnectionClose(arg0 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackConnectionClose", reflect.TypeOf((*MockPacker)(nil).PackConnectionClose), arg0)
}
// PackMTUProbePacket mocks base method.
func (m *MockPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PackMTUProbePacket", ping, size)
ret0, _ := ret[0].(*packedPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PackMTUProbePacket indicates an expected call of PackMTUProbePacket.
func (mr *MockPackerMockRecorder) PackMTUProbePacket(ping, size interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackMTUProbePacket", reflect.TypeOf((*MockPacker)(nil).PackMTUProbePacket), ping, size)
}
// PackPacket mocks base method.
func (m *MockPacker) PackPacket() (*packedPacket, error) {
m.ctrl.T.Helper()
@ -123,6 +139,18 @@ func (mr *MockPackerMockRecorder) PackPacket() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket))
}
// SetMaxPacketSize mocks base method.
func (m *MockPacker) SetMaxPacketSize(arg0 protocol.ByteCount) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetMaxPacketSize", arg0)
}
// SetMaxPacketSize indicates an expected call of SetMaxPacketSize.
func (mr *MockPackerMockRecorder) SetMaxPacketSize(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxPacketSize", reflect.TypeOf((*MockPacker)(nil).SetMaxPacketSize), arg0)
}
// SetToken mocks base method.
func (m *MockPacker) SetToken(arg0 []byte) {
m.ctrl.T.Helper()

View file

@ -15,6 +15,7 @@ package quic
//go:generate sh -c "./mockgen_private.sh quic mock_sealing_manager_test.go github.com/lucas-clemente/quic-go sealingManager"
//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker"
//go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/lucas-clemente/quic-go packer"
//go:generate sh -c "./mockgen_private.sh quic mock_mtu_discoverer_test.go github.com/lucas-clemente/quic-go mtuDiscoverer"
//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner"
//go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession"
//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler"

View file

@ -22,6 +22,9 @@ type packer interface {
MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error)
PackConnectionClose(*qerr.QuicError) (*coalescedPacket, error)
SetMaxPacketSize(protocol.ByteCount)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error)
HandleTransportParameters(*wire.TransportParameters)
SetToken([]byte)
}

View file

@ -155,9 +155,10 @@ type session struct {
tokenStoreKey string // only set for the client
tokenGenerator *handshake.TokenGenerator // only set for the server
unpacker unpacker
frameParser wire.FrameParser
packer packer
unpacker unpacker
frameParser wire.FrameParser
packer packer
mtuDiscoverer mtuDiscoverer // initialized when the handshake completes
oneRTTStream cryptoStream // only set for the server
cryptoStreamHandler cryptoStreamHandler
@ -731,6 +732,11 @@ func (s *session) maybeResetTimer() {
} else {
deadline = s.idleTimeoutStartTime().Add(s.idleTimeout)
}
if !s.config.DisablePathMTUDiscovery {
if probeTime := s.mtuDiscoverer.NextProbeTime(); !probeTime.IsZero() {
deadline = utils.MinTime(deadline, probeTime)
}
}
}
if ackAlarm := s.receivedPacketHandler.GetAlarmTimeout(); !ackAlarm.IsZero() {
@ -761,6 +767,23 @@ func (s *session) handleHandshakeComplete() {
s.connIDManager.SetHandshakeComplete()
s.connIDGenerator.SetHandshakeComplete()
if !s.config.DisablePathMTUDiscovery {
maxPacketSize := s.peerParams.MaxUDPPayloadSize
if maxPacketSize == 0 {
maxPacketSize = protocol.MaxByteCount
}
maxPacketSize = utils.MinByteCount(maxPacketSize, protocol.MaxPacketBufferSize)
s.mtuDiscoverer = newMTUDiscoverer(
s.rttStats,
getMaxPacketSize(s.conn.RemoteAddr()),
maxPacketSize,
func(size protocol.ByteCount) {
s.sentPacketHandler.SetMaxDatagramSize(size)
s.packer.SetMaxPacketSize(size)
},
)
}
if s.perspective == protocol.PerspectiveServer {
s.handshakeConfirmed = true
s.sentPacketHandler.SetHandshakeConfirmed()
@ -1584,7 +1607,7 @@ func (s *session) maybeSendAckOnlyPacket() error {
if packet == nil {
return nil
}
s.sendPackedPacket(packet)
s.sendPackedPacket(packet, time.Now())
return nil
}
@ -1626,7 +1649,7 @@ func (s *session) sendProbePacket(encLevel protocol.EncryptionLevel) error {
if packet == nil || packet.packetContents == nil {
return fmt.Errorf("session BUG: couldn't pack %s probe packet", encLevel)
}
s.sendPackedPacket(packet)
s.sendPackedPacket(packet, time.Now())
return nil
}
@ -1636,8 +1659,8 @@ func (s *session) sendPacket() (bool, error) {
}
s.windowUpdateQueue.QueueAll()
now := time.Now()
if !s.handshakeConfirmed {
now := time.Now()
packet, err := s.packer.PackCoalescedPacket()
if err != nil || packet == nil {
return false, err
@ -1653,16 +1676,23 @@ func (s *session) sendPacket() (bool, error) {
s.sendQueue.Send(packet.buffer)
return true, nil
}
if !s.config.DisablePathMTUDiscovery && s.handshakeComplete && s.mtuDiscoverer.ShouldSendProbe(now) {
packet, err := s.packer.PackMTUProbePacket(s.mtuDiscoverer.GetPing())
if err != nil {
return false, err
}
s.sendPackedPacket(packet, now)
return true, nil
}
packet, err := s.packer.PackPacket()
if err != nil || packet == nil {
return false, err
}
s.sendPackedPacket(packet)
s.sendPackedPacket(packet, now)
return true, nil
}
func (s *session) sendPackedPacket(packet *packedPacket) {
now := time.Now()
func (s *session) sendPackedPacket(packet *packedPacket, now time.Time) {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && packet.IsAckEliciting() {
s.firstAckElicitingPacketAfterIdleSentTime = now
}

View file

@ -102,7 +102,7 @@ var _ = Describe("Session", func() {
destConnID,
srcConnID,
protocol.StatelessResetToken{},
populateServerConfig(&Config{}),
populateServerConfig(&Config{DisablePathMTUDiscovery: true}),
nil, // tls.Config
tokenGenerator,
false,
@ -1692,6 +1692,34 @@ var _ = Describe("Session", func() {
sess.scheduleSending() // no packet will get sent
time.Sleep(50 * time.Millisecond)
})
It("sends a Path MTU probe packet", func() {
mtuDiscoverer := NewMockMtuDiscoverer(mockCtrl)
sess.mtuDiscoverer = mtuDiscoverer
sess.config.DisablePathMTUDiscovery = false
sph.EXPECT().SentPacket(gomock.Any())
sph.EXPECT().HasPacingBudget().Return(true).AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAny)
sph.EXPECT().SendMode().Return(ackhandler.SendNone)
written := make(chan struct{}, 1)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} })
gomock.InOrder(
mtuDiscoverer.EXPECT().NextProbeTime(),
mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true),
mtuDiscoverer.EXPECT().NextProbeTime(),
)
ping := ackhandler.Frame{Frame: &wire.PingFrame{}}
mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234))
packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234)).Return(getPacket(1), nil)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
sess.run()
}()
sess.scheduleSending()
Eventually(written).Should(Receive())
})
})
Context("scheduling sending", func() {