use stream 0 for the crypto stream when using TLS

This commit is contained in:
Marten Seemann 2017-11-01 13:51:47 +07:00
parent 05f6e1cf8e
commit f662822486
16 changed files with 93 additions and 59 deletions

View file

@ -90,9 +90,6 @@ func (c *client) dial() error {
if err != nil {
return err
}
if c.headerStream.StreamID() != 3 {
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3")
}
c.requestWriter = newRequestWriter(c.headerStream)
go c.handleHeaderStream()
return nil

View file

@ -97,16 +97,6 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError(testErr))
})
It("errors if the header stream has the wrong stream ID", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)
session.streamsToOpen = []quic.Stream{&mockStream{id: 2}}
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
_, err := client.RoundTrip(req)
Expect(err).To(MatchError("h2quic Client BUG: StreamID of Header Stream is not 3"))
})
It("errors if it can't open a stream", func() {
testErr := errors.New("you shall not pass")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil)

View file

@ -122,10 +122,6 @@ func (s *Server) handleHeaderStream(session streamCreator) {
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
return
}
if stream.StreamID() != 3 {
session.Close(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3"))
return
}
hpackDecoder := hpack.NewDecoder(4096, nil)
h2framer := http2.NewFramer(nil, stream)

View file

@ -308,19 +308,6 @@ var _ = Describe("H2 server", func() {
Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")))
})
It("errors if the accepted header stream has the wrong stream ID", func() {
headerStream := &mockStream{id: 1}
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
session.streamToAccept = headerStream
go s.handleHeaderStream(session)
Eventually(func() bool { return session.closed }).Should(BeTrue())
Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3")))
})
It("supports closing after first request", func() {
s.CloseAfterFirstRequest = true
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

View file

@ -45,7 +45,7 @@ func (vn VersionNumber) String() string {
case VersionTLS:
return "TLS dev version (WIP)"
default:
if vn > gquicVersion0 && vn <= maxGquicVersion {
if vn.isGQUIC() {
return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion())
}
return fmt.Sprintf("%d", vn)
@ -54,12 +54,24 @@ func (vn VersionNumber) String() string {
// ToAltSvc returns the representation of the version for the H2 Alt-Svc parameters
func (vn VersionNumber) ToAltSvc() string {
if vn > gquicVersion0 && vn <= maxGquicVersion {
if vn.isGQUIC() {
return fmt.Sprintf("%d", vn.toGQUICVersion())
}
return fmt.Sprintf("%d", vn)
}
// CryptoStreamID gets the Stream ID of the crypto stream
func (vn VersionNumber) CryptoStreamID() StreamID {
if vn.isGQUIC() {
return 1
}
return 0
}
func (vn VersionNumber) isGQUIC() bool {
return vn > gquicVersion0 && vn <= maxGquicVersion
}
func (vn VersionNumber) toGQUICVersion() int {
return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
}

View file

@ -43,7 +43,13 @@ var _ = Describe("Version", func() {
Expect(VersionNumber(0x51303133).ToAltSvc()).To(Equal("13"))
Expect(VersionNumber(0x51303235).ToAltSvc()).To(Equal("25"))
Expect(VersionNumber(0x51303438).ToAltSvc()).To(Equal("48"))
})
It("tells the Stream ID of the crypto stream", func() {
Expect(Version37.CryptoStreamID()).To(Equal(StreamID(1)))
Expect(Version38.CryptoStreamID()).To(Equal(StreamID(1)))
Expect(Version39.CryptoStreamID()).To(Equal(StreamID(1)))
Expect(VersionTLS.CryptoStreamID()).To(Equal(StreamID(0)))
})
It("recognizes supported versions", func() {

View file

@ -63,7 +63,7 @@ var _ = Describe("Packet packer", func() {
BeforeEach(func() {
cryptoStream = &stream{flowController: flowcontrol.NewStreamFlowController(1, false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)}
streamsMap := newStreamsMap(nil, protocol.PerspectiveServer)
streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, protocol.VersionWhatever)
streamFramer = newStreamFramer(cryptoStream, streamsMap, nil)
packer = &packetPacker{
@ -574,7 +574,10 @@ var _ = Describe("Packet packer", func() {
Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{StreamID: 1, Data: []byte("foobar")}))
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{
StreamID: packer.version.CryptoStreamID(),
Data: []byte("foobar"),
}))
})
It("sends encrypted stream data on the crypto stream", func() {
@ -584,7 +587,10 @@ var _ = Describe("Packet packer", func() {
Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{StreamID: 1, Data: []byte("foobar")}))
Expect(p.frames[0]).To(Equal(&wire.StreamFrame{
StreamID: packer.version.CryptoStreamID(),
Data: []byte("foobar"),
}))
})
It("does not pack stream frames if not allowed", func() {

View file

@ -55,7 +55,7 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []by
err = qerr.Error(qerr.InvalidStreamData, err.Error())
} else {
streamID := frame.(*wire.StreamFrame).StreamID
if streamID != 1 && encryptionLevel <= protocol.EncryptionUnencrypted {
if streamID != u.version.CryptoStreamID() && encryptionLevel <= protocol.EncryptionUnencrypted {
err = qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", streamID))
}
}

View file

@ -224,10 +224,10 @@ var _ = Describe("Packet unpacker", func() {
})
Context("unpacking STREAM frames", func() {
It("unpacks unencrypted STREAM frames on stream 1", func() {
It("unpacks unencrypted STREAM frames on the crypto stream", func() {
unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionUnencrypted
f := &wire.StreamFrame{
StreamID: 1,
StreamID: unpacker.version.CryptoStreamID(),
Data: []byte("foobar"),
}
err := f.Write(buf, 0)
@ -238,10 +238,10 @@ var _ = Describe("Packet unpacker", func() {
Expect(packet.frames).To(Equal([]wire.Frame{f}))
})
It("unpacks encrypted STREAM frames on stream 1", func() {
It("unpacks encrypted STREAM frames on the crypto stream", func() {
unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionSecure
f := &wire.StreamFrame{
StreamID: 1,
StreamID: unpacker.version.CryptoStreamID(),
Data: []byte("foobar"),
}
err := f.Write(buf, 0)

View file

@ -200,8 +200,8 @@ func (s *session) setup(
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
s.rttStats,
)
s.streamsMap = newStreamsMap(s.newStream, s.perspective)
s.cryptoStream = s.newStream(1)
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version)
s.cryptoStream = s.newStream(s.version.CryptoStreamID())
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController)
var err error
@ -527,7 +527,7 @@ func (s *session) handlePacket(p *receivedPacket) {
}
func (s *session) handleStreamFrame(frame *wire.StreamFrame) error {
if frame.StreamID == 1 {
if frame.StreamID == s.version.CryptoStreamID() {
return s.cryptoStream.AddStreamFrame(frame)
}
str, err := s.streamsMap.GetOrOpenStream(frame.StreamID)
@ -820,7 +820,7 @@ func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.By
func (s *session) newStream(id protocol.StreamID) streamI {
// TODO: find a better solution for determining which streams contribute to connection level flow control
var contributesToConnection bool
if id != 1 && id != 3 {
if id != 0 && id != 1 && id != 3 {
contributesToConnection = true
}
var initialSendWindow protocol.ByteCount
@ -836,7 +836,7 @@ func (s *session) newStream(id protocol.StreamID) streamI {
initialSendWindow,
s.rttStats,
)
return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController)
return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController, s.version)
}
func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error {

View file

@ -75,6 +75,7 @@ type stream struct {
writeDeadline time.Time
flowController flowcontrol.StreamFlowController
version protocol.VersionNumber
}
var _ Stream = &stream{}
@ -93,6 +94,7 @@ func newStream(StreamID protocol.StreamID,
onData func(),
onReset func(protocol.StreamID, protocol.ByteCount),
flowController flowcontrol.StreamFlowController,
version protocol.VersionNumber,
) *stream {
s := &stream{
onData: onData,
@ -102,6 +104,7 @@ func newStream(StreamID protocol.StreamID,
frameQueue: newStreamFrameSorter(),
readChan: make(chan struct{}, 1),
writeChan: make(chan struct{}, 1),
version: version,
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
return s
@ -274,7 +277,7 @@ func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) []byte {
}
// TODO(#657): Flow control for the crypto stream
if s.streamID != 1 {
if s.streamID != s.version.CryptoStreamID() {
maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize())
}
if maxBytes == 0 {

View file

@ -60,7 +60,7 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str
return nil
}
frame := &wire.StreamFrame{
StreamID: 1,
StreamID: f.cryptoStream.StreamID(),
Offset: f.cryptoStream.GetWriteOffset(),
}
frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error

View file

@ -41,7 +41,7 @@ var _ = Describe("Stream Framer", func() {
stream2 = mocks.NewMockStreamI(mockCtrl)
stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes()
streamsMap = newStreamsMap(nil, protocol.PerspectiveServer)
streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, protocol.VersionWhatever)
streamsMap.putStream(stream1)
streamsMap.putStream(stream2)

View file

@ -59,7 +59,7 @@ var _ = Describe("Stream", func() {
onDataCalled = false
resetCalled = false
mockFC = mocks.NewMockStreamFlowController(mockCtrl)
str = newStream(streamID, onData, onReset, mockFC)
str = newStream(streamID, onData, onReset, mockFC, protocol.VersionWhatever)
timeout := scaleDuration(250 * time.Millisecond)
strWithTimeout = struct {

View file

@ -41,7 +41,7 @@ type newStreamLambda func(protocol.StreamID) streamI
var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) *streamsMap {
func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver protocol.VersionNumber) *streamsMap {
// add some tolerance to the maximum incoming streams value
maxStreams := uint32(protocol.MaxIncomingStreams)
maxIncomingStreams := utils.MaxUint32(
@ -58,12 +58,16 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) *stream
sm.nextStreamOrErrCond.L = &sm.mutex
sm.openStreamOrErrCond.L = &sm.mutex
nextOddStream := protocol.StreamID(1)
if ver.CryptoStreamID() == protocol.StreamID(1) {
nextOddStream = 3
}
if pers == protocol.PerspectiveClient {
sm.nextStream = 3
sm.nextStream = nextOddStream
sm.nextStreamToAccept = 2
} else {
sm.nextStream = 2
sm.nextStreamToAccept = 3
sm.nextStreamToAccept = nextOddStream
}
return &sm

View file

@ -14,6 +14,11 @@ import (
)
var _ = Describe("Streams Map", func() {
const (
versionCryptoStream1 = protocol.Version39
versionCryptoStream0 = protocol.VersionTLS
)
var (
m *streamsMap
finishedStreams map[protocol.StreamID]*gomock.Call
@ -27,11 +32,13 @@ var _ = Describe("Streams Map", func() {
return str
}
setNewStreamsMap := func(p protocol.Perspective) {
m = newStreamsMap(newStream, p)
setNewStreamsMap := func(p protocol.Perspective, v protocol.VersionNumber) {
m = newStreamsMap(newStream, p, v)
}
BeforeEach(func() {
Expect(versionCryptoStream0.CryptoStreamID()).To(Equal(protocol.StreamID(0)))
Expect(versionCryptoStream1.CryptoStreamID()).To(Equal(protocol.StreamID(1)))
finishedStreams = make(map[protocol.StreamID]*gomock.Call)
})
@ -50,7 +57,7 @@ var _ = Describe("Streams Map", func() {
Context("getting and creating streams", func() {
Context("as a server", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveServer)
setNewStreamsMap(protocol.PerspectiveServer, versionCryptoStream1)
})
Context("client-side streams", func() {
@ -280,7 +287,22 @@ var _ = Describe("Streams Map", func() {
Consistently(func() bool { return accepted }).Should(BeFalse())
})
It("start with stream 3", func() {
It("starts with stream 1, if the crypto stream is stream 0", func() {
setNewStreamsMap(protocol.PerspectiveServer, versionCryptoStream0)
var str streamI
go func() {
defer GinkgoRecover()
var err error
str, err = m.AcceptStream()
Expect(err).ToNot(HaveOccurred())
}()
_, err := m.GetOrOpenStream(1)
Expect(err).ToNot(HaveOccurred())
Eventually(func() Stream { return str }).ShouldNot(BeNil())
Expect(str.StreamID()).To(Equal(protocol.StreamID(1)))
})
It("starts with stream 3, if the crypto stream is stream 1", func() {
var str streamI
go func() {
defer GinkgoRecover()
@ -399,7 +421,7 @@ var _ = Describe("Streams Map", func() {
Context("as a client", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveClient)
setNewStreamsMap(protocol.PerspectiveClient, versionCryptoStream1)
m.UpdateMaxStreamLimit(100)
})
@ -445,7 +467,18 @@ var _ = Describe("Streams Map", func() {
})
Context("server-side streams", func() {
It("starts with stream 3", func() {
It("starts with stream 1, if the crypto stream is stream 0", func() {
setNewStreamsMap(protocol.PerspectiveClient, versionCryptoStream0)
m.UpdateMaxStreamLimit(100)
s, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
Expect(s.StreamID()).To(BeEquivalentTo(1))
Expect(m.numOutgoingStreams).To(BeZero())
Expect(m.numIncomingStreams).To(BeEquivalentTo(1))
})
It("starts with stream 3, if the crypto stream is stream 1", func() {
s, err := m.OpenStream()
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
@ -493,7 +526,7 @@ var _ = Describe("Streams Map", func() {
Context("DoS mitigation, iterating and deleting", func() {
BeforeEach(func() {
setNewStreamsMap(protocol.PerspectiveServer)
setNewStreamsMap(protocol.PerspectiveServer, versionCryptoStream1)
})
closeStream := func(id protocol.StreamID) {