#!/usr/bin/env python
# Copyright 2022 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import codecs
import math
import os
import re
import sys
import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import xngen
import xnncommon

parser = argparse.ArgumentParser(
    description="Generates xnn_operator_type enum.")
parser.add_argument(
    "-s",
    "--spec",
    metavar="FILE",
    required=True,
    help="Specification (YAML) file")
parser.add_argument(
    "-o",
    "--output",
    metavar="FILE",
    required=True,
    help="Output (C source) file")
# The value is interpolated into the enum's type name (xnn_<NAME>_type); it is
# an identifier fragment, not a path, so it is advertised as NAME in --help.
parser.add_argument(
    "-e",
    "--enum",
    metavar="NAME",
    required=True,
    help="Enum to generate")
parser.set_defaults(defines=list())


# Preamble of the generated header, up to and including the line that opens
# the enum. Doubled braces are str.format escapes for a literal "{".
# NOTE(review): whitespace inside the string literals of this script is
# reproduced as it appeared in the reviewed copy; confirm it against a
# checked-in generated header before regenerating files.
_HEADER_TEMPLATE = """\
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
// Specification: {specification}
// Generator: {generator}

#pragma once

enum xnn_{enum}_type {{
"""


def _enumerator_names(spec_yaml):
  """Validates the parsed spec and returns its enumerator names in order.

  Args:
    spec_yaml: The object produced by yaml.safe_load for the spec file.

  Returns:
    A list with the "name" value of every spec entry, in spec order.

  Raises:
    ValueError: If the spec is not a list, is an empty list, or contains an
      entry that is not a mapping with a "name" key.
  """
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of operators in the spec")
  if not spec_yaml:
    # An empty list would otherwise fail later with a bare IndexError.
    raise ValueError("expected at least one operator in the spec")

  names = []
  for index, entry in enumerate(spec_yaml):
    if not isinstance(entry, dict) or "name" not in entry:
      raise ValueError(
          "spec entry #%d is not a mapping with a \"name\" key" % index)
    names.append(entry["name"])
  return names


def _render_enum(names, specification, generator, enum):
  """Builds the full text of the generated header.

  Args:
    names: Non-empty list of enumerator names, in declaration order.
    specification: Spec path recorded in the header comment.
    generator: Generator path recorded in the header comment.
    enum: Name fragment used to form the enum type xnn_<enum>_type.

  Returns:
    The header text. The first enumerator is pinned to 0 explicitly; the
    text ends with the closing "};" and no trailing newline.
  """
  pieces = [
      _HEADER_TEMPLATE.format(
          specification=specification, generator=generator, enum=enum)
  ]
  for index, name in enumerate(names):
    terminator = " = 0,\n" if index == 0 else ",\n"
    pieces.append(" " + name + terminator)
  pieces.append("};")
  return "".join(pieces)


def _write_if_changed(path, text):
  """Writes text to path unless the file already holds exactly that text.

  Leaving an up-to-date file untouched means it is not rewritten on every
  run of the generator.

  Args:
    path: Destination file path.
    text: Desired file content, encoded as UTF-8 on write.
  """
  if os.path.exists(path):
    with codecs.open(path, "r", encoding="utf-8") as existing_file:
      if existing_file.read() == text:
        return

  with codecs.open(path, "w", encoding="utf-8") as output_file:
    output_file.write(text)


def main(args):
  """Generates the C enum header described by a YAML specification.

  Args:
    args: Command-line arguments, excluding the program name.

  Raises:
    ValueError: If the specification is not a non-empty list of mappings
      that each carry a "name" key.
  """
  options = parser.parse_args(args)

  # Only the parse needs the spec file open; close it before generating.
  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)

  names = _enumerator_names(spec_yaml)
  output = _render_enum(
      names,
      specification=options.spec,
      generator=sys.argv[0],
      enum=options.enum)
  _write_if_changed(options.output, output)


if __name__ == "__main__":
  main(sys.argv[1:])