# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tfr_gen: Generate mlir tfr decomposition function from python code."""

# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import os
import re
import types
import gast as ast

from tensorflow.compiler.mlir.tfr import tfr_wrapper as tfr
from tensorflow.core.framework import types_pb2
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
from tensorflow.python.autograph.pyct.static_analysis import type_inference
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import load_library
from tensorflow.python.framework import op_def_registry
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect

# TODO(mdan): Use class definitions so that we can mix these with Python types.


class TFRTypes(enum.Enum):
  """All the supported types.

    1-3: tfr types
    4-99: mlir built-in types
    100-199: TF related translator internal types
    200- : Python related translator internal types
  """
  TENSOR = 1
  TENSOR_LIST = 2
  ATTR = 3
  NONE = 4
  SHAPE = 5  # shape -> !shape.shape
  I1 = 21
  I8 = 22
  I16 = 23
  I32 = 24
  I64 = 25
  F32 = 26
  INDEX = 27
  AG_UNDEFINED_VAL = 100
  AG_BUILTIN_FUNC = 101
  TF_RAW_OP = 102
  TF_REGION = 103
  TF_TENSOR_SHAPE_FUNC = 104  # shape.as_list
  TF_TENSOR_SHAPE_LIST = 105  # shape.as_list()
  PY_BUILTIN_FUNC = 200
  TFR_BUILTIN_FUNC = 201

  # As these are not real types, __getattribute__ helps them appear more like
  # actual types (i.e. class definitions).
  def __getattribute__(self, name):
    if name == 'shape' and object.__getattribute__(self, 'value') == 1:
      return TFRTypes.SHAPE
    if name == 'as_list' and object.__getattribute__(self, 'value') == 5:
      return TFRTypes.TF_TENSOR_SHAPE_FUNC
    return object.__getattribute__(self, name)

  def __str__(self):
    if self.value < 4:  # pylint: disable=comparison-with-callable
      return '!tfr.' + self.name.lower()
    elif self.value < 10:  # pylint: disable=comparison-with-callable
      return '!shape.' + self.name.lower()
    else:
      return self.name.lower()
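
  # Illustrative (not executed): how these enum members print as MLIR type
  # spellings via __str__:
  #   str(TFRTypes.TENSOR) -> '!tfr.tensor'
  #   str(TFRTypes.SHAPE)  -> '!shape.shape'
  #   str(TFRTypes.I64)    -> 'i64'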


_attribute_types = [
    TFRTypes.I1, TFRTypes.I32, TFRTypes.I64, TFRTypes.F32, TFRTypes.INDEX,
    TFRTypes.ATTR
]


def _get_type_from_proto(arg_def=None, attr_def=None):
  if not arg_def:
    if attr_def.type == 'bool':
      return TFRTypes.I1
    elif attr_def.type == 'int32':
      return TFRTypes.I32
    elif attr_def.type == 'int' or attr_def.type == 'int64':
      return TFRTypes.I64
    elif attr_def.type == 'float':
      return TFRTypes.F32
    else:
      return TFRTypes.ATTR

  if arg_def.number_attr or arg_def.type_list_attr:
    return TFRTypes.TENSOR_LIST
  else:
    return TFRTypes.TENSOR
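
# For example (hypothetical protos): an AttrDef with type 'int' maps to
# TFRTypes.I64; an ArgDef with number_attr set (a variadic input) maps to
# TFRTypes.TENSOR_LIST; a plain ArgDef maps to TFRTypes.TENSOR.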


def _get_type_info_from_proto(arg_def=None, attr_def=None):
  attr_type = _get_type_from_proto(arg_def, attr_def)
  if not arg_def:
    return '{}{{tfr.name="{}",tfr.type="{}"}}'.format(
        attr_type, attr_def.name, attr_def.type)
  else:
    attr_names = []
    if arg_def.number_attr:
      attr_names.append(arg_def.number_attr)
    if arg_def.type_attr:
      attr_names.append(arg_def.type_attr)
    if arg_def.type_list_attr:
      attr_names.append(arg_def.type_list_attr)
    # TODO(fengliuai): currently we don't support backward type inference, so
    # we have to store these non-derivable types in the signatures, and then
    # they can be used to cast the values when raising to tf ops.
    if arg_def.type == types_pb2.DT_FLOAT:
      attr_names.append('f32_')
    elif arg_def.type == types_pb2.DT_INT32:
      attr_names.append('i32_')
    elif arg_def.type == types_pb2.DT_INT64:
      attr_names.append('i64_')
    elif arg_def.type == types_pb2.DT_BOOL:
      attr_names.append('i1_')

    if not attr_names:
      return str(attr_type)
    else:
      return '{}<{}>'.format(attr_type, ','.join(attr_names))
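
# For example (hypothetical protos): an input ArgDef with type_attr 'T' yields
# '!tfr.tensor<T>', while an AttrDef named 'axis' of type 'int' yields
# 'i64{tfr.name="axis",tfr.type="int"}'.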


def _get_val_from_proto(attr_type, attr_val):
  if attr_type == TFRTypes.I1:
    return 'true' if attr_val.b else 'false'
  elif attr_type == TFRTypes.I32 or attr_type == TFRTypes.I64:
    return attr_val.i
  elif attr_type == TFRTypes.F32:
    return attr_val.f
  elif attr_type == TFRTypes.ATTR:
    # string
    if attr_val.HasField('s'):
      return '"{}"'.format(attr_val.s.decode())
    # type
    if attr_val.HasField('type'):
      if attr_val.type == types_pb2.DT_FLOAT:
        return 'f32'
      elif attr_val.type == types_pb2.DT_INT32:
        return 'i32'
      elif attr_val.type == types_pb2.DT_INT64:
        return 'i64'
      elif attr_val.type == types_pb2.DT_BOOL:
        return 'i1'
    # list
    if attr_val.HasField('list'):
      if attr_val.list.f:
        elt_ty = TFRTypes.F32
        values = attr_val.list.f
      elif attr_val.list.i:
        elt_ty = TFRTypes.I64
        values = attr_val.list.i
      else:
        elt_ty = TFRTypes.NONE
        values = []
      array_attr_elts = ['{}:{}'.format(val, elt_ty) for val in values]
      return '[{}]'.format(','.join(array_attr_elts))
  raise NotImplementedError(
      'Proto AttrValue not recognized. type: {}, value: {}'.format(
          attr_type, attr_val))


def _collect_derived_attrs_from_proto(op_def):
  derived_attrs = set()
  for arg in op_def.input_arg:
    if arg.type_attr:
      derived_attrs.add(arg.type_attr)
    if arg.number_attr:
      derived_attrs.add(arg.number_attr)
    if arg.type_list_attr:
      derived_attrs.add(arg.type_list_attr)

    # TODO(fengliuai): currently we don't support backward type inference, so
    # we have to store these non-derivable types in the signatures, and then
    # they can be used to cast the values when raising to tf ops.
    if arg.type == types_pb2.DT_FLOAT:
      derived_attrs.add('f32_')
    elif arg.type == types_pb2.DT_INT32:
      derived_attrs.add('i32_')
    elif arg.type == types_pb2.DT_INT64:
      derived_attrs.add('i64_')
    elif arg.type == types_pb2.DT_BOOL:
      derived_attrs.add('i1_')
  return derived_attrs


def _require_tensor_list(arg_def):
  return arg_def.type_list_attr or arg_def.number_attr


def _camel_to_snake(name):
  s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
  return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
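
# For example:
#   _camel_to_snake('MyAddN') -> 'my_add_n'
#   _camel_to_snake('Conv2D') -> 'conv2_d'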


class OpDefCache(object):
  """A Dict to cache the OpDef for the Python function name."""

  def __init__(self):
    self._op_defs = {}

  def lookup(self, f_name, func_def=None, optional=False):
    if f_name in self._op_defs:
      return self._op_defs[f_name]

    if isinstance(func_def, types.FunctionType):
      if not hasattr(func_def, '_tfr_op_name'):
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      op_name = getattr(func_def, '_tfr_op_name')
    elif not func_def:
      op_name = f_name
    else:
      # TODO(fengliuai): create one utility method to match different APIs.
      compose_dec = []
      for dec in func_def.decorator_list:
        if isinstance(dec, ast.Call):
          if isinstance(dec.func,
                        ast.Attribute) and dec.func.attr == 'Composite':
            compose_dec.append(dec)
          if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite':
            compose_dec.append(dec)

      if not compose_dec:
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      elif len(compose_dec) > 1:
        raise KeyError('More than one TF op is decomposed by: ' + f_name)
      else:
        op_name = compose_dec[0].args[0].value

    op_def = op_def_registry.get(op_name)
    if not op_def:
      raise ValueError('Not a registered op: ' + op_name)
    derived_attrs = _collect_derived_attrs_from_proto(op_def)
    self._op_defs[f_name] = (op_def, derived_attrs)
    return (op_def, derived_attrs)

  def mlir_external_funcs(self):
    tfr_funcs = set()
    for _, (op_def, derived_attrs) in sorted(self._op_defs.items()):
      tfr_func = '\ntfr.func @tf__{}_('.format(_camel_to_snake(op_def.name))

      # tensor inputs
      inputs = [
          _get_type_info_from_proto(arg_def) for arg_def in op_def.input_arg
      ]

      # Attribute inputs. Attributes with default values are moved to the end.
      non_derived_attrs = [
          attr for attr in op_def.attr if attr.name not in derived_attrs
      ]
      attrs_no_default = [
          attr for attr in non_derived_attrs
          if not attr.HasField('default_value')
      ]
      attrs_with_default = [
          attr for attr in non_derived_attrs if attr.HasField('default_value')
      ]
      attr_names = {'f32_', 'i32_', 'i64_', 'i1_'}  # reserved
      for attr_def in attrs_no_default + attrs_with_default:
        inputs.append(_get_type_info_from_proto(None, attr_def))
        attr_names.add(attr_def.name)

      # tensor outputs
      outputs = [
          _get_type_info_from_proto(arg_def) for arg_def in op_def.output_arg
      ]

      inputs = ','.join(inputs)
      outputs = ','.join(outputs)
      attrs = ','.join(sorted(derived_attrs.union(attr_names)))
      tfr_funcs.add('{}{}) -> ({}) attributes {{{}}}'.format(
          tfr_func, inputs, outputs, attrs))
    return sorted(list(tfr_funcs))
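
# Typical usage (a sketch; 'Pack' is one example of a registered TF op):
#   op_defs = OpDefCache()
#   op_def, derived_attrs = op_defs.lookup('Pack')
#   print(op_defs.mlir_external_funcs())  # external tfr.func declarations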


_PY_TYPE_TO_TFR = {
    bool: TFRTypes.I1,
    int: TFRTypes.I64,
    float: TFRTypes.F32,
}

_TF_DTYPE_TO_TFR = {
    'bool': TFRTypes.I1,
    'int64': TFRTypes.I64,
    'int32': TFRTypes.I32,
    'int16': TFRTypes.I16,
    'int8': TFRTypes.I8,
    'float32': TFRTypes.F32,
}

_AG_FIXED_RETURN_TYPE = {
    'for_stmt': type(None),
    'if_stmt': type(None),
    'Undefined': TFRTypes.AG_UNDEFINED_VAL,
}

QN = qual_names.QN

# TODO(mdan): Fix this with an importable module.
AG_MODULE = api._TRANSPILER.get_extra_locals()['ag__']  # pylint:disable=protected-access

# When an item is callable, the signature is (*operand_types) -> result_type(s)
TFR_BUILTINS = {
    '_tfr_quant_act_range': (TFRTypes.TENSOR, TFRTypes.TENSOR),
    '_tfr_quant_rescale': TFRTypes.TENSOR,
    '_tfr_quant_raw_data': lambda input_type: input_type,
    '_tfr_quant_qparam': (TFRTypes.TENSOR, TFRTypes.TENSOR),
    '_tfr_quant_scale_factor': TFRTypes.TENSOR,
}


class TFRTypeResolver(type_inference.Resolver):
  """Resolve types for the external names, calls and arguments."""

  def __init__(self, op_defs):
    super(TFRTypeResolver, self).__init__()
    self._op_defs = op_defs

    # This pattern matching mechanism works with the functional form generated
    # by autograph:
    #
    #   for i in data:
    #     print(i)
    #
    # generates:
    #
    #   def loop_body(itr):
    #     i = itr
    #     print(i)
    #   ag__.for_stmt(target)
    #
    # The mechanism lets us infer the type of the itr argument based on that of
    # target.
    self._for_loop_target_types = {}  # Maps body function name to iterated.
    self._for_loop_body_fns = {}  # Used only to avoid collisions.

  def res_name(self, ns, types_ns, name):
    name_str = str(name)
    if name_str in TFR_BUILTINS:
      return {TFRTypes.TFR_BUILTIN_FUNC}, name_str
    if name_str in ns:
      ns_val = ns[name_str]
      return {type(ns_val)}, ns_val
    if name_str in __builtins__:
      return {TFRTypes.PY_BUILTIN_FUNC}, __builtins__[name_str]
    # This name is not in the namespace because the autograph transformation
    # is not backloaded into Python.
    if name_str == 'ag__':
      return {type(AG_MODULE)}, AG_MODULE

    return None, None

  def res_value(self, ns, value):
    # resolves the type of the symbol by the metadata in 'value'
    if value is None:
      return {TFRTypes.NONE}
    if value in (TFRTypes.SHAPE, TFRTypes.TF_TENSOR_SHAPE_FUNC):
      # See TFRTypes.__getattribute__.
      # TODO(mdan): Replacing the enum with classes would avoid this overlap.
      return {value}
    # TODO(mdan): Index more efficiently. Could do a name check instead.
    if any(v is value for v in AG_MODULE.__dict__.values()):
      return {TFRTypes.AG_BUILTIN_FUNC}
    if getattr(value, '__name__', None) == 'tensorflow.raw_ops':
      return {types.ModuleType}
    if hasattr(value, '__module__'):
      if isinstance(value, dtypes.DType):
        return {TFRTypes.ATTR}

      # All the imported operations, which are not autograph built-ins, are
      # considered to be TF raw ops.
      # TODO(fengliuai): refine the condition that we only match TensorFlow
      # ops here.
      return {TFRTypes.TF_RAW_OP}
    # TODO(mdan): Is ATTR equivalent to string?
    return {_PY_TYPE_TO_TFR.get(type(value), TFRTypes.ATTR)}

  def res_call(self, ns, types_ns, node, f_type, args, keywords):
    # resolves the return type of the function call.
    name = anno.Basic.QN.of(node.func)
    if f_type == (TFRTypes.AG_BUILTIN_FUNC,):

      if name == QN(QN('ag__'), attr='if_stmt'):
        nouts = node.args[6].value
        # TODO(mdan): Look at the actual types out of if_body.
        side_effects = {
            qual_names.QN(n.value): {TFRTypes.TENSOR}
            for n in node.args[5].elts[:nouts]
        }
        return {type(None)}, side_effects

      if name == QN(QN('ag__'), attr='for_stmt'):
        assert isinstance(node.args[2], ast.Name)
        body_fn_name = str(anno.Basic.QN.of(node.args[2]))
        assert body_fn_name not in self._for_loop_body_fns, (
            'Previously used here: {}. Are you reusing the Resolver across '
            'transformations?').format(self._for_loop_body_fns[body_fn_name])
        self._for_loop_body_fns[body_fn_name] = anno.Basic.ORIGIN.of(node)

        iterated_type = args[0]
        assert iterated_type & {
            TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, TFRTypes.ATTR
        }, (
            iterated_type)
        self._for_loop_target_types[body_fn_name] = iterated_type

        return {type(None)}, None

      # TODO(mdan): Actually resolve the type here instead.
      ret_type = _AG_FIXED_RETURN_TYPE.get(name.qn[1], None)
      if ret_type is not None:
        return {ret_type}, None
      raise NotImplementedError('return type of {}'.format(name))

    elif f_type == (TFRTypes.TF_RAW_OP,):
      # This is a TF operation, so it should be found in the op_defs.
      op_name = name.qn[1]
      op_def, _ = self._op_defs.lookup(op_name)
      if len(op_def.output_arg) == 1:
        return {_get_type_from_proto(op_def.output_arg[0])}, None
      return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
              None)

    elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
      assert name.is_simple()
      if name == QN('range'):
        return {TFRTypes.ATTR}, None

      if name == QN('len'):
        return {TFRTypes.INDEX}, None

    elif f_type == (TFRTypes.TFR_BUILTIN_FUNC,):
      op_name = name.qn[0]
      if callable(TFR_BUILTINS[op_name]):
        return {TFR_BUILTINS[op_name](*[list(arg)[0] for arg in args])}, None
      return {TFR_BUILTINS[op_name]}, None

    elif f_type == (TFRTypes.TF_TENSOR_SHAPE_FUNC,):
      return {TFRTypes.TF_TENSOR_SHAPE_LIST}, None

    elif f_type == (types.FunctionType,):
      # This is a function call which isn't using tf.raw_ops.
      op_name = name.qn[0]

      # 'new TF operation' produces outputs defined by the composition function.
      op_def, _ = self._op_defs.lookup(op_name)
      if len(op_def.output_arg) == 1:
        return {_get_type_from_proto(op_def.output_arg[0])}, None
      return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
              None)

    raise NotImplementedError('Function: {} {}'.format(name, f_type))

  def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
    if f_is_local:
      f_name_str = str(f_name)
      if f_name_str in self._for_loop_target_types:
        # See autograph/converters/control_flow.py - the function has a single
        # argument, the iterate before any expansion.
        assert self._for_loop_target_types[f_name_str] & {TFRTypes.ATTR}
        # Assume all loops are TF loops. Then the iterates are autoboxed into
        # Tensors.
        return {TFRTypes.INDEX}
      else:
        return None

    func = ns[f_name]

    op_def, derived_attrs = self._op_defs.lookup(f_name, func)
    if op_def is None:
      return None
    pos = tf_inspect.getfullargspec(func).args.index(str(name))

    if pos < len(op_def.input_arg):
      arg_def = op_def.input_arg[pos]
      return {_get_type_from_proto(arg_def)}
    elif pos < len(op_def.input_arg) + len(op_def.attr) - len(derived_attrs):
      non_derived_attr_pos = pos - len(op_def.input_arg)
      for attr_def in op_def.attr:
        # derived attribute, skip this one and continue to the next one.
        if attr_def.name in derived_attrs:
          continue
        if non_derived_attr_pos == 0:
          return {_get_type_from_proto(None, attr_def)}
        non_derived_attr_pos -= 1

    raise ValueError('Argument is not defined in OpDef: ' + str(name))

  def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
    if not value:
      return value

    if isinstance(value, set):
      type_tuple = value.pop()
      if isinstance(type_tuple, tuple):
        value = {type_tuple[node_or_slice]}
      else:
        value = {type_tuple}

    assert len(value) == 1
    value, = tuple(value)
    if value == TFRTypes.TF_TENSOR_SHAPE_LIST:
      # TODO(mdan): This is not entirely correct for multi-element slices.
      return {int}
    elif value in (TFRTypes.TENSOR_LIST, TFRTypes.TENSOR):
      # TODO(mdan): This is not entirely correct for multi-element slices.
      return {TFRTypes.TENSOR}
    else:
      return {value}

  def res_compare(self, ns, types_ns, node, left, right):
    # TODO(fengliuai): make sure left and right are compatible
    return {TFRTypes.I1}

  def res_unop(self, ns, types_ns, node, opnd):
    return opnd

  def res_binop(self, ns, types_ns, node, left, right):
    # TODO(fengliuai): make sure left and right are compatible
    return left

  def _coerce_to_more_specific_type(self, elt_types):
    # TODO(mdan): This needs some type theory study.
    if TFRTypes.INDEX in elt_types:
      # Constants collapse to indices.
      elt_types.discard(TFRTypes.I64)
    if TFRTypes.TENSOR in elt_types:
      # Constants collapse to tensors.
      elt_types.discard(TFRTypes.I64)
      # Indices collapse to tensors.
      elt_types.discard(TFRTypes.INDEX)
    return elt_types

  def res_list_literal(self, ns, elt_types):
    all_elt_types = set()
    for t in elt_types:
      all_elt_types |= t

    if len(all_elt_types) != 1:
      all_elt_types = self._coerce_to_more_specific_type(all_elt_types)

    if len(all_elt_types) != 1:
      raise ValueError('ambiguous list element types: {}'.format(elt_types))

    if TFRTypes.TENSOR in all_elt_types:
      return {TFRTypes.TENSOR_LIST}
    return {TFRTypes.ATTR}


class SymbolTable(object):
  """Symbol Table for python code."""

  def __init__(self):
    self.symbols = []
    self.enter_scope()
    self.scf_scope = 0
    # reserved key words
    self.insert_symbol('len', 'len', TFRTypes.PY_BUILTIN_FUNC)

  def enter_scope(self, scf_scope=False):
    """Enter a new scope - at function level."""
    self.symbols.append({'types': {}, 'symbols': {}})
    self.curr_table = self.symbols[len(self.symbols) - 1]
    if scf_scope:
      self.scf_scope += 1

  def insert_symbol(self, name, value, type_):
    self.curr_table['symbols'][name] = (value, type_)
    # TODO(mdan): Use the inferred type rather than tracking it here.
    # The following field is deprecated.
    self.curr_table['types'][name] = type_
    return value

  def exit_scope(self):
    self.symbols.pop()
    self.curr_table = self.symbols[len(self.symbols) - 1]
    if self.scf_scope > 0:
      self.scf_scope -= 1

  def in_scf_scope(self):
    return self.scf_scope > 0

  def lookup(self, name):
    curr_idx = len(self.symbols) - 1
    while curr_idx >= 0 and (name not in self.symbols[curr_idx]['symbols']):
      curr_idx -= 1
    if curr_idx < 0:
      return None
    return self.symbols[curr_idx]['symbols'][name]
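
# A minimal sketch of how SymbolTable behaves (lookup walks outer scopes):
#   table = SymbolTable()
#   table.insert_symbol('x', '%x', TFRTypes.TENSOR)
#   table.enter_scope(scf_scope=True)
#   table.lookup('x')  # -> ('%x', TFRTypes.TENSOR), found in the outer scope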


class TFRGen(transformer.CodeGenerator):
  """Visit the AST and generate MLIR TFR functions."""

  def __init__(self, ctx, op_defs):
    super(TFRGen, self).__init__(ctx)
    self.ctx = ctx
    self.symbol_table = SymbolTable()
    self._op_defs = op_defs

  def _create_mlir_loc(self, loc):
    """Creates mlir location from autograph ORIGIN value.

    Args:
      loc: OriginInfo

    Returns:
      A serialized mlir location string.
    """
    if loc is not None and loc.loc.filename:
      file_name = os.path.basename(loc.loc.filename)
      return 'loc("{}":{}:{})'.format(file_name, loc.loc.lineno,
                                      loc.loc.col_offset)
    else:
      return 'loc(unknown)'

  def _emit_with_loc(self, op_str, node=None):
    """Emit the mlir operation with the location associated with the node.

    Args:
      op_str: The mlir operation string to be emitted.
      node: The node of the AST tree, the mlir operation translated from.
    """
    loc = ''
    if node:
      loc = self._create_mlir_loc(
          anno.getanno(node, anno.Basic.ORIGIN, default=None))
    self.emit(op_str + ' ' + loc)

  def _get_inferred_type(self, node, default=None):
    """Return single type or a tuple of types if more than one type."""
    types_ = anno.getanno(node, anno.Static.TYPES, None)
    if not types_:
      print('WARN: no Static.TYPES annotation. Fix the type inference pass: ')
      self.debug_print(node)
      return default

    if len(types_) == 1:
      type_, = types_
    else:
      type_ = types_

    if default is not None and type_ != default:
      print('WARN: type annotation {}({}) does not match {}({})'.format(
          type_, type(type_), default, type(default)))
      self.debug_print(node)

    return type_

  def _pack_tensor_list(self, value):
    # This packs a list of tensors along axis 0.
    axis = self._ssa_name('zero')
    self._emit_with_loc('\n{} = constant 0 : i64'.format(axis))
    casted = self._ssa_name('pack')
    self.emit('\n{} = tfr.call @tf__pack({}, {})'.format(casted, value, axis))
    self._emit_with_loc(' : (!tfr.tensor_list, i64) -> !tfr.tensor')
    # load the op def of tf.Pack
    self._op_defs.lookup('Pack')
    return casted, TFRTypes.TENSOR
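
  # For a value '%list', the emitted MLIR looks roughly like (a sketch):
  #   %zero = constant 0 : i64
  #   %pack = tfr.call @tf__pack(%list, %zero)
  #       : (!tfr.tensor_list, i64) -> !tfr.tensor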

  def _index_to_I64(self, value, ty):
    if ty == TFRTypes.INDEX:
      casted = self._ssa_name('casted')
      self._emit_with_loc('\n{} = index_cast {} : index to i64'.format(
          casted, value))
      return casted, TFRTypes.I64
    else:
      return value, ty

  def _i64_to_index(self, value, ty):
    if ty == TFRTypes.I64:
      casted = self._ssa_name('casted')
      self._emit_with_loc('\n{} = index_cast {} : i64 to index'.format(
          casted, value))
      return casted, TFRTypes.INDEX
    else:
      return value, ty

  def _value_to_tensor(self, value, ty, node):
    value, ty = self._index_to_I64(value, ty)
    cst_tensor = self._ssa_name('cst')
    self.emit('\n{} = "tfr.constant_tensor"({})'.format(cst_tensor, value))
    self._emit_with_loc(' : ({}) -> !tfr.tensor'.format(ty), node)
    return cst_tensor, TFRTypes.TENSOR

  def _ssa_name(self, prefix):
    if isinstance(prefix, qual_names.QN):
      assert prefix.is_simple(), 'ANF transform should have cleaned this up'
      prefix = prefix.ssf()
    return '%' + self.ctx.namer.new_symbol(prefix, set())

  def _op_def(self, op_name):
    return op_def_registry.get(op_name)

  def visit_block(self, block):
    return [self.visit(item) for item in block]

  def visit_Pass(self, node):
    if self.symbol_table.in_scf_scope():
      self._emit_with_loc('\nscf.yield', node)
    else:
      self._emit_with_loc('\ntfr.return', node)

  def visit_Attribute(self, node):
    node_type = self._get_inferred_type(node, None)
    if isinstance(node.value, ast.Name):
      if node.value.id == 'ag__':
        # Some variables are assigned with the 'ag__.xxx' method, so we should
        # handle them following the autograph conventions.
        return (node.attr, TFRTypes.AG_BUILTIN_FUNC)

      if node_type == TFRTypes.TF_RAW_OP:
        # This branch is used when it is inside tensorflow
        return (node.attr, TFRTypes.TF_RAW_OP)

      if node_type == TFRTypes.ATTR:
        attr = self._ssa_name('attr')
        tfr_type = _TF_DTYPE_TO_TFR.get(node.attr)
        self._emit_with_loc(
            '\n{} = tfr.constant {} -> !tfr.attr'.format(attr, tfr_type), node)
        return (attr, TFRTypes.ATTR)

      value, _ = self.visit(node.value)
      tensor_type = self._get_inferred_type(node.value, None)
      # TODO(fengliuai): use node_type once the type inference is fixed.
      if node_type == TFRTypes.SHAPE:
        print('TODO: use "node_type"')
      if node.attr == 'shape' and tensor_type == TFRTypes.TENSOR:
        ssa_value = self._ssa_name('shape')
        self._emit_with_loc(
            '\n{} = tfr.get_shape {} -> !shape.shape'.format(ssa_value, value),
            node)
        return (ssa_value, TFRTypes.SHAPE)

    if isinstance(node.value, ast.Attribute):
      if isinstance(node.value.value, ast.Name):
        if node.value.value.id == 'tf' and node.value.attr == 'raw_ops':
          return (node.attr, TFRTypes.TF_RAW_OP)

      value, ty = self.visit(node.value)
      # TODO(fengliuai): use node_type once the type inference is fixed.
      if node_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
        print('TODO: use "node_type"')
      if ty == TFRTypes.SHAPE and node.attr == 'as_list':
        return (value, TFRTypes.TF_TENSOR_SHAPE_FUNC)

    raise NotImplementedError('Attribute kind not recognized.')

  def visit_Assign(self, node):
    values = self.visit(node.value)
    if isinstance(node.targets[0], ast.Tuple):
      targets = [elt.id for elt in node.targets[0].elts]
    elif isinstance(node.targets[0], ast.Name):
      targets = [node.targets[0].id]
    else:
      raise NotImplementedError('Assignment target type not recognized.')

    if isinstance(values, list):
      if isinstance(node.value, ast.Call):
        expected = tuple(t for n, t in values)
        if len(values) == 1:
          expected = expected[0]
      elif isinstance(node.value, ast.Tuple):
        expected = tuple(t for n, t in values)
      else:
        raise ValueError('unknown assignment target node', node.value)
      ty = self._get_inferred_type(node.value, expected)

      if len(targets) == len(values):
        # TODO(mdan): This should already be a tuple.
        ty_ = (ty,) if len(values) == 1 else ty
        for key, value, t in zip(targets, values, ty_):
          ssa_value, _ = value
          self.symbol_table.insert_symbol(key, ssa_value, t)
      elif len(values) == 1:
        name, tys = values[0]
        if ty == TFRTypes.TENSOR_LIST:
          # assign single tensor_list to multiple variables
          for idx, key in enumerate(targets):
            idx_name = self._ssa_name('idx')
            self._emit_with_loc(
                '\n{} = constant {} : index'.format(idx_name, idx), node)
            elt_name = self._ssa_name('elt')
            self.emit('\n{} = tfr.get_element {}[{}]'.format(
                elt_name, name, idx_name))
            self._emit_with_loc(' : (!tfr.tensor_list, index) -> !tfr.tensor',
                                node)
            self.symbol_table.insert_symbol(key, elt_name, TFRTypes.TENSOR)
        else:
          # Assign a single value to multiple targets. This single value is
          # usually a function return, and the return types are carried in
          # the value's type tuple.
          for idx, key in enumerate(targets):
            ssa_name = '{}#{}'.format(name, idx)
            ssa_type = tys[idx]
            self.symbol_table.insert_symbol(key, ssa_name, ssa_type)
      elif len(targets) == 1:
        ssa_names = [n for n, _ in values]
        self.symbol_table.insert_symbol(targets[0], ssa_names, ty)
      return

    ty = self._get_inferred_type(node.value, values[1])
    self.symbol_table.insert_symbol(targets[0], values[0], ty)

  def _emit_binary_op(self, op, lhs, lhs_ty, rhs, rhs_ty):
    assert lhs_ty == rhs_ty, (lhs_ty, rhs_ty)
    if isinstance(op, ast.Sub):
      code = 'sub'
    elif isinstance(op, ast.Add):
      code = 'add'
    elif isinstance(op, ast.Mult):
      code = 'mul'
    elif isinstance(op, ast.Div):
      code = 'div'
    else:
      raise NotImplementedError('BinOp operator not recognized: ' + str(op))

    if lhs_ty == TFRTypes.I64 or lhs_ty == TFRTypes.I32:
      suffix = 'i'
    elif lhs_ty == TFRTypes.F32:
      suffix = 'f'
    else:
      raise NotImplementedError('BinOp operand type not recognized: ' +
                                str(lhs_ty))

    ret = self._ssa_name(code)
    self._emit_with_loc(
        '\n{} = {}{} {}, {} : {}'.format(ret, code, suffix, lhs, rhs, lhs_ty),
        op)
    return ret, lhs_ty
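
  # For example, visiting 'a + b' with two i64 operands emits (a sketch):
  #   %add = addi %a, %b : i64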

  def visit_AugAssign(self, node):
    lhs, lhs_ty = self.visit(node.target)
    rhs, rhs_ty = self.visit(node.value)
    ret, ret_ty = self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty)
    self.symbol_table.insert_symbol(node.target.id, ret, ret_ty)

  def visit_BinOp(self, node):
    lhs, lhs_ty = self.visit(node.left)
    rhs, rhs_ty = self.visit(node.right)
    return self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty)

  def visit_BoolOp(self, node):
    values = [self.visit(value) for value in node.values]
    # TODO(fengliuai): Handle more ast node types.
    if isinstance(node.op, ast.Or):
      raise NotImplementedError('Or operator not recognized')
    elif isinstance(node.op, ast.And):
      raise NotImplementedError('And operator not recognized')

  def visit_Call(self, node):
    func_name, func_type = self.visit(node.func)
    func_type = self._get_inferred_type(node.func, func_type)
    if func_type == TFRTypes.AG_BUILTIN_FUNC:
      if func_name == 'if_stmt':
        cond, _ = self.visit(node.args[0])
        body, _ = self.visit(node.args[1])
        orelse, _ = self.visit(node.args[2])
        get_state, _ = self.visit(node.args[3])
        nouts = int(node.args[6].value)
        out_symbols = []
        # The out symbols are just a Tuple of names
        for out in node.args[5].elts[:nouts]:
          val, ty = self.symbol_table.lookup(out.value)
          out_symbols.append(out.value)
        return self._visit_if_stmt(cond, body, orelse, get_state, out_symbols,
                                   node)
      elif func_name == 'for_stmt':
        range_ = self._visit_iter(node.args[0])
        body, _ = self.visit(node.args[2])
        get_state, _ = self.visit(node.args[3])
        loop_carried = [out.value for out in node.args[5].elts]
        # TODO(fengliuai): opt is not used here.
        return self._visit_for_stmt(range_, body, get_state, loop_carried, node)
      elif func_name == 'Undefined':
        val = self._ssa_name(node.args[0].value)
        return (val, TFRTypes.AG_UNDEFINED_VAL)
      elif func_name == 'UndefinedReturnValue':
        val = self._ssa_name('return_val')
        return (val, TFRTypes.AG_UNDEFINED_VAL)

    if func_type == TFRTypes.TF_RAW_OP:
      return self._visit_tf_op(func_name, node.args, node.keywords, node)

    if func_type == TFRTypes.TFR_BUILTIN_FUNC:
      return self._visit_tfr_builtins(func_name, node.args, node)

    if func_type == types.FunctionType:
      return self._visit_tf_op(func_name, node.args, node.keywords, node)

    if func_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
      return (func_name, TFRTypes.TF_TENSOR_SHAPE_LIST)

    if func_type == TFRTypes.PY_BUILTIN_FUNC:
      if func_name == 'len':
        arg, ty = self.visit(node.args[0])
        ty = self._get_inferred_type(node.args[0], ty)
        if ty == TFRTypes.TF_TENSOR_SHAPE_LIST:
          len_value = self._ssa_name('len')
          self._emit_with_loc(
              '\n{} = shape.rank {} : !shape.shape -> !shape.size'.format(
                  len_value, arg), node)
          size_value = self._ssa_name('len_size')
          self._emit_with_loc(
              '\n{} = shape.size_to_index {} : !shape.size'.format(
                  size_value, len_value), node)
        elif ty == TFRTypes.TENSOR_LIST:
          size_value = self._ssa_name('len')
          self._emit_with_loc(
              '\n{} = tfr.get_length {} -> index'.format(size_value, arg), node)
        return (size_value, TFRTypes.INDEX)

    raise NotImplementedError('call operator not recognized: {} {}'.format(
        func_name, func_type))

  def visit_Compare(self, node):
    lhs, lhs_ty = self.visit(node.left)
    for op, right in zip(node.ops, node.comparators):
      rhs, rhs_ty = self.visit(right)
      if isinstance(op, ast.Eq):
        pred = 'eq'
      elif isinstance(op, ast.Lt):
        pred = 'ult'
      elif isinstance(op, ast.LtE):
        pred = 'ule'
      elif isinstance(op, ast.Gt):
        pred = 'ugt'
      elif isinstance(op, ast.GtE):
        pred = 'uge'
      elif isinstance(op, ast.NotEq):
        pred = 'ne'
      else:
        raise NotImplementedError('Compare operator not recognized')

      ret = self._ssa_name(pred)
      if lhs_ty == TFRTypes.ATTR:
        self._emit_with_loc(
            '\n{} = tfr.equal {}, {} -> i1'.format(ret, lhs, rhs), node)
      else:
        if lhs_ty == TFRTypes.I64:
          code = 'cmpi'
        elif lhs_ty == TFRTypes.F32:
          code = 'cmpf'
        elif lhs_ty == TFRTypes.INDEX:
          code = 'cmpi'
          # TODO(fengliuai): the reverse type inference should solve the issue.
          rhs, _ = self._i64_to_index(rhs, rhs_ty)
        else:
          raise NotImplementedError('Compare operand type not recognized')
        self._emit_with_loc(
            '\n{} = {} "{}", {}, {} : {}'.format(ret, code, pred, lhs, rhs,
                                                 lhs_ty), node)

      return ret, TFRTypes.I1

  def visit_Constant(self, node):
    cst_name = self._ssa_name('cst')
    if node.value is None:
      cst_ty = TFRTypes.NONE
    elif isinstance(node.value, bool):
      cst_ty = self._get_inferred_type(node)
      cst_val = str(node.value).lower()
      self._emit_with_loc('\n{} = constant {}'.format(cst_name, cst_val), node)
    else:
      cst_ty = self._get_inferred_type(node)
      cst_val = node.value
      if cst_ty == TFRTypes.ATTR:
        self._emit_with_loc(
            '\n{} = tfr.constant "{}" -> {}'.format(cst_name, cst_val, cst_ty),
            node)
      else:
        self._emit_with_loc(
            '\n{} = constant {} : {}'.format(cst_name, cst_val, cst_ty), node)
    return cst_name, cst_ty

  def visit_FunctionDef(self, node):
    op_def, derived_attrs = self._op_defs.lookup(node.name, node, True)
    if op_def is None:
      # Nested function. Insert it to symbol table for looking up later.
      self.symbol_table.insert_symbol(node.name, node, None)
      return
    op_name = op_def.name
    if self.symbol_table.lookup(op_name):
      raise LookupError('Composition has not been registered for op: ' +
                        op_name)
    else:
      self.symbol_table.insert_symbol(node.name, None, None)

    self.symbol_table.enter_scope()
    self.emit('\ntfr.func @tf__{0}('.format(_camel_to_snake(op_name)))

    arg_list = []
    idx = 0
    max_idx = len(op_def.input_arg) + len(op_def.attr)
    for arg in node.args.args:
      arg_name = self._ssa_name(anno.getanno(arg, anno.Basic.QN))
      arg_type = anno.getanno(arg, anno.Static.TYPES)[0]

      arg_attr = ''
      if idx >= len(op_def.input_arg):
        attr_def = op_def.attr[idx - len(op_def.input_arg)]
        # skip the derived attributes
        while attr_def.name in derived_attrs and (idx + 1) < max_idx:
          idx += 1
          attr_def = op_def.attr[idx - len(op_def.input_arg)]
        if idx >= max_idx:
          raise ValueError('Argument is not defined in OpDef: ' + arg_name)

        arg_attr += '{{tfr.name="{}"'.format(attr_def.name)
        if attr_def.HasField('default_value'):
          default_val = _get_val_from_proto(arg_type, attr_def.default_value)
          arg_attr += ',tfr.default={}'.format(default_val)
        arg_attr += '}'

      idx += 1
      arg_str = '{}: {}{}'.format(arg_name, arg_type, arg_attr)
      arg_list.append(arg_str)
      self.symbol_table.insert_symbol(arg.id, arg_name, arg_type)

    ret_type_list = []
    for ret_def in op_def.output_arg:
      if ret_def.number_attr or ret_def.type_list_attr:
        ret_type_list.append(str(TFRTypes.TENSOR_LIST))
      else:
        ret_type_list.append(str(TFRTypes.TENSOR))

    self.emit('{}) -> ({}) {{'.format(', '.join(arg_list),
                                      ', '.join(ret_type_list)))
    self.visit_block(node.body)
    self._emit_with_loc('\n}', node)
    self.symbol_table.exit_scope()

  def visit_arguments(self, node):
    # TODO(fengliuai): return the ordered types and names.
    # We need to order the arguments to match the assumption in the TFR dialect.
    raise NotImplementedError('arguments not supported.')

  def visit_Lambda(self, node):
    raise NotImplementedError('Lambda not supported.')

  def _get_mlir_ssa_values(self, name_prefix, out_types):
    """Create MLIR convention SSA values."""
    out_ssa_values = []
    if not out_types:
      return '', out_ssa_values

    out_name = self._ssa_name(name_prefix)
    if len(out_types) == 1:
      out_name_suffix = ''
      out_ssa_values.append(out_name)
    else:
      # For multiple returns, MLIR uses '%s:i' when they are defined and
      # '%s#i' when they are used.
      out_name_suffix = ':{}'.format(len(out_types))
      for idx, _ in enumerate(out_types):
        out_ssa_values.append('{}#{}'.format(out_name, idx))

    return '{}{}'.format(out_name, out_name_suffix), out_ssa_values
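
  # For example, with two result types this returns roughly (a sketch):
  #   ('%if_stmt:2', ['%if_stmt#0', '%if_stmt#1'])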

  def _visit_if_stmt(self, cond, body_def, orelse_def, get_state, out_symbols,
                     node):
    self.emit('\n')
    ret_str, ret_ssa_values = self._get_mlir_ssa_values(
        'if_stmt', [TFRTypes.TENSOR] * len(out_symbols))
    if ret_ssa_values:
      self.emit(ret_str + ' = ')

    out_types = []
    for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
      out_types.append(str(TFRTypes.TENSOR))

    self.emit('scf.if {} -> ({}) {{'.format(cond, ', '.join(out_types)))
    # Create a new scope in case the local variables are leaked.
    self.symbol_table.enter_scope(scf_scope=True)
    self.visit_block(body_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()

    self.emit('\n} else {')

    # Create a new scope in case the local variables are leaked.
    self.symbol_table.enter_scope(scf_scope=True)
    self.visit_block(orelse_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()

    # add ssa values to the symbol table
    for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
      self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR)

    self._emit_with_loc('\n}', node)
    return list(zip(ret_ssa_values, out_types))
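
  # The emitted structure looks roughly like (a sketch, one output):
  #   %if_stmt = scf.if %cond -> (!tfr.tensor) {
  #     ... scf.yield ...
  #   } else {
  #     ... scf.yield ...
  #   }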

  def _visit_iter(self, node):
    if isinstance(node, ast.Call):
      f_name = anno.getanno(node.func, anno.Basic.QN)
      if f_name == QN('range'):
        args = [self.visit(arg) for arg in node.args]
        begin = None
        step = None
        end = None
        if len(args) == 1:
          end, end_ty = args[0]
        elif len(args) == 2:
          begin, begin_ty = args[0]
          end, end_ty = args[1]
        elif len(args) == 3:
          begin, begin_ty = args[0]
          end, end_ty = args[1]
          step, step_ty = args[2]

        if begin is None:
          begin = self._ssa_name('begin')
          self._emit_with_loc('\n{} = constant 0 : index'.format(begin), node)
        elif begin_ty != TFRTypes.INDEX:
          begin_ = self._ssa_name('begin')
          self._emit_with_loc(
              '\n{} = index_cast {} : {} to index'.format(
                  begin_, begin, begin_ty), node)
          begin = begin_

        if end_ty != TFRTypes.INDEX:
          end_ = self._ssa_name('end')
          self._emit_with_loc(
              '\n{} = index_cast {} : {} to index'.format(end_, end, end_ty),
              node)
          end = end_

        if step is None:
          step = self._ssa_name('step')
          self._emit_with_loc('\n{} = constant 1 : index'.format(step), node)
        elif step_ty != TFRTypes.INDEX:
          step_ = self._ssa_name('step')
          self._emit_with_loc(
              '\n{} = index_cast {} : {} to index'.format(step_, step, step_ty),
              node)
          step = step_

        return begin, end, step

    raise NotImplementedError('Iterator entity not supported: ' + str(node))

  def _visit_for_stmt(self, range_, body_def, get_state, loop_carried, node):
    self.emit('\n')
    ret_str, ret_ssa_values = self._get_mlir_ssa_values(
        'for_stmt', [TFRTypes.TENSOR] * len(loop_carried))
    if ret_ssa_values:
      self.emit(ret_str + ' = ')

    # Before entering the loop, we use the original ssa values as the initial
    # values of the loop iteration arguments. We also create new ssa values as
    # the returns of the scf for statements. The symbol table needs to be
    # updated with these new ssa values before it enters the scope of the loop.
    out_types = []
    init_values = []
    for symbol, ssa_value in zip(loop_carried, ret_ssa_values):
      init, ty = self.symbol_table.lookup(symbol)
      self.symbol_table.insert_symbol(symbol, ssa_value, ty)
      out_types.append(str(ty))
      init_values.append((init, ty))

    # Create a new scope in case the local variables are leaked.
    self.symbol_table.enter_scope(scf_scope=True)

    # Create the iteration variable with index type
    assert len(body_def.args.args) == 1
    it_name = body_def.args.args[0].id
    it = self._ssa_name(it_name)
    self.symbol_table.insert_symbol(it_name, it, TFRTypes.INDEX)

    self.emit('scf.for {} = {} to {} step {} '.format(it, range_[0], range_[1],
                                                      range_[2]))
    if loop_carried:
      iter_args = []
      for symbol, init in zip(loop_carried, init_values):
        # create new ssa values for the loop carried variables
        it_arg = self._ssa_name('it_arg')
        self.symbol_table.insert_symbol(symbol, it_arg, init[1])
        iter_args.append('{} = {}'.format(it_arg, init[0]))
      self.emit('iter_args({}) '.format(', '.join(iter_args)))
      self.emit('-> ({}) {{'.format(', '.join(out_types)))
    else:
      self.emit(' {')
    self.visit_block(body_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()
    self._emit_with_loc('\n}', node)
    return list(zip(ret_ssa_values, out_types))
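
  # The emitted structure looks roughly like (a sketch, one carried value):
  #   %for_stmt = scf.for %i = %begin to %end step %step
  #       iter_args(%it_arg = %init) -> (!tfr.tensor) {
  #     ... scf.yield ...
  #   }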

  def _emit_default_constant_from_proto(self, attr_def):
    """Emits an mlir constant statement from the AttrDef's default value."""
    name = self._ssa_name('cst')
    cst_ty = _get_type_from_proto(None, attr_def)
    try:
      cst_val = _get_val_from_proto(cst_ty, attr_def.default_value)
    except AttributeError:
      raise AttributeError(
          f'attribute "{attr_def.name}" does not have default_value')
    if cst_ty == TFRTypes.ATTR:
      self._emit_with_loc('\n{} = tfr.constant {} -> {}'.format(
          name, cst_val, cst_ty))
    elif cst_ty == TFRTypes.I1:
      self._emit_with_loc('\n{} = constant {}'.format(name, cst_val))
    else:
      self._emit_with_loc('\n{} = constant {} : {}'.format(
          name, cst_val, cst_ty))
    return name, cst_ty

  def visit_keyword(self, node):
    return node.arg, self.visit(node.value)

  def _visit_tfr_builtins(self, op_name, args, node):
    arg_strs = []
    arg_tys = []
    for arg in args:
      value, ty = self.visit(arg)
      arg_strs.append(value)
      arg_tys.append(ty)
    tfr_op_name = 'tfr.' + op_name[5:]
    ret_tys = (
        TFR_BUILTINS[op_name](*arg_tys)
        if callable(TFR_BUILTINS[op_name]) else TFR_BUILTINS[op_name])
    # Convert the tfr builtin returns to a list.
    if isinstance(ret_tys, tuple):
      ret_tys = list(ret_tys)
    else:
      ret_tys = [ret_tys]

    ret_str, ret_ssa_values = self._get_mlir_ssa_values(op_name, ret_tys)

    arg_str = ', '.join(arg_strs)
    arg_ty_str = ', '.join(str(ty) for ty in arg_tys)
    ret_ty_str = ', '.join(str(ty) for ty in ret_tys)
    self._emit_with_loc('\n{} = {}({}) : ({}) -> ({})'.format(
        ret_str, tfr_op_name, arg_str, arg_ty_str, ret_ty_str), node)
    return list(zip(ret_ssa_values, ret_tys))

  def _visit_tf_op(self, op_name, args, keywords, node):
    op_def, derived_attrs = self._op_defs.lookup(op_name)
    ret_tys = [_get_type_from_proto(arg) for arg in op_def.output_arg]

    ret_str, ret_ssa_values = self._get_mlir_ssa_values(op_name, ret_tys)

    arg_strs = []
    ty_strs = []
    for arg in args:
      value, ty = self.visit(arg)
      arg_strs.append(value)
      ty_strs.append(str(ty))

    input_args = [arg for arg in op_def.input_arg]
    attrs_no_default = [
        attr for attr in op_def.attr
        if not attr.HasField('default_value') and attr.name not in derived_attrs
    ]
    attrs_with_default = [
        attr for attr in op_def.attr
        if attr.HasField('default_value') and attr.name not in derived_attrs
    ]

    kw_args = {}
    for arg in keywords:
      value, (ssa_name, ty) = self.visit(arg)
      ty = self._get_inferred_type(arg.value, ty)

      # TODO(fengliuai): implement the "rename_to" for the customization in
      # tensorflow/core/api_def/base_api/*
      if value == 'axis':
        value = 'split_dim'

      kw_args[value] = (ssa_name, ty)

    # tensor arguments and attribute arguments
    ordered_args = input_args + attrs_no_default + attrs_with_default
    for attr_def in ordered_args[len(args):]:
      if attr_def.name in kw_args:
        value, ty = kw_args[attr_def.name]
        if attr_def in input_args:
          if ty in _attribute_types:
            # The argument shouldn't be passed to the tf op call directly.
            value, ty = self._value_to_tensor(value, ty, node)
          if ty is TFRTypes.TENSOR_LIST and not _require_tensor_list(attr_def):
            value, ty = self._pack_tensor_list(value)
      else:
        value, ty = self._emit_default_constant_from_proto(attr_def)
      arg_strs.append(value)
      ty_strs.append(str(ty))

    if ret_ssa_values:
      self.emit('\n{} = '.format(ret_str))

    self.emit('tfr.call @tf__{}('.format(_camel_to_snake(op_name)))
    arg_str = ', '.join(arg_strs)
    arg_ty_str = ', '.join(ty_strs)
    ret_ty_str = ', '.join([str(ty) for ty in ret_tys])
    self._emit_with_loc(
        '{}) : ({}) -> ({})'.format(arg_str, arg_ty_str, ret_ty_str), node)
    return list(zip(ret_ssa_values, ret_tys))

  def visit_If(self, node):
    raise NotImplementedError('If not supported.')

  def visit_Name(self, node):
    val_and_lookup_type = self.symbol_table.lookup(node.id)
    if val_and_lookup_type:
      (val, lookup_type) = val_and_lookup_type
    elif node.id in TFR_BUILTINS:
      val = node.id
      lookup_type = anno.getanno(node, anno.Static.TYPES, types.FunctionType)
    else:
      op_def, _ = self._op_defs.lookup(node.id)
      val = op_def.name
      lookup_type = anno.getanno(node, anno.Static.TYPES, types.FunctionType)
    type_ = self._get_inferred_type(node, lookup_type)
    return val, type_

  def visit_Return(self, node):
    values = self.visit(node.value)
    if self.symbol_table.in_scf_scope():
      self.emit('\nscf.yield ')
    else:
      self.emit('\ntfr.return ')
    if not values:
      return

    if isinstance(values, list):
      vals, tys = zip(*values)
    else:
      vals = values[0]
      tys = values[1]

    if isinstance(tys, list) or isinstance(tys, tuple):
      tys = [str(t) for t in tys]
      self._emit_with_loc('{} : {}'.format(', '.join(vals), ', '.join(tys)),
                          node)
    elif tys != TFRTypes.NONE:
      # TODO(fengliuai): scf region yield uses this branch. Fix it.
      self._emit_with_loc('{} : {}'.format(vals, tys), node)

  def visit_Subscript(self, node):
    val, ty = self.visit(node.value)
    type_ = self._get_inferred_type(node.value, ty)

    # TODO(fengliuai): Here we hardcode node.slice to get the index type. Use
    # the visit method once the type inference is done.
    # slice_val, slice_ty = self.visit(node.slice)
    s = node.slice
    if not isinstance(s, (ast.Tuple, ast.Slice)):
      if isinstance(s, ast.Constant):
        # TODO(fengliuai): promote to an assignment
        idx_val = self._ssa_name('cst')
        self._emit_with_loc(
            '\n{} = constant {} : index'.format(idx_val, s.value), node)
      else:
        idx_val, _ = self.visit(s)
    else:
      raise NotImplementedError('non-index slice not supported.')

    elt = self._ssa_name('elt')
    if type_ == TFRTypes.TENSOR_LIST:
      self.emit('\n{} = tfr.get_element {}[{}] '.format(elt, val, idx_val))
      self._emit_with_loc(': (!tfr.tensor_list, index) -> !tfr.tensor', node)
      return (elt, TFRTypes.TENSOR)
    elif type_ == TFRTypes.TF_TENSOR_SHAPE_LIST:
      size_ = self._ssa_name('size')
      self.emit('\n{} = shape.get_extent {}, {}'.format(size_, val, idx_val))
      self._emit_with_loc(': !shape.shape, index -> !shape.size', node)
      self._emit_with_loc(
          '\n{} = shape.size_to_index {} : !shape.size'.format(elt, size_),
          node)
      return (elt, TFRTypes.INDEX)

  def visit_List(self, node):
    out_type = self._get_inferred_type(node)
    vals = []
    tys = []
    for elt in node.elts:
      val, ty = self.visit(elt)
      ty = self._get_inferred_type(elt, ty)
      if ty in _attribute_types and out_type == TFRTypes.TENSOR_LIST:
        # This list is a tensor list, so cast all the input values to tensors.
        val, ty = self._value_to_tensor(val, ty, node)
      else:
        # We shouldn't use the index type to build the list because the list
        # will be used as an attribute.
        val, ty = self._index_to_I64(val, ty)
      vals.append(val)
      tys.append(str(ty))

    list_val = self._ssa_name('list')
    self.emit('\n{} = "tfr.build_list"({})'.format(list_val, ', '.join(vals)))
    self._emit_with_loc(' : ({}) -> {}'.format(', '.join(tys), out_type), node)
    return (list_val, out_type)

  def visit_Tuple(self, node):
    return [self.visit(elt) for elt in node.elts]

  def visit_UnaryOp(self, node):
    value, ty = self.visit(node.operand)
    if isinstance(node.op, ast.USub):
      zero_value = self._ssa_name('zero')
      ssa_value = self._ssa_name('cst')
      if ty == TFRTypes.I32 or ty == TFRTypes.I64:
        self._emit_with_loc(
            '\n{} = constant 0 : {}'.format(zero_value, ty), node)
        self._emit_with_loc(
            '\n{} = subi {}, {} : {}'.format(ssa_value, zero_value, value, ty),
            node)
      elif ty == TFRTypes.F32:
        self._emit_with_loc(
            '\n{} = constant 0.0 : {}'.format(zero_value, ty), node)
        self._emit_with_loc(
            '\n{} = subf {}, {} : {}'.format(ssa_value, zero_value, value, ty),
            node)
      else:
        raise NotImplementedError('USub type not recognized: ' + str(ty))
      return ssa_value, ty
    raise NotImplementedError('USub operator not recognized')

  def visit_For(self, node):
    raise NotImplementedError('For operator not recognized')

  def visit_While(self, node):
    raise NotImplementedError('While operator not recognized')

  def visit_Try(self, node):
    # Only handles the body of the try statement.
    self.visit_block(node.body)


def _apply_py_to_tf_passes(node, ctx):
  """Apply transformations from PyToTF to match tf.function tracing."""
  # TODO(fengliuai): we don't know which passes are required, thus we evaluate
  # each one when the corresponding node is handled.
  # copied from PyToTF.transform_ast
  node = return_statements.transform(node, ctx, False)
  node = control_flow.transform(node, ctx)
  return node


class TfrGen(transpiler.GenericTranspiler):
  """Transforms Python objects into TFR MLIR source code."""

  def __init__(self, op_defs):
    self._op_defs = op_defs

  def transform_ast(self, node, ctx):
    node = _apply_py_to_tf_passes(node, ctx)
    # TODO(mdan): Enable this.
    # node = anf.transform(node, ctx)

    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx)
    node = reaching_definitions.resolve(node, ctx, graphs)
    node = reaching_fndefs.resolve(node, ctx, graphs)
    node = type_inference.resolve(node, ctx, graphs,
                                  TFRTypeResolver(self._op_defs))

    mlir_generator = TFRGen(ctx, self._op_defs)
    mlir_generator.visit(node)
    return mlir_generator.code_buffer


def tfr_gen(func, op_defs):
  """Parse a function and emit the TFR functions."""
  mlir_code, _ = TfrGen(op_defs).transform(func, None)
  assert tfr.verify(mlir_code), 'mlir code not verified: {}'.format(mlir_code)
  return mlir_code
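
# Typical usage (a sketch; assumes 'my_composite_fn' is a composition function
# registered with the TFR Composite decorator):
#   op_defs = OpDefCache()
#   mlir_code = tfr_gen(my_composite_fn, op_defs)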


def tfr_funcs_gen_from_module(source, op_defs, method_prefix=None,
                              op_libraries=None):
  """Parse the input source module and emit the TFR functions."""

  # Load the op library so the op is added to the op registry. This is
  # required when the op cc_library couldn't be statically linked in the
  # open-source build.
  # This is a no-op if the op shared library couldn't be found in the same
  # directory as the op Python API.
  # TODO(fengliuai): make the .so file path configurable.
  if op_libraries:
    prefix_len = len('gen_')
    for m in op_libraries:
      lib_dir = os.path.dirname(m.__file__)
      lib_name = os.path.basename(m.__file__)[prefix_len:].replace('.py', '.so')
      lib_path = os.path.join(lib_dir, lib_name)
      if os.path.exists(lib_path):
        logging.info('load file: ' + lib_path)
        load_library.load_op_library(lib_path)
  else:
    # The op library is generated from the source module, so we load all the
    # .so files in the directory.
    lib_dir = os.path.dirname(source.__file__)
    for lib_name in os.listdir(lib_dir):
      if lib_name.endswith('.so'):
        lib_path = os.path.join(lib_dir, lib_name)
        logging.info('load file: ' + lib_path)
        load_library.load_op_library(lib_path)

  py_funcs = [
      func
      for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction)
      if not method_prefix or name.startswith(method_prefix)
  ]
  # Sort the methods by the line number, to make sure the definitions are
  # processed before the usages.
  # TODO(fengliuai): Use type inference resolver to recursively process any
  # functions called.
  py_funcs = sorted(py_funcs, key=lambda x: x.__code__.co_firstlineno)
  mlir_funcs = [tfr_gen(func, op_defs) for func in py_funcs]

  return mlir_funcs


def tfr_gen_from_module(source, method_prefix=None, op_libraries=None,
                        op_defs=OpDefCache()):
  """Parse the input source module and emit the TFR and external functions."""
  mlir_funcs = tfr_funcs_gen_from_module(
      source, op_defs, method_prefix, op_libraries)

  return '\n'.join(mlir_funcs + op_defs.mlir_external_funcs())
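
# End-to-end usage (a sketch; 'my_composites' is a hypothetical module whose
# functions are TFR compositions; method_prefix filters by function name):
#   import my_composites
#   mlir = tfr_gen_from_module(my_composites, method_prefix='_composite_')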