mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
remove the tls.Config from the quic.Config
The tls.Config now is a separate parameter to all Listen and Dial functions in the quic package.
This commit is contained in:
parent
890b801a60
commit
a851aaacda
14 changed files with 126 additions and 86 deletions
|
@ -7,6 +7,7 @@
|
|||
- Add a `quic.Config` option to request truncation of the connection ID from a server
|
||||
- Add a `quic.Config` option to configure the source address validation
|
||||
- Add a `quic.Config` option to configure the handshake timeout
|
||||
- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/lucas-clemente/quic-go) for details.
|
||||
- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/lucas-clemente/quic-go/wiki/Logging) for more details.
|
||||
- Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper`
|
||||
- Various bugfixes
|
||||
|
|
|
@ -35,7 +35,7 @@ var _ = Describe("Benchmarks", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
ln, err = ListenAddr("localhost:0", &Config{TLSConfig: testdata.GetTLSConfig()})
|
||||
ln, err = ListenAddr("localhost:0", testdata.GetTLSConfig(), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAddr <- ln.Addr()
|
||||
sess, err := ln.Accept()
|
||||
|
@ -49,11 +49,8 @@ var _ = Describe("Benchmarks", func() {
|
|||
}()
|
||||
|
||||
// start the client
|
||||
conf := &Config{
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
addr := <-serverAddr
|
||||
sess, err := DialAddr(addr.String(), conf)
|
||||
sess, err := DialAddr(addr.String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
44
client.go
44
client.go
|
@ -2,6 +2,7 @@ package quic
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -24,6 +25,7 @@ type client struct {
|
|||
errorChan chan struct{}
|
||||
handshakeChan <-chan handshakeEvent
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
versionNegotiated bool // has version negotiation completed yet
|
||||
|
||||
|
@ -39,7 +41,7 @@ var (
|
|||
|
||||
// DialAddr establishes a new QUIC connection to a server.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddr(addr string, config *Config) (Session, error) {
|
||||
func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -48,12 +50,16 @@ func DialAddr(addr string, config *Config) (Session, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Dial(udpConn, udpAddr, addr, config)
|
||||
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) {
|
||||
func DialAddrNonFWSecure(
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (NonFWSession, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -62,20 +68,26 @@ func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DialNonFWSecure(udpConn, udpAddr, addr, config)
|
||||
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (NonFWSession, error) {
|
||||
func DialNonFWSecure(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (NonFWSession, error) {
|
||||
connID, err := utils.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var hostname string
|
||||
if config.TLSConfig != nil {
|
||||
hostname = config.TLSConfig.ServerName
|
||||
if tlsConf != nil {
|
||||
hostname = tlsConf.ServerName
|
||||
}
|
||||
|
||||
if hostname == "" {
|
||||
|
@ -90,6 +102,7 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con
|
|||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
connectionID: connID,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: clientConfig.Versions[0],
|
||||
errorChan: make(chan struct{}),
|
||||
|
@ -107,8 +120,14 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con
|
|||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) {
|
||||
sess, err := DialNonFWSecure(pconn, remoteAddr, host, config)
|
||||
func Dial(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -119,7 +138,12 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
|
|||
return sess, nil
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateClientConfig(config *Config) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
versions := config.Versions
|
||||
if len(versions) == 0 {
|
||||
versions = protocol.SupportedVersions
|
||||
|
@ -140,7 +164,6 @@ func populateClientConfig(config *Config) *Config {
|
|||
}
|
||||
|
||||
return &Config{
|
||||
TLSConfig: config.TLSConfig,
|
||||
Versions: versions,
|
||||
HandshakeTimeout: handshakeTimeout,
|
||||
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
|
||||
|
@ -270,6 +293,7 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
|
|||
c.hostname,
|
||||
c.version,
|
||||
c.connectionID,
|
||||
c.tlsConf,
|
||||
c.config,
|
||||
negotiatedVersions,
|
||||
)
|
||||
|
|
|
@ -22,13 +22,13 @@ var _ = Describe("Client", func() {
|
|||
packetConn *mockPacketConn
|
||||
addr net.Addr
|
||||
|
||||
originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error)
|
||||
originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error)
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
originalClientSessConstructor = newClientSession
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
msess, _, _ := newMockSession(nil, 0, 0, nil, nil)
|
||||
msess, _, _ := newMockSession(nil, 0, 0, nil, nil, nil)
|
||||
sess = msess.(*mockSession)
|
||||
packetConn = &mockPacketConn{addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}}
|
||||
config = &Config{
|
||||
|
@ -63,6 +63,7 @@ var _ = Describe("Client", func() {
|
|||
_ string,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ []protocol.VersionNumber,
|
||||
) (packetHandler, <-chan handshakeEvent, error) {
|
||||
|
@ -75,7 +76,7 @@ var _ = Describe("Client", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
dialedSess, err = DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
dialedSess, err = DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
Consistently(func() Session { return dialedSess }).Should(BeNil())
|
||||
|
@ -89,7 +90,7 @@ var _ = Describe("Client", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
dialedSess, err = DialAddrNonFWSecure("localhost:18901", config)
|
||||
dialedSess, err = DialAddrNonFWSecure("localhost:18901", nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
Consistently(func() Session { return dialedSess }).Should(BeNil())
|
||||
|
@ -103,7 +104,7 @@ var _ = Describe("Client", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
dialedSess, err = Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
dialedSess, err = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
|
@ -120,13 +121,14 @@ var _ = Describe("Client", func() {
|
|||
_ string,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ []protocol.VersionNumber,
|
||||
) (packetHandler, <-chan handshakeEvent, error) {
|
||||
cconn = conn
|
||||
return sess, nil, nil
|
||||
}
|
||||
go DialAddr("localhost:17890", &Config{})
|
||||
go DialAddr("localhost:17890", nil, &Config{})
|
||||
Eventually(func() connection { return cconn }).ShouldNot(BeNil())
|
||||
Expect(cconn.RemoteAddr().String()).To(Equal("127.0.0.1:17890"))
|
||||
close(done)
|
||||
|
@ -139,13 +141,14 @@ var _ = Describe("Client", func() {
|
|||
h string,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ []protocol.VersionNumber,
|
||||
) (packetHandler, <-chan handshakeEvent, error) {
|
||||
hostname = h
|
||||
return sess, nil, nil
|
||||
}
|
||||
go DialAddr("localhost:17890", &Config{TLSConfig: &tls.Config{ServerName: "foobar"}})
|
||||
go DialAddr("localhost:17890", &tls.Config{ServerName: "foobar"}, nil)
|
||||
Eventually(func() string { return hostname }).Should(Equal("foobar"))
|
||||
close(done)
|
||||
})
|
||||
|
@ -154,7 +157,7 @@ var _ = Describe("Client", func() {
|
|||
testErr := errors.New("early handshake error")
|
||||
var dialErr error
|
||||
go func() {
|
||||
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
}()
|
||||
sess.handshakeChan <- handshakeEvent{err: testErr}
|
||||
Eventually(func() error { return dialErr }).Should(MatchError(testErr))
|
||||
|
@ -165,7 +168,7 @@ var _ = Describe("Client", func() {
|
|||
testErr := errors.New("late handshake error")
|
||||
var dialErr error
|
||||
go func() {
|
||||
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
}()
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
sess.handshakeComplete <- testErr
|
||||
|
@ -192,7 +195,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("errors when receiving an invalid first packet from the server", func(done Done) {
|
||||
packetConn.dataToRead = []byte{0xff}
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).To(HaveOccurred())
|
||||
close(done)
|
||||
})
|
||||
|
@ -200,7 +203,7 @@ var _ = Describe("Client", func() {
|
|||
It("errors when receiving an error from the connection", func(done Done) {
|
||||
testErr := errors.New("connection error")
|
||||
packetConn.readErr = testErr
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
})
|
||||
|
@ -212,12 +215,13 @@ var _ = Describe("Client", func() {
|
|||
_ string,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ []protocol.VersionNumber,
|
||||
) (packetHandler, <-chan handshakeEvent, error) {
|
||||
return nil, nil, testErr
|
||||
}
|
||||
_, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
_, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
|
@ -243,6 +247,7 @@ var _ = Describe("Client", func() {
|
|||
_ string,
|
||||
_ protocol.VersionNumber,
|
||||
connectionID protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
negotiatedVersionsP []protocol.VersionNumber,
|
||||
) (packetHandler, <-chan handshakeEvent, error) {
|
||||
|
@ -324,6 +329,7 @@ var _ = Describe("Client", func() {
|
|||
hostnameP string,
|
||||
versionP protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
configP *Config,
|
||||
_ []protocol.VersionNumber,
|
||||
) (packetHandler, <-chan handshakeEvent, error) {
|
||||
|
@ -335,7 +341,7 @@ var _ = Describe("Client", func() {
|
|||
return sess, nil, nil
|
||||
}
|
||||
go func() {
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
<-c
|
||||
|
|
|
@ -31,10 +31,7 @@ func main() {
|
|||
|
||||
// Start a server that echos all data on the first stream opened by the client
|
||||
func echoServer() error {
|
||||
cfgServer := &quic.Config{
|
||||
TLSConfig: generateTLSConfig(),
|
||||
}
|
||||
listener, err := quic.ListenAddr(addr, cfgServer)
|
||||
listener, err := quic.ListenAddr(addr, generateTLSConfig(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -52,10 +49,7 @@ func echoServer() error {
|
|||
}
|
||||
|
||||
func clientMain() error {
|
||||
cfgClient := &quic.Config{
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
session, err := quic.DialAddr(addr, cfgClient)
|
||||
session, err := quic.DialAddr(addr, &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -28,7 +28,8 @@ type roundTripperOpts struct {
|
|||
type client struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
dialAddr func(hostname string, config *quic.Config) (quic.Session, error)
|
||||
dialAddr func(hostname string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error)
|
||||
tlsConf *tls.Config
|
||||
config *quic.Config
|
||||
opts *roundTripperOpts
|
||||
|
||||
|
@ -55,8 +56,8 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *
|
|||
hostname: authorityAddr("https", hostname),
|
||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||
encryptionLevel: protocol.EncryptionUnencrypted,
|
||||
tlsConf: tlsConfig,
|
||||
config: &quic.Config{
|
||||
TLSConfig: tlsConfig,
|
||||
RequestConnectionIDTruncation: true,
|
||||
},
|
||||
opts: opts,
|
||||
|
@ -67,7 +68,7 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *
|
|||
// dial dials the connection
|
||||
func (c *client) dial() error {
|
||||
var err error
|
||||
c.session, err = c.dialAddr(c.hostname, c.config)
|
||||
c.session, err = c.dialAddr(c.hostname, c.tlsConf, c.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ var _ = Describe("Client", func() {
|
|||
It("saves the TLS config", func() {
|
||||
tlsConf := &tls.Config{InsecureSkipVerify: true}
|
||||
client = newClient(tlsConf, "", &roundTripperOpts{})
|
||||
Expect(client.config.TLSConfig).To(Equal(tlsConf))
|
||||
Expect(client.tlsConf).To(Equal(tlsConf))
|
||||
})
|
||||
|
||||
It("adds the port to the hostname, if none is given", func() {
|
||||
|
@ -56,7 +56,7 @@ var _ = Describe("Client", func() {
|
|||
It("dials", func(done Done) {
|
||||
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
|
||||
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
||||
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
||||
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
}
|
||||
close(headerStream.unblockRead)
|
||||
|
@ -68,7 +68,7 @@ var _ = Describe("Client", func() {
|
|||
It("errors when dialing fails", func() {
|
||||
testErr := errors.New("handshake error")
|
||||
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
|
||||
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
||||
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
_, err := client.RoundTrip(req)
|
||||
|
@ -78,7 +78,7 @@ var _ = Describe("Client", func() {
|
|||
It("errors if the header stream has the wrong stream ID", func() {
|
||||
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
|
||||
session.streamsToOpen = []quic.Stream{&mockStream{id: 2}}
|
||||
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
||||
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
}
|
||||
_, err := client.RoundTrip(req)
|
||||
|
@ -89,7 +89,7 @@ var _ = Describe("Client", func() {
|
|||
testErr := errors.New("you shall not pass")
|
||||
client = newClient(nil, "localhost:1337", &roundTripperOpts{})
|
||||
session.streamOpenErr = testErr
|
||||
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
||||
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
}
|
||||
_, err := client.RoundTrip(req)
|
||||
|
@ -98,7 +98,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("returns a request when dial fails", func() {
|
||||
testErr := errors.New("dial error")
|
||||
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
||||
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||
|
@ -140,7 +140,7 @@ var _ = Describe("Client", func() {
|
|||
BeforeEach(func() {
|
||||
var err error
|
||||
client.encryptionLevel = protocol.EncryptionForwardSecure
|
||||
client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) {
|
||||
client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
}
|
||||
dataStream = newMockStream(5)
|
||||
|
|
|
@ -84,16 +84,15 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error {
|
|||
}
|
||||
|
||||
config := quic.Config{
|
||||
TLSConfig: tlsConfig,
|
||||
Versions: protocol.SupportedVersions,
|
||||
Versions: protocol.SupportedVersions,
|
||||
}
|
||||
|
||||
var ln quic.Listener
|
||||
var err error
|
||||
if conn == nil {
|
||||
ln, err = quic.ListenAddr(s.Addr, &config)
|
||||
ln, err = quic.ListenAddr(s.Addr, tlsConfig, &config)
|
||||
} else {
|
||||
ln, err = quic.Listen(conn, &config)
|
||||
ln, err = quic.Listen(conn, tlsConfig, &config)
|
||||
}
|
||||
if err != nil {
|
||||
s.listenerMutex.Unlock()
|
||||
|
|
|
@ -11,8 +11,8 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
"github.com/lucas-clemente/quic-go/testdata"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/testdata"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
@ -24,11 +24,10 @@ var _ = Describe("Handshake integration tets", func() {
|
|||
serverConfig *quic.Config
|
||||
testStartedAt time.Time
|
||||
)
|
||||
|
||||
rtt := 350 * time.Millisecond
|
||||
|
||||
BeforeEach(func() {
|
||||
serverConfig = &quic.Config{TLSConfig: testdata.GetTLSConfig()}
|
||||
serverConfig = &quic.Config{}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
|
@ -39,7 +38,7 @@ var _ = Describe("Handshake integration tets", func() {
|
|||
runServerAndProxy := func() {
|
||||
var err error
|
||||
// start the server
|
||||
server, err = quic.ListenAddr("localhost:0", serverConfig)
|
||||
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// start the proxy
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", quicproxy.Opts{
|
||||
|
@ -73,7 +72,7 @@ var _ = Describe("Handshake integration tets", func() {
|
|||
clientConfig := &quic.Config{
|
||||
Versions: protocol.SupportedVersions[1:2],
|
||||
}
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), clientConfig)
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), nil, clientConfig)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.InvalidVersion))
|
||||
expectDurationInRTTs(1)
|
||||
|
@ -84,7 +83,7 @@ var _ = Describe("Handshake integration tets", func() {
|
|||
// 1 RTT to become forward-secure
|
||||
It("is forward-secure after 3 RTTs", func() {
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}})
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(3)
|
||||
})
|
||||
|
@ -95,7 +94,7 @@ var _ = Describe("Handshake integration tets", func() {
|
|||
PIt("is secure after 2 RTTs", func() {
|
||||
utils.SetLogLevel(utils.LogLevelDebug)
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}})
|
||||
_, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
fmt.Println("#### is non fw secure ###")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(2)
|
||||
|
@ -106,7 +105,7 @@ var _ = Describe("Handshake integration tets", func() {
|
|||
return true
|
||||
}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}})
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(2)
|
||||
})
|
||||
|
@ -116,7 +115,7 @@ var _ = Describe("Handshake integration tets", func() {
|
|||
return false
|
||||
}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}})
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects))
|
||||
})
|
||||
|
@ -124,7 +123,7 @@ var _ = Describe("Handshake integration tets", func() {
|
|||
It("doesn't complete the handshake when the handshake timeout is too short", func() {
|
||||
serverConfig.HandshakeTimeout = 2 * rtt
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}})
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
|
||||
// 2 RTTs during the timeout
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
@ -64,7 +63,6 @@ type STK struct {
|
|||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
|
||||
type Config struct {
|
||||
TLSConfig *tls.Config
|
||||
// The QUIC versions that can be negotiated.
|
||||
// If not set, it uses all versions available.
|
||||
// Warning: This API should not be considered stable and will change soon.
|
||||
|
|
24
server.go
24
server.go
|
@ -2,6 +2,7 @@ package quic
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
|
@ -23,7 +24,8 @@ type packetHandler interface {
|
|||
|
||||
// A Listener of QUIC
|
||||
type server struct {
|
||||
config *Config
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
|
||||
conn net.PacketConn
|
||||
|
||||
|
@ -38,14 +40,15 @@ type server struct {
|
|||
sessionQueue chan Session
|
||||
errorChan chan struct{}
|
||||
|
||||
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, config *Config) (packetHandler, <-chan handshakeEvent, error)
|
||||
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, <-chan handshakeEvent, error)
|
||||
}
|
||||
|
||||
var _ Listener = &server{}
|
||||
|
||||
// ListenAddr creates a QUIC server listening on a given address.
|
||||
// The listener is not active until Serve() is called.
|
||||
func ListenAddr(addr string, config *Config) (Listener, error) {
|
||||
// The tls.Config must not be nil, the quic.Config may be nil.
|
||||
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -54,13 +57,14 @@ func ListenAddr(addr string, config *Config) (Listener, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Listen(conn, config)
|
||||
return Listen(conn, tlsConf, config)
|
||||
}
|
||||
|
||||
// Listen listens for QUIC connections on a given net.PacketConn.
|
||||
// The listener is not active until Serve() is called.
|
||||
func Listen(conn net.PacketConn, config *Config) (Listener, error) {
|
||||
certChain := crypto.NewCertChain(config.TLSConfig)
|
||||
// The tls.Config must not be nil, the quic.Config may be nil.
|
||||
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
|
||||
certChain := crypto.NewCertChain(tlsConf)
|
||||
kex, err := crypto.NewCurve25519KEX()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -72,6 +76,7 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
|
|||
|
||||
s := &server{
|
||||
conn: conn,
|
||||
tlsConf: tlsConf,
|
||||
config: populateServerConfig(config),
|
||||
certChain: certChain,
|
||||
scfg: scfg,
|
||||
|
@ -101,7 +106,12 @@ var defaultAcceptSTK = func(clientAddr net.Addr, stk *STK) bool {
|
|||
return sourceAddr == stk.remoteAddr
|
||||
}
|
||||
|
||||
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateServerConfig(config *Config) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
versions := config.Versions
|
||||
if len(versions) == 0 {
|
||||
versions = protocol.SupportedVersions
|
||||
|
@ -127,7 +137,6 @@ func populateServerConfig(config *Config) *Config {
|
|||
}
|
||||
|
||||
return &Config{
|
||||
TLSConfig: config.TLSConfig,
|
||||
Versions: versions,
|
||||
HandshakeTimeout: handshakeTimeout,
|
||||
AcceptSTK: vsa,
|
||||
|
@ -256,6 +265,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
|
|||
version,
|
||||
hdr.ConnectionID,
|
||||
s.scfg,
|
||||
s.tlsConf,
|
||||
s.config,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
@ -75,6 +75,7 @@ func newMockSession(
|
|||
_ protocol.VersionNumber,
|
||||
connectionID protocol.ConnectionID,
|
||||
_ *handshake.ServerConfig,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
) (packetHandler, <-chan handshakeEvent, error) {
|
||||
s := mockSession{
|
||||
|
@ -95,10 +96,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
conn = &mockPacketConn{}
|
||||
config = &Config{
|
||||
TLSConfig: &tls.Config{},
|
||||
Versions: protocol.SupportedVersions,
|
||||
}
|
||||
config = &Config{Versions: protocol.SupportedVersions}
|
||||
})
|
||||
|
||||
Context("with mock session", func() {
|
||||
|
@ -225,7 +223,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("closes sessions and the connection when Close is called", func() {
|
||||
session, _, _ := newMockSession(nil, 0, 0, nil, nil)
|
||||
session, _, _ := newMockSession(nil, 0, 0, nil, nil, nil)
|
||||
serv.sessions[1] = session
|
||||
err := serv.Close()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
@ -241,8 +239,15 @@ var _ = Describe("Server", func() {
|
|||
Expect(serv.sessions[connID]).To(BeNil())
|
||||
})
|
||||
|
||||
It("works if no quic.Config is given", func(done Done) {
|
||||
ln, err := ListenAddr("127.0.0.1:0", nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
close(done)
|
||||
}, 1)
|
||||
|
||||
It("closes properly", func() {
|
||||
ln, err := ListenAddr("127.0.0.1:0", config)
|
||||
ln, err := ListenAddr("127.0.0.1:0", nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var returned bool
|
||||
|
@ -268,7 +273,7 @@ var _ = Describe("Server", func() {
|
|||
}, 0.5)
|
||||
|
||||
It("closes all sessions when encountering a connection error", func() {
|
||||
session, _, _ := newMockSession(nil, 0, 0, nil, nil)
|
||||
session, _, _ := newMockSession(nil, 0, 0, nil, nil, nil)
|
||||
serv.sessions[0x12345] = session
|
||||
Expect(serv.sessions[0x12345].(*mockSession).closed).To(BeFalse())
|
||||
testErr := errors.New("connection error")
|
||||
|
@ -348,12 +353,11 @@ var _ = Describe("Server", func() {
|
|||
supportedVersions := []protocol.VersionNumber{1, 3, 5}
|
||||
acceptSTK := func(_ net.Addr, _ *STK) bool { return true }
|
||||
config := Config{
|
||||
TLSConfig: &tls.Config{},
|
||||
Versions: supportedVersions,
|
||||
AcceptSTK: acceptSTK,
|
||||
HandshakeTimeout: 1337 * time.Hour,
|
||||
}
|
||||
ln, err := Listen(conn, &config)
|
||||
ln, err := Listen(conn, &tls.Config{}, &config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
server := ln.(*server)
|
||||
Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout))
|
||||
|
@ -365,8 +369,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("fills in default values if options are not set in the Config", func() {
|
||||
config := Config{TLSConfig: &tls.Config{}}
|
||||
ln, err := Listen(conn, &config)
|
||||
ln, err := Listen(conn, &tls.Config{}, &Config{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
server := ln.(*server)
|
||||
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
|
||||
|
@ -376,7 +379,7 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("listens on a given address", func() {
|
||||
addr := "127.0.0.1:13579"
|
||||
ln, err := ListenAddr(addr, config)
|
||||
ln, err := ListenAddr(addr, nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serv := ln.(*server)
|
||||
Expect(serv.Addr().String()).To(Equal(addr))
|
||||
|
@ -384,13 +387,13 @@ var _ = Describe("Server", func() {
|
|||
|
||||
It("errors if given an invalid address", func() {
|
||||
addr := "127.0.0.1"
|
||||
_, err := ListenAddr(addr, config)
|
||||
_, err := ListenAddr(addr, nil, config)
|
||||
Expect(err).To(BeAssignableToTypeOf(&net.AddrError{}))
|
||||
})
|
||||
|
||||
It("errors if given an invalid address", func() {
|
||||
addr := "1.1.1.1:1111"
|
||||
_, err := ListenAddr(addr, config)
|
||||
_, err := ListenAddr(addr, nil, config)
|
||||
Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
|
||||
})
|
||||
|
||||
|
@ -407,7 +410,7 @@ var _ = Describe("Server", func() {
|
|||
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
|
||||
conn.dataToRead = b.Bytes()
|
||||
conn.dataReadFrom = udpAddr
|
||||
ln, err := Listen(conn, config)
|
||||
ln, err := Listen(conn, nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var returned bool
|
||||
|
@ -431,7 +434,7 @@ var _ = Describe("Server", func() {
|
|||
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
|
||||
conn.dataReadFrom = udpAddr
|
||||
conn.dataToRead = []byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x01}
|
||||
ln, err := Listen(conn, config)
|
||||
ln, err := Listen(conn, nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -53,6 +54,7 @@ type session struct {
|
|||
connectionID protocol.ConnectionID
|
||||
perspective protocol.Perspective
|
||||
version protocol.VersionNumber
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
|
||||
conn connection
|
||||
|
@ -119,6 +121,7 @@ func newSession(
|
|||
v protocol.VersionNumber,
|
||||
connectionID protocol.ConnectionID,
|
||||
sCfg *handshake.ServerConfig,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (packetHandler, <-chan handshakeEvent, error) {
|
||||
s := &session{
|
||||
|
@ -137,6 +140,7 @@ var newClientSession = func(
|
|||
hostname string,
|
||||
v protocol.VersionNumber,
|
||||
connectionID protocol.ConnectionID,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
negotiatedVersions []protocol.VersionNumber,
|
||||
) (packetHandler, <-chan handshakeEvent, error) {
|
||||
|
@ -145,6 +149,7 @@ var newClientSession = func(
|
|||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
version: v,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
}
|
||||
return s.setup(nil, hostname, negotiatedVersions)
|
||||
|
@ -209,7 +214,7 @@ func (s *session) setup(
|
|||
s.connectionID,
|
||||
s.version,
|
||||
cryptoStream,
|
||||
s.config.TLSConfig,
|
||||
s.tlsConf,
|
||||
s.connectionParameters,
|
||||
aeadChanged,
|
||||
&handshake.TransportParameters{RequestConnectionIDTruncation: s.config.RequestConnectionIDTruncation},
|
||||
|
|
|
@ -169,6 +169,7 @@ var _ = Describe("Session", func() {
|
|||
protocol.Version35,
|
||||
0,
|
||||
scfg,
|
||||
nil,
|
||||
populateServerConfig(&Config{}),
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
@ -220,6 +221,7 @@ var _ = Describe("Session", func() {
|
|||
protocol.Version35,
|
||||
0,
|
||||
scfg,
|
||||
nil,
|
||||
conf,
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
@ -1635,6 +1637,7 @@ var _ = Describe("Client Session", func() {
|
|||
"hostname",
|
||||
protocol.Version35,
|
||||
0,
|
||||
nil,
|
||||
populateClientConfig(&Config{}),
|
||||
nil,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue