set the Long Header packet type based on the state of the handshake

This commit is contained in:
Marten Seemann 2017-10-27 08:39:06 +07:00
parent a65929f6cf
commit 3f62ea8673
14 changed files with 205 additions and 22 deletions

View file

@ -381,6 +381,10 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
h.divNonceChan <- data
}
func (h *cryptoSetupClient) GetNextPacketType() protocol.PacketType {
panic("not needed for cryptoSetupServer")
}
func (h *cryptoSetupClient) sendCHLO() error {
h.clientHelloCounter++
if h.clientHelloCounter > protocol.MaxClientHellos {

View file

@ -458,6 +458,10 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
panic("not needed for cryptoSetupServer")
}
func (h *cryptoSetupServer) GetNextPacketType() protocol.PacketType {
panic("not needed for cryptoSetupServer")
}
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
if len(nonce) != 32 {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")

View file

@ -19,13 +19,14 @@ type cryptoSetupTLS struct {
perspective protocol.Perspective
keyDerivation KeyDerivationFunction
tls mintTLS
conn *fakeConn
nullAEAD crypto.AEAD
aead crypto.AEAD
nextPacketType protocol.PacketType
keyDerivation KeyDerivationFunction
nullAEAD crypto.AEAD
aead crypto.AEAD
aeadChanged chan<- protocol.EncryptionLevel
}
@ -98,12 +99,13 @@ func NewCryptoSetupTLSClient(
}
return &cryptoSetupTLS{
conn: conn,
perspective: protocol.PerspectiveClient,
tls: &mintController{mintConn},
nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged,
conn: conn,
perspective: protocol.PerspectiveClient,
tls: &mintController{mintConn},
nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged,
nextPacketType: protocol.PacketTypeClientInitial,
}, nil
}
@ -114,7 +116,10 @@ handshakeLoop:
case mint.AlertNoAlert: // handshake complete
break handshakeLoop
case mint.AlertWouldBlock:
h.conn.UnblockRead()
h.determineNextPacketType()
if err := h.conn.Continue(); err != nil {
return err
}
default:
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
}
@ -184,6 +189,35 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupTLS) determineNextPacketType() error {
h.mutex.Lock()
defer h.mutex.Unlock()
state := h.tls.State().HandshakeState
if h.perspective == protocol.PerspectiveServer {
switch state {
case "ServerStateStart": // if we're still at ServerStateStart when writing the first packet, that means we've come back to that state by sending a HelloRetryRequest
h.nextPacketType = protocol.PacketTypeServerStatelessRetry
case "ServerStateWaitFinished":
h.nextPacketType = protocol.PacketTypeServerCleartext
default:
// TODO: accept 0-RTT data
return fmt.Errorf("Unexpected handshake state: %s", state)
}
return nil
}
// client
if state != "ClientStateWaitSH" {
h.nextPacketType = protocol.PacketTypeClientCleartext
}
return nil
}
func (h *cryptoSetupTLS) GetNextPacketType() protocol.PacketType {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.nextPacketType
}
func (h *cryptoSetupTLS) DiversificationNonce() []byte {
panic("diversification nonce not needed for TLS")
}

View file

@ -1,6 +1,7 @@
package handshake
import (
"bytes"
"errors"
"fmt"
@ -53,8 +54,10 @@ var _ = Describe("TLS Crypto Setup", func() {
})
It("continues shaking hands when mint says that it would block", func() {
cs.conn.stream = &bytes.Buffer{}
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertWouldBlock)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{})
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
@ -71,6 +74,60 @@ var _ = Describe("TLS Crypto Setup", func() {
Expect(aeadChanged).To(BeClosed())
})
Context("determining the packet type", func() {
Context("for the client", func() {
var csClient *cryptoSetupTLS
BeforeEach(func() {
csInt, err := NewCryptoSetupTLSClient(
nil,
1,
"quic.clemente.io",
testdata.GetTLSConfig(),
&TransportParameters{},
paramsChan,
aeadChanged,
protocol.VersionTLS,
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
)
Expect(err).ToNot(HaveOccurred())
csClient = csInt.(*cryptoSetupTLS)
csClient.tls = mockhandshake.NewMockmintTLS(mockCtrl)
})
It("sends a Client Initial first", func() {
Expect(csClient.GetNextPacketType()).To(Equal(protocol.PacketTypeClientInitial))
})
It("sends a Client Cleartext after the server sent a Server Hello", func() {
csClient.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ClientStateWaitEE"})
err := csClient.determineNextPacketType()
Expect(err).ToNot(HaveOccurred())
})
})
Context("for the server", func() {
BeforeEach(func() {
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
})
It("sends a Stateless Retry packet", func() {
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateStart"})
err := cs.determineNextPacketType()
Expect(err).ToNot(HaveOccurred())
Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeServerStatelessRetry))
})
It("sends a Server Cleartext packet", func() {
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateWaitFinished"})
err := cs.determineNextPacketType()
Expect(err).ToNot(HaveOccurred())
Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeServerCleartext))
})
})
})
Context("escalating crypto", func() {
doHandshake := func() {
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)

View file

@ -15,8 +15,9 @@ type CryptoSetup interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
HandleCryptoStream() error
// TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
GetNextPacketType() protocol.PacketType // only needed for cryptoSetupServer
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)

View file

@ -1,6 +1,7 @@
package handshake
import (
"bytes"
gocrypto "crypto"
"crypto/tls"
"crypto/x509"
@ -50,6 +51,7 @@ type mintTLS interface {
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
// additional methods
Handshake() mint.Alert
State() mint.ConnectionState
}
var _ crypto.TLSExporter = (mintTLS)(nil)
@ -72,13 +74,18 @@ func (mc *mintController) Handshake() mint.Alert {
return mc.conn.Handshake()
}
func (mc *mintController) State() mint.ConnectionState {
return mc.conn.State()
}
// mint expects a net.Conn, but we're doing the handshake on a stream
// so we wrap a stream such that implements a net.Conn
type fakeConn struct {
stream io.ReadWriter
pers protocol.Perspective
blockRead bool
blockRead bool
writeBuffer bytes.Buffer
}
var _ net.Conn = &fakeConn{}
@ -92,11 +99,23 @@ func (c *fakeConn) Read(b []byte) (int, error) {
}
func (c *fakeConn) Write(p []byte) (int, error) {
return c.stream.Write(p)
if c.pers == protocol.PerspectiveClient {
return c.stream.Write(p)
}
// Buffer all writes by the server.
// Mint transitions to the next state *after* writing, so we need to let all the writes happen, only then we can determine the packet type to use to send out this data.
return c.writeBuffer.Write(p)
}
func (c *fakeConn) UnblockRead() {
func (c *fakeConn) Continue() error {
c.blockRead = false
if c.pers == protocol.PerspectiveClient {
return nil
}
// write all contents of the write buffer to the stream.
_, err := c.stream.Write(c.writeBuffer.Bytes())
c.writeBuffer.Reset()
return err
}
func (c *fakeConn) Close() error { return nil }

View file

@ -3,6 +3,7 @@ package handshake
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
@ -35,10 +36,30 @@ var _ = Describe("Fake Conn", func() {
b := make([]byte, 3)
_, err := c.Read(b)
Expect(err).ToNot(HaveOccurred())
c.UnblockRead()
err = c.Continue()
Expect(err).ToNot(HaveOccurred())
_, err = c.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(Equal([]byte("bar")))
})
})
Context("Writing", func() {
It("writes directly when acting as a client", func() {
c.pers = protocol.PerspectiveClient
_, err := c.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
})
It("only writes after flushing when acting as a server", func() {
c.pers = protocol.PerspectiveServer
_, err := c.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(stream.Bytes()).To(BeEmpty())
err = c.Continue()
Expect(err).ToNot(HaveOccurred())
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
})
})
})

View file

@ -69,3 +69,15 @@ func (_m *MockmintTLS) Handshake() mint.Alert {
func (_mr *MockmintTLSMockRecorder) Handshake() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockmintTLS)(nil).Handshake))
}
// State mocks base method
func (_m *MockmintTLS) State() mint.ConnectionState {
ret := _m.ctrl.Call(_m, "State")
ret0, _ := ret[0].(mint.ConnectionState)
return ret0
}
// State indicates an expected call of State
func (_mr *MockmintTLSMockRecorder) State() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "State", reflect.TypeOf((*MockmintTLS)(nil).State))
}

View file

@ -21,6 +21,24 @@ const (
PacketNumberLen6 PacketNumberLen = 6
)
// The PacketType is the Long Header Type (only used for the IETF draft header format)
type PacketType uint8
const (
// PacketTypeVersionNegotiation is the packet type of a Version Negotiation packet
PacketTypeVersionNegotiation PacketType = 1
// PacketTypeClientInitial is the packet type of a Client Initial packet
PacketTypeClientInitial PacketType = 2
// PacketTypeServerStatelessRetry is the packet type of a Server Stateless Retry packet
PacketTypeServerStatelessRetry PacketType = 3
// PacketTypeServerCleartext is the packet type of a Server Cleartext packet
PacketTypeServerCleartext PacketType = 4
// PacketTypeClientCleartext is the packet type of a Client Cleartext packet
PacketTypeClientCleartext PacketType = 5
// PacketType0RTT is the packet type of a 0-RTT packet
PacketType0RTT PacketType = 6
)
// A ConnectionID in QUIC
type ConnectionID uint64

View file

@ -24,7 +24,7 @@ type Header struct {
DiversificationNonce []byte
// only needed for the IETF Header
Type uint8
Type protocol.PacketType
IsLongHeader bool
KeyPhase int

View file

@ -35,7 +35,7 @@ func parseLongHeader(b *bytes.Reader, packetSentBy protocol.Perspective, typeByt
return nil, err
}
h := &Header{
Type: typeByte & 0x7f,
Type: protocol.PacketType(typeByte & 0x7f),
IsLongHeader: true,
ConnectionID: protocol.ConnectionID(connID),
PacketNumber: protocol.PacketNumber(pn),

View file

@ -32,7 +32,7 @@ var _ = Describe("IETF draft Header", func() {
b := bytes.NewReader(data)
h, err := parseHeader(b, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
Expect(h.Type).To(BeEquivalentTo(3))
Expect(h.Type).To(Equal(protocol.PacketType(3)))
Expect(h.IsLongHeader).To(BeTrue())
Expect(h.OmitConnectionID).To(BeFalse())
Expect(h.ConnectionID).To(Equal(protocol.ConnectionID(0xdeadbeefcafe1337)))
@ -62,6 +62,7 @@ var _ = Describe("IETF draft Header", func() {
b := bytes.NewReader(data)
h, err := parseHeader(b, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
Expect(h.Type).To(Equal(protocol.PacketTypeVersionNegotiation))
Expect(h.SupportedVersions).To(Equal([]protocol.VersionNumber{
0x22334455,
0x33445566,

View file

@ -291,8 +291,11 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
header.VersionFlag = true
header.Version = p.version
}
} else if encLevel != protocol.EncryptionForwardSecure {
header.Version = p.version
} else {
header.Type = p.cryptoSetup.GetNextPacketType()
if encLevel != protocol.EncryptionForwardSecure {
header.Version = p.version
}
}
return header
}

View file

@ -28,6 +28,7 @@ type mockCryptoSetup struct {
divNonce []byte
encLevelSeal protocol.EncryptionLevel
encLevelSealCrypto protocol.EncryptionLevel
nextPacketType protocol.PacketType
}
var _ handshake.CryptoSetup = &mockCryptoSetup{}
@ -49,6 +50,7 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel)
}
func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce }
func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce }
func (m *mockCryptoSetup) GetNextPacketType() protocol.PacketType { return m.nextPacketType }
var _ = Describe("Packet packer", func() {
var (
@ -189,6 +191,13 @@ var _ = Describe("Packet packer", func() {
Expect(h.Version).To(Equal(versionIETFHeader))
})
It("sets the packet type based on the state of the handshake", func() {
packer.cryptoSetup.(*mockCryptoSetup).nextPacketType = 5
h := packer.getHeader(protocol.EncryptionSecure)
Expect(h.IsLongHeader).To(BeTrue())
Expect(h.Type).To(Equal(protocol.PacketType(5)))
})
It("uses the Short Header format for forward-secure packets", func() {
h := packer.getHeader(protocol.EncryptionForwardSecure)
Expect(h.IsLongHeader).To(BeFalse())