remove unused return value from session constructor

This commit is contained in:
Marten Seemann 2019-10-27 13:30:27 +07:00
parent 672328ca30
commit 416fe8364e
13 changed files with 99 additions and 166 deletions

View file

@ -259,10 +259,7 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
if err := c.createNewTLSSession(c.version); err != nil {
return err
}
c.createNewTLSSession(c.version)
err := c.establishSecureConnection(ctx)
if err == errCloseForRecreating {
return c.dial(ctx)
@ -357,7 +354,7 @@ func (c *client) handleVersionNegotiationPacket(p *receivedPacket) {
c.initialPacketNumber = c.session.closeForRecreating()
}
func (c *client) createNewTLSSession(_ protocol.VersionNumber) error {
func (c *client) createNewTLSSession(_ protocol.VersionNumber) {
params := &handshake.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
@ -372,8 +369,7 @@ func (c *client) createNewTLSSession(_ protocol.VersionNumber) error {
}
c.mutex.Lock()
defer c.mutex.Unlock()
sess, err := newClientSession(
c.session = newClientSession(
c.conn,
c.packetHandlers,
c.destConnID,
@ -386,12 +382,8 @@ func (c *client) createNewTLSSession(_ protocol.VersionNumber) error {
c.logger,
c.version,
)
if err != nil {
return err
}
c.session = sess
c.mutex.Unlock()
c.packetHandlers.Add(c.srcConnID, c)
return nil
}
func (c *client) Close() error {

View file

@ -42,7 +42,7 @@ var _ = Describe("Client", func() {
initialVersion protocol.VersionNumber,
logger utils.Logger,
v protocol.VersionNumber,
) (quicSession, error)
) quicSession
)
// generate a packet sent by the server that accepts the QUIC version suggested by the client
@ -145,12 +145,12 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
remoteAddrChan <- conn.RemoteAddr().String()
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run()
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil
return sess
}
_, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeTimeout: time.Millisecond})
Expect(err).ToNot(HaveOccurred())
@ -176,12 +176,12 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
hostnameChan <- tlsConf.ServerName
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run()
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil
return sess
}
tlsConf.ServerName = "foobar"
_, err := DialAddr("localhost:17890", tlsConf, nil)
@ -207,12 +207,12 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
hostnameChan <- tlsConf.ServerName
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().HandshakeComplete().Return(context.Background())
sess.EXPECT().run()
return sess, nil
return sess
}
_, err := Dial(
packetConn,
@ -243,13 +243,13 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Do(func() { close(run) })
ctx, cancel := context.WithCancel(context.Background())
cancel()
sess.EXPECT().HandshakeComplete().Return(ctx)
return sess, nil
return sess
}
s, err := Dial(
packetConn,
@ -281,11 +281,11 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Return(testErr)
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil
return sess
}
packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID)
_, err := Dial(
@ -322,8 +322,8 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
return sess, nil
) quicSession {
return sess
}
ctx, cancel := context.WithCancel(context.Background())
dialed := make(chan struct{})
@ -366,9 +366,9 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
runner = runnerP
return sess, nil
return sess
}
sess.EXPECT().run().Do(func() {
runner.Retire(connID)
@ -411,10 +411,10 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
conn = connP
close(sessionCreated)
return sess, nil
return sess
}
sess.EXPECT().run().Do(func() {
<-run
@ -531,7 +531,7 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber, /* initial version */
_ utils.Logger,
versionP protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
cconn = connP
version = versionP
conf = configP
@ -540,7 +540,7 @@ var _ = Describe("Client", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run()
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil
return sess
}
_, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config)
Expect(err).ToNot(HaveOccurred())
@ -580,12 +580,12 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Return(testErr)
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil
return sess
}
_, err := Dial(
packetConn,

View file

@ -126,8 +126,8 @@ func NewCryptoSetupClient(
tlsConf *tls.Config,
rttStats *congestion.RTTStats,
logger utils.Logger,
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
cs, clientHelloWritten, err := newCryptoSetup(
) (CryptoSetup, <-chan struct{} /* ClientHello written */) {
cs, clientHelloWritten := newCryptoSetup(
initialStream,
handshakeStream,
oneRTTStream,
@ -139,11 +139,8 @@ func NewCryptoSetupClient(
logger,
protocol.PerspectiveClient,
)
if err != nil {
return nil, nil, err
}
cs.conn = qtls.Client(newConn(remoteAddr), cs.tlsConf)
return cs, clientHelloWritten, nil
return cs, clientHelloWritten
}
// NewCryptoSetupServer creates a new crypto setup for the server
@ -158,8 +155,8 @@ func NewCryptoSetupServer(
tlsConf *tls.Config,
rttStats *congestion.RTTStats,
logger utils.Logger,
) (CryptoSetup, error) {
cs, _, err := newCryptoSetup(
) CryptoSetup {
cs, _ := newCryptoSetup(
initialStream,
handshakeStream,
oneRTTStream,
@ -171,11 +168,8 @@ func NewCryptoSetupServer(
logger,
protocol.PerspectiveServer,
)
if err != nil {
return nil, err
}
cs.conn = qtls.Server(newConn(remoteAddr), cs.tlsConf)
return cs, nil
return cs
}
func newCryptoSetup(
@ -189,11 +183,8 @@ func newCryptoSetup(
rttStats *congestion.RTTStats,
logger utils.Logger,
perspective protocol.Perspective,
) (*cryptoSetup, <-chan struct{} /* ClientHello written */, error) {
initialSealer, initialOpener, err := NewInitialAEAD(connID, perspective)
if err != nil {
return nil, nil, err
}
) (*cryptoSetup, <-chan struct{} /* ClientHello written */) {
initialSealer, initialOpener := NewInitialAEAD(connID, perspective)
extHandler := newExtensionHandler(tp.Marshal(), perspective)
cs := &cryptoSetup{
initialStream: initialStream,
@ -219,17 +210,13 @@ func newCryptoSetup(
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler)
cs.tlsConf = qtlsConf
return cs, cs.clientHelloWrittenChan, nil
return cs, cs.clientHelloWrittenChan
}
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) error {
initialSealer, initialOpener, err := NewInitialAEAD(id, h.perspective)
if err != nil {
return err
}
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
initialSealer, initialOpener := NewInitialAEAD(id, h.perspective)
h.initialSealer = initialSealer
h.initialOpener = initialOpener
return nil
}
func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) {

View file

@ -86,7 +86,7 @@ var _ = Describe("Crypto Setup TLS", func() {
return &tls.Config{ServerName: ch.ServerName}, nil
},
}
server, err := NewCryptoSetupServer(
server := NewCryptoSetupServer(
&bytes.Buffer{},
&bytes.Buffer{},
ioutil.Discard,
@ -98,7 +98,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
qtlsConf := server.(*cryptoSetup).tlsConf
Expect(qtlsConf.ServerName).To(Equal(tlsConf.ServerName))
_, getCertificateErr := qtlsConf.GetCertificate(nil)
@ -118,7 +117,7 @@ var _ = Describe("Crypto Setup TLS", func() {
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
_, sInitialStream, sHandshakeStream, _ := initStreams()
server, err := NewCryptoSetupServer(
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
@ -130,7 +129,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
@ -156,7 +154,7 @@ var _ = Describe("Crypto Setup TLS", func() {
_, sInitialStream, sHandshakeStream, _ := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
server, err := NewCryptoSetupServer(
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
@ -168,7 +166,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
@ -179,6 +176,7 @@ var _ = Describe("Crypto Setup TLS", func() {
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
var err error
Expect(sErrChan).To(Receive(&err))
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
qerr := err.(*qerr.QuicError)
@ -196,7 +194,7 @@ var _ = Describe("Crypto Setup TLS", func() {
_, sInitialStream, sHandshakeStream, _ := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
server, err := NewCryptoSetupServer(
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
@ -208,7 +206,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
@ -230,7 +227,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("returns Handshake() when it is closed", func() {
_, sInitialStream, sHandshakeStream, _ := initStreams()
server, err := NewCryptoSetupServer(
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
@ -242,7 +239,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
@ -319,7 +315,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1)
cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1)
client, _, err := NewCryptoSetupClient(
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
cOneRTTStream,
@ -331,7 +327,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
var sHandshakeComplete bool
sChunkChan, sInitialStream, sHandshakeStream, sOneRTTStream := initStreams()
@ -341,7 +336,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1)
sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1)
var token [16]byte
server, err := NewCryptoSetupServer(
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
sOneRTTStream,
@ -353,7 +348,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
handshake(client, cChunkChan, server, sChunkChan)
var cErr, sErr error
@ -394,7 +388,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("signals when it has written the ClientHello", func() {
runner := NewMockHandshakeRunner(mockCtrl)
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
client, chChan, err := NewCryptoSetupClient(
client, chChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
@ -406,7 +400,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
@ -435,7 +428,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { sTransportParametersRcvd = b })
cRunner.EXPECT().OnHandshakeComplete()
client, _, err := NewCryptoSetupClient(
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
@ -447,7 +440,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
var token [16]byte
@ -458,7 +450,7 @@ var _ = Describe("Crypto Setup TLS", func() {
IdleTimeout: 0x1337 * time.Second,
StatelessResetToken: &token,
}
server, err := NewCryptoSetupServer(
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
@ -470,7 +462,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
@ -495,7 +486,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, _, err := NewCryptoSetupClient(
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
@ -507,13 +498,12 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
server, err := NewCryptoSetupServer(
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
@ -525,7 +515,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
@ -552,7 +541,7 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, _, err := NewCryptoSetupClient(
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
@ -564,13 +553,12 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
server, err := NewCryptoSetupServer(
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
@ -582,7 +570,6 @@ var _ = Describe("Crypto Setup TLS", func() {
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {

View file

@ -17,7 +17,7 @@ var initialSuite = &qtls.CipherSuiteTLS13{
}
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, LongHeaderOpener, error) {
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, LongHeaderOpener) {
clientSecret, serverSecret := computeSecrets(connID)
var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient {
@ -34,8 +34,7 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Lo
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true)),
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true)),
nil
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true))
}
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {

View file

@ -55,8 +55,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() {
})
It("encrypts the client's Initial", func() {
sealer, _, err := NewInitialAEAD(connID, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveClient)
header := split("c3ff000017088394c8f03e5157080000449e00000002")
data := split("060040c4010000c003036660261ff947 cea49cce6cfad687f457cf1b14531ba1 4131a0e8f309a1d0b9c4000006130113 031302010000910000000b0009000006 736572766572ff01000100000a001400 12001d00170018001901000101010201 03010400230000003300260024001d00 204cfdfcd178b784bf328cae793b136f 2aedce005ff183d7bb14952072366470 37002b0003020304000d0020001e0403 05030603020308040805080604010501 060102010402050206020202002d0002 0101001c00024001")
data = append(data, make([]byte, 1162-len(data))...) // add PADDING
@ -71,8 +70,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() {
})
It("encrypt the server's Initial", func() {
sealer, _, err := NewInitialAEAD(connID, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveServer)
header := split("c1ff0000170008f067a5502a4262b50040740001")
data := split("0d0000000018410a020000560303eefc e7f7b37ba1d1632e96677825ddf73988 cfc79825df566dc5430b9a045a120013 0100002e00330024001d00209d3c940d 89690b84d08a60993c144eca684d1081 287c834d5311bcf32bb9da1a002b0002 0304")
sealed := sealer.Seal(nil, data, 1, header)
@ -87,10 +85,8 @@ var _ = Describe("Initial AEAD using AES-GCM", func() {
It("seals and opens", func() {
connectionID := protocol.ConnectionID{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef}
clientSealer, clientOpener, err := NewInitialAEAD(connectionID, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
serverSealer, serverOpener, err := NewInitialAEAD(connectionID, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
clientSealer, clientOpener := NewInitialAEAD(connectionID, protocol.PerspectiveClient)
serverSealer, serverOpener := NewInitialAEAD(connectionID, protocol.PerspectiveServer)
clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad"))
m, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad"))
@ -105,22 +101,18 @@ var _ = Describe("Initial AEAD using AES-GCM", func() {
It("doesn't work if initialized with different connection IDs", func() {
c1 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 1}
c2 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 2}
clientSealer, _, err := NewInitialAEAD(c1, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
_, serverOpener, err := NewInitialAEAD(c2, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
clientSealer, _ := NewInitialAEAD(c1, protocol.PerspectiveClient)
_, serverOpener := NewInitialAEAD(c2, protocol.PerspectiveServer)
clientMessage := clientSealer.Seal(nil, []byte("foobar"), 42, []byte("aad"))
_, err = serverOpener.Open(nil, clientMessage, 42, []byte("aad"))
_, err := serverOpener.Open(nil, clientMessage, 42, []byte("aad"))
Expect(err).To(MatchError(ErrDecryptionFailed))
})
It("encrypts und decrypts the header", func() {
connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}
clientSealer, clientOpener, err := NewInitialAEAD(connID, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
serverSealer, serverOpener, err := NewInitialAEAD(connID, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
clientSealer, clientOpener := NewInitialAEAD(connID, protocol.PerspectiveClient)
serverSealer, serverOpener := NewInitialAEAD(connID, protocol.PerspectiveServer)
// the first byte and the last 4 bytes should be encrypted
header := []byte{0x5e, 0, 1, 2, 3, 4, 0xde, 0xad, 0xbe, 0xef}

View file

@ -69,7 +69,7 @@ type handshakeRunner interface {
type CryptoSetup interface {
RunHandshake()
io.Closer
ChangeConnectionID(protocol.ConnectionID) error
ChangeConnectionID(protocol.ConnectionID)
HandleMessage([]byte, protocol.EncryptionLevel) bool
SetLargest1RTTAcked(protocol.PacketNumber)

View file

@ -37,11 +37,9 @@ func (m *MockCryptoSetup) EXPECT() *MockCryptoSetupMockRecorder {
}
// ChangeConnectionID mocks base method
func (m *MockCryptoSetup) ChangeConnectionID(arg0 protocol.ConnectionID) error {
func (m *MockCryptoSetup) ChangeConnectionID(arg0 protocol.ConnectionID) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ChangeConnectionID", arg0)
ret0, _ := ret[0].(error)
return ret0
m.ctrl.Call(m, "ChangeConnectionID", arg0)
}
// ChangeConnectionID indicates an expected call of ChangeConnectionID

View file

@ -65,7 +65,7 @@ func ComposeAckFrame(smallest protocol.PacketNumber, largest protocol.PacketNumb
// ComposeInitialPacket returns an Initial packet encrypted under key
// (the original destination connection ID) containing specified frames
func ComposeInitialPacket(srcConnID protocol.ConnectionID, destConnID protocol.ConnectionID, version protocol.VersionNumber, key protocol.ConnectionID, frames []wire.Frame) []byte {
sealer, _, _ := handshake.NewInitialAEAD(key, protocol.PerspectiveServer)
sealer, _ := handshake.NewInitialAEAD(key, protocol.PerspectiveServer)
// compose payload
var payload []byte

View file

@ -82,7 +82,7 @@ type baseServer struct {
sessionHandler packetHandlerManager
// set as a member, so they can be set in the tests
newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, *Config, *tls.Config, *handshake.TransportParameters, *handshake.TokenGenerator, utils.Logger, protocol.VersionNumber) (quicSession, error)
newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, *Config, *tls.Config, *handshake.TransportParameters, *handshake.TokenGenerator, utils.Logger, protocol.VersionNumber) quicSession
serverError error
errorChan chan struct{}
@ -424,7 +424,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
return nil, nil, err
}
s.logger.Debugf("Changing connection ID to %s.", connID)
sess, err := s.createNewSession(
sess := s.createNewSession(
p.remoteAddr,
origDestConnectionID,
hdr.DestConnectionID,
@ -432,9 +432,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (qui
connID,
hdr.Version,
)
if err != nil {
return nil, nil, err
}
sess.handlePacket(p)
return sess, connID, nil
}
@ -446,7 +443,7 @@ func (s *baseServer) createNewSession(
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
version protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
token := s.sessionHandler.GetStatelessResetToken(srcConnID)
params := &handshake.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
@ -462,7 +459,7 @@ func (s *baseServer) createNewSession(
StatelessResetToken: &token,
OriginalConnectionID: origDestConnID,
}
sess, err := s.newSession(
sess := s.newSession(
&conn{pconn: s.conn, currentAddr: remoteAddr},
s.sessionHandler,
clientDestConnID,
@ -475,12 +472,9 @@ func (s *baseServer) createNewSession(
s.logger,
version,
)
if err != nil {
return nil, err
}
go sess.run()
go s.handleNewSession(sess)
return sess, nil
return sess
}
func (s *baseServer) handleNewSession(sess quicSession) {
@ -542,10 +536,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
}
func (s *baseServer) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error {
sealer, _, err := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer)
if err != nil {
return err
}
sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer)
packetBuffer := getPacketBuffer()
defer packetBuffer.Release()
buf := bytes.NewBuffer(packetBuffer.Slice[:0])

View file

@ -294,7 +294,7 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
Expect(origConnID).To(Equal(hdr.DestConnectionID))
Expect(destConnID).To(Equal(hdr.SrcConnectionID))
// make sure we're using a server-generated connection ID
@ -305,7 +305,7 @@ var _ = Describe("Server", func() {
sess.EXPECT().run().Do(func() { close(run) })
sess.EXPECT().Context().Return(context.Background())
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess, nil
return sess
}
done := make(chan struct{})
@ -346,7 +346,7 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(p)
sess.EXPECT().run()
@ -354,7 +354,7 @@ var _ = Describe("Server", func() {
ctx, cancel := context.WithCancel(context.Background())
cancel()
sess.EXPECT().HandshakeComplete().Return(ctx)
return sess, nil
return sess
}
var wg sync.WaitGroup
@ -407,7 +407,7 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
sess.EXPECT().handlePacket(p)
sess.EXPECT().run()
sess.EXPECT().Context().Return(ctx)
@ -415,7 +415,7 @@ var _ = Describe("Server", func() {
cancel()
sess.EXPECT().HandshakeComplete().Return(ctx)
close(sessionCreated)
return sess, nil
return sess
}
serv.handlePacket(p)
@ -504,14 +504,13 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
sess.EXPECT().HandshakeComplete().Return(ctx)
sess.EXPECT().run().Do(func() {})
sess.EXPECT().Context().Return(context.Background())
return sess, nil
return sess
}
_, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Consistently(done).ShouldNot(BeClosed())
cancel() // complete the handshake
Eventually(done).Should(BeClosed())
@ -553,14 +552,13 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
sess.EXPECT().run().Do(func() {})
sess.EXPECT().earlySessionReady().Return(ready)
sess.EXPECT().Context().Return(context.Background())
return sess, nil
return sess
}
_, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Consistently(done).ShouldNot(BeClosed())
close(ready)
Eventually(done).Should(BeClosed())
@ -591,7 +589,7 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
ready := make(chan struct{})
close(ready)
sess := NewMockQuicSession(mockCtrl)
@ -599,7 +597,7 @@ var _ = Describe("Server", func() {
sess.EXPECT().run()
sess.EXPECT().earlySessionReady().Return(ready)
sess.EXPECT().Context().Return(context.Background())
return sess, nil
return sess
}
var wg sync.WaitGroup
@ -652,13 +650,13 @@ var _ = Describe("Server", func() {
_ *handshake.TokenGenerator,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
sess.EXPECT().handlePacket(p)
sess.EXPECT().run()
sess.EXPECT().earlySessionReady()
sess.EXPECT().Context().Return(ctx)
close(sessionCreated)
return sess, nil
return sess
}
serv.handlePacket(p)

View file

@ -49,7 +49,7 @@ type streamManager interface {
type cryptoStreamHandler interface {
RunHandshake()
ChangeConnectionID(protocol.ConnectionID) error
ChangeConnectionID(protocol.ConnectionID)
SetLargest1RTTAcked(protocol.PacketNumber)
io.Closer
ConnectionState() tls.ConnectionState
@ -188,7 +188,7 @@ var newSession = func(
tokenGenerator *handshake.TokenGenerator,
logger utils.Logger,
v protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
s := &session{
conn: conn,
sessionRunner: runner,
@ -215,7 +215,7 @@ var newSession = func(
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
cs, err := handshake.NewCryptoSetupServer(
cs := handshake.NewCryptoSetupServer(
initialStream,
handshakeStream,
oneRTTStream,
@ -232,9 +232,6 @@ var newSession = func(
s.rttStats,
logger,
)
if err != nil {
return nil, err
}
s.cryptoStreamHandler = cs
s.packer = newPacketPacker(
s.destConnID,
@ -254,7 +251,7 @@ var newSession = func(
s.postSetup()
s.unpacker = newPacketUnpacker(cs, s.version)
return s, nil
return s
}
// declare this as a variable, such that we can it mock it in the tests
@ -270,7 +267,7 @@ var newClientSession = func(
initialVersion protocol.VersionNumber,
logger utils.Logger,
v protocol.VersionNumber,
) (quicSession, error) {
) quicSession {
s := &session{
conn: conn,
sessionRunner: runner,
@ -288,7 +285,7 @@ var newClientSession = func(
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
cs, clientHelloWritten, err := handshake.NewCryptoSetupClient(
cs, clientHelloWritten := handshake.NewCryptoSetupClient(
initialStream,
handshakeStream,
oneRTTStream,
@ -305,9 +302,6 @@ var newClientSession = func(
s.rttStats,
logger,
)
if err != nil {
return nil, err
}
s.clientHelloWritten = clientHelloWritten
s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
@ -346,7 +340,7 @@ var newClientSession = func(
}
}
s.postSetup()
return s, nil
return s
}
func (s *session) preSetup() {

View file

@ -82,8 +82,7 @@ var _ = Describe("Session", func() {
mconn = newMockConnection()
tokenGenerator, err := handshake.NewTokenGenerator()
Expect(err).ToNot(HaveOccurred())
var pSess Session
pSess, err = newSession(
sess = newSession(
mconn,
sessionRunner,
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
@ -95,9 +94,7 @@ var _ = Describe("Session", func() {
tokenGenerator,
utils.DefaultLogger,
protocol.VersionTLS,
)
Expect(err).NotTo(HaveOccurred())
sess = pSess.(*session)
).(*session)
streamManager = NewMockStreamManager(mockCtrl)
sess.streamsMap = streamManager
packer = NewMockPacker(mockCtrl)
@ -1498,7 +1495,7 @@ var _ = Describe("Client Session", func() {
}
mconn = newMockConnection()
sessionRunner = NewMockSessionRunner(mockCtrl)
sessP, err := newClientSession(
sess = newClientSession(
mconn,
sessionRunner,
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
@ -1510,9 +1507,7 @@ var _ = Describe("Client Session", func() {
protocol.VersionTLS,
utils.DefaultLogger,
protocol.VersionTLS,
)
sess = sessP.(*session)
Expect(err).ToNot(HaveOccurred())
).(*session)
packer = NewMockPacker(mockCtrl)
sess.packer = packer
cryptoSetup = mocks.NewMockCryptoSetup(mockCtrl)