move the QUIC version check to the config validation

This commit is contained in:
Marten Seemann 2023-03-31 19:35:03 +09:00
parent 5400587610
commit ae5a8bd35c
5 changed files with 9 additions and 20 deletions

View file

@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"github.com/quic-go/quic-go/internal/protocol"
@ -136,15 +135,6 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
tlsConf = tlsConf.Clone()
}
// check that all versions are actually supported
if config != nil {
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
}
srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err

View file

@ -366,12 +366,9 @@ var _ = Describe("Client", func() {
})
It("errors when the Config contains an invalid version", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
version := protocol.VersionNumber(0x1234)
_, err := Dial(context.Background(), packetConn, nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
})
It("disables bidirectional streams", func() {

View file

@ -2,6 +2,7 @@ package quic
import (
"errors"
"fmt"
"net"
"time"
@ -29,6 +30,12 @@ func validateConfig(config *Config) error {
if config.MaxIncomingUniStreams > 1<<60 {
return errors.New("invalid value for Config.MaxIncomingUniStreams")
}
// check that all QUIC versions are actually supported
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return fmt.Errorf("invalid QUIC version: %s", v)
}
}
return nil
}

View file

@ -234,11 +234,6 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
return nil, err
}
config = populateServerConfig(config)
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
if err != nil {

View file

@ -114,7 +114,7 @@ var _ = Describe("Server", func() {
It("errors when the Config contains an invalid version", func() {
version := protocol.VersionNumber(0x1234)
_, err := Listen(nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
})
It("fills in default values if options are not set in the Config", func() {