remove the tls.Config from the quic.Config

The tls.Config now is a separate parameter to all Listen and Dial
functions in the quic package.
This commit is contained in:
Marten Seemann 2017-06-30 17:33:35 +02:00
parent 890b801a60
commit a851aaacda
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
14 changed files with 126 additions and 86 deletions

View file

@ -7,6 +7,7 @@
- Add a `quic.Config` option to request truncation of the connection ID from a server
- Add a `quic.Config` option to configure the source address validation
- Add a `quic.Config` option to configure the handshake timeout
- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/lucas-clemente/quic-go) for details.
- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/lucas-clemente/quic-go/wiki/Logging) for more details.
- Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper`
- Various bugfixes

View file

@ -35,7 +35,7 @@ var _ = Describe("Benchmarks", func() {
go func() {
defer GinkgoRecover()
var err error
ln, err = ListenAddr("localhost:0", &Config{TLSConfig: testdata.GetTLSConfig()})
ln, err = ListenAddr("localhost:0", testdata.GetTLSConfig(), nil)
Expect(err).ToNot(HaveOccurred())
serverAddr <- ln.Addr()
sess, err := ln.Accept()
@ -49,11 +49,8 @@ var _ = Describe("Benchmarks", func() {
}()
// start the client
conf := &Config{
TLSConfig: &tls.Config{InsecureSkipVerify: true},
}
addr := <-serverAddr
sess, err := DialAddr(addr.String(), conf)
sess, err := DialAddr(addr.String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())

View file

@ -2,6 +2,7 @@ package quic
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"net"
@ -24,6 +25,7 @@ type client struct {
errorChan chan struct{}
handshakeChan <-chan handshakeEvent
tlsConf *tls.Config
config *Config
versionNegotiated bool // has version negotiation completed yet
@ -39,7 +41,7 @@ var (
// DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddr(addr string, config *Config) (Session, error) {
func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
@ -48,12 +50,16 @@ func DialAddr(addr string, config *Config) (Session, error) {
if err != nil {
return nil, err
}
return Dial(udpConn, udpAddr, addr, config)
return Dial(udpConn, udpAddr, addr, tlsConf, config)
}
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) {
func DialAddrNonFWSecure(
addr string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
@ -62,20 +68,26 @@ func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) {
if err != nil {
return nil, err
}
return DialNonFWSecure(udpConn, udpAddr, addr, config)
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
}
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (NonFWSession, error) {
func DialNonFWSecure(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
connID, err := utils.GenerateConnectionID()
if err != nil {
return nil, err
}
var hostname string
if config.TLSConfig != nil {
hostname = config.TLSConfig.ServerName
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
@ -90,6 +102,7 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: clientConfig.Versions[0],
errorChan: make(chan struct{}),
@ -107,8 +120,14 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, config)
func Dial(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
if err != nil {
return nil, err
}
@ -119,7 +138,12 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
return sess, nil
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
@ -140,7 +164,6 @@ func populateClientConfig(config *Config) *Config {
}
return &Config{
TLSConfig: config.TLSConfig,
Versions: versions,
HandshakeTimeout: handshakeTimeout,
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
@ -270,6 +293,7 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
c.hostname,
c.version,
c.connectionID,
c.tlsConf,
c.config,
negotiatedVersions,
)

View file

@ -22,13 +22,13 @@ var _ = Describe("Client", func() {
packetConn *mockPacketConn
addr net.Addr
originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error)
originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error)
)
BeforeEach(func() {
originalClientSessConstructor = newClientSession
Eventually(areSessionsRunning).Should(BeFalse())
msess, _, _ := newMockSession(nil, 0, 0, nil, nil)
msess, _, _ := newMockSession(nil, 0, 0, nil, nil, nil)
sess = msess.(*mockSession)
packetConn = &mockPacketConn{addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}}
config = &Config{
@ -63,6 +63,7 @@ var _ = Describe("Client", func() {
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
@ -75,7 +76,7 @@ var _ = Describe("Client", func() {
go func() {
defer GinkgoRecover()
var err error
dialedSess, err = DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", config)
dialedSess, err = DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
}()
Consistently(func() Session { return dialedSess }).Should(BeNil())
@ -89,7 +90,7 @@ var _ = Describe("Client", func() {
go func() {
defer GinkgoRecover()
var err error
dialedSess, err = DialAddrNonFWSecure("localhost:18901", config)
dialedSess, err = DialAddrNonFWSecure("localhost:18901", nil, config)
Expect(err).ToNot(HaveOccurred())
}()
Consistently(func() Session { return dialedSess }).Should(BeNil())
@ -103,7 +104,7 @@ var _ = Describe("Client", func() {
go func() {
defer GinkgoRecover()
var err error
dialedSess, err = Dial(packetConn, addr, "quic.clemente.io:1337", config)
dialedSess, err = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
}()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
@ -120,13 +121,14 @@ var _ = Describe("Client", func() {
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
cconn = conn
return sess, nil, nil
}
go DialAddr("localhost:17890", &Config{})
go DialAddr("localhost:17890", nil, &Config{})
Eventually(func() connection { return cconn }).ShouldNot(BeNil())
Expect(cconn.RemoteAddr().String()).To(Equal("127.0.0.1:17890"))
close(done)
@ -139,13 +141,14 @@ var _ = Describe("Client", func() {
h string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
hostname = h
return sess, nil, nil
}
go DialAddr("localhost:17890", &Config{TLSConfig: &tls.Config{ServerName: "foobar"}})
go DialAddr("localhost:17890", &tls.Config{ServerName: "foobar"}, nil)
Eventually(func() string { return hostname }).Should(Equal("foobar"))
close(done)
})
@ -154,7 +157,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("early handshake error")
var dialErr error
go func() {
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", config)
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
}()
sess.handshakeChan <- handshakeEvent{err: testErr}
Eventually(func() error { return dialErr }).Should(MatchError(testErr))
@ -165,7 +168,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("late handshake error")
var dialErr error
go func() {
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", config)
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
}()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
sess.handshakeComplete <- testErr
@ -192,7 +195,7 @@ var _ = Describe("Client", func() {
It("errors when receiving an invalid first packet from the server", func(done Done) {
packetConn.dataToRead = []byte{0xff}
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(HaveOccurred())
close(done)
})
@ -200,7 +203,7 @@ var _ = Describe("Client", func() {
It("errors when receiving an error from the connection", func(done Done) {
testErr := errors.New("connection error")
packetConn.readErr = testErr
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(MatchError(testErr))
close(done)
})
@ -212,12 +215,13 @@ var _ = Describe("Client", func() {
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
return nil, nil, testErr
}
_, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", config)
_, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(MatchError(testErr))
})
@ -243,6 +247,7 @@ var _ = Describe("Client", func() {
_ string,
_ protocol.VersionNumber,
connectionID protocol.ConnectionID,
_ *tls.Config,
_ *Config,
negotiatedVersionsP []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
@ -324,6 +329,7 @@ var _ = Describe("Client", func() {
hostnameP string,
versionP protocol.VersionNumber,
_ protocol.ConnectionID,
_ *tls.Config,
configP *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
@ -335,7 +341,7 @@ var _ = Describe("Client", func() {
return sess, nil, nil
}
go func() {
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
}()
<-c

View file

@ -31,10 +31,7 @@ func main() {
// Start a server that echos all data on the first stream opened by the client
func echoServer() error {
cfgServer := &quic.Config{
TLSConfig: generateTLSConfig(),
}
listener, err := quic.ListenAddr(addr, cfgServer)
listener, err := quic.ListenAddr(addr, generateTLSConfig(), nil)
if err != nil {
return err
}
@ -52,10 +49,7 @@ func echoServer() error {
}
func clientMain() error {
cfgClient := &quic.Config{
TLSConfig: &tls.Config{InsecureSkipVerify: true},
}
session, err := quic.DialAddr(addr, cfgClient)
session, err := quic.DialAddr(addr, &tls.Config{InsecureSkipVerify: true}, nil)
if err != nil {
return err
}

View file

@ -28,7 +28,8 @@ type roundTripperOpts struct {
type client struct {
mutex sync.RWMutex
dialAddr func(hostname string, config *quic.Config) (quic.Session, error)
dialAddr func(hostname string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error)
tlsConf *tls.Config
config *quic.Config
opts *roundTripperOpts
@ -55,8 +56,8 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
encryptionLevel: protocol.EncryptionUnencrypted,
tlsConf: tlsConfig,
config: &quic.Config{
TLSConfig: tlsConfig,
RequestConnectionIDTruncation: true,
},
opts: opts,
@ -67,7 +68,7 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *
// dial dials the connection
func (c *client) dial() error {
var err error
c.session, err = c.dialAddr(c.hostname, c.config)
c.session, err = c.dialAddr(c.hostname, c.tlsConf, c.config)
if err != nil {
return err
}

View file

@ -45,7 +45,7 @@ var _ = Describe("Client", func() {
It("saves the TLS config", func() {
tlsConf := &tls.Config{InsecureSkipVerify: true}
client = newClient(tlsConf, "", &roundTripperOpts{})
Expect(client.config.TLSConfig).To(Equal(tlsConf))
Expect(client.tlsConf).To(Equal(tlsConf))
})
It("adds the port to the hostname, if none is given", func() {
@ -56,7 +56,7 @@ var _ = Describe("Client", func() {
It("dials", func(done Done) {
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
close(headerStream.unblockRead)
@ -68,7 +68,7 @@ var _ = Describe("Client", func() {
It("errors when dialing fails", func() {
testErr := errors.New("handshake error")
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
_, err := client.RoundTrip(req)
@ -78,7 +78,7 @@ var _ = Describe("Client", func() {
It("errors if the header stream has the wrong stream ID", func() {
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
session.streamsToOpen = []quic.Stream{&mockStream{id: 2}}
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
_, err := client.RoundTrip(req)
@ -89,7 +89,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("you shall not pass")
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
session.streamOpenErr = testErr
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
_, err := client.RoundTrip(req)
@ -98,7 +98,7 @@ var _ = Describe("Client", func() {
It("returns a request when dial fails", func() {
testErr := errors.New("dial error")
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
@ -140,7 +140,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
var err error
client.encryptionLevel = protocol.EncryptionForwardSecure
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
dataStream = newMockStream(5)

View file

@ -84,16 +84,15 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error {
}
config := quic.Config{
TLSConfig: tlsConfig,
Versions: protocol.SupportedVersions,
Versions: protocol.SupportedVersions,
}
var ln quic.Listener
var err error
if conn == nil {
ln, err = quic.ListenAddr(s.Addr, &config)
ln, err = quic.ListenAddr(s.Addr, tlsConfig, &config)
} else {
ln, err = quic.Listen(conn, &config)
ln, err = quic.Listen(conn, tlsConfig, &config)
}
if err != nil {
s.listenerMutex.Unlock()

View file

@ -11,8 +11,8 @@ import (
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/testdata"
"github.com/lucas-clemente/quic-go/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
@ -24,11 +24,10 @@ var _ = Describe("Handshake integration tets", func() {
serverConfig *quic.Config
testStartedAt time.Time
)
rtt := 350 * time.Millisecond
BeforeEach(func() {
serverConfig = &quic.Config{TLSConfig: testdata.GetTLSConfig()}
serverConfig = &quic.Config{}
})
AfterEach(func() {
@ -39,7 +38,7 @@ var _ = Describe("Handshake integration tets", func() {
runServerAndProxy := func() {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", serverConfig)
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
// start the proxy
proxy, err = quicproxy.NewQuicProxy("localhost:0", quicproxy.Opts{
@ -73,7 +72,7 @@ var _ = Describe("Handshake integration tets", func() {
clientConfig := &quic.Config{
Versions: protocol.SupportedVersions[1:2],
}
_, err := quic.DialAddr(proxy.LocalAddr().String(), clientConfig)
_, err := quic.DialAddr(proxy.LocalAddr().String(), nil, clientConfig)
Expect(err).To(HaveOccurred())
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.InvalidVersion))
expectDurationInRTTs(1)
@ -84,7 +83,7 @@ var _ = Describe("Handshake integration tets", func() {
// 1 RTT to become forward-secure
It("is forward-secure after 3 RTTs", func() {
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}})
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(3)
})
@ -95,7 +94,7 @@ var _ = Describe("Handshake integration tets", func() {
PIt("is secure after 2 RTTs", func() {
utils.SetLogLevel(utils.LogLevelDebug)
runServerAndProxy()
_, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}})
_, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
fmt.Println("#### is non fw secure ###")
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
@ -106,7 +105,7 @@ var _ = Describe("Handshake integration tets", func() {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}})
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
@ -116,7 +115,7 @@ var _ = Describe("Handshake integration tets", func() {
return false
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}})
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects))
})
@ -124,7 +123,7 @@ var _ = Describe("Handshake integration tets", func() {
It("doesn't complete the handshake when the handshake timeout is too short", func() {
serverConfig.HandshakeTimeout = 2 * rtt
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}})
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
// 2 RTTs during the timeout

View file

@ -1,7 +1,6 @@
package quic
import (
"crypto/tls"
"io"
"net"
"time"
@ -64,7 +63,6 @@ type STK struct {
// Config contains all configuration data needed for a QUIC server or client.
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
type Config struct {
TLSConfig *tls.Config
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
// Warning: This API should not be considered stable and will change soon.

View file

@ -2,6 +2,7 @@ package quic
import (
"bytes"
"crypto/tls"
"errors"
"net"
"sync"
@ -23,7 +24,8 @@ type packetHandler interface {
// A Listener of QUIC
type server struct {
config *Config
tlsConf *tls.Config
config *Config
conn net.PacketConn
@ -38,14 +40,15 @@ type server struct {
sessionQueue chan Session
errorChan chan struct{}
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, config *Config) (packetHandler, <-chan handshakeEvent, error)
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, <-chan handshakeEvent, error)
}
var _ Listener = &server{}
// ListenAddr creates a QUIC server listening on a given address.
// The listener is not active until Serve() is called.
func ListenAddr(addr string, config *Config) (Listener, error) {
// The tls.Config must not be nil, the quic.Config may be nil.
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
@ -54,13 +57,14 @@ func ListenAddr(addr string, config *Config) (Listener, error) {
if err != nil {
return nil, err
}
return Listen(conn, config)
return Listen(conn, tlsConf, config)
}
// Listen listens for QUIC connections on a given net.PacketConn.
// The listener is not active until Serve() is called.
func Listen(conn net.PacketConn, config *Config) (Listener, error) {
certChain := crypto.NewCertChain(config.TLSConfig)
// The tls.Config must not be nil, the quic.Config may be nil.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
certChain := crypto.NewCertChain(tlsConf)
kex, err := crypto.NewCurve25519KEX()
if err != nil {
return nil, err
@ -72,6 +76,7 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
s := &server{
conn: conn,
tlsConf: tlsConf,
config: populateServerConfig(config),
certChain: certChain,
scfg: scfg,
@ -101,7 +106,12 @@ var defaultAcceptSTK = func(clientAddr net.Addr, stk *STK) bool {
return sourceAddr == stk.remoteAddr
}
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateServerConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
@ -127,7 +137,6 @@ func populateServerConfig(config *Config) *Config {
}
return &Config{
TLSConfig: config.TLSConfig,
Versions: versions,
HandshakeTimeout: handshakeTimeout,
AcceptSTK: vsa,
@ -256,6 +265,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
version,
hdr.ConnectionID,
s.scfg,
s.tlsConf,
s.config,
)
if err != nil {

View file

@ -75,6 +75,7 @@ func newMockSession(
_ protocol.VersionNumber,
connectionID protocol.ConnectionID,
_ *handshake.ServerConfig,
_ *tls.Config,
_ *Config,
) (packetHandler, <-chan handshakeEvent, error) {
s := mockSession{
@ -95,10 +96,7 @@ var _ = Describe("Server", func() {
BeforeEach(func() {
conn = &mockPacketConn{}
config = &Config{
TLSConfig: &tls.Config{},
Versions: protocol.SupportedVersions,
}
config = &Config{Versions: protocol.SupportedVersions}
})
Context("with mock session", func() {
@ -225,7 +223,7 @@ var _ = Describe("Server", func() {
})
It("closes sessions and the connection when Close is called", func() {
session, _, _ := newMockSession(nil, 0, 0, nil, nil)
session, _, _ := newMockSession(nil, 0, 0, nil, nil, nil)
serv.sessions[1] = session
err := serv.Close()
Expect(err).NotTo(HaveOccurred())
@ -241,8 +239,15 @@ var _ = Describe("Server", func() {
Expect(serv.sessions[connID]).To(BeNil())
})
It("works if no quic.Config is given", func(done Done) {
ln, err := ListenAddr("127.0.0.1:0", nil, config)
Expect(err).ToNot(HaveOccurred())
Expect(ln.Close()).To(Succeed())
close(done)
}, 1)
It("closes properly", func() {
ln, err := ListenAddr("127.0.0.1:0", config)
ln, err := ListenAddr("127.0.0.1:0", nil, config)
Expect(err).ToNot(HaveOccurred())
var returned bool
@ -268,7 +273,7 @@ var _ = Describe("Server", func() {
}, 0.5)
It("closes all sessions when encountering a connection error", func() {
session, _, _ := newMockSession(nil, 0, 0, nil, nil)
session, _, _ := newMockSession(nil, 0, 0, nil, nil, nil)
serv.sessions[0x12345] = session
Expect(serv.sessions[0x12345].(*mockSession).closed).To(BeFalse())
testErr := errors.New("connection error")
@ -348,12 +353,11 @@ var _ = Describe("Server", func() {
supportedVersions := []protocol.VersionNumber{1, 3, 5}
acceptSTK := func(_ net.Addr, _ *STK) bool { return true }
config := Config{
TLSConfig: &tls.Config{},
Versions: supportedVersions,
AcceptSTK: acceptSTK,
HandshakeTimeout: 1337 * time.Hour,
}
ln, err := Listen(conn, &config)
ln, err := Listen(conn, &tls.Config{}, &config)
Expect(err).ToNot(HaveOccurred())
server := ln.(*server)
Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout))
@ -365,8 +369,7 @@ var _ = Describe("Server", func() {
})
It("fills in default values if options are not set in the Config", func() {
config := Config{TLSConfig: &tls.Config{}}
ln, err := Listen(conn, &config)
ln, err := Listen(conn, &tls.Config{}, &Config{})
Expect(err).ToNot(HaveOccurred())
server := ln.(*server)
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
@ -376,7 +379,7 @@ var _ = Describe("Server", func() {
It("listens on a given address", func() {
addr := "127.0.0.1:13579"
ln, err := ListenAddr(addr, config)
ln, err := ListenAddr(addr, nil, config)
Expect(err).ToNot(HaveOccurred())
serv := ln.(*server)
Expect(serv.Addr().String()).To(Equal(addr))
@ -384,13 +387,13 @@ var _ = Describe("Server", func() {
It("errors if given an invalid address", func() {
addr := "127.0.0.1"
_, err := ListenAddr(addr, config)
_, err := ListenAddr(addr, nil, config)
Expect(err).To(BeAssignableToTypeOf(&net.AddrError{}))
})
It("errors if given an invalid address", func() {
addr := "1.1.1.1:1111"
_, err := ListenAddr(addr, config)
_, err := ListenAddr(addr, nil, config)
Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
})
@ -407,7 +410,7 @@ var _ = Describe("Server", func() {
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
conn.dataToRead = b.Bytes()
conn.dataReadFrom = udpAddr
ln, err := Listen(conn, config)
ln, err := Listen(conn, nil, config)
Expect(err).ToNot(HaveOccurred())
var returned bool
@ -431,7 +434,7 @@ var _ = Describe("Server", func() {
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
conn.dataReadFrom = udpAddr
conn.dataToRead = []byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x01}
ln, err := Listen(conn, config)
ln, err := Listen(conn, nil, config)
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()

View file

@ -1,6 +1,7 @@
package quic
import (
"crypto/tls"
"errors"
"fmt"
"net"
@ -53,6 +54,7 @@ type session struct {
connectionID protocol.ConnectionID
perspective protocol.Perspective
version protocol.VersionNumber
tlsConf *tls.Config
config *Config
conn connection
@ -119,6 +121,7 @@ func newSession(
v protocol.VersionNumber,
connectionID protocol.ConnectionID,
sCfg *handshake.ServerConfig,
tlsConf *tls.Config,
config *Config,
) (packetHandler, <-chan handshakeEvent, error) {
s := &session{
@ -137,6 +140,7 @@ var newClientSession = func(
hostname string,
v protocol.VersionNumber,
connectionID protocol.ConnectionID,
tlsConf *tls.Config,
config *Config,
negotiatedVersions []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
@ -145,6 +149,7 @@ var newClientSession = func(
connectionID: connectionID,
perspective: protocol.PerspectiveClient,
version: v,
tlsConf: tlsConf,
config: config,
}
return s.setup(nil, hostname, negotiatedVersions)
@ -209,7 +214,7 @@ func (s *session) setup(
s.connectionID,
s.version,
cryptoStream,
s.config.TLSConfig,
s.tlsConf,
s.connectionParameters,
aeadChanged,
&handshake.TransportParameters{RequestConnectionIDTruncation: s.config.RequestConnectionIDTruncation},

View file

@ -169,6 +169,7 @@ var _ = Describe("Session", func() {
protocol.Version35,
0,
scfg,
nil,
populateServerConfig(&Config{}),
)
Expect(err).NotTo(HaveOccurred())
@ -220,6 +221,7 @@ var _ = Describe("Session", func() {
protocol.Version35,
0,
scfg,
nil,
conf,
)
Expect(err).NotTo(HaveOccurred())
@ -1635,6 +1637,7 @@ var _ = Describe("Client Session", func() {
"hostname",
protocol.Version35,
0,
nil,
populateClientConfig(&Config{}),
nil,
)