From 624ac614128cd7eb888229e13ffe9a69b4be8cab Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 5 Jan 2018 13:38:27 +0700 Subject: [PATCH] reject unknown versions in the quic.Config --- client.go | 8 ++++ client_test.go | 64 ++++++++++++++++++------------- internal/protocol/version.go | 5 +++ internal/protocol/version_test.go | 8 ++++ server.go | 6 ++- server_test.go | 18 ++++++--- 6 files changed, 76 insertions(+), 33 deletions(-) diff --git a/client.go b/client.go index 610cfee2..460ddd21 100644 --- a/client.go +++ b/client.go @@ -85,6 +85,14 @@ func Dial( } } + // check that all versions are actually supported + if config != nil { + for _, v := range config.Versions { + if !protocol.IsValidVersion(v) { + return nil, fmt.Errorf("%s is not a valid QUIC version", v) + } + } + } clientConfig := populateClientConfig(config) c := &client{ conn: &conn{pconn: pconn, currentAddr: remoteAddr}, diff --git a/client_test.go b/client_test.go index 2c2f88bd..787bb0e3 100644 --- a/client_test.go +++ b/client_test.go @@ -21,7 +21,6 @@ import ( var _ = Describe("Client", func() { var ( cl *client - config *Config sess *mockSession packetConn *mockPacketConn addr net.Addr @@ -50,11 +49,7 @@ var _ = Describe("Client", func() { packetConn = newMockPacketConn() packetConn.addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} packetConn.dataReadFrom = addr - config = &Config{ - Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78}, - } cl = &client{ - config: config, connectionID: 0x1337, session: sess, version: protocol.SupportedVersions[0], @@ -106,7 +101,7 @@ var _ = Describe("Client", func() { dialed := make(chan struct{}) go func() { defer GinkgoRecover() - s, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + s, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) close(dialed) @@ -179,7 +174,7 @@ var _ = Describe("Client", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) Expect(err).To(MatchError(testErr)) close(done) }() @@ -193,7 +188,7 @@ var _ = Describe("Client", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) Expect(err).To(MatchError(testErr)) close(done) }() @@ -217,6 +212,12 @@ var _ = Describe("Client", func() { Expect(c.MaxIncomingUniStreams).To(Equal(4321)) }) + It("errors when the Config contains an invalid version", func() { + version := protocol.VersionNumber(0x1234) + _, err := Dial(nil, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}}) + Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) + }) + It("disables bidirectional streams", func() { config := &Config{ MaxIncomingStreams: -1, @@ -248,7 +249,7 @@ var _ = Describe("Client", func() { It("errors when receiving an error from the connection", func() { testErr := errors.New("connection error") packetConn.readErr = testErr - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) Expect(err).To(MatchError(testErr)) }) @@ -266,11 +267,22 @@ var _ = Describe("Client", func() { ) (packetHandler, error) { return nil, testErr } - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) Expect(err).To(MatchError(testErr)) }) Context("version negotiation", func() { + var origSupportedVersions []protocol.VersionNumber + + BeforeEach(func() { + origSupportedVersions = protocol.SupportedVersions + protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{77, 78}...) + }) + + AfterEach(func() { + protocol.SupportedVersions = origSupportedVersions + }) + It("recognizes that a packet without VersionFlag means that the server accepted the suggested version", func() { ph := wire.Header{ PacketNumber: 1, @@ -290,7 +302,7 @@ var _ = Describe("Client", func() { var negotiatedVersions []protocol.VersionNumber newVersion := protocol.VersionNumber(77) Expect(newVersion).ToNot(Equal(cl.version)) - Expect(config.Versions).To(ContainElement(newVersion)) + cl.config = &Config{Versions: []protocol.VersionNumber{newVersion}} sessionChan := make(chan *mockSession) handshakeChan := make(chan error) newClientSession = func( @@ -367,36 +379,34 @@ var _ = Describe("Client", func() { } go cl.dial() Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1)) - newVersion := protocol.VersionNumber(77) - Expect(newVersion).ToNot(Equal(cl.version)) - Expect(config.Versions).To(ContainElement(newVersion)) - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) + cl.config = &Config{Versions: []protocol.VersionNumber{77, 78}} + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{77})) Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2)) - newVersion = protocol.VersionNumber(78) - Expect(newVersion).ToNot(Equal(cl.version)) - Expect(config.Versions).To(ContainElement(newVersion)) - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{78})) Consistently(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2)) }) It("errors if no matching version is found", func() { + cl.config = &Config{Versions: protocol.SupportedVersions} cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{1})) Expect(cl.session.(*mockSession).closed).To(BeTrue()) Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) }) It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() { - v := protocol.VersionNumber(111) + v := protocol.VersionNumber(1234) Expect(v).ToNot(Equal(cl.version)) - Expect(config.Versions).ToNot(ContainElement(v)) + cl.config = &Config{Versions: protocol.SupportedVersions} cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{v})) Expect(cl.session.(*mockSession).closed).To(BeTrue()) Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) }) It("changes to the version preferred by the quic.Config", func() { - cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{config.Versions[2], config.Versions[1]})) - Expect(cl.version).To(Equal(config.Versions[1])) + config := &Config{Versions: []protocol.VersionNumber{1234, 4321}} + cl.config = config + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{4321, 1234})) + Expect(cl.version).To(Equal(protocol.VersionNumber(1234))) }) It("ignores delayed version negotiation packets", func() { @@ -423,7 +433,7 @@ var _ = Describe("Client", func() { }) It("ignores packets without connection id, if it didn't request connection id trunctation", func() { - cl.config.RequestConnectionIDOmission = false + cl.config = &Config{RequestConnectionIDOmission: false} buf := &bytes.Buffer{} (&wire.Header{ OmitConnectionID: true, @@ -448,6 +458,7 @@ var _ = Describe("Client", func() { }) It("creates new GQUIC sessions with the right parameters", func() { + config := &Config{Versions: protocol.SupportedVersions} closeErr := errors.New("peer doesn't reply") c := make(chan struct{}) var cconn connection @@ -488,7 +499,7 @@ var _ = Describe("Client", func() { }) It("creates new TLS sessions with the right parameters", func() { - config.Versions = []protocol.VersionNumber{protocol.VersionTLS} + config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} c := make(chan struct{}) var cconn connection var hostname string @@ -527,7 +538,8 @@ var _ = Describe("Client", func() { }) It("creates a new session when the server performs a retry", func() { - config.Versions = []protocol.VersionNumber{protocol.VersionTLS} + config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} + cl.config = config sessionChan := make(chan *mockSession) newTLSClientSession = func( connP connection, diff --git a/internal/protocol/version.go b/internal/protocol/version.go index 90641165..6d965eba 100644 --- a/internal/protocol/version.go +++ b/internal/protocol/version.go @@ -30,6 +30,11 @@ var SupportedVersions = []VersionNumber{ Version39, } +// IsValidVersion says if the version is known to quic-go +func IsValidVersion(v VersionNumber) bool { + return v == VersionTLS || IsSupportedVersion(SupportedVersions, v) +} + // UsesTLS says if this QUIC version uses TLS 1.3 for the handshake func (vn VersionNumber) UsesTLS() bool { return vn == VersionTLS diff --git a/internal/protocol/version_test.go b/internal/protocol/version_test.go index 01de14e3..d47bff19 100644 --- a/internal/protocol/version_test.go +++ b/internal/protocol/version_test.go @@ -15,6 +15,14 @@ var _ = Describe("Version", func() { Expect(Version39).To(BeEquivalentTo(0x51303339)) }) + It("says if a version is valid", func() { + Expect(IsValidVersion(Version39)).To(BeTrue()) + Expect(IsValidVersion(VersionTLS)).To(BeTrue()) + Expect(IsValidVersion(VersionWhatever)).To(BeFalse()) + Expect(IsValidVersion(VersionUnknown)).To(BeFalse()) + Expect(IsValidVersion(1234)).To(BeFalse()) + }) + It("says if a version supports TLS", func() { Expect(Version39.UsesTLS()).To(BeFalse()) Expect(VersionTLS.UsesTLS()).To(BeTrue()) diff --git a/server.go b/server.go index 67154dd2..52241eb6 100644 --- a/server.go +++ b/server.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/tls" "errors" + "fmt" "net" "sync" "time" @@ -85,9 +86,12 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, } config = populateServerConfig(config) - // check if any of the supported versions supports TLS var supportsTLS bool for _, v := range config.Versions { + if !protocol.IsValidVersion(v) { + return nil, fmt.Errorf("%s is not a valid QUIC version", v) + } + // check if any of the supported versions supports TLS if v.UsesTLS() { supportsTLS = true break diff --git a/server_test.go b/server_test.go index 69e7f2e3..78330811 100644 --- a/server_test.go +++ b/server_test.go @@ -446,7 +446,7 @@ var _ = Describe("Server", func() { }) It("setups with the right values", func() { - supportedVersions := []protocol.VersionNumber{1, 3, 5} + supportedVersions := []protocol.VersionNumber{protocol.VersionTLS, protocol.Version39} acceptCookie := func(_ net.Addr, _ *Cookie) bool { return true } config := Config{ Versions: supportedVersions, @@ -468,6 +468,12 @@ var _ = Describe("Server", func() { Expect(server.config.KeepAlive).To(BeTrue()) }) + It("errors when the Config contains an invalid version", func() { + version := protocol.VersionNumber(0x1234) + _, err := Listen(conn, &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}}) + Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) + }) + It("fills in default values if options are not set in the Config", func() { ln, err := Listen(conn, &tls.Config{}, &Config{}) Expect(err).ToNot(HaveOccurred()) @@ -500,7 +506,6 @@ var _ = Describe("Server", func() { }) It("sends a gQUIC Version Negotaion Packet, if the client sent a gQUIC Public Header", func() { - config.Versions = []protocol.VersionNumber{99} b := &bytes.Buffer{} hdr := wire.Header{ VersionFlag: true, @@ -537,7 +542,7 @@ var _ = Describe("Server", func() { }) It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() { - config.Versions = []protocol.VersionNumber{99, protocol.VersionTLS} + config.Versions = append(config.Versions, protocol.VersionTLS) b := &bytes.Buffer{} hdr := wire.Header{ Type: protocol.PacketTypeInitial, @@ -556,6 +561,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) go func() { + defer GinkgoRecover() ln.Accept() close(done) }() @@ -569,12 +575,12 @@ var _ = Describe("Server", func() { Expect(packet.ConnectionID).To(Equal(protocol.ConnectionID(0x1337))) Expect(r.Len()).To(BeZero()) Consistently(done).ShouldNot(BeClosed()) + // make the go routine return + ln.Close() + Eventually(done).Should(BeClosed()) }) It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() { - version := protocol.VersionNumber(99) - Expect(version.UsesTLS()).To(BeFalse()) - config.Versions = []protocol.VersionNumber{version} b := &bytes.Buffer{} hdr := wire.Header{ Type: protocol.PacketTypeInitial,