From 3a7075b7245143265d6497c0fa6391d9eacec543 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:48:36 +0300 Subject: [PATCH 01/36] feat: add headers group support --- src/sigdb/core/__init__.py | 4 +- src/sigdb/core/reader.py | 151 ++++++++++++++++++++++++++++++++++++- src/sigdb/format/trie.py | 69 ++++++++++++----- src/sigdb/groups.py | 55 ++++++++++++++ 4 files changed, 257 insertions(+), 22 deletions(-) create mode 100644 src/sigdb/groups.py diff --git a/src/sigdb/core/__init__.py b/src/sigdb/core/__init__.py index 3ad41eb..0d4d52d 100644 --- a/src/sigdb/core/__init__.py +++ b/src/sigdb/core/__init__.py @@ -7,7 +7,7 @@ read_sigdb_metadata, validate_sigdb, ) -from sigdb.core.reader import SigDBMatcher, SigDBReader, match +from sigdb.core.reader import SigDBMatcher, SigDBReader, match, match_group, match_search __all__ = [ "SigDBMatcher", @@ -16,6 +16,8 @@ "compile_sigdb_json", "load_sigdb", "match", + "match_group", + "match_search", "read_sigdb_metadata", "validate_sigdb", ] diff --git a/src/sigdb/core/reader.py b/src/sigdb/core/reader.py index 28aea24..835e11c 100644 --- a/src/sigdb/core/reader.py +++ b/src/sigdb/core/reader.py @@ -1,10 +1,24 @@ from __future__ import annotations +from collections.abc import Mapping from pathlib import Path from typing import Any, overload from sigdb.format.trie import load_sigdb, read_sigdb_metadata, validate_sigdb -from sigdb.types import SigDBDatabase, SigDBMatchResult, SigDBValidationResult +from sigdb.groups import ( + SIGDB_GROUPS, + SIGDB_GROUPS_MAP, + format_list_pattern, + format_map_pattern, + parse_string_list, + parse_string_map, +) +from sigdb.types import ( + SigDBDatabase, + SigDBFormatError, + SigDBMatchResult, + SigDBValidationResult, +) def _normalize_head(head: str) -> str: @@ -17,6 +31,26 @@ def _normalize_head(head: str) -> str: return f"{name}:{value}" +def _iter_search_heads(search: Mapping[str, Any]) -> list[str]: + heads: list[str] = [] + headers = parse_string_map(search.get("headers"), "headers") + for header_name, header_value in headers.items(): + heads.append(format_map_pattern("headers", header_name, header_value)) + + for group in SIGDB_GROUPS: + if group == "headers": + continue + if group in SIGDB_GROUPS_MAP: + group_map = parse_string_map(search.get(group), group) + for name, value in group_map.items(): + heads.append(format_map_pattern(group, name, value)) + else: + group_values = parse_string_list(search.get(group), group) + for value in group_values: + heads.append(format_list_pattern(group, value)) + return heads + + class SigDBMatcher: __slots__ = ("_automaton", "_items") @@ -76,6 +110,32 @@ def match(self, head: str) -> SigDBMatchResult: return SigDBMatchResult(result=False, item_id=None, item=None, head=normalized) + def match_group( + self, + group: str, + value: str, + *, + name: str | None = None, + ) -> SigDBMatchResult: + if group not in SIGDB_GROUPS: + raise SigDBFormatError(f"unknown group: {group}") + if name is None: + if group in SIGDB_GROUPS_MAP: + raise SigDBFormatError(f"group {group} requires a name") + head = format_list_pattern(group, value) + else: + if group not in SIGDB_GROUPS_MAP: + raise SigDBFormatError(f"group {group} does not accept a name") + head = format_map_pattern(group, name, value) + return self.match(head) + + def match_search(self, search: Mapping[str, Any]) -> SigDBMatchResult: + for head in _iter_search_heads(search): + result = self.match(head) + if result.result: + return result + return SigDBMatchResult(result=False, item_id=None, item=None, head="") + @overload def match(head: str, src: SigDBMatcher) -> SigDBMatchResult: ... @@ -99,6 +159,83 @@ def match(head: str, src: object) -> SigDBMatchResult: raise TypeError("src must be SigDBReader, SigDBDatabase, or SigDBMatcher") +@overload +def match_group( + group: str, + value: str, + src: SigDBMatcher, + *, + name: str | None = None, +) -> SigDBMatchResult: ... + + +@overload +def match_group( + group: str, + value: str, + src: SigDBDatabase, + *, + name: str | None = None, +) -> SigDBMatchResult: ... + + +@overload +def match_group( + group: str, + value: str, + src: SigDBReader, + *, + name: str | None = None, +) -> SigDBMatchResult: ... + + +def match_group( + group: str, + value: str, + src: object, + *, + name: str | None = None, +) -> SigDBMatchResult: + if isinstance(src, SigDBMatcher): + return src.match_group(group, value, name=name) + if isinstance(src, SigDBDatabase): + return SigDBMatcher(src).match_group(group, value, name=name) + if isinstance(src, SigDBReader): + return src.matcher().match_group(group, value, name=name) + raise TypeError("src must be SigDBReader, SigDBDatabase, or SigDBMatcher") + + +@overload +def match_search( + search: Mapping[str, Any], + src: SigDBMatcher, +) -> SigDBMatchResult: ... + + +@overload +def match_search( + search: Mapping[str, Any], + src: SigDBDatabase, +) -> SigDBMatchResult: ... + + +@overload +def match_search( + search: Mapping[str, Any], + src: SigDBReader, +) -> SigDBMatchResult: ... + + +def match_search(search: Mapping[str, Any], src: object) -> SigDBMatchResult: + if isinstance(src, SigDBMatcher): + return src.match_search(search) + if isinstance(src, SigDBDatabase): + return SigDBMatcher(src).match_search(search) + if isinstance(src, SigDBReader): + return src.matcher().match_search(search) + raise TypeError("src must be SigDBReader, SigDBDatabase, or SigDBMatcher") + + class SigDBReader: def __init__(self, path: str | Path) -> None: self._path = Path(path) @@ -175,3 +312,15 @@ def matcher( def match(self, head: str) -> SigDBMatchResult: return self.matcher().match(head) + + def match_group( + self, + group: str, + value: str, + *, + name: str | None = None, + ) -> SigDBMatchResult: + return self.matcher().match_group(group, value, name=name) + + def match_search(self, search: Mapping[str, Any]) -> SigDBMatchResult: + return self.matcher().match_search(search) diff --git a/src/sigdb/format/trie.py b/src/sigdb/format/trie.py index 22347ca..bf0bb76 100644 --- a/src/sigdb/format/trie.py +++ b/src/sigdb/format/trie.py @@ -16,6 +16,14 @@ sign_hash, verify_hash_signature, ) +from sigdb.groups import ( + SIGDB_GROUPS, + SIGDB_GROUPS_MAP, + format_list_pattern, + format_map_pattern, + parse_string_list, + parse_string_map, +) from sigdb.storage import read_exact from sigdb.types import ( Automaton, @@ -290,6 +298,7 @@ def _compile_rules(rules: object) -> tuple[list[SigDBItem], dict[bytes, list[int # { "nginx": {"headers": {"Server": "nginx"}}, ... } items: list[SigDBItem] = [] rules_map = cast(Mapping[object, object], rules) + patterns: dict[bytes, list[int]] = {} for key_any, value_any in rules_map.items(): if not isinstance(key_any, str) or not key_any: raise SigDBFormatError("rule keys must be non-empty strings") @@ -297,15 +306,20 @@ def _compile_rules(rules: object) -> tuple[list[SigDBItem], dict[bytes, list[int raise SigDBFormatError("rule value must be an object") key = key_any value = cast(Mapping[str, Any], value_any) - headers = _parse_headers(value.get("headers", {})) + headers = parse_string_map(value.get("headers", {}), "headers") + item_id = len(items) items.append(SigDBItem(key=key, headers=headers)) - patterns: dict[bytes, list[int]] = {} - for item_id, item in enumerate(items): - for header_name, needle in item.headers.items(): - # Pattern is "header:needle" and matches as a substring in "header:value". - pattern = f"{header_name}:{needle}".lower().encode("utf-8") - patterns.setdefault(pattern, []).append(item_id) + _add_map_patterns(patterns, "headers", headers, item_id) + for group in SIGDB_GROUPS: + if group == "headers": + continue + if group in SIGDB_GROUPS_MAP: + group_map = parse_string_map(value.get(group), group) + _add_map_patterns(patterns, group, group_map, item_id) + else: + group_values = parse_string_list(value.get(group), group) + _add_list_patterns(patterns, group, group_values, item_id) for pattern, ids in patterns.items(): if len(ids) > 1: @@ -314,18 +328,33 @@ def _compile_rules(rules: object) -> tuple[list[SigDBItem], dict[bytes, list[int return items, patterns -def _parse_headers(value: object) -> dict[str, str]: - if value is None: - return {} - if not isinstance(value, Mapping): - raise SigDBFormatError("headers must be an object") - m = cast(Mapping[object, object], value) - out: dict[str, str] = {} - for k, v in m.items(): - if not isinstance(k, str) or not isinstance(v, str): - raise SigDBFormatError("headers keys/values must be strings") - out[k] = v - return out +def _add_pattern( + patterns: dict[bytes, list[int]], + pattern: str, + item_id: int, +) -> None: + pattern_bytes = pattern.strip().lower().encode("utf-8") + patterns.setdefault(pattern_bytes, []).append(item_id) + + +def _add_map_patterns( + patterns: dict[bytes, list[int]], + group: str, + values: Mapping[str, str], + item_id: int, +) -> None: + for key, needle in values.items(): + _add_pattern(patterns, format_map_pattern(group, key, needle), item_id) + + +def _add_list_patterns( + patterns: dict[bytes, list[int]], + group: str, + values: list[str], + item_id: int, +) -> None: + for needle in values: + _add_pattern(patterns, format_list_pattern(group, needle), item_id) def _parse_items(data: bytes) -> list[SigDBItem]: @@ -347,7 +376,7 @@ def _parse_items(data: bytes) -> list[SigDBItem]: key_any, headers_any = entry_list[0], entry_list[1] if not isinstance(key_any, str) or not key_any: raise SigDBFormatError("item key must be a non-empty string") - headers = _parse_headers(headers_any) + headers = parse_string_map(headers_any, "headers") items.append(SigDBItem(key=key_any, headers=headers)) return items diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py new file mode 100644 index 0000000..90219f9 --- /dev/null +++ b/src/sigdb/groups.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence + +from sigdb.types import SigDBFormatError + +SIGDB_GROUPS: tuple[str, ...] = ( + "headers", +) + +SIGDB_GROUPS_MAP: frozenset[str] = frozenset( + { + "headers", + } +) + + +def parse_string_map(value: object, group: str) -> dict[str, str]: + if value is None: + return {} + if not isinstance(value, Mapping): + raise SigDBFormatError(f"{group} must be an object") + out: dict[str, str] = {} + for k, v in value.items(): + if not isinstance(k, str) or not isinstance(v, str): + raise SigDBFormatError(f"{group} keys/values must be strings") + out[k] = v + return out + + +def parse_string_list(value: object, group: str) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + return [value] + if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray)): + out: list[str] = [] + for item in value: + if not isinstance(item, str): + raise SigDBFormatError(f"{group} items must be strings") + out.append(item) + return out + raise SigDBFormatError(f"{group} must be a string or list of strings") + + +def format_map_pattern(group: str, key: str, value: str) -> str: + key_s = key.strip() + value_s = value.strip() + if group == "headers": + return f"{key_s}:{value_s}" + return f"{group}:{key_s}:{value_s}" + + +def format_list_pattern(group: str, value: str) -> str: + return f"{group}:{value.strip()}" From ac872dcf6bc767700e25ee894ecc1a5903f4ab8c Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:49:00 +0300 Subject: [PATCH 02/36] feat: add js group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index 90219f9..8a16165 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -6,6 +6,7 @@ SIGDB_GROUPS: tuple[str, ...] = ( "headers", + "js", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From 2e35af51df84f8411058f4aaec8b85d88eaa2e54 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:49:18 +0300 Subject: [PATCH 03/36] feat: add meta group --- src/sigdb/groups.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index 8a16165..6dd31f3 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -7,11 +7,13 @@ SIGDB_GROUPS: tuple[str, ...] = ( "headers", "js", + "meta", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( { "headers", + "meta", } ) From 4fe18e84fe28301624fbb838342c3f12883604ce Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:49:39 +0300 Subject: [PATCH 04/36] feat: add html group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index 6dd31f3..3208f22 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -8,6 +8,7 @@ "headers", "js", "meta", + "html", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From 6b1e3bf95443e3806a984a1ecdf9dbfcdd5986ac Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:49:56 +0300 Subject: [PATCH 05/36] feat: add script_src group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index 3208f22..e6512ae 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -9,6 +9,7 @@ "js", "meta", "html", + "script_src", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From 0cb1be729f62e6ad88cec9d39c6a9e3296e88dc1 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:50:12 +0300 Subject: [PATCH 06/36] feat: add css group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index e6512ae..17d64ce 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -10,6 +10,7 @@ "meta", "html", "script_src", + "css", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From c2dc587ddf6632202111d5a98cc6c184ddbfb99c Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:50:31 +0300 Subject: [PATCH 07/36] feat: add url group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index 17d64ce..c689477 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -11,6 +11,7 @@ "html", "script_src", "css", + "url", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From daf88c3b8fdc081021c7a06f42056cf69445c74b Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:50:57 +0300 Subject: [PATCH 08/36] feat: add path group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index c689477..824603d 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -12,6 +12,7 @@ "script_src", "css", "url", + "path", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From 1f30b25b85f405ea59ec5c232cd54255bba940de Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:51:16 +0300 Subject: [PATCH 09/36] feat: add file group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index 824603d..c5b75c4 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -13,6 +13,7 @@ "css", "url", "path", + "file", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From bd201c7797050f3012e403c582cb355d04be189b Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:51:35 +0300 Subject: [PATCH 10/36] feat: add dns group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index c5b75c4..c17186a 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -14,6 +14,7 @@ "url", "path", "file", + "dns", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From 1cbaf9590d6471ffc2d0f56e84a64c7f70b6b715 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:51:52 +0300 Subject: [PATCH 11/36] feat: add subdomain group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index c17186a..41ab05e 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -15,6 +15,7 @@ "path", "file", "dns", + "subdomain", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From 21a52347cfca2fc3d6f46f9a03366fbfc6299df2 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:52:08 +0300 Subject: [PATCH 12/36] feat: add link group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index 41ab05e..0edaaf9 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -16,6 +16,7 @@ "file", "dns", "subdomain", + "link", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From 3b20a05121bb4247ae73ab371abcfa7b872ba65d Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:52:28 +0300 Subject: [PATCH 13/36] feat: add json group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index 0edaaf9..b849e21 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -17,6 +17,7 @@ "dns", "subdomain", "link", + "json", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From 7e7b95ff6f1d40045d0aa45a3a788a2b66ef2beb Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:52:45 +0300 Subject: [PATCH 14/36] feat: add api group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index b849e21..adc7108 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -18,6 +18,7 @@ "subdomain", "link", "json", + "api", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From f0157cd0b987c9cf6777aeee9ae1ff46fdfaaab8 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:53:02 +0300 Subject: [PATCH 15/36] feat: add tls group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index adc7108..1b4f712 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -19,6 +19,7 @@ "link", "json", "api", + "tls", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From eb88ee913327979a6e1fb4d9c0c57f8acb4ae008 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:53:20 +0300 Subject: [PATCH 16/36] feat: add server group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index 1b4f712..0cfcb02 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -20,6 +20,7 @@ "json", "api", "tls", + "server", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From 16dcde4c47d901735fa93cf7808d3ffe381de8cd Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:53:37 +0300 Subject: [PATCH 17/36] feat: add framework group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index 0cfcb02..dfc1544 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -21,6 +21,7 @@ "api", "tls", "server", + "framework", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From f6ed02c95f1001ddff197aa8961891f49d60daf5 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:54:04 +0300 Subject: [PATCH 18/36] feat: add cms group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index dfc1544..01a8d86 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -22,6 +22,7 @@ "tls", "server", "framework", + "cms", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From 8ab2db0cdefbdf3b11a8b92e43d9d1a8f7fa7084 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:54:22 +0300 Subject: [PATCH 19/36] feat: add cdn group --- src/sigdb/groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index 01a8d86..cf5dc61 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -23,6 +23,7 @@ "server", "framework", "cms", + "cdn", ) SIGDB_GROUPS_MAP: frozenset[str] = frozenset( From 7cd312785be0de1879cdb5be83e3061992a404b5 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 16:57:16 +0300 Subject: [PATCH 20/36] feat: add rule and search types --- src/sigdb/core/api.py | 9 ++++- src/sigdb/core/compiler.py | 4 +- src/sigdb/core/reader.py | 29 +++++++------- src/sigdb/format/trie.py | 3 +- src/sigdb/types/__init__.py | 14 +++++++ src/sigdb/types/rules.py | 77 +++++++++++++++++++++++++++++++++++++ 6 files changed, 117 insertions(+), 19 deletions(-) create mode 100644 src/sigdb/types/rules.py diff --git a/src/sigdb/core/api.py b/src/sigdb/core/api.py index b40006a..e89107c 100644 --- a/src/sigdb/core/api.py +++ b/src/sigdb/core/api.py @@ -4,12 +4,17 @@ from pathlib import Path from typing import Any -from sigdb.types import SigDBBuildResult, SigDBDatabase, SigDBValidationResult +from sigdb.types import ( + SigDBBuildResult, + SigDBDatabase, + SigDBRules, + SigDBValidationResult, +) def build_sigdb( *, - rules: object, + rules: SigDBRules, output_path: str | Path, metadata: Mapping[str, Any] | None = None, signing_key_hex: str | None = None, diff --git a/src/sigdb/core/compiler.py b/src/sigdb/core/compiler.py index b07f3ea..02b15db 100644 --- a/src/sigdb/core/compiler.py +++ b/src/sigdb/core/compiler.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Any, cast -from sigdb.types import SigDBBuildResult, SigDBFormatError +from sigdb.types import SigDBBuildResult, SigDBFormatError, SigDBRules def compile_sigdb_json( @@ -25,7 +25,7 @@ def compile_sigdb_json( from sigdb.format.trie import build_sigdb return build_sigdb( - rules=cast(object, rules_any), + rules=cast(SigDBRules, rules_any), output_path=output_path, metadata=metadata, signing_key_hex=signing_key_hex, diff --git a/src/sigdb/core/reader.py b/src/sigdb/core/reader.py index 835e11c..9fd7f22 100644 --- a/src/sigdb/core/reader.py +++ b/src/sigdb/core/reader.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Mapping from pathlib import Path from typing import Any, overload @@ -16,7 +15,9 @@ from sigdb.types import ( SigDBDatabase, SigDBFormatError, + SigDBGroupName, SigDBMatchResult, + SigDBSearchDefinition, SigDBValidationResult, ) @@ -31,7 +32,7 @@ def _normalize_head(head: str) -> str: return f"{name}:{value}" -def _iter_search_heads(search: Mapping[str, Any]) -> list[str]: +def _iter_search_heads(search: SigDBSearchDefinition) -> list[str]: heads: list[str] = [] headers = parse_string_map(search.get("headers"), "headers") for header_name, header_value in headers.items(): @@ -112,7 +113,7 @@ def match(self, head: str) -> SigDBMatchResult: def match_group( self, - group: str, + group: SigDBGroupName, value: str, *, name: str | None = None, @@ -129,7 +130,7 @@ def match_group( head = format_map_pattern(group, name, value) return self.match(head) - def match_search(self, search: Mapping[str, Any]) -> SigDBMatchResult: + def match_search(self, search: SigDBSearchDefinition) -> SigDBMatchResult: for head in _iter_search_heads(search): result = self.match(head) if result.result: @@ -161,7 +162,7 @@ def match(head: str, src: object) -> SigDBMatchResult: @overload def match_group( - group: str, + group: SigDBGroupName, value: str, src: SigDBMatcher, *, @@ -171,7 +172,7 @@ def match_group( @overload def match_group( - group: str, + group: SigDBGroupName, value: str, src: SigDBDatabase, *, @@ -181,7 +182,7 @@ def match_group( @overload def match_group( - group: str, + group: SigDBGroupName, value: str, src: SigDBReader, *, @@ -190,7 +191,7 @@ def match_group( def match_group( - group: str, + group: SigDBGroupName, value: str, src: object, *, @@ -207,26 +208,26 @@ def match_group( @overload def match_search( - search: Mapping[str, Any], + search: SigDBSearchDefinition, src: SigDBMatcher, ) -> SigDBMatchResult: ... @overload def match_search( - search: Mapping[str, Any], + search: SigDBSearchDefinition, src: SigDBDatabase, ) -> SigDBMatchResult: ... @overload def match_search( - search: Mapping[str, Any], + search: SigDBSearchDefinition, src: SigDBReader, ) -> SigDBMatchResult: ... -def match_search(search: Mapping[str, Any], src: object) -> SigDBMatchResult: +def match_search(search: SigDBSearchDefinition, src: object) -> SigDBMatchResult: if isinstance(src, SigDBMatcher): return src.match_search(search) if isinstance(src, SigDBDatabase): @@ -315,12 +316,12 @@ def match(self, head: str) -> SigDBMatchResult: def match_group( self, - group: str, + group: SigDBGroupName, value: str, *, name: str | None = None, ) -> SigDBMatchResult: return self.matcher().match_group(group, value, name=name) - def match_search(self, search: Mapping[str, Any]) -> SigDBMatchResult: + def match_search(self, search: SigDBSearchDefinition) -> SigDBMatchResult: return self.matcher().match_search(search) diff --git a/src/sigdb/format/trie.py b/src/sigdb/format/trie.py index bf0bb76..024c0ea 100644 --- a/src/sigdb/format/trie.py +++ b/src/sigdb/format/trie.py @@ -32,6 +32,7 @@ SigDBFormatError, SigDBIntegrityError, SigDBItem, + SigDBRules, SigDBSignatureError, SigDBValidationResult, ) @@ -73,7 +74,7 @@ def read_sigdb_metadata(path: str | Path) -> dict[str, Any]: def build_sigdb( *, - rules: object, + rules: SigDBRules, output_path: str | Path, metadata: Mapping[str, Any] | None = None, signing_key_hex: str | None = None, diff --git a/src/sigdb/types/__init__.py b/src/sigdb/types/__init__.py index 3ea6ba8..2e63b04 100644 --- a/src/sigdb/types/__init__.py +++ b/src/sigdb/types/__init__.py @@ -15,6 +15,14 @@ SigDBMatchResult, SigDBValidationResult, ) +from sigdb.types.rules import ( + SigDBGroupListName, + SigDBGroupMapName, + SigDBGroupName, + SigDBRuleDefinition, + SigDBRules, + SigDBSearchDefinition, +) __all__ = [ "Automaton", @@ -26,6 +34,12 @@ "SigDBIntegrityError", "SigDBItem", "SigDBMatchResult", + "SigDBGroupListName", + "SigDBGroupMapName", + "SigDBGroupName", + "SigDBRuleDefinition", + "SigDBRules", + "SigDBSearchDefinition", "SigDBSignatureError", "SigDBValidationResult", ] diff --git a/src/sigdb/types/rules.py b/src/sigdb/types/rules.py new file mode 100644 index 0000000..737fcd4 --- /dev/null +++ b/src/sigdb/types/rules.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Literal, TypeAlias, TypedDict + +SigDBGroupName: TypeAlias = Literal[ + "headers", + "js", + "meta", + "html", + "script_src", + "css", + "url", + "path", + "file", + "dns", + "subdomain", + "link", + "json", + "api", + "tls", + "server", + "framework", + "cms", + "cdn", +] + +SigDBGroupMapName: TypeAlias = Literal["headers", "meta"] + +SigDBGroupListName: TypeAlias = Literal[ + "js", + "html", + "script_src", + "css", + "url", + "path", + "file", + "dns", + "subdomain", + "link", + "json", + "api", + "tls", + "server", + "framework", + "cms", + "cdn", +] + +SigDBStringList: TypeAlias = Sequence[str] | str +SigDBStringMap: TypeAlias = Mapping[str, str] + + +class SigDBRuleDefinition(TypedDict, total=False): + headers: SigDBStringMap + js: SigDBStringList + meta: SigDBStringMap + html: SigDBStringList + script_src: SigDBStringList + css: SigDBStringList + url: SigDBStringList + path: SigDBStringList + file: SigDBStringList + dns: SigDBStringList + subdomain: SigDBStringList + link: SigDBStringList + json: SigDBStringList + api: SigDBStringList + tls: SigDBStringList + server: SigDBStringList + framework: SigDBStringList + cms: SigDBStringList + cdn: SigDBStringList + + +SigDBRules: TypeAlias = Mapping[str, SigDBRuleDefinition] +SigDBSearchDefinition: TypeAlias = SigDBRuleDefinition From f96dd5867743e26e7d1ab578572de3ce2cfc974e Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 17:03:19 +0300 Subject: [PATCH 21/36] feat: add html tag matching --- src/sigdb/core/__init__.py | 10 +++- src/sigdb/core/reader.py | 37 +++++++++++++- src/sigdb/format/trie.py | 4 +- src/sigdb/groups.py | 102 ++++++++++++++++++++++++++++++++++++- 4 files changed, 147 insertions(+), 6 deletions(-) diff --git a/src/sigdb/core/__init__.py b/src/sigdb/core/__init__.py index 0d4d52d..e0c63a5 100644 --- a/src/sigdb/core/__init__.py +++ b/src/sigdb/core/__init__.py @@ -7,7 +7,14 @@ read_sigdb_metadata, validate_sigdb, ) -from sigdb.core.reader import SigDBMatcher, SigDBReader, match, match_group, match_search +from sigdb.core.reader import ( + SigDBMatcher, + SigDBReader, + match, + match_group, + match_html, + match_search, +) __all__ = [ "SigDBMatcher", @@ -17,6 +24,7 @@ "load_sigdb", "match", "match_group", + "match_html", "match_search", "read_sigdb_metadata", "validate_sigdb", diff --git a/src/sigdb/core/reader.py b/src/sigdb/core/reader.py index 9fd7f22..df868af 100644 --- a/src/sigdb/core/reader.py +++ b/src/sigdb/core/reader.py @@ -9,7 +9,8 @@ SIGDB_GROUPS_MAP, format_list_pattern, format_map_pattern, - parse_string_list, + html_heads, + parse_group_list, parse_string_map, ) from sigdb.types import ( @@ -46,7 +47,7 @@ def _iter_search_heads(search: SigDBSearchDefinition) -> list[str]: for name, value in group_map.items(): heads.append(format_map_pattern(group, name, value)) else: - group_values = parse_string_list(search.get(group), group) + group_values = parse_group_list(search.get(group), group) for value in group_values: heads.append(format_list_pattern(group, value)) return heads @@ -137,6 +138,13 @@ def match_search(self, search: SigDBSearchDefinition) -> SigDBMatchResult: return result return SigDBMatchResult(result=False, item_id=None, item=None, head="") + def match_html(self, html: str) -> SigDBMatchResult: + for head in html_heads(html): + result = self.match(head) + if result.result: + return result + return SigDBMatchResult(result=False, item_id=None, item=None, head="") + @overload def match(head: str, src: SigDBMatcher) -> SigDBMatchResult: ... @@ -237,6 +245,28 @@ def match_search(search: SigDBSearchDefinition, src: object) -> SigDBMatchResult raise TypeError("src must be SigDBReader, SigDBDatabase, or SigDBMatcher") +@overload +def match_html(html: str, src: SigDBMatcher) -> SigDBMatchResult: ... + + +@overload +def match_html(html: str, src: SigDBDatabase) -> SigDBMatchResult: ... + + +@overload +def match_html(html: str, src: SigDBReader) -> SigDBMatchResult: ... + + +def match_html(html: str, src: object) -> SigDBMatchResult: + if isinstance(src, SigDBMatcher): + return src.match_html(html) + if isinstance(src, SigDBDatabase): + return SigDBMatcher(src).match_html(html) + if isinstance(src, SigDBReader): + return src.matcher().match_html(html) + raise TypeError("src must be SigDBReader, SigDBDatabase, or SigDBMatcher") + + class SigDBReader: def __init__(self, path: str | Path) -> None: self._path = Path(path) @@ -325,3 +355,6 @@ def match_group( def match_search(self, search: SigDBSearchDefinition) -> SigDBMatchResult: return self.matcher().match_search(search) + + def match_html(self, html: str) -> SigDBMatchResult: + return self.matcher().match_html(html) diff --git a/src/sigdb/format/trie.py b/src/sigdb/format/trie.py index 024c0ea..5ef5d13 100644 --- a/src/sigdb/format/trie.py +++ b/src/sigdb/format/trie.py @@ -21,7 +21,7 @@ SIGDB_GROUPS_MAP, format_list_pattern, format_map_pattern, - parse_string_list, + parse_group_list, parse_string_map, ) from sigdb.storage import read_exact @@ -319,7 +319,7 @@ def _compile_rules(rules: object) -> tuple[list[SigDBItem], dict[bytes, list[int group_map = parse_string_map(value.get(group), group) _add_map_patterns(patterns, group, group_map, item_id) else: - group_values = parse_string_list(value.get(group), group) + group_values = parse_group_list(value.get(group), group) _add_list_patterns(patterns, group, group_values, item_id) for pattern, ids in patterns.items(): diff --git a/src/sigdb/groups.py b/src/sigdb/groups.py index cf5dc61..d3083ff 100644 --- a/src/sigdb/groups.py +++ b/src/sigdb/groups.py @@ -1,6 +1,7 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence +import re +from collections.abc import Iterable, Mapping, Sequence from sigdb.types import SigDBFormatError @@ -33,6 +34,11 @@ } ) +_HTML_TAG_RE = re.compile(r"<\s*([a-zA-Z][\w:-]*)\b([^<>]*)>", re.IGNORECASE) +_HTML_ATTR_RE = re.compile( + r'([a-zA-Z_:][\w:.-]*)(?:\s*=\s*(?:"([^"]*)"|\'([^\']*)\'|([^\s"\'=<>`]+)))?' +) + def parse_string_map(value: object, group: str) -> dict[str, str]: if value is None: @@ -62,6 +68,100 @@ def parse_string_list(value: object, group: str) -> list[str]: raise SigDBFormatError(f"{group} must be a string or list of strings") +def parse_group_list(value: object, group: str) -> list[str]: + if group == "html": + return parse_html_list(value) + return parse_string_list(value, group) + + +def parse_html_list(value: object) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + return [value] + if isinstance(value, Mapping): + return [_html_spec_to_value(value)] + if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray)): + out: list[str] = [] + for item in value: + if isinstance(item, str): + out.append(item) + continue + if isinstance(item, Mapping): + out.append(_html_spec_to_value(item)) + continue + raise SigDBFormatError("html items must be strings or objects") + return out + raise SigDBFormatError("html must be a string, object, or list") + + +def _html_spec_to_value(spec: Mapping[object, object]) -> str: + allowed = {"tag", "attr", "value"} + for key in spec.keys(): + if not isinstance(key, str) or key not in allowed: + raise SigDBFormatError("html spec has invalid keys") + + tag = spec.get("tag") + attr = spec.get("attr") + value = spec.get("value") + + if tag is not None and (not isinstance(tag, str) or not tag): + raise SigDBFormatError("html tag must be a non-empty string") + if attr is not None and (not isinstance(attr, str) or not attr): + raise SigDBFormatError("html attr must be a non-empty string") + if value is not None and not isinstance(value, str): + raise SigDBFormatError("html value must be a string") + if value is not None and attr is None: + raise SigDBFormatError("html value requires attr") + if tag is None and attr is None and value is None: + raise SigDBFormatError("html spec must include tag or attr") + + parts: list[str] = [] + if tag is not None: + parts.extend(("tag", tag)) + if attr is not None: + parts.extend(("attr", attr)) + if value is not None: + parts.extend(("value", value)) + return ":".join(parts) + + +def html_heads(html: str) -> list[str]: + heads: list[str] = [] + seen: set[str] = set() + + def add(value: str) -> None: + if value in seen: + return + seen.add(value) + heads.append(value) + + for tag, attrs in _iter_html_tags(html): + add(format_list_pattern("html", f"tag:{tag}")) + for name, value in attrs: + add(format_list_pattern("html", f"tag:{tag}:attr:{name}")) + if value is not None: + add(format_list_pattern("html", f"tag:{tag}:attr:{name}:value:{value}")) + add(format_list_pattern("html", f"attr:{name}")) + if value is not None: + add(format_list_pattern("html", f"attr:{name}:value:{value}")) + return heads + + +def _iter_html_tags(html: str) -> Iterable[tuple[str, list[tuple[str, str | None]]]]: + for match in _HTML_TAG_RE.finditer(html): + tag = match.group(1) + attrs_raw = match.group(2) + attrs: list[tuple[str, str | None]] = [] + for attr in _HTML_ATTR_RE.finditer(attrs_raw): + name = attr.group(1) + value = attr.group(2) or attr.group(3) or attr.group(4) + if value is not None: + value = value.strip() + attrs.append((name, value)) + yield tag, attrs + + def format_map_pattern(group: str, key: str, value: str) -> str: key_s = key.strip() value_s = value.strip() From 1ea0e03ae75b82572a8310189e26a5e67ddeea08 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 17:03:51 +0300 Subject: [PATCH 22/36] chore: update html rule types --- src/sigdb/types/__init__.py | 6 ++++++ src/sigdb/types/rules.py | 12 +++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/sigdb/types/__init__.py b/src/sigdb/types/__init__.py index 2e63b04..7878865 100644 --- a/src/sigdb/types/__init__.py +++ b/src/sigdb/types/__init__.py @@ -19,6 +19,9 @@ SigDBGroupListName, SigDBGroupMapName, SigDBGroupName, + SigDBHtmlList, + SigDBHtmlPattern, + SigDBHtmlSpec, SigDBRuleDefinition, SigDBRules, SigDBSearchDefinition, @@ -37,6 +40,9 @@ "SigDBGroupListName", "SigDBGroupMapName", "SigDBGroupName", + "SigDBHtmlList", + "SigDBHtmlPattern", + "SigDBHtmlSpec", "SigDBRuleDefinition", "SigDBRules", "SigDBSearchDefinition", diff --git a/src/sigdb/types/rules.py b/src/sigdb/types/rules.py index 737fcd4..cd3971b 100644 --- a/src/sigdb/types/rules.py +++ b/src/sigdb/types/rules.py @@ -51,11 +51,21 @@ SigDBStringMap: TypeAlias = Mapping[str, str] +class SigDBHtmlSpec(TypedDict, total=False): + tag: str + attr: str + value: str + + +SigDBHtmlPattern: TypeAlias = SigDBHtmlSpec | str +SigDBHtmlList: TypeAlias = Sequence[SigDBHtmlPattern] | SigDBHtmlPattern + + class SigDBRuleDefinition(TypedDict, total=False): headers: SigDBStringMap js: SigDBStringList meta: SigDBStringMap - html: SigDBStringList + html: SigDBHtmlList script_src: SigDBStringList css: SigDBStringList url: SigDBStringList From 418dcdb123958b21dbb30a3131b9a564d7280bc0 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 17:05:15 +0300 Subject: [PATCH 23/36] refactor: move groups into internal --- src/sigdb/core/reader.py | 2 +- src/sigdb/format/trie.py | 2 +- src/sigdb/internal/__init__.py | 1 + src/sigdb/{ => internal}/groups.py | 0 4 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 src/sigdb/internal/__init__.py rename src/sigdb/{ => internal}/groups.py (100%) diff --git a/src/sigdb/core/reader.py b/src/sigdb/core/reader.py index df868af..3afbb4d 100644 --- a/src/sigdb/core/reader.py +++ b/src/sigdb/core/reader.py @@ -4,7 +4,7 @@ from typing import Any, overload from sigdb.format.trie import load_sigdb, read_sigdb_metadata, validate_sigdb -from sigdb.groups import ( +from sigdb.internal.groups import ( SIGDB_GROUPS, SIGDB_GROUPS_MAP, format_list_pattern, diff --git a/src/sigdb/format/trie.py b/src/sigdb/format/trie.py index 5ef5d13..f1c1228 100644 --- a/src/sigdb/format/trie.py +++ b/src/sigdb/format/trie.py @@ -16,7 +16,7 @@ sign_hash, verify_hash_signature, ) -from sigdb.groups import ( +from sigdb.internal.groups import ( SIGDB_GROUPS, SIGDB_GROUPS_MAP, format_list_pattern, diff --git a/src/sigdb/internal/__init__.py b/src/sigdb/internal/__init__.py new file mode 100644 index 0000000..9d48db4 --- /dev/null +++ b/src/sigdb/internal/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/src/sigdb/groups.py b/src/sigdb/internal/groups.py similarity index 100% rename from src/sigdb/groups.py rename to src/sigdb/internal/groups.py From 21cdab019a3f2d924b0145babf9ecbfcce72f635 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 17:05:19 +0300 Subject: [PATCH 24/36] chore: bump version to 1.0.2 --- src/sigdb/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sigdb/__init__.py b/src/sigdb/__init__.py index 6114b4b..4dd6c99 100644 --- a/src/sigdb/__init__.py +++ b/src/sigdb/__init__.py @@ -12,4 +12,4 @@ "utils", ] -__version__ = "1.0.1" \ No newline at end of file +__version__ = "1.0.2" From fc4680de83d7a26615f29cae222af8be32719cd8 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Sun, 15 Mar 2026 17:11:48 +0300 Subject: [PATCH 25/36] chore: change `ruff.line-length` from 88 to 100 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a58f122..6946a20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ typecheck = "pyright" [tool.ruff] target-version = "py311" -line-length = 88 +line-length = 100 [tool.ruff.format] quote-style = "double" From 86e42ab86a3532d4b95d0a84442c763687ed8be7 Mon Sep 17 00:00:00 2001 From: "reekeer[bot]" Date: Mon, 16 Mar 2026 17:29:45 +0300 Subject: [PATCH 26/36] chore: add reekeerBot --- .gitignore | 4 + scripts/reekeerBot.py | 369 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 373 insertions(+) create mode 100644 scripts/reekeerBot.py diff --git a/.gitignore b/.gitignore index d594933..aabcd3a 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,10 @@ __pycache__/ venv/ .env/ +# Enviroments +*.env +*.pem + # Packaging / build outputs build/ dist/ diff --git a/scripts/reekeerBot.py b/scripts/reekeerBot.py new file mode 100644 index 0000000..5345287 --- /dev/null +++ b/scripts/reekeerBot.py @@ -0,0 +1,369 @@ +import datetime +import json +import os +import subprocess +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, cast + +import jwt +import requests +from dotenv import load_dotenv + +load_dotenv() + +API = "https://api.github.com" +REPO_ROOT = Path(__file__).resolve().parent.parent + + +@dataclass(frozen=True) +class Config: + repo: str + base_branch: str + reviewer: str | None + app_id: int + installation_id: int + private_key_path: Path + + +def _require_env(name: str) -> str: + value = os.getenv(name) + if not value: + raise RuntimeError(f"Missing required env var: {name}") + return value + + +def run_cmd( + args: list[str], + *, + capture: bool = False, + check: bool = True, +) -> str | None: + result = subprocess.run(args, capture_output=capture, text=True) + if check and result.returncode != 0: + stderr = (result.stderr or "").strip() + raise RuntimeError(f"Command failed ({result.returncode}): {args}\n{stderr}") + return result.stdout.strip() if capture else None + + +def run_shell( + cmd: str, + *, + capture: bool = False, + check: bool = True, +) -> str | None: + result = subprocess.run(cmd, shell=True, capture_output=capture, text=True) + if check and result.returncode != 0: + stderr = (result.stderr or "").strip() + raise RuntimeError(f"Command failed ({result.returncode}): {cmd}\n{stderr}") + return result.stdout.strip() if capture else None + + +def git(args: list[str]) -> None: + run_cmd(["git", *args], check=True) + + +def load_json(path: Path) -> Any | None: + if not path.exists(): + return None + try: + with path.open(encoding="utf-8") as f: + return json.load(f) + except (OSError, ValueError): + return None + + +def create_jwt(*, app_id: int, private_key_path: Path) -> str: + private_key = private_key_path.read_text(encoding="utf-8") + payload = { + "iat": int(time.time()) - 60, + "exp": int(time.time()) + 600, + # PyJWT expects "iss" to be a string. + "iss": str(app_id), + } + return jwt.encode(payload, private_key, algorithm="RS256") # pyright: ignore[reportUnknownMemberType] + + +def installation_token(config: Config) -> str: + jwt_token = create_jwt(app_id=config.app_id, private_key_path=config.private_key_path) + headers = { + "Authorization": f"Bearer {jwt_token}", + "Accept": "application/vnd.github+json", + } + url = f"{API}/app/installations/{config.installation_id}/access_tokens" + r = requests.post(url, headers=headers, timeout=30) + r.raise_for_status() + token = r.json().get("token") + if not isinstance(token, str) or not token: + raise RuntimeError("GitHub installation token response is missing a token") + return token + + +def gh_headers(token: str) -> dict[str, str]: + return { + "Authorization": f"Bearer {token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + } + + +def gh_request(token: str, method: str, url: str, *, data: dict[str, Any] | None = None) -> Any: + r = requests.request(method, url, headers=gh_headers(token), json=data, timeout=30) + try: + r.raise_for_status() + except requests.HTTPError as exc: + body = (r.text or "").strip() + raise RuntimeError(f"GitHub API error {r.status_code} for {method} {url}: {body}") from exc + if r.status_code == 204: + return None + return r.json() + + +def git_setup() -> None: + git(["config", "user.name", "reekeer[bot]"]) + git(["config", "user.email", "reekeer[bot]@users.noreply.github.com"]) + + +def create_branch() -> str: + ts = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d%H%M%S") + branch = f"bot/code-quality-{ts}" + git(["checkout", "-b", branch]) + return branch + + +def commit_if_changes(message: str) -> bool: + git(["add", "-A"]) + changed = subprocess.run(["git", "diff", "--cached", "--quiet"]) + if changed.returncode != 0: + git(["commit", "-m", message]) + return True + return False + + +def push(*, token: str, repo: str, branch: str) -> None: + git(["push", f"https://x-access-token:{token}@github.com/{repo}.git", branch]) + + +def ruff_fix() -> None: + run_shell("ruff check . --fix", check=False) + commit_if_changes("style(ruff): auto-fix lint issues") + run_shell("ruff check . --output-format=json > ruff.json || true", check=False) + + +def black_fix() -> None: + run_shell("black .", check=False) + commit_if_changes("style(black): format code") + + +def pyright_scan() -> None: + run_shell("pyright --outputjson > pyright.json || true", check=False) + + +def _as_dict_list(value: Any) -> list[dict[str, Any]]: + if not isinstance(value, list): + return [] + out: list[dict[str, Any]] = [] + for item in cast(list[Any], value): + if isinstance(item, dict): + out.append(cast(dict[str, Any], item)) + return out + + +def ruff_errors() -> list[dict[str, Any]]: + data = load_json(Path("ruff.json")) + return _as_dict_list(data) + + +def pyright_errors() -> list[dict[str, Any]]: + data = load_json(Path("pyright.json")) + if not isinstance(data, dict): + return [] + data_d = cast(dict[str, Any], data) + diags = data_d.get("generalDiagnostics") + return _as_dict_list(diags) + + +def create_pr(*, token: str, repo: str, branch: str, base_branch: str) -> dict[str, Any]: + url = f"{API}/repos/{repo}/pulls" + pr = gh_request( + token, + "POST", + url, + data={ + "title": "chore: automated code quality fixes", + "head": branch, + "base": base_branch, + "body": "Automated fixes from reekeerBot", + }, + ) + if not isinstance(pr, dict): + raise RuntimeError("Unexpected response from create PR") + return cast(dict[str, Any], pr) + + +def add_reviewer(*, token: str, repo: str, pr_number: int, reviewer: str | None) -> None: + if not reviewer: + return + url = f"{API}/repos/{repo}/pulls/{pr_number}/requested_reviewers" + gh_request(token, "POST", url, data={"reviewers": [reviewer]}) + + +def comment_pr(*, token: str, repo: str, pr_number: int, text: str) -> None: + url = f"{API}/repos/{repo}/issues/{pr_number}/comments" + gh_request(token, "POST", url, data={"body": text}) + + +def pr_head_sha(*, token: str, repo: str, pr_number: int) -> str: + url = f"{API}/repos/{repo}/pulls/{pr_number}" + pr = gh_request(token, "GET", url) + if not isinstance(pr, dict): + raise RuntimeError("Unexpected response from PR details") + pr_d = cast(dict[str, Any], pr) + head = pr_d.get("head") + if not isinstance(head, dict): + raise RuntimeError("Unexpected response from PR details (missing head)") + head_d = cast(dict[str, Any], head) + sha = head_d.get("sha") + if not isinstance(sha, str) or not sha: + raise RuntimeError("Could not read PR head SHA") + return sha + + +def _diagnostic_repo_path(file_value: Any) -> str | None: + if not isinstance(file_value, str) or not file_value: + return None + p = Path(file_value) + if not p.is_absolute(): + p = (REPO_ROOT / p).resolve() + try: + rel = p.relative_to(REPO_ROOT) + except ValueError: + return None + return rel.as_posix() + + +def review_comments(*, token: str, repo: str, pr_number: int, errors: list[dict[str, Any]]) -> None: + url = f"{API}/repos/{repo}/pulls/{pr_number}/comments" + commit_sha = pr_head_sha(token=token, repo=repo, pr_number=pr_number) + + posted = 0 + for e in errors: + if posted >= 30: + break + message = e.get("message") + repo_path = _diagnostic_repo_path(e.get("file")) + start_line = e.get("range", {}).get("start", {}).get("line") + if not isinstance(message, str) or not message: + continue + if repo_path is None: + continue + if not isinstance(start_line, int) or start_line < 0: + continue + + gh_request( + token, + "POST", + url, + data={ + "body": f"Pyright error:\n{message}", + "commit_id": commit_sha, + "path": repo_path, + "line": start_line + 1, # Pyright is 0-based; GitHub is 1-based. + "side": "RIGHT", + }, + ) + posted += 1 + + +def auto_merge(*, token: str, repo: str, pr_number: int) -> None: + url = f"{API}/repos/{repo}/pulls/{pr_number}/merge" + gh_request(token, "PUT", url) + + +def summarize(ruff: list[dict[str, Any]], pyright: list[dict[str, Any]]) -> str: + msg = "# reekeer[bot] Report\n\n" + + if ruff: + msg += "## Ruff issues\n" + for r in ruff[:15]: + filename = r.get("filename") + loc = r.get("location") + row = cast(dict[str, Any], loc).get("row") if isinstance(loc, dict) else None + message = r.get("message") + if isinstance(filename, str) and isinstance(row, int) and isinstance(message, str): + msg += f"- {filename}:{row} {message}\n" + + if pyright: + msg += "\n## Pyright issues\n" + for p in pyright[:15]: + file_value = p.get("file") + message = p.get("message") + if isinstance(file_value, str) and isinstance(message, str): + msg += f"- {file_value} {message}\n" + + if not ruff and not pyright: + msg += "✅ No issues detected" + + return msg + + +def load_config() -> Config: + repo = _require_env("GITHUB_REPOSITORY") + base_branch = os.getenv("BASE_BRANCH", "dev") + reviewer = os.getenv("REVIEWER") or None + app_id = int(_require_env("APP_ID")) + installation_id = int(_require_env("INSTALLATION_ID")) + + private_key_env = os.getenv("PRIVATE_KEY_PATH") + private_key_path = ( + Path(private_key_env).expanduser() if private_key_env else (REPO_ROOT / "private-key.pem") + ) + if not private_key_path.is_absolute(): + private_key_path = (REPO_ROOT / private_key_path).resolve() + if not private_key_path.is_file(): + raise RuntimeError(f"Private key not found: {private_key_path}") + + return Config( + repo=repo, + base_branch=base_branch, + reviewer=reviewer, + app_id=app_id, + installation_id=installation_id, + private_key_path=private_key_path, + ) + + +def main() -> None: + config = load_config() + token = installation_token(config) + + git_setup() + branch = create_branch() + + ruff_fix() + black_fix() + pyright_scan() + + push(token=token, repo=config.repo, branch=branch) + + pr = create_pr(token=token, repo=config.repo, branch=branch, base_branch=config.base_branch) + pr_number = pr.get("number") + if not isinstance(pr_number, int): + raise RuntimeError("Unexpected PR response (missing number)") + + add_reviewer(token=token, repo=config.repo, pr_number=pr_number, reviewer=config.reviewer) + + ruff = ruff_errors() + pyright = pyright_errors() + + comment_pr(token=token, repo=config.repo, pr_number=pr_number, text=summarize(ruff, pyright)) + + if pyright: + review_comments(token=token, repo=config.repo, pr_number=pr_number, errors=pyright) + else: + auto_merge(token=token, repo=config.repo, pr_number=pr_number) + + +if __name__ == "__main__": + main() From a2fe9f542e9fd51bc08612262943b3de8ee41bdc Mon Sep 17 00:00:00 2001 From: "reekeer[bot]" Date: Mon, 16 Mar 2026 17:38:21 +0300 Subject: [PATCH 27/36] fix: some bug and add more checks --- scripts/reekeerBot.py | 49 +++++++++++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/scripts/reekeerBot.py b/scripts/reekeerBot.py index 5345287..8449c3c 100644 --- a/scripts/reekeerBot.py +++ b/scripts/reekeerBot.py @@ -230,6 +230,22 @@ def pr_head_sha(*, token: str, repo: str, pr_number: int) -> str: return sha +def pr_changed_files(*, token: str, repo: str, pr_number: int) -> set[str]: + url = f"{API}/repos/{repo}/pulls/{pr_number}/files?per_page=100" + files: set[str] = set() + + page = gh_request(token, "GET", url) + if not isinstance(page, list): + return files + for item in cast(list[Any], page): + if not isinstance(item, dict): + continue + filename = cast(dict[str, Any], item).get("filename") + if isinstance(filename, str) and filename: + files.add(filename) + return files + + def _diagnostic_repo_path(file_value: Any) -> str | None: if not isinstance(file_value, str) or not file_value: return None @@ -246,6 +262,7 @@ def _diagnostic_repo_path(file_value: Any) -> str | None: def review_comments(*, token: str, repo: str, pr_number: int, errors: list[dict[str, Any]]) -> None: url = f"{API}/repos/{repo}/pulls/{pr_number}/comments" commit_sha = pr_head_sha(token=token, repo=repo, pr_number=pr_number) + changed_files = pr_changed_files(token=token, repo=repo, pr_number=pr_number) posted = 0 for e in errors: @@ -258,22 +275,28 @@ def review_comments(*, token: str, repo: str, pr_number: int, errors: list[dict[ continue if repo_path is None: continue + if changed_files and repo_path not in changed_files: + continue if not isinstance(start_line, int) or start_line < 0: continue - gh_request( - token, - "POST", - url, - data={ - "body": f"Pyright error:\n{message}", - "commit_id": commit_sha, - "path": repo_path, - "line": start_line + 1, # Pyright is 0-based; GitHub is 1-based. - "side": "RIGHT", - }, - ) - posted += 1 + try: + gh_request( + token, + "POST", + url, + data={ + "body": f"Pyright error:\n{message}", + "commit_id": commit_sha, + "path": repo_path, + "line": start_line + 1, + "side": "RIGHT", + }, + ) + except RuntimeError: + continue + else: + posted += 1 def auto_merge(*, token: str, repo: str, pr_number: int) -> None: From 6224daa43ae8b3fea3a5a9947c441795351b7654 Mon Sep 17 00:00:00 2001 From: "reekeer[bot]" Date: Mon, 16 Mar 2026 17:40:34 +0300 Subject: [PATCH 28/36] style(ruff): auto-fix lint issues --- .gitignore | 90 +-- scripts/reekeerBot.py | 784 +++++++++++------------ src/sigdb/core/reader.py | 720 ++++++++++----------- src/sigdb/crypto/ed25519.py | 102 +-- src/sigdb/format/trie.py | 1068 ++++++++++++++++---------------- src/sigdb/types/__init__.py | 102 +-- tests/test_container_errors.py | 182 +++--- tests/test_crypto.py | 168 ++--- tests/test_varint.py | 172 ++--- 9 files changed, 1694 insertions(+), 1694 deletions(-) diff --git a/.gitignore b/.gitignore index aabcd3a..24b47bd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,46 +1,46 @@ -# Python bytecode / caches -__pycache__/ -*.py[cod] -*.pyd - -# Virtual environments -.venv/ -venv/ -.env/ - -# Enviroments -*.env -*.pem - -# Packaging / build outputs -build/ -dist/ -*.egg-info/ -pip-wheel-metadata/ - -# Hatch -.hatch/ - -# Test / coverage -.pytest_cache/ -.coverage -.coverage.* -htmlcov/ -coverage.xml - -# Type checking / linting caches -.mypy_cache/ -.pyright/ -.ruff_cache/ - -# IDEs / editors -.idea/ -.vscode/ -*.iml - -# OS files -.DS_Store -Thumbs.db - -# SignatureDB files +# Python bytecode / caches +__pycache__/ +*.py[cod] +*.pyd + +# Virtual environments +.venv/ +venv/ +.env/ + +# Enviroments +*.env +*.pem + +# Packaging / build outputs +build/ +dist/ +*.egg-info/ +pip-wheel-metadata/ + +# Hatch +.hatch/ + +# Test / coverage +.pytest_cache/ +.coverage +.coverage.* +htmlcov/ +coverage.xml + +# Type checking / linting caches +.mypy_cache/ +.pyright/ +.ruff_cache/ + +# IDEs / editors +.idea/ +.vscode/ +*.iml + +# OS files +.DS_Store +Thumbs.db + +# SignatureDB files *.sigdb \ No newline at end of file diff --git a/scripts/reekeerBot.py b/scripts/reekeerBot.py index 8449c3c..ce4f2e9 100644 --- a/scripts/reekeerBot.py +++ b/scripts/reekeerBot.py @@ -1,392 +1,392 @@ -import datetime -import json -import os -import subprocess -import time -from dataclasses import dataclass -from pathlib import Path -from typing import Any, cast - -import jwt -import requests -from dotenv import load_dotenv - -load_dotenv() - -API = "https://api.github.com" -REPO_ROOT = Path(__file__).resolve().parent.parent - - -@dataclass(frozen=True) -class Config: - repo: str - base_branch: str - reviewer: str | None - app_id: int - installation_id: int - private_key_path: Path - - -def _require_env(name: str) -> str: - value = os.getenv(name) - if not value: - raise RuntimeError(f"Missing required env var: {name}") - return value - - -def run_cmd( - args: list[str], - *, - capture: bool = False, - check: bool = True, -) -> str | None: - result = subprocess.run(args, capture_output=capture, text=True) - if check and result.returncode != 0: - stderr = (result.stderr or "").strip() - raise RuntimeError(f"Command failed ({result.returncode}): {args}\n{stderr}") - return result.stdout.strip() if capture else None - - -def run_shell( - cmd: str, - *, - capture: bool = False, - check: bool = True, -) -> str | None: - result = subprocess.run(cmd, shell=True, capture_output=capture, text=True) - if check and result.returncode != 0: - stderr = (result.stderr or "").strip() - raise RuntimeError(f"Command failed ({result.returncode}): {cmd}\n{stderr}") - return result.stdout.strip() if capture else None - - -def git(args: list[str]) -> None: - run_cmd(["git", *args], check=True) - - -def load_json(path: Path) -> Any | None: - if not path.exists(): - return None - try: - with path.open(encoding="utf-8") as f: - return json.load(f) - except (OSError, ValueError): - return None - - -def create_jwt(*, app_id: int, private_key_path: Path) -> str: - private_key = private_key_path.read_text(encoding="utf-8") - payload = { - "iat": int(time.time()) - 60, - "exp": int(time.time()) + 600, - # PyJWT expects "iss" to be a string. - "iss": str(app_id), - } - return jwt.encode(payload, private_key, algorithm="RS256") # pyright: ignore[reportUnknownMemberType] - - -def installation_token(config: Config) -> str: - jwt_token = create_jwt(app_id=config.app_id, private_key_path=config.private_key_path) - headers = { - "Authorization": f"Bearer {jwt_token}", - "Accept": "application/vnd.github+json", - } - url = f"{API}/app/installations/{config.installation_id}/access_tokens" - r = requests.post(url, headers=headers, timeout=30) - r.raise_for_status() - token = r.json().get("token") - if not isinstance(token, str) or not token: - raise RuntimeError("GitHub installation token response is missing a token") - return token - - -def gh_headers(token: str) -> dict[str, str]: - return { - "Authorization": f"Bearer {token}", - "Accept": "application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28", - } - - -def gh_request(token: str, method: str, url: str, *, data: dict[str, Any] | None = None) -> Any: - r = requests.request(method, url, headers=gh_headers(token), json=data, timeout=30) - try: - r.raise_for_status() - except requests.HTTPError as exc: - body = (r.text or "").strip() - raise RuntimeError(f"GitHub API error {r.status_code} for {method} {url}: {body}") from exc - if r.status_code == 204: - return None - return r.json() - - -def git_setup() -> None: - git(["config", "user.name", "reekeer[bot]"]) - git(["config", "user.email", "reekeer[bot]@users.noreply.github.com"]) - - -def create_branch() -> str: - ts = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d%H%M%S") - branch = f"bot/code-quality-{ts}" - git(["checkout", "-b", branch]) - return branch - - -def commit_if_changes(message: str) -> bool: - git(["add", "-A"]) - changed = subprocess.run(["git", "diff", "--cached", "--quiet"]) - if changed.returncode != 0: - git(["commit", "-m", message]) - return True - return False - - -def push(*, token: str, repo: str, branch: str) -> None: - git(["push", f"https://x-access-token:{token}@github.com/{repo}.git", branch]) - - -def ruff_fix() -> None: - run_shell("ruff check . --fix", check=False) - commit_if_changes("style(ruff): auto-fix lint issues") - run_shell("ruff check . --output-format=json > ruff.json || true", check=False) - - -def black_fix() -> None: - run_shell("black .", check=False) - commit_if_changes("style(black): format code") - - -def pyright_scan() -> None: - run_shell("pyright --outputjson > pyright.json || true", check=False) - - -def _as_dict_list(value: Any) -> list[dict[str, Any]]: - if not isinstance(value, list): - return [] - out: list[dict[str, Any]] = [] - for item in cast(list[Any], value): - if isinstance(item, dict): - out.append(cast(dict[str, Any], item)) - return out - - -def ruff_errors() -> list[dict[str, Any]]: - data = load_json(Path("ruff.json")) - return _as_dict_list(data) - - -def pyright_errors() -> list[dict[str, Any]]: - data = load_json(Path("pyright.json")) - if not isinstance(data, dict): - return [] - data_d = cast(dict[str, Any], data) - diags = data_d.get("generalDiagnostics") - return _as_dict_list(diags) - - -def create_pr(*, token: str, repo: str, branch: str, base_branch: str) -> dict[str, Any]: - url = f"{API}/repos/{repo}/pulls" - pr = gh_request( - token, - "POST", - url, - data={ - "title": "chore: automated code quality fixes", - "head": branch, - "base": base_branch, - "body": "Automated fixes from reekeerBot", - }, - ) - if not isinstance(pr, dict): - raise RuntimeError("Unexpected response from create PR") - return cast(dict[str, Any], pr) - - -def add_reviewer(*, token: str, repo: str, pr_number: int, reviewer: str | None) -> None: - if not reviewer: - return - url = f"{API}/repos/{repo}/pulls/{pr_number}/requested_reviewers" - gh_request(token, "POST", url, data={"reviewers": [reviewer]}) - - -def comment_pr(*, token: str, repo: str, pr_number: int, text: str) -> None: - url = f"{API}/repos/{repo}/issues/{pr_number}/comments" - gh_request(token, "POST", url, data={"body": text}) - - -def pr_head_sha(*, token: str, repo: str, pr_number: int) -> str: - url = f"{API}/repos/{repo}/pulls/{pr_number}" - pr = gh_request(token, "GET", url) - if not isinstance(pr, dict): - raise RuntimeError("Unexpected response from PR details") - pr_d = cast(dict[str, Any], pr) - head = pr_d.get("head") - if not isinstance(head, dict): - raise RuntimeError("Unexpected response from PR details (missing head)") - head_d = cast(dict[str, Any], head) - sha = head_d.get("sha") - if not isinstance(sha, str) or not sha: - raise RuntimeError("Could not read PR head SHA") - return sha - - -def pr_changed_files(*, token: str, repo: str, pr_number: int) -> set[str]: - url = f"{API}/repos/{repo}/pulls/{pr_number}/files?per_page=100" - files: set[str] = set() - - page = gh_request(token, "GET", url) - if not isinstance(page, list): - return files - for item in cast(list[Any], page): - if not isinstance(item, dict): - continue - filename = cast(dict[str, Any], item).get("filename") - if isinstance(filename, str) and filename: - files.add(filename) - return files - - -def _diagnostic_repo_path(file_value: Any) -> str | None: - if not isinstance(file_value, str) or not file_value: - return None - p = Path(file_value) - if not p.is_absolute(): - p = (REPO_ROOT / p).resolve() - try: - rel = p.relative_to(REPO_ROOT) - except ValueError: - return None - return rel.as_posix() - - -def review_comments(*, token: str, repo: str, pr_number: int, errors: list[dict[str, Any]]) -> None: - url = f"{API}/repos/{repo}/pulls/{pr_number}/comments" - commit_sha = pr_head_sha(token=token, repo=repo, pr_number=pr_number) - changed_files = pr_changed_files(token=token, repo=repo, pr_number=pr_number) - - posted = 0 - for e in errors: - if posted >= 30: - break - message = e.get("message") - repo_path = _diagnostic_repo_path(e.get("file")) - start_line = e.get("range", {}).get("start", {}).get("line") - if not isinstance(message, str) or not message: - continue - if repo_path is None: - continue - if changed_files and repo_path not in changed_files: - continue - if not isinstance(start_line, int) or start_line < 0: - continue - - try: - gh_request( - token, - "POST", - url, - data={ - "body": f"Pyright error:\n{message}", - "commit_id": commit_sha, - "path": repo_path, - "line": start_line + 1, - "side": "RIGHT", - }, - ) - except RuntimeError: - continue - else: - posted += 1 - - -def auto_merge(*, token: str, repo: str, pr_number: int) -> None: - url = f"{API}/repos/{repo}/pulls/{pr_number}/merge" - gh_request(token, "PUT", url) - - -def summarize(ruff: list[dict[str, Any]], pyright: list[dict[str, Any]]) -> str: - msg = "# reekeer[bot] Report\n\n" - - if ruff: - msg += "## Ruff issues\n" - for r in ruff[:15]: - filename = r.get("filename") - loc = r.get("location") - row = cast(dict[str, Any], loc).get("row") if isinstance(loc, dict) else None - message = r.get("message") - if isinstance(filename, str) and isinstance(row, int) and isinstance(message, str): - msg += f"- {filename}:{row} {message}\n" - - if pyright: - msg += "\n## Pyright issues\n" - for p in pyright[:15]: - file_value = p.get("file") - message = p.get("message") - if isinstance(file_value, str) and isinstance(message, str): - msg += f"- {file_value} {message}\n" - - if not ruff and not pyright: - msg += "✅ No issues detected" - - return msg - - -def load_config() -> Config: - repo = _require_env("GITHUB_REPOSITORY") - base_branch = os.getenv("BASE_BRANCH", "dev") - reviewer = os.getenv("REVIEWER") or None - app_id = int(_require_env("APP_ID")) - installation_id = int(_require_env("INSTALLATION_ID")) - - private_key_env = os.getenv("PRIVATE_KEY_PATH") - private_key_path = ( - Path(private_key_env).expanduser() if private_key_env else (REPO_ROOT / "private-key.pem") - ) - if not private_key_path.is_absolute(): - private_key_path = (REPO_ROOT / private_key_path).resolve() - if not private_key_path.is_file(): - raise RuntimeError(f"Private key not found: {private_key_path}") - - return Config( - repo=repo, - base_branch=base_branch, - reviewer=reviewer, - app_id=app_id, - installation_id=installation_id, - private_key_path=private_key_path, - ) - - -def main() -> None: - config = load_config() - token = installation_token(config) - - git_setup() - branch = create_branch() - - ruff_fix() - black_fix() - pyright_scan() - - push(token=token, repo=config.repo, branch=branch) - - pr = create_pr(token=token, repo=config.repo, branch=branch, base_branch=config.base_branch) - pr_number = pr.get("number") - if not isinstance(pr_number, int): - raise RuntimeError("Unexpected PR response (missing number)") - - add_reviewer(token=token, repo=config.repo, pr_number=pr_number, reviewer=config.reviewer) - - ruff = ruff_errors() - pyright = pyright_errors() - - comment_pr(token=token, repo=config.repo, pr_number=pr_number, text=summarize(ruff, pyright)) - - if pyright: - review_comments(token=token, repo=config.repo, pr_number=pr_number, errors=pyright) - else: - auto_merge(token=token, repo=config.repo, pr_number=pr_number) - - -if __name__ == "__main__": - main() +import datetime +import json +import os +import subprocess +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, cast + +import jwt +import requests +from dotenv import load_dotenv + +load_dotenv() + +API = "https://api.github.com" +REPO_ROOT = Path(__file__).resolve().parent.parent + + +@dataclass(frozen=True) +class Config: + repo: str + base_branch: str + reviewer: str | None + app_id: int + installation_id: int + private_key_path: Path + + +def _require_env(name: str) -> str: + value = os.getenv(name) + if not value: + raise RuntimeError(f"Missing required env var: {name}") + return value + + +def run_cmd( + args: list[str], + *, + capture: bool = False, + check: bool = True, +) -> str | None: + result = subprocess.run(args, capture_output=capture, text=True) + if check and result.returncode != 0: + stderr = (result.stderr or "").strip() + raise RuntimeError(f"Command failed ({result.returncode}): {args}\n{stderr}") + return result.stdout.strip() if capture else None + + +def run_shell( + cmd: str, + *, + capture: bool = False, + check: bool = True, +) -> str | None: + result = subprocess.run(cmd, shell=True, capture_output=capture, text=True) + if check and result.returncode != 0: + stderr = (result.stderr or "").strip() + raise RuntimeError(f"Command failed ({result.returncode}): {cmd}\n{stderr}") + return result.stdout.strip() if capture else None + + +def git(args: list[str]) -> None: + run_cmd(["git", *args], check=True) + + +def load_json(path: Path) -> Any | None: + if not path.exists(): + return None + try: + with path.open(encoding="utf-8") as f: + return json.load(f) + except (OSError, ValueError): + return None + + +def create_jwt(*, app_id: int, private_key_path: Path) -> str: + private_key = private_key_path.read_text(encoding="utf-8") + payload = { + "iat": int(time.time()) - 60, + "exp": int(time.time()) + 600, + # PyJWT expects "iss" to be a string. + "iss": str(app_id), + } + return jwt.encode(payload, private_key, algorithm="RS256") # pyright: ignore[reportUnknownMemberType] + + +def installation_token(config: Config) -> str: + jwt_token = create_jwt(app_id=config.app_id, private_key_path=config.private_key_path) + headers = { + "Authorization": f"Bearer {jwt_token}", + "Accept": "application/vnd.github+json", + } + url = f"{API}/app/installations/{config.installation_id}/access_tokens" + r = requests.post(url, headers=headers, timeout=30) + r.raise_for_status() + token = r.json().get("token") + if not isinstance(token, str) or not token: + raise RuntimeError("GitHub installation token response is missing a token") + return token + + +def gh_headers(token: str) -> dict[str, str]: + return { + "Authorization": f"Bearer {token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + } + + +def gh_request(token: str, method: str, url: str, *, data: dict[str, Any] | None = None) -> Any: + r = requests.request(method, url, headers=gh_headers(token), json=data, timeout=30) + try: + r.raise_for_status() + except requests.HTTPError as exc: + body = (r.text or "").strip() + raise RuntimeError(f"GitHub API error {r.status_code} for {method} {url}: {body}") from exc + if r.status_code == 204: + return None + return r.json() + + +def git_setup() -> None: + git(["config", "user.name", "reekeer[bot]"]) + git(["config", "user.email", "reekeer[bot]@users.noreply.github.com"]) + + +def create_branch() -> str: + ts = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d%H%M%S") + branch = f"bot/code-quality-{ts}" + git(["checkout", "-b", branch]) + return branch + + +def commit_if_changes(message: str) -> bool: + git(["add", "-A"]) + changed = subprocess.run(["git", "diff", "--cached", "--quiet"]) + if changed.returncode != 0: + git(["commit", "-m", message]) + return True + return False + + +def push(*, token: str, repo: str, branch: str) -> None: + git(["push", f"https://x-access-token:{token}@github.com/{repo}.git", branch]) + + +def ruff_fix() -> None: + run_shell("ruff check . --fix", check=False) + commit_if_changes("style(ruff): auto-fix lint issues") + run_shell("ruff check . --output-format=json > ruff.json || true", check=False) + + +def black_fix() -> None: + run_shell("black .", check=False) + commit_if_changes("style(black): format code") + + +def pyright_scan() -> None: + run_shell("pyright --outputjson > pyright.json || true", check=False) + + +def _as_dict_list(value: Any) -> list[dict[str, Any]]: + if not isinstance(value, list): + return [] + out: list[dict[str, Any]] = [] + for item in cast(list[Any], value): + if isinstance(item, dict): + out.append(cast(dict[str, Any], item)) + return out + + +def ruff_errors() -> list[dict[str, Any]]: + data = load_json(Path("ruff.json")) + return _as_dict_list(data) + + +def pyright_errors() -> list[dict[str, Any]]: + data = load_json(Path("pyright.json")) + if not isinstance(data, dict): + return [] + data_d = cast(dict[str, Any], data) + diags = data_d.get("generalDiagnostics") + return _as_dict_list(diags) + + +def create_pr(*, token: str, repo: str, branch: str, base_branch: str) -> dict[str, Any]: + url = f"{API}/repos/{repo}/pulls" + pr = gh_request( + token, + "POST", + url, + data={ + "title": "chore: automated code quality fixes", + "head": branch, + "base": base_branch, + "body": "Automated fixes from reekeerBot", + }, + ) + if not isinstance(pr, dict): + raise RuntimeError("Unexpected response from create PR") + return cast(dict[str, Any], pr) + + +def add_reviewer(*, token: str, repo: str, pr_number: int, reviewer: str | None) -> None: + if not reviewer: + return + url = f"{API}/repos/{repo}/pulls/{pr_number}/requested_reviewers" + gh_request(token, "POST", url, data={"reviewers": [reviewer]}) + + +def comment_pr(*, token: str, repo: str, pr_number: int, text: str) -> None: + url = f"{API}/repos/{repo}/issues/{pr_number}/comments" + gh_request(token, "POST", url, data={"body": text}) + + +def pr_head_sha(*, token: str, repo: str, pr_number: int) -> str: + url = f"{API}/repos/{repo}/pulls/{pr_number}" + pr = gh_request(token, "GET", url) + if not isinstance(pr, dict): + raise RuntimeError("Unexpected response from PR details") + pr_d = cast(dict[str, Any], pr) + head = pr_d.get("head") + if not isinstance(head, dict): + raise RuntimeError("Unexpected response from PR details (missing head)") + head_d = cast(dict[str, Any], head) + sha = head_d.get("sha") + if not isinstance(sha, str) or not sha: + raise RuntimeError("Could not read PR head SHA") + return sha + + +def pr_changed_files(*, token: str, repo: str, pr_number: int) -> set[str]: + url = f"{API}/repos/{repo}/pulls/{pr_number}/files?per_page=100" + files: set[str] = set() + + page = gh_request(token, "GET", url) + if not isinstance(page, list): + return files + for item in cast(list[Any], page): + if not isinstance(item, dict): + continue + filename = cast(dict[str, Any], item).get("filename") + if isinstance(filename, str) and filename: + files.add(filename) + return files + + +def _diagnostic_repo_path(file_value: Any) -> str | None: + if not isinstance(file_value, str) or not file_value: + return None + p = Path(file_value) + if not p.is_absolute(): + p = (REPO_ROOT / p).resolve() + try: + rel = p.relative_to(REPO_ROOT) + except ValueError: + return None + return rel.as_posix() + + +def review_comments(*, token: str, repo: str, pr_number: int, errors: list[dict[str, Any]]) -> None: + url = f"{API}/repos/{repo}/pulls/{pr_number}/comments" + commit_sha = pr_head_sha(token=token, repo=repo, pr_number=pr_number) + changed_files = pr_changed_files(token=token, repo=repo, pr_number=pr_number) + + posted = 0 + for e in errors: + if posted >= 30: + break + message = e.get("message") + repo_path = _diagnostic_repo_path(e.get("file")) + start_line = e.get("range", {}).get("start", {}).get("line") + if not isinstance(message, str) or not message: + continue + if repo_path is None: + continue + if changed_files and repo_path not in changed_files: + continue + if not isinstance(start_line, int) or start_line < 0: + continue + + try: + gh_request( + token, + "POST", + url, + data={ + "body": f"Pyright error:\n{message}", + "commit_id": commit_sha, + "path": repo_path, + "line": start_line + 1, + "side": "RIGHT", + }, + ) + except RuntimeError: + continue + else: + posted += 1 + + +def auto_merge(*, token: str, repo: str, pr_number: int) -> None: + url = f"{API}/repos/{repo}/pulls/{pr_number}/merge" + gh_request(token, "PUT", url) + + +def summarize(ruff: list[dict[str, Any]], pyright: list[dict[str, Any]]) -> str: + msg = "# reekeer[bot] Report\n\n" + + if ruff: + msg += "## Ruff issues\n" + for r in ruff[:15]: + filename = r.get("filename") + loc = r.get("location") + row = cast(dict[str, Any], loc).get("row") if isinstance(loc, dict) else None + message = r.get("message") + if isinstance(filename, str) and isinstance(row, int) and isinstance(message, str): + msg += f"- {filename}:{row} {message}\n" + + if pyright: + msg += "\n## Pyright issues\n" + for p in pyright[:15]: + file_value = p.get("file") + message = p.get("message") + if isinstance(file_value, str) and isinstance(message, str): + msg += f"- {file_value} {message}\n" + + if not ruff and not pyright: + msg += "✅ No issues detected" + + return msg + + +def load_config() -> Config: + repo = _require_env("GITHUB_REPOSITORY") + base_branch = os.getenv("BASE_BRANCH", "dev") + reviewer = os.getenv("REVIEWER") or None + app_id = int(_require_env("APP_ID")) + installation_id = int(_require_env("INSTALLATION_ID")) + + private_key_env = os.getenv("PRIVATE_KEY_PATH") + private_key_path = ( + Path(private_key_env).expanduser() if private_key_env else (REPO_ROOT / "private-key.pem") + ) + if not private_key_path.is_absolute(): + private_key_path = (REPO_ROOT / private_key_path).resolve() + if not private_key_path.is_file(): + raise RuntimeError(f"Private key not found: {private_key_path}") + + return Config( + repo=repo, + base_branch=base_branch, + reviewer=reviewer, + app_id=app_id, + installation_id=installation_id, + private_key_path=private_key_path, + ) + + +def main() -> None: + config = load_config() + token = installation_token(config) + + git_setup() + branch = create_branch() + + ruff_fix() + black_fix() + pyright_scan() + + push(token=token, repo=config.repo, branch=branch) + + pr = create_pr(token=token, repo=config.repo, branch=branch, base_branch=config.base_branch) + pr_number = pr.get("number") + if not isinstance(pr_number, int): + raise RuntimeError("Unexpected PR response (missing number)") + + add_reviewer(token=token, repo=config.repo, pr_number=pr_number, reviewer=config.reviewer) + + ruff = ruff_errors() + pyright = pyright_errors() + + comment_pr(token=token, repo=config.repo, pr_number=pr_number, text=summarize(ruff, pyright)) + + if pyright: + review_comments(token=token, repo=config.repo, pr_number=pr_number, errors=pyright) + else: + auto_merge(token=token, repo=config.repo, pr_number=pr_number) + + +if __name__ == "__main__": + main() diff --git a/src/sigdb/core/reader.py b/src/sigdb/core/reader.py index 3afbb4d..f86464e 100644 --- a/src/sigdb/core/reader.py +++ b/src/sigdb/core/reader.py @@ -1,360 +1,360 @@ -from __future__ import annotations - -from pathlib import Path -from typing import Any, overload - -from sigdb.format.trie import load_sigdb, read_sigdb_metadata, validate_sigdb -from sigdb.internal.groups import ( - SIGDB_GROUPS, - SIGDB_GROUPS_MAP, - format_list_pattern, - format_map_pattern, - html_heads, - parse_group_list, - parse_string_map, -) -from sigdb.types import ( - SigDBDatabase, - SigDBFormatError, - SigDBGroupName, - SigDBMatchResult, - SigDBSearchDefinition, - SigDBValidationResult, -) - - -def _normalize_head(head: str) -> str: - s = head.strip() - i = s.find(":") - if i == -1: - return s.lower() - name = s[:i].strip().lower() - value = s[i + 1 :].strip().lower() - return f"{name}:{value}" - - -def _iter_search_heads(search: SigDBSearchDefinition) -> list[str]: - heads: list[str] = [] - headers = parse_string_map(search.get("headers"), "headers") - for header_name, header_value in headers.items(): - heads.append(format_map_pattern("headers", header_name, header_value)) - - for group in SIGDB_GROUPS: - if group == "headers": - continue - if group in SIGDB_GROUPS_MAP: - group_map = parse_string_map(search.get(group), group) - for name, value in group_map.items(): - heads.append(format_map_pattern(group, name, value)) - else: - group_values = parse_group_list(search.get(group), group) - for value in group_values: - heads.append(format_list_pattern(group, value)) - return heads - - -class SigDBMatcher: - __slots__ = ("_automaton", "_items") - - def __init__(self, db: SigDBDatabase) -> None: - self._automaton = db.automaton - self._items = db.items - - def match(self, head: str) -> SigDBMatchResult: - normalized = _normalize_head(head) - data = normalized.encode("utf-8") - - a = self._automaton - labels = a.labels - children_start = a.children_start - children_count = a.children_count - next_state = a.next_state - fail = a.fail - out_start = a.out_start - out_count = a.out_count - outputs = a.outputs - - state = 0 - for b in data: - while True: - start = children_start[state] - count = children_count[state] - nxt = -1 - if count: - lo = 0 - hi = count - while lo < hi: - mid = (lo + hi) >> 1 - lb = labels[start + mid] - if lb < b: - lo = mid + 1 - elif lb > b: - hi = mid - else: - nxt = next_state[start + mid] - break - - if nxt != -1: - state = nxt - break - if state == 0: - break - state = fail[state] - - oc = out_count[state] - if oc: - ostart = out_start[state] - item_id = outputs[ostart] - item = self._items[item_id] - return SigDBMatchResult( - result=True, item_id=item_id, item=item, head=normalized - ) - - return SigDBMatchResult(result=False, item_id=None, item=None, head=normalized) - - def match_group( - self, - group: SigDBGroupName, - value: str, - *, - name: str | None = None, - ) -> SigDBMatchResult: - if group not in SIGDB_GROUPS: - raise SigDBFormatError(f"unknown group: {group}") - if name is None: - if group in SIGDB_GROUPS_MAP: - raise SigDBFormatError(f"group {group} requires a name") - head = format_list_pattern(group, value) - else: - if group not in SIGDB_GROUPS_MAP: - raise SigDBFormatError(f"group {group} does not accept a name") - head = format_map_pattern(group, name, value) - return self.match(head) - - def match_search(self, search: SigDBSearchDefinition) -> SigDBMatchResult: - for head in _iter_search_heads(search): - result = self.match(head) - if result.result: - return result - return SigDBMatchResult(result=False, item_id=None, item=None, head="") - - def match_html(self, html: str) -> SigDBMatchResult: - for head in html_heads(html): - result = self.match(head) - if result.result: - return result - return SigDBMatchResult(result=False, item_id=None, item=None, head="") - - -@overload -def match(head: str, src: SigDBMatcher) -> SigDBMatchResult: ... - - -@overload -def match(head: str, src: SigDBDatabase) -> SigDBMatchResult: ... - - -@overload -def match(head: str, src: SigDBReader) -> SigDBMatchResult: ... - - -def match(head: str, src: object) -> SigDBMatchResult: - if isinstance(src, SigDBMatcher): - return src.match(head) - if isinstance(src, SigDBDatabase): - return SigDBMatcher(src).match(head) - if isinstance(src, SigDBReader): - return src.matcher().match(head) - raise TypeError("src must be SigDBReader, SigDBDatabase, or SigDBMatcher") - - -@overload -def match_group( - group: SigDBGroupName, - value: str, - src: SigDBMatcher, - *, - name: str | None = None, -) -> SigDBMatchResult: ... - - -@overload -def match_group( - group: SigDBGroupName, - value: str, - src: SigDBDatabase, - *, - name: str | None = None, -) -> SigDBMatchResult: ... - - -@overload -def match_group( - group: SigDBGroupName, - value: str, - src: SigDBReader, - *, - name: str | None = None, -) -> SigDBMatchResult: ... - - -def match_group( - group: SigDBGroupName, - value: str, - src: object, - *, - name: str | None = None, -) -> SigDBMatchResult: - if isinstance(src, SigDBMatcher): - return src.match_group(group, value, name=name) - if isinstance(src, SigDBDatabase): - return SigDBMatcher(src).match_group(group, value, name=name) - if isinstance(src, SigDBReader): - return src.matcher().match_group(group, value, name=name) - raise TypeError("src must be SigDBReader, SigDBDatabase, or SigDBMatcher") - - -@overload -def match_search( - search: SigDBSearchDefinition, - src: SigDBMatcher, -) -> SigDBMatchResult: ... - - -@overload -def match_search( - search: SigDBSearchDefinition, - src: SigDBDatabase, -) -> SigDBMatchResult: ... - - -@overload -def match_search( - search: SigDBSearchDefinition, - src: SigDBReader, -) -> SigDBMatchResult: ... - - -def match_search(search: SigDBSearchDefinition, src: object) -> SigDBMatchResult: - if isinstance(src, SigDBMatcher): - return src.match_search(search) - if isinstance(src, SigDBDatabase): - return SigDBMatcher(src).match_search(search) - if isinstance(src, SigDBReader): - return src.matcher().match_search(search) - raise TypeError("src must be SigDBReader, SigDBDatabase, or SigDBMatcher") - - -@overload -def match_html(html: str, src: SigDBMatcher) -> SigDBMatchResult: ... - - -@overload -def match_html(html: str, src: SigDBDatabase) -> SigDBMatchResult: ... - - -@overload -def match_html(html: str, src: SigDBReader) -> SigDBMatchResult: ... - - -def match_html(html: str, src: object) -> SigDBMatchResult: - if isinstance(src, SigDBMatcher): - return src.match_html(html) - if isinstance(src, SigDBDatabase): - return SigDBMatcher(src).match_html(html) - if isinstance(src, SigDBReader): - return src.matcher().match_html(html) - raise TypeError("src must be SigDBReader, SigDBDatabase, or SigDBMatcher") - - -class SigDBReader: - def __init__(self, path: str | Path) -> None: - self._path = Path(path) - self._cache_db: SigDBDatabase | None = None - self._cache_params: tuple[str | None, bool, bool] | None = None - - @property - def path(self) -> Path: - return self._path - - def metadata(self) -> dict[str, Any]: - return read_sigdb_metadata(self._path) - - def validate( - self, - *, - public_key_hex: str | None = None, - verify_hash: bool = True, - verify_signature: bool = True, - ) -> SigDBValidationResult: - return validate_sigdb( - self._path, - public_key_hex=public_key_hex, - verify_hash=verify_hash, - verify_signature=verify_signature, - ) - - def load( - self, - *, - public_key_hex: str | None = None, - verify_hash: bool = True, - verify_signature: bool = True, - ) -> SigDBDatabase: - return load_sigdb( - self._path, - public_key_hex=public_key_hex, - verify_hash=verify_hash, - verify_signature=verify_signature, - ) - - def load_cached( - self, - *, - public_key_hex: str | None = None, - verify_hash: bool = True, - verify_signature: bool = True, - ) -> SigDBDatabase: - params = (public_key_hex, verify_hash, verify_signature) - if self._cache_db is not None and self._cache_params == params: - return self._cache_db - self._cache_db = self.load( - public_key_hex=public_key_hex, - verify_hash=verify_hash, - verify_signature=verify_signature, - ) - self._cache_params = params - return self._cache_db - - def matcher( - self, - *, - public_key_hex: str | None = None, - verify_hash: bool = True, - verify_signature: bool = True, - ) -> SigDBMatcher: - return SigDBMatcher( - self.load_cached( - public_key_hex=public_key_hex, - verify_hash=verify_hash, - verify_signature=verify_signature, - ) - ) - - def match(self, head: str) -> SigDBMatchResult: - return self.matcher().match(head) - - def match_group( - self, - group: SigDBGroupName, - value: str, - *, - name: str | None = None, - ) -> SigDBMatchResult: - return self.matcher().match_group(group, value, name=name) - - def match_search(self, search: SigDBSearchDefinition) -> SigDBMatchResult: - return self.matcher().match_search(search) - - def match_html(self, html: str) -> SigDBMatchResult: - return self.matcher().match_html(html) +from __future__ import annotations + +from pathlib import Path +from typing import Any, overload + +from sigdb.format.trie import load_sigdb, read_sigdb_metadata, validate_sigdb +from sigdb.internal.groups import ( + SIGDB_GROUPS, + SIGDB_GROUPS_MAP, + format_list_pattern, + format_map_pattern, + html_heads, + parse_group_list, + parse_string_map, +) +from sigdb.types import ( + SigDBDatabase, + SigDBFormatError, + SigDBGroupName, + SigDBMatchResult, + SigDBSearchDefinition, + SigDBValidationResult, +) + + +def _normalize_head(head: str) -> str: + s = head.strip() + i = s.find(":") + if i == -1: + return s.lower() + name = s[:i].strip().lower() + value = s[i + 1 :].strip().lower() + return f"{name}:{value}" + + +def _iter_search_heads(search: SigDBSearchDefinition) -> list[str]: + heads: list[str] = [] + headers = parse_string_map(search.get("headers"), "headers") + for header_name, header_value in headers.items(): + heads.append(format_map_pattern("headers", header_name, header_value)) + + for group in SIGDB_GROUPS: + if group == "headers": + continue + if group in SIGDB_GROUPS_MAP: + group_map = parse_string_map(search.get(group), group) + for name, value in group_map.items(): + heads.append(format_map_pattern(group, name, value)) + else: + group_values = parse_group_list(search.get(group), group) + for value in group_values: + heads.append(format_list_pattern(group, value)) + return heads + + +class SigDBMatcher: + __slots__ = ("_automaton", "_items") + + def __init__(self, db: SigDBDatabase) -> None: + self._automaton = db.automaton + self._items = db.items + + def match(self, head: str) -> SigDBMatchResult: + normalized = _normalize_head(head) + data = normalized.encode("utf-8") + + a = self._automaton + labels = a.labels + children_start = a.children_start + children_count = a.children_count + next_state = a.next_state + fail = a.fail + out_start = a.out_start + out_count = a.out_count + outputs = a.outputs + + state = 0 + for b in data: + while True: + start = children_start[state] + count = children_count[state] + nxt = -1 + if count: + lo = 0 + hi = count + while lo < hi: + mid = (lo + hi) >> 1 + lb = labels[start + mid] + if lb < b: + lo = mid + 1 + elif lb > b: + hi = mid + else: + nxt = next_state[start + mid] + break + + if nxt != -1: + state = nxt + break + if state == 0: + break + state = fail[state] + + oc = out_count[state] + if oc: + ostart = out_start[state] + item_id = outputs[ostart] + item = self._items[item_id] + return SigDBMatchResult( + result=True, item_id=item_id, item=item, head=normalized + ) + + return SigDBMatchResult(result=False, item_id=None, item=None, head=normalized) + + def match_group( + self, + group: SigDBGroupName, + value: str, + *, + name: str | None = None, + ) -> SigDBMatchResult: + if group not in SIGDB_GROUPS: + raise SigDBFormatError(f"unknown group: {group}") + if name is None: + if group in SIGDB_GROUPS_MAP: + raise SigDBFormatError(f"group {group} requires a name") + head = format_list_pattern(group, value) + else: + if group not in SIGDB_GROUPS_MAP: + raise SigDBFormatError(f"group {group} does not accept a name") + head = format_map_pattern(group, name, value) + return self.match(head) + + def match_search(self, search: SigDBSearchDefinition) -> SigDBMatchResult: + for head in _iter_search_heads(search): + result = self.match(head) + if result.result: + return result + return SigDBMatchResult(result=False, item_id=None, item=None, head="") + + def match_html(self, html: str) -> SigDBMatchResult: + for head in html_heads(html): + result = self.match(head) + if result.result: + return result + return SigDBMatchResult(result=False, item_id=None, item=None, head="") + + +@overload +def match(head: str, src: SigDBMatcher) -> SigDBMatchResult: ... + + +@overload +def match(head: str, src: SigDBDatabase) -> SigDBMatchResult: ... + + +@overload +def match(head: str, src: SigDBReader) -> SigDBMatchResult: ... + + +def match(head: str, src: object) -> SigDBMatchResult: + if isinstance(src, SigDBMatcher): + return src.match(head) + if isinstance(src, SigDBDatabase): + return SigDBMatcher(src).match(head) + if isinstance(src, SigDBReader): + return src.matcher().match(head) + raise TypeError("src must be SigDBReader, SigDBDatabase, or SigDBMatcher") + + +@overload +def match_group( + group: SigDBGroupName, + value: str, + src: SigDBMatcher, + *, + name: str | None = None, +) -> SigDBMatchResult: ... + + +@overload +def match_group( + group: SigDBGroupName, + value: str, + src: SigDBDatabase, + *, + name: str | None = None, +) -> SigDBMatchResult: ... + + +@overload +def match_group( + group: SigDBGroupName, + value: str, + src: SigDBReader, + *, + name: str | None = None, +) -> SigDBMatchResult: ... + + +def match_group( + group: SigDBGroupName, + value: str, + src: object, + *, + name: str | None = None, +) -> SigDBMatchResult: + if isinstance(src, SigDBMatcher): + return src.match_group(group, value, name=name) + if isinstance(src, SigDBDatabase): + return SigDBMatcher(src).match_group(group, value, name=name) + if isinstance(src, SigDBReader): + return src.matcher().match_group(group, value, name=name) + raise TypeError("src must be SigDBReader, SigDBDatabase, or SigDBMatcher") + + +@overload +def match_search( + search: SigDBSearchDefinition, + src: SigDBMatcher, +) -> SigDBMatchResult: ... + + +@overload +def match_search( + search: SigDBSearchDefinition, + src: SigDBDatabase, +) -> SigDBMatchResult: ... + + +@overload +def match_search( + search: SigDBSearchDefinition, + src: SigDBReader, +) -> SigDBMatchResult: ... + + +def match_search(search: SigDBSearchDefinition, src: object) -> SigDBMatchResult: + if isinstance(src, SigDBMatcher): + return src.match_search(search) + if isinstance(src, SigDBDatabase): + return SigDBMatcher(src).match_search(search) + if isinstance(src, SigDBReader): + return src.matcher().match_search(search) + raise TypeError("src must be SigDBReader, SigDBDatabase, or SigDBMatcher") + + +@overload +def match_html(html: str, src: SigDBMatcher) -> SigDBMatchResult: ... + + +@overload +def match_html(html: str, src: SigDBDatabase) -> SigDBMatchResult: ... + + +@overload +def match_html(html: str, src: SigDBReader) -> SigDBMatchResult: ... + + +def match_html(html: str, src: object) -> SigDBMatchResult: + if isinstance(src, SigDBMatcher): + return src.match_html(html) + if isinstance(src, SigDBDatabase): + return SigDBMatcher(src).match_html(html) + if isinstance(src, SigDBReader): + return src.matcher().match_html(html) + raise TypeError("src must be SigDBReader, SigDBDatabase, or SigDBMatcher") + + +class SigDBReader: + def __init__(self, path: str | Path) -> None: + self._path = Path(path) + self._cache_db: SigDBDatabase | None = None + self._cache_params: tuple[str | None, bool, bool] | None = None + + @property + def path(self) -> Path: + return self._path + + def metadata(self) -> dict[str, Any]: + return read_sigdb_metadata(self._path) + + def validate( + self, + *, + public_key_hex: str | None = None, + verify_hash: bool = True, + verify_signature: bool = True, + ) -> SigDBValidationResult: + return validate_sigdb( + self._path, + public_key_hex=public_key_hex, + verify_hash=verify_hash, + verify_signature=verify_signature, + ) + + def load( + self, + *, + public_key_hex: str | None = None, + verify_hash: bool = True, + verify_signature: bool = True, + ) -> SigDBDatabase: + return load_sigdb( + self._path, + public_key_hex=public_key_hex, + verify_hash=verify_hash, + verify_signature=verify_signature, + ) + + def load_cached( + self, + *, + public_key_hex: str | None = None, + verify_hash: bool = True, + verify_signature: bool = True, + ) -> SigDBDatabase: + params = (public_key_hex, verify_hash, verify_signature) + if self._cache_db is not None and self._cache_params == params: + return self._cache_db + self._cache_db = self.load( + public_key_hex=public_key_hex, + verify_hash=verify_hash, + verify_signature=verify_signature, + ) + self._cache_params = params + return self._cache_db + + def matcher( + self, + *, + public_key_hex: str | None = None, + verify_hash: bool = True, + verify_signature: bool = True, + ) -> SigDBMatcher: + return SigDBMatcher( + self.load_cached( + public_key_hex=public_key_hex, + verify_hash=verify_hash, + verify_signature=verify_signature, + ) + ) + + def match(self, head: str) -> SigDBMatchResult: + return self.matcher().match(head) + + def match_group( + self, + group: SigDBGroupName, + value: str, + *, + name: str | None = None, + ) -> SigDBMatchResult: + return self.matcher().match_group(group, value, name=name) + + def match_search(self, search: SigDBSearchDefinition) -> SigDBMatchResult: + return self.matcher().match_search(search) + + def match_html(self, html: str) -> SigDBMatchResult: + return self.matcher().match_html(html) diff --git a/src/sigdb/crypto/ed25519.py b/src/sigdb/crypto/ed25519.py index 52d3ddf..3bdae08 100644 --- a/src/sigdb/crypto/ed25519.py +++ b/src/sigdb/crypto/ed25519.py @@ -1,51 +1,51 @@ -from __future__ import annotations - -from sigdb.types import SigDBError, SigDBSignatureError - - -def _import_nacl(): - try: - from nacl.exceptions import BadSignatureError # type: ignore[import-not-found] - from nacl.signing import SigningKey, VerifyKey # type: ignore[import-not-found] - except ModuleNotFoundError as e: # pragma: no cover - raise SigDBError("missing dependency: pynacl") from e - return SigningKey, VerifyKey, BadSignatureError - - -def generate_signing_key_hex() -> str: - SigningKey, _VerifyKey, _BadSignatureError = _import_nacl() - return SigningKey.generate().encode().hex() - - -def _parse_signing_key_hex(signing_key_hex: str): - SigningKey, _VerifyKey, _BadSignatureError = _import_nacl() - raw = bytes.fromhex(signing_key_hex) - if len(raw) == 32: - return SigningKey(raw) - if len(raw) == 64: - return SigningKey(raw[:32]) - raise ValueError("signing key must be 32-byte seed or 64-byte private key") - - -def derive_public_key_hex(signing_key_hex: str) -> str: - return _parse_signing_key_hex(signing_key_hex).verify_key.encode().hex() - - -def sign_hash(data_hash: bytes, *, signing_key_hex: str) -> bytes: - key = _parse_signing_key_hex(signing_key_hex) - return key.sign(data_hash).signature - - -def verify_hash_signature( - data_hash: bytes, signature: bytes, *, public_key_hex: str -) -> None: - _SigningKey, VerifyKey, BadSignatureError = _import_nacl() - try: - verify = VerifyKey(bytes.fromhex(public_key_hex)) - except ValueError as e: - raise SigDBSignatureError("invalid public key hex") from e - - try: - verify.verify(data_hash, signature) - except BadSignatureError as e: - raise SigDBSignatureError("invalid signature") from e +from __future__ import annotations + +from sigdb.types import SigDBError, SigDBSignatureError + + +def _import_nacl(): + try: + from nacl.exceptions import BadSignatureError # type: ignore[import-not-found] + from nacl.signing import SigningKey, VerifyKey # type: ignore[import-not-found] + except ModuleNotFoundError as e: # pragma: no cover + raise SigDBError("missing dependency: pynacl") from e + return SigningKey, VerifyKey, BadSignatureError + + +def generate_signing_key_hex() -> str: + SigningKey, _VerifyKey, _BadSignatureError = _import_nacl() + return SigningKey.generate().encode().hex() + + +def _parse_signing_key_hex(signing_key_hex: str): + SigningKey, _VerifyKey, _BadSignatureError = _import_nacl() + raw = bytes.fromhex(signing_key_hex) + if len(raw) == 32: + return SigningKey(raw) + if len(raw) == 64: + return SigningKey(raw[:32]) + raise ValueError("signing key must be 32-byte seed or 64-byte private key") + + +def derive_public_key_hex(signing_key_hex: str) -> str: + return _parse_signing_key_hex(signing_key_hex).verify_key.encode().hex() + + +def sign_hash(data_hash: bytes, *, signing_key_hex: str) -> bytes: + key = _parse_signing_key_hex(signing_key_hex) + return key.sign(data_hash).signature + + +def verify_hash_signature( + data_hash: bytes, signature: bytes, *, public_key_hex: str +) -> None: + _SigningKey, VerifyKey, BadSignatureError = _import_nacl() + try: + verify = VerifyKey(bytes.fromhex(public_key_hex)) + except ValueError as e: + raise SigDBSignatureError("invalid public key hex") from e + + try: + verify.verify(data_hash, signature) + except BadSignatureError as e: + raise SigDBSignatureError("invalid signature") from e diff --git a/src/sigdb/format/trie.py b/src/sigdb/format/trie.py index f1c1228..5be4001 100644 --- a/src/sigdb/format/trie.py +++ b/src/sigdb/format/trie.py @@ -1,534 +1,534 @@ -from __future__ import annotations - -import json -import struct -import time -from collections import deque -from collections.abc import Mapping -from datetime import date -from pathlib import Path -from typing import Any, cast - -from sigdb.compression import compress_zstd, decompress_zstd -from sigdb.crypto import ( - derive_public_key_hex, - generate_signing_key_hex, - sign_hash, - verify_hash_signature, -) -from sigdb.internal.groups import ( - SIGDB_GROUPS, - SIGDB_GROUPS_MAP, - format_list_pattern, - format_map_pattern, - parse_group_list, - parse_string_map, -) -from sigdb.storage import read_exact -from sigdb.types import ( - Automaton, - SigDBBuildResult, - SigDBDatabase, - SigDBFormatError, - SigDBIntegrityError, - SigDBItem, - SigDBRules, - SigDBSignatureError, - SigDBValidationResult, -) -from sigdb.utils.hashing import sha256 -from sigdb.utils.varint import decode_varint, encode_varint - -MAGIC: bytes = b"SIGT" -VERSION: int = 1 - -MAX_HEADER_BYTES: int = 65_536 -SHA256_SIZE: int = 32 -ED25519_SIGNATURE_SIZE: int = 64 - - -def read_sigdb_metadata(path: str | Path) -> dict[str, Any]: - p = Path(path) - with p.open("rb") as f: - magic = read_exact(f, 4) - if magic != MAGIC: - raise SigDBFormatError("invalid magic") - - version = read_exact(f, 1)[0] - if version != VERSION: - raise SigDBFormatError(f"unsupported sigdb version: {version}") - - header_len = struct.unpack(">I", read_exact(f, 4))[0] - if header_len > MAX_HEADER_BYTES: - raise SigDBFormatError("HEADER_DATA too large") - - header_raw = read_exact(f, header_len) - try: - header_any = json.loads(header_raw) - except json.JSONDecodeError as e: - raise SigDBFormatError("invalid HEADER_DATA json") from e - if not isinstance(header_any, dict): - raise SigDBFormatError("HEADER_DATA must be an object") - return cast(dict[str, Any], header_any) - - -def build_sigdb( - *, - rules: SigDBRules, - output_path: str | Path, - metadata: Mapping[str, Any] | None = None, - signing_key_hex: str | None = None, - zstd_level: int = 19, -) -> SigDBBuildResult: - output = Path(output_path) - - items, patterns = _compile_rules(rules) - automaton = _build_automaton(patterns) - - header_meta: dict[str, Any] = dict(metadata or {}) - header_meta.setdefault("format", "SIGDB-TRIE") - header_meta.setdefault("version", VERSION) - header_meta.setdefault("created", int(time.time())) - header_meta.setdefault("build", date.today().isoformat()) - header_meta.setdefault("items", len(items)) - header_meta.setdefault("patterns", len(patterns)) - header_meta.setdefault("certificate", "ed25519") - header_meta.setdefault("signature_algorithm", "ed25519") - - generated_signing_key_hex: str | None = None - if signing_key_hex is None: - generated_signing_key_hex = generate_signing_key_hex() - signing_key_hex = generated_signing_key_hex - - public_key_hex = derive_public_key_hex(signing_key_hex) - if "public_key" in header_meta and header_meta["public_key"] != public_key_hex: - raise SigDBFormatError("metadata.public_key does not match signing key") - header_meta.setdefault("public_key", public_key_hex) - - header_data = json.dumps( - header_meta, ensure_ascii=False, separators=(",", ":") - ).encode("utf-8") - if len(header_data) > MAX_HEADER_BYTES: - raise SigDBFormatError("HEADER_DATA too large") - - items_raw = json.dumps( - [item.to_compact() for item in items], - ensure_ascii=False, - separators=(",", ":"), - ).encode("utf-8") - - automaton_raw = _serialize_automaton(automaton) - data_hash = sha256(items_raw + automaton_raw) - signature = sign_hash(data_hash, signing_key_hex=signing_key_hex) - - items_data = compress_zstd(items_raw, level=zstd_level) - automaton_data = compress_zstd(automaton_raw, level=zstd_level) - - if len(items_data) > 0xFFFFFFFF or len(automaton_data) > 0xFFFFFFFF: - raise SigDBFormatError("compressed block too large for 32-bit length") - - output.parent.mkdir(parents=True, exist_ok=True) - with output.open("wb") as f: - f.write(MAGIC) - f.write(bytes([VERSION])) - - f.write(struct.pack(">I", len(header_data))) - f.write(header_data) - - f.write(struct.pack(">I", len(items_data))) - f.write(items_data) - - f.write(struct.pack(">I", len(automaton_data))) - f.write(automaton_data) - - f.write(data_hash) - f.write(signature) - - return SigDBBuildResult( - output_path=output, - public_key_hex=public_key_hex, - signing_key_hex=generated_signing_key_hex, - data_hash_hex=data_hash.hex(), - signature_hex=signature.hex(), - metadata=header_meta, - ) - - -def load_sigdb( - path: str | Path, - *, - public_key_hex: str | None = None, - verify_hash: bool = True, - verify_signature: bool = True, - max_items_json_size: int = 256 * 1024 * 1024, - max_automaton_size: int = 512 * 1024 * 1024, -) -> SigDBDatabase: - header, items_compressed, auto_compressed, stored_hash, signature = _read_container( - Path(path) - ) - - items_raw = decompress_zstd(items_compressed, max_output_size=max_items_json_size) - auto_raw = decompress_zstd(auto_compressed, max_output_size=max_automaton_size) - - if verify_hash: - computed = sha256(items_raw + auto_raw) - if computed != stored_hash: - raise SigDBIntegrityError("corrupted database (hash mismatch)") - - if verify_signature: - pk = public_key_hex or _metadata_public_key(header) - if pk is None: - raise SigDBSignatureError("public key not provided and not in metadata") - verify_hash_signature(stored_hash, signature, public_key_hex=pk) - - items = _parse_items(items_raw) - automaton = _deserialize_automaton(auto_raw) - return SigDBDatabase(metadata=header, items=items, automaton=automaton) - - -def validate_sigdb( - path: str | Path, - *, - public_key_hex: str | None = None, - verify_hash: bool = True, - verify_signature: bool = True, - max_items_json_size: int = 256 * 1024 * 1024, - max_automaton_size: int = 512 * 1024 * 1024, -) -> SigDBValidationResult: - errors: list[str] = [] - header: dict[str, Any] | None = None - stored_hash: bytes | None = None - computed_hash: bytes | None = None - signature_ok: bool | None = None - pk: str | None = None - - try: - ( - header, - items_compressed, - auto_compressed, - stored_hash, - signature, - ) = _read_container(Path(path)) - pk = public_key_hex or _metadata_public_key(header) - - items_raw = decompress_zstd( - items_compressed, max_output_size=max_items_json_size - ) - auto_raw = decompress_zstd(auto_compressed, max_output_size=max_automaton_size) - - if verify_hash: - computed_hash = sha256(items_raw + auto_raw) - if computed_hash != stored_hash: - errors.append("hash mismatch") - - if verify_signature: - if pk is None: - errors.append("missing public key") - else: - try: - verify_hash_signature(stored_hash, signature, public_key_hex=pk) - signature_ok = True - except Exception: - signature_ok = False - errors.append("bad signature") - except Exception as e: - errors.append(str(e)) - - return SigDBValidationResult( - ok=(len(errors) == 0), - errors=errors, - metadata=header, - public_key_hex=pk, - stored_hash_hex=(stored_hash.hex() if stored_hash else None), - computed_hash_hex=(computed_hash.hex() if computed_hash else None), - signature_ok=signature_ok, - ) - - -def _metadata_public_key(metadata: Mapping[str, Any]) -> str | None: - pk = metadata.get("public_key") - if isinstance(pk, str) and pk: - return pk - return None - - -def _read_container(path: Path) -> tuple[dict[str, Any], bytes, bytes, bytes, bytes]: - with path.open("rb") as f: - magic = read_exact(f, 4) - if magic != MAGIC: - raise SigDBFormatError("invalid magic") - - version = read_exact(f, 1)[0] - if version != VERSION: - raise SigDBFormatError(f"unsupported sigdb version: {version}") - - header_len = struct.unpack(">I", read_exact(f, 4))[0] - if header_len > MAX_HEADER_BYTES: - raise SigDBFormatError("HEADER_DATA too large") - - header_raw = read_exact(f, header_len) - try: - header_any = json.loads(header_raw) - except json.JSONDecodeError as e: - raise SigDBFormatError("invalid HEADER_DATA json") from e - if not isinstance(header_any, dict): - raise SigDBFormatError("HEADER_DATA must be an object") - header = cast(dict[str, Any], header_any) - - items_len = struct.unpack(">I", read_exact(f, 4))[0] - items_compressed = read_exact(f, items_len) - - auto_len = struct.unpack(">I", read_exact(f, 4))[0] - auto_compressed = read_exact(f, auto_len) - - stored_hash = read_exact(f, SHA256_SIZE) - signature = read_exact(f, ED25519_SIGNATURE_SIZE) - - if f.read(1): - raise SigDBFormatError("trailing data after signature") - - return header, items_compressed, auto_compressed, stored_hash, signature - - -def _compile_rules(rules: object) -> tuple[list[SigDBItem], dict[bytes, list[int]]]: - if not isinstance(rules, Mapping): - raise SigDBFormatError("rules must be a JSON object") - - # Expected shorthand: - # { "nginx": {"headers": {"Server": "nginx"}}, ... } - items: list[SigDBItem] = [] - rules_map = cast(Mapping[object, object], rules) - patterns: dict[bytes, list[int]] = {} - for key_any, value_any in rules_map.items(): - if not isinstance(key_any, str) or not key_any: - raise SigDBFormatError("rule keys must be non-empty strings") - if not isinstance(value_any, Mapping): - raise SigDBFormatError("rule value must be an object") - key = key_any - value = cast(Mapping[str, Any], value_any) - headers = parse_string_map(value.get("headers", {}), "headers") - item_id = len(items) - items.append(SigDBItem(key=key, headers=headers)) - - _add_map_patterns(patterns, "headers", headers, item_id) - for group in SIGDB_GROUPS: - if group == "headers": - continue - if group in SIGDB_GROUPS_MAP: - group_map = parse_string_map(value.get(group), group) - _add_map_patterns(patterns, group, group_map, item_id) - else: - group_values = parse_group_list(value.get(group), group) - _add_list_patterns(patterns, group, group_values, item_id) - - for pattern, ids in patterns.items(): - if len(ids) > 1: - patterns[pattern] = sorted(set(ids)) - - return items, patterns - - -def _add_pattern( - patterns: dict[bytes, list[int]], - pattern: str, - item_id: int, -) -> None: - pattern_bytes = pattern.strip().lower().encode("utf-8") - patterns.setdefault(pattern_bytes, []).append(item_id) - - -def _add_map_patterns( - patterns: dict[bytes, list[int]], - group: str, - values: Mapping[str, str], - item_id: int, -) -> None: - for key, needle in values.items(): - _add_pattern(patterns, format_map_pattern(group, key, needle), item_id) - - -def _add_list_patterns( - patterns: dict[bytes, list[int]], - group: str, - values: list[str], - item_id: int, -) -> None: - for needle in values: - _add_pattern(patterns, format_list_pattern(group, needle), item_id) - - -def _parse_items(data: bytes) -> list[SigDBItem]: - try: - decoded_any = json.loads(data) - except json.JSONDecodeError as e: - raise SigDBFormatError("invalid items json") from e - if not isinstance(decoded_any, list): - raise SigDBFormatError("items block must be a JSON array") - decoded = cast(list[Any], decoded_any) - - items: list[SigDBItem] = [] - for entry in decoded: - if not isinstance(entry, list): - raise SigDBFormatError("item must be [key, headers]") - entry_list = cast(list[object], entry) - if len(entry_list) != 2: - raise SigDBFormatError("item must be [key, headers]") - key_any, headers_any = entry_list[0], entry_list[1] - if not isinstance(key_any, str) or not key_any: - raise SigDBFormatError("item key must be a non-empty string") - headers = parse_string_map(headers_any, "headers") - items.append(SigDBItem(key=key_any, headers=headers)) - return items - - -def _serialize_automaton(a: Automaton) -> bytes: - out = bytearray() - - node_count = len(a.children_start) - edge_count = len(a.labels) - output_total = len(a.outputs) - - out.extend(encode_varint(node_count)) - out.extend(encode_varint(edge_count)) - out.extend(encode_varint(output_total)) - - for i in range(node_count): - out.extend(encode_varint(a.children_start[i])) - out.extend(encode_varint(a.children_count[i])) - out.extend(encode_varint(a.fail[i])) - out.extend(encode_varint(a.out_start[i])) - out.extend(encode_varint(a.out_count[i])) - - out.extend(a.labels) - for nxt in a.next_state: - out.extend(encode_varint(nxt)) - for out_id in a.outputs: - out.extend(encode_varint(out_id)) - - return bytes(out) - - -def _read_varint(data: bytes, pos: int) -> tuple[int, int]: - r = decode_varint(data, pos) - return r.value, r.offset - - -def _deserialize_automaton(data: bytes) -> Automaton: - pos = 0 - node_count, pos = _read_varint(data, pos) - edge_count, pos = _read_varint(data, pos) - output_total, pos = _read_varint(data, pos) - - children_start: list[int] = [0] * node_count - children_count: list[int] = [0] * node_count - fail: list[int] = [0] * node_count - out_start: list[int] = [0] * node_count - out_count: list[int] = [0] * node_count - - for i in range(node_count): - children_start[i], pos = _read_varint(data, pos) - children_count[i], pos = _read_varint(data, pos) - fail[i], pos = _read_varint(data, pos) - out_start[i], pos = _read_varint(data, pos) - out_count[i], pos = _read_varint(data, pos) - - labels = data[pos : pos + edge_count] - pos += edge_count - - next_state: list[int] = [0] * edge_count - for i in range(edge_count): - next_state[i], pos = _read_varint(data, pos) - - outputs: list[int] = [0] * output_total - for i in range(output_total): - outputs[i], pos = _read_varint(data, pos) - - if pos != len(data): - raise SigDBFormatError("automaton block has trailing bytes") - - return Automaton( - children_start=children_start, - children_count=children_count, - fail=fail, - out_start=out_start, - out_count=out_count, - labels=labels, - next_state=next_state, - outputs=outputs, - ) - - -def _build_automaton(patterns: Mapping[bytes, list[int]]) -> Automaton: - trans: list[dict[int, int]] = [{}] - out: list[list[int]] = [[]] - - for pattern_bytes, item_ids in patterns.items(): - state = 0 - for b in pattern_bytes: - nxt = trans[state].get(b) - if nxt is None: - nxt = len(trans) - trans[state][b] = nxt - trans.append({}) - out.append([]) - state = nxt - out[state].extend(item_ids) - - for i in range(len(out)): - if len(out[i]) > 1: - out[i] = sorted(set(out[i])) - - fail: list[int] = [0] * len(trans) - q: deque[int] = deque() - for nxt in trans[0].values(): - q.append(nxt) - - while q: - v = q.popleft() - for b, u in trans[v].items(): - q.append(u) - f = fail[v] - while f != 0 and b not in trans[f]: - f = fail[f] - fail[u] = trans[f].get(b, 0) - if out[fail[u]]: - out[u].extend(out[fail[u]]) - out[u] = sorted(set(out[u])) - - node_count = len(trans) - children_start: list[int] = [0] * node_count - children_count: list[int] = [0] * node_count - out_start: list[int] = [0] * node_count - out_count: list[int] = [0] * node_count - - labels = bytearray() - next_state: list[int] = [] - outputs: list[int] = [] - - edge_cursor = 0 - out_cursor = 0 - for i in range(node_count): - items = sorted(trans[i].items(), key=lambda kv: kv[0]) - children_start[i] = edge_cursor - children_count[i] = len(items) - for b, nxt in items: - labels.append(b) - next_state.append(nxt) - edge_cursor += 1 - - out_start[i] = out_cursor - out_count[i] = len(out[i]) - outputs.extend(out[i]) - out_cursor += len(out[i]) - - return Automaton( - children_start=children_start, - children_count=children_count, - fail=fail, - out_start=out_start, - out_count=out_count, - labels=bytes(labels), - next_state=next_state, - outputs=outputs, - ) +from __future__ import annotations + +import json +import struct +import time +from collections import deque +from collections.abc import Mapping +from datetime import date +from pathlib import Path +from typing import Any, cast + +from sigdb.compression import compress_zstd, decompress_zstd +from sigdb.crypto import ( + derive_public_key_hex, + generate_signing_key_hex, + sign_hash, + verify_hash_signature, +) +from sigdb.internal.groups import ( + SIGDB_GROUPS, + SIGDB_GROUPS_MAP, + format_list_pattern, + format_map_pattern, + parse_group_list, + parse_string_map, +) +from sigdb.storage import read_exact +from sigdb.types import ( + Automaton, + SigDBBuildResult, + SigDBDatabase, + SigDBFormatError, + SigDBIntegrityError, + SigDBItem, + SigDBRules, + SigDBSignatureError, + SigDBValidationResult, +) +from sigdb.utils.hashing import sha256 +from sigdb.utils.varint import decode_varint, encode_varint + +MAGIC: bytes = b"SIGT" +VERSION: int = 1 + +MAX_HEADER_BYTES: int = 65_536 +SHA256_SIZE: int = 32 +ED25519_SIGNATURE_SIZE: int = 64 + + +def read_sigdb_metadata(path: str | Path) -> dict[str, Any]: + p = Path(path) + with p.open("rb") as f: + magic = read_exact(f, 4) + if magic != MAGIC: + raise SigDBFormatError("invalid magic") + + version = read_exact(f, 1)[0] + if version != VERSION: + raise SigDBFormatError(f"unsupported sigdb version: {version}") + + header_len = struct.unpack(">I", read_exact(f, 4))[0] + if header_len > MAX_HEADER_BYTES: + raise SigDBFormatError("HEADER_DATA too large") + + header_raw = read_exact(f, header_len) + try: + header_any = json.loads(header_raw) + except json.JSONDecodeError as e: + raise SigDBFormatError("invalid HEADER_DATA json") from e + if not isinstance(header_any, dict): + raise SigDBFormatError("HEADER_DATA must be an object") + return cast(dict[str, Any], header_any) + + +def build_sigdb( + *, + rules: SigDBRules, + output_path: str | Path, + metadata: Mapping[str, Any] | None = None, + signing_key_hex: str | None = None, + zstd_level: int = 19, +) -> SigDBBuildResult: + output = Path(output_path) + + items, patterns = _compile_rules(rules) + automaton = _build_automaton(patterns) + + header_meta: dict[str, Any] = dict(metadata or {}) + header_meta.setdefault("format", "SIGDB-TRIE") + header_meta.setdefault("version", VERSION) + header_meta.setdefault("created", int(time.time())) + header_meta.setdefault("build", date.today().isoformat()) + header_meta.setdefault("items", len(items)) + header_meta.setdefault("patterns", len(patterns)) + header_meta.setdefault("certificate", "ed25519") + header_meta.setdefault("signature_algorithm", "ed25519") + + generated_signing_key_hex: str | None = None + if signing_key_hex is None: + generated_signing_key_hex = generate_signing_key_hex() + signing_key_hex = generated_signing_key_hex + + public_key_hex = derive_public_key_hex(signing_key_hex) + if "public_key" in header_meta and header_meta["public_key"] != public_key_hex: + raise SigDBFormatError("metadata.public_key does not match signing key") + header_meta.setdefault("public_key", public_key_hex) + + header_data = json.dumps( + header_meta, ensure_ascii=False, separators=(",", ":") + ).encode("utf-8") + if len(header_data) > MAX_HEADER_BYTES: + raise SigDBFormatError("HEADER_DATA too large") + + items_raw = json.dumps( + [item.to_compact() for item in items], + ensure_ascii=False, + separators=(",", ":"), + ).encode("utf-8") + + automaton_raw = _serialize_automaton(automaton) + data_hash = sha256(items_raw + automaton_raw) + signature = sign_hash(data_hash, signing_key_hex=signing_key_hex) + + items_data = compress_zstd(items_raw, level=zstd_level) + automaton_data = compress_zstd(automaton_raw, level=zstd_level) + + if len(items_data) > 0xFFFFFFFF or len(automaton_data) > 0xFFFFFFFF: + raise SigDBFormatError("compressed block too large for 32-bit length") + + output.parent.mkdir(parents=True, exist_ok=True) + with output.open("wb") as f: + f.write(MAGIC) + f.write(bytes([VERSION])) + + f.write(struct.pack(">I", len(header_data))) + f.write(header_data) + + f.write(struct.pack(">I", len(items_data))) + f.write(items_data) + + f.write(struct.pack(">I", len(automaton_data))) + f.write(automaton_data) + + f.write(data_hash) + f.write(signature) + + return SigDBBuildResult( + output_path=output, + public_key_hex=public_key_hex, + signing_key_hex=generated_signing_key_hex, + data_hash_hex=data_hash.hex(), + signature_hex=signature.hex(), + metadata=header_meta, + ) + + +def load_sigdb( + path: str | Path, + *, + public_key_hex: str | None = None, + verify_hash: bool = True, + verify_signature: bool = True, + max_items_json_size: int = 256 * 1024 * 1024, + max_automaton_size: int = 512 * 1024 * 1024, +) -> SigDBDatabase: + header, items_compressed, auto_compressed, stored_hash, signature = _read_container( + Path(path) + ) + + items_raw = decompress_zstd(items_compressed, max_output_size=max_items_json_size) + auto_raw = decompress_zstd(auto_compressed, max_output_size=max_automaton_size) + + if verify_hash: + computed = sha256(items_raw + auto_raw) + if computed != stored_hash: + raise SigDBIntegrityError("corrupted database (hash mismatch)") + + if verify_signature: + pk = public_key_hex or _metadata_public_key(header) + if pk is None: + raise SigDBSignatureError("public key not provided and not in metadata") + verify_hash_signature(stored_hash, signature, public_key_hex=pk) + + items = _parse_items(items_raw) + automaton = _deserialize_automaton(auto_raw) + return SigDBDatabase(metadata=header, items=items, automaton=automaton) + + +def validate_sigdb( + path: str | Path, + *, + public_key_hex: str | None = None, + verify_hash: bool = True, + verify_signature: bool = True, + max_items_json_size: int = 256 * 1024 * 1024, + max_automaton_size: int = 512 * 1024 * 1024, +) -> SigDBValidationResult: + errors: list[str] = [] + header: dict[str, Any] | None = None + stored_hash: bytes | None = None + computed_hash: bytes | None = None + signature_ok: bool | None = None + pk: str | None = None + + try: + ( + header, + items_compressed, + auto_compressed, + stored_hash, + signature, + ) = _read_container(Path(path)) + pk = public_key_hex or _metadata_public_key(header) + + items_raw = decompress_zstd( + items_compressed, max_output_size=max_items_json_size + ) + auto_raw = decompress_zstd(auto_compressed, max_output_size=max_automaton_size) + + if verify_hash: + computed_hash = sha256(items_raw + auto_raw) + if computed_hash != stored_hash: + errors.append("hash mismatch") + + if verify_signature: + if pk is None: + errors.append("missing public key") + else: + try: + verify_hash_signature(stored_hash, signature, public_key_hex=pk) + signature_ok = True + except Exception: + signature_ok = False + errors.append("bad signature") + except Exception as e: + errors.append(str(e)) + + return SigDBValidationResult( + ok=(len(errors) == 0), + errors=errors, + metadata=header, + public_key_hex=pk, + stored_hash_hex=(stored_hash.hex() if stored_hash else None), + computed_hash_hex=(computed_hash.hex() if computed_hash else None), + signature_ok=signature_ok, + ) + + +def _metadata_public_key(metadata: Mapping[str, Any]) -> str | None: + pk = metadata.get("public_key") + if isinstance(pk, str) and pk: + return pk + return None + + +def _read_container(path: Path) -> tuple[dict[str, Any], bytes, bytes, bytes, bytes]: + with path.open("rb") as f: + magic = read_exact(f, 4) + if magic != MAGIC: + raise SigDBFormatError("invalid magic") + + version = read_exact(f, 1)[0] + if version != VERSION: + raise SigDBFormatError(f"unsupported sigdb version: {version}") + + header_len = struct.unpack(">I", read_exact(f, 4))[0] + if header_len > MAX_HEADER_BYTES: + raise SigDBFormatError("HEADER_DATA too large") + + header_raw = read_exact(f, header_len) + try: + header_any = json.loads(header_raw) + except json.JSONDecodeError as e: + raise SigDBFormatError("invalid HEADER_DATA json") from e + if not isinstance(header_any, dict): + raise SigDBFormatError("HEADER_DATA must be an object") + header = cast(dict[str, Any], header_any) + + items_len = struct.unpack(">I", read_exact(f, 4))[0] + items_compressed = read_exact(f, items_len) + + auto_len = struct.unpack(">I", read_exact(f, 4))[0] + auto_compressed = read_exact(f, auto_len) + + stored_hash = read_exact(f, SHA256_SIZE) + signature = read_exact(f, ED25519_SIGNATURE_SIZE) + + if f.read(1): + raise SigDBFormatError("trailing data after signature") + + return header, items_compressed, auto_compressed, stored_hash, signature + + +def _compile_rules(rules: object) -> tuple[list[SigDBItem], dict[bytes, list[int]]]: + if not isinstance(rules, Mapping): + raise SigDBFormatError("rules must be a JSON object") + + # Expected shorthand: + # { "nginx": {"headers": {"Server": "nginx"}}, ... } + items: list[SigDBItem] = [] + rules_map = cast(Mapping[object, object], rules) + patterns: dict[bytes, list[int]] = {} + for key_any, value_any in rules_map.items(): + if not isinstance(key_any, str) or not key_any: + raise SigDBFormatError("rule keys must be non-empty strings") + if not isinstance(value_any, Mapping): + raise SigDBFormatError("rule value must be an object") + key = key_any + value = cast(Mapping[str, Any], value_any) + headers = parse_string_map(value.get("headers", {}), "headers") + item_id = len(items) + items.append(SigDBItem(key=key, headers=headers)) + + _add_map_patterns(patterns, "headers", headers, item_id) + for group in SIGDB_GROUPS: + if group == "headers": + continue + if group in SIGDB_GROUPS_MAP: + group_map = parse_string_map(value.get(group), group) + _add_map_patterns(patterns, group, group_map, item_id) + else: + group_values = parse_group_list(value.get(group), group) + _add_list_patterns(patterns, group, group_values, item_id) + + for pattern, ids in patterns.items(): + if len(ids) > 1: + patterns[pattern] = sorted(set(ids)) + + return items, patterns + + +def _add_pattern( + patterns: dict[bytes, list[int]], + pattern: str, + item_id: int, +) -> None: + pattern_bytes = pattern.strip().lower().encode("utf-8") + patterns.setdefault(pattern_bytes, []).append(item_id) + + +def _add_map_patterns( + patterns: dict[bytes, list[int]], + group: str, + values: Mapping[str, str], + item_id: int, +) -> None: + for key, needle in values.items(): + _add_pattern(patterns, format_map_pattern(group, key, needle), item_id) + + +def _add_list_patterns( + patterns: dict[bytes, list[int]], + group: str, + values: list[str], + item_id: int, +) -> None: + for needle in values: + _add_pattern(patterns, format_list_pattern(group, needle), item_id) + + +def _parse_items(data: bytes) -> list[SigDBItem]: + try: + decoded_any = json.loads(data) + except json.JSONDecodeError as e: + raise SigDBFormatError("invalid items json") from e + if not isinstance(decoded_any, list): + raise SigDBFormatError("items block must be a JSON array") + decoded = cast(list[Any], decoded_any) + + items: list[SigDBItem] = [] + for entry in decoded: + if not isinstance(entry, list): + raise SigDBFormatError("item must be [key, headers]") + entry_list = cast(list[object], entry) + if len(entry_list) != 2: + raise SigDBFormatError("item must be [key, headers]") + key_any, headers_any = entry_list[0], entry_list[1] + if not isinstance(key_any, str) or not key_any: + raise SigDBFormatError("item key must be a non-empty string") + headers = parse_string_map(headers_any, "headers") + items.append(SigDBItem(key=key_any, headers=headers)) + return items + + +def _serialize_automaton(a: Automaton) -> bytes: + out = bytearray() + + node_count = len(a.children_start) + edge_count = len(a.labels) + output_total = len(a.outputs) + + out.extend(encode_varint(node_count)) + out.extend(encode_varint(edge_count)) + out.extend(encode_varint(output_total)) + + for i in range(node_count): + out.extend(encode_varint(a.children_start[i])) + out.extend(encode_varint(a.children_count[i])) + out.extend(encode_varint(a.fail[i])) + out.extend(encode_varint(a.out_start[i])) + out.extend(encode_varint(a.out_count[i])) + + out.extend(a.labels) + for nxt in a.next_state: + out.extend(encode_varint(nxt)) + for out_id in a.outputs: + out.extend(encode_varint(out_id)) + + return bytes(out) + + +def _read_varint(data: bytes, pos: int) -> tuple[int, int]: + r = decode_varint(data, pos) + return r.value, r.offset + + +def _deserialize_automaton(data: bytes) -> Automaton: + pos = 0 + node_count, pos = _read_varint(data, pos) + edge_count, pos = _read_varint(data, pos) + output_total, pos = _read_varint(data, pos) + + children_start: list[int] = [0] * node_count + children_count: list[int] = [0] * node_count + fail: list[int] = [0] * node_count + out_start: list[int] = [0] * node_count + out_count: list[int] = [0] * node_count + + for i in range(node_count): + children_start[i], pos = _read_varint(data, pos) + children_count[i], pos = _read_varint(data, pos) + fail[i], pos = _read_varint(data, pos) + out_start[i], pos = _read_varint(data, pos) + out_count[i], pos = _read_varint(data, pos) + + labels = data[pos : pos + edge_count] + pos += edge_count + + next_state: list[int] = [0] * edge_count + for i in range(edge_count): + next_state[i], pos = _read_varint(data, pos) + + outputs: list[int] = [0] * output_total + for i in range(output_total): + outputs[i], pos = _read_varint(data, pos) + + if pos != len(data): + raise SigDBFormatError("automaton block has trailing bytes") + + return Automaton( + children_start=children_start, + children_count=children_count, + fail=fail, + out_start=out_start, + out_count=out_count, + labels=labels, + next_state=next_state, + outputs=outputs, + ) + + +def _build_automaton(patterns: Mapping[bytes, list[int]]) -> Automaton: + trans: list[dict[int, int]] = [{}] + out: list[list[int]] = [[]] + + for pattern_bytes, item_ids in patterns.items(): + state = 0 + for b in pattern_bytes: + nxt = trans[state].get(b) + if nxt is None: + nxt = len(trans) + trans[state][b] = nxt + trans.append({}) + out.append([]) + state = nxt + out[state].extend(item_ids) + + for i in range(len(out)): + if len(out[i]) > 1: + out[i] = sorted(set(out[i])) + + fail: list[int] = [0] * len(trans) + q: deque[int] = deque() + for nxt in trans[0].values(): + q.append(nxt) + + while q: + v = q.popleft() + for b, u in trans[v].items(): + q.append(u) + f = fail[v] + while f != 0 and b not in trans[f]: + f = fail[f] + fail[u] = trans[f].get(b, 0) + if out[fail[u]]: + out[u].extend(out[fail[u]]) + out[u] = sorted(set(out[u])) + + node_count = len(trans) + children_start: list[int] = [0] * node_count + children_count: list[int] = [0] * node_count + out_start: list[int] = [0] * node_count + out_count: list[int] = [0] * node_count + + labels = bytearray() + next_state: list[int] = [] + outputs: list[int] = [] + + edge_cursor = 0 + out_cursor = 0 + for i in range(node_count): + items = sorted(trans[i].items(), key=lambda kv: kv[0]) + children_start[i] = edge_cursor + children_count[i] = len(items) + for b, nxt in items: + labels.append(b) + next_state.append(nxt) + edge_cursor += 1 + + out_start[i] = out_cursor + out_count[i] = len(out[i]) + outputs.extend(out[i]) + out_cursor += len(out[i]) + + return Automaton( + children_start=children_start, + children_count=children_count, + fail=fail, + out_start=out_start, + out_count=out_count, + labels=bytes(labels), + next_state=next_state, + outputs=outputs, + ) diff --git a/src/sigdb/types/__init__.py b/src/sigdb/types/__init__.py index 7878865..b2dd40e 100644 --- a/src/sigdb/types/__init__.py +++ b/src/sigdb/types/__init__.py @@ -1,51 +1,51 @@ -from __future__ import annotations - -from sigdb.types.exceptions import ( - SigDBError, - SigDBFormatError, - SigDBIntegrityError, - SigDBSignatureError, -) -from sigdb.types.models import ( - Automaton, - DecodeResult, - SigDBBuildResult, - SigDBDatabase, - SigDBItem, - SigDBMatchResult, - SigDBValidationResult, -) -from sigdb.types.rules import ( - SigDBGroupListName, - SigDBGroupMapName, - SigDBGroupName, - SigDBHtmlList, - SigDBHtmlPattern, - SigDBHtmlSpec, - SigDBRuleDefinition, - SigDBRules, - SigDBSearchDefinition, -) - -__all__ = [ - "Automaton", - "DecodeResult", - "SigDBBuildResult", - "SigDBDatabase", - "SigDBError", - "SigDBFormatError", - "SigDBIntegrityError", - "SigDBItem", - "SigDBMatchResult", - "SigDBGroupListName", - "SigDBGroupMapName", - "SigDBGroupName", - "SigDBHtmlList", - "SigDBHtmlPattern", - "SigDBHtmlSpec", - "SigDBRuleDefinition", - "SigDBRules", - "SigDBSearchDefinition", - "SigDBSignatureError", - "SigDBValidationResult", -] +from __future__ import annotations + +from sigdb.types.exceptions import ( + SigDBError, + SigDBFormatError, + SigDBIntegrityError, + SigDBSignatureError, +) +from sigdb.types.models import ( + Automaton, + DecodeResult, + SigDBBuildResult, + SigDBDatabase, + SigDBItem, + SigDBMatchResult, + SigDBValidationResult, +) +from sigdb.types.rules import ( + SigDBGroupListName, + SigDBGroupMapName, + SigDBGroupName, + SigDBHtmlList, + SigDBHtmlPattern, + SigDBHtmlSpec, + SigDBRuleDefinition, + SigDBRules, + SigDBSearchDefinition, +) + +__all__ = [ + "Automaton", + "DecodeResult", + "SigDBBuildResult", + "SigDBDatabase", + "SigDBError", + "SigDBFormatError", + "SigDBGroupListName", + "SigDBGroupMapName", + "SigDBGroupName", + "SigDBHtmlList", + "SigDBHtmlPattern", + "SigDBHtmlSpec", + "SigDBIntegrityError", + "SigDBItem", + "SigDBMatchResult", + "SigDBRuleDefinition", + "SigDBRules", + "SigDBSearchDefinition", + "SigDBSignatureError", + "SigDBValidationResult", +] diff --git a/tests/test_container_errors.py b/tests/test_container_errors.py index f34138a..f607484 100644 --- a/tests/test_container_errors.py +++ b/tests/test_container_errors.py @@ -1,91 +1,91 @@ -from __future__ import annotations - -from collections.abc import Callable -from pathlib import Path -from typing import Any, TypeVar - -from sigdb.core import build_sigdb, load_sigdb, validate_sigdb -from sigdb.crypto import derive_public_key_hex, generate_signing_key_hex -from sigdb.types import SigDBFormatError - -TExc = TypeVar("TExc", bound=BaseException) - - -def assert_true(value: bool, msg: str) -> None: - if not value: - raise AssertionError(msg) - - -def assert_in(needle: str, haystack: str, msg: str) -> None: - if needle not in haystack: - raise AssertionError(f"{msg}: {needle!r} not in {haystack!r}") - - -def assert_raises( - exc_type: type[TExc], - fn: Callable[[], object], - *, - msg_contains: str | None = None, -) -> TExc: - try: - fn() - except exc_type as e: - if msg_contains is not None: - assert_in(msg_contains, str(e), "exception message mismatch") - return e - except Exception as e: - raise AssertionError( - f"expected {exc_type.__name__}, got {type(e).__name__}: {e}" - ) from e - raise AssertionError(f"expected {exc_type.__name__}, got no exception") - - -def main() -> None: - out = Path(__file__).with_name("test_trailing_data.sigdb") - out_bad = Path(__file__).with_name("test_trailing_data_extra.sigdb") - - rules: dict[str, Any] = { - "nginx": {"headers": {"Server": "nginx"}}, - } - - metadata: dict[str, Any] = { - "dataset": "Example", - "version": "1.0.0", - "author": "Container Checker", - "contact": "container@reekeer.hidden", - "license": "MIT", - "repository": "https://github.com/reekeer/sigdb", - "homepage": "https://reekeer.com", - "description": "Container layout tests", - } - - signing_key_hex = generate_signing_key_hex() - public_key_hex = derive_public_key_hex(signing_key_hex) - metadata["public_key"] = public_key_hex - - build_sigdb( - rules=rules, - output_path=out, - metadata=metadata, - signing_key_hex=signing_key_hex, - ) - - out_bad.write_bytes(out.read_bytes() + b"\x00") - - assert_raises( - SigDBFormatError, - lambda: load_sigdb(out_bad), - msg_contains="trailing data after signature", - ) - - v = validate_sigdb(out_bad) - assert_true(not v.ok, "validate_sigdb must fail for trailing data") - assert_true(len(v.errors) > 0, "validate_sigdb must report errors") - assert_true( - any("trailing data after signature" in e for e in v.errors), - "missing expected error", - ) - - -if __name__ == "__main__": - main() +from __future__ import annotations + +from collections.abc import Callable +from pathlib import Path +from typing import Any, TypeVar + +from sigdb.core import build_sigdb, load_sigdb, validate_sigdb +from sigdb.crypto import derive_public_key_hex, generate_signing_key_hex +from sigdb.types import SigDBFormatError + +TExc = TypeVar("TExc", bound=BaseException) + + +def assert_true(value: bool, msg: str) -> None: + if not value: + raise AssertionError(msg) + + +def assert_in(needle: str, haystack: str, msg: str) -> None: + if needle not in haystack: + raise AssertionError(f"{msg}: {needle!r} not in {haystack!r}") + + +def assert_raises( + exc_type: type[TExc], + fn: Callable[[], object], + *, + msg_contains: str | None = None, +) -> TExc: + try: + fn() + except exc_type as e: + if msg_contains is not None: + assert_in(msg_contains, str(e), "exception message mismatch") + return e + except Exception as e: + raise AssertionError( + f"expected {exc_type.__name__}, got {type(e).__name__}: {e}" + ) from e + raise AssertionError(f"expected {exc_type.__name__}, got no exception") + + +def main() -> None: + out = Path(__file__).with_name("test_trailing_data.sigdb") + out_bad = Path(__file__).with_name("test_trailing_data_extra.sigdb") + + rules: dict[str, Any] = { + "nginx": {"headers": {"Server": "nginx"}}, + } + + metadata: dict[str, Any] = { + "dataset": "Example", + "version": "1.0.0", + "author": "Container Checker", + "contact": "container@reekeer.hidden", + "license": "MIT", + "repository": "https://github.com/reekeer/sigdb", + "homepage": "https://reekeer.com", + "description": "Container layout tests", + } + + signing_key_hex = generate_signing_key_hex() + public_key_hex = derive_public_key_hex(signing_key_hex) + metadata["public_key"] = public_key_hex + + build_sigdb( + rules=rules, + output_path=out, + metadata=metadata, + signing_key_hex=signing_key_hex, + ) + + out_bad.write_bytes(out.read_bytes() + b"\x00") + + assert_raises( + SigDBFormatError, + lambda: load_sigdb(out_bad), + msg_contains="trailing data after signature", + ) + + v = validate_sigdb(out_bad) + assert_true(not v.ok, "validate_sigdb must fail for trailing data") + assert_true(len(v.errors) > 0, "validate_sigdb must report errors") + assert_true( + any("trailing data after signature" in e for e in v.errors), + "missing expected error", + ) + + +if __name__ == "__main__": + main() diff --git a/tests/test_crypto.py b/tests/test_crypto.py index 5d4e025..db2070e 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -1,84 +1,84 @@ -from __future__ import annotations - -import hashlib -from collections.abc import Callable -from typing import TypeVar - -from sigdb.crypto import ( - derive_public_key_hex, - generate_signing_key_hex, - sign_hash, - verify_hash_signature, -) -from sigdb.types import SigDBSignatureError - -TExc = TypeVar("TExc", bound=BaseException) - - -def assert_true(value: bool, msg: str) -> None: - if not value: - raise AssertionError(msg) - - -def assert_eq(left: object, right: object, msg: str) -> None: - if left != right: - raise AssertionError(f"{msg}: {left!r} != {right!r}") - - -def assert_in(needle: str, haystack: str, msg: str) -> None: - if needle not in haystack: - raise AssertionError(f"{msg}: {needle!r} not in {haystack!r}") - - -def assert_raises( - exc_type: type[TExc], - fn: Callable[[], object], - *, - msg_contains: str | None = None, -) -> TExc: - try: - fn() - except exc_type as e: - if msg_contains is not None: - assert_in(msg_contains, str(e), "exception message mismatch") - return e - except Exception as e: - raise AssertionError( - f"expected {exc_type.__name__}, got {type(e).__name__}: {e}" - ) from e - raise AssertionError(f"expected {exc_type.__name__}, got no exception") - - -def main() -> None: - signing_key_hex = generate_signing_key_hex() - assert_eq(len(signing_key_hex), 64, "signing_key_hex length mismatch") - assert_eq(len(bytes.fromhex(signing_key_hex)), 32, "signing_key_hex bytes mismatch") - - public_key_hex = derive_public_key_hex(signing_key_hex) - assert_eq(len(public_key_hex), 64, "public_key_hex length mismatch") - assert_eq(len(bytes.fromhex(public_key_hex)), 32, "public_key_hex bytes mismatch") - - data_hash = hashlib.sha256(b"sigdb").digest() - signature = sign_hash(data_hash, signing_key_hex=signing_key_hex) - assert_eq(len(signature), 64, "signature length mismatch") - - verify_hash_signature(data_hash, signature, public_key_hex=public_key_hex) - - assert_raises( - SigDBSignatureError, - lambda: verify_hash_signature(data_hash, signature, public_key_hex="0"), - msg_contains="invalid public key hex", - ) - - other_pk = derive_public_key_hex(generate_signing_key_hex()) - assert_true(other_pk != public_key_hex, "generated public keys must differ") - assert_raises( - SigDBSignatureError, - lambda: verify_hash_signature(data_hash, signature, public_key_hex=other_pk), - msg_contains="invalid signature", - ) - - -if __name__ == "__main__": - main() - +from __future__ import annotations + +import hashlib +from collections.abc import Callable +from typing import TypeVar + +from sigdb.crypto import ( + derive_public_key_hex, + generate_signing_key_hex, + sign_hash, + verify_hash_signature, +) +from sigdb.types import SigDBSignatureError + +TExc = TypeVar("TExc", bound=BaseException) + + +def assert_true(value: bool, msg: str) -> None: + if not value: + raise AssertionError(msg) + + +def assert_eq(left: object, right: object, msg: str) -> None: + if left != right: + raise AssertionError(f"{msg}: {left!r} != {right!r}") + + +def assert_in(needle: str, haystack: str, msg: str) -> None: + if needle not in haystack: + raise AssertionError(f"{msg}: {needle!r} not in {haystack!r}") + + +def assert_raises( + exc_type: type[TExc], + fn: Callable[[], object], + *, + msg_contains: str | None = None, +) -> TExc: + try: + fn() + except exc_type as e: + if msg_contains is not None: + assert_in(msg_contains, str(e), "exception message mismatch") + return e + except Exception as e: + raise AssertionError( + f"expected {exc_type.__name__}, got {type(e).__name__}: {e}" + ) from e + raise AssertionError(f"expected {exc_type.__name__}, got no exception") + + +def main() -> None: + signing_key_hex = generate_signing_key_hex() + assert_eq(len(signing_key_hex), 64, "signing_key_hex length mismatch") + assert_eq(len(bytes.fromhex(signing_key_hex)), 32, "signing_key_hex bytes mismatch") + + public_key_hex = derive_public_key_hex(signing_key_hex) + assert_eq(len(public_key_hex), 64, "public_key_hex length mismatch") + assert_eq(len(bytes.fromhex(public_key_hex)), 32, "public_key_hex bytes mismatch") + + data_hash = hashlib.sha256(b"sigdb").digest() + signature = sign_hash(data_hash, signing_key_hex=signing_key_hex) + assert_eq(len(signature), 64, "signature length mismatch") + + verify_hash_signature(data_hash, signature, public_key_hex=public_key_hex) + + assert_raises( + SigDBSignatureError, + lambda: verify_hash_signature(data_hash, signature, public_key_hex="0"), + msg_contains="invalid public key hex", + ) + + other_pk = derive_public_key_hex(generate_signing_key_hex()) + assert_true(other_pk != public_key_hex, "generated public keys must differ") + assert_raises( + SigDBSignatureError, + lambda: verify_hash_signature(data_hash, signature, public_key_hex=other_pk), + msg_contains="invalid signature", + ) + + +if __name__ == "__main__": + main() + diff --git a/tests/test_varint.py b/tests/test_varint.py index e8fbc25..93eb7bb 100644 --- a/tests/test_varint.py +++ b/tests/test_varint.py @@ -1,86 +1,86 @@ -from __future__ import annotations - -from collections.abc import Callable -from typing import TypeVar - -from sigdb.types import SigDBFormatError -from sigdb.utils.varint import decode_varint, encode_varint - -TExc = TypeVar("TExc", bound=BaseException) - - -def assert_true(value: bool, msg: str) -> None: - if not value: - raise AssertionError(msg) - - -def assert_eq(left: object, right: object, msg: str) -> None: - if left != right: - raise AssertionError(f"{msg}: {left!r} != {right!r}") - - -def assert_in(needle: str, haystack: str, msg: str) -> None: - if needle not in haystack: - raise AssertionError(f"{msg}: {needle!r} not in {haystack!r}") - - -def assert_raises( - exc_type: type[TExc], - fn: Callable[[], object], - *, - msg_contains: str | None = None, -) -> TExc: - try: - fn() - except exc_type as e: - if msg_contains is not None: - assert_in(msg_contains, str(e), "exception message mismatch") - return e - except Exception as e: - raise AssertionError( - f"expected {exc_type.__name__}, got {type(e).__name__}: {e}" - ) from e - raise AssertionError(f"expected {exc_type.__name__}, got no exception") - - -def main() -> None: - for value in [0, 1, 2, 10, 127, 128, 255, 300, 16_384, 2**32, 2**63 - 1]: - encoded = encode_varint(value) - decoded = decode_varint(encoded, 0) - assert_eq(decoded.value, value, "varint roundtrip mismatch") - assert_eq(decoded.offset, len(encoded), "varint decode offset mismatch") - - assert_raises( - ValueError, - lambda: encode_varint(-1), - msg_contains="negative", - ) - assert_raises( - SigDBFormatError, - lambda: decode_varint(b"", 0), - msg_contains="truncated varint", - ) - assert_raises( - SigDBFormatError, - lambda: decode_varint(b"\x80", 0), - msg_contains="truncated varint", - ) - assert_raises( - SigDBFormatError, - lambda: decode_varint(b"\x80" * 10, 0), - msg_contains="varint too long", - ) - assert_raises( - SigDBFormatError, - lambda: decode_varint(b"\x00", -1), - msg_contains="negative offset", - ) - - encoded = encode_varint(300) - decoded = decode_varint(encoded, 1) - assert_true(decoded.value != 300, "decode must respect offset") - - -if __name__ == "__main__": - main() - +from __future__ import annotations + +from collections.abc import Callable +from typing import TypeVar + +from sigdb.types import SigDBFormatError +from sigdb.utils.varint import decode_varint, encode_varint + +TExc = TypeVar("TExc", bound=BaseException) + + +def assert_true(value: bool, msg: str) -> None: + if not value: + raise AssertionError(msg) + + +def assert_eq(left: object, right: object, msg: str) -> None: + if left != right: + raise AssertionError(f"{msg}: {left!r} != {right!r}") + + +def assert_in(needle: str, haystack: str, msg: str) -> None: + if needle not in haystack: + raise AssertionError(f"{msg}: {needle!r} not in {haystack!r}") + + +def assert_raises( + exc_type: type[TExc], + fn: Callable[[], object], + *, + msg_contains: str | None = None, +) -> TExc: + try: + fn() + except exc_type as e: + if msg_contains is not None: + assert_in(msg_contains, str(e), "exception message mismatch") + return e + except Exception as e: + raise AssertionError( + f"expected {exc_type.__name__}, got {type(e).__name__}: {e}" + ) from e + raise AssertionError(f"expected {exc_type.__name__}, got no exception") + + +def main() -> None: + for value in [0, 1, 2, 10, 127, 128, 255, 300, 16_384, 2**32, 2**63 - 1]: + encoded = encode_varint(value) + decoded = decode_varint(encoded, 0) + assert_eq(decoded.value, value, "varint roundtrip mismatch") + assert_eq(decoded.offset, len(encoded), "varint decode offset mismatch") + + assert_raises( + ValueError, + lambda: encode_varint(-1), + msg_contains="negative", + ) + assert_raises( + SigDBFormatError, + lambda: decode_varint(b"", 0), + msg_contains="truncated varint", + ) + assert_raises( + SigDBFormatError, + lambda: decode_varint(b"\x80", 0), + msg_contains="truncated varint", + ) + assert_raises( + SigDBFormatError, + lambda: decode_varint(b"\x80" * 10, 0), + msg_contains="varint too long", + ) + assert_raises( + SigDBFormatError, + lambda: decode_varint(b"\x00", -1), + msg_contains="negative offset", + ) + + encoded = encode_varint(300) + decoded = decode_varint(encoded, 1) + assert_true(decoded.value != 300, "decode must respect offset") + + +if __name__ == "__main__": + main() + From 6fdaaa4f76fcfc2d5ce71d8635a24fdd710bac29 Mon Sep 17 00:00:00 2001 From: "reekeer[bot]" Date: Mon, 16 Mar 2026 17:40:38 +0300 Subject: [PATCH 29/36] style(black): format code --- ruff.json | 35 ++++++++++++++++++++++++++++++++++ scripts/reekeerBot.py | 4 +++- src/sigdb/core/reader.py | 4 +--- src/sigdb/crypto/ed25519.py | 4 +--- src/sigdb/format/trie.py | 12 +++--------- tests/test_container_errors.py | 4 +--- tests/test_crypto.py | 5 +---- tests/test_varint.py | 5 +---- tests/validate_lib.py | 4 +--- 9 files changed, 47 insertions(+), 30 deletions(-) create mode 100644 ruff.json diff --git a/ruff.json b/ruff.json new file mode 100644 index 0000000..9bc5372 --- /dev/null +++ b/ruff.json @@ -0,0 +1,35 @@ +[ + { + "cell": null, + "code": "SIM118", + "end_location": { + "column": 27, + "row": 100 + }, + "filename": "/mnt/c/Users/Pavel/Documents/projects/sigdb/src/sigdb/internal/groups.py", + "fix": { + "applicability": "unsafe", + "edits": [ + { + "content": "", + "end_location": { + "column": 27, + "row": 100 + }, + "location": { + "column": 20, + "row": 100 + } + } + ], + "message": "Remove `.keys()`" + }, + "location": { + "column": 9, + "row": 100 + }, + "message": "Use `key in dict` instead of `key in dict.keys()`", + "noqa_row": 100, + "url": "https://docs.astral.sh/ruff/rules/in-dict-keys" + } +] \ No newline at end of file diff --git a/scripts/reekeerBot.py b/scripts/reekeerBot.py index ce4f2e9..cee344f 100644 --- a/scripts/reekeerBot.py +++ b/scripts/reekeerBot.py @@ -82,7 +82,9 @@ def create_jwt(*, app_id: int, private_key_path: Path) -> str: # PyJWT expects "iss" to be a string. "iss": str(app_id), } - return jwt.encode(payload, private_key, algorithm="RS256") # pyright: ignore[reportUnknownMemberType] + return jwt.encode( + payload, private_key, algorithm="RS256" + ) # pyright: ignore[reportUnknownMemberType] def installation_token(config: Config) -> str: diff --git a/src/sigdb/core/reader.py b/src/sigdb/core/reader.py index f86464e..34e3c5f 100644 --- a/src/sigdb/core/reader.py +++ b/src/sigdb/core/reader.py @@ -106,9 +106,7 @@ def match(self, head: str) -> SigDBMatchResult: ostart = out_start[state] item_id = outputs[ostart] item = self._items[item_id] - return SigDBMatchResult( - result=True, item_id=item_id, item=item, head=normalized - ) + return SigDBMatchResult(result=True, item_id=item_id, item=item, head=normalized) return SigDBMatchResult(result=False, item_id=None, item=None, head=normalized) diff --git a/src/sigdb/crypto/ed25519.py b/src/sigdb/crypto/ed25519.py index 3bdae08..6f6a341 100644 --- a/src/sigdb/crypto/ed25519.py +++ b/src/sigdb/crypto/ed25519.py @@ -36,9 +36,7 @@ def sign_hash(data_hash: bytes, *, signing_key_hex: str) -> bytes: return key.sign(data_hash).signature -def verify_hash_signature( - data_hash: bytes, signature: bytes, *, public_key_hex: str -) -> None: +def verify_hash_signature(data_hash: bytes, signature: bytes, *, public_key_hex: str) -> None: _SigningKey, VerifyKey, BadSignatureError = _import_nacl() try: verify = VerifyKey(bytes.fromhex(public_key_hex)) diff --git a/src/sigdb/format/trie.py b/src/sigdb/format/trie.py index 5be4001..166fcea 100644 --- a/src/sigdb/format/trie.py +++ b/src/sigdb/format/trie.py @@ -105,9 +105,7 @@ def build_sigdb( raise SigDBFormatError("metadata.public_key does not match signing key") header_meta.setdefault("public_key", public_key_hex) - header_data = json.dumps( - header_meta, ensure_ascii=False, separators=(",", ":") - ).encode("utf-8") + header_data = json.dumps(header_meta, ensure_ascii=False, separators=(",", ":")).encode("utf-8") if len(header_data) > MAX_HEADER_BYTES: raise SigDBFormatError("HEADER_DATA too large") @@ -163,9 +161,7 @@ def load_sigdb( max_items_json_size: int = 256 * 1024 * 1024, max_automaton_size: int = 512 * 1024 * 1024, ) -> SigDBDatabase: - header, items_compressed, auto_compressed, stored_hash, signature = _read_container( - Path(path) - ) + header, items_compressed, auto_compressed, stored_hash, signature = _read_container(Path(path)) items_raw = decompress_zstd(items_compressed, max_output_size=max_items_json_size) auto_raw = decompress_zstd(auto_compressed, max_output_size=max_automaton_size) @@ -212,9 +208,7 @@ def validate_sigdb( ) = _read_container(Path(path)) pk = public_key_hex or _metadata_public_key(header) - items_raw = decompress_zstd( - items_compressed, max_output_size=max_items_json_size - ) + items_raw = decompress_zstd(items_compressed, max_output_size=max_items_json_size) auto_raw = decompress_zstd(auto_compressed, max_output_size=max_automaton_size) if verify_hash: diff --git a/tests/test_container_errors.py b/tests/test_container_errors.py index f607484..8d7339b 100644 --- a/tests/test_container_errors.py +++ b/tests/test_container_errors.py @@ -34,9 +34,7 @@ def assert_raises( assert_in(msg_contains, str(e), "exception message mismatch") return e except Exception as e: - raise AssertionError( - f"expected {exc_type.__name__}, got {type(e).__name__}: {e}" - ) from e + raise AssertionError(f"expected {exc_type.__name__}, got {type(e).__name__}: {e}") from e raise AssertionError(f"expected {exc_type.__name__}, got no exception") diff --git a/tests/test_crypto.py b/tests/test_crypto.py index db2070e..c81a983 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -43,9 +43,7 @@ def assert_raises( assert_in(msg_contains, str(e), "exception message mismatch") return e except Exception as e: - raise AssertionError( - f"expected {exc_type.__name__}, got {type(e).__name__}: {e}" - ) from e + raise AssertionError(f"expected {exc_type.__name__}, got {type(e).__name__}: {e}") from e raise AssertionError(f"expected {exc_type.__name__}, got no exception") @@ -81,4 +79,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/tests/test_varint.py b/tests/test_varint.py index 93eb7bb..6c8033e 100644 --- a/tests/test_varint.py +++ b/tests/test_varint.py @@ -37,9 +37,7 @@ def assert_raises( assert_in(msg_contains, str(e), "exception message mismatch") return e except Exception as e: - raise AssertionError( - f"expected {exc_type.__name__}, got {type(e).__name__}: {e}" - ) from e + raise AssertionError(f"expected {exc_type.__name__}, got {type(e).__name__}: {e}") from e raise AssertionError(f"expected {exc_type.__name__}, got no exception") @@ -83,4 +81,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/tests/validate_lib.py b/tests/validate_lib.py index 8f9db3a..37e78cd 100644 --- a/tests/validate_lib.py +++ b/tests/validate_lib.py @@ -51,9 +51,7 @@ def assert_raises( assert_in(msg_contains, str(e), "exception message mismatch") return e except Exception as e: # pragma: no cover - raise AssertionError( - f"expected {exc_type.__name__}, got {type(e).__name__}: {e}" - ) from e + raise AssertionError(f"expected {exc_type.__name__}, got {type(e).__name__}: {e}") from e raise AssertionError(f"expected {exc_type.__name__}, got no exception") From 9a09e3e229a80ae51a2f511fe013ae4e6309b0e8 Mon Sep 17 00:00:00 2001 From: "reekeer[bot]" Date: Mon, 16 Mar 2026 17:58:00 +0300 Subject: [PATCH 30/36] chore: remove reekeerBot cache files and add them to .gitignore --- .gitignore | 4 ++++ ruff.json | 35 ----------------------------------- 2 files changed, 4 insertions(+), 35 deletions(-) delete mode 100644 ruff.json diff --git a/.gitignore b/.gitignore index 24b47bd..f2c83e7 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,10 @@ venv/ *.env *.pem +# reekeerBot cache files +ruff.json +pyright.json + # Packaging / build outputs build/ dist/ diff --git a/ruff.json b/ruff.json deleted file mode 100644 index 9bc5372..0000000 --- a/ruff.json +++ /dev/null @@ -1,35 +0,0 @@ -[ - { - "cell": null, - "code": "SIM118", - "end_location": { - "column": 27, - "row": 100 - }, - "filename": "/mnt/c/Users/Pavel/Documents/projects/sigdb/src/sigdb/internal/groups.py", - "fix": { - "applicability": "unsafe", - "edits": [ - { - "content": "", - "end_location": { - "column": 27, - "row": 100 - }, - "location": { - "column": 20, - "row": 100 - } - } - ], - "message": "Remove `.keys()`" - }, - "location": { - "column": 9, - "row": 100 - }, - "message": "Use `key in dict` instead of `key in dict.keys()`", - "noqa_row": 100, - "url": "https://docs.astral.sh/ruff/rules/in-dict-keys" - } -] \ No newline at end of file From a68b220238bf9347e5319532e0f58f2ab029416f Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Mon, 16 Mar 2026 18:33:44 +0300 Subject: [PATCH 31/36] refactor: remove local file paths and use root-relative paths --- scripts/reekeerBot.py | 77 ++++++++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 20 deletions(-) diff --git a/scripts/reekeerBot.py b/scripts/reekeerBot.py index cee344f..fd62f53 100644 --- a/scripts/reekeerBot.py +++ b/scripts/reekeerBot.py @@ -2,6 +2,7 @@ import json import os import subprocess +import sys import time from dataclasses import dataclass from pathlib import Path @@ -38,9 +39,11 @@ def run_cmd( args: list[str], *, capture: bool = False, + quiet: bool = False, check: bool = True, ) -> str | None: - result = subprocess.run(args, capture_output=capture, text=True) + capture_output = capture or quiet + result = subprocess.run(args, capture_output=capture_output, text=True) if check and result.returncode != 0: stderr = (result.stderr or "").strip() raise RuntimeError(f"Command failed ({result.returncode}): {args}\n{stderr}") @@ -51,9 +54,11 @@ def run_shell( cmd: str, *, capture: bool = False, + quiet: bool = False, check: bool = True, ) -> str | None: - result = subprocess.run(cmd, shell=True, capture_output=capture, text=True) + capture_output = capture or quiet + result = subprocess.run(cmd, shell=True, capture_output=capture_output, text=True) if check and result.returncode != 0: stderr = (result.stderr or "").strip() raise RuntimeError(f"Command failed ({result.returncode}): {cmd}\n{stderr}") @@ -74,6 +79,11 @@ def load_json(path: Path) -> Any | None: return None +def _redact_local_paths(text: str) -> str: + repo_prefix = REPO_ROOT.as_posix().rstrip("/") + "/" + return text.replace(repo_prefix, "") + + def create_jwt(*, app_id: int, private_key_path: Path) -> str: private_key = private_key_path.read_text(encoding="utf-8") payload = { @@ -82,9 +92,7 @@ def create_jwt(*, app_id: int, private_key_path: Path) -> str: # PyJWT expects "iss" to be a string. "iss": str(app_id), } - return jwt.encode( - payload, private_key, algorithm="RS256" - ) # pyright: ignore[reportUnknownMemberType] + return jwt.encode(payload, private_key, algorithm="RS256") # pyright: ignore[reportUnknownMemberType] def installation_token(config: Config) -> str: @@ -147,19 +155,37 @@ def push(*, token: str, repo: str, branch: str) -> None: git(["push", f"https://x-access-token:{token}@github.com/{repo}.git", branch]) -def ruff_fix() -> None: - run_shell("ruff check . --fix", check=False) +def ruff_fix() -> list[dict[str, Any]]: + run_shell("ruff check . --fix", check=False, quiet=True) commit_if_changes("style(ruff): auto-fix lint issues") - run_shell("ruff check . --output-format=json > ruff.json || true", check=False) + + out = ( + run_shell("ruff check . --output-format=json", capture=True, check=False, quiet=True) or "" + ) + try: + data: Any = json.loads(out) if out else [] + except ValueError: + return [] + return _as_dict_list(data) def black_fix() -> None: - run_shell("black .", check=False) + run_shell("black .", check=False, quiet=True) commit_if_changes("style(black): format code") -def pyright_scan() -> None: - run_shell("pyright --outputjson > pyright.json || true", check=False) +def pyright_scan() -> list[dict[str, Any]]: + out = run_shell("pyright --outputjson", capture=True, check=False, quiet=True) or "" + try: + data: Any = json.loads(out) if out else {} + except ValueError: + return [] + + if not isinstance(data, dict): + return [] + data_d = cast(dict[str, Any], data) + diags = data_d.get("generalDiagnostics") + return _as_dict_list(diags) def _as_dict_list(value: Any) -> list[dict[str, Any]]: @@ -312,7 +338,12 @@ def summarize(ruff: list[dict[str, Any]], pyright: list[dict[str, Any]]) -> str: if ruff: msg += "## Ruff issues\n" for r in ruff[:15]: - filename = r.get("filename") + filename_raw = r.get("filename") + filename = _diagnostic_repo_path(filename_raw) or ( + filename_raw if isinstance(filename_raw, str) else None + ) + if isinstance(filename, str): + filename = _redact_local_paths(filename) loc = r.get("location") row = cast(dict[str, Any], loc).get("row") if isinstance(loc, dict) else None message = r.get("message") @@ -322,7 +353,12 @@ def summarize(ruff: list[dict[str, Any]], pyright: list[dict[str, Any]]) -> str: if pyright: msg += "\n## Pyright issues\n" for p in pyright[:15]: - file_value = p.get("file") + file_value_raw = p.get("file") + file_value = _diagnostic_repo_path(file_value_raw) or ( + file_value_raw if isinstance(file_value_raw, str) else None + ) + if isinstance(file_value, str): + file_value = _redact_local_paths(file_value) message = p.get("message") if isinstance(file_value, str) and isinstance(message, str): msg += f"- {file_value} {message}\n" @@ -347,7 +383,7 @@ def load_config() -> Config: if not private_key_path.is_absolute(): private_key_path = (REPO_ROOT / private_key_path).resolve() if not private_key_path.is_file(): - raise RuntimeError(f"Private key not found: {private_key_path}") + raise RuntimeError(f"Private key not found: {_redact_local_paths(str(private_key_path))}") return Config( repo=repo, @@ -366,9 +402,9 @@ def main() -> None: git_setup() branch = create_branch() - ruff_fix() + ruff = ruff_fix() black_fix() - pyright_scan() + pyright = pyright_scan() push(token=token, repo=config.repo, branch=branch) @@ -379,9 +415,6 @@ def main() -> None: add_reviewer(token=token, repo=config.repo, pr_number=pr_number, reviewer=config.reviewer) - ruff = ruff_errors() - pyright = pyright_errors() - comment_pr(token=token, repo=config.repo, pr_number=pr_number, text=summarize(ruff, pyright)) if pyright: @@ -391,4 +424,8 @@ def main() -> None: if __name__ == "__main__": - main() + try: + main() + except Exception as exc: + print(_redact_local_paths(str(exc)), file=sys.stderr) + raise SystemExit(1) from None From 3f44bbbd3f707353dd94fe8274d60f4dd2fbb70a Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Mon, 16 Mar 2026 18:39:29 +0300 Subject: [PATCH 32/36] feat: add cleanup artifacts --- scripts/reekeerBot.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/scripts/reekeerBot.py b/scripts/reekeerBot.py index fd62f53..e68db26 100644 --- a/scripts/reekeerBot.py +++ b/scripts/reekeerBot.py @@ -1,3 +1,4 @@ +import contextlib import datetime import json import os @@ -155,6 +156,18 @@ def push(*, token: str, repo: str, branch: str) -> None: git(["push", f"https://x-access-token:{token}@github.com/{repo}.git", branch]) +def _cleanup_artifacts() -> None: + for name in ("ruff.json", "pyright.json"): + with contextlib.suppress(FileNotFoundError): + (REPO_ROOT / name).unlink() + + run_cmd( + ["git", "rm", "-f", "--ignore-unmatch", "ruff.json", "pyright.json"], + check=False, + quiet=True, + ) + + def ruff_fix() -> list[dict[str, Any]]: run_shell("ruff check . --fix", check=False, quiet=True) commit_if_changes("style(ruff): auto-fix lint issues") @@ -402,10 +415,14 @@ def main() -> None: git_setup() branch = create_branch() + _cleanup_artifacts() + ruff = ruff_fix() black_fix() pyright = pyright_scan() + _cleanup_artifacts() + push(token=token, repo=config.repo, branch=branch) pr = create_pr(token=token, repo=config.repo, branch=branch, base_branch=config.base_branch) From 0fed5e5fcce23a6a9c971e947ccb66cc2099be13 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Mon, 16 Mar 2026 18:49:03 +0300 Subject: [PATCH 33/36] feat: add workflow to run reekeerBot --- .github/workflows/reekeer-bot.yml | 50 +++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 .github/workflows/reekeer-bot.yml diff --git a/.github/workflows/reekeer-bot.yml b/.github/workflows/reekeer-bot.yml new file mode 100644 index 0000000..d7f2098 --- /dev/null +++ b/.github/workflows/reekeer-bot.yml @@ -0,0 +1,50 @@ +name: reekeerBot + +on: + workflow_dispatch: + push: + branches: + - dev + pull_request: + +jobs: + bot: + + runs-on: ubuntu-latest + + permissions: + contents: write + pull-requests: write + issues: write + + steps: + + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + pip install requests pyjwt cryptography python-dotenv + pip install ruff black pyright + + - name: Create GitHub App private key + run: | + echo "${{ secrets.APP_PRIVATE_KEY }}" > private-key.pem + + - name: Run reekeerBot + env: + APP_ID: ${{ secrets.APP_ID }} + INSTALLATION_ID: ${{ secrets.INSTALLATION_ID }} + PRIVATE_KEY_PATH: private-key.pem + + GITHUB_REPOSITORY: ${{ github.repository }} + BASE_BRANCH: ${{ secrets.BASE_BRANCH }} + REVIEWER: ${{ secrets.REVIEWER }} + + run: | + python scripts/reekeerBot.py \ No newline at end of file From 167843aee9088c6f35f547a2a69b41087373b8b0 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Mon, 16 Mar 2026 19:02:54 +0300 Subject: [PATCH 34/36] fix: bug with incorrect environments specification --- .github/workflows/reekeer-bot.yml | 51 +++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 .github/workflows/reekeer-bot.yml diff --git a/.github/workflows/reekeer-bot.yml b/.github/workflows/reekeer-bot.yml new file mode 100644 index 0000000..bbd62d0 --- /dev/null +++ b/.github/workflows/reekeer-bot.yml @@ -0,0 +1,51 @@ +name: reekeerBot + +on: + workflow_dispatch: + push: + branches: + - dev + pull_request: + +jobs: + bot: + + runs-on: ubuntu-latest + environment: reekeerBot + + permissions: + contents: write + pull-requests: write + issues: write + + steps: + + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + pip install requests pyjwt cryptography python-dotenv + pip install ruff black pyright + + - name: Create GitHub App private key + run: | + echo "${{ secrets.APP_PRIVATE_KEY }}" > private-key.pem + + - name: Run reekeerBot + env: + APP_ID: ${{ secrets.APP_ID }} + INSTALLATION_ID: ${{ secrets.INSTALLATION_ID }} + PRIVATE_KEY_PATH: private-key.pem + + GITHUB_REPOSITORY: ${{ github.repository }} + BASE_BRANCH: ${{ vars.BASE_BRANCH }} + REVIEWER: ${{ vars.REVIEWER }} + + run: | + python scripts/reekeerBot.py \ No newline at end of file From 0b7c5fa16fc6b3b16560dfa60e6cddb82323e964 Mon Sep 17 00:00:00 2001 From: "reekeer[bot]" Date: Mon, 16 Mar 2026 16:05:38 +0000 Subject: [PATCH 35/36] style(black): format code --- scripts/reekeerBot.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/scripts/reekeerBot.py b/scripts/reekeerBot.py index e68db26..cbb2bf4 100644 --- a/scripts/reekeerBot.py +++ b/scripts/reekeerBot.py @@ -93,7 +93,9 @@ def create_jwt(*, app_id: int, private_key_path: Path) -> str: # PyJWT expects "iss" to be a string. "iss": str(app_id), } - return jwt.encode(payload, private_key, algorithm="RS256") # pyright: ignore[reportUnknownMemberType] + return jwt.encode( + payload, private_key, algorithm="RS256" + ) # pyright: ignore[reportUnknownMemberType] def installation_token(config: Config) -> str: @@ -166,7 +168,7 @@ def _cleanup_artifacts() -> None: check=False, quiet=True, ) - + def ruff_fix() -> list[dict[str, Any]]: run_shell("ruff check . --fix", check=False, quiet=True) @@ -416,13 +418,13 @@ def main() -> None: branch = create_branch() _cleanup_artifacts() - + ruff = ruff_fix() black_fix() pyright = pyright_scan() _cleanup_artifacts() - + push(token=token, repo=config.repo, branch=branch) pr = create_pr(token=token, repo=config.repo, branch=branch, base_branch=config.base_branch) From de658752103ef15202ded7dfe0307a86d0898397 Mon Sep 17 00:00:00 2001 From: IMDelewer Date: Mon, 16 Mar 2026 19:07:46 +0300 Subject: [PATCH 36/36] fix: hotfix bug with infinity runs --- .github/workflows/reekeer-bot.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/reekeer-bot.yml b/.github/workflows/reekeer-bot.yml index bbd62d0..9fbb36e 100644 --- a/.github/workflows/reekeer-bot.yml +++ b/.github/workflows/reekeer-bot.yml @@ -5,14 +5,13 @@ on: push: branches: - dev - pull_request: jobs: bot: runs-on: ubuntu-latest environment: reekeerBot - + permissions: contents: write pull-requests: write