diff --git a/crypto_stream.go b/crypto_stream.go index fbd41d7e..9007a2b0 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -1,6 +1,7 @@ package quic import ( + "errors" "fmt" "io" @@ -13,6 +14,7 @@ type cryptoStream interface { // for receiving data HandleCryptoFrame(*wire.CryptoFrame) error GetCryptoData() []byte + Finish() error // for sending data io.Writer HasData() bool @@ -23,6 +25,9 @@ type cryptoStreamImpl struct { queue *frameSorter msgBuf []byte + highestOffset protocol.ByteCount + finished bool + writeOffset protocol.ByteCount writeBuf []byte } @@ -34,9 +39,20 @@ func newCryptoStream() cryptoStream { } func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { - if maxOffset := f.Offset + protocol.ByteCount(len(f.Data)); maxOffset > protocol.MaxCryptoStreamOffset { + highestOffset := f.Offset + protocol.ByteCount(len(f.Data)) + if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset { return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset) } + if s.finished { + if highestOffset > s.highestOffset { + // reject crypto data received after this stream was already finished + return errors.New("received crypto data after change of encryption level") + } + // ignore data with a smaller offset than the highest received + // could e.g. be a retransmission + return nil + } + s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset) if err := s.queue.Push(f.Data, f.Offset, false); err != nil { return err } @@ -64,6 +80,14 @@ func (s *cryptoStreamImpl) GetCryptoData() []byte { return msg } +func (s *cryptoStreamImpl) Finish() error { + if s.queue.HasMoreData() { + return errors.New("encryption level changed, but crypto stream has more data to read") + } + s.finished = true + return nil +} + // Writes writes data that should be sent out in CRYPTO frames func (s *cryptoStreamImpl) Write(p []byte) (int, error) { s.writeBuf = append(s.writeBuf, p...) diff --git a/crypto_stream_manager.go b/crypto_stream_manager.go index 764bc2f2..0498b516 100644 --- a/crypto_stream_manager.go +++ b/crypto_stream_manager.go @@ -8,7 +8,7 @@ import ( ) type cryptoDataHandler interface { - HandleMessage([]byte, protocol.EncryptionLevel) + HandleMessage([]byte, protocol.EncryptionLevel) bool } type cryptoStreamManager struct { @@ -48,6 +48,8 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve if data == nil { return nil } - m.cryptoHandler.HandleMessage(data, encLevel) + if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished { + return str.Finish() + } } } diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go index 21785f5e..aada3197 100644 --- a/crypto_stream_manager_test.go +++ b/crypto_stream_manager_test.go @@ -1,6 +1,9 @@ package quic import ( + "errors" + + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -61,6 +64,29 @@ var _ = Describe("Crypto Stream Manager", func() { Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) }) + It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + gomock.InOrder( + handshakeStream.EXPECT().HandleCryptoFrame(cf), + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), + handshakeStream.EXPECT().Finish(), + ) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + }) + + It("returns errors that occur when finishing a stream", func() { + testErr := errors.New("test error") + cf := &wire.CryptoFrame{Data: []byte("foobar")} + gomock.InOrder( + handshakeStream.EXPECT().HandleCryptoFrame(cf), + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), + handshakeStream.EXPECT().Finish().Return(testErr), + ) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(MatchError(testErr)) + }) + It("errors for unknown encryption levels", func() { err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT) Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT")) diff --git a/crypto_stream_test.go b/crypto_stream_test.go index 98f35619..ed8e769e 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -89,6 +89,52 @@ var _ = Describe("Crypto Stream", func() { Expect(str.GetCryptoData()).To(Equal(msg)) Expect(str.GetCryptoData()).To(BeNil()) }) + + Context("finishing", func() { + It("errors if there's still data to read after finishing", func() { + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: createHandshakeMessage(5), + Offset: 10, + })).To(Succeed()) + err := str.Finish() + Expect(err).To(MatchError("encryption level changed, but crypto stream has more data to read")) + }) + + It("works with reordered data", func() { + f1 := &wire.CryptoFrame{ + Data: []byte("foo"), + } + f2 := &wire.CryptoFrame{ + Offset: 3, + Data: []byte("bar"), + } + Expect(str.HandleCryptoFrame(f2)).To(Succeed()) + Expect(str.HandleCryptoFrame(f1)).To(Succeed()) + Expect(str.Finish()).To(Succeed()) + Expect(str.HandleCryptoFrame(f2)).To(Succeed()) + }) + + It("rejects new crypto data after finishing", func() { + Expect(str.Finish()).To(Succeed()) + err := str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: createHandshakeMessage(5), + }) + Expect(err).To(MatchError("received crypto data after change of encryption level")) + }) + + It("ignores crypto data below the maximum offset received before finishing", func() { + msg := createHandshakeMessage(15) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: msg, + })).To(Succeed()) + Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.Finish()).To(Succeed()) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: protocol.ByteCount(len(msg) - 6), + Data: []byte("foobar"), + })).To(Succeed()) + }) + }) }) Context("writing data", func() { diff --git a/frame_sorter.go b/frame_sorter.go index 47062c06..e07dad47 100644 --- a/frame_sorter.go +++ b/frame_sorter.go @@ -156,3 +156,8 @@ func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) { s.readPos += protocol.ByteCount(len(data)) return data, s.readPos >= s.finalOffset } + +// HasMoreData says if there is any more data queued at *any* offset. +func (s *frameSorter) HasMoreData() bool { + return len(s.queue) > 0 +} diff --git a/frame_sorter_test.go b/frame_sorter_test.go index 9def2100..433b43a2 100644 --- a/frame_sorter_test.go +++ b/frame_sorter_test.go @@ -55,6 +55,15 @@ var _ = Describe("STREAM frame sorter", func() { Expect(s.Pop()).To(BeNil()) }) + It("says if has more data", func() { + Expect(s.HasMoreData()).To(BeFalse()) + Expect(s.Push([]byte("foo"), 0, false)).To(Succeed()) + Expect(s.HasMoreData()).To(BeTrue()) + data, _ := s.Pop() + Expect(data).To(Equal([]byte("foo"))) + Expect(s.HasMoreData()).To(BeFalse()) + }) + Context("FIN handling", func() { It("saves a FIN at offset 0", func() { Expect(s.Push(nil, 0, true)).To(Succeed()) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index efe0ba3f..e8a06ea8 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -271,19 +271,20 @@ func (h *cryptoSetupTLS) RunHandshake() error { // handleMessage handles a TLS handshake message. // It is called by the crypto streams when a new message is available. -func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) { +// It returns if it is done with messages on the same encryption level. +func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ { msgType := messageType(data[0]) h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { h.messageErrChan <- err - return + return false } h.messageChan <- data switch h.perspective { case protocol.PerspectiveClient: - h.handleMessageForClient(msgType) + return h.handleMessageForClient(msgType) case protocol.PerspectiveServer: - h.handleMessageForServer(msgType) + return h.handleMessageForServer(msgType) default: panic("") } @@ -310,78 +311,81 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot return nil } -func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) { +func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool { switch msgType { case typeClientHello: select { case params := <-h.receivedTransportParams: h.handleParamsCallback(¶ms) case <-h.handshakeErrChan: - return + return false } // get the handshake write key select { case <-h.receivedWriteKey: case <-h.handshakeErrChan: - return + return false } // get the 1-RTT write key select { case <-h.receivedWriteKey: case <-h.handshakeErrChan: - return + return false } // get the handshake read key // TODO: check that the initial stream doesn't have any more data select { case <-h.receivedReadKey: case <-h.handshakeErrChan: - return + return false } h.handshakeEvent <- struct{}{} + return true case typeCertificate, typeCertificateVerify: // nothing to do + return false case typeFinished: // get the 1-RTT read key - // TODO: check that the handshake stream doesn't have any more data select { case <-h.receivedReadKey: case <-h.handshakeErrChan: - return + return false } h.handshakeEvent <- struct{}{} + return true default: panic("unexpected handshake message") } } -func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) { +func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool { switch msgType { case typeServerHello: // get the handshake read key - // TODO: check that the initial stream doesn't have any more data select { case <-h.receivedReadKey: case <-h.handshakeErrChan: - return + return false } h.handshakeEvent <- struct{}{} + return true case typeEncryptedExtensions: select { case params := <-h.receivedTransportParams: h.handleParamsCallback(¶ms) case <-h.handshakeErrChan: - return + return false } + return false case typeCertificateRequest, typeCertificate, typeCertificateVerify: // nothing to do + return false case typeFinished: // get the handshake write key - // TODO: check that the initial stream doesn't have any more data select { case <-h.receivedWriteKey: case <-h.handshakeErrChan: - return + return false } // While the order of these two is not defined by the TLS spec, // we have to do it on the same order as our TLS library does it. @@ -389,16 +393,16 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) { select { case <-h.receivedWriteKey: case <-h.handshakeErrChan: - return + return false } // get the 1-RTT read key select { case <-h.receivedReadKey: case <-h.handshakeErrChan: - return + return false } - // TODO: check that the handshake stream doesn't have any more data h.handshakeEvent <- struct{}{} + return true default: panic("unexpected handshake message: ") } diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 264e9a65..88122dc1 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -44,7 +44,7 @@ type CryptoSetup interface { type CryptoSetupTLS interface { baseCryptoSetup - HandleMessage([]byte, protocol.EncryptionLevel) + HandleMessage([]byte, protocol.EncryptionLevel) bool OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) diff --git a/mock_crypto_data_handler.go b/mock_crypto_data_handler.go index 8c740066..37a40800 100644 --- a/mock_crypto_data_handler.go +++ b/mock_crypto_data_handler.go @@ -35,8 +35,10 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder { } // HandleMessage mocks base method -func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) { - m.ctrl.Call(m, "HandleMessage", arg0, arg1) +func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool { + ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 } // HandleMessage indicates an expected call of HandleMessage diff --git a/mock_crypto_stream_test.go b/mock_crypto_stream_test.go index 66de8d0f..0465e7e2 100644 --- a/mock_crypto_stream_test.go +++ b/mock_crypto_stream_test.go @@ -35,6 +35,18 @@ func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder { return m.recorder } +// Finish mocks base method +func (m *MockCryptoStream) Finish() error { + ret := m.ctrl.Call(m, "Finish") + ret0, _ := ret[0].(error) + return ret0 +} + +// Finish indicates an expected call of Finish +func (mr *MockCryptoStreamMockRecorder) Finish() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finish", reflect.TypeOf((*MockCryptoStream)(nil).Finish)) +} + // GetCryptoData mocks base method func (m *MockCryptoStream) GetCryptoData() []byte { ret := m.ctrl.Call(m, "GetCryptoData")