diff --git a/client.go b/client.go index 02deb680..1906abdf 100644 --- a/client.go +++ b/client.go @@ -38,6 +38,8 @@ type client struct { version protocol.VersionNumber session packetHandler + + logger utils.Logger } var ( @@ -102,9 +104,10 @@ func Dial( config: clientConfig, version: clientConfig.Versions[0], versionNegotiationChan: make(chan struct{}), + logger: utils.DefaultLogger, } - utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) + c.logger.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) if err := c.dial(); err != nil { return nil, err @@ -197,7 +200,7 @@ func (c *client) dialTLS() error { MaxUniStreams: uint16(c.config.MaxIncomingUniStreams), } csc := handshake.NewCryptoStreamConn(nil) - extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version) + extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger) mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient) if err != nil { return err @@ -214,7 +217,7 @@ func (c *client) dialTLS() error { if err != handshake.ErrCloseSessionForRetry { return err } - utils.Infof("Received a Retry packet. Recreating session.") + c.logger.Infof("Received a Retry packet. Recreating session.") if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { return err } @@ -237,7 +240,7 @@ func (c *client) establishSecureConnection() error { go func() { runErr = c.session.run() // returns as soon as the session is closed close(errorChan) - utils.Infof("Connection %x closed.", c.connectionID) + c.logger.Infof("Connection %x closed.", c.connectionID) if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion { c.conn.Close() } @@ -291,7 +294,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { r := bytes.NewReader(packet) hdr, err := wire.ParseHeaderSentByServer(r, c.version) if err != nil { - utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) + c.logger.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) // drop this packet if we can't parse the header return } @@ -314,15 +317,15 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { // check if the remote address and the connection ID match // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID { - utils.Infof("Received a spoofed Public Reset. Ignoring.") + c.logger.Infof("Received a spoofed Public Reset. Ignoring.") return } pr, err := wire.ParsePublicReset(r) if err != nil { - utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err) + c.logger.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err) return } - utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber) + c.logger.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber) c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber))) return } @@ -368,7 +371,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { } } - utils.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions) + c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions) newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) if !ok { @@ -385,7 +388,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { if err != nil { return err } - utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) + c.logger.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) c.session.Close(errCloseSessionForNewVersion) return nil } @@ -402,6 +405,7 @@ func (c *client) createNewGQUICSession() (err error) { c.config, c.initialVersion, c.negotiatedVersions, + c.logger, ) return err } @@ -421,6 +425,7 @@ func (c *client) createNewTLSSession( c.tls, paramsChan, 1, + c.logger, ) return err } diff --git a/client_test.go b/client_test.go index 787bb0e3..06665305 100644 --- a/client_test.go +++ b/client_test.go @@ -11,6 +11,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" @@ -25,7 +26,7 @@ var _ = Describe("Client", func() { packetConn *mockPacketConn addr net.Addr - originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) (packetHandler, error) + originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, logger utils.Logger) (packetHandler, error) ) // generate a packet sent by the server that accepts the QUIC version suggested by the client @@ -43,7 +44,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { originalClientSessConstructor = newClientSession Eventually(areSessionsRunning).Should(BeFalse()) - msess, _ := newMockSession(nil, 0, 0, nil, nil, nil) + msess, _ := newMockSession(nil, 0, 0, nil, nil, nil, nil) sess = msess.(*mockSession) addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} packetConn = newMockPacketConn() @@ -55,6 +56,7 @@ var _ = Describe("Client", func() { version: protocol.SupportedVersions[0], conn: &conn{pconn: packetConn, currentAddr: addr}, versionNegotiationChan: make(chan struct{}), + logger: utils.DefaultLogger, } }) @@ -82,6 +84,7 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed()) return sess, nil @@ -125,6 +128,7 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { remoteAddrChan <- conn.RemoteAddr().String() return sess, nil @@ -153,6 +157,7 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { hostnameChan <- h return sess, nil @@ -264,6 +269,7 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { return nil, testErr } @@ -314,6 +320,7 @@ var _ = Describe("Client", func() { _ *Config, initialVersionP protocol.VersionNumber, negotiatedVersionsP []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { initialVersion = initialVersionP negotiatedVersions = negotiatedVersionsP @@ -370,6 +377,7 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { atomic.AddUint32(&sessionCounter, 1) return &mockSession{ @@ -474,6 +482,7 @@ var _ = Describe("Client", func() { configP *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, + _ utils.Logger, ) (packetHandler, error) { cconn = connP hostname = hostnameP @@ -514,6 +523,7 @@ var _ = Describe("Client", func() { tls handshake.MintTLS, paramsChan <-chan handshake.TransportParameters, _ protocol.PacketNumber, + _ utils.Logger, ) (packetHandler, error) { cconn = connP hostname = hostnameP @@ -550,6 +560,7 @@ var _ = Describe("Client", func() { tls handshake.MintTLS, paramsChan <-chan handshake.TransportParameters, _ protocol.PacketNumber, + _ utils.Logger, ) (packetHandler, error) { sess := &mockSession{ stopRunLoop: make(chan struct{}), diff --git a/example/client/main.go b/example/client/main.go index 2a28c161..23f045c8 100644 --- a/example/client/main.go +++ b/example/client/main.go @@ -19,12 +19,14 @@ func main() { flag.Parse() urls := flag.Args() + logger := utils.DefaultLogger + if *verbose { - utils.SetLogLevel(utils.LogLevelDebug) + logger.SetLogLevel(utils.LogLevelDebug) } else { - utils.SetLogLevel(utils.LogLevelInfo) + logger.SetLogLevel(utils.LogLevelInfo) } - utils.SetLogTimeFormat("") + logger.SetLogTimeFormat("") versions := protocol.SupportedVersions if *tls { @@ -42,21 +44,21 @@ func main() { var wg sync.WaitGroup wg.Add(len(urls)) for _, addr := range urls { - utils.Infof("GET %s", addr) + logger.Infof("GET %s", addr) go func(addr string) { rsp, err := hclient.Get(addr) if err != nil { panic(err) } - utils.Infof("Got response for %s: %#v", addr, rsp) + logger.Infof("Got response for %s: %#v", addr, rsp) body := &bytes.Buffer{} _, err = io.Copy(body, rsp.Body) if err != nil { panic(err) } - utils.Infof("Request Body:") - utils.Infof("%s", body.Bytes()) + logger.Infof("Request Body:") + logger.Infof("%s", body.Bytes()) wg.Done() }(addr) } diff --git a/example/main.go b/example/main.go index 35aaa85c..e83fb870 100644 --- a/example/main.go +++ b/example/main.go @@ -91,7 +91,7 @@ func init() { } } if err != nil { - utils.Infof("Error receiving upload: %#v", err) + utils.DefaultLogger.Infof("Error receiving upload: %#v", err) } } io.WriteString(w, `
@@ -126,12 +126,14 @@ func main() { tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)") flag.Parse() + logger := utils.DefaultLogger + if *verbose { - utils.SetLogLevel(utils.LogLevelDebug) + logger.SetLogLevel(utils.LogLevelDebug) } else { - utils.SetLogLevel(utils.LogLevelInfo) + logger.SetLogLevel(utils.LogLevelInfo) } - utils.SetLogTimeFormat("") + logger.SetLogTimeFormat("") versions := protocol.SupportedVersions if *tls { diff --git a/h2quic/client.go b/h2quic/client.go index 53e667cd..40980882 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -46,6 +46,8 @@ type client struct { requestWriter *requestWriter responses map[protocol.StreamID]chan *http.Response + + logger utils.Logger } var _ http.RoundTripper = &client{} @@ -75,6 +77,7 @@ func newClient( opts: opts, headerErrored: make(chan struct{}), dialer: dialer, + logger: utils.DefaultLogger, } } @@ -95,7 +98,7 @@ func (c *client) dial() error { if err != nil { return err } - c.requestWriter = newRequestWriter(c.headerStream) + c.requestWriter = newRequestWriter(c.headerStream, c.logger) go c.handleHeaderStream() return nil } @@ -109,7 +112,7 @@ func (c *client) handleHeaderStream() { err = c.readResponse(h2framer, decoder) } if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway { - utils.Debugf("Error handling header stream: %s", err) + c.logger.Debugf("Error handling header stream: %s", err) } c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error()) // stop all running request diff --git a/h2quic/client_test.go b/h2quic/client_test.go index dc23d13a..d4732c63 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -14,6 +14,7 @@ import ( quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" "time" @@ -54,7 +55,7 @@ var _ = Describe("Client", func() { headerStream = newMockStream(3) client.headerStream = headerStream - client.requestWriter = newRequestWriter(headerStream) + client.requestWriter = newRequestWriter(headerStream, utils.DefaultLogger) var err error req, err = http.NewRequest("GET", "https://localhost:1337", nil) Expect(err).ToNot(HaveOccurred()) diff --git a/h2quic/request_writer.go b/h2quic/request_writer.go index fdd7ad20..ddaaa741 100644 --- a/h2quic/request_writer.go +++ b/h2quic/request_writer.go @@ -23,13 +23,16 @@ type requestWriter struct { henc *hpack.Encoder hbuf bytes.Buffer // HPACK encoder writes into this + + logger utils.Logger } const defaultUserAgent = "quic-go" -func newRequestWriter(headerStream quic.Stream) *requestWriter { +func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter { rw := &requestWriter{ headerStream: headerStream, + logger: logger, } rw.henc = hpack.NewEncoder(&rw.hbuf) return rw @@ -156,7 +159,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra } func (w *requestWriter) writeHeader(name, value string) { - utils.Debugf("http2: Transport encoding header %q = %q", name, value) + w.logger.Debugf("http2: Transport encoding header %q = %q", name, value) w.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) } diff --git a/h2quic/request_writer_test.go b/h2quic/request_writer_test.go index 576f755c..2b33fe7d 100644 --- a/h2quic/request_writer_test.go +++ b/h2quic/request_writer_test.go @@ -10,6 +10,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -23,7 +24,7 @@ var _ = Describe("Request", func() { BeforeEach(func() { headerStream = &mockStream{} - rw = newRequestWriter(headerStream) + rw = newRequestWriter(headerStream, utils.DefaultLogger) decoder = hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) }) diff --git a/h2quic/response_writer.go b/h2quic/response_writer.go index 1dd4e928..25b77a54 100644 --- a/h2quic/response_writer.go +++ b/h2quic/response_writer.go @@ -24,15 +24,24 @@ type responseWriter struct { header http.Header status int // status code passed to WriteHeader headerWritten bool + + logger utils.Logger } -func newResponseWriter(headerStream quic.Stream, headerStreamMutex *sync.Mutex, dataStream quic.Stream, dataStreamID protocol.StreamID) *responseWriter { +func newResponseWriter( + headerStream quic.Stream, + headerStreamMutex *sync.Mutex, + dataStream quic.Stream, + dataStreamID protocol.StreamID, + logger utils.Logger, +) *responseWriter { return &responseWriter{ header: http.Header{}, headerStream: headerStream, headerStreamMutex: headerStreamMutex, dataStream: dataStream, dataStreamID: dataStreamID, + logger: logger, } } @@ -57,7 +66,7 @@ func (w *responseWriter) WriteHeader(status int) { } } - utils.Infof("Responding with %d", status) + w.logger.Infof("Responding with %d", status) w.headerStreamMutex.Lock() defer w.headerStreamMutex.Unlock() h2framer := http2.NewFramer(w.headerStream, nil) @@ -67,7 +76,7 @@ func (w *responseWriter) WriteHeader(status int) { BlockFragment: headers.Bytes(), }) if err != nil { - utils.Errorf("could not write h2 header: %s", err.Error()) + w.logger.Errorf("could not write h2 header: %s", err.Error()) } } diff --git a/h2quic/response_writer_test.go b/h2quic/response_writer_test.go index 847cf645..77c67a93 100644 --- a/h2quic/response_writer_test.go +++ b/h2quic/response_writer_test.go @@ -13,6 +13,7 @@ import ( quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -72,7 +73,7 @@ var _ = Describe("Response Writer", func() { BeforeEach(func() { headerStream = &mockStream{} dataStream = &mockStream{} - w = newResponseWriter(headerStream, &sync.Mutex{}, dataStream, 5) + w = newResponseWriter(headerStream, &sync.Mutex{}, dataStream, 5, utils.DefaultLogger) }) decodeHeaderFields := func() map[string][]string { diff --git a/h2quic/server.go b/h2quic/server.go index 329edfd0..a2412bd1 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -53,6 +53,8 @@ type Server struct { closed bool supportedVersionsAsString string + + logger utils.Logger // will be set by Server.serveImpl() } // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections. @@ -88,6 +90,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { if s.Server == nil { return errors.New("use of h2quic.Server without http.Server") } + s.logger = utils.DefaultLogger s.listenerMutex.Lock() if s.closed { s.listenerMutex.Unlock() @@ -138,7 +141,7 @@ func (s *Server) handleHeaderStream(session streamCreator) { // In this case, the session has already logged the error, so we don't // need to log it again. if _, ok := err.(*qerr.QuicError); !ok { - utils.Errorf("error handling h2 request: %s", err.Error()) + s.logger.Errorf("error handling h2 request: %s", err.Error()) } session.Close(err) return @@ -160,7 +163,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, } headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment()) if err != nil { - utils.Errorf("invalid http2 headers encoding: %s", err.Error()) + s.logger.Errorf("invalid http2 headers encoding: %s", err.Error()) return err } @@ -169,10 +172,10 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, return err } - if utils.Debug() { - utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID) + if s.logger.Debug() { + s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID) } else { - utils.Infof("%s %s%s", req.Method, req.Host, req.RequestURI) + s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI) } dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID)) @@ -201,7 +204,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, req.RemoteAddr = session.RemoteAddr().String() - responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID)) + responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger) handler := s.Handler if handler == nil { @@ -215,7 +218,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] - utils.Errorf("http: panic serving: %v\n%s", p, buf) + s.logger.Errorf("http: panic serving: %v\n%s", p, buf) panicked = true } }() diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 9d986559..5217a725 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -19,6 +19,7 @@ import ( quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" @@ -96,6 +97,7 @@ var _ = Describe("H2 server", func() { Server: &http.Server{ TLSConfig: testdata.GetTLSConfig(), }, + logger: utils.DefaultLogger, } dataStream = newMockStream(0) close(dataStream.unblockRead) @@ -287,7 +289,6 @@ var _ = Describe("H2 server", func() { Expect(dataStream.remoteClosed).To(BeTrue()) Expect(dataStream.reset).To(BeFalse()) }) - }) It("handles the header stream", func() { diff --git a/integrationtests/chrome/chrome_suite_test.go b/integrationtests/chrome/chrome_suite_test.go index 3117460b..526bb6b4 100644 --- a/integrationtests/chrome/chrome_suite_test.go +++ b/integrationtests/chrome/chrome_suite_test.go @@ -126,7 +126,7 @@ func chromeTest(version protocol.VersionNumber, url string, blockUntilDone func( fmt.Sprintf("--quic-version=QUIC_VERSION_%s", version.ToAltSvc()), url, } - utils.Infof("Running chrome: %s '%s'", getChromePath(), strings.Join(args, "' '")) + utils.DefaultLogger.Infof("Running chrome: %s '%s'", getChromePath(), strings.Join(args, "' '")) command := exec.Command(path, args...) session, err := gexec.Start(command, nil, nil) Expect(err).NotTo(HaveOccurred()) diff --git a/integrationtests/tools/proxy/proxy.go b/integrationtests/tools/proxy/proxy.go index 423aec01..d12a3bae 100644 --- a/integrationtests/tools/proxy/proxy.go +++ b/integrationtests/tools/proxy/proxy.go @@ -95,6 +95,8 @@ type QuicProxy struct { // Mapping from client addresses (as host:port) to connection clientDict map[string]*connection + + logger utils.Logger } // NewQuicProxy creates a new UDP proxy @@ -132,9 +134,10 @@ func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*Qu dropPacket: packetDropper, delayPacket: packetDelayer, version: version, + logger: utils.DefaultLogger, } - utils.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr) + p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr) go p.runProxy() return &p, nil } @@ -200,8 +203,8 @@ func (p *QuicProxy) runProxy() error { packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1) if p.dropPacket(DirectionIncoming, packetCount) { - if utils.Debug() { - utils.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n) + if p.logger.Debug() { + p.logger.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n) } continue } @@ -209,16 +212,16 @@ func (p *QuicProxy) runProxy() error { // Send the packet to the server delay := p.delayPacket(DirectionIncoming, packetCount) if delay != 0 { - if utils.Debug() { - utils.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay) + if p.logger.Debug() { + p.logger.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay) } time.AfterFunc(delay, func() { // TODO: handle error _, _ = conn.ServerConn.Write(raw) }) } else { - if utils.Debug() { - utils.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr()) + if p.logger.Debug() { + p.logger.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr()) } if _, err := conn.ServerConn.Write(raw); err != nil { return err @@ -240,24 +243,24 @@ func (p *QuicProxy) runConnection(conn *connection) error { packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1) if p.dropPacket(DirectionOutgoing, packetCount) { - if utils.Debug() { - utils.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n) + if p.logger.Debug() { + p.logger.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n) } continue } delay := p.delayPacket(DirectionOutgoing, packetCount) if delay != 0 { - if utils.Debug() { - utils.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay) + if p.logger.Debug() { + p.logger.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay) } time.AfterFunc(delay, func() { // TODO: handle error _, _ = p.conn.WriteToUDP(raw, conn.ClientAddr) }) } else { - if utils.Debug() { - utils.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr) + if p.logger.Debug() { + p.logger.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr) } if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil { return err diff --git a/integrationtests/tools/testlog/testlog.go b/integrationtests/tools/testlog/testlog.go index 78353162..c987ddb7 100644 --- a/integrationtests/tools/testlog/testlog.go +++ b/integrationtests/tools/testlog/testlog.go @@ -30,7 +30,7 @@ var _ = BeforeEach(func() { logFile, err = os.Create(logFileName) Expect(err).ToNot(HaveOccurred()) log.SetOutput(logFile) - utils.SetLogLevel(utils.LogLevelDebug) + utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug) } }) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 40ab9abf..11bd6a1c 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -65,10 +65,12 @@ type sentPacketHandler struct { // The alarm timeout alarm time.Time + + logger utils.Logger } // NewSentPacketHandler creates a new sentPacketHandler -func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler { +func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler { congestion := congestion.NewCubicSender( congestion.DefaultClock{}, rttStats, @@ -82,6 +84,7 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler { stopWaitingManager: stopWaitingManager{}, rttStats: rttStats, congestion: congestion, + logger: logger, } } @@ -170,7 +173,7 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe // duplicate or out of order ACK if withPacketNumber != 0 && withPacketNumber <= h.largestReceivedPacketWithAck { - utils.Debugf("Ignoring ACK frame (duplicate or out of order).") + h.logger.Debugf("Ignoring ACK frame (duplicate or out of order).") return nil } h.largestReceivedPacketWithAck = withPacketNumber @@ -435,7 +438,7 @@ func (h *sentPacketHandler) SendMode() SendMode { // we will stop sending out new data when reaching MaxOutstandingSentPackets, // but still allow sending of retransmissions and ACKs. if numTrackedPackets >= protocol.MaxTrackedSentPackets { - utils.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets) + h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets) return SendNone } // Send retransmissions first, if there are any. @@ -444,11 +447,11 @@ func (h *sentPacketHandler) SendMode() SendMode { } // Only send ACKs if we're congestion limited. if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd { - utils.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd) + h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd) return SendAck } if numTrackedPackets >= protocol.MaxOutstandingSentPackets { - utils.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets) + h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets) return SendAck } return SendAny @@ -470,7 +473,7 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int { func (h *sentPacketHandler) queueRTOs() error { for i := 0; i < 2; i++ { if p := h.packetHistory.FirstOutstanding(); p != nil { - utils.Debugf("\tQueueing packet %#x for retransmission (RTO), %d outstanding", p.PacketNumber, h.packetHistory.Len()) + h.logger.Debugf("\tQueueing packet %#x for retransmission (RTO), %d outstanding", p.PacketNumber, h.packetHistory.Len()) if err := h.queuePacketForRetransmission(p); err != nil { return err } diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 58642f9f..5c67f97d 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -8,6 +8,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -61,7 +62,7 @@ var _ = Describe("SentPacketHandler", func() { BeforeEach(func() { rttStats := &congestion.RTTStats{} - handler = NewSentPacketHandler(rttStats).(*sentPacketHandler) + handler = NewSentPacketHandler(rttStats, utils.DefaultLogger).(*sentPacketHandler) handler.SetHandshakeComplete() streamFrame = wire.StreamFrame{ StreamID: 5, diff --git a/internal/congestion/rtt_stats.go b/internal/congestion/rtt_stats.go index 9e5e4541..599e350f 100644 --- a/internal/congestion/rtt_stats.go +++ b/internal/congestion/rtt_stats.go @@ -84,7 +84,6 @@ func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) { // UpdateRTT updates the RTT based on a new sample. func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { if sendDelta == utils.InfDuration || sendDelta <= 0 { - utils.Debugf("Ignoring measured sendDelta, because it's is either infinite, zero, or negative: %d", sendDelta/time.Microsecond) return } diff --git a/internal/flowcontrol/base_flow_controller.go b/internal/flowcontrol/base_flow_controller.go index 3a5c0161..fb92f084 100644 --- a/internal/flowcontrol/base_flow_controller.go +++ b/internal/flowcontrol/base_flow_controller.go @@ -25,6 +25,8 @@ type baseFlowController struct { epochStartTime time.Time epochStartOffset protocol.ByteCount rttStats *congestion.RTTStats + + logger utils.Logger } func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { diff --git a/internal/flowcontrol/connection_flow_controller.go b/internal/flowcontrol/connection_flow_controller.go index 975cc583..c4f6e125 100644 --- a/internal/flowcontrol/connection_flow_controller.go +++ b/internal/flowcontrol/connection_flow_controller.go @@ -22,6 +22,7 @@ func NewConnectionFlowController( receiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount, rttStats *congestion.RTTStats, + logger utils.Logger, ) ConnectionFlowController { return &connectionFlowController{ baseFlowController: baseFlowController{ @@ -29,6 +30,7 @@ func NewConnectionFlowController( receiveWindow: receiveWindow, receiveWindowSize: receiveWindow, maxReceiveWindowSize: maxReceiveWindow, + logger: logger, }, } } @@ -65,7 +67,7 @@ func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { oldWindowSize := c.receiveWindowSize offset := c.baseFlowController.getWindowUpdate() if oldWindowSize < c.receiveWindowSize { - utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) + c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) } c.mutex.Unlock() return offset diff --git a/internal/flowcontrol/connection_flow_controller_test.go b/internal/flowcontrol/connection_flow_controller_test.go index 25e867a1..cba41eb2 100644 --- a/internal/flowcontrol/connection_flow_controller_test.go +++ b/internal/flowcontrol/connection_flow_controller_test.go @@ -5,6 +5,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -21,6 +22,7 @@ var _ = Describe("Connection Flow controller", func() { BeforeEach(func() { controller = &connectionFlowController{} controller.rttStats = &congestion.RTTStats{} + controller.logger = utils.DefaultLogger }) Context("Constructor", func() { @@ -30,7 +32,7 @@ var _ = Describe("Connection Flow controller", func() { receiveWindow := protocol.ByteCount(2000) maxReceiveWindow := protocol.ByteCount(3000) - fc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, rttStats).(*connectionFlowController) + fc := NewConnectionFlowController(receiveWindow, maxReceiveWindow, rttStats, utils.DefaultLogger).(*connectionFlowController) Expect(fc.receiveWindow).To(Equal(receiveWindow)) Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow)) }) diff --git a/internal/flowcontrol/stream_flow_controller.go b/internal/flowcontrol/stream_flow_controller.go index 51ecfe7f..16bef261 100644 --- a/internal/flowcontrol/stream_flow_controller.go +++ b/internal/flowcontrol/stream_flow_controller.go @@ -31,6 +31,7 @@ func NewStreamFlowController( maxReceiveWindow protocol.ByteCount, initialSendWindow protocol.ByteCount, rttStats *congestion.RTTStats, + logger utils.Logger, ) StreamFlowController { return &streamFlowController{ streamID: streamID, @@ -42,6 +43,7 @@ func NewStreamFlowController( receiveWindowSize: receiveWindow, maxReceiveWindowSize: maxReceiveWindow, sendWindow: initialSendWindow, + logger: logger, }, } } @@ -137,7 +139,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { oldWindowSize := c.receiveWindowSize offset := c.baseFlowController.getWindowUpdate() if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size - utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) + c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) if c.contributesToConnection { c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)) } diff --git a/internal/flowcontrol/stream_flow_controller_test.go b/internal/flowcontrol/stream_flow_controller_test.go index ba2e7dee..dfca659a 100644 --- a/internal/flowcontrol/stream_flow_controller_test.go +++ b/internal/flowcontrol/stream_flow_controller_test.go @@ -5,6 +5,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -17,10 +18,11 @@ var _ = Describe("Stream Flow controller", func() { rttStats := &congestion.RTTStats{} controller = &streamFlowController{ streamID: 10, - connection: NewConnectionFlowController(1000, 1000, rttStats).(*connectionFlowController), + connection: NewConnectionFlowController(1000, 1000, rttStats, utils.DefaultLogger).(*connectionFlowController), } controller.maxReceiveWindowSize = 10000 controller.rttStats = rttStats + controller.logger = utils.DefaultLogger }) Context("Constructor", func() { @@ -31,8 +33,8 @@ var _ = Describe("Stream Flow controller", func() { maxReceiveWindow := protocol.ByteCount(3000) sendWindow := protocol.ByteCount(4000) - cc := NewConnectionFlowController(0, 0, nil) - fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats).(*streamFlowController) + cc := NewConnectionFlowController(0, 0, nil, utils.DefaultLogger) + fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, rttStats, utils.DefaultLogger).(*streamFlowController) Expect(fc.streamID).To(Equal(protocol.StreamID(5))) Expect(fc.receiveWindow).To(Equal(receiveWindow)) Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow)) diff --git a/internal/handshake/cookie_handler.go b/internal/handshake/cookie_handler.go index 4257745c..bc2bd8e5 100644 --- a/internal/handshake/cookie_handler.go +++ b/internal/handshake/cookie_handler.go @@ -11,15 +11,16 @@ import ( // The cookie is sent in the TLS Retry. // By including the cookie in its ClientHello, a client can proof ownership of its source address. type CookieHandler struct { - callback func(net.Addr, *Cookie) bool - + callback func(net.Addr, *Cookie) bool cookieGenerator *CookieGenerator + + logger utils.Logger } var _ mint.CookieHandler = &CookieHandler{} // NewCookieHandler creates a new CookieHandler. -func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) { +func NewCookieHandler(callback func(net.Addr, *Cookie) bool, logger utils.Logger) (*CookieHandler, error) { cookieGenerator, err := NewCookieGenerator() if err != nil { return nil, err @@ -27,6 +28,7 @@ func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, er return &CookieHandler{ callback: callback, cookieGenerator: cookieGenerator, + logger: logger, }, nil } @@ -42,7 +44,7 @@ func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) { func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool { data, err := h.cookieGenerator.DecodeToken(token) if err != nil { - utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error()) + h.logger.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error()) return false } return h.callback(conn.RemoteAddr(), data) diff --git a/internal/handshake/cookie_handler_test.go b/internal/handshake/cookie_handler_test.go index 16b9207e..8fc8cc14 100644 --- a/internal/handshake/cookie_handler_test.go +++ b/internal/handshake/cookie_handler_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -37,7 +38,7 @@ var _ = Describe("Cookie Handler", func() { BeforeEach(func() { callbackReturn = false var err error - ch, err = NewCookieHandler(mockCallback) + ch, err = NewCookieHandler(mockCallback, utils.DefaultLogger) Expect(err).ToNot(HaveOccurred()) addr := &net.UDPAddr{IP: net.IPv4(42, 43, 44, 45), Port: 46} conn = mint.NewConn(&mockConn{remoteAddr: addr}, &mint.Config{}, false) diff --git a/internal/handshake/crypto_setup_client.go b/internal/handshake/crypto_setup_client.go index ee553639..0700399a 100644 --- a/internal/handshake/crypto_setup_client.go +++ b/internal/handshake/crypto_setup_client.go @@ -54,6 +54,8 @@ type cryptoSetupClient struct { handshakeEvent chan<- struct{} params *TransportParameters + + logger utils.Logger } var _ CryptoSetup = &cryptoSetupClient{} @@ -76,6 +78,7 @@ func NewCryptoSetupClient( handshakeEvent chan<- struct{}, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, + logger utils.Logger, ) (CryptoSetup, chan<- []byte, error) { nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) if err != nil { @@ -96,6 +99,7 @@ func NewCryptoSetupClient( initialVersion: initialVersion, negotiatedVersions: negotiatedVersions, divNonceChan: divNonceChan, + logger: logger, } return cs, divNonceChan, nil } @@ -146,7 +150,7 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { return err } - utils.Debugf("Got %s", message) + h.logger.Debugf("Got %s", message) switch message.Tag { case TagREJ: if err := h.handleREJMessage(message.Data); err != nil { @@ -211,7 +215,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { err = h.certManager.Verify(h.hostname) if err != nil { - utils.Infof("Certificate validation failed: %s", err.Error()) + h.logger.Infof("Certificate validation failed: %s", err.Error()) return qerr.ProofInvalid } } @@ -219,7 +223,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil { validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get()) if !validProof { - utils.Infof("Server proof verification failed") + h.logger.Infof("Server proof verification failed") return qerr.ProofInvalid } @@ -400,7 +404,7 @@ func (h *cryptoSetupClient) sendCHLO() error { Data: tags, } - utils.Debugf("Sending %s", message) + h.logger.Debugf("Sending %s", message) message.Write(b) _, err = h.cryptoStream.Write(b.Bytes()) diff --git a/internal/handshake/crypto_setup_client_test.go b/internal/handshake/crypto_setup_client_test.go index 56e566b9..0ab8beb8 100644 --- a/internal/handshake/crypto_setup_client_test.go +++ b/internal/handshake/crypto_setup_client_test.go @@ -131,6 +131,7 @@ var _ = Describe("Client Crypto Setup", func() { handshakeEvent, protocol.Version39, nil, + utils.DefaultLogger, ) Expect(err).ToNot(HaveOccurred()) cs = csInt.(*cryptoSetupClient) diff --git a/internal/handshake/crypto_setup_server.go b/internal/handshake/crypto_setup_server.go index ec4fbbab..d977f655 100644 --- a/internal/handshake/crypto_setup_server.go +++ b/internal/handshake/crypto_setup_server.go @@ -54,6 +54,8 @@ type cryptoSetupServer struct { params *TransportParameters sni string // need to fill out the ConnectionState + + logger utils.Logger } var _ CryptoSetup = &cryptoSetupServer{} @@ -80,6 +82,7 @@ func NewCryptoSetup( acceptSTK func(net.Addr, *Cookie) bool, paramsChan chan<- TransportParameters, handshakeEvent chan<- struct{}, + logger utils.Logger, ) (CryptoSetup, error) { nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) if err != nil { @@ -101,6 +104,7 @@ func NewCryptoSetup( sentSHLO: make(chan struct{}), paramsChan: paramsChan, handshakeEvent: handshakeEvent, + logger: logger, }, nil } @@ -116,7 +120,7 @@ func (h *cryptoSetupServer) HandleCryptoStream() error { return qerr.InvalidCryptoMessageType } - utils.Debugf("Got %s", message) + h.logger.Debugf("Got %s", message) done, err := h.handleMessage(chloData.Bytes(), message.Data) if err != nil { return err @@ -299,7 +303,7 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt func (h *cryptoSetupServer) acceptSTK(token []byte) bool { stk, err := h.scfg.cookieGenerator.DecodeToken(token) if err != nil { - utils.Debugf("STK invalid: %s", err.Error()) + h.logger.Debugf("STK invalid: %s", err.Error()) return false } return h.acceptSTKCallback(h.remoteAddr, stk) @@ -342,7 +346,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa var serverReply bytes.Buffer message.Write(&serverReply) - utils.Debugf("Sending %s", message) + h.logger.Debugf("Sending %s", message) return serverReply.Bytes(), nil } @@ -443,7 +447,7 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T } var reply bytes.Buffer message.Write(&reply) - utils.Debugf("Sending %s", message) + h.logger.Debugf("Sending %s", message) return reply.Bytes(), nil } diff --git a/internal/handshake/crypto_setup_server_test.go b/internal/handshake/crypto_setup_server_test.go index bc4d90c9..2fdf053b 100644 --- a/internal/handshake/crypto_setup_server_test.go +++ b/internal/handshake/crypto_setup_server_test.go @@ -171,6 +171,7 @@ var _ = Describe("Server Crypto Setup", func() { nil, paramsChan, handshakeEvent, + utils.DefaultLogger, ) Expect(err).NotTo(HaveOccurred()) cs = csInt.(*cryptoSetupServer) diff --git a/internal/handshake/tls_extension_handler_client.go b/internal/handshake/tls_extension_handler_client.go index 765cf79f..8e711be5 100644 --- a/internal/handshake/tls_extension_handler_client.go +++ b/internal/handshake/tls_extension_handler_client.go @@ -19,6 +19,8 @@ type extensionHandlerClient struct { initialVersion protocol.VersionNumber supportedVersions []protocol.VersionNumber version protocol.VersionNumber + + logger utils.Logger } var _ mint.AppExtensionHandler = &extensionHandlerClient{} @@ -30,6 +32,7 @@ func NewExtensionHandlerClient( initialVersion protocol.VersionNumber, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, + logger utils.Logger, ) TLSExtensionHandler { // The client reads the transport parameters from the Encrypted Extensions message. // The paramsChan is used in the session's run loop's select statement. @@ -41,6 +44,7 @@ func NewExtensionHandlerClient( initialVersion: initialVersion, supportedVersions: supportedVersions, version: version, + logger: logger, } } @@ -49,7 +53,7 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi return nil } - utils.Debugf("Sending Transport Parameters: %s", h.ourParams) + h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams) data, err := syntax.Marshal(clientHelloTransportParameters{ InitialVersion: uint32(h.initialVersion), Parameters: h.ourParams.getTransportParameters(), @@ -122,7 +126,7 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte if err != nil { return err } - utils.Debugf("Received Transport Parameters: %s", params) + h.logger.Debugf("Received Transport Parameters: %s", params) h.paramsChan <- *params return nil } diff --git a/internal/handshake/tls_extension_handler_client_test.go b/internal/handshake/tls_extension_handler_client_test.go index ddbd2eb4..e68034eb 100644 --- a/internal/handshake/tls_extension_handler_client_test.go +++ b/internal/handshake/tls_extension_handler_client_test.go @@ -7,6 +7,7 @@ import ( "github.com/bifurcation/mint" "github.com/bifurcation/mint/syntax" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -18,7 +19,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() { ) BeforeEach(func() { - handler = NewExtensionHandlerClient(&TransportParameters{}, protocol.VersionWhatever, nil, protocol.VersionWhatever).(*extensionHandlerClient) + handler = NewExtensionHandlerClient(&TransportParameters{}, protocol.VersionWhatever, nil, protocol.VersionWhatever, utils.DefaultLogger).(*extensionHandlerClient) el = make(mint.ExtensionList, 0) }) diff --git a/internal/handshake/tls_extension_handler_server.go b/internal/handshake/tls_extension_handler_server.go index ec73585c..138fc21b 100644 --- a/internal/handshake/tls_extension_handler_server.go +++ b/internal/handshake/tls_extension_handler_server.go @@ -19,6 +19,8 @@ type extensionHandlerServer struct { version protocol.VersionNumber supportedVersions []protocol.VersionNumber + + logger utils.Logger } var _ mint.AppExtensionHandler = &extensionHandlerServer{} @@ -29,6 +31,7 @@ func NewExtensionHandlerServer( params *TransportParameters, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, + logger utils.Logger, ) TLSExtensionHandler { // Processing the ClientHello is performed statelessly (and from a single go-routine). // Therefore, we have to use a buffered chan to pass the transport parameters to that go routine. @@ -38,6 +41,7 @@ func NewExtensionHandlerServer( paramsChan: paramsChan, supportedVersions: supportedVersions, version: version, + logger: logger, } } @@ -56,7 +60,7 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi for i, v := range supportedVersions { versions[i] = uint32(v) } - utils.Debugf("Sending Transport Parameters: %s", h.ourParams) + h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams) data, err := syntax.Marshal(encryptedExtensionsTransportParameters{ NegotiatedVersion: uint32(h.version), SupportedVersions: versions, @@ -108,7 +112,7 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte if err != nil { return err } - utils.Debugf("Received Transport Parameters: %s", params) + h.logger.Debugf("Received Transport Parameters: %s", params) h.paramsChan <- *params return nil } diff --git a/internal/handshake/tls_extension_handler_server_test.go b/internal/handshake/tls_extension_handler_server_test.go index 41169489..43c49d81 100644 --- a/internal/handshake/tls_extension_handler_server_test.go +++ b/internal/handshake/tls_extension_handler_server_test.go @@ -6,6 +6,7 @@ import ( "github.com/bifurcation/mint" "github.com/bifurcation/mint/syntax" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -25,7 +26,7 @@ var _ = Describe("TLS Extension Handler, for the server", func() { ) BeforeEach(func() { - handler = NewExtensionHandlerServer(&TransportParameters{}, nil, protocol.VersionWhatever).(*extensionHandlerServer) + handler = NewExtensionHandlerServer(&TransportParameters{}, nil, protocol.VersionWhatever, utils.DefaultLogger).(*extensionHandlerServer) el = make(mint.ExtensionList, 0) }) diff --git a/internal/utils/log.go b/internal/utils/log.go index 342d8ddc..62a3d075 100644 --- a/internal/utils/log.go +++ b/internal/utils/log.go @@ -11,8 +11,6 @@ import ( // LogLevel of quic-go type LogLevel uint8 -const logEnv = "QUIC_GO_LOG_LEVEL" - const ( // LogLevelNothing disables LogLevelNothing LogLevel = iota @@ -24,72 +22,92 @@ const ( LogLevelDebug ) -var ( - logLevel = LogLevelNothing - timeFormat = "" -) +const logEnv = "QUIC_GO_LOG_LEVEL" + +// A Logger logs. +type Logger interface { + SetLogLevel(LogLevel) + SetLogTimeFormat(format string) + Debug() bool + + Errorf(format string, args ...interface{}) + Infof(format string, args ...interface{}) + Debugf(format string, args ...interface{}) +} + +// DefaultLogger is used by quic-go for logging. +var DefaultLogger Logger + +type defaultLogger struct { + logLevel LogLevel + timeFormat string +} + +var _ Logger = &defaultLogger{} // SetLogLevel sets the log level -func SetLogLevel(level LogLevel) { - logLevel = level +func (l *defaultLogger) SetLogLevel(level LogLevel) { + l.logLevel = level } // SetLogTimeFormat sets the format of the timestamp // an empty string disables the logging of timestamps -func SetLogTimeFormat(format string) { +func (l *defaultLogger) SetLogTimeFormat(format string) { log.SetFlags(0) // disable timestamp logging done by the log package - timeFormat = format + l.timeFormat = format } // Debugf logs something -func Debugf(format string, args ...interface{}) { - if logLevel == LogLevelDebug { - logMessage(format, args...) +func (l *defaultLogger) Debugf(format string, args ...interface{}) { + if l.logLevel == LogLevelDebug { + l.logMessage(format, args...) } } // Infof logs something -func Infof(format string, args ...interface{}) { - if logLevel >= LogLevelInfo { - logMessage(format, args...) +func (l *defaultLogger) Infof(format string, args ...interface{}) { + if l.logLevel >= LogLevelInfo { + l.logMessage(format, args...) } } // Errorf logs something -func Errorf(format string, args ...interface{}) { - if logLevel >= LogLevelError { - logMessage(format, args...) +func (l *defaultLogger) Errorf(format string, args ...interface{}) { + if l.logLevel >= LogLevelError { + l.logMessage(format, args...) } } -func logMessage(format string, args ...interface{}) { - if len(timeFormat) > 0 { - log.Printf(time.Now().Format(timeFormat)+" "+format, args...) +func (l *defaultLogger) logMessage(format string, args ...interface{}) { + if len(l.timeFormat) > 0 { + log.Printf(time.Now().Format(l.timeFormat)+" "+format, args...) } else { log.Printf(format, args...) } } // Debug returns true if the log level is LogLevelDebug -func Debug() bool { - return logLevel == LogLevelDebug +func (l *defaultLogger) Debug() bool { + return l.logLevel == LogLevelDebug } func init() { - readLoggingEnv() + DefaultLogger = &defaultLogger{} + DefaultLogger.SetLogLevel(readLoggingEnv()) } -func readLoggingEnv() { +func readLoggingEnv() LogLevel { switch strings.ToLower(os.Getenv(logEnv)) { case "": - return + return LogLevelNothing case "debug": - logLevel = LogLevelDebug + return LogLevelDebug case "info": - logLevel = LogLevelInfo + return LogLevelInfo case "error": - logLevel = LogLevelError + return LogLevelError default: fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging") + return LogLevelNothing } } diff --git a/internal/utils/log_test.go b/internal/utils/log_test.go index f13c3150..dcd904af 100644 --- a/internal/utils/log_test.go +++ b/internal/utils/log_test.go @@ -11,22 +11,16 @@ import ( ) var _ = Describe("Log", func() { - var ( - b *bytes.Buffer - - initialTimeFormat string - ) + var b *bytes.Buffer BeforeEach(func() { - b = bytes.NewBuffer([]byte{}) + b = &bytes.Buffer{} log.SetOutput(b) - initialTimeFormat = timeFormat }) AfterEach(func() { log.SetOutput(os.Stdout) - SetLogLevel(LogLevelNothing) - timeFormat = initialTimeFormat + DefaultLogger.SetLogLevel(LogLevelNothing) }) It("the log level has the correct numeric value", func() { @@ -37,103 +31,97 @@ var _ = Describe("Log", func() { }) It("log level nothing", func() { - SetLogLevel(LogLevelNothing) - Debugf("debug") - Infof("info") - Errorf("err") + DefaultLogger.SetLogLevel(LogLevelNothing) + DefaultLogger.Debugf("debug") + DefaultLogger.Infof("info") + DefaultLogger.Errorf("err") Expect(b.Bytes()).To(Equal([]byte(""))) }) It("log level err", func() { - SetLogLevel(LogLevelError) - Debugf("debug") - Infof("info") - Errorf("err") + DefaultLogger.SetLogLevel(LogLevelError) + DefaultLogger.Debugf("debug") + DefaultLogger.Infof("info") + DefaultLogger.Errorf("err") Expect(b.Bytes()).To(ContainSubstring("err\n")) Expect(b.Bytes()).ToNot(ContainSubstring("info")) Expect(b.Bytes()).ToNot(ContainSubstring("debug")) }) It("log level info", func() { - SetLogLevel(LogLevelInfo) - Debugf("debug") - Infof("info") - Errorf("err") + DefaultLogger.SetLogLevel(LogLevelInfo) + DefaultLogger.Debugf("debug") + DefaultLogger.Infof("info") + DefaultLogger.Errorf("err") Expect(b.Bytes()).To(ContainSubstring("err\n")) Expect(b.Bytes()).To(ContainSubstring("info\n")) Expect(b.Bytes()).ToNot(ContainSubstring("debug")) }) It("log level debug", func() { - SetLogLevel(LogLevelDebug) - Debugf("debug") - Infof("info") - Errorf("err") + DefaultLogger.SetLogLevel(LogLevelDebug) + DefaultLogger.Debugf("debug") + DefaultLogger.Infof("info") + DefaultLogger.Errorf("err") Expect(b.Bytes()).To(ContainSubstring("err\n")) Expect(b.Bytes()).To(ContainSubstring("info\n")) Expect(b.Bytes()).To(ContainSubstring("debug\n")) }) It("doesn't add a timestamp if the time format is empty", func() { - SetLogLevel(LogLevelDebug) - SetLogTimeFormat("") - Debugf("debug") + DefaultLogger.SetLogLevel(LogLevelDebug) + DefaultLogger.SetLogTimeFormat("") + DefaultLogger.Debugf("debug") Expect(b.Bytes()).To(Equal([]byte("debug\n"))) }) It("adds a timestamp", func() { format := "Jan 2, 2006" - SetLogTimeFormat(format) - SetLogLevel(LogLevelInfo) - Infof("info") + DefaultLogger.SetLogTimeFormat(format) + DefaultLogger.SetLogLevel(LogLevelInfo) + DefaultLogger.Infof("info") t, err := time.Parse(format, string(b.Bytes()[:b.Len()-6])) Expect(err).ToNot(HaveOccurred()) Expect(t).To(BeTemporally("~", time.Now(), 25*time.Hour)) }) It("says whether debug is enabled", func() { - Expect(Debug()).To(BeFalse()) - SetLogLevel(LogLevelDebug) - Expect(Debug()).To(BeTrue()) + Expect(DefaultLogger.Debug()).To(BeFalse()) + DefaultLogger.SetLogLevel(LogLevelDebug) + Expect(DefaultLogger.Debug()).To(BeTrue()) }) Context("reading from env", func() { BeforeEach(func() { - Expect(logLevel).To(Equal(LogLevelNothing)) + Expect(DefaultLogger.(*defaultLogger).logLevel).To(Equal(LogLevelNothing)) }) It("reads DEBUG", func() { os.Setenv(logEnv, "DEBUG") - readLoggingEnv() - Expect(logLevel).To(Equal(LogLevelDebug)) + Expect(readLoggingEnv()).To(Equal(LogLevelDebug)) }) It("reads debug", func() { os.Setenv(logEnv, "debug") - readLoggingEnv() - Expect(logLevel).To(Equal(LogLevelDebug)) + Expect(readLoggingEnv()).To(Equal(LogLevelDebug)) }) It("reads INFO", func() { os.Setenv(logEnv, "INFO") readLoggingEnv() - Expect(logLevel).To(Equal(LogLevelInfo)) + Expect(readLoggingEnv()).To(Equal(LogLevelInfo)) }) It("reads ERROR", func() { os.Setenv(logEnv, "ERROR") - readLoggingEnv() - Expect(logLevel).To(Equal(LogLevelError)) + Expect(readLoggingEnv()).To(Equal(LogLevelError)) }) It("does not error reading invalid log levels from env", func() { - Expect(logLevel).To(Equal(LogLevelNothing)) os.Setenv(logEnv, "") - readLoggingEnv() - Expect(logLevel).To(Equal(LogLevelNothing)) + Expect(readLoggingEnv()).To(Equal(LogLevelNothing)) os.Setenv(logEnv, "asdf") - readLoggingEnv() - Expect(logLevel).To(Equal(LogLevelNothing)) + Expect(readLoggingEnv()).To(Equal(LogLevelNothing)) }) }) }) diff --git a/internal/wire/header.go b/internal/wire/header.go index 207312ff..fc346f3f 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -4,6 +4,7 @@ import ( "bytes" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" ) // Header is the header of a QUIC packet. @@ -103,10 +104,10 @@ func (h *Header) GetLength(pers protocol.Perspective, version protocol.VersionNu } // Log logs the Header -func (h *Header) Log() { +func (h *Header) Log(logger utils.Logger) { if h.isPublicHeader { - h.logPublicHeader() + h.logPublicHeader(logger) } else { - h.logHeader() + h.logHeader(logger) } } diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index a6128c7a..1533915f 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -225,30 +225,33 @@ var _ = Describe("Header", func() { }) Context("logging", func() { - var buf bytes.Buffer + var ( + buf *bytes.Buffer + logger utils.Logger + ) BeforeEach(func() { - buf.Reset() - utils.SetLogLevel(utils.LogLevelDebug) - log.SetOutput(&buf) + buf = &bytes.Buffer{} + logger = utils.DefaultLogger + logger.SetLogLevel(utils.LogLevelDebug) + log.SetOutput(buf) }) AfterEach(func() { - utils.SetLogLevel(utils.LogLevelNothing) log.SetOutput(os.Stdout) }) It("logs an IETF draft header", func() { (&Header{ IsLongHeader: true, - }).Log() + }).Log(logger) Expect(buf.String()).To(ContainSubstring("Long Header")) }) It("logs a Public Header", func() { (&Header{ isPublicHeader: true, - }).Log() + }).Log(logger) Expect(buf.String()).To(ContainSubstring("Public Header")) }) }) diff --git a/internal/wire/ietf_header.go b/internal/wire/ietf_header.go index 6d18ef39..01bf0a26 100644 --- a/internal/wire/ietf_header.go +++ b/internal/wire/ietf_header.go @@ -174,14 +174,14 @@ func (h *Header) getHeaderLength() (protocol.ByteCount, error) { return length, nil } -func (h *Header) logHeader() { +func (h *Header) logHeader(logger utils.Logger) { if h.IsLongHeader { - utils.Debugf(" Long Header{Type: %s, ConnectionID: %#x, PacketNumber: %#x, Version: %s}", h.Type, h.ConnectionID, h.PacketNumber, h.Version) + logger.Debugf(" Long Header{Type: %s, ConnectionID: %#x, PacketNumber: %#x, Version: %s}", h.Type, h.ConnectionID, h.PacketNumber, h.Version) } else { connID := "(omitted)" if !h.OmitConnectionID { connID = fmt.Sprintf("%#x", h.ConnectionID) } - utils.Debugf(" Short Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", connID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) + logger.Debugf(" Short Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", connID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) } } diff --git a/internal/wire/ietf_header_test.go b/internal/wire/ietf_header_test.go index 9eecde82..4097bc6a 100644 --- a/internal/wire/ietf_header_test.go +++ b/internal/wire/ietf_header_test.go @@ -372,16 +372,19 @@ var _ = Describe("IETF draft Header", func() { }) Context("logging", func() { - var buf bytes.Buffer + var ( + buf *bytes.Buffer + logger utils.Logger + ) BeforeEach(func() { - buf.Reset() - utils.SetLogLevel(utils.LogLevelDebug) - log.SetOutput(&buf) + buf = &bytes.Buffer{} + logger = utils.DefaultLogger + logger.SetLogLevel(utils.LogLevelDebug) + log.SetOutput(buf) }) AfterEach(func() { - utils.SetLogLevel(utils.LogLevelNothing) log.SetOutput(os.Stdout) }) @@ -392,7 +395,7 @@ var _ = Describe("IETF draft Header", func() { PacketNumber: 0x1337, ConnectionID: 0xdeadbeef, Version: 0xfeed, - }).logHeader() + }).logHeader(logger) Expect(buf.String()).To(ContainSubstring("Long Header{Type: Handshake, ConnectionID: 0xdeadbeef, PacketNumber: 0x1337, Version: 0xfeed}")) }) @@ -402,7 +405,7 @@ var _ = Describe("IETF draft Header", func() { PacketNumber: 0x1337, PacketNumberLen: 4, ConnectionID: 0xdeadbeef, - }).logHeader() + }).logHeader(logger) Expect(buf.String()).To(ContainSubstring("Short Header{ConnectionID: 0xdeadbeef, PacketNumber: 0x1337, PacketNumberLen: 4, KeyPhase: 1}")) }) @@ -411,7 +414,7 @@ var _ = Describe("IETF draft Header", func() { PacketNumber: 0x12, PacketNumberLen: 1, OmitConnectionID: true, - }).logHeader() + }).logHeader(logger) Expect(buf.String()).To(ContainSubstring("Short Header{ConnectionID: (omitted), PacketNumber: 0x12, PacketNumberLen: 1, KeyPhase: 0}")) }) }) diff --git a/internal/wire/log.go b/internal/wire/log.go index 0e72ea98..eaf5b1ea 100644 --- a/internal/wire/log.go +++ b/internal/wire/log.go @@ -3,8 +3,8 @@ package wire import "github.com/lucas-clemente/quic-go/internal/utils" // LogFrame logs a frame, either sent or received -func LogFrame(frame Frame, sent bool) { - if !utils.Debug() { +func LogFrame(logger utils.Logger, frame Frame, sent bool) { + if !logger.Debug() { return } dir := "<-" @@ -13,16 +13,16 @@ func LogFrame(frame Frame, sent bool) { } switch f := frame.(type) { case *StreamFrame: - utils.Debugf("\t%s &wire.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen()) + logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen()) case *StopWaitingFrame: if sent { - utils.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen) + logger.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen) } else { - utils.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked) + logger.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked) } case *AckFrame: - utils.Debugf("\t%s &wire.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String()) + logger.Debugf("\t%s &wire.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String()) default: - utils.Debugf("\t%s %#v", dir, frame) + logger.Debugf("\t%s %#v", dir, frame) } } diff --git a/internal/wire/log_test.go b/internal/wire/log_test.go index 33547efa..9bed2c53 100644 --- a/internal/wire/log_test.go +++ b/internal/wire/log_test.go @@ -15,33 +15,34 @@ import ( var _ = Describe("Frame logging", func() { var ( - buf bytes.Buffer + buf *bytes.Buffer + logger utils.Logger ) BeforeEach(func() { - buf.Reset() - utils.SetLogLevel(utils.LogLevelDebug) - log.SetOutput(&buf) + buf = &bytes.Buffer{} + logger = utils.DefaultLogger + logger.SetLogLevel(utils.LogLevelDebug) + log.SetOutput(buf) }) - AfterSuite(func() { - utils.SetLogLevel(utils.LogLevelNothing) + AfterEach(func() { log.SetOutput(os.Stdout) }) It("doesn't log when debug is disabled", func() { - utils.SetLogLevel(utils.LogLevelInfo) - LogFrame(&RstStreamFrame{}, true) + logger.SetLogLevel(utils.LogLevelInfo) + LogFrame(logger, &RstStreamFrame{}, true) Expect(buf.Len()).To(BeZero()) }) It("logs sent frames", func() { - LogFrame(&RstStreamFrame{}, true) + LogFrame(logger, &RstStreamFrame{}, true) Expect(buf.Bytes()).To(ContainSubstring("\t-> &wire.RstStreamFrame{StreamID:0x0, ErrorCode:0x0, ByteOffset:0x0}\n")) }) It("logs received frames", func() { - LogFrame(&RstStreamFrame{}, false) + LogFrame(logger, &RstStreamFrame{}, false) Expect(buf.Bytes()).To(ContainSubstring("\t<- &wire.RstStreamFrame{StreamID:0x0, ErrorCode:0x0, ByteOffset:0x0}\n")) }) @@ -51,7 +52,7 @@ var _ = Describe("Frame logging", func() { Offset: 0x1337, Data: bytes.Repeat([]byte{'f'}, 0x100), } - LogFrame(frame, false) + LogFrame(logger, frame, false) Expect(buf.Bytes()).To(ContainSubstring("\t<- &wire.StreamFrame{StreamID: 42, FinBit: false, Offset: 0x1337, Data length: 0x100, Offset + Data length: 0x1437}\n")) }) @@ -61,7 +62,7 @@ var _ = Describe("Frame logging", func() { LowestAcked: 0x42, DelayTime: 1 * time.Millisecond, } - LogFrame(frame, false) + LogFrame(logger, frame, false) Expect(buf.Bytes()).To(ContainSubstring("\t<- &wire.AckFrame{LargestAcked: 0x1337, LowestAcked: 0x42, AckRanges: []wire.AckRange(nil), DelayTime: 1ms}\n")) }) @@ -69,7 +70,7 @@ var _ = Describe("Frame logging", func() { frame := &StopWaitingFrame{ LeastUnacked: 0x1337, } - LogFrame(frame, false) + LogFrame(logger, frame, false) Expect(buf.Bytes()).To(ContainSubstring("\t<- &wire.StopWaitingFrame{LeastUnacked: 0x1337}\n")) }) @@ -78,7 +79,7 @@ var _ = Describe("Frame logging", func() { LeastUnacked: 0x1337, PacketNumberLen: protocol.PacketNumberLen4, } - LogFrame(frame, true) + LogFrame(logger, frame, true) Expect(buf.Bytes()).To(ContainSubstring("\t-> &wire.StopWaitingFrame{LeastUnacked: 0x1337, PacketNumberLen: 0x4}\n")) }) }) diff --git a/internal/wire/public_header.go b/internal/wire/public_header.go index 77bd8f52..af996b29 100644 --- a/internal/wire/public_header.go +++ b/internal/wire/public_header.go @@ -231,7 +231,7 @@ func (h *Header) hasPacketNumber(packetSentBy protocol.Perspective) bool { return true } -func (h *Header) logPublicHeader() { +func (h *Header) logPublicHeader(logger utils.Logger) { connID := "(omitted)" if !h.OmitConnectionID { connID = fmt.Sprintf("%#x", h.ConnectionID) @@ -240,5 +240,5 @@ func (h *Header) logPublicHeader() { if h.Version != 0 { ver = h.Version.String() } - utils.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", connID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce) + logger.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", connID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce) } diff --git a/internal/wire/public_header_test.go b/internal/wire/public_header_test.go index 762b3b61..f0de8bf5 100644 --- a/internal/wire/public_header_test.go +++ b/internal/wire/public_header_test.go @@ -468,16 +468,19 @@ var _ = Describe("Public Header", func() { }) Context("logging", func() { - var buf bytes.Buffer + var ( + buf *bytes.Buffer + logger utils.Logger + ) BeforeEach(func() { - buf.Reset() - utils.SetLogLevel(utils.LogLevelDebug) - log.SetOutput(&buf) + buf = &bytes.Buffer{} + logger = utils.DefaultLogger + logger.SetLogLevel(utils.LogLevelDebug) + log.SetOutput(buf) }) AfterEach(func() { - utils.SetLogLevel(utils.LogLevelNothing) log.SetOutput(os.Stdout) }) @@ -487,7 +490,7 @@ var _ = Describe("Public Header", func() { PacketNumber: 0x1337, PacketNumberLen: 6, Version: protocol.Version39, - }).logPublicHeader() + }).logPublicHeader(logger) Expect(buf.String()).To(ContainSubstring("Public Header{ConnectionID: 0xdecafbad, PacketNumber: 0x1337, PacketNumberLen: 6, Version: gQUIC 39")) }) @@ -497,7 +500,7 @@ var _ = Describe("Public Header", func() { PacketNumber: 0x1337, PacketNumberLen: 6, Version: protocol.Version39, - }).logPublicHeader() + }).logPublicHeader(logger) Expect(buf.String()).To(ContainSubstring("Public Header{ConnectionID: (omitted)")) }) @@ -506,7 +509,7 @@ var _ = Describe("Public Header", func() { OmitConnectionID: true, PacketNumber: 0x1337, PacketNumberLen: 6, - }).logPublicHeader() + }).logPublicHeader(logger) Expect(buf.String()).To(ContainSubstring("Version: (unset)")) }) @@ -514,7 +517,7 @@ var _ = Describe("Public Header", func() { (&Header{ ConnectionID: 0xdecafbad, DiversificationNonce: []byte{0xba, 0xdf, 0x00, 0x0d}, - }).logPublicHeader() + }).logPublicHeader(logger) Expect(buf.String()).To(ContainSubstring("DiversificationNonce: []byte{0xba, 0xdf, 0x0, 0xd}")) }) diff --git a/mint_utils.go b/mint_utils.go index 8b20275a..0f5a48e4 100644 --- a/mint_utils.go +++ b/mint_utils.go @@ -107,7 +107,7 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf // unpackInitialOrRetryPacket unpacks packets Initial and Retry packets // These packets must contain a STREAM_FRAME for the crypto stream, starting at offset 0. -func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.StreamFrame, error) { +func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, logger utils.Logger, version protocol.VersionNumber) (*wire.StreamFrame, error) { buf := *getPacketBuffer() buf = buf[:0] defer putPacketBuffer(&buf) @@ -139,17 +139,17 @@ func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, versio if frame.Offset != 0 { return nil, errors.New("received stream data with non-zero offset") } - if utils.Debug() { - utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) - hdr.Log() - wire.LogFrame(frame, false) + if logger.Debug() { + logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) + hdr.Log(logger) + wire.LogFrame(logger, frame, false) } return frame, nil } // packUnencryptedPacket provides a low-overhead way to pack a packet. // It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available. -func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective) ([]byte, error) { +func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective, logger utils.Logger) ([]byte, error) { raw := *getPacketBuffer() buffer := bytes.NewBuffer(raw[:0]) if err := hdr.Write(buffer, pers, hdr.Version); err != nil { @@ -162,10 +162,10 @@ func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, per raw = raw[0:buffer.Len()] _ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex]) raw = raw[0 : buffer.Len()+aead.Overhead()] - if utils.Debug() { - utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted) - hdr.Log() - wire.LogFrame(f, true) + if logger.Debug() { + logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted) + hdr.Log(logger) + wire.LogFrame(logger, f, true) } return raw, nil } diff --git a/mint_utils_test.go b/mint_utils_test.go index 52b1b37b..42f4c0b5 100644 --- a/mint_utils_test.go +++ b/mint_utils_test.go @@ -9,6 +9,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -107,14 +108,14 @@ var _ = Describe("Packing and unpacking Initial packets", func() { Data: []byte("foobar"), } p := packPacket([]wire.Frame{f}) - frame, err := unpackInitialPacket(aead, hdr, p, ver) + frame, err := unpackInitialPacket(aead, hdr, p, utils.DefaultLogger, ver) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) It("rejects a packet that doesn't contain a STREAM_FRAME", func() { p := packPacket([]wire.Frame{&wire.PingFrame{}}) - _, err := unpackInitialPacket(aead, hdr, p, ver) + _, err := unpackInitialPacket(aead, hdr, p, utils.DefaultLogger, ver) Expect(err).To(MatchError("Packet doesn't contain a STREAM_FRAME")) }) @@ -124,7 +125,7 @@ var _ = Describe("Packing and unpacking Initial packets", func() { Data: []byte("foobar"), } p := packPacket([]wire.Frame{f}) - _, err := unpackInitialPacket(aead, hdr, p, ver) + _, err := unpackInitialPacket(aead, hdr, p, utils.DefaultLogger, ver) Expect(err).To(MatchError("Received STREAM_FRAME for wrong stream (Stream ID 42)")) }) @@ -135,7 +136,7 @@ var _ = Describe("Packing and unpacking Initial packets", func() { Data: []byte("foobar"), } p := packPacket([]wire.Frame{f}) - _, err := unpackInitialPacket(aead, hdr, p, ver) + _, err := unpackInitialPacket(aead, hdr, p, utils.DefaultLogger, ver) Expect(err).To(MatchError("received stream data with non-zero offset")) }) }) @@ -146,7 +147,7 @@ var _ = Describe("Packing and unpacking Initial packets", func() { Data: []byte("foobar"), FinBit: true, } - data, err := packUnencryptedPacket(aead, hdr, f, protocol.PerspectiveServer) + data, err := packUnencryptedPacket(aead, hdr, f, protocol.PerspectiveServer, utils.DefaultLogger) Expect(err).ToNot(HaveOccurred()) aeadCl, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, ver) Expect(err).ToNot(HaveOccurred()) diff --git a/qerr/quic_error.go b/qerr/quic_error.go index a620bd19..42d08c4c 100644 --- a/qerr/quic_error.go +++ b/qerr/quic_error.go @@ -2,8 +2,6 @@ package qerr import ( "fmt" - - "github.com/lucas-clemente/quic-go/internal/utils" ) // ErrorCode can be used as a normal error without reason. @@ -51,6 +49,5 @@ func ToQuicError(err error) *QuicError { case ErrorCode: return Error(e, "") } - utils.Errorf("Internal error: %v", err) return Error(InternalError, err.Error()) } diff --git a/server.go b/server.go index 7a19afcb..1e56f0b9 100644 --- a/server.go +++ b/server.go @@ -50,8 +50,10 @@ type server struct { errorChan chan struct{} // set as members, so they can be set in the tests - newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error) + newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, logger utils.Logger) (packetHandler, error) deleteClosedSessionsAfter time.Duration + + logger utils.Logger } var _ Listener = &server{} @@ -110,6 +112,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, sessionQueue: make(chan Session, 5), errorChan: make(chan struct{}), supportsTLS: supportsTLS, + logger: utils.DefaultLogger, } if supportsTLS { if err := s.setupTLS(); err != nil { @@ -117,16 +120,16 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, } } go s.serve() - utils.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) + s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil } func (s *server) setupTLS() error { - cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie) + cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie, s.logger) if err != nil { return err } - serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf) + serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf, s.logger) if err != nil { return err } @@ -245,7 +248,7 @@ func (s *server) serve() { } data = data[:n] if err := s.handlePacket(s.conn, remoteAddr, data); err != nil { - utils.Errorf("error handling packet: %s", err.Error()) + s.logger.Errorf("error handling packet: %s", err.Error()) } } } @@ -328,12 +331,12 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet var pr *wire.PublicReset pr, err = wire.ParsePublicReset(r) if err != nil { - utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.", hdr.ConnectionID) + s.logger.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.", hdr.ConnectionID) } else { - utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.RejectedPacketNumber) + s.logger.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.RejectedPacketNumber) } } else { - utils.Infof("Received Public Reset for unknown connection %x.", hdr.ConnectionID) + s.logger.Infof("Received Public Reset for unknown connection %x.", hdr.ConnectionID) } return nil } @@ -360,7 +363,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet if len(packet) < protocol.MinClientHelloSize+len(hdr.Raw) { return errors.New("dropping small packet with unknown version") } - utils.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version) + s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version) _, err := pconn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr) return err } @@ -377,7 +380,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return errors.New("Server BUG: negotiated version not supported") } - utils.Infof("Serving new connection: %x, version %s from %v", hdr.ConnectionID, version, remoteAddr) + s.logger.Infof("Serving new connection: %x, version %s from %v", hdr.ConnectionID, version, remoteAddr) session, err = s.newSession( &conn{pconn: pconn, currentAddr: remoteAddr}, version, @@ -385,6 +388,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet s.scfg, s.tlsConf, s.config, + s.logger, ) if err != nil { return err diff --git a/server_test.go b/server_test.go index f93af333..ac2b8671 100644 --- a/server_test.go +++ b/server_test.go @@ -79,6 +79,7 @@ func newMockSession( _ *handshake.ServerConfig, _ *tls.Config, _ *Config, + _ utils.Logger, ) (packetHandler, error) { s := mockSession{ connectionID: connectionID, @@ -116,6 +117,7 @@ var _ = Describe("Server", func() { config: config, sessionQueue: make(chan Session, 5), errorChan: make(chan struct{}), + logger: utils.DefaultLogger, } b := &bytes.Buffer{} utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0])) @@ -179,7 +181,7 @@ var _ = Describe("Server", func() { It("accepts new TLS sessions", func() { connID := protocol.ConnectionID(0x12345) - sess, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil) + sess, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) err = serv.setupTLS() Expect(err).ToNot(HaveOccurred()) @@ -196,9 +198,9 @@ var _ = Describe("Server", func() { It("only accepts one new TLS sessions for one connection ID", func() { connID := protocol.ConnectionID(0x12345) - sess1, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil) + sess1, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) - sess2, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil) + sess2, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) err = serv.setupTLS() Expect(err).ToNot(HaveOccurred()) @@ -301,7 +303,7 @@ var _ = Describe("Server", func() { It("closes sessions and the connection when Close is called", func() { go serv.serve() - session, _ := newMockSession(nil, 0, 0, nil, nil, nil) + session, _ := newMockSession(nil, 0, 0, nil, nil, nil, nil) serv.sessions[1] = session err := serv.Close() Expect(err).NotTo(HaveOccurred()) @@ -351,7 +353,7 @@ var _ = Describe("Server", func() { }, 0.5) It("closes all sessions when encountering a connection error", func() { - session, _ := newMockSession(nil, 0, 0, nil, nil, nil) + session, _ := newMockSession(nil, 0, 0, nil, nil, nil, nil) serv.sessions[0x12345] = session Expect(serv.sessions[0x12345].(*mockSession).closed).To(BeFalse()) testErr := errors.New("connection error") diff --git a/server_tls.go b/server_tls.go index ba8b593c..9f387409 100644 --- a/server_tls.go +++ b/server_tls.go @@ -43,6 +43,8 @@ type serverTLS struct { newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) sessionChan chan<- tlsSession + + logger utils.Logger } func newServerTLS( @@ -50,6 +52,7 @@ func newServerTLS( config *Config, cookieHandler *handshake.CookieHandler, tlsConf *tls.Config, + logger utils.Logger, ) (*serverTLS, <-chan tlsSession, error) { mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer) if err != nil { @@ -77,16 +80,17 @@ func newServerTLS( MaxBidiStreams: uint16(config.MaxIncomingStreams), MaxUniStreams: uint16(config.MaxIncomingUniStreams), }, + logger: logger, } s.newMintConn = s.newMintConnImpl return s, sessionChan, nil } func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) { - utils.Debugf("Received a Packet. Handling it statelessly.") + s.logger.Debugf("Received a Packet. Handling it statelessly.") sess, err := s.handleInitialImpl(remoteAddr, hdr, data) if err != nil { - utils.Errorf("Error occurred handling initial packet: %s", err) + s.logger.Errorf("Error occurred handling initial packet: %s", err) return } if sess == nil { // a stateless reset was done @@ -100,7 +104,7 @@ func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data [] // will be set to s.newMintConn by the constructor func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) { - extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v) + extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v, s.logger) conf := s.mintConf.Clone() conf.ExtensionHandler = extHandler return newMintController(bc, conf, protocol.PerspectiveServer), extHandler.GetPeerParams(), nil @@ -118,7 +122,7 @@ func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Hea PacketNumber: 1, // random packet number Version: clientHdr.Version, } - data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer) + data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer, s.logger) if err != nil { return err } @@ -132,7 +136,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat } // check version, if not matching send VNP if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) { - utils.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) + s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) _, err := s.conn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, s.supportedVersions), remoteAddr) return nil, err } @@ -142,15 +146,15 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat if err != nil { return nil, err } - frame, err := unpackInitialPacket(aead, hdr, data, hdr.Version) + frame, err := unpackInitialPacket(aead, hdr, data, s.logger, hdr.Version) if err != nil { - utils.Debugf("Error unpacking initial packet: %s", err) + s.logger.Debugf("Error unpacking initial packet: %s", err) return nil, nil } sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead) if err != nil { if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil { - utils.Debugf("Error sending CONNECTION_CLOSE: %s", ccerr) + s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", ccerr) } return nil, err } @@ -180,7 +184,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, StreamID: version.CryptoStreamID(), Data: bc.GetDataForWriting(), } - data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer) + data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer, s.logger) if err != nil { return nil, err } @@ -210,6 +214,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, aead, ¶ms, version, + s.logger, ) if err != nil { return nil, err diff --git a/server_tls_test.go b/server_tls_test.go index 3994294c..633dcf9a 100644 --- a/server_tls_test.go +++ b/server_tls_test.go @@ -11,6 +11,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/mocks/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" @@ -36,7 +37,7 @@ var _ = Describe("Stateless TLS handling", func() { Versions: []protocol.VersionNumber{protocol.VersionTLS}, } var err error - server, sessionChan, err = newServerTLS(conn, config, nil, testdata.GetTLSConfig()) + server, sessionChan, err = newServerTLS(conn, config, nil, testdata.GetTLSConfig(), utils.DefaultLogger) Expect(err).ToNot(HaveOccurred()) server.newMintConn = func(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) { mintReply = bc diff --git a/session.go b/session.go index 8f1e56be..697fb53a 100644 --- a/session.go +++ b/session.go @@ -136,6 +136,8 @@ type session struct { // keepAlivePingSent stores whether a Ping frame was sent to the peer or not // it is reset as soon as we receive a packet from the peer keepAlivePingSent bool + + logger utils.Logger } var _ Session = &session{} @@ -149,6 +151,7 @@ func newSession( scfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, + logger utils.Logger, ) (packetHandler, error) { paramsChan := make(chan handshake.TransportParameters) handshakeEvent := make(chan struct{}, 1) @@ -160,6 +163,7 @@ func newSession( config: config, handshakeEvent: handshakeEvent, paramsChan: paramsChan, + logger: logger, } s.preSetup() transportParams := &handshake.TransportParameters{ @@ -184,6 +188,7 @@ func newSession( s.config.AcceptCookie, paramsChan, handshakeEvent, + s.logger, ) if err != nil { return nil, err @@ -215,6 +220,7 @@ var newClientSession = func( config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiation + logger utils.Logger, ) (packetHandler, error) { paramsChan := make(chan handshake.TransportParameters) handshakeEvent := make(chan struct{}, 1) @@ -226,6 +232,7 @@ var newClientSession = func( config: config, handshakeEvent: handshakeEvent, paramsChan: paramsChan, + logger: logger, } s.preSetup() transportParams := &handshake.TransportParameters{ @@ -246,6 +253,7 @@ var newClientSession = func( handshakeEvent, initialVersion, negotiatedVersions, + s.logger, ) if err != nil { return nil, err @@ -278,6 +286,7 @@ func newTLSServerSession( nullAEAD crypto.AEAD, peerParams *handshake.TransportParameters, v protocol.VersionNumber, + logger utils.Logger, ) (packetHandler, error) { handshakeEvent := make(chan struct{}, 1) s := &session{ @@ -287,6 +296,7 @@ func newTLSServerSession( perspective: protocol.PerspectiveServer, version: v, handshakeEvent: handshakeEvent, + logger: logger, } s.preSetup() cs := handshake.NewCryptoSetupTLSServer( @@ -328,6 +338,7 @@ var newTLSClientSession = func( tls handshake.MintTLS, paramsChan <-chan handshake.TransportParameters, initialPacketNumber protocol.PacketNumber, + logger utils.Logger, ) (packetHandler, error) { handshakeEvent := make(chan struct{}, 1) s := &session{ @@ -338,6 +349,7 @@ var newTLSClientSession = func( version: v, handshakeEvent: handshakeEvent, paramsChan: paramsChan, + logger: logger, } s.preSetup() tls.SetCryptoStream(s.cryptoStream) @@ -371,11 +383,12 @@ var newTLSClientSession = func( func (s *session) preSetup() { s.rttStats = &congestion.RTTStats{} - s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) + s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger) s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.ReceiveConnectionFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), s.rttStats, + s.logger, ) s.cryptoStream = s.newCryptoStream() } @@ -575,13 +588,13 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { ) packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) - if utils.Debug() { + if s.logger.Debug() { if err != nil { - utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) + s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) } else { - utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID, packet.encryptionLevel) + s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID, packet.encryptionLevel) } - hdr.Log() + hdr.Log(s.logger) } // if the decryption failed, this might be a packet sent by an attacker if err != nil { @@ -616,7 +629,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLevel) error { for _, ff := range fs { var err error - wire.LogFrame(ff, false) + wire.LogFrame(s.logger, ff, false) switch frame := ff.(type) { case *wire.StreamFrame: err = s.handleStreamFrame(frame, encLevel) @@ -779,9 +792,9 @@ func (s *session) handleCloseError(closeErr closeError) error { } // Don't log 'normal' reasons if quicErr.ErrorCode == qerr.PeerGoingAway || quicErr.ErrorCode == qerr.NetworkIdleTimeout { - utils.Infof("Closing connection %x", s.connectionID) + s.logger.Infof("Closing connection %x", s.connectionID) } else { - utils.Errorf("Closing session with error: %s", closeErr.err.Error()) + s.logger.Errorf("Closing session with error: %s", closeErr.err.Error()) } s.cryptoStream.closeForShutdown(quicErr) @@ -907,16 +920,16 @@ func (s *session) maybeSendRetransmission() (bool, error) { // An Initial might have been retransmitted multiple times before we receive a response. // As soon as we receive one response, we don't need to send any more Initials. if s.receivedFirstPacket && retransmitPacket.PacketType == protocol.PacketTypeInitial { - utils.Debugf("Skipping retransmission of packet %d. Already received a response to an Initial.", retransmitPacket.PacketNumber) + s.logger.Debugf("Skipping retransmission of packet %d. Already received a response to an Initial.", retransmitPacket.PacketNumber) continue } break } if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { - utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) + s.logger.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) } else { - utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) + s.logger.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) } if s.version.UsesStopWaitingFrames() { @@ -987,14 +1000,14 @@ func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error { } func (s *session) logPacket(packet *packedPacket) { - if !utils.Debug() { + if !s.logger.Debug() { // We don't need to allocate the slices for calling the format functions return } - utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", packet.header.PacketNumber, len(packet.raw), s.connectionID, packet.encryptionLevel) - packet.header.Log() + s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", packet.header.PacketNumber, len(packet.raw), s.connectionID, packet.encryptionLevel) + packet.header.Log(s.logger) for _, frame := range packet.frames { - wire.LogFrame(frame, true) + wire.LogFrame(s.logger, frame, true) } } @@ -1057,6 +1070,7 @@ func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlow protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), initialSendWindow, s.rttStats, + s.logger, ) } @@ -1070,12 +1084,13 @@ func (s *session) newCryptoStream() cryptoStreamI { protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), 0, s.rttStats, + s.logger, ) return newCryptoStream(s, flowController, s.version) } func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { - utils.Infof("Sending public reset for connection %x, packet number %d", s.connectionID, rejectedPacketNumber) + s.logger.Infof("Sending public reset for connection %x, packet number %d", s.connectionID, rejectedPacketNumber) return s.conn.Write(wire.WritePublicReset(s.connectionID, rejectedPacketNumber, 0)) } @@ -1089,7 +1104,7 @@ func (s *session) scheduleSending() { func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { if s.handshakeComplete { - utils.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data)) + s.logger.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data)) return } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { @@ -1098,10 +1113,10 @@ func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { s.receivedTooManyUndecrytablePacketsTime = time.Now() s.maybeResetTimer() } - utils.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber) + s.logger.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber) return } - utils.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber) + s.logger.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber) s.undecryptablePackets = append(s.undecryptablePackets, p) } diff --git a/session_test.go b/session_test.go index fd8a7e8a..b2fac4ea 100644 --- a/session_test.go +++ b/session_test.go @@ -22,6 +22,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/mocks/ackhandler" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/qerr" ) @@ -90,6 +91,7 @@ var _ = Describe("Session", func() { _ func(net.Addr, *Cookie) bool, _ chan<- handshake.TransportParameters, handshakeChanP chan<- struct{}, + _ utils.Logger, ) (handshake.CryptoSetup, error) { handshakeChan = handshakeChanP return cryptoSetup, nil @@ -109,6 +111,7 @@ var _ = Describe("Session", func() { scfg, nil, populateServerConfig(&Config{}), + utils.DefaultLogger, ) Expect(err).NotTo(HaveOccurred()) sess = pSess.(*session) @@ -142,6 +145,7 @@ var _ = Describe("Session", func() { cookieFunc func(net.Addr, *Cookie) bool, _ chan<- handshake.TransportParameters, _ chan<- struct{}, + _ utils.Logger, ) (handshake.CryptoSetup, error) { cookieVerify = cookieFunc return cryptoSetup, nil @@ -160,6 +164,7 @@ var _ = Describe("Session", func() { scfg, nil, conf, + utils.DefaultLogger, ) Expect(err).NotTo(HaveOccurred()) sess = pSess.(*session) @@ -1687,6 +1692,7 @@ var _ = Describe("Client Session", func() { handshakeChanP chan<- struct{}, _ protocol.VersionNumber, _ []protocol.VersionNumber, + _ utils.Logger, ) (handshake.CryptoSetup, chan<- []byte, error) { handshakeChan = handshakeChanP return cryptoSetup, divNonceChan, nil @@ -1702,6 +1708,7 @@ var _ = Describe("Client Session", func() { populateClientConfig(&Config{}), protocol.VersionWhatever, nil, + utils.DefaultLogger, ) sess = sessP.(*session) Expect(err).ToNot(HaveOccurred())