mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
return from Dial after conn is forward-secure, unless ConnState is given
This commit is contained in:
parent
6f27b7f70d
commit
8bfeb2ea8d
4 changed files with 54 additions and 35 deletions
54
client.go
54
client.go
|
@ -2,7 +2,6 @@ package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -21,14 +20,11 @@ type client struct {
|
||||||
conn connection
|
conn connection
|
||||||
hostname string
|
hostname string
|
||||||
|
|
||||||
config *Config
|
config *Config
|
||||||
|
connState ConnState
|
||||||
|
|
||||||
connectionID protocol.ConnectionID
|
connectionID protocol.ConnectionID
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
versionNegotiated bool
|
|
||||||
|
|
||||||
tlsConfig *tls.Config
|
|
||||||
cryptoChangeCallback CryptoChangeCallback
|
|
||||||
|
|
||||||
session packetHandler
|
session packetHandler
|
||||||
}
|
}
|
||||||
|
@ -61,19 +57,6 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
|
||||||
|
|
||||||
c.connStateChangeCond.L = &c.mutex
|
c.connStateChangeCond.L = &c.mutex
|
||||||
|
|
||||||
c.cryptoChangeCallback = func(isForwardSecure bool) {
|
|
||||||
var state ConnState
|
|
||||||
if isForwardSecure {
|
|
||||||
state = ConnStateForwardSecure
|
|
||||||
} else {
|
|
||||||
state = ConnStateSecure
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.ConnState != nil {
|
|
||||||
go config.ConnState(c.session, state)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.createNewSession(nil)
|
err = c.createNewSession(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -84,7 +67,13 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
|
||||||
go c.listen()
|
go c.listen()
|
||||||
|
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
for !c.versionNegotiated {
|
for {
|
||||||
|
if c.config.ConnState != nil && c.connState >= ConnStateVersionNegotiated {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if c.config.ConnState == nil && c.connState == ConnStateForwardSecure {
|
||||||
|
break
|
||||||
|
}
|
||||||
c.connStateChangeCond.Wait()
|
c.connStateChangeCond.Wait()
|
||||||
}
|
}
|
||||||
c.mutex.Unlock()
|
c.mutex.Unlock()
|
||||||
|
@ -147,15 +136,15 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
||||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||||
|
|
||||||
// ignore delayed / duplicated version negotiation packets
|
// ignore delayed / duplicated version negotiation packets
|
||||||
if c.versionNegotiated && hdr.VersionFlag {
|
if c.connState >= ConnStateVersionNegotiated && hdr.VersionFlag {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is the first packet after the client sent a packet with the VersionFlag set
|
// this is the first packet after the client sent a packet with the VersionFlag set
|
||||||
// if the server doesn't send a version negotiation packet, it supports the suggested version
|
// if the server doesn't send a version negotiation packet, it supports the suggested version
|
||||||
if !hdr.VersionFlag && !c.versionNegotiated {
|
if !hdr.VersionFlag && c.connState == ConnStateInitial {
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
c.versionNegotiated = true
|
c.connState = ConnStateVersionNegotiated
|
||||||
c.connStateChangeCond.Signal()
|
c.connStateChangeCond.Signal()
|
||||||
c.mutex.Unlock()
|
c.mutex.Unlock()
|
||||||
if c.config.ConnState != nil {
|
if c.config.ConnState != nil {
|
||||||
|
@ -186,7 +175,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
||||||
|
|
||||||
// switch to negotiated version
|
// switch to negotiated version
|
||||||
c.version = highestSupportedVersion
|
c.version = highestSupportedVersion
|
||||||
c.versionNegotiated = true
|
c.connState = ConnStateVersionNegotiated
|
||||||
c.connectionID, err = utils.GenerateConnectionID()
|
c.connectionID, err = utils.GenerateConnectionID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -217,6 +206,19 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *client) cryptoChangeCallback(isForwardSecure bool) {
|
||||||
|
var state ConnState
|
||||||
|
if isForwardSecure {
|
||||||
|
state = ConnStateForwardSecure
|
||||||
|
} else {
|
||||||
|
state = ConnStateSecure
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.ConnState != nil {
|
||||||
|
go c.config.ConnState(c.session, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
|
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
|
||||||
var err error
|
var err error
|
||||||
c.session, err = newClientSession(
|
c.session, err = newClientSession(
|
||||||
|
|
|
@ -57,6 +57,21 @@ var _ = Describe("Client", func() {
|
||||||
Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io"))
|
Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// TODO: actually test this
|
||||||
|
// now we're only testing that Dial doesn't return directly after version negotiation
|
||||||
|
It("only returns once a forward-secure connection is established if no ConnState is defined", func() {
|
||||||
|
packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
|
||||||
|
config.ConnState = nil
|
||||||
|
var dialReturned bool
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
dialReturned = true
|
||||||
|
}()
|
||||||
|
Consistently(func() bool { return dialReturned }).Should(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
It("errors on invalid public header", func() {
|
It("errors on invalid public header", func() {
|
||||||
err := cl.handlePacket(nil, nil)
|
err := cl.handlePacket(nil, nil)
|
||||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
|
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
|
||||||
|
@ -175,7 +190,7 @@ var _ = Describe("Client", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
err = cl.handlePacket(nil, b.Bytes())
|
err = cl.handlePacket(nil, b.Bytes())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(cl.versionNegotiated).To(BeTrue())
|
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated))
|
||||||
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
|
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -186,7 +201,7 @@ var _ = Describe("Client", func() {
|
||||||
cl.connectionID = 0x1337
|
cl.connectionID = 0x1337
|
||||||
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{newVersion}))
|
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{newVersion}))
|
||||||
Expect(cl.version).To(Equal(newVersion))
|
Expect(cl.version).To(Equal(newVersion))
|
||||||
Expect(cl.versionNegotiated).To(BeTrue())
|
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated))
|
||||||
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
|
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
|
||||||
// it swapped the sessions
|
// it swapped the sessions
|
||||||
Expect(cl.session).ToNot(Equal(sess))
|
Expect(cl.session).ToNot(Equal(sess))
|
||||||
|
@ -204,11 +219,11 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
It("ignores delayed version negotiation packets", func() {
|
It("ignores delayed version negotiation packets", func() {
|
||||||
// if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test
|
// if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test
|
||||||
cl.versionNegotiated = true
|
cl.connState = ConnStateVersionNegotiated
|
||||||
Expect(sess.packetCount).To(BeZero())
|
Expect(sess.packetCount).To(BeZero())
|
||||||
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
|
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(cl.versionNegotiated).To(BeTrue())
|
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated))
|
||||||
Expect(sess.packetCount).To(BeZero())
|
Expect(sess.packetCount).To(BeZero())
|
||||||
Consistently(func() bool { return versionNegotiateConnStateCalled }).Should(BeFalse())
|
Consistently(func() bool { return versionNegotiateConnStateCalled }).Should(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
|
@ -36,8 +36,10 @@ type Session interface {
|
||||||
type ConnState int
|
type ConnState int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// ConnStateInitial is the initial state
|
||||||
|
ConnStateInitial ConnState = iota
|
||||||
// ConnStateVersionNegotiated means that version negotiation is complete
|
// ConnStateVersionNegotiated means that version negotiation is complete
|
||||||
ConnStateVersionNegotiated ConnState = iota
|
ConnStateVersionNegotiated
|
||||||
// ConnStateSecure means that the connection is encrypted
|
// ConnStateSecure means that the connection is encrypted
|
||||||
ConnStateSecure
|
ConnStateSecure
|
||||||
// ConnStateForwardSecure means that the connection is forward secure
|
// ConnStateForwardSecure means that the connection is forward secure
|
||||||
|
|
|
@ -35,9 +35,9 @@ var (
|
||||||
errSessionAlreadyClosed = errors.New("Cannot close session. It was already closed before.")
|
errSessionAlreadyClosed = errors.New("Cannot close session. It was already closed before.")
|
||||||
)
|
)
|
||||||
|
|
||||||
// CryptoChangeCallback is called every time the encryption level changes
|
// cryptoChangeCallback is called every time the encryption level changes
|
||||||
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
|
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
|
||||||
type CryptoChangeCallback func(isForwardSecure bool)
|
type cryptoChangeCallback func(isForwardSecure bool)
|
||||||
|
|
||||||
// closeCallback is called when a session is closed
|
// closeCallback is called when a session is closed
|
||||||
type closeCallback func(id protocol.ConnectionID)
|
type closeCallback func(id protocol.ConnectionID)
|
||||||
|
@ -49,7 +49,7 @@ type session struct {
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
|
|
||||||
closeCallback closeCallback
|
closeCallback closeCallback
|
||||||
cryptoChangeCallback CryptoChangeCallback
|
cryptoChangeCallback cryptoChangeCallback
|
||||||
|
|
||||||
conn connection
|
conn connection
|
||||||
|
|
||||||
|
@ -132,7 +132,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
||||||
return s, err
|
return s, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) {
|
func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) {
|
||||||
s := &session{
|
s := &session{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
connectionID: connectionID,
|
connectionID: connectionID,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue