• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2# Copyright 2022 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import codecs
9import math
10import os
11import re
12import sys
13import yaml
14
15sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
16import xngen
17import xnncommon
18
# Command-line interface: a YAML spec goes in, and a C header declaring
# `enum xnn_<enum>_type` comes out.
parser = argparse.ArgumentParser(
    description="Generates an xnn_<enum>_type enum header from a YAML spec.")
parser.add_argument(
    "-s",
    "--spec",
    metavar="FILE",
    required=True,
    help="Specification (YAML) file")
parser.add_argument(
    "-o",
    "--output",
    metavar="FILE",
    required=True,
    help="Output (C source) file")
# The value is an identifier fragment, not a path: NAME produces
# `enum xnn_NAME_type`.
parser.add_argument(
    "-e",
    "--enum",
    metavar="NAME",
    required=True,
    help="Enum to generate (NAME yields enum xnn_NAME_type)")
# `defines` is not read anywhere in this file. It is kept so that the parsed
# namespace keeps the same attributes as before.
parser.set_defaults(defines=list())
40
41
def _enumerator_names(spec_yaml, spec_path):
  """Extracts and validates the enumerator names from a parsed spec.

  Args:
    spec_yaml: The object returned by `yaml.safe_load` for the spec file.
    spec_path: Path of the spec file, used only in error messages.

  Returns:
    A non-empty list of enumerator names, in spec order.

  Raises:
    ValueError: If the spec is not a non-empty list of mappings, or an entry
      lacks a "name" that is a valid C identifier unique within the spec.
  """
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of operators in the spec")
  if not spec_yaml:
    # The C grammar requires at least one enumerator inside the braces.
    raise ValueError(
        "expected at least one operator in the spec {}".format(spec_path))

  names = []
  seen = set()
  for index, entry in enumerate(spec_yaml):
    if not isinstance(entry, dict) or "name" not in entry:
      raise ValueError(
          "spec entry #{} in {} must be a mapping with a 'name' key".format(
              index, spec_path))
    name = entry["name"]
    # Each name is pasted verbatim into C source, so reject anything that
    # would not compile as an enumerator.
    if not isinstance(name, str) or not re.fullmatch(
        r"[A-Za-z_][A-Za-z0-9_]*", name):
      raise ValueError(
          "spec entry #{} in {} has name {!r}, which is not a valid C "
          "identifier".format(index, spec_path, name))
    if name in seen:
      raise ValueError(
          "duplicate enumerator {!r} in {}".format(name, spec_path))
    seen.add(name)
    names.append(name)
  return names


def _render_header(names, enum, specification, generator):
  """Returns the text of the generated C header.

  Args:
    names: Non-empty list of enumerator names, in declaration order.
    enum: Fragment substituted into `enum xnn_<enum>_type`.
    specification: Spec path recorded in the "Specification:" comment.
    generator: Generator path recorded in the "Generator:" comment.
  """
  header = """\
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}

#pragma once

enum xnn_{enum}_type {{
""".format(specification=specification, generator=generator, enum=enum)
  # The first enumerator is pinned to 0; the rest follow implicitly.
  body = ["  " + names[0] + " = 0,"]
  body.extend("  " + name + "," for name in names[1:])
  # End with a newline: C requires non-empty source files to end in one.
  return header + "\n".join(body) + "\n};\n"


def _write_if_changed(path, text):
  """Writes `text` to `path` unless the file already contains exactly `text`.

  Skipping an identical write leaves the existing file and its modification
  time untouched, presumably to avoid needless rebuilds of dependents.
  """
  # newline="" disables newline translation on both read and write. This
  # keeps the comparison and the output byte-exact, as codecs.open did.
  try:
    with open(path, "r", encoding="utf-8", newline="") as existing:
      if existing.read() == text:
        return
  except FileNotFoundError:
    pass  # No previous output: fall through and create it.
  with open(path, "w", encoding="utf-8", newline="") as output_file:
    output_file.write(text)


def main(args):
  """Generates a C header declaring `enum xnn_<enum>_type` from a YAML spec.

  The spec must be a non-empty YAML list of mappings, each with a "name" key.
  Each name becomes one enumerator, in spec order, and the first is set to 0.
  The output file is only rewritten when its contents would change.

  Args:
    args: Command-line arguments (excluding the program name), parsed with
      the module-level `parser`.

  Raises:
    ValueError: If the spec is not a non-empty list of mappings with unique,
      C-identifier "name" values.
  """
  options = parser.parse_args(args)

  with open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)

  names = _enumerator_names(spec_yaml, options.spec)
  output = _render_header(names, options.enum, options.spec, sys.argv[0])
  _write_if_changed(options.output, output)
81
82
if __name__ == "__main__":
  # Strip argv[0] (the script path); main() parses only the real options.
  cli_args = sys.argv[1:]
  main(cli_args)
85