return from Dial after conn is forward-secure, unless ConnState is given

This commit is contained in:
Marten Seemann 2017-02-22 16:55:30 +07:00
parent 6f27b7f70d
commit 8bfeb2ea8d
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
4 changed files with 54 additions and 35 deletions

View file

@ -2,7 +2,6 @@ package quic
import (
"bytes"
"crypto/tls"
"errors"
"net"
"strings"
@ -21,14 +20,11 @@ type client struct {
conn connection
hostname string
config *Config
config *Config
connState ConnState
connectionID protocol.ConnectionID
version protocol.VersionNumber
versionNegotiated bool
tlsConfig *tls.Config
cryptoChangeCallback CryptoChangeCallback
connectionID protocol.ConnectionID
version protocol.VersionNumber
session packetHandler
}
@ -61,19 +57,6 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
c.connStateChangeCond.L = &c.mutex
c.cryptoChangeCallback = func(isForwardSecure bool) {
var state ConnState
if isForwardSecure {
state = ConnStateForwardSecure
} else {
state = ConnStateSecure
}
if c.config.ConnState != nil {
go config.ConnState(c.session, state)
}
}
err = c.createNewSession(nil)
if err != nil {
return nil, err
@ -84,7 +67,13 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
go c.listen()
c.mutex.Lock()
for !c.versionNegotiated {
for {
if c.config.ConnState != nil && c.connState >= ConnStateVersionNegotiated {
break
}
if c.config.ConnState == nil && c.connState == ConnStateForwardSecure {
break
}
c.connStateChangeCond.Wait()
}
c.mutex.Unlock()
@ -147,15 +136,15 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
hdr.Raw = packet[:len(packet)-r.Len()]
// ignore delayed / duplicated version negotiation packets
if c.versionNegotiated && hdr.VersionFlag {
if c.connState >= ConnStateVersionNegotiated && hdr.VersionFlag {
return nil
}
// this is the first packet after the client sent a packet with the VersionFlag set
// if the server doesn't send a version negotiation packet, it supports the suggested version
if !hdr.VersionFlag && !c.versionNegotiated {
if !hdr.VersionFlag && c.connState == ConnStateInitial {
c.mutex.Lock()
c.versionNegotiated = true
c.connState = ConnStateVersionNegotiated
c.connStateChangeCond.Signal()
c.mutex.Unlock()
if c.config.ConnState != nil {
@ -186,7 +175,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
// switch to negotiated version
c.version = highestSupportedVersion
c.versionNegotiated = true
c.connState = ConnStateVersionNegotiated
c.connectionID, err = utils.GenerateConnectionID()
if err != nil {
return err
@ -217,6 +206,19 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
return nil
}
func (c *client) cryptoChangeCallback(isForwardSecure bool) {
var state ConnState
if isForwardSecure {
state = ConnStateForwardSecure
} else {
state = ConnStateSecure
}
if c.config.ConnState != nil {
go c.config.ConnState(c.session, state)
}
}
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
var err error
c.session, err = newClientSession(

View file

@ -57,6 +57,21 @@ var _ = Describe("Client", func() {
Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io"))
})
// TODO: actually test this
// now we're only testing that Dial doesn't return directly after version negotiation
It("only returns once a forward-secure connection is established if no ConnState is defined", func() {
packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
config.ConnState = nil
var dialReturned bool
go func() {
defer GinkgoRecover()
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).ToNot(HaveOccurred())
dialReturned = true
}()
Consistently(func() bool { return dialReturned }).Should(BeFalse())
})
It("errors on invalid public header", func() {
err := cl.handlePacket(nil, nil)
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
@ -175,7 +190,7 @@ var _ = Describe("Client", func() {
Expect(err).ToNot(HaveOccurred())
err = cl.handlePacket(nil, b.Bytes())
Expect(err).ToNot(HaveOccurred())
Expect(cl.versionNegotiated).To(BeTrue())
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated))
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
})
@ -186,7 +201,7 @@ var _ = Describe("Client", func() {
cl.connectionID = 0x1337
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{newVersion}))
Expect(cl.version).To(Equal(newVersion))
Expect(cl.versionNegotiated).To(BeTrue())
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated))
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
// it swapped the sessions
Expect(cl.session).ToNot(Equal(sess))
@ -204,11 +219,11 @@ var _ = Describe("Client", func() {
It("ignores delayed version negotiation packets", func() {
// if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test
cl.versionNegotiated = true
cl.connState = ConnStateVersionNegotiated
Expect(sess.packetCount).To(BeZero())
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.versionNegotiated).To(BeTrue())
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated))
Expect(sess.packetCount).To(BeZero())
Consistently(func() bool { return versionNegotiateConnStateCalled }).Should(BeFalse())
})

View file

@ -36,8 +36,10 @@ type Session interface {
type ConnState int
const (
// ConnStateInitial is the initial state
ConnStateInitial ConnState = iota
// ConnStateVersionNegotiated means that version negotiation is complete
ConnStateVersionNegotiated ConnState = iota
ConnStateVersionNegotiated
// ConnStateSecure means that the connection is encrypted
ConnStateSecure
// ConnStateForwardSecure means that the connection is forward secure

View file

@ -35,9 +35,9 @@ var (
errSessionAlreadyClosed = errors.New("Cannot close session. It was already closed before.")
)
// CryptoChangeCallback is called every time the encryption level changes
// cryptoChangeCallback is called every time the encryption level changes
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
type CryptoChangeCallback func(isForwardSecure bool)
type cryptoChangeCallback func(isForwardSecure bool)
// closeCallback is called when a session is closed
type closeCallback func(id protocol.ConnectionID)
@ -49,7 +49,7 @@ type session struct {
version protocol.VersionNumber
closeCallback closeCallback
cryptoChangeCallback CryptoChangeCallback
cryptoChangeCallback cryptoChangeCallback
conn connection
@ -132,7 +132,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
return s, err
}
func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) {
func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) {
s := &session{
conn: conn,
connectionID: connectionID,