• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# coding=utf-8
2#
3# Copyright (c) 2025 Huawei Device Co., Ltd.
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Manage output files."""
17
18import sys
19from collections import defaultdict
20from collections.abc import Generator
21from contextlib import contextmanager
22from dataclasses import dataclass
23from enum import Enum, auto
24from io import StringIO
25from os import path, sep
26from pathlib import Path
27from types import FrameType, TracebackType
28from typing import TYPE_CHECKING, TextIO
29
30from typing_extensions import Self, override
31
32if TYPE_CHECKING:
33    from taihe.driver.contexts import CompilerInstance
34
DEFAULT_INDENT = " " * 4  # One indentation level: four spaces
36
37
class DebugLevel(Enum):
    """Controls the code-generator debug info.

    When enabled, the generated code contains comments identifying the
    location of the Python code that generated it.
    """

    NONE = auto()
    """Don't print any debug info."""
    CONCISE = auto()
    """Prints function and line number."""
    # NOTE(review): `_format_frame` in this module does not emit a code
    # snippet, so VERBOSE output currently matches CONCISE — confirm.
    VERBOSE = auto()
    """Besides CONCISE, also prints a code snippet. Could be slow."""
51
52
class FileKind(str, Enum):
    """Category of a generated output file.

    `OutputManager` groups registered files by kind, and the generated CMake
    script lists the files of the C_SOURCE, CPP_SOURCE and ETS kinds. Because
    of the `str` mix-in, each member compares equal to its string value.
    """

    C_HEADER = "c_header"
    C_SOURCE = "c_source"
    CPP_HEADER = "cpp_header"
    CPP_SOURCE = "cpp_source"
    TEMPLATE = "template"
    ETS = "ets"
    OTHER = "other"  # anything else, e.g. the generated TaiheGenerated.cmake
61
62
@dataclass
class FileDescriptor:
    """Identifies one generated file by its relative path and its kind."""

    # Path relative to the output directory, e.g. "include/foo.h";
    # OutputManager.save joins it onto `dst_dir`.
    relative_path: str  # e.g., "include/foo.h"
    kind: FileKind
67
68
class BaseWriter:
    """Line-oriented code writer with indentation and optional debug comments.

    Lines written through `writeln`, `writelns`, `write_block` and
    `write_comment` are prefixed with the current indentation (except empty
    lines from `writeln`). `indented` deepens the indentation for the duration
    of a ``with`` block. When the debug level is not `DebugLevel.NONE`,
    `writelns`, `write_block` and `indented` first emit a comment naming the
    Python source location of their caller.
    """

    def __init__(
        self,
        out: TextIO,
        *,
        comment_prefix: str,
        default_indent: str,
        debug_level: DebugLevel = DebugLevel.NONE,
    ):
        """Initialize a code writer with a writable output stream.

        Args:
            out: A writable stream object.
            comment_prefix: The prefix for line comments, for instance "// "
                for C++.
            default_indent: The indentation string added per level by
                `indented` when no explicit ``indent`` is given.
            debug_level: See `DebugLevel` for details.
        """
        self._out = out
        self._default_indent = default_indent
        self._current_indent = ""
        self._debug_level = debug_level
        self._comment_prefix = comment_prefix

    def newline(self):
        """Write a bare newline character (no indentation)."""
        self._out.write("\n")

    def writeln(
        self,
        line: str = "",
    ):
        """Write a single line at the current indentation.

        Empty lines are written without indentation so the output does not
        carry trailing whitespace.

        Args:
            line: The line to write (must not contain newlines).
        """
        if line:
            self._out.write(self._current_indent)
            self._out.write(line)
        self._out.write("\n")

    def writelns(self, *lines: str):
        """Write several single-line strings, one per output line.

        Args:
            *lines: One or more lines to write (none may contain newlines).
        """
        # skip=2 reports the caller of `writelns`.
        self._write_debug(skip=2)
        for line in lines:
            self.writeln(line)

    def write_block(self, text_block: str):
        """Write a potentially multi-line text block, line by line.

        Args:
            text_block: The block of text to write; it is split with
                `str.splitlines`, so an empty string writes nothing.
        """
        # Emit the debug comment here instead of delegating to `writelns`:
        # from `writelns`, skip=2 would point at this method rather than at
        # the code that called `write_block`.
        self._write_debug(skip=2)
        for line in text_block.splitlines():
            self.writeln(line)

    def write_comment(self, comment: str):
        """Write a comment block, prefixing each line with the comment prefix.

        Every line (including empty ones) is indented to the current level.
        An empty ``comment`` writes nothing.

        Args:
            comment: The comment text to write. Can be multi-line.
        """
        for line in comment.splitlines():
            self._out.write(self._current_indent)
            self._out.write(self._comment_prefix)
            self._out.write(line)
            self._out.write("\n")

    @contextmanager
    def indented(
        self,
        prologue: str | None,
        epilogue: str | None,
        /,
        *,
        indent: str | None = None,
    ) -> Generator[Self, None, None]:
        """Context manager that indents code written within its scope.

        The prologue is written at the enclosing indentation before the scope
        starts; the epilogue is written at the enclosing indentation after the
        scope ends, even if the body raised.

        Args:
            prologue: Optional line to write before indentation.
            epilogue: Optional line to write after indentation.
            indent: Optional string to use for indentation (overrides the
                default indent).

        Yields:
            This writer.
        """
        # skip=3: this generator body is resumed from contextlib's
        # `__enter__`, so the caller's `with` statement is one frame further.
        self._write_debug(skip=3)
        if prologue is not None:
            self.writeln(prologue)
        previous_indent = self._current_indent
        self._current_indent += self._default_indent if indent is None else indent
        try:
            yield self
        finally:
            self._current_indent = previous_indent
            if epilogue is not None:
                self.writeln(epilogue)

    def _write_debug(self, *, skip: int):
        """Write a debug comment naming a caller's source location.

        Does nothing when the debug level is `DebugLevel.NONE`.

        Args:
            skip: Number of frames above this method to report; 0 would be
                `_write_debug` itself, 1 its direct caller.
        """
        if self._debug_level == DebugLevel.NONE:
            return
        self.write_comment(_format_frame(sys._getframe(skip)))  # type: ignore
189
190
class FileWriter(BaseWriter):
    """Writer for one output file, buffered in memory until the writer exits.

    Meant to be used as a context manager: on exit the writer passes itself
    to ``OutputManager.save``, which calls `write_prologue`, `write_body` and
    `write_epilogue` in that order when writing the file.
    """

    def __init__(
        self,
        om: "OutputManager",
        relative_path: str,
        file_kind: FileKind,
        *,
        default_indent: str = DEFAULT_INDENT,
        comment_prefix: str,
    ):
        """Create a writer for ``relative_path`` owned by ``om``.

        Args:
            om: Output manager that receives this writer on exit; its
                ``debug_level`` is also used for this writer.
            relative_path: Path of the file relative to the output directory.
            file_kind: Kind under which the file is registered.
            default_indent: Indentation string per `indented` level.
            comment_prefix: Line-comment prefix of the target language.
        """
        super().__init__(
            out=StringIO(),  # buffered; copied out by `write_body`
            default_indent=default_indent,
            comment_prefix=comment_prefix,
            debug_level=om.debug_level,
        )
        self._om = om
        self.desc = FileDescriptor(
            relative_path=relative_path,
            kind=file_kind,
        )

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        """Hand this writer to the output manager; never suppress exceptions."""
        del exc_val, exc_tb, exc_type
        # NOTE(review): this also runs when the ``with`` body raised, so any
        # partial output is saved before the exception propagates — confirm
        # that is intended.
        self._om.save(self)
        return False

    def write_body(self, f: TextIO):
        """Copy everything written to this writer so far into ``f``."""
        # `_out` is the StringIO created in `__init__`.
        f.write(self._out.getvalue())

    def write_prologue(self, f: TextIO):
        """Hook for subclasses: text written before the body. No-op here."""
        del f

    def write_epilogue(self, f: TextIO):
        """Hook for subclasses: text written after the body. No-op here."""
        del f
235
236
def _format_frame(f: FrameType) -> str:
    """Describe the source location of frame ``f`` for a debug comment.

    Returns ``"CODEGEN-DEBUG: <function> in <file>:<line>"``. Only the last
    three components of the file path are kept (for ``/a/b/c/d/e.py`` that is
    ``c/d/e.py``); shorter paths are shown unchanged.
    """
    kept_components = 3
    code = f.f_code
    location = code.co_filename
    pieces = location.split(sep)
    if len(pieces) > kept_components:
        location = path.join(*pieces[-kept_components:])
    return f"CODEGEN-DEBUG: {code.co_name} in {location}:{f.f_lineno}"
249
250
@dataclass
class OutputConfig:
    """Settings from which an `OutputManager` is built."""

    # Output directory; when None, OutputManager.save only registers files.
    dst_dir: Path | None = None
    debug_level: DebugLevel = DebugLevel.NONE

    def construct(self, ci: "CompilerInstance") -> "OutputManager":
        """Construct an OutputManager based on this configuration.

        Args:
            ci: The compiler instance; not used by this base configuration.
        """
        settings = {"dst_dir": self.dst_dir, "debug_level": self.debug_level}
        return OutputManager(**settings)
262
263
class OutputManager:
    """Manages the creation and saving of output files.

    Every saved file is registered by relative path (``files``) and grouped
    by kind (``files_by_kind``). Files are written under ``dst_dir`` only when
    it is set; otherwise they are just registered.
    """

    files: dict[str, FileDescriptor]
    files_by_kind: dict[FileKind, list[FileDescriptor]]

    dst_dir: Path | None

    debug_level: DebugLevel

    def __init__(
        self,
        dst_dir: Path | None = None,
        debug_level: DebugLevel = DebugLevel.NONE,
    ):
        """Create an empty manager.

        Args:
            dst_dir: Directory to write files into, or None to only register.
            debug_level: Debug level handed to the writers created for it.
        """
        self.files = {}
        self.files_by_kind = defaultdict(list)
        self.dst_dir = dst_dir
        self.debug_level = debug_level

    def register(self, desc: FileDescriptor):
        """Record ``desc`` as a generated file.

        Registering an equal descriptor again is a no-op, so each file appears
        at most once in ``files_by_kind``.

        Raises:
            ValueError: If the same relative path is already registered with a
                different kind.
        """
        prev = self.files.get(desc.relative_path)
        if prev is None:
            self.files[desc.relative_path] = desc
            self.files_by_kind[desc.kind].append(desc)
            return
        if prev != desc:
            raise ValueError(
                f"File {desc.relative_path} is already registered as {prev.kind}, "
                f"cannot re-register with {desc.kind}."
            )
        # Same path and kind already registered: do not append a duplicate
        # entry to `files_by_kind` (it would be listed twice downstream).

    def save(self, writer: FileWriter):
        """Saves the content of a FileWriter to the output directory.

        The writer's descriptor is always registered. If ``dst_dir`` is set,
        parent directories are created and the writer's prologue, body and
        epilogue are written to ``dst_dir / relative_path`` as UTF-8.
        """
        self.register(writer.desc)

        if self.dst_dir is None:
            return

        file_path = self.dst_dir / writer.desc.relative_path
        file_path.parent.mkdir(exist_ok=True, parents=True)
        with open(file_path, "w", encoding="utf-8") as dst:
            writer.write_prologue(dst)
            writer.write_body(dst)
            writer.write_epilogue(dst)

    def get_all_files(self) -> list[FileDescriptor]:
        """Return all registered descriptors, in registration order."""
        return list(self.files.values())

    def get_files_by_kind(self, kind: FileKind) -> list[FileDescriptor]:
        """Return the registered descriptors of ``kind`` (a fresh list)."""
        # Copy so callers cannot mutate the manager's internal bookkeeping,
        # matching `get_all_files`.
        return list(self.files_by_kind.get(kind, ()))

    def post_generate(self) -> None:
        """Hook run after generation; no-op here, overridden by subclasses."""
314
315
#################################
# CMake code generation related #
#################################
319
320
class CMakeWriter(FileWriter):
    """Writer for a generated CMake file; line comments start with ``# ``."""

    @override
    def __init__(
        self,
        om: OutputManager,
        relative_path: str,
        file_kind: FileKind,
        indent_unit: str = DEFAULT_INDENT,
    ):
        """Create a CMake writer.

        Args:
            om: Output manager that saves this file on exit.
            relative_path: Path of the file relative to the output directory.
            file_kind: Kind under which the file is registered.
            indent_unit: Indentation string per nesting level.
        """
        cmake_style = {"default_indent": indent_unit, "comment_prefix": "# "}
        super().__init__(
            om,
            relative_path=relative_path,
            file_kind=file_kind,
            **cmake_style,
        )
        # NOTE(review): not used within this class; presumably filled
        # elsewhere and used as an insertion-ordered set (values are None).
        self.headers: dict[str, None] = {}
340
341
@dataclass(init=False)
class CMakeOutputConfig(OutputConfig):
    """Output configuration that also emits a CMake script for the Taihe runtime.

    Declared as ``@dataclass(init=False)`` so the generated ``__repr__`` and
    ``__eq__`` take the runtime directories into account (previously only the
    inherited ``dst_dir``/``debug_level`` fields were compared and shown). The
    hand-written ``__init__`` is kept because these required fields follow the
    base class's defaulted ones.
    """

    runtime_include_dir: Path
    runtime_src_dir: Path

    def __init__(
        self,
        runtime_include_dir: Path,
        runtime_src_dir: Path,
        dst_dir: Path | None = None,
        debug_level: DebugLevel = DebugLevel.NONE,
    ):
        """Create the configuration.

        Args:
            runtime_include_dir: Taihe runtime include directory.
            runtime_src_dir: Directory searched for Taihe runtime sources.
            dst_dir: Output directory, or None to only register files.
            debug_level: Code-generator debug level.
        """
        super().__init__(dst_dir=dst_dir, debug_level=debug_level)
        self.runtime_include_dir = runtime_include_dir
        self.runtime_src_dir = runtime_src_dir

    def construct(self, ci: "CompilerInstance") -> "CMakeOutputManager":
        """Construct a CMakeOutputManager based on this configuration.

        Args:
            ci: The compiler instance; not used here.
        """
        del ci
        return CMakeOutputManager(
            dst_dir=self.dst_dir,
            debug_level=self.debug_level,
            runtime_include_dir=self.runtime_include_dir,
            runtime_src_dir=self.runtime_src_dir,
        )
364
365
class CMakeOutputManager(OutputManager):
    """Manages the generation of CMake files for Taihe runtime.

    In addition to the usual file bookkeeping, `post_generate` writes
    ``TaiheGenerated.cmake``. That file defines CMake variables listing the
    Taihe runtime include directory and sources, as well as the generated
    include directory, sources and ETS files.
    """

    runtime_include_dir: Path
    runtime_c_src_files: list[Path]
    runtime_cxx_src_files: list[Path]

    def __init__(
        self,
        dst_dir: Path | None = None,
        debug_level: DebugLevel = DebugLevel.NONE,
        *,
        runtime_include_dir: Path,
        runtime_src_dir: Path,
    ):
        """Create the manager and collect the runtime source files.

        Args:
            dst_dir: Output directory, or None to only register files.
            debug_level: Code-generator debug level.
            runtime_include_dir: Taihe runtime include directory, written into
                the generated script.
            runtime_src_dir: Directory searched recursively for ``*.c`` and
                ``*.cpp`` runtime sources.
        """
        super().__init__(dst_dir=dst_dir, debug_level=debug_level)
        self.runtime_include_dir = runtime_include_dir
        # Sorted so the generated script does not depend on the order in
        # which the filesystem happens to return directory entries.
        self.runtime_c_src_files = sorted(
            p for p in runtime_src_dir.rglob("*.c") if p.is_file()
        )
        self.runtime_cxx_src_files = sorted(
            p for p in runtime_src_dir.rglob("*.cpp") if p.is_file()
        )

    @staticmethod
    def _cmake_path(p: Path | str) -> str:
        """Render a filesystem path as one quoted CMake argument.

        The path is rendered with forward slashes via `as_posix()`. On
        Windows this replaces backslash separators, which CMake would read as
        escape sequences. Any remaining backslash, double quote or dollar
        sign is escaped. Quoting keeps paths containing spaces as a single
        list element.
        """
        text = Path(p).as_posix()
        text = text.replace("\\", "\\\\").replace('"', '\\"').replace("$", "\\$")
        return f'"{text}"'

    @override
    def post_generate(self):
        """Write ``TaiheGenerated.cmake`` describing runtime and generated files."""
        with CMakeWriter(
            self,
            "TaiheGenerated.cmake",
            FileKind.OTHER,
        ) as cmake_target:
            self.emit_runtime_files_list(cmake_target)
            self.emit_generated_dir("${CMAKE_CURRENT_LIST_DIR}", cmake_target)
            self.emit_generated_includes(cmake_target)
            self.emit_generated_sources(cmake_target)
            self.emit_generated_ets_files(cmake_target)
            self.emit_set_cpp_standard(cmake_target)

    def emit_runtime_files_list(
        self,
        cmake_target: CMakeWriter,
    ):
        """Emit the runtime include dir and source lists.

        The ``*_INNER`` variables are only set when not already defined, so a
        consumer can pre-set them. ``TAIHE_RUNTIME_INCLUDE`` and
        ``TAIHE_RUNTIME_SRC`` are then always derived from them.
        """
        with cmake_target.indented(
            "if(NOT DEFINED TAIHE_RUNTIME_INCLUDE_INNER)",
            "endif()",
        ):
            with cmake_target.indented(
                "set(TAIHE_RUNTIME_INCLUDE_INNER",
                ")",
            ):
                cmake_target.writelns(
                    self._cmake_path(self.runtime_include_dir),
                )
        with cmake_target.indented(
            "if(NOT DEFINED TAIHE_RUNTIME_C_SRC_INNER)",
            "endif()",
        ):
            with cmake_target.indented(
                "set(TAIHE_RUNTIME_C_SRC_INNER",
                ")",
            ):
                for runtime_src_file in self.runtime_c_src_files:
                    cmake_target.writelns(
                        self._cmake_path(runtime_src_file),
                    )
        with cmake_target.indented(
            "if(NOT DEFINED TAIHE_RUNTIME_CXX_SRC_INNER)",
            "endif()",
        ):
            with cmake_target.indented(
                "set(TAIHE_RUNTIME_CXX_SRC_INNER",
                ")",
            ):
                for runtime_src_file in self.runtime_cxx_src_files:
                    cmake_target.writelns(
                        self._cmake_path(runtime_src_file),
                    )
        with cmake_target.indented(
            "set(TAIHE_RUNTIME_INCLUDE",
            ")",
        ):
            cmake_target.writelns(
                "${TAIHE_RUNTIME_INCLUDE_INNER}",
            )
        with cmake_target.indented(
            "set(TAIHE_RUNTIME_SRC",
            ")",
        ):
            cmake_target.writelns(
                "${TAIHE_RUNTIME_C_SRC_INNER}",
                "${TAIHE_RUNTIME_CXX_SRC_INNER}",
            )

    def emit_generated_dir(
        self,
        generated_path: str,
        cmake_target: CMakeWriter,
    ):
        """Set ``TAIHE_GEN_DIR`` to ``generated_path`` unless already defined.

        ``generated_path`` is written verbatim, so it may be a CMake
        expression such as ``${CMAKE_CURRENT_LIST_DIR}``.
        """
        with cmake_target.indented(
            "if(NOT DEFINED TAIHE_GEN_DIR)",
            "endif()",
        ):
            with cmake_target.indented(
                "set(TAIHE_GEN_DIR",
                ")",
            ):
                cmake_target.writelns(
                    f"{generated_path}",
                )

    def emit_generated_includes(self, cmake_target: CMakeWriter):
        """Set ``TAIHE_GEN_INCLUDE`` to ``${TAIHE_GEN_DIR}/include``."""
        with cmake_target.indented(
            "set(TAIHE_GEN_INCLUDE",
            ")",
        ):
            cmake_target.writelns(
                "${TAIHE_GEN_DIR}/include",
            )

    def emit_generated_sources(
        self,
        cmake_target: CMakeWriter,
    ):
        """List registered C/C++ sources in ``TAIHE_GEN_{C,CXX,}_SRC``."""
        with cmake_target.indented(
            "set(TAIHE_GEN_C_SRC",
            ")",
        ):
            for file in self.get_files_by_kind(FileKind.C_SOURCE):
                cmake_target.writelns(
                    f"${{TAIHE_GEN_DIR}}/{file.relative_path}",
                )
        with cmake_target.indented(
            "set(TAIHE_GEN_CXX_SRC",
            ")",
        ):
            for file in self.get_files_by_kind(FileKind.CPP_SOURCE):
                cmake_target.writelns(
                    f"${{TAIHE_GEN_DIR}}/{file.relative_path}",
                )
        with cmake_target.indented(
            "set(TAIHE_GEN_SRC",
            ")",
        ):
            cmake_target.writelns(
                "${TAIHE_GEN_C_SRC}",
                "${TAIHE_GEN_CXX_SRC}",
            )

    def emit_generated_ets_files(
        self,
        cmake_target: CMakeWriter,
    ):
        """List registered ETS files in ``TAIHE_GEN_ETS_FILES``."""
        with cmake_target.indented(
            "set(TAIHE_GEN_ETS_FILES",
            ")",
        ):
            for file in self.get_files_by_kind(FileKind.ETS):
                cmake_target.writelns(
                    f"${{TAIHE_GEN_DIR}}/{file.relative_path}",
                )

    def emit_set_cpp_standard(
        self,
        cmake_target: CMakeWriter,
    ):
        """Compile generated and runtime C++ sources as C++17."""
        with cmake_target.indented(
            "set_source_files_properties(",
            ")",
        ):
            cmake_target.writelns(
                "${TAIHE_GEN_CXX_SRC}",
                "${TAIHE_RUNTIME_CXX_SRC_INNER}",
                # setting
                "PROPERTIES",
                "LANGUAGE CXX",
                'COMPILE_FLAGS "-std=c++17"',
            )