# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tfr_gen: Generate mlir tfr decomposition function from python code."""

# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
# pylint: disable=g-direct-tensorflow-import

import enum
import os
import re
import types
import gast as ast

from tensorflow.compiler.mlir.tfr import tfr_wrapper as tfr
from tensorflow.core.framework import types_pb2
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
from tensorflow.python.autograph.pyct.static_analysis import type_inference
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import load_library
from tensorflow.python.framework import op_def_registry
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect

# TODO(mdan): Use class definitions so that we can mix these with Python types.


class TFRTypes(enum.Enum):
  """All the supported types.

    1-3: tfr types
    4-99: mlir built-in types
    100-199: TF related translator internal types
    200- : Python related translator internal types
  """
  TENSOR = 1
  TENSOR_LIST = 2
  ATTR = 3
  NONE = 4
  SHAPE = 5  # shape -> !shape.shape
  I1 = 21
  I8 = 22
  I16 = 23
  I32 = 24
  I64 = 25
  F32 = 26
  INDEX = 27
  AG_UNDEFINED_VAL = 100
  AG_BUILTIN_FUNC = 101
  TF_RAW_OP = 102
  TF_REGION = 103
  TF_TENSOR_SHAPE_FUNC = 104  # shape.as_list
  TF_TENSOR_SHAPE_LIST = 105  # shape.as_list()
  PY_BUILTIN_FUNC = 200
  TFR_BUILTIN_FUNC = 201

  # As these are not real types, __getattribute__ helps them appear more like
  # actual types (i.e. class definitions).
  def __getattribute__(self, name):
    if name == 'shape' and object.__getattribute__(self, 'value') == 1:
      return TFRTypes.SHAPE
    if name == 'as_list' and object.__getattribute__(self, 'value') == 5:
      return TFRTypes.TF_TENSOR_SHAPE_FUNC
    return object.__getattribute__(self, name)

  def __str__(self):
    if self.value < 4:  # pylint: disable=comparison-with-callable
      return '!tfr.' + self.name.lower()
    elif self.value < 10:  # pylint: disable=comparison-with-callable
      return '!shape.' + self.name.lower()
    else:
      return self.name.lower()

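# An informal sketch of the enum-to-MLIR-string mapping implemented by
# __str__ above (worked out from the code, for quick reference):
#
#   str(TFRTypes.TENSOR)       -> '!tfr.tensor'
#   str(TFRTypes.TENSOR_LIST)  -> '!tfr.tensor_list'
#   str(TFRTypes.ATTR)         -> '!tfr.attr'
#   str(TFRTypes.SHAPE)        -> '!shape.shape'
#   str(TFRTypes.I1)           -> 'i1'
#   str(TFRTypes.I64)          -> 'i64'
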
_ATTRIBUTE_TYPES = (
    TFRTypes.I1, TFRTypes.I32, TFRTypes.I64, TFRTypes.F32, TFRTypes.INDEX,
    TFRTypes.ATTR
    )

# TODO(b/203493652): implement the "rename_to" for the customization in
# tensorflow/core/api_def/base_api/*
# {op_name: {API's attribute name: OpDef's attribute name}}
_ATTRIBUTE_RENAMES = {
    'Mean': {'axis': 'reduction_indices'},
    'Split': {'axis': 'split_dim'},
    'SplitV': {'axis': 'split_dim'},
}


def _get_type_from_proto(arg_def=None, attr_def=None):
  if not arg_def:
    if attr_def.type == 'bool':
      return TFRTypes.I1
    elif attr_def.type == 'int32':
      return TFRTypes.I32
    elif attr_def.type == 'int' or attr_def.type == 'int64':
      return TFRTypes.I64
    elif attr_def.type == 'float':
      return TFRTypes.F32
    else:
      return TFRTypes.ATTR

  if arg_def.number_attr or arg_def.type_list_attr:
    return TFRTypes.TENSOR_LIST
  else:
    return TFRTypes.TENSOR


def _get_type_info_from_proto(arg_def=None, attr_def=None):
  attr_type = _get_type_from_proto(arg_def, attr_def)
  if not arg_def:
    return '{}{{tfr.name="{}",tfr.type="{}"}}'.format(
        attr_type, attr_def.name, attr_def.type)
  else:
    attr_names = []
    if arg_def.number_attr:
      attr_names.append(arg_def.number_attr)
    if arg_def.type_attr:
      attr_names.append(arg_def.type_attr)
    if arg_def.type_list_attr:
      attr_names.append(arg_def.type_list_attr)

    # TODO(fengliuai): currently we don't support backward type inference, so
    # we have to store these non-derivable types in the signatures, and then
    # they can be used to cast the values when raising to tf ops.
    if arg_def.type == types_pb2.DT_FLOAT:
      attr_names.append('f32_')
    elif arg_def.type == types_pb2.DT_INT32:
      attr_names.append('i32_')
    elif arg_def.type == types_pb2.DT_INT64:
      attr_names.append('i64_')
    elif arg_def.type == types_pb2.DT_BOOL:
      attr_names.append('i1_')

    if not attr_names:
      return str(attr_type)
    else:
      return '{}<{}>'.format(attr_type, ','.join(attr_names))

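# A sketch of the signature strings produced above (derived from the format
# strings, for illustration only):
#
#   attr_def(name='axis', type='int')        -> 'i64{tfr.name="axis",tfr.type="int"}'
#   arg_def(type_attr='T')                   -> '!tfr.tensor<T>'
#   arg_def(number_attr='N', type_attr='T')  -> '!tfr.tensor_list<N,T>'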

def _get_val_from_proto(attr_type, attr_val):
  if attr_type == TFRTypes.I1:
    return 'true' if attr_val.b else 'false'
  elif attr_type == TFRTypes.I32 or attr_type == TFRTypes.I64:
    return attr_val.i
  elif attr_type == TFRTypes.F32:
    return attr_val.f
  elif attr_type == TFRTypes.ATTR:
    # string
    if attr_val.HasField('s'):
      return '"{}"'.format(attr_val.s.decode())
    # type
    if attr_val.HasField('type'):
      if attr_val.type == types_pb2.DT_FLOAT:
        return 'f32'
      elif attr_val.type == types_pb2.DT_INT32:
        return 'i32'
      elif attr_val.type == types_pb2.DT_INT64:
        return 'i64'
      elif attr_val.type == types_pb2.DT_BOOL:
        return 'i1'
    # list
    if attr_val.HasField('list'):
      if attr_val.list.f:
        elt_ty = TFRTypes.F32
        values = attr_val.list.f
      elif attr_val.list.i:
        elt_ty = TFRTypes.I64
        values = attr_val.list.i
      else:
        elt_ty = TFRTypes.NONE
        values = []
      array_attr_elts = ['{}:{}'.format(val, elt_ty) for val in values]
      return '[{}]'.format(','.join(array_attr_elts))
  raise NotImplementedError(
      'Proto AttrValue not recognized. type: {}, value: {}'.format(
          attr_type, attr_val))


def _collect_derived_attrs_from_proto(op_def):
  derived_attrs = set()
  for arg in op_def.input_arg:
    if arg.type_attr:
      derived_attrs.add(arg.type_attr)
    if arg.number_attr:
      derived_attrs.add(arg.number_attr)
    if arg.type_list_attr:
      derived_attrs.add(arg.type_list_attr)
    # TODO(fengliuai): currently we don't support backward type inference, so
    # we have to store these non-derivable types in the signatures, and then
    # they can be used to cast the values when raising to tf ops.
    if arg.type == types_pb2.DT_FLOAT:
      derived_attrs.add('f32_')
    elif arg.type == types_pb2.DT_INT32:
      derived_attrs.add('i32_')
    elif arg.type == types_pb2.DT_INT64:
      derived_attrs.add('i64_')
    elif arg.type == types_pb2.DT_BOOL:
      derived_attrs.add('i1_')
  return derived_attrs


def _require_tensor_list(arg_def):
  return arg_def.type_list_attr or arg_def.number_attr


def _camel_to_snake(name):
  s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
  return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()

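# A few illustrative conversions (worked out from the two substitutions
# above, not an exhaustive spec):
#
#   _camel_to_snake('MatMul')  -> 'mat_mul'
#   _camel_to_snake('SplitV')  -> 'split_v'
#   _camel_to_snake('Conv2D')  -> 'conv2_d'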

class OpDefCache(object):
  """A Dict to cache the OpDef for the Python function name."""

  def __init__(self):
    self._op_defs = {}

  def lookup(self, f_name, func_def=None, optional=False):
    if f_name in self._op_defs:
      return self._op_defs[f_name]

    if isinstance(func_def, types.FunctionType):
      if not hasattr(func_def, '_tfr_op_name'):
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      op_name = getattr(func_def, '_tfr_op_name')
    elif not func_def:
      op_name = f_name
    else:
      # TODO(fengliuai): create one utility method to match different APIs.
      compose_dec = []
      for dec in func_def.decorator_list:
        if isinstance(dec, ast.Call):
          if isinstance(dec.func,
                        ast.Attribute) and dec.func.attr == 'Composite':
            compose_dec.append(dec)
          if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite':
            compose_dec.append(dec)

      if not compose_dec:
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      elif len(compose_dec) > 1:
        raise KeyError('More than one TF op decomposition found for ' + f_name)
      else:
        op_name = compose_dec[0].args[0].value

    op_def = op_def_registry.get(op_name)
    if not op_def:
      raise ValueError('Not a registered op: ' + op_name)
    derived_attrs = _collect_derived_attrs_from_proto(op_def)
    self._op_defs[f_name] = (op_def, derived_attrs)
    return (op_def, derived_attrs)

  def mlir_external_funcs(self):
    tfr_funcs = set()
    for _, (op_def, derived_attrs) in sorted(self._op_defs.items()):
      tfr_func = '\ntfr.func @tf__{}_('.format(_camel_to_snake(op_def.name))

      # tensor inputs
      inputs = [
          _get_type_info_from_proto(arg_def) for arg_def in op_def.input_arg
      ]

      # attribute inputs. The attributes with default values are moved to the
      # end.
      non_derived_attrs = [
          attr for attr in op_def.attr if attr.name not in derived_attrs
      ]
      attrs_no_default = [
          attr for attr in non_derived_attrs
          if not attr.HasField('default_value')
      ]
      attrs_with_default = [
          attr for attr in non_derived_attrs if attr.HasField('default_value')
      ]
      attr_names = {'f32_', 'i32_', 'i64_', 'i1_'}  # reserved
      for attr_def in attrs_no_default + attrs_with_default:
        inputs.append(_get_type_info_from_proto(None, attr_def))
        attr_names.add(attr_def.name)

      # tensor outputs
      outputs = [
          _get_type_info_from_proto(arg_def) for arg_def in op_def.output_arg
      ]

      inputs = ','.join(inputs)
      outputs = ','.join(outputs)
      attrs = ','.join(sorted(derived_attrs.union(attr_names)))
      tfr_funcs.add('{}{}) -> ({}) attributes {{{}}}'.format(
          tfr_func, inputs, outputs, attrs))
    return sorted(list(tfr_funcs))

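# For reference, a sketch of one external declaration this produces. Assuming
# tf.raw_ops.Pack (input 'values' with number_attr N and type_attr T, output
# with type_attr T, and an 'axis' int attribute), the emitted declaration
# would look roughly like:
#
#   tfr.func @tf__pack_(!tfr.tensor_list<N,T>,i64{tfr.name="axis",tfr.type="int"})
#       -> (!tfr.tensor<T>) attributes {N,T,axis,f32_,i1_,i32_,i64_}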

_PY_TYPE_TO_TFR = {
    bool: TFRTypes.I1,
    int: TFRTypes.I64,
    float: TFRTypes.F32,
}

_TF_DTYPE_TO_TFR = {
    'bool': TFRTypes.I1,
    'int64': TFRTypes.I64,
    'int32': TFRTypes.I32,
    'int16': TFRTypes.I16,
    'int8': TFRTypes.I8,
    'float32': TFRTypes.F32,
}

_AG_FIXED_RETURN_TYPE = {
    'for_stmt': type(None),
    'if_stmt': type(None),
    'Undefined': TFRTypes.AG_UNDEFINED_VAL,
}

QN = qual_names.QN

# TODO(mdan): Fix this with an importable module.
AG_MODULE = api._TRANSPILER.get_extra_locals()['ag__']  # pylint:disable=protected-access

# When an item is callable, the signature is (*operand_types) -> result_type(s)
TFR_BUILTINS = {
    '_tfr_quant_act_range': (TFRTypes.TENSOR, TFRTypes.TENSOR),
    '_tfr_quant_rescale': TFRTypes.TENSOR,
    '_tfr_quant_raw_data': lambda input_type: input_type,
    '_tfr_quant_qparam': (TFRTypes.TENSOR, TFRTypes.TENSOR),
    '_tfr_quant_scale_factor': TFRTypes.TENSOR,
}
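# For example, '_tfr_quant_raw_data' maps its operand type through unchanged
# (TENSOR in -> TENSOR out, TENSOR_LIST in -> TENSOR_LIST out), while the
# tuple-valued entries declare a fixed pair of tensor results.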


class TFRTypeResolver(type_inference.Resolver):
  """Resolve types for the external names, calls and arguments."""

  def __init__(self, op_defs):
    super(TFRTypeResolver, self).__init__()
    self._op_defs = op_defs

    # This pattern matching mechanism works with the functional form generated
    # by autograph:
    #
    #   for i in data:
    #     print(i)
    #
    # generates:
    #
    #   def loop_body(itr):
    #     i = itr
    #     print(i)
    #   ag__.for_stmt(target)
    #
    # The mechanism lets us infer the type of the itr argument based on that of
    # target.
    self._for_loop_target_types = {}  # Maps body function name to iterated.
    self._for_loop_body_fns = {}  # Used only to avoid collisions.

  def res_name(self, ns, types_ns, name):
    name_str = str(name)
    if name_str in TFR_BUILTINS:
      return {TFRTypes.TFR_BUILTIN_FUNC}, name_str
    if name_str in ns:
      ns_val = ns[name_str]
      return {type(ns_val)}, ns_val
    if name_str in __builtins__:
      return {TFRTypes.PY_BUILTIN_FUNC}, __builtins__[name_str]
    # This name is not in the namespace because the autograph transformation
    # is not backloaded into Python.
    if name_str == 'ag__':
      return {type(AG_MODULE)}, AG_MODULE

    return None, None

  def res_value(self, ns, value):
    # resolves the type of the symbol by the metadata in 'value'
    if value is None:
      return {TFRTypes.NONE}
    if value in (TFRTypes.SHAPE, TFRTypes.TF_TENSOR_SHAPE_FUNC):
      # See TFRTypes.__getattribute__.
      # TODO(mdan): Replacing the enum with classes would avoid this overlap.
      return {value}
    # TODO(mdan): Index more efficiently. Could do a name check instead.
    if any(v is value for v in AG_MODULE.__dict__.values()):
      return {TFRTypes.AG_BUILTIN_FUNC}
    if getattr(value, '__name__', None) == 'tensorflow.raw_ops':
      return {types.ModuleType}
    if hasattr(value, '__module__'):
      if isinstance(value, dtypes.DType):
        return {TFRTypes.ATTR}

      # All the imported operations, which are not autograph built-ins, are
      # considered to be TF raw ops.
      # TODO(fengliuai): refine the condition that we only match TensorFlow
      # ops here.
      return {TFRTypes.TF_RAW_OP}
    # TODO(mdan): Is ATTR equivalent to string?
    return {_PY_TYPE_TO_TFR.get(type(value), TFRTypes.ATTR)}

  def res_call(self, ns, types_ns, node, f_type, args, keywords):
    # resolves the return type of the function call.
    name = anno.Basic.QN.of(node.func)
    if f_type == (TFRTypes.AG_BUILTIN_FUNC,):

      if name == QN(QN('ag__'), attr='if_stmt'):
        nouts = node.args[6].value
        # TODO(mdan): Look at the actual types out of if_body.
        side_effects = {
            qual_names.QN(n.value): {TFRTypes.TENSOR}
            for n in node.args[5].elts[:nouts]
        }
        return {type(None)}, side_effects

      if name == QN(QN('ag__'), attr='for_stmt'):
        assert isinstance(node.args[2], ast.Name)
        body_fn_name = str(anno.Basic.QN.of(node.args[2]))
        assert body_fn_name not in self._for_loop_body_fns, (
            'Previously used here: {}. Are you reusing the Resolver across '
            'transformations?').format(self._for_loop_body_fns[body_fn_name])
        self._for_loop_body_fns[body_fn_name] = anno.Basic.ORIGIN.of(node)

        iterated_type = args[0]
        assert iterated_type & {
            TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, TFRTypes.ATTR
        }, (
            iterated_type)
        self._for_loop_target_types[body_fn_name] = iterated_type

        return {type(None)}, None

      # TODO(mdan): Actually resolve the type here instead.
      ret_type = _AG_FIXED_RETURN_TYPE.get(name.qn[1], None)
      if ret_type is not None:
        return {ret_type}, None
      raise NotImplementedError('return type of {}'.format(name))

    elif f_type == (TFRTypes.TF_RAW_OP,):
      # This is a TF operation, so it should be found in the op_defs.
      op_name = name.qn[1]
      op_def, _ = self._op_defs.lookup(op_name)
      if len(op_def.output_arg) == 1:
        return {_get_type_from_proto(op_def.output_arg[0])}, None
      return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
              None)

    elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
      assert name.is_simple()
      if name == QN('range'):
        return {TFRTypes.ATTR}, None

      if name == QN('len'):
        return {TFRTypes.INDEX}, None

    elif f_type == (TFRTypes.TFR_BUILTIN_FUNC,):
      op_name = name.qn[0]
      if callable(TFR_BUILTINS[op_name]):
        return {TFR_BUILTINS[op_name](*[list(arg)[0] for arg in args])}, None
      return {TFR_BUILTINS[op_name]}, None

    elif f_type == (TFRTypes.TF_TENSOR_SHAPE_FUNC,):
      return {TFRTypes.TF_TENSOR_SHAPE_LIST}, None

    elif f_type == (types.FunctionType,):
      # This is a function call which isn't using tf.raw_ops.
      op_name = name.qn[0]

      # 'new TF operation' produces outputs defined by the composition function.
      op_def, _ = self._op_defs.lookup(op_name)
      if len(op_def.output_arg) == 1:
        return {_get_type_from_proto(op_def.output_arg[0])}, None
      return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
              None)

    raise NotImplementedError('Function:', name, f_type)

  def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
    if f_is_local:
      f_name_str = str(f_name)
      if f_name_str in self._for_loop_target_types:
        # See autograph/converters/control_flow.py - the function has a single
        # argument, the iterate before any expansion.
        assert self._for_loop_target_types[f_name_str] & {TFRTypes.ATTR}
        # Assume all loops are TF loops. Then the iterates are autoboxed into
        # Tensors.
        return {TFRTypes.INDEX}
      else:
        return None

    func = ns[f_name]

    op_def, derived_attrs = self._op_defs.lookup(f_name, func)
    if op_def is None:
      return None
    pos = tf_inspect.getfullargspec(func).args.index(str(name))

    if pos < len(op_def.input_arg):
      arg_def = op_def.input_arg[pos]
      return {_get_type_from_proto(arg_def)}
    elif pos < len(op_def.input_arg) + len(op_def.attr) - len(derived_attrs):
      non_derived_attr_pos = pos - len(op_def.input_arg)
      for attr_def in op_def.attr:
        # derived attribute, skip this one and continue to the next one.
        if attr_def.name in derived_attrs:
          continue
        if non_derived_attr_pos == 0:
          return {_get_type_from_proto(None, attr_def)}
        non_derived_attr_pos -= 1

    raise ValueError('Argument is not defined in OpDef: ' + str(name))

  def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
    if not value:
      return value

    if isinstance(value, set):
      type_tuple = value.pop()
      if isinstance(type_tuple, tuple):
        value = {type_tuple[node_or_slice]}
      else:
        value = {type_tuple}

    assert len(value) == 1
    value, = tuple(value)
    if value == TFRTypes.TF_TENSOR_SHAPE_LIST:
      # TODO(mdan): This is not entirely correct for multi-element slices.
      return {int}
    elif value in (TFRTypes.TENSOR_LIST, TFRTypes.TENSOR):
      # TODO(mdan): This is not entirely correct for multi-element slices.
      return {TFRTypes.TENSOR}
    else:
      return {value}

  def res_compare(self, ns, types_ns, node, left, right):
    # TODO(fengliuai): make sure left and right are compatible
    return {TFRTypes.I1}

  def res_unop(self, ns, types_ns, node, opnd):
    return opnd

  def res_binop(self, ns, types_ns, node, left, right):
    # TODO(fengliuai): make sure left and right are compatible
    return left

  def _coerce_to_more_specific_type(self, elt_types):
    # TODO(mdan): This needs some type theory study.
    if TFRTypes.INDEX in elt_types:
      # Constants collapse to indices.
      elt_types.discard(TFRTypes.I64)
    if TFRTypes.TENSOR in elt_types:
      # Constants collapse to tensors.
      elt_types.discard(TFRTypes.I64)
      # Indices collapse to tensors.
      elt_types.discard(TFRTypes.INDEX)
    return elt_types

  def res_list_literal(self, ns, elt_types):
    all_elt_types = set()
    for t in elt_types:
      all_elt_types |= t

    if len(all_elt_types) != 1:
      all_elt_types = self._coerce_to_more_specific_type(all_elt_types)

    if len(all_elt_types) != 1:
      raise ValueError('ambiguous list element types: {}'.format(elt_types))

    if TFRTypes.TENSOR in all_elt_types:
      return {TFRTypes.TENSOR_LIST}
    return {TFRTypes.ATTR}

class SymbolTable(object):
  """Symbol Table for python code."""

  def __init__(self):
    self.symbols = []
    self.enter_scope()
    self.scf_scope = 0
    # reserved keywords
    self.insert_symbol('len', 'len', TFRTypes.PY_BUILTIN_FUNC)

  def enter_scope(self, scf_scope=False):
    """Enter a new scope - at function level."""
    self.symbols.append({'types': {}, 'symbols': {}})
    self.curr_table = self.symbols[len(self.symbols) - 1]
    if scf_scope:
      self.scf_scope += 1

  def insert_symbol(self, name, value, type_):
    self.curr_table['symbols'][name] = (value, type_)
    # TODO(mdan): Use the inferred type rather than tracking it here.
    # The following field is deprecated.
    self.curr_table['types'][name] = type_
    return value

  def exit_scope(self):
    self.symbols.pop()
    self.curr_table = self.symbols[len(self.symbols) - 1]
    if self.scf_scope > 0:
      self.scf_scope -= 1

  def in_scf_scope(self):
    return self.scf_scope > 0

  def lookup(self, name):
    curr_idx = len(self.symbols) - 1
    while curr_idx >= 0 and (name not in self.symbols[curr_idx]['symbols']):
      curr_idx -= 1
    if curr_idx < 0:
      return None
    return self.symbols[curr_idx]['symbols'][name]

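# A minimal usage sketch of the scoped lookup above (illustrative only, not
# part of the generator):
#
#   table = SymbolTable()
#   table.insert_symbol('x', '%x', TFRTypes.TENSOR)
#   table.enter_scope(scf_scope=True)
#   table.insert_symbol('y', '%y', TFRTypes.I64)
#   table.lookup('x')   # -> ('%x', TFRTypes.TENSOR), found in the outer scope
#   table.exit_scope()
#   table.lookup('y')   # -> None, 'y' went away with its scope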

class TFRGen(transformer.CodeGenerator):
  """Visit the AST and generate MLIR TFR functions."""

  def __init__(self, ctx, op_defs):
    super(TFRGen, self).__init__(ctx)
    self.ctx = ctx
    self.symbol_table = SymbolTable()
    self._op_defs = op_defs

  def _create_mlir_loc(self, loc):
    """Creates mlir location from autograph ORIGIN value.

    Args:
      loc: OriginInfo

    Returns:
      A serialized mlir location string.
    """
    if loc is not None and loc.loc.filename:
      file_name = os.path.basename(loc.loc.filename)
      return 'loc("{}":{}:{})'.format(file_name, loc.loc.lineno,
                                      loc.loc.col_offset)
    else:
      return 'loc(unknown)'

  def _emit_with_loc(self, op_str, node=None):
    """Emit the mlir operation with the location associated with the node.

    Args:
      op_str: The mlir operation string to be emitted.
      node: The AST node from which the mlir operation is translated.
670    """
671    loc = ''
672    if node:
673      loc = self._create_mlir_loc(
674          anno.getanno(node, anno.Basic.ORIGIN, default=None))
675    self.emit(op_str + ' ' + loc)
676
677  def _get_inferred_type(self, node, default=None):
678    """Return single type or a tuple of types if more than one type."""
679    types_ = anno.getanno(node, anno.Static.TYPES, None)
680    if not types_:
681      print('WARN: no Static.TYPES annotation. Fix the type inference pass: ')
682      self.debug_print(node)
683      return default
684
685    if len(types_) == 1:
686      type_, = types_
687    else:
688      type_ = types_
689
690    if default is not None and type_ != default:
691      print('WARN: type annotation {}({}) does not match {}({})'.format(
692          type_, type(type_), default, type(default)))
693      self.debug_print(node)
694
695    return type_
696
697  def _pack_tensor_list(self, value):
    # When packing a list of tensors, the pack axis is 0.
    axis = self._ssa_name('zero')
    self._emit_with_loc('\n{} = arith.constant 0 : i64'.format(axis))
    casted = self._ssa_name('pack')
    self.emit('\n{} = tfr.call @tf__pack({}, {})'.format(casted, value, axis))
    self._emit_with_loc(' : (!tfr.tensor_list, i64) -> !tfr.tensor')
    # load the op def of tf.Pack
    self._op_defs.lookup('Pack')
    return casted, TFRTypes.TENSOR

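  # The MLIR emitted by the pack above looks roughly like (sketch, SSA names
  # illustrative):
  #
  #   %zero = arith.constant 0 : i64
  #   %pack = tfr.call @tf__pack(%v, %zero) : (!tfr.tensor_list, i64) -> !tfr.tensor
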
  def _index_to_I64(self, value, ty):
    if ty == TFRTypes.INDEX:
      casted = self._ssa_name('casted')
      self._emit_with_loc('\n{} = arith.index_cast {} : index to i64'.format(
          casted, value))
      return casted, TFRTypes.I64
    else:
      return value, ty

  def _i64_to_index(self, value, ty):
    if ty == TFRTypes.I64:
      casted = self._ssa_name('casted')
      self._emit_with_loc('\n{} = arith.index_cast {} : i64 to index'.format(
          casted, value))
      return casted, TFRTypes.INDEX
    else:
      return value, ty

  def _value_to_tensor(self, value, ty, node):
    value, ty = self._index_to_I64(value, ty)
    cst_tensor = self._ssa_name('cst')
    self.emit('\n{} = "tfr.constant_tensor"({})'.format(cst_tensor, value))
    self._emit_with_loc(' : ({}) -> !tfr.tensor'.format(ty), node)
    return cst_tensor, TFRTypes.TENSOR

  def _ssa_name(self, prefix):
    if isinstance(prefix, qual_names.QN):
      assert prefix.is_simple(), 'ANF transform should have cleaned this up'
      prefix = prefix.ssf()
    return '%' + self.ctx.namer.new_symbol(prefix, set())

  def _op_def(self, op_name):
    return op_def_registry.get(op_name)

  def visit_block(self, block):
    return [self.visit(item) for item in block]

  def visit_Pass(self, node):
    if self.symbol_table.in_scf_scope():
      self._emit_with_loc('\nscf.yield', node)
    else:
      self._emit_with_loc('\ntfr.return', node)

  def visit_Attribute(self, node):
    node_type = self._get_inferred_type(node, None)
    if isinstance(node.value, ast.Name):
      if node.value.id == 'ag__':
        # Some variables are assigned with the 'ag__.xxx' methods; we should
        # handle them following the autograph conventions.
        return (node.attr, TFRTypes.AG_BUILTIN_FUNC)

      if node_type == TFRTypes.TF_RAW_OP:
        # This branch is used when it is inside tensorflow
        return (node.attr, TFRTypes.TF_RAW_OP)

      if node_type == TFRTypes.ATTR:
        attr = self._ssa_name('attr')
        tfr_type = _TF_DTYPE_TO_TFR.get(node.attr)
        self._emit_with_loc(
            '\n{} = tfr.constant {} -> !tfr.attr'.format(attr, tfr_type), node)
        return (attr, TFRTypes.ATTR)

      value, _ = self.visit(node.value)
      tensor_type = self._get_inferred_type(node.value, None)
      # TODO(fengliuai): use node_type once it
      if node_type == TFRTypes.SHAPE:
        print('TODO: use "node_type"')
      if node.attr == 'shape' and tensor_type == TFRTypes.TENSOR:
        ssa_value = self._ssa_name('shape')
        self._emit_with_loc(
            '\n{} = tfr.get_shape {} -> !shape.shape'.format(ssa_value, value),
            node)
        return (ssa_value, TFRTypes.SHAPE)

    if isinstance(node.value, ast.Attribute):
      if isinstance(node.value.value, ast.Name):
        if node.value.value.id == 'tf' and node.value.attr == 'raw_ops':
          return (node.attr, TFRTypes.TF_RAW_OP)

      value, ty = self.visit(node.value)
      # TODO(fengliuai): use node_type once it
      if node_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
        print('TODO: use "node_type"')
      if ty == TFRTypes.SHAPE and node.attr == 'as_list':
        return (value, TFRTypes.TF_TENSOR_SHAPE_FUNC)

    raise NotImplementedError('Attribute kind not recognized.')

  def visit_Assign(self, node):
    values = self.visit(node.value)
    if isinstance(node.targets[0], ast.Tuple):
      targets = [elt.id for elt in node.targets[0].elts]
    elif isinstance(node.targets[0], ast.Name):
      targets = [node.targets[0].id]
    else:
      raise NotImplementedError('Assignment target type not recognized.')

    if isinstance(values, list):
      if isinstance(node.value, ast.Call):
        expected = tuple(t for n, t in values)
        if len(values) == 1:
          expected = expected[0]
      elif isinstance(node.value, ast.Tuple):
        expected = tuple(t for n, t in values)
      else:
        raise ValueError('unknown assignment target node', node.value)
      ty = self._get_inferred_type(node.value, expected)

      if len(targets) == len(values):
        # TODO(mdan): This should already be a tuple.
        ty_ = (ty,) if len(values) == 1 else ty
        for key, value, t in zip(targets, values, ty_):
          ssa_value, _ = value
          self.symbol_table.insert_symbol(key, ssa_value, t)
      elif len(values) == 1:
        name, tys = values[0]
        if ty == TFRTypes.TENSOR_LIST:
          # assign single tensor_list to multiple variables
          for idx, key in enumerate(targets):
            idx_name = self._ssa_name('idx')
            self._emit_with_loc(
                '\n{} = arith.constant {} : index'.format(idx_name, idx), node)
            elt_name = self._ssa_name('elt')
            self.emit('\n{} = tfr.get_element {}[{}]'.format(
                elt_name, name, idx_name))
            self._emit_with_loc(' : (!tfr.tensor_list, index) -> !tfr.tensor',
                                node)
            self.symbol_table.insert_symbol(key, elt_name, TFRTypes.TENSOR)
        else:
          # Assign a single value to multiple targets. This single value is
          # usually a function return, whose type annotation carries the tuple
          # of result types.
          for idx, key in enumerate(targets):
            ssa_name = '{}#{}'.format(name, idx)
            ssa_type = tys[idx]
            self.symbol_table.insert_symbol(key, ssa_name, ssa_type)
      elif len(targets) == 1:
        ssa_names = [n for n, _ in values]
        self.symbol_table.insert_symbol(targets[0], ssa_names, ty)
      return

    ty = self._get_inferred_type(node.value, values[1])
    self.symbol_table.insert_symbol(targets[0], values[0], ty)

  def _emit_binary_op(self, op, lhs, lhs_ty, rhs, rhs_ty):
    assert lhs_ty and rhs_ty
    if isinstance(op, ast.Sub):
      code = 'arith.sub'
    elif isinstance(op, ast.Add):
      code = 'arith.add'
    elif isinstance(op, ast.Mult):
      code = 'arith.mul'
    elif isinstance(op, ast.Div):
      code = 'arith.div'
    else:
      raise NotImplementedError('BinOp operator not recognized: ' + str(op))

    if lhs_ty == TFRTypes.I64 or lhs_ty == TFRTypes.I32:
      suffix = 'i'
    elif lhs_ty == TFRTypes.F32:
      suffix = 'f'
    else:
      raise NotImplementedError('BinOp operand type not recognized: ' +
                                str(lhs_ty))

    ret = self._ssa_name(code)
    self._emit_with_loc(
        '\n{} = {}{} {}, {} : {}'.format(ret, code, suffix, lhs, rhs, lhs_ty),
        op)
    return ret, lhs_ty

  def visit_AugAssign(self, node):
    lhs, lhs_ty = self.visit(node.target)
    rhs, rhs_ty = self.visit(node.value)
    ret, ret_ty = self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty)
    self.symbol_table.insert_symbol(node.target.id, ret, ret_ty)

  def visit_BinOp(self, node):
    lhs, lhs_ty = self.visit(node.left)
    rhs, rhs_ty = self.visit(node.right)
    return self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty)

  def visit_BoolOp(self, node):
    values = [self.visit(value) for value in node.values]
    # TODO(fengliuai): Handle more ast node types.
    if isinstance(node.op, ast.Or):
      raise NotImplementedError('Or operator not recognized')
    elif isinstance(node.op, ast.And):
      raise NotImplementedError('And operator not recognized')

  def visit_Call(self, node):
    func_name, func_type = self.visit(node.func)
    func_type = self._get_inferred_type(node.func, func_type)
    if func_type == TFRTypes.AG_BUILTIN_FUNC:
      if func_name == 'if_stmt':
        cond, _ = self.visit(node.args[0])
        body, _ = self.visit(node.args[1])
        orelse, _ = self.visit(node.args[2])
        get_state, _ = self.visit(node.args[3])
        nouts = int(node.args[6].value)
        out_symbols = []
        # The out symbols are just a Tuple of names
        for out in node.args[5].elts[:nouts]:
          val, ty = self.symbol_table.lookup(out.value)
          out_symbols.append(out.value)
        return self._visit_if_stmt(cond, body, orelse, get_state, out_symbols,
                                   node)
      elif func_name == 'for_stmt':
        range_ = self._visit_iter(node.args[0])
        body, _ = self.visit(node.args[2])
        get_state, _ = self.visit(node.args[3])
        loop_carried = [out.value for out in node.args[5].elts]
        # TODO(fengliuai): opt is not used here.
        return self._visit_for_stmt(range_, body, get_state, loop_carried, node)
      elif func_name == 'Undefined':
        val = self._ssa_name(node.args[0].value)
        return (val, TFRTypes.AG_UNDEFINED_VAL)
      elif func_name == 'UndefinedReturnValue':
        val = self._ssa_name('return_val')
        return (val, TFRTypes.AG_UNDEFINED_VAL)

    if func_type == TFRTypes.TF_RAW_OP:
      return self._visit_tf_op(func_name, node.args, node.keywords, node)

    if func_type == TFRTypes.TFR_BUILTIN_FUNC:
      return self._visit_tfr_builtins(func_name, node.args, node)

    if func_type == types.FunctionType:
      return self._visit_tf_op(func_name, node.args, node.keywords, node)

    if func_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
      return (func_name, TFRTypes.TF_TENSOR_SHAPE_LIST)

    if func_type == TFRTypes.PY_BUILTIN_FUNC:
      if func_name == 'len':
        arg, ty = self.visit(node.args[0])
        ty = self._get_inferred_type(node.args[0], ty)
        if ty == TFRTypes.TF_TENSOR_SHAPE_LIST:
          len_value = self._ssa_name('len')
          self._emit_with_loc(
              '\n{} = shape.rank {} : !shape.shape -> !shape.size'.format(
                  len_value, arg), node)
          size_value = self._ssa_name('len_size')
          self._emit_with_loc(
              '\n{} = shape.size_to_index {} : !shape.size'.format(
                  size_value, len_value), node)
        elif ty == TFRTypes.TENSOR_LIST:
          size_value = self._ssa_name('len')
          self._emit_with_loc(
              '\n{} = tfr.get_length {} -> index'.format(size_value, arg), node)
        return (size_value, TFRTypes.INDEX)

    raise NotImplementedError('call operator not recognized: {} {}'.format(
        func_name, func_type))

  def visit_Compare(self, node):
    lhs, lhs_ty = self.visit(node.left)
    for op, right in zip(node.ops, node.comparators):
      rhs, rhs_ty = self.visit(right)
      if isinstance(op, ast.Eq):
        pred = 'eq'
      elif isinstance(op, ast.Lt):
        pred = 'ult'
      elif isinstance(op, ast.LtE):
        pred = 'ule'
      elif isinstance(op, ast.Gt):
        pred = 'ugt'
      elif isinstance(op, ast.GtE):
        pred = 'uge'
      elif isinstance(op, ast.NotEq):
        pred = 'ne'
      else:
        raise NotImplementedError('Compare operator not recognized')

      ret = self._ssa_name(pred)
      if lhs_ty == TFRTypes.ATTR:
        self._emit_with_loc(
            '\n{} = tfr.equal {}, {} -> i1'.format(ret, lhs, rhs), node)
      else:
        if lhs_ty == TFRTypes.I64:
          code = 'arith.cmpi'
        elif lhs_ty == TFRTypes.F32:
          code = 'arith.cmpf'
        elif lhs_ty == TFRTypes.INDEX:
          code = 'arith.cmpi'
          # TODO(fengliuai): the reverse type inference should solve the issue.
          rhs, _ = self._i64_to_index(rhs, rhs_ty)
        else:
          raise NotImplementedError('Compare operand type not recognized')
        self._emit_with_loc(
            '\n{} = {} "{}", {}, {} : {}'.format(ret, code, pred, lhs, rhs,
                                                 lhs_ty), node)

      return ret, TFRTypes.I1

  def visit_Constant(self, node):
    cst_name = self._ssa_name('cst')
    if node.value is None:
      cst_ty = TFRTypes.NONE
    elif isinstance(node.value, bool):
      cst_ty = self._get_inferred_type(node)
      cst_val = str(node.value).lower()
      self._emit_with_loc('\n{} = arith.constant {}'.format(cst_name, cst_val),
                          node)
    else:
      cst_ty = self._get_inferred_type(node)
      cst_val = node.value
      if cst_ty == TFRTypes.ATTR:
        self._emit_with_loc(
            '\n{} = tfr.constant "{}" -> {}'.format(cst_name, cst_val, cst_ty),
            node)
      else:
        self._emit_with_loc(
            '\n{} = arith.constant {} : {}'.format(cst_name, cst_val, cst_ty),
            node)
    return cst_name, cst_ty

  def visit_FunctionDef(self, node):
    op_def, derived_attrs = self._op_defs.lookup(node.name, node, True)
    if op_def is None:
      # Nested function. Insert it to symbol table for looking up later.
      self.symbol_table.insert_symbol(node.name, node, None)
      return
    op_name = op_def.name
    if self.symbol_table.lookup(op_name):
      raise LookupError('Composition has already been registered for op: ' +
                        op_name)
    else:
      self.symbol_table.insert_symbol(node.name, None, None)

    self.symbol_table.enter_scope()
    self.emit('\ntfr.func @tf__{0}('.format(_camel_to_snake(op_name)))

    arg_list = []
    idx = 0
    max_idx = len(op_def.input_arg) + len(op_def.attr)
    for arg in node.args.args:
      arg_name = self._ssa_name(anno.getanno(arg, anno.Basic.QN))
      arg_type = anno.getanno(arg, anno.Static.TYPES)[0]

      arg_attr = ''
      if idx >= len(op_def.input_arg):
        attr_def = op_def.attr[idx - len(op_def.input_arg)]
        # skip the derived attributes
        while attr_def.name in derived_attrs and (idx + 1) < max_idx:
          idx += 1
          attr_def = op_def.attr[idx - len(op_def.input_arg)]
        if idx >= max_idx:
          raise ValueError('Argument is not defined in OpDef: ' + arg_name)

        arg_attr += '{{tfr.name="{}"'.format(attr_def.name)
        if attr_def.HasField('default_value'):
          default_val = _get_val_from_proto(arg_type, attr_def.default_value)
          arg_attr += ',tfr.default={}'.format(default_val)
        arg_attr += '}'

      idx += 1
      arg_str = '{}: {}{}'.format(arg_name, arg_type, arg_attr)
      arg_list.append(arg_str)
      self.symbol_table.insert_symbol(arg.id, arg_name, arg_type)

    ret_type_list = []
    for ret_def in op_def.output_arg:
      if ret_def.number_attr or ret_def.type_list_attr:
        ret_type_list.append(str(TFRTypes.TENSOR_LIST))
      else:
        ret_type_list.append(str(TFRTypes.TENSOR))

    self.emit('{}) -> ({}) {{'.format(', '.join(arg_list),
                                      ', '.join(ret_type_list)))
    self.visit_block(node.body)
    self._emit_with_loc('\n}', node)
    self.symbol_table.exit_scope()

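  # Sketch of the function header emitted above, for a hypothetical composite
  # with one tensor input and one defaulted int attribute (names and the line
  # wrapping are illustrative; the generator emits it on one line):
  #
  #   tfr.func @tf__my_op(%x: !tfr.tensor,
  #       %axis: i64{tfr.name="axis",tfr.default=0}) -> (!tfr.tensor) {
  #     ...
  #   }
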
  def visit_arguments(self, node):
    # TODO(fengliuai): return the ordered types and names.
    # We need to order the arguments to match the assumption in the TFR
    # dialect.
    raise NotImplementedError('arguments not supported.')

  def visit_Lambda(self, node):
    raise NotImplementedError('Lambda not supported.')

  def _get_mlir_ssa_values(self, name_prefix, out_types):
    """Create MLIR convention SSA values."""
    out_ssa_values = []
    if not out_types:
      return '', out_ssa_values

    out_name = self._ssa_name(name_prefix)
    if len(out_types) == 1:
      out_name_suffix = ''
      out_ssa_values.append(out_name)
    else:
      # For multiple returns, MLIR uses '%s:i' when they are defined and
      # '%s#i' when they are used.
      out_name_suffix = ':{}'.format(len(out_types))
      for idx, _ in enumerate(out_types):
        out_ssa_values.append('{}#{}'.format(out_name, idx))

    return '{}{}'.format(out_name, out_name_suffix), out_ssa_values

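  # For two results this returns roughly ('%if_stmt:2', ['%if_stmt#0',
  # '%if_stmt#1']): the first string names the results at the definition, the
  # list holds the names used at each use site (derived from the code above).
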
  def _visit_if_stmt(self, cond, body_def, orelse_def, get_state, out_symbols,
                     node):
    self.emit('\n')
    ret_str, ret_ssa_values = self._get_mlir_ssa_values(
        'if_stmt', [TFRTypes.TENSOR] * len(out_symbols))
    if ret_ssa_values:
      self.emit(ret_str + ' = ')

    out_types = []
    for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
      out_types.append(str(TFRTypes.TENSOR))

    self.emit('scf.if {} -> ({}) {{'.format(cond, ', '.join(out_types)))
    # Create a new scope so that the local variables don't leak.
    self.symbol_table.enter_scope(scf_scope=True)
    self.visit_block(body_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()

    self.emit('\n} else {')

    # Create a new scope so that the local variables don't leak.
    self.symbol_table.enter_scope(scf_scope=True)
    self.visit_block(orelse_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()

    # add ssa values to the symbol table
    for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
      self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR)

    self._emit_with_loc('\n}', node)
    return list(zip(ret_ssa_values, out_types))

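  # The emitted control flow looks roughly like this sketch for two outputs
  # (scf.yield comes from visiting the branch bodies in scf scope):
  #
  #   %if_stmt:2 = scf.if %cond -> (!tfr.tensor, !tfr.tensor) {
  #     ...body...
  #     scf.yield ...
  #   } else {
  #     ...orelse...
  #     scf.yield ...
  #   }
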
  def _visit_iter(self, node):
    if isinstance(node, ast.Call):
      f_name = anno.getanno(node.func, anno.Basic.QN)
      if f_name == QN('range'):
        args = [self.visit(arg) for arg in node.args]
        begin = None
        step = None
        end = None
        if len(args) == 1:
          end, end_ty = args[0]
        elif len(args) == 2:
          begin, begin_ty = args[0]
          end, end_ty = args[1]
        elif len(args) == 3:
          begin, begin_ty = args[0]
          end, end_ty = args[1]
          step, step_ty = args[2]

        if begin is None:
          begin = self._ssa_name('begin')
          self._emit_with_loc('\n{} = arith.constant 0 : index'.format(begin),
                              node)
        elif begin_ty != TFRTypes.INDEX:
          begin_ = self._ssa_name('begin')
          self._emit_with_loc(
              '\n{} = arith.index_cast {} : {} to index'.format(
                  begin_, begin, begin_ty), node)
          begin = begin_

        if end_ty != TFRTypes.INDEX:
          end_ = self._ssa_name('end')
          self._emit_with_loc(
              '\n{} = arith.index_cast {} : {} to index'.format(
                  end_, end, end_ty), node)
          end = end_

        if step is None:
          step = self._ssa_name('step')
          self._emit_with_loc('\n{} = arith.constant 1 : index'.format(step),
                              node)
        elif step_ty != TFRTypes.INDEX:
          step_ = self._ssa_name('step')
          self._emit_with_loc(
              '\n{} = arith.index_cast {} : {} to index'.format(
                  step_, step, step_ty), node)
          step = step_

        return begin, end, step

    raise NotImplementedError('Iterator entity not supported: ' + str(node))

  def _visit_for_stmt(self, range_, body_def, get_state, loop_carried, node):
    self.emit('\n')
    ret_str, ret_ssa_values = self._get_mlir_ssa_values(
        'for_stmt', [TFRTypes.TENSOR] * len(loop_carried))
    if ret_ssa_values:
      self.emit(ret_str + ' = ')

    # Before entering the loop, we use the original ssa values as the initial
    # values of the loop iteration arguments. We also create new ssa values
    # for the results of the scf.for statement. The symbol table needs to be
    # updated to these new ssa values before it enters the scope of the loop.
    out_types = []
    init_values = []
    for symbol, ssa_value in zip(loop_carried, ret_ssa_values):
      init, ty = self.symbol_table.lookup(symbol)
      self.symbol_table.insert_symbol(symbol, ssa_value, ty)
      out_types.append(str(ty))
      init_values.append((init, ty))

    # Create a new scope so that the local variables don't leak.
    self.symbol_table.enter_scope(scf_scope=True)

    # Create the iteration variable with index type
    assert len(body_def.args.args) == 1
    it_name = body_def.args.args[0].id
    it = self._ssa_name(it_name)
    self.symbol_table.insert_symbol(it_name, it, TFRTypes.INDEX)

    self.emit('scf.for {} = {} to {} step {} '.format(it, range_[0], range_[1],
                                                      range_[2]))
    if loop_carried:
      iter_args = []
      for symbol, init in zip(loop_carried, init_values):
        # create new ssa values for the loop carried variables
        it_arg = self._ssa_name('it_arg')
        self.symbol_table.insert_symbol(symbol, it_arg, init[1])
        iter_args.append('{} = {}'.format(it_arg, init[0]))
      self.emit('iter_args({}) '.format(', '.join(iter_args)))
      self.emit('-> ({}) {{'.format(', '.join(out_types)))
    else:
      self.emit(' {')
    self.visit_block(body_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()
    self._emit_with_loc('\n}', node)
    return list(zip(ret_ssa_values, out_types))

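  # The emitted loop looks roughly like this sketch with one loop-carried
  # tensor (SSA names illustrative):
  #
  #   %for_stmt = scf.for %i = %begin to %end step %step
  #       iter_args(%it_arg = %init) -> (!tfr.tensor) {
  #     ...body...
  #     scf.yield ...
  #   }
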
  def _emit_default_constant_from_proto(self, attr_def):
    """Emit an mlir constant statement from the default value of an AttrDef."""
    name = self._ssa_name('cst')
    cst_ty = _get_type_from_proto(None, attr_def)
    try:
      cst_val = _get_val_from_proto(cst_ty, attr_def.default_value)
    except AttributeError:
      raise AttributeError(
          f'attribute "{attr_def.name}" does not have default_value. If the '
          "attribute names from the API and OpDef don't match, please add it "
          'to _ATTRIBUTE_RENAMES.')
    if cst_ty == TFRTypes.ATTR:
      self._emit_with_loc('\n{} = tfr.constant {} -> {}'.format(
          name, cst_val, cst_ty))
    elif cst_ty == TFRTypes.I1:
      self._emit_with_loc('\n{} = arith.constant {}'.format(name, cst_val))
    else:
      self._emit_with_loc('\n{} = arith.constant {} : {}'.format(
          name, cst_val, cst_ty))
    return name, cst_ty

  def visit_keyword(self, node):
    return node.arg, self.visit(node.value)

  def _visit_tfr_builtins(self, op_name, args, node):
    arg_strs = []
    arg_tys = []
    for arg in args:
      value, ty = self.visit(arg)
      arg_strs.append(value)
      arg_tys.append(ty)
    tfr_op_name = 'tfr.' + op_name[5:]
    ret_tys = (
        TFR_BUILTINS[op_name](*arg_tys)
        if callable(TFR_BUILTINS[op_name]) else TFR_BUILTINS[op_name])
    # Convert the tfr builtin returns to a list.
    if isinstance(ret_tys, tuple):
      ret_tys = list(ret_tys)
    else:
      ret_tys = [ret_tys]

    ret_str, ret_ssa_values = self._get_mlir_ssa_values(op_name, ret_tys)

    arg_str = ', '.join(arg_strs)
    arg_ty_str = ', '.join(str(ty) for ty in arg_tys)
    ret_ty_str = ', '.join(str(ty) for ty in ret_tys)
    self._emit_with_loc('\n{} = {}({}) : ({}) -> ({})'.format(
        ret_str, tfr_op_name, arg_str, arg_ty_str, ret_ty_str), node)
    return list(zip(ret_ssa_values, ret_tys))

  def _visit_tf_op(self, op_name, args, keywords, node):
    op_def, derived_attrs = self._op_defs.lookup(op_name)
    ret_tys = [_get_type_from_proto(arg) for arg in op_def.output_arg]

    ret_str, ret_ssa_values = self._get_mlir_ssa_values(op_name, ret_tys)

    arg_strs = []
    ty_strs = []
    for arg in args:
      value, ty = self.visit(arg)
      arg_strs.append(value)
      ty_strs.append(str(ty))

    input_args = [arg for arg in op_def.input_arg]
    attrs_no_default = [
        attr for attr in op_def.attr
        if not attr.HasField('default_value') and attr.name not in derived_attrs
    ]
    attrs_with_default = [
        attr for attr in op_def.attr
        if attr.HasField('default_value') and attr.name not in derived_attrs
    ]

    kw_args = {}
    for arg in keywords:
      value, (ssa_name, ty) = self.visit(arg)
      ty = self._get_inferred_type(arg.value, ty)

      # TODO(b/203493652): see comment on _ATTRIBUTE_RENAMES
      if op_name in _ATTRIBUTE_RENAMES and value in _ATTRIBUTE_RENAMES[op_name]:
        value = _ATTRIBUTE_RENAMES[op_name][value]

      kw_args[value] = (ssa_name, ty)

    # tensor arguments and attribute arguments
    ordered_args = input_args + attrs_no_default + attrs_with_default
    for attr_def in ordered_args[len(args):]:
      if attr_def.name in kw_args:
        value, ty = kw_args[attr_def.name]
        if attr_def in input_args:
          if ty in _ATTRIBUTE_TYPES:
            # An attribute-typed value can't be passed to the tf op call
            # directly, so box it into a tensor.
            value, ty = self._value_to_tensor(value, ty, node)
          if ty is TFRTypes.TENSOR_LIST and not _require_tensor_list(attr_def):
            value, ty = self._pack_tensor_list(value)
      else:
        value, ty = self._emit_default_constant_from_proto(attr_def)
      arg_strs.append(value)
      ty_strs.append(str(ty))

    if ret_ssa_values:
      self.emit('\n{} = '.format(ret_str))

    self.emit('tfr.call @tf__{}('.format(_camel_to_snake(op_name)))
    arg_str = ', '.join(arg_strs)
    arg_ty_str = ', '.join(ty_strs)
    ret_ty_str = ', '.join([str(ty) for ty in ret_tys])
    self._emit_with_loc(
        '{}) : ({}) -> ({})'.format(arg_str, arg_ty_str, ret_ty_str), node)
    return list(zip(ret_ssa_values, ret_tys))

  def visit_If(self, node):
    raise NotImplementedError('If not supported.')

  def visit_Name(self, node):
    val_and_lookup_type = self.symbol_table.lookup(node.id)
    if val_and_lookup_type:
      (val, lookup_type) = val_and_lookup_type
    elif node.id in TFR_BUILTINS:
      val = node.id
      lookup_type = anno.getanno(node, anno.Static.TYPES, types.FunctionType)
    else:
      op_def, _ = self._op_defs.lookup(node.id)
      val = op_def.name
      lookup_type = anno.getanno(node, anno.Static.TYPES, types.FunctionType)
    type_ = self._get_inferred_type(node, lookup_type)
    return val, type_

  def visit_Return(self, node):
    values = self.visit(node.value)
    if self.symbol_table.in_scf_scope():
      self.emit('\nscf.yield ')
    else:
      self.emit('\ntfr.return ')
    if not values:
      return

    if isinstance(values, list):
      vals, tys = zip(*values)
    else:
      vals = values[0]
      tys = values[1]

    if isinstance(tys, list) or isinstance(tys, tuple):
      tys = [str(t) for t in tys]
      self._emit_with_loc('{} : {}'.format(', '.join(vals), ', '.join(tys)),
                          node)
    elif tys != TFRTypes.NONE:
      # TODO(fengliuai): scf region yield uses this branch. Fix it.
      self._emit_with_loc('{} : {}'.format(vals, tys), node)

  def visit_Subscript(self, node):
    val, ty = self.visit(node.value)
    type_ = self._get_inferred_type(node.value, ty)

    # TODO(fengliuai): we hardcode the handling of node.slice here to get the
    # index type. Use the visit method once the type inference is done.
    # slice_val, slice_ty = self.visit(node.slice)
    s = node.slice
    if not isinstance(s, (ast.Tuple, ast.Slice)):
      if isinstance(s, ast.Constant):
        # TODO(fengliuai): promote to an assignment
        idx_val = self._ssa_name('cst')
        self._emit_with_loc(
            '\n{} = arith.constant {} : index'.format(idx_val, s.value), node)
      else:
        idx_val, _ = self.visit(s)
    else:
      raise NotImplementedError('non-index slice not supported.')

    elt = self._ssa_name('elt')
    if type_ == TFRTypes.TENSOR_LIST:
      self.emit('\n{} = tfr.get_element {}[{}] '.format(elt, val, idx_val))
      self._emit_with_loc(': (!tfr.tensor_list, index) -> !tfr.tensor', node)
      return (elt, TFRTypes.TENSOR)
    elif type_ == TFRTypes.TF_TENSOR_SHAPE_LIST:
      size_ = self._ssa_name('size')
      self.emit('\n{} = shape.get_extent {}, {}'.format(size_, val, idx_val))
      self._emit_with_loc(': !shape.shape, index -> !shape.size', node)
      self._emit_with_loc(
          '\n{} = shape.size_to_index {} : !shape.size'.format(elt, size_),
          node)
      return (elt, TFRTypes.INDEX)

  def visit_List(self, node):
    out_type = self._get_inferred_type(node)
    vals = []
    tys = []
    for elt in node.elts:
      val, ty = self.visit(elt)
      ty = self._get_inferred_type(elt, ty)
      if ty in _ATTRIBUTE_TYPES and out_type == TFRTypes.TENSOR_LIST:
        # This list is a tensor list, so cast all the input values to tensors.
        val, ty = self._value_to_tensor(val, ty, node)
      else:
        # We shouldn't use the index type to build the list because the list
        # will be used as an attribute.
        val, ty = self._index_to_I64(val, ty)
      vals.append(val)
      tys.append(str(ty))

    list_val = self._ssa_name('list')
    self.emit('\n{} = "tfr.build_list"({})'.format(list_val, ', '.join(vals)))
    self._emit_with_loc(' : ({}) -> {}'.format(', '.join(tys), out_type), node)
    return (list_val, out_type)

  def visit_Tuple(self, node):
    return [self.visit(elt) for elt in node.elts]

  def visit_UnaryOp(self, node):
    value, ty = self.visit(node.operand)
    if isinstance(node.op, ast.USub):
      zero_value = self._ssa_name('zero')
      ssa_value = self._ssa_name('cst')
      if ty == TFRTypes.I32 or ty == TFRTypes.I64:
        self._emit_with_loc(
            '\n{} = arith.constant 0 : {}'.format(zero_value, ty), node)
        self._emit_with_loc(
            '\n{} = arith.subi {}, {} : {}'.format(ssa_value, zero_value, value,
                                                   ty), node)
      elif ty == TFRTypes.F32:
        self._emit_with_loc(
            '\n{} = arith.constant 0.0 : {}'.format(zero_value, ty), node)
        self._emit_with_loc(
            '\n{} = arith.subf {}, {} : {}'.format(ssa_value, zero_value, value,
                                                   ty), node)
      else:
        raise NotImplementedError('USub type not recognized: ' + str(ty))
      return ssa_value, ty
    raise NotImplementedError('USub operator not recognized')

  def visit_For(self, node):
    raise NotImplementedError('For operator not recognized')

  def visit_While(self, node):
    raise NotImplementedError('While operator not recognized')

  def visit_Try(self, node):
    # Only handles the body of the try statement.
    self.visit_block(node.body)


def _apply_py_to_tf_passes(node, ctx):
  """Apply transformations from PyToTF to match tf.function tracing."""
  # TODO(fengliuai): we don't know which passes are required, thus we evaluate
  # each one when the corresponding node is handled.
  # copied from PyToTF.transform_ast
  node = return_statements.transform(node, ctx, False)
  node = control_flow.transform(node, ctx)
  return node


class TfrGen(transpiler.GenericTranspiler):
  """Transforms Python objects into TFR MLIR source code."""

  def __init__(self, op_defs):
    self._op_defs = op_defs

  def transform_ast(self, node, ctx):
    node = _apply_py_to_tf_passes(node, ctx)
    # TODO(mdan): Enable this.
    # node = anf.transform(node, ctx)

    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx)
    node = reaching_definitions.resolve(node, ctx, graphs)
    node = reaching_fndefs.resolve(node, ctx, graphs)
    node = type_inference.resolve(node, ctx, graphs,
                                  TFRTypeResolver(self._op_defs))

    mlir_generator = TFRGen(ctx, self._op_defs)
    mlir_generator.visit(node)
    return mlir_generator.code_buffer


def tfr_gen(func, op_defs):
  """Parse a function and emit the TFR functions."""
  mlir_code, _ = TfrGen(op_defs).transform(func, None)
  assert tfr.verify(mlir_code), 'mlir code not verified: {}'.format(mlir_code)
  return mlir_code


def tfr_funcs_gen_from_module(source, op_defs, method_prefix=None,
                              op_libraries=None):
  """Parse the input source module and emit the TFR functions."""

  # Load the op library so the op is added to the op registry. This is
  # required when the op cc_library can't be statically linked in open
  # source.
  # This is a no-op if the op shared library can't be found in the same
  # directory as the op Python API.
  # TODO(fengliuai): make the .so file path configurable.
  if op_libraries:
    prefix_len = len('gen_')
    for m in op_libraries:
      lib_dir = os.path.dirname(m.__file__)
      lib_name = os.path.basename(m.__file__)[prefix_len:].replace('.py', '.so')
      lib_path = os.path.join(lib_dir, lib_name)
      if os.path.exists(lib_path):
        logging.info('load file: ' + lib_path)
        load_library.load_op_library(lib_path)
  else:
    # The op library is generated from the source module, so we load all the
    # .so files in its directory.
    lib_dir = os.path.dirname(source.__file__)
    for lib_name in os.listdir(lib_dir):
      if lib_name.endswith('.so'):
        lib_path = os.path.join(lib_dir, lib_name)
        logging.info('load file: ' + lib_path)
        load_library.load_op_library(lib_path)

  py_funcs = [
      func
      for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction)
      if not method_prefix or name.startswith(method_prefix)
  ]
  # Sort the methods by the line number, to make sure the definitions are
  # processed before the usages.
  # TODO(fengliuai): Use type inference resolver to recursively process any
  # functions called.
  py_funcs = sorted(py_funcs, key=lambda x: x.__code__.co_firstlineno)
  mlir_funcs = [tfr_gen(func, op_defs) for func in py_funcs]

  return mlir_funcs


def tfr_gen_from_module(source, method_prefix=None, op_libraries=None,
                        op_defs=OpDefCache()):
  """Parse the input source module and emit the TFR and external functions."""
  mlir_funcs = tfr_funcs_gen_from_module(
      source, op_defs, method_prefix, op_libraries)

  return '\n'.join(mlir_funcs + op_defs.mlir_external_funcs())
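
# A minimal usage sketch (module and function names hypothetical): given a
# module whose compositions are declared with the Composite decorator checked
# for in OpDefCache.lookup, e.g.
#
#   @Composite('MyAddN', ...)
#   def _composite_my_add_n(values, N):
#     ...
#
# the MLIR for the whole module can be generated with:
#
#   mlir_code = tfr_gen_from_module(my_composite_module,
#                                   method_prefix='_composite_')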