# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A template to define composite ops.

Running this script passes this module, together with the ``_composite_``
name prefix, either to ``gen_register_op`` (producing a ``.cc`` register-op
file) or to ``tfr_gen_from_module`` (producing a TFR ``.mlir`` file), and
writes the result to the path given by ``--output``.
"""

# pylint: disable=g-direct-tensorflow-import

import os
import sys

from absl import app
from tensorflow.compiler.mlir.tfr.python.composite import Composite
from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
from tensorflow.python.platform import flags

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'output', None,
    'Path to write the generated register op file and MLIR file.')

flags.DEFINE_bool('gen_register_op', True,
                  'Generate register op cc file or tfr mlir file.')

flags.mark_flag_as_required('output')


# Example composite op definition. Its name carries the `_composite_` prefix
# that `main` passes to the generators. The body is intentionally left as-is
# because the TFR generator consumes this function's source.
@Composite('TestRandom', derived_attrs=['T: numbertype'], outputs=['o: T'])
def _composite_random_op():
  pass


def main(_):
  """Generates the register-op or TFR MLIR file for this module.

  With ``--gen_register_op`` (the default) the output is produced by
  ``gen_register_op`` and must end in ``.cc``; otherwise it is produced by
  ``tfr_gen_from_module`` and must end in ``.mlir``. Missing parent
  directories of ``--output`` are created.

  Raises:
    app.UsageError: If ``--output`` does not end with the extension required
      by the selected mode.
  """
  module = sys.modules[__name__]
  if FLAGS.gen_register_op:
    expected_suffix = '.cc'
    generator = gen_register_op
  else:
    expected_suffix = '.mlir'
    generator = tfr_gen_from_module

  # Explicit check instead of `assert`: asserts are stripped under
  # `python -O`, which would silently accept a mismatched extension.
  if not FLAGS.output.endswith(expected_suffix):
    raise app.UsageError(
        '--output must end with %r when --gen_register_op=%s, got %r' %
        (expected_suffix, FLAGS.gen_register_op, FLAGS.output))

  generated_code = generator(module, '_composite_')

  # dirname is '' when --output is a bare filename (current directory);
  # os.makedirs('') would raise, so only create a non-empty parent.
  # exist_ok avoids the check-then-create race of os.path.exists.
  dirname = os.path.dirname(FLAGS.output)
  if dirname:
    os.makedirs(dirname, exist_ok=True)
  with open(FLAGS.output, 'w', encoding='utf-8') as f:
    f.write(generated_code)


if __name__ == '__main__':
  app.run(main=main)