1# coding=utf-8 2# 3# Copyright (c) 2025 Huawei Device Co., Ltd. 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15 16"""Manage output files.""" 17 18import sys 19from collections import defaultdict 20from collections.abc import Generator 21from contextlib import contextmanager 22from dataclasses import dataclass 23from enum import Enum, auto 24from io import StringIO 25from os import path, sep 26from pathlib import Path 27from types import FrameType, TracebackType 28from typing import TYPE_CHECKING, TextIO 29 30from typing_extensions import Self, override 31 32if TYPE_CHECKING: 33 from taihe.driver.contexts import CompilerInstance 34 35DEFAULT_INDENT = " " # Four spaces 36 37 38class DebugLevel(Enum): 39 """Controls the code-generator debug info. 40 41 When enabled, the generated code would contain comments, representing the 42 location of Python code which generates. 43 """ 44 45 NONE = auto() 46 """Don't print any debug info.""" 47 CONCISE = auto() 48 """Prints function and line number.""" 49 VERBOSE = auto() 50 """Besides CONSICE, also prints code snippet. Could be slow.""" 51 52 53class FileKind(str, Enum): 54 C_HEADER = "c_header" 55 C_SOURCE = "c_source" 56 CPP_HEADER = "cpp_header" 57 CPP_SOURCE = "cpp_source" 58 TEMPLATE = "template" 59 ETS = "ets" 60 OTHER = "other" 61 62 63@dataclass 64class FileDescriptor: 65 relative_path: str # e.g., "include/foo.h" 66 kind: FileKind 67 68 69class BaseWriter: 70 def __init__( 71 self, 72 out: TextIO, 73 *, 74 comment_prefix: str, 75 default_indent: str, 76 debug_level: DebugLevel = DebugLevel.NONE, 77 ): 78 """Initialize a code writer with a writable output stream. 79 80 Args: 81 out: A writable stream object 82 comment_prefix: The prefix for line-comment, for instance, "// " for C++ 83 default_indent: The default indentation string for each level of indentation 84 debug_level: see `DebugLevel` for details 85 """ 86 self._out = out 87 self._default_indent = default_indent 88 self._current_indent = "" 89 self._debug_level = debug_level 90 self._comment_prefix = comment_prefix 91 92 def newline(self): 93 """Writes a newline character.""" 94 self._out.write("\n") 95 96 def writeln( 97 self, 98 line: str = "", 99 ): 100 """Writes a single-line string. 101 102 Args: 103 line: The line to write (must not contain newlines) 104 """ 105 pass 106 if not line: 107 # Don't use indent for empty lines 108 self._out.write("\n") 109 return 110 111 self._out.write(self._current_indent) 112 self._out.write(line) 113 self._out.write("\n") 114 115 def writelns(self, *lines: str): 116 """Writes multiple one-line strings. 117 118 Args: 119 *lines: One or more lines to write 120 """ 121 self._write_debug(skip=2) 122 for line in lines: 123 self.writeln( 124 line, 125 ) 126 127 def write_block(self, text_block: str): 128 """Writes a potentially multi-line text block. 129 130 Args: 131 text_block: The block of text to write 132 """ 133 self.writelns(*text_block.splitlines()) 134 135 def write_comment(self, comment: str): 136 """Writes a comment block, prefixing each line with the comment prefix. 137 138 Indents the comment block according to the current indentation level. 139 Handles multi-line comments by splitting the input string. 140 141 Args: 142 comment: The comment text to write. Can be multi-line. 143 """ 144 for line in comment.splitlines(): 145 self._out.write(self._current_indent) 146 self._out.write(self._comment_prefix) 147 self._out.write(line) 148 self._out.write("\n") 149 150 @contextmanager 151 def indented( 152 self, 153 prologue: str | None, 154 epilogue: str | None, 155 /, 156 *, 157 indent: str | None = None, 158 ) -> Generator[Self, None, None]: 159 """Context manager that indents code within its scope. 160 161 Args: 162 prologue: Optional text to write before indentation 163 epilogue: Optional text to write after indentation 164 indent: Optional string to use for indentation (overrides default) 165 166 Returns: 167 A context manager that yields this BaseWriter 168 """ 169 self._write_debug(skip=3) 170 if prologue is not None: 171 self.writeln( 172 prologue, 173 ) 174 previous_indent = self._current_indent 175 self._current_indent += self._default_indent if indent is None else indent 176 try: 177 yield self 178 finally: 179 self._current_indent = previous_indent 180 if epilogue is not None: 181 self.writeln( 182 epilogue, 183 ) 184 185 def _write_debug(self, *, skip: int): 186 if self._debug_level == DebugLevel.NONE: 187 return 188 self.write_comment(_format_frame(sys._getframe(skip))) # type: ignore 189 190 191class FileWriter(BaseWriter): 192 def __init__( 193 self, 194 om: "OutputManager", 195 relative_path: str, 196 file_kind: FileKind, 197 *, 198 default_indent: str = DEFAULT_INDENT, 199 comment_prefix: str, 200 ): 201 super().__init__( 202 out=StringIO(), 203 default_indent=default_indent, 204 comment_prefix=comment_prefix, 205 debug_level=om.debug_level, 206 ) 207 self._om = om 208 self.desc = FileDescriptor( 209 relative_path=relative_path, 210 kind=file_kind, 211 ) 212 213 def __enter__(self): 214 return self 215 216 def __exit__( 217 self, 218 exc_type: type[BaseException] | None, 219 exc_val: BaseException | None, 220 exc_tb: TracebackType | None, 221 ) -> bool: 222 del exc_val, exc_tb, exc_type 223 self._om.save(self) 224 return False 225 226 def write_body(self, f: TextIO): 227 pass 228 f.write(self._out.getvalue()) 229 230 def write_prologue(self, f: TextIO): 231 del f 232 233 def write_epilogue(self, f: TextIO): 234 del f 235 236 237def _format_frame(f: FrameType) -> str: 238 # For /a/b/c/d/e.py, only keep FILENAME_KEEP directories, resulting "c/d/e.py" 239 FILENAME_KEEP = 3 240 241 file_name = f.f_code.co_filename 242 parts = file_name.split(sep) 243 if len(parts) > FILENAME_KEEP: 244 file_name = path.join(*parts[-FILENAME_KEEP:]) 245 246 base_format = f"CODEGEN-DEBUG: {f.f_code.co_name} in {file_name}:{f.f_lineno}" 247 248 return base_format 249 250 251@dataclass 252class OutputConfig: 253 dst_dir: Path | None = None 254 debug_level: DebugLevel = DebugLevel.NONE 255 256 def construct(self, ci: "CompilerInstance") -> "OutputManager": 257 """Construct an OutputManager based on this configuration.""" 258 return OutputManager( 259 dst_dir=self.dst_dir, 260 debug_level=self.debug_level, 261 ) 262 263 264class OutputManager: 265 """Manages the creation and saving of output files.""" 266 267 files: dict[str, FileDescriptor] 268 files_by_kind: dict[FileKind, list[FileDescriptor]] 269 270 dst_dir: Path | None 271 272 debug_level: DebugLevel 273 274 def __init__( 275 self, 276 dst_dir: Path | None = None, 277 debug_level: DebugLevel = DebugLevel.NONE, 278 ): 279 self.files: dict[str, FileDescriptor] = {} 280 self.files_by_kind: dict[FileKind, list[FileDescriptor]] = defaultdict(list) 281 self.dst_dir = dst_dir 282 self.debug_level = debug_level 283 284 def register(self, desc: FileDescriptor): 285 if (prev := self.files.setdefault(desc.relative_path, desc)) != desc: 286 raise ValueError( 287 f"File {desc.relative_path} is already registered as {prev.kind}, " 288 f"cannot re-register with {desc.kind}." 289 ) 290 self.files_by_kind[desc.kind].append(desc) 291 292 def save(self, writer: FileWriter): 293 """Saves the content of a FileWriter to the output directory.""" 294 self.register(writer.desc) 295 296 if self.dst_dir is None: 297 return 298 299 file_path = self.dst_dir / writer.desc.relative_path 300 file_path.parent.mkdir(exist_ok=True, parents=True) 301 with open(file_path, "w", encoding="utf-8") as dst: 302 writer.write_prologue(dst) 303 writer.write_body(dst) 304 writer.write_epilogue(dst) 305 306 def get_all_files(self) -> list[FileDescriptor]: 307 return list(self.files.values()) 308 309 def get_files_by_kind(self, kind: FileKind) -> list[FileDescriptor]: 310 return self.files_by_kind.get(kind, []) 311 312 def post_generate(self) -> None: 313 pass 314 315 316################################# 317# Cmake code generation related # 318################################# 319 320 321class CMakeWriter(FileWriter): 322 """Represents a CMake file.""" 323 324 @override 325 def __init__( 326 self, 327 om: OutputManager, 328 relative_path: str, 329 file_kind: FileKind, 330 indent_unit: str = DEFAULT_INDENT, 331 ): 332 super().__init__( 333 om, 334 relative_path=relative_path, 335 file_kind=file_kind, 336 default_indent=indent_unit, 337 comment_prefix="# ", 338 ) 339 self.headers: dict[str, None] = {} 340 341 342class CMakeOutputConfig(OutputConfig): 343 runtime_include_dir: Path 344 runtime_src_dir: Path 345 346 def __init__( 347 self, 348 runtime_include_dir: Path, 349 runtime_src_dir: Path, 350 dst_dir: Path | None = None, 351 debug_level: DebugLevel = DebugLevel.NONE, 352 ): 353 super().__init__(dst_dir=dst_dir, debug_level=debug_level) 354 self.runtime_include_dir = runtime_include_dir 355 self.runtime_src_dir = runtime_src_dir 356 357 def construct(self, ci: "CompilerInstance") -> "CMakeOutputManager": 358 return CMakeOutputManager( 359 dst_dir=self.dst_dir, 360 debug_level=self.debug_level, 361 runtime_include_dir=self.runtime_include_dir, 362 runtime_src_dir=self.runtime_src_dir, 363 ) 364 365 366class CMakeOutputManager(OutputManager): 367 """Manages the generation of CMake files for Taihe runtime.""" 368 369 runtime_include_dir: Path 370 runtime_src_files: list[Path] 371 372 def __init__( 373 self, 374 dst_dir: Path | None = None, 375 debug_level: DebugLevel = DebugLevel.NONE, 376 *, 377 runtime_include_dir: Path, 378 runtime_src_dir: Path, 379 ): 380 super().__init__(dst_dir=dst_dir, debug_level=debug_level) 381 self.runtime_include_dir = runtime_include_dir 382 self.runtime_c_src_files = [ 383 p for p in runtime_src_dir.rglob("*.c") if p.is_file() 384 ] 385 self.runtime_cxx_src_files = [ 386 p for p in runtime_src_dir.rglob("*.cpp") if p.is_file() 387 ] 388 389 @override 390 def post_generate(self): 391 with CMakeWriter( 392 self, 393 "TaiheGenerated.cmake", 394 FileKind.OTHER, 395 ) as cmake_target: 396 self.emit_runtime_files_list(cmake_target) 397 self.emit_generated_dir("${CMAKE_CURRENT_LIST_DIR}", cmake_target) 398 self.emit_generated_includes(cmake_target) 399 self.emit_generated_sources(cmake_target) 400 self.emit_generated_ets_files(cmake_target) 401 self.emit_set_cpp_standard(cmake_target) 402 403 def emit_runtime_files_list( 404 self, 405 cmake_target: CMakeWriter, 406 ): 407 with cmake_target.indented( 408 f"if(NOT DEFINED TAIHE_RUNTIME_INCLUDE_INNER)", 409 f"endif()", 410 ): 411 with cmake_target.indented( 412 f"set(TAIHE_RUNTIME_INCLUDE_INNER", 413 f")", 414 ): 415 cmake_target.writelns( 416 f"{self.runtime_include_dir}", 417 ) 418 with cmake_target.indented( 419 f"if(NOT DEFINED TAIHE_RUNTIME_C_SRC_INNER)", 420 f"endif()", 421 ): 422 with cmake_target.indented( 423 f"set(TAIHE_RUNTIME_C_SRC_INNER", 424 f")", 425 ): 426 for runtime_src_file in self.runtime_c_src_files: 427 cmake_target.writelns( 428 f"{runtime_src_file}", 429 ) 430 with cmake_target.indented( 431 f"if(NOT DEFINED TAIHE_RUNTIME_CXX_SRC_INNER)", 432 f"endif()", 433 ): 434 with cmake_target.indented( 435 f"set(TAIHE_RUNTIME_CXX_SRC_INNER", 436 f")", 437 ): 438 for runtime_src_file in self.runtime_cxx_src_files: 439 cmake_target.writelns( 440 f"{runtime_src_file}", 441 ) 442 with cmake_target.indented( 443 f"set(TAIHE_RUNTIME_INCLUDE", 444 f")", 445 ): 446 cmake_target.writelns( 447 f"${{TAIHE_RUNTIME_INCLUDE_INNER}}", 448 ) 449 with cmake_target.indented( 450 f"set(TAIHE_RUNTIME_SRC", 451 f")", 452 ): 453 cmake_target.writelns( 454 f"${{TAIHE_RUNTIME_C_SRC_INNER}}", 455 f"${{TAIHE_RUNTIME_CXX_SRC_INNER}}", 456 ) 457 458 def emit_generated_dir( 459 self, 460 generated_path: str, 461 cmake_target: CMakeWriter, 462 ): 463 with cmake_target.indented( 464 f"if(NOT DEFINED TAIHE_GEN_DIR)", 465 f"endif()", 466 ): 467 with cmake_target.indented( 468 f"set(TAIHE_GEN_DIR", 469 f")", 470 ): 471 cmake_target.writelns( 472 f"{generated_path}", 473 ) 474 475 def emit_generated_includes(self, cmake_target: CMakeWriter): 476 with cmake_target.indented( 477 f"set(TAIHE_GEN_INCLUDE", 478 f")", 479 ): 480 cmake_target.writelns( 481 f"${{TAIHE_GEN_DIR}}/include", 482 ) 483 484 def emit_generated_sources( 485 self, 486 cmake_target: CMakeWriter, 487 ): 488 with cmake_target.indented( 489 f"set(TAIHE_GEN_C_SRC", 490 f")", 491 ): 492 for file in self.get_files_by_kind(FileKind.C_SOURCE): 493 cmake_target.writelns( 494 f"${{TAIHE_GEN_DIR}}/{file.relative_path}", 495 ) 496 with cmake_target.indented( 497 f"set(TAIHE_GEN_CXX_SRC", 498 f")", 499 ): 500 for file in self.get_files_by_kind(FileKind.CPP_SOURCE): 501 cmake_target.writelns( 502 f"${{TAIHE_GEN_DIR}}/{file.relative_path}", 503 ) 504 with cmake_target.indented( 505 f"set(TAIHE_GEN_SRC", 506 f")", 507 ): 508 cmake_target.writelns( 509 f"${{TAIHE_GEN_C_SRC}}", 510 f"${{TAIHE_GEN_CXX_SRC}}", 511 ) 512 513 def emit_generated_ets_files( 514 self, 515 cmake_target: CMakeWriter, 516 ): 517 with cmake_target.indented( 518 f"set(TAIHE_GEN_ETS_FILES", 519 f")", 520 ): 521 for file in self.get_files_by_kind(FileKind.ETS): 522 cmake_target.writelns( 523 f"${{TAIHE_GEN_DIR}}/{file.relative_path}", 524 ) 525 526 def emit_set_cpp_standard( 527 self, 528 cmake_target: CMakeWriter, 529 ): 530 with cmake_target.indented( 531 f"set_source_files_properties(", 532 f")", 533 ): 534 cmake_target.writelns( 535 f"${{TAIHE_GEN_CXX_SRC}}", 536 f"${{TAIHE_RUNTIME_CXX_SRC_INNER}}", 537 # setting 538 f"PROPERTIES", 539 f"LANGUAGE CXX", 540 f'COMPILE_FLAGS "-std=c++17"', 541 )