move TLS message header parsing logic to the crypto stream

This commit is contained in:
Marten Seemann 2018-10-20 10:11:25 +09:00
parent d1f49ad2d0
commit 19e5feef57
8 changed files with 117 additions and 91 deletions

View file

@ -20,7 +20,8 @@ type cryptoStream interface {
} }
type cryptoStreamImpl struct { type cryptoStreamImpl struct {
queue *frameSorter queue *frameSorter
msgBuf []byte
writeOffset protocol.ByteCount writeOffset protocol.ByteCount
writeBuf []byte writeBuf []byte
@ -36,13 +37,31 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
if maxOffset := f.Offset + protocol.ByteCount(len(f.Data)); maxOffset > protocol.MaxCryptoStreamOffset { if maxOffset := f.Offset + protocol.ByteCount(len(f.Data)); maxOffset > protocol.MaxCryptoStreamOffset {
return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset) return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset)
} }
return s.queue.Push(f.Data, f.Offset, false) if err := s.queue.Push(f.Data, f.Offset, false); err != nil {
return err
}
for {
data, _ := s.queue.Pop()
if data == nil {
return nil
}
s.msgBuf = append(s.msgBuf, data...)
}
} }
// GetCryptoData retrieves data that was received in CRYPTO frames // GetCryptoData retrieves data that was received in CRYPTO frames
func (s *cryptoStreamImpl) GetCryptoData() []byte { func (s *cryptoStreamImpl) GetCryptoData() []byte {
data, _ := s.queue.Pop() if len(s.msgBuf) < 4 {
return data return nil
}
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
if len(s.msgBuf) < msgLen {
return nil
}
msg := make([]byte, msgLen)
copy(msg, s.msgBuf[:msgLen])
s.msgBuf = s.msgBuf[msgLen:]
return msg
} }
// Writes writes data that should be sent out in CRYPTO frames // Writes writes data that should be sent out in CRYPTO frames

View file

@ -8,7 +8,7 @@ import (
) )
type cryptoDataHandler interface { type cryptoDataHandler interface {
HandleData([]byte, protocol.EncryptionLevel) HandleMessage([]byte, protocol.EncryptionLevel)
} }
type cryptoStreamManager struct { type cryptoStreamManager struct {
@ -48,6 +48,6 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
if data == nil { if data == nil {
return nil return nil
} }
m.cryptoHandler.HandleData(data, encLevel) m.cryptoHandler.HandleMessage(data, encLevel)
} }
} }

View file

@ -1,7 +1,6 @@
package quic package quic
import ( import (
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
@ -22,34 +21,35 @@ var _ = Describe("Crypto Stream Manager", func() {
csm = newCryptoStreamManager(cs, initialStream, handshakeStream) csm = newCryptoStreamManager(cs, initialStream, handshakeStream)
}) })
It("handles in in-order crypto frame", func() { It("passes messages to the right stream", func() {
f := &wire.CryptoFrame{Data: []byte("foobar")} initialMsg := createHandshakeMessage(10)
cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionInitial) handshakeMsg := createHandshakeMessage(20)
Expect(csm.HandleCryptoFrame(f, protocol.EncryptionInitial)).To(Succeed())
// only pass in a part of the message, to make sure they get assembled in the right crypto stream
Expect(csm.HandleCryptoFrame(&wire.CryptoFrame{
Data: initialMsg[:5],
}, protocol.EncryptionInitial)).To(Succeed())
Expect(csm.HandleCryptoFrame(&wire.CryptoFrame{
Data: handshakeMsg[:5],
}, protocol.EncryptionHandshake)).To(Succeed())
// now pass in the rest of the initial message
cs.EXPECT().HandleMessage(initialMsg, protocol.EncryptionInitial)
Expect(csm.HandleCryptoFrame(&wire.CryptoFrame{
Data: initialMsg[5:],
Offset: 5,
}, protocol.EncryptionInitial)).To(Succeed())
// now pass in the rest of the handshake message
cs.EXPECT().HandleMessage(handshakeMsg, protocol.EncryptionHandshake)
Expect(csm.HandleCryptoFrame(&wire.CryptoFrame{
Data: handshakeMsg[5:],
Offset: 5,
}, protocol.EncryptionHandshake)).To(Succeed())
}) })
It("errors for unknown encryption levels", func() { It("errors for unknown encryption levels", func() {
err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT) err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT)
Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT")) Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT"))
}) })
It("handles out-of-order crypto frames", func() {
f1 := &wire.CryptoFrame{Data: []byte("foo")}
f2 := &wire.CryptoFrame{
Offset: 3,
Data: []byte("bar"),
}
gomock.InOrder(
cs.EXPECT().HandleData([]byte("foo"), protocol.EncryptionInitial),
cs.EXPECT().HandleData([]byte("bar"), protocol.EncryptionInitial),
)
Expect(csm.HandleCryptoFrame(f1, protocol.EncryptionInitial)).To(Succeed())
Expect(csm.HandleCryptoFrame(f2, protocol.EncryptionInitial)).To(Succeed())
})
It("handles handshake data", func() {
f := &wire.CryptoFrame{Data: []byte("foobar")}
cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake)
Expect(csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)).To(Succeed())
})
}) })

View file

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"crypto/rand"
"fmt" "fmt"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
@ -10,6 +11,16 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
func createHandshakeMessage(len int) []byte {
msg := make([]byte, 4+len)
rand.Read(msg[:1]) // random message type
msg[1] = uint8(len >> 16)
msg[2] = uint8(len >> 8)
msg[3] = uint8(len)
rand.Read(msg[4:])
return msg
}
var _ = Describe("Crypto Stream", func() { var _ = Describe("Crypto Stream", func() {
var ( var (
str cryptoStream str cryptoStream
@ -21,11 +32,21 @@ var _ = Describe("Crypto Stream", func() {
Context("handling incoming data", func() { Context("handling incoming data", func() {
It("handles in-order CRYPTO frames", func() { It("handles in-order CRYPTO frames", func() {
err := str.HandleCryptoFrame(&wire.CryptoFrame{ msg := createHandshakeMessage(6)
Data: []byte("foobar"), err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg})
})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(Equal([]byte("foobar"))) Expect(str.GetCryptoData()).To(Equal(msg))
Expect(str.GetCryptoData()).To(BeNil())
})
It("handles multiple messages in one CRYPTO frame", func() {
msg1 := createHandshakeMessage(6)
msg2 := createHandshakeMessage(10)
msg := append(append([]byte{}, msg1...), msg2...)
err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(Equal(msg1))
Expect(str.GetCryptoData()).To(Equal(msg2))
Expect(str.GetCryptoData()).To(BeNil()) Expect(str.GetCryptoData()).To(BeNil())
}) })
@ -37,19 +58,35 @@ var _ = Describe("Crypto Stream", func() {
Expect(err).To(MatchError(fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", protocol.MaxCryptoStreamOffset+1, protocol.MaxCryptoStreamOffset))) Expect(err).To(MatchError(fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", protocol.MaxCryptoStreamOffset+1, protocol.MaxCryptoStreamOffset)))
}) })
It("handles out-of-order CRYPTO frames", func() { It("handles messages split over multiple CRYPTO frames", func() {
msg := createHandshakeMessage(6)
err := str.HandleCryptoFrame(&wire.CryptoFrame{ err := str.HandleCryptoFrame(&wire.CryptoFrame{
Offset: 3, Data: msg[:4],
Data: []byte("bar"),
}) })
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(BeNil()) Expect(str.GetCryptoData()).To(BeNil())
err = str.HandleCryptoFrame(&wire.CryptoFrame{ err = str.HandleCryptoFrame(&wire.CryptoFrame{
Data: []byte("foo"), Offset: 4,
Data: msg[4:],
}) })
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(Equal([]byte("foo"))) Expect(str.GetCryptoData()).To(Equal(msg))
Expect(str.GetCryptoData()).To(Equal([]byte("bar"))) Expect(str.GetCryptoData()).To(BeNil())
})
It("handles out-of-order CRYPTO frames", func() {
msg := createHandshakeMessage(6)
err := str.HandleCryptoFrame(&wire.CryptoFrame{
Offset: 4,
Data: msg[4:],
})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(BeNil())
err = str.HandleCryptoFrame(&wire.CryptoFrame{
Data: msg[:4],
})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(Equal(msg))
Expect(str.GetCryptoData()).To(BeNil()) Expect(str.GetCryptoData()).To(BeNil())
}) })
}) })

View file

@ -1,7 +1,6 @@
package handshake package handshake
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -74,14 +73,12 @@ type cryptoSetupTLS struct {
clientHelloWritten bool clientHelloWritten bool
clientHelloWrittenChan chan struct{} clientHelloWrittenChan chan struct{}
initialReadBuf bytes.Buffer initialStream io.Writer
initialStream io.Writer initialAEAD crypto.AEAD
initialAEAD crypto.AEAD
handshakeReadBuf bytes.Buffer handshakeStream io.Writer
handshakeStream io.Writer handshakeOpener Opener
handshakeOpener Opener handshakeSealer Sealer
handshakeSealer Sealer
opener Opener opener Opener
sealer Sealer sealer Sealer
@ -272,40 +269,14 @@ func (h *cryptoSetupTLS) RunHandshake() error {
} }
} }
func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) {
var buf *bytes.Buffer
switch encLevel {
case protocol.EncryptionInitial:
buf = &h.initialReadBuf
case protocol.EncryptionHandshake:
buf = &h.handshakeReadBuf
default:
h.messageErrChan <- fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel)
return
}
buf.Write(data)
for buf.Len() >= 4 {
b := buf.Bytes()
// read the TLS message length
length := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
if buf.Len() < 4+length { // message not yet complete
return
}
msg := make([]byte, length+4)
buf.Read(msg)
if err := h.handleMessage(msg, encLevel); err != nil {
h.messageErrChan <- err
}
}
}
// handleMessage handles a TLS handshake message. // handleMessage handles a TLS handshake message.
// It is called by the crypto streams when a new message is available. // It is called by the crypto streams when a new message is available.
func (h *cryptoSetupTLS) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error { func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) {
msgType := messageType(data[0]) msgType := messageType(data[0])
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
return err h.messageErrChan <- err
return
} }
h.messageChan <- data h.messageChan <- data
switch h.perspective { switch h.perspective {
@ -316,7 +287,6 @@ func (h *cryptoSetupTLS) handleMessage(data []byte, encLevel protocol.Encryption
default: default:
panic("") panic("")
} }
return nil
} }
func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {

View file

@ -83,7 +83,7 @@ var _ = Describe("Crypto Setup TLS", func() {
}() }()
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
server.HandleData(fakeCH, protocol.EncryptionInitial) server.HandleMessage(fakeCH, protocol.EncryptionInitial)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -114,7 +114,7 @@ var _ = Describe("Crypto Setup TLS", func() {
}() }()
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
server.HandleData(fakeCH, protocol.EncryptionHandshake) // wrong encryption level server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -150,9 +150,9 @@ var _ = Describe("Crypto Setup TLS", func() {
for { for {
select { select {
case c := <-cChunkChan: case c := <-cChunkChan:
server.HandleData(c.data, c.encLevel) server.HandleMessage(c.data, c.encLevel)
case c := <-sChunkChan: case c := <-sChunkChan:
client.HandleData(c.data, c.encLevel) client.HandleMessage(c.data, c.encLevel)
case <-done: // handshake complete case <-done: // handshake complete
} }
} }
@ -264,7 +264,7 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(len(ch.data) - 4).To(Equal(length)) Expect(len(ch.data) - 4).To(Equal(length))
// make the go routine return // make the go routine return
client.HandleData([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial) client.HandleMessage([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })

View file

@ -44,7 +44,7 @@ type CryptoSetup interface {
type CryptoSetupTLS interface { type CryptoSetupTLS interface {
baseCryptoSetup baseCryptoSetup
HandleData([]byte, protocol.EncryptionLevel) HandleMessage([]byte, protocol.EncryptionLevel)
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)

View file

@ -34,12 +34,12 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder {
return m.recorder return m.recorder
} }
// HandleData mocks base method // HandleMessage mocks base method
func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) { func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) {
m.ctrl.Call(m, "HandleData", arg0, arg1) m.ctrl.Call(m, "HandleMessage", arg0, arg1)
} }
// HandleData indicates an expected call of HandleData // HandleMessage indicates an expected call of HandleMessage
func (mr *MockCryptoDataHandlerMockRecorder) HandleData(arg0, arg1 interface{}) *gomock.Call { func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleData", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleData), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleMessage), arg0, arg1)
} }