# coding=utf-8
#
# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Manages diagnostics messages such as semantic errors."""

from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import IntEnum
from sys import stderr
from typing import (
    ClassVar,
    TextIO,
    TypeVar,
)

from typing_extensions import override

from taihe.utils.logging import AnsiStyle, should_use_color
from taihe.utils.sources import SourceLocation

T = TypeVar("T")


def _passthrough(x: str) -> str:
    """Color filter that keeps ANSI style codes unchanged."""
    return x


def _discard(x: str) -> str:
    """Color filter that replaces ANSI style codes with the empty string."""
    del x
    return ""


# A filter applied to ANSI style codes before they are embedded in output.
FilterT = Callable[[str], str]


###################
# The Basic Types #
###################


class Severity(IntEnum):
    """Diagnostic severity levels; a larger value means more severe."""

    NOTE = 0
    WARN = 1
    ERROR = 2
    FATAL = 3


@dataclass
class DiagBase(ABC):
    """The base class for diagnostic messages."""

    # Set by each severity-specific subclass below.
    SEVERITY: ClassVar[Severity]
    SEVERITY_DESC: ClassVar[str]
    STYLE: ClassVar[str]

    loc: SourceLocation | None = field(kw_only=True)
    """The source location where the diagnostic refers to."""

    def __str__(self) -> str:
        # Plain-text rendering: every ANSI style code is discarded.
        return self.format_message(_discard)

    @abstractmethod
    def describe(self) -> str:
        """Concise, human-readable description of the diagnostic message.

        Subclasses must implement this method to explain the specific issue.

        Example: "redefinition of ..."
        """

    def notes(self) -> Iterable["DiagNote"]:
        """Returns an iterable of associated diagnostic notes.

        Notes provide additional context or suggestions related to the main diagnostic.
        By default, a diagnostic has no associated notes.
        """
        return ()

    def format_message(self, f: FilterT) -> str:
        """Formats the diagnostic message, optionally applying ANSI styling.

        Args:
            f: A filter for ANSI codes applied to parts of the string for styling.

        Returns:
            A string representing the formatted diagnostic message,
            including location, severity, and description.

        Example:
            "example.taihe:7:20: error: redefinition of ..."
        """
        return (
            f"{f(AnsiStyle.BRIGHT)}{self.loc or '???'}: "  # "example.taihe:7:20: "
            f"{f(self.STYLE)}{self.SEVERITY_DESC}{f(AnsiStyle.RESET)}: "  # "error: "
            f"{self.describe()}{f(AnsiStyle.RESET_ALL)}"  # "redefinition of ..."
        )


##########################################
# Base classes with different severities #
##########################################


@dataclass
class DiagNote(DiagBase):
    SEVERITY = Severity.NOTE
    SEVERITY_DESC = "note"
    STYLE = AnsiStyle.CYAN


@dataclass
class DiagWarn(DiagBase):
    SEVERITY = Severity.WARN
    SEVERITY_DESC = "warning"
    STYLE = AnsiStyle.MAGENTA


@dataclass
class DiagError(DiagBase, Exception):
    """An error diagnostic; it can also be raised, then caught and emitted."""

    SEVERITY = Severity.ERROR
    SEVERITY_DESC = "error"
    STYLE = AnsiStyle.RED


@dataclass
class DiagFatalError(DiagError):
    SEVERITY = Severity.FATAL
    SEVERITY_DESC = "fatal"


########################


class DiagnosticsManager(ABC):
    """Receives diagnostics and tracks the highest severity emitted so far."""

    # Highest severity emitted since construction or the last `reset_severity()`.
    _max_severity_seen: Severity = Severity.NOTE

    def has_reached_severity(self, severity: Severity) -> bool:
        """Returns whether any emitted diagnostic was at least as severe as `severity`."""
        return self._max_severity_seen >= severity

    @property
    def has_error(self) -> bool:
        """Whether an "error" or "fatal" diagnostic has been emitted."""
        return self.has_reached_severity(Severity.ERROR)

    @property
    def has_fatal_error(self) -> bool:
        """Whether a "fatal" diagnostic has been emitted."""
        return self.has_reached_severity(Severity.FATAL)

    def reset_severity(self) -> None:
        """Resets the current maximum diagnostic severity."""
        self._max_severity_seen = Severity.NOTE

    @abstractmethod
    def emit(self, diag: DiagBase) -> None:
        """Emits a new diagnostic message.

        This base implementation only records the severity of `diag`.
        Subclasses must call it via `super().emit(diag)`.
        """
        self._max_severity_seen = max(self._max_severity_seen, diag.SEVERITY)

    @contextmanager
    def capture_error(self) -> Iterator[None]:
        """Captures "error" and "fatal" diagnostics using context manager.

        Example:
            ```
            # Emit the error and prevent its propagation
            with diag_mgr.capture_error():
                foo()
                raise DiagError(...)
                bar()

            # Equivalent to:
            try:
                foo()
                raise DiagError(...)
                bar()
            except DiagError as e:
                diag_mgr.emit(e)
            ```
        """
        try:
            yield None
        except DiagError as e:
            self.emit(e)

    def for_each(self, xs: Iterable[T], cb: Callable[[T], bool | None]) -> bool:
        """Calls `cb` for each element. Records and recovers from `DiagError`s.

        Any `DiagError` raised by `cb` is emitted, and iteration continues
        with the next element. If `cb` returns a truthy value, iteration
        stops early.

        Returns:
            `True` if none of the calls to `cb` raised a `DiagError`,
            `False` otherwise. An early stop does not hide earlier errors.
        """
        no_error = True
        for x in xs:
            try:
                stop = cb(x)
            except DiagError as e:
                self.emit(e)
                no_error = False
                continue
            if stop:
                break
        return no_error


class ConsoleDiagnosticsManager(DiagnosticsManager):
    """Prints diagnostics, followed by source excerpts, to a text stream."""

    def __init__(self, out: TextIO = stderr):
        """Creates a manager that writes to `out` (the process's stderr by default).

        ANSI styling is kept only when `should_use_color(out)` is truthy.
        """
        self._out = out
        self._color_filter_fn: FilterT = (
            _passthrough if should_use_color(self._out) else _discard
        )

    @override
    def emit(self, diag: DiagBase) -> None:
        """Records the severity of `diag`, then prints it and its notes."""
        super().emit(diag)
        self._render(diag)
        for n in diag.notes():
            self._render(n)
        # Flush the stream actually written to, which is not necessarily stderr.
        self._flush()

    def _write(self, s: str) -> None:
        self._out.write(s)

    def _flush(self) -> None:
        self._out.flush()

    def _render_source_location(self, loc: SourceLocation) -> None:
        """Prints the source rows covered by `loc.span`, with `^` markers below them.

        Nothing is printed when `loc` has no span, or when the span's rows
        fall outside the file's lines. Rows and columns are treated as 1-based.
        """
        line_no_width = 5
        span = loc.span
        if not span:
            return

        line_contents = loc.file.read().splitlines()

        if span.start.row < 1 or span.stop.row > len(line_contents):
            return

        f = self._color_filter_fn
        marker_style = f(AnsiStyle.GREEN + AnsiStyle.BRIGHT)
        reset = f(AnsiStyle.RESET_ALL)

        # Visit only the rows covered by the span (bounds were checked above).
        for line in range(span.start.row, span.stop.row + 1):
            line_content = line_contents[line - 1]
            markers = "".join(
                # Outside the span, echo tabs so the carets stay aligned with
                # the source line when it is displayed.
                ("\t" if ch == "\t" else " ")
                if (line == span.start.row and col < span.start.col)
                or (line == span.stop.row and col > span.stop.col)
                else "^"
                for col, ch in enumerate(line_content, 1)
            )

            self._write(
                f"{line:{line_no_width}} | {line_content}\n"
                f"{'':{line_no_width}} | {marker_style}{markers}{reset}\n"
            )

    def _render(self, d: DiagBase) -> None:
        """Writes one formatted diagnostic and, if it has a location, its source excerpt."""
        self._write(f"{d.format_message(self._color_filter_fn)}\n")
        if d.loc:
            self._render_source_location(d.loc)