# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tfr_gen: Generate mlir tfr decomposition function from python code."""

# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import os
import re
import types
import gast as ast

from tensorflow.compiler.mlir.tfr import tfr_wrapper as tfr
from tensorflow.core.framework import types_pb2
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
from tensorflow.python.autograph.pyct.static_analysis import type_inference
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import load_library
from tensorflow.python.framework import op_def_registry
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect

# TODO(mdan): Use class definitions so that we can mix these with Python types.


class TFRTypes(enum.Enum):
  """All the supported types.

    1-3: tfr types
    4-99: mlir built-in types
    100-199: TF related translator internal types
    200- : Python related translator internal types
  """
  TENSOR = 1
  TENSOR_LIST = 2
  ATTR = 3
  NONE = 4
  SHAPE = 5  # shape -> !shape.shape
  I1 = 21
  I32 = 22
  I64 = 23
  F32 = 24
  INDEX = 25
  AG_UNDEFINED_VAL = 100
  AG_BUILTIN_FUNC = 101
  TF_RAW_OP = 102
  TF_REGION = 103
  TF_TENSOR_SHAPE_FUNC = 104  # shape.as_list
  TF_TENSOR_SHAPE_LIST = 105  # shape.as_list()
  PY_BUILTIN_FUNC = 200

  # As these are not real types, __getattribute__ helps them appear more like
  # actual types (i.e. class definitions).
  def __getattribute__(self, name):
    if name == 'shape' and object.__getattribute__(self, 'value') == 1:
      return TFRTypes.SHAPE
    if name == 'as_list' and object.__getattribute__(self, 'value') == 5:
      return TFRTypes.TF_TENSOR_SHAPE_FUNC
    return object.__getattribute__(self, name)

  def __str__(self):
    if self.value < 4:  # pylint: disable=comparison-with-callable
      return '!tfr.' + self.name.lower()
    elif self.value < 10:  # pylint: disable=comparison-with-callable
      return '!shape.' + self.name.lower()
    else:
      return self.name.lower()

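# Illustrative examples (not part of the original module): __str__ renders
# the enum members using MLIR type syntax, and __getattribute__ makes the
# TENSOR and SHAPE members behave a little like class definitions:
#
#   str(TFRTypes.TENSOR)     # -> '!tfr.tensor'   (value 1 < 4)
#   str(TFRTypes.SHAPE)      # -> '!shape.shape'  (value 5 < 10)
#   str(TFRTypes.I64)        # -> 'i64'
#   TFRTypes.TENSOR.shape    # -> TFRTypes.SHAPE
#   TFRTypes.SHAPE.as_list   # -> TFRTypes.TF_TENSOR_SHAPE_FUNC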

_attribute_types = [
    TFRTypes.I1, TFRTypes.I32, TFRTypes.I64, TFRTypes.F32, TFRTypes.INDEX,
    TFRTypes.ATTR
]


def _get_type_from_proto(arg_def=None, attr_def=None):
  if not arg_def:
    if attr_def.type == 'bool':
      return TFRTypes.I1
    elif attr_def.type == 'int32':
      return TFRTypes.I32
    elif attr_def.type == 'int' or attr_def.type == 'int64':
      return TFRTypes.I64
    elif attr_def.type == 'float':
      return TFRTypes.F32
    else:
      return TFRTypes.ATTR

  if arg_def.number_attr or arg_def.type_list_attr:
    return TFRTypes.TENSOR_LIST
  else:
    return TFRTypes.TENSOR


def _get_type_info_from_proto(arg_def=None, attr_def=None):
  attr_type = _get_type_from_proto(arg_def, attr_def)
  if not arg_def:
    return '{}{{tfr.name="{}",tfr.type="{}"}}'.format(
        attr_type, attr_def.name, attr_def.type)
  else:
    attr_names = []
    if arg_def.number_attr:
      attr_names.append(arg_def.number_attr)
    if arg_def.type_attr:
      attr_names.append(arg_def.type_attr)
    if arg_def.type_list_attr:
      attr_names.append(arg_def.type_list_attr)

    # TODO(fengliuai): currently we don't support backward type inference, so
    # we have to store these non-derivable types in the signatures; they can
    # then be used to cast the values when raising to tf ops.
    if arg_def.type == types_pb2.DT_FLOAT:
      attr_names.append('f32_')
    elif arg_def.type == types_pb2.DT_INT32:
      attr_names.append('i32_')
    elif arg_def.type == types_pb2.DT_INT64:
      attr_names.append('i64_')
    elif arg_def.type == types_pb2.DT_BOOL:
      attr_names.append('i1_')

    if not attr_names:
      return str(attr_type)
    else:
      return '{}<{}>'.format(attr_type, ','.join(attr_names))

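# Illustrative examples (not part of the original module) of the type-info
# strings produced above, for hypothetical OpDef fields:
#
#   - an attr 'axis' of proto type 'int' (arg_def=None) renders as
#     'i64{tfr.name="axis",tfr.type="int"}'
#   - an input arg with type_attr='T' renders as '!tfr.tensor<T>'
#   - an input arg with number_attr='N' and type_attr='T' renders as
#     '!tfr.tensor_list<N,T>'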

def _get_val_from_proto(attr_type, attr_val):
  if attr_type == TFRTypes.I1:
    return 'true' if attr_val.b else 'false'
  elif attr_type == TFRTypes.I32 or attr_type == TFRTypes.I64:
    return attr_val.i
  elif attr_type == TFRTypes.F32:
    return attr_val.f
  elif attr_type == TFRTypes.ATTR:
    # string
    if attr_val.HasField('s'):
      return '"{}"'.format(attr_val.s.decode())
    # type
    if attr_val.HasField('type'):
      if attr_val.type == types_pb2.DT_FLOAT:
        return 'f32'
      elif attr_val.type == types_pb2.DT_INT32:
        return 'i32'
      elif attr_val.type == types_pb2.DT_INT64:
        return 'i64'
      elif attr_val.type == types_pb2.DT_BOOL:
        return 'i1'
    # list
    if attr_val.HasField('list'):
      if attr_val.list.f:
        elt_ty = TFRTypes.F32
        values = attr_val.list.f
      elif attr_val.list.i:
        elt_ty = TFRTypes.I64
        values = attr_val.list.i
      else:
        elt_ty = TFRTypes.NONE
        values = []
      array_attr_elts = ['{}:{}'.format(val, elt_ty) for val in values]
      return '[{}]'.format(','.join(array_attr_elts))
  raise NotImplementedError(
      'Proto AttrValue not recognized. type: {}, value: {}'.format(
          attr_type, attr_val))


def _collect_derived_attrs_from_proto(op_def):
  derived_attrs = set()
  for arg in op_def.input_arg:
    if arg.type_attr:
      derived_attrs.add(arg.type_attr)
    if arg.number_attr:
      derived_attrs.add(arg.number_attr)
    if arg.type_list_attr:
      derived_attrs.add(arg.type_list_attr)

    # TODO(fengliuai): currently we don't support backward type inference, so
    # we have to store these non-derivable types in the signatures; they can
    # then be used to cast the values when raising to tf ops.
    if arg.type == types_pb2.DT_FLOAT:
      derived_attrs.add('f32_')
    elif arg.type == types_pb2.DT_INT32:
      derived_attrs.add('i32_')
    elif arg.type == types_pb2.DT_INT64:
      derived_attrs.add('i64_')
    elif arg.type == types_pb2.DT_BOOL:
      derived_attrs.add('i1_')
  return derived_attrs


def _require_tensor_list(arg_def):
  return arg_def.type_list_attr or arg_def.number_attr


def _camel_to_snake(name):
  s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
  return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()

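# For illustration (not part of the original module):
#
#   _camel_to_snake('MatMul')  # -> 'mat_mul'
#   _camel_to_snake('Conv2D')  # -> 'conv2_d'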

class OpDefCache(object):
  """A dict that caches the OpDef for each Python function name."""

  def __init__(self):
    self._op_defs = {}

  def lookup(self, f_name, func_def=None, optional=False):
    if f_name in self._op_defs.keys():
      return self._op_defs[f_name]

    if isinstance(func_def, types.FunctionType):
      if not hasattr(func_def, '_tfr_op_name'):
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      op_name = getattr(func_def, '_tfr_op_name')
    elif not func_def:
      op_name = f_name
    else:
      # TODO(fengliuai): create one utility method to match different APIs.
      compose_dec = []
      for dec in func_def.decorator_list:
        if isinstance(dec, ast.Call):
          if isinstance(dec.func,
                        ast.Attribute) and dec.func.attr == 'Composite':
            compose_dec.append(dec)
          if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite':
            compose_dec.append(dec)

      if not compose_dec:
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      elif len(compose_dec) > 1:
        raise KeyError('More than one TF op decomposition defined: ' + f_name)
      else:
        op_name = compose_dec[0].args[0].value

    op_def = op_def_registry.get(op_name)
    if not op_def:
      raise ValueError('Not a registered op: ' + op_name)
    derived_attrs = _collect_derived_attrs_from_proto(op_def)
    self._op_defs[f_name] = (op_def, derived_attrs)
    return (op_def, derived_attrs)

  def mlir_external_funcs(self):
    tfr_funcs = []
    for _, (op_def, derived_attrs) in sorted(self._op_defs.items()):
      tfr_func = '\ntfr.func @tf__{}_('.format(_camel_to_snake(op_def.name))

      # tensor inputs
      inputs = [
          _get_type_info_from_proto(arg_def) for arg_def in op_def.input_arg
      ]

      # attribute inputs. Attributes with default values are moved to the end.
      non_derived_attrs = [
          attr for attr in op_def.attr if attr.name not in derived_attrs
      ]
      attrs_no_default = [
          attr for attr in non_derived_attrs
          if not attr.HasField('default_value')
      ]
      attrs_with_default = [
          attr for attr in non_derived_attrs if attr.HasField('default_value')
      ]
      attr_names = set()
      for attr_def in attrs_no_default + attrs_with_default:
        inputs.append(_get_type_info_from_proto(None, attr_def))
        attr_names.add(attr_def.name)

      # tensor outputs
      outputs = [
          _get_type_info_from_proto(arg_def) for arg_def in op_def.output_arg
      ]

      inputs = ','.join(inputs)
      outputs = ','.join(outputs)
      attrs = ','.join(sorted(derived_attrs.union(attr_names)))
      tfr_funcs.append('{}{}) -> ({}) attributes {{{}}}'.format(
          tfr_func, inputs, outputs, attrs))
    return tfr_funcs

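# A sketch (not part of the original module) of one external function
# signature this produces, for a hypothetical op 'MyOp' with one input of
# type attr 'T', one attr 'axis' of proto type 'int', and one tensor output
# (emitted on a single line; wrapped here for readability):
#
#   tfr.func @tf__my_op_(!tfr.tensor<T>,
#       i64{tfr.name="axis",tfr.type="int"}) -> (!tfr.tensor<T>)
#       attributes {T,axis}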

_PY_TYPE_TO_TFR = {
    bool: TFRTypes.I1,
    int: TFRTypes.I64,
    float: TFRTypes.F32,
}

_TF_DTYPE_TO_TFR = {
    'bool': TFRTypes.I1,
    'int64': TFRTypes.I64,
    'int32': TFRTypes.I32,
    'float32': TFRTypes.F32,
}

_AG_FIXED_RETURN_TYPE = {
    'for_stmt': type(None),
    'if_stmt': type(None),
    'Undefined': TFRTypes.AG_UNDEFINED_VAL,
}

QN = qual_names.QN

# TODO(mdan): Fix this with an importable module.
AG_MODULE = api._TRANSPILER.get_extra_locals()['ag__']  # pylint:disable=protected-access


class TFRTypeResolver(type_inference.Resolver):
  """Resolve types for the external names, calls and arguments."""

  def __init__(self, op_defs):
    super(TFRTypeResolver, self).__init__()
    self._op_defs = op_defs

    # This pattern matching mechanism works with the functional form generated
    # by autograph:
    #
    #   for i in data:
    #     print(i)
    #
    # generates:
    #
    #   def loop_body(itr):
    #     i = itr
    #     print(i)
    #   ag__.for_stmt(target)
    #
    # The mechanism lets us infer the type of the itr argument based on that of
    # target.
    self._for_loop_target_types = {}  # Maps body function name to iterated.
    self._for_loop_body_fns = {}  # Used only to avoid collisions.

  def res_name(self, ns, types_ns, name):
    name_str = str(name)
    if name_str in ns:
      ns_val = ns[name_str]
      return {type(ns_val)}, ns_val
    if name_str in __builtins__:
      return {TFRTypes.PY_BUILTIN_FUNC}, __builtins__[name_str]
    # This name is not in the namespace because the autograph transformation
    # is not backloaded into Python.
    if name_str == 'ag__':
      return {type(AG_MODULE)}, AG_MODULE

    return None, None

  def res_value(self, ns, value):
    if value is None:
      return {TFRTypes.NONE}
    if value in (TFRTypes.SHAPE, TFRTypes.TF_TENSOR_SHAPE_FUNC):
      # See TFRTypes.__getattribute__.
      # TODO(mdan): Replacing the enum with classes would avoid this overlap.
      return {value}
    # TODO(mdan): Index more efficiently. Could do a name check instead.
    if any(v is value for v in AG_MODULE.__dict__.values()):
      return {TFRTypes.AG_BUILTIN_FUNC}
    if getattr(value, '__name__', None) == 'tensorflow.raw_ops':
      return {types.ModuleType}
    if hasattr(value, '__module__'):
      if isinstance(value, dtypes.DType):
        return {TFRTypes.ATTR}

      # All the imported operations, which are not autograph built-ins, are
      # considered to be TF raw ops.
      # TODO(fengliuai): refine the condition so we only match TensorFlow
      # ops here.
      return {TFRTypes.TF_RAW_OP}
    # TODO(mdan): Is ATTR equivalent to string?
    return {_PY_TYPE_TO_TFR.get(type(value), TFRTypes.ATTR)}

  def res_call(self, ns, types_ns, node, f_type, args, keywords):
    name = anno.Basic.QN.of(node.func)
    if f_type == (TFRTypes.AG_BUILTIN_FUNC,):

      if name == QN(QN('ag__'), attr='if_stmt'):
        nouts = node.args[6].value
        # TODO(mdan): Look at the actual types out of if_body.
        side_effects = {
            qual_names.QN(n.value): {TFRTypes.TENSOR}
            for n in node.args[5].elts[:nouts]
        }
        return {type(None)}, side_effects

      if name == QN(QN('ag__'), attr='for_stmt'):
        assert isinstance(node.args[2], ast.Name)
        body_fn_name = str(anno.Basic.QN.of(node.args[2]))
        assert body_fn_name not in self._for_loop_body_fns, (
            'Previously used here: {}. Are you reusing the Resolver across '
            'transformations?').format(self._for_loop_body_fns[body_fn_name])
        self._for_loop_body_fns[body_fn_name] = anno.Basic.ORIGIN.of(node)

        iterated_type = args[0]
        assert iterated_type & {
            TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, TFRTypes.ATTR
        }, (
            iterated_type)
        self._for_loop_target_types[body_fn_name] = iterated_type

        return {type(None)}, None

      # TODO(mdan): Actually resolve the type here instead.
      ret_type = _AG_FIXED_RETURN_TYPE.get(name.qn[1], None)
      if ret_type is not None:
        return {ret_type}, None
      raise NotImplementedError('return type of {}'.format(name))

    elif f_type == (TFRTypes.TF_RAW_OP,):
      op_name = name.qn[1]
      op_def, _ = self._op_defs.lookup(op_name)
      if len(op_def.output_arg) == 1:
        return {_get_type_from_proto(op_def.output_arg[0])}, None
      return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
              None)

    elif f_type == (types.FunctionType,):
      # A composition Python function name is used directly.
      op_name = name.qn[0]
      op_def, _ = self._op_defs.lookup(op_name)
      if len(op_def.output_arg) == 1:
        return {_get_type_from_proto(op_def.output_arg[0])}, None
      return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
              None)

    elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
      assert name.is_simple()
      if name == QN('range'):
        return {TFRTypes.ATTR}, None

      if name == QN('len'):
        return {TFRTypes.INDEX}, None

    elif f_type == (TFRTypes.TF_TENSOR_SHAPE_FUNC,):
      return {TFRTypes.TF_TENSOR_SHAPE_LIST}, None

    raise NotImplementedError('Function: {}, {}'.format(name, f_type))

  def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
    if f_is_local:
      f_name_str = str(f_name)
      if f_name_str in self._for_loop_target_types:
        # See autograph/converters/control_flow.py - the function has a single
        # argument, the iterate before any expansion.
        assert self._for_loop_target_types[f_name_str] & {TFRTypes.ATTR}
        # Assume all loops are TF loops. Then the iterates are autoboxed into
        # Tensors.
        return {TFRTypes.INDEX}
      else:
        return None

    func = ns[f_name]

    op_def, derived_attrs = self._op_defs.lookup(f_name, func)
    if op_def is None:
      return None
    pos = tf_inspect.getfullargspec(func).args.index(str(name))

    if pos < len(op_def.input_arg):
      arg_def = op_def.input_arg[pos]
      return {_get_type_from_proto(arg_def)}
    elif pos < len(op_def.input_arg) + len(op_def.attr) - len(derived_attrs):
      non_derived_attr_pos = pos - len(op_def.input_arg)
      for attr_def in op_def.attr:
        # derived attribute, skip this one and continue to the next one.
        if attr_def.name in derived_attrs:
          continue
        if non_derived_attr_pos == 0:
          return {_get_type_from_proto(None, attr_def)}
        non_derived_attr_pos -= 1

    raise ValueError('Argument is not defined in OpDef: ' + str(name))

  def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
    assert len(value) == 1
    value, = tuple(value)
    if value == TFRTypes.TF_TENSOR_SHAPE_LIST:
      # TODO(mdan): This is not entirely correct for multi-element slices.
      return {int}
    elif value in (TFRTypes.TENSOR_LIST, TFRTypes.TENSOR):
      # TODO(mdan): This is not entirely correct for multi-element slices.
      return {TFRTypes.TENSOR}
    raise NotImplementedError('slice of {}'.format(value))

  def res_compare(self, ns, types_ns, node, left, right):
    # TODO(fengliuai): make sure left and right are compatible
    return {TFRTypes.I1}

  def res_unop(self, ns, types_ns, node, opnd):
    return opnd

  def res_binop(self, ns, types_ns, node, left, right):
    # TODO(fengliuai): make sure left and right are compatible
    return left

  def _coerce_to_more_specific_type(self, elt_types):
    # TODO(mdan): This needs some type theory study.
    if TFRTypes.INDEX in elt_types:
      # Constants collapse to indices.
      elt_types.discard(TFRTypes.I64)
    if TFRTypes.TENSOR in elt_types:
      # Constants collapse to tensors.
      elt_types.discard(TFRTypes.I64)
      # Indices collapse to tensors.
      elt_types.discard(TFRTypes.INDEX)
    return elt_types

  def res_list_literal(self, ns, elt_types):
    all_elt_types = set()
    for t in elt_types:
      all_elt_types |= t

    if len(all_elt_types) != 1:
      all_elt_types = self._coerce_to_more_specific_type(all_elt_types)

    if len(all_elt_types) != 1:
      raise ValueError('ambiguous list element types: {}'.format(elt_types))

    if TFRTypes.TENSOR in all_elt_types:
      return {TFRTypes.TENSOR_LIST}
    return {TFRTypes.ATTR}


class SymbolTable(object):
  """Symbol Table for python code."""

  def __init__(self):
    self.symbols = []
    self.enter_scope()
    self.scf_scope = 0
    # reserved keywords
    self.insert_symbol('len', 'len', TFRTypes.PY_BUILTIN_FUNC)

  def enter_scope(self, scf_scope=False):
    """Enter a new scope - at function level."""
    self.symbols.append({'types': {}, 'symbols': {}})
    self.curr_table = self.symbols[len(self.symbols) - 1]
    if scf_scope:
      self.scf_scope += 1

  def insert_symbol(self, name, value, type_):
    self.curr_table['symbols'][name] = (value, type_)
    # TODO(mdan): Use the inferred type rather than tracking it here.
    # The following field is deprecated.
    self.curr_table['types'][name] = type_
    return value

  def exit_scope(self):
    self.symbols.pop()
    self.curr_table = self.symbols[len(self.symbols) - 1]
    if self.scf_scope > 0:
      self.scf_scope -= 1

  def in_scf_scope(self):
    return self.scf_scope > 0

  def lookup(self, name):
    curr_idx = len(self.symbols) - 1
    while curr_idx >= 0 and (name not in self.symbols[curr_idx]['symbols']):
      curr_idx -= 1
    if curr_idx < 0:
      return None
    return self.symbols[curr_idx]['symbols'][name]

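# Illustrative usage (not part of the original module): lookup walks scopes
# from innermost to outermost, so inner bindings shadow outer ones:
#
#   table = SymbolTable()
#   table.insert_symbol('x', '%x', TFRTypes.TENSOR)
#   table.enter_scope(scf_scope=True)
#   table.insert_symbol('x', '%x_inner', TFRTypes.TENSOR)
#   table.lookup('x')   # -> ('%x_inner', TFRTypes.TENSOR)
#   table.exit_scope()
#   table.lookup('x')   # -> ('%x', TFRTypes.TENSOR)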

class TFRGen(transformer.CodeGenerator):
  """Visit the AST and generate MLIR TFR functions."""

  def __init__(self, ctx, op_defs):
    super(TFRGen, self).__init__(ctx)
    self.ctx = ctx
    self.symbol_table = SymbolTable()
    self._op_defs = op_defs

  def _create_mlir_loc(self, loc):
    """Creates mlir location from autograph ORIGIN value.

    Args:
      loc: OriginInfo

    Returns:
      A serialized mlir location string.
    """
    if loc is not None and loc.loc.filename:
      file_name = os.path.basename(loc.loc.filename)
      return 'loc("{}":{}:{})'.format(file_name, loc.loc.lineno,
                                      loc.loc.col_offset)
    else:
      return 'loc(unknown)'

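  # For illustration (not part of the original module): an OriginInfo whose
  # loc has filename '/tmp/my_ops.py', lineno 42 and col_offset 2 renders as
  # 'loc("my_ops.py":42:2)'; a missing origin renders as 'loc(unknown)'.
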
  def _emit_with_loc(self, op_str, node=None):
    """Emit the mlir operation with the location associated with the node.

    Args:
      op_str: The mlir operation string to be emitted.
      node: The AST node the mlir operation is translated from.
    """
    loc = ''
    if node:
      loc = self._create_mlir_loc(
          anno.getanno(node, anno.Basic.ORIGIN, default=None))
    self.emit(op_str + ' ' + loc)

  def _get_inferred_type(self, node, default=None):
    types_ = anno.getanno(node, anno.Static.TYPES, None)
    if not types_:
      print('WARN: no Static.TYPES annotation. Fix the type inference pass: ')
      self.debug_print(node)
      return default
    if types_ and len(types_) > 1:
      raise ValueError('ambiguous inferred type for "{}": {}'.format(
          node, types_))

    type_, = types_

    if default is not None and type_ != default:
      print('WARN: type annotation {}({}) does not match {}({})'.format(
          type_, type(type_), default, type(default)))
      self.debug_print(node)

    return type_

  def _pack_tensor_list(self, value):
    # When packing a list of tensors, the pack axis is 0.
    axis = self._ssa_name('zero')
    self._emit_with_loc('\n{} = constant 0 : i64'.format(axis))
    casted = self._ssa_name('pack')
    self.emit('\n{} = tfr.call @tf__pack({}, {})'.format(casted, value, axis))
    self._emit_with_loc(' : (!tfr.tensor_list, i64) -> !tfr.tensor')
    # load the op def of tf.Pack
    self._op_defs.lookup('Pack')
    return casted, TFRTypes.TENSOR

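  # A sketch (not part of the original module) of the MLIR emitted by
  # _pack_tensor_list for a value '%list' (SSA names are illustrative):
  #
  #   %zero = constant 0 : i64
  #   %pack = tfr.call @tf__pack(%list, %zero) : (!tfr.tensor_list, i64) -> !tfr.tensor
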
  def _index_to_I64(self, value, ty):
    if ty == TFRTypes.INDEX:
      casted = self._ssa_name('casted')
      self._emit_with_loc('\n{} = index_cast {} : index to i64'.format(
          casted, value))
      return casted, TFRTypes.I64
    else:
      return value, ty

  def _i64_to_index(self, value, ty):
    if ty == TFRTypes.I64:
      casted = self._ssa_name('casted')
      self._emit_with_loc('\n{} = index_cast {} : i64 to index'.format(
          casted, value))
      return casted, TFRTypes.INDEX
    else:
      return value, ty

  def _value_to_tensor(self, value, ty, node):
    value, ty = self._index_to_I64(value, ty)
    cst_tensor = self._ssa_name('cst')
    self.emit('\n{} = "tfr.constant_tensor"({})'.format(cst_tensor, value))
    self._emit_with_loc(' : ({}) -> !tfr.tensor'.format(ty), node)
    return cst_tensor, TFRTypes.TENSOR

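  # A sketch (not part of the original module) of the MLIR emitted by
  # _value_to_tensor for an i64 value '%val' (SSA names are illustrative):
  #
  #   %cst = "tfr.constant_tensor"(%val) : (i64) -> !tfr.tensor
  #
  # INDEX values are first cast through _index_to_I64:
  #
  #   %casted = index_cast %val : index to i64
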
  def _ssa_name(self, prefix):
    if isinstance(prefix, qual_names.QN):
      assert prefix.is_simple(), 'ANF transform should have cleaned this up'
      prefix = prefix.ssf()
    return '%' + self.ctx.namer.new_symbol(prefix, set())

  def _op_def(self, op_name):
    return op_def_registry.get(op_name)

  def visit_block(self, block):
    return [self.visit(item) for item in block]

  def visit_Pass(self, node):
    if self.symbol_table.in_scf_scope():
      self._emit_with_loc('\nscf.yield', node)
    else:
      self._emit_with_loc('\ntfr.return', node)

  def visit_Attribute(self, node):
    node_type = self._get_inferred_type(node, None)
    if isinstance(node.value, ast.Name):
      if node.value.id == 'ag__':
        # some variables are assigned with the 'ag__.xxx' method; handle them
        # following the autograph conventions.
        return (node.attr, TFRTypes.AG_BUILTIN_FUNC)

      if node_type == TFRTypes.TF_RAW_OP:
        # This branch is taken when the value is inside the tensorflow module,
        # i.e. the attribute is a TF raw op.
        return (node.attr, TFRTypes.TF_RAW_OP)

      if node_type == TFRTypes.ATTR:
        attr = self._ssa_name('attr')
        tfr_type = _TF_DTYPE_TO_TFR.get(node.attr)
        self._emit_with_loc(
            '\n{} = tfr.constant {} -> !tfr.attr'.format(attr, tfr_type), node)
        return (attr, TFRTypes.ATTR)

      value, _ = self.visit(node.value)
      tensor_type = self._get_inferred_type(node.value, None)
      # TODO(fengliuai): use node_type once it is correct.
      if node_type == TFRTypes.SHAPE:
        print('TODO: use "node_type"')
      if node.attr == 'shape' and tensor_type == TFRTypes.TENSOR:
        ssa_value = self._ssa_name('shape')
        self._emit_with_loc(
            '\n{} = tfr.get_shape {} -> !shape.shape'.format(ssa_value, value),
            node)
        return (ssa_value, TFRTypes.SHAPE)

    if isinstance(node.value, ast.Attribute):
      if isinstance(node.value.value, ast.Name):
        if node.value.value.id == 'tf' and node.value.attr == 'raw_ops':
          return (node.attr, TFRTypes.TF_RAW_OP)

      value, ty = self.visit(node.value)
      # TODO(fengliuai): use node_type once it is correct.
      if node_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
        print('TODO: use "node_type"')
      if ty == TFRTypes.SHAPE and node.attr == 'as_list':
        return (value, TFRTypes.TF_TENSOR_SHAPE_FUNC)

    raise NotImplementedError('Attribute kind not recognized.')


  def visit_Assign(self, node):
    values = self.visit(node.value)
    if isinstance(node.targets[0], ast.Tuple):
      targets = [elt.id for elt in node.targets[0].elts]
    elif isinstance(node.targets[0], ast.Name):
      targets = [node.targets[0].id]
    else:
      raise NotImplementedError('Assignment target type not recognized.')

    if isinstance(values, list):
      if isinstance(node.value, ast.Call):
        expected = tuple(t for n, t in values)
        if len(values) == 1:
          expected = expected[0]
      elif isinstance(node.value, ast.Tuple):
        expected = tuple(t for n, t in values)
      else:
        raise ValueError('unknown assignment target node', node.value)
      ty = self._get_inferred_type(node.value, expected)

      if len(targets) == len(values):
        # TODO(mdan): This should already be a tuple.
        ty_ = (ty,) if len(values) == 1 else ty
        for key, value, t in zip(targets, values, ty_):
          ssa_value, _ = value
          self.symbol_table.insert_symbol(key, ssa_value, t)
      elif len(values) == 1:
        n, _ = values[0]
        assert ty == TFRTypes.TENSOR_LIST
        # assign a tensor_list to multiple variables
        for idx, key in enumerate(targets):
          idx_name = self._ssa_name('idx')
          self._emit_with_loc(
              '\n{} = constant {} : index'.format(idx_name, idx), node)
          elt_name = self._ssa_name('elt')
          self.emit('\n{} = tfr.get_element {}[{}]'.format(
              elt_name, n, idx_name))
          self._emit_with_loc(' : (!tfr.tensor_list, index) -> !tfr.tensor',
                              node)
          self.symbol_table.insert_symbol(key, elt_name, TFRTypes.TENSOR)
      elif len(targets) == 1:
        ssa_names = [n for n, _ in values]
        self.symbol_table.insert_symbol(targets[0], ssa_names, ty)
      return

    ty = self._get_inferred_type(node.value, values[1])
    self.symbol_table.insert_symbol(targets[0], values[0], ty)

  def _emit_binary_op(self, op, lhs, lhs_ty, rhs, rhs_ty):
    assert lhs_ty == rhs_ty, (lhs_ty, rhs_ty)
    if isinstance(op, ast.Sub):
      code = 'sub'
    elif isinstance(op, ast.Add):
      code = 'add'
    else:
      raise NotImplementedError('BinOp operator not recognized: ' + str(op))

    if lhs_ty == TFRTypes.I64:
      suffix = 'i'
    elif lhs_ty == TFRTypes.F32:
      suffix = 'f'
    else:
      raise NotImplementedError('BinOp operand type not recognized: ' +
                                str(lhs_ty))

    ret = self._ssa_name(code)
    self._emit_with_loc(
        '\n{} = {}{} {}, {} : {}'.format(ret, code, suffix, lhs, rhs, lhs_ty),
        op)
    return ret, lhs_ty

  def visit_AugAssign(self, node):
    lhs, lhs_ty = self.visit(node.target)
    rhs, rhs_ty = self.visit(node.value)
    ret, ret_ty = self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty)
    self.symbol_table.insert_symbol(node.target.id, ret, ret_ty)

  def visit_BinOp(self, node):
    lhs, lhs_ty = self.visit(node.left)
    rhs, rhs_ty = self.visit(node.right)
    return self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty)

  def visit_BoolOp(self, node):
    values = [self.visit(value) for value in node.values]
    # TODO(fengliuai): Handle more ast node types.
    if isinstance(node.op, ast.Or):
      raise NotImplementedError('Or operator not recognized')
    elif isinstance(node.op, ast.And):
      raise NotImplementedError('And operator not recognized')

  def visit_Call(self, node):
    func_name, func_type = self.visit(node.func)
    func_type = self._get_inferred_type(node.func, func_type)
    if func_type == TFRTypes.AG_BUILTIN_FUNC:
      if func_name == 'if_stmt':
        cond, _ = self.visit(node.args[0])
        body, _ = self.visit(node.args[1])
        orelse, _ = self.visit(node.args[2])
        get_state, _ = self.visit(node.args[3])
        nouts = int(node.args[6].value)
        out_symbols = []
        # The out symbols are just a Tuple of names
        for out in node.args[5].elts[:nouts]:
          val, ty = self.symbol_table.lookup(out.value)
          out_symbols.append(out.value)
        return self._visit_if_stmt(cond, body, orelse, get_state, out_symbols,
                                   node)
      elif func_name == 'for_stmt':
        range_ = self._visit_iter(node.args[0])
        body, _ = self.visit(node.args[2])
        get_state, _ = self.visit(node.args[3])
        loop_carried = [out.value for out in node.args[5].elts]
        # TODO(fengliuai): opt is not used here.
        return self._visit_for_stmt(range_, body, get_state, loop_carried, node)
      elif func_name == 'Undefined':
        val = self._ssa_name(node.args[0].value)
        return (val, TFRTypes.AG_UNDEFINED_VAL)
      elif func_name == 'UndefinedReturnValue':
        val = self._ssa_name('return_val')
        return (val, TFRTypes.AG_UNDEFINED_VAL)

    if func_type == TFRTypes.TF_RAW_OP:
      return self._visit_tf_op(func_name, node.args, node.keywords, node)

    if func_type == types.FunctionType:
      return self._visit_tf_op(func_name, node.args, node.keywords, node)

    if func_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
      return (func_name, TFRTypes.TF_TENSOR_SHAPE_LIST)

    if func_type == TFRTypes.PY_BUILTIN_FUNC:
      if func_name == 'len':
        arg, ty = self.visit(node.args[0])
        ty = self._get_inferred_type(node.args[0], ty)
        if ty == TFRTypes.TF_TENSOR_SHAPE_LIST:
          len_value = self._ssa_name('len')
          self._emit_with_loc(
              '\n{} = shape.rank {} : !shape.shape -> !shape.size'.format(
                  len_value, arg), node)
          size_value = self._ssa_name('len_size')
          self._emit_with_loc(
              '\n{} = shape.size_to_index {} : !shape.size'.format(
                  size_value, len_value), node)
        elif ty == TFRTypes.TENSOR_LIST:
          size_value = self._ssa_name('len')
          self._emit_with_loc(
              '\n{} = tfr.get_length {} -> index'.format(size_value, arg), node)
        return (size_value, TFRTypes.INDEX)

    raise NotImplementedError('call operator not recognized: {} {}'.format(
        func_name, func_type))

  def visit_Compare(self, node):
    lhs, lhs_ty = self.visit(node.left)
    for op, right in zip(node.ops, node.comparators):
      rhs, rhs_ty = self.visit(right)
      if isinstance(op, ast.Eq):
        pred = 'eq'
      elif isinstance(op, ast.Lt):
        pred = 'ult'
      elif isinstance(op, ast.LtE):
        pred = 'ule'
      elif isinstance(op, ast.Gt):
        pred = 'ugt'
      elif isinstance(op, ast.GtE):
        pred = 'uge'
      elif isinstance(op, ast.NotEq):
        pred = 'ne'
      else:
        raise NotImplementedError('Compare operator not recognized')

      ret = self._ssa_name(pred)
      if lhs_ty == TFRTypes.ATTR:
        self._emit_with_loc(
            '\n{} = tfr.equal {}, {} -> i1'.format(ret, lhs, rhs), node)
      else:
        if lhs_ty == TFRTypes.I64:
          code = 'cmpi'
        elif lhs_ty == TFRTypes.F32:
          code = 'cmpf'
        elif lhs_ty == TFRTypes.INDEX:
          code = 'cmpi'
          # TODO(fengliuai): the reverse type inference should solve the issue.
          rhs, _ = self._i64_to_index(rhs, rhs_ty)
        else:
          raise NotImplementedError('Compare operand type not recognized')
        self._emit_with_loc(
            '\n{} = {} "{}", {}, {} : {}'.format(ret, code, pred, lhs, rhs,
                                                 lhs_ty), node)

      return ret, TFRTypes.I1

  def visit_Constant(self, node):
    cst_name = self._ssa_name('cst')
    if node.value is None:
      cst_ty = TFRTypes.NONE
    elif isinstance(node.value, bool):
      cst_ty = self._get_inferred_type(node)
      cst_val = str(node.value).lower()
      self._emit_with_loc('\n{} = constant {}'.format(cst_name, cst_val), node)
    else:
      cst_ty = self._get_inferred_type(node)
      cst_val = node.value
      if cst_ty == TFRTypes.ATTR:
        self._emit_with_loc(
            '\n{} = tfr.constant "{}" -> {}'.format(cst_name, cst_val, cst_ty),
            node)
      else:
        self._emit_with_loc(
            '\n{} = constant {} : {}'.format(cst_name, cst_val, cst_ty), node)
    return cst_name, cst_ty

  def visit_FunctionDef(self, node):
    op_def, derived_attrs = self._op_defs.lookup(node.name, node, True)
    if op_def is None:
      # Nested function. Insert it into the symbol table for later lookup.
      self.symbol_table.insert_symbol(node.name, node, None)
      return
    op_name = op_def.name
    if self.symbol_table.lookup(op_name):
      raise LookupError('Composition has already been registered for op: ' +
                        op_name)
    else:
      self.symbol_table.insert_symbol(node.name, None, None)

    self.symbol_table.enter_scope()
    self.emit('\ntfr.func @tf__{0}('.format(_camel_to_snake(op_name)))

    arg_list = []
    idx = 0
    max_idx = len(op_def.input_arg) + len(op_def.attr)
    for arg in node.args.args:
      arg_name = self._ssa_name(anno.getanno(arg, anno.Basic.QN))
      arg_type = anno.getanno(arg, anno.Static.TYPES)[0]

      arg_attr = ''
      if idx >= len(op_def.input_arg):
        attr_def = op_def.attr[idx - len(op_def.input_arg)]
        # skip the derived attributes
        while attr_def.name in derived_attrs and (idx + 1) < max_idx:
          idx += 1
          attr_def = op_def.attr[idx - len(op_def.input_arg)]
        if idx >= max_idx:
          raise ValueError('Argument is not defined in OpDef: ' + arg_name)

        arg_attr += '{{tfr.name="{}"'.format(attr_def.name)
        if attr_def.HasField('default_value'):
          default_val = _get_val_from_proto(arg_type, attr_def.default_value)
          arg_attr += ',tfr.default={}'.format(default_val)
        arg_attr += '}'

      idx += 1
      arg_str = '{}: {}{}'.format(arg_name, arg_type, arg_attr)
      arg_list.append(arg_str)
      self.symbol_table.insert_symbol(arg.id, arg_name, arg_type)

    ret_type_list = []
    for ret_def in op_def.output_arg:
      if ret_def.number_attr or ret_def.type_list_attr:
        ret_type_list.append(str(TFRTypes.TENSOR_LIST))
      else:
        ret_type_list.append(str(TFRTypes.TENSOR))

    self.emit('{}) -> ({}) {{'.format(', '.join(arg_list),
                                      ', '.join(ret_type_list)))
    self.visit_block(node.body)
    self._emit_with_loc('\n}', node)
    self.symbol_table.exit_scope()

  def visit_arguments(self, node):
    # TODO(fengliuai): return the types and names in order.
    # We need to order the arguments to match the assumption in the TFR dialect.
    raise NotImplementedError('arguments not supported.')

  def visit_Lambda(self, node):
    raise NotImplementedError('Lambda not supported.')

  def _get_mlir_ssa_values(self, name_prefix, out_types):
    """Create MLIR convention SSA values."""
    out_ssa_values = []
    if not out_types:
      return '', out_ssa_values

    out_name = self._ssa_name(name_prefix)
    if len(out_types) == 1:
      out_name_suffix = ''
      out_ssa_values.append(out_name)
    else:
      # For multiple returns, MLIR uses '%s:i' when they are defined and
      # '%s#i' when they are used.
      out_name_suffix = ':{}'.format(len(out_types))
      for idx, _ in enumerate(out_types):
        out_ssa_values.append('{}#{}'.format(out_name, idx))

    return '{}{}'.format(out_name, out_name_suffix), out_ssa_values

  def _visit_if_stmt(self, cond, body_def, orelse_def, get_state, out_symbols,
                     node):
    self.emit('\n')
    ret_str, ret_ssa_values = self._get_mlir_ssa_values(
        'if_stmt', [TFRTypes.TENSOR] * len(out_symbols))
    if ret_ssa_values:
      self.emit(ret_str + ' = ')

    out_types = []
    for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
      out_types.append(str(TFRTypes.TENSOR))

    self.emit('scf.if {} -> ({}) {{'.format(cond, ', '.join(out_types)))
    # Create a new scope so that local variables do not leak.
    self.symbol_table.enter_scope(scf_scope=True)
    self.visit_block(body_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()

    self.emit('\n} else {')

    # Create a new scope so that local variables do not leak.
    self.symbol_table.enter_scope(scf_scope=True)
    self.visit_block(orelse_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()

    # add ssa values to the symbol table
    for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
      self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR)

    self._emit_with_loc('\n}', node)
    return list(zip(ret_ssa_values, out_types))

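  # A sketch (not part of the original module) of the scf.if emitted above
  # for two out symbols (SSA names are illustrative; the yields come from
  # visiting the body and get_state blocks):
  #
  #   %if_stmt:2 = scf.if %cond -> (!tfr.tensor, !tfr.tensor) {
  #     ...
  #     scf.yield %a, %b : !tfr.tensor, !tfr.tensor
  #   } else {
  #     ...
  #     scf.yield %c, %d : !tfr.tensor, !tfr.tensor
  #   }
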
  def _visit_iter(self, node):
    if isinstance(node, ast.Call):
      f_name = anno.getanno(node.func, anno.Basic.QN)
      if f_name == QN('range'):
        args = [self.visit(arg) for arg in node.args]
        begin = None
        step = None
        end = None
        if len(args) == 1:
          end, end_ty = args[0]
        elif len(args) == 2:
          begin, begin_ty = args[0]
          end, end_ty = args[1]
        elif len(args) == 3:
          begin, begin_ty = args[0]
          end, end_ty = args[1]
          step, step_ty = args[2]

        if begin is None:
          begin = self._ssa_name('begin')
          self._emit_with_loc('\n{} = constant 0 : index'.format(begin), node)
        elif begin_ty != TFRTypes.INDEX:
          begin_ = self._ssa_name('begin')
          self._emit_with_loc(
              '\n{} = index_cast {} : {} to index'.format(
                  begin_, begin, begin_ty), node)
          begin = begin_

        if end_ty != TFRTypes.INDEX:
          end_ = self._ssa_name('end')
          self._emit_with_loc(
              '\n{} = index_cast {} : {} to index'.format(end_, end, end_ty),
              node)
          end = end_

        if step is None:
          step = self._ssa_name('step')
          self._emit_with_loc('\n{} = constant 1 : index'.format(step), node)
        elif step_ty != TFRTypes.INDEX:
          step_ = self._ssa_name('step')
          self._emit_with_loc(
              '\n{} = index_cast {} : {} to index'.format(step_, step, step_ty),
              node)
          step = step_

        return begin, end, step

    raise NotImplementedError('Iterator entity not supported: ' + str(node))

  def _visit_for_stmt(self, range_, body_def, get_state, loop_carried, node):
    self.emit('\n')
    ret_str, ret_ssa_values = self._get_mlir_ssa_values(
        'for_stmt', [TFRTypes.TENSOR] * len(loop_carried))
    if ret_ssa_values:
      self.emit(ret_str + ' = ')

    # Before entering the loop, we use the original ssa values as the initial
    # values to the loop iteration arguments. We also create new ssa values as
    # the returns of the scf for statements. The symbol table needs to be
    # updated to these new ssa values before it enters the scope of the loop.
    out_types = []
    init_values = []
    for symbol, ssa_value in zip(loop_carried, ret_ssa_values):
      init, ty = self.symbol_table.lookup(symbol)
      self.symbol_table.insert_symbol(symbol, ssa_value, ty)
      out_types.append(str(ty))
      init_values.append((init, ty))

    # Create a new scope so that local variables do not leak.
    self.symbol_table.enter_scope(scf_scope=True)

    # Create the iteration variable with index type
    assert len(body_def.args.args) == 1
    it_name = body_def.args.args[0].id
    it = self._ssa_name(it_name)
    self.symbol_table.insert_symbol(it_name, it, TFRTypes.INDEX)

    self.emit('scf.for {} = {} to {} step {} '.format(it, range_[0], range_[1],
                                                      range_[2]))
    if loop_carried:
      iter_args = []
      for symbol, init in zip(loop_carried, init_values):
        # create new ssa values for the loop carried variables
        it_arg = self._ssa_name('it_arg')
        self.symbol_table.insert_symbol(symbol, it_arg, init[1])
        iter_args.append('{} = {}'.format(it_arg, init[0]))
      self.emit('iter_args({}) '.format(', '.join(iter_args)))
      self.emit('-> ({}) {{'.format(', '.join(out_types)))
    else:
      self.emit(' {')
    self.visit_block(body_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()
    self._emit_with_loc('\n}', node)
    return list(zip(ret_ssa_values, out_types))

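  # A sketch (not part of the original module) of the scf.for emitted above
  # for one loop-carried tensor (SSA names illustrative; emitted on one line,
  # wrapped here for readability):
  #
  #   %for_stmt = scf.for %i = %begin to %end step %step
  #       iter_args(%it_arg = %init) -> (!tfr.tensor) {
  #     ...
  #     scf.yield %next : !tfr.tensor
  #   }
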
  def _emit_default_constant_from_proto(self, attr_def):
    """Emits an mlir constant statement from the default value of an AttrDef."""
    name = self._ssa_name('cst')
    cst_ty = _get_type_from_proto(None, attr_def)
    cst_val = _get_val_from_proto(cst_ty, attr_def.default_value)
    if cst_ty == TFRTypes.ATTR:
      self._emit_with_loc('\n{} = tfr.constant {} -> {}'.format(
          name, cst_val, cst_ty))
    elif cst_ty == TFRTypes.I1:
      self._emit_with_loc('\n{} = constant {}'.format(name, cst_val))
    else:
      self._emit_with_loc('\n{} = constant {} : {}'.format(
          name, cst_val, cst_ty))
    return name, cst_ty

  def visit_keyword(self, node):
    return node.arg, self.visit(node.value)

  def _visit_tf_op(self, op_name, args, keywords, node):
    op_def, derived_attrs = self._op_defs.lookup(op_name)
    ret_tys = [_get_type_from_proto(arg) for arg in op_def.output_arg]

    ret_str, ret_ssa_values = self._get_mlir_ssa_values(op_name, ret_tys)

    arg_strs = []
    ty_strs = []
    for arg in args:
      value, ty = self.visit(arg)
      arg_strs.append(value)
      ty_strs.append(str(ty))

    input_args = [arg for arg in op_def.input_arg]
    attrs_no_default = [
        attr for attr in op_def.attr
        if not attr.HasField('default_value') and attr.name not in derived_attrs
    ]
    attrs_with_default = [
        attr for attr in op_def.attr
        if attr.HasField('default_value') and attr.name not in derived_attrs
    ]

    kw_args = {}
    for arg in keywords:
      value, (ssa_name, ty) = self.visit(arg)
      ty = self._get_inferred_type(arg.value, ty)

      # TODO(fengliuai): implement the "rename_to" for the customization in
      # tensorflow/core/api_def/base_api/*
      if value == 'axis':
        value = 'split_dim'

      kw_args[value] = (ssa_name, ty)

    # tensor arguments and attribute arguments
    ordered_args = input_args + attrs_no_default + attrs_with_default
    for attr_def in ordered_args[len(args):]:
      if attr_def.name in kw_args:
        value, ty = kw_args[attr_def.name]
        if attr_def in input_args:
          if ty in _attribute_types:
            # the argument shouldn't be passed to the tf op call directly, so
            # convert it to a tensor first.
            value, ty = self._value_to_tensor(value, ty, node)
          if ty is TFRTypes.TENSOR_LIST and not _require_tensor_list(attr_def):
            value, ty = self._pack_tensor_list(value)
      else:
        value, ty = self._emit_default_constant_from_proto(attr_def)
      arg_strs.append(value)
      ty_strs.append(str(ty))

    if ret_ssa_values:
      self.emit('\n{} = '.format(ret_str))

    self.emit('tfr.call @tf__{}('.format(_camel_to_snake(op_name)))
    arg_str = ', '.join(arg_strs)
    arg_ty_str = ', '.join(ty_strs)
    ret_ty_str = ', '.join([str(ty) for ty in ret_tys])
    self._emit_with_loc(
        '{}) : ({}) -> ({})'.format(arg_str, arg_ty_str, ret_ty_str), node)
    return list(zip(ret_ssa_values, ret_tys))

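  # A sketch (not part of the original module) of the call emitted by
  # _visit_tf_op for tf.raw_ops.Add(x=%x, y=%y) (SSA names are illustrative):
  #
  #   %Add = tfr.call @tf__add(%x, %y) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor)
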
  def visit_If(self, node):
    raise NotImplementedError('If not supported.')

  def visit_Name(self, node):
    val_and_lookup_type = self.symbol_table.lookup(node.id)
    if val_and_lookup_type:
      (val, lookup_type) = val_and_lookup_type
    else:
      op_def, _ = self._op_defs.lookup(node.id)
      val = op_def.name
      lookup_type = anno.getanno(node, anno.Static.TYPES, types.FunctionType)
    type_ = self._get_inferred_type(node, lookup_type)
    return val, type_

  def visit_Return(self, node):
    values = self.visit(node.value)
    if self.symbol_table.in_scf_scope():
      self.emit('\nscf.yield ')
    else:
      self.emit('\ntfr.return ')
    if not values:
      return

    if isinstance(values, list):
      vals, tys = zip(*values)
    else:
      vals = values[0]
      tys = values[1]

    if isinstance(tys, (list, tuple)):
      tys = [str(t) for t in tys]
      self._emit_with_loc('{} : {}'.format(', '.join(vals), ', '.join(tys)),
                          node)
    elif tys != TFRTypes.NONE:
      # TODO(fengliuai): scf region yield uses this branch. Fix it.
      self._emit_with_loc('{} : {}'.format(vals, tys), node)

  def visit_Subscript(self, node):
    val, ty = self.visit(node.value)
    type_ = self._get_inferred_type(node.value, ty)

    # TODO(fengliuai): we hardcode the handling of node.slice here to get the
    # index type. Use the visit method once the type inference is done.
    # slice_val, slice_ty = self.visit(node.slice)
    s = node.slice
    if not isinstance(s, (ast.Tuple, ast.Slice)):
      if isinstance(s, ast.Constant):
        # TODO(fengliuai): promote to an assignment
        idx_val = self._ssa_name('cst')
        self._emit_with_loc(
            '\n{} = constant {} : index'.format(idx_val, s.value), node)
      else:
        idx_val, _ = self.visit(s)
    else:
      raise NotImplementedError('non-index slice not supported.')

    elt = self._ssa_name('elt')
    if type_ == TFRTypes.TENSOR_LIST:
      self.emit('\n{} = tfr.get_element {}[{}] '.format(elt, val, idx_val))
      self._emit_with_loc(': (!tfr.tensor_list, index) -> !tfr.tensor', node)
      return (elt, TFRTypes.TENSOR)
    elif type_ == TFRTypes.TF_TENSOR_SHAPE_LIST:
      size_ = self._ssa_name('size')
      self.emit('\n{} = shape.get_extent {}, {}'.format(size_, val, idx_val))
      self._emit_with_loc(': !shape.shape, index -> !shape.size', node)
      self._emit_with_loc(
          '\n{} = shape.size_to_index {} : !shape.size'.format(elt, size_),
          node)
      return (elt, TFRTypes.INDEX)

  def visit_List(self, node):
    out_type = self._get_inferred_type(node)
    vals = []
    tys = []
    for elt in node.elts:
      val, ty = self.visit(elt)
      ty = self._get_inferred_type(elt, ty)
      if ty in _attribute_types and out_type == TFRTypes.TENSOR_LIST:
        # This list is a tensor list, so cast all the input values to tensors.
        val, ty = self._value_to_tensor(val, ty, node)
      else:
        # We shouldn't use the index type to build the list because the list
        # will be used as an attribute.
        val, ty = self._index_to_I64(val, ty)
      vals.append(val)
      tys.append(str(ty))

    list_val = self._ssa_name('list')
    self.emit('\n{} = "tfr.build_list"({})'.format(list_val, ', '.join(vals)))
    self._emit_with_loc(' : ({}) -> {}'.format(', '.join(tys), out_type), node)
    return (list_val, out_type)

  def visit_Tuple(self, node):
    return [self.visit(elt) for elt in node.elts]

  def visit_UnaryOp(self, node):
    value, ty = self.visit(node.operand)
    if isinstance(node.op, ast.USub):
      zero_value = self._ssa_name('zero')
      self._emit_with_loc('\n{} = constant 0 : {}'.format(zero_value, ty), node)
      ssa_value = self._ssa_name('cst')
      if ty == TFRTypes.I32 or ty == TFRTypes.I64:
        self._emit_with_loc(
            '\n{} = subi {}, {} : {}'.format(ssa_value, zero_value, value, ty),
            node)
      elif ty == TFRTypes.F32:
        self._emit_with_loc(
            '\n{} = subf {}, {} : {}'.format(ssa_value, zero_value, value, ty),
            node)
      else:
        raise NotImplementedError('USub type not recognized: ' + str(ty))
      return ssa_value, ty
    raise NotImplementedError('USub operator not recognized')

  def visit_For(self, node):
    raise NotImplementedError('For operator not recognized')

  def visit_While(self, node):
    raise NotImplementedError('While operator not recognized')

  def visit_Try(self, node):
    # Only handles the body of the try statement.
    self.visit_block(node.body)


def _apply_py_to_tf_passes(node, ctx):
  """Apply transformations from PyToTF to match tf.function tracing."""
  # TODO(fengliuai): we don't know which passes are required, thus we evaluate
  # each one when the corresponding node is handled.
  # copied from PyToTF.transform_ast
  node = return_statements.transform(node, ctx, False)
  node = control_flow.transform(node, ctx)
  return node


class TfrGen(transpiler.GenericTranspiler):
  """Transforms Python objects into TFR MLIR source code."""

  def __init__(self, op_defs):
    self._op_defs = op_defs

  def transform_ast(self, node, ctx):
    node = _apply_py_to_tf_passes(node, ctx)
    # TODO(mdan): Enable this.
    # node = anf.transform(node, ctx)

    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx)
    node = reaching_definitions.resolve(node, ctx, graphs)
    node = reaching_fndefs.resolve(node, ctx, graphs)
    node = type_inference.resolve(node, ctx, graphs,
                                  TFRTypeResolver(self._op_defs))

    mlir_generator = TFRGen(ctx, self._op_defs)
    mlir_generator.visit(node)
    return mlir_generator.code_buffer


def tfr_gen(func, op_defs):
  """Parse a function and emit the TFR functions."""
  mlir_code, _ = TfrGen(op_defs).transform(func, None)
  assert tfr.verify(mlir_code), 'mlir code not verified: {}'.format(mlir_code)
  return mlir_code


def tfr_gen_from_module(source, method_prefix=None, op_libraries=None):
  """Parse the input source module and emit the TFR functions."""
  op_defs = OpDefCache()

  # Load the op library so the op is added to the op registry. This is
  # required when the op cc_library couldn't be statically linked in open
  # source.
  # This is a no-op if the op shared library couldn't be found in the same
  # directory as the op Python API.
  # TODO(fengliuai): make the .so file path configurable.
  if op_libraries:
    prefix_len = len('gen_')
    for m in op_libraries:
      lib_dir = os.path.dirname(m.__file__)
      lib_name = os.path.basename(m.__file__)[prefix_len:].replace('.py', '.so')
      lib_path = os.path.join(lib_dir, lib_name)
      if os.path.exists(lib_path):
        logging.info('load file: ' + lib_path)
        load_library.load_op_library(lib_path)
  else:
    # The op library is generated from the source module, so we load all the
    # .so files in the directory.
    lib_dir = os.path.dirname(source.__file__)
    for lib_name in os.listdir(lib_dir):
      if lib_name.endswith('.so'):
        lib_path = os.path.join(lib_dir, lib_name)
        logging.info('load file: ' + lib_path)
        load_library.load_op_library(lib_path)

  py_funcs = [
      func
      for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction)
      if not method_prefix or name.startswith(method_prefix)
  ]
  # Sort the methods by line number to make sure the definitions are processed
  # before the usages.
  # TODO(fengliuai): Use the type inference resolver to recursively process any
  # functions called.
  py_funcs = sorted(py_funcs, key=lambda x: x.__code__.co_firstlineno)
  mlir_funcs = [tfr_gen(func, op_defs) for func in py_funcs]

  return '\n'.join(mlir_funcs + op_defs.mlir_external_funcs())
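

# Illustrative usage sketch (not part of the original module). The import
# path for the Composite decorator and the example module and op names below
# are assumptions for illustration only:
#
#   from tensorflow.compiler.mlir.tfr.python.composite import Composite
#
#   # In a module my_composite_ops.py, register a decomposition for a
#   # (hypothetical) registered op 'MyAddN':
#   @Composite('MyAddN', ...)
#   def _composite_my_add_n(...):
#     ...
#
#   import my_composite_ops
#   mlir_code = tfr_gen_from_module(my_composite_ops,
#                                   method_prefix='_composite_')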
1465