diff --git a/client.go b/client.go index 02deb680..1906abdf 100644 --- a/client.go +++ b/client.go @@ -38,6 +38,8 @@ type client struct { version protocol.VersionNumber session packetHandler + + logger utils.Logger } var ( @@ -102,9 +104,10 @@ func Dial( config: clientConfig, version: clientConfig.Versions[0], versionNegotiationChan: make(chan struct{}), + logger: utils.DefaultLogger, } - utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) + c.logger.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) if err := c.dial(); err != nil { return nil, err @@ -197,7 +200,7 @@ func (c *client) dialTLS() error { MaxUniStreams: uint16(c.config.MaxIncomingUniStreams), } csc := handshake.NewCryptoStreamConn(nil) - extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version) + extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger) mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient) if err != nil { return err @@ -214,7 +217,7 @@ func (c *client) dialTLS() error { if err != handshake.ErrCloseSessionForRetry { return err } - utils.Infof("Received a Retry packet. Recreating session.") + c.logger.Infof("Received a Retry packet. Recreating session.") if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { return err } @@ -237,7 +240,7 @@ func (c *client) establishSecureConnection() error { go func() { runErr = c.session.run() // returns as soon as the session is closed close(errorChan) - utils.Infof("Connection %x closed.", c.connectionID) + c.logger.Infof("Connection %x closed.", c.connectionID) if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion { c.conn.Close() } @@ -291,7 +294,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { r := bytes.NewReader(packet) hdr, err := wire.ParseHeaderSentByServer(r, c.version) if err != nil { - utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) + c.logger.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) // drop this packet if we can't parse the header return } @@ -314,15 +317,15 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { // check if the remote address and the connection ID match // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID { - utils.Infof("Received a spoofed Public Reset. Ignoring.") + c.logger.Infof("Received a spoofed Public Reset. Ignoring.") return } pr, err := wire.ParsePublicReset(r) if err != nil { - utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err) + c.logger.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err) return } - utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber) + c.logger.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber) c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber))) return } @@ -368,7 +371,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { } } - utils.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions) + c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions) newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) if !ok { @@ -385,7 +388,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { if err != nil { return err } - utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) + c.logger.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) c.session.Close(errCloseSessionForNewVersion) return nil } @@ -402,6 +405,7 @@ func (c *client) createNewGQUICSession() (err error) { c.config, c.initialVersion, c.negotiatedVersions, + c.logger, ) return err } @@ -421,6 +425,7 @@ func (c *client) createNewTLSSession( c.tls, paramsChan, 1, + c.logger, ) return err } diff --git a/client_test.go b/client_test.go index 787bb0e3..06665305 100644 --- a/client_test.go +++ b/client_test.go @@ -11,6 +11,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" @@ -25,7 +26,7 @@ var _ = Describe("Client", func() { packetConn *mockPacketConn addr net.Addr - originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) (packetHandler, error) + originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, logger utils.Logger) (packetHandler, error) ) // generate a packet sent by the server that accepts the QUIC version suggested by the client @@ -43,7 +44,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { originalClientSessConstructor = newClientSession Eventually(areSessionsRunning).Should(BeFalse()) - msess, _ := newMockSession(nil, 0, 0, nil, nil, nil) + msess, _ := newMockSession(nil, 0, 0, nil, nil, nil, nil) sess = msess.(*mockSession) addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} packetConn = newMockPacketConn() @@ -55,6 +56,7 @@ var _ = Describe("Client", func() { version: protocol.SupportedVersions[0], conn: &conn{pconn: packetConn, currentAddr: addr}, versionNegotiationChan: make(chan struct{}), + logger: utils.DefaultLogger, } }) @@ -82,6 +84,7 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed()) return sess, nil @@ -125,6 +128,7 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { remoteAddrChan <- conn.RemoteAddr().String() return sess, nil @@ -153,6 +157,7 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { hostnameChan <- h return sess, nil @@ -264,6 +269,7 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { return nil, testErr } @@ -314,6 +320,7 @@ var _ = Describe("Client", func() { _ *Config, initialVersionP protocol.VersionNumber, negotiatedVersionsP []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { initialVersion = initialVersionP negotiatedVersions = negotiatedVersionsP @@ -370,6 +377,7 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { atomic.AddUint32(&sessionCounter, 1) return &mockSession{ @@ -474,6 +482,7 @@ var _ = Describe("Client", func() { configP *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { cconn = connP hostname = hostnameP @@ -514,6 +523,7 @@ var _ = Describe("Client", func() { tls handshake.MintTLS, paramsChan <-chan handshake.TransportParameters, _ protocol.PacketNumber, + _ utils.Logger, ) (packetHandler, error) { cconn = connP hostname = hostnameP @@ -550,6 +560,7 @@ var _ = Describe("Client", func() { tls handshake.MintTLS, paramsChan <-chan handshake.TransportParameters, _ protocol.PacketNumber, + _ utils.Logger, ) (packetHandler, error) { sess := &mockSession{ stopRunLoop: make(chan struct{}), diff --git a/example/client/main.go b/example/client/main.go index 2a28c161..23f045c8 100644 --- a/example/client/main.go +++ b/example/client/main.go @@ -19,12 +19,14 @@ func main() { flag.Parse() urls := flag.Args() + logger := utils.DefaultLogger + if *verbose { - utils.SetLogLevel(utils.LogLevelDebug) + logger.SetLogLevel(utils.LogLevelDebug) } else { - utils.SetLogLevel(utils.LogLevelInfo) + logger.SetLogLevel(utils.LogLevelInfo) } - utils.SetLogTimeFormat("") + logger.SetLogTimeFormat("") versions := protocol.SupportedVersions if *tls { @@ -42,21 +44,21 @@ func main() { var wg sync.WaitGroup wg.Add(len(urls)) for _, addr := range urls { - utils.Infof("GET %s", addr) + logger.Infof("GET %s", addr) go func(addr string) { rsp, err := hclient.Get(addr) if err != nil { panic(err) } - utils.Infof("Got response for %s: %#v", addr, rsp) + logger.Infof("Got response for %s: %#v", addr, rsp) body := &bytes.Buffer{} _, err = io.Copy(body, rsp.Body) if err != nil { panic(err) } - utils.Infof("Request Body:") - utils.Infof("%s", body.Bytes()) + logger.Infof("Request Body:") + logger.Infof("%s", body.Bytes()) wg.Done() }(addr) } diff --git a/example/main.go b/example/main.go index 35aaa85c..e83fb870 100644 --- a/example/main.go +++ b/example/main.go @@ -91,7 +91,7 @@ func init() { } } if err != nil { - utils.Infof("Error receiving upload: %#v", err) + utils.DefaultLogger.Infof("Error receiving upload: %#v", err) } } io.WriteString(w, `