From ae5a8bd35ce35470b4609caf2dedc9b00dfc7cee Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 31 Mar 2023 19:35:03 +0900 Subject: [PATCH] move the QUIC version check to the config validation --- client.go | 10 ---------- client_test.go | 5 +---- config.go | 7 +++++++ server.go | 5 ----- server_test.go | 2 +- 5 files changed, 9 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index 62cad580..ad80d4f2 100644 --- a/client.go +++ b/client.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "errors" - "fmt" "net" "github.com/quic-go/quic-go/internal/protocol" @@ -136,15 +135,6 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon tlsConf = tlsConf.Clone() } - // check that all versions are actually supported - if config != nil { - for _, v := range config.Versions { - if !protocol.IsValidVersion(v) { - return nil, fmt.Errorf("%s is not a valid QUIC version", v) - } - } - } - srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID() if err != nil { return nil, err diff --git a/client_test.go b/client_test.go index 23122e62..dbd03cf3 100644 --- a/client_test.go +++ b/client_test.go @@ -366,12 +366,9 @@ var _ = Describe("Client", func() { }) It("errors when the Config contains an invalid version", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - version := protocol.VersionNumber(0x1234) _, err := Dial(context.Background(), packetConn, nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) - Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) + Expect(err).To(MatchError("invalid QUIC version: 0x1234")) }) It("disables bidirectional streams", func() { diff --git a/config.go b/config.go index 3ead9b7a..b513a46b 100644 --- a/config.go +++ b/config.go @@ -2,6 +2,7 @@ package quic import ( "errors" + "fmt" "net" "time" @@ -29,6 +30,12 @@ func validateConfig(config *Config) error { if config.MaxIncomingUniStreams > 1<<60 { return errors.New("invalid value for Config.MaxIncomingUniStreams") } + // check that all QUIC versions are actually supported + for _, v := range config.Versions { + if !protocol.IsValidVersion(v) { + return fmt.Errorf("invalid QUIC version: %s", v) + } + } return nil } diff --git a/server.go b/server.go index 3543e546..d5bb19e6 100644 --- a/server.go +++ b/server.go @@ -234,11 +234,6 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl return nil, err } config = populateServerConfig(config) - for _, v := range config.Versions { - if !protocol.IsValidVersion(v) { - return nil, fmt.Errorf("%s is not a valid QUIC version", v) - } - } connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) if err != nil { diff --git a/server_test.go b/server_test.go index e9f751e6..4108c698 100644 --- a/server_test.go +++ b/server_test.go @@ -114,7 +114,7 @@ var _ = Describe("Server", func() { It("errors when the Config contains an invalid version", func() { version := protocol.VersionNumber(0x1234) _, err := Listen(nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) - Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) + Expect(err).To(MatchError("invalid QUIC version: 0x1234")) }) It("fills in default values if options are not set in the Config", func() {