# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tfr_gen: Generate mlir tfr decomposition function from python code."""

# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import os
import re
import types
import gast as ast

from tensorflow.compiler.mlir.tfr import tfr_wrapper as tfr
from tensorflow.core.framework import types_pb2
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
from tensorflow.python.autograph.pyct.static_analysis import type_inference
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import load_library
from tensorflow.python.framework import op_def_registry
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect

# TODO(mdan): Use class definitions so that we can mix these with Python types.


class TFRTypes(enum.Enum):
  """All the supported types.

    1-3: tfr types
    4-99: mlir built-in types
    100-199: TF related translator internal types
    200- : Python related translator internal types
  """
  TENSOR = 1
  TENSOR_LIST = 2
  ATTR = 3
  NONE = 4
  SHAPE = 5  # shape -> !shape.shape
  I1 = 21
  I32 = 22
  I64 = 23
  F32 = 24
  INDEX = 25
  AG_UNDEFINED_VAL = 100
  AG_BUILTIN_FUNC = 101
  TF_RAW_OP = 102
  TF_REGION = 103
  TF_TENSOR_SHAPE_FUNC = 104  # shape.as_list
  TF_TENSOR_SHAPE_LIST = 105  # shape.as_list()
  PY_BUILTIN_FUNC = 200

  # As these are not real types, __getattribute__ helps them appear more like
  # actual types (i.e. class definitions).
  def __getattribute__(self, name):
    if name == 'shape' and object.__getattribute__(self, 'value') == 1:
      return TFRTypes.SHAPE
    if name == 'as_list' and object.__getattribute__(self, 'value') == 5:
      return TFRTypes.TF_TENSOR_SHAPE_FUNC
    return object.__getattribute__(self, name)

  def __str__(self):
    if self.value < 4:  # pylint: disable=comparison-with-callable
      return '!tfr.' + self.name.lower()
    elif self.value < 10:  # pylint: disable=comparison-with-callable
      return '!shape.' + self.name.lower()
    else:
      return self.name.lower()
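
# For illustration, str() of these enum members yields the type spelling used
# when emitting MLIR: str(TFRTypes.TENSOR) == '!tfr.tensor',
# str(TFRTypes.SHAPE) == '!shape.shape', and str(TFRTypes.I64) == 'i64'.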


_attribute_types = [
    TFRTypes.I1, TFRTypes.I32, TFRTypes.I64, TFRTypes.F32, TFRTypes.INDEX,
    TFRTypes.ATTR
]


def _get_type_from_proto(arg_def=None, attr_def=None):
  if not arg_def:
    if attr_def.type == 'bool':
      return TFRTypes.I1
    elif attr_def.type == 'int32':
      return TFRTypes.I32
    elif attr_def.type == 'int' or attr_def.type == 'int64':
      return TFRTypes.I64
    elif attr_def.type == 'float':
      return TFRTypes.F32
    else:
      return TFRTypes.ATTR

  if arg_def.number_attr or arg_def.type_list_attr:
    return TFRTypes.TENSOR_LIST
  else:
    return TFRTypes.TENSOR


def _get_type_info_from_proto(arg_def=None, attr_def=None):
  attr_type = _get_type_from_proto(arg_def, attr_def)
  if not arg_def:
    return '{}{{tfr.name="{}",tfr.type="{}"}}'.format(
        attr_type, attr_def.name, attr_def.type)
  else:
    attr_names = []
    if arg_def.number_attr:
      attr_names.append(arg_def.number_attr)
    if arg_def.type_attr:
      attr_names.append(arg_def.type_attr)
    if arg_def.type_list_attr:
      attr_names.append(arg_def.type_list_attr)

    # TODO(fengliuai): currently we don't support backward type inference, so
    # we have to store these non-derivable types in the signatures, and then
    # they can be used to cast the values when raising to tf ops.
    if arg_def.type == types_pb2.DT_FLOAT:
      attr_names.append('f32_')
    elif arg_def.type == types_pb2.DT_INT32:
      attr_names.append('i32_')
    elif arg_def.type == types_pb2.DT_INT64:
      attr_names.append('i64_')
    elif arg_def.type == types_pb2.DT_BOOL:
      attr_names.append('i1_')

    if not attr_names:
      return str(attr_type)
    else:
      return '{}<{}>'.format(attr_type, ','.join(attr_names))
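
# For illustration (with hypothetical proto fields): an attr_def named 'axis'
# of type 'int' yields 'i64{tfr.name="axis",tfr.type="int"}', and an arg_def
# with number_attr 'N' and type_attr 'T' yields '!tfr.tensor_list<N,T>'.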


def _get_val_from_proto(attr_type, attr_val):
  if attr_type == TFRTypes.I1:
    return 'true' if attr_val.b else 'false'
  elif attr_type == TFRTypes.I32 or attr_type == TFRTypes.I64:
    return attr_val.i
  elif attr_type == TFRTypes.F32:
    return attr_val.f
  elif attr_type == TFRTypes.ATTR:
    # string
    if attr_val.HasField('s'):
      return '"{}"'.format(attr_val.s.decode())
    # type
    if attr_val.HasField('type'):
      if attr_val.type == types_pb2.DT_FLOAT:
        return 'f32'
      elif attr_val.type == types_pb2.DT_INT32:
        return 'i32'
      elif attr_val.type == types_pb2.DT_INT64:
        return 'i64'
      elif attr_val.type == types_pb2.DT_BOOL:
        return 'i1'
    # list
    if attr_val.HasField('list'):
      if attr_val.list.f:
        elt_ty = TFRTypes.F32
        values = attr_val.list.f
      elif attr_val.list.i:
        elt_ty = TFRTypes.I64
        values = attr_val.list.i
      else:
        elt_ty = TFRTypes.NONE
        values = []
      array_attr_elts = ['{}:{}'.format(val, elt_ty) for val in values]
      return '[{}]'.format(','.join(array_attr_elts))
  raise NotImplementedError(
      'Proto AttrValue not recognized. type: {}, value: {}'.format(
          attr_type, attr_val))


def _collect_derived_attrs_from_proto(op_def):
  derived_attrs = set()
  for arg in op_def.input_arg:
    if arg.type_attr:
      derived_attrs.add(arg.type_attr)
    if arg.number_attr:
      derived_attrs.add(arg.number_attr)
    if arg.type_list_attr:
      derived_attrs.add(arg.type_list_attr)

    # TODO(fengliuai): currently we don't support backward type inference, so
    # we have to store these non-derivable types in the signatures, and then
    # they can be used to cast the values when raising to tf ops.
    if arg.type == types_pb2.DT_FLOAT:
      derived_attrs.add('f32_')
    elif arg.type == types_pb2.DT_INT32:
      derived_attrs.add('i32_')
    elif arg.type == types_pb2.DT_INT64:
      derived_attrs.add('i64_')
    elif arg.type == types_pb2.DT_BOOL:
      derived_attrs.add('i1_')
  return derived_attrs


def _require_tensor_list(arg_def):
  return arg_def.type_list_attr or arg_def.number_attr


def _camel_to_snake(name):
  s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
  return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
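
# For illustration: _camel_to_snake('MatMul') == 'mat_mul' and
# _camel_to_snake('FusedBatchNormV3') == 'fused_batch_norm_v3'. This is how TF
# op names are mapped to tfr.func names throughout this module.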


class OpDefCache(object):
  """A Dict to cache the OpDef for the Python function name."""

  def __init__(self):
    self._op_defs = {}

  def lookup(self, f_name, func_def=None, optional=False):
    if f_name in self._op_defs.keys():
      return self._op_defs[f_name]

    if isinstance(func_def, types.FunctionType):
      if not hasattr(func_def, '_tfr_op_name'):
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      op_name = getattr(func_def, '_tfr_op_name')
    elif not func_def:
      op_name = f_name
    else:
      # TODO(fengliuai): create one utility method to match different APIs.
      compose_dec = []
      for dec in func_def.decorator_list:
        if isinstance(dec, ast.Call):
          if isinstance(dec.func,
                        ast.Attribute) and dec.func.attr == 'Composite':
            compose_dec.append(dec)
          if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite':
            compose_dec.append(dec)

      if not compose_dec:
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      elif len(compose_dec) > 1:
        raise KeyError('More than one Composite decorator found on ' + f_name)
      else:
        op_name = compose_dec[0].args[0].value

    op_def = op_def_registry.get(op_name)
    if not op_def:
      raise ValueError('Not a registered op: ' + op_name)
    derived_attrs = _collect_derived_attrs_from_proto(op_def)
    self._op_defs[f_name] = (op_def, derived_attrs)
    return (op_def, derived_attrs)

  def mlir_external_funcs(self):
    tfr_funcs = []
    for _, (op_def, derived_attrs) in sorted(self._op_defs.items()):
      tfr_func = '\ntfr.func @tf__{}_('.format(_camel_to_snake(op_def.name))

      # tensor inputs
      inputs = [
          _get_type_info_from_proto(arg_def) for arg_def in op_def.input_arg
      ]

      # attribute inputs. Attributes with default values are moved to the end.
      non_derived_attrs = [
          attr for attr in op_def.attr if attr.name not in derived_attrs
      ]
      attrs_no_default = [
          attr for attr in non_derived_attrs
          if not attr.HasField('default_value')
      ]
      attrs_with_default = [
          attr for attr in non_derived_attrs if attr.HasField('default_value')
      ]
      attr_names = set()
      for attr_def in attrs_no_default + attrs_with_default:
        inputs.append(_get_type_info_from_proto(None, attr_def))
        attr_names.add(attr_def.name)

      # tensor outputs
      outputs = [
          _get_type_info_from_proto(arg_def) for arg_def in op_def.output_arg
      ]

      inputs = ','.join(inputs)
      outputs = ','.join(outputs)
      attrs = ','.join(sorted(derived_attrs.union(attr_names)))
      tfr_funcs.append('{}{}) -> ({}) attributes {{{}}}'.format(
          tfr_func, inputs, outputs, attrs))
    return tfr_funcs
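
# For illustration, assuming the registered OpDef of tf.Pack (a variadic
# 'values' input with number_attr 'N' and type_attr 'T', plus an 'axis' attr
# with a default value), mlir_external_funcs would produce a signature like:
#   tfr.func @tf__pack_(!tfr.tensor_list<N,T>,
#       i64{tfr.name="axis",tfr.type="int"}) -> (!tfr.tensor<T>)
#       attributes {N,T,axis}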


_PY_TYPE_TO_TFR = {
    bool: TFRTypes.I1,
    int: TFRTypes.I64,
    float: TFRTypes.F32,
}

_TF_DTYPE_TO_TFR = {
    'bool': TFRTypes.I1,
    'int64': TFRTypes.I64,
    'int32': TFRTypes.I32,
    'float32': TFRTypes.F32,
}

_AG_FIXED_RETURN_TYPE = {
    'for_stmt': type(None),
    'if_stmt': type(None),
    'Undefined': TFRTypes.AG_UNDEFINED_VAL,
}

QN = qual_names.QN

# TODO(mdan): Fix this with an importable module.
AG_MODULE = api._TRANSPILER.get_extra_locals()['ag__']  # pylint:disable=protected-access


class TFRTypeResolver(type_inference.Resolver):
  """Resolve types for the external names, calls and arguments."""

  def __init__(self, op_defs):
    super(TFRTypeResolver, self).__init__()
    self._op_defs = op_defs

    # This pattern matching mechanism works with the functional form generated
    # by autograph:
    #
    #   for i in data:
    #     print(i)
    #
    # generates:
    #
    #   def loop_body(itr):
    #     i = itr
    #     print(i)
    #   ag__.for_stmt(target)
    #
    # The mechanism lets us infer the type of the itr argument based on that
    # of target.
    self._for_loop_target_types = {}  # Maps body function name to iterated.
    self._for_loop_body_fns = {}  # Used only to avoid collisions.

  def res_name(self, ns, types_ns, name):
    name_str = str(name)
    if name_str in ns:
      ns_val = ns[name_str]
      return {type(ns_val)}, ns_val
    if name_str in __builtins__:
      return {TFRTypes.PY_BUILTIN_FUNC}, __builtins__[name_str]
    # This name is not in the namespace because the autograph transformation
    # is not backloaded into Python.
    if name_str == 'ag__':
      return {type(AG_MODULE)}, AG_MODULE

    return None, None

  def res_value(self, ns, value):
    if value is None:
      return {TFRTypes.NONE}
    if value in (TFRTypes.SHAPE, TFRTypes.TF_TENSOR_SHAPE_FUNC):
      # See TFRTypes.__getattribute__.
      # TODO(mdan): Replacing the enum with classes would avoid this overlap.
      return {value}
    # TODO(mdan): Index more efficiently. Could do a name check instead.
    if any(v is value for v in AG_MODULE.__dict__.values()):
      return {TFRTypes.AG_BUILTIN_FUNC}
    if getattr(value, '__name__', None) == 'tensorflow.raw_ops':
      return {types.ModuleType}
    if hasattr(value, '__module__'):
      if isinstance(value, dtypes.DType):
        return {TFRTypes.ATTR}

      # All the imported operations, which are not autograph built-ins, are
      # considered to be TF raw ops.
      # TODO(fengliuai): refine the condition so we only match TensorFlow
      # ops here.
      return {TFRTypes.TF_RAW_OP}
    # TODO(mdan): Is ATTR equivalent to string?
    return {_PY_TYPE_TO_TFR.get(type(value), TFRTypes.ATTR)}
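
  # For illustration, res_value maps Python values to TFR types as follows:
  # res_value(ns, None) -> {TFRTypes.NONE}; res_value(ns, 1) -> {TFRTypes.I64}
  # via _PY_TYPE_TO_TFR; and res_value(ns, dtypes.float32) -> {TFRTypes.ATTR},
  # since tf.DType values are treated as attributes.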

  def res_call(self, ns, types_ns, node, f_type, args, keywords):
    name = anno.Basic.QN.of(node.func)
    if f_type == (TFRTypes.AG_BUILTIN_FUNC,):

      if name == QN(QN('ag__'), attr='if_stmt'):
        nouts = node.args[6].value
        # TODO(mdan): Look at the actual types out of if_body.
        side_effects = {
            qual_names.QN(n.value): {TFRTypes.TENSOR}
            for n in node.args[5].elts[:nouts]
        }
        return {type(None)}, side_effects

      if name == QN(QN('ag__'), attr='for_stmt'):
        assert isinstance(node.args[2], ast.Name)
        body_fn_name = str(anno.Basic.QN.of(node.args[2]))
        assert body_fn_name not in self._for_loop_body_fns, (
            'Previously used here: {}. Are you reusing the Resolver across '
            'transformations?').format(self._for_loop_body_fns[body_fn_name])
        self._for_loop_body_fns[body_fn_name] = anno.Basic.ORIGIN.of(node)

        iterated_type = args[0]
        assert iterated_type & {
            TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, TFRTypes.ATTR
        }, (
            iterated_type)
        self._for_loop_target_types[body_fn_name] = iterated_type

        return {type(None)}, None

      # TODO(mdan): Actually resolve the type here instead.
      ret_type = _AG_FIXED_RETURN_TYPE.get(name.qn[1], None)
      if ret_type is not None:
        return {ret_type}, None
      raise NotImplementedError('return type of {}'.format(name))

    elif f_type == (TFRTypes.TF_RAW_OP,):
      op_name = name.qn[1]
      op_def, _ = self._op_defs.lookup(op_name)
      if len(op_def.output_arg) == 1:
        return {_get_type_from_proto(op_def.output_arg[0])}, None
      return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
              None)

    elif f_type == (types.FunctionType,):
      # A composition Python function name is used directly.
      op_name = name.qn[0]
      op_def, _ = self._op_defs.lookup(op_name)
      if len(op_def.output_arg) == 1:
        return {_get_type_from_proto(op_def.output_arg[0])}, None
      return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
              None)

    elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
      assert name.is_simple()
      if name == QN('range'):
        return {TFRTypes.ATTR}, None

      if name == QN('len'):
        return {TFRTypes.INDEX}, None

    elif f_type == (TFRTypes.TF_TENSOR_SHAPE_FUNC,):
      return {TFRTypes.TF_TENSOR_SHAPE_LIST}, None

    raise NotImplementedError('Function:', name, f_type)

  def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
    if f_is_local:
      f_name_str = str(f_name)
      if f_name_str in self._for_loop_target_types:
        # See autograph/converters/control_flow.py - the function has a single
        # argument, the iterate before any expansion.
        assert self._for_loop_target_types[f_name_str] & {TFRTypes.ATTR}
        # Assume all loops are TF loops. Then the iterates are autoboxed into
        # Tensors.
        return {TFRTypes.INDEX}
      else:
        return None

    func = ns[f_name]

    op_def, derived_attrs = self._op_defs.lookup(f_name, func)
    if op_def is None:
      return None
    pos = tf_inspect.getfullargspec(func).args.index(str(name))

    if pos < len(op_def.input_arg):
      arg_def = op_def.input_arg[pos]
      return {_get_type_from_proto(arg_def)}
    elif pos < len(op_def.input_arg) + len(op_def.attr) - len(derived_attrs):
      non_derived_attr_pos = pos - len(op_def.input_arg)
      for attr_def in op_def.attr:
        # derived attribute, skip this one and continue to the next one.
        if attr_def.name in derived_attrs:
          continue
        if non_derived_attr_pos == 0:
          return {_get_type_from_proto(None, attr_def)}
        non_derived_attr_pos -= 1

    raise ValueError('Argument is not defined in OpDef: ' + str(name))

  def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
    assert len(value) == 1
    value, = tuple(value)
    if value == TFRTypes.TF_TENSOR_SHAPE_LIST:
      # TODO(mdan): This is not entirely correct for multi-element slices.
      return {int}
    elif value in (TFRTypes.TENSOR_LIST, TFRTypes.TENSOR):
      # TODO(mdan): This is not entirely correct for multi-element slices.
      return {TFRTypes.TENSOR}
    raise NotImplementedError('slice of {}'.format(value))

  def res_compare(self, ns, types_ns, node, left, right):
    # TODO(fengliuai): make sure left and right are compatible
    return {TFRTypes.I1}

  def res_unop(self, ns, types_ns, node, opnd):
    return opnd

  def res_binop(self, ns, types_ns, node, left, right):
    # TODO(fengliuai): make sure left and right are compatible
    return left

  def _coerce_to_more_specific_type(self, elt_types):
    # TODO(mdan): This needs some type theory study.
    if TFRTypes.INDEX in elt_types:
      # Constants collapse to indices.
      elt_types.discard(TFRTypes.I64)
    if TFRTypes.TENSOR in elt_types:
      # Constants collapse to tensors.
      elt_types.discard(TFRTypes.I64)
      # Indices collapse to tensors.
      elt_types.discard(TFRTypes.INDEX)
    return elt_types

  def res_list_literal(self, ns, elt_types):
    all_elt_types = set()
    for t in elt_types:
      all_elt_types |= t

    if len(all_elt_types) != 1:
      all_elt_types = self._coerce_to_more_specific_type(all_elt_types)

    if len(all_elt_types) != 1:
      raise ValueError('ambiguous list element types: {}'.format(elt_types))

    if TFRTypes.TENSOR in all_elt_types:
      return {TFRTypes.TENSOR_LIST}
    return {TFRTypes.ATTR}


class SymbolTable(object):
  """Symbol Table for python code."""

  def __init__(self):
    self.symbols = []
    self.enter_scope()
    self.scf_scope = 0
    # reserved key words
    self.insert_symbol('len', 'len', TFRTypes.PY_BUILTIN_FUNC)

  def enter_scope(self, scf_scope=False):
    """Enter a new scope - at function level."""
    self.symbols.append({'types': {}, 'symbols': {}})
    self.curr_table = self.symbols[len(self.symbols) - 1]
    if scf_scope:
      self.scf_scope += 1

  def insert_symbol(self, name, value, type_):
    self.curr_table['symbols'][name] = (value, type_)
    # TODO(mdan): Use the inferred type rather than tracking it here.
    # The following field is deprecated.
    self.curr_table['types'][name] = type_
    return value

  def exit_scope(self):
    self.symbols.pop()
    self.curr_table = self.symbols[len(self.symbols) - 1]
    if self.scf_scope > 0:
      self.scf_scope -= 1

  def in_scf_scope(self):
    return self.scf_scope > 0

  def lookup(self, name):
    curr_idx = len(self.symbols) - 1
    while curr_idx >= 0 and (name not in self.symbols[curr_idx]['symbols']):
      curr_idx -= 1
    if curr_idx < 0:
      return None
    return self.symbols[curr_idx]['symbols'][name]
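
# A minimal usage sketch (hypothetical values); lookup walks the scope chain
# from the innermost scope outwards:
#   table = SymbolTable()
#   table.insert_symbol('x', '%x', TFRTypes.TENSOR)
#   table.enter_scope(scf_scope=True)  # e.g. entering an scf.if region
#   table.lookup('x')                  # -> ('%x', TFRTypes.TENSOR)
#   table.exit_scope()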


class TFRGen(transformer.CodeGenerator):
  """Visit the AST and generate MLIR TFR functions."""

  def __init__(self, ctx, op_defs):
    super(TFRGen, self).__init__(ctx)
    self.ctx = ctx
    self.symbol_table = SymbolTable()
    self._op_defs = op_defs

  def _create_mlir_loc(self, loc):
    """Creates mlir location from autograph ORIGIN value.

    Args:
      loc: OriginInfo

    Returns:
      A serialized mlir location string.
    """
    if loc is not None and loc.loc.filename:
      file_name = os.path.basename(loc.loc.filename)
      return 'loc("{}":{}:{})'.format(file_name, loc.loc.lineno,
                                      loc.loc.col_offset)
    else:
      return 'loc(unknown)'

  def _emit_with_loc(self, op_str, node=None):
    """Emit the mlir operation with the location associated with the node.

    Args:
      op_str: The mlir operation string to be emitted.
      node: The AST node that the mlir operation is translated from.
    """
    loc = ''
    if node:
      loc = self._create_mlir_loc(
          anno.getanno(node, anno.Basic.ORIGIN, default=None))
    self.emit(op_str + ' ' + loc)

  def _get_inferred_type(self, node, default=None):
    types_ = anno.getanno(node, anno.Static.TYPES, None)
    if not types_:
      print('WARN: no Static.TYPES annotation. Fix the type inference pass: ')
      self.debug_print(node)
      return default
    if types_ and len(types_) > 1:
      raise ValueError('ambiguous inferred type for "{}": {}'.format(
          node, types_))

    type_, = types_

    if default is not None and type_ != default:
      print('WARN: type annotation {}({}) does not match {}({})'.format(
          type_, type(type_), default, type(default)))
      self.debug_print(node)

    return type_

  def _pack_tensor_list(self, value):
    # When packing a list of tensors, the pack axis is 0.
    axis = self._ssa_name('zero')
    self._emit_with_loc('\n{} = constant 0 : i64'.format(axis))
    casted = self._ssa_name('pack')
    self.emit('\n{} = tfr.call @tf__pack({}, {})'.format(casted, value, axis))
    self._emit_with_loc(' : (!tfr.tensor_list, i64) -> !tfr.tensor')
    # load the op def of tf.Pack
    self._op_defs.lookup('Pack')
    return casted, TFRTypes.TENSOR

  def _index_to_I64(self, value, ty):
    if ty == TFRTypes.INDEX:
      casted = self._ssa_name('casted')
      self._emit_with_loc('\n{} = index_cast {} : index to i64'.format(
          casted, value))
      return casted, TFRTypes.I64
    else:
      return value, ty

  def _i64_to_index(self, value, ty):
    if ty == TFRTypes.I64:
      casted = self._ssa_name('casted')
      self._emit_with_loc('\n{} = index_cast {} : i64 to index'.format(
          casted, value))
      return casted, TFRTypes.INDEX
    else:
      return value, ty

  def _value_to_tensor(self, value, ty, node):
    value, ty = self._index_to_I64(value, ty)
    cst_tensor = self._ssa_name('cst')
    self.emit('\n{} = "tfr.constant_tensor"({})'.format(cst_tensor, value))
    self._emit_with_loc(' : ({}) -> !tfr.tensor'.format(ty), node)
    return cst_tensor, TFRTypes.TENSOR

  def _ssa_name(self, prefix):
    if isinstance(prefix, qual_names.QN):
      assert prefix.is_simple(), 'ANF transform should have cleaned this up'
      prefix = prefix.ssf()
    return '%' + self.ctx.namer.new_symbol(prefix, set())

  def _op_def(self, op_name):
    return op_def_registry.get(op_name)
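
  # For illustration, given a !tfr.tensor_list value %list, _pack_tensor_list
  # emits roughly (SSA names depend on the namer):
  #   %zero = constant 0 : i64
  #   %pack = tfr.call @tf__pack(%list, %zero)
  #       : (!tfr.tensor_list, i64) -> !tfr.tensor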

  def visit_block(self, block):
    return [self.visit(item) for item in block]

  def visit_Pass(self, node):
    if self.symbol_table.in_scf_scope():
      self._emit_with_loc('\nscf.yield', node)
    else:
      self._emit_with_loc('\ntfr.return', node)

  def visit_Attribute(self, node):
    node_type = self._get_inferred_type(node, None)
    if isinstance(node.value, ast.Name):
      if node.value.id == 'ag__':
        # some variables are assigned with 'ag__.xxx' method, we should handle
        # them following the autograph conventions.
        return (node.attr, TFRTypes.AG_BUILTIN_FUNC)

      if node_type == TFRTypes.TF_RAW_OP:
        # This branch is used when it is inside tensorflow
        return (node.attr, TFRTypes.TF_RAW_OP)

      if node_type == TFRTypes.ATTR:
        attr = self._ssa_name('attr')
        tfr_type = _TF_DTYPE_TO_TFR.get(node.attr)
        self._emit_with_loc(
            '\n{} = tfr.constant {} -> !tfr.attr'.format(attr, tfr_type),
            node)
        return (attr, TFRTypes.ATTR)

      value, _ = self.visit(node.value)
      tensor_type = self._get_inferred_type(node.value, None)
      # TODO(fengliuai): use node_type once the type inference is complete.
      if node_type == TFRTypes.SHAPE:
        print('TODO: use "node_type"')
      if node.attr == 'shape' and tensor_type == TFRTypes.TENSOR:
        ssa_value = self._ssa_name('shape')
        self._emit_with_loc(
            '\n{} = tfr.get_shape {} -> !shape.shape'.format(ssa_value, value),
            node)
        return (ssa_value, TFRTypes.SHAPE)

    if isinstance(node.value, ast.Attribute):
      if isinstance(node.value.value, ast.Name):
        if node.value.value.id == 'tf' and node.value.attr == 'raw_ops':
          return (node.attr, TFRTypes.TF_RAW_OP)

      value, ty = self.visit(node.value)
      # TODO(fengliuai): use node_type once the type inference is complete.
      if node_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
        print('TODO: use "node_type"')
      if ty == TFRTypes.SHAPE and node.attr == 'as_list':
        return (value, TFRTypes.TF_TENSOR_SHAPE_FUNC)

    raise NotImplementedError('Attribute kind not recognized.')

  def visit_Assign(self, node):
    values = self.visit(node.value)
    if isinstance(node.targets[0], ast.Tuple):
      targets = [elt.id for elt in node.targets[0].elts]
    elif isinstance(node.targets[0], ast.Name):
      targets = [node.targets[0].id]
    else:
      raise NotImplementedError('Assignment target type not recognized.')

    if isinstance(values, list):
      if isinstance(node.value, ast.Call):
        expected = tuple(t for n, t in values)
        if len(values) == 1:
          expected = expected[0]
      elif isinstance(node.value, ast.Tuple):
        expected = tuple(t for n, t in values)
      else:
        raise ValueError('unknown assignment target node', node.value)
      ty = self._get_inferred_type(node.value, expected)

      if len(targets) == len(values):
        # TODO(mdan): This should already be a tuple.
        ty_ = (ty,) if len(values) == 1 else ty
        for key, value, t in zip(targets, values, ty_):
          ssa_value, _ = value
          self.symbol_table.insert_symbol(key, ssa_value, t)
      elif len(values) == 1:
        n, _ = values[0]
        assert ty == TFRTypes.TENSOR_LIST
        # assign a tensor_list to multiple variables
        for idx, key in enumerate(targets):
          idx_name = self._ssa_name('idx')
          self._emit_with_loc(
              '\n{} = constant {} : index'.format(idx_name, idx), node)
          elt_name = self._ssa_name('elt')
          self.emit('\n{} = tfr.get_element {}[{}]'.format(
              elt_name, n, idx_name))
          self._emit_with_loc(' : (!tfr.tensor_list, index) -> !tfr.tensor',
                              node)
          self.symbol_table.insert_symbol(key, elt_name, TFRTypes.TENSOR)
      elif len(targets) == 1:
        ssa_names = [n for n, _ in values]
        self.symbol_table.insert_symbol(targets[0], ssa_names, ty)
      return

    ty = self._get_inferred_type(node.value, values[1])
    self.symbol_table.insert_symbol(targets[0], values[0], ty)
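
  # For illustration, unpacking 'a, b = some_list' where some_list is a
  # !tfr.tensor_list emits, per target, roughly:
  #   %idx = constant 0 : index
  #   %elt = tfr.get_element %list[%idx]
  #       : (!tfr.tensor_list, index) -> !tfr.tensor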

  def _emit_binary_op(self, op, lhs, lhs_ty, rhs, rhs_ty):
    assert lhs_ty and rhs_ty
    if isinstance(op, ast.Sub):
      code = 'sub'
    elif isinstance(op, ast.Add):
      code = 'add'
    else:
      raise NotImplementedError('BinOp operator not recognized: ' + str(op))

    if lhs_ty == TFRTypes.I64:
      suffix = 'i'
    elif lhs_ty == TFRTypes.F32:
      suffix = 'f'
    else:
      raise NotImplementedError('BinOp operand type not recognized: ' +
                                str(lhs_ty))

    ret = self._ssa_name(code)
    self._emit_with_loc(
        '\n{} = {}{} {}, {} : {}'.format(ret, code, suffix, lhs, rhs, lhs_ty),
        op)
    return ret, lhs_ty

  def visit_AugAssign(self, node):
    lhs, lhs_ty = self.visit(node.target)
    rhs, rhs_ty = self.visit(node.value)
    ret, ret_ty = self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty)
    self.symbol_table.insert_symbol(node.target.id, ret, ret_ty)

  def visit_BinOp(self, node):
    lhs, lhs_ty = self.visit(node.left)
    rhs, rhs_ty = self.visit(node.right)
    return self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty)

  def visit_BoolOp(self, node):
    values = [self.visit(value) for value in node.values]
    # TODO(fengliuai): Handle more ast node types.
    if isinstance(node.op, ast.Or):
      raise NotImplementedError('Or operator not recognized')
    elif isinstance(node.op, ast.And):
      raise NotImplementedError('And operator not recognized')
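
  # For illustration, Python 'x + y' on two I64 values lowers to something
  # like '%add = addi %x, %y : i64'; F32 operands use the 'f' suffix, e.g.
  # '%sub = subf %a, %b : f32'.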

  def visit_Call(self, node):
    func_name, func_type = self.visit(node.func)
    func_type = self._get_inferred_type(node.func, func_type)
    if func_type == TFRTypes.AG_BUILTIN_FUNC:
      if func_name == 'if_stmt':
        cond, _ = self.visit(node.args[0])
        body, _ = self.visit(node.args[1])
        orelse, _ = self.visit(node.args[2])
        get_state, _ = self.visit(node.args[3])
        nouts = int(node.args[6].value)
        out_symbols = []
        # The out symbols are just a Tuple of names
        for out in node.args[5].elts[:nouts]:
          val, ty = self.symbol_table.lookup(out.value)
          out_symbols.append(out.value)
        return self._visit_if_stmt(cond, body, orelse, get_state, out_symbols,
                                   node)
      elif func_name == 'for_stmt':
        range_ = self._visit_iter(node.args[0])
        body, _ = self.visit(node.args[2])
        get_state, _ = self.visit(node.args[3])
        loop_carried = [out.value for out in node.args[5].elts]
        # TODO(fengliuai): opt is not used here.
        return self._visit_for_stmt(range_, body, get_state, loop_carried,
                                    node)
      elif func_name == 'Undefined':
        val = self._ssa_name(node.args[0].value)
        return (val, TFRTypes.AG_UNDEFINED_VAL)
      elif func_name == 'UndefinedReturnValue':
        val = self._ssa_name('return_val')
        return (val, TFRTypes.AG_UNDEFINED_VAL)

    if func_type == TFRTypes.TF_RAW_OP:
      return self._visit_tf_op(func_name, node.args, node.keywords, node)

    if func_type == types.FunctionType:
      return self._visit_tf_op(func_name, node.args, node.keywords, node)

    if func_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
      return (func_name, TFRTypes.TF_TENSOR_SHAPE_LIST)

    if func_type == TFRTypes.PY_BUILTIN_FUNC:
      if func_name == 'len':
        arg, ty = self.visit(node.args[0])
        ty = self._get_inferred_type(node.args[0], ty)
        if ty == TFRTypes.TF_TENSOR_SHAPE_LIST:
          len_value = self._ssa_name('len')
          self._emit_with_loc(
              '\n{} = shape.rank {} : !shape.shape -> !shape.size'.format(
                  len_value, arg), node)
          size_value = self._ssa_name('len_size')
          self._emit_with_loc(
              '\n{} = shape.size_to_index {} : !shape.size'.format(
                  size_value, len_value), node)
        elif ty == TFRTypes.TENSOR_LIST:
          size_value = self._ssa_name('len')
          self._emit_with_loc(
              '\n{} = tfr.get_length {} -> index'.format(size_value, arg),
              node)
        return (size_value, TFRTypes.INDEX)

    raise NotImplementedError('call operator not recognized: {} {}'.format(
        func_name, func_type))

  def visit_Compare(self, node):
    lhs, lhs_ty = self.visit(node.left)
    for op, right in zip(node.ops, node.comparators):
      rhs, rhs_ty = self.visit(right)
      if isinstance(op, ast.Eq):
        pred = 'eq'
      elif isinstance(op, ast.Lt):
        pred = 'ult'
      elif isinstance(op, ast.LtE):
        pred = 'ule'
      elif isinstance(op, ast.Gt):
        pred = 'ugt'
      elif isinstance(op, ast.GtE):
        pred = 'uge'
      elif isinstance(op, ast.NotEq):
        pred = 'ne'
      else:
        raise NotImplementedError('Compare operator not recognized')

      ret = self._ssa_name(pred)
      if lhs_ty == TFRTypes.ATTR:
        self._emit_with_loc(
            '\n{} = tfr.equal {}, {} -> i1'.format(ret, lhs, rhs), node)
      else:
        if lhs_ty == TFRTypes.I64:
          code = 'cmpi'
        elif lhs_ty == TFRTypes.F32:
          code = 'cmpf'
        elif lhs_ty == TFRTypes.INDEX:
          code = 'cmpi'
          # TODO(fengliuai): the reverse type inference should solve the
          # issue.
          rhs, _ = self._i64_to_index(rhs, rhs_ty)
        else:
          raise NotImplementedError('Compare operand type not recognized')
        self._emit_with_loc(
            '\n{} = {} "{}", {}, {} : {}'.format(ret, code, pred, lhs, rhs,
                                                 lhs_ty), node)

    return ret, TFRTypes.I1
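
  # For illustration, a comparison 'i < n' on index operands emits
  # '%ult = cmpi "ult", %i, %n : index', while equality of two attribute
  # values lowers to '%eq = tfr.equal %lhs, %rhs -> i1'.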

  def visit_Constant(self, node):
    cst_name = self._ssa_name('cst')
    if node.value is None:
      cst_ty = TFRTypes.NONE
    elif isinstance(node.value, bool):
      cst_ty = self._get_inferred_type(node)
      cst_val = str(node.value).lower()
      self._emit_with_loc('\n{} = constant {}'.format(cst_name, cst_val), node)
    else:
      cst_ty = self._get_inferred_type(node)
      cst_val = node.value
      if cst_ty == TFRTypes.ATTR:
        self._emit_with_loc(
            '\n{} = tfr.constant "{}" -> {}'.format(cst_name, cst_val, cst_ty),
            node)
      else:
        self._emit_with_loc(
            '\n{} = constant {} : {}'.format(cst_name, cst_val, cst_ty), node)
    return cst_name, cst_ty

  def visit_FunctionDef(self, node):
    op_def, derived_attrs = self._op_defs.lookup(node.name, node, True)
    if op_def is None:
      # Nested function. Insert it to symbol table for looking up later.
      self.symbol_table.insert_symbol(node.name, node, None)
      return
    op_name = op_def.name
    if self.symbol_table.lookup(op_name):
      raise LookupError('Composition has already been registered for op: ' +
                        op_name)
    else:
      self.symbol_table.insert_symbol(node.name, None, None)

    self.symbol_table.enter_scope()
    self.emit('\ntfr.func @tf__{0}('.format(_camel_to_snake(op_name)))

    arg_list = []
    idx = 0
    max_idx = len(op_def.input_arg) + len(op_def.attr)
    for arg in node.args.args:
      arg_name = self._ssa_name(anno.getanno(arg, anno.Basic.QN))
      arg_type = anno.getanno(arg, anno.Static.TYPES)[0]

      arg_attr = ''
      if idx >= len(op_def.input_arg):
        attr_def = op_def.attr[idx - len(op_def.input_arg)]
        # skip the derived attributes
        while attr_def.name in derived_attrs and (idx + 1) < max_idx:
          idx += 1
          attr_def = op_def.attr[idx - len(op_def.input_arg)]
        if idx >= max_idx:
          raise ValueError('Argument is not defined in OpDef: ' + arg_name)

        arg_attr += '{{tfr.name="{}"'.format(attr_def.name)
        if attr_def.HasField('default_value'):
          default_val = _get_val_from_proto(arg_type, attr_def.default_value)
          arg_attr += ',tfr.default={}'.format(default_val)
        arg_attr += '}'

      idx += 1
      arg_str = '{}: {}{}'.format(arg_name, arg_type, arg_attr)
      arg_list.append(arg_str)
      self.symbol_table.insert_symbol(arg.id, arg_name, arg_type)

    ret_type_list = []
    for ret_def in op_def.output_arg:
      if ret_def.number_attr or ret_def.type_list_attr:
        ret_type_list.append(str(TFRTypes.TENSOR_LIST))
      else:
        ret_type_list.append(str(TFRTypes.TENSOR))

    self.emit('{}) -> ({}) {{'.format(', '.join(arg_list),
                                      ', '.join(ret_type_list)))
    self.visit_block(node.body)
    self._emit_with_loc('\n}', node)
    self.symbol_table.exit_scope()

  def visit_arguments(self, node):
    # TODO(fengliuai): return the ordered types and names.
    # We need to order the arguments to match the assumption in the TFR
    # dialect.
    raise NotImplementedError('arguments not supported.')

  def visit_Lambda(self, node):
    raise NotImplementedError('Lambda not supported.')
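
  # For illustration, a composite registered for a hypothetical op 'MyPad'
  # with one tensor input and a non-derived 'paddings' attribute would open a
  # function like (SSA names follow the Python parameter names):
  #   tfr.func @tf__my_pad(%input: !tfr.tensor,
  #       %paddings: !tfr.attr{tfr.name="paddings"}) -> (!tfr.tensor) {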

  def _get_mlir_ssa_values(self, name_prefix, out_types):
    """Create MLIR convention SSA values."""
    out_ssa_values = []
    if not out_types:
      return '', out_ssa_values

    out_name = self._ssa_name(name_prefix)
    if len(out_types) == 1:
      out_name_suffix = ''
      out_ssa_values.append(out_name)
    else:
      # For multiple returns, MLIR uses '%s:i' when they are defined and
      # '%s#i' when they are used.
      out_name_suffix = ':{}'.format(len(out_types))
      for idx, _ in enumerate(out_types):
        out_ssa_values.append('{}#{}'.format(out_name, idx))

    return '{}{}'.format(out_name, out_name_suffix), out_ssa_values

  def _visit_if_stmt(self, cond, body_def, orelse_def, get_state, out_symbols,
                     node):
    self.emit('\n')
    ret_str, ret_ssa_values = self._get_mlir_ssa_values(
        'if_stmt', [TFRTypes.TENSOR] * len(out_symbols))
    if ret_ssa_values:
      self.emit(ret_str + ' = ')

    out_types = []
    for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
      out_types.append(str(TFRTypes.TENSOR))

    self.emit('scf.if {} -> ({}) {{'.format(cond, ', '.join(out_types)))
    # Create a new scope in case the local variables are leaked.
    self.symbol_table.enter_scope(scf_scope=True)
    self.visit_block(body_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()

    self.emit('\n} else {')

    # Create a new scope in case the local variables are leaked.
    self.symbol_table.enter_scope(scf_scope=True)
    self.visit_block(orelse_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()

    # add ssa values to the symbol table
    for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
      self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR)

    self._emit_with_loc('\n}', node)
    return list(zip(ret_ssa_values, out_types))

  def _visit_iter(self, node):
    if isinstance(node, ast.Call):
      f_name = anno.getanno(node.func, anno.Basic.QN)
      if f_name == QN('range'):
        args = [self.visit(arg) for arg in node.args]
        begin = None
        step = None
        end = None
        if len(args) == 1:
          end, end_ty = args[0]
        elif len(args) == 2:
          begin, begin_ty = args[0]
          end, end_ty = args[1]
        elif len(args) == 3:
          begin, begin_ty = args[0]
          end, end_ty = args[1]
          step, step_ty = args[2]

        if begin is None:
          begin = self._ssa_name('begin')
          self._emit_with_loc('\n{} = constant 0 : index'.format(begin), node)
        elif begin_ty != TFRTypes.INDEX:
          begin_ = self._ssa_name('begin')
          self._emit_with_loc(
              '\n{} = index_cast {} : {} to index'.format(
                  begin_, begin, begin_ty), node)
          begin = begin_

        if end_ty != TFRTypes.INDEX:
          end_ = self._ssa_name('end')
          self._emit_with_loc(
              '\n{} = index_cast {} : {} to index'.format(end_, end, end_ty),
              node)
          end = end_

        if step is None:
          step = self._ssa_name('step')
          self._emit_with_loc('\n{} = constant 1 : index'.format(step), node)
        elif step_ty != TFRTypes.INDEX:
          step_ = self._ssa_name('step')
          self._emit_with_loc(
              '\n{} = index_cast {} : {} to index'.format(
                  step_, step, step_ty), node)
          step = step_

        return begin, end, step

    raise NotImplementedError('Iterator entity not supported: ' + str(node))
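
  # For illustration, iterating over 'range(n)' with an I64 bound lowers
  # roughly to:
  #   %begin = constant 0 : index
  #   %end = index_cast %n : i64 to index
  #   %step = constant 1 : index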

  def _visit_for_stmt(self, range_, body_def, get_state, loop_carried, node):
    self.emit('\n')
    ret_str, ret_ssa_values = self._get_mlir_ssa_values(
        'for_stmt', [TFRTypes.TENSOR] * len(loop_carried))
    if ret_ssa_values:
      self.emit(ret_str + ' = ')

    # Before entering the loop, we use the original ssa values as the initial
    # values to the loop iteration arguments. We also create new ssa values as
    # the returns of the scf for statements. The symbol table needs to be
    # updated to these new ssa values before it enters the scope of the loop.
    out_types = []
    init_values = []
    for symbol, ssa_value in zip(loop_carried, ret_ssa_values):
      init, ty = self.symbol_table.lookup(symbol)
      self.symbol_table.insert_symbol(symbol, ssa_value, ty)
      out_types.append(str(ty))
      init_values.append((init, ty))

    # Create a new scope in case the local variables are leaked.
    self.symbol_table.enter_scope(scf_scope=True)

    # Create the iteration variable with index type
    assert len(body_def.args.args) == 1
    it_name = body_def.args.args[0].id
    it = self._ssa_name(it_name)
    self.symbol_table.insert_symbol(it_name, it, TFRTypes.INDEX)

    self.emit('scf.for {} = {} to {} step {} '.format(it, range_[0], range_[1],
                                                      range_[2]))
    if loop_carried:
      iter_args = []
      for symbol, init in zip(loop_carried, init_values):
        # create new ssa values for the loop carried variables
        it_arg = self._ssa_name('it_arg')
        self.symbol_table.insert_symbol(symbol, it_arg, init[1])
        iter_args.append('{} = {}'.format(it_arg, init[0]))
      self.emit('iter_args({}) '.format(', '.join(iter_args)))
      self.emit('-> ({}) {{'.format(', '.join(out_types)))
    else:
      self.emit(' {')
    self.visit_block(body_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()
    self._emit_with_loc('\n}', node)
    return list(zip(ret_ssa_values, out_types))

  def _emit_default_constant_from_proto(self, attr_def):
    """emit mlir constant statement from default value of the AttrDef proto."""
    name = self._ssa_name('cst')
    cst_ty = _get_type_from_proto(None, attr_def)
    cst_val = _get_val_from_proto(cst_ty, attr_def.default_value)
    if cst_ty == TFRTypes.ATTR:
      self._emit_with_loc('\n{} = tfr.constant {} -> {}'.format(
          name, cst_val, cst_ty))
    elif cst_ty == TFRTypes.I1:
      self._emit_with_loc('\n{} = constant {}'.format(name, cst_val))
    else:
      self._emit_with_loc('\n{} = constant {} : {}'.format(
          name, cst_val, cst_ty))
    return name, cst_ty

  def visit_keyword(self, node):
    return node.arg, self.visit(node.value)
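
  # For illustration, a loop carrying one tensor variable would be emitted
  # roughly as (names hypothetical):
  #   %for_stmt = scf.for %i = %begin to %end step %step
  #       iter_args(%it_arg = %x) -> (!tfr.tensor) { ... }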

  def _visit_tf_op(self, op_name, args, keywords, node):
    op_def, derived_attrs = self._op_defs.lookup(op_name)
    ret_tys = [_get_type_from_proto(arg) for arg in op_def.output_arg]

    ret_str, ret_ssa_values = self._get_mlir_ssa_values(op_name, ret_tys)

    arg_strs = []
    ty_strs = []
    for arg in args:
      value, ty = self.visit(arg)
      arg_strs.append(value)
      ty_strs.append(str(ty))

    input_args = [arg for arg in op_def.input_arg]
    attrs_no_default = [
        attr for attr in op_def.attr
        if not attr.HasField('default_value') and attr.name not in derived_attrs
    ]
    attrs_with_default = [
        attr for attr in op_def.attr
        if attr.HasField('default_value') and attr.name not in derived_attrs
    ]

    kw_args = {}
    for arg in keywords:
      value, (ssa_name, ty) = self.visit(arg)
      ty = self._get_inferred_type(arg.value, ty)

      # TODO(fengliuai): implement the "rename_to" for the customization in
      # tensorflow/core/api_def/base_api/*
      if value == 'axis':
        value = 'split_dim'

      kw_args[value] = (ssa_name, ty)

    # tensor arguments and attribute arguments
    ordered_args = input_args + attrs_no_default + attrs_with_default
    for attr_def in ordered_args[len(args):]:
      if attr_def.name in kw_args:
        value, ty = kw_args[attr_def.name]
        if attr_def in input_args:
          if ty in _attribute_types:
            # attribute values shouldn't be passed to the tf op call directly.
            value, ty = self._value_to_tensor(value, ty, node)
          if ty is TFRTypes.TENSOR_LIST and not _require_tensor_list(attr_def):
            value, ty = self._pack_tensor_list(value)
      else:
        value, ty = self._emit_default_constant_from_proto(attr_def)
      arg_strs.append(value)
      ty_strs.append(str(ty))

    if ret_ssa_values:
      self.emit('\n{} = '.format(ret_str))

    self.emit('tfr.call @tf__{}('.format(_camel_to_snake(op_name)))
    arg_str = ', '.join(arg_strs)
    arg_ty_str = ', '.join(ty_strs)
    ret_ty_str = ', '.join([str(ty) for ty in ret_tys])
    self._emit_with_loc(
        '{}) : ({}) -> ({})'.format(arg_str, arg_ty_str, ret_ty_str), node)
    return list(zip(ret_ssa_values, ret_tys))

  def visit_If(self, node):
    raise NotImplementedError('If not supported.')

  def visit_Name(self, node):
    val_and_lookup_type = self.symbol_table.lookup(node.id)
    if val_and_lookup_type:
      (val, lookup_type) = val_and_lookup_type
    else:
      op_def, _ = self._op_defs.lookup(node.id)
      val = op_def.name
      lookup_type = anno.getanno(node, anno.Static.TYPES, types.FunctionType)
    type_ = self._get_inferred_type(node, lookup_type)
    return val, type_

  def visit_Return(self, node):
    values = self.visit(node.value)
    if self.symbol_table.in_scf_scope():
      self.emit('\nscf.yield ')
    else:
      self.emit('\ntfr.return ')
    if not values:
      return

    if isinstance(values, list):
      vals, tys = zip(*values)
    else:
      vals = values[0]
      tys = values[1]

    if isinstance(tys, list) or isinstance(tys, tuple):
      tys = [str(t) for t in tys]
      self._emit_with_loc('{} : {}'.format(', '.join(vals), ', '.join(tys)),
                          node)
    elif tys != TFRTypes.NONE:
      # TODO(fengliuai): scf region yield uses this branch. Fix it.
      self._emit_with_loc('{} : {}'.format(vals, tys), node)
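
  # For illustration, raising a call to a hypothetical registered op 'MyOp'
  # with one tensor argument and one defaulted bool attribute emits roughly:
  #   %MyOp = tfr.call @tf__my_op(%x, %cst)
  #       : (!tfr.tensor, i1) -> (!tfr.tensor)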

  def visit_Subscript(self, node):
    val, ty = self.visit(node.value)
    type_ = self._get_inferred_type(node.value, ty)

    # TODO(fengliuai): We hardcode node.slice here to get the index type.
    # Use the visit method once the type inference is done.
    # slice_val, slice_ty = self.visit(node.slice)
    s = node.slice
    if not isinstance(s, (ast.Tuple, ast.Slice)):
      if isinstance(s, ast.Constant):
        # TODO(fengliuai): promote to an assignment
        idx_val = self._ssa_name('cst')
        self._emit_with_loc(
            '\n{} = constant {} : index'.format(idx_val, s.value), node)
      else:
        idx_val, _ = self.visit(s)
    else:
      raise NotImplementedError('non-index slice not supported.')

    elt = self._ssa_name('elt')
    if type_ == TFRTypes.TENSOR_LIST:
      self.emit('\n{} = tfr.get_element {}[{}] '.format(elt, val, idx_val))
      self._emit_with_loc(': (!tfr.tensor_list, index) -> !tfr.tensor', node)
      return (elt, TFRTypes.TENSOR)
    elif type_ == TFRTypes.TF_TENSOR_SHAPE_LIST:
      size_ = self._ssa_name('size')
      self.emit('\n{} = shape.get_extent {}, {}'.format(size_, val, idx_val))
      self._emit_with_loc(': !shape.shape, index -> !shape.size', node)
      self._emit_with_loc(
          '\n{} = shape.size_to_index {} : !shape.size'.format(elt, size_),
          node)
      return (elt, TFRTypes.INDEX)

  def visit_List(self, node):
    out_type = self._get_inferred_type(node)
    vals = []
    tys = []
    for elt in node.elts:
      val, ty = self.visit(elt)
      ty = self._get_inferred_type(elt, ty)
      if ty in _attribute_types and out_type == TFRTypes.TENSOR_LIST:
        # If the list is a tensor list, cast all the input values to tensors.
        val, ty = self._value_to_tensor(val, ty, node)
      else:
        # We shouldn't use the index type to build the list, because the list
        # will be used as an attribute.
        val, ty = self._index_to_I64(val, ty)
      vals.append(val)
      tys.append(str(ty))

    list_val = self._ssa_name('list')
    self.emit('\n{} = "tfr.build_list"({})'.format(list_val, ', '.join(vals)))
    self._emit_with_loc(' : ({}) -> {}'.format(', '.join(tys), out_type), node)
    return (list_val, out_type)

  def visit_Tuple(self, node):
    return [self.visit(elt) for elt in node.elts]

  def visit_UnaryOp(self, node):
    value, ty = self.visit(node.operand)
    if isinstance(node.op, ast.USub):
      zero_value = self._ssa_name('zero')
      self._emit_with_loc('\n{} = constant 0 : {}'.format(zero_value, ty),
                          node)
      ssa_value = self._ssa_name('cst')
      if ty == TFRTypes.I32 or ty == TFRTypes.I64:
        self._emit_with_loc(
            '\n{} = subi {}, {} : {}'.format(ssa_value, zero_value, value, ty),
            node)
      elif ty == TFRTypes.F32:
        self._emit_with_loc(
            '\n{} = subf {}, {} : {}'.format(ssa_value, zero_value, value, ty),
            node)
      else:
        raise NotImplementedError('USub type not recognized: ' + str(ty))
      return ssa_value, ty
    raise NotImplementedError('USub operator not recognized')

  def visit_For(self, node):
    raise NotImplementedError('For operator not recognized')

  def visit_While(self, node):
    raise NotImplementedError('While operator not recognized')

  def visit_Try(self, node):
    # Only handles the body of the try statement.
    self.visit_block(node.body)


def _apply_py_to_tf_passes(node, ctx):
  """Apply transformations from PyToTF to match tf.function tracing."""
  # TODO(fengliuai): we don't know which passes are required, thus we evaluate
  # each one when the corresponding node is handled.
  # copied from PyToTF.transform_ast
  node = return_statements.transform(node, ctx, False)
  node = control_flow.transform(node, ctx)
  return node


class TfrGen(transpiler.GenericTranspiler):
  """Transforms Python objects into TFR MLIR source code."""

  def __init__(self, op_defs):
    self._op_defs = op_defs

  def transform_ast(self, node, ctx):
    node = _apply_py_to_tf_passes(node, ctx)
    # TODO(mdan): Enable this.
    # node = anf.transform(node, ctx)

    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx)
    node = reaching_definitions.resolve(node, ctx, graphs)
    node = reaching_fndefs.resolve(node, ctx, graphs)
    node = type_inference.resolve(node, ctx, graphs,
                                  TFRTypeResolver(self._op_defs))

    mlir_generator = TFRGen(ctx, self._op_defs)
    mlir_generator.visit(node)
    return mlir_generator.code_buffer


def tfr_gen(func, op_defs):
  """Parse a function and emit the TFR functions."""
  mlir_code, _ = TfrGen(op_defs).transform(func, None)
  assert tfr.verify(mlir_code), 'mlir code not verified: {}'.format(mlir_code)
  return mlir_code


def tfr_gen_from_module(source, method_prefix=None, op_libraries=None):
  """Parse the input source module and emit the TFR functions."""
  op_defs = OpDefCache()

  # Load the op library so the op is added to the op registry. This is
  # required when the op cc_library couldn't be statically linked in open
  # source.
  # This is a no-op if the op shared library couldn't be found in the same
  # directory as the op Python API.
  # TODO(fengliuai): make the .so file path configurable.
  if op_libraries:
    prefix_len = len('gen_')
    for m in op_libraries:
      lib_dir = os.path.dirname(m.__file__)
      lib_name = os.path.basename(m.__file__)[prefix_len:].replace('.py', '.so')
      lib_path = os.path.join(lib_dir, lib_name)
      if os.path.exists(lib_path):
        logging.info('load file: ' + lib_path)
        load_library.load_op_library(lib_path)
  else:
    # The op library is generated from the source module, so we load all the
    # .so files in the directory.
    lib_dir = os.path.dirname(source.__file__)
    for lib_name in os.listdir(lib_dir):
      if lib_name.endswith('.so'):
        lib_path = os.path.join(lib_dir, lib_name)
        logging.info('load file: ' + lib_path)
        load_library.load_op_library(lib_path)

  py_funcs = [
      func
      for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction)
      if not method_prefix or name.startswith(method_prefix)
  ]
  # Sort the methods by the line number, to make sure the definitions are
  # processed before the usages.
  # TODO(fengliuai): Use type inference resolver to recursively process any
  # functions called.
  py_funcs = sorted(py_funcs, key=lambda x: x.__code__.co_firstlineno)
  mlir_funcs = [tfr_gen(func, op_defs) for func in py_funcs]

  return '\n'.join(mlir_funcs + op_defs.mlir_external_funcs())
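
# A minimal usage sketch (hypothetical module 'my_composites' whose functions
# are decorated with @Composite('SomeOp', ...)):
#
#   import my_composites
#   mlir_code = tfr_gen_from_module(my_composites)
#   # mlir_code holds the tfr.func definitions plus the external op
#   # signatures produced by OpDefCache.mlir_external_funcs().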