diff --git a/client_multiplexer.go b/client_multiplexer.go index 5b541e7e..e2b1f717 100644 --- a/client_multiplexer.go +++ b/client_multiplexer.go @@ -15,9 +15,14 @@ import ( var ( clientMuxerOnce sync.Once - clientMuxer *clientMultiplexer + clientMuxer multiplexer ) +type multiplexer interface { + AddConn(net.PacketConn) packetHandlerManager + AddHandler(net.PacketConn, protocol.ConnectionID, packetHandler) error +} + // The clientMultiplexer listens on multiple net.PacketConns and dispatches // incoming packets to the session handler. type clientMultiplexer struct { @@ -29,7 +34,9 @@ type clientMultiplexer struct { logger utils.Logger } -func getClientMultiplexer() *clientMultiplexer { +var _ multiplexer = &clientMultiplexer{} + +func getClientMultiplexer() multiplexer { clientMuxerOnce.Do(func() { clientMuxer = &clientMultiplexer{ conns: make(map[net.PacketConn]packetHandlerManager), diff --git a/client_multiplexer_test.go b/client_multiplexer_test.go index 77ada895..87109572 100644 --- a/client_multiplexer_test.go +++ b/client_multiplexer_test.go @@ -96,11 +96,11 @@ var _ = Describe("Client Multiplexer", func() { It("ignores packets arriving late for closed sessions", func() { manager := NewMockPacketHandlerManager(mockCtrl) - origNewPacketHandlerManager := getClientMultiplexer().newPacketHandlerManager + origNewPacketHandlerManager := getClientMultiplexer().(*clientMultiplexer).newPacketHandlerManager defer func() { - getClientMultiplexer().newPacketHandlerManager = origNewPacketHandlerManager + getClientMultiplexer().(*clientMultiplexer).newPacketHandlerManager = origNewPacketHandlerManager }() - getClientMultiplexer().newPacketHandlerManager = func() packetHandlerManager { return manager } + getClientMultiplexer().(*clientMultiplexer).newPacketHandlerManager = func() packetHandlerManager { return manager } conn := newMockPacketConn() connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} diff --git a/client_test.go b/client_test.go index e910521c..39284af2 100644 --- a/client_test.go +++ b/client_test.go @@ -23,10 +23,12 @@ import ( var _ = Describe("Client", func() { var ( - cl *client - packetConn *mockPacketConn - addr net.Addr - connID protocol.ConnectionID + cl *client + packetConn *mockPacketConn + addr net.Addr + connID protocol.ConnectionID + mockMultiplexer *MockMultiplexer + origMultiplexer multiplexer originalClientSessConstructor func(connection, sessionRunner, string, protocol.VersionNumber, protocol.ConnectionID, *tls.Config, *Config, protocol.VersionNumber, []protocol.VersionNumber, utils.Logger) (quicSession, error) ) @@ -59,9 +61,15 @@ var _ = Describe("Client", func() { conn: &conn{pconn: packetConn, currentAddr: addr}, logger: utils.DefaultLogger, } + getClientMultiplexer() // make the sync.Once execute + // replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer + mockMultiplexer = NewMockMultiplexer(mockCtrl) + origMultiplexer = clientMuxer + clientMuxer = mockMultiplexer }) AfterEach(func() { + clientMuxer = origMultiplexer newClientSession = originalClientSessConstructor }) @@ -137,14 +145,11 @@ var _ = Describe("Client", func() { Eventually(hostnameChan).Should(Receive(Equal("foobar"))) }) - It("errors when receiving an error from the connection", func() { - testErr := errors.New("connection error") - packetConn.readErr = testErr - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) - Expect(err).To(MatchError(testErr)) - }) - It("returns after the handshake is complete", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) + run := make(chan struct{}) newClientSession = func( _ connection, @@ -160,11 +165,9 @@ var _ = Describe("Client", func() { ) (quicSession, error) { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run().Do(func() { close(run) }) - sess.EXPECT().handlePacket(gomock.Any()) runner.onHandshakeComplete(sess) return sess, nil } - packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID) s, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) @@ -172,8 +175,11 @@ var _ = Describe("Client", func() { }) It("returns an error that occurs while waiting for the connection to become secure", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) + testErr := errors.New("early handshake error") - handledPacket := make(chan struct{}) newClientSession = func( conn connection, _ sessionRunner, @@ -187,17 +193,19 @@ var _ = Describe("Client", func() { _ utils.Logger, ) (quicSession, error) { sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()).Do(func(_ *receivedPacket) { close(handledPacket) }) sess.EXPECT().run().Return(testErr) return sess, nil } packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID) _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, nil) Expect(err).To(MatchError(testErr)) - Eventually(handledPacket).Should(BeClosed()) }) - It("closes the session when the context is canceledd", func() { + It("closes the session when the context is canceled", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) + sessionRunning := make(chan struct{}) defer close(sessionRunning) sess := NewMockQuicSession(mockCtrl) @@ -232,6 +240,37 @@ var _ = Describe("Client", func() { Eventually(dialed).Should(BeClosed()) }) + It("removes closed sessions from the multiplexer", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Remove(connID) + mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) + + var runner sessionRunner + sess := NewMockQuicSession(mockCtrl) + newClientSession = func( + conn connection, + runnerP sessionRunner, + _ string, + _ protocol.VersionNumber, + connID protocol.ConnectionID, + _ *tls.Config, + _ *Config, + _ protocol.VersionNumber, + _ []protocol.VersionNumber, + _ utils.Logger, + ) (quicSession, error) { + runner = runnerP + return sess, nil + } + sess.EXPECT().run().Do(func() { + runner.removeConnectionID(connID) + }) + + _, err := DialContext(context.Background(), packetConn, addr, "quic.clemnte.io:1337", nil, nil) + Expect(err).ToNot(HaveOccurred()) + }) + Context("quic.Config", func() { It("setups with the right values", func() { config := &Config{ @@ -250,6 +289,9 @@ var _ = Describe("Client", func() { }) It("errors when the Config contains an invalid version", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + version := protocol.VersionNumber(0x1234) _, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}}) Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) @@ -286,6 +328,10 @@ var _ = Describe("Client", func() { Context("gQUIC", func() { It("errors if it can't create a session", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) + testErr := errors.New("error creating session") newClientSession = func( _ connection, @@ -308,6 +354,10 @@ var _ = Describe("Client", func() { Context("IETF QUIC", func() { It("creates new TLS sessions with the right parameters", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) + config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} c := make(chan struct{}) var cconn connection @@ -360,6 +410,10 @@ var _ = Describe("Client", func() { }) It("returns an error that occurs during version negotiation", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) + testErr := errors.New("early handshake error") newClientSession = func( conn connection, @@ -627,6 +681,10 @@ var _ = Describe("Client", func() { }) It("creates new gQUIC sessions with the right parameters", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) + config := &Config{Versions: protocol.SupportedVersions} c := make(chan struct{}) var cconn connection @@ -664,6 +722,10 @@ var _ = Describe("Client", func() { }) It("creates a new session when the server performs a retry", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) + config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} cl.config = config sess1 := NewMockQuicSession(mockCtrl) @@ -694,6 +756,10 @@ var _ = Describe("Client", func() { }) It("only accepts one Retry packet", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + mockMultiplexer.EXPECT().AddConn(packetConn).Return(manager) + mockMultiplexer.EXPECT().AddHandler(packetConn, gomock.Any(), gomock.Any()) + config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} sess1 := NewMockQuicSession(mockCtrl) sess1.EXPECT().run().Return(handshake.ErrCloseSessionForRetry) diff --git a/mock_multiplexer_test.go b/mock_multiplexer_test.go new file mode 100644 index 00000000..d1e34a78 --- /dev/null +++ b/mock_multiplexer_test.go @@ -0,0 +1,60 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: Multiplexer) + +// Package quic is a generated GoMock package. +package quic + +import ( + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockMultiplexer is a mock of Multiplexer interface +type MockMultiplexer struct { + ctrl *gomock.Controller + recorder *MockMultiplexerMockRecorder +} + +// MockMultiplexerMockRecorder is the mock recorder for MockMultiplexer +type MockMultiplexerMockRecorder struct { + mock *MockMultiplexer +} + +// NewMockMultiplexer creates a new mock instance +func NewMockMultiplexer(ctrl *gomock.Controller) *MockMultiplexer { + mock := &MockMultiplexer{ctrl: ctrl} + mock.recorder = &MockMultiplexerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder { + return m.recorder +} + +// AddConn mocks base method +func (m *MockMultiplexer) AddConn(arg0 net.PacketConn) packetHandlerManager { + ret := m.ctrl.Call(m, "AddConn", arg0) + ret0, _ := ret[0].(packetHandlerManager) + return ret0 +} + +// AddConn indicates an expected call of AddConn +func (mr *MockMultiplexerMockRecorder) AddConn(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0) +} + +// AddHandler mocks base method +func (m *MockMultiplexer) AddHandler(arg0 net.PacketConn, arg1 protocol.ConnectionID, arg2 packetHandler) error { + ret := m.ctrl.Call(m, "AddHandler", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddHandler indicates an expected call of AddHandler +func (mr *MockMultiplexerMockRecorder) AddHandler(arg0, arg1, arg2 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddHandler", reflect.TypeOf((*MockMultiplexer)(nil).AddHandler), arg0, arg1, arg2) +} diff --git a/mockgen.go b/mockgen.go index feaaf595..13da145a 100644 --- a/mockgen.go +++ b/mockgen.go @@ -14,5 +14,6 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner" //go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession QuicSession" //go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager PacketHandlerManager" +//go:generate sh -c "./mockgen_private.sh quic mock_multiplexer_test.go github.com/lucas-clemente/quic-go multiplexer Multiplexer" //go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'" //go:generate sh -c "goimports -w mock*_test.go"