remove the params negotiator

This commit is contained in:
Marten Seemann 2017-10-16 08:55:16 +07:00
parent 925a52f032
commit f3e9bf4332
37 changed files with 1013 additions and 1296 deletions

View file

@ -13,32 +13,31 @@ import (
)
type flowControlManager struct {
connParams handshake.ParamsNegotiator
rttStats *congestion.RTTStats
maxReceiveStreamWindow protocol.ByteCount
streamFlowController map[protocol.StreamID]*flowController
connFlowController *flowController
mutex sync.RWMutex
initialStreamSendWindow protocol.ByteCount
}
var _ FlowControlManager = &flowControlManager{}
var errMapAccess = errors.New("Error accessing the flowController map.")
var errMapAccess = errors.New("Error accessing the flowController map")
// NewFlowControlManager creates a new flow control manager
func NewFlowControlManager(
connParams handshake.ParamsNegotiator,
maxReceiveStreamWindow protocol.ByteCount,
maxReceiveConnectionWindow protocol.ByteCount,
rttStats *congestion.RTTStats,
) FlowControlManager {
return &flowControlManager{
connParams: connParams,
rttStats: rttStats,
maxReceiveStreamWindow: maxReceiveStreamWindow,
streamFlowController: make(map[protocol.StreamID]*flowController),
connFlowController: newFlowController(0, false, connParams, protocol.ReceiveConnectionFlowControlWindow, maxReceiveConnectionWindow, rttStats),
connFlowController: newFlowController(0, false, protocol.ReceiveConnectionFlowControlWindow, maxReceiveConnectionWindow, 0, rttStats),
}
}
@ -51,7 +50,7 @@ func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesTo
if _, ok := f.streamFlowController[streamID]; ok {
return
}
f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connParams, protocol.ReceiveStreamFlowControlWindow, f.maxReceiveStreamWindow, f.rttStats)
f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, protocol.ReceiveStreamFlowControlWindow, f.maxReceiveStreamWindow, f.initialStreamSendWindow, f.rttStats)
}
// RemoveStream removes a closed stream from flow control
@ -61,6 +60,17 @@ func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) {
f.mutex.Unlock()
}
func (f *flowControlManager) UpdateTransportParameters(params *handshake.TransportParameters) {
f.mutex.Lock()
defer f.mutex.Unlock()
f.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow)
f.initialStreamSendWindow = params.StreamFlowControlWindow
for _, fc := range f.streamFlowController {
fc.UpdateSendWindow(params.StreamFlowControlWindow)
}
}
// ResetStream should be called when receiving a RstStreamFrame
// it updates the byte offset to the value in the RstStreamFrame
// streamID must not be 0 here
@ -233,7 +243,6 @@ func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset pro
return false, err
}
}
return fc.UpdateSendWindow(offset), nil
}

View file

@ -3,8 +3,9 @@ package flowcontrol
import (
"time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
@ -15,13 +16,18 @@ var _ = Describe("Flow Control Manager", func() {
var fcm *flowControlManager
BeforeEach(func() {
mockPn := mocks.NewMockParamsNegotiator(mockCtrl)
fcm = NewFlowControlManager(mockPn, protocol.MaxByteCount, protocol.MaxByteCount, &congestion.RTTStats{}).(*flowControlManager)
fcm = NewFlowControlManager(
0x2000, // maxReceiveStreamWindow
0x4000, // maxReceiveConnectionWindow
&congestion.RTTStats{},
).(*flowControlManager)
})
It("creates a connection level flow controller", func() {
Expect(fcm.streamFlowController).ToNot(HaveKey(protocol.StreamID(0)))
Expect(fcm.streamFlowController).To(BeEmpty())
Expect(fcm.connFlowController.ContributesToConnection()).To(BeFalse())
Expect(fcm.connFlowController.sendWindow).To(BeZero())
Expect(fcm.connFlowController.maxReceiveWindowIncrement).To(Equal(protocol.ByteCount(0x4000)))
})
Context("creating new streams", func() {
@ -31,6 +37,19 @@ var _ = Describe("Flow Control Manager", func() {
fc := fcm.streamFlowController[5]
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
Expect(fc.ContributesToConnection()).To(BeFalse())
// the transport parameters have not yet been received. Start with a window of size 0
Expect(fc.sendWindow).To(BeZero())
Expect(fc.maxReceiveWindowIncrement).To(Equal(protocol.ByteCount(0x2000)))
})
It("creates a new stream after it has received transport parameters", func() {
fcm.UpdateTransportParameters(&handshake.TransportParameters{
StreamFlowControlWindow: 0x3000,
})
fcm.NewStream(5, false)
Expect(fcm.streamFlowController).To(HaveKey(protocol.StreamID(5)))
fc := fcm.streamFlowController[5]
Expect(fc.sendWindow).To(Equal(protocol.ByteCount(0x3000)))
})
It("doesn't create a new flow controller if called for an existing stream", func() {
@ -51,6 +70,16 @@ var _ = Describe("Flow Control Manager", func() {
Expect(fcm.streamFlowController).ToNot(HaveKey(protocol.StreamID(5)))
})
It("updates the send windows for existing streams when receiveing the transport parameters", func() {
fcm.NewStream(5, false)
fcm.UpdateTransportParameters(&handshake.TransportParameters{
StreamFlowControlWindow: 0x3000,
ConnectionFlowControlWindow: 0x6000,
})
Expect(fcm.connFlowController.sendWindow).To(Equal(protocol.ByteCount(0x6000)))
Expect(fcm.streamFlowController[5].sendWindow).To(Equal(protocol.ByteCount(0x3000)))
})
Context("receiving data", func() {
BeforeEach(func() {
fcm.NewStream(1, false)

View file

@ -5,7 +5,6 @@ import (
"time"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
@ -14,7 +13,6 @@ type flowController struct {
streamID protocol.StreamID
contributesToConnection bool // does the stream contribute to connection level flow control
connParams handshake.ParamsNegotiator
rttStats *congestion.RTTStats
bytesSent protocol.ByteCount
@ -36,19 +34,19 @@ var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset")
func newFlowController(
streamID protocol.StreamID,
contributesToConnection bool,
connParams handshake.ParamsNegotiator,
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount,
rttStats *congestion.RTTStats,
) *flowController {
return &flowController{
streamID: streamID,
contributesToConnection: contributesToConnection,
connParams: connParams,
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowIncrement: receiveWindow,
maxReceiveWindowIncrement: maxReceiveWindow,
sendWindow: initialSendWindow,
}
}
@ -56,16 +54,6 @@ func (c *flowController) ContributesToConnection() bool {
return c.contributesToConnection
}
func (c *flowController) getSendWindow() protocol.ByteCount {
if c.sendWindow == 0 {
if c.streamID == 0 {
return c.connParams.GetSendConnectionFlowControlWindow()
}
return c.connParams.GetSendStreamFlowControlWindow()
}
return c.sendWindow
}
func (c *flowController) AddBytesSent(n protocol.ByteCount) {
c.bytesSent += n
}
@ -81,16 +69,11 @@ func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool {
}
func (c *flowController) SendWindowSize() protocol.ByteCount {
sendWindow := c.getSendWindow()
if c.bytesSent > sendWindow { // should never happen, but make sure we don't do an underflow here
// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
if c.bytesSent > c.sendWindow {
return 0
}
return sendWindow - c.bytesSent
}
func (c *flowController) SendWindowOffset() protocol.ByteCount {
return c.getSendWindow()
return c.sendWindow - c.bytesSent
}
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher

View file

@ -4,7 +4,6 @@ import (
"time"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@ -19,61 +18,28 @@ var _ = Describe("Flow controller", func() {
})
Context("Constructor", func() {
var rttStats *congestion.RTTStats
var mockPn *mocks.MockParamsNegotiator
rttStats := &congestion.RTTStats{}
receiveStreamWindow := protocol.ByteCount(2000)
receiveConnectionWindow := protocol.ByteCount(4000)
maxReceiveStreamWindow := protocol.ByteCount(8000)
maxReceiveConnectionWindow := protocol.ByteCount(9000)
BeforeEach(func() {
mockPn = mocks.NewMockParamsNegotiator(mockCtrl)
mockPn.EXPECT().GetSendStreamFlowControlWindow().AnyTimes().Return(protocol.ByteCount(1000))
mockPn.EXPECT().GetSendConnectionFlowControlWindow().AnyTimes().Return(protocol.ByteCount(3000))
rttStats = &congestion.RTTStats{}
})
It("reads the stream send and receive windows when acting as stream-level flow controller", func() {
fc := newFlowController(5, true, mockPn, receiveStreamWindow, maxReceiveStreamWindow, rttStats)
It("sets the send and receive windows", func() {
receiveWindow := protocol.ByteCount(2000)
maxReceiveWindow := protocol.ByteCount(3000)
sendWindow := protocol.ByteCount(4000)
fc := newFlowController(5, true, receiveWindow, maxReceiveWindow, sendWindow, rttStats)
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
Expect(fc.receiveWindow).To(Equal(receiveStreamWindow))
Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveStreamWindow))
})
It("reads the stream send and receive windows when acting as connection-level flow controller", func() {
fc := newFlowController(0, false, mockPn, receiveConnectionWindow, maxReceiveConnectionWindow, rttStats)
Expect(fc.streamID).To(Equal(protocol.StreamID(0)))
Expect(fc.receiveWindow).To(Equal(receiveConnectionWindow))
Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveConnectionWindow))
})
It("does not set the stream flow control windows for sending", func() {
fc := newFlowController(5, true, mockPn, protocol.MaxByteCount, protocol.MaxByteCount, rttStats)
Expect(fc.sendWindow).To(BeZero())
})
It("does not set the connection flow control windows for sending", func() {
fc := newFlowController(0, false, mockPn, protocol.MaxByteCount, protocol.MaxByteCount, rttStats)
Expect(fc.sendWindow).To(BeZero())
Expect(fc.receiveWindow).To(Equal(receiveWindow))
Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveWindow))
Expect(fc.sendWindow).To(Equal(sendWindow))
})
It("says if it contributes to connection-level flow control", func() {
fc := newFlowController(1, false, mockPn, protocol.MaxByteCount, protocol.MaxByteCount, rttStats)
fc := newFlowController(1, false, protocol.MaxByteCount, protocol.MaxByteCount, protocol.MaxByteCount, rttStats)
Expect(fc.ContributesToConnection()).To(BeFalse())
fc = newFlowController(5, true, mockPn, protocol.MaxByteCount, protocol.MaxByteCount, rttStats)
fc = newFlowController(5, true, protocol.MaxByteCount, protocol.MaxByteCount, protocol.MaxByteCount, rttStats)
Expect(fc.ContributesToConnection()).To(BeTrue())
})
})
Context("send flow control", func() {
var mockPn *mocks.MockParamsNegotiator
BeforeEach(func() {
mockPn = mocks.NewMockParamsNegotiator(mockCtrl)
controller.connParams = mockPn
})
It("adds bytes sent", func() {
controller.bytesSent = 5
controller.AddBytesSent(6)
@ -89,14 +55,14 @@ var _ = Describe("Flow controller", func() {
It("gets the offset of the flow control window", func() {
controller.bytesSent = 5
controller.sendWindow = 12
Expect(controller.SendWindowOffset()).To(Equal(protocol.ByteCount(12)))
Expect(controller.sendWindow).To(Equal(protocol.ByteCount(12)))
})
It("updates the size of the flow control window", func() {
controller.bytesSent = 5
updateSuccessful := controller.UpdateSendWindow(15)
Expect(updateSuccessful).To(BeTrue())
Expect(controller.SendWindowOffset()).To(Equal(protocol.ByteCount(15)))
Expect(controller.sendWindow).To(Equal(protocol.ByteCount(15)))
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(15 - 5)))
})
@ -108,36 +74,6 @@ var _ = Describe("Flow controller", func() {
Expect(updateSuccessful).To(BeFalse())
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(20)))
})
It("asks the ConnectionParametersManager for the stream flow control window size", func() {
controller.streamID = 5
mockPn.EXPECT().GetSendStreamFlowControlWindow().Return(protocol.ByteCount(1000))
Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(1000)))
// make sure the value is not cached
mockPn.EXPECT().GetSendStreamFlowControlWindow().Return(protocol.ByteCount(2000))
Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(2000)))
})
It("stops asking the ConnectionParametersManager for the flow control stream window size once a window update has arrived", func() {
controller.streamID = 5
Expect(controller.UpdateSendWindow(8000))
Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(8000)))
})
It("asks the ConnectionParametersManager for the connection flow control window size", func() {
controller.streamID = 0
mockPn.EXPECT().GetSendConnectionFlowControlWindow().Return(protocol.ByteCount(3000))
Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(3000)))
// make sure the value is not cached
mockPn.EXPECT().GetSendConnectionFlowControlWindow().Return(protocol.ByteCount(5000))
Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(5000)))
})
It("stops asking the ConnectionParametersManager for the connection flow control window size once a window update has arrived", func() {
controller.streamID = 0
Expect(controller.UpdateSendWindow(7000))
Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(7000)))
})
})
Context("receive flow control", func() {

View file

@ -1,6 +1,7 @@
package flowcontrol
import "github.com/lucas-clemente/quic-go/internal/protocol"
import "github.com/lucas-clemente/quic-go/internal/handshake"
// WindowUpdate provides the data for WindowUpdateFrames.
type WindowUpdate struct {
@ -12,6 +13,7 @@ type WindowUpdate struct {
type FlowControlManager interface {
NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool)
RemoveStream(streamID protocol.StreamID)
UpdateTransportParameters(*handshake.TransportParameters)
// methods needed for receiving data
ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error

View file

@ -49,10 +49,11 @@ type cryptoSetupClient struct {
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
paramsChan chan<- TransportParameters
aeadChanged chan<- protocol.EncryptionLevel
requestConnIDOmission bool
params *paramsNegotiatorGQUIC
params *TransportParameters
}
var _ CryptoSetup = &cryptoSetupClient{}
@ -70,24 +71,24 @@ func NewCryptoSetupClient(
version protocol.VersionNumber,
tlsConfig *tls.Config,
params *TransportParameters,
paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel,
negotiatedVersions []protocol.VersionNumber,
) (CryptoSetup, ParamsNegotiator, error) {
pn := newParamsNegotiatorGQUIC(protocol.PerspectiveClient, version, params)
) (CryptoSetup, error) {
return &cryptoSetupClient{
hostname: hostname,
connID: connID,
version: version,
certManager: crypto.NewCertManager(tlsConfig),
params: pn,
requestConnIDOmission: params.RequestConnectionIDOmission,
params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX,
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
paramsChan: paramsChan,
aeadChanged: aeadChanged,
negotiatedVersions: negotiatedVersions,
divNonceChan: make(chan []byte),
}, pn, nil
}, nil
}
func (h *cryptoSetupClient) HandleCryptoStream(stream io.ReadWriter) error {
@ -141,15 +142,21 @@ func (h *cryptoSetupClient) HandleCryptoStream(stream io.ReadWriter) error {
utils.Debugf("Got %s", message)
switch message.Tag {
case TagREJ:
err = h.handleREJMessage(message.Data)
case TagSHLO:
err = h.handleSHLOMessage(message.Data)
default:
return qerr.InvalidCryptoMessageType
if err := h.handleREJMessage(message.Data); err != nil {
return err
}
case TagSHLO:
params, err := h.handleSHLOMessage(message.Data)
if err != nil {
return err
}
// blocks until the session has received the parameters
h.paramsChan <- *params
h.aeadChanged <- protocol.EncryptionForwardSecure
close(h.aeadChanged)
default:
return qerr.InvalidCryptoMessageType
}
}
}
@ -215,12 +222,12 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
return nil
}
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if !h.receivedSecurePacket {
return qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
}
if sno, ok := cryptoData[TagSNO]; ok {
@ -229,22 +236,22 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
serverPubs, ok := cryptoData[TagPUBS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
}
verTag, ok := cryptoData[TagVER]
if !ok {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
}
if !h.validateVersionList(verTag) {
return qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
}
nonce := append(h.nonc, h.sno...)
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
if err != nil {
return err
return nil, err
}
leafCert := h.certManager.GetLeafCert()
@ -261,18 +268,14 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
protocol.PerspectiveClient,
)
if err != nil {
return err
return nil, err
}
err = h.params.SetFromMap(cryptoData)
params, err := readHelloMap(cryptoData)
if err != nil {
return qerr.InvalidCryptoMessageParameter
return nil, qerr.InvalidCryptoMessageParameter
}
h.aeadChanged <- protocol.EncryptionForwardSecure
close(h.aeadChanged)
return nil
return params, nil
}
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
@ -405,10 +408,7 @@ func (h *cryptoSetupClient) sendCHLO() error {
}
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
tags, err := h.params.GetHelloMap()
if err != nil {
return nil, err
}
tags := h.params.getHelloMap()
tags[TagSNI] = []byte(h.hostname)
tags[TagPDMD] = []byte("X509")
@ -421,9 +421,6 @@ func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version))
tags[TagVER] = versionTag
if h.requestConnIDOmission {
tags[TagTCID] = []byte{0, 0, 0, 0}
}
if len(h.stk) > 0 {
tags[TagSTK] = h.stk
}

View file

@ -79,6 +79,7 @@ var _ = Describe("Client Crypto Setup", func() {
keyDerivationCalledWith *keyDerivationValues
shloMap map[Tag][]byte
aeadChanged chan protocol.EncryptionLevel
paramsChan chan TransportParameters
)
BeforeEach(func() {
@ -108,13 +109,16 @@ var _ = Describe("Client Crypto Setup", func() {
stream = newMockStream()
certManager = &mockCertManager{}
version := protocol.Version37
// use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1)
aeadChanged = make(chan protocol.EncryptionLevel, 2)
csInt, _, err := NewCryptoSetupClient(
csInt, err := NewCryptoSetupClient(
"hostname",
0,
version,
nil,
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
paramsChan,
aeadChanged,
nil,
)
@ -222,7 +226,7 @@ var _ = Describe("Client Crypto Setup", func() {
It("returns the right error when detecting a downgrade attack", func() {
cs.negotiatedVersions = []protocol.VersionNumber{protocol.VersionWhatever}
cs.receivedSecurePacket = true
err := cs.handleSHLOMessage(map[Tag][]byte{
_, err := cs.handleSHLOMessage(map[Tag][]byte{
TagPUBS: []byte{0},
TagVER: []byte{0, 1},
})
@ -385,7 +389,7 @@ var _ = Describe("Client Crypto Setup", func() {
It("rejects unencrypted SHLOs", func() {
cs.receivedSecurePacket = false
err := cs.handleSHLOMessage(shloMap)
_, err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")))
Expect(aeadChanged).ToNot(Receive())
Expect(aeadChanged).ToNot(BeClosed())
@ -393,14 +397,14 @@ var _ = Describe("Client Crypto Setup", func() {
It("rejects SHLOs without a PUBS", func() {
delete(shloMap, TagPUBS)
err := cs.handleSHLOMessage(shloMap)
_, err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")))
Expect(aeadChanged).ToNot(BeClosed())
})
It("rejects SHLOs without a version list", func() {
delete(shloMap, TagVER)
err := cs.handleSHLOMessage(shloMap)
_, err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")))
Expect(aeadChanged).ToNot(BeClosed())
})
@ -412,36 +416,58 @@ var _ = Describe("Client Crypto Setup", func() {
b := &bytes.Buffer{}
utils.LittleEndian.WriteUint32(b, protocol.VersionNumberToTag(ver))
shloMap[TagVER] = b.Bytes()
err := cs.handleSHLOMessage(shloMap)
_, err := cs.handleSHLOMessage(shloMap)
Expect(err).ToNot(HaveOccurred())
})
It("reads the server nonce, if set", func() {
shloMap[TagSNO] = []byte("server nonce")
err := cs.handleSHLOMessage(shloMap)
_, err := cs.handleSHLOMessage(shloMap)
Expect(err).ToNot(HaveOccurred())
Expect(cs.sno).To(Equal(shloMap[TagSNO]))
})
It("creates a forwardSecureAEAD", func() {
shloMap[TagSNO] = []byte("server nonce")
err := cs.handleSHLOMessage(shloMap)
_, err := cs.handleSHLOMessage(shloMap)
Expect(err).ToNot(HaveOccurred())
Expect(cs.forwardSecureAEAD).ToNot(BeNil())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
Expect(aeadChanged).To(BeClosed())
})
It("reads the connection paramaters", func() {
shloMap[TagICSL] = []byte{13, 0, 0, 0} // 13 seconds
err := cs.handleSHLOMessage(shloMap)
params, err := cs.handleSHLOMessage(shloMap)
Expect(err).ToNot(HaveOccurred())
Expect(cs.params.GetRemoteIdleTimeout()).To(Equal(13 * time.Second))
Expect(params.IdleTimeout).To(Equal(13 * time.Second))
})
It("closes the aeadChanged when receiving an SHLO", func() {
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
go func() {
defer GinkgoRecover()
err := cs.HandleCryptoStream(stream)
Expect(err).ToNot(HaveOccurred())
}()
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure)))
Eventually(aeadChanged).Should(BeClosed())
})
It("passes the transport parameters on the channel", func() {
shloMap[TagSFCW] = []byte{0x0d, 0x00, 0xdf, 0xba}
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
go func() {
defer GinkgoRecover()
err := cs.HandleCryptoStream(stream)
Expect(err).ToNot(HaveOccurred())
}()
var params TransportParameters
Eventually(paramsChan).Should(Receive(&params))
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xbadf000d)))
})
It("errors if it can't read a connection parameter", func() {
shloMap[TagICSL] = []byte{3, 0, 0} // 1 byte too short
err := cs.handleSHLOMessage(shloMap)
_, err := cs.handleSHLOMessage(shloMap)
Expect(err).To(MatchError(qerr.InvalidCryptoMessageParameter))
})
})
@ -488,15 +514,14 @@ var _ = Describe("Client Crypto Setup", func() {
})
It("requests to omit the connection ID", func() {
cs.requestConnIDOmission = true
cs.params.OmitConnectionID = true
tags, err := cs.getTags()
Expect(err).ToNot(HaveOccurred())
Expect(tags).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0}))
})
It("adds the tags returned from the connectionParametersManager to the CHLO", func() {
pnTags, err := cs.params.GetHelloMap()
Expect(err).ToNot(HaveOccurred())
pnTags := cs.params.getHelloMap()
Expect(pnTags).ToNot(BeEmpty())
tags, err := cs.getTags()
Expect(err).ToNot(HaveOccurred())
@ -588,7 +613,7 @@ var _ = Describe("Client Crypto Setup", func() {
doSHLO := func() {
cs.receivedSecurePacket = true
err := cs.handleSHLOMessage(shloMap)
_, err := cs.handleSHLOMessage(shloMap)
Expect(err).ToNot(HaveOccurred())
}

View file

@ -40,6 +40,9 @@ type cryptoSetupServer struct {
receivedForwardSecurePacket bool
receivedSecurePacket bool
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
receivedParams bool
paramsChan chan<- TransportParameters
aeadChanged chan<- protocol.EncryptionLevel
keyDerivation QuicCryptoKeyDerivationFunction
@ -47,7 +50,7 @@ type cryptoSetupServer struct {
cryptoStream io.ReadWriter
params *paramsNegotiatorGQUIC
params *TransportParameters
mutex sync.RWMutex
}
@ -72,14 +75,14 @@ func NewCryptoSetup(
params *TransportParameters,
supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *Cookie) bool,
paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel,
) (CryptoSetup, ParamsNegotiator, error) {
) (CryptoSetup, error) {
stkGenerator, err := NewCookieGenerator()
if err != nil {
return nil, nil, err
return nil, err
}
pn := newParamsNegotiatorGQUIC(protocol.PerspectiveServer, version, params)
return &cryptoSetupServer{
connID: connID,
remoteAddr: remoteAddr,
@ -90,11 +93,12 @@ func NewCryptoSetup(
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX,
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
params: pn,
params: params,
acceptSTKCallback: acceptSTK,
sentSHLO: make(chan struct{}),
paramsChan: paramsChan,
aeadChanged: aeadChanged,
}, pn, nil
}, nil
}
// HandleCryptoStream reads and writes messages on the crypto stream
@ -163,6 +167,16 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
return false, err
}
params, err := readHelloMap(cryptoData)
if err != nil {
return false, err
}
// blocks until the session has received the parameters
if !h.receivedParams {
h.receivedParams = true
h.paramsChan <- *params
}
if !h.isInchoateCHLO(cryptoData, certUncompressed) {
// We have a CHLO with a proper server config ID, do a 0-RTT handshake
reply, err = h.handleCHLO(sni, chloData, cryptoData)
@ -418,14 +432,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
return nil, err
}
if err := h.params.SetFromMap(cryptoData); err != nil {
return nil, err
}
replyMap, err := h.params.GetHelloMap()
if err != nil {
return nil, err
}
replyMap := h.params.getHelloMap()
// add crypto parameters
verTag := &bytes.Buffer{}
for _, v := range h.supportedVersions {

View file

@ -5,6 +5,7 @@ import (
"encoding/binary"
"errors"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -167,6 +168,7 @@ var _ = Describe("Server Crypto Setup", func() {
scfg *ServerConfig
cs *cryptoSetupServer
stream *mockStream
paramsChan chan TransportParameters
aeadChanged chan protocol.EncryptionLevel
nonce32 []byte
versionTag []byte
@ -183,6 +185,8 @@ var _ = Describe("Server Crypto Setup", func() {
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
expectedInitialNonceLen = 32
expectedFSNonceLen = 64
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1)
aeadChanged = make(chan protocol.EncryptionLevel, 2)
stream = newMockStream()
kex = &mockKEX{}
@ -197,7 +201,7 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(err).NotTo(HaveOccurred())
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
supportedVersions = []protocol.VersionNumber{version, 98, 99}
csInt, _, err := NewCryptoSetup(
csInt, err := NewCryptoSetup(
protocol.ConnectionID(42),
remoteAddr,
version,
@ -205,6 +209,7 @@ var _ = Describe("Server Crypto Setup", func() {
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
supportedVersions,
nil,
paramsChan,
aeadChanged,
)
Expect(err).NotTo(HaveOccurred())
@ -285,6 +290,16 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(err).To(MatchError(ErrNSTPExperiment))
})
It("reads the transport parameters sent by the client", func() {
sourceAddrValid = true
fullCHLO[TagICSL] = []byte{0x37, 0x13, 0, 0}
_, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), fullCHLO)
Expect(err).ToNot(HaveOccurred())
var params TransportParameters
Expect(paramsChan).To(Receive(&params))
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
})
It("generates REJ messages", func() {
sourceAddrValid = false
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil)

View file

@ -38,52 +38,52 @@ var newMintController = func(conn *mint.Conn) crypto.MintController {
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
func NewCryptoSetupTLSServer(
tlsConfig *tls.Config,
transportParams *TransportParameters,
params *TransportParameters,
paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber,
) (CryptoSetup, ParamsNegotiator, error) {
) (CryptoSetup, error) {
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer)
if err != nil {
return nil, nil, err
return nil, err
}
params := newParamsNegotiator(protocol.PerspectiveServer, version, transportParams)
return &cryptoSetupTLS{
perspective: protocol.PerspectiveServer,
mintConf: mintConf,
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged,
extensionHandler: newExtensionHandlerServer(params, supportedVersions, version),
}, params, nil
extensionHandler: newExtensionHandlerServer(params, paramsChan, supportedVersions, version),
}, nil
}
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
func NewCryptoSetupTLSClient(
hostname string, // only needed for the client
tlsConfig *tls.Config,
transportParams *TransportParameters,
params *TransportParameters,
paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel,
initialVersion protocol.VersionNumber,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber,
) (CryptoSetup, ParamsNegotiator, error) {
) (CryptoSetup, error) {
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient)
if err != nil {
return nil, nil, err
return nil, err
}
mintConf.ServerName = hostname
params := newParamsNegotiator(protocol.PerspectiveClient, version, transportParams)
return &cryptoSetupTLS{
perspective: protocol.PerspectiveClient,
mintConf: mintConf,
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged,
extensionHandler: newExtensionHandlerClient(params, initialVersion, supportedVersions, version),
}, params, nil
extensionHandler: newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version),
}, nil
}
func (h *cryptoSetupTLS) HandleCryptoStream(cryptoStream io.ReadWriter) error {

View file

@ -33,16 +33,19 @@ func mockKeyDerivation(crypto.MintController, protocol.Perspective) (crypto.AEAD
var _ = Describe("TLS Crypto Setup", func() {
var (
cs *cryptoSetupTLS
paramsChan chan TransportParameters
aeadChanged chan protocol.EncryptionLevel
mintControllerConstructor = newMintController
)
BeforeEach(func() {
paramsChan = make(chan TransportParameters)
aeadChanged = make(chan protocol.EncryptionLevel, 2)
csInt, _, err := NewCryptoSetupTLSServer(
csInt, err := NewCryptoSetupTLSServer(
testdata.GetTLSConfig(),
&TransportParameters{},
paramsChan,
aeadChanged,
nil,
protocol.VersionTLS,

View file

@ -2,7 +2,6 @@ package handshake
import (
"io"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
@ -25,9 +24,3 @@ type CryptoSetup interface {
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
}
// TransportParameters are parameters sent to the peer during the handshake
type TransportParameters struct {
RequestConnectionIDOmission bool
IdleTimeout time.Duration
}

View file

@ -1,111 +0,0 @@
package handshake
import (
"encoding/binary"
"errors"
"fmt"
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type paramsNegotiator struct {
paramsNegotiatorBase
}
var _ ParamsNegotiator = &paramsNegotiator{}
// newParamsNegotiator creates a new connection parameters manager
func newParamsNegotiator(pers protocol.Perspective, v protocol.VersionNumber, params *TransportParameters) *paramsNegotiator {
h := &paramsNegotiator{}
h.perspective = pers
h.version = v
h.init(params)
return h
}
func (h *paramsNegotiator) SetFromTransportParameters(params []transportParameter) error {
h.mutex.Lock()
defer h.mutex.Unlock()
var foundInitialMaxStreamData bool
var foundInitialMaxData bool
var foundInitialMaxStreamID bool
var foundIdleTimeout bool
for _, p := range params {
switch p.Parameter {
case initialMaxStreamDataParameterID:
foundInitialMaxStreamData = true
if len(p.Value) != 4 {
return fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", len(p.Value))
}
h.sendStreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
utils.Debugf("h.sendStreamFlowControlWindow: %#x", h.sendStreamFlowControlWindow)
case initialMaxDataParameterID:
foundInitialMaxData = true
if len(p.Value) != 4 {
return fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value))
}
h.sendConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
utils.Debugf("h.sendConnectionFlowControlWindow: %#x", h.sendConnectionFlowControlWindow)
case initialMaxStreamIDParameterID:
foundInitialMaxStreamID = true
if len(p.Value) != 4 {
return fmt.Errorf("wrong length for initial_max_stream_id: %d (expected 4)", len(p.Value))
}
// TODO: handle this value
case idleTimeoutParameterID:
foundIdleTimeout = true
if len(p.Value) != 2 {
return fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value))
}
h.setRemoteIdleTimeout(time.Duration(binary.BigEndian.Uint16(p.Value)) * time.Second)
case omitConnectionIDParameterID:
if len(p.Value) != 0 {
return fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value))
}
h.omitConnectionID = true
}
}
if !(foundInitialMaxStreamData && foundInitialMaxData && foundInitialMaxStreamID && foundIdleTimeout) {
return errors.New("missing parameter")
}
return nil
}
func (h *paramsNegotiator) GetTransportParameters() []transportParameter {
initialMaxStreamData := make([]byte, 4)
binary.BigEndian.PutUint32(initialMaxStreamData, uint32(protocol.ReceiveStreamFlowControlWindow))
initialMaxData := make([]byte, 4)
binary.BigEndian.PutUint32(initialMaxData, uint32(protocol.ReceiveConnectionFlowControlWindow))
initialMaxStreamID := make([]byte, 4)
// TODO: use a reasonable value here
binary.BigEndian.PutUint32(initialMaxStreamID, math.MaxUint32)
idleTimeout := make([]byte, 2)
binary.BigEndian.PutUint16(idleTimeout, uint16(h.idleTimeout))
maxPacketSize := make([]byte, 2)
binary.BigEndian.PutUint16(maxPacketSize, uint16(protocol.MaxReceivePacketSize))
params := []transportParameter{
{initialMaxStreamDataParameterID, initialMaxStreamData},
{initialMaxDataParameterID, initialMaxData},
{initialMaxStreamIDParameterID, initialMaxStreamID},
{idleTimeoutParameterID, idleTimeout},
{maxPacketSizeParameterID, maxPacketSize},
}
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.omitConnectionID {
params = append(params, transportParameter{omitConnectionIDParameterID, []byte{}})
}
return params
}
func (h *paramsNegotiator) OmitConnectionID() bool {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.omitConnectionID
}

View file

@ -1,85 +0,0 @@
package handshake
import (
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// The ParamsNegotiator negotiates and stores the connection parameters.
// It can be used for a server as well as a client.
type ParamsNegotiator interface {
GetSendStreamFlowControlWindow() protocol.ByteCount
GetSendConnectionFlowControlWindow() protocol.ByteCount
GetMaxOutgoingStreams() uint32
// get the idle timeout that was sent by the peer
GetRemoteIdleTimeout() time.Duration
// determines if the client requests omission of connection IDs.
OmitConnectionID() bool
}
// For the server:
// 1. call SetFromMap with the values received in the CHLO. This sets the corresponding values here, subject to negotiation
// 2. call GetHelloMap to get the values to send in the SHLO
// For the client:
// 1. call GetHelloMap to get the values to send in a CHLO
// 2. call SetFromMap with the values received in the SHLO
type paramsNegotiatorBase struct {
mutex sync.RWMutex
version protocol.VersionNumber
perspective protocol.Perspective
flowControlNegotiated bool
omitConnectionID bool
requestConnectionIDOmission bool
maxOutgoingStreams uint32
idleTimeout time.Duration
remoteIdleTimeout time.Duration
sendStreamFlowControlWindow protocol.ByteCount
sendConnectionFlowControlWindow protocol.ByteCount
}
func (h *paramsNegotiatorBase) init(params *TransportParameters) {
h.sendStreamFlowControlWindow = protocol.InitialStreamFlowControlWindow // can only be changed by the client
h.sendConnectionFlowControlWindow = protocol.InitialConnectionFlowControlWindow // can only be changed by the client
h.requestConnectionIDOmission = params.RequestConnectionIDOmission
h.idleTimeout = params.IdleTimeout
// use this as a default value. As soon as the client sends its value, this gets updated
h.maxOutgoingStreams = protocol.MaxIncomingStreams
}
// GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data
func (h *paramsNegotiatorBase) GetSendStreamFlowControlWindow() protocol.ByteCount {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.sendStreamFlowControlWindow
}
// GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data
func (h *paramsNegotiatorBase) GetSendConnectionFlowControlWindow() protocol.ByteCount {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.sendConnectionFlowControlWindow
}
func (h *paramsNegotiatorBase) GetMaxOutgoingStreams() uint32 {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.maxOutgoingStreams
}
func (h *paramsNegotiatorBase) setRemoteIdleTimeout(t time.Duration) {
h.remoteIdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, t)
}
func (h *paramsNegotiatorBase) GetRemoteIdleTimeout() time.Duration {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.remoteIdleTimeout
}

View file

@ -1,116 +0,0 @@
package handshake
import (
"bytes"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// errMalformedTag is returned when the tag value cannot be read
var (
errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
errFlowControlRenegotiationNotSupported = qerr.Error(qerr.InvalidCryptoMessageParameter, "renegotiation of flow control parameters not supported")
)
type paramsNegotiatorGQUIC struct {
paramsNegotiatorBase
}
var _ ParamsNegotiator = &paramsNegotiatorGQUIC{}
// newParamsNegotiatorGQUIC creates a new connection parameters manager
func newParamsNegotiatorGQUIC(pers protocol.Perspective, v protocol.VersionNumber, params *TransportParameters) *paramsNegotiatorGQUIC {
h := &paramsNegotiatorGQUIC{}
h.perspective = pers
h.version = v
h.init(params)
return h
}
// SetFromMap reads all params.
func (h *paramsNegotiatorGQUIC) SetFromMap(params map[Tag][]byte) error {
h.mutex.Lock()
defer h.mutex.Unlock()
if value, ok := params[TagTCID]; ok && h.perspective == protocol.PerspectiveServer {
clientValue, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return errMalformedTag
}
h.omitConnectionID = (clientValue == 0)
}
if value, ok := params[TagMIDS]; ok {
clientValue, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return errMalformedTag
}
h.maxOutgoingStreams = clientValue
}
if value, ok := params[TagICSL]; ok {
clientValue, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return errMalformedTag
}
h.setRemoteIdleTimeout(time.Duration(clientValue) * time.Second)
}
if value, ok := params[TagSFCW]; ok {
if h.flowControlNegotiated {
return errFlowControlRenegotiationNotSupported
}
sendStreamFlowControlWindow, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return errMalformedTag
}
h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow)
}
if value, ok := params[TagCFCW]; ok {
if h.flowControlNegotiated {
return errFlowControlRenegotiationNotSupported
}
sendConnectionFlowControlWindow, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return errMalformedTag
}
h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow)
}
_, containsSFCW := params[TagSFCW]
_, containsCFCW := params[TagCFCW]
if containsCFCW || containsSFCW {
h.flowControlNegotiated = true
}
return nil
}
// GetHelloMap gets all parameters needed for the Hello message.
func (h *paramsNegotiatorGQUIC) GetHelloMap() (map[Tag][]byte, error) {
sfcw := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(sfcw, uint32(protocol.ReceiveStreamFlowControlWindow))
cfcw := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(cfcw, uint32(protocol.ReceiveConnectionFlowControlWindow))
mids := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(mids, protocol.MaxIncomingStreams)
icsl := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(icsl, uint32(h.idleTimeout/time.Second))
return map[Tag][]byte{
TagICSL: icsl.Bytes(),
TagMIDS: mids.Bytes(),
TagCFCW: cfcw.Bytes(),
TagSFCW: sfcw.Bytes(),
}, nil
}
func (h *paramsNegotiatorGQUIC) OmitConnectionID() bool {
if h.perspective == protocol.PerspectiveClient {
return false
}
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.omitConnectionID
}

View file

@ -1,231 +0,0 @@
package handshake
import (
"encoding/binary"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Params Negotiator (for gQUIC)", func() {
var pn *paramsNegotiatorGQUIC // a connectionParametersManager for a server
var pnClient *paramsNegotiatorGQUIC
idleTimeout := 42 * time.Second
BeforeEach(func() {
pn = newParamsNegotiatorGQUIC(
protocol.PerspectiveServer,
protocol.VersionWhatever,
&TransportParameters{
IdleTimeout: idleTimeout,
},
)
pnClient = newParamsNegotiatorGQUIC(
protocol.PerspectiveClient,
protocol.VersionWhatever,
&TransportParameters{
IdleTimeout: idleTimeout,
},
)
})
Context("SHLO", func() {
BeforeEach(func() {
// these tests should only use the server connectionParametersManager. Make them panic if they don't
pnClient = nil
})
It("returns all parameters necessary for the SHLO", func() {
entryMap, err := pn.GetHelloMap()
Expect(err).ToNot(HaveOccurred())
Expect(entryMap).To(HaveKey(TagICSL))
Expect(entryMap).To(HaveKey(TagMIDS))
})
It("sets the stream-level flow control windows in SHLO", func() {
entryMap, err := pn.GetHelloMap()
Expect(err).ToNot(HaveOccurred())
expected := make([]byte, 4)
binary.LittleEndian.PutUint32(expected, uint32(protocol.ReceiveStreamFlowControlWindow))
Expect(entryMap).To(HaveKeyWithValue(TagSFCW, expected))
})
It("sets the connection-level flow control windows in SHLO", func() {
entryMap, err := pn.GetHelloMap()
Expect(err).ToNot(HaveOccurred())
expected := make([]byte, 4)
binary.LittleEndian.PutUint32(expected, uint32(protocol.ReceiveConnectionFlowControlWindow))
Expect(entryMap).To(HaveKeyWithValue(TagCFCW, expected))
})
It("sets the connection-level flow control windows in SHLO", func() {
pn.idleTimeout = 0xdecafbad * time.Second
entryMap, err := pn.GetHelloMap()
Expect(err).ToNot(HaveOccurred())
Expect(entryMap).To(HaveKey(TagICSL))
Expect(entryMap[TagICSL]).To(Equal([]byte{0xad, 0xfb, 0xca, 0xde}))
})
It("always sends its own value for the maximum incoming dynamic streams in the SHLO", func() {
err := pn.SetFromMap(map[Tag][]byte{TagMIDS: []byte{5, 0, 0, 0}})
Expect(err).ToNot(HaveOccurred())
entryMap, err := pn.GetHelloMap()
Expect(err).ToNot(HaveOccurred())
Expect(entryMap[TagMIDS]).To(Equal([]byte{byte(protocol.MaxIncomingStreams), 0, 0, 0}))
})
})
Context("CHLO", func() {
BeforeEach(func() {
// these tests should only use the client connectionParametersManager. Make them panic if they don't
pn = nil
})
It("has the right values", func() {
entryMap, err := pnClient.GetHelloMap()
Expect(err).ToNot(HaveOccurred())
Expect(entryMap).To(HaveKey(TagICSL))
Expect(binary.LittleEndian.Uint32(entryMap[TagICSL])).To(BeEquivalentTo(idleTimeout / time.Second))
Expect(entryMap).To(HaveKey(TagMIDS))
Expect(binary.LittleEndian.Uint32(entryMap[TagMIDS])).To(BeEquivalentTo(protocol.MaxIncomingStreams))
Expect(entryMap).To(HaveKey(TagSFCW))
Expect(binary.LittleEndian.Uint32(entryMap[TagSFCW])).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow))
Expect(entryMap).To(HaveKey(TagCFCW))
Expect(binary.LittleEndian.Uint32(entryMap[TagCFCW])).To(BeEquivalentTo(protocol.ReceiveConnectionFlowControlWindow))
})
})
Context("Omitted connection IDs", func() {
It("does not send omitted connection IDs if the TCID tag is missing", func() {
Expect(pn.OmitConnectionID()).To(BeFalse())
})
It("reads the tag for omitted connection IDs", func() {
values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}}
pn.SetFromMap(values)
Expect(pn.OmitConnectionID()).To(BeTrue())
})
It("ignores the TCID tag, as a client", func() {
values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}}
pnClient.SetFromMap(values)
Expect(pnClient.OmitConnectionID()).To(BeFalse())
})
It("errors when given an invalid value", func() {
values := map[Tag][]byte{TagTCID: {2, 0, 0}} // 1 byte too short
err := pn.SetFromMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
})
Context("flow control", func() {
It("has the correct default flow control windows for sending", func() {
Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow))
Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow))
Expect(pnClient.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow))
Expect(pnClient.GetSendConnectionFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow))
})
It("sets a new stream-level flow control window for sending", func() {
values := map[Tag][]byte{TagSFCW: {0xDE, 0xAD, 0xBE, 0xEF}}
err := pn.SetFromMap(values)
Expect(err).ToNot(HaveOccurred())
Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE)))
})
It("does not change the stream-level flow control window when given an invalid value", func() {
values := map[Tag][]byte{TagSFCW: {0xDE, 0xAD, 0xBE}} // 1 byte too short
err := pn.SetFromMap(values)
Expect(err).To(MatchError(errMalformedTag))
Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow))
})
It("sets a new connection-level flow control window for sending", func() {
values := map[Tag][]byte{TagCFCW: {0xDE, 0xAD, 0xBE, 0xEF}}
err := pn.SetFromMap(values)
Expect(err).ToNot(HaveOccurred())
Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE)))
})
It("does not change the connection-level flow control window when given an invalid value", func() {
values := map[Tag][]byte{TagCFCW: {0xDE, 0xAD, 0xBE}} // 1 byte too short
err := pn.SetFromMap(values)
Expect(err).To(MatchError(errMalformedTag))
Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow))
})
It("does not allow renegotiation of flow control parameters", func() {
values := map[Tag][]byte{
TagCFCW: {0xDE, 0xAD, 0xBE, 0xEF},
TagSFCW: {0xDE, 0xAD, 0xBE, 0xEF},
}
err := pn.SetFromMap(values)
Expect(err).ToNot(HaveOccurred())
values = map[Tag][]byte{
TagCFCW: {0x13, 0x37, 0x13, 0x37},
TagSFCW: {0x13, 0x37, 0x13, 0x37},
}
err = pn.SetFromMap(values)
Expect(err).To(MatchError(errFlowControlRenegotiationNotSupported))
Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE)))
Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE)))
})
})
Context("idle timeout", func() {
It("sets the remote idle timeout", func() {
values := map[Tag][]byte{
TagICSL: {10, 0, 0, 0},
}
err := pn.SetFromMap(values)
Expect(err).ToNot(HaveOccurred())
Expect(pn.GetRemoteIdleTimeout()).To(Equal(10 * time.Second))
})
It("doesn't allow values below the minimum remote idle timeout", func() {
t := 2 * time.Second
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
values := map[Tag][]byte{
TagICSL: {uint8(t.Seconds()), 0, 0, 0},
}
err := pn.SetFromMap(values)
Expect(err).ToNot(HaveOccurred())
Expect(pn.GetRemoteIdleTimeout()).To(Equal(protocol.MinRemoteIdleTimeout))
})
It("errors when given an invalid value", func() {
values := map[Tag][]byte{TagICSL: {2, 0, 0}} // 1 byte too short
err := pn.SetFromMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
})
Context("max streams per connection", func() {
It("errors when given an invalid max dynamic incoming streams per connection value", func() {
values := map[Tag][]byte{TagMIDS: {2, 0, 0}} // 1 byte too short
err := pn.SetFromMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
Context("outgoing connections", func() {
It("sets the negotiated max streams per connection value", func() {
// this test only works if the value given here is smaller than protocol.MaxStreamsPerConnection
err := pn.SetFromMap(map[Tag][]byte{
TagMIDS: {2, 0, 0, 0},
})
Expect(err).ToNot(HaveOccurred())
Expect(pn.GetMaxOutgoingStreams()).To(Equal(uint32(2)))
})
It("uses the the MSPC value, if no MIDS is given", func() {
err := pn.SetFromMap(map[Tag][]byte{
TagMIDS: {3, 0, 0, 0},
})
Expect(err).ToNot(HaveOccurred())
Expect(pn.GetMaxOutgoingStreams()).To(Equal(uint32(3)))
})
})
})
})

View file

@ -1,154 +0,0 @@
package handshake
import (
"encoding/binary"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Params Negotiator (for TLS)", func() {
var params map[transportParameterID][]byte
var pn *paramsNegotiator
paramsMapToList := func(p map[transportParameterID][]byte) []transportParameter {
var list []transportParameter
for id, val := range p {
list = append(list, transportParameter{id, val})
}
return list
}
paramsListToMap := func(l []transportParameter) map[transportParameterID][]byte {
p := make(map[transportParameterID][]byte)
for _, v := range l {
p[v.Parameter] = v.Value
}
return p
}
BeforeEach(func() {
pn = newParamsNegotiator(
protocol.PerspectiveServer,
protocol.VersionWhatever,
&TransportParameters{},
)
params = map[transportParameterID][]byte{
initialMaxStreamDataParameterID: []byte{0x11, 0x22, 0x33, 0x44},
initialMaxDataParameterID: []byte{0x22, 0x33, 0x44, 0x55},
initialMaxStreamIDParameterID: []byte{0x33, 0x44, 0x55, 0x66},
idleTimeoutParameterID: []byte{0x13, 0x37},
}
})
Context("getting", func() {
It("creates the parameters list", func() {
pn.idleTimeout = 0xcafe
buf := make([]byte, 4)
values := paramsListToMap(pn.GetTransportParameters())
Expect(values).To(HaveLen(5))
binary.BigEndian.PutUint32(buf, uint32(protocol.ReceiveStreamFlowControlWindow))
Expect(values).To(HaveKeyWithValue(initialMaxStreamDataParameterID, buf))
binary.BigEndian.PutUint32(buf, uint32(protocol.ReceiveConnectionFlowControlWindow))
Expect(values).To(HaveKeyWithValue(initialMaxDataParameterID, buf))
Expect(values).To(HaveKeyWithValue(initialMaxStreamIDParameterID, []byte{0xff, 0xff, 0xff, 0xff}))
Expect(values).To(HaveKeyWithValue(idleTimeoutParameterID, []byte{0xca, 0xfe}))
Expect(values).To(HaveKeyWithValue(maxPacketSizeParameterID, []byte{0x5, 0xac})) // 1452 = 0x5ac
})
It("request ommision of the connection ID", func() {
pn.omitConnectionID = true
values := paramsListToMap(pn.GetTransportParameters())
Expect(values).To(HaveKeyWithValue(omitConnectionIDParameterID, []byte{}))
})
})
Context("setting", func() {
It("reads parameters", func() {
err := pn.SetFromTransportParameters(paramsMapToList(params))
Expect(err).ToNot(HaveOccurred())
Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.ByteCount(0x11223344)))
Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.ByteCount(0x22334455)))
Expect(pn.GetRemoteIdleTimeout()).To(Equal(0x1337 * time.Second))
Expect(pn.OmitConnectionID()).To(BeFalse())
})
It("saves if it should omit the connection ID", func() {
params[omitConnectionIDParameterID] = []byte{}
err := pn.SetFromTransportParameters(paramsMapToList(params))
Expect(err).ToNot(HaveOccurred())
Expect(pn.OmitConnectionID()).To(BeTrue())
})
It("rejects the parameters if the initial_max_stream_data is missing", func() {
delete(params, initialMaxStreamDataParameterID)
err := pn.SetFromTransportParameters(paramsMapToList(params))
Expect(err).To(MatchError("missing parameter"))
})
It("rejects the parameters if the initial_max_data is missing", func() {
delete(params, initialMaxDataParameterID)
err := pn.SetFromTransportParameters(paramsMapToList(params))
Expect(err).To(MatchError("missing parameter"))
})
It("rejects the parameters if the initial_max_stream_id is missing", func() {
delete(params, initialMaxStreamIDParameterID)
err := pn.SetFromTransportParameters(paramsMapToList(params))
Expect(err).To(MatchError("missing parameter"))
})
It("rejects the parameters if the idle_timeout is missing", func() {
delete(params, idleTimeoutParameterID)
err := pn.SetFromTransportParameters(paramsMapToList(params))
Expect(err).To(MatchError("missing parameter"))
})
It("doesn't allow values below the minimum remote idle timeout", func() {
t := 2 * time.Second
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
params[idleTimeoutParameterID] = []byte{0, uint8(t.Seconds())}
err := pn.SetFromTransportParameters(paramsMapToList(params))
Expect(err).ToNot(HaveOccurred())
Expect(pn.GetRemoteIdleTimeout()).To(Equal(protocol.MinRemoteIdleTimeout))
})
It("rejects the parameters if the initial_max_stream_data has the wrong length", func() {
params[initialMaxStreamDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
err := pn.SetFromTransportParameters(paramsMapToList(params))
Expect(err).To(MatchError("wrong length for initial_max_stream_data: 3 (expected 4)"))
})
It("rejects the parameters if the initial_max_data has the wrong length", func() {
params[initialMaxDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
err := pn.SetFromTransportParameters(paramsMapToList(params))
Expect(err).To(MatchError("wrong length for initial_max_data: 3 (expected 4)"))
})
It("rejects the parameters if the initial_max_stream_id has the wrong length", func() {
params[initialMaxStreamIDParameterID] = []byte{0x11, 0x22, 0x33, 0x44, 0x55} // should be 4 bytes
err := pn.SetFromTransportParameters(paramsMapToList(params))
Expect(err).To(MatchError("wrong length for initial_max_stream_id: 5 (expected 4)"))
})
It("rejects the parameters if the initial_idle_timeout has the wrong length", func() {
params[idleTimeoutParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
err := pn.SetFromTransportParameters(paramsMapToList(params))
Expect(err).To(MatchError("wrong length for idle_timeout: 3 (expected 2)"))
})
It("rejects the parameters if omit_connection_id is non-empty", func() {
params[omitConnectionIDParameterID] = []byte{0} // should be empty
err := pn.SetFromTransportParameters(paramsMapToList(params))
Expect(err).To(MatchError("wrong length for omit_connection_id: 1 (expected empty)"))
})
It("ignores unknown parameters", func() {
params[1337] = []byte{42}
err := pn.SetFromTransportParameters(paramsMapToList(params))
Expect(err).ToNot(HaveOccurred())
})
})
})

View file

@ -3,6 +3,7 @@ package handshake
import (
"errors"
"fmt"
"math"
"github.com/lucas-clemente/quic-go/qerr"
@ -12,7 +13,8 @@ import (
)
type extensionHandlerClient struct {
params *paramsNegotiator
params *TransportParameters
paramsChan chan<- TransportParameters
initialVersion protocol.VersionNumber
supportedVersions []protocol.VersionNumber
@ -21,9 +23,16 @@ type extensionHandlerClient struct {
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
func newExtensionHandlerClient(params *paramsNegotiator, initialVersion protocol.VersionNumber, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber) *extensionHandlerClient {
func newExtensionHandlerClient(
params *TransportParameters,
paramsChan chan<- TransportParameters,
initialVersion protocol.VersionNumber,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber,
) *extensionHandlerClient {
return &extensionHandlerClient{
params: params,
paramsChan: paramsChan,
initialVersion: initialVersion,
supportedVersions: supportedVersions,
version: version,
@ -38,7 +47,7 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
data, err := syntax.Marshal(clientHelloTransportParameters{
NegotiatedVersion: uint32(h.version),
InitialVersion: uint32(h.initialVersion),
Parameters: h.params.GetTransportParameters(),
Parameters: h.params.getTransportParameters(),
})
if err != nil {
return err
@ -99,5 +108,12 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
// TODO: return the right error here
return errors.New("server didn't sent stateless_reset_token")
}
return h.params.SetFromTransportParameters(eetp.Parameters)
params, err := readTransportParamters(eetp.Parameters)
if err != nil {
return err
}
// TODO(#878): remove this when implementing the MAX_STREAM_ID frame
params.MaxStreams = math.MaxUint32
h.paramsChan <- *params
return nil
}

View file

@ -12,12 +12,16 @@ import (
)
var _ = Describe("TLS Extension Handler, for the client", func() {
var handler *extensionHandlerClient
var el mint.ExtensionList
var (
handler *extensionHandlerClient
el mint.ExtensionList
paramsChan chan TransportParameters
)
BeforeEach(func() {
pn := &paramsNegotiator{}
handler = newExtensionHandlerClient(pn, protocol.VersionWhatever, nil, protocol.VersionWhatever)
// use a buffered channel here, so that we don't have to receive concurrently when parsing a message
paramsChan = make(chan TransportParameters, 1)
handler = newExtensionHandlerClient(&TransportParameters{}, paramsChan, protocol.VersionWhatever, nil, protocol.VersionWhatever)
el = make(mint.ExtensionList, 0)
})
@ -78,7 +82,9 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
addEncryptedExtensionsWithParameters(parameters)
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).ToNot(HaveOccurred())
Expect(handler.params.GetSendStreamFlowControlWindow()).To(BeEquivalentTo(0x11223344))
var params TransportParameters
Expect(paramsChan).To(Receive(&params))
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
})
It("errors if the EncryptedExtensions message doesn't contain TransportParameters", func() {

View file

@ -4,6 +4,7 @@ import (
"bytes"
"errors"
"fmt"
"math"
"github.com/lucas-clemente/quic-go/qerr"
@ -13,7 +14,8 @@ import (
)
type extensionHandlerServer struct {
params *paramsNegotiator
params *TransportParameters
paramsChan chan<- TransportParameters
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
@ -21,9 +23,15 @@ type extensionHandlerServer struct {
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
func newExtensionHandlerServer(params *paramsNegotiator, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber) *extensionHandlerServer {
func newExtensionHandlerServer(
params *TransportParameters,
paramsChan chan<- TransportParameters,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber,
) *extensionHandlerServer {
return &extensionHandlerServer{
params: params,
paramsChan: paramsChan,
version: version,
supportedVersions: supportedVersions,
}
@ -35,7 +43,8 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
}
transportParams := append(
h.params.GetTransportParameters(),
h.params.getTransportParameters(),
// TODO(#855): generate a real token
transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)},
)
supportedVersions := make([]uint32, len(h.supportedVersions))
@ -89,5 +98,12 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
return errors.New("client sent a stateless reset token")
}
}
return h.params.SetFromTransportParameters(chtp.Parameters)
params, err := readTransportParamters(chtp.Parameters)
if err != nil {
return err
}
// TODO(#878): remove this when implementing the MAX_STREAM_ID frame
params.MaxStreams = math.MaxUint32
h.paramsChan <- *params
return nil
}

View file

@ -19,12 +19,16 @@ func parameterMapToList(paramMap map[transportParameterID][]byte) []transportPar
}
var _ = Describe("TLS Extension Handler, for the server", func() {
var handler *extensionHandlerServer
var el mint.ExtensionList
var (
handler *extensionHandlerServer
el mint.ExtensionList
paramsChan chan TransportParameters
)
BeforeEach(func() {
pn := &paramsNegotiator{}
handler = newExtensionHandlerServer(pn, nil, protocol.VersionWhatever)
// use a buffered channel here, so that we don't have to receive concurrently when parsing a message
paramsChan = make(chan TransportParameters, 1)
handler = newExtensionHandlerServer(&TransportParameters{}, paramsChan, nil, protocol.VersionWhatever)
el = make(mint.ExtensionList, 0)
})
@ -79,7 +83,9 @@ var _ = Describe("TLS Extension Handler, for the server", func() {
addClientHelloWithParameters(parameters)
err := handler.Receive(mint.HandshakeTypeClientHello, &el)
Expect(err).ToNot(HaveOccurred())
Expect(handler.params.GetSendStreamFlowControlWindow()).To(BeEquivalentTo(0x11223344))
var params TransportParameters
Expect(paramsChan).To(Receive(&params))
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
})
It("errors if the ClientHello doesn't contain TransportParameters", func() {

View file

@ -6,25 +6,6 @@ import (
)
var _ = Describe("TLS extension body", func() {
// var server, client mint.AppExtensionHandler
// var el mint.ExtensionList
// BeforeEach(func() {
// server = &extensionHandler{perspective: protocol.PerspectiveServer}
// client = &extensionHandler{perspective: protocol.PerspectiveClient}
// // el = make(mint.ExtensionList, 0)
// // TODO: initialize el with some dummy extensions
// })
// It("writes and reads a ClientHello", func() {
// err := client.Send(mint.HandshakeTypeClientHello, &el)
// Expect(err).ToNot(HaveOccurred())
// ch := &tlsExtensionBody{}
// found := el.Find(ch)
// Expect(found).To(BeTrue())
// err = server.Receive(mint.HandshakeTypeClientHello, &el)
// Expect(err).ToNot(HaveOccurred())
// })
var extBody *tlsExtensionBody
BeforeEach(func() {

View file

@ -0,0 +1,246 @@
package handshake
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Transport Parameters", func() {
Context("for gQUIC", func() {
Context("parsing", func() {
It("sets all values", func() {
values := map[Tag][]byte{
TagSFCW: {0xad, 0xfb, 0xca, 0xde},
TagCFCW: {0xef, 0xbe, 0xad, 0xde},
TagICSL: {0x0d, 0xf0, 0xad, 0xba},
TagMIDS: {0xff, 0x10, 0x00, 0xc0},
}
params, err := readHelloMap(values)
Expect(err).ToNot(HaveOccurred())
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xdecafbad)))
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0xdeadbeef)))
Expect(params.IdleTimeout).To(Equal(time.Duration(0xbaadf00d) * time.Second))
Expect(params.MaxStreams).To(Equal(uint32(0xc00010ff)))
Expect(params.OmitConnectionID).To(BeFalse())
})
It("reads if the connection ID should be omitted", func() {
values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}}
params, err := readHelloMap(values)
Expect(err).ToNot(HaveOccurred())
Expect(params.OmitConnectionID).To(BeTrue())
})
It("doesn't allow idle timeouts below the minimum remote idle timeout", func() {
t := 2 * time.Second
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
values := map[Tag][]byte{
TagICSL: {uint8(t.Seconds()), 0, 0, 0},
}
params, err := readHelloMap(values)
Expect(err).ToNot(HaveOccurred())
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
})
It("errors when given an invalid SFCW value", func() {
values := map[Tag][]byte{TagSFCW: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
It("errors when given an invalid CFCW value", func() {
values := map[Tag][]byte{TagCFCW: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
It("errors when given an invalid TCID value", func() {
values := map[Tag][]byte{TagTCID: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
It("errors when given an invalid ICSL value", func() {
values := map[Tag][]byte{TagICSL: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
It("errors when given an invalid MIDS value", func() {
values := map[Tag][]byte{TagMIDS: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
})
Context("writing", func() {
It("returns all necessary parameters ", func() {
params := &TransportParameters{
StreamFlowControlWindow: 0xdeadbeef,
ConnectionFlowControlWindow: 0xdecafbad,
IdleTimeout: 0xbaaaaaad * time.Second,
MaxStreams: 0x1337,
}
entryMap := params.getHelloMap()
Expect(entryMap).To(HaveLen(4))
Expect(entryMap).ToNot(HaveKey(TagTCID))
Expect(entryMap).To(HaveKeyWithValue(TagSFCW, []byte{0xef, 0xbe, 0xad, 0xde}))
Expect(entryMap).To(HaveKeyWithValue(TagCFCW, []byte{0xad, 0xfb, 0xca, 0xde}))
Expect(entryMap).To(HaveKeyWithValue(TagICSL, []byte{0xad, 0xaa, 0xaa, 0xba}))
Expect(entryMap).To(HaveKeyWithValue(TagMIDS, []byte{0x37, 0x13, 0, 0}))
})
It("requests omission of the connection ID", func() {
params := &TransportParameters{OmitConnectionID: true}
entryMap := params.getHelloMap()
Expect(entryMap).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0}))
})
})
})
Context("for TLS", func() {
paramsMapToList := func(p map[transportParameterID][]byte) []transportParameter {
var list []transportParameter
for id, val := range p {
list = append(list, transportParameter{id, val})
}
return list
}
Context("parsing", func() {
var parameters map[transportParameterID][]byte
BeforeEach(func() {
parameters = map[transportParameterID][]byte{
initialMaxStreamDataParameterID: []byte{0x11, 0x22, 0x33, 0x44},
initialMaxDataParameterID: []byte{0x22, 0x33, 0x44, 0x55},
initialMaxStreamIDParameterID: []byte{0x33, 0x44, 0x55, 0x66},
idleTimeoutParameterID: []byte{0x13, 0x37},
}
})
It("reads parameters", func() {
params, err := readTransportParamters(paramsMapToList(parameters))
Expect(err).ToNot(HaveOccurred())
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0x11223344)))
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0x22334455)))
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
Expect(params.OmitConnectionID).To(BeFalse())
})
It("saves if it should omit the connection ID", func() {
parameters[omitConnectionIDParameterID] = []byte{}
params, err := readTransportParamters(paramsMapToList(parameters))
Expect(err).ToNot(HaveOccurred())
Expect(params.OmitConnectionID).To(BeTrue())
})
It("rejects the parameters if the initial_max_stream_data is missing", func() {
delete(parameters, initialMaxStreamDataParameterID)
_, err := readTransportParamters(paramsMapToList(parameters))
Expect(err).To(MatchError("missing parameter"))
})
It("rejects the parameters if the initial_max_data is missing", func() {
delete(parameters, initialMaxDataParameterID)
_, err := readTransportParamters(paramsMapToList(parameters))
Expect(err).To(MatchError("missing parameter"))
})
It("rejects the parameters if the initial_max_stream_id is missing", func() {
delete(parameters, initialMaxStreamIDParameterID)
_, err := readTransportParamters(paramsMapToList(parameters))
Expect(err).To(MatchError("missing parameter"))
})
It("rejects the parameters if the idle_timeout is missing", func() {
delete(parameters, idleTimeoutParameterID)
_, err := readTransportParamters(paramsMapToList(parameters))
Expect(err).To(MatchError("missing parameter"))
})
It("doesn't allow values below the minimum remote idle timeout", func() {
t := 2 * time.Second
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
parameters[idleTimeoutParameterID] = []byte{0, uint8(t.Seconds())}
params, err := readTransportParamters(paramsMapToList(parameters))
Expect(err).ToNot(HaveOccurred())
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
})
It("rejects the parameters if the initial_max_stream_data has the wrong length", func() {
parameters[initialMaxStreamDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
_, err := readTransportParamters(paramsMapToList(parameters))
Expect(err).To(MatchError("wrong length for initial_max_stream_data: 3 (expected 4)"))
})
It("rejects the parameters if the initial_max_data has the wrong length", func() {
parameters[initialMaxDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
_, err := readTransportParamters(paramsMapToList(parameters))
Expect(err).To(MatchError("wrong length for initial_max_data: 3 (expected 4)"))
})
It("rejects the parameters if the initial_max_stream_id has the wrong length", func() {
parameters[initialMaxStreamIDParameterID] = []byte{0x11, 0x22, 0x33, 0x44, 0x55} // should be 4 bytes
_, err := readTransportParamters(paramsMapToList(parameters))
Expect(err).To(MatchError("wrong length for initial_max_stream_id: 5 (expected 4)"))
})
It("rejects the parameters if the initial_idle_timeout has the wrong length", func() {
parameters[idleTimeoutParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
_, err := readTransportParamters(paramsMapToList(parameters))
Expect(err).To(MatchError("wrong length for idle_timeout: 3 (expected 2)"))
})
It("rejects the parameters if omit_connection_id is non-empty", func() {
parameters[omitConnectionIDParameterID] = []byte{0} // should be empty
_, err := readTransportParamters(paramsMapToList(parameters))
Expect(err).To(MatchError("wrong length for omit_connection_id: 1 (expected empty)"))
})
It("ignores unknown parameters", func() {
parameters[1337] = []byte{42}
_, err := readTransportParamters(paramsMapToList(parameters))
Expect(err).ToNot(HaveOccurred())
})
})
Context("writing", func() {
var params *TransportParameters
paramsListToMap := func(l []transportParameter) map[transportParameterID][]byte {
p := make(map[transportParameterID][]byte)
for _, v := range l {
p[v.Parameter] = v.Value
}
return p
}
BeforeEach(func() {
params = &TransportParameters{
StreamFlowControlWindow: 0xdeadbeef,
ConnectionFlowControlWindow: 0xdecafbad,
IdleTimeout: 0xcafe,
}
})
It("creates the parameters list", func() {
values := paramsListToMap(params.getTransportParameters())
Expect(values).To(HaveLen(5))
Expect(values).To(HaveKeyWithValue(initialMaxStreamDataParameterID, []byte{0xde, 0xad, 0xbe, 0xef}))
Expect(values).To(HaveKeyWithValue(initialMaxDataParameterID, []byte{0xde, 0xca, 0xfb, 0xad}))
Expect(values).To(HaveKeyWithValue(initialMaxStreamIDParameterID, []byte{0xff, 0xff, 0xff, 0xff}))
Expect(values).To(HaveKeyWithValue(idleTimeoutParameterID, []byte{0xca, 0xfe}))
Expect(values).To(HaveKeyWithValue(maxPacketSizeParameterID, []byte{0x5, 0xac})) // 1452 = 0x5ac
})
It("request ommision of the connection ID", func() {
params.OmitConnectionID = true
values := paramsListToMap(params.getTransportParameters())
Expect(values).To(HaveKeyWithValue(omitConnectionIDParameterID, []byte{}))
})
})
})
})

View file

@ -0,0 +1,167 @@
package handshake
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// errMalformedTag is returned when the tag value cannot be read
var errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
// TransportParameters are parameters sent to the peer during the handshake
type TransportParameters struct {
StreamFlowControlWindow protocol.ByteCount
ConnectionFlowControlWindow protocol.ByteCount
MaxStreams uint32
OmitConnectionID bool
IdleTimeout time.Duration
}
// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message
func readHelloMap(tags map[Tag][]byte) (*TransportParameters, error) {
params := &TransportParameters{}
if value, ok := tags[TagTCID]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.OmitConnectionID = (v == 0)
}
if value, ok := tags[TagMIDS]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.MaxStreams = v
}
if value, ok := tags[TagICSL]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(v)*time.Second)
}
if value, ok := tags[TagSFCW]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.StreamFlowControlWindow = protocol.ByteCount(v)
}
if value, ok := tags[TagCFCW]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.ConnectionFlowControlWindow = protocol.ByteCount(v)
}
return params, nil
}
// GetHelloMap gets all parameters needed for the Hello message in the gQUIC handshake.
func (p *TransportParameters) getHelloMap() map[Tag][]byte {
sfcw := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(sfcw, uint32(p.StreamFlowControlWindow))
cfcw := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(cfcw, uint32(p.ConnectionFlowControlWindow))
mids := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(mids, p.MaxStreams)
icsl := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(icsl, uint32(p.IdleTimeout/time.Second))
tags := map[Tag][]byte{
TagICSL: icsl.Bytes(),
TagMIDS: mids.Bytes(),
TagCFCW: cfcw.Bytes(),
TagSFCW: sfcw.Bytes(),
}
if p.OmitConnectionID {
tags[TagTCID] = []byte{0, 0, 0, 0}
}
return tags
}
// readTransportParameters reads the transport parameters sent in the QUIC TLS extension
func readTransportParamters(paramsList []transportParameter) (*TransportParameters, error) {
params := &TransportParameters{}
var foundInitialMaxStreamData bool
var foundInitialMaxData bool
var foundInitialMaxStreamID bool
var foundIdleTimeout bool
for _, p := range paramsList {
switch p.Parameter {
case initialMaxStreamDataParameterID:
foundInitialMaxStreamData = true
if len(p.Value) != 4 {
return nil, fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", len(p.Value))
}
params.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
case initialMaxDataParameterID:
foundInitialMaxData = true
if len(p.Value) != 4 {
return nil, fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value))
}
params.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
case initialMaxStreamIDParameterID:
foundInitialMaxStreamID = true
if len(p.Value) != 4 {
return nil, fmt.Errorf("wrong length for initial_max_stream_id: %d (expected 4)", len(p.Value))
}
// TODO: handle this value
case idleTimeoutParameterID:
foundIdleTimeout = true
if len(p.Value) != 2 {
return nil, fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value))
}
params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(p.Value))*time.Second)
case omitConnectionIDParameterID:
if len(p.Value) != 0 {
return nil, fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value))
}
params.OmitConnectionID = true
}
}
if !(foundInitialMaxStreamData && foundInitialMaxData && foundInitialMaxStreamID && foundIdleTimeout) {
return nil, errors.New("missing parameter")
}
return params, nil
}
// GetTransportParameters gets the parameters needed for the TLS handshake.
func (p *TransportParameters) getTransportParameters() []transportParameter {
initialMaxStreamData := make([]byte, 4)
binary.BigEndian.PutUint32(initialMaxStreamData, uint32(p.StreamFlowControlWindow))
initialMaxData := make([]byte, 4)
binary.BigEndian.PutUint32(initialMaxData, uint32(p.ConnectionFlowControlWindow))
initialMaxStreamID := make([]byte, 4)
// TODO: use a reasonable value here
binary.BigEndian.PutUint32(initialMaxStreamID, math.MaxUint32)
idleTimeout := make([]byte, 2)
binary.BigEndian.PutUint16(idleTimeout, uint16(p.IdleTimeout))
maxPacketSize := make([]byte, 2)
binary.BigEndian.PutUint16(maxPacketSize, uint16(protocol.MaxReceivePacketSize))
params := []transportParameter{
{initialMaxStreamDataParameterID, initialMaxStreamData},
{initialMaxDataParameterID, initialMaxData},
{initialMaxStreamIDParameterID, initialMaxStreamID},
{idleTimeoutParameterID, idleTimeout},
{maxPacketSizeParameterID, maxPacketSize},
}
if p.OmitConnectionID {
params = append(params, transportParameter{omitConnectionIDParameterID, []byte{}})
}
return params
}

View file

@ -0,0 +1,177 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ../flowcontrol/interface.go
package mocks
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockFlowControlManager is a mock of FlowControlManager interface
type MockFlowControlManager struct {
ctrl *gomock.Controller
recorder *MockFlowControlManagerMockRecorder
}
// MockFlowControlManagerMockRecorder is the mock recorder for MockFlowControlManager
type MockFlowControlManagerMockRecorder struct {
mock *MockFlowControlManager
}
// NewMockFlowControlManager creates a new mock instance
func NewMockFlowControlManager(ctrl *gomock.Controller) *MockFlowControlManager {
mock := &MockFlowControlManager{ctrl: ctrl}
mock.recorder = &MockFlowControlManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (_m *MockFlowControlManager) EXPECT() *MockFlowControlManagerMockRecorder {
return _m.recorder
}
// NewStream mocks base method
func (_m *MockFlowControlManager) NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) {
_m.ctrl.Call(_m, "NewStream", streamID, contributesToConnectionFlow)
}
// NewStream indicates an expected call of NewStream
func (_mr *MockFlowControlManagerMockRecorder) NewStream(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "NewStream", reflect.TypeOf((*MockFlowControlManager)(nil).NewStream), arg0, arg1)
}
// RemoveStream mocks base method
func (_m *MockFlowControlManager) RemoveStream(streamID protocol.StreamID) {
_m.ctrl.Call(_m, "RemoveStream", streamID)
}
// RemoveStream indicates an expected call of RemoveStream
func (_mr *MockFlowControlManagerMockRecorder) RemoveStream(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "RemoveStream", reflect.TypeOf((*MockFlowControlManager)(nil).RemoveStream), arg0)
}
// UpdateTransportParameters mocks base method
func (_m *MockFlowControlManager) UpdateTransportParameters(_param0 *handshake.TransportParameters) {
_m.ctrl.Call(_m, "UpdateTransportParameters", _param0)
}
// UpdateTransportParameters indicates an expected call of UpdateTransportParameters
func (_mr *MockFlowControlManagerMockRecorder) UpdateTransportParameters(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateTransportParameters", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateTransportParameters), arg0)
}
// ResetStream mocks base method
func (_m *MockFlowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "ResetStream", streamID, byteOffset)
ret0, _ := ret[0].(error)
return ret0
}
// ResetStream indicates an expected call of ResetStream
func (_mr *MockFlowControlManagerMockRecorder) ResetStream(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ResetStream", reflect.TypeOf((*MockFlowControlManager)(nil).ResetStream), arg0, arg1)
}
// UpdateHighestReceived mocks base method
func (_m *MockFlowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "UpdateHighestReceived", streamID, byteOffset)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateHighestReceived indicates an expected call of UpdateHighestReceived
func (_mr *MockFlowControlManagerMockRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateHighestReceived), arg0, arg1)
}
// AddBytesRead mocks base method
func (_m *MockFlowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "AddBytesRead", streamID, n)
ret0, _ := ret[0].(error)
return ret0
}
// AddBytesRead indicates an expected call of AddBytesRead
func (_mr *MockFlowControlManagerMockRecorder) AddBytesRead(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "AddBytesRead", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesRead), arg0, arg1)
}
// GetWindowUpdates mocks base method
func (_m *MockFlowControlManager) GetWindowUpdates() []flowcontrol.WindowUpdate {
ret := _m.ctrl.Call(_m, "GetWindowUpdates")
ret0, _ := ret[0].([]flowcontrol.WindowUpdate)
return ret0
}
// GetWindowUpdates indicates an expected call of GetWindowUpdates
func (_mr *MockFlowControlManagerMockRecorder) GetWindowUpdates() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetWindowUpdates", reflect.TypeOf((*MockFlowControlManager)(nil).GetWindowUpdates))
}
// GetReceiveWindow mocks base method
func (_m *MockFlowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) {
ret := _m.ctrl.Call(_m, "GetReceiveWindow", streamID)
ret0, _ := ret[0].(protocol.ByteCount)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetReceiveWindow indicates an expected call of GetReceiveWindow
func (_mr *MockFlowControlManagerMockRecorder) GetReceiveWindow(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetReceiveWindow", reflect.TypeOf((*MockFlowControlManager)(nil).GetReceiveWindow), arg0)
}
// AddBytesSent mocks base method
func (_m *MockFlowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "AddBytesSent", streamID, n)
ret0, _ := ret[0].(error)
return ret0
}
// AddBytesSent indicates an expected call of AddBytesSent
func (_mr *MockFlowControlManagerMockRecorder) AddBytesSent(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "AddBytesSent", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesSent), arg0, arg1)
}
// SendWindowSize mocks base method
func (_m *MockFlowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) {
ret := _m.ctrl.Call(_m, "SendWindowSize", streamID)
ret0, _ := ret[0].(protocol.ByteCount)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SendWindowSize indicates an expected call of SendWindowSize
func (_mr *MockFlowControlManagerMockRecorder) SendWindowSize(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SendWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).SendWindowSize), arg0)
}
// RemainingConnectionWindowSize mocks base method
func (_m *MockFlowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "RemainingConnectionWindowSize")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
// RemainingConnectionWindowSize indicates an expected call of RemainingConnectionWindowSize
func (_mr *MockFlowControlManagerMockRecorder) RemainingConnectionWindowSize() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "RemainingConnectionWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).RemainingConnectionWindowSize))
}
// UpdateWindow mocks base method
func (_m *MockFlowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) {
ret := _m.ctrl.Call(_m, "UpdateWindow", streamID, offset)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateWindow indicates an expected call of UpdateWindow
func (_mr *MockFlowControlManagerMockRecorder) UpdateWindow(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateWindow", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateWindow), arg0, arg1)
}

View file

@ -3,6 +3,5 @@ package mocks
// mockgen source mode doesn't properly recognize structs defined in the same package
// so we have to use sed to correct for that
//go:generate sh -c "mockgen -package mocks_fc -source ../flowcontrol/interface.go | sed \"s/\\[\\]WindowUpdate/[]flowcontrol.WindowUpdate/g\" > mocks_fc/flow_control_manager.go"
//go:generate sh -c "mockgen -package mocks -source ../handshake/params_negotiator_base.go > params_negotiator.go"
//go:generate sh -c "mockgen -package mocks -source ../flowcontrol/interface.go | sed \"s/\\[\\]WindowUpdate/[]flowcontrol.WindowUpdate/g\" > flow_control_manager.go"
//go:generate sh -c "goimports -w ."

View file

@ -1,167 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ../flowcontrol/interface.go
// Package mocks_fc is a generated GoMock package.
package mocks_fc
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockFlowControlManager is a mock of FlowControlManager interface
type MockFlowControlManager struct {
ctrl *gomock.Controller
recorder *MockFlowControlManagerMockRecorder
}
// MockFlowControlManagerMockRecorder is the mock recorder for MockFlowControlManager
type MockFlowControlManagerMockRecorder struct {
mock *MockFlowControlManager
}
// NewMockFlowControlManager creates a new mock instance
func NewMockFlowControlManager(ctrl *gomock.Controller) *MockFlowControlManager {
mock := &MockFlowControlManager{ctrl: ctrl}
mock.recorder = &MockFlowControlManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockFlowControlManager) EXPECT() *MockFlowControlManagerMockRecorder {
return m.recorder
}
// NewStream mocks base method
func (m *MockFlowControlManager) NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) {
m.ctrl.Call(m, "NewStream", streamID, contributesToConnectionFlow)
}
// NewStream indicates an expected call of NewStream
func (mr *MockFlowControlManagerMockRecorder) NewStream(streamID, contributesToConnectionFlow interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewStream", reflect.TypeOf((*MockFlowControlManager)(nil).NewStream), streamID, contributesToConnectionFlow)
}
// RemoveStream mocks base method
func (m *MockFlowControlManager) RemoveStream(streamID protocol.StreamID) {
m.ctrl.Call(m, "RemoveStream", streamID)
}
// RemoveStream indicates an expected call of RemoveStream
func (mr *MockFlowControlManagerMockRecorder) RemoveStream(streamID interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveStream", reflect.TypeOf((*MockFlowControlManager)(nil).RemoveStream), streamID)
}
// ResetStream mocks base method
func (m *MockFlowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
ret := m.ctrl.Call(m, "ResetStream", streamID, byteOffset)
ret0, _ := ret[0].(error)
return ret0
}
// ResetStream indicates an expected call of ResetStream
func (mr *MockFlowControlManagerMockRecorder) ResetStream(streamID, byteOffset interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetStream", reflect.TypeOf((*MockFlowControlManager)(nil).ResetStream), streamID, byteOffset)
}
// UpdateHighestReceived mocks base method
func (m *MockFlowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
ret := m.ctrl.Call(m, "UpdateHighestReceived", streamID, byteOffset)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateHighestReceived indicates an expected call of UpdateHighestReceived
func (mr *MockFlowControlManagerMockRecorder) UpdateHighestReceived(streamID, byteOffset interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateHighestReceived), streamID, byteOffset)
}
// AddBytesRead mocks base method
func (m *MockFlowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error {
ret := m.ctrl.Call(m, "AddBytesRead", streamID, n)
ret0, _ := ret[0].(error)
return ret0
}
// AddBytesRead indicates an expected call of AddBytesRead
func (mr *MockFlowControlManagerMockRecorder) AddBytesRead(streamID, n interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesRead), streamID, n)
}
// GetWindowUpdates mocks base method
func (m *MockFlowControlManager) GetWindowUpdates() []flowcontrol.WindowUpdate {
ret := m.ctrl.Call(m, "GetWindowUpdates")
ret0, _ := ret[0].([]flowcontrol.WindowUpdate)
return ret0
}
// GetWindowUpdates indicates an expected call of GetWindowUpdates
func (mr *MockFlowControlManagerMockRecorder) GetWindowUpdates() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdates", reflect.TypeOf((*MockFlowControlManager)(nil).GetWindowUpdates))
}
// GetReceiveWindow mocks base method
func (m *MockFlowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) {
ret := m.ctrl.Call(m, "GetReceiveWindow", streamID)
ret0, _ := ret[0].(protocol.ByteCount)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetReceiveWindow indicates an expected call of GetReceiveWindow
func (mr *MockFlowControlManagerMockRecorder) GetReceiveWindow(streamID interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetReceiveWindow", reflect.TypeOf((*MockFlowControlManager)(nil).GetReceiveWindow), streamID)
}
// AddBytesSent mocks base method
func (m *MockFlowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error {
ret := m.ctrl.Call(m, "AddBytesSent", streamID, n)
ret0, _ := ret[0].(error)
return ret0
}
// AddBytesSent indicates an expected call of AddBytesSent
func (mr *MockFlowControlManagerMockRecorder) AddBytesSent(streamID, n interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesSent), streamID, n)
}
// SendWindowSize mocks base method
func (m *MockFlowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) {
ret := m.ctrl.Call(m, "SendWindowSize", streamID)
ret0, _ := ret[0].(protocol.ByteCount)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SendWindowSize indicates an expected call of SendWindowSize
func (mr *MockFlowControlManagerMockRecorder) SendWindowSize(streamID interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).SendWindowSize), streamID)
}
// RemainingConnectionWindowSize mocks base method
func (m *MockFlowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
ret := m.ctrl.Call(m, "RemainingConnectionWindowSize")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
// RemainingConnectionWindowSize indicates an expected call of RemainingConnectionWindowSize
func (mr *MockFlowControlManagerMockRecorder) RemainingConnectionWindowSize() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemainingConnectionWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).RemainingConnectionWindowSize))
}
// UpdateWindow mocks base method
func (m *MockFlowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) {
ret := m.ctrl.Call(m, "UpdateWindow", streamID, offset)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateWindow indicates an expected call of UpdateWindow
func (mr *MockFlowControlManagerMockRecorder) UpdateWindow(streamID, offset interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWindow", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateWindow), streamID, offset)
}

View file

@ -1,96 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ../handshake/params_negotiator_base.go
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockParamsNegotiator is a mock of ParamsNegotiator interface
type MockParamsNegotiator struct {
ctrl *gomock.Controller
recorder *MockParamsNegotiatorMockRecorder
}
// MockParamsNegotiatorMockRecorder is the mock recorder for MockParamsNegotiator
type MockParamsNegotiatorMockRecorder struct {
mock *MockParamsNegotiator
}
// NewMockParamsNegotiator creates a new mock instance
func NewMockParamsNegotiator(ctrl *gomock.Controller) *MockParamsNegotiator {
mock := &MockParamsNegotiator{ctrl: ctrl}
mock.recorder = &MockParamsNegotiatorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockParamsNegotiator) EXPECT() *MockParamsNegotiatorMockRecorder {
return m.recorder
}
// GetSendStreamFlowControlWindow mocks base method
func (m *MockParamsNegotiator) GetSendStreamFlowControlWindow() protocol.ByteCount {
ret := m.ctrl.Call(m, "GetSendStreamFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
// GetSendStreamFlowControlWindow indicates an expected call of GetSendStreamFlowControlWindow
func (mr *MockParamsNegotiatorMockRecorder) GetSendStreamFlowControlWindow() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSendStreamFlowControlWindow", reflect.TypeOf((*MockParamsNegotiator)(nil).GetSendStreamFlowControlWindow))
}
// GetSendConnectionFlowControlWindow mocks base method
func (m *MockParamsNegotiator) GetSendConnectionFlowControlWindow() protocol.ByteCount {
ret := m.ctrl.Call(m, "GetSendConnectionFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
// GetSendConnectionFlowControlWindow indicates an expected call of GetSendConnectionFlowControlWindow
func (mr *MockParamsNegotiatorMockRecorder) GetSendConnectionFlowControlWindow() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSendConnectionFlowControlWindow", reflect.TypeOf((*MockParamsNegotiator)(nil).GetSendConnectionFlowControlWindow))
}
// GetMaxOutgoingStreams mocks base method
func (m *MockParamsNegotiator) GetMaxOutgoingStreams() uint32 {
ret := m.ctrl.Call(m, "GetMaxOutgoingStreams")
ret0, _ := ret[0].(uint32)
return ret0
}
// GetMaxOutgoingStreams indicates an expected call of GetMaxOutgoingStreams
func (mr *MockParamsNegotiatorMockRecorder) GetMaxOutgoingStreams() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaxOutgoingStreams", reflect.TypeOf((*MockParamsNegotiator)(nil).GetMaxOutgoingStreams))
}
// GetRemoteIdleTimeout mocks base method
func (m *MockParamsNegotiator) GetRemoteIdleTimeout() time.Duration {
ret := m.ctrl.Call(m, "GetRemoteIdleTimeout")
ret0, _ := ret[0].(time.Duration)
return ret0
}
// GetRemoteIdleTimeout indicates an expected call of GetRemoteIdleTimeout
func (mr *MockParamsNegotiatorMockRecorder) GetRemoteIdleTimeout() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRemoteIdleTimeout", reflect.TypeOf((*MockParamsNegotiator)(nil).GetRemoteIdleTimeout))
}
// OmitConnectionID mocks base method
func (m *MockParamsNegotiator) OmitConnectionID() bool {
ret := m.ctrl.Call(m, "OmitConnectionID")
ret0, _ := ret[0].(bool)
return ret0
}
// OmitConnectionID indicates an expected call of OmitConnectionID
func (mr *MockParamsNegotiatorMockRecorder) OmitConnectionID() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OmitConnectionID", reflect.TypeOf((*MockParamsNegotiator)(nil).OmitConnectionID))
}

View file

@ -43,12 +43,6 @@ const MaxReceivePacketSize ByteCount = 1452
// Used in QUIC for congestion window computations in bytes.
const DefaultTCPMSS ByteCount = 1460
// InitialStreamFlowControlWindow is the initial stream-level flow control window for sending
const InitialStreamFlowControlWindow ByteCount = (1 << 14) // 16 kB
// InitialConnectionFlowControlWindow is the initial connection-level flow control window for sending
const InitialConnectionFlowControlWindow ByteCount = (1 << 14) // 16 kB
// ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have.
const ClientHelloMinimumSize = 1024

View file

@ -25,18 +25,17 @@ type packetPacker struct {
cryptoSetup handshake.CryptoSetup
packetNumberGenerator *packetNumberGenerator
connParams handshake.ParamsNegotiator
streamFramer *streamFramer
controlFrames []wire.Frame
stopWaiting *wire.StopWaitingFrame
ackFrame *wire.AckFrame
leastUnacked protocol.PacketNumber
omitConnectionID bool
}
func newPacketPacker(connectionID protocol.ConnectionID,
cryptoSetup handshake.CryptoSetup,
connParams handshake.ParamsNegotiator,
streamFramer *streamFramer,
perspective protocol.Perspective,
version protocol.VersionNumber,
@ -44,7 +43,6 @@ func newPacketPacker(connectionID protocol.ConnectionID,
return &packetPacker{
cryptoSetup: cryptoSetup,
connectionID: connectionID,
connParams: connParams,
perspective: perspective,
version: version,
streamFramer: streamFramer,
@ -271,9 +269,11 @@ func (p *packetPacker) getPublicHeader(encLevel protocol.EncryptionLevel) *wire.
ConnectionID: p.connectionID,
PacketNumber: pnum,
PacketNumberLen: packetNumberLen,
OmitConnectionID: p.connParams.OmitConnectionID(),
}
if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure {
publicHeader.OmitConnectionID = true
}
if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure {
publicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce()
}
@ -329,3 +329,7 @@ func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool {
func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) {
p.leastUnacked = leastUnacked
}
func (p *packetPacker) SetOmitConnectionID() {
p.omitConnectionID = true
}

View file

@ -7,7 +7,6 @@ import (
"github.com/lucas-clemente/quic-go/ackhandler"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
@ -61,19 +60,15 @@ var _ = Describe("Packet packer", func() {
)
BeforeEach(func() {
mockPn := mocks.NewMockParamsNegotiator(mockCtrl)
mockPn.EXPECT().OmitConnectionID().Return(false).AnyTimes()
cryptoStream = &stream{}
streamsMap := newStreamsMap(nil, nil, protocol.PerspectiveServer, nil)
streamsMap := newStreamsMap(nil, nil, protocol.PerspectiveServer)
streamsMap.streams[1] = cryptoStream
streamsMap.openStreams = []protocol.StreamID{1}
streamFramer = newStreamFramer(streamsMap, nil)
packer = &packetPacker{
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
connParams: mockPn,
connectionID: 0x1337,
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength),
streamFramer: streamFramer,
@ -234,6 +229,20 @@ var _ = Describe("Packet packer", func() {
Expect(p).ToNot(BeNil())
})
It("it omits the connection ID for forward-secure packets", func() {
ph := packer.getPublicHeader(protocol.EncryptionForwardSecure)
Expect(ph.OmitConnectionID).To(BeFalse())
packer.SetOmitConnectionID()
ph = packer.getPublicHeader(protocol.EncryptionForwardSecure)
Expect(ph.OmitConnectionID).To(BeTrue())
})
It("doesn't omit the connection ID for non-forware-secure packets", func() {
packer.SetOmitConnectionID()
ph := packer.getPublicHeader(protocol.EncryptionSecure)
Expect(ph.OmitConnectionID).To(BeFalse())
})
It("adds the version flag to the public header before the crypto handshake is finished", func() {
packer.perspective = protocol.PerspectiveClient
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure

View file

@ -88,6 +88,8 @@ type session struct {
undecryptablePackets []*receivedPacket
receivedTooManyUndecrytablePacketsTime time.Time
// this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them
paramsChan <-chan handshake.TransportParameters
// this channel is passed to the CryptoSetup and receives the current encryption level
// it is closed as soon as the handshake is complete
aeadChanged <-chan protocol.EncryptionLevel
@ -100,8 +102,6 @@ type session struct {
// it receives at most 3 handshake events: 2 when the encryption level changes, and one error
handshakeChan chan<- handshakeEvent
connParams handshake.ParamsNegotiator
lastRcvdPacketNumber protocol.PacketNumber
// Used to calculate the next packet number from the truncated wire
// representation, and sent back in public reset packets
@ -109,6 +109,7 @@ type session struct {
sessionCreationTime time.Time
lastNetworkActivityTime time.Time
remoteIdleTimeout time.Duration
timer *utils.Timer
// keepAlivePingSent stores whether a Ping frame was sent to the peer or not
@ -166,7 +167,9 @@ func (s *session) setup(
negotiatedVersions []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
aeadChanged := make(chan protocol.EncryptionLevel, 2)
paramsChan := make(chan handshake.TransportParameters)
s.aeadChanged = aeadChanged
s.paramsChan = paramsChan
handshakeChan := make(chan handshakeEvent, 3)
s.handshakeChan = handshakeChan
s.handshakeCompleteChan = make(chan error, 1)
@ -183,6 +186,9 @@ func (s *session) setup(
s.rttStats = &congestion.RTTStats{}
transportParams := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
MaxStreams: protocol.MaxIncomingStreams,
IdleTimeout: s.config.IdleTimeout,
}
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
@ -194,15 +200,16 @@ func (s *session) setup(
return s.config.AcceptCookie(clientAddr, cookie)
}
if s.version.UsesTLS() {
s.cryptoSetup, s.connParams, err = handshake.NewCryptoSetupTLSServer(
s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer(
tlsConf,
transportParams,
paramsChan,
aeadChanged,
s.config.Versions,
s.version,
)
} else {
s.cryptoSetup, s.connParams, err = newCryptoSetup(
s.cryptoSetup, err = newCryptoSetup(
s.connectionID,
s.conn.RemoteAddr(),
s.version,
@ -210,28 +217,31 @@ func (s *session) setup(
transportParams,
s.config.Versions,
verifySourceAddr,
paramsChan,
aeadChanged,
)
}
} else {
transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission
if s.version.UsesTLS() {
s.cryptoSetup, s.connParams, err = handshake.NewCryptoSetupTLSClient(
s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient(
hostname,
tlsConf,
transportParams,
paramsChan,
aeadChanged,
initialVersion,
s.config.Versions,
s.version,
)
} else {
transportParams.RequestConnectionIDOmission = s.config.RequestConnectionIDOmission
s.cryptoSetup, s.connParams, err = newCryptoSetupClient(
s.cryptoSetup, err = newCryptoSetupClient(
hostname,
s.connectionID,
s.version,
tlsConf,
transportParams,
paramsChan,
aeadChanged,
negotiatedVersions,
)
@ -242,16 +252,14 @@ func (s *session) setup(
}
s.flowControlManager = flowcontrol.NewFlowControlManager(
s.connParams,
protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow),
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
s.rttStats,
)
s.streamsMap = newStreamsMap(s.newStream, s.flowControlManager.RemoveStream, s.perspective, s.connParams)
s.streamsMap = newStreamsMap(s.newStream, s.flowControlManager.RemoveStream, s.perspective)
s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager)
s.packer = newPacketPacker(s.connectionID,
s.cryptoSetup,
s.connParams,
s.streamFramer,
s.perspective,
s.version,
@ -318,6 +326,8 @@ runLoop:
// This is a bit unclean, but works properly, since the packet always
// begins with the public header and we never copy it.
putPacketBuffer(p.publicHeader.Raw)
case p := <-s.paramsChan:
s.processTransportParameters(&p)
case l, ok := <-aeadChanged:
if !ok { // the aeadChanged chan was closed. This means that the handshake is completed.
s.handshakeComplete = true
@ -338,7 +348,7 @@ runLoop:
s.sentPacketHandler.OnAlarm()
}
if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.connParams.GetRemoteIdleTimeout()/2 {
if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.remoteIdleTimeout/2 {
// send the PING frame since there is no activity in the session
s.packer.QueueControlFrame(&wire.PingFrame{})
s.keepAlivePingSent = true
@ -379,7 +389,7 @@ func (s *session) Context() context.Context {
func (s *session) maybeResetTimer() {
var deadline time.Time
if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent {
deadline = s.lastNetworkActivityTime.Add(s.connParams.GetRemoteIdleTimeout() / 2)
deadline = s.lastNetworkActivityTime.Add(s.remoteIdleTimeout / 2)
} else {
deadline = s.lastNetworkActivityTime.Add(s.config.IdleTimeout)
}
@ -613,6 +623,15 @@ func (s *session) handleCloseError(closeErr closeError) error {
return s.sendConnectionClose(quicErr)
}
func (s *session) processTransportParameters(params *handshake.TransportParameters) {
s.remoteIdleTimeout = params.IdleTimeout
s.flowControlManager.UpdateTransportParameters(params)
s.streamsMap.UpdateMaxStreamLimit(params.MaxStreams)
if params.OmitConnectionID {
s.packer.SetOmitConnectionID()
}
}
func (s *session) sendPacket() error {
s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked())

View file

@ -18,7 +18,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/wire"
@ -142,20 +141,6 @@ func areSessionsRunning() bool {
return strings.Contains(b.String(), "quic-go.(*session).run")
}
type mockParamsNegotiator struct{}
var _ handshake.ParamsNegotiator = &mockParamsNegotiator{}
func (m *mockParamsNegotiator) GetSendStreamFlowControlWindow() protocol.ByteCount {
return protocol.InitialStreamFlowControlWindow
}
func (m *mockParamsNegotiator) GetSendConnectionFlowControlWindow() protocol.ByteCount {
return protocol.InitialConnectionFlowControlWindow
}
func (m *mockParamsNegotiator) GetMaxOutgoingStreams() uint32 { return 100 }
func (m *mockParamsNegotiator) GetRemoteIdleTimeout() time.Duration { return time.Hour }
func (m *mockParamsNegotiator) OmitConnectionID() bool { return false }
var _ = Describe("Session", func() {
var (
sess *session
@ -178,10 +163,11 @@ var _ = Describe("Session", func() {
_ *handshake.TransportParameters,
_ []protocol.VersionNumber,
_ func(net.Addr, *Cookie) bool,
_ chan<- handshake.TransportParameters,
aeadChangedP chan<- protocol.EncryptionLevel,
) (handshake.CryptoSetup, handshake.ParamsNegotiator, error) {
) (handshake.CryptoSetup, error) {
aeadChanged = aeadChangedP
return cryptoSetup, &mockParamsNegotiator{}, nil
return cryptoSetup, nil
}
mconn = newMockConnection()
@ -202,8 +188,6 @@ var _ = Describe("Session", func() {
Expect(err).NotTo(HaveOccurred())
sess = pSess.(*session)
Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // 1 stream: the crypto stream
sess.connParams = &mockParamsNegotiator{}
})
AfterEach(func() {
@ -228,10 +212,11 @@ var _ = Describe("Session", func() {
_ *handshake.TransportParameters,
_ []protocol.VersionNumber,
cookieFunc func(net.Addr, *Cookie) bool,
_ chan<- handshake.TransportParameters,
_ chan<- protocol.EncryptionLevel,
) (handshake.CryptoSetup, handshake.ParamsNegotiator, error) {
) (handshake.CryptoSetup, error) {
cookieVerify = cookieFunc
return cryptoSetup, &mockParamsNegotiator{}, nil
return cryptoSetup, nil
}
conf := populateServerConfig(&Config{})
@ -270,6 +255,10 @@ var _ = Describe("Session", func() {
})
Context("when handling stream frames", func() {
BeforeEach(func() {
sess.streamsMap.UpdateMaxStreamLimit(100)
})
It("makes new streams", func() {
sess.handleStreamFrame(&wire.StreamFrame{
StreamID: 5,
@ -464,7 +453,7 @@ var _ = Describe("Session", func() {
It("passes the byte offset to the flow controller", func() {
sess.streamsMap.GetOrOpenStream(5)
fcm := mocks_fc.NewMockFlowControlManager(mockCtrl)
fcm := mocks.NewMockFlowControlManager(mockCtrl)
sess.flowControlManager = fcm
fcm.EXPECT().ResetStream(protocol.StreamID(5), protocol.ByteCount(0x1337))
err := sess.handleRstStreamFrame(&wire.RstStreamFrame{
@ -477,7 +466,7 @@ var _ = Describe("Session", func() {
It("returns errors from the flow controller", func() {
testErr := errors.New("flow control violation")
sess.streamsMap.GetOrOpenStream(5)
fcm := mocks_fc.NewMockFlowControlManager(mockCtrl)
fcm := mocks.NewMockFlowControlManager(mockCtrl)
sess.flowControlManager = fcm
fcm.EXPECT().ResetStream(protocol.StreamID(5), protocol.ByteCount(0x1337)).Return(testErr)
err := sess.handleRstStreamFrame(&wire.RstStreamFrame{
@ -525,6 +514,10 @@ var _ = Describe("Session", func() {
})
Context("handling WINDOW_UPDATE frames", func() {
BeforeEach(func() {
sess.flowControlManager.UpdateTransportParameters(&handshake.TransportParameters{ConnectionFlowControlWindow: 0x1000})
})
It("updates the Flow Control Window of a stream", func() {
_, err := sess.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
@ -1093,7 +1086,7 @@ var _ = Describe("Session", func() {
It("retransmits a WindowUpdate if it hasn't already sent a WindowUpdate with a higher ByteOffset", func() {
_, err := sess.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
fcm := mocks_fc.NewMockFlowControlManager(mockCtrl)
fcm := mocks.NewMockFlowControlManager(mockCtrl)
sess.flowControlManager = fcm
fcm.EXPECT().GetWindowUpdates()
fcm.EXPECT().GetReceiveWindow(protocol.StreamID(5)).Return(protocol.ByteCount(0x1000), nil)
@ -1114,7 +1107,7 @@ var _ = Describe("Session", func() {
It("doesn't retransmit WindowUpdates if it already sent a WindowUpdate with a higher ByteOffset", func() {
_, err := sess.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
fcm := mocks_fc.NewMockFlowControlManager(mockCtrl)
fcm := mocks.NewMockFlowControlManager(mockCtrl)
sess.flowControlManager = fcm
fcm.EXPECT().GetWindowUpdates()
fcm.EXPECT().GetReceiveWindow(protocol.StreamID(5)).Return(protocol.ByteCount(0x2000), nil)
@ -1140,7 +1133,7 @@ var _ = Describe("Session", func() {
err = sess.streamsMap.DeleteClosedStreams()
Expect(err).ToNot(HaveOccurred())
_, err = sess.flowControlManager.SendWindowSize(5)
Expect(err).To(MatchError("Error accessing the flowController map."))
Expect(err).To(MatchError("Error accessing the flowController map"))
sph.retransmissionQueue = []*ackhandler.Packet{{
Frames: []wire.Frame{&wire.WindowUpdateFrame{
StreamID: 5,
@ -1183,6 +1176,11 @@ var _ = Describe("Session", func() {
Context("scheduling sending", func() {
BeforeEach(func() {
sess.processTransportParameters(&handshake.TransportParameters{
StreamFlowControlWindow: protocol.MaxByteCount,
ConnectionFlowControlWindow: protocol.MaxByteCount,
MaxStreams: 1000,
})
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
})
@ -1420,16 +1418,33 @@ var _ = Describe("Session", func() {
close(done)
})
It("process transport parameters received from the peer", func() {
paramsChan := make(chan handshake.TransportParameters)
sess.paramsChan = paramsChan
_, err := sess.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
go sess.run()
paramsChan <- handshake.TransportParameters{
MaxStreams: 123,
IdleTimeout: 90 * time.Second,
StreamFlowControlWindow: 0x5000,
ConnectionFlowControlWindow: 0x5000,
OmitConnectionID: true,
}
Eventually(func() time.Duration { return sess.remoteIdleTimeout }).Should(Equal(90 * time.Second))
Eventually(func() uint32 { return sess.streamsMap.maxOutgoingStreams }).Should(Equal(uint32(123)))
Eventually(func() (protocol.ByteCount, error) { return sess.flowControlManager.SendWindowSize(5) }).Should(Equal(protocol.ByteCount(0x5000)))
Eventually(func() bool { return sess.packer.omitConnectionID }).Should(BeTrue())
Expect(sess.Close(nil)).To(Succeed())
})
Context("keep-alives", func() {
var mockPn *mocks.MockParamsNegotiator
// should be shorter than the local timeout for these tests
// otherwise we'd send a CONNECTION_CLOSE in the tests where we're testing that no PING is sent
remoteIdleTimeout := 20 * time.Second
BeforeEach(func() {
mockPn = mocks.NewMockParamsNegotiator(mockCtrl)
mockPn.EXPECT().GetRemoteIdleTimeout().Return(remoteIdleTimeout).AnyTimes()
sess.connParams = mockPn
sess.remoteIdleTimeout = remoteIdleTimeout
})
It("sends a PING", func() {
@ -1523,6 +1538,10 @@ var _ = Describe("Session", func() {
}, 0.5)
Context("getting streams", func() {
BeforeEach(func() {
sess.processTransportParameters(&handshake.TransportParameters{MaxStreams: 1000})
})
It("returns a new stream", func() {
str, err := sess.GetOrOpenStream(11)
Expect(err).ToNot(HaveOccurred())
@ -1653,11 +1672,12 @@ var _ = Describe("Client Session", func() {
_ protocol.VersionNumber,
_ *tls.Config,
_ *handshake.TransportParameters,
_ chan<- handshake.TransportParameters,
aeadChangedP chan<- protocol.EncryptionLevel,
_ []protocol.VersionNumber,
) (handshake.CryptoSetup, handshake.ParamsNegotiator, error) {
) (handshake.CryptoSetup, error) {
aeadChanged = aeadChangedP
return cryptoSetup, &mockParamsNegotiator{}, nil
return cryptoSetup, nil
}
mconn = newMockConnection()

View file

@ -3,7 +3,7 @@ package quic
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
@ -21,7 +21,7 @@ var _ = Describe("Stream Framer", func() {
framer *streamFramer
streamsMap *streamsMap
stream1, stream2 *stream
mockFcm *mocks_fc.MockFlowControlManager
mockFcm *mocks.MockFlowControlManager
)
BeforeEach(func() {
@ -37,11 +37,11 @@ var _ = Describe("Stream Framer", func() {
stream1 = &stream{streamID: id1}
stream2 = &stream{streamID: id2}
streamsMap = newStreamsMap(nil, nil, protocol.PerspectiveServer, nil)
streamsMap = newStreamsMap(nil, nil, protocol.PerspectiveServer)
streamsMap.putStream(stream1)
streamsMap.putStream(stream2)
mockFcm = mocks_fc.NewMockFlowControlManager(mockCtrl)
mockFcm = mocks.NewMockFlowControlManager(mockCtrl)
framer = newStreamFramer(streamsMap, mockFcm)
})

View file

@ -9,7 +9,7 @@ import (
"os"
"github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
@ -30,7 +30,7 @@ var _ = Describe("Stream", func() {
resetCalledForStream protocol.StreamID
resetCalledAtOffset protocol.ByteCount
mockFcm *mocks_fc.MockFlowControlManager
mockFcm *mocks.MockFlowControlManager
)
// in the tests for the stream deadlines we set a deadline
@ -58,7 +58,7 @@ var _ = Describe("Stream", func() {
BeforeEach(func() {
onDataCalled = false
resetCalled = false
mockFcm = mocks_fc.NewMockFlowControlManager(mockCtrl)
mockFcm = mocks.NewMockFlowControlManager(mockCtrl)
str = newStream(streamID, onData, onReset, mockFcm)
timeout := scaleDuration(250 * time.Millisecond)

View file

@ -5,7 +5,6 @@ import (
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
@ -14,7 +13,6 @@ import (
type streamsMap struct {
mutex sync.RWMutex
connParams handshake.ParamsNegotiator
perspective protocol.Perspective
streams map[protocol.StreamID]*stream
@ -36,6 +34,7 @@ type streamsMap struct {
numOutgoingStreams uint32
numIncomingStreams uint32
maxIncomingStreams uint32
maxOutgoingStreams uint32
}
type streamLambda func(*stream) (bool, error)
@ -44,7 +43,7 @@ type newStreamLambda func(protocol.StreamID) *stream
var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamCallback, pers protocol.Perspective, connParams handshake.ParamsNegotiator) *streamsMap {
func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamCallback, pers protocol.Perspective) *streamsMap {
// add some tolerance to the maximum incoming streams value
maxStreams := uint32(protocol.MaxIncomingStreams)
maxIncomingStreams := utils.MaxUint32(
@ -57,7 +56,6 @@ func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamC
openStreams: make([]protocol.StreamID, 0),
newStream: newStream,
removeStreamCallback: removeStreamCallback,
connParams: connParams,
maxIncomingStreams: maxIncomingStreams,
}
sm.nextStreamOrErrCond.L = &sm.mutex
@ -66,6 +64,8 @@ func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamC
if pers == protocol.PerspectiveClient {
sm.nextStream = 1
sm.nextStreamToAccept = 2
// TODO: find a better solution for opening the crypto stream
sm.maxOutgoingStreams = 1 // allow the crypto stream
} else {
sm.nextStream = 2
sm.nextStreamToAccept = 1
@ -159,7 +159,7 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) {
func (m *streamsMap) openStreamImpl() (*stream, error) {
id := m.nextStream
if m.numOutgoingStreams >= m.connParams.GetMaxOutgoingStreams() {
if m.numOutgoingStreams >= m.maxOutgoingStreams {
return nil, qerr.TooManyOpenStreams
}
@ -340,3 +340,9 @@ func (m *streamsMap) CloseWithError(err error) {
m.streams[s].Cancel(err)
}
}
func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.maxOutgoingStreams = limit
}

View file

@ -3,7 +3,6 @@ package quic
import (
"errors"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
@ -11,22 +10,16 @@ import (
)
var _ = Describe("Streams Map", func() {
const maxOutgoingStreams = 60
var (
m *streamsMap
mockPn *mocks.MockParamsNegotiator
)
setNewStreamsMap := func(p protocol.Perspective) {
mockPn = mocks.NewMockParamsNegotiator(mockCtrl)
mockPn.EXPECT().GetMaxOutgoingStreams().AnyTimes().Return(uint32(maxOutgoingStreams))
newStream := func(id protocol.StreamID) *stream {
return newStream(id, func() {}, nil, nil)
}
removeStreamCallback := func(protocol.StreamID) {}
m = newStreamsMap(newStream, removeStreamCallback, p, mockPn)
m = newStreamsMap(newStream, removeStreamCallback, p)
}
AfterEach(func() {
@ -132,7 +125,13 @@ var _ = Describe("Streams Map", func() {
})
Context("server-side streams", func() {
It("doesn't allow opening streams before receiving the transport parameters", func() {
_, err := m.OpenStream()
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
})
It("opens a stream 2 first", func() {
m.UpdateMaxStreamLimit(100)
s, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
@ -149,6 +148,7 @@ var _ = Describe("Streams Map", func() {
})
It("doesn't reopen an already closed stream", func() {
m.UpdateMaxStreamLimit(100)
str, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(str.StreamID()).To(Equal(protocol.StreamID(2)))
@ -160,6 +160,12 @@ var _ = Describe("Streams Map", func() {
})
Context("counting streams", func() {
const maxOutgoingStreams = 50
BeforeEach(func() {
m.UpdateMaxStreamLimit(maxOutgoingStreams)
})
It("errors when too many streams are opened", func() {
for i := 1; i <= maxOutgoingStreams; i++ {
_, err := m.OpenStream()
@ -190,6 +196,12 @@ var _ = Describe("Streams Map", func() {
})
Context("opening streams synchronously", func() {
const maxOutgoingStreams = 10
BeforeEach(func() {
m.UpdateMaxStreamLimit(maxOutgoingStreams)
})
openMaxNumStreams := func() {
for i := 1; i <= maxOutgoingStreams; i++ {
_, err := m.OpenStream()