diff --git a/mock_connection_test.go b/mock_connection_test.go new file mode 100644 index 00000000..4eb2b3ea --- /dev/null +++ b/mock_connection_test.go @@ -0,0 +1,119 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: Connection) + +// Package quic is a generated GoMock package. +package quic + +import ( + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockConnection is a mock of Connection interface +type MockConnection struct { + ctrl *gomock.Controller + recorder *MockConnectionMockRecorder +} + +// MockConnectionMockRecorder is the mock recorder for MockConnection +type MockConnectionMockRecorder struct { + mock *MockConnection +} + +// NewMockConnection creates a new mock instance +func NewMockConnection(ctrl *gomock.Controller) *MockConnection { + mock := &MockConnection{ctrl: ctrl} + mock.recorder = &MockConnectionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockConnection) EXPECT() *MockConnectionMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *MockConnection) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockConnectionMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnection)(nil).Close)) +} + +// LocalAddr mocks base method +func (m *MockConnection) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr +func (mr *MockConnectionMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockConnection)(nil).LocalAddr)) +} + +// Read mocks base method +func (m *MockConnection) Read(arg0 []byte) (int, net.Addr, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(net.Addr) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Read indicates an expected call of Read +func (mr *MockConnectionMockRecorder) Read(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConnection)(nil).Read), arg0) +} + +// RemoteAddr mocks base method +func (m *MockConnection) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr +func (mr *MockConnectionMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockConnection)(nil).RemoteAddr)) +} + +// SetCurrentRemoteAddr mocks base method +func (m *MockConnection) SetCurrentRemoteAddr(arg0 net.Addr) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetCurrentRemoteAddr", arg0) +} + +// SetCurrentRemoteAddr indicates an expected call of SetCurrentRemoteAddr +func (mr *MockConnectionMockRecorder) SetCurrentRemoteAddr(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCurrentRemoteAddr", reflect.TypeOf((*MockConnection)(nil).SetCurrentRemoteAddr), arg0) +} + +// Write mocks base method +func (m *MockConnection) Write(arg0 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Write indicates an expected call of Write +func (mr *MockConnectionMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConnection)(nil).Write), arg0) +} diff --git a/mockgen.go b/mockgen.go index bced2fc5..4ea15f20 100644 --- a/mockgen.go +++ b/mockgen.go @@ -1,5 +1,6 @@ package quic +//go:generate sh -c "./mockgen_private.sh quic mock_connection_test.go github.com/lucas-clemente/quic-go connection" //go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI" //go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStream" //go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI" diff --git a/session_test.go b/session_test.go index f09a1008..253d3307 100644 --- a/session_test.go +++ b/session_test.go @@ -75,7 +75,7 @@ var _ = Describe("Session", func() { var ( sess *session sessionRunner *MockSessionRunner - mconn *mockConnection + mconn *MockConnection streamManager *MockStreamManager packer *MockPacker cryptoSetup *mocks.MockCryptoSetup @@ -108,7 +108,8 @@ var _ = Describe("Session", func() { Eventually(areSessionsRunning).Should(BeFalse()) sessionRunner = NewMockSessionRunner(mockCtrl) - mconn = newMockConnection() + mconn = NewMockConnection(mockCtrl) + mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).Times(2) tokenGenerator, err := handshake.NewTokenGenerator() Expect(err).ToNot(HaveOccurred()) sess = newSession( @@ -453,10 +454,9 @@ var _ = Describe("Session", func() { Expect(f.ReasonPhrase).To(BeEmpty()) return &packedPacket{raw: []byte("connection close")}, nil }) + mconn.EXPECT().Write([]byte("connection close")) Expect(sess.Close()).To(Succeed()) Eventually(areSessionsRunning).Should(BeFalse()) - Expect(mconn.written).To(HaveLen(1)) - Expect(mconn.written).To(Receive(ContainSubstring("connection close"))) Expect(sess.Context().Done()).To(BeClosed()) }) @@ -465,10 +465,10 @@ var _ = Describe("Session", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + mconn.EXPECT().Write(gomock.Any()) Expect(sess.Close()).To(Succeed()) Expect(sess.Close()).To(Succeed()) Eventually(areSessionsRunning).Should(BeFalse()) - Expect(mconn.written).To(HaveLen(1)) Expect(sess.Context().Done()).To(BeClosed()) }) @@ -482,6 +482,7 @@ var _ = Describe("Session", func() { Expect(f.ReasonPhrase).To(Equal("test error")) return &packedPacket{}, nil }) + mconn.EXPECT().Write(gomock.Any()) sess.CloseWithError(0x1337, "test error") Eventually(areSessionsRunning).Should(BeFalse()) Expect(sess.Context().Done()).To(BeClosed()) @@ -499,6 +500,7 @@ var _ = Describe("Session", func() { Expect(f.ReasonPhrase).To(Equal("test error")) return &packedPacket{}, nil }) + mconn.EXPECT().Write(gomock.Any()) sess.closeLocal(testErr) Eventually(areSessionsRunning).Should(BeFalse()) Expect(sess.Context().Done()).To(BeClosed()) @@ -515,6 +517,7 @@ var _ = Describe("Session", func() { Expect(f.ReasonPhrase).To(BeEmpty()) return &packedPacket{}, nil }) + mconn.EXPECT().Write(gomock.Any()) sess.CloseWithError(0x1337, "test error") Eventually(areSessionsRunning).Should(BeFalse()) Expect(sess.Context().Done()).To(BeClosed()) @@ -524,8 +527,8 @@ var _ = Describe("Session", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() + // don't EXPECT any calls to mconn.Write() sess.closeForRecreating() - Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent Eventually(areSessionsRunning).Should(BeFalse()) expectedRunErr = errCloseForRecreating }) @@ -535,9 +538,9 @@ var _ = Describe("Session", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() + // don't EXPECT any calls to mconn.Write() sess.destroy(testErr) Eventually(areSessionsRunning).Should(BeFalse()) - Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent expectedRunErr = testErr }) @@ -555,6 +558,7 @@ var _ = Describe("Session", func() { close(returned) }() Consistently(returned).ShouldNot(BeClosed()) + mconn.EXPECT().Write(gomock.Any()) sess.Close() Eventually(returned).Should(BeClosed()) }) @@ -646,6 +650,7 @@ var _ = Describe("Session", func() { }, nil)) Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make the go routine return + mconn.EXPECT().Write(gomock.Any()) sess.closeLocal(errors.New("close")) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -665,6 +670,7 @@ var _ = Describe("Session", func() { close(done) }() expectReplaceWithClosed() + mconn.EXPECT().Write(gomock.Any()) sess.handlePacket(getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, @@ -691,6 +697,7 @@ var _ = Describe("Session", func() { }, nil)) Consistently(runErr).ShouldNot(Receive()) // make the go routine return + mconn.EXPECT().Write(gomock.Any()) sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -712,6 +719,7 @@ var _ = Describe("Session", func() { close(done) }() expectReplaceWithClosed() + mconn.EXPECT().Write(gomock.Any()) sess.handlePacket(getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, @@ -784,16 +792,12 @@ var _ = Describe("Session", func() { hdr: &wire.ExtendedHeader{}, data: []byte{0}, // one PADDING frame }, nil) - origAddr := sess.conn.(*mockConnection).remoteAddr - remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} - Expect(origAddr).ToNot(Equal(remoteIP)) packet := getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, }, nil) - packet.remoteAddr = remoteIP + packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} Expect(sess.handlePacketImpl(packet)).To(BeTrue()) - Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(origAddr)) }) }) @@ -904,6 +908,7 @@ var _ = Describe("Session", func() { packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -911,10 +916,10 @@ var _ = Describe("Session", func() { It("sends packets", func() { packer.EXPECT().PackPacket().Return(getPacket(1), nil) sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.Encryption1RTT, time.Now(), true) + mconn.EXPECT().Write(gomock.Any()) sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(sent).To(BeTrue()) - Eventually(mconn.written).Should(Receive()) }) It("doesn't send packets if there's nothing to send", func() { @@ -940,6 +945,7 @@ var _ = Describe("Session", func() { fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) packer.EXPECT().PackPacket().Return(getPacket(1), nil) sess.connFlowController = fc + mconn.EXPECT().Write(gomock.Any()) sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(sent).To(BeTrue()) @@ -952,8 +958,7 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendNone) sess.sentPacketHandler = sph - err := sess.sendPackets() - Expect(err).ToNot(HaveOccurred()) + Expect(sess.sendPackets()).To(Succeed()) }) for _, enc := range []protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption1RTT} { @@ -989,6 +994,7 @@ var _ = Describe("Session", func() { Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) }) sess.sentPacketHandler = sph + mconn.EXPECT().Write(gomock.Any()) Expect(sess.sendPackets()).To(Succeed()) }) @@ -1004,6 +1010,7 @@ var _ = Describe("Session", func() { Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) }) sess.sentPacketHandler = sph + mconn.EXPECT().Write(gomock.Any()) Expect(sess.sendPackets()).To(Succeed()) // We're using a mock packet packer in this test. // We therefore need to test separately that the PING was actually queued. @@ -1028,6 +1035,7 @@ var _ = Describe("Session", func() { packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1040,14 +1048,14 @@ var _ = Describe("Session", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) // allow 2 packets... packer.EXPECT().PackPacket().Return(getPacket(10), nil) packer.EXPECT().PackPacket().Return(getPacket(11), nil) + mconn.EXPECT().Write(gomock.Any()).Times(2) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(2)) - Consistently(mconn.written).Should(HaveLen(2)) + time.Sleep(50 * time.Millisecond) // make sure that only 2 packes are sent }) // when becoming congestion limited, at some point the SendMode will change from SendAny to SendAck @@ -1059,14 +1067,14 @@ var _ = Describe("Session", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAny) sph.EXPECT().SendMode().Return(ackhandler.SendAck) packer.EXPECT().PackPacket().Return(getPacket(100), nil) + mconn.EXPECT().Write(gomock.Any()) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(1)) - Consistently(mconn.written).Should(HaveLen(1)) + time.Sleep(50 * time.Millisecond) // make sure that only 1 packet is sent }) It("paces packets", func() { @@ -1079,15 +1087,20 @@ var _ = Describe("Session", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() packer.EXPECT().PackPacket().Return(getPacket(100), nil) packer.EXPECT().PackPacket().Return(getPacket(101), nil) + written := make(chan struct{}, 2) + mconn.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + written <- struct{}{} + return len(p), nil + }).Times(2) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(1)) - Consistently(mconn.written, pacingDelay/2).Should(HaveLen(1)) - Eventually(mconn.written, 2*pacingDelay).Should(HaveLen(2)) + Eventually(written).Should(HaveLen(1)) + Consistently(written, pacingDelay/2).Should(HaveLen(1)) + Eventually(written, 2*pacingDelay).Should(HaveLen(2)) }) It("sends multiple packets at once", func() { @@ -1099,13 +1112,18 @@ var _ = Describe("Session", func() { packer.EXPECT().PackPacket().Return(getPacket(1000), nil) packer.EXPECT().PackPacket().Return(getPacket(1001), nil) packer.EXPECT().PackPacket().Return(getPacket(1002), nil) + written := make(chan struct{}, 3) + mconn.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + written <- struct{}{} + return len(p), nil + }).Times(3) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(3)) + Eventually(written).Should(HaveLen(3)) }) It("doesn't set a pacing timer when there is no data to send", func() { @@ -1113,17 +1131,29 @@ var _ = Describe("Session", func() { sph.EXPECT().ShouldSendNumPackets().Return(1) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() packer.EXPECT().PackPacket() + // don't EXPECT any calls to mconn.Write() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() sess.scheduleSending() // no packet will get sent - Consistently(mconn.written).ShouldNot(Receive()) + time.Sleep(50 * time.Millisecond) }) }) Context("scheduling sending", func() { + AfterEach(func() { + // make the go routine return + expectReplaceWithClosed() + streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + sess.Close() + Eventually(sess.Context().Done()).Should(BeClosed()) + }) + It("sends when scheduleSending is called", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() @@ -1139,16 +1169,13 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() - Consistently(mconn.written).ShouldNot(Receive()) + // don't EXPECT any calls to mconn.Write() + time.Sleep(50 * time.Millisecond) + // only EXPECT calls after scheduleSending is called + written := make(chan struct{}) + mconn.EXPECT().Write(gomock.Any()).Do(func([]byte) { close(written) }) sess.scheduleSending() - Eventually(mconn.written).Should(Receive()) - // make the go routine return - expectReplaceWithClosed() - streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - cryptoSetup.EXPECT().Close() - sess.Close() - Eventually(sess.Context().Done()).Should(BeClosed()) + Eventually(written).Should(BeClosed()) }) It("sets the timer to the ack timer", func() { @@ -1169,19 +1196,14 @@ var _ = Describe("Session", func() { rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).MaxTimes(1) sess.receivedPacketHandler = rph + written := make(chan struct{}) + mconn.EXPECT().Write(gomock.Any()).Do(func([]byte) { close(written) }) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) sess.run() }() - Eventually(mconn.written).Should(Receive()) - // make sure the go routine returns - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - expectReplaceWithClosed() - streamManager.EXPECT().CloseWithError(gomock.Any()) - cryptoSetup.EXPECT().Close() - sess.Close() - Eventually(sess.Context().Done()).Should(BeClosed()) + Eventually(written).Should(BeClosed()) }) }) @@ -1206,6 +1228,7 @@ var _ = Describe("Session", func() { }() handshakeCtx := sess.HandshakeComplete() Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token close(finishHandshake) Eventually(handshakeCtx.Done()).Should(BeClosed()) Eventually(sphNotified).Should(BeClosed()) @@ -1214,6 +1237,7 @@ var _ = Describe("Session", func() { expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1231,6 +1255,7 @@ var _ = Describe("Session", func() { }() handshakeCtx := sess.HandshakeComplete() Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + mconn.EXPECT().Write(gomock.Any()) sess.closeLocal(errors.New("handshake error")) Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed()) @@ -1254,6 +1279,8 @@ var _ = Describe("Session", func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() cryptoSetup.EXPECT().DropHandshakeKeys() + mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token + mconn.EXPECT().Write(gomock.Any()) close(sess.handshakeCompleteChan) sess.run() }() @@ -1263,6 +1290,7 @@ var _ = Describe("Session", func() { expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1279,6 +1307,7 @@ var _ = Describe("Session", func() { expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) Expect(sess.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -1297,6 +1326,7 @@ var _ = Describe("Session", func() { expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) Expect(sess.CloseWithError(0x1337, testErr.Error())).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -1326,14 +1356,14 @@ var _ = Describe("Session", func() { // make the go routine return streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{})) Expect(s.Close()).To(Succeed()) }).Times(4) // initial connection ID + initial client dest conn ID + 2 newly issued conn IDs packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() - sess.Close() + mconn.EXPECT().Write(gomock.Any()) + Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) }) @@ -1364,6 +1394,7 @@ var _ = Describe("Session", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1400,7 +1431,8 @@ var _ = Describe("Session", func() { sess.config.KeepAlive = false sess.lastPacketReceivedTime = time.Now().Add(-time.Second * 5 / 2) runSession() - Consistently(mconn.written).ShouldNot(Receive()) + // don't EXPECT() any calls to mconn.Write() + time.Sleep(50 * time.Millisecond) }) It("doesn't send a PING if the handshake isn't completed yet", func() { @@ -1409,7 +1441,8 @@ var _ = Describe("Session", func() { // Otherwise we'll try to send a CONNECTION_CLOSE. sess.lastPacketReceivedTime = time.Now().Add(-20 * time.Second) runSession() - Consistently(mconn.written).ShouldNot(Receive()) + // don't EXPECT() any calls to mconn.Write() + time.Sleep(50 * time.Millisecond) }) }) @@ -1475,7 +1508,8 @@ var _ = Describe("Session", func() { sess.handshakeComplete = true expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - sess.Close() + mconn.EXPECT().Write(gomock.Any()) + Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1492,6 +1526,7 @@ var _ = Describe("Session", func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1) + mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) close(sess.handshakeCompleteChan) err := sess.run() nerr, ok := err.(net.Error) @@ -1517,6 +1552,7 @@ var _ = Describe("Session", func() { packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1586,13 +1622,13 @@ var _ = Describe("Session", func() { It("returns the local address", func() { addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - mconn.localAddr = addr + mconn.EXPECT().LocalAddr().Return(addr) Expect(sess.LocalAddr()).To(Equal(addr)) }) It("returns the remote address", func() { addr := &net.UDPAddr{IP: net.IPv4(1, 2, 7, 1), Port: 7331} - mconn.remoteAddr = addr + mconn.EXPECT().RemoteAddr().Return(addr) Expect(sess.RemoteAddr()).To(Equal(addr)) }) }) @@ -1602,7 +1638,7 @@ var _ = Describe("Client Session", func() { sess *session sessionRunner *MockSessionRunner packer *MockPacker - mconn *mockConnection + mconn *MockConnection cryptoSetup *mocks.MockCryptoSetup tlsConf *tls.Config quicConf *Config @@ -1628,15 +1664,18 @@ var _ = Describe("Client Session", func() { BeforeEach(func() { quicConf = populateClientConfig(&Config{}, true) + tlsConf = nil }) JustBeforeEach(func() { Eventually(areSessionsRunning).Should(BeFalse()) + mconn = NewMockConnection(mockCtrl) + mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).Times(2) if tlsConf == nil { + mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) tlsConf = &tls.Config{} } - mconn = newMockConnection() sessionRunner = NewMockSessionRunner(mockCtrl) sess = newClientSession( mconn, @@ -1687,6 +1726,7 @@ var _ = Describe("Client Session", func() { packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1812,6 +1852,7 @@ var _ = Describe("Client Session", func() { }) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil).MaxTimes(1) cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) } closed = true }