save the RTT in non-0-RTT session tickets (#4042)

* also send session ticket when 0-RTT is disabled for go1.21

* allow session ticket without transport parameters

* do not include transport parameters for non-0RTT session ticket

* remove the test assertion because it is not supported for go1.20

* Update internal/handshake/session_ticket.go

Co-authored-by: Marten Seemann <martenseemann@gmail.com>

* add a 0-RTT argument to unmarshaling session tickets

* bump sessionTicketRevision to 4

* check if non-0-RTT session ticket has expected length

* change parameter order

* add test checks

---------

Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
Ameagari 2023-09-11 23:05:31 +08:00 committed by GitHub
parent 1f25153884
commit d1f6ea997c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 95 additions and 42 deletions

View file

@ -127,7 +127,7 @@ func NewCryptoSetupServer(
cs.allow0RTT = allow0RTT
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT)
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket)
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
cs.tlsConf = quicConf.TLSConfig
@ -347,10 +347,13 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo
}
func (h *cryptoSetup) getDataForSessionTicket() []byte {
return (&sessionTicket{
Parameters: h.ourParams,
RTT: h.rttStats.SmoothedRTT(),
}).Marshal()
ticket := &sessionTicket{
RTT: h.rttStats.SmoothedRTT(),
}
if h.allow0RTT {
ticket.Parameters = h.ourParams
}
return ticket.Marshal()
}
// GetSessionTicket generates a new session ticket.
@ -379,12 +382,16 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
return ticket, nil
}
// accept0RTT is called for the server when receiving the client's session ticket.
// It decides whether to accept 0-RTT.
func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
// handleSessionTicket is called for the server when receiving the client's session ticket.
// It reads parameters from the session ticket and decides whether to accept 0-RTT when the session ticket is used for 0-RTT.
func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool {
var t sessionTicket
if err := t.Unmarshal(sessionTicketData); err != nil {
h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error())
if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil {
h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error())
return false
}
h.rttStats.SetInitialRTT(t.RTT)
if !using0RTT {
return false
}
valid := h.ourParams.ValidFor0RTT(t.Parameters)
@ -397,7 +404,6 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
return false
}
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
h.rttStats.SetInitialRTT(t.RTT)
return true
}

View file

@ -8,6 +8,8 @@ import (
"crypto/x509/pkix"
"math/big"
"net"
"runtime"
"strings"
"time"
mocktls "github.com/quic-go/quic-go/internal/mocks/tls"
@ -417,11 +419,13 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, &utils.RTTStats{},
clientOrigRTTStats, serverOrigRTTStats,
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
false,
)
@ -434,9 +438,10 @@ var _ = Describe("Crypto Setup TLS", func() {
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &utils.RTTStats{}
serverRTTStats := &utils.RTTStats{}
client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, &utils.RTTStats{},
clientRTTStats, serverRTTStats,
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
false,
)
@ -446,6 +451,9 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
if !strings.Contains(runtime.Version(), "go1.20") {
Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
}
})
It("doesn't use session resumption if the server disabled it", func() {

View file

@ -10,7 +10,7 @@ import (
"github.com/quic-go/quic-go/quicvarint"
)
const sessionTicketRevision = 3
const sessionTicketRevision = 4
type sessionTicket struct {
Parameters *wire.TransportParameters
@ -21,10 +21,13 @@ func (t *sessionTicket) Marshal() []byte {
b := make([]byte, 0, 256)
b = quicvarint.Append(b, sessionTicketRevision)
b = quicvarint.Append(b, uint64(t.RTT.Microseconds()))
if t.Parameters == nil {
return b
}
return t.Parameters.MarshalForSessionTicket(b)
}
func (t *sessionTicket) Unmarshal(b []byte) error {
func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
r := bytes.NewReader(b)
rev, err := quicvarint.Read(r)
if err != nil {
@ -37,11 +40,15 @@ func (t *sessionTicket) Unmarshal(b []byte) error {
if err != nil {
return errors.New("failed to read RTT")
}
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
if using0RTT {
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
}
t.Parameters = &tp
} else if r.Len() > 0 {
return fmt.Errorf("the session ticket has more bytes than expected")
}
t.Parameters = &tp
t.RTT = time.Duration(rtt) * time.Microsecond
return nil
}

View file

@ -11,7 +11,7 @@ import (
)
var _ = Describe("Session Ticket", func() {
It("marshals and unmarshals a session ticket", func() {
It("marshals and unmarshals a 0-RTT session ticket", func() {
ticket := &sessionTicket{
Parameters: &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: 1,
@ -22,33 +22,65 @@ var _ = Describe("Session Ticket", func() {
RTT: 1337 * time.Microsecond,
}
var t sessionTicket
Expect(t.Unmarshal(ticket.Marshal())).To(Succeed())
Expect(t.Unmarshal(ticket.Marshal(), true)).To(Succeed())
Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1))
Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2))
Expect(t.Parameters.ActiveConnectionIDLimit).To(BeEquivalentTo(10))
Expect(t.Parameters.MaxDatagramFrameSize).To(BeEquivalentTo(20))
Expect(t.RTT).To(Equal(1337 * time.Microsecond))
// fails to unmarshal the ticket as a non-0-RTT ticket
Expect(t.Unmarshal(ticket.Marshal(), false)).To(MatchError("the session ticket has more bytes than expected"))
})
It("marshals and unmarshals a non-0-RTT session ticket", func() {
ticket := &sessionTicket{
RTT: 1337 * time.Microsecond,
}
var t sessionTicket
Expect(t.Unmarshal(ticket.Marshal(), false)).To(Succeed())
Expect(t.Parameters).To(BeNil())
Expect(t.RTT).To(Equal(1337 * time.Microsecond))
// fails to unmarshal the ticket as a 0-RTT ticket
Expect(t.Unmarshal(ticket.Marshal(), true)).To(MatchError(ContainSubstring("unmarshaling transport parameters from session ticket failed")))
})
It("refuses to unmarshal if the ticket is too short for the revision", func() {
Expect((&sessionTicket{}).Unmarshal([]byte{})).To(MatchError("failed to read session ticket revision"))
Expect((&sessionTicket{}).Unmarshal([]byte{}, true)).To(MatchError("failed to read session ticket revision"))
Expect((&sessionTicket{}).Unmarshal([]byte{}, false)).To(MatchError("failed to read session ticket revision"))
})
It("refuses to unmarshal if the revision doesn't match", func() {
b := quicvarint.Append(nil, 1337)
Expect((&sessionTicket{}).Unmarshal(b)).To(MatchError("unknown session ticket revision: 1337"))
Expect((&sessionTicket{}).Unmarshal(b, true)).To(MatchError("unknown session ticket revision: 1337"))
Expect((&sessionTicket{}).Unmarshal(b, false)).To(MatchError("unknown session ticket revision: 1337"))
})
It("refuses to unmarshal if the RTT cannot be read", func() {
b := quicvarint.Append(nil, sessionTicketRevision)
Expect((&sessionTicket{}).Unmarshal(b)).To(MatchError("failed to read RTT"))
Expect((&sessionTicket{}).Unmarshal(b, true)).To(MatchError("failed to read RTT"))
Expect((&sessionTicket{}).Unmarshal(b, false)).To(MatchError("failed to read RTT"))
})
It("refuses to unmarshal if unmarshaling the transport parameters fails", func() {
It("refuses to unmarshal a 0-RTT session ticket if unmarshaling the transport parameters fails", func() {
b := quicvarint.Append(nil, sessionTicketRevision)
b = append(b, []byte("foobar")...)
err := (&sessionTicket{}).Unmarshal(b)
err := (&sessionTicket{}).Unmarshal(b, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("unmarshaling transport parameters from session ticket failed"))
})
It("refuses to unmarshal if the non-0-RTT session ticket has more bytes than expected", func() {
ticket := &sessionTicket{
Parameters: &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: 1,
InitialMaxStreamDataBidiRemote: 2,
ActiveConnectionIDLimit: 10,
MaxDatagramFrameSize: 20,
},
RTT: 1234 * time.Microsecond,
}
err := (&sessionTicket{}).Unmarshal(ticket.Marshal(), false)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("the session ticket has more bytes than expected"))
})
})

View file

@ -39,13 +39,15 @@ const (
QUICHandshakeDone = qtls.QUICHandshakeDone
)
func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, accept0RTT func([]byte) bool) {
func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, handleSessionTicket func([]byte, bool) bool) {
qtls.InitSessionTicketKeys(conf.TLSConfig)
conf.TLSConfig = conf.TLSConfig.Clone()
conf.TLSConfig.MinVersion = tls.VersionTLS13
conf.ExtraConfig = &qtls.ExtraConfig{
Enable0RTT: enable0RTT,
Accept0RTT: accept0RTT,
Enable0RTT: enable0RTT,
Accept0RTT: func(data []byte) bool {
return handleSessionTicket(data, true)
},
GetAppDataForSessionTicket: getDataForSessionTicket,
}
}

View file

@ -41,7 +41,7 @@ const (
func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) }
func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) }
func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, accept0RTT func([]byte) bool) {
func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
conf := qconf.TLSConfig
// Workaround for https://github.com/golang/go/issues/60506.
@ -55,11 +55,9 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce
// add callbacks to save transport parameters into the session ticket
origWrapSession := conf.WrapSession
conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) {
// Add QUIC transport parameters if this is a 0-RTT packet.
// TODO(#3853): also save the RTT for non-0-RTT tickets
if state.EarlyData {
state.Extra = append(state.Extra, addExtraPrefix(getData()))
}
// Add QUIC session ticket
state.Extra = append(state.Extra, addExtraPrefix(getData()))
if origWrapSession != nil {
return origWrapSession(cs, state)
}
@ -83,14 +81,14 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce
if err != nil || state == nil {
return nil, err
}
if state.EarlyData {
extra := findExtraData(state.Extra)
if unwrapCount == 1 && extra != nil { // first session ticket
state.EarlyData = accept0RTT(extra)
} else { // subsequent session ticket, can't be used for 0-RTT
state.EarlyData = false
}
extra := findExtraData(state.Extra)
if extra != nil {
state.EarlyData = handleSessionTicket(extra, state.EarlyData && unwrapCount == 1)
} else {
state.EarlyData = false
}
return state, nil
}
}