• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Support for documenting audit events."""
2
3from __future__ import annotations
4
5import re
6from typing import TYPE_CHECKING
7
8from docutils import nodes
9from sphinx.errors import NoUri
10from sphinx.locale import _ as sphinx_gettext
11from sphinx.transforms.post_transforms import SphinxPostTransform
12from sphinx.util import logging
13from sphinx.util.docutils import SphinxDirective
14
15if TYPE_CHECKING:
16    from collections.abc import Iterator
17
18    from sphinx.application import Sphinx
19    from sphinx.builders import Builder
20    from sphinx.environment import BuildEnvironment
21
logger = logging.getLogger(__name__)

# Sets of argument names that may be used interchangeably for the same audit
# event. When two names at the same position belong to the same set, the
# difference is not reported as a mismatch. Synonyms cannot hide a difference
# in the *number* of arguments; that is always reported.
_SYNONYMS = [
    frozenset(("file", "path", "fd")),
]
31
32
class AuditEvents:
    """Registry of audit events documented across the build.

    ``events`` maps each event name to the argument names recorded the first
    time that event was added. ``sources`` maps each event name to every
    ``(docname, target)`` pair at which it was added, in insertion order.
    """

    def __init__(self) -> None:
        self.events: dict[str, list[str]] = {}
        self.sources: dict[str, list[tuple[str, str]]] = {}

    def __iter__(self) -> Iterator[tuple[str, list[str], tuple[str, str]]]:
        """Yield one ``(name, args, source)`` triple per recorded source."""
        for name, args in self.events.items():
            for source in self.sources[name]:
                yield name, args, source

    def add_event(
        self, name: str, args: list[str], source: tuple[str, str]
    ) -> None:
        """Record that event *name* is documented at *source*.

        The first occurrence of *name* fixes its argument list. Later
        occurrences are only compared against it: a mismatch is logged as a
        warning, never raised, and never replaces the stored arguments.
        """
        if name in self.events:
            self._check_args_match(name, args)
        else:
            self.events[name] = args
        self.sources.setdefault(name, []).append(source)

    def _check_args_match(self, name: str, args: list[str]) -> None:
        """Warn once if *args* disagree with the arguments stored for *name*.

        Arguments at the same position match when they are equal or belong
        to the same ``_SYNONYMS`` set. A differing argument count is always
        a mismatch.
        """
        current_args = self.events[name]
        if current_args == args:
            return
        # Lengths are compared first, so zip() only ever sees equal-length
        # lists here.
        mismatched = len(current_args) != len(args) or any(
            a1 != a2 and not any(a1 in s and a2 in s for s in _SYNONYMS)
            for a1, a2 in zip(current_args, args, strict=False)
        )
        if mismatched:
            # Only build the message when a warning is actually emitted.
            logger.warning(
                f"Mismatched arguments for audit-event {name}: "
                f"{current_args!r} != {args!r}"
            )

    def id_for(self, name: str) -> str:
        """Return an anchor id for the next occurrence of event *name*.

        Non-word characters in *name* become underscores. The numeric suffix
        is the number of sources already recorded for *name*. Calling this
        before ``add_event`` therefore gives each occurrence a distinct
        suffix.
        """
        source_count = len(self.sources.get(name, ()))
        name_clean = re.sub(r"\W", "_", name)
        return f"audit_event_{name_clean}_{source_count}"

    def rows(self) -> Iterator[tuple[str, list[str], list[tuple[str, str]]]]:
        """Yield ``(name, args, sources)`` for every event, sorted by name."""
        for name in sorted(self.events):
            yield name, self.events[name], self.sources[name]
79
80
def initialise_audit_events(app: Sphinx) -> None:
    """Attach an empty ``AuditEvents`` registry to the build environment.

    If the environment already has an ``audit_events`` attribute, it is left
    untouched, so calling this more than once is harmless.
    """
    env = app.env
    if hasattr(env, "audit_events"):
        return
    env.audit_events = AuditEvents()
85
86
def audit_events_purge(
    app: Sphinx, env: BuildEnvironment, docname: str
) -> None:
    """Remove every audit-event source recorded for *docname*.

    Connected to Sphinx's ``env-purge-doc`` event. The registry is rebuilt
    from the sources that survive, rather than edited in place. An event
    whose only sources were in *docname* therefore disappears entirely, and
    the per-event source counts used by ``AuditEvents.id_for`` drop
    accordingly.

    :param app: The Sphinx application (unused).
    :param env: Build environment whose ``audit_events`` is replaced.
    :param docname: Document whose recorded sources are discarded.
    """
    fresh_audit_events = AuditEvents()
    for name, args, (doc, target) in env.audit_events:
        if doc != docname:
            fresh_audit_events.add_event(name, args, (doc, target))
    # Store the rebuilt registry. Without this the purge has no effect, and
    # sources from removed or re-read documents stay registered.
    env.audit_events = fresh_audit_events
95
96
def audit_events_merge(
    app: Sphinx,
    env: BuildEnvironment,
    docnames: list[str],
    other: BuildEnvironment,
) -> None:
    """Fold the audit events from *other* into *env*.

    Connected to ``env-merge-info``. During parallel builds, *other* is an
    environment produced by a worker, and each of its recorded sources is
    re-added to the main registry. Argument mismatches are therefore checked
    exactly as for directly added events.
    """
    registry = env.audit_events
    for event_name, event_args, event_source in other.audit_events:
        registry.add_event(event_name, event_args, event_source)
106
107
class AuditEvent(SphinxDirective):
    """The ``audit-event`` directive.

    Arguments, in order:

    1. the event name (required);
    2. a comma-separated list of argument names. The list may be wrapped in
       quotes, and empty entries are dropped;
    3. a target anchor. It may be wrapped in quotes.

    Each use is recorded in ``env.audit_events``. When no target is given,
    one is generated and attached to the output paragraph as its id. An
    explicit target is only recorded. The paragraph holds the directive's
    own content if any, otherwise a sentence taken from ``_label``.
    """

    has_content = True
    required_arguments = 1
    optional_arguments = 2
    final_argument_whitespace = True

    # Default sentences, indexed by min(2, number of event arguments).
    _label = [
        sphinx_gettext(
            "Raises an :ref:`auditing event <auditing>` "
            "{name} with no arguments."
        ),
        sphinx_gettext(
            "Raises an :ref:`auditing event <auditing>` "
            "{name} with argument {args}."
        ),
        sphinx_gettext(
            "Raises an :ref:`auditing event <auditing>` "
            "{name} with arguments {args}."
        ),
    ]

    def run(self) -> list[nodes.paragraph]:
        event_name = self.arguments[0]
        event_args = self._event_args()
        registry = self.env.audit_events

        anchor_ids = []
        target = self._explicit_target()
        if not target:
            # Generate the anchor before recording this occurrence:
            # id_for() numbers it by the sources already registered.
            target = registry.id_for(event_name)
            anchor_ids.append(target)
        registry.add_event(event_name, event_args, (self.env.docname, target))

        paragraph = nodes.paragraph("", classes=["audit-hook"], ids=anchor_ids)
        self.set_source_info(paragraph)
        if not self.content:
            self._render_default_text(paragraph, event_name, event_args)
        else:
            paragraph.rawsource = '\n'.join(self.content)  # for gettext
            self.state.nested_parse(
                self.content, self.content_offset, paragraph
            )
        return [paragraph]

    def _event_args(self) -> list[str]:
        """Parse the optional second argument into a list of names."""
        if len(self.arguments) < 2 or not self.arguments[1]:
            return []
        raw = self.arguments[1].strip("'\"")
        return [arg for piece in raw.split(",") if (arg := piece.strip())]

    def _explicit_target(self) -> str | None:
        """Return the optional third argument without quotes, or None."""
        try:
            return self.arguments[2].strip("\"'")
        except (IndexError, TypeError):
            return None

    def _render_default_text(
        self, paragraph: nodes.paragraph, name: str, args: list[str]
    ) -> None:
        """Fill *paragraph* with the ``_label`` sentence for *name*/*args*."""
        template = self._label[min(2, len(args))]
        formatted_args = ", ".join(f"``{a}``" for a in args)
        text = template.format(name=f"``{name}``", args=formatted_args)
        paragraph.rawsource = text  # for gettext
        inline_nodes, messages = self.state.inline_text(text, self.lineno)
        paragraph += inline_nodes
        paragraph += messages
165
166
class audit_event_list(nodes.General, nodes.Element):  # noqa: N801
    """Placeholder node emitted by the ``audit-event-table`` directive.

    It carries no data. AuditEventListTransform replaces every instance
    with the generated table of audit events.
    """
169
170
class AuditEventListDirective(SphinxDirective):
    """The ``audit-event-table`` directive.

    Emits a single empty ``audit_event_list`` placeholder.
    AuditEventListTransform later swaps it for the real table.
    """

    def run(self) -> list[audit_event_list]:
        placeholder = audit_event_list()
        return [placeholder]
174
175
class AuditEventListTransform(SphinxPostTransform):
    """Replace ``audit_event_list`` placeholders with the audit-event table.

    The table is built from ``builder.env.audit_events``. It has one row per
    event name, listing the event's arguments and numbered links back to
    every place the event is documented.
    """

    default_priority = 500

    def run(self) -> None:
        if self.document.next_node(audit_event_list) is None:
            return

        table = self._make_table(self.app.builder, self.env.docname)
        # Collect all placeholders before replacing any of them, so the tree
        # is not modified while findall() is still lazily walking it.
        placeholders = list(self.document.findall(audit_event_list))
        for index, node in enumerate(placeholders):
            # A docutils node can have only one parent. Every placeholder
            # after the first gets its own copy of the table.
            node.replace_self(table if index == 0 else table.deepcopy())

    def _make_table(self, builder: Builder, docname: str) -> nodes.table:
        """Build the three-column table: event, arguments, references.

        *docname* is the document the table is rendered into. Reference
        links are made relative to it.
        """
        table = nodes.table(cols=3)
        group = nodes.tgroup(
            "",
            nodes.colspec(colwidth=30),
            nodes.colspec(colwidth=55),
            nodes.colspec(colwidth=15),
            cols=3,
        )
        head = nodes.thead()
        body = nodes.tbody()

        table += group
        group += head
        group += body

        head += nodes.row(
            "",
            nodes.entry("", nodes.paragraph("", "Audit event")),
            nodes.entry("", nodes.paragraph("", "Arguments")),
            nodes.entry("", nodes.paragraph("", "References")),
        )

        for name, args, sources in builder.env.audit_events.rows():
            body += self._make_row(builder, docname, name, args, sources)

        return table

    @staticmethod
    def _make_row(
        builder: Builder,
        docname: str,
        name: str,
        args: list[str],
        sources: list[tuple[str, str]],
    ) -> nodes.row:
        """Build the table row for event *name*.

        The arguments column lists *args* as comma-separated literals. The
        references column holds ``[i]`` links, numbered by position in the
        sorted, de-duplicated *sources*. Sources with a non-string label, or
        whose URI the builder cannot produce (``NoUri``), get no link, but
        they still consume a number.
        """
        row = nodes.row()
        name_node = nodes.paragraph("", nodes.Text(name))
        row += nodes.entry("", name_node)

        args_node = nodes.paragraph()
        for arg in args:
            args_node += nodes.literal(arg, arg)
            args_node += nodes.Text(", ")
        if len(args_node.children) > 0:
            args_node.children.pop()  # remove trailing comma
        row += nodes.entry("", args_node)

        backlinks_node = nodes.paragraph()
        backlinks = enumerate(sorted(set(sources)), start=1)
        for i, (doc, label) in backlinks:
            if isinstance(label, str):
                ref = nodes.reference("", f"[{i}]", internal=True)
                try:
                    target = (
                        f"{builder.get_relative_uri(docname, doc)}#{label}"
                    )
                except NoUri:
                    continue
                else:
                    ref["refuri"] = target
                    backlinks_node += ref
        row += nodes.entry("", backlinks_node)
        return row
251
252
def setup(app: Sphinx):
    """Register the audit-event extension with Sphinx.

    Adds the ``audit-event`` and ``audit-event-table`` directives and the
    post-transform that renders the table. Connects the hooks that create,
    purge and merge ``env.audit_events``. Returns the extension metadata,
    declaring it safe for parallel reading and writing.
    """
    for directive_name, directive_class in (
        ("audit-event", AuditEvent),
        ("audit-event-table", AuditEventListDirective),
    ):
        app.add_directive(directive_name, directive_class)
    app.add_post_transform(AuditEventListTransform)

    for event_name, handler in (
        ("builder-inited", initialise_audit_events),
        ("env-purge-doc", audit_events_purge),
        ("env-merge-info", audit_events_merge),
    ):
        app.connect(event_name, handler)

    return dict(
        version="1.0",
        parallel_read_safe=True,
        parallel_write_safe=True,
    )
265