mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
use stream 0 for the crypto stream when using TLS
This commit is contained in:
parent
05f6e1cf8e
commit
f662822486
16 changed files with 93 additions and 59 deletions
|
@ -90,9 +90,6 @@ func (c *client) dial() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.headerStream.StreamID() != 3 {
|
||||
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3")
|
||||
}
|
||||
c.requestWriter = newRequestWriter(c.headerStream)
|
||||
go c.handleHeaderStream()
|
||||
return nil
|
||||
|
|
|
@ -97,16 +97,6 @@ var _ = Describe("Client", func() {
|
|||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("errors if the header stream has the wrong stream ID", func() {
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
||||
session.streamsToOpen = []quic.Stream{&mockStream{id: 2}}
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
}
|
||||
_, err := client.RoundTrip(req)
|
||||
Expect(err).To(MatchError("h2quic Client BUG: StreamID of Header Stream is not 3"))
|
||||
})
|
||||
|
||||
It("errors if it can't open a stream", func() {
|
||||
testErr := errors.New("you shall not pass")
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
|
||||
|
|
|
@ -122,10 +122,6 @@ func (s *Server) handleHeaderStream(session streamCreator) {
|
|||
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
|
||||
return
|
||||
}
|
||||
if stream.StreamID() != 3 {
|
||||
session.Close(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3"))
|
||||
return
|
||||
}
|
||||
|
||||
hpackDecoder := hpack.NewDecoder(4096, nil)
|
||||
h2framer := http2.NewFramer(nil, stream)
|
||||
|
|
|
@ -308,19 +308,6 @@ var _ = Describe("H2 server", func() {
|
|||
Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")))
|
||||
})
|
||||
|
||||
It("errors if the accepted header stream has the wrong stream ID", func() {
|
||||
headerStream := &mockStream{id: 1}
|
||||
headerStream.dataToRead.Write([]byte{
|
||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
session.streamToAccept = headerStream
|
||||
go s.handleHeaderStream(session)
|
||||
Eventually(func() bool { return session.closed }).Should(BeTrue())
|
||||
Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3")))
|
||||
})
|
||||
|
||||
It("supports closing after first request", func() {
|
||||
s.CloseAfterFirstRequest = true
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
|
|
@ -45,7 +45,7 @@ func (vn VersionNumber) String() string {
|
|||
case VersionTLS:
|
||||
return "TLS dev version (WIP)"
|
||||
default:
|
||||
if vn > gquicVersion0 && vn <= maxGquicVersion {
|
||||
if vn.isGQUIC() {
|
||||
return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion())
|
||||
}
|
||||
return fmt.Sprintf("%d", vn)
|
||||
|
@ -54,12 +54,24 @@ func (vn VersionNumber) String() string {
|
|||
|
||||
// ToAltSvc returns the representation of the version for the H2 Alt-Svc parameters
|
||||
func (vn VersionNumber) ToAltSvc() string {
|
||||
if vn > gquicVersion0 && vn <= maxGquicVersion {
|
||||
if vn.isGQUIC() {
|
||||
return fmt.Sprintf("%d", vn.toGQUICVersion())
|
||||
}
|
||||
return fmt.Sprintf("%d", vn)
|
||||
}
|
||||
|
||||
// CryptoStreamID gets the Stream ID of the crypto stream
|
||||
func (vn VersionNumber) CryptoStreamID() StreamID {
|
||||
if vn.isGQUIC() {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (vn VersionNumber) isGQUIC() bool {
|
||||
return vn > gquicVersion0 && vn <= maxGquicVersion
|
||||
}
|
||||
|
||||
func (vn VersionNumber) toGQUICVersion() int {
|
||||
return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
|
||||
}
|
||||
|
|
|
@ -43,7 +43,13 @@ var _ = Describe("Version", func() {
|
|||
Expect(VersionNumber(0x51303133).ToAltSvc()).To(Equal("13"))
|
||||
Expect(VersionNumber(0x51303235).ToAltSvc()).To(Equal("25"))
|
||||
Expect(VersionNumber(0x51303438).ToAltSvc()).To(Equal("48"))
|
||||
})
|
||||
|
||||
It("tells the Stream ID of the crypto stream", func() {
|
||||
Expect(Version37.CryptoStreamID()).To(Equal(StreamID(1)))
|
||||
Expect(Version38.CryptoStreamID()).To(Equal(StreamID(1)))
|
||||
Expect(Version39.CryptoStreamID()).To(Equal(StreamID(1)))
|
||||
Expect(VersionTLS.CryptoStreamID()).To(Equal(StreamID(0)))
|
||||
})
|
||||
|
||||
It("recognizes supported versions", func() {
|
||||
|
|
|
@ -63,7 +63,7 @@ var _ = Describe("Packet packer", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
cryptoStream = &stream{flowController: flowcontrol.NewStreamFlowController(1, false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)}
|
||||
streamsMap := newStreamsMap(nil, protocol.PerspectiveServer)
|
||||
streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, protocol.VersionWhatever)
|
||||
streamFramer = newStreamFramer(cryptoStream, streamsMap, nil)
|
||||
|
||||
packer = &packetPacker{
|
||||
|
@ -574,7 +574,10 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{StreamID: 1, Data: []byte("foobar")}))
|
||||
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{
|
||||
StreamID: packer.version.CryptoStreamID(),
|
||||
Data: []byte("foobar"),
|
||||
}))
|
||||
})
|
||||
|
||||
It("sends encrypted stream data on the crypto stream", func() {
|
||||
|
@ -584,7 +587,10 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
|
||||
Expect(p.frames).To(HaveLen(1))
|
||||
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{StreamID: 1, Data: []byte("foobar")}))
|
||||
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{
|
||||
StreamID: packer.version.CryptoStreamID(),
|
||||
Data: []byte("foobar"),
|
||||
}))
|
||||
})
|
||||
|
||||
It("does not pack stream frames if not allowed", func() {
|
||||
|
|
|
@ -55,7 +55,7 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []by
|
|||
err = qerr.Error(qerr.InvalidStreamData, err.Error())
|
||||
} else {
|
||||
streamID := frame.(*wire.StreamFrame).StreamID
|
||||
if streamID != 1 && encryptionLevel <= protocol.EncryptionUnencrypted {
|
||||
if streamID != u.version.CryptoStreamID() && encryptionLevel <= protocol.EncryptionUnencrypted {
|
||||
err = qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", streamID))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -224,10 +224,10 @@ var _ = Describe("Packet unpacker", func() {
|
|||
})
|
||||
|
||||
Context("unpacking STREAM frames", func() {
|
||||
It("unpacks unencrypted STREAM frames on stream 1", func() {
|
||||
It("unpacks unencrypted STREAM frames on the crypto stream", func() {
|
||||
unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionUnencrypted
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 1,
|
||||
StreamID: unpacker.version.CryptoStreamID(),
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
err := f.Write(buf, 0)
|
||||
|
@ -238,10 +238,10 @@ var _ = Describe("Packet unpacker", func() {
|
|||
Expect(packet.frames).To(Equal([]wire.Frame{f}))
|
||||
})
|
||||
|
||||
It("unpacks encrypted STREAM frames on stream 1", func() {
|
||||
It("unpacks encrypted STREAM frames on the crypto stream", func() {
|
||||
unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionSecure
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 1,
|
||||
StreamID: unpacker.version.CryptoStreamID(),
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
err := f.Write(buf, 0)
|
||||
|
|
10
session.go
10
session.go
|
@ -200,8 +200,8 @@ func (s *session) setup(
|
|||
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
|
||||
s.rttStats,
|
||||
)
|
||||
s.streamsMap = newStreamsMap(s.newStream, s.perspective)
|
||||
s.cryptoStream = s.newStream(1)
|
||||
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version)
|
||||
s.cryptoStream = s.newStream(s.version.CryptoStreamID())
|
||||
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController)
|
||||
|
||||
var err error
|
||||
|
@ -527,7 +527,7 @@ func (s *session) handlePacket(p *receivedPacket) {
|
|||
}
|
||||
|
||||
func (s *session) handleStreamFrame(frame *wire.StreamFrame) error {
|
||||
if frame.StreamID == 1 {
|
||||
if frame.StreamID == s.version.CryptoStreamID() {
|
||||
return s.cryptoStream.AddStreamFrame(frame)
|
||||
}
|
||||
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
|
||||
|
@ -820,7 +820,7 @@ func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.By
|
|||
func (s *session) newStream(id protocol.StreamID) streamI {
|
||||
// TODO: find a better solution for determining which streams contribute to connection level flow control
|
||||
var contributesToConnection bool
|
||||
if id != 1 && id != 3 {
|
||||
if id != 0 && id != 1 && id != 3 {
|
||||
contributesToConnection = true
|
||||
}
|
||||
var initialSendWindow protocol.ByteCount
|
||||
|
@ -836,7 +836,7 @@ func (s *session) newStream(id protocol.StreamID) streamI {
|
|||
initialSendWindow,
|
||||
s.rttStats,
|
||||
)
|
||||
return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController)
|
||||
return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController, s.version)
|
||||
}
|
||||
|
||||
func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error {
|
||||
|
|
|
@ -75,6 +75,7 @@ type stream struct {
|
|||
writeDeadline time.Time
|
||||
|
||||
flowController flowcontrol.StreamFlowController
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
var _ Stream = &stream{}
|
||||
|
@ -93,6 +94,7 @@ func newStream(StreamID protocol.StreamID,
|
|||
onData func(),
|
||||
onReset func(protocol.StreamID, protocol.ByteCount),
|
||||
flowController flowcontrol.StreamFlowController,
|
||||
version protocol.VersionNumber,
|
||||
) *stream {
|
||||
s := &stream{
|
||||
onData: onData,
|
||||
|
@ -102,6 +104,7 @@ func newStream(StreamID protocol.StreamID,
|
|||
frameQueue: newStreamFrameSorter(),
|
||||
readChan: make(chan struct{}, 1),
|
||||
writeChan: make(chan struct{}, 1),
|
||||
version: version,
|
||||
}
|
||||
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
|
||||
return s
|
||||
|
@ -274,7 +277,7 @@ func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) []byte {
|
|||
}
|
||||
|
||||
// TODO(#657): Flow control for the crypto stream
|
||||
if s.streamID != 1 {
|
||||
if s.streamID != s.version.CryptoStreamID() {
|
||||
maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize())
|
||||
}
|
||||
if maxBytes == 0 {
|
||||
|
|
|
@ -60,7 +60,7 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str
|
|||
return nil
|
||||
}
|
||||
frame := &wire.StreamFrame{
|
||||
StreamID: 1,
|
||||
StreamID: f.cryptoStream.StreamID(),
|
||||
Offset: f.cryptoStream.GetWriteOffset(),
|
||||
}
|
||||
frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error
|
||||
|
|
|
@ -41,7 +41,7 @@ var _ = Describe("Stream Framer", func() {
|
|||
stream2 = mocks.NewMockStreamI(mockCtrl)
|
||||
stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes()
|
||||
|
||||
streamsMap = newStreamsMap(nil, protocol.PerspectiveServer)
|
||||
streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, protocol.VersionWhatever)
|
||||
streamsMap.putStream(stream1)
|
||||
streamsMap.putStream(stream2)
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ var _ = Describe("Stream", func() {
|
|||
onDataCalled = false
|
||||
resetCalled = false
|
||||
mockFC = mocks.NewMockStreamFlowController(mockCtrl)
|
||||
str = newStream(streamID, onData, onReset, mockFC)
|
||||
str = newStream(streamID, onData, onReset, mockFC, protocol.VersionWhatever)
|
||||
|
||||
timeout := scaleDuration(250 * time.Millisecond)
|
||||
strWithTimeout = struct {
|
||||
|
|
|
@ -41,7 +41,7 @@ type newStreamLambda func(protocol.StreamID) streamI
|
|||
|
||||
var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
|
||||
|
||||
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) *streamsMap {
|
||||
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver protocol.VersionNumber) *streamsMap {
|
||||
// add some tolerance to the maximum incoming streams value
|
||||
maxStreams := uint32(protocol.MaxIncomingStreams)
|
||||
maxIncomingStreams := utils.MaxUint32(
|
||||
|
@ -58,12 +58,16 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) *stream
|
|||
sm.nextStreamOrErrCond.L = &sm.mutex
|
||||
sm.openStreamOrErrCond.L = &sm.mutex
|
||||
|
||||
nextOddStream := protocol.StreamID(1)
|
||||
if ver.CryptoStreamID() == protocol.StreamID(1) {
|
||||
nextOddStream = 3
|
||||
}
|
||||
if pers == protocol.PerspectiveClient {
|
||||
sm.nextStream = 3
|
||||
sm.nextStream = nextOddStream
|
||||
sm.nextStreamToAccept = 2
|
||||
} else {
|
||||
sm.nextStream = 2
|
||||
sm.nextStreamToAccept = 3
|
||||
sm.nextStreamToAccept = nextOddStream
|
||||
}
|
||||
|
||||
return &sm
|
||||
|
|
|
@ -14,6 +14,11 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("Streams Map", func() {
|
||||
const (
|
||||
versionCryptoStream1 = protocol.Version39
|
||||
versionCryptoStream0 = protocol.VersionTLS
|
||||
)
|
||||
|
||||
var (
|
||||
m *streamsMap
|
||||
finishedStreams map[protocol.StreamID]*gomock.Call
|
||||
|
@ -27,11 +32,13 @@ var _ = Describe("Streams Map", func() {
|
|||
return str
|
||||
}
|
||||
|
||||
setNewStreamsMap := func(p protocol.Perspective) {
|
||||
m = newStreamsMap(newStream, p)
|
||||
setNewStreamsMap := func(p protocol.Perspective, v protocol.VersionNumber) {
|
||||
m = newStreamsMap(newStream, p, v)
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
Expect(versionCryptoStream0.CryptoStreamID()).To(Equal(protocol.StreamID(0)))
|
||||
Expect(versionCryptoStream1.CryptoStreamID()).To(Equal(protocol.StreamID(1)))
|
||||
finishedStreams = make(map[protocol.StreamID]*gomock.Call)
|
||||
})
|
||||
|
||||
|
@ -50,7 +57,7 @@ var _ = Describe("Streams Map", func() {
|
|||
Context("getting and creating streams", func() {
|
||||
Context("as a server", func() {
|
||||
BeforeEach(func() {
|
||||
setNewStreamsMap(protocol.PerspectiveServer)
|
||||
setNewStreamsMap(protocol.PerspectiveServer, versionCryptoStream1)
|
||||
})
|
||||
|
||||
Context("client-side streams", func() {
|
||||
|
@ -280,7 +287,22 @@ var _ = Describe("Streams Map", func() {
|
|||
Consistently(func() bool { return accepted }).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("start with stream 3", func() {
|
||||
It("starts with stream 1, if the crypto stream is stream 0", func() {
|
||||
setNewStreamsMap(protocol.PerspectiveServer, versionCryptoStream0)
|
||||
var str streamI
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
str, err = m.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
_, err := m.GetOrOpenStream(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(func() Stream { return str }).ShouldNot(BeNil())
|
||||
Expect(str.StreamID()).To(Equal(protocol.StreamID(1)))
|
||||
})
|
||||
|
||||
It("starts with stream 3, if the crypto stream is stream 1", func() {
|
||||
var str streamI
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -399,7 +421,7 @@ var _ = Describe("Streams Map", func() {
|
|||
|
||||
Context("as a client", func() {
|
||||
BeforeEach(func() {
|
||||
setNewStreamsMap(protocol.PerspectiveClient)
|
||||
setNewStreamsMap(protocol.PerspectiveClient, versionCryptoStream1)
|
||||
m.UpdateMaxStreamLimit(100)
|
||||
})
|
||||
|
||||
|
@ -445,7 +467,18 @@ var _ = Describe("Streams Map", func() {
|
|||
})
|
||||
|
||||
Context("server-side streams", func() {
|
||||
It("starts with stream 3", func() {
|
||||
It("starts with stream 1, if the crypto stream is stream 0", func() {
|
||||
setNewStreamsMap(protocol.PerspectiveClient, versionCryptoStream0)
|
||||
m.UpdateMaxStreamLimit(100)
|
||||
s, err := m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s.StreamID()).To(BeEquivalentTo(1))
|
||||
Expect(m.numOutgoingStreams).To(BeZero())
|
||||
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
|
||||
})
|
||||
|
||||
It("starts with stream 3, if the crypto stream is stream 1", func() {
|
||||
s, err := m.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).ToNot(BeNil())
|
||||
|
@ -493,7 +526,7 @@ var _ = Describe("Streams Map", func() {
|
|||
|
||||
Context("DoS mitigation, iterating and deleting", func() {
|
||||
BeforeEach(func() {
|
||||
setNewStreamsMap(protocol.PerspectiveServer)
|
||||
setNewStreamsMap(protocol.PerspectiveServer, versionCryptoStream1)
|
||||
})
|
||||
|
||||
closeStream := func(id protocol.StreamID) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue