• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""op_reg_gen: Generate op registration code from composite op code."""
16
17# pylint: disable=invalid-name
18# pylint: disable=missing-function-docstring
19# pylint: disable=g-direct-tensorflow-import
20
21import gast as ast
22
23from tensorflow.python.autograph.pyct import transformer
24from tensorflow.python.autograph.pyct import transpiler
25from tensorflow.python.framework import op_def_registry
26from tensorflow.python.util import tf_inspect
27
28_COMPOSITE_ARG_LIST = ['op_name', 'inputs', 'attrs', 'derived_attrs', 'outputs']
29
30
class OpRegGenImpl(transformer.CodeGenerator):
  """Visit the AST and generate C++ op registration functions.

  Only function definitions decorated with a `Composite(...)` call (written
  either as `Composite(...)` or as `<something>.Composite(...)`) produce
  output. For each such function whose op is not yet known to
  `op_def_registry`, a C++ `REGISTER_OP(...)` statement is emitted into the
  code buffer.
  """

  def __init__(self, ctx):
    super(OpRegGenImpl, self).__init__(ctx)
    self.ctx = ctx

  # The visitors below turn literal decorator arguments into plain Python
  # values: names -> identifier strings, constants -> their values,
  # lists -> Python lists, keywords -> (name, value) pairs.

  def visit_Name(self, node):
    return node.id

  def visit_Constant(self, node):
    return node.value

  def visit_keyword(self, node):
    return node.arg, self.visit(node.value)

  def visit_List(self, node):
    return [self.visit(cst) for cst in node.elts]

  def visit_arguments(self, node):
    # Collects only `node.args`; keyword-only parameters, `*args` and
    # `**kwargs` are not included.
    return [self.visit(arg) for arg in node.args]

  @staticmethod
  def _is_composite_callee(func):
    """Returns True if the call target `func` is `Composite` or `x.Composite`."""
    if isinstance(func, ast.Attribute):
      return func.attr == 'Composite'
    if isinstance(func, ast.Name):
      return func.id == 'Composite'
    return False

  @staticmethod
  def _is_none_literal(node):
    """Returns True if the AST `node` is the literal `None`."""
    return isinstance(node, ast.Constant) and node.value is None

  def _find_composite_decorator(self, node):
    """Returns the `Composite(...)` decorator call on `node`, or None.

    Raises:
      KeyError: if `node` carries more than one `Composite(...)` decorator.
    """
    compose_dec = [
        dec for dec in node.decorator_list
        if isinstance(dec, ast.Call) and self._is_composite_callee(dec.func)
    ]
    if len(compose_dec) > 1:
      raise KeyError('More than one TF ops decomposes for.')
    return compose_dec[0] if compose_dec else None

  def _collect_decorator_args(self, dec):
    """Maps the arguments of a `Composite(...)` call to their Python values.

    Positional arguments are named, in order, after `_COMPOSITE_ARG_LIST`.
    An argument written as the literal `None` is treated as if it had been
    omitted, so it neither crashes the list handling in `visit_FunctionDef`
    nor counts as an op-def specification.

    Args:
      dec: the `Composite(...)` call node.

    Returns:
      A dict from argument name to its visited value.

    Raises:
      KeyError: if there are more positional arguments than known names, if
        an argument is passed both positionally and by keyword, or if no
        `op_name` is given.
    """
    if len(dec.args) > len(_COMPOSITE_ARG_LIST):
      # zip() below would otherwise silently drop the extra arguments.
      raise KeyError('More arguments than expected.')
    positional = list(zip(_COMPOSITE_ARG_LIST, dec.args))
    keywords = [(kw.arg, kw.value) for kw in dec.keywords]

    positional_names = {name for name, _ in positional}
    if positional_names & {name for name, _ in keywords}:
      raise KeyError('More arguments than expected.')

    all_dec_args = {
        name: self.visit(value)
        for name, value in positional + keywords
        if not self._is_none_literal(value)
    }
    if 'op_name' not in all_dec_args:
      raise KeyError('Composite decorator requires an op_name argument.')
    return all_dec_args

  @staticmethod
  def _build_registration(op_name, inputs, attrs, derived_attrs, outputs):
    """Returns the C++ `REGISTER_OP` statement for the given op signature.

    Double quotes inside attribute specs are replaced by single quotes so the
    spec can be embedded in a C++ string literal.
    """
    cxx_reg_code = ['\nREGISTER_OP("{}")'.format(op_name)]
    cxx_reg_code.extend('.Input("{}")'.format(input_) for input_ in inputs)
    cxx_reg_code.extend(
        '.Attr("{}")'.format(attr.replace('"', "'")) for attr in attrs)
    cxx_reg_code.extend(
        '.Attr("{}")'.format(attr.replace('"', "'")) for attr in derived_attrs)
    cxx_reg_code.extend('.Output("{}")'.format(output_) for output_ in outputs)
    cxx_reg_code[-1] += ';\n'
    return '\n    '.join(cxx_reg_code)

  def visit_FunctionDef(self, node):
    """Emits a `REGISTER_OP` statement for a `Composite`-decorated function.

    Raises:
      KeyError: on malformed decorator arguments, or when the number of
        function parameters differs from the declared inputs plus attrs.
      ValueError: if the op is already registered but op-def arguments were
        given anyway.
    """
    # TODO(fengliuai): create one utility method to match different apis and
    # shared it with the tfr_gen.py module.
    compose_dec = self._find_composite_decorator(node)
    if compose_dec is None:
      # skip a non-composition function
      return

    all_dec_args = self._collect_decorator_args(compose_dec)
    op_name = all_dec_args['op_name']
    if op_def_registry.get(op_name):
      if len(all_dec_args) > 1:
        # Op has been registered, so it is a user error to specify op def.
        raise ValueError('Op has been registered: ' + op_name)
      # Op has been registered, then we don't need to generate register code.
      return

    # Validates the function inputs match what are in the decorator. Only the
    # number of parameters is compared, not their names.
    inputs = all_dec_args.get('inputs', [])
    attrs = all_dec_args.get('attrs', [])
    expected_args = [arg.split(':')[0] for arg in inputs + attrs]
    all_func_args = self.visit(node.args)
    if len(expected_args) != len(all_func_args):
      raise KeyError(
          'Composition arguments for {} do not match the registration. {} vs {}'
          .format(op_name, expected_args, all_func_args))

    self.emit(self._build_registration(
        op_name, inputs, attrs,
        all_dec_args.get('derived_attrs', []),
        all_dec_args.get('outputs', [])))
115
116
class OpRegGen(transpiler.GenericTranspiler):
  """Transpiles Python functions into C++ op registration code."""

  def transform_ast(self, node, ctx):
    """Runs an `OpRegGenImpl` over `node` and returns the code it emitted."""
    generator = OpRegGenImpl(ctx)
    generator.visit(node)
    emitted = generator.code_buffer
    return emitted
124
125
def op_reg_gen(func):
  """Generates the C++ op registration code for a single Python function."""
  generator = OpRegGen()
  registration, _ = generator.transform(func, None)
  return registration
130
131
def gen_register_op(source, method_prefix=None):
  """Builds a C++ source file registering the ops composed in `source`.

  Args:
    source: object whose function members (per `tf_inspect.getmembers`) are
      scanned for composite op definitions.
    method_prefix: when truthy, only functions whose names start with this
      prefix are processed; otherwise every function member is processed.

  Returns:
    The C++ code: the `op.h` include, the generated registrations, all
    wrapped in `namespace tensorflow`.
  """
  registrations = []
  for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction):
    if method_prefix and not name.startswith(method_prefix):
      continue
    registrations.append(op_reg_gen(func))

  preamble = ('\n'
              '#include "tensorflow/core/framework/op.h"\n'
              '\n'
              'namespace tensorflow {\n'
              '  ')
  body = '\n'.join(registrations)
  return preamble + body + '}  // namespace tensorflow\n'
146