diff --git a/client_test.go b/client_test.go index 5eca0f66..7d530938 100644 --- a/client_test.go +++ b/client_test.go @@ -7,9 +7,9 @@ import ( "fmt" "net" "os" - "sync/atomic" "time" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -23,7 +23,6 @@ import ( var _ = Describe("Client", func() { var ( cl *client - sess *mockSession packetConn *mockPacketConn addr net.Addr connID protocol.ConnectionID @@ -43,13 +42,13 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) return b.Bytes() } + _ = acceptClientVersionPacket BeforeEach(func() { connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} originalClientSessConstructor = newClientSession Eventually(areSessionsRunning).Should(BeFalse()) - msess, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil) - sess = msess.(*mockSession) + // sess = NewMockPacketHandler(mockCtrl) addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} packetConn = newMockPacketConn() packetConn.addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} @@ -57,7 +56,6 @@ var _ = Describe("Client", func() { cl = &client{ srcConnID: connID, destConnID: connID, - session: sess, version: protocol.SupportedVersions[0], conn: &conn{pconn: packetConn, currentAddr: addr}, logger: utils.DefaultLogger, @@ -93,8 +91,7 @@ var _ = Describe("Client", func() { if os.Getenv("APPVEYOR") == "True" { Skip("This test is flaky on AppVeyor.") } - closeErr := errors.New("peer doesn't reply") - remoteAddrChan := make(chan string) + remoteAddrChan := make(chan string, 1) newClientSession = func( conn connection, _ sessionRunner, @@ -108,23 +105,17 @@ var _ = Describe("Client", func() { _ utils.Logger, ) (packetHandler, error) { remoteAddrChan <- conn.RemoteAddr().String() + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().run() return sess, nil } - dialed := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := DialAddr("localhost:17890", nil, &Config{HandshakeTimeout: time.Millisecond}) - Expect(err).To(MatchError(closeErr)) - close(dialed) - }() + _, err := DialAddr("localhost:17890", nil, &Config{HandshakeTimeout: time.Millisecond}) + Expect(err).ToNot(HaveOccurred()) Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890"))) - sess.Close(closeErr) - Eventually(dialed).Should(BeClosed()) }) It("uses the tls.Config.ServerName as the hostname, if present", func() { - closeErr := errors.New("peer doesn't reply") - hostnameChan := make(chan string) + hostnameChan := make(chan string, 1) newClientSession = func( _ connection, _ sessionRunner, @@ -138,18 +129,13 @@ var _ = Describe("Client", func() { _ utils.Logger, ) (packetHandler, error) { hostnameChan <- h + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().run() return sess, nil } - dialed := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := DialAddr("localhost:17890", &tls.Config{ServerName: "foobar"}, nil) - Expect(err).To(MatchError(closeErr)) - close(dialed) - }() + _, err := DialAddr("localhost:17890", &tls.Config{ServerName: "foobar"}, nil) + Expect(err).ToNot(HaveOccurred()) Eventually(hostnameChan).Should(Receive(Equal("foobar"))) - sess.Close(closeErr) - Eventually(dialed).Should(BeClosed()) }) It("errors when receiving an error from the connection", func() { @@ -160,6 +146,7 @@ var _ = Describe("Client", func() { }) It("returns after the handshake is complete", func() { + run := make(chan struct{}) newClientSession = func( _ connection, runner sessionRunner, @@ -172,25 +159,22 @@ var _ = Describe("Client", func() { _ []protocol.VersionNumber, _ utils.Logger, ) (packetHandler, error) { + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().run().Do(func() { close(run) }) + sess.EXPECT().handlePacket(gomock.Any()) runner.onHandshakeComplete(sess) return sess, nil } packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID) - dialed := make(chan struct{}) - go func() { - defer GinkgoRecover() - s, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(s).ToNot(BeNil()) - close(dialed) - }() - Eventually(dialed).Should(BeClosed()) - // make the session run loop return - close(sess.stopRunLoop) + s, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(s).ToNot(BeNil()) + Eventually(run).Should(BeClosed()) }) It("returns an error that occurs while waiting for the connection to become secure", func() { testErr := errors.New("early handshake error") + handledPacket := make(chan struct{}) newClientSession = func( conn connection, _ sessionRunner, @@ -203,20 +187,15 @@ var _ = Describe("Client", func() { _ []protocol.VersionNumber, _ utils.Logger, ) (packetHandler, error) { - Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed()) + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().handlePacket(gomock.Any()).Do(func(_ *receivedPacket) { close(handledPacket) }) + sess.EXPECT().run().Return(testErr) return sess, nil } packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) - Expect(err).To(MatchError(testErr)) - close(done) - }() - sess.closeReason = testErr - close(sess.stopRunLoop) - Eventually(done).Should(BeClosed()) + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) + Expect(err).To(MatchError(testErr)) + Eventually(handledPacket).Should(BeClosed()) }) Context("quic.Config", func() { @@ -320,21 +299,17 @@ var _ = Describe("Client", func() { conf = configP close(c) // TODO: check connection IDs? + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().run() return sess, nil } - dialed := make(chan struct{}) - go func() { - defer GinkgoRecover() - Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) - close(dialed) - }() + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + Expect(err).ToNot(HaveOccurred()) Eventually(c).Should(BeClosed()) Expect(cconn.(*conn).pconn).To(Equal(packetConn)) Expect(hostname).To(Equal("quic.clemente.io")) Expect(version).To(Equal(config.Versions[0])) Expect(conf.Versions).To(Equal(config.Versions)) - sess.Close(errors.New("peer doesn't reply")) - Eventually(dialed).Should(BeClosed()) }) }) @@ -351,6 +326,7 @@ var _ = Describe("Client", func() { }) It("returns an error that occurs during version negotiation", func() { + testErr := errors.New("early handshake error") newClientSession = func( conn connection, _ sessionRunner, @@ -364,21 +340,18 @@ var _ = Describe("Client", func() { _ utils.Logger, ) (packetHandler, error) { Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed()) + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().run().Return(testErr) return sess, nil } - testErr := errors.New("early handshake error") - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) - Expect(err).To(MatchError(testErr)) - close(done) - }() - sess.Close(testErr) - Eventually(done).Should(BeClosed()) + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) + Expect(err).To(MatchError(testErr)) }) It("recognizes that a packet without VersionFlag means that the server accepted the suggested version", func() { + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().handlePacket(gomock.Any()) + cl.session = sess ph := wire.Header{ PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen2, @@ -394,68 +367,18 @@ var _ = Describe("Client", func() { }) It("changes the version after receiving a version negotiation packet", func() { - var initialVersion protocol.VersionNumber - var negotiatedVersions []protocol.VersionNumber - newVersion := protocol.VersionNumber(77) - Expect(newVersion).ToNot(Equal(cl.version)) - cl.config = &Config{Versions: []protocol.VersionNumber{newVersion}} - sessionChan := make(chan *mockSession) - stopRunLoop := make(chan struct{}) - newClientSession = func( - _ connection, - _ sessionRunner, - _ string, - _ protocol.VersionNumber, - connectionID protocol.ConnectionID, - _ *tls.Config, - _ *Config, - initialVersionP protocol.VersionNumber, - negotiatedVersionsP []protocol.VersionNumber, - _ utils.Logger, - ) (packetHandler, error) { - initialVersion = initialVersionP - negotiatedVersions = negotiatedVersionsP - - sess := &mockSession{ - connectionID: connectionID, - stopRunLoop: stopRunLoop, - } - sessionChan <- sess - return sess, nil - } - - established := make(chan struct{}) - go func() { - defer GinkgoRecover() - err := cl.dial() - Expect(err).ToNot(HaveOccurred()) - close(established) - }() - go cl.listen() - - actualInitialVersion := cl.version - var firstSession, secondSession *mockSession - Eventually(sessionChan).Should(Receive(&firstSession)) - packetConn.dataToRead <- wire.ComposeGQUICVersionNegotiation( - connID, - []protocol.VersionNumber{newVersion}, - ) - // it didn't pass the version negoation packet to the old session (since it has no payload) - Eventually(func() bool { return firstSession.closed }).Should(BeTrue()) - Expect(firstSession.closeReason).To(Equal(errCloseSessionForNewVersion)) - Expect(firstSession.handledPackets).To(BeEmpty()) - Eventually(sessionChan).Should(Receive(&secondSession)) - // make the server accept the new version - packetConn.dataToRead <- acceptClientVersionPacket(secondSession.connectionID) - Consistently(func() bool { return secondSession.closed }).Should(BeFalse()) - Expect(negotiatedVersions).To(ContainElement(newVersion)) - Expect(initialVersion).To(Equal(actualInitialVersion)) - - Eventually(established).Should(BeClosed()) - }) - - It("only accepts one version negotiation packet", func() { - sessionCounter := uint32(0) + version1 := protocol.Version39 + version2 := protocol.Version39 + 1 + Expect(version2.UsesTLS()).To(BeFalse()) + sess1 := NewMockPacketHandler(mockCtrl) + run1 := make(chan struct{}) + sess1.EXPECT().run().Do(func() { <-run1 }).Return(errCloseSessionForNewVersion) + sess1.EXPECT().Close(errCloseSessionForNewVersion).Do(func(error) { close(run1) }) + sess2 := NewMockPacketHandler(mockCtrl) + sess2.EXPECT().run() + sessionChan := make(chan *MockPacketHandler, 2) + sessionChan <- sess1 + sessionChan <- sess2 newClientSession = func( _ connection, _ sessionRunner, @@ -468,42 +391,94 @@ var _ = Describe("Client", func() { _ []protocol.VersionNumber, _ utils.Logger, ) (packetHandler, error) { - atomic.AddUint32(&sessionCounter, 1) - return &mockSession{ - connectionID: connectionID, - stopRunLoop: make(chan struct{}), - }, nil + return <-sessionChan, nil } - go cl.dial() - Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1)) - cl.config = &Config{Versions: []protocol.VersionNumber{77, 78}} - err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{77})) + + cl.config = &Config{Versions: []protocol.VersionNumber{version1, version2}} + dialed := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := cl.dial() + Expect(err).ToNot(HaveOccurred()) + close(dialed) + }() + Eventually(sessionChan).Should(HaveLen(1)) + err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2})) Expect(err).ToNot(HaveOccurred()) - Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2)) - err = cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{78})) + Eventually(sessionChan).Should(BeEmpty()) + }) + + It("only accepts one version negotiation packet", func() { + version1 := protocol.Version39 + version2 := protocol.Version39 + 1 + version3 := protocol.Version39 + 2 + Expect(version2.UsesTLS()).To(BeFalse()) + Expect(version3.UsesTLS()).To(BeFalse()) + sess1 := NewMockPacketHandler(mockCtrl) + run1 := make(chan struct{}) + sess1.EXPECT().run().Do(func() { <-run1 }).Return(errCloseSessionForNewVersion) + sess1.EXPECT().Close(errCloseSessionForNewVersion).Do(func(error) { close(run1) }) + sess2 := NewMockPacketHandler(mockCtrl) + sess2.EXPECT().run() + sessionChan := make(chan *MockPacketHandler, 2) + sessionChan <- sess1 + sessionChan <- sess2 + newClientSession = func( + _ connection, + _ sessionRunner, + _ string, + _ protocol.VersionNumber, + connectionID protocol.ConnectionID, + _ *tls.Config, + _ *Config, + _ protocol.VersionNumber, + _ []protocol.VersionNumber, + _ utils.Logger, + ) (packetHandler, error) { + return <-sessionChan, nil + } + + cl.config = &Config{Versions: []protocol.VersionNumber{version1, version2}} + dialed := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := cl.dial() + Expect(err).ToNot(HaveOccurred()) + close(dialed) + }() + Eventually(sessionChan).Should(HaveLen(1)) + err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2})) + Expect(err).ToNot(HaveOccurred()) + Eventually(sessionChan).Should(BeEmpty()) + err = cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version3})) Expect(err).To(MatchError("received a delayed Version Negotiation Packet")) - Consistently(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2)) + Eventually(dialed).Should(BeClosed()) }) It("errors if no matching version is found", func() { + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().Close(gomock.Any()) + cl.session = sess cl.config = &Config{Versions: protocol.SupportedVersions} err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1})) Expect(err).ToNot(HaveOccurred()) - Expect(cl.session.(*mockSession).closed).To(BeTrue()) - Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) }) It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() { + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().Close(gomock.Any()) + cl.session = sess v := protocol.VersionNumber(1234) Expect(v).ToNot(Equal(cl.version)) cl.config = &Config{Versions: protocol.SupportedVersions} err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{v})) Expect(err).ToNot(HaveOccurred()) - Expect(cl.session.(*mockSession).closed).To(BeTrue()) - Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) }) It("changes to the version preferred by the quic.Config", func() { + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().Close(errCloseSessionForNewVersion) + cl.session = sess config := &Config{Versions: []protocol.VersionNumber{1234, 4321}} cl.config = config err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{4321, 1234})) @@ -521,14 +496,14 @@ var _ = Describe("Client", func() { }) It("ignores packets with an invalid public header", func() { + cl.session = NewMockPacketHandler(mockCtrl) // don't EXPECT any handlePacket calls err := cl.handlePacket(addr, []byte("invalid packet")) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("error parsing packet from")) - Expect(sess.handledPackets).To(BeEmpty()) - Expect(sess.closed).To(BeFalse()) }) It("errors on packets that are smaller than the Payload Length in the packet header", func() { + cl.session = NewMockPacketHandler(mockCtrl) // don't EXPECT any handlePacket calls b := &bytes.Buffer{} hdr := &wire.Header{ IsLongHeader: true, @@ -540,11 +515,14 @@ var _ = Describe("Client", func() { } Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...)) - Expect(sess.handledPackets).To(BeEmpty()) - Expect(sess.closed).To(BeFalse()) }) It("cuts packets at the payload length", func() { + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) { + Expect(packet.data).To(HaveLen(123)) + }) + cl.session = sess b := &bytes.Buffer{} hdr := &wire.Header{ IsLongHeader: true, @@ -557,8 +535,6 @@ var _ = Describe("Client", func() { Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) err := cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...)) Expect(err).ToNot(HaveOccurred()) - Expect(sess.handledPackets).To(HaveLen(1)) - Expect(sess.handledPackets[0].data).To(HaveLen(123)) }) It("ignores packets with the wrong Long Header Type", func() { @@ -577,6 +553,7 @@ var _ = Describe("Client", func() { }) It("ignores packets without connection id, if it didn't request connection id trunctation", func() { + cl.session = NewMockPacketHandler(mockCtrl) // don't EXPECT any handlePacket calls cl.config = &Config{RequestConnectionIDOmission: false} buf := &bytes.Buffer{} err := (&wire.Header{ @@ -589,11 +566,10 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) err = cl.handlePacket(addr, buf.Bytes()) Expect(err).To(MatchError("received packet with truncated connection ID, but didn't request truncation")) - Expect(sess.handledPackets).To(BeEmpty()) - Expect(sess.closed).To(BeFalse()) }) It("ignores packets with the wrong destination connection ID", func() { + cl.session = NewMockPacketHandler(mockCtrl) // don't EXPECT any handlePacket calls buf := &bytes.Buffer{} cl.version = versionIETFFrames cl.config = &Config{RequestConnectionIDOmission: false} @@ -609,13 +585,10 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) err = cl.handlePacket(addr, buf.Bytes()) Expect(err).To(MatchError(fmt.Sprintf("received a packet with an unexpected connection ID (0x0807060504030201, expected %s)", connID))) - Expect(sess.handledPackets).To(BeEmpty()) - Expect(sess.closed).To(BeFalse()) }) - It("creates new GQUIC sessions with the right parameters", func() { + It("creates new gQUIC sessions with the right parameters", func() { config := &Config{Versions: protocol.SupportedVersions} - closeErr := errors.New("peer doesn't reply") c := make(chan struct{}) var cconn connection var hostname string @@ -638,28 +611,27 @@ var _ = Describe("Client", func() { version = versionP conf = configP close(c) + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().run() return sess, nil } - dialed := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) - Expect(err).To(MatchError(closeErr)) - close(dialed) - }() + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + Expect(err).ToNot(HaveOccurred()) Eventually(c).Should(BeClosed()) Expect(cconn.(*conn).pconn).To(Equal(packetConn)) Expect(hostname).To(Equal("quic.clemente.io")) Expect(version).To(Equal(config.Versions[0])) Expect(conf.Versions).To(Equal(config.Versions)) - sess.Close(closeErr) - Eventually(dialed).Should(BeClosed()) }) It("creates a new session when the server performs a retry", func() { config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} cl.config = config - sessionChan := make(chan *mockSession) + sess1 := NewMockPacketHandler(mockCtrl) + sess1.EXPECT().run().Return(handshake.ErrCloseSessionForRetry) + sess2 := NewMockPacketHandler(mockCtrl) + sess2.EXPECT().run() + sessions := []*MockPacketHandler{sess1, sess2} newTLSClientSession = func( connP connection, _ sessionRunner, @@ -673,28 +645,20 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ utils.Logger, ) (packetHandler, error) { - sess := &mockSession{ - stopRunLoop: make(chan struct{}), - } - sessionChan <- sess + sess := sessions[0] + sessions = sessions[1:] return sess, nil } - dialed := make(chan struct{}) - go func() { - defer GinkgoRecover() - Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) - close(dialed) - }() - var firstSession, secondSession *mockSession - Eventually(sessionChan).Should(Receive(&firstSession)) - firstSession.Close(handshake.ErrCloseSessionForRetry) - Eventually(sessionChan).Should(Receive(&secondSession)) - secondSession.Close(errors.New("stop test")) - Eventually(dialed).Should(BeClosed()) + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + Expect(err).ToNot(HaveOccurred()) + Expect(sessions).To(BeEmpty()) }) Context("handling packets", func() { It("handles packets", func() { + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().handlePacket(gomock.Any()) + cl.session = sess ph := wire.Header{ PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen2, @@ -706,52 +670,55 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) packetConn.dataToRead <- b.Bytes() - Expect(sess.handledPackets).To(BeEmpty()) - stoppedListening := make(chan struct{}) + done := make(chan struct{}) go func() { + defer GinkgoRecover() cl.listen() // it should continue listening when receiving valid packets - close(stoppedListening) + close(done) }() - Eventually(func() []*receivedPacket { return sess.handledPackets }).Should(HaveLen(1)) - Expect(sess.closed).To(BeFalse()) - Consistently(stoppedListening).ShouldNot(BeClosed()) + Consistently(done).ShouldNot(BeClosed()) + // make the go routine return + sess.EXPECT().Close(gomock.Any()) + Expect(packetConn.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) }) It("closes the session when encountering an error while reading from the connection", func() { testErr := errors.New("test error") + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().Close(testErr) + cl.session = sess packetConn.readErr = testErr cl.listen() - Expect(sess.closed).To(BeTrue()) - Expect(sess.closeReason).To(MatchError(testErr)) }) }) Context("Public Reset handling", func() { It("closes the session when receiving a Public Reset", func() { + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().closeRemote(gomock.Any()).Do(func(err error) { + Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset)) + }) + cl.session = sess err := cl.handlePacket(addr, wire.WritePublicReset(cl.destConnID, 1, 0)) Expect(err).ToNot(HaveOccurred()) - Expect(cl.session.(*mockSession).closed).To(BeTrue()) - Expect(cl.session.(*mockSession).closedRemote).To(BeTrue()) - Expect(cl.session.(*mockSession).closeReason.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset)) }) It("ignores Public Resets from the wrong remote address", func() { + cl.session = NewMockPacketHandler(mockCtrl) // don't EXPECT any calls spoofedAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 5678} err := cl.handlePacket(spoofedAddr, wire.WritePublicReset(cl.destConnID, 1, 0)) Expect(err).To(MatchError("Received a spoofed Public Reset")) - Expect(cl.session.(*mockSession).closed).To(BeFalse()) - Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) }) It("ignores unparseable Public Resets", func() { + cl.session = NewMockPacketHandler(mockCtrl) // don't EXPECT any calls pr := wire.WritePublicReset(cl.destConnID, 1, 0) err := cl.handlePacket(addr, pr[:len(pr)-5]) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("Received a Public Reset. An error occurred parsing the packet")) - Expect(cl.session.(*mockSession).closed).To(BeFalse()) - Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) }) }) }) diff --git a/mock_packet_handler_test.go b/mock_packet_handler_test.go new file mode 100644 index 00000000..17d1ae7d --- /dev/null +++ b/mock_packet_handler_test.go @@ -0,0 +1,232 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: PacketHandler) + +// Package quic is a generated GoMock package. +package quic + +import ( + context "context" + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + handshake "github.com/lucas-clemente/quic-go/internal/handshake" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockPacketHandler is a mock of PacketHandler interface +type MockPacketHandler struct { + ctrl *gomock.Controller + recorder *MockPacketHandlerMockRecorder +} + +// MockPacketHandlerMockRecorder is the mock recorder for MockPacketHandler +type MockPacketHandlerMockRecorder struct { + mock *MockPacketHandler +} + +// NewMockPacketHandler creates a new mock instance +func NewMockPacketHandler(ctrl *gomock.Controller) *MockPacketHandler { + mock := &MockPacketHandler{ctrl: ctrl} + mock.recorder = &MockPacketHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockPacketHandler) EXPECT() *MockPacketHandlerMockRecorder { + return m.recorder +} + +// AcceptStream mocks base method +func (m *MockPacketHandler) AcceptStream() (Stream, error) { + ret := m.ctrl.Call(m, "AcceptStream") + ret0, _ := ret[0].(Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptStream indicates an expected call of AcceptStream +func (mr *MockPacketHandlerMockRecorder) AcceptStream() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockPacketHandler)(nil).AcceptStream)) +} + +// AcceptUniStream mocks base method +func (m *MockPacketHandler) AcceptUniStream() (ReceiveStream, error) { + ret := m.ctrl.Call(m, "AcceptUniStream") + ret0, _ := ret[0].(ReceiveStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptUniStream indicates an expected call of AcceptUniStream +func (mr *MockPacketHandlerMockRecorder) AcceptUniStream() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockPacketHandler)(nil).AcceptUniStream)) +} + +// Close mocks base method +func (m *MockPacketHandler) Close(arg0 error) error { + ret := m.ctrl.Call(m, "Close", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockPacketHandlerMockRecorder) Close(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandler)(nil).Close), arg0) +} + +// ConnectionState mocks base method +func (m *MockPacketHandler) ConnectionState() handshake.ConnectionState { + ret := m.ctrl.Call(m, "ConnectionState") + ret0, _ := ret[0].(handshake.ConnectionState) + return ret0 +} + +// ConnectionState indicates an expected call of ConnectionState +func (mr *MockPacketHandlerMockRecorder) ConnectionState() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockPacketHandler)(nil).ConnectionState)) +} + +// Context mocks base method +func (m *MockPacketHandler) Context() context.Context { + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context +func (mr *MockPacketHandlerMockRecorder) Context() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockPacketHandler)(nil).Context)) +} + +// GetVersion mocks base method +func (m *MockPacketHandler) GetVersion() protocol.VersionNumber { + ret := m.ctrl.Call(m, "GetVersion") + ret0, _ := ret[0].(protocol.VersionNumber) + return ret0 +} + +// GetVersion indicates an expected call of GetVersion +func (mr *MockPacketHandlerMockRecorder) GetVersion() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockPacketHandler)(nil).GetVersion)) +} + +// LocalAddr mocks base method +func (m *MockPacketHandler) LocalAddr() net.Addr { + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr +func (mr *MockPacketHandlerMockRecorder) LocalAddr() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockPacketHandler)(nil).LocalAddr)) +} + +// OpenStream mocks base method +func (m *MockPacketHandler) OpenStream() (Stream, error) { + ret := m.ctrl.Call(m, "OpenStream") + ret0, _ := ret[0].(Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStream indicates an expected call of OpenStream +func (mr *MockPacketHandlerMockRecorder) OpenStream() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockPacketHandler)(nil).OpenStream)) +} + +// OpenStreamSync mocks base method +func (m *MockPacketHandler) OpenStreamSync() (Stream, error) { + ret := m.ctrl.Call(m, "OpenStreamSync") + ret0, _ := ret[0].(Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStreamSync indicates an expected call of OpenStreamSync +func (mr *MockPacketHandlerMockRecorder) OpenStreamSync() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockPacketHandler)(nil).OpenStreamSync)) +} + +// OpenUniStream mocks base method +func (m *MockPacketHandler) OpenUniStream() (SendStream, error) { + ret := m.ctrl.Call(m, "OpenUniStream") + ret0, _ := ret[0].(SendStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenUniStream indicates an expected call of OpenUniStream +func (mr *MockPacketHandlerMockRecorder) OpenUniStream() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockPacketHandler)(nil).OpenUniStream)) +} + +// OpenUniStreamSync mocks base method +func (m *MockPacketHandler) OpenUniStreamSync() (SendStream, error) { + ret := m.ctrl.Call(m, "OpenUniStreamSync") + ret0, _ := ret[0].(SendStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenUniStreamSync indicates an expected call of OpenUniStreamSync +func (mr *MockPacketHandlerMockRecorder) OpenUniStreamSync() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockPacketHandler)(nil).OpenUniStreamSync)) +} + +// RemoteAddr mocks base method +func (m *MockPacketHandler) RemoteAddr() net.Addr { + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr +func (mr *MockPacketHandlerMockRecorder) RemoteAddr() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockPacketHandler)(nil).RemoteAddr)) +} + +// closeRemote mocks base method +func (m *MockPacketHandler) closeRemote(arg0 error) { + m.ctrl.Call(m, "closeRemote", arg0) +} + +// closeRemote indicates an expected call of closeRemote +func (mr *MockPacketHandlerMockRecorder) closeRemote(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeRemote", reflect.TypeOf((*MockPacketHandler)(nil).closeRemote), arg0) +} + +// getCryptoStream mocks base method +func (m *MockPacketHandler) getCryptoStream() cryptoStreamI { + ret := m.ctrl.Call(m, "getCryptoStream") + ret0, _ := ret[0].(cryptoStreamI) + return ret0 +} + +// getCryptoStream indicates an expected call of getCryptoStream +func (mr *MockPacketHandlerMockRecorder) getCryptoStream() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getCryptoStream", reflect.TypeOf((*MockPacketHandler)(nil).getCryptoStream)) +} + +// handlePacket mocks base method +func (m *MockPacketHandler) handlePacket(arg0 *receivedPacket) { + m.ctrl.Call(m, "handlePacket", arg0) +} + +// handlePacket indicates an expected call of handlePacket +func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockPacketHandler)(nil).handlePacket), arg0) +} + +// run mocks base method +func (m *MockPacketHandler) run() error { + ret := m.ctrl.Call(m, "run") + ret0, _ := ret[0].(error) + return ret0 +} + +// run indicates an expected call of run +func (mr *MockPacketHandlerMockRecorder) run() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "run", reflect.TypeOf((*MockPacketHandler)(nil).run)) +} diff --git a/mockgen.go b/mockgen.go index 833f29f8..f2d83575 100644 --- a/mockgen.go +++ b/mockgen.go @@ -12,5 +12,6 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD QuicAEAD" //go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD" //go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner" +//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler PacketHandler" //go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'" //go:generate sh -c "goimports -w mock*_test.go" diff --git a/server.go b/server.go index 160bc380..105ce9ff 100644 --- a/server.go +++ b/server.go @@ -128,10 +128,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, supportsTLS: supportsTLS, logger: utils.DefaultLogger.WithPrefix("server"), } - s.sessionRunner = &runner{ - onHandshakeCompleteImpl: func(sess packetHandler) { s.sessionQueue <- sess }, - removeConnectionIDImpl: s.removeConnection, - } + s.setup() if supportsTLS { if err := s.setupTLS(); err != nil { return nil, err @@ -142,6 +139,13 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, return s, nil } +func (s *server) setup() { + s.sessionRunner = &runner{ + onHandshakeCompleteImpl: func(sess packetHandler) { s.sessionQueue <- sess }, + removeConnectionIDImpl: s.removeConnection, + } +} + func (s *server) setupTLS() error { cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie, s.logger) if err != nil { diff --git a/server_test.go b/server_test.go index 6e29e46a..7b83ffa8 100644 --- a/server_test.go +++ b/server_test.go @@ -2,13 +2,13 @@ package quic import ( "bytes" - "context" "crypto/tls" "errors" "net" "reflect" "time" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" @@ -21,71 +21,10 @@ import ( ) type mockSession struct { - runner sessionRunner - connectionID protocol.ConnectionID - handledPackets []*receivedPacket - closed bool - closeReason error - closedRemote bool - stopRunLoop chan struct{} // run returns as soon as this channel receives a value -} + *MockPacketHandler -func (s *mockSession) handlePacket(p *receivedPacket) { - s.handledPackets = append(s.handledPackets, p) -} - -func (s *mockSession) run() error { - <-s.stopRunLoop - return s.closeReason -} -func (s *mockSession) Close(e error) error { - if s.closed { - return nil - } - s.closeReason = e - s.closed = true - close(s.stopRunLoop) - return nil -} -func (s *mockSession) closeRemote(e error) { - s.closeReason = e - s.closed = true - s.closedRemote = true - close(s.stopRunLoop) -} -func (s *mockSession) OpenStream() (Stream, error) { - return &stream{}, nil -} -func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") } -func (s *mockSession) AcceptUniStream() (ReceiveStream, error) { panic("not implemented") } -func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") } -func (s *mockSession) OpenUniStream() (SendStream, error) { panic("not implemented") } -func (s *mockSession) OpenUniStreamSync() (SendStream, error) { panic("not implemented") } -func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") } -func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") } -func (*mockSession) Context() context.Context { panic("not implemented") } -func (*mockSession) ConnectionState() ConnectionState { panic("not implemented") } -func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } -func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") } - -var _ Session = &mockSession{} - -func newMockSession( - _ connection, - runner sessionRunner, - _ protocol.VersionNumber, - connectionID protocol.ConnectionID, - _ *handshake.ServerConfig, - _ *tls.Config, - _ *Config, - _ utils.Logger, -) (packetHandler, error) { - s := mockSession{ - runner: runner, - connectionID: connectionID, - stopRunLoop: make(chan struct{}), - } - return &s, nil + connID protocol.ConnectionID + runner sessionRunner } var _ = Describe("Server", func() { @@ -101,30 +40,7 @@ var _ = Describe("Server", func() { config = &Config{Versions: protocol.SupportedVersions} }) - Context("with mock session", func() { - var ( - serv *server - firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID) - connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} - ) - - BeforeEach(func() { - serv = &server{ - sessions: make(map[string]packetHandler), - newSession: newMockSession, - conn: conn, - config: config, - sessionQueue: make(chan Session, 5), - errorChan: make(chan struct{}), - logger: utils.DefaultLogger, - } - b := &bytes.Buffer{} - utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0])) - firstPacket = []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} - firstPacket = append(append(firstPacket, b.Bytes()...), 0x01) - firstPacket = append(firstPacket, bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)...) // add padding - }) - + Context("quic.Config", func() { It("setups with the right values", func() { config := &Config{ HandshakeTimeout: 1337 * time.Minute, @@ -160,6 +76,54 @@ var _ = Describe("Server", func() { Expect(c.MaxIncomingStreams).To(Equal(1234)) Expect(c.MaxIncomingUniStreams).To(BeZero()) }) + }) + + Context("with mock session", func() { + var ( + serv *server + firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID) + connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} + sessions = make([]*MockPacketHandler, 0) + ) + + BeforeEach(func() { + newMockSession := func( + _ connection, + runner sessionRunner, + _ protocol.VersionNumber, + connID protocol.ConnectionID, + _ *handshake.ServerConfig, + _ *tls.Config, + _ *Config, + _ utils.Logger, + ) (packetHandler, error) { + ExpectWithOffset(0, sessions).ToNot(BeEmpty()) + s := &mockSession{MockPacketHandler: sessions[0]} + s.connID = connID + s.runner = runner + sessions = sessions[1:] + return s, nil + } + serv = &server{ + sessions: make(map[string]packetHandler), + newSession: newMockSession, + conn: conn, + config: config, + sessionQueue: make(chan Session, 5), + errorChan: make(chan struct{}), + logger: utils.DefaultLogger, + } + serv.setup() + b := &bytes.Buffer{} + utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0])) + firstPacket = []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} + firstPacket = append(append(firstPacket, b.Bytes()...), 0x01) + firstPacket = append(firstPacket, bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)...) // add padding + }) + + AfterEach(func() { + Expect(sessions).To(BeEmpty()) + }) It("returns the address", func() { conn.addr = &net.UDPAddr{ @@ -170,19 +134,25 @@ var _ = Describe("Server", func() { }) It("creates new sessions", func() { + s := NewMockPacketHandler(mockCtrl) + s.EXPECT().handlePacket(gomock.Any()) + run := make(chan struct{}) + s.EXPECT().run().Do(func() { close(run) }) + sessions = append(sessions, s) err := serv.handlePacket(nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) sess := serv.sessions[string(connID)].(*mockSession) - Expect(sess.connectionID).To(Equal(connID)) - Expect(sess.handledPackets).To(HaveLen(1)) + Expect(sess.connID).To(Equal(connID)) + Eventually(run).Should(BeClosed()) }) It("accepts new TLS sessions", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - sess, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil) - Expect(err).ToNot(HaveOccurred()) - err = serv.setupTLS() + run := make(chan struct{}) + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().run().Do(func() { close(run) }) + err := serv.setupTLS() Expect(err).ToNot(HaveOccurred()) serv.serverTLS.sessionChan <- tlsSession{ connID: connID, @@ -193,57 +163,66 @@ var _ = Describe("Server", func() { defer serv.sessionsMutex.Unlock() return serv.sessions[string(connID)] }).Should(Equal(sess)) + Eventually(run).Should(BeClosed()) }) It("only accepts one new TLS sessions for one connection ID", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - sess1, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil) - Expect(err).ToNot(HaveOccurred()) - sess2, err := newMockSession(nil, nil, protocol.VersionTLS, connID, nil, nil, nil, nil) - Expect(err).ToNot(HaveOccurred()) - err = serv.setupTLS() + run := make(chan struct{}) + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().run().Do(func() { close(run) }) + sess2 := NewMockPacketHandler(mockCtrl) + err := serv.setupTLS() Expect(err).ToNot(HaveOccurred()) serv.serverTLS.sessionChan <- tlsSession{ connID: connID, - sess: sess1, + sess: sess, } Eventually(func() packetHandler { serv.sessionsMutex.Lock() defer serv.sessionsMutex.Unlock() return serv.sessions[string(connID)] - }).Should(Equal(sess1)) + }).Should(Equal(sess)) serv.serverTLS.sessionChan <- tlsSession{ connID: connID, sess: sess2, } - Eventually(func() packetHandler { + Consistently(func() packetHandler { serv.sessionsMutex.Lock() defer serv.sessionsMutex.Unlock() return serv.sessions[string(connID)] - }).Should(Equal(sess1)) + }).Should(Equal(sess)) + Eventually(run).Should(BeClosed()) }) It("accepts a session once the connection it is forward secure", func() { - var acceptedSess Session + s := NewMockPacketHandler(mockCtrl) + s.EXPECT().handlePacket(gomock.Any()) + s.EXPECT().run() + sessions = append(sessions, s) done := make(chan struct{}) go func() { defer GinkgoRecover() - var err error - acceptedSess, err = serv.Accept() + sess, err := serv.Accept() Expect(err).ToNot(HaveOccurred()) + Expect(sess.(*mockSession).connID).To(Equal(connID)) close(done) }() err := serv.handlePacket(nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) + Consistently(done).ShouldNot(BeClosed()) sess := serv.sessions[string(connID)].(*mockSession) - Consistently(func() Session { return acceptedSess }).Should(BeNil()) - serv.sessionQueue <- sess - Eventually(func() Session { return acceptedSess }).Should(Equal(sess)) + sess.runner.onHandshakeComplete(sess) Eventually(done).Should(BeClosed()) }) It("doesn't accept sessions that error during the handshake", func() { + run := make(chan error) + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().handlePacket(gomock.Any()) + sess.EXPECT().run().DoAndReturn(func() error { return <-run }) + sessions = append(sessions, sess) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -253,25 +232,27 @@ var _ = Describe("Server", func() { err := serv.handlePacket(nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - sess := serv.sessions[string(connID)].(*mockSession) - sess.closeReason = errors.New("handshake failed") - close(sess.stopRunLoop) + run <- errors.New("handshake error") + serv.sessions[string(connID)].(*mockSession).runner.removeConnectionID(connID) Consistently(done).ShouldNot(BeClosed()) // make the go routine return - serv.removeConnection(connID) close(serv.errorChan) serv.Close() Eventually(done).Should(BeClosed()) }) It("assigns packets to existing sessions", func() { + run := make(chan struct{}) + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().handlePacket(gomock.Any()).Times(2) + sess.EXPECT().run().Do(func() { close(run) }) + sessions = append(sessions, sess) + err := serv.handlePacket(nil, firstPacket) Expect(err).ToNot(HaveOccurred()) err = serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}) Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[string(connID)].(*mockSession).connectionID).To(Equal(connID)) - Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2)) + Eventually(run).Should(BeClosed()) }) It("deletes sessions", func() { @@ -297,12 +278,21 @@ var _ = Describe("Server", func() { }) It("closes sessions and the connection when Close is called", func() { - go serv.serve() - session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil) - serv.sessions[string(connID)] = session - err := serv.Close() - Expect(err).NotTo(HaveOccurred()) - Expect(session.(*mockSession).closed).To(BeTrue()) + run := make(chan struct{}) + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().Close(nil) + sess.EXPECT().handlePacket(gomock.Any()) + sess.EXPECT().run().Do(func() { close(run) }) + sessions = append(sessions, sess) + go func() { + defer GinkgoRecover() + serv.serve() + }() + err := serv.handlePacket(nil, firstPacket) + Expect(err).ToNot(HaveOccurred()) + Eventually(run).Should(BeClosed()) + // close the server + Expect(serv.Close()).To(Succeed()) Expect(conn.closed).To(BeTrue()) }) @@ -348,20 +338,32 @@ var _ = Describe("Server", func() { }, 0.5) It("closes all sessions when encountering a connection error", func() { - session, _ := newMockSession(nil, nil, 0, connID, nil, nil, nil, nil) - serv.sessions[string(connID)] = session - Expect(serv.sessions[string(connID)].(*mockSession).closed).To(BeFalse()) - testErr := errors.New("connection error") - conn.readErr = testErr - go serv.serve() - Eventually(func() bool { return session.(*mockSession).closed }).Should(BeTrue()) + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().Close(nil) + serv.sessions[string(connID)] = sess + + conn.readErr = errors.New("connection error") + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + serv.serve() + close(done) + }() Expect(serv.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) }) It("ignores delayed packets with mismatching versions", func() { + run := make(chan struct{}) + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().handlePacket(gomock.Any()) // only called once + sess.EXPECT().run().Do(func() { close(run) }) + sessions = append(sessions, sess) + err := serv.handlePacket(nil, firstPacket) Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1)) + Eventually(run).Should(BeClosed()) + b := &bytes.Buffer{} // add an unsupported version data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} @@ -371,8 +373,6 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) // if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn Expect(conn.dataWritten.Bytes()).To(BeEmpty()) - // make sure the packet was *not* passed to session.handlePacket() - Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1)) }) It("errors on invalid public header", func() { @@ -397,9 +397,21 @@ var _ = Describe("Server", func() { }) It("cuts packets at the payload length", func() { + run := make(chan struct{}) + sess := NewMockPacketHandler(mockCtrl) + gomock.InOrder( + sess.EXPECT().handlePacket(gomock.Any()), // first packet + sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) { + Expect(packet.data).To(HaveLen(123)) + }), + ) + sess.EXPECT().run().Do(func() { close(run) }) + sessions = append(sessions, sess) + serv.supportsTLS = true err := serv.handlePacket(nil, firstPacket) Expect(err).ToNot(HaveOccurred()) + Eventually(run).Should(BeClosed()) b := &bytes.Buffer{} hdr := &wire.Header{ IsLongHeader: true, @@ -412,8 +424,6 @@ var _ = Describe("Server", func() { Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) err = serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...)) Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2)) - Expect(serv.sessions[string(connID)].(*mockSession).handledPackets[1].data).To(HaveLen(123)) }) It("drops packets with invalid packet types", func() { @@ -433,14 +443,18 @@ var _ = Describe("Server", func() { }) It("ignores Public Resets", func() { + run := make(chan struct{}) + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().handlePacket(gomock.Any()) // called only once + sess.EXPECT().run().Do(func() { close(run) }) + sessions = append(sessions, sess) err := serv.handlePacket(nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1)) + Eventually(run).Should(BeClosed()) err = serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337)) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1)) }) It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() {