• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""A template to define composite ops."""
15
16# pylint: disable=g-direct-tensorflow-import
17
18import os
19import sys
20
21from absl import app
22from tensorflow.compiler.mlir.tfr.python.composite import Composite
23from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
24from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
25from tensorflow.python.platform import flags
26
FLAGS = flags.FLAGS

flags.DEFINE_string(
    'output', None,
    'Path to write the generated file: the register op .cc file when '
    '--gen_register_op is true, otherwise the TFR .mlir file.')

flags.DEFINE_bool(
    'gen_register_op', True,
    'If true, generate the register op .cc file; otherwise generate the TFR '
    '.mlir file.')

# There is no sensible default destination, so every invocation must pass it.
flags.mark_flag_as_required('output')
37
38
# Composite op 'TestRandom'. The decorator declares no inputs and a single
# output `o` of type T, where T is a derived attribute constrained to
# `numbertype`.
# The `_composite_` name prefix matches the prefix that main() passes to
# gen_register_op / tfr_gen_from_module. That is presumably how those
# generators discover this function in the module.
# NOTE(review): the body is deliberately empty. Presumably this op only
# exercises op registration and has no decomposition; confirm against the
# generator tests. Keep the body as `pass` (not a docstring), since the TFR
# generator presumably inspects this function's definition.
@Composite('TestRandom', derived_attrs=['T: numbertype'], outputs=['o: T'])
def _composite_random_op():
  pass
42
43
def main(_):
  """Generates register-op C++ code or TFR MLIR for this module's composites.

  --gen_register_op selects which generator runs over the `_composite_`-prefixed
  functions of this module. The result is written to --output, and the parent
  directory of --output is created if it does not already exist.

  Args:
    _: Positional command-line arguments left after flag parsing (unused).

  Raises:
    ValueError: If --output lacks the extension the selected generator expects
      ('.cc' for register-op code, '.mlir' for TFR).
  """
  module = sys.modules[__name__]
  # Validate with explicit raises: `assert` is stripped under `python -O`.
  if FLAGS.gen_register_op:
    if not FLAGS.output.endswith('.cc'):
      raise ValueError(
          '--output must end with ".cc" when --gen_register_op is true, '
          'got: %r' % FLAGS.output)
    generated_code = gen_register_op(module, '_composite_')
  else:
    if not FLAGS.output.endswith('.mlir'):
      raise ValueError(
          '--output must end with ".mlir" when --gen_register_op is false, '
          'got: %r' % FLAGS.output)
    generated_code = tfr_gen_from_module(module, '_composite_')

  dirname = os.path.dirname(FLAGS.output)
  # dirname is '' for a bare filename (current directory), and
  # os.makedirs('') would raise. exist_ok=True also removes the race of
  # checking os.path.exists() before creating the directory.
  if dirname:
    os.makedirs(dirname, exist_ok=True)
  with open(FLAGS.output, 'w') as f:
    f.write(generated_code)
57
58
if __name__ == '__main__':
  # absl parses the command-line flags, then invokes main with remaining argv.
  app.run(main)
61