qtls: only attempt 0-RTT resumption for 0-RTT enabled session tickets (#4183)

This commit is contained in:
Marten Seemann 2023-12-09 19:47:47 +05:30 committed by GitHub
parent 38eafe4ad8
commit d234d62d52
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 135 additions and 25 deletions

View file

@ -641,6 +641,49 @@ var _ = Describe("0-RTT", func() {
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
It("doesn't use 0-RTT, if the server didn't enable it", func() {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer server.Close()
gets := make(chan string, 100)
puts := make(chan string, 100)
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn1.CloseWithError(0, "")
var sessionKey string
Eventually(puts).Should(Receive(&sessionKey))
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
serverConn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
conn2, err := quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Expect(gets).To(Receive(Equal(sessionKey)))
Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue())
serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse())
conn2.CloseWithError(0, "")
})
DescribeTable("flow control limits",
func(addFlowControlLimit func(*quic.Config, uint64)) {
counter, tracer := newPacketTracer()

View file

@ -692,6 +692,49 @@ var _ = Describe("0-RTT", func() {
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
It("doesn't use 0-RTT, if the server didn't enable it", func() {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer server.Close()
gets := make(chan string, 100)
puts := make(chan string, 100)
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn1.CloseWithError(0, "")
var sessionKey string
Eventually(puts).Should(Receive(&sessionKey))
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
serverConn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
conn2, err := quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Expect(gets).To(Receive(Equal(sessionKey)))
Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue())
serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse())
conn2.CloseWithError(0, "")
})
DescribeTable("flow control limits",
func(addFlowControlLimit func(*quic.Config, uint64)) {
counter, tracer := newPacketTracer()

View file

@ -25,7 +25,7 @@ type quicVersionContextKey struct{}
var QUICVersionContextKey = &quicVersionContextKey{}
const clientSessionStateRevision = 3
const clientSessionStateRevision = 4
type cryptoSetup struct {
tlsConf *tls.Config
@ -313,19 +313,24 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) error {
}
// must be called after receiving the transport parameters
func (h *cryptoSetup) marshalDataForSessionState() []byte {
func (h *cryptoSetup) marshalDataForSessionState(earlyData bool) []byte {
b := make([]byte, 0, 256)
b = quicvarint.Append(b, clientSessionStateRevision)
b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds()))
return h.peerParams.MarshalForSessionTicket(b)
if earlyData {
// only save the transport parameters for 0-RTT enabled session tickets
return h.peerParams.MarshalForSessionTicket(b)
}
return b
}
func (h *cryptoSetup) handleDataFromSessionState(data []byte) (allowEarlyData bool) {
tp, err := h.handleDataFromSessionStateImpl(data)
func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (allowEarlyData bool) {
rtt, tp, err := decodeDataFromSessionState(data, earlyData)
if err != nil {
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
return
}
h.rttStats.SetInitialRTT(rtt)
// The session ticket might have been saved from a connection that allowed 0-RTT,
// and therefore contain transport parameters.
// Only use them if 0-RTT is actually used on the new connection.
@ -336,25 +341,28 @@ func (h *cryptoSetup) handleDataFromSessionState(data []byte) (allowEarlyData bo
return false
}
func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) {
func decodeDataFromSessionState(data []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
r := bytes.NewReader(data)
ver, err := quicvarint.Read(r)
if err != nil {
return nil, err
return 0, nil, err
}
if ver != clientSessionStateRevision {
return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
return 0, nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
}
rtt, err := quicvarint.Read(r)
rttEncoded, err := quicvarint.Read(r)
if err != nil {
return nil, err
return 0, nil, err
}
rtt := time.Duration(rttEncoded) * time.Microsecond
if !earlyData {
return rtt, nil, nil
}
h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return nil, err
return 0, nil, err
}
return &tp, nil
return rtt, &tp, nil
}
func (h *cryptoSetup) getDataForSessionTicket() []byte {

View file

@ -7,8 +7,8 @@ import (
)
type clientSessionCache struct {
getData func() []byte
setData func([]byte) (allowEarlyData bool)
getData func(earlyData bool) []byte
setData func(data []byte, earlyData bool) (allowEarlyData bool)
wrapped tls.ClientSessionCache
}
@ -24,7 +24,7 @@ func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
c.wrapped.Put(key, cs)
return
}
state.Extra = append(state.Extra, addExtraPrefix(c.getData()))
state.Extra = append(state.Extra, addExtraPrefix(c.getData(state.EarlyData)))
newCS, err := tls.NewResumptionState(ticket, state)
if err != nil {
// It's not clear why this would error. Just save the original state.
@ -46,12 +46,13 @@ func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
c.wrapped.Put(key, nil)
return nil, false
}
var earlyData bool
// restore QUIC transport parameters and RTT stored in state.Extra
if extra := findExtraData(state.Extra); extra != nil {
earlyData = c.setData(extra)
earlyData := c.setData(extra, state.EarlyData)
if state.EarlyData {
state.EarlyData = earlyData
}
}
state.EarlyData = earlyData
session, err := tls.NewResumptionState(ticket, state)
if err != nil {
// It's not clear why this would error.

View file

@ -40,8 +40,9 @@ var _ = Describe("Client Session Cache", func() {
RootCAs: testdata.GetRootCA(),
ClientSessionCache: &clientSessionCache{
wrapped: tls.NewLRUClientSessionCache(10),
getData: func() []byte { return []byte("session") },
setData: func(data []byte) bool {
getData: func(bool) []byte { return []byte("session") },
setData: func(data []byte, earlyData bool) bool {
Expect(earlyData).To(BeFalse()) // running on top of TCP, we can only test non-0-RTT here
restored <- data
return true
},

View file

@ -52,10 +52,20 @@ func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTi
}
}
func SetupConfigForClient(conf *QUICConfig, getDataForSessionState func() []byte, setDataFromSessionState func([]byte) bool) {
func SetupConfigForClient(
conf *QUICConfig,
getDataForSessionState func(earlyData bool) []byte,
setDataFromSessionState func(data []byte, earlyData bool) (allowEarlyData bool),
) {
conf.ExtraConfig = &qtls.ExtraConfig{
GetAppDataForSessionState: getDataForSessionState,
SetAppDataFromSessionState: setDataFromSessionState,
GetAppDataForSessionState: func() []byte {
// qtls only calls the GetAppDataForSessionState when doing 0-RTT
return getDataForSessionState(true)
},
SetAppDataFromSessionState: func(data []byte) (allowEarlyData bool) {
// qtls only calls the SetAppDataFromSessionState for 0-RTT enabled tickets
return setDataFromSessionState(data, true)
},
}
}

View file

@ -93,7 +93,11 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, hand
}
}
func SetupConfigForClient(qconf *QUICConfig, getData func() []byte, setData func([]byte) bool) {
func SetupConfigForClient(
qconf *QUICConfig,
getData func(earlyData bool) []byte,
setData func(data []byte, earlyData bool) (allowEarlyData bool),
) {
conf := qconf.TLSConfig
if conf.ClientSessionCache != nil {
origCache := conf.ClientSessionCache