diff --git a/client_test.go b/client_test.go index b22aa29f..2d1b090a 100644 --- a/client_test.go +++ b/client_test.go @@ -9,15 +9,15 @@ import ( "os" "time" - "github.com/lucas-clemente/quic-go/logging" - - "github.com/golang/mock/gomock" - "github.com/lucas-clemente/quic-go/internal/mocks" + mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/logging" "github.com/lucas-clemente/quic-go/quictrace" + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -31,7 +31,7 @@ var _ = Describe("Client", func() { mockMultiplexer *MockMultiplexer origMultiplexer multiplexer tlsConf *tls.Config - tracer *mocks.MockConnectionTracer + tracer *mocklogging.MockConnectionTracer config *Config originalClientSessConstructor func( @@ -66,8 +66,8 @@ var _ = Describe("Client", func() { tlsConf = &tls.Config{NextProtos: []string{"proto1"}} connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} originalClientSessConstructor = newClientSession - tracer = mocks.NewMockConnectionTracer(mockCtrl) - tr := mocks.NewMockTracer(mockCtrl) + tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + tr := mocklogging.NewMockTracer(mockCtrl) tr.EXPECT().TracerForConnection(protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) config = &Config{Tracer: tr} Eventually(areSessionsRunning).Should(BeFalse()) diff --git a/config_test.go b/config_test.go index ad9b23c7..d7199646 100644 --- a/config_test.go +++ b/config_test.go @@ -6,7 +6,7 @@ import ( "reflect" "time" - "github.com/lucas-clemente/quic-go/internal/mocks" + mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/quictrace" @@ -73,7 +73,7 @@ var _ = Describe("Config", func() { case "QuicTracer": f.Set(reflect.ValueOf(quictrace.NewTracer())) case "Tracer": - f.Set(reflect.ValueOf(mocks.NewMockTracer(mockCtrl))) + f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl))) default: Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn)) } diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index bf45394a..7e777eb9 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -242,10 +242,11 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { if a.shouldInitiateKeyUpdate() { + a.rollKeys(time.Now()) + a.logger.Debugf("Initiating key update to key phase %s", a.keyPhase) if a.tracer != nil { a.tracer.UpdatedKey(a.keyPhase, false) } - a.rollKeys(time.Now()) } return a.keyPhase.Bit() } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 7ec8cb61..3ed89caf 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -7,6 +7,7 @@ import ( "os" "time" + mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qtls" "github.com/lucas-clemente/quic-go/internal/utils" @@ -37,24 +38,31 @@ var _ = Describe("Updatable AEAD", func() { cs := cipherSuites[i] Context(fmt.Sprintf("using %s", qtls.CipherSuiteName(cs.ID)), func() { - getPeers := func(rttStats *utils.RTTStats) (client, server *updatableAEAD) { + var ( + client, server *updatableAEAD + clientTracer, serverTracer *mocklogging.MockConnectionTracer + rttStats *utils.RTTStats + ) + + BeforeEach(func() { + clientTracer = mocklogging.NewMockConnectionTracer(mockCtrl) + serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) trafficSecret1 := make([]byte, 16) trafficSecret2 := make([]byte, 16) rand.Read(trafficSecret1) rand.Read(trafficSecret2) - client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger) - server = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger) + rttStats = utils.NewRTTStats() + client = newUpdatableAEAD(rttStats, clientTracer, utils.DefaultLogger) + server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger) client.SetReadKey(cs, trafficSecret2) client.SetWriteKey(cs, trafficSecret1) server.SetReadKey(cs, trafficSecret1) server.SetWriteKey(cs, trafficSecret2) - return - } + }) Context("header protection", func() { It("encrypts and decrypts the header", func() { - server, client := getPeers(&utils.RTTStats{}) var lastFiveBitsDifferent int for i := 0; i < 100; i++ { sample := make([]byte, 16) @@ -76,12 +84,8 @@ var _ = Describe("Updatable AEAD", func() { Context("message encryption", func() { var msg, ad []byte - var server, client *updatableAEAD - var rttStats *utils.RTTStats BeforeEach(func() { - rttStats = &utils.RTTStats{} - server, client = getPeers(rttStats) msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") ad = []byte("Donec in velit neque.") }) @@ -144,6 +148,7 @@ var _ = Describe("Updatable AEAD", func() { // now received a message at key phase one client.rollKeys(now) encrypted1 := client.Seal(nil, msg, 0x43, ad) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) Expect(err).ToNot(HaveOccurred()) Expect(decrypted).To(Equal(msg)) @@ -163,6 +168,7 @@ var _ = Describe("Updatable AEAD", func() { client.rollKeys(now) encrypted1 := client.Seal(nil, msg, 0x44, ad) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) Expect(err).ToNot(HaveOccurred()) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) @@ -188,6 +194,7 @@ var _ = Describe("Updatable AEAD", func() { client.rollKeys(now) encrypted1 := client.Seal(nil, msg, 0x44, ad) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) Expect(err).ToNot(HaveOccurred()) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) @@ -233,6 +240,7 @@ var _ = Describe("Updatable AEAD", func() { // no update allowed before receiving an acknowledgement for the current key phase Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.SetLargestAcked(0) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) }) @@ -248,6 +256,7 @@ var _ = Describe("Updatable AEAD", func() { Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) server.Seal(nil, msg, 1, ad) server.SetLargestAcked(1) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) }) }) diff --git a/internal/mocks/connection_tracer.go b/internal/mocks/logging/connection_tracer.go similarity index 99% rename from internal/mocks/connection_tracer.go rename to internal/mocks/logging/connection_tracer.go index aac749ba..a91694e9 100644 --- a/internal/mocks/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/lucas-clemente/quic-go/logging (interfaces: ConnectionTracer) -// Package mocks is a generated GoMock package. -package mocks +// Package mocklogging is a generated GoMock package. +package mocklogging import ( net "net" diff --git a/internal/mocks/tracer.go b/internal/mocks/logging/tracer.go similarity index 97% rename from internal/mocks/tracer.go rename to internal/mocks/logging/tracer.go index 4ed0d14d..2a643e35 100644 --- a/internal/mocks/tracer.go +++ b/internal/mocks/logging/tracer.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/lucas-clemente/quic-go/logging (interfaces: Tracer) -// Package mocks is a generated GoMock package. -package mocks +// Package mocklogging is a generated GoMock package. +package mocklogging import ( net "net" diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index c2128c2d..fb38a3cc 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -3,8 +3,8 @@ package mocks //go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go" //go:generate sh -c "mockgen -package mockquic -destination quic/early_session_tmp.go github.com/lucas-clemente/quic-go EarlySession && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_session_tmp.go > quic/early_session.go && rm quic/early_session_tmp.go && goimports -w quic/early_session.go" //go:generate sh -c "mockgen -package mockquic -destination quic/early_listener.go github.com/lucas-clemente/quic-go EarlyListener && goimports -w quic/early_listener.go" -//go:generate sh -c "mockgen -package mocks -destination tracer.go github.com/lucas-clemente/quic-go/logging Tracer && goimports -w tracer.go" -//go:generate sh -c "mockgen -package mocks -destination connection_tracer.go github.com/lucas-clemente/quic-go/logging ConnectionTracer && goimports -w connection_tracer.go" +//go:generate sh -c "mockgen -package mocklogging -destination logging/tracer.go github.com/lucas-clemente/quic-go/logging Tracer && goimports -w logging/tracer.go" +//go:generate sh -c "mockgen -package mocklogging -destination logging/connection_tracer.go github.com/lucas-clemente/quic-go/logging ConnectionTracer && goimports -w logging/connection_tracer.go" //go:generate sh -c "mockgen -package mocks -destination short_header_sealer.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderSealer && goimports -w short_header_sealer.go" //go:generate sh -c "mockgen -package mocks -destination short_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderOpener && goimports -w short_header_opener.go" //go:generate sh -c "mockgen -package mocks -destination long_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake LongHeaderOpener && goimports -w long_header_opener.go" diff --git a/multiplexer_test.go b/multiplexer_test.go index 4141c46c..5faa701e 100644 --- a/multiplexer_test.go +++ b/multiplexer_test.go @@ -3,7 +3,7 @@ package quic import ( "net" - "github.com/lucas-clemente/quic-go/internal/mocks" + mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -25,7 +25,7 @@ var _ = Describe("Client Multiplexer", func() { pconn := newMockPacketConn() pconn.addr = &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321} conn := testConn{PacketConn: pconn} - tracer := mocks.NewMockTracer(mockCtrl) + tracer := mocklogging.NewMockTracer(mockCtrl) _, err := getMultiplexer().AddConn(conn, 8, []byte("foobar"), tracer) Expect(err).ToNot(HaveOccurred()) conn.counter++ @@ -52,9 +52,9 @@ var _ = Describe("Client Multiplexer", func() { It("errors when adding an existing conn with different tracers", func() { conn := newMockPacketConn() - _, err := getMultiplexer().AddConn(conn, 7, nil, mocks.NewMockTracer(mockCtrl)) + _, err := getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl)) Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 7, nil, mocks.NewMockTracer(mockCtrl)) + _, err = getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl)) Expect(err).To(MatchError("cannot use different tracers on the same packet conn")) }) }) diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index b2e99355..5e7e1a2b 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -7,7 +7,7 @@ import ( "net" "time" - "github.com/lucas-clemente/quic-go/internal/mocks" + mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" @@ -23,7 +23,7 @@ var _ = Describe("Packet Handler Map", func() { var ( handler *packetHandlerMap conn *mockPacketConn - tracer *mocks.MockTracer + tracer *mocklogging.MockTracer connIDLen int statelessResetKey []byte @@ -51,7 +51,7 @@ var _ = Describe("Packet Handler Map", func() { BeforeEach(func() { statelessResetKey = nil connIDLen = 0 - tracer = mocks.NewMockTracer(mockCtrl) + tracer = mocklogging.NewMockTracer(mockCtrl) }) JustBeforeEach(func() { diff --git a/server_test.go b/server_test.go index b866c393..a2a9278a 100644 --- a/server_test.go +++ b/server_test.go @@ -14,9 +14,8 @@ import ( "sync/atomic" "time" - "github.com/lucas-clemente/quic-go/internal/mocks" - "github.com/lucas-clemente/quic-go/internal/handshake" + mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/testdata" @@ -186,11 +185,11 @@ var _ = Describe("Server", func() { var ( serv *baseServer phm *MockPacketHandlerManager - tracer *mocks.MockTracer + tracer *mocklogging.MockTracer ) BeforeEach(func() { - tracer = mocks.NewMockTracer(mockCtrl) + tracer = mocklogging.NewMockTracer(mockCtrl) ln, err := Listen(conn, tlsConf, &Config{Tracer: tracer}) Expect(err).ToNot(HaveOccurred()) serv = ln.(*baseServer) diff --git a/session_test.go b/session_test.go index e7901cce..07cb3e91 100644 --- a/session_test.go +++ b/session_test.go @@ -17,6 +17,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/mocks" mockackhandler "github.com/lucas-clemente/quic-go/internal/mocks/ackhandler" + mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/testutils" @@ -50,7 +51,7 @@ var _ = Describe("Session", func() { streamManager *MockStreamManager packer *MockPacker cryptoSetup *mocks.MockCryptoSetup - tracer *mocks.MockConnectionTracer + tracer *mocklogging.MockConnectionTracer ) remoteAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7331} @@ -88,7 +89,7 @@ var _ = Describe("Session", func() { mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) Expect(err).ToNot(HaveOccurred()) - tracer = mocks.NewMockConnectionTracer(mockCtrl) + tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() tracer.EXPECT().UpdatedCongestionState(gomock.Any()) @@ -2109,7 +2110,7 @@ var _ = Describe("Client Session", func() { packer *MockPacker mconn *MockSendConn cryptoSetup *mocks.MockCryptoSetup - tracer *mocks.MockConnectionTracer + tracer *mocklogging.MockConnectionTracer tlsConf *tls.Config quicConf *Config ) @@ -2148,7 +2149,7 @@ var _ = Describe("Client Session", func() { tlsConf = &tls.Config{} } sessionRunner = NewMockSessionRunner(mockCtrl) - tracer = mocks.NewMockConnectionTracer(mockCtrl) + tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() tracer.EXPECT().UpdatedCongestionState(gomock.Any())