• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# coding=utf-8
2#
3# Copyright (c) 2025 Huawei Device Co., Ltd.
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Manages diagnostics messages such as semantic errors."""
17
18from abc import ABC, abstractmethod
19from collections.abc import Callable, Iterable
20from contextlib import contextmanager
21from dataclasses import dataclass, field
22from enum import IntEnum
23from sys import stderr
24from typing import (
25    ClassVar,
26    TextIO,
27    TypeVar,
28)
29
30from typing_extensions import override
31
32from taihe.utils.logging import AnsiStyle, should_use_color
33from taihe.utils.sources import SourceLocation
34
35T = TypeVar("T")
36
37
38def _passthrough(x: str) -> str:
39    return x
40
41
42def _discard(x: str) -> str:
43    del x
44    return ""
45
46
47FilterT = Callable[[str], str]
48
49
50###################
51# The Basic Types #
52###################
53
54
class Severity(IntEnum):
    """Seriousness of a diagnostic.

    Members are integers ordered from least to most serious, so they can be
    compared with `<`/`>=` and combined with `max()`.
    """

    NOTE = 0  # lowest level
    WARN = 1
    ERROR = 2
    FATAL = 3  # highest level
60
61
@dataclass
class DiagBase(ABC):
    """Common base of every diagnostic message.

    Concrete subclasses provide the class-level severity metadata below and
    implement `describe`; `format_message` combines them with the location
    into a single-line report.
    """

    # Per-subclass metadata: numeric severity, its printed label, and the
    # style code applied to that label by `format_message`.
    SEVERITY: ClassVar[Severity]
    SEVERITY_DESC: ClassVar[str]
    STYLE: ClassVar[str]

    # The source location the diagnostic refers to (keyword-only, may be None).
    loc: SourceLocation | None = field(kw_only=True)

    def __str__(self) -> str:
        # Plain-text form: every style code is replaced by "".
        return self.format_message(_discard)

    @abstractmethod
    def describe(self) -> str:
        """Concise, human-readable description of the diagnostic message.

        Subclasses must implement this method to explain the specific issue.

        Example: "redefinition of ..."
        """

    def notes(self) -> Iterable["DiagNote"]:
        """Returns the notes attached to this diagnostic.

        Notes give additional context or suggestions for the main diagnostic.
        The base implementation attaches none.
        """
        return ()

    def format_message(self, f: FilterT) -> str:
        """Builds the one-line message: location, severity label, description.

        Args:
            f: Filter applied to every style code; it either keeps the code
                or removes it.

        Returns:
            The formatted message. A missing location is shown as "???".

        Example:
            "example.taihe:7:20: error: redefinition of ..."
        """
        # Each piece is evaluated in the same order as it appears in the output.
        location_part = f"{f(AnsiStyle.BRIGHT)}{self.loc or '???'}: "
        severity_part = f"{f(self.STYLE)}{self.SEVERITY_DESC}{f(AnsiStyle.RESET)}: "
        description_part = f"{self.describe()}{f(AnsiStyle.RESET_ALL)}"
        return location_part + severity_part + description_part
111
112
113##########################################
114# Base classes with different severities #
115##########################################
116
117
@dataclass
class DiagNote(DiagBase):
    """Base for notes: severity NOTE, labelled "note", styled with CYAN."""

    SEVERITY_DESC = "note"
    STYLE = AnsiStyle.CYAN
    SEVERITY = Severity.NOTE
123
124
@dataclass
class DiagWarn(DiagBase):
    """Base for warnings: severity WARN, labelled "warning", styled with MAGENTA."""

    SEVERITY_DESC = "warning"
    STYLE = AnsiStyle.MAGENTA
    SEVERITY = Severity.WARN
130
131
@dataclass
class DiagError(DiagBase, Exception):
    """Base for errors: severity ERROR, labelled "error", styled with RED.

    Also an `Exception`, so an error can be raised and later caught and
    emitted (see `DiagnosticsManager.capture_error` and `for_each`).
    """

    SEVERITY_DESC = "error"
    STYLE = AnsiStyle.RED
    SEVERITY = Severity.ERROR
137
138
@dataclass
class DiagFatalError(DiagError):
    """Base for fatal errors: severity FATAL, labelled "fatal".

    Inherits the style from `DiagError`, and is caught wherever a
    `DiagError` is caught.
    """

    SEVERITY_DESC = "fatal"
    SEVERITY = Severity.FATAL
143
144
145########################
146
147
class DiagnosticsManager(ABC):
    """Abstract sink for diagnostics that tracks the highest severity emitted.

    Subclasses decide how diagnostics are presented by overriding `emit`.
    They must call `super().emit(diag)` so severity tracking stays accurate.
    """

    # Highest severity recorded by `emit` since creation or the last
    # `reset_severity()`. The class attribute supplies the initial value;
    # `emit` and `reset_severity` shadow it with an instance attribute.
    _max_severity_seen: Severity = Severity.NOTE

    def has_reached_severity(self, severity: Severity) -> bool:
        """Returns whether the highest severity seen so far is >= `severity`.

        The tracked level starts at `Severity.NOTE`, so this is always `True`
        for `Severity.NOTE`, even before anything has been emitted.
        """
        return self._max_severity_seen >= severity

    @property
    def has_error(self) -> bool:
        """Whether an "error" or "fatal" diagnostic was emitted since the last reset."""
        return self.has_reached_severity(Severity.ERROR)

    @property
    def has_fatal_error(self) -> bool:
        """Whether a "fatal" diagnostic was emitted since the last reset."""
        return self.has_reached_severity(Severity.FATAL)

    def reset_severity(self) -> None:
        """Resets the current maximum diagnostic severity back to NOTE."""
        self._max_severity_seen = Severity.NOTE

    @abstractmethod
    def emit(self, diag: DiagBase) -> None:
        """Emits a new diagnostic message.

        This base implementation only records `diag.SEVERITY`. Overriding
        subclasses must still call `super().emit(diag)`.
        """
        self._max_severity_seen = max(self._max_severity_seen, diag.SEVERITY)

    @contextmanager
    def capture_error(self):
        """Captures "error" and "fatal" diagnostics raised inside a `with` block.

        A `DiagError` (including `DiagFatalError`) raised in the block is
        passed to `emit` instead of propagating, and the rest of the block is
        skipped. Any other exception propagates unchanged.

        Example:
        ```
        # Emit the error and prevent its propagation
        with diag_mgr.capture_error():
            foo()
            raise DiagError(...)
            bar()

        # Equivalent to:
        try:
            foo()
            raise DiagError(...)
            bar()
        except DiagError as e:
            diag_mgr.emit(e)
        ```
        """
        try:
            yield None
        except DiagError as e:
            self.emit(e)

    def for_each(self, xs: Iterable[T], cb: Callable[[T], bool | None]) -> bool:
        """Calls `cb` for each element, recording and recovering from `DiagError`s.

        A `DiagError` raised by `cb` is emitted, and iteration continues with
        the next element. If `cb` returns a truthy value, iteration stops
        immediately.

        Args:
            xs: The elements to visit.
            cb: Called once per element. Return a truthy value to stop early.

        Returns:
            `True` as soon as `cb` returns a truthy value. That value is returned
            even if `cb` raised `DiagError`s for earlier elements. If iteration
            runs to the end, returns `True` only if no `DiagError` was raised.
        """
        no_error = True
        for x in xs:
            try:
                if cb(x):
                    # NOTE(review): an early stop reports True even when errors
                    # were already emitted for earlier elements; confirm callers
                    # rely on this rather than on `no_error`.
                    return True
            except DiagError as e:
                self.emit(e)
                no_error = False
        return no_error
211
212
class ConsoleDiagnosticsManager(DiagnosticsManager):
    """Writes diagnostics as text to a stream, with optional ANSI styling.

    Each diagnostic is written as its one-line message. If it has a location,
    the spanned source lines follow, with `^` markers under the spanned
    columns. Each attached note is then rendered the same way.
    """

    def __init__(self, out: TextIO = stderr):
        """Creates a manager writing to `out`.

        Args:
            out: Destination stream (defaults to `stderr`). Style codes are
                kept only when `should_use_color(out)` is true; otherwise
                they are stripped.
        """
        self._out = out
        # Decide once whether style codes are emitted or dropped.
        if should_use_color(self._out):
            self._color_filter_fn = _passthrough
        else:
            self._color_filter_fn = _discard

    @override
    def emit(self, diag: DiagBase) -> None:
        """Records the severity, then writes the diagnostic and its notes."""
        super().emit(diag)
        self._render(diag)
        for note in diag.notes():
            self._render(note)
        # Flush the stream that was actually written to (not always stderr).
        self._flush()

    def _write(self, s: str) -> None:
        """Writes `s` to the output stream."""
        self._out.write(s)

    def _flush(self) -> None:
        """Flushes the output stream."""
        self._out.flush()

    def _render_source_location(self, loc: SourceLocation) -> None:
        """Writes the source lines covered by `loc.span` with `^` markers.

        Nothing is written if `loc` has no span, or if the span's rows fall
        outside the file. Rows and columns are treated as 1-based, and the
        stop column is inclusive.
        """
        MAX_LINE_NO_SPACE = 5  # width of the right-aligned line-number gutter
        span = loc.span
        if not span:
            return

        line_contents = loc.file.read().splitlines()
        start, stop = span.start, span.stop

        if start.row < 1 or stop.row > len(line_contents):
            return

        # Loop-invariant style codes for the marker line.
        f = self._color_filter_fn
        marker_on = f(AnsiStyle.GREEN + AnsiStyle.BRIGHT)
        marker_off = f(AnsiStyle.RESET_ALL)

        # Visit only the spanned rows. The bounds check above guarantees that
        # every index is valid, and an inverted span yields an empty range.
        for line in range(start.row, stop.row + 1):
            line_content = line_contents[line - 1]
            markers = "".join(
                (
                    " "
                    if (line == start.row and col < start.col)
                    or (line == stop.row and col > stop.col)
                    else "^"
                )
                for col in range(1, len(line_content) + 1)
            )

            self._write(
                f"{line:{MAX_LINE_NO_SPACE}} | {line_content}\n"
                f"{'':{MAX_LINE_NO_SPACE}} | {marker_on}{markers}{marker_off}\n"
            )

    def _render(self, d: DiagBase) -> None:
        """Writes `d`'s formatted message, then its source excerpt if located."""
        self._write(f"{d.format_message(self._color_filter_fn)}\n")
        if d.loc:
            self._render_source_location(d.loc)