mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
remove the params negotiator
This commit is contained in:
parent
925a52f032
commit
f3e9bf4332
37 changed files with 1013 additions and 1296 deletions
|
@ -13,32 +13,31 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type flowControlManager struct {
|
type flowControlManager struct {
|
||||||
connParams handshake.ParamsNegotiator
|
|
||||||
rttStats *congestion.RTTStats
|
rttStats *congestion.RTTStats
|
||||||
maxReceiveStreamWindow protocol.ByteCount
|
maxReceiveStreamWindow protocol.ByteCount
|
||||||
|
|
||||||
streamFlowController map[protocol.StreamID]*flowController
|
streamFlowController map[protocol.StreamID]*flowController
|
||||||
connFlowController *flowController
|
connFlowController *flowController
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
|
initialStreamSendWindow protocol.ByteCount
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ FlowControlManager = &flowControlManager{}
|
var _ FlowControlManager = &flowControlManager{}
|
||||||
|
|
||||||
var errMapAccess = errors.New("Error accessing the flowController map.")
|
var errMapAccess = errors.New("Error accessing the flowController map")
|
||||||
|
|
||||||
// NewFlowControlManager creates a new flow control manager
|
// NewFlowControlManager creates a new flow control manager
|
||||||
func NewFlowControlManager(
|
func NewFlowControlManager(
|
||||||
connParams handshake.ParamsNegotiator,
|
|
||||||
maxReceiveStreamWindow protocol.ByteCount,
|
maxReceiveStreamWindow protocol.ByteCount,
|
||||||
maxReceiveConnectionWindow protocol.ByteCount,
|
maxReceiveConnectionWindow protocol.ByteCount,
|
||||||
rttStats *congestion.RTTStats,
|
rttStats *congestion.RTTStats,
|
||||||
) FlowControlManager {
|
) FlowControlManager {
|
||||||
return &flowControlManager{
|
return &flowControlManager{
|
||||||
connParams: connParams,
|
|
||||||
rttStats: rttStats,
|
rttStats: rttStats,
|
||||||
maxReceiveStreamWindow: maxReceiveStreamWindow,
|
maxReceiveStreamWindow: maxReceiveStreamWindow,
|
||||||
streamFlowController: make(map[protocol.StreamID]*flowController),
|
streamFlowController: make(map[protocol.StreamID]*flowController),
|
||||||
connFlowController: newFlowController(0, false, connParams, protocol.ReceiveConnectionFlowControlWindow, maxReceiveConnectionWindow, rttStats),
|
connFlowController: newFlowController(0, false, protocol.ReceiveConnectionFlowControlWindow, maxReceiveConnectionWindow, 0, rttStats),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,7 +50,7 @@ func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesTo
|
||||||
if _, ok := f.streamFlowController[streamID]; ok {
|
if _, ok := f.streamFlowController[streamID]; ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connParams, protocol.ReceiveStreamFlowControlWindow, f.maxReceiveStreamWindow, f.rttStats)
|
f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, protocol.ReceiveStreamFlowControlWindow, f.maxReceiveStreamWindow, f.initialStreamSendWindow, f.rttStats)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveStream removes a closed stream from flow control
|
// RemoveStream removes a closed stream from flow control
|
||||||
|
@ -61,6 +60,17 @@ func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) {
|
||||||
f.mutex.Unlock()
|
f.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *flowControlManager) UpdateTransportParameters(params *handshake.TransportParameters) {
|
||||||
|
f.mutex.Lock()
|
||||||
|
defer f.mutex.Unlock()
|
||||||
|
|
||||||
|
f.connFlowController.UpdateSendWindow(params.ConnectionFlowControlWindow)
|
||||||
|
f.initialStreamSendWindow = params.StreamFlowControlWindow
|
||||||
|
for _, fc := range f.streamFlowController {
|
||||||
|
fc.UpdateSendWindow(params.StreamFlowControlWindow)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ResetStream should be called when receiving a RstStreamFrame
|
// ResetStream should be called when receiving a RstStreamFrame
|
||||||
// it updates the byte offset to the value in the RstStreamFrame
|
// it updates the byte offset to the value in the RstStreamFrame
|
||||||
// streamID must not be 0 here
|
// streamID must not be 0 here
|
||||||
|
@ -233,7 +243,6 @@ func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset pro
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return fc.UpdateSendWindow(offset), nil
|
return fc.UpdateSendWindow(offset), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,9 @@ package flowcontrol
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/congestion"
|
"github.com/lucas-clemente/quic-go/congestion"
|
||||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
|
@ -15,13 +16,18 @@ var _ = Describe("Flow Control Manager", func() {
|
||||||
var fcm *flowControlManager
|
var fcm *flowControlManager
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
mockPn := mocks.NewMockParamsNegotiator(mockCtrl)
|
fcm = NewFlowControlManager(
|
||||||
fcm = NewFlowControlManager(mockPn, protocol.MaxByteCount, protocol.MaxByteCount, &congestion.RTTStats{}).(*flowControlManager)
|
0x2000, // maxReceiveStreamWindow
|
||||||
|
0x4000, // maxReceiveConnectionWindow
|
||||||
|
&congestion.RTTStats{},
|
||||||
|
).(*flowControlManager)
|
||||||
})
|
})
|
||||||
|
|
||||||
It("creates a connection level flow controller", func() {
|
It("creates a connection level flow controller", func() {
|
||||||
Expect(fcm.streamFlowController).ToNot(HaveKey(protocol.StreamID(0)))
|
Expect(fcm.streamFlowController).To(BeEmpty())
|
||||||
Expect(fcm.connFlowController.ContributesToConnection()).To(BeFalse())
|
Expect(fcm.connFlowController.ContributesToConnection()).To(BeFalse())
|
||||||
|
Expect(fcm.connFlowController.sendWindow).To(BeZero())
|
||||||
|
Expect(fcm.connFlowController.maxReceiveWindowIncrement).To(Equal(protocol.ByteCount(0x4000)))
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("creating new streams", func() {
|
Context("creating new streams", func() {
|
||||||
|
@ -31,6 +37,19 @@ var _ = Describe("Flow Control Manager", func() {
|
||||||
fc := fcm.streamFlowController[5]
|
fc := fcm.streamFlowController[5]
|
||||||
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
|
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
|
||||||
Expect(fc.ContributesToConnection()).To(BeFalse())
|
Expect(fc.ContributesToConnection()).To(BeFalse())
|
||||||
|
// the transport parameters have not yet been received. Start with a window of size 0
|
||||||
|
Expect(fc.sendWindow).To(BeZero())
|
||||||
|
Expect(fc.maxReceiveWindowIncrement).To(Equal(protocol.ByteCount(0x2000)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("creates a new stream after it has received transport parameters", func() {
|
||||||
|
fcm.UpdateTransportParameters(&handshake.TransportParameters{
|
||||||
|
StreamFlowControlWindow: 0x3000,
|
||||||
|
})
|
||||||
|
fcm.NewStream(5, false)
|
||||||
|
Expect(fcm.streamFlowController).To(HaveKey(protocol.StreamID(5)))
|
||||||
|
fc := fcm.streamFlowController[5]
|
||||||
|
Expect(fc.sendWindow).To(Equal(protocol.ByteCount(0x3000)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't create a new flow controller if called for an existing stream", func() {
|
It("doesn't create a new flow controller if called for an existing stream", func() {
|
||||||
|
@ -51,6 +70,16 @@ var _ = Describe("Flow Control Manager", func() {
|
||||||
Expect(fcm.streamFlowController).ToNot(HaveKey(protocol.StreamID(5)))
|
Expect(fcm.streamFlowController).ToNot(HaveKey(protocol.StreamID(5)))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("updates the send windows for existing streams when receiveing the transport parameters", func() {
|
||||||
|
fcm.NewStream(5, false)
|
||||||
|
fcm.UpdateTransportParameters(&handshake.TransportParameters{
|
||||||
|
StreamFlowControlWindow: 0x3000,
|
||||||
|
ConnectionFlowControlWindow: 0x6000,
|
||||||
|
})
|
||||||
|
Expect(fcm.connFlowController.sendWindow).To(Equal(protocol.ByteCount(0x6000)))
|
||||||
|
Expect(fcm.streamFlowController[5].sendWindow).To(Equal(protocol.ByteCount(0x3000)))
|
||||||
|
})
|
||||||
|
|
||||||
Context("receiving data", func() {
|
Context("receiving data", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
fcm.NewStream(1, false)
|
fcm.NewStream(1, false)
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/congestion"
|
"github.com/lucas-clemente/quic-go/congestion"
|
||||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
@ -14,8 +13,7 @@ type flowController struct {
|
||||||
streamID protocol.StreamID
|
streamID protocol.StreamID
|
||||||
contributesToConnection bool // does the stream contribute to connection level flow control
|
contributesToConnection bool // does the stream contribute to connection level flow control
|
||||||
|
|
||||||
connParams handshake.ParamsNegotiator
|
rttStats *congestion.RTTStats
|
||||||
rttStats *congestion.RTTStats
|
|
||||||
|
|
||||||
bytesSent protocol.ByteCount
|
bytesSent protocol.ByteCount
|
||||||
sendWindow protocol.ByteCount
|
sendWindow protocol.ByteCount
|
||||||
|
@ -36,19 +34,19 @@ var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset")
|
||||||
func newFlowController(
|
func newFlowController(
|
||||||
streamID protocol.StreamID,
|
streamID protocol.StreamID,
|
||||||
contributesToConnection bool,
|
contributesToConnection bool,
|
||||||
connParams handshake.ParamsNegotiator,
|
|
||||||
receiveWindow protocol.ByteCount,
|
receiveWindow protocol.ByteCount,
|
||||||
maxReceiveWindow protocol.ByteCount,
|
maxReceiveWindow protocol.ByteCount,
|
||||||
|
initialSendWindow protocol.ByteCount,
|
||||||
rttStats *congestion.RTTStats,
|
rttStats *congestion.RTTStats,
|
||||||
) *flowController {
|
) *flowController {
|
||||||
return &flowController{
|
return &flowController{
|
||||||
streamID: streamID,
|
streamID: streamID,
|
||||||
contributesToConnection: contributesToConnection,
|
contributesToConnection: contributesToConnection,
|
||||||
connParams: connParams,
|
|
||||||
rttStats: rttStats,
|
rttStats: rttStats,
|
||||||
receiveWindow: receiveWindow,
|
receiveWindow: receiveWindow,
|
||||||
receiveWindowIncrement: receiveWindow,
|
receiveWindowIncrement: receiveWindow,
|
||||||
maxReceiveWindowIncrement: maxReceiveWindow,
|
maxReceiveWindowIncrement: maxReceiveWindow,
|
||||||
|
sendWindow: initialSendWindow,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,16 +54,6 @@ func (c *flowController) ContributesToConnection() bool {
|
||||||
return c.contributesToConnection
|
return c.contributesToConnection
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *flowController) getSendWindow() protocol.ByteCount {
|
|
||||||
if c.sendWindow == 0 {
|
|
||||||
if c.streamID == 0 {
|
|
||||||
return c.connParams.GetSendConnectionFlowControlWindow()
|
|
||||||
}
|
|
||||||
return c.connParams.GetSendStreamFlowControlWindow()
|
|
||||||
}
|
|
||||||
return c.sendWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) AddBytesSent(n protocol.ByteCount) {
|
func (c *flowController) AddBytesSent(n protocol.ByteCount) {
|
||||||
c.bytesSent += n
|
c.bytesSent += n
|
||||||
}
|
}
|
||||||
|
@ -81,16 +69,11 @@ func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *flowController) SendWindowSize() protocol.ByteCount {
|
func (c *flowController) SendWindowSize() protocol.ByteCount {
|
||||||
sendWindow := c.getSendWindow()
|
// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
|
||||||
|
if c.bytesSent > c.sendWindow {
|
||||||
if c.bytesSent > sendWindow { // should never happen, but make sure we don't do an underflow here
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
return sendWindow - c.bytesSent
|
return c.sendWindow - c.bytesSent
|
||||||
}
|
|
||||||
|
|
||||||
func (c *flowController) SendWindowOffset() protocol.ByteCount {
|
|
||||||
return c.getSendWindow()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
|
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/congestion"
|
"github.com/lucas-clemente/quic-go/congestion"
|
||||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
@ -19,61 +18,28 @@ var _ = Describe("Flow controller", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("Constructor", func() {
|
Context("Constructor", func() {
|
||||||
var rttStats *congestion.RTTStats
|
rttStats := &congestion.RTTStats{}
|
||||||
var mockPn *mocks.MockParamsNegotiator
|
|
||||||
|
|
||||||
receiveStreamWindow := protocol.ByteCount(2000)
|
It("sets the send and receive windows", func() {
|
||||||
receiveConnectionWindow := protocol.ByteCount(4000)
|
receiveWindow := protocol.ByteCount(2000)
|
||||||
maxReceiveStreamWindow := protocol.ByteCount(8000)
|
maxReceiveWindow := protocol.ByteCount(3000)
|
||||||
maxReceiveConnectionWindow := protocol.ByteCount(9000)
|
sendWindow := protocol.ByteCount(4000)
|
||||||
|
fc := newFlowController(5, true, receiveWindow, maxReceiveWindow, sendWindow, rttStats)
|
||||||
BeforeEach(func() {
|
|
||||||
mockPn = mocks.NewMockParamsNegotiator(mockCtrl)
|
|
||||||
mockPn.EXPECT().GetSendStreamFlowControlWindow().AnyTimes().Return(protocol.ByteCount(1000))
|
|
||||||
mockPn.EXPECT().GetSendConnectionFlowControlWindow().AnyTimes().Return(protocol.ByteCount(3000))
|
|
||||||
rttStats = &congestion.RTTStats{}
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads the stream send and receive windows when acting as stream-level flow controller", func() {
|
|
||||||
fc := newFlowController(5, true, mockPn, receiveStreamWindow, maxReceiveStreamWindow, rttStats)
|
|
||||||
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
|
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
|
||||||
Expect(fc.receiveWindow).To(Equal(receiveStreamWindow))
|
Expect(fc.receiveWindow).To(Equal(receiveWindow))
|
||||||
Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveStreamWindow))
|
Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveWindow))
|
||||||
})
|
Expect(fc.sendWindow).To(Equal(sendWindow))
|
||||||
|
|
||||||
It("reads the stream send and receive windows when acting as connection-level flow controller", func() {
|
|
||||||
fc := newFlowController(0, false, mockPn, receiveConnectionWindow, maxReceiveConnectionWindow, rttStats)
|
|
||||||
Expect(fc.streamID).To(Equal(protocol.StreamID(0)))
|
|
||||||
Expect(fc.receiveWindow).To(Equal(receiveConnectionWindow))
|
|
||||||
Expect(fc.maxReceiveWindowIncrement).To(Equal(maxReceiveConnectionWindow))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not set the stream flow control windows for sending", func() {
|
|
||||||
fc := newFlowController(5, true, mockPn, protocol.MaxByteCount, protocol.MaxByteCount, rttStats)
|
|
||||||
Expect(fc.sendWindow).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not set the connection flow control windows for sending", func() {
|
|
||||||
fc := newFlowController(0, false, mockPn, protocol.MaxByteCount, protocol.MaxByteCount, rttStats)
|
|
||||||
Expect(fc.sendWindow).To(BeZero())
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("says if it contributes to connection-level flow control", func() {
|
It("says if it contributes to connection-level flow control", func() {
|
||||||
fc := newFlowController(1, false, mockPn, protocol.MaxByteCount, protocol.MaxByteCount, rttStats)
|
fc := newFlowController(1, false, protocol.MaxByteCount, protocol.MaxByteCount, protocol.MaxByteCount, rttStats)
|
||||||
Expect(fc.ContributesToConnection()).To(BeFalse())
|
Expect(fc.ContributesToConnection()).To(BeFalse())
|
||||||
fc = newFlowController(5, true, mockPn, protocol.MaxByteCount, protocol.MaxByteCount, rttStats)
|
fc = newFlowController(5, true, protocol.MaxByteCount, protocol.MaxByteCount, protocol.MaxByteCount, rttStats)
|
||||||
Expect(fc.ContributesToConnection()).To(BeTrue())
|
Expect(fc.ContributesToConnection()).To(BeTrue())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("send flow control", func() {
|
Context("send flow control", func() {
|
||||||
var mockPn *mocks.MockParamsNegotiator
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
mockPn = mocks.NewMockParamsNegotiator(mockCtrl)
|
|
||||||
controller.connParams = mockPn
|
|
||||||
})
|
|
||||||
|
|
||||||
It("adds bytes sent", func() {
|
It("adds bytes sent", func() {
|
||||||
controller.bytesSent = 5
|
controller.bytesSent = 5
|
||||||
controller.AddBytesSent(6)
|
controller.AddBytesSent(6)
|
||||||
|
@ -89,14 +55,14 @@ var _ = Describe("Flow controller", func() {
|
||||||
It("gets the offset of the flow control window", func() {
|
It("gets the offset of the flow control window", func() {
|
||||||
controller.bytesSent = 5
|
controller.bytesSent = 5
|
||||||
controller.sendWindow = 12
|
controller.sendWindow = 12
|
||||||
Expect(controller.SendWindowOffset()).To(Equal(protocol.ByteCount(12)))
|
Expect(controller.sendWindow).To(Equal(protocol.ByteCount(12)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("updates the size of the flow control window", func() {
|
It("updates the size of the flow control window", func() {
|
||||||
controller.bytesSent = 5
|
controller.bytesSent = 5
|
||||||
updateSuccessful := controller.UpdateSendWindow(15)
|
updateSuccessful := controller.UpdateSendWindow(15)
|
||||||
Expect(updateSuccessful).To(BeTrue())
|
Expect(updateSuccessful).To(BeTrue())
|
||||||
Expect(controller.SendWindowOffset()).To(Equal(protocol.ByteCount(15)))
|
Expect(controller.sendWindow).To(Equal(protocol.ByteCount(15)))
|
||||||
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(15 - 5)))
|
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(15 - 5)))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -108,36 +74,6 @@ var _ = Describe("Flow controller", func() {
|
||||||
Expect(updateSuccessful).To(BeFalse())
|
Expect(updateSuccessful).To(BeFalse())
|
||||||
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(20)))
|
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(20)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("asks the ConnectionParametersManager for the stream flow control window size", func() {
|
|
||||||
controller.streamID = 5
|
|
||||||
mockPn.EXPECT().GetSendStreamFlowControlWindow().Return(protocol.ByteCount(1000))
|
|
||||||
Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(1000)))
|
|
||||||
// make sure the value is not cached
|
|
||||||
mockPn.EXPECT().GetSendStreamFlowControlWindow().Return(protocol.ByteCount(2000))
|
|
||||||
Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(2000)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("stops asking the ConnectionParametersManager for the flow control stream window size once a window update has arrived", func() {
|
|
||||||
controller.streamID = 5
|
|
||||||
Expect(controller.UpdateSendWindow(8000))
|
|
||||||
Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(8000)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("asks the ConnectionParametersManager for the connection flow control window size", func() {
|
|
||||||
controller.streamID = 0
|
|
||||||
mockPn.EXPECT().GetSendConnectionFlowControlWindow().Return(protocol.ByteCount(3000))
|
|
||||||
Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(3000)))
|
|
||||||
// make sure the value is not cached
|
|
||||||
mockPn.EXPECT().GetSendConnectionFlowControlWindow().Return(protocol.ByteCount(5000))
|
|
||||||
Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(5000)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("stops asking the ConnectionParametersManager for the connection flow control window size once a window update has arrived", func() {
|
|
||||||
controller.streamID = 0
|
|
||||||
Expect(controller.UpdateSendWindow(7000))
|
|
||||||
Expect(controller.getSendWindow()).To(Equal(protocol.ByteCount(7000)))
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("receive flow control", func() {
|
Context("receive flow control", func() {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package flowcontrol
|
package flowcontrol
|
||||||
|
|
||||||
import "github.com/lucas-clemente/quic-go/internal/protocol"
|
import "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
import "github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
|
|
||||||
// WindowUpdate provides the data for WindowUpdateFrames.
|
// WindowUpdate provides the data for WindowUpdateFrames.
|
||||||
type WindowUpdate struct {
|
type WindowUpdate struct {
|
||||||
|
@ -12,6 +13,7 @@ type WindowUpdate struct {
|
||||||
type FlowControlManager interface {
|
type FlowControlManager interface {
|
||||||
NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool)
|
NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool)
|
||||||
RemoveStream(streamID protocol.StreamID)
|
RemoveStream(streamID protocol.StreamID)
|
||||||
|
UpdateTransportParameters(*handshake.TransportParameters)
|
||||||
// methods needed for receiving data
|
// methods needed for receiving data
|
||||||
ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
|
ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
|
||||||
UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
|
UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
|
||||||
|
|
|
@ -49,10 +49,11 @@ type cryptoSetupClient struct {
|
||||||
nullAEAD crypto.AEAD
|
nullAEAD crypto.AEAD
|
||||||
secureAEAD crypto.AEAD
|
secureAEAD crypto.AEAD
|
||||||
forwardSecureAEAD crypto.AEAD
|
forwardSecureAEAD crypto.AEAD
|
||||||
aeadChanged chan<- protocol.EncryptionLevel
|
|
||||||
|
|
||||||
requestConnIDOmission bool
|
paramsChan chan<- TransportParameters
|
||||||
params *paramsNegotiatorGQUIC
|
aeadChanged chan<- protocol.EncryptionLevel
|
||||||
|
|
||||||
|
params *TransportParameters
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ CryptoSetup = &cryptoSetupClient{}
|
var _ CryptoSetup = &cryptoSetupClient{}
|
||||||
|
@ -70,24 +71,24 @@ func NewCryptoSetupClient(
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
tlsConfig *tls.Config,
|
tlsConfig *tls.Config,
|
||||||
params *TransportParameters,
|
params *TransportParameters,
|
||||||
|
paramsChan chan<- TransportParameters,
|
||||||
aeadChanged chan<- protocol.EncryptionLevel,
|
aeadChanged chan<- protocol.EncryptionLevel,
|
||||||
negotiatedVersions []protocol.VersionNumber,
|
negotiatedVersions []protocol.VersionNumber,
|
||||||
) (CryptoSetup, ParamsNegotiator, error) {
|
) (CryptoSetup, error) {
|
||||||
pn := newParamsNegotiatorGQUIC(protocol.PerspectiveClient, version, params)
|
|
||||||
return &cryptoSetupClient{
|
return &cryptoSetupClient{
|
||||||
hostname: hostname,
|
hostname: hostname,
|
||||||
connID: connID,
|
connID: connID,
|
||||||
version: version,
|
version: version,
|
||||||
certManager: crypto.NewCertManager(tlsConfig),
|
certManager: crypto.NewCertManager(tlsConfig),
|
||||||
params: pn,
|
params: params,
|
||||||
requestConnIDOmission: params.RequestConnectionIDOmission,
|
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
keyExchange: getEphermalKEX,
|
||||||
keyExchange: getEphermalKEX,
|
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
|
||||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
|
paramsChan: paramsChan,
|
||||||
aeadChanged: aeadChanged,
|
aeadChanged: aeadChanged,
|
||||||
negotiatedVersions: negotiatedVersions,
|
negotiatedVersions: negotiatedVersions,
|
||||||
divNonceChan: make(chan []byte),
|
divNonceChan: make(chan []byte),
|
||||||
}, pn, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) HandleCryptoStream(stream io.ReadWriter) error {
|
func (h *cryptoSetupClient) HandleCryptoStream(stream io.ReadWriter) error {
|
||||||
|
@ -141,15 +142,21 @@ func (h *cryptoSetupClient) HandleCryptoStream(stream io.ReadWriter) error {
|
||||||
utils.Debugf("Got %s", message)
|
utils.Debugf("Got %s", message)
|
||||||
switch message.Tag {
|
switch message.Tag {
|
||||||
case TagREJ:
|
case TagREJ:
|
||||||
err = h.handleREJMessage(message.Data)
|
if err := h.handleREJMessage(message.Data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
case TagSHLO:
|
case TagSHLO:
|
||||||
err = h.handleSHLOMessage(message.Data)
|
params, err := h.handleSHLOMessage(message.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// blocks until the session has received the parameters
|
||||||
|
h.paramsChan <- *params
|
||||||
|
h.aeadChanged <- protocol.EncryptionForwardSecure
|
||||||
|
close(h.aeadChanged)
|
||||||
default:
|
default:
|
||||||
return qerr.InvalidCryptoMessageType
|
return qerr.InvalidCryptoMessageType
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -215,12 +222,12 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
|
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
defer h.mutex.Unlock()
|
defer h.mutex.Unlock()
|
||||||
|
|
||||||
if !h.receivedSecurePacket {
|
if !h.receivedSecurePacket {
|
||||||
return qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
|
return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
|
||||||
}
|
}
|
||||||
|
|
||||||
if sno, ok := cryptoData[TagSNO]; ok {
|
if sno, ok := cryptoData[TagSNO]; ok {
|
||||||
|
@ -229,22 +236,22 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
|
||||||
|
|
||||||
serverPubs, ok := cryptoData[TagPUBS]
|
serverPubs, ok := cryptoData[TagPUBS]
|
||||||
if !ok {
|
if !ok {
|
||||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
||||||
}
|
}
|
||||||
|
|
||||||
verTag, ok := cryptoData[TagVER]
|
verTag, ok := cryptoData[TagVER]
|
||||||
if !ok {
|
if !ok {
|
||||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
|
return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
|
||||||
}
|
}
|
||||||
if !h.validateVersionList(verTag) {
|
if !h.validateVersionList(verTag) {
|
||||||
return qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
|
return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce := append(h.nonc, h.sno...)
|
nonce := append(h.nonc, h.sno...)
|
||||||
|
|
||||||
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
|
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
leafCert := h.certManager.GetLeafCert()
|
leafCert := h.certManager.GetLeafCert()
|
||||||
|
@ -261,18 +268,14 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
|
||||||
protocol.PerspectiveClient,
|
protocol.PerspectiveClient,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.params.SetFromMap(cryptoData)
|
params, err := readHelloMap(cryptoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return qerr.InvalidCryptoMessageParameter
|
return nil, qerr.InvalidCryptoMessageParameter
|
||||||
}
|
}
|
||||||
|
return params, nil
|
||||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
|
||||||
close(h.aeadChanged)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
|
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
|
||||||
|
@ -405,10 +408,7 @@ func (h *cryptoSetupClient) sendCHLO() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
|
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
|
||||||
tags, err := h.params.GetHelloMap()
|
tags := h.params.getHelloMap()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tags[TagSNI] = []byte(h.hostname)
|
tags[TagSNI] = []byte(h.hostname)
|
||||||
tags[TagPDMD] = []byte("X509")
|
tags[TagPDMD] = []byte("X509")
|
||||||
|
|
||||||
|
@ -421,9 +421,6 @@ func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
|
||||||
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version))
|
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version))
|
||||||
tags[TagVER] = versionTag
|
tags[TagVER] = versionTag
|
||||||
|
|
||||||
if h.requestConnIDOmission {
|
|
||||||
tags[TagTCID] = []byte{0, 0, 0, 0}
|
|
||||||
}
|
|
||||||
if len(h.stk) > 0 {
|
if len(h.stk) > 0 {
|
||||||
tags[TagSTK] = h.stk
|
tags[TagSTK] = h.stk
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,6 +79,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
||||||
keyDerivationCalledWith *keyDerivationValues
|
keyDerivationCalledWith *keyDerivationValues
|
||||||
shloMap map[Tag][]byte
|
shloMap map[Tag][]byte
|
||||||
aeadChanged chan protocol.EncryptionLevel
|
aeadChanged chan protocol.EncryptionLevel
|
||||||
|
paramsChan chan TransportParameters
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -108,13 +109,16 @@ var _ = Describe("Client Crypto Setup", func() {
|
||||||
stream = newMockStream()
|
stream = newMockStream()
|
||||||
certManager = &mockCertManager{}
|
certManager = &mockCertManager{}
|
||||||
version := protocol.Version37
|
version := protocol.Version37
|
||||||
|
// use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking
|
||||||
|
paramsChan = make(chan TransportParameters, 1)
|
||||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||||
csInt, _, err := NewCryptoSetupClient(
|
csInt, err := NewCryptoSetupClient(
|
||||||
"hostname",
|
"hostname",
|
||||||
0,
|
0,
|
||||||
version,
|
version,
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
|
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
|
||||||
|
paramsChan,
|
||||||
aeadChanged,
|
aeadChanged,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
@ -222,7 +226,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
||||||
It("returns the right error when detecting a downgrade attack", func() {
|
It("returns the right error when detecting a downgrade attack", func() {
|
||||||
cs.negotiatedVersions = []protocol.VersionNumber{protocol.VersionWhatever}
|
cs.negotiatedVersions = []protocol.VersionNumber{protocol.VersionWhatever}
|
||||||
cs.receivedSecurePacket = true
|
cs.receivedSecurePacket = true
|
||||||
err := cs.handleSHLOMessage(map[Tag][]byte{
|
_, err := cs.handleSHLOMessage(map[Tag][]byte{
|
||||||
TagPUBS: []byte{0},
|
TagPUBS: []byte{0},
|
||||||
TagVER: []byte{0, 1},
|
TagVER: []byte{0, 1},
|
||||||
})
|
})
|
||||||
|
@ -385,7 +389,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
||||||
|
|
||||||
It("rejects unencrypted SHLOs", func() {
|
It("rejects unencrypted SHLOs", func() {
|
||||||
cs.receivedSecurePacket = false
|
cs.receivedSecurePacket = false
|
||||||
err := cs.handleSHLOMessage(shloMap)
|
_, err := cs.handleSHLOMessage(shloMap)
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")))
|
Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")))
|
||||||
Expect(aeadChanged).ToNot(Receive())
|
Expect(aeadChanged).ToNot(Receive())
|
||||||
Expect(aeadChanged).ToNot(BeClosed())
|
Expect(aeadChanged).ToNot(BeClosed())
|
||||||
|
@ -393,14 +397,14 @@ var _ = Describe("Client Crypto Setup", func() {
|
||||||
|
|
||||||
It("rejects SHLOs without a PUBS", func() {
|
It("rejects SHLOs without a PUBS", func() {
|
||||||
delete(shloMap, TagPUBS)
|
delete(shloMap, TagPUBS)
|
||||||
err := cs.handleSHLOMessage(shloMap)
|
_, err := cs.handleSHLOMessage(shloMap)
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")))
|
Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")))
|
||||||
Expect(aeadChanged).ToNot(BeClosed())
|
Expect(aeadChanged).ToNot(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("rejects SHLOs without a version list", func() {
|
It("rejects SHLOs without a version list", func() {
|
||||||
delete(shloMap, TagVER)
|
delete(shloMap, TagVER)
|
||||||
err := cs.handleSHLOMessage(shloMap)
|
_, err := cs.handleSHLOMessage(shloMap)
|
||||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")))
|
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")))
|
||||||
Expect(aeadChanged).ToNot(BeClosed())
|
Expect(aeadChanged).ToNot(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -412,36 +416,58 @@ var _ = Describe("Client Crypto Setup", func() {
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
utils.LittleEndian.WriteUint32(b, protocol.VersionNumberToTag(ver))
|
utils.LittleEndian.WriteUint32(b, protocol.VersionNumberToTag(ver))
|
||||||
shloMap[TagVER] = b.Bytes()
|
shloMap[TagVER] = b.Bytes()
|
||||||
err := cs.handleSHLOMessage(shloMap)
|
_, err := cs.handleSHLOMessage(shloMap)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("reads the server nonce, if set", func() {
|
It("reads the server nonce, if set", func() {
|
||||||
shloMap[TagSNO] = []byte("server nonce")
|
shloMap[TagSNO] = []byte("server nonce")
|
||||||
err := cs.handleSHLOMessage(shloMap)
|
_, err := cs.handleSHLOMessage(shloMap)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(cs.sno).To(Equal(shloMap[TagSNO]))
|
Expect(cs.sno).To(Equal(shloMap[TagSNO]))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("creates a forwardSecureAEAD", func() {
|
It("creates a forwardSecureAEAD", func() {
|
||||||
shloMap[TagSNO] = []byte("server nonce")
|
shloMap[TagSNO] = []byte("server nonce")
|
||||||
err := cs.handleSHLOMessage(shloMap)
|
_, err := cs.handleSHLOMessage(shloMap)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(cs.forwardSecureAEAD).ToNot(BeNil())
|
Expect(cs.forwardSecureAEAD).ToNot(BeNil())
|
||||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
|
|
||||||
Expect(aeadChanged).To(BeClosed())
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("reads the connection paramaters", func() {
|
It("reads the connection paramaters", func() {
|
||||||
shloMap[TagICSL] = []byte{13, 0, 0, 0} // 13 seconds
|
shloMap[TagICSL] = []byte{13, 0, 0, 0} // 13 seconds
|
||||||
err := cs.handleSHLOMessage(shloMap)
|
params, err := cs.handleSHLOMessage(shloMap)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(cs.params.GetRemoteIdleTimeout()).To(Equal(13 * time.Second))
|
Expect(params.IdleTimeout).To(Equal(13 * time.Second))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("closes the aeadChanged when receiving an SHLO", func() {
|
||||||
|
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
err := cs.HandleCryptoStream(stream)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}()
|
||||||
|
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure)))
|
||||||
|
Eventually(aeadChanged).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("passes the transport parameters on the channel", func() {
|
||||||
|
shloMap[TagSFCW] = []byte{0x0d, 0x00, 0xdf, 0xba}
|
||||||
|
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
err := cs.HandleCryptoStream(stream)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}()
|
||||||
|
var params TransportParameters
|
||||||
|
Eventually(paramsChan).Should(Receive(¶ms))
|
||||||
|
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xbadf000d)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors if it can't read a connection parameter", func() {
|
It("errors if it can't read a connection parameter", func() {
|
||||||
shloMap[TagICSL] = []byte{3, 0, 0} // 1 byte too short
|
shloMap[TagICSL] = []byte{3, 0, 0} // 1 byte too short
|
||||||
err := cs.handleSHLOMessage(shloMap)
|
_, err := cs.handleSHLOMessage(shloMap)
|
||||||
Expect(err).To(MatchError(qerr.InvalidCryptoMessageParameter))
|
Expect(err).To(MatchError(qerr.InvalidCryptoMessageParameter))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -488,15 +514,14 @@ var _ = Describe("Client Crypto Setup", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("requests to omit the connection ID", func() {
|
It("requests to omit the connection ID", func() {
|
||||||
cs.requestConnIDOmission = true
|
cs.params.OmitConnectionID = true
|
||||||
tags, err := cs.getTags()
|
tags, err := cs.getTags()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(tags).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0}))
|
Expect(tags).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0}))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("adds the tags returned from the connectionParametersManager to the CHLO", func() {
|
It("adds the tags returned from the connectionParametersManager to the CHLO", func() {
|
||||||
pnTags, err := cs.params.GetHelloMap()
|
pnTags := cs.params.getHelloMap()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(pnTags).ToNot(BeEmpty())
|
Expect(pnTags).ToNot(BeEmpty())
|
||||||
tags, err := cs.getTags()
|
tags, err := cs.getTags()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -588,7 +613,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
||||||
|
|
||||||
doSHLO := func() {
|
doSHLO := func() {
|
||||||
cs.receivedSecurePacket = true
|
cs.receivedSecurePacket = true
|
||||||
err := cs.handleSHLOMessage(shloMap)
|
_, err := cs.handleSHLOMessage(shloMap)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,14 +40,17 @@ type cryptoSetupServer struct {
|
||||||
receivedForwardSecurePacket bool
|
receivedForwardSecurePacket bool
|
||||||
receivedSecurePacket bool
|
receivedSecurePacket bool
|
||||||
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
|
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
|
||||||
aeadChanged chan<- protocol.EncryptionLevel
|
|
||||||
|
receivedParams bool
|
||||||
|
paramsChan chan<- TransportParameters
|
||||||
|
aeadChanged chan<- protocol.EncryptionLevel
|
||||||
|
|
||||||
keyDerivation QuicCryptoKeyDerivationFunction
|
keyDerivation QuicCryptoKeyDerivationFunction
|
||||||
keyExchange KeyExchangeFunction
|
keyExchange KeyExchangeFunction
|
||||||
|
|
||||||
cryptoStream io.ReadWriter
|
cryptoStream io.ReadWriter
|
||||||
|
|
||||||
params *paramsNegotiatorGQUIC
|
params *TransportParameters
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
@ -72,14 +75,14 @@ func NewCryptoSetup(
|
||||||
params *TransportParameters,
|
params *TransportParameters,
|
||||||
supportedVersions []protocol.VersionNumber,
|
supportedVersions []protocol.VersionNumber,
|
||||||
acceptSTK func(net.Addr, *Cookie) bool,
|
acceptSTK func(net.Addr, *Cookie) bool,
|
||||||
|
paramsChan chan<- TransportParameters,
|
||||||
aeadChanged chan<- protocol.EncryptionLevel,
|
aeadChanged chan<- protocol.EncryptionLevel,
|
||||||
) (CryptoSetup, ParamsNegotiator, error) {
|
) (CryptoSetup, error) {
|
||||||
stkGenerator, err := NewCookieGenerator()
|
stkGenerator, err := NewCookieGenerator()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pn := newParamsNegotiatorGQUIC(protocol.PerspectiveServer, version, params)
|
|
||||||
return &cryptoSetupServer{
|
return &cryptoSetupServer{
|
||||||
connID: connID,
|
connID: connID,
|
||||||
remoteAddr: remoteAddr,
|
remoteAddr: remoteAddr,
|
||||||
|
@ -90,11 +93,12 @@ func NewCryptoSetup(
|
||||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||||
keyExchange: getEphermalKEX,
|
keyExchange: getEphermalKEX,
|
||||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
|
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
|
||||||
params: pn,
|
params: params,
|
||||||
acceptSTKCallback: acceptSTK,
|
acceptSTKCallback: acceptSTK,
|
||||||
sentSHLO: make(chan struct{}),
|
sentSHLO: make(chan struct{}),
|
||||||
|
paramsChan: paramsChan,
|
||||||
aeadChanged: aeadChanged,
|
aeadChanged: aeadChanged,
|
||||||
}, pn, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleCryptoStream reads and writes messages on the crypto stream
|
// HandleCryptoStream reads and writes messages on the crypto stream
|
||||||
|
@ -163,6 +167,16 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
params, err := readHelloMap(cryptoData)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
// blocks until the session has received the parameters
|
||||||
|
if !h.receivedParams {
|
||||||
|
h.receivedParams = true
|
||||||
|
h.paramsChan <- *params
|
||||||
|
}
|
||||||
|
|
||||||
if !h.isInchoateCHLO(cryptoData, certUncompressed) {
|
if !h.isInchoateCHLO(cryptoData, certUncompressed) {
|
||||||
// We have a CHLO with a proper server config ID, do a 0-RTT handshake
|
// We have a CHLO with a proper server config ID, do a 0-RTT handshake
|
||||||
reply, err = h.handleCHLO(sni, chloData, cryptoData)
|
reply, err = h.handleCHLO(sni, chloData, cryptoData)
|
||||||
|
@ -418,14 +432,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.params.SetFromMap(cryptoData); err != nil {
|
replyMap := h.params.getHelloMap()
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
replyMap, err := h.params.GetHelloMap()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// add crypto parameters
|
// add crypto parameters
|
||||||
verTag := &bytes.Buffer{}
|
verTag := &bytes.Buffer{}
|
||||||
for _, v := range h.supportedVersions {
|
for _, v := range h.supportedVersions {
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
@ -167,6 +168,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
scfg *ServerConfig
|
scfg *ServerConfig
|
||||||
cs *cryptoSetupServer
|
cs *cryptoSetupServer
|
||||||
stream *mockStream
|
stream *mockStream
|
||||||
|
paramsChan chan TransportParameters
|
||||||
aeadChanged chan protocol.EncryptionLevel
|
aeadChanged chan protocol.EncryptionLevel
|
||||||
nonce32 []byte
|
nonce32 []byte
|
||||||
versionTag []byte
|
versionTag []byte
|
||||||
|
@ -183,6 +185,8 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
|
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
|
||||||
expectedInitialNonceLen = 32
|
expectedInitialNonceLen = 32
|
||||||
expectedFSNonceLen = 64
|
expectedFSNonceLen = 64
|
||||||
|
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
|
||||||
|
paramsChan = make(chan TransportParameters, 1)
|
||||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||||
stream = newMockStream()
|
stream = newMockStream()
|
||||||
kex = &mockKEX{}
|
kex = &mockKEX{}
|
||||||
|
@ -197,7 +201,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
|
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
|
||||||
supportedVersions = []protocol.VersionNumber{version, 98, 99}
|
supportedVersions = []protocol.VersionNumber{version, 98, 99}
|
||||||
csInt, _, err := NewCryptoSetup(
|
csInt, err := NewCryptoSetup(
|
||||||
protocol.ConnectionID(42),
|
protocol.ConnectionID(42),
|
||||||
remoteAddr,
|
remoteAddr,
|
||||||
version,
|
version,
|
||||||
|
@ -205,6 +209,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
|
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
|
||||||
supportedVersions,
|
supportedVersions,
|
||||||
nil,
|
nil,
|
||||||
|
paramsChan,
|
||||||
aeadChanged,
|
aeadChanged,
|
||||||
)
|
)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
@ -285,6 +290,16 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||||
Expect(err).To(MatchError(ErrNSTPExperiment))
|
Expect(err).To(MatchError(ErrNSTPExperiment))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("reads the transport parameters sent by the client", func() {
|
||||||
|
sourceAddrValid = true
|
||||||
|
fullCHLO[TagICSL] = []byte{0x37, 0x13, 0, 0}
|
||||||
|
_, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), fullCHLO)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
var params TransportParameters
|
||||||
|
Expect(paramsChan).To(Receive(¶ms))
|
||||||
|
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
|
||||||
|
})
|
||||||
|
|
||||||
It("generates REJ messages", func() {
|
It("generates REJ messages", func() {
|
||||||
sourceAddrValid = false
|
sourceAddrValid = false
|
||||||
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil)
|
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), nil)
|
||||||
|
|
|
@ -38,52 +38,52 @@ var newMintController = func(conn *mint.Conn) crypto.MintController {
|
||||||
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
||||||
func NewCryptoSetupTLSServer(
|
func NewCryptoSetupTLSServer(
|
||||||
tlsConfig *tls.Config,
|
tlsConfig *tls.Config,
|
||||||
transportParams *TransportParameters,
|
params *TransportParameters,
|
||||||
|
paramsChan chan<- TransportParameters,
|
||||||
aeadChanged chan<- protocol.EncryptionLevel,
|
aeadChanged chan<- protocol.EncryptionLevel,
|
||||||
supportedVersions []protocol.VersionNumber,
|
supportedVersions []protocol.VersionNumber,
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) (CryptoSetup, ParamsNegotiator, error) {
|
) (CryptoSetup, error) {
|
||||||
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer)
|
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
params := newParamsNegotiator(protocol.PerspectiveServer, version, transportParams)
|
|
||||||
return &cryptoSetupTLS{
|
return &cryptoSetupTLS{
|
||||||
perspective: protocol.PerspectiveServer,
|
perspective: protocol.PerspectiveServer,
|
||||||
mintConf: mintConf,
|
mintConf: mintConf,
|
||||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
|
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
|
||||||
keyDerivation: crypto.DeriveAESKeys,
|
keyDerivation: crypto.DeriveAESKeys,
|
||||||
aeadChanged: aeadChanged,
|
aeadChanged: aeadChanged,
|
||||||
extensionHandler: newExtensionHandlerServer(params, supportedVersions, version),
|
extensionHandler: newExtensionHandlerServer(params, paramsChan, supportedVersions, version),
|
||||||
}, params, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
|
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
|
||||||
func NewCryptoSetupTLSClient(
|
func NewCryptoSetupTLSClient(
|
||||||
hostname string, // only needed for the client
|
hostname string, // only needed for the client
|
||||||
tlsConfig *tls.Config,
|
tlsConfig *tls.Config,
|
||||||
transportParams *TransportParameters,
|
params *TransportParameters,
|
||||||
|
paramsChan chan<- TransportParameters,
|
||||||
aeadChanged chan<- protocol.EncryptionLevel,
|
aeadChanged chan<- protocol.EncryptionLevel,
|
||||||
initialVersion protocol.VersionNumber,
|
initialVersion protocol.VersionNumber,
|
||||||
supportedVersions []protocol.VersionNumber,
|
supportedVersions []protocol.VersionNumber,
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) (CryptoSetup, ParamsNegotiator, error) {
|
) (CryptoSetup, error) {
|
||||||
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient)
|
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mintConf.ServerName = hostname
|
mintConf.ServerName = hostname
|
||||||
|
|
||||||
params := newParamsNegotiator(protocol.PerspectiveClient, version, transportParams)
|
|
||||||
return &cryptoSetupTLS{
|
return &cryptoSetupTLS{
|
||||||
perspective: protocol.PerspectiveClient,
|
perspective: protocol.PerspectiveClient,
|
||||||
mintConf: mintConf,
|
mintConf: mintConf,
|
||||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
|
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
|
||||||
keyDerivation: crypto.DeriveAESKeys,
|
keyDerivation: crypto.DeriveAESKeys,
|
||||||
aeadChanged: aeadChanged,
|
aeadChanged: aeadChanged,
|
||||||
extensionHandler: newExtensionHandlerClient(params, initialVersion, supportedVersions, version),
|
extensionHandler: newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version),
|
||||||
}, params, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) HandleCryptoStream(cryptoStream io.ReadWriter) error {
|
func (h *cryptoSetupTLS) HandleCryptoStream(cryptoStream io.ReadWriter) error {
|
||||||
|
|
|
@ -33,16 +33,19 @@ func mockKeyDerivation(crypto.MintController, protocol.Perspective) (crypto.AEAD
|
||||||
var _ = Describe("TLS Crypto Setup", func() {
|
var _ = Describe("TLS Crypto Setup", func() {
|
||||||
var (
|
var (
|
||||||
cs *cryptoSetupTLS
|
cs *cryptoSetupTLS
|
||||||
|
paramsChan chan TransportParameters
|
||||||
aeadChanged chan protocol.EncryptionLevel
|
aeadChanged chan protocol.EncryptionLevel
|
||||||
|
|
||||||
mintControllerConstructor = newMintController
|
mintControllerConstructor = newMintController
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
paramsChan = make(chan TransportParameters)
|
||||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||||
csInt, _, err := NewCryptoSetupTLSServer(
|
csInt, err := NewCryptoSetupTLSServer(
|
||||||
testdata.GetTLSConfig(),
|
testdata.GetTLSConfig(),
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
paramsChan,
|
||||||
aeadChanged,
|
aeadChanged,
|
||||||
nil,
|
nil,
|
||||||
protocol.VersionTLS,
|
protocol.VersionTLS,
|
||||||
|
|
|
@ -2,7 +2,6 @@ package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
@ -25,9 +24,3 @@ type CryptoSetup interface {
|
||||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||||
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TransportParameters are parameters sent to the peer during the handshake
|
|
||||||
type TransportParameters struct {
|
|
||||||
RequestConnectionIDOmission bool
|
|
||||||
IdleTimeout time.Duration
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,111 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
type paramsNegotiator struct {
|
|
||||||
paramsNegotiatorBase
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ ParamsNegotiator = ¶msNegotiator{}
|
|
||||||
|
|
||||||
// newParamsNegotiator creates a new connection parameters manager
|
|
||||||
func newParamsNegotiator(pers protocol.Perspective, v protocol.VersionNumber, params *TransportParameters) *paramsNegotiator {
|
|
||||||
h := ¶msNegotiator{}
|
|
||||||
h.perspective = pers
|
|
||||||
h.version = v
|
|
||||||
h.init(params)
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *paramsNegotiator) SetFromTransportParameters(params []transportParameter) error {
|
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
var foundInitialMaxStreamData bool
|
|
||||||
var foundInitialMaxData bool
|
|
||||||
var foundInitialMaxStreamID bool
|
|
||||||
var foundIdleTimeout bool
|
|
||||||
|
|
||||||
for _, p := range params {
|
|
||||||
switch p.Parameter {
|
|
||||||
case initialMaxStreamDataParameterID:
|
|
||||||
foundInitialMaxStreamData = true
|
|
||||||
if len(p.Value) != 4 {
|
|
||||||
return fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", len(p.Value))
|
|
||||||
}
|
|
||||||
h.sendStreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
|
|
||||||
utils.Debugf("h.sendStreamFlowControlWindow: %#x", h.sendStreamFlowControlWindow)
|
|
||||||
case initialMaxDataParameterID:
|
|
||||||
foundInitialMaxData = true
|
|
||||||
if len(p.Value) != 4 {
|
|
||||||
return fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value))
|
|
||||||
}
|
|
||||||
h.sendConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
|
|
||||||
utils.Debugf("h.sendConnectionFlowControlWindow: %#x", h.sendConnectionFlowControlWindow)
|
|
||||||
case initialMaxStreamIDParameterID:
|
|
||||||
foundInitialMaxStreamID = true
|
|
||||||
if len(p.Value) != 4 {
|
|
||||||
return fmt.Errorf("wrong length for initial_max_stream_id: %d (expected 4)", len(p.Value))
|
|
||||||
}
|
|
||||||
// TODO: handle this value
|
|
||||||
case idleTimeoutParameterID:
|
|
||||||
foundIdleTimeout = true
|
|
||||||
if len(p.Value) != 2 {
|
|
||||||
return fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value))
|
|
||||||
}
|
|
||||||
h.setRemoteIdleTimeout(time.Duration(binary.BigEndian.Uint16(p.Value)) * time.Second)
|
|
||||||
case omitConnectionIDParameterID:
|
|
||||||
if len(p.Value) != 0 {
|
|
||||||
return fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value))
|
|
||||||
}
|
|
||||||
h.omitConnectionID = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(foundInitialMaxStreamData && foundInitialMaxData && foundInitialMaxStreamID && foundIdleTimeout) {
|
|
||||||
return errors.New("missing parameter")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *paramsNegotiator) GetTransportParameters() []transportParameter {
|
|
||||||
initialMaxStreamData := make([]byte, 4)
|
|
||||||
binary.BigEndian.PutUint32(initialMaxStreamData, uint32(protocol.ReceiveStreamFlowControlWindow))
|
|
||||||
initialMaxData := make([]byte, 4)
|
|
||||||
binary.BigEndian.PutUint32(initialMaxData, uint32(protocol.ReceiveConnectionFlowControlWindow))
|
|
||||||
initialMaxStreamID := make([]byte, 4)
|
|
||||||
// TODO: use a reasonable value here
|
|
||||||
binary.BigEndian.PutUint32(initialMaxStreamID, math.MaxUint32)
|
|
||||||
idleTimeout := make([]byte, 2)
|
|
||||||
binary.BigEndian.PutUint16(idleTimeout, uint16(h.idleTimeout))
|
|
||||||
maxPacketSize := make([]byte, 2)
|
|
||||||
binary.BigEndian.PutUint16(maxPacketSize, uint16(protocol.MaxReceivePacketSize))
|
|
||||||
params := []transportParameter{
|
|
||||||
{initialMaxStreamDataParameterID, initialMaxStreamData},
|
|
||||||
{initialMaxDataParameterID, initialMaxData},
|
|
||||||
{initialMaxStreamIDParameterID, initialMaxStreamID},
|
|
||||||
{idleTimeoutParameterID, idleTimeout},
|
|
||||||
{maxPacketSizeParameterID, maxPacketSize},
|
|
||||||
}
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
if h.omitConnectionID {
|
|
||||||
params = append(params, transportParameter{omitConnectionIDParameterID, []byte{}})
|
|
||||||
}
|
|
||||||
return params
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *paramsNegotiator) OmitConnectionID() bool {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.omitConnectionID
|
|
||||||
}
|
|
|
@ -1,85 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
// The ParamsNegotiator negotiates and stores the connection parameters.
|
|
||||||
// It can be used for a server as well as a client.
|
|
||||||
type ParamsNegotiator interface {
|
|
||||||
GetSendStreamFlowControlWindow() protocol.ByteCount
|
|
||||||
GetSendConnectionFlowControlWindow() protocol.ByteCount
|
|
||||||
GetMaxOutgoingStreams() uint32
|
|
||||||
// get the idle timeout that was sent by the peer
|
|
||||||
GetRemoteIdleTimeout() time.Duration
|
|
||||||
// determines if the client requests omission of connection IDs.
|
|
||||||
OmitConnectionID() bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// For the server:
|
|
||||||
// 1. call SetFromMap with the values received in the CHLO. This sets the corresponding values here, subject to negotiation
|
|
||||||
// 2. call GetHelloMap to get the values to send in the SHLO
|
|
||||||
// For the client:
|
|
||||||
// 1. call GetHelloMap to get the values to send in a CHLO
|
|
||||||
// 2. call SetFromMap with the values received in the SHLO
|
|
||||||
type paramsNegotiatorBase struct {
|
|
||||||
mutex sync.RWMutex
|
|
||||||
|
|
||||||
version protocol.VersionNumber
|
|
||||||
perspective protocol.Perspective
|
|
||||||
|
|
||||||
flowControlNegotiated bool
|
|
||||||
|
|
||||||
omitConnectionID bool
|
|
||||||
requestConnectionIDOmission bool
|
|
||||||
|
|
||||||
maxOutgoingStreams uint32
|
|
||||||
idleTimeout time.Duration
|
|
||||||
remoteIdleTimeout time.Duration
|
|
||||||
sendStreamFlowControlWindow protocol.ByteCount
|
|
||||||
sendConnectionFlowControlWindow protocol.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *paramsNegotiatorBase) init(params *TransportParameters) {
|
|
||||||
h.sendStreamFlowControlWindow = protocol.InitialStreamFlowControlWindow // can only be changed by the client
|
|
||||||
h.sendConnectionFlowControlWindow = protocol.InitialConnectionFlowControlWindow // can only be changed by the client
|
|
||||||
h.requestConnectionIDOmission = params.RequestConnectionIDOmission
|
|
||||||
|
|
||||||
h.idleTimeout = params.IdleTimeout
|
|
||||||
// use this as a default value. As soon as the client sends its value, this gets updated
|
|
||||||
h.maxOutgoingStreams = protocol.MaxIncomingStreams
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data
|
|
||||||
func (h *paramsNegotiatorBase) GetSendStreamFlowControlWindow() protocol.ByteCount {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.sendStreamFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data
|
|
||||||
func (h *paramsNegotiatorBase) GetSendConnectionFlowControlWindow() protocol.ByteCount {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.sendConnectionFlowControlWindow
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *paramsNegotiatorBase) GetMaxOutgoingStreams() uint32 {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.maxOutgoingStreams
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *paramsNegotiatorBase) setRemoteIdleTimeout(t time.Duration) {
|
|
||||||
h.remoteIdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *paramsNegotiatorBase) GetRemoteIdleTimeout() time.Duration {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.remoteIdleTimeout
|
|
||||||
}
|
|
|
@ -1,116 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
// errMalformedTag is returned when the tag value cannot be read
|
|
||||||
var (
|
|
||||||
errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
|
|
||||||
errFlowControlRenegotiationNotSupported = qerr.Error(qerr.InvalidCryptoMessageParameter, "renegotiation of flow control parameters not supported")
|
|
||||||
)
|
|
||||||
|
|
||||||
type paramsNegotiatorGQUIC struct {
|
|
||||||
paramsNegotiatorBase
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ ParamsNegotiator = ¶msNegotiatorGQUIC{}
|
|
||||||
|
|
||||||
// newParamsNegotiatorGQUIC creates a new connection parameters manager
|
|
||||||
func newParamsNegotiatorGQUIC(pers protocol.Perspective, v protocol.VersionNumber, params *TransportParameters) *paramsNegotiatorGQUIC {
|
|
||||||
h := ¶msNegotiatorGQUIC{}
|
|
||||||
h.perspective = pers
|
|
||||||
h.version = v
|
|
||||||
h.init(params)
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFromMap reads all params.
|
|
||||||
func (h *paramsNegotiatorGQUIC) SetFromMap(params map[Tag][]byte) error {
|
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
if value, ok := params[TagTCID]; ok && h.perspective == protocol.PerspectiveServer {
|
|
||||||
clientValue, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return errMalformedTag
|
|
||||||
}
|
|
||||||
h.omitConnectionID = (clientValue == 0)
|
|
||||||
}
|
|
||||||
if value, ok := params[TagMIDS]; ok {
|
|
||||||
clientValue, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return errMalformedTag
|
|
||||||
}
|
|
||||||
h.maxOutgoingStreams = clientValue
|
|
||||||
}
|
|
||||||
if value, ok := params[TagICSL]; ok {
|
|
||||||
clientValue, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return errMalformedTag
|
|
||||||
}
|
|
||||||
h.setRemoteIdleTimeout(time.Duration(clientValue) * time.Second)
|
|
||||||
}
|
|
||||||
if value, ok := params[TagSFCW]; ok {
|
|
||||||
if h.flowControlNegotiated {
|
|
||||||
return errFlowControlRenegotiationNotSupported
|
|
||||||
}
|
|
||||||
sendStreamFlowControlWindow, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return errMalformedTag
|
|
||||||
}
|
|
||||||
h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow)
|
|
||||||
}
|
|
||||||
if value, ok := params[TagCFCW]; ok {
|
|
||||||
if h.flowControlNegotiated {
|
|
||||||
return errFlowControlRenegotiationNotSupported
|
|
||||||
}
|
|
||||||
sendConnectionFlowControlWindow, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return errMalformedTag
|
|
||||||
}
|
|
||||||
h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, containsSFCW := params[TagSFCW]
|
|
||||||
_, containsCFCW := params[TagCFCW]
|
|
||||||
if containsCFCW || containsSFCW {
|
|
||||||
h.flowControlNegotiated = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetHelloMap gets all parameters needed for the Hello message.
|
|
||||||
func (h *paramsNegotiatorGQUIC) GetHelloMap() (map[Tag][]byte, error) {
|
|
||||||
sfcw := bytes.NewBuffer([]byte{})
|
|
||||||
utils.LittleEndian.WriteUint32(sfcw, uint32(protocol.ReceiveStreamFlowControlWindow))
|
|
||||||
cfcw := bytes.NewBuffer([]byte{})
|
|
||||||
utils.LittleEndian.WriteUint32(cfcw, uint32(protocol.ReceiveConnectionFlowControlWindow))
|
|
||||||
mids := bytes.NewBuffer([]byte{})
|
|
||||||
utils.LittleEndian.WriteUint32(mids, protocol.MaxIncomingStreams)
|
|
||||||
icsl := bytes.NewBuffer([]byte{})
|
|
||||||
utils.LittleEndian.WriteUint32(icsl, uint32(h.idleTimeout/time.Second))
|
|
||||||
|
|
||||||
return map[Tag][]byte{
|
|
||||||
TagICSL: icsl.Bytes(),
|
|
||||||
TagMIDS: mids.Bytes(),
|
|
||||||
TagCFCW: cfcw.Bytes(),
|
|
||||||
TagSFCW: sfcw.Bytes(),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *paramsNegotiatorGQUIC) OmitConnectionID() bool {
|
|
||||||
if h.perspective == protocol.PerspectiveClient {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.omitConnectionID
|
|
||||||
}
|
|
|
@ -1,231 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Params Negotiator (for gQUIC)", func() {
|
|
||||||
var pn *paramsNegotiatorGQUIC // a connectionParametersManager for a server
|
|
||||||
var pnClient *paramsNegotiatorGQUIC
|
|
||||||
idleTimeout := 42 * time.Second
|
|
||||||
BeforeEach(func() {
|
|
||||||
pn = newParamsNegotiatorGQUIC(
|
|
||||||
protocol.PerspectiveServer,
|
|
||||||
protocol.VersionWhatever,
|
|
||||||
&TransportParameters{
|
|
||||||
IdleTimeout: idleTimeout,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
pnClient = newParamsNegotiatorGQUIC(
|
|
||||||
protocol.PerspectiveClient,
|
|
||||||
protocol.VersionWhatever,
|
|
||||||
&TransportParameters{
|
|
||||||
IdleTimeout: idleTimeout,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("SHLO", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
// these tests should only use the server connectionParametersManager. Make them panic if they don't
|
|
||||||
pnClient = nil
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns all parameters necessary for the SHLO", func() {
|
|
||||||
entryMap, err := pn.GetHelloMap()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(entryMap).To(HaveKey(TagICSL))
|
|
||||||
Expect(entryMap).To(HaveKey(TagMIDS))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets the stream-level flow control windows in SHLO", func() {
|
|
||||||
entryMap, err := pn.GetHelloMap()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
expected := make([]byte, 4)
|
|
||||||
binary.LittleEndian.PutUint32(expected, uint32(protocol.ReceiveStreamFlowControlWindow))
|
|
||||||
Expect(entryMap).To(HaveKeyWithValue(TagSFCW, expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets the connection-level flow control windows in SHLO", func() {
|
|
||||||
entryMap, err := pn.GetHelloMap()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
expected := make([]byte, 4)
|
|
||||||
binary.LittleEndian.PutUint32(expected, uint32(protocol.ReceiveConnectionFlowControlWindow))
|
|
||||||
Expect(entryMap).To(HaveKeyWithValue(TagCFCW, expected))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets the connection-level flow control windows in SHLO", func() {
|
|
||||||
pn.idleTimeout = 0xdecafbad * time.Second
|
|
||||||
entryMap, err := pn.GetHelloMap()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(entryMap).To(HaveKey(TagICSL))
|
|
||||||
Expect(entryMap[TagICSL]).To(Equal([]byte{0xad, 0xfb, 0xca, 0xde}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("always sends its own value for the maximum incoming dynamic streams in the SHLO", func() {
|
|
||||||
err := pn.SetFromMap(map[Tag][]byte{TagMIDS: []byte{5, 0, 0, 0}})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
entryMap, err := pn.GetHelloMap()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(entryMap[TagMIDS]).To(Equal([]byte{byte(protocol.MaxIncomingStreams), 0, 0, 0}))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("CHLO", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
// these tests should only use the client connectionParametersManager. Make them panic if they don't
|
|
||||||
pn = nil
|
|
||||||
})
|
|
||||||
|
|
||||||
It("has the right values", func() {
|
|
||||||
entryMap, err := pnClient.GetHelloMap()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(entryMap).To(HaveKey(TagICSL))
|
|
||||||
Expect(binary.LittleEndian.Uint32(entryMap[TagICSL])).To(BeEquivalentTo(idleTimeout / time.Second))
|
|
||||||
Expect(entryMap).To(HaveKey(TagMIDS))
|
|
||||||
Expect(binary.LittleEndian.Uint32(entryMap[TagMIDS])).To(BeEquivalentTo(protocol.MaxIncomingStreams))
|
|
||||||
Expect(entryMap).To(HaveKey(TagSFCW))
|
|
||||||
Expect(binary.LittleEndian.Uint32(entryMap[TagSFCW])).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow))
|
|
||||||
Expect(entryMap).To(HaveKey(TagCFCW))
|
|
||||||
Expect(binary.LittleEndian.Uint32(entryMap[TagCFCW])).To(BeEquivalentTo(protocol.ReceiveConnectionFlowControlWindow))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Omitted connection IDs", func() {
|
|
||||||
It("does not send omitted connection IDs if the TCID tag is missing", func() {
|
|
||||||
Expect(pn.OmitConnectionID()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("reads the tag for omitted connection IDs", func() {
|
|
||||||
values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}}
|
|
||||||
pn.SetFromMap(values)
|
|
||||||
Expect(pn.OmitConnectionID()).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ignores the TCID tag, as a client", func() {
|
|
||||||
values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}}
|
|
||||||
pnClient.SetFromMap(values)
|
|
||||||
Expect(pnClient.OmitConnectionID()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when given an invalid value", func() {
|
|
||||||
values := map[Tag][]byte{TagTCID: {2, 0, 0}} // 1 byte too short
|
|
||||||
err := pn.SetFromMap(values)
|
|
||||||
Expect(err).To(MatchError(errMalformedTag))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("flow control", func() {
|
|
||||||
It("has the correct default flow control windows for sending", func() {
|
|
||||||
Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow))
|
|
||||||
Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow))
|
|
||||||
Expect(pnClient.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow))
|
|
||||||
Expect(pnClient.GetSendConnectionFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets a new stream-level flow control window for sending", func() {
|
|
||||||
values := map[Tag][]byte{TagSFCW: {0xDE, 0xAD, 0xBE, 0xEF}}
|
|
||||||
err := pn.SetFromMap(values)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not change the stream-level flow control window when given an invalid value", func() {
|
|
||||||
values := map[Tag][]byte{TagSFCW: {0xDE, 0xAD, 0xBE}} // 1 byte too short
|
|
||||||
err := pn.SetFromMap(values)
|
|
||||||
Expect(err).To(MatchError(errMalformedTag))
|
|
||||||
Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialStreamFlowControlWindow))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sets a new connection-level flow control window for sending", func() {
|
|
||||||
values := map[Tag][]byte{TagCFCW: {0xDE, 0xAD, 0xBE, 0xEF}}
|
|
||||||
err := pn.SetFromMap(values)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not change the connection-level flow control window when given an invalid value", func() {
|
|
||||||
values := map[Tag][]byte{TagCFCW: {0xDE, 0xAD, 0xBE}} // 1 byte too short
|
|
||||||
err := pn.SetFromMap(values)
|
|
||||||
Expect(err).To(MatchError(errMalformedTag))
|
|
||||||
Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.InitialConnectionFlowControlWindow))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("does not allow renegotiation of flow control parameters", func() {
|
|
||||||
values := map[Tag][]byte{
|
|
||||||
TagCFCW: {0xDE, 0xAD, 0xBE, 0xEF},
|
|
||||||
TagSFCW: {0xDE, 0xAD, 0xBE, 0xEF},
|
|
||||||
}
|
|
||||||
err := pn.SetFromMap(values)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
values = map[Tag][]byte{
|
|
||||||
TagCFCW: {0x13, 0x37, 0x13, 0x37},
|
|
||||||
TagSFCW: {0x13, 0x37, 0x13, 0x37},
|
|
||||||
}
|
|
||||||
err = pn.SetFromMap(values)
|
|
||||||
Expect(err).To(MatchError(errFlowControlRenegotiationNotSupported))
|
|
||||||
Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE)))
|
|
||||||
Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.ByteCount(0xEFBEADDE)))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("idle timeout", func() {
|
|
||||||
It("sets the remote idle timeout", func() {
|
|
||||||
values := map[Tag][]byte{
|
|
||||||
TagICSL: {10, 0, 0, 0},
|
|
||||||
}
|
|
||||||
err := pn.SetFromMap(values)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(pn.GetRemoteIdleTimeout()).To(Equal(10 * time.Second))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't allow values below the minimum remote idle timeout", func() {
|
|
||||||
t := 2 * time.Second
|
|
||||||
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
|
|
||||||
values := map[Tag][]byte{
|
|
||||||
TagICSL: {uint8(t.Seconds()), 0, 0, 0},
|
|
||||||
}
|
|
||||||
err := pn.SetFromMap(values)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(pn.GetRemoteIdleTimeout()).To(Equal(protocol.MinRemoteIdleTimeout))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("errors when given an invalid value", func() {
|
|
||||||
values := map[Tag][]byte{TagICSL: {2, 0, 0}} // 1 byte too short
|
|
||||||
err := pn.SetFromMap(values)
|
|
||||||
Expect(err).To(MatchError(errMalformedTag))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("max streams per connection", func() {
|
|
||||||
It("errors when given an invalid max dynamic incoming streams per connection value", func() {
|
|
||||||
values := map[Tag][]byte{TagMIDS: {2, 0, 0}} // 1 byte too short
|
|
||||||
err := pn.SetFromMap(values)
|
|
||||||
Expect(err).To(MatchError(errMalformedTag))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("outgoing connections", func() {
|
|
||||||
It("sets the negotiated max streams per connection value", func() {
|
|
||||||
// this test only works if the value given here is smaller than protocol.MaxStreamsPerConnection
|
|
||||||
err := pn.SetFromMap(map[Tag][]byte{
|
|
||||||
TagMIDS: {2, 0, 0, 0},
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(pn.GetMaxOutgoingStreams()).To(Equal(uint32(2)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the the MSPC value, if no MIDS is given", func() {
|
|
||||||
err := pn.SetFromMap(map[Tag][]byte{
|
|
||||||
TagMIDS: {3, 0, 0, 0},
|
|
||||||
})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(pn.GetMaxOutgoingStreams()).To(Equal(uint32(3)))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
|
@ -1,154 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Params Negotiator (for TLS)", func() {
|
|
||||||
var params map[transportParameterID][]byte
|
|
||||||
var pn *paramsNegotiator
|
|
||||||
|
|
||||||
paramsMapToList := func(p map[transportParameterID][]byte) []transportParameter {
|
|
||||||
var list []transportParameter
|
|
||||||
for id, val := range p {
|
|
||||||
list = append(list, transportParameter{id, val})
|
|
||||||
}
|
|
||||||
return list
|
|
||||||
}
|
|
||||||
|
|
||||||
paramsListToMap := func(l []transportParameter) map[transportParameterID][]byte {
|
|
||||||
p := make(map[transportParameterID][]byte)
|
|
||||||
for _, v := range l {
|
|
||||||
p[v.Parameter] = v.Value
|
|
||||||
}
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
pn = newParamsNegotiator(
|
|
||||||
protocol.PerspectiveServer,
|
|
||||||
protocol.VersionWhatever,
|
|
||||||
&TransportParameters{},
|
|
||||||
)
|
|
||||||
params = map[transportParameterID][]byte{
|
|
||||||
initialMaxStreamDataParameterID: []byte{0x11, 0x22, 0x33, 0x44},
|
|
||||||
initialMaxDataParameterID: []byte{0x22, 0x33, 0x44, 0x55},
|
|
||||||
initialMaxStreamIDParameterID: []byte{0x33, 0x44, 0x55, 0x66},
|
|
||||||
idleTimeoutParameterID: []byte{0x13, 0x37},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("getting", func() {
|
|
||||||
It("creates the parameters list", func() {
|
|
||||||
pn.idleTimeout = 0xcafe
|
|
||||||
buf := make([]byte, 4)
|
|
||||||
values := paramsListToMap(pn.GetTransportParameters())
|
|
||||||
Expect(values).To(HaveLen(5))
|
|
||||||
binary.BigEndian.PutUint32(buf, uint32(protocol.ReceiveStreamFlowControlWindow))
|
|
||||||
Expect(values).To(HaveKeyWithValue(initialMaxStreamDataParameterID, buf))
|
|
||||||
binary.BigEndian.PutUint32(buf, uint32(protocol.ReceiveConnectionFlowControlWindow))
|
|
||||||
Expect(values).To(HaveKeyWithValue(initialMaxDataParameterID, buf))
|
|
||||||
Expect(values).To(HaveKeyWithValue(initialMaxStreamIDParameterID, []byte{0xff, 0xff, 0xff, 0xff}))
|
|
||||||
Expect(values).To(HaveKeyWithValue(idleTimeoutParameterID, []byte{0xca, 0xfe}))
|
|
||||||
Expect(values).To(HaveKeyWithValue(maxPacketSizeParameterID, []byte{0x5, 0xac})) // 1452 = 0x5ac
|
|
||||||
})
|
|
||||||
|
|
||||||
It("request ommision of the connection ID", func() {
|
|
||||||
pn.omitConnectionID = true
|
|
||||||
values := paramsListToMap(pn.GetTransportParameters())
|
|
||||||
Expect(values).To(HaveKeyWithValue(omitConnectionIDParameterID, []byte{}))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("setting", func() {
|
|
||||||
It("reads parameters", func() {
|
|
||||||
err := pn.SetFromTransportParameters(paramsMapToList(params))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(pn.GetSendStreamFlowControlWindow()).To(Equal(protocol.ByteCount(0x11223344)))
|
|
||||||
Expect(pn.GetSendConnectionFlowControlWindow()).To(Equal(protocol.ByteCount(0x22334455)))
|
|
||||||
Expect(pn.GetRemoteIdleTimeout()).To(Equal(0x1337 * time.Second))
|
|
||||||
Expect(pn.OmitConnectionID()).To(BeFalse())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("saves if it should omit the connection ID", func() {
|
|
||||||
params[omitConnectionIDParameterID] = []byte{}
|
|
||||||
err := pn.SetFromTransportParameters(paramsMapToList(params))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(pn.OmitConnectionID()).To(BeTrue())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the initial_max_stream_data is missing", func() {
|
|
||||||
delete(params, initialMaxStreamDataParameterID)
|
|
||||||
err := pn.SetFromTransportParameters(paramsMapToList(params))
|
|
||||||
Expect(err).To(MatchError("missing parameter"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the initial_max_data is missing", func() {
|
|
||||||
delete(params, initialMaxDataParameterID)
|
|
||||||
err := pn.SetFromTransportParameters(paramsMapToList(params))
|
|
||||||
Expect(err).To(MatchError("missing parameter"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the initial_max_stream_id is missing", func() {
|
|
||||||
delete(params, initialMaxStreamIDParameterID)
|
|
||||||
err := pn.SetFromTransportParameters(paramsMapToList(params))
|
|
||||||
Expect(err).To(MatchError("missing parameter"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the idle_timeout is missing", func() {
|
|
||||||
delete(params, idleTimeoutParameterID)
|
|
||||||
err := pn.SetFromTransportParameters(paramsMapToList(params))
|
|
||||||
Expect(err).To(MatchError("missing parameter"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("doesn't allow values below the minimum remote idle timeout", func() {
|
|
||||||
t := 2 * time.Second
|
|
||||||
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
|
|
||||||
params[idleTimeoutParameterID] = []byte{0, uint8(t.Seconds())}
|
|
||||||
err := pn.SetFromTransportParameters(paramsMapToList(params))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(pn.GetRemoteIdleTimeout()).To(Equal(protocol.MinRemoteIdleTimeout))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the initial_max_stream_data has the wrong length", func() {
|
|
||||||
params[initialMaxStreamDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
|
|
||||||
err := pn.SetFromTransportParameters(paramsMapToList(params))
|
|
||||||
Expect(err).To(MatchError("wrong length for initial_max_stream_data: 3 (expected 4)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the initial_max_data has the wrong length", func() {
|
|
||||||
params[initialMaxDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
|
|
||||||
err := pn.SetFromTransportParameters(paramsMapToList(params))
|
|
||||||
Expect(err).To(MatchError("wrong length for initial_max_data: 3 (expected 4)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the initial_max_stream_id has the wrong length", func() {
|
|
||||||
params[initialMaxStreamIDParameterID] = []byte{0x11, 0x22, 0x33, 0x44, 0x55} // should be 4 bytes
|
|
||||||
err := pn.SetFromTransportParameters(paramsMapToList(params))
|
|
||||||
Expect(err).To(MatchError("wrong length for initial_max_stream_id: 5 (expected 4)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if the initial_idle_timeout has the wrong length", func() {
|
|
||||||
params[idleTimeoutParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
|
|
||||||
err := pn.SetFromTransportParameters(paramsMapToList(params))
|
|
||||||
Expect(err).To(MatchError("wrong length for idle_timeout: 3 (expected 2)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("rejects the parameters if omit_connection_id is non-empty", func() {
|
|
||||||
params[omitConnectionIDParameterID] = []byte{0} // should be empty
|
|
||||||
err := pn.SetFromTransportParameters(paramsMapToList(params))
|
|
||||||
Expect(err).To(MatchError("wrong length for omit_connection_id: 1 (expected empty)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("ignores unknown parameters", func() {
|
|
||||||
params[1337] = []byte{42}
|
|
||||||
err := pn.SetFromTransportParameters(paramsMapToList(params))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
|
@ -3,6 +3,7 @@ package handshake
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
|
|
||||||
|
@ -12,7 +13,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type extensionHandlerClient struct {
|
type extensionHandlerClient struct {
|
||||||
params *paramsNegotiator
|
params *TransportParameters
|
||||||
|
paramsChan chan<- TransportParameters
|
||||||
|
|
||||||
initialVersion protocol.VersionNumber
|
initialVersion protocol.VersionNumber
|
||||||
supportedVersions []protocol.VersionNumber
|
supportedVersions []protocol.VersionNumber
|
||||||
|
@ -21,9 +23,16 @@ type extensionHandlerClient struct {
|
||||||
|
|
||||||
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
|
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
|
||||||
|
|
||||||
func newExtensionHandlerClient(params *paramsNegotiator, initialVersion protocol.VersionNumber, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber) *extensionHandlerClient {
|
func newExtensionHandlerClient(
|
||||||
|
params *TransportParameters,
|
||||||
|
paramsChan chan<- TransportParameters,
|
||||||
|
initialVersion protocol.VersionNumber,
|
||||||
|
supportedVersions []protocol.VersionNumber,
|
||||||
|
version protocol.VersionNumber,
|
||||||
|
) *extensionHandlerClient {
|
||||||
return &extensionHandlerClient{
|
return &extensionHandlerClient{
|
||||||
params: params,
|
params: params,
|
||||||
|
paramsChan: paramsChan,
|
||||||
initialVersion: initialVersion,
|
initialVersion: initialVersion,
|
||||||
supportedVersions: supportedVersions,
|
supportedVersions: supportedVersions,
|
||||||
version: version,
|
version: version,
|
||||||
|
@ -38,7 +47,7 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
|
||||||
data, err := syntax.Marshal(clientHelloTransportParameters{
|
data, err := syntax.Marshal(clientHelloTransportParameters{
|
||||||
NegotiatedVersion: uint32(h.version),
|
NegotiatedVersion: uint32(h.version),
|
||||||
InitialVersion: uint32(h.initialVersion),
|
InitialVersion: uint32(h.initialVersion),
|
||||||
Parameters: h.params.GetTransportParameters(),
|
Parameters: h.params.getTransportParameters(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -99,5 +108,12 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
|
||||||
// TODO: return the right error here
|
// TODO: return the right error here
|
||||||
return errors.New("server didn't sent stateless_reset_token")
|
return errors.New("server didn't sent stateless_reset_token")
|
||||||
}
|
}
|
||||||
return h.params.SetFromTransportParameters(eetp.Parameters)
|
params, err := readTransportParamters(eetp.Parameters)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// TODO(#878): remove this when implementing the MAX_STREAM_ID frame
|
||||||
|
params.MaxStreams = math.MaxUint32
|
||||||
|
h.paramsChan <- *params
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,12 +12,16 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("TLS Extension Handler, for the client", func() {
|
var _ = Describe("TLS Extension Handler, for the client", func() {
|
||||||
var handler *extensionHandlerClient
|
var (
|
||||||
var el mint.ExtensionList
|
handler *extensionHandlerClient
|
||||||
|
el mint.ExtensionList
|
||||||
|
paramsChan chan TransportParameters
|
||||||
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
pn := ¶msNegotiator{}
|
// use a buffered channel here, so that we don't have to receive concurrently when parsing a message
|
||||||
handler = newExtensionHandlerClient(pn, protocol.VersionWhatever, nil, protocol.VersionWhatever)
|
paramsChan = make(chan TransportParameters, 1)
|
||||||
|
handler = newExtensionHandlerClient(&TransportParameters{}, paramsChan, protocol.VersionWhatever, nil, protocol.VersionWhatever)
|
||||||
el = make(mint.ExtensionList, 0)
|
el = make(mint.ExtensionList, 0)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -78,7 +82,9 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
|
||||||
addEncryptedExtensionsWithParameters(parameters)
|
addEncryptedExtensionsWithParameters(parameters)
|
||||||
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(handler.params.GetSendStreamFlowControlWindow()).To(BeEquivalentTo(0x11223344))
|
var params TransportParameters
|
||||||
|
Expect(paramsChan).To(Receive(¶ms))
|
||||||
|
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors if the EncryptedExtensions message doesn't contain TransportParameters", func() {
|
It("errors if the EncryptedExtensions message doesn't contain TransportParameters", func() {
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
|
|
||||||
|
@ -13,7 +14,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type extensionHandlerServer struct {
|
type extensionHandlerServer struct {
|
||||||
params *paramsNegotiator
|
params *TransportParameters
|
||||||
|
paramsChan chan<- TransportParameters
|
||||||
|
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
supportedVersions []protocol.VersionNumber
|
supportedVersions []protocol.VersionNumber
|
||||||
|
@ -21,9 +23,15 @@ type extensionHandlerServer struct {
|
||||||
|
|
||||||
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
|
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
|
||||||
|
|
||||||
func newExtensionHandlerServer(params *paramsNegotiator, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber) *extensionHandlerServer {
|
func newExtensionHandlerServer(
|
||||||
|
params *TransportParameters,
|
||||||
|
paramsChan chan<- TransportParameters,
|
||||||
|
supportedVersions []protocol.VersionNumber,
|
||||||
|
version protocol.VersionNumber,
|
||||||
|
) *extensionHandlerServer {
|
||||||
return &extensionHandlerServer{
|
return &extensionHandlerServer{
|
||||||
params: params,
|
params: params,
|
||||||
|
paramsChan: paramsChan,
|
||||||
version: version,
|
version: version,
|
||||||
supportedVersions: supportedVersions,
|
supportedVersions: supportedVersions,
|
||||||
}
|
}
|
||||||
|
@ -35,7 +43,8 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
|
||||||
}
|
}
|
||||||
|
|
||||||
transportParams := append(
|
transportParams := append(
|
||||||
h.params.GetTransportParameters(),
|
h.params.getTransportParameters(),
|
||||||
|
// TODO(#855): generate a real token
|
||||||
transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)},
|
transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)},
|
||||||
)
|
)
|
||||||
supportedVersions := make([]uint32, len(h.supportedVersions))
|
supportedVersions := make([]uint32, len(h.supportedVersions))
|
||||||
|
@ -89,5 +98,12 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
|
||||||
return errors.New("client sent a stateless reset token")
|
return errors.New("client sent a stateless reset token")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return h.params.SetFromTransportParameters(chtp.Parameters)
|
params, err := readTransportParamters(chtp.Parameters)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// TODO(#878): remove this when implementing the MAX_STREAM_ID frame
|
||||||
|
params.MaxStreams = math.MaxUint32
|
||||||
|
h.paramsChan <- *params
|
||||||
|
return nil
|
||||||
}
|
}
|
|
@ -19,12 +19,16 @@ func parameterMapToList(paramMap map[transportParameterID][]byte) []transportPar
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ = Describe("TLS Extension Handler, for the server", func() {
|
var _ = Describe("TLS Extension Handler, for the server", func() {
|
||||||
var handler *extensionHandlerServer
|
var (
|
||||||
var el mint.ExtensionList
|
handler *extensionHandlerServer
|
||||||
|
el mint.ExtensionList
|
||||||
|
paramsChan chan TransportParameters
|
||||||
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
pn := ¶msNegotiator{}
|
// use a buffered channel here, so that we don't have to receive concurrently when parsing a message
|
||||||
handler = newExtensionHandlerServer(pn, nil, protocol.VersionWhatever)
|
paramsChan = make(chan TransportParameters, 1)
|
||||||
|
handler = newExtensionHandlerServer(&TransportParameters{}, paramsChan, nil, protocol.VersionWhatever)
|
||||||
el = make(mint.ExtensionList, 0)
|
el = make(mint.ExtensionList, 0)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -79,7 +83,9 @@ var _ = Describe("TLS Extension Handler, for the server", func() {
|
||||||
addClientHelloWithParameters(parameters)
|
addClientHelloWithParameters(parameters)
|
||||||
err := handler.Receive(mint.HandshakeTypeClientHello, &el)
|
err := handler.Receive(mint.HandshakeTypeClientHello, &el)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(handler.params.GetSendStreamFlowControlWindow()).To(BeEquivalentTo(0x11223344))
|
var params TransportParameters
|
||||||
|
Expect(paramsChan).To(Receive(¶ms))
|
||||||
|
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors if the ClientHello doesn't contain TransportParameters", func() {
|
It("errors if the ClientHello doesn't contain TransportParameters", func() {
|
||||||
|
|
|
@ -6,25 +6,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("TLS extension body", func() {
|
var _ = Describe("TLS extension body", func() {
|
||||||
// var server, client mint.AppExtensionHandler
|
|
||||||
// var el mint.ExtensionList
|
|
||||||
|
|
||||||
// BeforeEach(func() {
|
|
||||||
// server = &extensionHandler{perspective: protocol.PerspectiveServer}
|
|
||||||
// client = &extensionHandler{perspective: protocol.PerspectiveClient}
|
|
||||||
// // el = make(mint.ExtensionList, 0)
|
|
||||||
// // TODO: initialize el with some dummy extensions
|
|
||||||
// })
|
|
||||||
|
|
||||||
// It("writes and reads a ClientHello", func() {
|
|
||||||
// err := client.Send(mint.HandshakeTypeClientHello, &el)
|
|
||||||
// Expect(err).ToNot(HaveOccurred())
|
|
||||||
// ch := &tlsExtensionBody{}
|
|
||||||
// found := el.Find(ch)
|
|
||||||
// Expect(found).To(BeTrue())
|
|
||||||
// err = server.Receive(mint.HandshakeTypeClientHello, &el)
|
|
||||||
// Expect(err).ToNot(HaveOccurred())
|
|
||||||
// })
|
|
||||||
var extBody *tlsExtensionBody
|
var extBody *tlsExtensionBody
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
|
246
internal/handshake/transport_parameter_test.go
Normal file
246
internal/handshake/transport_parameter_test.go
Normal file
|
@ -0,0 +1,246 @@
|
||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Transport Parameters", func() {
|
||||||
|
Context("for gQUIC", func() {
|
||||||
|
Context("parsing", func() {
|
||||||
|
It("sets all values", func() {
|
||||||
|
values := map[Tag][]byte{
|
||||||
|
TagSFCW: {0xad, 0xfb, 0xca, 0xde},
|
||||||
|
TagCFCW: {0xef, 0xbe, 0xad, 0xde},
|
||||||
|
TagICSL: {0x0d, 0xf0, 0xad, 0xba},
|
||||||
|
TagMIDS: {0xff, 0x10, 0x00, 0xc0},
|
||||||
|
}
|
||||||
|
params, err := readHelloMap(values)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xdecafbad)))
|
||||||
|
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0xdeadbeef)))
|
||||||
|
Expect(params.IdleTimeout).To(Equal(time.Duration(0xbaadf00d) * time.Second))
|
||||||
|
Expect(params.MaxStreams).To(Equal(uint32(0xc00010ff)))
|
||||||
|
Expect(params.OmitConnectionID).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads if the connection ID should be omitted", func() {
|
||||||
|
values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}}
|
||||||
|
params, err := readHelloMap(values)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(params.OmitConnectionID).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("doesn't allow idle timeouts below the minimum remote idle timeout", func() {
|
||||||
|
t := 2 * time.Second
|
||||||
|
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
|
||||||
|
values := map[Tag][]byte{
|
||||||
|
TagICSL: {uint8(t.Seconds()), 0, 0, 0},
|
||||||
|
}
|
||||||
|
params, err := readHelloMap(values)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when given an invalid SFCW value", func() {
|
||||||
|
values := map[Tag][]byte{TagSFCW: {2, 0, 0}} // 1 byte too short
|
||||||
|
_, err := readHelloMap(values)
|
||||||
|
Expect(err).To(MatchError(errMalformedTag))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when given an invalid CFCW value", func() {
|
||||||
|
values := map[Tag][]byte{TagCFCW: {2, 0, 0}} // 1 byte too short
|
||||||
|
_, err := readHelloMap(values)
|
||||||
|
Expect(err).To(MatchError(errMalformedTag))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when given an invalid TCID value", func() {
|
||||||
|
values := map[Tag][]byte{TagTCID: {2, 0, 0}} // 1 byte too short
|
||||||
|
_, err := readHelloMap(values)
|
||||||
|
Expect(err).To(MatchError(errMalformedTag))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when given an invalid ICSL value", func() {
|
||||||
|
values := map[Tag][]byte{TagICSL: {2, 0, 0}} // 1 byte too short
|
||||||
|
_, err := readHelloMap(values)
|
||||||
|
Expect(err).To(MatchError(errMalformedTag))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when given an invalid MIDS value", func() {
|
||||||
|
values := map[Tag][]byte{TagMIDS: {2, 0, 0}} // 1 byte too short
|
||||||
|
_, err := readHelloMap(values)
|
||||||
|
Expect(err).To(MatchError(errMalformedTag))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("writing", func() {
|
||||||
|
It("returns all necessary parameters ", func() {
|
||||||
|
params := &TransportParameters{
|
||||||
|
StreamFlowControlWindow: 0xdeadbeef,
|
||||||
|
ConnectionFlowControlWindow: 0xdecafbad,
|
||||||
|
IdleTimeout: 0xbaaaaaad * time.Second,
|
||||||
|
MaxStreams: 0x1337,
|
||||||
|
}
|
||||||
|
entryMap := params.getHelloMap()
|
||||||
|
Expect(entryMap).To(HaveLen(4))
|
||||||
|
Expect(entryMap).ToNot(HaveKey(TagTCID))
|
||||||
|
Expect(entryMap).To(HaveKeyWithValue(TagSFCW, []byte{0xef, 0xbe, 0xad, 0xde}))
|
||||||
|
Expect(entryMap).To(HaveKeyWithValue(TagCFCW, []byte{0xad, 0xfb, 0xca, 0xde}))
|
||||||
|
Expect(entryMap).To(HaveKeyWithValue(TagICSL, []byte{0xad, 0xaa, 0xaa, 0xba}))
|
||||||
|
Expect(entryMap).To(HaveKeyWithValue(TagMIDS, []byte{0x37, 0x13, 0, 0}))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("requests omission of the connection ID", func() {
|
||||||
|
params := &TransportParameters{OmitConnectionID: true}
|
||||||
|
entryMap := params.getHelloMap()
|
||||||
|
Expect(entryMap).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0}))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("for TLS", func() {
|
||||||
|
paramsMapToList := func(p map[transportParameterID][]byte) []transportParameter {
|
||||||
|
var list []transportParameter
|
||||||
|
for id, val := range p {
|
||||||
|
list = append(list, transportParameter{id, val})
|
||||||
|
}
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
|
||||||
|
Context("parsing", func() {
|
||||||
|
var parameters map[transportParameterID][]byte
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
parameters = map[transportParameterID][]byte{
|
||||||
|
initialMaxStreamDataParameterID: []byte{0x11, 0x22, 0x33, 0x44},
|
||||||
|
initialMaxDataParameterID: []byte{0x22, 0x33, 0x44, 0x55},
|
||||||
|
initialMaxStreamIDParameterID: []byte{0x33, 0x44, 0x55, 0x66},
|
||||||
|
idleTimeoutParameterID: []byte{0x13, 0x37},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
It("reads parameters", func() {
|
||||||
|
params, err := readTransportParamters(paramsMapToList(parameters))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0x11223344)))
|
||||||
|
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0x22334455)))
|
||||||
|
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
|
||||||
|
Expect(params.OmitConnectionID).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("saves if it should omit the connection ID", func() {
|
||||||
|
parameters[omitConnectionIDParameterID] = []byte{}
|
||||||
|
params, err := readTransportParamters(paramsMapToList(parameters))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(params.OmitConnectionID).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects the parameters if the initial_max_stream_data is missing", func() {
|
||||||
|
delete(parameters, initialMaxStreamDataParameterID)
|
||||||
|
_, err := readTransportParamters(paramsMapToList(parameters))
|
||||||
|
Expect(err).To(MatchError("missing parameter"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects the parameters if the initial_max_data is missing", func() {
|
||||||
|
delete(parameters, initialMaxDataParameterID)
|
||||||
|
_, err := readTransportParamters(paramsMapToList(parameters))
|
||||||
|
Expect(err).To(MatchError("missing parameter"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects the parameters if the initial_max_stream_id is missing", func() {
|
||||||
|
delete(parameters, initialMaxStreamIDParameterID)
|
||||||
|
_, err := readTransportParamters(paramsMapToList(parameters))
|
||||||
|
Expect(err).To(MatchError("missing parameter"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects the parameters if the idle_timeout is missing", func() {
|
||||||
|
delete(parameters, idleTimeoutParameterID)
|
||||||
|
_, err := readTransportParamters(paramsMapToList(parameters))
|
||||||
|
Expect(err).To(MatchError("missing parameter"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("doesn't allow values below the minimum remote idle timeout", func() {
|
||||||
|
t := 2 * time.Second
|
||||||
|
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
|
||||||
|
parameters[idleTimeoutParameterID] = []byte{0, uint8(t.Seconds())}
|
||||||
|
params, err := readTransportParamters(paramsMapToList(parameters))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects the parameters if the initial_max_stream_data has the wrong length", func() {
|
||||||
|
parameters[initialMaxStreamDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
|
||||||
|
_, err := readTransportParamters(paramsMapToList(parameters))
|
||||||
|
Expect(err).To(MatchError("wrong length for initial_max_stream_data: 3 (expected 4)"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects the parameters if the initial_max_data has the wrong length", func() {
|
||||||
|
parameters[initialMaxDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
|
||||||
|
_, err := readTransportParamters(paramsMapToList(parameters))
|
||||||
|
Expect(err).To(MatchError("wrong length for initial_max_data: 3 (expected 4)"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects the parameters if the initial_max_stream_id has the wrong length", func() {
|
||||||
|
parameters[initialMaxStreamIDParameterID] = []byte{0x11, 0x22, 0x33, 0x44, 0x55} // should be 4 bytes
|
||||||
|
_, err := readTransportParamters(paramsMapToList(parameters))
|
||||||
|
Expect(err).To(MatchError("wrong length for initial_max_stream_id: 5 (expected 4)"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects the parameters if the initial_idle_timeout has the wrong length", func() {
|
||||||
|
parameters[idleTimeoutParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
|
||||||
|
_, err := readTransportParamters(paramsMapToList(parameters))
|
||||||
|
Expect(err).To(MatchError("wrong length for idle_timeout: 3 (expected 2)"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects the parameters if omit_connection_id is non-empty", func() {
|
||||||
|
parameters[omitConnectionIDParameterID] = []byte{0} // should be empty
|
||||||
|
_, err := readTransportParamters(paramsMapToList(parameters))
|
||||||
|
Expect(err).To(MatchError("wrong length for omit_connection_id: 1 (expected empty)"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("ignores unknown parameters", func() {
|
||||||
|
parameters[1337] = []byte{42}
|
||||||
|
_, err := readTransportParamters(paramsMapToList(parameters))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("writing", func() {
|
||||||
|
var params *TransportParameters
|
||||||
|
|
||||||
|
paramsListToMap := func(l []transportParameter) map[transportParameterID][]byte {
|
||||||
|
p := make(map[transportParameterID][]byte)
|
||||||
|
for _, v := range l {
|
||||||
|
p[v.Parameter] = v.Value
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
params = &TransportParameters{
|
||||||
|
StreamFlowControlWindow: 0xdeadbeef,
|
||||||
|
ConnectionFlowControlWindow: 0xdecafbad,
|
||||||
|
IdleTimeout: 0xcafe,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("creates the parameters list", func() {
|
||||||
|
values := paramsListToMap(params.getTransportParameters())
|
||||||
|
Expect(values).To(HaveLen(5))
|
||||||
|
Expect(values).To(HaveKeyWithValue(initialMaxStreamDataParameterID, []byte{0xde, 0xad, 0xbe, 0xef}))
|
||||||
|
Expect(values).To(HaveKeyWithValue(initialMaxDataParameterID, []byte{0xde, 0xca, 0xfb, 0xad}))
|
||||||
|
Expect(values).To(HaveKeyWithValue(initialMaxStreamIDParameterID, []byte{0xff, 0xff, 0xff, 0xff}))
|
||||||
|
Expect(values).To(HaveKeyWithValue(idleTimeoutParameterID, []byte{0xca, 0xfe}))
|
||||||
|
Expect(values).To(HaveKeyWithValue(maxPacketSizeParameterID, []byte{0x5, 0xac})) // 1452 = 0x5ac
|
||||||
|
})
|
||||||
|
|
||||||
|
It("request ommision of the connection ID", func() {
|
||||||
|
params.OmitConnectionID = true
|
||||||
|
values := paramsListToMap(params.getTransportParameters())
|
||||||
|
Expect(values).To(HaveKeyWithValue(omitConnectionIDParameterID, []byte{}))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
167
internal/handshake/transport_parameters.go
Normal file
167
internal/handshake/transport_parameters.go
Normal file
|
@ -0,0 +1,167 @@
|
||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// errMalformedTag is returned when the tag value cannot be read
|
||||||
|
var errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
|
||||||
|
|
||||||
|
// TransportParameters are parameters sent to the peer during the handshake
|
||||||
|
type TransportParameters struct {
|
||||||
|
StreamFlowControlWindow protocol.ByteCount
|
||||||
|
ConnectionFlowControlWindow protocol.ByteCount
|
||||||
|
|
||||||
|
MaxStreams uint32
|
||||||
|
|
||||||
|
OmitConnectionID bool
|
||||||
|
IdleTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message
|
||||||
|
func readHelloMap(tags map[Tag][]byte) (*TransportParameters, error) {
|
||||||
|
params := &TransportParameters{}
|
||||||
|
if value, ok := tags[TagTCID]; ok {
|
||||||
|
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errMalformedTag
|
||||||
|
}
|
||||||
|
params.OmitConnectionID = (v == 0)
|
||||||
|
}
|
||||||
|
if value, ok := tags[TagMIDS]; ok {
|
||||||
|
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errMalformedTag
|
||||||
|
}
|
||||||
|
params.MaxStreams = v
|
||||||
|
}
|
||||||
|
if value, ok := tags[TagICSL]; ok {
|
||||||
|
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errMalformedTag
|
||||||
|
}
|
||||||
|
params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(v)*time.Second)
|
||||||
|
}
|
||||||
|
if value, ok := tags[TagSFCW]; ok {
|
||||||
|
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errMalformedTag
|
||||||
|
}
|
||||||
|
params.StreamFlowControlWindow = protocol.ByteCount(v)
|
||||||
|
}
|
||||||
|
if value, ok := tags[TagCFCW]; ok {
|
||||||
|
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errMalformedTag
|
||||||
|
}
|
||||||
|
params.ConnectionFlowControlWindow = protocol.ByteCount(v)
|
||||||
|
}
|
||||||
|
return params, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHelloMap gets all parameters needed for the Hello message in the gQUIC handshake.
|
||||||
|
func (p *TransportParameters) getHelloMap() map[Tag][]byte {
|
||||||
|
sfcw := bytes.NewBuffer([]byte{})
|
||||||
|
utils.LittleEndian.WriteUint32(sfcw, uint32(p.StreamFlowControlWindow))
|
||||||
|
cfcw := bytes.NewBuffer([]byte{})
|
||||||
|
utils.LittleEndian.WriteUint32(cfcw, uint32(p.ConnectionFlowControlWindow))
|
||||||
|
mids := bytes.NewBuffer([]byte{})
|
||||||
|
utils.LittleEndian.WriteUint32(mids, p.MaxStreams)
|
||||||
|
icsl := bytes.NewBuffer([]byte{})
|
||||||
|
utils.LittleEndian.WriteUint32(icsl, uint32(p.IdleTimeout/time.Second))
|
||||||
|
|
||||||
|
tags := map[Tag][]byte{
|
||||||
|
TagICSL: icsl.Bytes(),
|
||||||
|
TagMIDS: mids.Bytes(),
|
||||||
|
TagCFCW: cfcw.Bytes(),
|
||||||
|
TagSFCW: sfcw.Bytes(),
|
||||||
|
}
|
||||||
|
if p.OmitConnectionID {
|
||||||
|
tags[TagTCID] = []byte{0, 0, 0, 0}
|
||||||
|
}
|
||||||
|
return tags
|
||||||
|
}
|
||||||
|
|
||||||
|
// readTransportParameters reads the transport parameters sent in the QUIC TLS extension
|
||||||
|
func readTransportParamters(paramsList []transportParameter) (*TransportParameters, error) {
|
||||||
|
params := &TransportParameters{}
|
||||||
|
|
||||||
|
var foundInitialMaxStreamData bool
|
||||||
|
var foundInitialMaxData bool
|
||||||
|
var foundInitialMaxStreamID bool
|
||||||
|
var foundIdleTimeout bool
|
||||||
|
|
||||||
|
for _, p := range paramsList {
|
||||||
|
switch p.Parameter {
|
||||||
|
case initialMaxStreamDataParameterID:
|
||||||
|
foundInitialMaxStreamData = true
|
||||||
|
if len(p.Value) != 4 {
|
||||||
|
return nil, fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", len(p.Value))
|
||||||
|
}
|
||||||
|
params.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
|
||||||
|
case initialMaxDataParameterID:
|
||||||
|
foundInitialMaxData = true
|
||||||
|
if len(p.Value) != 4 {
|
||||||
|
return nil, fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value))
|
||||||
|
}
|
||||||
|
params.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
|
||||||
|
case initialMaxStreamIDParameterID:
|
||||||
|
foundInitialMaxStreamID = true
|
||||||
|
if len(p.Value) != 4 {
|
||||||
|
return nil, fmt.Errorf("wrong length for initial_max_stream_id: %d (expected 4)", len(p.Value))
|
||||||
|
}
|
||||||
|
// TODO: handle this value
|
||||||
|
case idleTimeoutParameterID:
|
||||||
|
foundIdleTimeout = true
|
||||||
|
if len(p.Value) != 2 {
|
||||||
|
return nil, fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value))
|
||||||
|
}
|
||||||
|
params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(p.Value))*time.Second)
|
||||||
|
case omitConnectionIDParameterID:
|
||||||
|
if len(p.Value) != 0 {
|
||||||
|
return nil, fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value))
|
||||||
|
}
|
||||||
|
params.OmitConnectionID = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !(foundInitialMaxStreamData && foundInitialMaxData && foundInitialMaxStreamID && foundIdleTimeout) {
|
||||||
|
return nil, errors.New("missing parameter")
|
||||||
|
}
|
||||||
|
return params, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTransportParameters gets the parameters needed for the TLS handshake.
|
||||||
|
func (p *TransportParameters) getTransportParameters() []transportParameter {
|
||||||
|
initialMaxStreamData := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint32(initialMaxStreamData, uint32(p.StreamFlowControlWindow))
|
||||||
|
initialMaxData := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint32(initialMaxData, uint32(p.ConnectionFlowControlWindow))
|
||||||
|
initialMaxStreamID := make([]byte, 4)
|
||||||
|
// TODO: use a reasonable value here
|
||||||
|
binary.BigEndian.PutUint32(initialMaxStreamID, math.MaxUint32)
|
||||||
|
idleTimeout := make([]byte, 2)
|
||||||
|
binary.BigEndian.PutUint16(idleTimeout, uint16(p.IdleTimeout))
|
||||||
|
maxPacketSize := make([]byte, 2)
|
||||||
|
binary.BigEndian.PutUint16(maxPacketSize, uint16(protocol.MaxReceivePacketSize))
|
||||||
|
params := []transportParameter{
|
||||||
|
{initialMaxStreamDataParameterID, initialMaxStreamData},
|
||||||
|
{initialMaxDataParameterID, initialMaxData},
|
||||||
|
{initialMaxStreamIDParameterID, initialMaxStreamID},
|
||||||
|
{idleTimeoutParameterID, idleTimeout},
|
||||||
|
{maxPacketSizeParameterID, maxPacketSize},
|
||||||
|
}
|
||||||
|
if p.OmitConnectionID {
|
||||||
|
params = append(params, transportParameter{omitConnectionIDParameterID, []byte{}})
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
}
|
177
internal/mocks/flow_control_manager.go
Normal file
177
internal/mocks/flow_control_manager.go
Normal file
|
@ -0,0 +1,177 @@
|
||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: ../flowcontrol/interface.go
|
||||||
|
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||||
|
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
|
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockFlowControlManager is a mock of FlowControlManager interface
|
||||||
|
type MockFlowControlManager struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockFlowControlManagerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockFlowControlManagerMockRecorder is the mock recorder for MockFlowControlManager
|
||||||
|
type MockFlowControlManagerMockRecorder struct {
|
||||||
|
mock *MockFlowControlManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockFlowControlManager creates a new mock instance
|
||||||
|
func NewMockFlowControlManager(ctrl *gomock.Controller) *MockFlowControlManager {
|
||||||
|
mock := &MockFlowControlManager{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockFlowControlManagerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (_m *MockFlowControlManager) EXPECT() *MockFlowControlManagerMockRecorder {
|
||||||
|
return _m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStream mocks base method
|
||||||
|
func (_m *MockFlowControlManager) NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) {
|
||||||
|
_m.ctrl.Call(_m, "NewStream", streamID, contributesToConnectionFlow)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStream indicates an expected call of NewStream
|
||||||
|
func (_mr *MockFlowControlManagerMockRecorder) NewStream(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "NewStream", reflect.TypeOf((*MockFlowControlManager)(nil).NewStream), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveStream mocks base method
|
||||||
|
func (_m *MockFlowControlManager) RemoveStream(streamID protocol.StreamID) {
|
||||||
|
_m.ctrl.Call(_m, "RemoveStream", streamID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveStream indicates an expected call of RemoveStream
|
||||||
|
func (_mr *MockFlowControlManagerMockRecorder) RemoveStream(arg0 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "RemoveStream", reflect.TypeOf((*MockFlowControlManager)(nil).RemoveStream), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTransportParameters mocks base method
|
||||||
|
func (_m *MockFlowControlManager) UpdateTransportParameters(_param0 *handshake.TransportParameters) {
|
||||||
|
_m.ctrl.Call(_m, "UpdateTransportParameters", _param0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTransportParameters indicates an expected call of UpdateTransportParameters
|
||||||
|
func (_mr *MockFlowControlManagerMockRecorder) UpdateTransportParameters(arg0 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateTransportParameters", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateTransportParameters), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetStream mocks base method
|
||||||
|
func (_m *MockFlowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
|
||||||
|
ret := _m.ctrl.Call(_m, "ResetStream", streamID, byteOffset)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetStream indicates an expected call of ResetStream
|
||||||
|
func (_mr *MockFlowControlManagerMockRecorder) ResetStream(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ResetStream", reflect.TypeOf((*MockFlowControlManager)(nil).ResetStream), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateHighestReceived mocks base method
|
||||||
|
func (_m *MockFlowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
|
||||||
|
ret := _m.ctrl.Call(_m, "UpdateHighestReceived", streamID, byteOffset)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateHighestReceived indicates an expected call of UpdateHighestReceived
|
||||||
|
func (_mr *MockFlowControlManagerMockRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateHighestReceived), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBytesRead mocks base method
|
||||||
|
func (_m *MockFlowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error {
|
||||||
|
ret := _m.ctrl.Call(_m, "AddBytesRead", streamID, n)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBytesRead indicates an expected call of AddBytesRead
|
||||||
|
func (_mr *MockFlowControlManagerMockRecorder) AddBytesRead(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "AddBytesRead", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesRead), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWindowUpdates mocks base method
|
||||||
|
func (_m *MockFlowControlManager) GetWindowUpdates() []flowcontrol.WindowUpdate {
|
||||||
|
ret := _m.ctrl.Call(_m, "GetWindowUpdates")
|
||||||
|
ret0, _ := ret[0].([]flowcontrol.WindowUpdate)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWindowUpdates indicates an expected call of GetWindowUpdates
|
||||||
|
func (_mr *MockFlowControlManagerMockRecorder) GetWindowUpdates() *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetWindowUpdates", reflect.TypeOf((*MockFlowControlManager)(nil).GetWindowUpdates))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetReceiveWindow mocks base method
|
||||||
|
func (_m *MockFlowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) {
|
||||||
|
ret := _m.ctrl.Call(_m, "GetReceiveWindow", streamID)
|
||||||
|
ret0, _ := ret[0].(protocol.ByteCount)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetReceiveWindow indicates an expected call of GetReceiveWindow
|
||||||
|
func (_mr *MockFlowControlManagerMockRecorder) GetReceiveWindow(arg0 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetReceiveWindow", reflect.TypeOf((*MockFlowControlManager)(nil).GetReceiveWindow), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBytesSent mocks base method
|
||||||
|
func (_m *MockFlowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error {
|
||||||
|
ret := _m.ctrl.Call(_m, "AddBytesSent", streamID, n)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBytesSent indicates an expected call of AddBytesSent
|
||||||
|
func (_mr *MockFlowControlManagerMockRecorder) AddBytesSent(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "AddBytesSent", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesSent), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendWindowSize mocks base method
|
||||||
|
func (_m *MockFlowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) {
|
||||||
|
ret := _m.ctrl.Call(_m, "SendWindowSize", streamID)
|
||||||
|
ret0, _ := ret[0].(protocol.ByteCount)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendWindowSize indicates an expected call of SendWindowSize
|
||||||
|
func (_mr *MockFlowControlManagerMockRecorder) SendWindowSize(arg0 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SendWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).SendWindowSize), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemainingConnectionWindowSize mocks base method
|
||||||
|
func (_m *MockFlowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
|
||||||
|
ret := _m.ctrl.Call(_m, "RemainingConnectionWindowSize")
|
||||||
|
ret0, _ := ret[0].(protocol.ByteCount)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemainingConnectionWindowSize indicates an expected call of RemainingConnectionWindowSize
|
||||||
|
func (_mr *MockFlowControlManagerMockRecorder) RemainingConnectionWindowSize() *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "RemainingConnectionWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).RemainingConnectionWindowSize))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateWindow mocks base method
|
||||||
|
func (_m *MockFlowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) {
|
||||||
|
ret := _m.ctrl.Call(_m, "UpdateWindow", streamID, offset)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateWindow indicates an expected call of UpdateWindow
|
||||||
|
func (_mr *MockFlowControlManagerMockRecorder) UpdateWindow(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "UpdateWindow", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateWindow), arg0, arg1)
|
||||||
|
}
|
|
@ -3,6 +3,5 @@ package mocks
|
||||||
// mockgen source mode doesn't properly recognize structs defined in the same package
|
// mockgen source mode doesn't properly recognize structs defined in the same package
|
||||||
// so we have to use sed to correct for that
|
// so we have to use sed to correct for that
|
||||||
|
|
||||||
//go:generate sh -c "mockgen -package mocks_fc -source ../flowcontrol/interface.go | sed \"s/\\[\\]WindowUpdate/[]flowcontrol.WindowUpdate/g\" > mocks_fc/flow_control_manager.go"
|
//go:generate sh -c "mockgen -package mocks -source ../flowcontrol/interface.go | sed \"s/\\[\\]WindowUpdate/[]flowcontrol.WindowUpdate/g\" > flow_control_manager.go"
|
||||||
//go:generate sh -c "mockgen -package mocks -source ../handshake/params_negotiator_base.go > params_negotiator.go"
|
|
||||||
//go:generate sh -c "goimports -w ."
|
//go:generate sh -c "goimports -w ."
|
||||||
|
|
|
@ -1,167 +0,0 @@
|
||||||
// Code generated by MockGen. DO NOT EDIT.
|
|
||||||
// Source: ../flowcontrol/interface.go
|
|
||||||
|
|
||||||
// Package mocks_fc is a generated GoMock package.
|
|
||||||
package mocks_fc
|
|
||||||
|
|
||||||
import (
|
|
||||||
reflect "reflect"
|
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
|
||||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockFlowControlManager is a mock of FlowControlManager interface
|
|
||||||
type MockFlowControlManager struct {
|
|
||||||
ctrl *gomock.Controller
|
|
||||||
recorder *MockFlowControlManagerMockRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockFlowControlManagerMockRecorder is the mock recorder for MockFlowControlManager
|
|
||||||
type MockFlowControlManagerMockRecorder struct {
|
|
||||||
mock *MockFlowControlManager
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMockFlowControlManager creates a new mock instance
|
|
||||||
func NewMockFlowControlManager(ctrl *gomock.Controller) *MockFlowControlManager {
|
|
||||||
mock := &MockFlowControlManager{ctrl: ctrl}
|
|
||||||
mock.recorder = &MockFlowControlManagerMockRecorder{mock}
|
|
||||||
return mock
|
|
||||||
}
|
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use
|
|
||||||
func (m *MockFlowControlManager) EXPECT() *MockFlowControlManagerMockRecorder {
|
|
||||||
return m.recorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewStream mocks base method
|
|
||||||
func (m *MockFlowControlManager) NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) {
|
|
||||||
m.ctrl.Call(m, "NewStream", streamID, contributesToConnectionFlow)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewStream indicates an expected call of NewStream
|
|
||||||
func (mr *MockFlowControlManagerMockRecorder) NewStream(streamID, contributesToConnectionFlow interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewStream", reflect.TypeOf((*MockFlowControlManager)(nil).NewStream), streamID, contributesToConnectionFlow)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveStream mocks base method
|
|
||||||
func (m *MockFlowControlManager) RemoveStream(streamID protocol.StreamID) {
|
|
||||||
m.ctrl.Call(m, "RemoveStream", streamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveStream indicates an expected call of RemoveStream
|
|
||||||
func (mr *MockFlowControlManagerMockRecorder) RemoveStream(streamID interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveStream", reflect.TypeOf((*MockFlowControlManager)(nil).RemoveStream), streamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetStream mocks base method
|
|
||||||
func (m *MockFlowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
|
|
||||||
ret := m.ctrl.Call(m, "ResetStream", streamID, byteOffset)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetStream indicates an expected call of ResetStream
|
|
||||||
func (mr *MockFlowControlManagerMockRecorder) ResetStream(streamID, byteOffset interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetStream", reflect.TypeOf((*MockFlowControlManager)(nil).ResetStream), streamID, byteOffset)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateHighestReceived mocks base method
|
|
||||||
func (m *MockFlowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
|
|
||||||
ret := m.ctrl.Call(m, "UpdateHighestReceived", streamID, byteOffset)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateHighestReceived indicates an expected call of UpdateHighestReceived
|
|
||||||
func (mr *MockFlowControlManagerMockRecorder) UpdateHighestReceived(streamID, byteOffset interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateHighestReceived), streamID, byteOffset)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddBytesRead mocks base method
|
|
||||||
func (m *MockFlowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error {
|
|
||||||
ret := m.ctrl.Call(m, "AddBytesRead", streamID, n)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddBytesRead indicates an expected call of AddBytesRead
|
|
||||||
func (mr *MockFlowControlManagerMockRecorder) AddBytesRead(streamID, n interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesRead), streamID, n)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetWindowUpdates mocks base method
|
|
||||||
func (m *MockFlowControlManager) GetWindowUpdates() []flowcontrol.WindowUpdate {
|
|
||||||
ret := m.ctrl.Call(m, "GetWindowUpdates")
|
|
||||||
ret0, _ := ret[0].([]flowcontrol.WindowUpdate)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetWindowUpdates indicates an expected call of GetWindowUpdates
|
|
||||||
func (mr *MockFlowControlManagerMockRecorder) GetWindowUpdates() *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdates", reflect.TypeOf((*MockFlowControlManager)(nil).GetWindowUpdates))
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetReceiveWindow mocks base method
|
|
||||||
func (m *MockFlowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) {
|
|
||||||
ret := m.ctrl.Call(m, "GetReceiveWindow", streamID)
|
|
||||||
ret0, _ := ret[0].(protocol.ByteCount)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetReceiveWindow indicates an expected call of GetReceiveWindow
|
|
||||||
func (mr *MockFlowControlManagerMockRecorder) GetReceiveWindow(streamID interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetReceiveWindow", reflect.TypeOf((*MockFlowControlManager)(nil).GetReceiveWindow), streamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddBytesSent mocks base method
|
|
||||||
func (m *MockFlowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error {
|
|
||||||
ret := m.ctrl.Call(m, "AddBytesSent", streamID, n)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddBytesSent indicates an expected call of AddBytesSent
|
|
||||||
func (mr *MockFlowControlManagerMockRecorder) AddBytesSent(streamID, n interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockFlowControlManager)(nil).AddBytesSent), streamID, n)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendWindowSize mocks base method
|
|
||||||
func (m *MockFlowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) {
|
|
||||||
ret := m.ctrl.Call(m, "SendWindowSize", streamID)
|
|
||||||
ret0, _ := ret[0].(protocol.ByteCount)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendWindowSize indicates an expected call of SendWindowSize
|
|
||||||
func (mr *MockFlowControlManagerMockRecorder) SendWindowSize(streamID interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).SendWindowSize), streamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemainingConnectionWindowSize mocks base method
|
|
||||||
func (m *MockFlowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
|
|
||||||
ret := m.ctrl.Call(m, "RemainingConnectionWindowSize")
|
|
||||||
ret0, _ := ret[0].(protocol.ByteCount)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemainingConnectionWindowSize indicates an expected call of RemainingConnectionWindowSize
|
|
||||||
func (mr *MockFlowControlManagerMockRecorder) RemainingConnectionWindowSize() *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemainingConnectionWindowSize", reflect.TypeOf((*MockFlowControlManager)(nil).RemainingConnectionWindowSize))
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateWindow mocks base method
|
|
||||||
func (m *MockFlowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) {
|
|
||||||
ret := m.ctrl.Call(m, "UpdateWindow", streamID, offset)
|
|
||||||
ret0, _ := ret[0].(bool)
|
|
||||||
ret1, _ := ret[1].(error)
|
|
||||||
return ret0, ret1
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateWindow indicates an expected call of UpdateWindow
|
|
||||||
func (mr *MockFlowControlManagerMockRecorder) UpdateWindow(streamID, offset interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWindow", reflect.TypeOf((*MockFlowControlManager)(nil).UpdateWindow), streamID, offset)
|
|
||||||
}
|
|
|
@ -1,96 +0,0 @@
|
||||||
// Code generated by MockGen. DO NOT EDIT.
|
|
||||||
// Source: ../handshake/params_negotiator_base.go
|
|
||||||
|
|
||||||
// Package mocks is a generated GoMock package.
|
|
||||||
package mocks
|
|
||||||
|
|
||||||
import (
|
|
||||||
reflect "reflect"
|
|
||||||
time "time"
|
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockParamsNegotiator is a mock of ParamsNegotiator interface
|
|
||||||
type MockParamsNegotiator struct {
|
|
||||||
ctrl *gomock.Controller
|
|
||||||
recorder *MockParamsNegotiatorMockRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockParamsNegotiatorMockRecorder is the mock recorder for MockParamsNegotiator
|
|
||||||
type MockParamsNegotiatorMockRecorder struct {
|
|
||||||
mock *MockParamsNegotiator
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMockParamsNegotiator creates a new mock instance
|
|
||||||
func NewMockParamsNegotiator(ctrl *gomock.Controller) *MockParamsNegotiator {
|
|
||||||
mock := &MockParamsNegotiator{ctrl: ctrl}
|
|
||||||
mock.recorder = &MockParamsNegotiatorMockRecorder{mock}
|
|
||||||
return mock
|
|
||||||
}
|
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use
|
|
||||||
func (m *MockParamsNegotiator) EXPECT() *MockParamsNegotiatorMockRecorder {
|
|
||||||
return m.recorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSendStreamFlowControlWindow mocks base method
|
|
||||||
func (m *MockParamsNegotiator) GetSendStreamFlowControlWindow() protocol.ByteCount {
|
|
||||||
ret := m.ctrl.Call(m, "GetSendStreamFlowControlWindow")
|
|
||||||
ret0, _ := ret[0].(protocol.ByteCount)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSendStreamFlowControlWindow indicates an expected call of GetSendStreamFlowControlWindow
|
|
||||||
func (mr *MockParamsNegotiatorMockRecorder) GetSendStreamFlowControlWindow() *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSendStreamFlowControlWindow", reflect.TypeOf((*MockParamsNegotiator)(nil).GetSendStreamFlowControlWindow))
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSendConnectionFlowControlWindow mocks base method
|
|
||||||
func (m *MockParamsNegotiator) GetSendConnectionFlowControlWindow() protocol.ByteCount {
|
|
||||||
ret := m.ctrl.Call(m, "GetSendConnectionFlowControlWindow")
|
|
||||||
ret0, _ := ret[0].(protocol.ByteCount)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSendConnectionFlowControlWindow indicates an expected call of GetSendConnectionFlowControlWindow
|
|
||||||
func (mr *MockParamsNegotiatorMockRecorder) GetSendConnectionFlowControlWindow() *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSendConnectionFlowControlWindow", reflect.TypeOf((*MockParamsNegotiator)(nil).GetSendConnectionFlowControlWindow))
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMaxOutgoingStreams mocks base method
|
|
||||||
func (m *MockParamsNegotiator) GetMaxOutgoingStreams() uint32 {
|
|
||||||
ret := m.ctrl.Call(m, "GetMaxOutgoingStreams")
|
|
||||||
ret0, _ := ret[0].(uint32)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMaxOutgoingStreams indicates an expected call of GetMaxOutgoingStreams
|
|
||||||
func (mr *MockParamsNegotiatorMockRecorder) GetMaxOutgoingStreams() *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaxOutgoingStreams", reflect.TypeOf((*MockParamsNegotiator)(nil).GetMaxOutgoingStreams))
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRemoteIdleTimeout mocks base method
|
|
||||||
func (m *MockParamsNegotiator) GetRemoteIdleTimeout() time.Duration {
|
|
||||||
ret := m.ctrl.Call(m, "GetRemoteIdleTimeout")
|
|
||||||
ret0, _ := ret[0].(time.Duration)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRemoteIdleTimeout indicates an expected call of GetRemoteIdleTimeout
|
|
||||||
func (mr *MockParamsNegotiatorMockRecorder) GetRemoteIdleTimeout() *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRemoteIdleTimeout", reflect.TypeOf((*MockParamsNegotiator)(nil).GetRemoteIdleTimeout))
|
|
||||||
}
|
|
||||||
|
|
||||||
// OmitConnectionID mocks base method
|
|
||||||
func (m *MockParamsNegotiator) OmitConnectionID() bool {
|
|
||||||
ret := m.ctrl.Call(m, "OmitConnectionID")
|
|
||||||
ret0, _ := ret[0].(bool)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// OmitConnectionID indicates an expected call of OmitConnectionID
|
|
||||||
func (mr *MockParamsNegotiatorMockRecorder) OmitConnectionID() *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OmitConnectionID", reflect.TypeOf((*MockParamsNegotiator)(nil).OmitConnectionID))
|
|
||||||
}
|
|
|
@ -43,12 +43,6 @@ const MaxReceivePacketSize ByteCount = 1452
|
||||||
// Used in QUIC for congestion window computations in bytes.
|
// Used in QUIC for congestion window computations in bytes.
|
||||||
const DefaultTCPMSS ByteCount = 1460
|
const DefaultTCPMSS ByteCount = 1460
|
||||||
|
|
||||||
// InitialStreamFlowControlWindow is the initial stream-level flow control window for sending
|
|
||||||
const InitialStreamFlowControlWindow ByteCount = (1 << 14) // 16 kB
|
|
||||||
|
|
||||||
// InitialConnectionFlowControlWindow is the initial connection-level flow control window for sending
|
|
||||||
const InitialConnectionFlowControlWindow ByteCount = (1 << 14) // 16 kB
|
|
||||||
|
|
||||||
// ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have.
|
// ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have.
|
||||||
const ClientHelloMinimumSize = 1024
|
const ClientHelloMinimumSize = 1024
|
||||||
|
|
||||||
|
|
|
@ -25,18 +25,17 @@ type packetPacker struct {
|
||||||
cryptoSetup handshake.CryptoSetup
|
cryptoSetup handshake.CryptoSetup
|
||||||
|
|
||||||
packetNumberGenerator *packetNumberGenerator
|
packetNumberGenerator *packetNumberGenerator
|
||||||
connParams handshake.ParamsNegotiator
|
|
||||||
streamFramer *streamFramer
|
streamFramer *streamFramer
|
||||||
|
|
||||||
controlFrames []wire.Frame
|
controlFrames []wire.Frame
|
||||||
stopWaiting *wire.StopWaitingFrame
|
stopWaiting *wire.StopWaitingFrame
|
||||||
ackFrame *wire.AckFrame
|
ackFrame *wire.AckFrame
|
||||||
leastUnacked protocol.PacketNumber
|
leastUnacked protocol.PacketNumber
|
||||||
|
omitConnectionID bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPacketPacker(connectionID protocol.ConnectionID,
|
func newPacketPacker(connectionID protocol.ConnectionID,
|
||||||
cryptoSetup handshake.CryptoSetup,
|
cryptoSetup handshake.CryptoSetup,
|
||||||
connParams handshake.ParamsNegotiator,
|
|
||||||
streamFramer *streamFramer,
|
streamFramer *streamFramer,
|
||||||
perspective protocol.Perspective,
|
perspective protocol.Perspective,
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
|
@ -44,7 +43,6 @@ func newPacketPacker(connectionID protocol.ConnectionID,
|
||||||
return &packetPacker{
|
return &packetPacker{
|
||||||
cryptoSetup: cryptoSetup,
|
cryptoSetup: cryptoSetup,
|
||||||
connectionID: connectionID,
|
connectionID: connectionID,
|
||||||
connParams: connParams,
|
|
||||||
perspective: perspective,
|
perspective: perspective,
|
||||||
version: version,
|
version: version,
|
||||||
streamFramer: streamFramer,
|
streamFramer: streamFramer,
|
||||||
|
@ -268,12 +266,14 @@ func (p *packetPacker) getPublicHeader(encLevel protocol.EncryptionLevel) *wire.
|
||||||
pnum := p.packetNumberGenerator.Peek()
|
pnum := p.packetNumberGenerator.Peek()
|
||||||
packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(pnum, p.leastUnacked)
|
packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(pnum, p.leastUnacked)
|
||||||
publicHeader := &wire.PublicHeader{
|
publicHeader := &wire.PublicHeader{
|
||||||
ConnectionID: p.connectionID,
|
ConnectionID: p.connectionID,
|
||||||
PacketNumber: pnum,
|
PacketNumber: pnum,
|
||||||
PacketNumberLen: packetNumberLen,
|
PacketNumberLen: packetNumberLen,
|
||||||
OmitConnectionID: p.connParams.OmitConnectionID(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure {
|
||||||
|
publicHeader.OmitConnectionID = true
|
||||||
|
}
|
||||||
if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure {
|
if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure {
|
||||||
publicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce()
|
publicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce()
|
||||||
}
|
}
|
||||||
|
@ -329,3 +329,7 @@ func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool {
|
||||||
func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) {
|
func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) {
|
||||||
p.leastUnacked = leastUnacked
|
p.leastUnacked = leastUnacked
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *packetPacker) SetOmitConnectionID() {
|
||||||
|
p.omitConnectionID = true
|
||||||
|
}
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
|
@ -61,19 +60,15 @@ var _ = Describe("Packet packer", func() {
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
mockPn := mocks.NewMockParamsNegotiator(mockCtrl)
|
|
||||||
mockPn.EXPECT().OmitConnectionID().Return(false).AnyTimes()
|
|
||||||
|
|
||||||
cryptoStream = &stream{}
|
cryptoStream = &stream{}
|
||||||
|
|
||||||
streamsMap := newStreamsMap(nil, nil, protocol.PerspectiveServer, nil)
|
streamsMap := newStreamsMap(nil, nil, protocol.PerspectiveServer)
|
||||||
streamsMap.streams[1] = cryptoStream
|
streamsMap.streams[1] = cryptoStream
|
||||||
streamsMap.openStreams = []protocol.StreamID{1}
|
streamsMap.openStreams = []protocol.StreamID{1}
|
||||||
streamFramer = newStreamFramer(streamsMap, nil)
|
streamFramer = newStreamFramer(streamsMap, nil)
|
||||||
|
|
||||||
packer = &packetPacker{
|
packer = &packetPacker{
|
||||||
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
|
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
|
||||||
connParams: mockPn,
|
|
||||||
connectionID: 0x1337,
|
connectionID: 0x1337,
|
||||||
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength),
|
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength),
|
||||||
streamFramer: streamFramer,
|
streamFramer: streamFramer,
|
||||||
|
@ -234,6 +229,20 @@ var _ = Describe("Packet packer", func() {
|
||||||
Expect(p).ToNot(BeNil())
|
Expect(p).ToNot(BeNil())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("it omits the connection ID for forward-secure packets", func() {
|
||||||
|
ph := packer.getPublicHeader(protocol.EncryptionForwardSecure)
|
||||||
|
Expect(ph.OmitConnectionID).To(BeFalse())
|
||||||
|
packer.SetOmitConnectionID()
|
||||||
|
ph = packer.getPublicHeader(protocol.EncryptionForwardSecure)
|
||||||
|
Expect(ph.OmitConnectionID).To(BeTrue())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("doesn't omit the connection ID for non-forware-secure packets", func() {
|
||||||
|
packer.SetOmitConnectionID()
|
||||||
|
ph := packer.getPublicHeader(protocol.EncryptionSecure)
|
||||||
|
Expect(ph.OmitConnectionID).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
It("adds the version flag to the public header before the crypto handshake is finished", func() {
|
It("adds the version flag to the public header before the crypto handshake is finished", func() {
|
||||||
packer.perspective = protocol.PerspectiveClient
|
packer.perspective = protocol.PerspectiveClient
|
||||||
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
|
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure
|
||||||
|
|
45
session.go
45
session.go
|
@ -88,6 +88,8 @@ type session struct {
|
||||||
undecryptablePackets []*receivedPacket
|
undecryptablePackets []*receivedPacket
|
||||||
receivedTooManyUndecrytablePacketsTime time.Time
|
receivedTooManyUndecrytablePacketsTime time.Time
|
||||||
|
|
||||||
|
// this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them
|
||||||
|
paramsChan <-chan handshake.TransportParameters
|
||||||
// this channel is passed to the CryptoSetup and receives the current encryption level
|
// this channel is passed to the CryptoSetup and receives the current encryption level
|
||||||
// it is closed as soon as the handshake is complete
|
// it is closed as soon as the handshake is complete
|
||||||
aeadChanged <-chan protocol.EncryptionLevel
|
aeadChanged <-chan protocol.EncryptionLevel
|
||||||
|
@ -100,8 +102,6 @@ type session struct {
|
||||||
// it receives at most 3 handshake events: 2 when the encryption level changes, and one error
|
// it receives at most 3 handshake events: 2 when the encryption level changes, and one error
|
||||||
handshakeChan chan<- handshakeEvent
|
handshakeChan chan<- handshakeEvent
|
||||||
|
|
||||||
connParams handshake.ParamsNegotiator
|
|
||||||
|
|
||||||
lastRcvdPacketNumber protocol.PacketNumber
|
lastRcvdPacketNumber protocol.PacketNumber
|
||||||
// Used to calculate the next packet number from the truncated wire
|
// Used to calculate the next packet number from the truncated wire
|
||||||
// representation, and sent back in public reset packets
|
// representation, and sent back in public reset packets
|
||||||
|
@ -109,6 +109,7 @@ type session struct {
|
||||||
|
|
||||||
sessionCreationTime time.Time
|
sessionCreationTime time.Time
|
||||||
lastNetworkActivityTime time.Time
|
lastNetworkActivityTime time.Time
|
||||||
|
remoteIdleTimeout time.Duration
|
||||||
|
|
||||||
timer *utils.Timer
|
timer *utils.Timer
|
||||||
// keepAlivePingSent stores whether a Ping frame was sent to the peer or not
|
// keepAlivePingSent stores whether a Ping frame was sent to the peer or not
|
||||||
|
@ -166,7 +167,9 @@ func (s *session) setup(
|
||||||
negotiatedVersions []protocol.VersionNumber,
|
negotiatedVersions []protocol.VersionNumber,
|
||||||
) (packetHandler, <-chan handshakeEvent, error) {
|
) (packetHandler, <-chan handshakeEvent, error) {
|
||||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||||
|
paramsChan := make(chan handshake.TransportParameters)
|
||||||
s.aeadChanged = aeadChanged
|
s.aeadChanged = aeadChanged
|
||||||
|
s.paramsChan = paramsChan
|
||||||
handshakeChan := make(chan handshakeEvent, 3)
|
handshakeChan := make(chan handshakeEvent, 3)
|
||||||
s.handshakeChan = handshakeChan
|
s.handshakeChan = handshakeChan
|
||||||
s.handshakeCompleteChan = make(chan error, 1)
|
s.handshakeCompleteChan = make(chan error, 1)
|
||||||
|
@ -183,7 +186,10 @@ func (s *session) setup(
|
||||||
|
|
||||||
s.rttStats = &congestion.RTTStats{}
|
s.rttStats = &congestion.RTTStats{}
|
||||||
transportParams := &handshake.TransportParameters{
|
transportParams := &handshake.TransportParameters{
|
||||||
IdleTimeout: s.config.IdleTimeout,
|
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||||
|
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||||
|
MaxStreams: protocol.MaxIncomingStreams,
|
||||||
|
IdleTimeout: s.config.IdleTimeout,
|
||||||
}
|
}
|
||||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
|
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
|
||||||
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
|
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
|
||||||
|
@ -194,15 +200,16 @@ func (s *session) setup(
|
||||||
return s.config.AcceptCookie(clientAddr, cookie)
|
return s.config.AcceptCookie(clientAddr, cookie)
|
||||||
}
|
}
|
||||||
if s.version.UsesTLS() {
|
if s.version.UsesTLS() {
|
||||||
s.cryptoSetup, s.connParams, err = handshake.NewCryptoSetupTLSServer(
|
s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer(
|
||||||
tlsConf,
|
tlsConf,
|
||||||
transportParams,
|
transportParams,
|
||||||
|
paramsChan,
|
||||||
aeadChanged,
|
aeadChanged,
|
||||||
s.config.Versions,
|
s.config.Versions,
|
||||||
s.version,
|
s.version,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
s.cryptoSetup, s.connParams, err = newCryptoSetup(
|
s.cryptoSetup, err = newCryptoSetup(
|
||||||
s.connectionID,
|
s.connectionID,
|
||||||
s.conn.RemoteAddr(),
|
s.conn.RemoteAddr(),
|
||||||
s.version,
|
s.version,
|
||||||
|
@ -210,28 +217,31 @@ func (s *session) setup(
|
||||||
transportParams,
|
transportParams,
|
||||||
s.config.Versions,
|
s.config.Versions,
|
||||||
verifySourceAddr,
|
verifySourceAddr,
|
||||||
|
paramsChan,
|
||||||
aeadChanged,
|
aeadChanged,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission
|
||||||
if s.version.UsesTLS() {
|
if s.version.UsesTLS() {
|
||||||
s.cryptoSetup, s.connParams, err = handshake.NewCryptoSetupTLSClient(
|
s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient(
|
||||||
hostname,
|
hostname,
|
||||||
tlsConf,
|
tlsConf,
|
||||||
transportParams,
|
transportParams,
|
||||||
|
paramsChan,
|
||||||
aeadChanged,
|
aeadChanged,
|
||||||
initialVersion,
|
initialVersion,
|
||||||
s.config.Versions,
|
s.config.Versions,
|
||||||
s.version,
|
s.version,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
transportParams.RequestConnectionIDOmission = s.config.RequestConnectionIDOmission
|
s.cryptoSetup, err = newCryptoSetupClient(
|
||||||
s.cryptoSetup, s.connParams, err = newCryptoSetupClient(
|
|
||||||
hostname,
|
hostname,
|
||||||
s.connectionID,
|
s.connectionID,
|
||||||
s.version,
|
s.version,
|
||||||
tlsConf,
|
tlsConf,
|
||||||
transportParams,
|
transportParams,
|
||||||
|
paramsChan,
|
||||||
aeadChanged,
|
aeadChanged,
|
||||||
negotiatedVersions,
|
negotiatedVersions,
|
||||||
)
|
)
|
||||||
|
@ -242,16 +252,14 @@ func (s *session) setup(
|
||||||
}
|
}
|
||||||
|
|
||||||
s.flowControlManager = flowcontrol.NewFlowControlManager(
|
s.flowControlManager = flowcontrol.NewFlowControlManager(
|
||||||
s.connParams,
|
|
||||||
protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow),
|
protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow),
|
||||||
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
|
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
|
||||||
s.rttStats,
|
s.rttStats,
|
||||||
)
|
)
|
||||||
s.streamsMap = newStreamsMap(s.newStream, s.flowControlManager.RemoveStream, s.perspective, s.connParams)
|
s.streamsMap = newStreamsMap(s.newStream, s.flowControlManager.RemoveStream, s.perspective)
|
||||||
s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager)
|
s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager)
|
||||||
s.packer = newPacketPacker(s.connectionID,
|
s.packer = newPacketPacker(s.connectionID,
|
||||||
s.cryptoSetup,
|
s.cryptoSetup,
|
||||||
s.connParams,
|
|
||||||
s.streamFramer,
|
s.streamFramer,
|
||||||
s.perspective,
|
s.perspective,
|
||||||
s.version,
|
s.version,
|
||||||
|
@ -318,6 +326,8 @@ runLoop:
|
||||||
// This is a bit unclean, but works properly, since the packet always
|
// This is a bit unclean, but works properly, since the packet always
|
||||||
// begins with the public header and we never copy it.
|
// begins with the public header and we never copy it.
|
||||||
putPacketBuffer(p.publicHeader.Raw)
|
putPacketBuffer(p.publicHeader.Raw)
|
||||||
|
case p := <-s.paramsChan:
|
||||||
|
s.processTransportParameters(&p)
|
||||||
case l, ok := <-aeadChanged:
|
case l, ok := <-aeadChanged:
|
||||||
if !ok { // the aeadChanged chan was closed. This means that the handshake is completed.
|
if !ok { // the aeadChanged chan was closed. This means that the handshake is completed.
|
||||||
s.handshakeComplete = true
|
s.handshakeComplete = true
|
||||||
|
@ -338,7 +348,7 @@ runLoop:
|
||||||
s.sentPacketHandler.OnAlarm()
|
s.sentPacketHandler.OnAlarm()
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.connParams.GetRemoteIdleTimeout()/2 {
|
if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.remoteIdleTimeout/2 {
|
||||||
// send the PING frame since there is no activity in the session
|
// send the PING frame since there is no activity in the session
|
||||||
s.packer.QueueControlFrame(&wire.PingFrame{})
|
s.packer.QueueControlFrame(&wire.PingFrame{})
|
||||||
s.keepAlivePingSent = true
|
s.keepAlivePingSent = true
|
||||||
|
@ -379,7 +389,7 @@ func (s *session) Context() context.Context {
|
||||||
func (s *session) maybeResetTimer() {
|
func (s *session) maybeResetTimer() {
|
||||||
var deadline time.Time
|
var deadline time.Time
|
||||||
if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent {
|
if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent {
|
||||||
deadline = s.lastNetworkActivityTime.Add(s.connParams.GetRemoteIdleTimeout() / 2)
|
deadline = s.lastNetworkActivityTime.Add(s.remoteIdleTimeout / 2)
|
||||||
} else {
|
} else {
|
||||||
deadline = s.lastNetworkActivityTime.Add(s.config.IdleTimeout)
|
deadline = s.lastNetworkActivityTime.Add(s.config.IdleTimeout)
|
||||||
}
|
}
|
||||||
|
@ -613,6 +623,15 @@ func (s *session) handleCloseError(closeErr closeError) error {
|
||||||
return s.sendConnectionClose(quicErr)
|
return s.sendConnectionClose(quicErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *session) processTransportParameters(params *handshake.TransportParameters) {
|
||||||
|
s.remoteIdleTimeout = params.IdleTimeout
|
||||||
|
s.flowControlManager.UpdateTransportParameters(params)
|
||||||
|
s.streamsMap.UpdateMaxStreamLimit(params.MaxStreams)
|
||||||
|
if params.OmitConnectionID {
|
||||||
|
s.packer.SetOmitConnectionID()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *session) sendPacket() error {
|
func (s *session) sendPacket() error {
|
||||||
s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked())
|
s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked())
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,6 @@ import (
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||||
"github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
@ -142,20 +141,6 @@ func areSessionsRunning() bool {
|
||||||
return strings.Contains(b.String(), "quic-go.(*session).run")
|
return strings.Contains(b.String(), "quic-go.(*session).run")
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockParamsNegotiator struct{}
|
|
||||||
|
|
||||||
var _ handshake.ParamsNegotiator = &mockParamsNegotiator{}
|
|
||||||
|
|
||||||
func (m *mockParamsNegotiator) GetSendStreamFlowControlWindow() protocol.ByteCount {
|
|
||||||
return protocol.InitialStreamFlowControlWindow
|
|
||||||
}
|
|
||||||
func (m *mockParamsNegotiator) GetSendConnectionFlowControlWindow() protocol.ByteCount {
|
|
||||||
return protocol.InitialConnectionFlowControlWindow
|
|
||||||
}
|
|
||||||
func (m *mockParamsNegotiator) GetMaxOutgoingStreams() uint32 { return 100 }
|
|
||||||
func (m *mockParamsNegotiator) GetRemoteIdleTimeout() time.Duration { return time.Hour }
|
|
||||||
func (m *mockParamsNegotiator) OmitConnectionID() bool { return false }
|
|
||||||
|
|
||||||
var _ = Describe("Session", func() {
|
var _ = Describe("Session", func() {
|
||||||
var (
|
var (
|
||||||
sess *session
|
sess *session
|
||||||
|
@ -178,10 +163,11 @@ var _ = Describe("Session", func() {
|
||||||
_ *handshake.TransportParameters,
|
_ *handshake.TransportParameters,
|
||||||
_ []protocol.VersionNumber,
|
_ []protocol.VersionNumber,
|
||||||
_ func(net.Addr, *Cookie) bool,
|
_ func(net.Addr, *Cookie) bool,
|
||||||
|
_ chan<- handshake.TransportParameters,
|
||||||
aeadChangedP chan<- protocol.EncryptionLevel,
|
aeadChangedP chan<- protocol.EncryptionLevel,
|
||||||
) (handshake.CryptoSetup, handshake.ParamsNegotiator, error) {
|
) (handshake.CryptoSetup, error) {
|
||||||
aeadChanged = aeadChangedP
|
aeadChanged = aeadChangedP
|
||||||
return cryptoSetup, &mockParamsNegotiator{}, nil
|
return cryptoSetup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
mconn = newMockConnection()
|
mconn = newMockConnection()
|
||||||
|
@ -202,8 +188,6 @@ var _ = Describe("Session", func() {
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
sess = pSess.(*session)
|
sess = pSess.(*session)
|
||||||
Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // 1 stream: the crypto stream
|
Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // 1 stream: the crypto stream
|
||||||
|
|
||||||
sess.connParams = &mockParamsNegotiator{}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
|
@ -228,10 +212,11 @@ var _ = Describe("Session", func() {
|
||||||
_ *handshake.TransportParameters,
|
_ *handshake.TransportParameters,
|
||||||
_ []protocol.VersionNumber,
|
_ []protocol.VersionNumber,
|
||||||
cookieFunc func(net.Addr, *Cookie) bool,
|
cookieFunc func(net.Addr, *Cookie) bool,
|
||||||
|
_ chan<- handshake.TransportParameters,
|
||||||
_ chan<- protocol.EncryptionLevel,
|
_ chan<- protocol.EncryptionLevel,
|
||||||
) (handshake.CryptoSetup, handshake.ParamsNegotiator, error) {
|
) (handshake.CryptoSetup, error) {
|
||||||
cookieVerify = cookieFunc
|
cookieVerify = cookieFunc
|
||||||
return cryptoSetup, &mockParamsNegotiator{}, nil
|
return cryptoSetup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
conf := populateServerConfig(&Config{})
|
conf := populateServerConfig(&Config{})
|
||||||
|
@ -270,6 +255,10 @@ var _ = Describe("Session", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("when handling stream frames", func() {
|
Context("when handling stream frames", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
sess.streamsMap.UpdateMaxStreamLimit(100)
|
||||||
|
})
|
||||||
|
|
||||||
It("makes new streams", func() {
|
It("makes new streams", func() {
|
||||||
sess.handleStreamFrame(&wire.StreamFrame{
|
sess.handleStreamFrame(&wire.StreamFrame{
|
||||||
StreamID: 5,
|
StreamID: 5,
|
||||||
|
@ -464,7 +453,7 @@ var _ = Describe("Session", func() {
|
||||||
|
|
||||||
It("passes the byte offset to the flow controller", func() {
|
It("passes the byte offset to the flow controller", func() {
|
||||||
sess.streamsMap.GetOrOpenStream(5)
|
sess.streamsMap.GetOrOpenStream(5)
|
||||||
fcm := mocks_fc.NewMockFlowControlManager(mockCtrl)
|
fcm := mocks.NewMockFlowControlManager(mockCtrl)
|
||||||
sess.flowControlManager = fcm
|
sess.flowControlManager = fcm
|
||||||
fcm.EXPECT().ResetStream(protocol.StreamID(5), protocol.ByteCount(0x1337))
|
fcm.EXPECT().ResetStream(protocol.StreamID(5), protocol.ByteCount(0x1337))
|
||||||
err := sess.handleRstStreamFrame(&wire.RstStreamFrame{
|
err := sess.handleRstStreamFrame(&wire.RstStreamFrame{
|
||||||
|
@ -477,7 +466,7 @@ var _ = Describe("Session", func() {
|
||||||
It("returns errors from the flow controller", func() {
|
It("returns errors from the flow controller", func() {
|
||||||
testErr := errors.New("flow control violation")
|
testErr := errors.New("flow control violation")
|
||||||
sess.streamsMap.GetOrOpenStream(5)
|
sess.streamsMap.GetOrOpenStream(5)
|
||||||
fcm := mocks_fc.NewMockFlowControlManager(mockCtrl)
|
fcm := mocks.NewMockFlowControlManager(mockCtrl)
|
||||||
sess.flowControlManager = fcm
|
sess.flowControlManager = fcm
|
||||||
fcm.EXPECT().ResetStream(protocol.StreamID(5), protocol.ByteCount(0x1337)).Return(testErr)
|
fcm.EXPECT().ResetStream(protocol.StreamID(5), protocol.ByteCount(0x1337)).Return(testErr)
|
||||||
err := sess.handleRstStreamFrame(&wire.RstStreamFrame{
|
err := sess.handleRstStreamFrame(&wire.RstStreamFrame{
|
||||||
|
@ -525,6 +514,10 @@ var _ = Describe("Session", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("handling WINDOW_UPDATE frames", func() {
|
Context("handling WINDOW_UPDATE frames", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
sess.flowControlManager.UpdateTransportParameters(&handshake.TransportParameters{ConnectionFlowControlWindow: 0x1000})
|
||||||
|
})
|
||||||
|
|
||||||
It("updates the Flow Control Window of a stream", func() {
|
It("updates the Flow Control Window of a stream", func() {
|
||||||
_, err := sess.GetOrOpenStream(5)
|
_, err := sess.GetOrOpenStream(5)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -1093,7 +1086,7 @@ var _ = Describe("Session", func() {
|
||||||
It("retransmits a WindowUpdate if it hasn't already sent a WindowUpdate with a higher ByteOffset", func() {
|
It("retransmits a WindowUpdate if it hasn't already sent a WindowUpdate with a higher ByteOffset", func() {
|
||||||
_, err := sess.GetOrOpenStream(5)
|
_, err := sess.GetOrOpenStream(5)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
fcm := mocks_fc.NewMockFlowControlManager(mockCtrl)
|
fcm := mocks.NewMockFlowControlManager(mockCtrl)
|
||||||
sess.flowControlManager = fcm
|
sess.flowControlManager = fcm
|
||||||
fcm.EXPECT().GetWindowUpdates()
|
fcm.EXPECT().GetWindowUpdates()
|
||||||
fcm.EXPECT().GetReceiveWindow(protocol.StreamID(5)).Return(protocol.ByteCount(0x1000), nil)
|
fcm.EXPECT().GetReceiveWindow(protocol.StreamID(5)).Return(protocol.ByteCount(0x1000), nil)
|
||||||
|
@ -1114,7 +1107,7 @@ var _ = Describe("Session", func() {
|
||||||
It("doesn't retransmit WindowUpdates if it already sent a WindowUpdate with a higher ByteOffset", func() {
|
It("doesn't retransmit WindowUpdates if it already sent a WindowUpdate with a higher ByteOffset", func() {
|
||||||
_, err := sess.GetOrOpenStream(5)
|
_, err := sess.GetOrOpenStream(5)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
fcm := mocks_fc.NewMockFlowControlManager(mockCtrl)
|
fcm := mocks.NewMockFlowControlManager(mockCtrl)
|
||||||
sess.flowControlManager = fcm
|
sess.flowControlManager = fcm
|
||||||
fcm.EXPECT().GetWindowUpdates()
|
fcm.EXPECT().GetWindowUpdates()
|
||||||
fcm.EXPECT().GetReceiveWindow(protocol.StreamID(5)).Return(protocol.ByteCount(0x2000), nil)
|
fcm.EXPECT().GetReceiveWindow(protocol.StreamID(5)).Return(protocol.ByteCount(0x2000), nil)
|
||||||
|
@ -1140,7 +1133,7 @@ var _ = Describe("Session", func() {
|
||||||
err = sess.streamsMap.DeleteClosedStreams()
|
err = sess.streamsMap.DeleteClosedStreams()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = sess.flowControlManager.SendWindowSize(5)
|
_, err = sess.flowControlManager.SendWindowSize(5)
|
||||||
Expect(err).To(MatchError("Error accessing the flowController map."))
|
Expect(err).To(MatchError("Error accessing the flowController map"))
|
||||||
sph.retransmissionQueue = []*ackhandler.Packet{{
|
sph.retransmissionQueue = []*ackhandler.Packet{{
|
||||||
Frames: []wire.Frame{&wire.WindowUpdateFrame{
|
Frames: []wire.Frame{&wire.WindowUpdateFrame{
|
||||||
StreamID: 5,
|
StreamID: 5,
|
||||||
|
@ -1183,6 +1176,11 @@ var _ = Describe("Session", func() {
|
||||||
|
|
||||||
Context("scheduling sending", func() {
|
Context("scheduling sending", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
sess.processTransportParameters(&handshake.TransportParameters{
|
||||||
|
StreamFlowControlWindow: protocol.MaxByteCount,
|
||||||
|
ConnectionFlowControlWindow: protocol.MaxByteCount,
|
||||||
|
MaxStreams: 1000,
|
||||||
|
})
|
||||||
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
|
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -1420,16 +1418,33 @@ var _ = Describe("Session", func() {
|
||||||
close(done)
|
close(done)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("process transport parameters received from the peer", func() {
|
||||||
|
paramsChan := make(chan handshake.TransportParameters)
|
||||||
|
sess.paramsChan = paramsChan
|
||||||
|
_, err := sess.GetOrOpenStream(5)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
go sess.run()
|
||||||
|
paramsChan <- handshake.TransportParameters{
|
||||||
|
MaxStreams: 123,
|
||||||
|
IdleTimeout: 90 * time.Second,
|
||||||
|
StreamFlowControlWindow: 0x5000,
|
||||||
|
ConnectionFlowControlWindow: 0x5000,
|
||||||
|
OmitConnectionID: true,
|
||||||
|
}
|
||||||
|
Eventually(func() time.Duration { return sess.remoteIdleTimeout }).Should(Equal(90 * time.Second))
|
||||||
|
Eventually(func() uint32 { return sess.streamsMap.maxOutgoingStreams }).Should(Equal(uint32(123)))
|
||||||
|
Eventually(func() (protocol.ByteCount, error) { return sess.flowControlManager.SendWindowSize(5) }).Should(Equal(protocol.ByteCount(0x5000)))
|
||||||
|
Eventually(func() bool { return sess.packer.omitConnectionID }).Should(BeTrue())
|
||||||
|
Expect(sess.Close(nil)).To(Succeed())
|
||||||
|
})
|
||||||
|
|
||||||
Context("keep-alives", func() {
|
Context("keep-alives", func() {
|
||||||
var mockPn *mocks.MockParamsNegotiator
|
|
||||||
// should be shorter than the local timeout for these tests
|
// should be shorter than the local timeout for these tests
|
||||||
// otherwise we'd send a CONNECTION_CLOSE in the tests where we're testing that no PING is sent
|
// otherwise we'd send a CONNECTION_CLOSE in the tests where we're testing that no PING is sent
|
||||||
remoteIdleTimeout := 20 * time.Second
|
remoteIdleTimeout := 20 * time.Second
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
mockPn = mocks.NewMockParamsNegotiator(mockCtrl)
|
sess.remoteIdleTimeout = remoteIdleTimeout
|
||||||
mockPn.EXPECT().GetRemoteIdleTimeout().Return(remoteIdleTimeout).AnyTimes()
|
|
||||||
sess.connParams = mockPn
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends a PING", func() {
|
It("sends a PING", func() {
|
||||||
|
@ -1523,6 +1538,10 @@ var _ = Describe("Session", func() {
|
||||||
}, 0.5)
|
}, 0.5)
|
||||||
|
|
||||||
Context("getting streams", func() {
|
Context("getting streams", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
sess.processTransportParameters(&handshake.TransportParameters{MaxStreams: 1000})
|
||||||
|
})
|
||||||
|
|
||||||
It("returns a new stream", func() {
|
It("returns a new stream", func() {
|
||||||
str, err := sess.GetOrOpenStream(11)
|
str, err := sess.GetOrOpenStream(11)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -1653,11 +1672,12 @@ var _ = Describe("Client Session", func() {
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
_ *tls.Config,
|
_ *tls.Config,
|
||||||
_ *handshake.TransportParameters,
|
_ *handshake.TransportParameters,
|
||||||
|
_ chan<- handshake.TransportParameters,
|
||||||
aeadChangedP chan<- protocol.EncryptionLevel,
|
aeadChangedP chan<- protocol.EncryptionLevel,
|
||||||
_ []protocol.VersionNumber,
|
_ []protocol.VersionNumber,
|
||||||
) (handshake.CryptoSetup, handshake.ParamsNegotiator, error) {
|
) (handshake.CryptoSetup, error) {
|
||||||
aeadChanged = aeadChangedP
|
aeadChanged = aeadChangedP
|
||||||
return cryptoSetup, &mockParamsNegotiator{}, nil
|
return cryptoSetup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
mconn = newMockConnection()
|
mconn = newMockConnection()
|
||||||
|
|
|
@ -3,7 +3,7 @@ package quic
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc"
|
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
|
@ -21,7 +21,7 @@ var _ = Describe("Stream Framer", func() {
|
||||||
framer *streamFramer
|
framer *streamFramer
|
||||||
streamsMap *streamsMap
|
streamsMap *streamsMap
|
||||||
stream1, stream2 *stream
|
stream1, stream2 *stream
|
||||||
mockFcm *mocks_fc.MockFlowControlManager
|
mockFcm *mocks.MockFlowControlManager
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -37,11 +37,11 @@ var _ = Describe("Stream Framer", func() {
|
||||||
stream1 = &stream{streamID: id1}
|
stream1 = &stream{streamID: id1}
|
||||||
stream2 = &stream{streamID: id2}
|
stream2 = &stream{streamID: id2}
|
||||||
|
|
||||||
streamsMap = newStreamsMap(nil, nil, protocol.PerspectiveServer, nil)
|
streamsMap = newStreamsMap(nil, nil, protocol.PerspectiveServer)
|
||||||
streamsMap.putStream(stream1)
|
streamsMap.putStream(stream1)
|
||||||
streamsMap.putStream(stream2)
|
streamsMap.putStream(stream2)
|
||||||
|
|
||||||
mockFcm = mocks_fc.NewMockFlowControlManager(mockCtrl)
|
mockFcm = mocks.NewMockFlowControlManager(mockCtrl)
|
||||||
framer = newStreamFramer(streamsMap, mockFcm)
|
framer = newStreamFramer(streamsMap, mockFcm)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc"
|
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ var _ = Describe("Stream", func() {
|
||||||
resetCalledForStream protocol.StreamID
|
resetCalledForStream protocol.StreamID
|
||||||
resetCalledAtOffset protocol.ByteCount
|
resetCalledAtOffset protocol.ByteCount
|
||||||
|
|
||||||
mockFcm *mocks_fc.MockFlowControlManager
|
mockFcm *mocks.MockFlowControlManager
|
||||||
)
|
)
|
||||||
|
|
||||||
// in the tests for the stream deadlines we set a deadline
|
// in the tests for the stream deadlines we set a deadline
|
||||||
|
@ -58,7 +58,7 @@ var _ = Describe("Stream", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
onDataCalled = false
|
onDataCalled = false
|
||||||
resetCalled = false
|
resetCalled = false
|
||||||
mockFcm = mocks_fc.NewMockFlowControlManager(mockCtrl)
|
mockFcm = mocks.NewMockFlowControlManager(mockCtrl)
|
||||||
str = newStream(streamID, onData, onReset, mockFcm)
|
str = newStream(streamID, onData, onReset, mockFcm)
|
||||||
|
|
||||||
timeout := scaleDuration(250 * time.Millisecond)
|
timeout := scaleDuration(250 * time.Millisecond)
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
|
@ -14,7 +13,6 @@ import (
|
||||||
type streamsMap struct {
|
type streamsMap struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
connParams handshake.ParamsNegotiator
|
|
||||||
perspective protocol.Perspective
|
perspective protocol.Perspective
|
||||||
|
|
||||||
streams map[protocol.StreamID]*stream
|
streams map[protocol.StreamID]*stream
|
||||||
|
@ -36,6 +34,7 @@ type streamsMap struct {
|
||||||
numOutgoingStreams uint32
|
numOutgoingStreams uint32
|
||||||
numIncomingStreams uint32
|
numIncomingStreams uint32
|
||||||
maxIncomingStreams uint32
|
maxIncomingStreams uint32
|
||||||
|
maxOutgoingStreams uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
type streamLambda func(*stream) (bool, error)
|
type streamLambda func(*stream) (bool, error)
|
||||||
|
@ -44,7 +43,7 @@ type newStreamLambda func(protocol.StreamID) *stream
|
||||||
|
|
||||||
var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
|
var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
|
||||||
|
|
||||||
func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamCallback, pers protocol.Perspective, connParams handshake.ParamsNegotiator) *streamsMap {
|
func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamCallback, pers protocol.Perspective) *streamsMap {
|
||||||
// add some tolerance to the maximum incoming streams value
|
// add some tolerance to the maximum incoming streams value
|
||||||
maxStreams := uint32(protocol.MaxIncomingStreams)
|
maxStreams := uint32(protocol.MaxIncomingStreams)
|
||||||
maxIncomingStreams := utils.MaxUint32(
|
maxIncomingStreams := utils.MaxUint32(
|
||||||
|
@ -57,7 +56,6 @@ func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamC
|
||||||
openStreams: make([]protocol.StreamID, 0),
|
openStreams: make([]protocol.StreamID, 0),
|
||||||
newStream: newStream,
|
newStream: newStream,
|
||||||
removeStreamCallback: removeStreamCallback,
|
removeStreamCallback: removeStreamCallback,
|
||||||
connParams: connParams,
|
|
||||||
maxIncomingStreams: maxIncomingStreams,
|
maxIncomingStreams: maxIncomingStreams,
|
||||||
}
|
}
|
||||||
sm.nextStreamOrErrCond.L = &sm.mutex
|
sm.nextStreamOrErrCond.L = &sm.mutex
|
||||||
|
@ -66,6 +64,8 @@ func newStreamsMap(newStream newStreamLambda, removeStreamCallback removeStreamC
|
||||||
if pers == protocol.PerspectiveClient {
|
if pers == protocol.PerspectiveClient {
|
||||||
sm.nextStream = 1
|
sm.nextStream = 1
|
||||||
sm.nextStreamToAccept = 2
|
sm.nextStreamToAccept = 2
|
||||||
|
// TODO: find a better solution for opening the crypto stream
|
||||||
|
sm.maxOutgoingStreams = 1 // allow the crypto stream
|
||||||
} else {
|
} else {
|
||||||
sm.nextStream = 2
|
sm.nextStream = 2
|
||||||
sm.nextStreamToAccept = 1
|
sm.nextStreamToAccept = 1
|
||||||
|
@ -159,7 +159,7 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) {
|
||||||
|
|
||||||
func (m *streamsMap) openStreamImpl() (*stream, error) {
|
func (m *streamsMap) openStreamImpl() (*stream, error) {
|
||||||
id := m.nextStream
|
id := m.nextStream
|
||||||
if m.numOutgoingStreams >= m.connParams.GetMaxOutgoingStreams() {
|
if m.numOutgoingStreams >= m.maxOutgoingStreams {
|
||||||
return nil, qerr.TooManyOpenStreams
|
return nil, qerr.TooManyOpenStreams
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -340,3 +340,9 @@ func (m *streamsMap) CloseWithError(err error) {
|
||||||
m.streams[s].Cancel(err)
|
m.streams[s].Cancel(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
m.maxOutgoingStreams = limit
|
||||||
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package quic
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
|
@ -11,22 +10,16 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("Streams Map", func() {
|
var _ = Describe("Streams Map", func() {
|
||||||
const maxOutgoingStreams = 60
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
m *streamsMap
|
m *streamsMap
|
||||||
mockPn *mocks.MockParamsNegotiator
|
|
||||||
)
|
)
|
||||||
|
|
||||||
setNewStreamsMap := func(p protocol.Perspective) {
|
setNewStreamsMap := func(p protocol.Perspective) {
|
||||||
mockPn = mocks.NewMockParamsNegotiator(mockCtrl)
|
|
||||||
mockPn.EXPECT().GetMaxOutgoingStreams().AnyTimes().Return(uint32(maxOutgoingStreams))
|
|
||||||
|
|
||||||
newStream := func(id protocol.StreamID) *stream {
|
newStream := func(id protocol.StreamID) *stream {
|
||||||
return newStream(id, func() {}, nil, nil)
|
return newStream(id, func() {}, nil, nil)
|
||||||
}
|
}
|
||||||
removeStreamCallback := func(protocol.StreamID) {}
|
removeStreamCallback := func(protocol.StreamID) {}
|
||||||
m = newStreamsMap(newStream, removeStreamCallback, p, mockPn)
|
m = newStreamsMap(newStream, removeStreamCallback, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
|
@ -132,7 +125,13 @@ var _ = Describe("Streams Map", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("server-side streams", func() {
|
Context("server-side streams", func() {
|
||||||
|
It("doesn't allow opening streams before receiving the transport parameters", func() {
|
||||||
|
_, err := m.OpenStream()
|
||||||
|
Expect(err).To(MatchError(qerr.TooManyOpenStreams))
|
||||||
|
})
|
||||||
|
|
||||||
It("opens a stream 2 first", func() {
|
It("opens a stream 2 first", func() {
|
||||||
|
m.UpdateMaxStreamLimit(100)
|
||||||
s, err := m.OpenStream()
|
s, err := m.OpenStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(s).ToNot(BeNil())
|
Expect(s).ToNot(BeNil())
|
||||||
|
@ -149,6 +148,7 @@ var _ = Describe("Streams Map", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't reopen an already closed stream", func() {
|
It("doesn't reopen an already closed stream", func() {
|
||||||
|
m.UpdateMaxStreamLimit(100)
|
||||||
str, err := m.OpenStream()
|
str, err := m.OpenStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(str.StreamID()).To(Equal(protocol.StreamID(2)))
|
Expect(str.StreamID()).To(Equal(protocol.StreamID(2)))
|
||||||
|
@ -160,6 +160,12 @@ var _ = Describe("Streams Map", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("counting streams", func() {
|
Context("counting streams", func() {
|
||||||
|
const maxOutgoingStreams = 50
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
m.UpdateMaxStreamLimit(maxOutgoingStreams)
|
||||||
|
})
|
||||||
|
|
||||||
It("errors when too many streams are opened", func() {
|
It("errors when too many streams are opened", func() {
|
||||||
for i := 1; i <= maxOutgoingStreams; i++ {
|
for i := 1; i <= maxOutgoingStreams; i++ {
|
||||||
_, err := m.OpenStream()
|
_, err := m.OpenStream()
|
||||||
|
@ -190,6 +196,12 @@ var _ = Describe("Streams Map", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("opening streams synchronously", func() {
|
Context("opening streams synchronously", func() {
|
||||||
|
const maxOutgoingStreams = 10
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
m.UpdateMaxStreamLimit(maxOutgoingStreams)
|
||||||
|
})
|
||||||
|
|
||||||
openMaxNumStreams := func() {
|
openMaxNumStreams := func() {
|
||||||
for i := 1; i <= maxOutgoingStreams; i++ {
|
for i := 1; i <= maxOutgoingStreams; i++ {
|
||||||
_, err := m.OpenStream()
|
_, err := m.OpenStream()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue