save the RTT in non-0-RTT session tickets (#4042)

* also send session ticket when 0-RTT is disabled for go1.21

* allow session ticket without transport parameters

* do not include transport parameters for non-0RTT session ticket

* remove the test assertion because it is not supported for go1.20

* Update internal/handshake/session_ticket.go

Co-authored-by: Marten Seemann <martenseemann@gmail.com>

* add a 0-RTT argument to unmarshaling session tickets

* bump sessionTicketRevision to 4

* check if non-0-RTT session ticket has expected length

* change parameter order

* add test checks

---------

Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
Ameagari 2023-09-11 23:05:31 +08:00 committed by GitHub
parent 1f25153884
commit d1f6ea997c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 95 additions and 42 deletions

View file

@ -127,7 +127,7 @@ func NewCryptoSetupServer(
cs.allow0RTT = allow0RTT cs.allow0RTT = allow0RTT
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT) qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket)
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr) addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
cs.tlsConf = quicConf.TLSConfig cs.tlsConf = quicConf.TLSConfig
@ -347,10 +347,13 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo
} }
func (h *cryptoSetup) getDataForSessionTicket() []byte { func (h *cryptoSetup) getDataForSessionTicket() []byte {
return (&sessionTicket{ ticket := &sessionTicket{
Parameters: h.ourParams, RTT: h.rttStats.SmoothedRTT(),
RTT: h.rttStats.SmoothedRTT(), }
}).Marshal() if h.allow0RTT {
ticket.Parameters = h.ourParams
}
return ticket.Marshal()
} }
// GetSessionTicket generates a new session ticket. // GetSessionTicket generates a new session ticket.
@ -379,12 +382,16 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
return ticket, nil return ticket, nil
} }
// accept0RTT is called for the server when receiving the client's session ticket. // handleSessionTicket is called for the server when receiving the client's session ticket.
// It decides whether to accept 0-RTT. // It reads parameters from the session ticket and decides whether to accept 0-RTT when the session ticket is used for 0-RTT.
func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool {
var t sessionTicket var t sessionTicket
if err := t.Unmarshal(sessionTicketData); err != nil { if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil {
h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error()) h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error())
return false
}
h.rttStats.SetInitialRTT(t.RTT)
if !using0RTT {
return false return false
} }
valid := h.ourParams.ValidFor0RTT(t.Parameters) valid := h.ourParams.ValidFor0RTT(t.Parameters)
@ -397,7 +404,6 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
return false return false
} }
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
h.rttStats.SetInitialRTT(t.RTT)
return true return true
} }

View file

@ -8,6 +8,8 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"math/big" "math/big"
"net" "net"
"runtime"
"strings"
"time" "time"
mocktls "github.com/quic-go/quic-go/internal/mocks/tls" mocktls "github.com/quic-go/quic-go/internal/mocks/tls"
@ -417,11 +419,13 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket) close(receivedSessionTicket)
}) })
clientConf.ClientSessionCache = csc clientConf.ClientSessionCache = csc
const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf, clientConf, serverConf,
clientOrigRTTStats, &utils.RTTStats{}, clientOrigRTTStats, serverOrigRTTStats,
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
false, false,
) )
@ -434,9 +438,10 @@ var _ = Describe("Crypto Setup TLS", func() {
csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &utils.RTTStats{} clientRTTStats := &utils.RTTStats{}
serverRTTStats := &utils.RTTStats{}
client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
clientConf, serverConf, clientConf, serverConf,
clientRTTStats, &utils.RTTStats{}, clientRTTStats, serverRTTStats,
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
false, false,
) )
@ -446,6 +451,9 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(server.ConnectionState().DidResume).To(BeTrue()) Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue()) Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
if !strings.Contains(runtime.Version(), "go1.20") {
Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
}
}) })
It("doesn't use session resumption if the server disabled it", func() { It("doesn't use session resumption if the server disabled it", func() {

View file

@ -10,7 +10,7 @@ import (
"github.com/quic-go/quic-go/quicvarint" "github.com/quic-go/quic-go/quicvarint"
) )
const sessionTicketRevision = 3 const sessionTicketRevision = 4
type sessionTicket struct { type sessionTicket struct {
Parameters *wire.TransportParameters Parameters *wire.TransportParameters
@ -21,10 +21,13 @@ func (t *sessionTicket) Marshal() []byte {
b := make([]byte, 0, 256) b := make([]byte, 0, 256)
b = quicvarint.Append(b, sessionTicketRevision) b = quicvarint.Append(b, sessionTicketRevision)
b = quicvarint.Append(b, uint64(t.RTT.Microseconds())) b = quicvarint.Append(b, uint64(t.RTT.Microseconds()))
if t.Parameters == nil {
return b
}
return t.Parameters.MarshalForSessionTicket(b) return t.Parameters.MarshalForSessionTicket(b)
} }
func (t *sessionTicket) Unmarshal(b []byte) error { func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
r := bytes.NewReader(b) r := bytes.NewReader(b)
rev, err := quicvarint.Read(r) rev, err := quicvarint.Read(r)
if err != nil { if err != nil {
@ -37,11 +40,15 @@ func (t *sessionTicket) Unmarshal(b []byte) error {
if err != nil { if err != nil {
return errors.New("failed to read RTT") return errors.New("failed to read RTT")
} }
var tp wire.TransportParameters if using0RTT {
if err := tp.UnmarshalFromSessionTicket(r); err != nil { var tp wire.TransportParameters
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
}
t.Parameters = &tp
} else if r.Len() > 0 {
return fmt.Errorf("the session ticket has more bytes than expected")
} }
t.Parameters = &tp
t.RTT = time.Duration(rtt) * time.Microsecond t.RTT = time.Duration(rtt) * time.Microsecond
return nil return nil
} }

View file

@ -11,7 +11,7 @@ import (
) )
var _ = Describe("Session Ticket", func() { var _ = Describe("Session Ticket", func() {
It("marshals and unmarshals a session ticket", func() { It("marshals and unmarshals a 0-RTT session ticket", func() {
ticket := &sessionTicket{ ticket := &sessionTicket{
Parameters: &wire.TransportParameters{ Parameters: &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: 1, InitialMaxStreamDataBidiLocal: 1,
@ -22,33 +22,65 @@ var _ = Describe("Session Ticket", func() {
RTT: 1337 * time.Microsecond, RTT: 1337 * time.Microsecond,
} }
var t sessionTicket var t sessionTicket
Expect(t.Unmarshal(ticket.Marshal())).To(Succeed()) Expect(t.Unmarshal(ticket.Marshal(), true)).To(Succeed())
Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1)) Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1))
Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2)) Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2))
Expect(t.Parameters.ActiveConnectionIDLimit).To(BeEquivalentTo(10)) Expect(t.Parameters.ActiveConnectionIDLimit).To(BeEquivalentTo(10))
Expect(t.Parameters.MaxDatagramFrameSize).To(BeEquivalentTo(20)) Expect(t.Parameters.MaxDatagramFrameSize).To(BeEquivalentTo(20))
Expect(t.RTT).To(Equal(1337 * time.Microsecond)) Expect(t.RTT).To(Equal(1337 * time.Microsecond))
// fails to unmarshal the ticket as a non-0-RTT ticket
Expect(t.Unmarshal(ticket.Marshal(), false)).To(MatchError("the session ticket has more bytes than expected"))
})
It("marshals and unmarshals a non-0-RTT session ticket", func() {
ticket := &sessionTicket{
RTT: 1337 * time.Microsecond,
}
var t sessionTicket
Expect(t.Unmarshal(ticket.Marshal(), false)).To(Succeed())
Expect(t.Parameters).To(BeNil())
Expect(t.RTT).To(Equal(1337 * time.Microsecond))
// fails to unmarshal the ticket as a 0-RTT ticket
Expect(t.Unmarshal(ticket.Marshal(), true)).To(MatchError(ContainSubstring("unmarshaling transport parameters from session ticket failed")))
}) })
It("refuses to unmarshal if the ticket is too short for the revision", func() { It("refuses to unmarshal if the ticket is too short for the revision", func() {
Expect((&sessionTicket{}).Unmarshal([]byte{})).To(MatchError("failed to read session ticket revision")) Expect((&sessionTicket{}).Unmarshal([]byte{}, true)).To(MatchError("failed to read session ticket revision"))
Expect((&sessionTicket{}).Unmarshal([]byte{}, false)).To(MatchError("failed to read session ticket revision"))
}) })
It("refuses to unmarshal if the revision doesn't match", func() { It("refuses to unmarshal if the revision doesn't match", func() {
b := quicvarint.Append(nil, 1337) b := quicvarint.Append(nil, 1337)
Expect((&sessionTicket{}).Unmarshal(b)).To(MatchError("unknown session ticket revision: 1337")) Expect((&sessionTicket{}).Unmarshal(b, true)).To(MatchError("unknown session ticket revision: 1337"))
Expect((&sessionTicket{}).Unmarshal(b, false)).To(MatchError("unknown session ticket revision: 1337"))
}) })
It("refuses to unmarshal if the RTT cannot be read", func() { It("refuses to unmarshal if the RTT cannot be read", func() {
b := quicvarint.Append(nil, sessionTicketRevision) b := quicvarint.Append(nil, sessionTicketRevision)
Expect((&sessionTicket{}).Unmarshal(b)).To(MatchError("failed to read RTT")) Expect((&sessionTicket{}).Unmarshal(b, true)).To(MatchError("failed to read RTT"))
Expect((&sessionTicket{}).Unmarshal(b, false)).To(MatchError("failed to read RTT"))
}) })
It("refuses to unmarshal if unmarshaling the transport parameters fails", func() { It("refuses to unmarshal a 0-RTT session ticket if unmarshaling the transport parameters fails", func() {
b := quicvarint.Append(nil, sessionTicketRevision) b := quicvarint.Append(nil, sessionTicketRevision)
b = append(b, []byte("foobar")...) b = append(b, []byte("foobar")...)
err := (&sessionTicket{}).Unmarshal(b) err := (&sessionTicket{}).Unmarshal(b, true)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("unmarshaling transport parameters from session ticket failed")) Expect(err.Error()).To(ContainSubstring("unmarshaling transport parameters from session ticket failed"))
}) })
It("refuses to unmarshal if the non-0-RTT session ticket has more bytes than expected", func() {
ticket := &sessionTicket{
Parameters: &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: 1,
InitialMaxStreamDataBidiRemote: 2,
ActiveConnectionIDLimit: 10,
MaxDatagramFrameSize: 20,
},
RTT: 1234 * time.Microsecond,
}
err := (&sessionTicket{}).Unmarshal(ticket.Marshal(), false)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("the session ticket has more bytes than expected"))
})
}) })

View file

@ -39,13 +39,15 @@ const (
QUICHandshakeDone = qtls.QUICHandshakeDone QUICHandshakeDone = qtls.QUICHandshakeDone
) )
func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, accept0RTT func([]byte) bool) { func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, handleSessionTicket func([]byte, bool) bool) {
qtls.InitSessionTicketKeys(conf.TLSConfig) qtls.InitSessionTicketKeys(conf.TLSConfig)
conf.TLSConfig = conf.TLSConfig.Clone() conf.TLSConfig = conf.TLSConfig.Clone()
conf.TLSConfig.MinVersion = tls.VersionTLS13 conf.TLSConfig.MinVersion = tls.VersionTLS13
conf.ExtraConfig = &qtls.ExtraConfig{ conf.ExtraConfig = &qtls.ExtraConfig{
Enable0RTT: enable0RTT, Enable0RTT: enable0RTT,
Accept0RTT: accept0RTT, Accept0RTT: func(data []byte) bool {
return handleSessionTicket(data, true)
},
GetAppDataForSessionTicket: getDataForSessionTicket, GetAppDataForSessionTicket: getDataForSessionTicket,
} }
} }

View file

@ -41,7 +41,7 @@ const (
func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) } func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) }
func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) } func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) }
func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, accept0RTT func([]byte) bool) { func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
conf := qconf.TLSConfig conf := qconf.TLSConfig
// Workaround for https://github.com/golang/go/issues/60506. // Workaround for https://github.com/golang/go/issues/60506.
@ -55,11 +55,9 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce
// add callbacks to save transport parameters into the session ticket // add callbacks to save transport parameters into the session ticket
origWrapSession := conf.WrapSession origWrapSession := conf.WrapSession
conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) { conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) {
// Add QUIC transport parameters if this is a 0-RTT packet. // Add QUIC session ticket
// TODO(#3853): also save the RTT for non-0-RTT tickets state.Extra = append(state.Extra, addExtraPrefix(getData()))
if state.EarlyData {
state.Extra = append(state.Extra, addExtraPrefix(getData()))
}
if origWrapSession != nil { if origWrapSession != nil {
return origWrapSession(cs, state) return origWrapSession(cs, state)
} }
@ -83,14 +81,14 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce
if err != nil || state == nil { if err != nil || state == nil {
return nil, err return nil, err
} }
if state.EarlyData {
extra := findExtraData(state.Extra) extra := findExtraData(state.Extra)
if unwrapCount == 1 && extra != nil { // first session ticket if extra != nil {
state.EarlyData = accept0RTT(extra) state.EarlyData = handleSessionTicket(extra, state.EarlyData && unwrapCount == 1)
} else { // subsequent session ticket, can't be used for 0-RTT } else {
state.EarlyData = false state.EarlyData = false
}
} }
return state, nil return state, nil
} }
} }