when the encryption level changes, reject data on that crypto stream

There are two checks that need to be performed:
1. the crypto stream must not have any more data queued for reading
2. when receiving CRYPTO frames for that crypto stream afterwards, they
must not exceed the highest offset received on that stream
This commit is contained in:
Marten Seemann 2018-10-20 11:22:05 +09:00
parent fe442e4d19
commit 387c28d707
10 changed files with 156 additions and 26 deletions

View file

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
@ -13,6 +14,7 @@ type cryptoStream interface {
// for receiving data // for receiving data
HandleCryptoFrame(*wire.CryptoFrame) error HandleCryptoFrame(*wire.CryptoFrame) error
GetCryptoData() []byte GetCryptoData() []byte
Finish() error
// for sending data // for sending data
io.Writer io.Writer
HasData() bool HasData() bool
@ -23,6 +25,9 @@ type cryptoStreamImpl struct {
queue *frameSorter queue *frameSorter
msgBuf []byte msgBuf []byte
highestOffset protocol.ByteCount
finished bool
writeOffset protocol.ByteCount writeOffset protocol.ByteCount
writeBuf []byte writeBuf []byte
} }
@ -34,9 +39,20 @@ func newCryptoStream() cryptoStream {
} }
func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
if maxOffset := f.Offset + protocol.ByteCount(len(f.Data)); maxOffset > protocol.MaxCryptoStreamOffset { highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset) return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset)
} }
if s.finished {
if highestOffset > s.highestOffset {
// reject crypto data received after this stream was already finished
return errors.New("received crypto data after change of encryption level")
}
// ignore data with a smaller offset than the highest received
// could e.g. be a retransmission
return nil
}
s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset)
if err := s.queue.Push(f.Data, f.Offset, false); err != nil { if err := s.queue.Push(f.Data, f.Offset, false); err != nil {
return err return err
} }
@ -64,6 +80,14 @@ func (s *cryptoStreamImpl) GetCryptoData() []byte {
return msg return msg
} }
func (s *cryptoStreamImpl) Finish() error {
if s.queue.HasMoreData() {
return errors.New("encryption level changed, but crypto stream has more data to read")
}
s.finished = true
return nil
}
// Writes writes data that should be sent out in CRYPTO frames // Writes writes data that should be sent out in CRYPTO frames
func (s *cryptoStreamImpl) Write(p []byte) (int, error) { func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
s.writeBuf = append(s.writeBuf, p...) s.writeBuf = append(s.writeBuf, p...)

View file

@ -8,7 +8,7 @@ import (
) )
type cryptoDataHandler interface { type cryptoDataHandler interface {
HandleMessage([]byte, protocol.EncryptionLevel) HandleMessage([]byte, protocol.EncryptionLevel) bool
} }
type cryptoStreamManager struct { type cryptoStreamManager struct {
@ -48,6 +48,8 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
if data == nil { if data == nil {
return nil return nil
} }
m.cryptoHandler.HandleMessage(data, encLevel) if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
return str.Finish()
}
} }
} }

View file

@ -1,6 +1,9 @@
package quic package quic
import ( import (
"errors"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
@ -61,6 +64,29 @@ var _ = Describe("Crypto Stream Manager", func() {
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
}) })
It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() {
cf := &wire.CryptoFrame{Data: []byte("foobar")}
gomock.InOrder(
handshakeStream.EXPECT().HandleCryptoFrame(cf),
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")),
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
handshakeStream.EXPECT().Finish(),
)
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed())
})
It("returns errors that occur when finishing a stream", func() {
testErr := errors.New("test error")
cf := &wire.CryptoFrame{Data: []byte("foobar")}
gomock.InOrder(
handshakeStream.EXPECT().HandleCryptoFrame(cf),
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")),
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
handshakeStream.EXPECT().Finish().Return(testErr),
)
Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(MatchError(testErr))
})
It("errors for unknown encryption levels", func() { It("errors for unknown encryption levels", func() {
err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT) err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT)
Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT")) Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT"))

View file

@ -89,6 +89,52 @@ var _ = Describe("Crypto Stream", func() {
Expect(str.GetCryptoData()).To(Equal(msg)) Expect(str.GetCryptoData()).To(Equal(msg))
Expect(str.GetCryptoData()).To(BeNil()) Expect(str.GetCryptoData()).To(BeNil())
}) })
Context("finishing", func() {
It("errors if there's still data to read after finishing", func() {
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
Data: createHandshakeMessage(5),
Offset: 10,
})).To(Succeed())
err := str.Finish()
Expect(err).To(MatchError("encryption level changed, but crypto stream has more data to read"))
})
It("works with reordered data", func() {
f1 := &wire.CryptoFrame{
Data: []byte("foo"),
}
f2 := &wire.CryptoFrame{
Offset: 3,
Data: []byte("bar"),
}
Expect(str.HandleCryptoFrame(f2)).To(Succeed())
Expect(str.HandleCryptoFrame(f1)).To(Succeed())
Expect(str.Finish()).To(Succeed())
Expect(str.HandleCryptoFrame(f2)).To(Succeed())
})
It("rejects new crypto data after finishing", func() {
Expect(str.Finish()).To(Succeed())
err := str.HandleCryptoFrame(&wire.CryptoFrame{
Data: createHandshakeMessage(5),
})
Expect(err).To(MatchError("received crypto data after change of encryption level"))
})
It("ignores crypto data below the maximum offset received before finishing", func() {
msg := createHandshakeMessage(15)
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
Data: msg,
})).To(Succeed())
Expect(str.GetCryptoData()).To(Equal(msg))
Expect(str.Finish()).To(Succeed())
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
Offset: protocol.ByteCount(len(msg) - 6),
Data: []byte("foobar"),
})).To(Succeed())
})
})
}) })
Context("writing data", func() { Context("writing data", func() {

View file

@ -156,3 +156,8 @@ func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) {
s.readPos += protocol.ByteCount(len(data)) s.readPos += protocol.ByteCount(len(data))
return data, s.readPos >= s.finalOffset return data, s.readPos >= s.finalOffset
} }
// HasMoreData says if there is any more data queued at *any* offset.
func (s *frameSorter) HasMoreData() bool {
return len(s.queue) > 0
}

View file

@ -55,6 +55,15 @@ var _ = Describe("STREAM frame sorter", func() {
Expect(s.Pop()).To(BeNil()) Expect(s.Pop()).To(BeNil())
}) })
It("says if has more data", func() {
Expect(s.HasMoreData()).To(BeFalse())
Expect(s.Push([]byte("foo"), 0, false)).To(Succeed())
Expect(s.HasMoreData()).To(BeTrue())
data, _ := s.Pop()
Expect(data).To(Equal([]byte("foo")))
Expect(s.HasMoreData()).To(BeFalse())
})
Context("FIN handling", func() { Context("FIN handling", func() {
It("saves a FIN at offset 0", func() { It("saves a FIN at offset 0", func() {
Expect(s.Push(nil, 0, true)).To(Succeed()) Expect(s.Push(nil, 0, true)).To(Succeed())

View file

@ -271,19 +271,20 @@ func (h *cryptoSetupTLS) RunHandshake() error {
// handleMessage handles a TLS handshake message. // handleMessage handles a TLS handshake message.
// It is called by the crypto streams when a new message is available. // It is called by the crypto streams when a new message is available.
func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) { // It returns if it is done with messages on the same encryption level.
func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
msgType := messageType(data[0]) msgType := messageType(data[0])
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
h.messageErrChan <- err h.messageErrChan <- err
return return false
} }
h.messageChan <- data h.messageChan <- data
switch h.perspective { switch h.perspective {
case protocol.PerspectiveClient: case protocol.PerspectiveClient:
h.handleMessageForClient(msgType) return h.handleMessageForClient(msgType)
case protocol.PerspectiveServer: case protocol.PerspectiveServer:
h.handleMessageForServer(msgType) return h.handleMessageForServer(msgType)
default: default:
panic("") panic("")
} }
@ -310,78 +311,81 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot
return nil return nil
} }
func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) { func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool {
switch msgType { switch msgType {
case typeClientHello: case typeClientHello:
select { select {
case params := <-h.receivedTransportParams: case params := <-h.receivedTransportParams:
h.handleParamsCallback(&params) h.handleParamsCallback(&params)
case <-h.handshakeErrChan: case <-h.handshakeErrChan:
return return false
} }
// get the handshake write key // get the handshake write key
select { select {
case <-h.receivedWriteKey: case <-h.receivedWriteKey:
case <-h.handshakeErrChan: case <-h.handshakeErrChan:
return return false
} }
// get the 1-RTT write key // get the 1-RTT write key
select { select {
case <-h.receivedWriteKey: case <-h.receivedWriteKey:
case <-h.handshakeErrChan: case <-h.handshakeErrChan:
return return false
} }
// get the handshake read key // get the handshake read key
// TODO: check that the initial stream doesn't have any more data // TODO: check that the initial stream doesn't have any more data
select { select {
case <-h.receivedReadKey: case <-h.receivedReadKey:
case <-h.handshakeErrChan: case <-h.handshakeErrChan:
return return false
} }
h.handshakeEvent <- struct{}{} h.handshakeEvent <- struct{}{}
return true
case typeCertificate, typeCertificateVerify: case typeCertificate, typeCertificateVerify:
// nothing to do // nothing to do
return false
case typeFinished: case typeFinished:
// get the 1-RTT read key // get the 1-RTT read key
// TODO: check that the handshake stream doesn't have any more data
select { select {
case <-h.receivedReadKey: case <-h.receivedReadKey:
case <-h.handshakeErrChan: case <-h.handshakeErrChan:
return return false
} }
h.handshakeEvent <- struct{}{} h.handshakeEvent <- struct{}{}
return true
default: default:
panic("unexpected handshake message") panic("unexpected handshake message")
} }
} }
func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) { func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool {
switch msgType { switch msgType {
case typeServerHello: case typeServerHello:
// get the handshake read key // get the handshake read key
// TODO: check that the initial stream doesn't have any more data
select { select {
case <-h.receivedReadKey: case <-h.receivedReadKey:
case <-h.handshakeErrChan: case <-h.handshakeErrChan:
return return false
} }
h.handshakeEvent <- struct{}{} h.handshakeEvent <- struct{}{}
return true
case typeEncryptedExtensions: case typeEncryptedExtensions:
select { select {
case params := <-h.receivedTransportParams: case params := <-h.receivedTransportParams:
h.handleParamsCallback(&params) h.handleParamsCallback(&params)
case <-h.handshakeErrChan: case <-h.handshakeErrChan:
return return false
} }
return false
case typeCertificateRequest, typeCertificate, typeCertificateVerify: case typeCertificateRequest, typeCertificate, typeCertificateVerify:
// nothing to do // nothing to do
return false
case typeFinished: case typeFinished:
// get the handshake write key // get the handshake write key
// TODO: check that the initial stream doesn't have any more data
select { select {
case <-h.receivedWriteKey: case <-h.receivedWriteKey:
case <-h.handshakeErrChan: case <-h.handshakeErrChan:
return return false
} }
// While the order of these two is not defined by the TLS spec, // While the order of these two is not defined by the TLS spec,
// we have to do it on the same order as our TLS library does it. // we have to do it on the same order as our TLS library does it.
@ -389,16 +393,16 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) {
select { select {
case <-h.receivedWriteKey: case <-h.receivedWriteKey:
case <-h.handshakeErrChan: case <-h.handshakeErrChan:
return return false
} }
// get the 1-RTT read key // get the 1-RTT read key
select { select {
case <-h.receivedReadKey: case <-h.receivedReadKey:
case <-h.handshakeErrChan: case <-h.handshakeErrChan:
return return false
} }
// TODO: check that the handshake stream doesn't have any more data
h.handshakeEvent <- struct{}{} h.handshakeEvent <- struct{}{}
return true
default: default:
panic("unexpected handshake message: ") panic("unexpected handshake message: ")
} }

View file

@ -44,7 +44,7 @@ type CryptoSetup interface {
type CryptoSetupTLS interface { type CryptoSetupTLS interface {
baseCryptoSetup baseCryptoSetup
HandleMessage([]byte, protocol.EncryptionLevel) HandleMessage([]byte, protocol.EncryptionLevel) bool
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)

View file

@ -35,8 +35,10 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder {
} }
// HandleMessage mocks base method // HandleMessage mocks base method
func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) { func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
m.ctrl.Call(m, "HandleMessage", arg0, arg1) ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
} }
// HandleMessage indicates an expected call of HandleMessage // HandleMessage indicates an expected call of HandleMessage

View file

@ -35,6 +35,18 @@ func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder {
return m.recorder return m.recorder
} }
// Finish mocks base method
func (m *MockCryptoStream) Finish() error {
ret := m.ctrl.Call(m, "Finish")
ret0, _ := ret[0].(error)
return ret0
}
// Finish indicates an expected call of Finish
func (mr *MockCryptoStreamMockRecorder) Finish() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finish", reflect.TypeOf((*MockCryptoStream)(nil).Finish))
}
// GetCryptoData mocks base method // GetCryptoData mocks base method
func (m *MockCryptoStream) GetCryptoData() []byte { func (m *MockCryptoStream) GetCryptoData() []byte {
ret := m.ctrl.Call(m, "GetCryptoData") ret := m.ctrl.Call(m, "GetCryptoData")