# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""op_reg_gen: Generate op registration code from composite op code."""

# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
# pylint: disable=g-direct-tensorflow-import

import gast as ast

from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.framework import op_def_registry
from tensorflow.python.util import tf_inspect

# Names assigned, in order, to the positional arguments of a `Composite(...)`
# decorator call.
_COMPOSITE_ARG_LIST = ['op_name', 'inputs', 'attrs', 'derived_attrs', 'outputs']


def _is_composite_decorator(dec):
  """Returns True if `dec` is a call to `Composite` or `<something>.Composite`.

  Args:
    dec: A decorator AST node taken from a `FunctionDef.decorator_list`.
  """
  if not isinstance(dec, ast.Call):
    return False
  func = dec.func
  if isinstance(func, ast.Attribute):
    return func.attr == 'Composite'
  if isinstance(func, ast.Name):
    return func.id == 'Composite'
  return False


class OpRegGenImpl(transformer.CodeGenerator):
  """Visit the AST and generate C++ op registration functions."""

  def __init__(self, ctx):
    super().__init__(ctx)
    self.ctx = ctx

  # The visitors below turn the literal decorator arguments (names, constants,
  # keywords and lists) into plain Python values.

  def visit_Name(self, node):
    return node.id

  def visit_Constant(self, node):
    return node.value

  def visit_keyword(self, node):
    return node.arg, self.visit(node.value)

  def visit_List(self, node):
    return [self.visit(cst) for cst in node.elts]

  def visit_arguments(self, node):
    # Only the positional parameters of the function are collected.
    return [self.visit(arg) for arg in node.args]

  def visit_FunctionDef(self, node):
    """Emits a C++ REGISTER_OP block for a `Composite`-decorated function.

    Nothing is emitted for functions without a `Composite` decorator, or when
    the op named by the decorator is already in the op registry and the
    decorator gives only `op_name`.

    Args:
      node: The `FunctionDef` AST node to inspect.

    Raises:
      KeyError: If the function has more than one `Composite` decorator, a
        decorator argument is given both positionally and by keyword,
        `op_name` is missing, or the number of positional function parameters
        differs from the number of decorator inputs plus attrs.
      ValueError: If the op is already registered but the decorator also
        specifies op definition fields.
    """
    # TODO(fengliuai): create one utility method to match different apis and
    # shared it with the tfr_gen.py module.
    compose_dec = [
        dec for dec in node.decorator_list if _is_composite_decorator(dec)
    ]

    if not compose_dec:
      # Skip a function that is not a composition.
      return
    if len(compose_dec) > 1:
      raise KeyError(
          'Function {} has more than one Composite decorator.'.format(
              node.name))

    dec = compose_dec[0]
    # Positional decorator arguments are named by their position; any beyond
    # the known parameter list are ignored by zip().
    all_dec_args = {
        arg_name: self.visit(arg_value)
        for arg_name, arg_value in zip(_COMPOSITE_ARG_LIST, dec.args)
    }
    kw_dec_args = dict(self.visit(kw) for kw in dec.keywords)

    duplicated = all_dec_args.keys() & kw_dec_args.keys()
    if duplicated:
      raise KeyError(
          'Arguments specified both positionally and by keyword: {}'.format(
              sorted(duplicated)))

    all_dec_args.update(kw_dec_args)

    if 'op_name' not in all_dec_args:
      raise KeyError(
          'The Composite decorator of {} does not specify op_name.'.format(
              node.name))
    op_name = all_dec_args['op_name']

    op_def = op_def_registry.get(op_name)
    if op_def:
      if len(all_dec_args) > 1:
        # Op has been registered, so it is a user error to specify op def.
        raise ValueError('Op has been registered: ' + op_name)
      # Op has been registered, then we don't need to generate register code.
      return

    # A list argument written explicitly as None is treated like an omitted
    # one, so that the list concatenation and iteration below work.
    inputs = all_dec_args.get('inputs') or []
    attrs = all_dec_args.get('attrs') or []
    derived_attrs = all_dec_args.get('derived_attrs') or []
    outputs = all_dec_args.get('outputs') or []

    # Validates the function inputs match what are in the decorator. Only the
    # counts are compared; the names before ':' are not checked against the
    # function's parameter names.
    expected_args = [arg.split(':')[0] for arg in inputs + attrs]
    all_func_args = self.visit(node.args)

    if len(expected_args) != len(all_func_args):
      raise KeyError(
          'Composition arguments for {} do not match the registration. {} vs {}'
          .format(op_name, expected_args, all_func_args))

    cxx_reg_code = ['\nREGISTER_OP("{}")'.format(op_name)]
    for input_ in inputs:
      cxx_reg_code.append('.Input("{}")'.format(input_))
    # Each attr spec is wrapped in a C++ double-quoted string literal, so any
    # double quotes inside it become single quotes.
    for attr in attrs:
      cxx_reg_code.append('.Attr("{}")'.format(attr.replace('"', "'")))
    for attr in derived_attrs:
      cxx_reg_code.append('.Attr("{}")'.format(attr.replace('"', "'")))
    for output_ in outputs:
      cxx_reg_code.append('.Output("{}")'.format(output_))
    cxx_reg_code[-1] += ';\n'
    self.emit('\n  '.join(cxx_reg_code))


class OpRegGen(transpiler.GenericTranspiler):
  """Transforms Python composite functions into C++ op registration code."""

  def transform_ast(self, node, ctx):
    gen = OpRegGenImpl(ctx)
    gen.visit(node)
    return gen.code_buffer


def op_reg_gen(func):
  """Returns the C++ REGISTER_OP code generated for the function `func`.

  Args:
    func: A Python function, presumably decorated with `Composite`.

  Returns:
    The code buffer produced by `OpRegGenImpl` for `func`. It is expected to
    be empty when nothing was emitted (TODO confirm the code_buffer default).
  """
  op_reg_code, _ = OpRegGen().transform(func, None)
  return op_reg_code


def gen_register_op(source, method_prefix=None):
  """Generates a C++ op registration source for the functions in `source`.

  Args:
    source: An object (for example a module) whose function members are
      inspected with `tf_inspect.getmembers`.
    method_prefix: If set, only functions whose names start with this prefix
      are processed.

  Returns:
    A C++ source string that includes op.h and holds the generated
    REGISTER_OP code inside `namespace tensorflow`.
  """
  reg_codes = [
      op_reg_gen(func)
      for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction)
      if not method_prefix or name.startswith(method_prefix)
  ]
  headers = r"""
#include "tensorflow/core/framework/op.h"

namespace tensorflow {
  """
  code = '\n'.join(reg_codes)
  return headers + code + '}  // namespace tensorflow\n'