remove the host parameter from all dial functions

This commit is contained in:
Marten Seemann 2023-04-01 18:40:26 +09:00
parent ea721c9c75
commit d683b841c4
14 changed files with 43 additions and 237 deletions

114
client.go
View file

@ -46,23 +46,13 @@ var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
// DialAddr establishes a new QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
// The hostname for SNI is taken from the given address.
func DialAddr(
addr string,
tlsConf *tls.Config,
config *Config,
) (Connection, error) {
func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Connection, error) {
return DialAddrContext(context.Background(), addr, tlsConf, config)
}
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
// The hostname for SNI is taken from the given address.
func DialAddrEarly(
addr string,
tlsConf *tls.Config,
config *Config,
) (EarlyConnection, error) {
func DialAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyConnection, error) {
return DialAddrEarlyContext(context.Background(), addr, tlsConf, config)
}
@ -84,22 +74,11 @@ func DialAddrEarlyContext(
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
// See DialAddr for details.
func DialAddrContext(
ctx context.Context,
addr string,
tlsConf *tls.Config,
config *Config,
) (Connection, error) {
func DialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (Connection, error) {
return dialAddrContext(ctx, addr, tlsConf, config, false)
}
func dialAddrContext(
ctx context.Context,
addr string,
tlsConf *tls.Config,
config *Config,
use0RTT bool,
) (quicConn, error) {
func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, config *Config, use0RTT bool) (quicConn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
@ -108,78 +87,42 @@ func dialAddrContext(
if err != nil {
return nil, err
}
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true)
return dialContext(ctx, udpConn, udpAddr, tlsConf, config, use0RTT, true)
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn. If
// the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
// packets. The same PacketConn can be used for multiple calls to Dial and
// Listen, QUIC connection IDs are used for demultiplexing the different
// connections. The host parameter is used for SNI. The tls.Config must define
// an application protocol (using NextProtos).
func Dial(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Connection, error) {
return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false)
// packets.
// The same PacketConn can be used for multiple calls to Dial and Listen.
// QUIC connection IDs are used for demultiplexing the different connections.
// The tls.Config must define an application protocol (using NextProtos).
func Dial(pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (Connection, error) {
return dialContext(context.Background(), pconn, addr, tlsConf, config, false, false)
}
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
// The same PacketConn can be used for multiple calls to Dial and Listen,
// QUIC connection IDs are used for demultiplexing the different connections.
// The host parameter is used for SNI.
// The tls.Config must define an application protocol (using NextProtos).
func DialEarly(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (EarlyConnection, error) {
return DialEarlyContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
func DialEarly(pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (EarlyConnection, error) {
return DialEarlyContext(context.Background(), pconn, addr, tlsConf, config)
}
// DialEarlyContext establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
// See DialEarly for details.
func DialEarlyContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (EarlyConnection, error) {
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, true, false)
func DialEarlyContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (EarlyConnection, error) {
return dialContext(ctx, pconn, addr, tlsConf, config, true, false)
}
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
// See Dial for details.
func DialContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Connection, error) {
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false)
func DialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (Connection, error) {
return dialContext(ctx, pconn, addr, tlsConf, config, false, false)
}
func dialContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
use0RTT bool,
createdPacketConn bool,
) (quicConn, error) {
func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config, use0RTT bool, createdPacketConn bool) (quicConn, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
@ -191,7 +134,7 @@ func dialContext(
if err != nil {
return nil, err
}
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, use0RTT, createdPacketConn)
c, err := newClient(pconn, addr, config, tlsConf, use0RTT, createdPacketConn)
if err != nil {
return nil, err
}
@ -214,29 +157,12 @@ func dialContext(
return c.conn, nil
}
func newClient(
pconn net.PacketConn,
remoteAddr net.Addr,
config *Config,
tlsConf *tls.Config,
host string,
use0RTT bool,
createdPacketConn bool,
) (*client, error) {
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, use0RTT bool, createdPacketConn bool) (*client, error) {
if tlsConf == nil {
tlsConf = &tls.Config{}
} else {
tlsConf = tlsConf.Clone()
}
if tlsConf.ServerName == "" {
sni, _, err := net.SplitHostPort(host)
if err != nil {
// It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
sni = host
}
tlsConf.ServerName = sni
}
// check that all versions are actually supported
if config != nil {

View file

@ -133,79 +133,6 @@ var _ = Describe("Client", func() {
Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890")))
})
It("uses the tls.Config.ServerName as the hostname, if present", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
manager.EXPECT().Destroy()
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
hostnameChan := make(chan string, 1)
newClientConnection = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
tlsConf *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
hostnameChan <- tlsConf.ServerName
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run()
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
tlsConf.ServerName = "foobar"
_, err := DialAddr("localhost:17890", tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
Eventually(hostnameChan).Should(Receive(Equal("foobar")))
})
It("allows passing host without port as server name", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
hostnameChan := make(chan string, 1)
newClientConnection = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
tlsConf *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
hostnameChan <- tlsConf.ServerName
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().run()
return conn
}
tracer.EXPECT().StartedConnection(packetConn.LocalAddr(), addr, gomock.Any(), gomock.Any())
_, err := Dial(
packetConn,
addr,
"test.com",
tlsConf,
config,
)
Expect(err).ToNot(HaveOccurred())
Eventually(hostnameChan).Should(Receive(Equal("test.com")))
})
It("returns after the handshake is complete", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
@ -236,13 +163,7 @@ var _ = Describe("Client", func() {
return conn
}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
s, err := Dial(
packetConn,
addr,
"localhost:1337",
tlsConf,
config,
)
s, err := Dial(packetConn, addr, tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
Eventually(run).Should(BeClosed())
@ -282,13 +203,7 @@ var _ = Describe("Client", func() {
defer GinkgoRecover()
defer close(done)
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
s, err := DialEarly(
packetConn,
addr,
"localhost:1337",
tlsConf,
config,
)
s, err := DialEarly(packetConn, addr, tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
}()
@ -324,13 +239,7 @@ var _ = Describe("Client", func() {
return conn
}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := Dial(
packetConn,
addr,
"localhost:1337",
tlsConf,
config,
)
_, err := Dial(packetConn, addr, tlsConf, config)
Expect(err).To(MatchError(testErr))
})
@ -368,14 +277,7 @@ var _ = Describe("Client", func() {
go func() {
defer GinkgoRecover()
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := DialContext(
ctx,
packetConn,
addr,
"localhost:1337",
tlsConf,
config,
)
_, err := DialContext(ctx, packetConn, addr, tlsConf, config)
Expect(err).To(MatchError(context.Canceled))
close(dialed)
}()
@ -468,7 +370,7 @@ var _ = Describe("Client", func() {
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
version := protocol.VersionNumber(0x1234)
_, err := Dial(packetConn, nil, "localhost:1234", tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
_, err := Dial(packetConn, nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
})
@ -540,7 +442,7 @@ var _ = Describe("Client", func() {
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
_, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config)
_, err := Dial(packetConn, addr, tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Eventually(c).Should(BeClosed())
Expect(cconn.(*spconn).PacketConn).To(Equal(packetConn))

View file

@ -273,6 +273,6 @@ func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCf
if err != nil {
return nil, err
}
return quicDialer(ctx, r.udpConn, udpAddr, addr, tlsCfg, cfg)
return quicDialer(ctx, r.udpConn, udpAddr, tlsCfg, cfg)
}
}

View file

@ -322,7 +322,7 @@ var _ = Describe("RoundTripper", func() {
})
Context("reusing udpconn", func() {
var originalDialer func(ctx context.Context, pconn net.PacketConn, remoteAddr net.Addr, host string, tlsConf *tls.Config, config *quic.Config) (quic.EarlyConnection, error)
var originalDialer func(ctx context.Context, pconn net.PacketConn, remoteAddr net.Addr, tlsConf *tls.Config, config *quic.Config) (quic.EarlyConnection, error)
var req1, req2 *http.Request
BeforeEach(func() {
@ -356,7 +356,7 @@ var _ = Describe("RoundTripper", func() {
It("reuses udpconn in different hosts", func() {
Expect(rt.udpConn).To(BeNil())
quicDialer = func(_ context.Context, pconn net.PacketConn, _ net.Addr, _ string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
quicDialer = func(_ context.Context, pconn net.PacketConn, _ net.Addr, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
conn := mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().LocalAddr().Return(pconn.LocalAddr())
return conn, nil

View file

@ -36,7 +36,7 @@ func BenchmarkHandshake(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
c, err := quic.Dial(conn, ln.Addr(), "localhost", tlsClientConfig, nil)
c, err := quic.Dial(conn, ln.Addr(), tlsClientConfig, nil)
if err != nil {
b.Fatal(err)
}

View file

@ -107,7 +107,6 @@ var _ = Describe("Datagram test", func() {
conn, err := quic.Dial(
clientConn,
raddr,
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
@ -143,7 +142,6 @@ var _ = Describe("Datagram test", func() {
conn, err := quic.Dial(
clientConn,
raddr,
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
@ -161,7 +159,6 @@ var _ = Describe("Datagram test", func() {
conn, err := quic.Dial(
clientConn,
raddr,
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
getTLSClientConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)

View file

@ -152,11 +152,12 @@ var _ = Describe("Handshake tests", func() {
runServer(getTLSConfig())
conn, err := net.ListenUDP("udp", nil)
Expect(err).ToNot(HaveOccurred())
conf := getTLSClientConfig()
conf.ServerName = "foo.bar"
_, err = quic.Dial(
conn,
server.Addr(),
"foo.bar",
getTLSClientConfig(),
conf,
getQuicConfig(nil),
)
Expect(err).To(HaveOccurred())
@ -222,13 +223,7 @@ var _ = Describe("Handshake tests", func() {
remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
raddr, err := net.ResolveUDPAddr("udp", remoteAddr)
Expect(err).ToNot(HaveOccurred())
return quic.Dial(
pconn,
raddr,
remoteAddr,
getTLSClientConfig(),
nil,
)
return quic.Dial(pconn, raddr, getTLSClientConfig(), nil)
}
BeforeEach(func() {

View file

@ -2,7 +2,6 @@ package self_test
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
@ -12,7 +11,6 @@ import (
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/quic-go/internal/testdata"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@ -84,17 +82,14 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
server1 = &http3.Server{
Handler: mux1,
TLSConfig: testdata.GetTLSConfig(),
QuicConfig: getQuicConfig(nil),
}
server2 = &http3.Server{
Handler: mux2,
TLSConfig: testdata.GetTLSConfig(),
QuicConfig: getQuicConfig(nil),
}
tlsConf := http3.ConfigureTLSConfig(testdata.GetTLSConfig())
tlsConf := http3.ConfigureTLSConfig(getTLSConfig())
quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(nil))
ln = &listenerWrapper{EarlyListener: quicln}
Expect(err).NotTo(HaveOccurred())
@ -108,9 +103,7 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
BeforeEach(func() {
client = &http.Client{
Transport: &http3.RoundTripper{
TLSClientConfig: &tls.Config{
RootCAs: testdata.GetRootCA(),
},
TLSClientConfig: getTLSClientConfig(),
DisableCompression: true,
QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
},

View file

@ -5,7 +5,6 @@ import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
@ -15,11 +14,11 @@ import (
"strconv"
"time"
"golang.org/x/sync/errgroup"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/testdata"
"golang.org/x/sync/errgroup"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@ -80,7 +79,7 @@ var _ = Describe("HTTP tests", func() {
server = &http3.Server{
Handler: mux,
TLSConfig: testdata.GetTLSConfig(),
TLSConfig: getTLSConfig(),
QuicConfig: getQuicConfig(nil),
}
@ -107,9 +106,7 @@ var _ = Describe("HTTP tests", func() {
BeforeEach(func() {
client = &http.Client{
Transport: &http3.RoundTripper{
TLSClientConfig: &tls.Config{
RootCAs: testdata.GetRootCA(),
},
TLSClientConfig: getTLSClientConfig(),
DisableCompression: true,
QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
},
@ -381,7 +378,7 @@ var _ = Describe("HTTP tests", func() {
if version == protocol.VersionDraft29 {
Skip("This test only works on RFC versions")
}
tlsConf := testdata.GetTLSConfig()
tlsConf := getTLSConfig()
tlsConf.NextProtos = []string{"h3"}
ln, err := quic.ListenAddr("localhost:0", tlsConf, nil)
Expect(err).ToNot(HaveOccurred())

View file

@ -149,7 +149,6 @@ var _ = Describe("MITM test", func() {
conn, err := quic.Dial(
clientUDPConn,
raddr,
fmt.Sprintf("localhost:%d", proxyPort),
getTLSClientConfig(),
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
)
@ -193,7 +192,6 @@ var _ = Describe("MITM test", func() {
conn, err := quic.Dial(
clientUDPConn,
raddr,
fmt.Sprintf("localhost:%d", proxyPort),
getTLSClientConfig(),
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
)
@ -308,7 +306,6 @@ var _ = Describe("MITM test", func() {
_, err = quic.Dial(
clientUDPConn,
raddr,
fmt.Sprintf("localhost:%d", proxyPort),
getTLSClientConfig(),
getQuicConfig(&quic.Config{
ConnectionIDLength: connIDLen,

View file

@ -2,7 +2,6 @@ package self_test
import (
"context"
"fmt"
"io"
"net"
"runtime"
@ -39,7 +38,6 @@ var _ = Describe("Multiplexing", func() {
conn, err := quic.Dial(
pconn,
addr,
fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)

View file

@ -127,6 +127,7 @@ func init() {
root := x509.NewCertPool()
root.AddCert(ca)
tlsClientConfig = &tls.Config{
ServerName: "localhost",
RootCAs: root,
NextProtos: []string{alpn},
}

View file

@ -486,7 +486,6 @@ var _ = Describe("Timeout tests", func() {
conn, err := quic.Dial(
&faultyConn{PacketConn: conn, MaxPackets: maxPackets},
ln.Addr(),
"localhost",
getTLSClientConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
)

View file

@ -44,6 +44,7 @@ func init() {
root := x509.NewCertPool()
root.AddCert(ca)
tlsClientConfig = &tls.Config{
ServerName: "localhost",
RootCAs: root,
NextProtos: []string{tools.ALPN},
}