mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
set the Long Header packet type based on the state of the handshake
This commit is contained in:
parent
a65929f6cf
commit
3f62ea8673
14 changed files with 205 additions and 22 deletions
|
@ -381,6 +381,10 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
|
|||
h.divNonceChan <- data
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) GetNextPacketType() protocol.PacketType {
|
||||
panic("not needed for cryptoSetupServer")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) sendCHLO() error {
|
||||
h.clientHelloCounter++
|
||||
if h.clientHelloCounter > protocol.MaxClientHellos {
|
||||
|
|
|
@ -458,6 +458,10 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
|
|||
panic("not needed for cryptoSetupServer")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) GetNextPacketType() protocol.PacketType {
|
||||
panic("not needed for cryptoSetupServer")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
|
||||
if len(nonce) != 32 {
|
||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")
|
||||
|
|
|
@ -19,13 +19,14 @@ type cryptoSetupTLS struct {
|
|||
|
||||
perspective protocol.Perspective
|
||||
|
||||
keyDerivation KeyDerivationFunction
|
||||
|
||||
tls mintTLS
|
||||
conn *fakeConn
|
||||
|
||||
nullAEAD crypto.AEAD
|
||||
aead crypto.AEAD
|
||||
nextPacketType protocol.PacketType
|
||||
|
||||
keyDerivation KeyDerivationFunction
|
||||
nullAEAD crypto.AEAD
|
||||
aead crypto.AEAD
|
||||
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
}
|
||||
|
@ -98,12 +99,13 @@ func NewCryptoSetupTLSClient(
|
|||
}
|
||||
|
||||
return &cryptoSetupTLS{
|
||||
conn: conn,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
tls: &mintController{mintConn},
|
||||
nullAEAD: nullAEAD,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
conn: conn,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
tls: &mintController{mintConn},
|
||||
nullAEAD: nullAEAD,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
nextPacketType: protocol.PacketTypeClientInitial,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -114,7 +116,10 @@ handshakeLoop:
|
|||
case mint.AlertNoAlert: // handshake complete
|
||||
break handshakeLoop
|
||||
case mint.AlertWouldBlock:
|
||||
h.conn.UnblockRead()
|
||||
h.determineNextPacketType()
|
||||
if err := h.conn.Continue(); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
||||
}
|
||||
|
@ -184,6 +189,35 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S
|
|||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) determineNextPacketType() error {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
state := h.tls.State().HandshakeState
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
switch state {
|
||||
case "ServerStateStart": // if we're still at ServerStateStart when writing the first packet, that means we've come back to that state by sending a HelloRetryRequest
|
||||
h.nextPacketType = protocol.PacketTypeServerStatelessRetry
|
||||
case "ServerStateWaitFinished":
|
||||
h.nextPacketType = protocol.PacketTypeServerCleartext
|
||||
default:
|
||||
// TODO: accept 0-RTT data
|
||||
return fmt.Errorf("Unexpected handshake state: %s", state)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// client
|
||||
if state != "ClientStateWaitSH" {
|
||||
h.nextPacketType = protocol.PacketTypeClientCleartext
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) GetNextPacketType() protocol.PacketType {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.nextPacketType
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) DiversificationNonce() []byte {
|
||||
panic("diversification nonce not needed for TLS")
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
|
@ -53,8 +54,10 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
})
|
||||
|
||||
It("continues shaking hands when mint says that it would block", func() {
|
||||
cs.conn.stream = &bytes.Buffer{}
|
||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertWouldBlock)
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{})
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
|
@ -71,6 +74,60 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
Expect(aeadChanged).To(BeClosed())
|
||||
})
|
||||
|
||||
Context("determining the packet type", func() {
|
||||
Context("for the client", func() {
|
||||
var csClient *cryptoSetupTLS
|
||||
|
||||
BeforeEach(func() {
|
||||
csInt, err := NewCryptoSetupTLSClient(
|
||||
nil,
|
||||
1,
|
||||
"quic.clemente.io",
|
||||
testdata.GetTLSConfig(),
|
||||
&TransportParameters{},
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
protocol.VersionTLS,
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
csClient = csInt.(*cryptoSetupTLS)
|
||||
csClient.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
})
|
||||
|
||||
It("sends a Client Initial first", func() {
|
||||
Expect(csClient.GetNextPacketType()).To(Equal(protocol.PacketTypeClientInitial))
|
||||
})
|
||||
|
||||
It("sends a Client Cleartext after the server sent a Server Hello", func() {
|
||||
csClient.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ClientStateWaitEE"})
|
||||
err := csClient.determineNextPacketType()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("for the server", func() {
|
||||
BeforeEach(func() {
|
||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
})
|
||||
|
||||
It("sends a Stateless Retry packet", func() {
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateStart"})
|
||||
err := cs.determineNextPacketType()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeServerStatelessRetry))
|
||||
})
|
||||
|
||||
It("sends a Server Cleartext packet", func() {
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateWaitFinished"})
|
||||
err := cs.determineNextPacketType()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeServerCleartext))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("escalating crypto", func() {
|
||||
doHandshake := func() {
|
||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
|
|
|
@ -15,8 +15,9 @@ type CryptoSetup interface {
|
|||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
||||
HandleCryptoStream() error
|
||||
// TODO: clean up this interface
|
||||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
||||
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
||||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
||||
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
||||
GetNextPacketType() protocol.PacketType // only needed for cryptoSetupServer
|
||||
|
||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
gocrypto "crypto"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
|
@ -50,6 +51,7 @@ type mintTLS interface {
|
|||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
||||
// additional methods
|
||||
Handshake() mint.Alert
|
||||
State() mint.ConnectionState
|
||||
}
|
||||
|
||||
var _ crypto.TLSExporter = (mintTLS)(nil)
|
||||
|
@ -72,13 +74,18 @@ func (mc *mintController) Handshake() mint.Alert {
|
|||
return mc.conn.Handshake()
|
||||
}
|
||||
|
||||
func (mc *mintController) State() mint.ConnectionState {
|
||||
return mc.conn.State()
|
||||
}
|
||||
|
||||
// mint expects a net.Conn, but we're doing the handshake on a stream
|
||||
// so we wrap a stream such that implements a net.Conn
|
||||
type fakeConn struct {
|
||||
stream io.ReadWriter
|
||||
pers protocol.Perspective
|
||||
|
||||
blockRead bool
|
||||
blockRead bool
|
||||
writeBuffer bytes.Buffer
|
||||
}
|
||||
|
||||
var _ net.Conn = &fakeConn{}
|
||||
|
@ -92,11 +99,23 @@ func (c *fakeConn) Read(b []byte) (int, error) {
|
|||
}
|
||||
|
||||
func (c *fakeConn) Write(p []byte) (int, error) {
|
||||
return c.stream.Write(p)
|
||||
if c.pers == protocol.PerspectiveClient {
|
||||
return c.stream.Write(p)
|
||||
}
|
||||
// Buffer all writes by the server.
|
||||
// Mint transitions to the next state *after* writing, so we need to let all the writes happen, only then we can determine the packet type to use to send out this data.
|
||||
return c.writeBuffer.Write(p)
|
||||
}
|
||||
|
||||
func (c *fakeConn) UnblockRead() {
|
||||
func (c *fakeConn) Continue() error {
|
||||
c.blockRead = false
|
||||
if c.pers == protocol.PerspectiveClient {
|
||||
return nil
|
||||
}
|
||||
// write all contents of the write buffer to the stream.
|
||||
_, err := c.stream.Write(c.writeBuffer.Bytes())
|
||||
c.writeBuffer.Reset()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *fakeConn) Close() error { return nil }
|
||||
|
|
|
@ -3,6 +3,7 @@ package handshake
|
|||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
@ -35,10 +36,30 @@ var _ = Describe("Fake Conn", func() {
|
|||
b := make([]byte, 3)
|
||||
_, err := c.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
c.UnblockRead()
|
||||
err = c.Continue()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = c.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(Equal([]byte("bar")))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Writing", func() {
|
||||
It("writes directly when acting as a client", func() {
|
||||
c.pers = protocol.PerspectiveClient
|
||||
_, err := c.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("only writes after flushing when acting as a server", func() {
|
||||
c.pers = protocol.PerspectiveServer
|
||||
_, err := c.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stream.Bytes()).To(BeEmpty())
|
||||
err = c.Continue()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -69,3 +69,15 @@ func (_m *MockmintTLS) Handshake() mint.Alert {
|
|||
func (_mr *MockmintTLSMockRecorder) Handshake() *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockmintTLS)(nil).Handshake))
|
||||
}
|
||||
|
||||
// State mocks base method
|
||||
func (_m *MockmintTLS) State() mint.ConnectionState {
|
||||
ret := _m.ctrl.Call(_m, "State")
|
||||
ret0, _ := ret[0].(mint.ConnectionState)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// State indicates an expected call of State
|
||||
func (_mr *MockmintTLSMockRecorder) State() *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "State", reflect.TypeOf((*MockmintTLS)(nil).State))
|
||||
}
|
||||
|
|
|
@ -21,6 +21,24 @@ const (
|
|||
PacketNumberLen6 PacketNumberLen = 6
|
||||
)
|
||||
|
||||
// The PacketType is the Long Header Type (only used for the IETF draft header format)
|
||||
type PacketType uint8
|
||||
|
||||
const (
|
||||
// PacketTypeVersionNegotiation is the packet type of a Version Negotiation packet
|
||||
PacketTypeVersionNegotiation PacketType = 1
|
||||
// PacketTypeClientInitial is the packet type of a Client Initial packet
|
||||
PacketTypeClientInitial PacketType = 2
|
||||
// PacketTypeServerStatelessRetry is the packet type of a Server Stateless Retry packet
|
||||
PacketTypeServerStatelessRetry PacketType = 3
|
||||
// PacketTypeServerCleartext is the packet type of a Server Cleartext packet
|
||||
PacketTypeServerCleartext PacketType = 4
|
||||
// PacketTypeClientCleartext is the packet type of a Client Cleartext packet
|
||||
PacketTypeClientCleartext PacketType = 5
|
||||
// PacketType0RTT is the packet type of a 0-RTT packet
|
||||
PacketType0RTT PacketType = 6
|
||||
)
|
||||
|
||||
// A ConnectionID in QUIC
|
||||
type ConnectionID uint64
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ type Header struct {
|
|||
DiversificationNonce []byte
|
||||
|
||||
// only needed for the IETF Header
|
||||
Type uint8
|
||||
Type protocol.PacketType
|
||||
IsLongHeader bool
|
||||
KeyPhase int
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ func parseLongHeader(b *bytes.Reader, packetSentBy protocol.Perspective, typeByt
|
|||
return nil, err
|
||||
}
|
||||
h := &Header{
|
||||
Type: typeByte & 0x7f,
|
||||
Type: protocol.PacketType(typeByte & 0x7f),
|
||||
IsLongHeader: true,
|
||||
ConnectionID: protocol.ConnectionID(connID),
|
||||
PacketNumber: protocol.PacketNumber(pn),
|
||||
|
|
|
@ -32,7 +32,7 @@ var _ = Describe("IETF draft Header", func() {
|
|||
b := bytes.NewReader(data)
|
||||
h, err := parseHeader(b, protocol.PerspectiveClient)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(h.Type).To(BeEquivalentTo(3))
|
||||
Expect(h.Type).To(Equal(protocol.PacketType(3)))
|
||||
Expect(h.IsLongHeader).To(BeTrue())
|
||||
Expect(h.OmitConnectionID).To(BeFalse())
|
||||
Expect(h.ConnectionID).To(Equal(protocol.ConnectionID(0xdeadbeefcafe1337)))
|
||||
|
@ -62,6 +62,7 @@ var _ = Describe("IETF draft Header", func() {
|
|||
b := bytes.NewReader(data)
|
||||
h, err := parseHeader(b, protocol.PerspectiveServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(h.Type).To(Equal(protocol.PacketTypeVersionNegotiation))
|
||||
Expect(h.SupportedVersions).To(Equal([]protocol.VersionNumber{
|
||||
0x22334455,
|
||||
0x33445566,
|
||||
|
|
|
@ -291,8 +291,11 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
|
|||
header.VersionFlag = true
|
||||
header.Version = p.version
|
||||
}
|
||||
} else if encLevel != protocol.EncryptionForwardSecure {
|
||||
header.Version = p.version
|
||||
} else {
|
||||
header.Type = p.cryptoSetup.GetNextPacketType()
|
||||
if encLevel != protocol.EncryptionForwardSecure {
|
||||
header.Version = p.version
|
||||
}
|
||||
}
|
||||
return header
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ type mockCryptoSetup struct {
|
|||
divNonce []byte
|
||||
encLevelSeal protocol.EncryptionLevel
|
||||
encLevelSealCrypto protocol.EncryptionLevel
|
||||
nextPacketType protocol.PacketType
|
||||
}
|
||||
|
||||
var _ handshake.CryptoSetup = &mockCryptoSetup{}
|
||||
|
@ -49,6 +50,7 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel)
|
|||
}
|
||||
func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce }
|
||||
func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce }
|
||||
func (m *mockCryptoSetup) GetNextPacketType() protocol.PacketType { return m.nextPacketType }
|
||||
|
||||
var _ = Describe("Packet packer", func() {
|
||||
var (
|
||||
|
@ -189,6 +191,13 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(h.Version).To(Equal(versionIETFHeader))
|
||||
})
|
||||
|
||||
It("sets the packet type based on the state of the handshake", func() {
|
||||
packer.cryptoSetup.(*mockCryptoSetup).nextPacketType = 5
|
||||
h := packer.getHeader(protocol.EncryptionSecure)
|
||||
Expect(h.IsLongHeader).To(BeTrue())
|
||||
Expect(h.Type).To(Equal(protocol.PacketType(5)))
|
||||
})
|
||||
|
||||
It("uses the Short Header format for forward-secure packets", func() {
|
||||
h := packer.getHeader(protocol.EncryptionForwardSecure)
|
||||
Expect(h.IsLongHeader).To(BeFalse())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue