reject unknown versions in the quic.Config

This commit is contained in:
Marten Seemann 2018-01-05 13:38:27 +07:00
parent cd01b55090
commit 624ac61412
6 changed files with 76 additions and 33 deletions

View file

@ -85,6 +85,14 @@ func Dial(
} }
} }
// check that all versions are actually supported
if config != nil {
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
}
clientConfig := populateClientConfig(config) clientConfig := populateClientConfig(config)
c := &client{ c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr}, conn: &conn{pconn: pconn, currentAddr: remoteAddr},

View file

@ -21,7 +21,6 @@ import (
var _ = Describe("Client", func() { var _ = Describe("Client", func() {
var ( var (
cl *client cl *client
config *Config
sess *mockSession sess *mockSession
packetConn *mockPacketConn packetConn *mockPacketConn
addr net.Addr addr net.Addr
@ -50,11 +49,7 @@ var _ = Describe("Client", func() {
packetConn = newMockPacketConn() packetConn = newMockPacketConn()
packetConn.addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} packetConn.addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
packetConn.dataReadFrom = addr packetConn.dataReadFrom = addr
config = &Config{
Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78},
}
cl = &client{ cl = &client{
config: config,
connectionID: 0x1337, connectionID: 0x1337,
session: sess, session: sess,
version: protocol.SupportedVersions[0], version: protocol.SupportedVersions[0],
@ -106,7 +101,7 @@ var _ = Describe("Client", func() {
dialed := make(chan struct{}) dialed := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
s, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) s, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil()) Expect(s).ToNot(BeNil())
close(dialed) close(dialed)
@ -179,7 +174,7 @@ var _ = Describe("Client", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil)
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
close(done) close(done)
}() }()
@ -193,7 +188,7 @@ var _ = Describe("Client", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil)
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
close(done) close(done)
}() }()
@ -217,6 +212,12 @@ var _ = Describe("Client", func() {
Expect(c.MaxIncomingUniStreams).To(Equal(4321)) Expect(c.MaxIncomingUniStreams).To(Equal(4321))
}) })
It("errors when the Config contains an invalid version", func() {
version := protocol.VersionNumber(0x1234)
_, err := Dial(nil, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}})
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
})
It("disables bidirectional streams", func() { It("disables bidirectional streams", func() {
config := &Config{ config := &Config{
MaxIncomingStreams: -1, MaxIncomingStreams: -1,
@ -248,7 +249,7 @@ var _ = Describe("Client", func() {
It("errors when receiving an error from the connection", func() { It("errors when receiving an error from the connection", func() {
testErr := errors.New("connection error") testErr := errors.New("connection error")
packetConn.readErr = testErr packetConn.readErr = testErr
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil)
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
@ -266,11 +267,22 @@ var _ = Describe("Client", func() {
) (packetHandler, error) { ) (packetHandler, error) {
return nil, testErr return nil, testErr
} }
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil)
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
Context("version negotiation", func() { Context("version negotiation", func() {
var origSupportedVersions []protocol.VersionNumber
BeforeEach(func() {
origSupportedVersions = protocol.SupportedVersions
protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{77, 78}...)
})
AfterEach(func() {
protocol.SupportedVersions = origSupportedVersions
})
It("recognizes that a packet without VersionFlag means that the server accepted the suggested version", func() { It("recognizes that a packet without VersionFlag means that the server accepted the suggested version", func() {
ph := wire.Header{ ph := wire.Header{
PacketNumber: 1, PacketNumber: 1,
@ -290,7 +302,7 @@ var _ = Describe("Client", func() {
var negotiatedVersions []protocol.VersionNumber var negotiatedVersions []protocol.VersionNumber
newVersion := protocol.VersionNumber(77) newVersion := protocol.VersionNumber(77)
Expect(newVersion).ToNot(Equal(cl.version)) Expect(newVersion).ToNot(Equal(cl.version))
Expect(config.Versions).To(ContainElement(newVersion)) cl.config = &Config{Versions: []protocol.VersionNumber{newVersion}}
sessionChan := make(chan *mockSession) sessionChan := make(chan *mockSession)
handshakeChan := make(chan error) handshakeChan := make(chan error)
newClientSession = func( newClientSession = func(
@ -367,36 +379,34 @@ var _ = Describe("Client", func() {
} }
go cl.dial() go cl.dial()
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1)) Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1))
newVersion := protocol.VersionNumber(77) cl.config = &Config{Versions: []protocol.VersionNumber{77, 78}}
Expect(newVersion).ToNot(Equal(cl.version)) cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{77}))
Expect(config.Versions).To(ContainElement(newVersion))
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2)) Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
newVersion = protocol.VersionNumber(78) cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{78}))
Expect(newVersion).ToNot(Equal(cl.version))
Expect(config.Versions).To(ContainElement(newVersion))
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
Consistently(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2)) Consistently(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
}) })
It("errors if no matching version is found", func() { It("errors if no matching version is found", func() {
cl.config = &Config{Versions: protocol.SupportedVersions}
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{1})) cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{1}))
Expect(cl.session.(*mockSession).closed).To(BeTrue()) Expect(cl.session.(*mockSession).closed).To(BeTrue())
Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion))
}) })
It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() { It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
v := protocol.VersionNumber(111) v := protocol.VersionNumber(1234)
Expect(v).ToNot(Equal(cl.version)) Expect(v).ToNot(Equal(cl.version))
Expect(config.Versions).ToNot(ContainElement(v)) cl.config = &Config{Versions: protocol.SupportedVersions}
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{v})) cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{v}))
Expect(cl.session.(*mockSession).closed).To(BeTrue()) Expect(cl.session.(*mockSession).closed).To(BeTrue())
Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion))
}) })
It("changes to the version preferred by the quic.Config", func() { It("changes to the version preferred by the quic.Config", func() {
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{config.Versions[2], config.Versions[1]})) config := &Config{Versions: []protocol.VersionNumber{1234, 4321}}
Expect(cl.version).To(Equal(config.Versions[1])) cl.config = config
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{4321, 1234}))
Expect(cl.version).To(Equal(protocol.VersionNumber(1234)))
}) })
It("ignores delayed version negotiation packets", func() { It("ignores delayed version negotiation packets", func() {
@ -423,7 +433,7 @@ var _ = Describe("Client", func() {
}) })
It("ignores packets without connection id, if it didn't request connection id trunctation", func() { It("ignores packets without connection id, if it didn't request connection id trunctation", func() {
cl.config.RequestConnectionIDOmission = false cl.config = &Config{RequestConnectionIDOmission: false}
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
(&wire.Header{ (&wire.Header{
OmitConnectionID: true, OmitConnectionID: true,
@ -448,6 +458,7 @@ var _ = Describe("Client", func() {
}) })
It("creates new GQUIC sessions with the right parameters", func() { It("creates new GQUIC sessions with the right parameters", func() {
config := &Config{Versions: protocol.SupportedVersions}
closeErr := errors.New("peer doesn't reply") closeErr := errors.New("peer doesn't reply")
c := make(chan struct{}) c := make(chan struct{})
var cconn connection var cconn connection
@ -488,7 +499,7 @@ var _ = Describe("Client", func() {
}) })
It("creates new TLS sessions with the right parameters", func() { It("creates new TLS sessions with the right parameters", func() {
config.Versions = []protocol.VersionNumber{protocol.VersionTLS} config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
c := make(chan struct{}) c := make(chan struct{})
var cconn connection var cconn connection
var hostname string var hostname string
@ -527,7 +538,8 @@ var _ = Describe("Client", func() {
}) })
It("creates a new session when the server performs a retry", func() { It("creates a new session when the server performs a retry", func() {
config.Versions = []protocol.VersionNumber{protocol.VersionTLS} config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
cl.config = config
sessionChan := make(chan *mockSession) sessionChan := make(chan *mockSession)
newTLSClientSession = func( newTLSClientSession = func(
connP connection, connP connection,

View file

@ -30,6 +30,11 @@ var SupportedVersions = []VersionNumber{
Version39, Version39,
} }
// IsValidVersion says if the version is known to quic-go
func IsValidVersion(v VersionNumber) bool {
return v == VersionTLS || IsSupportedVersion(SupportedVersions, v)
}
// UsesTLS says if this QUIC version uses TLS 1.3 for the handshake // UsesTLS says if this QUIC version uses TLS 1.3 for the handshake
func (vn VersionNumber) UsesTLS() bool { func (vn VersionNumber) UsesTLS() bool {
return vn == VersionTLS return vn == VersionTLS

View file

@ -15,6 +15,14 @@ var _ = Describe("Version", func() {
Expect(Version39).To(BeEquivalentTo(0x51303339)) Expect(Version39).To(BeEquivalentTo(0x51303339))
}) })
It("says if a version is valid", func() {
Expect(IsValidVersion(Version39)).To(BeTrue())
Expect(IsValidVersion(VersionTLS)).To(BeTrue())
Expect(IsValidVersion(VersionWhatever)).To(BeFalse())
Expect(IsValidVersion(VersionUnknown)).To(BeFalse())
Expect(IsValidVersion(1234)).To(BeFalse())
})
It("says if a version supports TLS", func() { It("says if a version supports TLS", func() {
Expect(Version39.UsesTLS()).To(BeFalse()) Expect(Version39.UsesTLS()).To(BeFalse())
Expect(VersionTLS.UsesTLS()).To(BeTrue()) Expect(VersionTLS.UsesTLS()).To(BeTrue())

View file

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
"time" "time"
@ -85,9 +86,12 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
} }
config = populateServerConfig(config) config = populateServerConfig(config)
// check if any of the supported versions supports TLS
var supportsTLS bool var supportsTLS bool
for _, v := range config.Versions { for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
// check if any of the supported versions supports TLS
if v.UsesTLS() { if v.UsesTLS() {
supportsTLS = true supportsTLS = true
break break

View file

@ -446,7 +446,7 @@ var _ = Describe("Server", func() {
}) })
It("setups with the right values", func() { It("setups with the right values", func() {
supportedVersions := []protocol.VersionNumber{1, 3, 5} supportedVersions := []protocol.VersionNumber{protocol.VersionTLS, protocol.Version39}
acceptCookie := func(_ net.Addr, _ *Cookie) bool { return true } acceptCookie := func(_ net.Addr, _ *Cookie) bool { return true }
config := Config{ config := Config{
Versions: supportedVersions, Versions: supportedVersions,
@ -468,6 +468,12 @@ var _ = Describe("Server", func() {
Expect(server.config.KeepAlive).To(BeTrue()) Expect(server.config.KeepAlive).To(BeTrue())
}) })
It("errors when the Config contains an invalid version", func() {
version := protocol.VersionNumber(0x1234)
_, err := Listen(conn, &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}})
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
})
It("fills in default values if options are not set in the Config", func() { It("fills in default values if options are not set in the Config", func() {
ln, err := Listen(conn, &tls.Config{}, &Config{}) ln, err := Listen(conn, &tls.Config{}, &Config{})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -500,7 +506,6 @@ var _ = Describe("Server", func() {
}) })
It("sends a gQUIC Version Negotaion Packet, if the client sent a gQUIC Public Header", func() { It("sends a gQUIC Version Negotaion Packet, if the client sent a gQUIC Public Header", func() {
config.Versions = []protocol.VersionNumber{99}
b := &bytes.Buffer{} b := &bytes.Buffer{}
hdr := wire.Header{ hdr := wire.Header{
VersionFlag: true, VersionFlag: true,
@ -537,7 +542,7 @@ var _ = Describe("Server", func() {
}) })
It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() { It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() {
config.Versions = []protocol.VersionNumber{99, protocol.VersionTLS} config.Versions = append(config.Versions, protocol.VersionTLS)
b := &bytes.Buffer{} b := &bytes.Buffer{}
hdr := wire.Header{ hdr := wire.Header{
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
@ -556,6 +561,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover()
ln.Accept() ln.Accept()
close(done) close(done)
}() }()
@ -569,12 +575,12 @@ var _ = Describe("Server", func() {
Expect(packet.ConnectionID).To(Equal(protocol.ConnectionID(0x1337))) Expect(packet.ConnectionID).To(Equal(protocol.ConnectionID(0x1337)))
Expect(r.Len()).To(BeZero()) Expect(r.Len()).To(BeZero())
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
// make the go routine return
ln.Close()
Eventually(done).Should(BeClosed())
}) })
It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() { It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() {
version := protocol.VersionNumber(99)
Expect(version.UsesTLS()).To(BeFalse())
config.Versions = []protocol.VersionNumber{version}
b := &bytes.Buffer{} b := &bytes.Buffer{}
hdr := wire.Header{ hdr := wire.Header{
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,