remove unused return value from session constructor

This commit is contained in:
Marten Seemann 2019-10-27 13:30:27 +07:00
parent 672328ca30
commit 416fe8364e
13 changed files with 99 additions and 166 deletions

View file

@ -259,10 +259,7 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
func (c *client) dial(ctx context.Context) error { func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.createNewTLSSession(c.version)
if err := c.createNewTLSSession(c.version); err != nil {
return err
}
err := c.establishSecureConnection(ctx) err := c.establishSecureConnection(ctx)
if err == errCloseForRecreating { if err == errCloseForRecreating {
return c.dial(ctx) return c.dial(ctx)
@ -357,7 +354,7 @@ func (c *client) handleVersionNegotiationPacket(p *receivedPacket) {
c.initialPacketNumber = c.session.closeForRecreating() c.initialPacketNumber = c.session.closeForRecreating()
} }
func (c *client) createNewTLSSession(_ protocol.VersionNumber) error { func (c *client) createNewTLSSession(_ protocol.VersionNumber) {
params := &handshake.TransportParameters{ params := &handshake.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
@ -372,8 +369,7 @@ func (c *client) createNewTLSSession(_ protocol.VersionNumber) error {
} }
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() c.session = newClientSession(
sess, err := newClientSession(
c.conn, c.conn,
c.packetHandlers, c.packetHandlers,
c.destConnID, c.destConnID,
@ -386,12 +382,8 @@ func (c *client) createNewTLSSession(_ protocol.VersionNumber) error {
c.logger, c.logger,
c.version, c.version,
) )
if err != nil { c.mutex.Unlock()
return err
}
c.session = sess
c.packetHandlers.Add(c.srcConnID, c) c.packetHandlers.Add(c.srcConnID, c)
return nil
} }
func (c *client) Close() error { func (c *client) Close() error {

View file

@ -42,7 +42,7 @@ var _ = Describe("Client", func() {
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
logger utils.Logger, logger utils.Logger,
v protocol.VersionNumber, v protocol.VersionNumber,
) (quicSession, error) ) quicSession
) )
// generate a packet sent by the server that accepts the QUIC version suggested by the client // generate a packet sent by the server that accepts the QUIC version suggested by the client
@ -145,12 +145,12 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
remoteAddrChan <- conn.RemoteAddr().String() remoteAddrChan <- conn.RemoteAddr().String()
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run() sess.EXPECT().run()
sess.EXPECT().HandshakeComplete().Return(context.Background()) sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil return sess
} }
_, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeTimeout: time.Millisecond}) _, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeTimeout: time.Millisecond})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -176,12 +176,12 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
hostnameChan <- tlsConf.ServerName hostnameChan <- tlsConf.ServerName
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run() sess.EXPECT().run()
sess.EXPECT().HandshakeComplete().Return(context.Background()) sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil return sess
} }
tlsConf.ServerName = "foobar" tlsConf.ServerName = "foobar"
_, err := DialAddr("localhost:17890", tlsConf, nil) _, err := DialAddr("localhost:17890", tlsConf, nil)
@ -207,12 +207,12 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
hostnameChan <- tlsConf.ServerName hostnameChan <- tlsConf.ServerName
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().HandshakeComplete().Return(context.Background()) sess.EXPECT().HandshakeComplete().Return(context.Background())
sess.EXPECT().run() sess.EXPECT().run()
return sess, nil return sess
} }
_, err := Dial( _, err := Dial(
packetConn, packetConn,
@ -243,13 +243,13 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Do(func() { close(run) }) sess.EXPECT().run().Do(func() { close(run) })
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
sess.EXPECT().HandshakeComplete().Return(ctx) sess.EXPECT().HandshakeComplete().Return(ctx)
return sess, nil return sess
} }
s, err := Dial( s, err := Dial(
packetConn, packetConn,
@ -281,11 +281,11 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Return(testErr) sess.EXPECT().run().Return(testErr)
sess.EXPECT().HandshakeComplete().Return(context.Background()) sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil return sess
} }
packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID) packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID)
_, err := Dial( _, err := Dial(
@ -322,8 +322,8 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
return sess, nil return sess
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
dialed := make(chan struct{}) dialed := make(chan struct{})
@ -366,9 +366,9 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
runner = runnerP runner = runnerP
return sess, nil return sess
} }
sess.EXPECT().run().Do(func() { sess.EXPECT().run().Do(func() {
runner.Retire(connID) runner.Retire(connID)
@ -411,10 +411,10 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
conn = connP conn = connP
close(sessionCreated) close(sessionCreated)
return sess, nil return sess
} }
sess.EXPECT().run().Do(func() { sess.EXPECT().run().Do(func() {
<-run <-run
@ -531,7 +531,7 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber, /* initial version */ _ protocol.VersionNumber, /* initial version */
_ utils.Logger, _ utils.Logger,
versionP protocol.VersionNumber, versionP protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
cconn = connP cconn = connP
version = versionP version = versionP
conf = configP conf = configP
@ -540,7 +540,7 @@ var _ = Describe("Client", func() {
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run() sess.EXPECT().run()
sess.EXPECT().HandshakeComplete().Return(context.Background()) sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil return sess
} }
_, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config) _, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -580,12 +580,12 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed()) Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Return(testErr) sess.EXPECT().run().Return(testErr)
sess.EXPECT().HandshakeComplete().Return(context.Background()) sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil return sess
} }
_, err := Dial( _, err := Dial(
packetConn, packetConn,

View file

@ -126,8 +126,8 @@ func NewCryptoSetupClient(
tlsConf *tls.Config, tlsConf *tls.Config,
rttStats *congestion.RTTStats, rttStats *congestion.RTTStats,
logger utils.Logger, logger utils.Logger,
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) { ) (CryptoSetup, <-chan struct{} /* ClientHello written */) {
cs, clientHelloWritten, err := newCryptoSetup( cs, clientHelloWritten := newCryptoSetup(
initialStream, initialStream,
handshakeStream, handshakeStream,
oneRTTStream, oneRTTStream,
@ -139,11 +139,8 @@ func NewCryptoSetupClient(
logger, logger,
protocol.PerspectiveClient, protocol.PerspectiveClient,
) )
if err != nil {
return nil, nil, err
}
cs.conn = qtls.Client(newConn(remoteAddr), cs.tlsConf) cs.conn = qtls.Client(newConn(remoteAddr), cs.tlsConf)
return cs, clientHelloWritten, nil return cs, clientHelloWritten
} }
// NewCryptoSetupServer creates a new crypto setup for the server // NewCryptoSetupServer creates a new crypto setup for the server
@ -158,8 +155,8 @@ func NewCryptoSetupServer(
tlsConf *tls.Config, tlsConf *tls.Config,
rttStats *congestion.RTTStats, rttStats *congestion.RTTStats,
logger utils.Logger, logger utils.Logger,
) (CryptoSetup, error) { ) CryptoSetup {
cs, _, err := newCryptoSetup( cs, _ := newCryptoSetup(
initialStream, initialStream,
handshakeStream, handshakeStream,
oneRTTStream, oneRTTStream,
@ -171,11 +168,8 @@ func NewCryptoSetupServer(
logger, logger,
protocol.PerspectiveServer, protocol.PerspectiveServer,
) )
if err != nil {
return nil, err
}
cs.conn = qtls.Server(newConn(remoteAddr), cs.tlsConf) cs.conn = qtls.Server(newConn(remoteAddr), cs.tlsConf)
return cs, nil return cs
} }
func newCryptoSetup( func newCryptoSetup(
@ -189,11 +183,8 @@ func newCryptoSetup(
rttStats *congestion.RTTStats, rttStats *congestion.RTTStats,
logger utils.Logger, logger utils.Logger,
perspective protocol.Perspective, perspective protocol.Perspective,
) (*cryptoSetup, <-chan struct{} /* ClientHello written */, error) { ) (*cryptoSetup, <-chan struct{} /* ClientHello written */) {
initialSealer, initialOpener, err := NewInitialAEAD(connID, perspective) initialSealer, initialOpener := NewInitialAEAD(connID, perspective)
if err != nil {
return nil, nil, err
}
extHandler := newExtensionHandler(tp.Marshal(), perspective) extHandler := newExtensionHandler(tp.Marshal(), perspective)
cs := &cryptoSetup{ cs := &cryptoSetup{
initialStream: initialStream, initialStream: initialStream,
@ -219,17 +210,13 @@ func newCryptoSetup(
} }
qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler) qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler)
cs.tlsConf = qtlsConf cs.tlsConf = qtlsConf
return cs, cs.clientHelloWrittenChan, nil return cs, cs.clientHelloWrittenChan
} }
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) error { func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
initialSealer, initialOpener, err := NewInitialAEAD(id, h.perspective) initialSealer, initialOpener := NewInitialAEAD(id, h.perspective)
if err != nil {
return err
}
h.initialSealer = initialSealer h.initialSealer = initialSealer
h.initialOpener = initialOpener h.initialOpener = initialOpener
return nil
} }
func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) { func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) {

View file

@ -86,7 +86,7 @@ var _ = Describe("Crypto Setup TLS", func() {
return &tls.Config{ServerName: ch.ServerName}, nil return &tls.Config{ServerName: ch.ServerName}, nil
}, },
} }
server, err := NewCryptoSetupServer( server := NewCryptoSetupServer(
&bytes.Buffer{}, &bytes.Buffer{},
&bytes.Buffer{}, &bytes.Buffer{},
ioutil.Discard, ioutil.Discard,
@ -98,7 +98,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
Expect(err).ToNot(HaveOccurred())
qtlsConf := server.(*cryptoSetup).tlsConf qtlsConf := server.(*cryptoSetup).tlsConf
Expect(qtlsConf.ServerName).To(Equal(tlsConf.ServerName)) Expect(qtlsConf.ServerName).To(Equal(tlsConf.ServerName))
_, getCertificateErr := qtlsConf.GetCertificate(nil) _, getCertificateErr := qtlsConf.GetCertificate(nil)
@ -118,7 +117,7 @@ var _ = Describe("Crypto Setup TLS", func() {
runner := NewMockHandshakeRunner(mockCtrl) runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
_, sInitialStream, sHandshakeStream, _ := initStreams() _, sInitialStream, sHandshakeStream, _ := initStreams()
server, err := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard, ioutil.Discard,
@ -130,7 +129,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -156,7 +154,7 @@ var _ = Describe("Crypto Setup TLS", func() {
_, sInitialStream, sHandshakeStream, _ := initStreams() _, sInitialStream, sHandshakeStream, _ := initStreams()
runner := NewMockHandshakeRunner(mockCtrl) runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
server, err := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard, ioutil.Discard,
@ -168,7 +166,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -179,6 +176,7 @@ var _ = Describe("Crypto Setup TLS", func() {
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
var err error
Expect(sErrChan).To(Receive(&err)) Expect(sErrChan).To(Receive(&err))
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
qerr := err.(*qerr.QuicError) qerr := err.(*qerr.QuicError)
@ -196,7 +194,7 @@ var _ = Describe("Crypto Setup TLS", func() {
_, sInitialStream, sHandshakeStream, _ := initStreams() _, sInitialStream, sHandshakeStream, _ := initStreams()
runner := NewMockHandshakeRunner(mockCtrl) runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
server, err := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard, ioutil.Discard,
@ -208,7 +206,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -230,7 +227,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("returns Handshake() when it is closed", func() { It("returns Handshake() when it is closed", func() {
_, sInitialStream, sHandshakeStream, _ := initStreams() _, sInitialStream, sHandshakeStream, _ := initStreams()
server, err := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard, ioutil.Discard,
@ -242,7 +239,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -319,7 +315,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner.EXPECT().OnReceivedParams(gomock.Any()) cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1) cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1)
cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1) cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1)
client, _, err := NewCryptoSetupClient( client, _ := NewCryptoSetupClient(
cInitialStream, cInitialStream,
cHandshakeStream, cHandshakeStream,
cOneRTTStream, cOneRTTStream,
@ -331,7 +327,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
Expect(err).ToNot(HaveOccurred())
var sHandshakeComplete bool var sHandshakeComplete bool
sChunkChan, sInitialStream, sHandshakeStream, sOneRTTStream := initStreams() sChunkChan, sInitialStream, sHandshakeStream, sOneRTTStream := initStreams()
@ -341,7 +336,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1) sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1)
sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1) sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1)
var token [16]byte var token [16]byte
server, err := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
sOneRTTStream, sOneRTTStream,
@ -353,7 +348,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
Expect(err).ToNot(HaveOccurred())
handshake(client, cChunkChan, server, sChunkChan) handshake(client, cChunkChan, server, sChunkChan)
var cErr, sErr error var cErr, sErr error
@ -394,7 +388,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("signals when it has written the ClientHello", func() { It("signals when it has written the ClientHello", func() {
runner := NewMockHandshakeRunner(mockCtrl) runner := NewMockHandshakeRunner(mockCtrl)
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
client, chChan, err := NewCryptoSetupClient( client, chChan := NewCryptoSetupClient(
cInitialStream, cInitialStream,
cHandshakeStream, cHandshakeStream,
ioutil.Discard, ioutil.Discard,
@ -406,7 +400,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -435,7 +428,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner := NewMockHandshakeRunner(mockCtrl) cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { sTransportParametersRcvd = b }) cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { sTransportParametersRcvd = b })
cRunner.EXPECT().OnHandshakeComplete() cRunner.EXPECT().OnHandshakeComplete()
client, _, err := NewCryptoSetupClient( client, _ := NewCryptoSetupClient(
cInitialStream, cInitialStream,
cHandshakeStream, cHandshakeStream,
ioutil.Discard, ioutil.Discard,
@ -447,7 +440,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
var token [16]byte var token [16]byte
@ -458,7 +450,7 @@ var _ = Describe("Crypto Setup TLS", func() {
IdleTimeout: 0x1337 * time.Second, IdleTimeout: 0x1337 * time.Second,
StatelessResetToken: &token, StatelessResetToken: &token,
} }
server, err := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard, ioutil.Discard,
@ -470,7 +462,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -495,7 +486,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner := NewMockHandshakeRunner(mockCtrl) cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any()) cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete() cRunner.EXPECT().OnHandshakeComplete()
client, _, err := NewCryptoSetupClient( client, _ := NewCryptoSetupClient(
cInitialStream, cInitialStream,
cHandshakeStream, cHandshakeStream,
ioutil.Discard, ioutil.Discard,
@ -507,13 +498,12 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl) sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete() sRunner.EXPECT().OnHandshakeComplete()
server, err := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard, ioutil.Discard,
@ -525,7 +515,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -552,7 +541,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner := NewMockHandshakeRunner(mockCtrl) cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any()) cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete() cRunner.EXPECT().OnHandshakeComplete()
client, _, err := NewCryptoSetupClient( client, _ := NewCryptoSetupClient(
cInitialStream, cInitialStream,
cHandshakeStream, cHandshakeStream,
ioutil.Discard, ioutil.Discard,
@ -564,13 +553,12 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl) sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete() sRunner.EXPECT().OnHandshakeComplete()
server, err := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard, ioutil.Discard,
@ -582,7 +570,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{}, &congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {

View file

@ -17,7 +17,7 @@ var initialSuite = &qtls.CipherSuiteTLS13{
} }
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption. // NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, LongHeaderOpener, error) { func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, LongHeaderOpener) {
clientSecret, serverSecret := computeSecrets(connID) clientSecret, serverSecret := computeSecrets(connID)
var mySecret, otherSecret []byte var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient { if pers == protocol.PerspectiveClient {
@ -34,8 +34,7 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV) decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true)), return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true)),
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true)), newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true))
nil
} }
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) { func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {

View file

@ -55,8 +55,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() {
}) })
It("encrypts the client's Initial", func() { It("encrypts the client's Initial", func() {
sealer, _, err := NewInitialAEAD(connID, protocol.PerspectiveClient) sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
header := split("c3ff000017088394c8f03e5157080000449e00000002") header := split("c3ff000017088394c8f03e5157080000449e00000002")
data := split("060040c4010000c003036660261ff947 cea49cce6cfad687f457cf1b14531ba1 4131a0e8f309a1d0b9c4000006130113 031302010000910000000b0009000006 736572766572ff01000100000a001400 12001d00170018001901000101010201 03010400230000003300260024001d00 204cfdfcd178b784bf328cae793b136f 2aedce005ff183d7bb14952072366470 37002b0003020304000d0020001e0403 05030603020308040805080604010501 060102010402050206020202002d0002 0101001c00024001") data := split("060040c4010000c003036660261ff947 cea49cce6cfad687f457cf1b14531ba1 4131a0e8f309a1d0b9c4000006130113 031302010000910000000b0009000006 736572766572ff01000100000a001400 12001d00170018001901000101010201 03010400230000003300260024001d00 204cfdfcd178b784bf328cae793b136f 2aedce005ff183d7bb14952072366470 37002b0003020304000d0020001e0403 05030603020308040805080604010501 060102010402050206020202002d0002 0101001c00024001")
data = append(data, make([]byte, 1162-len(data))...) // add PADDING data = append(data, make([]byte, 1162-len(data))...) // add PADDING
@ -71,8 +70,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() {
}) })
It("encrypt the server's Initial", func() { It("encrypt the server's Initial", func() {
sealer, _, err := NewInitialAEAD(connID, protocol.PerspectiveServer) sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
header := split("c1ff0000170008f067a5502a4262b50040740001") header := split("c1ff0000170008f067a5502a4262b50040740001")
data := split("0d0000000018410a020000560303eefc e7f7b37ba1d1632e96677825ddf73988 cfc79825df566dc5430b9a045a120013 0100002e00330024001d00209d3c940d 89690b84d08a60993c144eca684d1081 287c834d5311bcf32bb9da1a002b0002 0304") data := split("0d0000000018410a020000560303eefc e7f7b37ba1d1632e96677825ddf73988 cfc79825df566dc5430b9a045a120013 0100002e00330024001d00209d3c940d 89690b84d08a60993c144eca684d1081 287c834d5311bcf32bb9da1a002b0002 0304")
sealed := sealer.Seal(nil, data, 1, header) sealed := sealer.Seal(nil, data, 1, header)
@ -87,10 +85,8 @@ var _ = Describe("Initial AEAD using AES-GCM", func() {
It("seals and opens", func() { It("seals and opens", func() {
connectionID := protocol.ConnectionID{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef} connectionID := protocol.ConnectionID{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef}
clientSealer, clientOpener, err := NewInitialAEAD(connectionID, protocol.PerspectiveClient) clientSealer, clientOpener := NewInitialAEAD(connectionID, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred()) serverSealer, serverOpener := NewInitialAEAD(connectionID, protocol.PerspectiveServer)
serverSealer, serverOpener, err := NewInitialAEAD(connectionID, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad"))
m, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad")) m, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad"))
@ -105,22 +101,18 @@ var _ = Describe("Initial AEAD using AES-GCM", func() {
It("doesn't work if initialized with different connection IDs", func() { It("doesn't work if initialized with different connection IDs", func() {
c1 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 1} c1 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 1}
c2 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 2} c2 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 2}
clientSealer, _, err := NewInitialAEAD(c1, protocol.PerspectiveClient) clientSealer, _ := NewInitialAEAD(c1, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred()) _, serverOpener := NewInitialAEAD(c2, protocol.PerspectiveServer)
_, serverOpener, err := NewInitialAEAD(c2, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad")) clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad"))
_, err = serverOpener.Open(nil, clientMessage, 42, []byte("aad")) _, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad"))
Expect(err).To(MatchError(ErrDecryptionFailed)) Expect(err).To(MatchError(ErrDecryptionFailed))
}) })
It("encrypts und decrypts the header", func() { It("encrypts und decrypts the header", func() {
connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}
clientSealer, clientOpener, err := NewInitialAEAD(connID, protocol.PerspectiveClient) clientSealer, clientOpener := NewInitialAEAD(connID, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred()) serverSealer, serverOpener := NewInitialAEAD(connID, protocol.PerspectiveServer)
serverSealer, serverOpener, err := NewInitialAEAD(connID, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
// the first byte and the last 4 bytes should be encrypted // the first byte and the last 4 bytes should be encrypted
header := []byte{0x5e, 0, 1, 2, 3, 4, 0xde, 0xad, 0xbe, 0xef} header := []byte{0x5e, 0, 1, 2, 3, 4, 0xde, 0xad, 0xbe, 0xef}

View file

@ -69,7 +69,7 @@ type handshakeRunner interface {
type CryptoSetup interface { type CryptoSetup interface {
RunHandshake() RunHandshake()
io.Closer io.Closer
ChangeConnectionID(protocol.ConnectionID) error ChangeConnectionID(protocol.ConnectionID)
HandleMessage([]byte, protocol.EncryptionLevel) bool HandleMessage([]byte, protocol.EncryptionLevel) bool
SetLargest1RTTAcked(protocol.PacketNumber) SetLargest1RTTAcked(protocol.PacketNumber)

View file

@ -37,11 +37,9 @@ func (m *MockCryptoSetup) EXPECT() *MockCryptoSetupMockRecorder {
} }
// ChangeConnectionID mocks base method // ChangeConnectionID mocks base method
func (m *MockCryptoSetup) ChangeConnectionID(arg0 protocol.ConnectionID) error { func (m *MockCryptoSetup) ChangeConnectionID(arg0 protocol.ConnectionID) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ChangeConnectionID", arg0) m.ctrl.Call(m, "ChangeConnectionID", arg0)
ret0, _ := ret[0].(error)
return ret0
} }
// ChangeConnectionID indicates an expected call of ChangeConnectionID // ChangeConnectionID indicates an expected call of ChangeConnectionID

View file

@ -65,7 +65,7 @@ func ComposeAckFrame(smallest protocol.PacketNumber, largest protocol.PacketNumb
// ComposeInitialPacket returns an Initial packet encrypted under key // ComposeInitialPacket returns an Initial packet encrypted under key
// (the original destination connection ID) containing specified frames // (the original destination connection ID) containing specified frames
func ComposeInitialPacket(srcConnID protocol.ConnectionID, destConnID protocol.ConnectionID, version protocol.VersionNumber, key protocol.ConnectionID, frames []wire.Frame) []byte { func ComposeInitialPacket(srcConnID protocol.ConnectionID, destConnID protocol.ConnectionID, version protocol.VersionNumber, key protocol.ConnectionID, frames []wire.Frame) []byte {
sealer, _, _ := handshake.NewInitialAEAD(key, protocol.PerspectiveServer) sealer, _ := handshake.NewInitialAEAD(key, protocol.PerspectiveServer)
// compose payload // compose payload
var payload []byte var payload []byte

View file

@ -82,7 +82,7 @@ type baseServer struct {
sessionHandler packetHandlerManager sessionHandler packetHandlerManager
// set as a member, so they can be set in the tests // set as a member, so they can be set in the tests
newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, *Config, *tls.Config, *handshake.TransportParameters, *handshake.TokenGenerator, utils.Logger, protocol.VersionNumber) (quicSession, error) newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, *Config, *tls.Config, *handshake.TransportParameters, *handshake.TokenGenerator, utils.Logger, protocol.VersionNumber) quicSession
serverError error serverError error
errorChan chan struct{} errorChan chan struct{}
@ -424,7 +424,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
return nil, nil, err return nil, nil, err
} }
s.logger.Debugf("Changing connection ID to %s.", connID) s.logger.Debugf("Changing connection ID to %s.", connID)
sess, err := s.createNewSession( sess := s.createNewSession(
p.remoteAddr, p.remoteAddr,
origDestConnectionID, origDestConnectionID,
hdr.DestConnectionID, hdr.DestConnectionID,
@ -432,9 +432,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
connID, connID,
hdr.Version, hdr.Version,
) )
if err != nil {
return nil, nil, err
}
sess.handlePacket(p) sess.handlePacket(p)
return sess, connID, nil return sess, connID, nil
} }
@ -446,7 +443,7 @@ func (s *baseServer) createNewSession(
destConnID protocol.ConnectionID, destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID, srcConnID protocol.ConnectionID,
version protocol.VersionNumber, version protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
token := s.sessionHandler.GetStatelessResetToken(srcConnID) token := s.sessionHandler.GetStatelessResetToken(srcConnID)
params := &handshake.TransportParameters{ params := &handshake.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
@ -462,7 +459,7 @@ func (s *baseServer) createNewSession(
StatelessResetToken: &token, StatelessResetToken: &token,
OriginalConnectionID: origDestConnID, OriginalConnectionID: origDestConnID,
} }
sess, err := s.newSession( sess := s.newSession(
&conn{pconn: s.conn, currentAddr: remoteAddr}, &conn{pconn: s.conn, currentAddr: remoteAddr},
s.sessionHandler, s.sessionHandler,
clientDestConnID, clientDestConnID,
@ -475,12 +472,9 @@ func (s *baseServer) createNewSession(
s.logger, s.logger,
version, version,
) )
if err != nil {
return nil, err
}
go sess.run() go sess.run()
go s.handleNewSession(sess) go s.handleNewSession(sess)
return sess, nil return sess
} }
func (s *baseServer) handleNewSession(sess quicSession) { func (s *baseServer) handleNewSession(sess quicSession) {
@ -542,10 +536,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
} }
func (s *baseServer) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error { func (s *baseServer) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error {
sealer, _, err := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer) sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer)
if err != nil {
return err
}
packetBuffer := getPacketBuffer() packetBuffer := getPacketBuffer()
defer packetBuffer.Release() defer packetBuffer.Release()
buf := bytes.NewBuffer(packetBuffer.Slice[:0]) buf := bytes.NewBuffer(packetBuffer.Slice[:0])

View file

@ -294,7 +294,7 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
Expect(origConnID).To(Equal(hdr.DestConnectionID)) Expect(origConnID).To(Equal(hdr.DestConnectionID))
Expect(destConnID).To(Equal(hdr.SrcConnectionID)) Expect(destConnID).To(Equal(hdr.SrcConnectionID))
// make sure we're using a server-generated connection ID // make sure we're using a server-generated connection ID
@ -305,7 +305,7 @@ var _ = Describe("Server", func() {
sess.EXPECT().run().Do(func() { close(run) }) sess.EXPECT().run().Do(func() { close(run) })
sess.EXPECT().Context().Return(context.Background()) sess.EXPECT().Context().Return(context.Background())
sess.EXPECT().HandshakeComplete().Return(context.Background()) sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil return sess
} }
done := make(chan struct{}) done := make(chan struct{})
@ -346,7 +346,7 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(p) sess.EXPECT().handlePacket(p)
sess.EXPECT().run() sess.EXPECT().run()
@ -354,7 +354,7 @@ var _ = Describe("Server", func() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
sess.EXPECT().HandshakeComplete().Return(ctx) sess.EXPECT().HandshakeComplete().Return(ctx)
return sess, nil return sess
} }
var wg sync.WaitGroup var wg sync.WaitGroup
@ -407,7 +407,7 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
sess.EXPECT().handlePacket(p) sess.EXPECT().handlePacket(p)
sess.EXPECT().run() sess.EXPECT().run()
sess.EXPECT().Context().Return(ctx) sess.EXPECT().Context().Return(ctx)
@ -415,7 +415,7 @@ var _ = Describe("Server", func() {
cancel() cancel()
sess.EXPECT().HandshakeComplete().Return(ctx) sess.EXPECT().HandshakeComplete().Return(ctx)
close(sessionCreated) close(sessionCreated)
return sess, nil return sess
} }
serv.handlePacket(p) serv.handlePacket(p)
@ -504,14 +504,13 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
sess.EXPECT().HandshakeComplete().Return(ctx) sess.EXPECT().HandshakeComplete().Return(ctx)
sess.EXPECT().run().Do(func() {}) sess.EXPECT().run().Do(func() {})
sess.EXPECT().Context().Return(context.Background()) sess.EXPECT().Context().Return(context.Background())
return sess, nil return sess
} }
_, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
cancel() // complete the handshake cancel() // complete the handshake
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -553,14 +552,13 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
sess.EXPECT().run().Do(func() {}) sess.EXPECT().run().Do(func() {})
sess.EXPECT().earlySessionReady().Return(ready) sess.EXPECT().earlySessionReady().Return(ready)
sess.EXPECT().Context().Return(context.Background()) sess.EXPECT().Context().Return(context.Background())
return sess, nil return sess
} }
_, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
close(ready) close(ready)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -591,7 +589,7 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
ready := make(chan struct{}) ready := make(chan struct{})
close(ready) close(ready)
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
@ -599,7 +597,7 @@ var _ = Describe("Server", func() {
sess.EXPECT().run() sess.EXPECT().run()
sess.EXPECT().earlySessionReady().Return(ready) sess.EXPECT().earlySessionReady().Return(ready)
sess.EXPECT().Context().Return(context.Background()) sess.EXPECT().Context().Return(context.Background())
return sess, nil return sess
} }
var wg sync.WaitGroup var wg sync.WaitGroup
@ -652,13 +650,13 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator, _ *handshake.TokenGenerator,
_ utils.Logger, _ utils.Logger,
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
sess.EXPECT().handlePacket(p) sess.EXPECT().handlePacket(p)
sess.EXPECT().run() sess.EXPECT().run()
sess.EXPECT().earlySessionReady() sess.EXPECT().earlySessionReady()
sess.EXPECT().Context().Return(ctx) sess.EXPECT().Context().Return(ctx)
close(sessionCreated) close(sessionCreated)
return sess, nil return sess
} }
serv.handlePacket(p) serv.handlePacket(p)

View file

@ -49,7 +49,7 @@ type streamManager interface {
type cryptoStreamHandler interface { type cryptoStreamHandler interface {
RunHandshake() RunHandshake()
ChangeConnectionID(protocol.ConnectionID) error ChangeConnectionID(protocol.ConnectionID)
SetLargest1RTTAcked(protocol.PacketNumber) SetLargest1RTTAcked(protocol.PacketNumber)
io.Closer io.Closer
ConnectionState() tls.ConnectionState ConnectionState() tls.ConnectionState
@ -188,7 +188,7 @@ var newSession = func(
tokenGenerator *handshake.TokenGenerator, tokenGenerator *handshake.TokenGenerator,
logger utils.Logger, logger utils.Logger,
v protocol.VersionNumber, v protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
s := &session{ s := &session{
conn: conn, conn: conn,
sessionRunner: runner, sessionRunner: runner,
@ -215,7 +215,7 @@ var newSession = func(
initialStream := newCryptoStream() initialStream := newCryptoStream()
handshakeStream := newCryptoStream() handshakeStream := newCryptoStream()
oneRTTStream := newPostHandshakeCryptoStream(s.framer) oneRTTStream := newPostHandshakeCryptoStream(s.framer)
cs, err := handshake.NewCryptoSetupServer( cs := handshake.NewCryptoSetupServer(
initialStream, initialStream,
handshakeStream, handshakeStream,
oneRTTStream, oneRTTStream,
@ -232,9 +232,6 @@ var newSession = func(
s.rttStats, s.rttStats,
logger, logger,
) )
if err != nil {
return nil, err
}
s.cryptoStreamHandler = cs s.cryptoStreamHandler = cs
s.packer = newPacketPacker( s.packer = newPacketPacker(
s.destConnID, s.destConnID,
@ -254,7 +251,7 @@ var newSession = func(
s.postSetup() s.postSetup()
s.unpacker = newPacketUnpacker(cs, s.version) s.unpacker = newPacketUnpacker(cs, s.version)
return s, nil return s
} }
// declare this as a variable, such that we can it mock it in the tests // declare this as a variable, such that we can it mock it in the tests
@ -270,7 +267,7 @@ var newClientSession = func(
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
logger utils.Logger, logger utils.Logger,
v protocol.VersionNumber, v protocol.VersionNumber,
) (quicSession, error) { ) quicSession {
s := &session{ s := &session{
conn: conn, conn: conn,
sessionRunner: runner, sessionRunner: runner,
@ -288,7 +285,7 @@ var newClientSession = func(
initialStream := newCryptoStream() initialStream := newCryptoStream()
handshakeStream := newCryptoStream() handshakeStream := newCryptoStream()
oneRTTStream := newPostHandshakeCryptoStream(s.framer) oneRTTStream := newPostHandshakeCryptoStream(s.framer)
cs, clientHelloWritten, err := handshake.NewCryptoSetupClient( cs, clientHelloWritten := handshake.NewCryptoSetupClient(
initialStream, initialStream,
handshakeStream, handshakeStream,
oneRTTStream, oneRTTStream,
@ -305,9 +302,6 @@ var newClientSession = func(
s.rttStats, s.rttStats,
logger, logger,
) )
if err != nil {
return nil, err
}
s.clientHelloWritten = clientHelloWritten s.clientHelloWritten = clientHelloWritten
s.cryptoStreamHandler = cs s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
@ -346,7 +340,7 @@ var newClientSession = func(
} }
} }
s.postSetup() s.postSetup()
return s, nil return s
} }
func (s *session) preSetup() { func (s *session) preSetup() {

View file

@ -82,8 +82,7 @@ var _ = Describe("Session", func() {
mconn = newMockConnection() mconn = newMockConnection()
tokenGenerator, err := handshake.NewTokenGenerator() tokenGenerator, err := handshake.NewTokenGenerator()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var pSess Session sess = newSession(
pSess, err = newSession(
mconn, mconn,
sessionRunner, sessionRunner,
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
@ -95,9 +94,7 @@ var _ = Describe("Session", func() {
tokenGenerator, tokenGenerator,
utils.DefaultLogger, utils.DefaultLogger,
protocol.VersionTLS, protocol.VersionTLS,
) ).(*session)
Expect(err).NotTo(HaveOccurred())
sess = pSess.(*session)
streamManager = NewMockStreamManager(mockCtrl) streamManager = NewMockStreamManager(mockCtrl)
sess.streamsMap = streamManager sess.streamsMap = streamManager
packer = NewMockPacker(mockCtrl) packer = NewMockPacker(mockCtrl)
@ -1498,7 +1495,7 @@ var _ = Describe("Client Session", func() {
} }
mconn = newMockConnection() mconn = newMockConnection()
sessionRunner = NewMockSessionRunner(mockCtrl) sessionRunner = NewMockSessionRunner(mockCtrl)
sessP, err := newClientSession( sess = newClientSession(
mconn, mconn,
sessionRunner, sessionRunner,
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
@ -1510,9 +1507,7 @@ var _ = Describe("Client Session", func() {
protocol.VersionTLS, protocol.VersionTLS,
utils.DefaultLogger, utils.DefaultLogger,
protocol.VersionTLS, protocol.VersionTLS,
) ).(*session)
sess = sessP.(*session)
Expect(err).ToNot(HaveOccurred())
packer = NewMockPacker(mockCtrl) packer = NewMockPacker(mockCtrl)
sess.packer = packer sess.packer = packer
cryptoSetup = mocks.NewMockCryptoSetup(mockCtrl) cryptoSetup = mocks.NewMockCryptoSetup(mockCtrl)