• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2#
3# ===- Generate headers for libc functions  -------------------*- python -*--==#
4#
5# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6# See https://llvm.org/LICENSE.txt for license information.
7# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8#
9# ==-------------------------------------------------------------------------==#
10
11import yaml
12import argparse
13from pathlib import Path
14from header import HeaderFile
15from gpu_headers import GpuHeaderFile as GpuHeader
16from class_implementation.classes.macro import Macro
17from class_implementation.classes.type import Type
18from class_implementation.classes.function import Function
19from class_implementation.classes.enumeration import Enumeration
20from class_implementation.classes.object import Object
21
22
def yaml_to_classes(yaml_data, header_class, entry_points=None):
    """
    Build a header object from parsed YAML header data.

    Args:
        yaml_data: Mapping parsed from a header YAML specification.
        header_class: Callable invoked with the header name to create the header.
        entry_points: Optional list of function names; when non-empty, only
            functions whose name is listed are added to the header.

    Returns:
        The object created by header_class, populated with the macros, types,
        enumerations, functions and objects found in yaml_data.
    """
    header = header_class(yaml_data.get("header"))

    def make_function(spec, guard):
        # Single place where a YAML function entry becomes a Function object.
        arg_types = [arg["type"] for arg in spec["arguments"]]
        attributes = spec.get("attributes", None)
        standards = spec.get("standards", None)
        return Function(
            spec["return_type"],
            spec["name"],
            arg_types,
            standards,
            guard,
            attributes,
        )

    # Macros and enums keep their YAML order; types and objects are sorted by name.
    for macro in yaml_data.get("macros", []):
        header.add_macro(Macro(macro["macro_name"], macro["macro_value"]))

    for type_entry in sorted(
        yaml_data.get("types", []), key=lambda entry: entry["type_name"]
    ):
        header.add_type(Type(type_entry["type_name"]))

    for enum in yaml_data.get("enums", []):
        header.add_enumeration(Enumeration(enum["name"], enum.get("value", None)))

    selected = yaml_data.get("functions", [])
    if entry_points:
        wanted = set(entry_points)
        selected = [spec for spec in selected if spec["name"] in wanted]

    # Unguarded functions are added immediately (in name order); guarded ones
    # are collected per guard and emitted afterwards, guard by guard.
    by_guard = {}
    for spec in sorted(selected, key=lambda entry: entry["name"]):
        guard = spec.get("guard", None)
        if guard is None:
            header.add_function(make_function(spec, guard))
        else:
            by_guard.setdefault(guard, []).append(spec)

    for guard in sorted(by_guard):
        for spec in by_guard[guard]:
            header.add_function(make_function(spec, guard))

    for obj in sorted(
        yaml_data.get("objects", []), key=lambda entry: entry["object_name"]
    ):
        header.add_object(Object(obj["object_name"], obj["object_type"]))

    return header
106
107
def load_yaml_file(yaml_file, header_class, entry_points):
    """
    Parse a YAML header specification from disk and build its header object.

    Args:
        yaml_file: Path to the YAML file.
        header_class: The class used to create the header (HeaderFile or GpuHeader).
        entry_points: A list of specific function names to include in the header.

    Returns:
        The header object produced by yaml_to_classes for the file's contents.
    """
    with open(yaml_file, "r") as stream:
        parsed = yaml.safe_load(stream)
    return yaml_to_classes(parsed, header_class, entry_points)
123
124
def fill_public_api(header_str, h_def_content):
    """
    Substitute the generated header text into a .h.def template.

    Only the first %%public_api() placeholder is replaced; a template that
    does not contain the placeholder is returned unchanged.

    Args:
        header_str: The generated header string; surrounding whitespace is
            stripped before insertion.
        h_def_content: The content of the .h.def file.

    Returns:
        The template content with the public API filled in.
    """
    api_text = header_str.strip()
    before, placeholder, after = h_def_content.partition("%%public_api()")
    if not placeholder:
        # No placeholder present: nothing to substitute.
        return h_def_content
    return before + api_text + after
138
139
def parse_function_details(details):
    """
    Turn the six command-line function fields into a Function object.

    Args:
        details: A sequence of six strings: return type, name, arguments,
            standards, guard and attributes. Arguments, standards and
            attributes are comma-separated; the literal string "null" marks
            standards, guard or attributes as absent.

    Returns:
        Function: An instance of Function initialized with the details.
    """
    return_type, name, raw_arguments, raw_standards, raw_guard, raw_attributes = details

    def split_unless_null(text):
        # "null" means "no values"; anything else is a comma-separated list.
        return [] if text == "null" else text.split(",")

    standards = split_unless_null(raw_standards)
    # Only argument entries have surrounding whitespace removed.
    arguments = [piece.strip() for piece in raw_arguments.split(",")]
    attributes = split_unless_null(raw_attributes)
    guard = None if raw_guard == "null" else raw_guard

    return Function(
        return_type=return_type,
        name=name,
        arguments=arguments,
        standards=standards,
        guard=guard,
        attributes=attributes,
    )
163
164
def add_function_to_yaml(yaml_file, function_details):
    """
    Add a function to the YAML file.

    The new entry is inserted before the first existing function whose name
    compares greater than the new name, or appended when there is none. The
    file is then rewritten in place.

    Args:
        yaml_file: The path to the YAML file.
        function_details: A list containing function details (return_type, name, arguments, standards, guard, attributes).
    """
    new_function = parse_function_details(function_details)

    with open(yaml_file, "r") as f:
        # safe_load returns None for an empty document; fall back to an empty
        # mapping so the function can still be added.
        yaml_data = yaml.safe_load(f) or {}
    # Handles both a missing "functions" key and one whose value is null.
    if not yaml_data.get("functions"):
        yaml_data["functions"] = []
    functions = yaml_data["functions"]

    function_dict = {
        "name": new_function.name,
        "standards": new_function.standards,
        "return_type": new_function.return_type,
        "arguments": [{"type": arg} for arg in new_function.arguments],
    }

    # Optional keys are only written when they carry a value.
    if new_function.guard:
        function_dict["guard"] = new_function.guard

    if new_function.attributes:
        function_dict["attributes"] = new_function.attributes

    insert_index = next(
        (
            index
            for index, func in enumerate(functions)
            if func["name"] > new_function.name
        ),
        len(functions),
    )
    functions.insert(insert_index, function_dict)

    class IndentYamlListDumper(yaml.Dumper):
        # Always passes indentless=False to the base Dumper, regardless of the
        # value requested by the caller.
        def increase_indent(self, flow=False, indentless=False):
            return super().increase_indent(flow, False)

    with open(yaml_file, "w") as f:
        yaml.dump(
            yaml_data,
            f,
            Dumper=IndentYamlListDumper,
            default_flow_style=False,
            sort_keys=False,
        )

    print(f"Added function {new_function.name} to {yaml_file}")
217
218
def main():
    """
    Command-line entry point: generate a header file from a YAML specification.

    Optionally adds a function to the YAML file first (--add_function), then
    loads the YAML, renders the header and writes it out. When --output_dir
    names an existing directory the file is written there as <yaml stem>.h;
    any other --output_dir value is used as the output file path itself.
    Without --output_dir the file <yaml stem>.h is written to the current
    working directory.
    """
    parser = argparse.ArgumentParser(description="Generate header files from YAML")
    parser.add_argument(
        "yaml_file", help="Path to the YAML file containing header specification"
    )
    parser.add_argument(
        "--output_dir",
        help="Directory to output the generated header file",
    )
    parser.add_argument(
        "--h_def_file",
        # The flag is spelled --export-decls; the template is only consulted
        # when that flag is absent.
        help="Path to the .h.def template file (ignored with --export-decls)",
    )
    parser.add_argument(
        "--add_function",
        nargs=6,
        metavar=(
            "RETURN_TYPE",
            "NAME",
            "ARGUMENTS",
            "STANDARDS",
            "GUARD",
            "ATTRIBUTES",
        ),
        help="Add a function to the YAML file",
    )
    parser.add_argument(
        "--e", action="append", help="Entry point to include", dest="entry_points"
    )
    parser.add_argument(
        "--export-decls",
        action="store_true",
        help="Flag to use GpuHeader for exporting declarations",
    )
    args = parser.parse_args()

    if args.add_function:
        add_function_to_yaml(args.yaml_file, args.add_function)

    header_class = GpuHeader if args.export_decls else HeaderFile
    header = load_yaml_file(args.yaml_file, header_class, args.entry_points)

    header_str = str(header)

    default_name = f"{Path(args.yaml_file).stem}.h"
    if args.output_dir:
        output_file_path = Path(args.output_dir)
        if output_file_path.is_dir():
            output_file_path /= default_name
    else:
        output_file_path = Path(default_name)

    # Only the non-GPU path with a template goes through placeholder
    # substitution; otherwise the rendered header is written verbatim.
    if not args.export_decls and args.h_def_file:
        with open(args.h_def_file, "r") as f:
            content = fill_public_api(header_str, f.read())
    else:
        content = header_str

    with open(output_file_path, "w") as f:
        f.write(content)
279
280
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
283