diff --git a/client.go b/client.go index c66a155c..b7d7fd67 100644 --- a/client.go +++ b/client.go @@ -47,12 +47,13 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config return nil, err } + clientConfig := populateClientConfig(config) c := &client{ conn: &conn{pconn: pconn, currentAddr: remoteAddr}, connectionID: connID, hostname: hostname, - config: config, - version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default + config: clientConfig, + version: clientConfig.Versions[0], } c.connStateChangeOrErrCond.L = &c.mutex @@ -67,6 +68,19 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config return c.establishConnection() } +func populateClientConfig(config *Config) *Config { + versions := config.Versions + if len(versions) == 0 { + versions = protocol.SupportedVersions + } + + return &Config{ + TLSConfig: config.TLSConfig, + ConnState: config.ConnState, + Versions: versions, + } +} + // DialAddr establishes a new QUIC connection to a server. // The hostname for SNI is taken from the given address. func DialAddr(addr string, config *Config) (Session, error) { @@ -191,7 +205,7 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error { } } - ok, highestSupportedVersion := protocol.HighestSupportedVersion(hdr.SupportedVersions) + ok, highestSupportedVersion := protocol.HighestSupportedVersion(c.config.Versions, hdr.SupportedVersions) if !ok { return qerr.InvalidVersion } diff --git a/client_test.go b/client_test.go index 8571239e..fbbd27ab 100644 --- a/client_test.go +++ b/client_test.go @@ -34,6 +34,7 @@ var _ = Describe("Client", func() { versionNegotiateConnStateCalled = true } }, + Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78}, } addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} sess = &mockSession{connectionID: 0x1337} @@ -41,7 +42,7 @@ var _ = Describe("Client", func() { config: config, connectionID: 0x1337, session: sess, - version: protocol.Version36, + version: protocol.SupportedVersions[0], conn: &conn{pconn: packetConn, currentAddr: addr}, } }) @@ -56,7 +57,6 @@ var _ = Describe("Client", func() { Context("Dialing", func() { It("creates a new client", func() { packetConn.dataToRead = []byte{0x0, 0x1, 0x0} - var err error sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) Expect(err).ToNot(HaveOccurred()) Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil()) @@ -64,6 +64,11 @@ var _ = Describe("Client", func() { sess.Close(nil) }) + It("uses all supported versions, if none are specified in the quic.Config", func() { + c := populateClientConfig(&Config{}) + Expect(c.Versions).To(Equal(protocol.SupportedVersions)) + }) + It("errors when receiving an invalid first packet from the server", func() { packetConn.dataToRead = []byte{0xff} sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) @@ -213,11 +218,13 @@ var _ = Describe("Client", func() { }) It("changes the version after receiving a version negotiation packet", func() { - newVersion := protocol.Version35 + newVersion := protocol.VersionNumber(77) + Expect(config.Versions).To(ContainElement(newVersion)) Expect(newVersion).ToNot(Equal(cl.version)) Expect(sess.packetCount).To(BeZero()) cl.connectionID = 0x1337 err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) + Expect(err).ToNot(HaveOccurred()) Expect(cl.version).To(Equal(newVersion)) Expect(cl.connState).To(Equal(ConnStateVersionNegotiated)) Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue()) @@ -229,7 +236,7 @@ var _ = Describe("Client", func() { Expect(sess.packetCount).To(BeZero()) // if the version negotiation packet was passed to the new session, it would end up as an undecryptable packet there Expect(cl.session.(*session).undecryptablePackets).To(BeEmpty()) - Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35})) + Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{newVersion})) }) It("errors if no matching version is found", func() { @@ -237,6 +244,20 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError(qerr.InvalidVersion)) }) + It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() { + v := protocol.SupportedVersions[1] + Expect(v).ToNot(Equal(cl.version)) + Expect(config.Versions).ToNot(ContainElement(v)) + err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{v})) + Expect(err).To(MatchError(qerr.InvalidVersion)) + }) + + It("changes to the version preferred by the quic.Config", func() { + err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{config.Versions[2], config.Versions[1]})) + Expect(err).ToNot(HaveOccurred()) + Expect(cl.version).To(Equal(config.Versions[1])) + }) + It("ignores delayed version negotiation packets", func() { // if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test cl.connState = ConnStateVersionNegotiated diff --git a/protocol/version.go b/protocol/version.go index 606500db..cbac32d7 100644 --- a/protocol/version.go +++ b/protocol/version.go @@ -8,8 +8,8 @@ const ( Version35 VersionNumber = 35 + iota Version36 Version37 - VersionWhatever = 0 // for when the version doesn't matter - VersionUnsupported = -1 + VersionWhatever VersionNumber = 0 // for when the version doesn't matter + VersionUnsupported VersionNumber = -1 ) // SupportedVersions lists the versions that the server supports @@ -40,16 +40,14 @@ func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { } // HighestSupportedVersion finds the highest version number that is both present in other and in SupportedVersions -// the versions in other do not need to be ordered // it returns true and the version number, if there is one, otherwise false -func HighestSupportedVersion(other []VersionNumber) (bool, VersionNumber) { - for _, v := range SupportedVersions { - for _, ver := range other { - if ver == v { - return true, ver +func HighestSupportedVersion(ours, theirs []VersionNumber) (bool, VersionNumber) { + for _, ourVer := range ours { + for _, theirVer := range theirs { + if ourVer == theirVer { + return true, ourVer } } } - - return false, 0 + return false, VersionUnsupported } diff --git a/protocol/version_test.go b/protocol/version_test.go index d37784f0..b92a061f 100644 --- a/protocol/version_test.go +++ b/protocol/version_test.go @@ -27,39 +27,33 @@ var _ = Describe("Version", func() { }) Context("highest supported version", func() { - var initialSupportedVersions []VersionNumber - - BeforeEach(func() { - initialSupportedVersions = make([]VersionNumber, len(SupportedVersions)) - copy(initialSupportedVersions, SupportedVersions) - }) - - AfterEach(func() { - SupportedVersions = initialSupportedVersions - }) - It("finds the supported version", func() { - SupportedVersions = []VersionNumber{3, 2, 1} - other := []VersionNumber{3, 4, 5, 6} - found, ver := HighestSupportedVersion(other) + supportedVersions := []VersionNumber{1, 2, 3} + other := []VersionNumber{6, 5, 4, 3} + found, ver := HighestSupportedVersion(supportedVersions, other) Expect(found).To(BeTrue()) Expect(ver).To(Equal(VersionNumber(3))) }) - It("picks the highest supported version", func() { - SupportedVersions = []VersionNumber{7, 6, 3, 2, 1} + It("picks the preferred version", func() { + supportedVersions := []VersionNumber{2, 1, 3} other := []VersionNumber{3, 6, 1, 8, 2, 10} - found, ver := HighestSupportedVersion(other) + found, ver := HighestSupportedVersion(supportedVersions, other) Expect(found).To(BeTrue()) - Expect(ver).To(Equal(VersionNumber(6))) + Expect(ver).To(Equal(VersionNumber(2))) }) It("handles empty inputs", func() { - SupportedVersions = []VersionNumber{102, 101} - Expect(HighestSupportedVersion([]VersionNumber{})).To(BeFalse()) - SupportedVersions = []VersionNumber{} - Expect(HighestSupportedVersion([]VersionNumber{1, 2})).To(BeFalse()) - Expect(HighestSupportedVersion([]VersionNumber{})).To(BeFalse()) + supportedVersions := []VersionNumber{102, 101} + found, _ := HighestSupportedVersion(supportedVersions, nil) + Expect(found).To(BeFalse()) + found, _ = HighestSupportedVersion(supportedVersions, []VersionNumber{}) + Expect(found).To(BeFalse()) + supportedVersions = []VersionNumber{} + found, _ = HighestSupportedVersion(supportedVersions, []VersionNumber{1, 2}) + Expect(found).To(BeFalse()) + found, _ = HighestSupportedVersion(supportedVersions, []VersionNumber{}) + Expect(found).To(BeFalse()) }) }) }) diff --git a/server.go b/server.go index 7b60bd4d..ca5192ec 100644 --- a/server.go +++ b/server.go @@ -68,7 +68,7 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) { return &server{ conn: conn, - config: populateConfig(config), + config: populateServerConfig(config), certChain: certChain, scfg: scfg, sessions: map[protocol.ConnectionID]packetHandler{}, @@ -77,7 +77,7 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) { }, nil } -func populateConfig(config *Config) *Config { +func populateServerConfig(config *Config) *Config { versions := config.Versions if len(versions) == 0 { versions = protocol.SupportedVersions