git » git-arr » commit cba2f26

Update type annotations to allow static type checking

author Alberto Bertogli
2025-07-14 00:19:35 UTC
committer Alberto Bertogli
2025-07-14 00:46:14 UTC
parent bb1856427a99585b29c3eb1646413ca78d81ee96

Update type annotations to allow static type checking

This patch updates the type annotations so that the mypy static type
checker passes again.

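No logic changes were needed, but the handling of optional modules
(pygments, markdown) in utils.py was restructured: the functions that
depend on them are now defined inside the corresponding try/except
ImportError blocks, with fallback definitions in the except branch, to
allow for more natural type checking.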

git.py +16 -9
pyproject.toml +4 -0
utils.py +126 -114

diff --git a/git.py b/git.py
index 0869bdd..e726051 100644
--- a/git.py
+++ b/git.py
@@ -17,7 +17,7 @@ import email.utils
 import datetime
 import urllib.request, urllib.parse, urllib.error
 from html import escape
-from typing import Any, Dict, IO, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, IO, Iterable, List, Tuple, Union
 
 
 # Path to the git binary.
@@ -25,7 +25,11 @@ GIT_BIN = "git"
 
 
 def run_git(
-    repo_path: str, params, stdin: bytes = None, silent_stderr=False, raw=False
+    repo_path: str,
+    params,
+    stdin: bytes | None = None,
+    silent_stderr=False,
+    raw=False,
 ) -> Union[IO[str], IO[bytes]]:
     """Invokes git with the given parameters.
 
@@ -48,7 +52,11 @@ def run_git(
 
 
 def _run_git(
-    repo_path: str, params, stdin: bytes = None, silent_stderr=False, raw=False
+    repo_path: str,
+    params,
+    stdin: bytes | None = None,
+    silent_stderr=False,
+    raw=False,
 ) -> Union[IO[str], IO[bytes]]:
     """Invokes git with the given parameters.
 
@@ -95,7 +103,7 @@ class GitCommand(object):
         self._cmd = cmd
         self._args: List[str] = []
         self._kwargs: Dict[str, str] = {}
-        self._stdin_buf: Optional[bytes] = None
+        self._stdin_buf: bytes | None = None
         self._raw = False
         self._override = False
 
@@ -149,8 +157,6 @@ class smstr:
     """A "smart" string, containing many representations for ease of use."""
 
     raw: str  # string, probably utf8-encoded, good enough to show.
-    url: str  # escaped for safe embedding in URLs (not human-readable).
-    html: str  # HTML-embeddable representation.
 
     def __init__(self, s: str):
         self.raw = s
@@ -174,11 +180,12 @@ class smstr:
         return smstr(self.raw + other)
 
     @functools.cached_property
-    def url(self):
+    def url(self) -> str:
+        """Escaped for safe embedding in URLs (not human-readable)."""
         return urllib.request.pathname2url(self.raw)
 
     @functools.cached_property
-    def html(self):
+    def html(self) -> str:
         """Returns an html representation of the unicode string."""
         html = ""
         for c in escape(self.raw):
@@ -571,7 +578,7 @@ class Tree:
     @functools.lru_cache
     def ls(
         self, path, recursive=False
-    ) -> Iterable[Tuple[str, smstr, str, Optional[int]]]:
+    ) -> Iterable[Tuple[str, smstr, str, int | None]]:
         """Generates (type, name, oid, size) for each file in path."""
         cmd = self.repo.cmd("ls-tree")
         cmd.long = None
diff --git a/pyproject.toml b/pyproject.toml
index 8573a6d..79e0f2e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,7 @@
 [tool.black]
 line-length = 79
 include = "(git-arr|git.py|utils.py)$"
+
+[[tool.mypy.overrides]]
+module = ["xattr.*"]
+follow_untyped_imports = true
diff --git a/utils.py b/utils.py
index 49fbb98..9b45904 100644
--- a/utils.py
+++ b/utils.py
@@ -4,29 +4,8 @@ Miscellaneous utilities.
 These are mostly used in templates, for presentation purposes.
 """
 
-try:
-    import pygments  # type: ignore
-    from pygments import highlight  # type: ignore
-    from pygments import lexers  # type: ignore
-    from pygments.formatters import HtmlFormatter  # type: ignore
-
-    _html_formatter = HtmlFormatter(
-        encoding="utf-8",
-        cssclass="source_code",
-        linenos="table",
-        anchorlinenos=True,
-        lineanchors="line",
-    )
-except ImportError:
-    pygments = None
-
-try:
-    import markdown  # type: ignore
-    import markdown.treeprocessors  # type: ignore
-except ImportError:
-    markdown = None
-
-
+from typing import Sequence
+from types import ModuleType
 import base64
 import functools
 import mimetypes
@@ -40,113 +19,98 @@ import os.path
 import git
 
 
-def shorten(s: str, width=60):
-    if len(s) < 60:
-        return s
-    return s[:57] + "..."
-
-
-@functools.lru_cache
-def can_colorize(s: str):
-    """True if we can colorize the string, False otherwise."""
-    if pygments is None:
-        return False
-
-    # Pygments can take a huge amount of time with long files, or with very
-    # long lines; these are heuristics to try to avoid those situations.
-    if len(s) > (512 * 1024):
-        return False
+try:
+    import pygments
+    from pygments import highlight
+    from pygments import lexers
+    from pygments.formatters import HtmlFormatter
 
-    # If any of the first 5 lines is over 300 characters long, don't colorize.
-    start = 0
-    for i in range(5):
-        pos = s.find("\n", start)
-        if pos == -1:
-            break
+    _html_formatter = HtmlFormatter(
+        encoding="utf-8",
+        cssclass="source_code",
+        linenos="table",
+        anchorlinenos=True,
+        lineanchors="line",
+    )
 
-        if pos - start > 300:
+    @functools.lru_cache
+    def can_colorize(s: str) -> bool:
+        """True if we can colorize the string, False otherwise."""
+        # Pygments can take a huge amount of time with long files, or with
+        # very long lines; these are heuristics to try to avoid those
+        # situations.
+        if len(s) > (512 * 1024):
             return False
-        start = pos + 1
-
-    return True
-
-
-def can_markdown(repo: git.Repo, fname: str):
-    """True if we can process file through markdown, False otherwise."""
-    if markdown is None:
-        return False
-
-    if not repo.info.embed_markdown:
-        return False
 
-    return fname.endswith(".md")
+        # If any of the first 5 lines is over 300 characters long, don't
+        # colorize.
+        start = 0
+        for i in range(5):
+            pos = s.find("\n", start)
+            if pos == -1:
+                break
 
+            if pos - start > 300:
+                return False
+            start = pos + 1
 
-def can_embed_image(repo, fname):
-    """True if we can embed image file in HTML, False otherwise."""
-    if not repo.info.embed_images:
-        return False
-
-    return ("." in fname) and (
-        fname.split(".")[-1].lower() in ["jpg", "jpeg", "png", "gif"]
-    )
-
+        return True
 
-@functools.lru_cache
-def colorize_diff(s: str) -> str:
-    lexer = lexers.DiffLexer(encoding="utf-8")
-    formatter = HtmlFormatter(encoding="utf-8", cssclass="source_code")
+    @functools.lru_cache
+    def colorize_diff(s: str) -> str:
+        lexer = lexers.DiffLexer(encoding="utf-8")
+        formatter = HtmlFormatter(encoding="utf-8", cssclass="source_code")
 
-    return highlight(s, lexer, formatter)
+        return highlight(s, lexer, formatter)
 
+    @functools.lru_cache
+    def colorize_blob(fname, s: str) -> str:
+        # Explicit import to enable type checking, otherwise mypy gets confused
+        # because pygments is defined as a generic module | None.
+        import pygments.lexer
 
-@functools.lru_cache
-def colorize_blob(fname, s: str) -> str:
-    try:
-        lexer = lexers.guess_lexer_for_filename(fname, s, encoding="utf-8")
-    except lexers.ClassNotFound:
-        # Only try to guess lexers if the file starts with a shebang,
-        # otherwise it's likely a text file and guess_lexer() is prone to
-        # make mistakes with those.
-        lexer = lexers.TextLexer(encoding="utf-8")
-        if s.startswith("#!"):
-            try:
-                lexer = lexers.guess_lexer(s[:80], encoding="utf-8")
-            except lexers.ClassNotFound:
-                pass
-
-    return highlight(s, lexer, _html_formatter)
+        lexer: pygments.lexer.Lexer | pygments.lexer.LexerMeta
+        try:
+            lexer = lexers.guess_lexer_for_filename(fname, s, encoding="utf-8")
+        except lexers.ClassNotFound:
+            # Only try to guess lexers if the file starts with a shebang,
+            # otherwise it's likely a text file and guess_lexer() is prone to
+            # make mistakes with those.
+            if s.startswith("#!"):
+                try:
+                    lexer = lexers.guess_lexer(s[:80], encoding="utf-8")
+                except lexers.ClassNotFound:
+                    pass
+            else:
+                lexer = lexers.TextLexer(encoding="utf-8")
+
+        return highlight(s, lexer, _html_formatter)
 
+except ImportError:
 
-def embed_image_blob(fname: str, image_data: bytes) -> str:
-    mimetype = mimetypes.guess_type(fname)[0]
-    b64img = base64.b64encode(image_data).decode("ascii")
-    return '<img style="max-width:100%;" src="data:{0};base64,{1}" />'.format(
-        mimetype, b64img
-    )
+    @functools.lru_cache
+    def can_colorize(s: str) -> bool:
+        """True if we can colorize the string, False otherwise."""
+        return False
 
+    @functools.lru_cache
+    def colorize_diff(s: str) -> str:
+        raise RuntimeError("colorize_diff() called without pygments support")
 
-@functools.lru_cache
-def is_binary(b: bytes):
-    # Git considers a blob binary if NUL in first ~8KB, so do the same.
-    return b"\0" in b[:8192]
+    @functools.lru_cache
+    def colorize_blob(fname, s: str) -> str:
+        raise RuntimeError("colorize_blob() called without pygments support")
 
 
-@functools.lru_cache
-def hexdump(s: bytes):
-    graph = string.ascii_letters + string.digits + string.punctuation + " "
-    b = s.decode("latin1")
-    offset = 0
-    while b:
-        t = b[:16]
-        hexvals = ["%.2x" % ord(c) for c in t]
-        text = "".join(c if c in graph else "." for c in t)
-        yield offset, " ".join(hexvals[:8]), " ".join(hexvals[8:]), text
-        offset += 16
-        b = b[16:]
+try:
+    import markdown
 
+    def can_markdown(repo: git.Repo, fname: str) -> bool:
+        """True if we can process file through markdown, False otherwise."""
+        if not repo.info.embed_markdown:
+            return False
 
-if markdown:
+        return fname.endswith(".md")
 
     class RewriteLocalLinks(markdown.treeprocessors.Treeprocessor):
         """Rewrites relative links to files, to match git-arr's links.
@@ -184,7 +148,7 @@ if markdown:
                 RewriteLocalLinks(), "RewriteLocalLinks", 1000
             )
 
-    _md_extensions = [
+    _md_extensions: Sequence[str | markdown.Extension] = [
         "markdown.extensions.fenced_code",
         "markdown.extensions.tables",
         RewriteLocalLinksExtension(),
@@ -194,13 +158,61 @@ if markdown:
     def markdown_blob(s: str) -> str:
         return markdown.markdown(s, extensions=_md_extensions)
 
-else:
+except ImportError:
+
+    def can_markdown(repo: git.Repo, fname: str) -> bool:
+        """True if we can process file through markdown, False otherwise."""
+        return False
 
     @functools.lru_cache
     def markdown_blob(s: str) -> str:
         raise RuntimeError("markdown_blob() called without markdown support")
 
 
+def shorten(s: str, width=60):
+    if len(s) < 60:
+        return s
+    return s[:57] + "..."
+
+
+def can_embed_image(repo: git.Repo, fname: str) -> bool:
+    """True if we can embed image file in HTML, False otherwise."""
+    if not repo.info.embed_images:
+        return False
+
+    return ("." in fname) and (
+        fname.split(".")[-1].lower() in ["jpg", "jpeg", "png", "gif"]
+    )
+
+
+def embed_image_blob(fname: str, image_data: bytes) -> str:
+    mimetype = mimetypes.guess_type(fname)[0]
+    b64img = base64.b64encode(image_data).decode("ascii")
+    return '<img style="max-width:100%;" src="data:{0};base64,{1}" />'.format(
+        mimetype, b64img
+    )
+
+
+@functools.lru_cache
+def is_binary(b: bytes):
+    # Git considers a blob binary if NUL in first ~8KB, so do the same.
+    return b"\0" in b[:8192]
+
+
+@functools.lru_cache
+def hexdump(s: bytes):
+    graph = string.ascii_letters + string.digits + string.punctuation + " "
+    b = s.decode("latin1")
+    offset = 0
+    while b:
+        t = b[:16]
+        hexvals = ["%.2x" % ord(c) for c in t]
+        text = "".join(c if c in graph else "." for c in t)
+        yield offset, " ".join(hexvals[:8]), " ".join(hexvals[8:]), text
+        offset += 16
+        b = b[16:]
+
+
 def log_timing(*log_args):
     "Decorator to log how long a function call took."
     if not os.environ.get("GIT_ARR_DEBUG"):