# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tfr_gen: Generate mlir tfr decomposition function from python code."""

# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
# pylint: disable=g-direct-tensorflow-import

import enum
import os
import re
import types
import gast as ast

from tensorflow.compiler.mlir.tfr import tfr_wrapper as tfr
from tensorflow.core.framework import types_pb2
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
from tensorflow.python.autograph.pyct.static_analysis import type_inference
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import load_library
from tensorflow.python.framework import op_def_registry
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect

# TODO(mdan): Use class definitions so that we can mix these with Python types.


class TFRTypes(enum.Enum):
  """All the supported types.

    1-3: tfr types
    4-99: mlir built-in types
    100-199: TF related translator internal types
    200- : Python related translator internal types
  """
  TENSOR = 1
  TENSOR_LIST = 2
  ATTR = 3
  NONE = 4
  SHAPE = 5  # shape -> !shape.shape
  I1 = 21
  I8 = 22
  I16 = 23
  I32 = 24
  I64 = 25
  F32 = 26
  INDEX = 27
  AG_UNDEFINED_VAL = 100
  AG_BUILTIN_FUNC = 101
  TF_RAW_OP = 102
  TF_REGION = 103
  TF_TENSOR_SHAPE_FUNC = 104  # shape.as_list
  TF_TENSOR_SHAPE_LIST = 105  # shape.as_list()
  PY_BUILTIN_FUNC = 200
  TFR_BUILTIN_FUNC = 201

  # As these are not real types, __getattribute__ helps them appear more like
  # actual types (i.e. class definitions).
  def __getattribute__(self, name):
    if name == 'shape' and object.__getattribute__(self, 'value') == 1:
      return TFRTypes.SHAPE
    if name == 'as_list' and object.__getattribute__(self, 'value') == 5:
      return TFRTypes.TF_TENSOR_SHAPE_FUNC
    return object.__getattribute__(self, name)

  def __str__(self):
    if self.value < 4:  # pylint: disable=comparison-with-callable
      return '!tfr.' + self.name.lower()
    elif self.value < 10:  # pylint: disable=comparison-with-callable
      return '!shape.' + self.name.lower()
    else:
      return self.name.lower()
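
  # Illustrative note only (not executed): __str__ above maps enum members to
  # MLIR type spellings, e.g. str(TFRTypes.TENSOR) -> '!tfr.tensor',
  # str(TFRTypes.SHAPE) -> '!shape.shape', str(TFRTypes.I64) -> 'i64'.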


_ATTRIBUTE_TYPES = (
    TFRTypes.I1, TFRTypes.I32, TFRTypes.I64, TFRTypes.F32, TFRTypes.INDEX,
    TFRTypes.ATTR
)

# TODO(b/203493652): implement the "rename_to" for the customization in
# tensorflow/core/api_def/base_api/*
# {op_name: {API's attribute name: OpDef's attribute name}}
_ATTRIBUTE_RENAMES = {
    'Mean': {'axis': 'reduction_indices'},
    'Split': {'axis': 'split_dim'},
    'SplitV': {'axis': 'split_dim'},
}


def _get_type_from_proto(arg_def=None, attr_def=None):
  if not arg_def:
    if attr_def.type == 'bool':
      return TFRTypes.I1
    elif attr_def.type == 'int32':
      return TFRTypes.I32
    elif attr_def.type == 'int' or attr_def.type == 'int64':
      return TFRTypes.I64
    elif attr_def.type == 'float':
      return TFRTypes.F32
    else:
      return TFRTypes.ATTR

  if arg_def.number_attr or arg_def.type_list_attr:
    return TFRTypes.TENSOR_LIST
  else:
    return TFRTypes.TENSOR


def _get_type_info_from_proto(arg_def=None, attr_def=None):
  attr_type = _get_type_from_proto(arg_def, attr_def)
  if not arg_def:
    return '{}{{tfr.name="{}",tfr.type="{}"}}'.format(
        attr_type, attr_def.name, attr_def.type)
  else:
    attr_names = []
    if arg_def.number_attr:
      attr_names.append(arg_def.number_attr)
    if arg_def.type_attr:
      attr_names.append(arg_def.type_attr)
    if arg_def.type_list_attr:
      attr_names.append(arg_def.type_list_attr)

    # TODO(fengliuai): currently we don't support backward type inference, so
    # we have to store these non-derivable types in the signatures, and then
    # they can be used to cast the values when raising to tf ops.
    if arg_def.type == types_pb2.DT_FLOAT:
      attr_names.append('f32_')
    elif arg_def.type == types_pb2.DT_INT32:
      attr_names.append('i32_')
    elif arg_def.type == types_pb2.DT_INT64:
      attr_names.append('i64_')
    elif arg_def.type == types_pb2.DT_BOOL:
      attr_names.append('i1_')

    if not attr_names:
      return str(attr_type)
    else:
      return '{}<{}>'.format(attr_type, ','.join(attr_names))


def _get_val_from_proto(attr_type, attr_val):
  if attr_type == TFRTypes.I1:
    return 'true' if attr_val.b else 'false'
  elif attr_type == TFRTypes.I32 or attr_type == TFRTypes.I64:
    return attr_val.i
  elif attr_type == TFRTypes.F32:
    return attr_val.f
  elif attr_type == TFRTypes.ATTR:
    # string
    if attr_val.HasField('s'):
      return '"{}"'.format(attr_val.s.decode())
    # type
    if attr_val.HasField('type'):
      if attr_val.type == types_pb2.DT_FLOAT:
        return 'f32'
      elif attr_val.type == types_pb2.DT_INT32:
        return 'i32'
      elif attr_val.type == types_pb2.DT_INT64:
        return 'i64'
      elif attr_val.type == types_pb2.DT_BOOL:
        return 'i1'
    # list
    if attr_val.HasField('list'):
      if attr_val.list.f:
        elt_ty = TFRTypes.F32
        values = attr_val.list.f
      elif attr_val.list.i:
        elt_ty = TFRTypes.I64
        values = attr_val.list.i
      else:
        elt_ty = TFRTypes.NONE
        values = []
      array_attr_elts = ['{}:{}'.format(val, elt_ty) for val in values]
      return '[{}]'.format(','.join(array_attr_elts))
  raise NotImplementedError(
      'Proto AttrValue not recognized. type: {}, value: {}'.format(
          attr_type, attr_val))
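

# Illustrative examples only (not executed): for an AttrValue with b=True and
# attr_type TFRTypes.I1, _get_val_from_proto returns 'true'; for a float list
# attribute with values [1.0, 2.0] it returns '[1.0:f32,2.0:f32]'.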


def _collect_derived_attrs_from_proto(op_def):
  derived_attrs = set()
  for arg in op_def.input_arg:
    if arg.type_attr:
      derived_attrs.add(arg.type_attr)
    if arg.number_attr:
      derived_attrs.add(arg.number_attr)
    if arg.type_list_attr:
      derived_attrs.add(arg.type_list_attr)

    # TODO(fengliuai): currently we don't support backward type inference, so
    # we have to store these non-derivable types in the signatures, and then
    # they can be used to cast the values when raising to tf ops.
    if arg.type == types_pb2.DT_FLOAT:
      derived_attrs.add('f32_')
    elif arg.type == types_pb2.DT_INT32:
      derived_attrs.add('i32_')
    elif arg.type == types_pb2.DT_INT64:
      derived_attrs.add('i64_')
    elif arg.type == types_pb2.DT_BOOL:
      derived_attrs.add('i1_')
  return derived_attrs


def _require_tensor_list(arg_def):
  return arg_def.type_list_attr or arg_def.number_attr


def _camel_to_snake(name):
  s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
  return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


class OpDefCache(object):
  """A Dict to cache the OpDef for the Python function name."""

  def __init__(self):
    self._op_defs = {}

  def lookup(self, f_name, func_def=None, optional=False):
    if f_name in self._op_defs:
      return self._op_defs[f_name]

    if isinstance(func_def, types.FunctionType):
      if not hasattr(func_def, '_tfr_op_name'):
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      op_name = getattr(func_def, '_tfr_op_name')
    elif not func_def:
      op_name = f_name
    else:
      # TODO(fengliuai): create one utility method to match different APIs.
      compose_dec = []
      for dec in func_def.decorator_list:
        if isinstance(dec, ast.Call):
          if isinstance(dec.func,
                        ast.Attribute) and dec.func.attr == 'Composite':
            compose_dec.append(dec)
          if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite':
            compose_dec.append(dec)

      if not compose_dec:
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      elif len(compose_dec) > 1:
        raise KeyError('More than one TF op decomposition is registered for '
                       'this function.')
      else:
        op_name = compose_dec[0].args[0].value

    op_def = op_def_registry.get(op_name)
    if not op_def:
      raise ValueError('Not a registered op: ' + op_name)
    derived_attrs = _collect_derived_attrs_from_proto(op_def)
    self._op_defs[f_name] = (op_def, derived_attrs)
    return (op_def, derived_attrs)
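
  # Illustrative note only (not executed): op names are converted with
  # _camel_to_snake when emitting TFR function names, e.g.
  # _camel_to_snake('SplitV') -> 'split_v', so the external signature emitted
  # below for tf.SplitV starts with 'tfr.func @tf__split_v_('.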

  def mlir_external_funcs(self):
    tfr_funcs = set()
    for _, (op_def, derived_attrs) in sorted(self._op_defs.items()):
      tfr_func = '\ntfr.func @tf__{}_('.format(_camel_to_snake(op_def.name))

      # tensor inputs
      inputs = [
          _get_type_info_from_proto(arg_def) for arg_def in op_def.input_arg
      ]

      # attribute inputs. The attributes with default values are moved to the
      # end.
      non_derived_attrs = [
          attr for attr in op_def.attr if attr.name not in derived_attrs
      ]
      attrs_no_default = [
          attr for attr in non_derived_attrs
          if not attr.HasField('default_value')
      ]
      attrs_with_default = [
          attr for attr in non_derived_attrs if attr.HasField('default_value')
      ]
      attr_names = {'f32_', 'i32_', 'i64_', 'i1_'}  # reserved
      for attr_def in attrs_no_default + attrs_with_default:
        inputs.append(_get_type_info_from_proto(None, attr_def))
        attr_names.add(attr_def.name)

      # tensor outputs
      outputs = [
          _get_type_info_from_proto(arg_def) for arg_def in op_def.output_arg
      ]

      inputs = ','.join(inputs)
      outputs = ','.join(outputs)
      attrs = ','.join(sorted(derived_attrs.union(attr_names)))
      tfr_funcs.add('{}{}) -> ({}) attributes {{{}}}'.format(
          tfr_func, inputs, outputs, attrs))
    return sorted(list(tfr_funcs))


_PY_TYPE_TO_TFR = {
    bool: TFRTypes.I1,
    int: TFRTypes.I64,
    float: TFRTypes.F32,
}

_TF_DTYPE_TO_TFR = {
    'bool': TFRTypes.I1,
    'int64': TFRTypes.I64,
    'int32': TFRTypes.I32,
    'int16': TFRTypes.I16,
    'int8': TFRTypes.I8,
    'float32': TFRTypes.F32,
}

_AG_FIXED_RETURN_TYPE = {
    'for_stmt': type(None),
    'if_stmt': type(None),
    'Undefined': TFRTypes.AG_UNDEFINED_VAL,
}

QN = qual_names.QN

# TODO(mdan): Fix this with an importable module.
AG_MODULE = api._TRANSPILER.get_extra_locals()['ag__']  # pylint:disable=protected-access

# When an item is callable, the signature is (*operand_types) -> result_type(s)
TFR_BUILTINS = {
    '_tfr_quant_act_range': (TFRTypes.TENSOR, TFRTypes.TENSOR),
    '_tfr_quant_rescale': TFRTypes.TENSOR,
    '_tfr_quant_raw_data': lambda input_type: input_type,
    '_tfr_quant_qparam': (TFRTypes.TENSOR, TFRTypes.TENSOR),
    '_tfr_quant_scale_factor': TFRTypes.TENSOR,
}
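

# Illustrative note only (not executed): a plain value in TFR_BUILTINS is the
# fixed result type, while a callable maps operand types to the result type.
# For example, '_tfr_quant_raw_data' returns its operand's type unchanged, and
# '_tfr_quant_act_range' always yields two !tfr.tensor results.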


class TFRTypeResolver(type_inference.Resolver):
  """Resolve types for the external names, calls and arguments."""

  def __init__(self, op_defs):
    super(TFRTypeResolver, self).__init__()
    self._op_defs = op_defs

    # This pattern matching mechanism works with the functional form generated
    # by autograph:
    #
    #   for i in data:
    #     print(i)
    #
    # generates:
    #
    #   def loop_body(itr):
    #     i = itr
    #     print(i)
    #   ag__.for_stmt(target)
    #
    # The mechanism lets us infer the type of the itr argument based on that of
    # target.
    self._for_loop_target_types = {}  # Maps body function name to iterated.
    self._for_loop_body_fns = {}  # Used only to avoid collisions.

  def res_name(self, ns, types_ns, name):
    name_str = str(name)
    if name_str in TFR_BUILTINS:
      return {TFRTypes.TFR_BUILTIN_FUNC}, name_str
    if name_str in ns:
      ns_val = ns[name_str]
      return {type(ns_val)}, ns_val
    if name_str in __builtins__:
      return {TFRTypes.PY_BUILTIN_FUNC}, __builtins__[name_str]
    # This name is not in the namespace because the autograph transformation
    # is not backloaded into Python.
    if name_str == 'ag__':
      return {type(AG_MODULE)}, AG_MODULE

    return None, None

  def res_value(self, ns, value):
    # resolves the type of the symbol by the metadata in 'value'
    if value is None:
      return {TFRTypes.NONE}
    if value in (TFRTypes.SHAPE, TFRTypes.TF_TENSOR_SHAPE_FUNC):
      # See TFRTypes.__getattribute__.
      # TODO(mdan): Replacing the enum with classes would avoid this overlap.
      return {value}
    # TODO(mdan): Index more efficiently. Could do a name check instead.
    if any(v is value for v in AG_MODULE.__dict__.values()):
      return {TFRTypes.AG_BUILTIN_FUNC}
    if getattr(value, '__name__', None) == 'tensorflow.raw_ops':
      return {types.ModuleType}
    if hasattr(value, '__module__'):
      if isinstance(value, dtypes.DType):
        return {TFRTypes.ATTR}

      # All the imported operations, which are not autograph built-ins, are
      # considered to be TF raw ops.
      # TODO(fengliuai): refine the condition that we only match TensorFlow
      # ops here.
      return {TFRTypes.TF_RAW_OP}
    # TODO(mdan): Is ATTR equivalent to string?
    return {_PY_TYPE_TO_TFR.get(type(value), TFRTypes.ATTR)}
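
  # Illustrative examples only (not executed): res_value(ns, True) resolves to
  # {TFRTypes.I1} and res_value(ns, 1.0) to {TFRTypes.F32} via _PY_TYPE_TO_TFR,
  # while a dtypes.DType value such as dtypes.float32 resolves to
  # {TFRTypes.ATTR}.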

  def res_call(self, ns, types_ns, node, f_type, args, keywords):
    # resolves the return type of the function call.
    name = anno.Basic.QN.of(node.func)
    if f_type == (TFRTypes.AG_BUILTIN_FUNC,):

      if name == QN(QN('ag__'), attr='if_stmt'):
        nouts = node.args[6].value
        # TODO(mdan): Look at the actual types out of if_body.
        side_effects = {
            qual_names.QN(n.value): {TFRTypes.TENSOR}
            for n in node.args[5].elts[:nouts]
        }
        return {type(None)}, side_effects

      if name == QN(QN('ag__'), attr='for_stmt'):
        assert isinstance(node.args[2], ast.Name)
        body_fn_name = str(anno.Basic.QN.of(node.args[2]))
        assert body_fn_name not in self._for_loop_body_fns, (
            'Previously used here: {}. Are you reusing the Resolver across '
            'transformations?').format(self._for_loop_body_fns[body_fn_name])
        self._for_loop_body_fns[body_fn_name] = anno.Basic.ORIGIN.of(node)

        iterated_type = args[0]
        assert iterated_type & {
            TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, TFRTypes.ATTR
        }, (
            iterated_type)
        self._for_loop_target_types[body_fn_name] = iterated_type

        return {type(None)}, None

      # TODO(mdan): Actually resolve the type here instead.
      ret_type = _AG_FIXED_RETURN_TYPE.get(name.qn[1], None)
      if ret_type is not None:
        return {ret_type}, None
      raise NotImplementedError('return type of {}'.format(name))

    elif f_type == (TFRTypes.TF_RAW_OP,):
      # This is a TF operation, so it should be found in the op_defs.
      op_name = name.qn[1]
      op_def, _ = self._op_defs.lookup(op_name)
      if len(op_def.output_arg) == 1:
        return {_get_type_from_proto(op_def.output_arg[0])}, None
      return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
              None)

    elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
      assert name.is_simple()
      if name == QN('range'):
        return {TFRTypes.ATTR}, None

      if name == QN('len'):
        return {TFRTypes.INDEX}, None

    elif f_type == (TFRTypes.TFR_BUILTIN_FUNC,):
      op_name = name.qn[0]
      if callable(TFR_BUILTINS[op_name]):
        return {TFR_BUILTINS[op_name](*[list(arg)[0] for arg in args])}, None
      return {TFR_BUILTINS[op_name]}, None

    elif f_type == (TFRTypes.TF_TENSOR_SHAPE_FUNC,):
      return {TFRTypes.TF_TENSOR_SHAPE_LIST}, None

    elif f_type == (types.FunctionType,):
      # This is a function call which isn't using tf.raw_ops.
      op_name = name.qn[0]

      # 'new TF operation' produces outputs defined by the composition
      # function.
      op_def, _ = self._op_defs.lookup(op_name)
      if len(op_def.output_arg) == 1:
        return {_get_type_from_proto(op_def.output_arg[0])}, None
      return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
              None)

    raise NotImplementedError('Function:', name, f_type)

  def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
    if f_is_local:
      f_name_str = str(f_name)
      if f_name_str in self._for_loop_target_types:
        # See autograph/converters/control_flow.py - the function has a single
        # argument, the iterate before any expansion.
        assert self._for_loop_target_types[f_name_str] & {TFRTypes.ATTR}
        # Assume all loops are TF loops. Then the iterates are autoboxed into
        # Tensors.
        return {TFRTypes.INDEX}
      else:
        return None

    func = ns[f_name]

    op_def, derived_attrs = self._op_defs.lookup(f_name, func)
    if op_def is None:
      return None
    pos = tf_inspect.getfullargspec(func).args.index(str(name))

    if pos < len(op_def.input_arg):
      arg_def = op_def.input_arg[pos]
      return {_get_type_from_proto(arg_def)}
    elif pos < len(op_def.input_arg) + len(op_def.attr) - len(derived_attrs):
      non_derived_attr_pos = pos - len(op_def.input_arg)
      for attr_def in op_def.attr:
        # derived attribute, skip this one and continue to the next one.
        if attr_def.name in derived_attrs:
          continue
        if non_derived_attr_pos == 0:
          return {_get_type_from_proto(None, attr_def)}
        non_derived_attr_pos -= 1

    raise ValueError('Argument is not defined in OpDef: ' + str(name))

  def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
    if not value:
      return value

    if isinstance(value, set):
      type_tuple = value.pop()
      if isinstance(type_tuple, tuple):
        value = {type_tuple[node_or_slice]}
      else:
        value = {type_tuple}

    assert len(value) == 1
    value, = tuple(value)
    if value == TFRTypes.TF_TENSOR_SHAPE_LIST:
      # TODO(mdan): This is not entirely correct for multi-element slices.
      return {int}
    elif value in (TFRTypes.TENSOR_LIST, TFRTypes.TENSOR):
      # TODO(mdan): This is not entirely correct for multi-element slices.
      return {TFRTypes.TENSOR}
    else:
      return {value}

  def res_compare(self, ns, types_ns, node, left, right):
    # TODO(fengliuai): make sure left and right are compatible
    return {TFRTypes.I1}

  def res_unop(self, ns, types_ns, node, opnd):
    return opnd

  def res_binop(self, ns, types_ns, node, left, right):
    # TODO(fengliuai): make sure left and right are compatible
    return left

  def _coerce_to_more_specific_type(self, elt_types):
    # TODO(mdan): This needs some type theory study.
    if TFRTypes.INDEX in elt_types:
      # Constants collapse to indices.
      elt_types.discard(TFRTypes.I64)
    if TFRTypes.TENSOR in elt_types:
      # Constants collapse to tensors.
      elt_types.discard(TFRTypes.I64)
      # Indices collapse to tensors.
      elt_types.discard(TFRTypes.INDEX)
    return elt_types

  def res_list_literal(self, ns, elt_types):
    all_elt_types = set()
    for t in elt_types:
      all_elt_types |= t

    if len(all_elt_types) != 1:
      all_elt_types = self._coerce_to_more_specific_type(all_elt_types)

    if len(all_elt_types) != 1:
      raise ValueError('ambiguous list element types: {}'.format(elt_types))

    if TFRTypes.TENSOR in all_elt_types:
      return {TFRTypes.TENSOR_LIST}
    return {TFRTypes.ATTR}


class SymbolTable(object):
  """Symbol Table for python code."""

  def __init__(self):
    self.symbols = []
    self.enter_scope()
    self.scf_scope = 0
    # reserved key words
    self.insert_symbol('len', 'len', TFRTypes.PY_BUILTIN_FUNC)

  def enter_scope(self, scf_scope=False):
    """Enter a new scope - at function level."""
    self.symbols.append({'types': {}, 'symbols': {}})
    self.curr_table = self.symbols[len(self.symbols) - 1]
    if scf_scope:
      self.scf_scope += 1

  def insert_symbol(self, name, value, type_):
    self.curr_table['symbols'][name] = (value, type_)
    # TODO(mdan): Use the inferred type rather than tracking it here.
    # The following field is deprecated.
    self.curr_table['types'][name] = type_
    return value

  def exit_scope(self):
    self.symbols.pop()
    self.curr_table = self.symbols[len(self.symbols) - 1]
    if self.scf_scope > 0:
      self.scf_scope -= 1

  def in_scf_scope(self):
    return self.scf_scope > 0

  def lookup(self, name):
    curr_idx = len(self.symbols) - 1
    while curr_idx >= 0 and (name not in self.symbols[curr_idx]['symbols']):
      curr_idx -= 1
    if curr_idx < 0:
      return None
    return self.symbols[curr_idx]['symbols'][name]
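

# Illustrative sketch only (not executed): the SymbolTable mirrors Python
# scoping during code generation, e.g.
#   table = SymbolTable()
#   table.insert_symbol('x', '%x', TFRTypes.TENSOR)
#   table.enter_scope(scf_scope=True)   # entering an scf.if/scf.for region
#   table.lookup('x')                   # -> ('%x', TFRTypes.TENSOR)
#   table.exit_scope()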


class TFRGen(transformer.CodeGenerator):
  """Visit the AST and generate MLIR TFR functions."""

  def __init__(self, ctx, op_defs):
    super(TFRGen, self).__init__(ctx)
    self.ctx = ctx
    self.symbol_table = SymbolTable()
    self._op_defs = op_defs

  def _create_mlir_loc(self, loc):
    """Creates mlir location from autograph ORIGIN value.

    Args:
      loc: OriginInfo

    Returns:
      A serialized mlir location string.
    """
    if loc is not None and loc.loc.filename:
      file_name = os.path.basename(loc.loc.filename)
      return 'loc("{}":{}:{})'.format(file_name, loc.loc.lineno,
                                      loc.loc.col_offset)
    else:
      return 'loc(unknown)'

  def _emit_with_loc(self, op_str, node=None):
    """Emit the mlir operation with the location associated with the node.

    Args:
      op_str: The mlir operation string to be emitted.
      node: The node of the AST tree, the mlir operation translated from.
    """
    loc = ''
    if node:
      loc = self._create_mlir_loc(
          anno.getanno(node, anno.Basic.ORIGIN, default=None))
    self.emit(op_str + ' ' + loc)

  def _get_inferred_type(self, node, default=None):
    """Return single type or a tuple of types if more than one type."""
    types_ = anno.getanno(node, anno.Static.TYPES, None)
    if not types_:
      print('WARN: no Static.TYPES annotation. Fix the type inference pass: ')
      self.debug_print(node)
      return default

    if len(types_) == 1:
      type_, = types_
    else:
      type_ = types_

    if default is not None and type_ != default:
      print('WARN: type annotation {}({}) does not match {}({})'.format(
          type_, type(type_), default, type(default)))
      self.debug_print(node)

    return type_

  def _pack_tensor_list(self, value):
    # When packing a list of tensors, the axis is 0.
    axis = self._ssa_name('zero')
    self._emit_with_loc('\n{} = arith.constant 0 : i64'.format(axis))
    casted = self._ssa_name('pack')
    self.emit('\n{} = tfr.call @tf__pack({}, {})'.format(casted, value, axis))
    self._emit_with_loc(' : (!tfr.tensor_list, i64) -> !tfr.tensor')
    # load the op def of tf.Pack
    self._op_defs.lookup('Pack')
    return casted, TFRTypes.TENSOR

  def _index_to_I64(self, value, ty):
    if ty == TFRTypes.INDEX:
      casted = self._ssa_name('casted')
      self._emit_with_loc('\n{} = arith.index_cast {} : index to i64'.format(
          casted, value))
      return casted, TFRTypes.I64
    else:
      return value, ty

  def _i64_to_index(self, value, ty):
    if ty == TFRTypes.I64:
      casted = self._ssa_name('casted')
      self._emit_with_loc('\n{} = arith.index_cast {} : i64 to index'.format(
          casted, value))
      return casted, TFRTypes.INDEX
    else:
      return value, ty

  def _value_to_tensor(self, value, ty, node):
    value, ty = self._index_to_I64(value, ty)
    cst_tensor = self._ssa_name('cst')
    self.emit('\n{} = "tfr.constant_tensor"({})'.format(cst_tensor, value))
    self._emit_with_loc(' : ({}) -> !tfr.tensor'.format(ty), node)
    return cst_tensor, TFRTypes.TENSOR

  def _ssa_name(self, prefix):
    if isinstance(prefix, qual_names.QN):
      assert prefix.is_simple(), 'ANF transform should have cleaned this up'
      prefix = prefix.ssf()
    return '%' + self.ctx.namer.new_symbol(prefix, set())

  def _op_def(self, op_name):
    return op_def_registry.get(op_name)

  def visit_block(self, block):
    return [self.visit(item) for item in block]

  def visit_Pass(self, node):
    if self.symbol_table.in_scf_scope():
      self._emit_with_loc('\nscf.yield', node)
    else:
      self._emit_with_loc('\ntfr.return', node)

  def visit_Attribute(self, node):
    node_type = self._get_inferred_type(node, None)
    if isinstance(node.value, ast.Name):
      if node.value.id == 'ag__':
        # some variables are assigned with 'ag__.xxx' method, we should handle
        # them following the autograph conventions.
        return (node.attr, TFRTypes.AG_BUILTIN_FUNC)

      if node_type == TFRTypes.TF_RAW_OP:
        # This branch is used when it is inside tensorflow
        return (node.attr, TFRTypes.TF_RAW_OP)

      if node_type == TFRTypes.ATTR:
        attr = self._ssa_name('attr')
        tfr_type = _TF_DTYPE_TO_TFR.get(node.attr)
        self._emit_with_loc(
            '\n{} = tfr.constant {} -> !tfr.attr'.format(attr, tfr_type), node)
        return (attr, TFRTypes.ATTR)

      value, _ = self.visit(node.value)
      tensor_type = self._get_inferred_type(node.value, None)
      # TODO(fengliuai): use node_type once it
      if node_type == TFRTypes.SHAPE:
        print('TODO: use "node_type"')
      if node.attr == 'shape' and tensor_type == TFRTypes.TENSOR:
        ssa_value = self._ssa_name('shape')
        self._emit_with_loc(
            '\n{} = tfr.get_shape {} -> !shape.shape'.format(ssa_value, value),
            node)
        return (ssa_value, TFRTypes.SHAPE)

    if isinstance(node.value, ast.Attribute):
      if isinstance(node.value.value, ast.Name):
        if node.value.value.id == 'tf' and node.value.attr == 'raw_ops':
          return (node.attr, TFRTypes.TF_RAW_OP)

      value, ty = self.visit(node.value)
      # TODO(fengliuai): use node_type once it
      if node_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
        print('TODO: use "node_type"')
      if ty == TFRTypes.SHAPE and node.attr == 'as_list':
        return (value, TFRTypes.TF_TENSOR_SHAPE_FUNC)

    raise NotImplementedError('Attribute kind not recognized.')

  def visit_Assign(self, node):
    values = self.visit(node.value)
    if isinstance(node.targets[0], ast.Tuple):
      targets = [elt.id for elt in node.targets[0].elts]
    elif isinstance(node.targets[0], ast.Name):
      targets = [node.targets[0].id]
    else:
      raise NotImplementedError('Assignment target type not recognized.')

    if isinstance(values, list):
      if isinstance(node.value, ast.Call):
        expected = tuple(t for n, t in values)
        if len(values) == 1:
          expected = expected[0]
      elif isinstance(node.value, ast.Tuple):
        expected = tuple(t for n, t in values)
      else:
        raise ValueError('unknown assignment target node', node.value)
      ty = self._get_inferred_type(node.value, expected)

      if len(targets) == len(values):
        # TODO(mdan): This should already be a tuple.
        ty_ = (ty,) if len(values) == 1 else ty
        for key, value, t in zip(targets, values, ty_):
          ssa_value, _ = value
          self.symbol_table.insert_symbol(key, ssa_value, t)
      elif len(values) == 1:
        name, tys = values[0]
        if ty == TFRTypes.TENSOR_LIST:
          # assign single tensor_list to multiple variables
          for idx, key in enumerate(targets):
            idx_name = self._ssa_name('idx')
            self._emit_with_loc(
                '\n{} = arith.constant {} : index'.format(idx_name, idx), node)
            elt_name = self._ssa_name('elt')
            self.emit('\n{} = tfr.get_element {}[{}]'.format(
                elt_name, name, idx_name))
            self._emit_with_loc(' : (!tfr.tensor_list, index) -> !tfr.tensor',
                                node)
            self.symbol_table.insert_symbol(key, elt_name, TFRTypes.TENSOR)
        else:
          # assign single value to multiple targets. This single value is
          # usually a function return. The return type should be in the tuple
          # of the value.
          for idx, key in enumerate(targets):
            ssa_name = '{}#{}'.format(name, idx)
            ssa_type = tys[idx]
            self.symbol_table.insert_symbol(key, ssa_name, ssa_type)
      elif len(targets) == 1:
        ssa_names = [n for n, _ in values]
        self.symbol_table.insert_symbol(targets[0], ssa_names, ty)
      return

    ty = self._get_inferred_type(node.value, values[1])
    self.symbol_table.insert_symbol(targets[0], values[0], ty)

  def _emit_binary_op(self, op, lhs, lhs_ty, rhs, rhs_ty):
    assert lhs_ty, rhs_ty
    if isinstance(op, ast.Sub):
      code = 'arith.sub'
    elif isinstance(op, ast.Add):
      code = 'arith.add'
    elif isinstance(op, ast.Mult):
      code = 'arith.mul'
    elif isinstance(op, ast.Div):
      code = 'arith.div'
    else:
      raise NotImplementedError('BinOp operator not recognized: ' + str(op))

    if lhs_ty == TFRTypes.I64 or lhs_ty == TFRTypes.I32:
      suffix = 'i'
    elif lhs_ty == TFRTypes.F32:
      suffix = 'f'
    else:
      raise NotImplementedError('BinOp operand type not recognized: ' +
                                str(op))

    ret = self._ssa_name(code)
    self._emit_with_loc(
        '\n{} = {}{} {}, {} : {}'.format(ret, code, suffix, lhs, rhs, lhs_ty),
        op)
    return ret, lhs_ty

  def visit_AugAssign(self, node):
    lhs, lhs_ty = self.visit(node.target)
    rhs, rhs_ty = self.visit(node.value)
    ret, ret_ty = self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty)
    self.symbol_table.insert_symbol(node.target.id, ret, ret_ty)

  def visit_BinOp(self, node):
    lhs, lhs_ty = self.visit(node.left)
    rhs, rhs_ty = self.visit(node.right)
    return self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty)

  def visit_BoolOp(self, node):
    values = [self.visit(value) for value in node.values]
    # TODO(fengliuai): Handle more ast node types.
    if isinstance(node.op, ast.Or):
      raise NotImplementedError('Or operator not recognized')
    elif isinstance(node.op, ast.And):
      raise NotImplementedError('And operator not recognized')

  def visit_Call(self, node):
    func_name, func_type = self.visit(node.func)
    func_type = self._get_inferred_type(node.func, func_type)
    if func_type == TFRTypes.AG_BUILTIN_FUNC:
      if func_name == 'if_stmt':
        cond, _ = self.visit(node.args[0])
        body, _ = self.visit(node.args[1])
        orelse, _ = self.visit(node.args[2])
        get_state, _ = self.visit(node.args[3])
        nouts = int(node.args[6].value)
        out_symbols = []
        # The out symbols are just a Tuple of names
        for out in node.args[5].elts[:nouts]:
          val, ty = self.symbol_table.lookup(out.value)
          out_symbols.append(out.value)
        return self._visit_if_stmt(cond, body, orelse, get_state, out_symbols,
                                   node)
      elif func_name == 'for_stmt':
        range_ = self._visit_iter(node.args[0])
        body, _ = self.visit(node.args[2])
        get_state, _ = self.visit(node.args[3])
        loop_carried = [out.value for out in node.args[5].elts]
        # TODO(fengliuai): opt is not used here.
        return self._visit_for_stmt(range_, body, get_state, loop_carried,
                                    node)
      elif func_name == 'Undefined':
        val = self._ssa_name(node.args[0].value)
        return (val, TFRTypes.AG_UNDEFINED_VAL)
      elif func_name == 'UndefinedReturnValue':
        val = self._ssa_name('return_val')
        return (val, TFRTypes.AG_UNDEFINED_VAL)

    if func_type == TFRTypes.TF_RAW_OP:
      return self._visit_tf_op(func_name, node.args, node.keywords, node)

    if func_type == TFRTypes.TFR_BUILTIN_FUNC:
      return self._visit_tfr_builtins(func_name, node.args, node)

    if func_type == types.FunctionType:
      return self._visit_tf_op(func_name, node.args, node.keywords, node)

    if func_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
      return (func_name, TFRTypes.TF_TENSOR_SHAPE_LIST)

    if func_type == TFRTypes.PY_BUILTIN_FUNC:
      if func_name == 'len':
        arg, ty = self.visit(node.args[0])
        ty = self._get_inferred_type(node.args[0], ty)
        if ty == TFRTypes.TF_TENSOR_SHAPE_LIST:
          len_value = self._ssa_name('len')
          self._emit_with_loc(
              '\n{} = shape.rank {} : !shape.shape -> !shape.size'.format(
                  len_value, arg), node)
          size_value = self._ssa_name('len_size')
          self._emit_with_loc(
              '\n{} = shape.size_to_index {} : !shape.size'.format(
                  size_value, len_value), node)
        elif ty == TFRTypes.TENSOR_LIST:
          size_value = self._ssa_name('len')
          self._emit_with_loc(
              '\n{} = tfr.get_length {} -> index'.format(size_value, arg),
              node)
        return (size_value, TFRTypes.INDEX)

    raise NotImplementedError('call operator not recognized: {} {}'.format(
        func_name, func_type))

  def visit_Compare(self, node):
    lhs, lhs_ty = self.visit(node.left)
    for op, right in zip(node.ops, node.comparators):
      rhs, rhs_ty = self.visit(right)
      if isinstance(op, ast.Eq):
        pred = 'eq'
      elif isinstance(op, ast.Lt):
        pred = 'ult'
      elif isinstance(op, ast.LtE):
        pred = 'ule'
      elif isinstance(op, ast.Gt):
        pred = 'ugt'
      elif isinstance(op, ast.GtE):
        pred = 'uge'
      elif isinstance(op, ast.NotEq):
        pred = 'ne'
      else:
        raise NotImplementedError('Compare operator not recognized')

      ret = self._ssa_name(pred)
      if lhs_ty == TFRTypes.ATTR:
        self._emit_with_loc(
            '\n{} = tfr.equal {}, {} -> i1'.format(ret, lhs, rhs), node)
      else:
        if lhs_ty == TFRTypes.I64:
          code = 'arith.cmpi'
        elif lhs_ty == TFRTypes.F32:
          code = 'arith.cmpf'
        elif lhs_ty == TFRTypes.INDEX:
          code = 'arith.cmpi'
          # TODO(fengliuai): the reverse type inference should solve the issue.
          rhs, _ = self._i64_to_index(rhs, rhs_ty)
        else:
          raise NotImplementedError('Compare operand type not recognized')
        self._emit_with_loc(
            '\n{} = {} "{}", {}, {} : {}'.format(ret, code, pred, lhs, rhs,
                                                 lhs_ty), node)

    return ret, TFRTypes.I1

  def visit_Constant(self, node):
    cst_name = self._ssa_name('cst')
    if node.value is None:
      cst_ty = TFRTypes.NONE
    elif isinstance(node.value, bool):
      cst_ty = self._get_inferred_type(node)
      cst_val = str(node.value).lower()
      self._emit_with_loc('\n{} = arith.constant {}'.format(cst_name, cst_val),
                          node)
    else:
      cst_ty = self._get_inferred_type(node)
      cst_val = node.value
      if cst_ty == TFRTypes.ATTR:
        self._emit_with_loc(
            '\n{} = tfr.constant "{}" -> {}'.format(cst_name, cst_val, cst_ty),
            node)
      else:
        self._emit_with_loc(
            '\n{} = arith.constant {} : {}'.format(cst_name, cst_val, cst_ty),
            node)
    return cst_name, cst_ty

  def visit_FunctionDef(self, node):
    op_def, derived_attrs = self._op_defs.lookup(node.name, node, True)
    if op_def is None:
      # Nested function. Insert it to symbol table for looking up later.
      self.symbol_table.insert_symbol(node.name, node, None)
      return
    op_name = op_def.name
    if self.symbol_table.lookup(op_name):
      raise LookupError('Composition has not been registered for op: ' +
                        op_name)
    else:
      self.symbol_table.insert_symbol(node.name, None, None)

    self.symbol_table.enter_scope()
    self.emit('\ntfr.func @tf__{0}('.format(_camel_to_snake(op_name)))

    arg_list = []
    idx = 0
    max_idx = len(op_def.input_arg) + len(op_def.attr)
    for arg in node.args.args:
      arg_name = self._ssa_name(anno.getanno(arg, anno.Basic.QN))
      arg_type = anno.getanno(arg, anno.Static.TYPES)[0]

      arg_attr = ''
      if idx >= len(op_def.input_arg):
        attr_def = op_def.attr[idx - len(op_def.input_arg)]
        # skip the derived attributes
        while attr_def.name in derived_attrs and (idx + 1) < max_idx:
          idx += 1
          attr_def = op_def.attr[idx - len(op_def.input_arg)]
        if idx >= max_idx:
          raise ValueError('Argument is not defined in OpDef: ' + arg_name)

        arg_attr += '{{tfr.name="{}"'.format(attr_def.name)
        if attr_def.HasField('default_value'):
          default_val = _get_val_from_proto(arg_type, attr_def.default_value)
          arg_attr += ',tfr.default={}'.format(default_val)
        arg_attr += '}'

      idx += 1
      arg_str = '{}: {}{}'.format(arg_name, arg_type, arg_attr)
      arg_list.append(arg_str)
      self.symbol_table.insert_symbol(arg.id, arg_name, arg_type)

    ret_type_list = []
    for ret_def in op_def.output_arg:
      if ret_def.number_attr or ret_def.type_list_attr:
        ret_type_list.append(str(TFRTypes.TENSOR_LIST))
      else:
        ret_type_list.append(str(TFRTypes.TENSOR))

    self.emit('{}) -> ({}) {{'.format(', '.join(arg_list),
                                      ', '.join(ret_type_list)))
    self.visit_block(node.body)
    self._emit_with_loc('\n}', node)
    self.symbol_table.exit_scope()
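
  # Illustrative sketch only (not emitted verbatim): for a composite registered
  # for a hypothetical op 'MyOp' with one tensor input and one attribute
  # 'axis', visit_FunctionDef above produces a header along the lines of
  #   tfr.func @tf__my_op(%x: !tfr.tensor, %axis: i64{tfr.name="axis"}) -> (!tfr.tensor) {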

  def visit_arguments(self, node):
    # TODO(fengliuai): return the ordered types and names.
    # We need to order the arguments to match the assumption in the TFR
    # dialect.
    raise NotImplementedError('arguments not supported.')

  def visit_Lambda(self, node):
    raise NotImplementedError('Lambda not supported.')

  def _get_mlir_ssa_values(self, name_prefix, out_types):
    """Create MLIR convention SSA values."""
    out_ssa_values = []
    if not out_types:
      return '', out_ssa_values

    out_name = self._ssa_name(name_prefix)
    if len(out_types) == 1:
      out_name_suffix = ''
      out_ssa_values.append(out_name)
    else:
      # For multiple returns, MLIR uses '%s:i' when they are defined and
      # '%s#i' when they are used.
      out_name_suffix = ':{}'.format(len(out_types))
      for idx, _ in enumerate(out_types):
        out_ssa_values.append('{}#{}'.format(out_name, idx))

    return '{}{}'.format(out_name, out_name_suffix), out_ssa_values

  def _visit_if_stmt(self, cond, body_def, orelse_def, get_state, out_symbols,
                     node):
    self.emit('\n')
    ret_str, ret_ssa_values = self._get_mlir_ssa_values(
        'if_stmt', [TFRTypes.TENSOR] * len(out_symbols))
    if ret_ssa_values:
      self.emit(ret_str + ' = ')

    out_types = []
    for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
      out_types.append(str(TFRTypes.TENSOR))

    self.emit('scf.if {} -> ({}) {{'.format(cond, ', '.join(out_types)))
    # Create a new scope in case the local variables are leaked.
    self.symbol_table.enter_scope(scf_scope=True)
    self.visit_block(body_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()

    self.emit('\n} else {')

    # Create a new scope in case the local variables are leaked.
    self.symbol_table.enter_scope(scf_scope=True)
    self.visit_block(orelse_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()

    # add ssa values to the symbol table
    for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
      self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR)

    self._emit_with_loc('\n}', node)
    return list(zip(ret_ssa_values, out_types))
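
  # Illustrative sketch only (not emitted verbatim): with two out symbols,
  # _visit_if_stmt produces MLIR shaped like
  #   %if_stmt:2 = scf.if %cond -> (!tfr.tensor, !tfr.tensor) { ... } else { ... }
  # and the out symbols are then bound to %if_stmt#0 and %if_stmt#1.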

  def _visit_iter(self, node):
    if isinstance(node, ast.Call):
      f_name = anno.getanno(node.func, anno.Basic.QN)
      if f_name == QN('range'):
        args = [self.visit(arg) for arg in node.args]
        begin = None
        step = None
        end = None
        if len(args) == 1:
          end, end_ty = args[0]
        elif len(args) == 2:
          begin, begin_ty = args[0]
          end, end_ty = args[1]
        elif len(args) == 3:
          begin, begin_ty = args[0]
          end, end_ty = args[1]
          step, step_ty = args[2]

        if begin is None:
          begin = self._ssa_name('begin')
          self._emit_with_loc('\n{} = arith.constant 0 : index'.format(begin),
                              node)
        elif begin_ty != TFRTypes.INDEX:
          begin_ = self._ssa_name('begin')
          self._emit_with_loc(
              '\n{} = arith.index_cast {} : {} to index'.format(
                  begin_, begin, begin_ty), node)
          begin = begin_

        if end_ty != TFRTypes.INDEX:
          end_ = self._ssa_name('end')
          self._emit_with_loc(
              '\n{} = arith.index_cast {} : {} to index'.format(
                  end_, end, end_ty), node)
          end = end_

        if step is None:
          step = self._ssa_name('step')
          self._emit_with_loc('\n{} = arith.constant 1 : index'.format(step),
                              node)
        elif step_ty != TFRTypes.INDEX:
          step_ = self._ssa_name('step')
          self._emit_with_loc(
              '\n{} = arith.index_cast {} : {} to index'.format(
                  step_, step, step_ty), node)
          step = step_

        return begin, end, step

    raise NotImplementedError('Iterator entity not supported: ' + str(node))

  def _visit_for_stmt(self, range_, body_def, get_state, loop_carried, node):
    self.emit('\n')
    ret_str, ret_ssa_values = self._get_mlir_ssa_values(
        'for_stmt', [TFRTypes.TENSOR] * len(loop_carried))
    if ret_ssa_values:
      self.emit(ret_str + ' = ')

    # Before entering the loop, we use the original ssa values as the initial
    # values for the loop iteration arguments. We also create new ssa values as
    # the returns of the scf for statements. The symbol table needs to be
    # updated to these new ssa values before it enters the scope of the loop.
    out_types = []
    init_values = []
    for symbol, ssa_value in zip(loop_carried, ret_ssa_values):
      init, ty = self.symbol_table.lookup(symbol)
      self.symbol_table.insert_symbol(symbol, ssa_value, ty)
      out_types.append(str(ty))
      init_values.append((init, ty))

    # Create a new scope in case the local variables are leaked.
    self.symbol_table.enter_scope(scf_scope=True)

    # Create the iteration variable with index type
    assert len(body_def.args.args) == 1
    it_name = body_def.args.args[0].id
    it = self._ssa_name(it_name)
    self.symbol_table.insert_symbol(it_name, it, TFRTypes.INDEX)

    self.emit('scf.for {} = {} to {} step {} '.format(it, range_[0], range_[1],
                                                      range_[2]))
    if loop_carried:
      iter_args = []
      for symbol, init in zip(loop_carried, init_values):
        # create new ssa values for the loop carried variables
        it_arg = self._ssa_name('it_arg')
        self.symbol_table.insert_symbol(symbol, it_arg, init[1])
        iter_args.append('{} = {}'.format(it_arg, init[0]))
      self.emit('iter_args({}) '.format(', '.join(iter_args)))
      self.emit('-> ({}) {{'.format(', '.join(out_types)))
    else:
      self.emit(' {')
    self.visit_block(body_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()
    self._emit_with_loc('\n}', node)
    return list(zip(ret_ssa_values, out_types))

  def _emit_default_constant_from_proto(self, attr_def):
    """emit mlir constant statement from default value of the ArgDef proto."""
    name = self._ssa_name('cst')
    cst_ty = _get_type_from_proto(None, attr_def)
    try:
      cst_val = _get_val_from_proto(cst_ty, attr_def.default_value)
    except AttributeError:
      raise AttributeError(
          f'attribute "{attr_def.name}" does not have default_value. If the '
          "attribute names from the API and OpDef don't match, please add it "
          'to _ATTRIBUTE_RENAMES.')
    if cst_ty == TFRTypes.ATTR:
      self._emit_with_loc('\n{} = tfr.constant {} -> {}'.format(
          name, cst_val, cst_ty))
    elif cst_ty == TFRTypes.I1:
      self._emit_with_loc('\n{} = arith.constant {}'.format(name, cst_val))
    else:
      self._emit_with_loc('\n{} = arith.constant {} : {}'.format(
          name, cst_val, cst_ty))
    return name, cst_ty

  def visit_keyword(self, node):
    return node.arg, self.visit(node.value)
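
  # Illustrative sketch only (not emitted verbatim): a call to
  # _tfr_quant_act_range(...) goes through _visit_tfr_builtins below and emits
  # something like
  #   %ret:2 = tfr.quant_act_range(...) : (...) -> (!tfr.tensor, !tfr.tensor)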

  def _visit_tfr_builtins(self, op_name, args, node):
    arg_strs = []
    arg_tys = []
    for arg in args:
      value, ty = self.visit(arg)
      arg_strs.append(value)
      arg_tys.append(ty)
    tfr_op_name = 'tfr.' + op_name[5:]
    ret_tys = (
        TFR_BUILTINS[op_name](*arg_tys)
        if callable(TFR_BUILTINS[op_name]) else TFR_BUILTINS[op_name])
    # Convert the tfr builtin returns to a list.
    if isinstance(ret_tys, tuple):
      ret_tys = list(ret_tys)
    else:
      ret_tys = [ret_tys]

    ret_str, ret_ssa_values = self._get_mlir_ssa_values(op_name, ret_tys)

    arg_str = ', '.join(arg_strs)
    arg_ty_str = ', '.join(str(ty) for ty in arg_tys)
    ret_ty_str = ', '.join(str(ty) for ty in ret_tys)
    self._emit_with_loc('\n{} = {}({}) : ({}) -> ({})'.format(
        ret_str, tfr_op_name, arg_str, arg_ty_str, ret_ty_str), node)
    return list(zip(ret_ssa_values, ret_tys))

  def _visit_tf_op(self, op_name, args, keywords, node):
    op_def, derived_attrs = self._op_defs.lookup(op_name)
    ret_tys = [_get_type_from_proto(arg) for arg in op_def.output_arg]

    ret_str, ret_ssa_values = self._get_mlir_ssa_values(op_name, ret_tys)

    arg_strs = []
    ty_strs = []
    for arg in args:
      value, ty = self.visit(arg)
      arg_strs.append(value)
      ty_strs.append(str(ty))

    input_args = [arg for arg in op_def.input_arg]
    attrs_no_default = [
        attr for attr in op_def.attr
        if not attr.HasField('default_value') and attr.name not in derived_attrs
    ]
    attrs_with_default = [
        attr for attr in op_def.attr
        if attr.HasField('default_value') and attr.name not in derived_attrs
    ]

    kw_args = {}
    for arg in keywords:
      value, (ssa_name, ty) = self.visit(arg)
      ty = self._get_inferred_type(arg.value, ty)

      # TODO(b/203493652): see comment on _ATTRIBUTE_RENAMES
      if op_name in _ATTRIBUTE_RENAMES and value in _ATTRIBUTE_RENAMES[op_name]:
        value = _ATTRIBUTE_RENAMES[op_name][value]

      kw_args[value] = (ssa_name, ty)

    # tensor arguments and attribute arguments
    ordered_args = input_args + attrs_no_default + attrs_with_default
    for attr_def in ordered_args[len(args):]:
      if attr_def.name in kw_args:
        value, ty = kw_args[attr_def.name]
        if attr_def in input_args:
          if ty in _ATTRIBUTE_TYPES:
            # the argument shouldn't be passed to the tf op call directly, so
            # cast it to a tensor.
            value, ty = self._value_to_tensor(value, ty, node)
          if ty is TFRTypes.TENSOR_LIST and not _require_tensor_list(attr_def):
            value, ty = self._pack_tensor_list(value)
      else:
        value, ty = self._emit_default_constant_from_proto(attr_def)
      arg_strs.append(value)
      ty_strs.append(str(ty))

    if ret_ssa_values:
      self.emit('\n{} = '.format(ret_str))

    self.emit('tfr.call @tf__{}('.format(_camel_to_snake(op_name)))
    arg_str = ', '.join(arg_strs)
    arg_ty_str = ', '.join(ty_strs)
    ret_ty_str = ', '.join([str(ty) for ty in ret_tys])
    self._emit_with_loc(
        '{}) : ({}) -> ({})'.format(arg_str, arg_ty_str, ret_ty_str), node)
    return list(zip(ret_ssa_values, ret_tys))

  def visit_If(self, node):
    raise NotImplementedError('If not supported.')

  def visit_Name(self, node):
    val_and_lookup_type = self.symbol_table.lookup(node.id)
    if val_and_lookup_type:
      (val, lookup_type) = val_and_lookup_type
    elif node.id in TFR_BUILTINS:
      val = node.id
      lookup_type = anno.getanno(node, anno.Static.TYPES, types.FunctionType)
    else:
      op_def, _ = self._op_defs.lookup(node.id)
      val = op_def.name
      lookup_type = anno.getanno(node, anno.Static.TYPES, types.FunctionType)
    type_ = self._get_inferred_type(node, lookup_type)
    return val, type_

  def visit_Return(self, node):
    values = self.visit(node.value)
    if self.symbol_table.in_scf_scope():
      self.emit('\nscf.yield ')
    else:
      self.emit('\ntfr.return ')
    if not values:
      return

    if isinstance(values, list):
      vals, tys = zip(*values)
    else:
      vals = values[0]
      tys = values[1]

    if isinstance(tys, list) or isinstance(tys, tuple):
      tys = [str(t) for t in tys]
      self._emit_with_loc('{} : {}'.format(', '.join(vals), ', '.join(tys)),
                          node)
    elif tys != TFRTypes.NONE:
      # TODO(fengliuai): scf region yield uses this branch. Fix it.
      self._emit_with_loc('{} : {}'.format(vals, tys), node)

  def visit_Subscript(self, node):
    val, ty = self.visit(node.value)
    type_ = self._get_inferred_type(node.value, ty)

    # TODO(fengliuai): Here we hardcode the node.slice to get the index
    # type. Use the visit method once the type inference is done.
    # slice_val, slice_ty = self.visit(node.slice)
    s = node.slice
    if not isinstance(s, (ast.Tuple, ast.Slice)):
      if isinstance(s, ast.Constant):
        # TODO(fengliuai): promote to an assignment
        idx_val = self._ssa_name('cst')
        self._emit_with_loc(
            '\n{} = arith.constant {} : index'.format(idx_val, s.value), node)
      else:
        idx_val, _ = self.visit(s)
    else:
      raise NotImplementedError('non-index slice not supported.')

    elt = self._ssa_name('elt')
    if type_ == TFRTypes.TENSOR_LIST:
      self.emit('\n{} = tfr.get_element {}[{}] '.format(elt, val, idx_val))
      self._emit_with_loc(': (!tfr.tensor_list, index) -> !tfr.tensor', node)
      return (elt, TFRTypes.TENSOR)
    elif type_ == TFRTypes.TF_TENSOR_SHAPE_LIST:
      size_ = self._ssa_name('size')
      self.emit('\n{} = shape.get_extent {}, {}'.format(size_, val, idx_val))
      self._emit_with_loc(': !shape.shape, index -> !shape.size', node)
      self._emit_with_loc(
          '\n{} = shape.size_to_index {} : !shape.size'.format(elt, size_),
          node)
      return (elt, TFRTypes.INDEX)

  def visit_List(self, node):
    out_type = self._get_inferred_type(node)
    vals = []
    tys = []
    for elt in node.elts:
      val, ty = self.visit(elt)
      ty = self._get_inferred_type(elt, ty)
      if ty in _ATTRIBUTE_TYPES and out_type == TFRTypes.TENSOR_LIST:
        # This list is a tensor list, so cast all the input values to tensors.
        val, ty = self._value_to_tensor(val, ty, node)
      else:
        # We shouldn't use the index type to build the list because the list
        # will be used as an attribute.
        val, ty = self._index_to_I64(val, ty)
      vals.append(val)
      tys.append(str(ty))

    list_val = self._ssa_name('list')
    self.emit('\n{} = "tfr.build_list"({})'.format(list_val, ', '.join(vals)))
    self._emit_with_loc(' : ({}) -> {}'.format(', '.join(tys), out_type), node)
    return (list_val, out_type)

  def visit_Tuple(self, node):
    return [self.visit(elt) for elt in node.elts]

  def visit_UnaryOp(self, node):
    value, ty = self.visit(node.operand)
    if isinstance(node.op, ast.USub):
      zero_value = self._ssa_name('zero')
      ssa_value = self._ssa_name('cst')
      if ty == TFRTypes.I32 or ty == TFRTypes.I64:
        self._emit_with_loc(
            '\n{} = arith.constant 0 : {}'.format(zero_value, ty), node)
        self._emit_with_loc(
            '\n{} = arith.subi {}, {} : {}'.format(ssa_value, zero_value,
                                                   value, ty), node)
      elif ty == TFRTypes.F32:
        self._emit_with_loc(
            '\n{} = arith.constant 0.0 : {}'.format(zero_value, ty), node)
        self._emit_with_loc(
            '\n{} = arith.subf {}, {} : {}'.format(ssa_value, zero_value,
                                                   value, ty), node)
      else:
        raise NotImplementedError('USub type not recognized: ' + str(ty))
      return ssa_value, ty
    raise NotImplementedError('USub operator not recognized')

  def visit_For(self, node):
    raise NotImplementedError('For operator not recognized')

  def visit_While(self, node):
    raise NotImplementedError('While operator not recognized')

  def visit_Try(self, node):
    # Only handles the body of the try statement.
    self.visit_block(node.body)


def _apply_py_to_tf_passes(node, ctx):
  """Apply transformations from PyToTF to match tf.function tracing."""
  # TODO(fengliuai): we don't know which passes are required, thus we evaluate
  # each one when the corresponding node is handled.
  # copied from PyToTF.transform_ast
  node = return_statements.transform(node, ctx, False)
  node = control_flow.transform(node, ctx)
  return node


class TfrGen(transpiler.GenericTranspiler):
  """Transforms Python objects into TFR MLIR source code."""

  def __init__(self, op_defs):
    self._op_defs = op_defs

  def transform_ast(self, node, ctx):
    node = _apply_py_to_tf_passes(node, ctx)
    # TODO(mdan): Enable this.
    # node = anf.transform(node, ctx)

    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx)
    node = reaching_definitions.resolve(node, ctx, graphs)
    node = reaching_fndefs.resolve(node, ctx, graphs)
    node = type_inference.resolve(node, ctx, graphs,
                                  TFRTypeResolver(self._op_defs))

    mlir_generator = TFRGen(ctx, self._op_defs)
    mlir_generator.visit(node)
    return mlir_generator.code_buffer


def tfr_gen(func, op_defs):
  """Parse a function and emit the TFR functions."""
  mlir_code, _ = TfrGen(op_defs).transform(func, None)
  assert tfr.verify(mlir_code), 'mlir code not verified: {}'.format(mlir_code)
  return mlir_code


def tfr_funcs_gen_from_module(source, op_defs, method_prefix=None,
                              op_libraries=None):
  """Parse the input source module and emit the TFR functions."""

  # Load the op library so the op is added to the op registry. This is
  # required when the op cc_library couldn't be statically linked in open
  # source.
  # This is a no op if the op shared library couldn't be found in the same
  # directory of the op Python API.
  # TODO(fengliuai): make the .so file path configurable.
  if op_libraries:
    prefix_len = len('gen_')
    for m in op_libraries:
      lib_dir = os.path.dirname(m.__file__)
      lib_name = os.path.basename(m.__file__)[prefix_len:].replace('.py', '.so')
      lib_path = os.path.join(lib_dir, lib_name)
      if os.path.exists(lib_path):
        logging.info('load file: ' + lib_path)
        load_library.load_op_library(lib_path)
  else:
    # The op library is generated from the source module, so we load all the
    # .so files in the directory.
    lib_dir = os.path.dirname(source.__file__)
    for lib_name in os.listdir(lib_dir):
      if lib_name.endswith('.so'):
        lib_path = os.path.join(lib_dir, lib_name)
        logging.info('load file: ' + lib_path)
        load_library.load_op_library(lib_path)

  py_funcs = [
      func
      for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction)
      if not method_prefix or name.startswith(method_prefix)
  ]
  # Sort the methods by the line number, to make sure the definitions are
  # processed before the usages.
  # TODO(fengliuai): Use type inference resolver to recursively process any
  # functions called.
  py_funcs = sorted(py_funcs, key=lambda x: x.__code__.co_firstlineno)
  mlir_funcs = [tfr_gen(func, op_defs) for func in py_funcs]

  return mlir_funcs


def tfr_gen_from_module(source, method_prefix=None, op_libraries=None,
                        op_defs=OpDefCache()):
  """Parse the input source module and emit the TFR and external functions."""
  mlir_funcs = tfr_funcs_gen_from_module(
      source, op_defs, method_prefix, op_libraries)

  return '\n'.join(mlir_funcs + op_defs.mlir_external_funcs())
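

# Illustrative usage only (not executed): given a hypothetical module
# `my_composite_ops` whose functions are decorated with @Composite, the MLIR
# for all of its composites plus the external op signatures could be obtained
# with something like:
#
#   mlir_code = tfr_gen_from_module(my_composite_ops,
#                                   method_prefix='_composite_')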