mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
save the RTT in non-0-RTT session tickets (#4042)
* also send session ticket when 0-RTT is disabled for go1.21 * allow session ticket without transport parameters * do not include transport parameters for non-0RTT session ticket * remove the test assertion because it is not supported for go1.20 * Update internal/handshake/session_ticket.go Co-authored-by: Marten Seemann <martenseemann@gmail.com> * add a 0-RTT argument to unmarshaling session tickets * bump sessionTicketRevision to 4 * check if non-0-RTT session ticket has expected length * change parameter order * add test checks --------- Co-authored-by: Marten Seemann <martenseemann@gmail.com>
This commit is contained in:
parent
1f25153884
commit
d1f6ea997c
6 changed files with 95 additions and 42 deletions
|
@ -127,7 +127,7 @@ func NewCryptoSetupServer(
|
||||||
cs.allow0RTT = allow0RTT
|
cs.allow0RTT = allow0RTT
|
||||||
|
|
||||||
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
|
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
|
||||||
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT)
|
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket)
|
||||||
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
|
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
|
||||||
|
|
||||||
cs.tlsConf = quicConf.TLSConfig
|
cs.tlsConf = quicConf.TLSConfig
|
||||||
|
@ -347,10 +347,13 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) getDataForSessionTicket() []byte {
|
func (h *cryptoSetup) getDataForSessionTicket() []byte {
|
||||||
return (&sessionTicket{
|
ticket := &sessionTicket{
|
||||||
Parameters: h.ourParams,
|
RTT: h.rttStats.SmoothedRTT(),
|
||||||
RTT: h.rttStats.SmoothedRTT(),
|
}
|
||||||
}).Marshal()
|
if h.allow0RTT {
|
||||||
|
ticket.Parameters = h.ourParams
|
||||||
|
}
|
||||||
|
return ticket.Marshal()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSessionTicket generates a new session ticket.
|
// GetSessionTicket generates a new session ticket.
|
||||||
|
@ -379,12 +382,16 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||||
return ticket, nil
|
return ticket, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// accept0RTT is called for the server when receiving the client's session ticket.
|
// handleSessionTicket is called for the server when receiving the client's session ticket.
|
||||||
// It decides whether to accept 0-RTT.
|
// It reads parameters from the session ticket and decides whether to accept 0-RTT when the session ticket is used for 0-RTT.
|
||||||
func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
|
func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool {
|
||||||
var t sessionTicket
|
var t sessionTicket
|
||||||
if err := t.Unmarshal(sessionTicketData); err != nil {
|
if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil {
|
||||||
h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error())
|
h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error())
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
h.rttStats.SetInitialRTT(t.RTT)
|
||||||
|
if !using0RTT {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
valid := h.ourParams.ValidFor0RTT(t.Parameters)
|
valid := h.ourParams.ValidFor0RTT(t.Parameters)
|
||||||
|
@ -397,7 +404,6 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
|
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
|
||||||
h.rttStats.SetInitialRTT(t.RTT)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,8 @@ import (
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
mocktls "github.com/quic-go/quic-go/internal/mocks/tls"
|
mocktls "github.com/quic-go/quic-go/internal/mocks/tls"
|
||||||
|
@ -417,11 +419,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
close(receivedSessionTicket)
|
close(receivedSessionTicket)
|
||||||
})
|
})
|
||||||
clientConf.ClientSessionCache = csc
|
clientConf.ClientSessionCache = csc
|
||||||
|
const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
|
||||||
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
|
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
|
||||||
|
serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
|
||||||
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
|
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
|
||||||
client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
|
client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
|
||||||
clientConf, serverConf,
|
clientConf, serverConf,
|
||||||
clientOrigRTTStats, &utils.RTTStats{},
|
clientOrigRTTStats, serverOrigRTTStats,
|
||||||
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
|
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
|
@ -434,9 +438,10 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
csc.EXPECT().Get(gomock.Any()).Return(state, true)
|
csc.EXPECT().Get(gomock.Any()).Return(state, true)
|
||||||
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
|
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||||
clientRTTStats := &utils.RTTStats{}
|
clientRTTStats := &utils.RTTStats{}
|
||||||
|
serverRTTStats := &utils.RTTStats{}
|
||||||
client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
|
client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
|
||||||
clientConf, serverConf,
|
clientConf, serverConf,
|
||||||
clientRTTStats, &utils.RTTStats{},
|
clientRTTStats, serverRTTStats,
|
||||||
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
|
&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
|
@ -446,6 +451,9 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
Expect(server.ConnectionState().DidResume).To(BeTrue())
|
Expect(server.ConnectionState().DidResume).To(BeTrue())
|
||||||
Expect(client.ConnectionState().DidResume).To(BeTrue())
|
Expect(client.ConnectionState().DidResume).To(BeTrue())
|
||||||
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
|
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
|
||||||
|
if !strings.Contains(runtime.Version(), "go1.20") {
|
||||||
|
Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't use session resumption if the server disabled it", func() {
|
It("doesn't use session resumption if the server disabled it", func() {
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"github.com/quic-go/quic-go/quicvarint"
|
"github.com/quic-go/quic-go/quicvarint"
|
||||||
)
|
)
|
||||||
|
|
||||||
const sessionTicketRevision = 3
|
const sessionTicketRevision = 4
|
||||||
|
|
||||||
type sessionTicket struct {
|
type sessionTicket struct {
|
||||||
Parameters *wire.TransportParameters
|
Parameters *wire.TransportParameters
|
||||||
|
@ -21,10 +21,13 @@ func (t *sessionTicket) Marshal() []byte {
|
||||||
b := make([]byte, 0, 256)
|
b := make([]byte, 0, 256)
|
||||||
b = quicvarint.Append(b, sessionTicketRevision)
|
b = quicvarint.Append(b, sessionTicketRevision)
|
||||||
b = quicvarint.Append(b, uint64(t.RTT.Microseconds()))
|
b = quicvarint.Append(b, uint64(t.RTT.Microseconds()))
|
||||||
|
if t.Parameters == nil {
|
||||||
|
return b
|
||||||
|
}
|
||||||
return t.Parameters.MarshalForSessionTicket(b)
|
return t.Parameters.MarshalForSessionTicket(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *sessionTicket) Unmarshal(b []byte) error {
|
func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
|
||||||
r := bytes.NewReader(b)
|
r := bytes.NewReader(b)
|
||||||
rev, err := quicvarint.Read(r)
|
rev, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -37,11 +40,15 @@ func (t *sessionTicket) Unmarshal(b []byte) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("failed to read RTT")
|
return errors.New("failed to read RTT")
|
||||||
}
|
}
|
||||||
var tp wire.TransportParameters
|
if using0RTT {
|
||||||
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
|
var tp wire.TransportParameters
|
||||||
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
|
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
|
||||||
|
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
|
||||||
|
}
|
||||||
|
t.Parameters = &tp
|
||||||
|
} else if r.Len() > 0 {
|
||||||
|
return fmt.Errorf("the session ticket has more bytes than expected")
|
||||||
}
|
}
|
||||||
t.Parameters = &tp
|
|
||||||
t.RTT = time.Duration(rtt) * time.Microsecond
|
t.RTT = time.Duration(rtt) * time.Microsecond
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("Session Ticket", func() {
|
var _ = Describe("Session Ticket", func() {
|
||||||
It("marshals and unmarshals a session ticket", func() {
|
It("marshals and unmarshals a 0-RTT session ticket", func() {
|
||||||
ticket := &sessionTicket{
|
ticket := &sessionTicket{
|
||||||
Parameters: &wire.TransportParameters{
|
Parameters: &wire.TransportParameters{
|
||||||
InitialMaxStreamDataBidiLocal: 1,
|
InitialMaxStreamDataBidiLocal: 1,
|
||||||
|
@ -22,33 +22,65 @@ var _ = Describe("Session Ticket", func() {
|
||||||
RTT: 1337 * time.Microsecond,
|
RTT: 1337 * time.Microsecond,
|
||||||
}
|
}
|
||||||
var t sessionTicket
|
var t sessionTicket
|
||||||
Expect(t.Unmarshal(ticket.Marshal())).To(Succeed())
|
Expect(t.Unmarshal(ticket.Marshal(), true)).To(Succeed())
|
||||||
Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1))
|
Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1))
|
||||||
Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2))
|
Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2))
|
||||||
Expect(t.Parameters.ActiveConnectionIDLimit).To(BeEquivalentTo(10))
|
Expect(t.Parameters.ActiveConnectionIDLimit).To(BeEquivalentTo(10))
|
||||||
Expect(t.Parameters.MaxDatagramFrameSize).To(BeEquivalentTo(20))
|
Expect(t.Parameters.MaxDatagramFrameSize).To(BeEquivalentTo(20))
|
||||||
Expect(t.RTT).To(Equal(1337 * time.Microsecond))
|
Expect(t.RTT).To(Equal(1337 * time.Microsecond))
|
||||||
|
// fails to unmarshal the ticket as a non-0-RTT ticket
|
||||||
|
Expect(t.Unmarshal(ticket.Marshal(), false)).To(MatchError("the session ticket has more bytes than expected"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("marshals and unmarshals a non-0-RTT session ticket", func() {
|
||||||
|
ticket := &sessionTicket{
|
||||||
|
RTT: 1337 * time.Microsecond,
|
||||||
|
}
|
||||||
|
var t sessionTicket
|
||||||
|
Expect(t.Unmarshal(ticket.Marshal(), false)).To(Succeed())
|
||||||
|
Expect(t.Parameters).To(BeNil())
|
||||||
|
Expect(t.RTT).To(Equal(1337 * time.Microsecond))
|
||||||
|
// fails to unmarshal the ticket as a 0-RTT ticket
|
||||||
|
Expect(t.Unmarshal(ticket.Marshal(), true)).To(MatchError(ContainSubstring("unmarshaling transport parameters from session ticket failed")))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("refuses to unmarshal if the ticket is too short for the revision", func() {
|
It("refuses to unmarshal if the ticket is too short for the revision", func() {
|
||||||
Expect((&sessionTicket{}).Unmarshal([]byte{})).To(MatchError("failed to read session ticket revision"))
|
Expect((&sessionTicket{}).Unmarshal([]byte{}, true)).To(MatchError("failed to read session ticket revision"))
|
||||||
|
Expect((&sessionTicket{}).Unmarshal([]byte{}, false)).To(MatchError("failed to read session ticket revision"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("refuses to unmarshal if the revision doesn't match", func() {
|
It("refuses to unmarshal if the revision doesn't match", func() {
|
||||||
b := quicvarint.Append(nil, 1337)
|
b := quicvarint.Append(nil, 1337)
|
||||||
Expect((&sessionTicket{}).Unmarshal(b)).To(MatchError("unknown session ticket revision: 1337"))
|
Expect((&sessionTicket{}).Unmarshal(b, true)).To(MatchError("unknown session ticket revision: 1337"))
|
||||||
|
Expect((&sessionTicket{}).Unmarshal(b, false)).To(MatchError("unknown session ticket revision: 1337"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("refuses to unmarshal if the RTT cannot be read", func() {
|
It("refuses to unmarshal if the RTT cannot be read", func() {
|
||||||
b := quicvarint.Append(nil, sessionTicketRevision)
|
b := quicvarint.Append(nil, sessionTicketRevision)
|
||||||
Expect((&sessionTicket{}).Unmarshal(b)).To(MatchError("failed to read RTT"))
|
Expect((&sessionTicket{}).Unmarshal(b, true)).To(MatchError("failed to read RTT"))
|
||||||
|
Expect((&sessionTicket{}).Unmarshal(b, false)).To(MatchError("failed to read RTT"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("refuses to unmarshal if unmarshaling the transport parameters fails", func() {
|
It("refuses to unmarshal a 0-RTT session ticket if unmarshaling the transport parameters fails", func() {
|
||||||
b := quicvarint.Append(nil, sessionTicketRevision)
|
b := quicvarint.Append(nil, sessionTicketRevision)
|
||||||
b = append(b, []byte("foobar")...)
|
b = append(b, []byte("foobar")...)
|
||||||
err := (&sessionTicket{}).Unmarshal(b)
|
err := (&sessionTicket{}).Unmarshal(b, true)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
Expect(err.Error()).To(ContainSubstring("unmarshaling transport parameters from session ticket failed"))
|
Expect(err.Error()).To(ContainSubstring("unmarshaling transport parameters from session ticket failed"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("refuses to unmarshal if the non-0-RTT session ticket has more bytes than expected", func() {
|
||||||
|
ticket := &sessionTicket{
|
||||||
|
Parameters: &wire.TransportParameters{
|
||||||
|
InitialMaxStreamDataBidiLocal: 1,
|
||||||
|
InitialMaxStreamDataBidiRemote: 2,
|
||||||
|
ActiveConnectionIDLimit: 10,
|
||||||
|
MaxDatagramFrameSize: 20,
|
||||||
|
},
|
||||||
|
RTT: 1234 * time.Microsecond,
|
||||||
|
}
|
||||||
|
err := (&sessionTicket{}).Unmarshal(ticket.Marshal(), false)
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring("the session ticket has more bytes than expected"))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -39,13 +39,15 @@ const (
|
||||||
QUICHandshakeDone = qtls.QUICHandshakeDone
|
QUICHandshakeDone = qtls.QUICHandshakeDone
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, accept0RTT func([]byte) bool) {
|
func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, handleSessionTicket func([]byte, bool) bool) {
|
||||||
qtls.InitSessionTicketKeys(conf.TLSConfig)
|
qtls.InitSessionTicketKeys(conf.TLSConfig)
|
||||||
conf.TLSConfig = conf.TLSConfig.Clone()
|
conf.TLSConfig = conf.TLSConfig.Clone()
|
||||||
conf.TLSConfig.MinVersion = tls.VersionTLS13
|
conf.TLSConfig.MinVersion = tls.VersionTLS13
|
||||||
conf.ExtraConfig = &qtls.ExtraConfig{
|
conf.ExtraConfig = &qtls.ExtraConfig{
|
||||||
Enable0RTT: enable0RTT,
|
Enable0RTT: enable0RTT,
|
||||||
Accept0RTT: accept0RTT,
|
Accept0RTT: func(data []byte) bool {
|
||||||
|
return handleSessionTicket(data, true)
|
||||||
|
},
|
||||||
GetAppDataForSessionTicket: getDataForSessionTicket,
|
GetAppDataForSessionTicket: getDataForSessionTicket,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,7 +41,7 @@ const (
|
||||||
func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) }
|
func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) }
|
||||||
func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) }
|
func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) }
|
||||||
|
|
||||||
func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, accept0RTT func([]byte) bool) {
|
func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
|
||||||
conf := qconf.TLSConfig
|
conf := qconf.TLSConfig
|
||||||
|
|
||||||
// Workaround for https://github.com/golang/go/issues/60506.
|
// Workaround for https://github.com/golang/go/issues/60506.
|
||||||
|
@ -55,11 +55,9 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce
|
||||||
// add callbacks to save transport parameters into the session ticket
|
// add callbacks to save transport parameters into the session ticket
|
||||||
origWrapSession := conf.WrapSession
|
origWrapSession := conf.WrapSession
|
||||||
conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) {
|
conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) {
|
||||||
// Add QUIC transport parameters if this is a 0-RTT packet.
|
// Add QUIC session ticket
|
||||||
// TODO(#3853): also save the RTT for non-0-RTT tickets
|
state.Extra = append(state.Extra, addExtraPrefix(getData()))
|
||||||
if state.EarlyData {
|
|
||||||
state.Extra = append(state.Extra, addExtraPrefix(getData()))
|
|
||||||
}
|
|
||||||
if origWrapSession != nil {
|
if origWrapSession != nil {
|
||||||
return origWrapSession(cs, state)
|
return origWrapSession(cs, state)
|
||||||
}
|
}
|
||||||
|
@ -83,14 +81,14 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce
|
||||||
if err != nil || state == nil {
|
if err != nil || state == nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if state.EarlyData {
|
|
||||||
extra := findExtraData(state.Extra)
|
extra := findExtraData(state.Extra)
|
||||||
if unwrapCount == 1 && extra != nil { // first session ticket
|
if extra != nil {
|
||||||
state.EarlyData = accept0RTT(extra)
|
state.EarlyData = handleSessionTicket(extra, state.EarlyData && unwrapCount == 1)
|
||||||
} else { // subsequent session ticket, can't be used for 0-RTT
|
} else {
|
||||||
state.EarlyData = false
|
state.EarlyData = false
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue