From aee7706d5ddb035aedc5ab9489e9940124f385bd Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 6 Apr 2023 18:32:46 +0800 Subject: [PATCH] clean up dial functions * add a context to all Dial functions * remove the explicit Dial{*}Context functions --- client.go | 42 ++++--------------- client_test.go | 18 ++++---- example/echo/echo.go | 2 +- http3/client.go | 2 +- http3/roundtrip.go | 2 +- http3/server_test.go | 10 ++--- integrationtests/self/benchmark_test.go | 4 +- integrationtests/self/cancelation_test.go | 9 ++++ integrationtests/self/close_test.go | 1 + integrationtests/self/conn_id_test.go | 1 + integrationtests/self/datagram_test.go | 3 ++ integrationtests/self/deadline_test.go | 1 + integrationtests/self/drop_test.go | 1 + integrationtests/self/early_data_test.go | 1 + integrationtests/self/handshake_drop_test.go | 3 ++ integrationtests/self/handshake_rtt_test.go | 5 +++ integrationtests/self/handshake_test.go | 14 ++++++- integrationtests/self/key_update_test.go | 1 + integrationtests/self/mitm_test.go | 3 ++ integrationtests/self/multiplex_test.go | 1 + integrationtests/self/packetization_test.go | 1 + integrationtests/self/resumption_test.go | 4 ++ integrationtests/self/rtt_test.go | 2 + integrationtests/self/stateless_reset_test.go | 1 + integrationtests/self/stream_test.go | 3 ++ integrationtests/self/timeout_test.go | 11 ++++- integrationtests/self/tracer_test.go | 1 + integrationtests/self/uni_stream_test.go | 3 ++ integrationtests/self/zero_rtt_test.go | 7 ++++ .../versionnegotiation/handshake_test.go | 4 +- .../versionnegotiation/rtt_test.go | 2 + interop/http09/client.go | 2 +- 32 files changed, 106 insertions(+), 59 deletions(-) diff --git a/client.go b/client.go index 287b3a20..62cad580 100644 --- a/client.go +++ b/client.go @@ -46,24 +46,14 @@ var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial // DialAddr establishes a new QUIC connection to a server. // It uses a new UDP connection and closes this connection when the QUIC connection is closed. -func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Connection, error) { - return DialAddrContext(context.Background(), addr, tlsConf, config) +// The hostname for SNI is taken from the given address. +func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (Connection, error) { + return dialAddrContext(ctx, addr, tlsConf, config, false) } // DialAddrEarly establishes a new 0-RTT QUIC connection to a server. // It uses a new UDP connection and closes this connection when the QUIC connection is closed. -func DialAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyConnection, error) { - return DialAddrEarlyContext(context.Background(), addr, tlsConf, config) -} - -// DialAddrEarlyContext establishes a new 0-RTT QUIC connection to a server using provided context. -// See DialAddrEarly for details -func DialAddrEarlyContext( - ctx context.Context, - addr string, - tlsConf *tls.Config, - config *Config, -) (EarlyConnection, error) { +func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (EarlyConnection, error) { conn, err := dialAddrContext(ctx, addr, tlsConf, config, true) if err != nil { return nil, err @@ -72,12 +62,6 @@ func DialAddrEarlyContext( return conn, nil } -// DialAddrContext establishes a new QUIC connection to a server using the provided context. -// See DialAddr for details. -func DialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (Connection, error) { - return dialAddrContext(ctx, addr, tlsConf, config, false) -} - func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, config *Config, use0RTT bool) (quicConn, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { @@ -98,30 +82,18 @@ func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, conf // The same PacketConn can be used for multiple calls to Dial and Listen. // QUIC connection IDs are used for demultiplexing the different connections. // The tls.Config must define an application protocol (using NextProtos). -func Dial(pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (Connection, error) { - return dialContext(context.Background(), pconn, addr, tlsConf, config, false, false) +func Dial(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (Connection, error) { + return dialContext(ctx, pconn, addr, tlsConf, config, false, false) } // DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn. // The same PacketConn can be used for multiple calls to Dial and Listen, // QUIC connection IDs are used for demultiplexing the different connections. // The tls.Config must define an application protocol (using NextProtos). -func DialEarly(pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (EarlyConnection, error) { - return DialEarlyContext(context.Background(), pconn, addr, tlsConf, config) -} - -// DialEarlyContext establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context. -// See DialEarly for details. -func DialEarlyContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (EarlyConnection, error) { +func DialEarly(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (EarlyConnection, error) { return dialContext(ctx, pconn, addr, tlsConf, config, true, false) } -// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context. -// See Dial for details. -func DialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (Connection, error) { - return dialContext(ctx, pconn, addr, tlsConf, config, false, false) -} - func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config, use0RTT bool, createdPacketConn bool) (quicConn, error) { if tlsConf == nil { return nil, errors.New("quic: tls.Config not set") diff --git a/client_test.go b/client_test.go index d5a10f37..23122e62 100644 --- a/client_test.go +++ b/client_test.go @@ -128,7 +128,7 @@ var _ = Describe("Client", func() { conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) return conn } - _, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond}) + _, err := DialAddr(context.Background(), "localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond}) Expect(err).ToNot(HaveOccurred()) Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890"))) }) @@ -163,7 +163,7 @@ var _ = Describe("Client", func() { return conn } tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - s, err := Dial(packetConn, addr, tlsConf, config) + s, err := Dial(context.Background(), packetConn, addr, tlsConf, config) Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) Eventually(run).Should(BeClosed()) @@ -203,7 +203,7 @@ var _ = Describe("Client", func() { defer GinkgoRecover() defer close(done) tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - s, err := DialEarly(packetConn, addr, tlsConf, config) + s, err := DialEarly(context.Background(), packetConn, addr, tlsConf, config) Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) }() @@ -239,7 +239,7 @@ var _ = Describe("Client", func() { return conn } tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - _, err := Dial(packetConn, addr, tlsConf, config) + _, err := Dial(context.Background(), packetConn, addr, tlsConf, config) Expect(err).To(MatchError(testErr)) }) @@ -277,7 +277,7 @@ var _ = Describe("Client", func() { go func() { defer GinkgoRecover() tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - _, err := DialContext(ctx, packetConn, addr, tlsConf, config) + _, err := Dial(ctx, packetConn, addr, tlsConf, config) Expect(err).To(MatchError(context.Canceled)) close(dialed) }() @@ -323,7 +323,7 @@ var _ = Describe("Client", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := DialAddr("localhost:1337", tlsConf, nil) + _, err := DialAddr(context.Background(), "localhost:1337", tlsConf, nil) Expect(err).ToNot(HaveOccurred()) close(done) }() @@ -370,7 +370,7 @@ var _ = Describe("Client", func() { mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) version := protocol.VersionNumber(0x1234) - _, err := Dial(packetConn, nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) + _, err := Dial(context.Background(), packetConn, nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) }) @@ -442,7 +442,7 @@ var _ = Describe("Client", func() { conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) return conn } - _, err := Dial(packetConn, addr, tlsConf, config) + _, err := Dial(context.Background(), packetConn, addr, tlsConf, config) Expect(err).ToNot(HaveOccurred()) Eventually(c).Should(BeClosed()) Expect(cconn.(*spconn).PacketConn).To(Equal(packetConn)) @@ -492,7 +492,7 @@ var _ = Describe("Client", func() { config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}} tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - _, err := DialAddr("localhost:7890", tlsConf, config) + _, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config) Expect(err).ToNot(HaveOccurred()) Expect(counter).To(Equal(2)) }) diff --git a/example/echo/echo.go b/example/echo/echo.go index 9d2de700..011c70a3 100644 --- a/example/echo/echo.go +++ b/example/echo/echo.go @@ -54,7 +54,7 @@ func clientMain() error { InsecureSkipVerify: true, NextProtos: []string{"quic-echo-example"}, } - conn, err := quic.DialAddr(addr, tlsConf, nil) + conn, err := quic.DialAddr(context.Background(), addr, tlsConf, nil) if err != nil { return err } diff --git a/http3/client.go b/http3/client.go index 82af3aff..9c94a31e 100644 --- a/http3/client.go +++ b/http3/client.go @@ -37,7 +37,7 @@ var defaultQuicConfig = &quic.Config{ type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) -var dialAddr = quic.DialAddrEarlyContext +var dialAddr dialFunc = quic.DialAddrEarly type roundTripperOpts struct { DisableCompression bool diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 0e8b3a6c..53885e1e 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -17,7 +17,7 @@ import ( ) // declare this as a variable, such that we can it mock it in the tests -var quicDialer = quic.DialEarlyContext +var quicDialer = quic.DialEarly type roundTripCloser interface { RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) diff --git a/http3/server_test.go b/http3/server_test.go index 3f288f9f..bbf8c9cc 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -932,7 +932,7 @@ var _ = Describe("Server", func() { ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) Expect(err).ToNot(HaveOccurred()) defer ln.Close() - c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) + c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) @@ -943,7 +943,7 @@ var _ = Describe("Server", func() { ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.VersionNumber{quic.VersionDraft29}}) Expect(err).ToNot(HaveOccurred()) defer ln.Close() - c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3Draft29}}, nil) + c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3Draft29}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3Draft29)) @@ -971,7 +971,7 @@ var _ = Describe("Server", func() { ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) Expect(err).ToNot(HaveOccurred()) defer ln.Close() - c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) + c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) @@ -984,7 +984,7 @@ var _ = Describe("Server", func() { ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) Expect(err).ToNot(HaveOccurred()) defer ln.Close() - c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) + c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) @@ -1002,7 +1002,7 @@ var _ = Describe("Server", func() { ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) Expect(err).ToNot(HaveOccurred()) defer ln.Close() - c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) + c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) diff --git a/integrationtests/self/benchmark_test.go b/integrationtests/self/benchmark_test.go index 682a6625..983d8b7d 100644 --- a/integrationtests/self/benchmark_test.go +++ b/integrationtests/self/benchmark_test.go @@ -36,7 +36,7 @@ func BenchmarkHandshake(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - c, err := quic.Dial(conn, ln.Addr(), tlsClientConfig, nil) + c, err := quic.Dial(context.Background(), conn, ln.Addr(), tlsClientConfig, nil) if err != nil { b.Fatal(err) } @@ -71,7 +71,7 @@ func BenchmarkStreamChurn(b *testing.B) { } }() - c, err := quic.DialAddr(fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), tlsClientConfig, nil) + c, err := quic.DialAddr(context.Background(), fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), tlsClientConfig, nil) if err != nil { b.Fatal(err) } diff --git a/integrationtests/self/cancelation_test.go b/integrationtests/self/cancelation_test.go index 7b98aba4..f4de3574 100644 --- a/integrationtests/self/cancelation_test.go +++ b/integrationtests/self/cancelation_test.go @@ -73,6 +73,7 @@ var _ = Describe("Stream Cancellations", func() { It("downloads when the client immediately cancels most streams", func() { serverCanceledCounterChan := runServer(PRData) conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), @@ -124,6 +125,7 @@ var _ = Describe("Stream Cancellations", func() { serverCanceledCounterChan := runServer(PRData) conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), @@ -174,6 +176,7 @@ var _ = Describe("Stream Cancellations", func() { serverCanceledCounterChan := runServer(make([]byte, 100)) // make sure the FIN is sent with the STREAM frame conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), @@ -221,6 +224,7 @@ var _ = Describe("Stream Cancellations", func() { Context("canceling the write side", func() { runClient := func(server quic.Listener) int32 /* number of canceled streams */ { conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), @@ -366,6 +370,7 @@ var _ = Describe("Stream Cancellations", func() { }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), @@ -451,6 +456,7 @@ var _ = Describe("Stream Cancellations", func() { }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), @@ -529,6 +535,7 @@ var _ = Describe("Stream Cancellations", func() { }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 3}), @@ -614,6 +621,7 @@ var _ = Describe("Stream Cancellations", func() { }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: maxIncomingStreams}), @@ -707,6 +715,7 @@ var _ = Describe("Stream Cancellations", func() { }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{}), diff --git a/integrationtests/self/close_test.go b/integrationtests/self/close_test.go index c2f1439a..31905e30 100644 --- a/integrationtests/self/close_test.go +++ b/integrationtests/self/close_test.go @@ -44,6 +44,7 @@ var _ = Describe("Connection ID lengths tests", func() { defer proxy.Close() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(nil), diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index 21b413e6..ddf3cfc5 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -61,6 +61,7 @@ var _ = Describe("Connection ID lengths tests", func() { runClient := func(addr net.Addr, conf *quic.Config) { GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength))) cl, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port), getTLSClientConfig(), conf, diff --git a/integrationtests/self/datagram_test.go b/integrationtests/self/datagram_test.go index d23f887d..552488f4 100644 --- a/integrationtests/self/datagram_test.go +++ b/integrationtests/self/datagram_test.go @@ -105,6 +105,7 @@ var _ = Describe("Datagram test", func() { raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( + context.Background(), clientConn, raddr, getTLSClientConfig(), @@ -140,6 +141,7 @@ var _ = Describe("Datagram test", func() { raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( + context.Background(), clientConn, raddr, getTLSClientConfig(), @@ -157,6 +159,7 @@ var _ = Describe("Datagram test", func() { raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( + context.Background(), clientConn, raddr, getTLSClientConfig(), diff --git a/integrationtests/self/deadline_test.go b/integrationtests/self/deadline_test.go index ae7a08cf..2dcb4b3b 100644 --- a/integrationtests/self/deadline_test.go +++ b/integrationtests/self/deadline_test.go @@ -30,6 +30,7 @@ var _ = Describe("Stream deadline tests", func() { }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), diff --git a/integrationtests/self/drop_test.go b/integrationtests/self/drop_test.go index 52eea516..f3265f1c 100644 --- a/integrationtests/self/drop_test.go +++ b/integrationtests/self/drop_test.go @@ -97,6 +97,7 @@ var _ = Describe("Drop Tests", func() { }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(nil), diff --git a/integrationtests/self/early_data_test.go b/integrationtests/self/early_data_test.go index 5b96b2cb..0ce09926 100644 --- a/integrationtests/self/early_data_test.go +++ b/integrationtests/self/early_data_test.go @@ -50,6 +50,7 @@ var _ = Describe("early data", func() { defer proxy.Close() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(nil), diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index ebb8ee14..e4d6d6b3 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -81,6 +81,7 @@ var _ = Describe("Handshake drop tests", func() { serverConnChan <- conn }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -118,6 +119,7 @@ var _ = Describe("Handshake drop tests", func() { serverConnChan <- conn }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -150,6 +152,7 @@ var _ = Describe("Handshake drop tests", func() { serverConnChan <- conn }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{ diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index 5fcc5546..e1dff321 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -63,6 +63,7 @@ var _ = Describe("Handshake RTT tests", func() { runProxy(ln.Addr()) startTime := time.Now() _, err = quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), @@ -79,6 +80,7 @@ var _ = Describe("Handshake RTT tests", func() { runProxy(ln.Addr()) startTime := time.Now() _, err = quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), @@ -96,6 +98,7 @@ var _ = Describe("Handshake RTT tests", func() { runProxy(ln.Addr()) startTime := time.Now() _, err = quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), @@ -122,6 +125,7 @@ var _ = Describe("Handshake RTT tests", func() { runProxy(ln.Addr()) startTime := time.Now() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), @@ -156,6 +160,7 @@ var _ = Describe("Handshake RTT tests", func() { runProxy(ln.Addr()) startTime := time.Now() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 42251ecc..7a5d000f 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -111,6 +111,7 @@ var _ = Describe("Handshake tests", func() { }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), nil, @@ -131,6 +132,7 @@ var _ = Describe("Handshake tests", func() { It("accepts the certificate", func() { runServer(getTLSConfig()) _, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), @@ -141,6 +143,7 @@ var _ = Describe("Handshake tests", func() { It("works with a long certificate chain", func() { runServer(getTLSConfigWithLongCertChain()) _, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), @@ -155,6 +158,7 @@ var _ = Describe("Handshake tests", func() { conf := getTLSClientConfig() conf.ServerName = "foo.bar" _, err = quic.Dial( + context.Background(), conn, server.Addr(), conf, @@ -173,6 +177,7 @@ var _ = Describe("Handshake tests", func() { runServer(tlsConf) conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), @@ -201,6 +206,7 @@ var _ = Describe("Handshake tests", func() { tlsConf := getTLSClientConfig() tlsConf.ServerName = "foo.bar" _, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, getQuicConfig(nil), @@ -223,7 +229,7 @@ var _ = Describe("Handshake tests", func() { remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) raddr, err := net.ResolveUDPAddr("udp", remoteAddr) Expect(err).ToNot(HaveOccurred()) - return quic.Dial(pconn, raddr, getTLSClientConfig(), nil) + return quic.Dial(context.Background(), pconn, raddr, getTLSClientConfig(), nil) } BeforeEach(func() { @@ -324,6 +330,7 @@ var _ = Describe("Handshake tests", func() { }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), nil, @@ -342,6 +349,7 @@ var _ = Describe("Handshake tests", func() { tlsConf := getTLSClientConfig() tlsConf.NextProtos = []string{"foobar"} _, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, nil, @@ -371,6 +379,7 @@ var _ = Describe("Handshake tests", func() { tokenStore := newTokenStore(gets, puts) quicConf := getQuicConfig(&quic.Config{TokenStore: tokenStore}) conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), quicConf, @@ -390,6 +399,7 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) }() conn, err = quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), quicConf, @@ -410,6 +420,7 @@ var _ = Describe("Handshake tests", func() { defer server.Close() _, err = quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), nil, @@ -441,6 +452,7 @@ var _ = Describe("Handshake tests", func() { tlsConf := getTLSClientConfig() tlsConf.NextProtos = []string{""} _, err = quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", ln.LocalAddr().(*net.UDPAddr).Port), tlsConf, nil, diff --git a/integrationtests/self/key_update_test.go b/integrationtests/self/key_update_test.go index 21b8aeef..51ae6be0 100644 --- a/integrationtests/self/key_update_test.go +++ b/integrationtests/self/key_update_test.go @@ -72,6 +72,7 @@ var _ = Describe("Key Update tests", func() { }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return &keyUpdateConnTracer{} })}), diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index b9b9800f..133c6f23 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -147,6 +147,7 @@ var _ = Describe("MITM test", func() { raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( + context.Background(), clientUDPConn, raddr, getTLSClientConfig(), @@ -190,6 +191,7 @@ var _ = Describe("MITM test", func() { raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( + context.Background(), clientUDPConn, raddr, getTLSClientConfig(), @@ -304,6 +306,7 @@ var _ = Describe("MITM test", func() { raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) _, err = quic.Dial( + context.Background(), clientUDPConn, raddr, getTLSClientConfig(), diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index 77fea357..825ccb9f 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -36,6 +36,7 @@ var _ = Describe("Multiplexing", func() { dial := func(pconn net.PacketConn, addr net.Addr) { conn, err := quic.Dial( + context.Background(), pconn, addr, getTLSClientConfig(), diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index f85f7f41..1338b30c 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -45,6 +45,7 @@ var _ = Describe("Packetization", func() { clientTracer := newPacketTracer() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{ diff --git a/integrationtests/self/resumption_test.go b/integrationtests/self/resumption_test.go index 7732a2b0..f2c6ec89 100644 --- a/integrationtests/self/resumption_test.go +++ b/integrationtests/self/resumption_test.go @@ -58,6 +58,7 @@ var _ = Describe("TLS session resumption", func() { tlsConf := getTLSClientConfig() tlsConf.ClientSessionCache = cache conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, nil, @@ -72,6 +73,7 @@ var _ = Describe("TLS session resumption", func() { Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) conn, err = quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, nil, @@ -98,6 +100,7 @@ var _ = Describe("TLS session resumption", func() { tlsConf := getTLSClientConfig() tlsConf.ClientSessionCache = cache conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, nil, @@ -111,6 +114,7 @@ var _ = Describe("TLS session resumption", func() { Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) conn, err = quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, nil, diff --git a/integrationtests/self/rtt_test.go b/integrationtests/self/rtt_test.go index 223177fd..7482a9c9 100644 --- a/integrationtests/self/rtt_test.go +++ b/integrationtests/self/rtt_test.go @@ -37,6 +37,7 @@ var _ = Describe("non-zero RTT", func() { downloadFile := func(port int) { conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", port), getTLSClientConfig(), getQuicConfig(nil), @@ -72,6 +73,7 @@ var _ = Describe("non-zero RTT", func() { defer proxy.Close() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(nil), diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index 581e6d75..98b21b2a 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -56,6 +56,7 @@ var _ = Describe("Stateless Resets", func() { defer proxy.Close() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{ diff --git a/integrationtests/self/stream_test.go b/integrationtests/self/stream_test.go index 622fad29..59484deb 100644 --- a/integrationtests/self/stream_test.go +++ b/integrationtests/self/stream_test.go @@ -87,6 +87,7 @@ var _ = Describe("Bidirectional streams", func() { }() client, err := quic.DialAddr( + context.Background(), serverAddr, getTLSClientConfig(), getQuicConfig(nil), @@ -105,6 +106,7 @@ var _ = Describe("Bidirectional streams", func() { }() client, err := quic.DialAddr( + context.Background(), serverAddr, getTLSClientConfig(), getQuicConfig(nil), @@ -132,6 +134,7 @@ var _ = Describe("Bidirectional streams", func() { }() client, err := quic.DialAddr( + context.Background(), serverAddr, getTLSClientConfig(), getQuicConfig(nil), diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index 3ff4175c..4fb2a733 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -54,6 +54,7 @@ var _ = Describe("Timeout tests", func() { errChan := make(chan error) go func() { _, err := quic.DialAddr( + context.Background(), "localhost:12345", getTLSClientConfig(), getQuicConfig(&quic.Config{HandshakeIdleTimeout: 10 * time.Millisecond}), @@ -70,7 +71,7 @@ var _ = Describe("Timeout tests", func() { defer cancel() errChan := make(chan error) go func() { - _, err := quic.DialAddrContext( + _, err := quic.DialAddr( ctx, "localhost:12345", getTLSClientConfig(), @@ -89,7 +90,7 @@ var _ = Describe("Timeout tests", func() { defer cancel() errChan := make(chan error) go func() { - _, err := quic.DialAddrEarlyContext( + _, err := quic.DialAddrEarly( ctx, "localhost:12345", getTLSClientConfig(), @@ -135,6 +136,7 @@ var _ = Describe("Timeout tests", func() { defer proxy.Close() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}), @@ -194,6 +196,7 @@ var _ = Describe("Timeout tests", func() { tr := newPacketTracer() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -273,6 +276,7 @@ var _ = Describe("Timeout tests", func() { }() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}), @@ -341,6 +345,7 @@ var _ = Describe("Timeout tests", func() { defer proxy.Close() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -423,6 +428,7 @@ var _ = Describe("Timeout tests", func() { go func() { defer GinkgoRecover() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -484,6 +490,7 @@ var _ = Describe("Timeout tests", func() { go func() { defer GinkgoRecover() conn, err := quic.Dial( + context.Background(), &faultyConn{PacketConn: conn, MaxPackets: maxPackets}, ln.Addr(), getTLSClientConfig(), diff --git a/integrationtests/self/tracer_test.go b/integrationtests/self/tracer_test.go index 6fe25e38..eb062de3 100644 --- a/integrationtests/self/tracer_test.go +++ b/integrationtests/self/tracer_test.go @@ -80,6 +80,7 @@ var _ = Describe("Handshake tests", func() { defer ln.Close() conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), quicClientConf, diff --git a/integrationtests/self/uni_stream_test.go b/integrationtests/self/uni_stream_test.go index 0f73c1f0..a809d9d3 100644 --- a/integrationtests/self/uni_stream_test.go +++ b/integrationtests/self/uni_stream_test.go @@ -77,6 +77,7 @@ var _ = Describe("Unidirectional Streams", func() { }() client, err := quic.DialAddr( + context.Background(), serverAddr, getTLSClientConfig(), getQuicConfig(nil), @@ -95,6 +96,7 @@ var _ = Describe("Unidirectional Streams", func() { }() client, err := quic.DialAddr( + context.Background(), serverAddr, getTLSClientConfig(), getQuicConfig(nil), @@ -121,6 +123,7 @@ var _ = Describe("Unidirectional Streams", func() { }() client, err := quic.DialAddr( + context.Background(), serverAddr, getTLSClientConfig(), getQuicConfig(nil), diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index cc7b4264..5aee0361 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -86,6 +86,7 @@ var _ = Describe("0-RTT", func() { puts := make(chan string, 100) clientConf.ClientSessionCache = newClientSessionCache(gets, puts) conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), clientConf, getQuicConfig(nil), @@ -126,6 +127,7 @@ var _ = Describe("0-RTT", func() { clientConf = getQuicConfig(nil) } conn, err := quic.DialAddrEarly( + context.Background(), fmt.Sprintf("localhost:%d", proxyPort), clientTLSConf, clientConf, @@ -151,6 +153,7 @@ var _ = Describe("0-RTT", func() { clientConf *tls.Config, ) { conn, err := quic.DialAddrEarly( + context.Background(), fmt.Sprintf("localhost:%d", proxyPort), clientConf, getQuicConfig(nil), @@ -280,6 +283,7 @@ var _ = Describe("0-RTT", func() { defer proxy.Close() conn, err := quic.DialAddrEarly( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), clientConf, getQuicConfig(nil), @@ -479,6 +483,7 @@ var _ = Describe("0-RTT", func() { defer proxy.Close() conn, err := quic.DialAddrEarly( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), clientConf, getQuicConfig(nil), @@ -609,6 +614,7 @@ var _ = Describe("0-RTT", func() { defer proxy.Close() conn, err := quic.DialAddrEarly( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), clientConf, getQuicConfig(nil), @@ -680,6 +686,7 @@ var _ = Describe("0-RTT", func() { defer proxy.Close() conn, err := quic.DialAddrEarly( + context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), clientConf, getQuicConfig(nil), diff --git a/integrationtests/versionnegotiation/handshake_test.go b/integrationtests/versionnegotiation/handshake_test.go index b2cd1269..a38edf30 100644 --- a/integrationtests/versionnegotiation/handshake_test.go +++ b/integrationtests/versionnegotiation/handshake_test.go @@ -69,7 +69,7 @@ var _ = Describe("Handshake tests", func() { var supportedVersions []protocol.VersionNumber BeforeEach(func() { - supportedVersions = protocol.SupportedVersions + supportedVersions = append([]quic.VersionNumber{}, protocol.SupportedVersions...) protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{7, 8, 9, 10}...) }) @@ -90,6 +90,7 @@ var _ = Describe("Handshake tests", func() { defer cl() clientTracer := &versionNegotiationTracer{} conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), maybeAddQlogTracer(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}), @@ -119,6 +120,7 @@ var _ = Describe("Handshake tests", func() { clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} clientTracer := &versionNegotiationTracer{} conn, err := quic.DialAddr( + context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), maybeAddQlogTracer(&quic.Config{ diff --git a/integrationtests/versionnegotiation/rtt_test.go b/integrationtests/versionnegotiation/rtt_test.go index 7f8186ee..2c8868db 100644 --- a/integrationtests/versionnegotiation/rtt_test.go +++ b/integrationtests/versionnegotiation/rtt_test.go @@ -1,6 +1,7 @@ package versionnegotiation import ( + "context" "time" "github.com/quic-go/quic-go" @@ -43,6 +44,7 @@ var _ = Describe("Handshake RTT tests", func() { startTime := time.Now() _, err = quic.DialAddr( + context.Background(), proxy.LocalAddr().String(), getTLSClientConfig(), maybeAddQlogTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}), diff --git a/interop/http09/client.go b/interop/http09/client.go index 6cb326bb..b6de79ef 100644 --- a/interop/http09/client.go +++ b/interop/http09/client.go @@ -90,7 +90,7 @@ type client struct { func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { c.once.Do(func() { - c.conn, c.dialErr = quic.DialAddrEarly(c.hostname, c.tlsConf, c.quicConf) + c.conn, c.dialErr = quic.DialAddrEarly(context.Background(), c.hostname, c.tlsConf, c.quicConf) }) if c.dialErr != nil { return nil, c.dialErr