mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
remove non forward-secure dialing
This was broken for a long time, and won't be available when using the TLS 1.3 handshake.
This commit is contained in:
parent
bd60e996dc
commit
99a2853e7d
15 changed files with 219 additions and 441 deletions
|
@ -1,8 +1,9 @@
|
|||
# Changelog
|
||||
|
||||
## v0.6.1 (unreleased)
|
||||
## v0.7 (unreleased)
|
||||
|
||||
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
|
||||
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
|
||||
|
||||
## v0.6.0 (2017-12-12)
|
||||
|
||||
|
|
55
client.go
55
client.go
|
@ -60,33 +60,15 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
|
|||
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddrNonFWSecure(
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (NonFWSession, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func DialNonFWSecure(
|
||||
func Dial(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (NonFWSession, error) {
|
||||
) (Session, error) {
|
||||
connID, err := generateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -119,26 +101,7 @@ func DialNonFWSecure(
|
|||
if err := c.dial(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.session.(NonFWSession), nil
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func Dial(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := sess.WaitUntilHandshakeComplete(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sess, nil
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
|
@ -268,14 +231,8 @@ func (c *client) establishSecureConnection() error {
|
|||
select {
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case ev := <-c.session.handshakeStatus():
|
||||
if ev.err != nil {
|
||||
return ev.err
|
||||
}
|
||||
if !c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure {
|
||||
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
|
||||
}
|
||||
return nil
|
||||
case err := <-c.session.handshakeStatus():
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -101,57 +101,7 @@ var _ = Describe("Client", func() {
|
|||
generateConnectionID = origGenerateConnectionID
|
||||
})
|
||||
|
||||
It("dials non-forward-secure", func() {
|
||||
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
s, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).ToNot(BeNil())
|
||||
close(dialed)
|
||||
}()
|
||||
Consistently(dialed).ShouldNot(BeClosed())
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("dials a non-forward-secure address", func() {
|
||||
serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
server, err := net.ListenUDP("udp", serverAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
for {
|
||||
_, clientAddr, err := server.ReadFromUDP(make([]byte, 200))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = server.WriteToUDP(acceptClientVersionPacket(cl.connectionID), clientAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
}()
|
||||
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
s, err := DialAddrNonFWSecure(server.LocalAddr().String(), nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).ToNot(BeNil())
|
||||
close(dialed)
|
||||
}()
|
||||
Consistently(dialed).ShouldNot(BeClosed())
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
server.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("Dial only returns after the handshake is complete", func() {
|
||||
It("returns after the handshake is complete", func() {
|
||||
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -161,9 +111,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(s).ToNot(BeNil())
|
||||
close(dialed)
|
||||
}()
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
Consistently(dialed).ShouldNot(BeClosed())
|
||||
close(sess.handshakeComplete)
|
||||
close(sess.handshakeChan)
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -249,22 +197,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
sess.handshakeChan <- handshakeEvent{err: testErr}
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns an error that occurs while waiting for the handshake to complete", func() {
|
||||
testErr := errors.New("late handshake error")
|
||||
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
sess.handshakeComplete <- testErr
|
||||
sess.handshakeChan <- testErr
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
@ -309,7 +242,7 @@ var _ = Describe("Client", func() {
|
|||
) (packetHandler, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
_, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
|
@ -335,7 +268,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(newVersion).ToNot(Equal(cl.version))
|
||||
Expect(config.Versions).To(ContainElement(newVersion))
|
||||
sessionChan := make(chan *mockSession)
|
||||
handshakeChan := make(chan handshakeEvent)
|
||||
handshakeChan := make(chan error)
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
_ string,
|
||||
|
@ -386,7 +319,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(negotiatedVersions).To(ContainElement(newVersion))
|
||||
Expect(initialVersion).To(Equal(actualInitialVersion))
|
||||
|
||||
handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
close(handshakeChan)
|
||||
Eventually(established).Should(BeClosed())
|
||||
})
|
||||
|
||||
|
|
|
@ -2,14 +2,12 @@ package self_test
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
@ -111,18 +109,6 @@ var _ = Describe("Handshake RTT tests", func() {
|
|||
expectDurationInRTTs(4)
|
||||
})
|
||||
|
||||
// 1 RTT for verifying the source address
|
||||
// 1 RTT to become secure
|
||||
// TODO (marten-seemann): enable this test (see #625)
|
||||
PIt("is secure after 2 RTTs", func() {
|
||||
utils.SetLogLevel(utils.LogLevelDebug)
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
fmt.Println("#### is non fw secure ###")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(2)
|
||||
})
|
||||
|
||||
It("is forward-secure after 2 RTTs when the server doesn't require a Cookie", func() {
|
||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
||||
return true
|
||||
|
|
|
@ -130,13 +130,6 @@ type Session interface {
|
|||
Context() context.Context
|
||||
}
|
||||
|
||||
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
|
||||
// The communication is encrypted, but not yet forward secure.
|
||||
type NonFWSession interface {
|
||||
Session
|
||||
WaitUntilHandshakeComplete() error
|
||||
}
|
||||
|
||||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
type Config struct {
|
||||
// The QUIC versions that can be negotiated.
|
||||
|
|
|
@ -51,8 +51,8 @@ type cryptoSetupClient struct {
|
|||
secureAEAD crypto.AEAD
|
||||
forwardSecureAEAD crypto.AEAD
|
||||
|
||||
paramsChan chan<- TransportParameters
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
paramsChan chan<- TransportParameters
|
||||
handshakeEvent chan<- struct{}
|
||||
|
||||
params *TransportParameters
|
||||
}
|
||||
|
@ -74,7 +74,7 @@ func NewCryptoSetupClient(
|
|||
tlsConfig *tls.Config,
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
handshakeEvent chan<- struct{},
|
||||
initialVersion protocol.VersionNumber,
|
||||
negotiatedVersions []protocol.VersionNumber,
|
||||
) (CryptoSetup, error) {
|
||||
|
@ -93,7 +93,7 @@ func NewCryptoSetupClient(
|
|||
keyExchange: getEphermalKEX,
|
||||
nullAEAD: nullAEAD,
|
||||
paramsChan: paramsChan,
|
||||
aeadChanged: aeadChanged,
|
||||
handshakeEvent: handshakeEvent,
|
||||
initialVersion: initialVersion,
|
||||
negotiatedVersions: negotiatedVersions,
|
||||
divNonceChan: make(chan []byte),
|
||||
|
@ -159,8 +159,8 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
|
|||
}
|
||||
// blocks until the session has received the parameters
|
||||
h.paramsChan <- *params
|
||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
||||
close(h.aeadChanged)
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.handshakeEvent)
|
||||
default:
|
||||
return qerr.InvalidCryptoMessageType
|
||||
}
|
||||
|
@ -496,10 +496,8 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.aeadChanged <- protocol.EncryptionSecure
|
||||
h.handshakeEvent <- struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
stream *mockStream
|
||||
keyDerivationCalledWith *keyDerivationValues
|
||||
shloMap map[Tag][]byte
|
||||
aeadChanged chan protocol.EncryptionLevel
|
||||
handshakeEvent chan struct{}
|
||||
paramsChan chan TransportParameters
|
||||
)
|
||||
|
||||
|
@ -108,7 +108,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
version := protocol.Version39
|
||||
// use a buffered channel here, so that we can parse a SHLO without having to receive the TransportParameters to avoid blocking
|
||||
paramsChan = make(chan TransportParameters, 1)
|
||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent = make(chan struct{}, 2)
|
||||
csInt, err := NewCryptoSetupClient(
|
||||
stream,
|
||||
"hostname",
|
||||
|
@ -117,7 +117,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
nil,
|
||||
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
protocol.Version39,
|
||||
nil,
|
||||
)
|
||||
|
@ -385,22 +385,22 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
cs.receivedSecurePacket = false
|
||||
_, err := cs.handleSHLOMessage(shloMap)
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")))
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects SHLOs without a PUBS", func() {
|
||||
delete(shloMap, TagPUBS)
|
||||
_, err := cs.handleSHLOMessage(shloMap)
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")))
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects SHLOs without a version list", func() {
|
||||
delete(shloMap, TagVER)
|
||||
_, err := cs.handleSHLOMessage(shloMap)
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")))
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("accepts a SHLO after a version negotiation", func() {
|
||||
|
@ -435,7 +435,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
Expect(params.IdleTimeout).To(Equal(13 * time.Second))
|
||||
})
|
||||
|
||||
It("closes the aeadChanged when receiving an SHLO", func() {
|
||||
It("closes the handshakeEvent chan when receiving an SHLO", func() {
|
||||
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -444,8 +444,8 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
|
||||
close(done)
|
||||
}()
|
||||
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure)))
|
||||
Eventually(aeadChanged).Should(BeClosed())
|
||||
Eventually(handshakeEvent).Should(Receive())
|
||||
Eventually(handshakeEvent).Should(BeClosed())
|
||||
// make the go routine return
|
||||
stream.close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -652,9 +652,9 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
Expect(keyDerivationCalledWith.cert).To(Equal(certManager.leafCert))
|
||||
Expect(keyDerivationCalledWith.divNonce).To(Equal(cs.diversificationNonce))
|
||||
Expect(keyDerivationCalledWith.pers).To(Equal(protocol.PerspectiveClient))
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).To(Receive())
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("uses the server nonce, if the server sent one", func() {
|
||||
|
@ -664,24 +664,24 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cs.secureAEAD).ToNot(BeNil())
|
||||
Expect(keyDerivationCalledWith.nonces).To(Equal(append(cs.nonc, cs.sno...)))
|
||||
Expect(aeadChanged).To(Receive())
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).To(Receive())
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't create a secureAEAD if the certificate is not yet verified, even if it has all necessary values", func() {
|
||||
err := cs.maybeUpgradeCrypto()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cs.secureAEAD).To(BeNil())
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
cs.serverVerified = true
|
||||
// make sure we really had all necessary values before, and only serverVerified was missing
|
||||
err = cs.maybeUpgradeCrypto()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cs.secureAEAD).ToNot(BeNil())
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).To(Receive())
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("tries to escalate before reading a handshake message", func() {
|
||||
|
@ -694,10 +694,10 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error())))
|
||||
close(done)
|
||||
}()
|
||||
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Eventually(handshakeEvent).Should(Receive())
|
||||
Expect(cs.secureAEAD).ToNot(BeNil())
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
// make the go routine return
|
||||
stream.close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -715,10 +715,10 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
cs.serverVerified = true
|
||||
Expect(cs.secureAEAD).To(BeNil())
|
||||
cs.SetDiversificationNonce([]byte("div"))
|
||||
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Eventually(handshakeEvent).Should(Receive())
|
||||
Expect(cs.secureAEAD).ToNot(BeNil())
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).ToNot(Receive())
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
// make the go routine return
|
||||
stream.close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
|
|
@ -42,7 +42,7 @@ type cryptoSetupServer struct {
|
|||
|
||||
receivedParams bool
|
||||
paramsChan chan<- TransportParameters
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
handshakeEvent chan<- struct{}
|
||||
|
||||
keyDerivation QuicCryptoKeyDerivationFunction
|
||||
keyExchange KeyExchangeFunction
|
||||
|
@ -76,7 +76,7 @@ func NewCryptoSetup(
|
|||
supportedVersions []protocol.VersionNumber,
|
||||
acceptSTK func(net.Addr, *Cookie) bool,
|
||||
paramsChan chan<- TransportParameters,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
handshakeEvent chan<- struct{},
|
||||
) (CryptoSetup, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
||||
if err != nil {
|
||||
|
@ -96,7 +96,7 @@ func NewCryptoSetup(
|
|||
acceptSTKCallback: acceptSTK,
|
||||
sentSHLO: make(chan struct{}),
|
||||
paramsChan: paramsChan,
|
||||
aeadChanged: aeadChanged,
|
||||
handshakeEvent: handshakeEvent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -182,7 +182,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
|
|||
if _, err := h.cryptoStream.Write(reply); err != nil {
|
||||
return false, err
|
||||
}
|
||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.sentSHLO)
|
||||
return true, nil
|
||||
}
|
||||
|
@ -206,9 +206,9 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
|
|||
if err == nil {
|
||||
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
|
||||
h.receivedForwardSecurePacket = true
|
||||
// wait until protocol.EncryptionForwardSecure was sent on the aeadChan
|
||||
// wait for the send on the handshakeEvent chan
|
||||
<-h.sentSHLO
|
||||
close(h.aeadChanged)
|
||||
close(h.handshakeEvent)
|
||||
}
|
||||
return res, protocol.EncryptionForwardSecure, nil
|
||||
}
|
||||
|
@ -396,8 +396,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h.aeadChanged <- protocol.EncryptionSecure
|
||||
h.handshakeEvent <- struct{}{}
|
||||
|
||||
// Generate a new curve instance to derive the forward secure key
|
||||
var fsNonce bytes.Buffer
|
||||
|
|
|
@ -124,7 +124,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
cs *cryptoSetupServer
|
||||
stream *mockStream
|
||||
paramsChan chan TransportParameters
|
||||
aeadChanged chan protocol.EncryptionLevel
|
||||
handshakeEvent chan struct{}
|
||||
nonce32 []byte
|
||||
versionTag []byte
|
||||
validSTK []byte
|
||||
|
@ -146,7 +146,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
|
||||
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
|
||||
paramsChan = make(chan TransportParameters, 1)
|
||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent = make(chan struct{}, 2)
|
||||
stream = newMockStream()
|
||||
kex = &mockKEX{}
|
||||
signer = &mockSigner{}
|
||||
|
@ -170,7 +170,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
supportedVersions,
|
||||
nil,
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
cs = csInt.(*cryptoSetupServer)
|
||||
|
@ -343,10 +343,10 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
err := cs.HandleCryptoStream()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ"))
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
||||
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects client nonces that have the wrong length", func() {
|
||||
|
@ -377,9 +377,9 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO"))
|
||||
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
|
||||
Expect(aeadChanged).ToNot(BeClosed())
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
|
||||
Expect(handshakeEvent).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("recognizes inchoate CHLOs missing SCID", func() {
|
||||
|
@ -535,7 +535,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
TagKEXS: kexs,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
||||
close(cs.sentSHLO)
|
||||
}
|
||||
|
||||
|
@ -657,7 +657,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
|
||||
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(aeadChanged).To(BeClosed())
|
||||
Expect(handshakeEvent).To(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -26,9 +26,9 @@ type cryptoSetupTLS struct {
|
|||
nullAEAD crypto.AEAD
|
||||
aead crypto.AEAD
|
||||
|
||||
tls MintTLS
|
||||
cryptoStream *CryptoStreamConn
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
tls MintTLS
|
||||
cryptoStream *CryptoStreamConn
|
||||
handshakeEvent chan<- struct{}
|
||||
}
|
||||
|
||||
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
||||
|
@ -36,16 +36,16 @@ func NewCryptoSetupTLSServer(
|
|||
tls MintTLS,
|
||||
cryptoStream *CryptoStreamConn,
|
||||
nullAEAD crypto.AEAD,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
handshakeEvent chan<- struct{},
|
||||
version protocol.VersionNumber,
|
||||
) CryptoSetup {
|
||||
return &cryptoSetupTLS{
|
||||
tls: tls,
|
||||
cryptoStream: cryptoStream,
|
||||
nullAEAD: nullAEAD,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
tls: tls,
|
||||
cryptoStream: cryptoStream,
|
||||
nullAEAD: nullAEAD,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
handshakeEvent: handshakeEvent,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -54,7 +54,7 @@ func NewCryptoSetupTLSClient(
|
|||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
hostname string,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
handshakeEvent chan<- struct{},
|
||||
tls MintTLS,
|
||||
version protocol.VersionNumber,
|
||||
) (CryptoSetup, error) {
|
||||
|
@ -64,11 +64,11 @@ func NewCryptoSetupTLSClient(
|
|||
}
|
||||
|
||||
return &cryptoSetupTLS{
|
||||
perspective: protocol.PerspectiveClient,
|
||||
tls: tls,
|
||||
nullAEAD: nullAEAD,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
tls: tls,
|
||||
nullAEAD: nullAEAD,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
handshakeEvent: handshakeEvent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -102,9 +102,8 @@ handshakeLoop:
|
|||
h.aead = aead
|
||||
h.mutex.Unlock()
|
||||
|
||||
// signal to the outside world that the handshake completed
|
||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
||||
close(h.aeadChanged)
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.handshakeEvent)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -20,17 +20,17 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e
|
|||
|
||||
var _ = Describe("TLS Crypto Setup", func() {
|
||||
var (
|
||||
cs *cryptoSetupTLS
|
||||
aeadChanged chan protocol.EncryptionLevel
|
||||
cs *cryptoSetupTLS
|
||||
handshakeEvent chan struct{}
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent = make(chan struct{}, 2)
|
||||
cs = NewCryptoSetupTLSServer(
|
||||
nil,
|
||||
NewCryptoStreamConn(nil),
|
||||
nil, // AEAD
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
protocol.VersionTLS,
|
||||
).(*cryptoSetupTLS)
|
||||
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
|
||||
|
@ -51,8 +51,8 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
|
||||
Expect(aeadChanged).To(BeClosed())
|
||||
Expect(handshakeEvent).To(Receive())
|
||||
Expect(handshakeEvent).To(BeClosed())
|
||||
})
|
||||
|
||||
It("handshakes until it is connected", func() {
|
||||
|
@ -63,7 +63,7 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(aeadChanged).To(Receive())
|
||||
Expect(handshakeEvent).To(Receive())
|
||||
})
|
||||
|
||||
Context("escalating crypto", func() {
|
||||
|
@ -180,17 +180,17 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
|
||||
var _ = Describe("TLS Crypto Setup, for the client", func() {
|
||||
var (
|
||||
cs *cryptoSetupTLS
|
||||
aeadChanged chan protocol.EncryptionLevel
|
||||
cs *cryptoSetupTLS
|
||||
handshakeEvent chan struct{}
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent = make(chan struct{})
|
||||
csInt, err := NewCryptoSetupTLSClient(
|
||||
nil,
|
||||
0,
|
||||
"quic.clemente.io",
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
nil, // mintTLS
|
||||
protocol.VersionTLS,
|
||||
)
|
||||
|
|
12
server.go
12
server.go
|
@ -20,7 +20,7 @@ import (
|
|||
type packetHandler interface {
|
||||
Session
|
||||
getCryptoStream() cryptoStreamI
|
||||
handshakeStatus() <-chan handshakeEvent
|
||||
handshakeStatus() <-chan error
|
||||
handlePacket(*receivedPacket)
|
||||
GetVersion() protocol.VersionNumber
|
||||
run() error
|
||||
|
@ -391,14 +391,8 @@ func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.C
|
|||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
ev := <-session.handshakeStatus()
|
||||
if ev.err != nil {
|
||||
return
|
||||
}
|
||||
if ev.encLevel == protocol.EncryptionForwardSecure {
|
||||
break
|
||||
}
|
||||
if err := <-session.handshakeStatus(); err != nil {
|
||||
return
|
||||
}
|
||||
s.sessionQueue <- session
|
||||
}()
|
||||
|
|
|
@ -22,14 +22,13 @@ import (
|
|||
)
|
||||
|
||||
type mockSession struct {
|
||||
connectionID protocol.ConnectionID
|
||||
packetCount int
|
||||
closed bool
|
||||
closeReason error
|
||||
closedRemote bool
|
||||
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
|
||||
handshakeChan chan handshakeEvent
|
||||
handshakeComplete chan error // for WaitUntilHandshakeComplete
|
||||
connectionID protocol.ConnectionID
|
||||
packetCount int
|
||||
closed bool
|
||||
closeReason error
|
||||
closedRemote bool
|
||||
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
|
||||
handshakeChan chan error
|
||||
}
|
||||
|
||||
func (s *mockSession) handlePacket(*receivedPacket) {
|
||||
|
@ -40,9 +39,6 @@ func (s *mockSession) run() error {
|
|||
<-s.stopRunLoop
|
||||
return s.closeReason
|
||||
}
|
||||
func (s *mockSession) WaitUntilHandshakeComplete() error {
|
||||
return <-s.handshakeComplete
|
||||
}
|
||||
func (s *mockSession) Close(e error) error {
|
||||
if s.closed {
|
||||
return nil
|
||||
|
@ -61,17 +57,16 @@ func (s *mockSession) closeRemote(e error) {
|
|||
func (s *mockSession) OpenStream() (Stream, error) {
|
||||
return &stream{}, nil
|
||||
}
|
||||
func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") }
|
||||
func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") }
|
||||
func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") }
|
||||
func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") }
|
||||
func (*mockSession) Context() context.Context { panic("not implemented") }
|
||||
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
|
||||
func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan }
|
||||
func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") }
|
||||
func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") }
|
||||
func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") }
|
||||
func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") }
|
||||
func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") }
|
||||
func (*mockSession) Context() context.Context { panic("not implemented") }
|
||||
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
|
||||
func (s *mockSession) handshakeStatus() <-chan error { return s.handshakeChan }
|
||||
func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") }
|
||||
|
||||
var _ Session = &mockSession{}
|
||||
var _ NonFWSession = &mockSession{}
|
||||
|
||||
func newMockSession(
|
||||
_ connection,
|
||||
|
@ -82,10 +77,9 @@ func newMockSession(
|
|||
_ *Config,
|
||||
) (packetHandler, error) {
|
||||
s := mockSession{
|
||||
connectionID: connectionID,
|
||||
handshakeChan: make(chan handshakeEvent),
|
||||
handshakeComplete: make(chan error),
|
||||
stopRunLoop: make(chan struct{}),
|
||||
connectionID: connectionID,
|
||||
handshakeChan: make(chan error),
|
||||
stopRunLoop: make(chan struct{}),
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
@ -155,9 +149,8 @@ var _ = Describe("Server", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
sess := serv.sessions[connID].(*mockSession)
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
|
||||
Consistently(func() Session { return acceptedSess }).Should(BeNil())
|
||||
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionForwardSecure}
|
||||
close(sess.handshakeChan)
|
||||
Eventually(func() Session { return acceptedSess }).Should(Equal(sess))
|
||||
close(done)
|
||||
}, 0.5)
|
||||
|
@ -173,7 +166,7 @@ var _ = Describe("Server", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
sess := serv.sessions[connID].(*mockSession)
|
||||
sess.handshakeChan <- handshakeEvent{err: errors.New("handshake failed")}
|
||||
sess.handshakeChan <- errors.New("handshake failed")
|
||||
Consistently(func() bool { return accepted }).Should(BeFalse())
|
||||
close(done)
|
||||
})
|
||||
|
|
112
session.go
112
session.go
|
@ -39,11 +39,6 @@ var (
|
|||
newCryptoSetupClient = handshake.NewCryptoSetupClient
|
||||
)
|
||||
|
||||
type handshakeEvent struct {
|
||||
encLevel protocol.EncryptionLevel
|
||||
err error
|
||||
}
|
||||
|
||||
type closeError struct {
|
||||
err error
|
||||
remote bool
|
||||
|
@ -90,17 +85,14 @@ type session struct {
|
|||
|
||||
// this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them
|
||||
paramsChan <-chan handshake.TransportParameters
|
||||
// this channel is passed to the CryptoSetup and receives the current encryption level
|
||||
// it is closed as soon as the handshake is complete
|
||||
aeadChanged <-chan protocol.EncryptionLevel
|
||||
// the handshakeEvent channel is passed to the CryptoSetup.
|
||||
// It receives when it makes sense to try decrypting undecryptable packets.
|
||||
handshakeEvent <-chan struct{}
|
||||
// handshakeChan is returned by handshakeStatus.
|
||||
// It receives any error that might occur during the handshake.
|
||||
// It is closed when the handshake is complete.
|
||||
handshakeChan chan error
|
||||
handshakeComplete bool
|
||||
// will be closed as soon as the handshake completes, and receive any error that might occur until then
|
||||
// it is used to block WaitUntilHandshakeComplete()
|
||||
handshakeCompleteChan chan error
|
||||
// handshakeChan receives handshake events and is closed as soon the handshake completes
|
||||
// the receiving end of this channel is passed to the creator of the session
|
||||
// it receives at most 3 handshake events: 2 when the encryption level changes, and one error
|
||||
handshakeChan chan handshakeEvent
|
||||
|
||||
lastRcvdPacketNumber protocol.PacketNumber
|
||||
// Used to calculate the next packet number from the truncated wire
|
||||
|
@ -131,15 +123,15 @@ func newSession(
|
|||
config *Config,
|
||||
) (packetHandler, error) {
|
||||
paramsChan := make(chan handshake.TransportParameters)
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent := make(chan struct{}, 1)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
version: v,
|
||||
config: config,
|
||||
aeadChanged: aeadChanged,
|
||||
paramsChan: paramsChan,
|
||||
conn: conn,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
version: v,
|
||||
config: config,
|
||||
handshakeEvent: handshakeEvent,
|
||||
paramsChan: paramsChan,
|
||||
}
|
||||
s.preSetup()
|
||||
transportParams := &handshake.TransportParameters{
|
||||
|
@ -158,7 +150,7 @@ func newSession(
|
|||
s.config.Versions,
|
||||
s.config.AcceptCookie,
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -179,15 +171,15 @@ var newClientSession = func(
|
|||
negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton
|
||||
) (packetHandler, error) {
|
||||
paramsChan := make(chan handshake.TransportParameters)
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent := make(chan struct{}, 1)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
version: v,
|
||||
config: config,
|
||||
aeadChanged: aeadChanged,
|
||||
paramsChan: paramsChan,
|
||||
conn: conn,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
version: v,
|
||||
config: config,
|
||||
handshakeEvent: handshakeEvent,
|
||||
paramsChan: paramsChan,
|
||||
}
|
||||
s.preSetup()
|
||||
transportParams := &handshake.TransportParameters{
|
||||
|
@ -205,7 +197,7 @@ var newClientSession = func(
|
|||
tlsConf,
|
||||
transportParams,
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
initialVersion,
|
||||
negotiatedVersions,
|
||||
)
|
||||
|
@ -227,21 +219,21 @@ func newTLSServerSession(
|
|||
peerParams *handshake.TransportParameters,
|
||||
v protocol.VersionNumber,
|
||||
) (packetHandler, error) {
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent := make(chan struct{}, 1)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
config: config,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
version: v,
|
||||
aeadChanged: aeadChanged,
|
||||
conn: conn,
|
||||
config: config,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
version: v,
|
||||
handshakeEvent: handshakeEvent,
|
||||
}
|
||||
s.preSetup()
|
||||
s.cryptoSetup = handshake.NewCryptoSetupTLSServer(
|
||||
tls,
|
||||
cryptoStreamConn,
|
||||
nullAEAD,
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
v,
|
||||
)
|
||||
if err := s.postSetup(initialPacketNumber); err != nil {
|
||||
|
@ -264,15 +256,15 @@ var newTLSClientSession = func(
|
|||
paramsChan <-chan handshake.TransportParameters,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
) (packetHandler, error) {
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
handshakeEvent := make(chan struct{}, 1)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
config: config,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
version: v,
|
||||
aeadChanged: aeadChanged,
|
||||
paramsChan: paramsChan,
|
||||
conn: conn,
|
||||
config: config,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
version: v,
|
||||
handshakeEvent: handshakeEvent,
|
||||
paramsChan: paramsChan,
|
||||
}
|
||||
s.preSetup()
|
||||
tls.SetCryptoStream(s.cryptoStream)
|
||||
|
@ -280,7 +272,7 @@ var newTLSClientSession = func(
|
|||
s.cryptoStream,
|
||||
s.connectionID,
|
||||
hostname,
|
||||
aeadChanged,
|
||||
handshakeEvent,
|
||||
tls,
|
||||
v,
|
||||
)
|
||||
|
@ -302,8 +294,7 @@ func (s *session) preSetup() {
|
|||
}
|
||||
|
||||
func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
|
||||
s.handshakeChan = make(chan handshakeEvent, 3)
|
||||
s.handshakeCompleteChan = make(chan error, 1)
|
||||
s.handshakeChan = make(chan error, 1)
|
||||
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
|
||||
s.closeChan = make(chan closeError, 1)
|
||||
s.sendingScheduled = make(chan struct{}, 1)
|
||||
|
@ -343,7 +334,7 @@ func (s *session) run() error {
|
|||
}()
|
||||
|
||||
var closeErr closeError
|
||||
aeadChanged := s.aeadChanged
|
||||
handshakeEvent := s.handshakeEvent
|
||||
|
||||
runLoop:
|
||||
for {
|
||||
|
@ -381,16 +372,14 @@ runLoop:
|
|||
putPacketBuffer(p.header.Raw)
|
||||
case p := <-s.paramsChan:
|
||||
s.processTransportParameters(&p)
|
||||
case l, ok := <-aeadChanged:
|
||||
case _, ok := <-handshakeEvent:
|
||||
if !ok { // the aeadChanged chan was closed. This means that the handshake is completed.
|
||||
s.handshakeComplete = true
|
||||
aeadChanged = nil // prevent this case from ever being selected again
|
||||
handshakeEvent = nil // prevent this case from ever being selected again
|
||||
s.sentPacketHandler.SetHandshakeComplete()
|
||||
close(s.handshakeChan)
|
||||
close(s.handshakeCompleteChan)
|
||||
} else {
|
||||
s.tryDecryptingQueuedPackets()
|
||||
s.handshakeChan <- handshakeEvent{encLevel: l}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -428,8 +417,7 @@ runLoop:
|
|||
// only send the error the handshakeChan when the handshake is not completed yet
|
||||
// otherwise this chan will already be closed
|
||||
if !s.handshakeComplete {
|
||||
s.handshakeCompleteChan <- closeErr.err
|
||||
s.handshakeChan <- handshakeEvent{err: closeErr.err}
|
||||
s.handshakeChan <- closeErr.err
|
||||
}
|
||||
s.handleCloseError(closeErr)
|
||||
return closeErr.err
|
||||
|
@ -878,10 +866,6 @@ func (s *session) OpenStreamSync() (Stream, error) {
|
|||
return s.streamsMap.OpenStreamSync()
|
||||
}
|
||||
|
||||
func (s *session) WaitUntilHandshakeComplete() error {
|
||||
return <-s.handshakeCompleteChan
|
||||
}
|
||||
|
||||
func (s *session) newStream(id protocol.StreamID) streamI {
|
||||
var initialSendWindow protocol.ByteCount
|
||||
if s.peerParams != nil {
|
||||
|
@ -970,7 +954,7 @@ func (s *session) RemoteAddr() net.Addr {
|
|||
return s.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (s *session) handshakeStatus() <-chan handshakeEvent {
|
||||
func (s *session) handshakeStatus() <-chan error {
|
||||
return s.handshakeChan
|
||||
}
|
||||
|
||||
|
|
167
session_test.go
167
session_test.go
|
@ -78,11 +78,11 @@ func areSessionsRunning() bool {
|
|||
|
||||
var _ = Describe("Session", func() {
|
||||
var (
|
||||
sess *session
|
||||
scfg *handshake.ServerConfig
|
||||
mconn *mockConnection
|
||||
cryptoSetup *mockCryptoSetup
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
sess *session
|
||||
scfg *handshake.ServerConfig
|
||||
mconn *mockConnection
|
||||
cryptoSetup *mockCryptoSetup
|
||||
handshakeChan chan<- struct{}
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
|
@ -99,9 +99,9 @@ var _ = Describe("Session", func() {
|
|||
_ []protocol.VersionNumber,
|
||||
_ func(net.Addr, *Cookie) bool,
|
||||
_ chan<- handshake.TransportParameters,
|
||||
aeadChangedP chan<- protocol.EncryptionLevel,
|
||||
handshakeChanP chan<- struct{},
|
||||
) (handshake.CryptoSetup, error) {
|
||||
aeadChanged = aeadChangedP
|
||||
handshakeChan = handshakeChanP
|
||||
return cryptoSetup, nil
|
||||
}
|
||||
|
||||
|
@ -149,7 +149,7 @@ var _ = Describe("Session", func() {
|
|||
_ []protocol.VersionNumber,
|
||||
cookieFunc func(net.Addr, *Cookie) bool,
|
||||
_ chan<- handshake.TransportParameters,
|
||||
_ chan<- protocol.EncryptionLevel,
|
||||
_ chan<- struct{},
|
||||
) (handshake.CryptoSetup, error) {
|
||||
cookieVerify = cookieFunc
|
||||
return cryptoSetup, nil
|
||||
|
@ -516,61 +516,6 @@ var _ = Describe("Session", func() {
|
|||
Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242)))
|
||||
})
|
||||
|
||||
Context("waiting until the handshake completes", func() {
|
||||
It("waits until the handshake is complete", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess.run()
|
||||
}()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := sess.WaitUntilHandshakeComplete()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
aeadChanged <- protocol.EncryptionForwardSecure
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
close(aeadChanged)
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
})
|
||||
|
||||
It("errors if the handshake fails", func(done Done) {
|
||||
testErr := errors.New("crypto error")
|
||||
sess.cryptoSetup = &mockCryptoSetup{handleErr: testErr}
|
||||
go sess.run()
|
||||
err := sess.WaitUntilHandshakeComplete()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}, 0.5)
|
||||
|
||||
It("returns when Close is called", func(done Done) {
|
||||
testErr := errors.New("close error")
|
||||
go sess.run()
|
||||
var waitReturned bool
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := sess.WaitUntilHandshakeComplete()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
waitReturned = true
|
||||
}()
|
||||
sess.Close(testErr)
|
||||
Eventually(func() bool { return waitReturned }).Should(BeTrue())
|
||||
close(done)
|
||||
})
|
||||
|
||||
It("doesn't wait if the handshake is already completed", func(done Done) {
|
||||
go sess.run()
|
||||
close(aeadChanged)
|
||||
err := sess.WaitUntilHandshakeComplete()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
close(done)
|
||||
})
|
||||
})
|
||||
|
||||
Context("accepting streams", func() {
|
||||
BeforeEach(func() {
|
||||
// don't use the mock here
|
||||
|
@ -1362,46 +1307,48 @@ var _ = Describe("Session", func() {
|
|||
})
|
||||
})
|
||||
|
||||
It("send a handshake event on the handshakeChan when the AEAD changes to secure", func(done Done) {
|
||||
go sess.run()
|
||||
aeadChanged <- protocol.EncryptionSecure
|
||||
Eventually(sess.handshakeStatus()).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionSecure}))
|
||||
It("doesn't do anything when the crypto setup says to decrypt undecryptable packets", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := sess.run()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
handshakeChan <- struct{}{}
|
||||
Consistently(sess.handshakeStatus()).ShouldNot(Receive())
|
||||
// make sure the go routine returns
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
close(done)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("send a handshake event on the handshakeChan when the AEAD changes to forward-secure", func(done Done) {
|
||||
go sess.run()
|
||||
aeadChanged <- protocol.EncryptionForwardSecure
|
||||
Eventually(sess.handshakeStatus()).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionForwardSecure}))
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
close(done)
|
||||
})
|
||||
|
||||
It("closes the handshakeChan when the handshake completes", func(done Done) {
|
||||
go sess.run()
|
||||
close(aeadChanged)
|
||||
It("closes the handshakeChan when the handshake completes", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := sess.run()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
close(handshakeChan)
|
||||
Eventually(sess.handshakeStatus()).Should(BeClosed())
|
||||
// make sure the go routine returns
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
close(done)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("passes errors to the handshakeChan", func(done Done) {
|
||||
It("passes errors to the handshakeChan", func() {
|
||||
testErr := errors.New("handshake error")
|
||||
go sess.run()
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
Expect(sess.handshakeStatus()).To(Receive(&handshakeEvent{err: testErr}))
|
||||
close(done)
|
||||
})
|
||||
|
||||
It("does not block if an error occurs", func(done Done) {
|
||||
// this test basically tests that the handshakeChan has a capacity of 3
|
||||
// The session needs to run (and close) properly, even if no one is receiving from the handshakeChan
|
||||
go sess.run()
|
||||
aeadChanged <- protocol.EncryptionSecure
|
||||
aeadChanged <- protocol.EncryptionForwardSecure
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
close(done)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := sess.run()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
sess.Close(testErr)
|
||||
Expect(sess.handshakeStatus()).To(Receive(Equal(testErr)))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("process transport parameters received from the peer", func() {
|
||||
|
@ -1503,7 +1450,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
It("closes the session due to the idle timeout after handshake", func() {
|
||||
sess.config.IdleTimeout = 0
|
||||
close(aeadChanged)
|
||||
close(handshakeChan)
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -1619,9 +1566,9 @@ var _ = Describe("Session", func() {
|
|||
|
||||
var _ = Describe("Client Session", func() {
|
||||
var (
|
||||
sess *session
|
||||
mconn *mockConnection
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
sess *session
|
||||
mconn *mockConnection
|
||||
handshakeChan chan<- struct{}
|
||||
|
||||
cryptoSetup *mockCryptoSetup
|
||||
)
|
||||
|
@ -1638,11 +1585,11 @@ var _ = Describe("Client Session", func() {
|
|||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ chan<- handshake.TransportParameters,
|
||||
aeadChangedP chan<- protocol.EncryptionLevel,
|
||||
handshakeChanP chan<- struct{},
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
) (handshake.CryptoSetup, error) {
|
||||
aeadChanged = aeadChangedP
|
||||
handshakeChan = handshakeChanP
|
||||
return cryptoSetup, nil
|
||||
}
|
||||
|
||||
|
@ -1674,10 +1621,13 @@ var _ = Describe("Client Session", func() {
|
|||
sess.unpacker = &mockUnpacker{}
|
||||
})
|
||||
|
||||
It("passes the diversification nonce to the cryptoSetup", func() {
|
||||
It("passes the diversification nonce to the crypto setup", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess.run()
|
||||
err := sess.run()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
hdr.PacketNumber = 5
|
||||
hdr.DiversificationNonce = []byte("foobar")
|
||||
|
@ -1685,16 +1635,7 @@ var _ = Describe("Client Session", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(func() []byte { return cryptoSetup.divNonce }).Should(Equal(hdr.DiversificationNonce))
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
It("does not block if an error occurs", func(done Done) {
|
||||
// this test basically tests that the handshakeChan has a capacity of 3
|
||||
// The session needs to run (and close) properly, even if no one is receiving from the handshakeChan
|
||||
go sess.run()
|
||||
aeadChanged <- protocol.EncryptionSecure
|
||||
aeadChanged <- protocol.EncryptionForwardSecure
|
||||
Expect(sess.Close(nil)).To(Succeed())
|
||||
close(done)
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue