make it possible to configure the QUIC versions for the client

This commit is contained in:
Marten Seemann 2017-04-29 22:43:04 +07:00
parent 1b70bd42d9
commit 16ca3012e9
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
5 changed files with 69 additions and 42 deletions

View file

@ -47,12 +47,13 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
return nil, err return nil, err
} }
clientConfig := populateClientConfig(config)
c := &client{ c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr}, conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID, connectionID: connID,
hostname: hostname, hostname: hostname,
config: config, config: clientConfig,
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default version: clientConfig.Versions[0],
} }
c.connStateChangeOrErrCond.L = &c.mutex c.connStateChangeOrErrCond.L = &c.mutex
@ -67,6 +68,19 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
return c.establishConnection() return c.establishConnection()
} }
func populateClientConfig(config *Config) *Config {
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
return &Config{
TLSConfig: config.TLSConfig,
ConnState: config.ConnState,
Versions: versions,
}
}
// DialAddr establishes a new QUIC connection to a server. // DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address. // The hostname for SNI is taken from the given address.
func DialAddr(addr string, config *Config) (Session, error) { func DialAddr(addr string, config *Config) (Session, error) {
@ -191,7 +205,7 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
} }
} }
ok, highestSupportedVersion := protocol.HighestSupportedVersion(hdr.SupportedVersions) ok, highestSupportedVersion := protocol.HighestSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok { if !ok {
return qerr.InvalidVersion return qerr.InvalidVersion
} }

View file

@ -34,6 +34,7 @@ var _ = Describe("Client", func() {
versionNegotiateConnStateCalled = true versionNegotiateConnStateCalled = true
} }
}, },
Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78},
} }
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
sess = &mockSession{connectionID: 0x1337} sess = &mockSession{connectionID: 0x1337}
@ -41,7 +42,7 @@ var _ = Describe("Client", func() {
config: config, config: config,
connectionID: 0x1337, connectionID: 0x1337,
session: sess, session: sess,
version: protocol.Version36, version: protocol.SupportedVersions[0],
conn: &conn{pconn: packetConn, currentAddr: addr}, conn: &conn{pconn: packetConn, currentAddr: addr},
} }
}) })
@ -56,7 +57,6 @@ var _ = Describe("Client", func() {
Context("Dialing", func() { Context("Dialing", func() {
It("creates a new client", func() { It("creates a new client", func() {
packetConn.dataToRead = []byte{0x0, 0x1, 0x0} packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
var err error
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil()) Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil())
@ -64,6 +64,11 @@ var _ = Describe("Client", func() {
sess.Close(nil) sess.Close(nil)
}) })
It("uses all supported versions, if none are specified in the quic.Config", func() {
c := populateClientConfig(&Config{})
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
})
It("errors when receiving an invalid first packet from the server", func() { It("errors when receiving an invalid first packet from the server", func() {
packetConn.dataToRead = []byte{0xff} packetConn.dataToRead = []byte{0xff}
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
@ -213,11 +218,13 @@ var _ = Describe("Client", func() {
}) })
It("changes the version after receiving a version negotiation packet", func() { It("changes the version after receiving a version negotiation packet", func() {
newVersion := protocol.Version35 newVersion := protocol.VersionNumber(77)
Expect(config.Versions).To(ContainElement(newVersion))
Expect(newVersion).ToNot(Equal(cl.version)) Expect(newVersion).ToNot(Equal(cl.version))
Expect(sess.packetCount).To(BeZero()) Expect(sess.packetCount).To(BeZero())
cl.connectionID = 0x1337 cl.connectionID = 0x1337
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.version).To(Equal(newVersion)) Expect(cl.version).To(Equal(newVersion))
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated)) Expect(cl.connState).To(Equal(ConnStateVersionNegotiated))
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue()) Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
@ -229,7 +236,7 @@ var _ = Describe("Client", func() {
Expect(sess.packetCount).To(BeZero()) Expect(sess.packetCount).To(BeZero())
// if the version negotiation packet was passed to the new session, it would end up as an undecryptable packet there // if the version negotiation packet was passed to the new session, it would end up as an undecryptable packet there
Expect(cl.session.(*session).undecryptablePackets).To(BeEmpty()) Expect(cl.session.(*session).undecryptablePackets).To(BeEmpty())
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35})) Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{newVersion}))
}) })
It("errors if no matching version is found", func() { It("errors if no matching version is found", func() {
@ -237,6 +244,20 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError(qerr.InvalidVersion)) Expect(err).To(MatchError(qerr.InvalidVersion))
}) })
It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
v := protocol.SupportedVersions[1]
Expect(v).ToNot(Equal(cl.version))
Expect(config.Versions).ToNot(ContainElement(v))
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{v}))
Expect(err).To(MatchError(qerr.InvalidVersion))
})
It("changes to the version preferred by the quic.Config", func() {
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{config.Versions[2], config.Versions[1]}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.version).To(Equal(config.Versions[1]))
})
It("ignores delayed version negotiation packets", func() { It("ignores delayed version negotiation packets", func() {
// if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test // if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test
cl.connState = ConnStateVersionNegotiated cl.connState = ConnStateVersionNegotiated

View file

@ -8,8 +8,8 @@ const (
Version35 VersionNumber = 35 + iota Version35 VersionNumber = 35 + iota
Version36 Version36
Version37 Version37
VersionWhatever = 0 // for when the version doesn't matter VersionWhatever VersionNumber = 0 // for when the version doesn't matter
VersionUnsupported = -1 VersionUnsupported VersionNumber = -1
) )
// SupportedVersions lists the versions that the server supports // SupportedVersions lists the versions that the server supports
@ -40,16 +40,14 @@ func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
} }
// HighestSupportedVersion finds the highest version number that is both present in other and in SupportedVersions // HighestSupportedVersion finds the highest version number that is both present in other and in SupportedVersions
// the versions in other do not need to be ordered
// it returns true and the version number, if there is one, otherwise false // it returns true and the version number, if there is one, otherwise false
func HighestSupportedVersion(other []VersionNumber) (bool, VersionNumber) { func HighestSupportedVersion(ours, theirs []VersionNumber) (bool, VersionNumber) {
for _, v := range SupportedVersions { for _, ourVer := range ours {
for _, ver := range other { for _, theirVer := range theirs {
if ver == v { if ourVer == theirVer {
return true, ver return true, ourVer
} }
} }
} }
return false, VersionUnsupported
return false, 0
} }

View file

@ -27,39 +27,33 @@ var _ = Describe("Version", func() {
}) })
Context("highest supported version", func() { Context("highest supported version", func() {
var initialSupportedVersions []VersionNumber
BeforeEach(func() {
initialSupportedVersions = make([]VersionNumber, len(SupportedVersions))
copy(initialSupportedVersions, SupportedVersions)
})
AfterEach(func() {
SupportedVersions = initialSupportedVersions
})
It("finds the supported version", func() { It("finds the supported version", func() {
SupportedVersions = []VersionNumber{3, 2, 1} supportedVersions := []VersionNumber{1, 2, 3}
other := []VersionNumber{3, 4, 5, 6} other := []VersionNumber{6, 5, 4, 3}
found, ver := HighestSupportedVersion(other) found, ver := HighestSupportedVersion(supportedVersions, other)
Expect(found).To(BeTrue()) Expect(found).To(BeTrue())
Expect(ver).To(Equal(VersionNumber(3))) Expect(ver).To(Equal(VersionNumber(3)))
}) })
It("picks the highest supported version", func() { It("picks the preferred version", func() {
SupportedVersions = []VersionNumber{7, 6, 3, 2, 1} supportedVersions := []VersionNumber{2, 1, 3}
other := []VersionNumber{3, 6, 1, 8, 2, 10} other := []VersionNumber{3, 6, 1, 8, 2, 10}
found, ver := HighestSupportedVersion(other) found, ver := HighestSupportedVersion(supportedVersions, other)
Expect(found).To(BeTrue()) Expect(found).To(BeTrue())
Expect(ver).To(Equal(VersionNumber(6))) Expect(ver).To(Equal(VersionNumber(2)))
}) })
It("handles empty inputs", func() { It("handles empty inputs", func() {
SupportedVersions = []VersionNumber{102, 101} supportedVersions := []VersionNumber{102, 101}
Expect(HighestSupportedVersion([]VersionNumber{})).To(BeFalse()) found, _ := HighestSupportedVersion(supportedVersions, nil)
SupportedVersions = []VersionNumber{} Expect(found).To(BeFalse())
Expect(HighestSupportedVersion([]VersionNumber{1, 2})).To(BeFalse()) found, _ = HighestSupportedVersion(supportedVersions, []VersionNumber{})
Expect(HighestSupportedVersion([]VersionNumber{})).To(BeFalse()) Expect(found).To(BeFalse())
supportedVersions = []VersionNumber{}
found, _ = HighestSupportedVersion(supportedVersions, []VersionNumber{1, 2})
Expect(found).To(BeFalse())
found, _ = HighestSupportedVersion(supportedVersions, []VersionNumber{})
Expect(found).To(BeFalse())
}) })
}) })
}) })

View file

@ -68,7 +68,7 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
return &server{ return &server{
conn: conn, conn: conn,
config: populateConfig(config), config: populateServerConfig(config),
certChain: certChain, certChain: certChain,
scfg: scfg, scfg: scfg,
sessions: map[protocol.ConnectionID]packetHandler{}, sessions: map[protocol.ConnectionID]packetHandler{},
@ -77,7 +77,7 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
}, nil }, nil
} }
func populateConfig(config *Config) *Config { func populateServerConfig(config *Config) *Config {
versions := config.Versions versions := config.Versions
if len(versions) == 0 { if len(versions) == 0 {
versions = protocol.SupportedVersions versions = protocol.SupportedVersions