git » git-arr » commit 5181882

Cache some (possibly) expensive function calls

author Alberto Bertogli
2022-08-31 20:48:45 UTC
committer Alberto Bertogli
2022-08-31 22:15:16 UTC
parent 15547b279664602d524f20630fdb35f5424b0eba

Cache some (possibly) expensive function calls

This patch memoizes some of the functions to help speed up execution.
The speedup is quite variable, but ~30% is normal when generating a
medium-sized repository, and the output remains byte-for-byte identical.

git-arr +1 -1
git.py +35 -22
utils.py +31 -18

diff --git a/git-arr b/git-arr
index b109f23..337912a 100755
--- a/git-arr
+++ b/git-arr
@@ -186,7 +186,7 @@ bottle.app.push(app)
 
 def with_utils(f):
     """Decorator to add the utilities to the return value.
-    
+
     Used to wrap functions that return dictionaries which are then passed to
     templates.
     """
diff --git a/git.py b/git.py
index 3fa160c..bf30e83 100644
--- a/git.py
+++ b/git.py
@@ -6,6 +6,7 @@ command line tool directly, so please be careful with using untrusted
 parameters.
 """
 
+import functools
 import sys
 import io
 import subprocess
@@ -199,7 +200,8 @@ class Repo:
         """Returns a GitCommand() on our path."""
         return GitCommand(self.path, cmd)
 
-    def for_each_ref(self, pattern=None, sort=None, count=None):
+    @functools.lru_cache
+    def _for_each_ref(self, pattern=None, sort=None, count=None):
         """Returns a list of references."""
         cmd = self.cmd("for-each-ref")
         if sort:
@@ -209,26 +211,25 @@ class Repo:
         if pattern:
             cmd.arg(pattern)
 
+        refs = []
         for l in cmd.run():
             obj_id, obj_type, ref = l.split()
-            yield obj_id, obj_type, ref
-
-    def branches(self, sort="-authordate"):
-        """Get the (name, obj_id) of the branches."""
-        refs = self.for_each_ref(pattern="refs/heads/", sort=sort)
-        for obj_id, _, ref in refs:
-            yield ref[len("refs/heads/") :], obj_id
+            refs.append((obj_id, obj_type, ref))
+        return refs
 
+    @functools.cache
     def branch_names(self):
         """Get the names of the branches."""
-        return (name for name, _ in self.branches())
+        refs = self._for_each_ref(pattern="refs/heads/", sort="-authordate")
+        return [ref[len("refs/heads/") :] for _, _, ref in refs]
 
+    @functools.cache
     def tags(self, sort="-taggerdate"):
         """Get the (name, obj_id) of the tags."""
-        refs = self.for_each_ref(pattern="refs/tags/", sort=sort)
-        for obj_id, _, ref in refs:
-            yield ref[len("refs/tags/") :], obj_id
+        refs = self._for_each_ref(pattern="refs/tags/", sort=sort)
+        return [(ref[len("refs/tags/") :], obj_id) for obj_id, _, ref in refs]
 
+    @functools.lru_cache
     def commit_ids(self, ref, limit=None):
         """Generate commit ids."""
         cmd = self.cmd("rev-list")
@@ -238,9 +239,9 @@ class Repo:
         cmd.arg(ref)
         cmd.arg("--")
 
-        for l in cmd.run():
-            yield l.rstrip("\n")
+        return [l.rstrip("\n") for l in cmd.run()]
 
+    @functools.lru_cache
     def commit(self, commit_id):
         """Return a single commit."""
         cs = list(self.commits(commit_id, limit=1))
@@ -248,11 +249,11 @@ class Repo:
             return None
         return cs[0]
 
-    def commits(self, ref, limit=None, offset=0):
+    @functools.lru_cache
+    def commits(self, ref, limit, offset=0):
         """Generate commit objects for the ref."""
         cmd = self.cmd("rev-list")
-        if limit:
-            cmd.max_count = limit + offset
+        cmd.max_count = limit + offset
 
         cmd.header = None
 
@@ -261,6 +262,7 @@ class Repo:
 
         info_buffer = ""
         count = 0
+        commits = []
         for l in cmd.run():
             if "\0" in l:
                 pre, post = l.split("\0", 1)
@@ -268,7 +270,7 @@ class Repo:
 
                 count += 1
                 if count > offset:
-                    yield Commit.from_str(self, info_buffer)
+                    commits.append(Commit.from_str(self, info_buffer))
 
                 # Start over.
                 info_buffer = post
@@ -278,8 +280,11 @@ class Repo:
         if info_buffer:
             count += 1
             if count > offset:
-                yield Commit.from_str(self, info_buffer)
+                commits.append(Commit.from_str(self, info_buffer))
 
+        return commits
+
+    @functools.lru_cache
     def diff(self, ref):
         """Return a Diff object for the ref."""
         cmd = self.cmd("diff-tree")
@@ -295,6 +300,7 @@ class Repo:
 
         return Diff.from_str(cmd.run())
 
+    @functools.lru_cache
     def refs(self):
         """Return a dict of obj_id -> ref."""
         cmd = self.cmd("show-ref")
@@ -308,10 +314,12 @@ class Repo:
 
         return r
 
+    @functools.lru_cache
     def tree(self, ref):
         """Returns a Tree instance for the given ref."""
         return Tree(self, ref)
 
+    @functools.lru_cache
     def blob(self, path, ref):
         """Returns a Blob instance for the given path."""
         cmd = self.cmd("cat-file")
@@ -329,9 +337,10 @@ class Repo:
 
         return Blob(out.read()[: int(head)])
 
+    @functools.cache
     def last_commit_timestamp(self):
         """Return the timestamp of the last commit."""
-        refs = self.for_each_ref(
+        refs = self._for_each_ref(
             pattern="refs/heads/", sort="-committerdate", count=1
         )
         for obj_id, _, _ in refs:
@@ -515,12 +524,13 @@ class Diff:
 
 
 class Tree:
-    """ A git tree."""
+    """A git tree."""
 
     def __init__(self, repo: Repo, ref: str):
         self.repo = repo
         self.ref = ref
 
+    @functools.lru_cache
     def ls(
         self, path, recursive=False
     ) -> Iterable[Tuple[str, smstr, Optional[int]]]:
@@ -537,6 +547,7 @@ class Tree:
         else:
             cmd.arg(path)
 
+        files = []
         for l in cmd.run():
             _mode, otype, _oid, size, name = l.split(None, 4)
             if size == "-":
@@ -553,7 +564,9 @@ class Tree:
 
             # We use a smart string for the name, as it's often tricky to
             # manipulate otherwise.
-            yield otype, smstr(name), size
+            files.append((otype, smstr(name), size))
+
+        return files
 
 
 class Blob:
diff --git a/utils.py b/utils.py
index b19c9d3..9fb6544 100644
--- a/utils.py
+++ b/utils.py
@@ -9,6 +9,14 @@ try:
     from pygments import highlight  # type: ignore
     from pygments import lexers  # type: ignore
     from pygments.formatters import HtmlFormatter  # type: ignore
+
+    _html_formatter = HtmlFormatter(
+        encoding="utf-8",
+        cssclass="source_code",
+        linenos="table",
+        anchorlinenos=True,
+        lineanchors="line",
+    )
 except ImportError:
     pygments = None
 
@@ -19,6 +27,7 @@ except ImportError:
     markdown = None
 
 import base64
+import functools
 import mimetypes
 import string
 import os.path
@@ -32,6 +41,7 @@ def shorten(s: str, width=60):
     return s[:57] + "..."
 
 
+@functools.lru_cache
 def can_colorize(s: str):
     """True if we can colorize the string, False otherwise."""
     if pygments is None:
@@ -77,6 +87,7 @@ def can_embed_image(repo, fname):
     )
 
 
+@functools.lru_cache
 def colorize_diff(s: str) -> str:
     lexer = lexers.DiffLexer(encoding="utf-8")
     formatter = HtmlFormatter(encoding="utf-8", cssclass="source_code")
@@ -84,6 +95,7 @@ def colorize_diff(s: str) -> str:
     return highlight(s, lexer, formatter)
 
 
+@functools.lru_cache
 def colorize_blob(fname, s: str) -> str:
     try:
         lexer = lexers.guess_lexer_for_filename(fname, s, encoding="utf-8")
@@ -98,24 +110,7 @@ def colorize_blob(fname, s: str) -> str:
             except lexers.ClassNotFound:
                 pass
 
-    formatter = HtmlFormatter(
-        encoding="utf-8",
-        cssclass="source_code",
-        linenos="table",
-        anchorlinenos=True,
-        lineanchors="line",
-    )
-
-    return highlight(s, lexer, formatter)
-
-
-def markdown_blob(s: str) -> str:
-    extensions = [
-        "markdown.extensions.fenced_code",
-        "markdown.extensions.tables",
-        RewriteLocalLinksExtension(),
-    ]
-    return markdown.markdown(s, extensions=extensions)
+    return highlight(s, lexer, _html_formatter)
 
 
 def embed_image_blob(fname: str, image_data: bytes) -> str:
@@ -126,11 +121,13 @@ def embed_image_blob(fname: str, image_data: bytes) -> str:
     )
 
 
+@functools.lru_cache
 def is_binary(b: bytes):
     # Git considers a blob binary if NUL in first ~8KB, so do the same.
     return b"\0" in b[:8192]
 
 
+@functools.lru_cache
 def hexdump(s: bytes):
     graph = string.ascii_letters + string.digits + string.punctuation + " "
     b = s.decode("latin1")
@@ -181,3 +178,19 @@ if markdown:
             md.treeprocessors.register(
                 RewriteLocalLinks(), "RewriteLocalLinks", 1000
             )
+
+    _md_extensions = [
+        "markdown.extensions.fenced_code",
+        "markdown.extensions.tables",
+        RewriteLocalLinksExtension(),
+    ]
+
+    @functools.lru_cache
+    def markdown_blob(s: str) -> str:
+        return markdown.markdown(s, extensions=_md_extensions)
+
+else:
+
+    @functools.lru_cache
+    def markdown_blob(s: str) -> str:
+        raise RuntimeError("markdown_blob() called without markdown support")