move the unencrypted STREAM frame check from the unpacker to the session

This commit is contained in:
Marten Seemann 2018-03-25 18:33:46 +02:00
parent 6f12844094
commit 2fbc994d29
6 changed files with 38 additions and 60 deletions

View file

@ -6,6 +6,7 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"github.com/bifurcation/mint"
@ -123,6 +124,9 @@ func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, versio
if frame == nil {
return nil, errors.New("Packet doesn't contain a STREAM_FRAME")
}
if frame.StreamID != version.CryptoStreamID() {
return nil, fmt.Errorf("Received STREAM_FRAME for wrong stream (Stream ID %d)", frame.StreamID)
}
// We don't need a check for the stream ID here.
// The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream.
if frame.Offset != 0 {

View file

@ -125,7 +125,7 @@ var _ = Describe("Packing and unpacking Initial packets", func() {
}
p := packPacket([]wire.Frame{f})
_, err := unpackInitialPacket(aead, hdr, p, ver)
Expect(err).To(MatchError("UnencryptedStreamData: received unencrypted stream data on stream 42"))
Expect(err).To(MatchError("Received STREAM_FRAME for wrong stream (Stream ID 42)"))
})
It("rejects a packet that has a STREAM_FRAME with a non-zero offset", func() {

View file

@ -2,7 +2,6 @@ package quic
import (
"bytes"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
@ -49,11 +48,6 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []by
if frame == nil {
break
}
if sf, ok := frame.(*wire.StreamFrame); ok {
if sf.StreamID != u.version.CryptoStreamID() && encryptionLevel <= protocol.EncryptionUnencrypted {
return nil, qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", sf.StreamID))
}
}
fs = append(fs, frame)
}

View file

@ -71,51 +71,4 @@ var _ = Describe("Packet unpacker", func() {
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionSecure))
})
Context("unpacking STREAM frames", func() {
BeforeEach(func() {
unpacker.version = versionGQUICFrames
})
It("unpacks unencrypted STREAM frames on the crypto stream", func() {
unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionUnencrypted
f := &wire.StreamFrame{
StreamID: versionGQUICFrames.CryptoStreamID(),
Data: []byte("foobar"),
}
err := f.Write(buf, versionGQUICFrames)
Expect(err).ToNot(HaveOccurred())
setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{f}))
})
It("unpacks encrypted STREAM frames on the crypto stream", func() {
unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionSecure
f := &wire.StreamFrame{
StreamID: versionGQUICFrames.CryptoStreamID(),
Data: []byte("foobar"),
}
err := f.Write(buf, versionGQUICFrames)
Expect(err).ToNot(HaveOccurred())
setData(buf.Bytes())
packet, err := unpacker.Unpack(hdrBin, hdr, data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{f}))
})
It("does not unpack unencrypted STREAM frames on higher streams", func() {
unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionUnencrypted
f := &wire.StreamFrame{
StreamID: 3,
Data: []byte("foobar"),
}
err := f.Write(buf, versionGQUICFrames)
Expect(err).ToNot(HaveOccurred())
setData(buf.Bytes())
_, err = unpacker.Unpack(hdrBin, hdr, data)
Expect(err).To(MatchError(qerr.Error(qerr.UnencryptedStreamData, "received unencrypted stream data on stream 3")))
})
})
})

View file

@ -572,7 +572,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
wire.LogFrame(ff, false)
switch frame := ff.(type) {
case *wire.StreamFrame:
err = s.handleStreamFrame(frame)
err = s.handleStreamFrame(frame, encLevel)
case *wire.AckFrame:
err = s.handleAckFrame(frame, encLevel)
case *wire.ConnectionCloseFrame:
@ -615,12 +615,14 @@ func (s *session) handlePacket(p *receivedPacket) {
}
}
func (s *session) handleStreamFrame(frame *wire.StreamFrame) error {
func (s *session) handleStreamFrame(frame *wire.StreamFrame, encLevel protocol.EncryptionLevel) error {
if frame.StreamID == s.version.CryptoStreamID() {
if frame.FinBit {
return errors.New("Received STREAM frame with FIN bit for the crypto stream")
}
return s.cryptoStream.handleStreamFrame(frame)
} else if encLevel <= protocol.EncryptionUnencrypted {
return qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", frame.StreamID))
}
str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
if err != nil {

View file

@ -203,7 +203,7 @@ var _ = Describe("Session", func() {
str := NewMockReceiveStreamI(mockCtrl)
str.EXPECT().handleStreamFrame(f)
streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(str, nil)
err := sess.handleStreamFrame(f)
err := sess.handleStreamFrame(f, protocol.EncryptionForwardSecure)
Expect(err).ToNot(HaveOccurred())
})
@ -216,7 +216,7 @@ var _ = Describe("Session", func() {
str := NewMockReceiveStreamI(mockCtrl)
str.EXPECT().handleStreamFrame(f).Return(testErr)
streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(str, nil)
err := sess.handleStreamFrame(f)
err := sess.handleStreamFrame(f, protocol.EncryptionForwardSecure)
Expect(err).To(MatchError(testErr))
})
@ -225,7 +225,7 @@ var _ = Describe("Session", func() {
err := sess.handleStreamFrame(&wire.StreamFrame{
StreamID: 5,
Data: []byte("foobar"),
})
}, protocol.EncryptionForwardSecure)
Expect(err).ToNot(HaveOccurred())
})
@ -234,9 +234,34 @@ var _ = Describe("Session", func() {
StreamID: sess.version.CryptoStreamID(),
Offset: 0x1337,
FinBit: true,
})
}, protocol.EncryptionForwardSecure)
Expect(err).To(MatchError("Received STREAM frame with FIN bit for the crypto stream"))
})
It("accepts unencrypted STREAM frames on the crypto stream", func() {
f := &wire.StreamFrame{
StreamID: versionGQUICFrames.CryptoStreamID(),
Data: []byte("foobar"),
}
err := sess.handleStreamFrame(f, protocol.EncryptionUnencrypted)
Expect(err).ToNot(HaveOccurred())
})
It("unpacks encrypted STREAM frames on the crypto stream", func() {
err := sess.handleStreamFrame(&wire.StreamFrame{
StreamID: versionGQUICFrames.CryptoStreamID(),
Data: []byte("foobar"),
}, protocol.EncryptionSecure)
Expect(err).ToNot(HaveOccurred())
})
It("does not unpack unencrypted STREAM frames on higher streams", func() {
err := sess.handleStreamFrame(&wire.StreamFrame{
StreamID: 3,
Data: []byte("foobar"),
}, protocol.EncryptionUnencrypted)
Expect(err).To(MatchError(qerr.Error(qerr.UnencryptedStreamData, "received unencrypted stream data on stream 3")))
})
})
Context("handling ACK frames", func() {