Merge pull request #3636 from lucas-clemente/early-conn

make ConnectionState usable during the handshake
This commit is contained in:
Marten Seemann 2023-01-17 22:29:08 -08:00 committed by GitHub
commit 4d9ab7b604
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 139 additions and 63 deletions

View file

@ -218,6 +218,9 @@ type connection struct {
datagramQueue *datagramQueue datagramQueue *datagramQueue
connStateMutex sync.Mutex
connState ConnectionState
logID string logID string
tracer logging.ConnectionTracer tracer logging.ConnectionTracer
logger utils.Logger logger utils.Logger
@ -545,6 +548,7 @@ func (s *connection) preSetup() {
s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame) s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame)
s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger)
s.connState.Version = s.version
} }
// run the connection main loop // run the connection main loop
@ -738,11 +742,10 @@ func (s *connection) supportsDatagrams() bool {
} }
func (s *connection) ConnectionState() ConnectionState { func (s *connection) ConnectionState() ConnectionState {
return ConnectionState{ s.connStateMutex.Lock()
TLS: s.cryptoStreamHandler.ConnectionState(), defer s.connStateMutex.Unlock()
SupportsDatagrams: s.supportsDatagrams(), s.connState.TLS = s.cryptoStreamHandler.ConnectionState()
Version: s.version, return s.connState
}
} }
// Time when the next keep-alive packet should be sent. // Time when the next keep-alive packet should be sent.
@ -1678,6 +1681,9 @@ func (s *connection) restoreTransportParameters(params *wire.TransportParameters
s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit) s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit)
s.connFlowController.UpdateSendWindow(params.InitialMaxData) s.connFlowController.UpdateSendWindow(params.InitialMaxData)
s.streamsMap.UpdateLimits(params) s.streamsMap.UpdateLimits(params)
s.connStateMutex.Lock()
s.connState.SupportsDatagrams = s.supportsDatagrams()
s.connStateMutex.Unlock()
} }
func (s *connection) handleTransportParameters(params *wire.TransportParameters) { func (s *connection) handleTransportParameters(params *wire.TransportParameters) {
@ -1696,6 +1702,10 @@ func (s *connection) handleTransportParameters(params *wire.TransportParameters)
// the client's transport parameters. // the client's transport parameters.
close(s.earlyConnReadyChan) close(s.earlyConnReadyChan)
} }
s.connStateMutex.Lock()
s.connState.SupportsDatagrams = s.supportsDatagrams()
s.connStateMutex.Unlock()
} }
func (s *connection) checkTransportParameters(params *wire.TransportParameters) error { func (s *connection) checkTransportParameters(params *wire.TransportParameters) error {

4
go.mod
View file

@ -6,8 +6,8 @@ require (
github.com/francoispqt/gojay v1.2.13 github.com/francoispqt/gojay v1.2.13
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/marten-seemann/qpack v0.3.0 github.com/marten-seemann/qpack v0.3.0
github.com/marten-seemann/qtls-go1-18 v0.1.3 github.com/marten-seemann/qtls-go1-18 v0.1.4
github.com/marten-seemann/qtls-go1-19 v0.1.1 github.com/marten-seemann/qtls-go1-19 v0.1.2
github.com/onsi/ginkgo/v2 v2.2.0 github.com/onsi/ginkgo/v2 v2.2.0
github.com/onsi/gomega v1.20.1 github.com/onsi/gomega v1.20.1
golang.org/x/crypto v0.4.0 golang.org/x/crypto v0.4.0

8
go.sum
View file

@ -70,10 +70,10 @@ github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm
github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/marten-seemann/qpack v0.3.0 h1:UiWstOgT8+znlkDPOg2+3rIuYXJ2CnGDkGUXN6ki6hE= github.com/marten-seemann/qpack v0.3.0 h1:UiWstOgT8+znlkDPOg2+3rIuYXJ2CnGDkGUXN6ki6hE=
github.com/marten-seemann/qpack v0.3.0/go.mod h1:cGfKPBiP4a9EQdxCwEwI/GEeWAsjSekBvx/X8mh58+g= github.com/marten-seemann/qpack v0.3.0/go.mod h1:cGfKPBiP4a9EQdxCwEwI/GEeWAsjSekBvx/X8mh58+g=
github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= github.com/marten-seemann/qtls-go1-18 v0.1.4 h1:ogomB+lWV3Vmwiu6RTwDVTMGx+9j7SEi98e8QB35Its=
github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= github.com/marten-seemann/qtls-go1-18 v0.1.4/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4=
github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sNlqWoDZnjRIE= github.com/marten-seemann/qtls-go1-19 v0.1.2 h1:ZevAEqKXH0bZmoOBPiqX2h5rhQ7cbZi+X+rlq2JUbCE=
github.com/marten-seemann/qtls-go1-19 v0.1.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= github.com/marten-seemann/qtls-go1-19 v0.1.2/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

View file

@ -1,7 +1,6 @@
package http3 package http3
import ( import (
"crypto/tls"
"errors" "errors"
"net/http" "net/http"
"net/url" "net/url"
@ -101,7 +100,6 @@ func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) {
ContentLength: contentLength, ContentLength: contentLength,
Host: authority, Host: authority,
RequestURI: requestURI, RequestURI: requestURI,
TLS: &tls.ConnectionState{},
}, nil }, nil
} }

View file

@ -30,7 +30,6 @@ var _ = Describe("Request", func() {
Expect(req.Body).To(BeNil()) Expect(req.Body).To(BeNil())
Expect(req.Host).To(Equal("quic.clemente.io")) Expect(req.Host).To(Equal("quic.clemente.io"))
Expect(req.RequestURI).To(Equal("/foo")) Expect(req.RequestURI).To(Equal("/foo"))
Expect(req.TLS).ToNot(BeNil())
}) })
It("parses path with leading double slashes", func() { It("parses path with leading double slashes", func() {

View file

@ -272,7 +272,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
baseConf := ConfigureTLSConfig(tlsConf) baseConf := ConfigureTLSConfig(tlsConf)
quicConf := s.QuicConfig quicConf := s.QuicConfig
if quicConf == nil { if quicConf == nil {
quicConf = &quic.Config{} quicConf = &quic.Config{Allow0RTT: func(net.Addr) bool { return true }}
} else { } else {
quicConf = s.QuicConfig.Clone() quicConf = s.QuicConfig.Clone()
} }
@ -570,6 +570,8 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q
return newStreamError(errorGeneralProtocolError, err) return newStreamError(errorGeneralProtocolError, err)
} }
connState := conn.ConnectionState().TLS.ConnectionState
req.TLS = &connState
req.RemoteAddr = conn.RemoteAddr().String() req.RemoteAddr = conn.RemoteAddr().String()
body := newRequestBody(newStream(str, onFrameError)) body := newRequestBody(newStream(str, onFrameError))
req.Body = body req.Body = body

View file

@ -163,6 +163,7 @@ var _ = Describe("Server", func() {
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
conn.EXPECT().LocalAddr().AnyTimes() conn.EXPECT().LocalAddr().AnyTimes()
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes()
}) })
It("calls the HTTP handler function", func() { It("calls the HTTP handler function", func() {
@ -632,6 +633,7 @@ var _ = Describe("Server", func() {
conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
conn.EXPECT().LocalAddr().AnyTimes() conn.EXPECT().LocalAddr().AnyTimes()
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes()
}) })
AfterEach(func() { testDone <- struct{}{} }) AfterEach(func() { testDone <- struct{}{} })

View file

@ -81,10 +81,10 @@ github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm
github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/marten-seemann/qpack v0.3.0 h1:UiWstOgT8+znlkDPOg2+3rIuYXJ2CnGDkGUXN6ki6hE= github.com/marten-seemann/qpack v0.3.0 h1:UiWstOgT8+znlkDPOg2+3rIuYXJ2CnGDkGUXN6ki6hE=
github.com/marten-seemann/qpack v0.3.0/go.mod h1:cGfKPBiP4a9EQdxCwEwI/GEeWAsjSekBvx/X8mh58+g= github.com/marten-seemann/qpack v0.3.0/go.mod h1:cGfKPBiP4a9EQdxCwEwI/GEeWAsjSekBvx/X8mh58+g=
github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= github.com/marten-seemann/qtls-go1-18 v0.1.4 h1:ogomB+lWV3Vmwiu6RTwDVTMGx+9j7SEi98e8QB35Its=
github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= github.com/marten-seemann/qtls-go1-18 v0.1.4/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4=
github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sNlqWoDZnjRIE= github.com/marten-seemann/qtls-go1-19 v0.1.2 h1:ZevAEqKXH0bZmoOBPiqX2h5rhQ7cbZi+X+rlq2JUbCE=
github.com/marten-seemann/qtls-go1-19 v0.1.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= github.com/marten-seemann/qtls-go1-19 v0.1.2/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io"
"net" "net"
"time" "time"
@ -18,54 +19,33 @@ import (
var _ = Describe("Handshake RTT tests", func() { var _ = Describe("Handshake RTT tests", func() {
var ( var (
proxy *quicproxy.QuicProxy proxy *quicproxy.QuicProxy
server quic.Listener
serverConfig *quic.Config serverConfig *quic.Config
serverTLSConfig *tls.Config serverTLSConfig *tls.Config
testStartedAt time.Time
acceptStopped chan struct{}
) )
rtt := 400 * time.Millisecond const rtt = 400 * time.Millisecond
BeforeEach(func() { BeforeEach(func() {
acceptStopped = make(chan struct{})
serverConfig = getQuicConfig(nil) serverConfig = getQuicConfig(nil)
serverTLSConfig = getTLSConfig() serverTLSConfig = getTLSConfig()
}) })
AfterEach(func() { AfterEach(func() {
Expect(proxy.Close()).To(Succeed()) Expect(proxy.Close()).To(Succeed())
Expect(server.Close()).To(Succeed())
<-acceptStopped
}) })
runServerAndProxy := func() { runProxy := func(serverAddr net.Addr) {
var err error var err error
// start the server
server, err = quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig)
Expect(err).ToNot(HaveOccurred())
// start the proxy // start the proxy
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: server.Addr().String(), RemoteAddr: serverAddr.String(),
DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 }, DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 },
}) })
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
testStartedAt = time.Now()
go func() {
defer GinkgoRecover()
defer close(acceptStopped)
for {
if _, err := server.Accept(context.Background()); err != nil {
return
}
}
}()
} }
expectDurationInRTTs := func(num int) { expectDurationInRTTs := func(startTime time.Time, num int) {
testDuration := time.Since(testStartedAt) testDuration := time.Since(startTime)
rtts := float32(testDuration) / float32(rtt) rtts := float32(testDuration) / float32(rtt)
Expect(rtts).To(SatisfyAll( Expect(rtts).To(SatisfyAll(
BeNumerically(">=", num), BeNumerically(">=", num),
@ -78,15 +58,19 @@ var _ = Describe("Handshake RTT tests", func() {
Skip("Test requires at least 2 supported versions.") Skip("Test requires at least 2 supported versions.")
} }
serverConfig.Versions = protocol.SupportedVersions[:1] serverConfig.Versions = protocol.SupportedVersions[:1]
runServerAndProxy() ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig)
clientConfig := getQuicConfig(&quic.Config{Versions: protocol.SupportedVersions[1:2]}) Expect(err).ToNot(HaveOccurred())
_, err := quic.DialAddr( defer ln.Close()
runProxy(ln.Addr())
startTime := time.Now()
_, err = quic.DialAddr(
proxy.LocalAddr().String(), proxy.LocalAddr().String(),
getTLSClientConfig(), getTLSClientConfig(),
clientConfig, getQuicConfig(&quic.Config{Versions: protocol.SupportedVersions[1:2]}),
) )
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
expectDurationInRTTs(1) expectDurationInRTTs(startTime, 1)
}) })
var clientConfig *quic.Config var clientConfig *quic.Config
@ -102,36 +86,114 @@ var _ = Describe("Handshake RTT tests", func() {
// 1 RTT for the TLS handshake // 1 RTT for the TLS handshake
It("is forward-secure after 2 RTTs", func() { It("is forward-secure after 2 RTTs", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
runServerAndProxy() ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig)
_, err := quic.DialAddr( Expect(err).ToNot(HaveOccurred())
defer ln.Close()
runProxy(ln.Addr())
startTime := time.Now()
_, err = quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
clientConfig, clientConfig,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2) expectDurationInRTTs(startTime, 2)
}) })
It("establishes a connection in 1 RTT when the server doesn't require a token", func() { It("establishes a connection in 1 RTT when the server doesn't require a token", func() {
runServerAndProxy() ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig)
_, err := quic.DialAddr( Expect(err).ToNot(HaveOccurred())
defer ln.Close()
runProxy(ln.Addr())
startTime := time.Now()
_, err = quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
clientConfig, clientConfig,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(1) expectDurationInRTTs(startTime, 1)
}) })
It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() { It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() {
serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384} serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384}
runServerAndProxy() ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig)
_, err := quic.DialAddr( Expect(err).ToNot(HaveOccurred())
defer ln.Close()
runProxy(ln.Addr())
startTime := time.Now()
_, err = quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(), getTLSClientConfig(),
clientConfig, clientConfig,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2) expectDurationInRTTs(startTime, 2)
})
It("receives the first message from the server after 2 RTTs, when the server uses ListenAddr", func() {
ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig)
Expect(err).ToNot(HaveOccurred())
go func() {
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
defer ln.Close()
runProxy(ln.Addr())
startTime := time.Now()
conn, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foobar")))
expectDurationInRTTs(startTime, 2)
})
It("receives the first message from the server after 1 RTT, when the server uses ListenAddrEarly", func() {
ln, err := quic.ListenAddrEarly("localhost:0", serverTLSConfig, serverConfig)
Expect(err).ToNot(HaveOccurred())
go func() {
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
// Check the ALPN now. This is probably what an application would do.
// It makes sure that ConnectionState does not block until the handshake completes.
Expect(conn.ConnectionState().TLS.NegotiatedProtocol).To(Equal(alpn))
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
defer ln.Close()
runProxy(ln.Addr())
startTime := time.Now()
conn, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foobar")))
expectDurationInRTTs(startTime, 1)
}) })
}) })

View file

@ -110,19 +110,20 @@ var _ = Describe("0-RTT", func() {
clientConf *quic.Config, clientConf *quic.Config,
testdata []byte, // data to transfer testdata []byte, // data to transfer
) { ) {
// now dial the second connection, and use 0-RTT to send some data // accept the second connection, and receive the data sent in 0-RTT
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
conn, err := ln.Accept(context.Background()) conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptUniStream(context.Background()) str, err := conn.AcceptStream(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str) data, err := io.ReadAll(str)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(testdata)) Expect(data).To(Equal(testdata))
Expect(str.Close()).To(Succeed())
Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue())
Expect(conn.CloseWithError(0, "")).To(Succeed()) <-conn.Context().Done()
close(done) close(done)
}() }()
@ -136,13 +137,15 @@ var _ = Describe("0-RTT", func() {
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "") defer conn.CloseWithError(0, "")
str, err := conn.OpenUniStream() str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = str.Write(testdata) _, err = str.Write(testdata)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed()) Expect(str.Close()).To(Succeed())
<-conn.HandshakeComplete().Done() <-conn.HandshakeComplete().Done()
Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue())
io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn
conn.CloseWithError(0, "")
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
Eventually(conn.Context().Done()).Should(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed())
} }