# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tfr_gen: Generates MLIR TFR decomposition functions from Python code."""

# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import os
import re
import types
import gast as ast

from tensorflow.compiler.mlir.tfr import tfr_wrapper as tfr
from tensorflow.core.framework import types_pb2
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
from tensorflow.python.autograph.pyct.static_analysis import type_inference
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import load_library
from tensorflow.python.framework import op_def_registry
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect

# TODO(mdan): Use class definitions so that we can mix these with Python types.


class TFRTypes(enum.Enum):
  """All the supported types.

    1-3: TFR types
    4-99: MLIR built-in types
    100-199: TF-related translator internal types
    200-   : Python-related translator internal types
  """
  TENSOR = 1
  TENSOR_LIST = 2
  ATTR = 3
  NONE = 4
  SHAPE = 5  # shape -> !shape.shape
  I1 = 21
  I8 = 22
  I16 = 23
  I32 = 24
  I64 = 25
  F32 = 26
  INDEX = 27
  AG_UNDEFINED_VAL = 100
  AG_BUILTIN_FUNC = 101
  TF_RAW_OP = 102
  TF_REGION = 103
  TF_TENSOR_SHAPE_FUNC = 104  # shape.as_list
  TF_TENSOR_SHAPE_LIST = 105  # shape.as_list()
  PY_BUILTIN_FUNC = 200
  TFR_BUILTIN_FUNC = 201

  # As these are not real types, __getattribute__ helps them appear more like
  # actual types (i.e. class definitions).
  def __getattribute__(self, name):
    if name == 'shape' and object.__getattribute__(self, 'value') == 1:
      return TFRTypes.SHAPE
    if name == 'as_list' and object.__getattribute__(self, 'value') == 5:
      return TFRTypes.TF_TENSOR_SHAPE_FUNC
    return object.__getattribute__(self, name)

  def __str__(self):
    if self.value < 4:  # pylint: disable=comparison-with-callable
      return '!tfr.' + self.name.lower()
    elif self.value < 10:  # pylint: disable=comparison-with-callable
      return '!shape.' + self.name.lower()
    else:
      return self.name.lower()
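
# Illustrative examples of the strings produced by TFRTypes.__str__ above
# (derived from the value ranges; not exhaustive):
#   str(TFRTypes.TENSOR) -> '!tfr.tensor'
#   str(TFRTypes.SHAPE)  -> '!shape.shape'
#   str(TFRTypes.I32)    -> 'i32'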


_attribute_types = [
    TFRTypes.I1, TFRTypes.I32, TFRTypes.I64, TFRTypes.F32, TFRTypes.INDEX,
    TFRTypes.ATTR
]


def _get_type_from_proto(arg_def=None, attr_def=None):
  if not arg_def:
    if attr_def.type == 'bool':
      return TFRTypes.I1
    elif attr_def.type == 'int32':
      return TFRTypes.I32
    elif attr_def.type == 'int' or attr_def.type == 'int64':
      return TFRTypes.I64
    elif attr_def.type == 'float':
      return TFRTypes.F32
    else:
      return TFRTypes.ATTR

  if arg_def.number_attr or arg_def.type_list_attr:
    return TFRTypes.TENSOR_LIST
  else:
    return TFRTypes.TENSOR


def _get_type_info_from_proto(arg_def=None, attr_def=None):
  attr_type = _get_type_from_proto(arg_def, attr_def)
  if not arg_def:
    return '{}{{tfr.name="{}",tfr.type="{}"}}'.format(
        attr_type, attr_def.name, attr_def.type)
  else:
    attr_names = []
    if arg_def.number_attr:
      attr_names.append(arg_def.number_attr)
    if arg_def.type_attr:
      attr_names.append(arg_def.type_attr)
    if arg_def.type_list_attr:
      attr_names.append(arg_def.type_list_attr)

    # TODO(fengliuai): currently we don't support backward type inference, so
    # we have to store these non-derivable types in the signatures, and then
    # they can be used to cast the values when raising to tf ops.
    if arg_def.type == types_pb2.DT_FLOAT:
      attr_names.append('f32_')
    elif arg_def.type == types_pb2.DT_INT32:
      attr_names.append('i32_')
    elif arg_def.type == types_pb2.DT_INT64:
      attr_names.append('i64_')
    elif arg_def.type == types_pb2.DT_BOOL:
      attr_names.append('i1_')

    if not attr_names:
      return str(attr_type)
    else:
      return '{}<{}>'.format(attr_type, ','.join(attr_names))


def _get_val_from_proto(attr_type, attr_val):
  if attr_type == TFRTypes.I1:
    return 'true' if attr_val.b else 'false'
  elif attr_type == TFRTypes.I32 or attr_type == TFRTypes.I64:
    return attr_val.i
  elif attr_type == TFRTypes.F32:
    return attr_val.f
  elif attr_type == TFRTypes.ATTR:
    # string
    if attr_val.HasField('s'):
      return '"{}"'.format(attr_val.s.decode())
    # type
    if attr_val.HasField('type'):
      if attr_val.type == types_pb2.DT_FLOAT:
        return 'f32'
      elif attr_val.type == types_pb2.DT_INT32:
        return 'i32'
      elif attr_val.type == types_pb2.DT_INT64:
        return 'i64'
      elif attr_val.type == types_pb2.DT_BOOL:
        return 'i1'
    # list
    if attr_val.HasField('list'):
      if attr_val.list.f:
        elt_ty = TFRTypes.F32
        values = attr_val.list.f
      elif attr_val.list.i:
        elt_ty = TFRTypes.I64
        values = attr_val.list.i
      else:
        elt_ty = TFRTypes.NONE
        values = []
      array_attr_elts = ['{}:{}'.format(val, elt_ty) for val in values]
      return '[{}]'.format(','.join(array_attr_elts))
  raise NotImplementedError(
      'Proto AttrValue not recognized. type: {}, value: {}'.format(
          attr_type, attr_val))
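
# A sketch of the mappings performed by _get_val_from_proto (assuming the
# listed AttrValue fields are the ones set):
#   TFRTypes.I1 with b=True            -> 'true'
#   TFRTypes.ATTR with type=DT_INT32   -> 'i32'
#   TFRTypes.ATTR with list.f=[1., 2.] -> '[1.0:f32,2.0:f32]'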


def _collect_derived_attrs_from_proto(op_def):
  derived_attrs = set()
  for arg in op_def.input_arg:
    if arg.type_attr:
      derived_attrs.add(arg.type_attr)
    if arg.number_attr:
      derived_attrs.add(arg.number_attr)
    if arg.type_list_attr:
      derived_attrs.add(arg.type_list_attr)

    # TODO(fengliuai): currently we don't support backward type inference, so
    # we have to store these non-derivable types in the signatures, and then
    # they can be used to cast the values when raising to tf ops.
    if arg.type == types_pb2.DT_FLOAT:
      derived_attrs.add('f32_')
    elif arg.type == types_pb2.DT_INT32:
      derived_attrs.add('i32_')
    elif arg.type == types_pb2.DT_INT64:
      derived_attrs.add('i64_')
    elif arg.type == types_pb2.DT_BOOL:
      derived_attrs.add('i1_')
  return derived_attrs


def _require_tensor_list(arg_def):
  return arg_def.type_list_attr or arg_def.number_attr


def _camel_to_snake(name):
  # e.g. 'MatMul' -> 'mat_mul'.
  s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
  return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


class OpDefCache(object):
  """A dict to cache the OpDef for the Python function name."""

  def __init__(self):
    self._op_defs = {}

  def lookup(self, f_name, func_def=None, optional=False):
    if f_name in self._op_defs:
      return self._op_defs[f_name]

    if isinstance(func_def, types.FunctionType):
      if not hasattr(func_def, '_tfr_op_name'):
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      op_name = getattr(func_def, '_tfr_op_name')
    elif not func_def:
      op_name = f_name
    else:
      # TODO(fengliuai): create one utility method to match different APIs.
      compose_dec = []
      for dec in func_def.decorator_list:
        if isinstance(dec, ast.Call):
          if isinstance(dec.func,
                        ast.Attribute) and dec.func.attr == 'Composite':
            compose_dec.append(dec)
          if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite':
            compose_dec.append(dec)

      if not compose_dec:
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      elif len(compose_dec) > 1:
        raise KeyError('More than one TF op decomposition defined for: ' +
                       f_name)
      else:
        op_name = compose_dec[0].args[0].value

    op_def = op_def_registry.get(op_name)
    if not op_def:
      raise ValueError('Not a registered op: ' + op_name)
    derived_attrs = _collect_derived_attrs_from_proto(op_def)
    self._op_defs[f_name] = (op_def, derived_attrs)
    return (op_def, derived_attrs)
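
  # The external function signatures emitted below look roughly like the
  # following (hypothetical op; the reserved names f32_/i32_/i64_/i1_ always
  # appear in the attribute list):
  #   tfr.func @tf__my_op_(!tfr.tensor<T>,i64{tfr.name="n",tfr.type="int"})
  #       -> (!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_,n}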

  def mlir_external_funcs(self):
    tfr_funcs = set()
    for _, (op_def, derived_attrs) in sorted(self._op_defs.items()):
      tfr_func = '\ntfr.func @tf__{}_('.format(_camel_to_snake(op_def.name))

      # tensor inputs
      inputs = [
          _get_type_info_from_proto(arg_def) for arg_def in op_def.input_arg
      ]

      # attribute inputs. The attributes with default values are moved to the
      # end.
      non_derived_attrs = [
          attr for attr in op_def.attr if attr.name not in derived_attrs
      ]
      attrs_no_default = [
          attr for attr in non_derived_attrs
          if not attr.HasField('default_value')
      ]
      attrs_with_default = [
          attr for attr in non_derived_attrs if attr.HasField('default_value')
      ]
      attr_names = {'f32_', 'i32_', 'i64_', 'i1_'}  # reserved
      for attr_def in attrs_no_default + attrs_with_default:
        inputs.append(_get_type_info_from_proto(None, attr_def))
        attr_names.add(attr_def.name)

      # tensor outputs
      outputs = [
          _get_type_info_from_proto(arg_def) for arg_def in op_def.output_arg
      ]

      inputs = ','.join(inputs)
      outputs = ','.join(outputs)
      attrs = ','.join(sorted(derived_attrs.union(attr_names)))
      tfr_funcs.add('{}{}) -> ({}) attributes {{{}}}'.format(
          tfr_func, inputs, outputs, attrs))
    return sorted(list(tfr_funcs))


_PY_TYPE_TO_TFR = {
    bool: TFRTypes.I1,
    int: TFRTypes.I64,
    float: TFRTypes.F32,
}

_TF_DTYPE_TO_TFR = {
    'bool': TFRTypes.I1,
    'int64': TFRTypes.I64,
    'int32': TFRTypes.I32,
    'int16': TFRTypes.I16,
    'int8': TFRTypes.I8,
    'float32': TFRTypes.F32,
}

_AG_FIXED_RETURN_TYPE = {
    'for_stmt': type(None),
    'if_stmt': type(None),
    'Undefined': TFRTypes.AG_UNDEFINED_VAL,
}

QN = qual_names.QN

# TODO(mdan): Fix this with an importable module.
AG_MODULE = api._TRANSPILER.get_extra_locals()['ag__']  # pylint:disable=protected-access

# When an item is callable, the signature is (*operand_types) -> result_type(s)
TFR_BUILTINS = {
    '_tfr_quant_act_range': (TFRTypes.TENSOR, TFRTypes.TENSOR),
    '_tfr_quant_rescale': TFRTypes.TENSOR,
    '_tfr_quant_raw_data': lambda input_type: input_type,
    '_tfr_quant_qparam': (TFRTypes.TENSOR, TFRTypes.TENSOR),
    '_tfr_quant_scale_factor': TFRTypes.TENSOR,
}


class TFRTypeResolver(type_inference.Resolver):
  """Resolves types for the external names, calls and arguments."""

  def __init__(self, op_defs):
    super(TFRTypeResolver, self).__init__()
    self._op_defs = op_defs

    # This pattern matching mechanism works with the functional form generated
    # by autograph:
    #
    #   for i in data:
    #     print(i)
    #
    # generates:
    #
    #   def loop_body(itr):
    #     i = itr
    #     print(i)
    #   ag__.for_stmt(target)
    #
    # The mechanism lets us infer the type of the itr argument based on that of
    # target.
    self._for_loop_target_types = {}  # Maps body function name to iterated.
    self._for_loop_body_fns = {}  # Used only to avoid collisions.

  def res_name(self, ns, types_ns, name):
    name_str = str(name)
    if name_str in TFR_BUILTINS:
      return {TFRTypes.TFR_BUILTIN_FUNC}, name_str
    if name_str in ns:
      ns_val = ns[name_str]
      return {type(ns_val)}, ns_val
    if name_str in __builtins__:
      return {TFRTypes.PY_BUILTIN_FUNC}, __builtins__[name_str]
    # This name is not in the namespace because the autograph transformation
    # is not backloaded into Python.
    if name_str == 'ag__':
      return {type(AG_MODULE)}, AG_MODULE

    return None, None
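
  # Resolution order in res_name above: TFR builtins shadow the user
  # namespace, which shadows the Python builtins. For example, 'len' (when not
  # bound in ns) resolves to ({TFRTypes.PY_BUILTIN_FUNC}, len), while
  # '_tfr_quant_rescale' resolves to
  # ({TFRTypes.TFR_BUILTIN_FUNC}, '_tfr_quant_rescale').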

  def res_value(self, ns, value):
    # Resolves the type of the symbol by the metadata in 'value'.
    if value is None:
      return {TFRTypes.NONE}
    if value in (TFRTypes.SHAPE, TFRTypes.TF_TENSOR_SHAPE_FUNC):
      # See TFRTypes.__getattribute__.
      # TODO(mdan): Replacing the enum with classes would avoid this overlap.
      return {value}
    # TODO(mdan): Index more efficiently. Could do a name check instead.
    if any(v is value for v in AG_MODULE.__dict__.values()):
      return {TFRTypes.AG_BUILTIN_FUNC}
    if getattr(value, '__name__', None) == 'tensorflow.raw_ops':
      return {types.ModuleType}
    if hasattr(value, '__module__'):
      if isinstance(value, dtypes.DType):
        return {TFRTypes.ATTR}

      # All the imported operations, which are not autograph built-ins, are
      # considered to be TF raw ops.
      # TODO(fengliuai): refine the condition so that we only match TensorFlow
      # ops here.
      return {TFRTypes.TF_RAW_OP}
    # TODO(mdan): Is ATTR equivalent to string?
    return {_PY_TYPE_TO_TFR.get(type(value), TFRTypes.ATTR)}

  def res_call(self, ns, types_ns, node, f_type, args, keywords):
    # Resolves the return type of the function call.
    name = anno.Basic.QN.of(node.func)
    if f_type == (TFRTypes.AG_BUILTIN_FUNC,):

      if name == QN(QN('ag__'), attr='if_stmt'):
        nouts = node.args[6].value
        # TODO(mdan): Look at the actual types out of if_body.
        side_effects = {
            qual_names.QN(n.value): {TFRTypes.TENSOR}
            for n in node.args[5].elts[:nouts]
        }
        return {type(None)}, side_effects

      if name == QN(QN('ag__'), attr='for_stmt'):
        assert isinstance(node.args[2], ast.Name)
        body_fn_name = str(anno.Basic.QN.of(node.args[2]))
        assert body_fn_name not in self._for_loop_body_fns, (
            'Previously used here: {}. Are you reusing the Resolver across '
            'transformations?').format(self._for_loop_body_fns[body_fn_name])
        self._for_loop_body_fns[body_fn_name] = anno.Basic.ORIGIN.of(node)

        iterated_type = args[0]
        assert iterated_type & {
            TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, TFRTypes.ATTR
        }, (
            iterated_type)
        self._for_loop_target_types[body_fn_name] = iterated_type

        return {type(None)}, None

      # TODO(mdan): Actually resolve the type here instead.
      ret_type = _AG_FIXED_RETURN_TYPE.get(name.qn[1], None)
      if ret_type is not None:
        return {ret_type}, None
      raise NotImplementedError('return type of {}'.format(name))

    elif f_type == (TFRTypes.TF_RAW_OP,):
      # This is a TF operation, so it should be found in the op_defs.
      op_name = name.qn[1]
      op_def, _ = self._op_defs.lookup(op_name)
      if len(op_def.output_arg) == 1:
        return {_get_type_from_proto(op_def.output_arg[0])}, None
      return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
              None)

    elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
      assert name.is_simple()
      if name == QN('range'):
        return {TFRTypes.ATTR}, None

      if name == QN('len'):
        return {TFRTypes.INDEX}, None

    elif f_type == (TFRTypes.TFR_BUILTIN_FUNC,):
      op_name = name.qn[0]
      if callable(TFR_BUILTINS[op_name]):
        return {TFR_BUILTINS[op_name](*[list(arg)[0] for arg in args])}, None
      return {TFR_BUILTINS[op_name]}, None

    elif f_type == (TFRTypes.TF_TENSOR_SHAPE_FUNC,):
      return {TFRTypes.TF_TENSOR_SHAPE_LIST}, None

    elif f_type == (types.FunctionType,):
      # This is a function call that isn't using tf.raw_ops.
      op_name = name.qn[0]

      # The 'new TF operation' produces the outputs defined by its composition
      # function.
      op_def, _ = self._op_defs.lookup(op_name)
      if len(op_def.output_arg) == 1:
        return {_get_type_from_proto(op_def.output_arg[0])}, None
      return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
              None)

    raise NotImplementedError('Function: {}, {}'.format(name, f_type))
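
  # res_call returns a set of possible result types plus optional side
  # effects. For a TF op with one output this is e.g. ({TFRTypes.TENSOR},
  # None); with several outputs the per-output types are wrapped in a single
  # tuple, e.g. ({(TFRTypes.TENSOR, TFRTypes.TENSOR)}, None).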

  def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
    if f_is_local:
      f_name_str = str(f_name)
      if f_name_str in self._for_loop_target_types:
        # See autograph/converters/control_flow.py - the function has a single
        # argument, the iterate before any expansion.
        assert self._for_loop_target_types[f_name_str] & {TFRTypes.ATTR}
        # Assume all loops are TF loops. Then the iterates are autoboxed into
        # Tensors.
        return {TFRTypes.INDEX}
      else:
        return None

    func = ns[f_name]

    op_def, derived_attrs = self._op_defs.lookup(f_name, func)
    if op_def is None:
      return None
    pos = tf_inspect.getfullargspec(func).args.index(str(name))

    if pos < len(op_def.input_arg):
      arg_def = op_def.input_arg[pos]
      return {_get_type_from_proto(arg_def)}
    elif pos < len(op_def.input_arg) + len(op_def.attr) - len(derived_attrs):
      non_derived_attr_pos = pos - len(op_def.input_arg)
      for attr_def in op_def.attr:
        # Derived attribute: skip this one and continue to the next one.
        if attr_def.name in derived_attrs:
          continue
        if non_derived_attr_pos == 0:
          return {_get_type_from_proto(None, attr_def)}
        non_derived_attr_pos -= 1

    raise ValueError('Argument is not defined in OpDef: ' + str(name))

  def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
    if not value:
      return value

    if isinstance(value, set):
      type_tuple = value.pop()
      if isinstance(type_tuple, tuple):
        value = {type_tuple[node_or_slice]}
      else:
        value = {type_tuple}

    assert len(value) == 1
    value, = tuple(value)
    if value == TFRTypes.TF_TENSOR_SHAPE_LIST:
      # TODO(mdan): This is not entirely correct for multi-element slices.
      return {int}
    elif value in (TFRTypes.TENSOR_LIST, TFRTypes.TENSOR):
      # TODO(mdan): This is not entirely correct for multi-element slices.
      return {TFRTypes.TENSOR}
    else:
      return {value}

  def res_compare(self, ns, types_ns, node, left, right):
    # TODO(fengliuai): make sure left and right are compatible
    return {TFRTypes.I1}

  def res_unop(self, ns, types_ns, node, opnd):
    return opnd

  def res_binop(self, ns, types_ns, node, left, right):
    # TODO(fengliuai): make sure left and right are compatible
    return left
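
  # Examples of the coercion performed below: {I64, INDEX} collapses to
  # {INDEX}, and {I64, INDEX, TENSOR} collapses to {TENSOR}.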

  def _coerce_to_more_specific_type(self, elt_types):
    # TODO(mdan): This needs some type theory study.
    if TFRTypes.INDEX in elt_types:
      # Constants collapse to indices.
      elt_types.discard(TFRTypes.I64)
    if TFRTypes.TENSOR in elt_types:
      # Constants collapse to tensors.
      elt_types.discard(TFRTypes.I64)
      # Indices collapse to tensors.
      elt_types.discard(TFRTypes.INDEX)
    return elt_types

  def res_list_literal(self, ns, elt_types):
    all_elt_types = set()
    for t in elt_types:
      all_elt_types |= t

    if len(all_elt_types) != 1:
      all_elt_types = self._coerce_to_more_specific_type(all_elt_types)

    if len(all_elt_types) != 1:
      raise ValueError('ambiguous list element types: {}'.format(elt_types))

    if TFRTypes.TENSOR in all_elt_types:
      return {TFRTypes.TENSOR_LIST}
    return {TFRTypes.ATTR}


class SymbolTable(object):
  """Symbol table for the Python code."""

  def __init__(self):
    self.symbols = []
    self.enter_scope()
    self.scf_scope = 0
    # reserved key words
    self.insert_symbol('len', 'len', TFRTypes.PY_BUILTIN_FUNC)

  def enter_scope(self, scf_scope=False):
    """Enters a new scope - at function level."""
    self.symbols.append({'types': {}, 'symbols': {}})
    self.curr_table = self.symbols[len(self.symbols) - 1]
    if scf_scope:
      self.scf_scope += 1

  def insert_symbol(self, name, value, type_):
    self.curr_table['symbols'][name] = (value, type_)
    # TODO(mdan): Use the inferred type rather than tracking it here.
    # The following field is deprecated.
    self.curr_table['types'][name] = type_
    return value

  def exit_scope(self):
    self.symbols.pop()
    self.curr_table = self.symbols[len(self.symbols) - 1]
    if self.scf_scope > 0:
      self.scf_scope -= 1

  def in_scf_scope(self):
    return self.scf_scope > 0

  def lookup(self, name):
    curr_idx = len(self.symbols) - 1
    while curr_idx >= 0 and (name not in self.symbols[curr_idx]['symbols']):
      curr_idx -= 1
    if curr_idx < 0:
      return None
    return self.symbols[curr_idx]['symbols'][name]


class TFRGen(transformer.CodeGenerator):
  """Visits the AST and generates MLIR TFR functions."""

  def __init__(self, ctx, op_defs):
    super(TFRGen, self).__init__(ctx)
    self.ctx = ctx
    self.symbol_table = SymbolTable()
    self._op_defs = op_defs

  def _create_mlir_loc(self, loc):
    """Creates an MLIR location from an autograph ORIGIN value.

    Args:
      loc: OriginInfo

    Returns:
      A serialized MLIR location string.
    """
    if loc is not None and loc.loc.filename:
      file_name = os.path.basename(loc.loc.filename)
      return 'loc("{}":{}:{})'.format(file_name, loc.loc.lineno,
                                      loc.loc.col_offset)
    else:
      return 'loc(unknown)'

  def _emit_with_loc(self, op_str, node=None):
    """Emits the MLIR operation with the location associated with the node.

    Args:
      op_str: The MLIR operation string to be emitted.
      node: The AST node that the MLIR operation was translated from.
    """
    loc = ''
    if node:
      loc = self._create_mlir_loc(
          anno.getanno(node, anno.Basic.ORIGIN, default=None))
    self.emit(op_str + ' ' + loc)
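
  # A location emitted by _create_mlir_loc looks like 'loc("foo.py":42:8)'
  # (file name, line, column; the name here is illustrative). Nodes without
  # origin information fall back to 'loc(unknown)'.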

  def _get_inferred_type(self, node, default=None):
    """Returns a single type, or a tuple of types if more than one type."""
    types_ = anno.getanno(node, anno.Static.TYPES, None)
    if not types_:
      print('WARN: no Static.TYPES annotation. Fix the type inference pass: ')
      self.debug_print(node)
      return default

    if len(types_) == 1:
      type_, = types_
    else:
      type_ = types_

    if default is not None and type_ != default:
      print('WARN: type annotation {}({}) does not match {}({})'.format(
          type_, type(type_), default, type(default)))
      self.debug_print(node)

    return type_

  def _pack_tensor_list(self, value):
    # This packs a list of tensors, so the axis is 0.
    axis = self._ssa_name('zero')
    self._emit_with_loc('\n{} = constant 0 : i64'.format(axis))
    casted = self._ssa_name('pack')
    self.emit('\n{} = tfr.call @tf__pack({}, {})'.format(casted, value, axis))
    self._emit_with_loc(' : (!tfr.tensor_list, i64) -> !tfr.tensor')
    # load the op def of tf.Pack
    self._op_defs.lookup('Pack')
    return casted, TFRTypes.TENSOR

  def _index_to_I64(self, value, ty):
    if ty == TFRTypes.INDEX:
      casted = self._ssa_name('casted')
      self._emit_with_loc('\n{} = index_cast {} : index to i64'.format(
          casted, value))
      return casted, TFRTypes.I64
    else:
      return value, ty

  def _i64_to_index(self, value, ty):
    if ty == TFRTypes.I64:
      casted = self._ssa_name('casted')
      self._emit_with_loc('\n{} = index_cast {} : i64 to index'.format(
          casted, value))
      return casted, TFRTypes.INDEX
    else:
      return value, ty

  def _value_to_tensor(self, value, ty, node):
    value, ty = self._index_to_I64(value, ty)
    cst_tensor = self._ssa_name('cst')
    self.emit('\n{} = "tfr.constant_tensor"({})'.format(cst_tensor, value))
    self._emit_with_loc(' : ({}) -> !tfr.tensor'.format(ty), node)
    return cst_tensor, TFRTypes.TENSOR

  def _ssa_name(self, prefix):
    if isinstance(prefix, qual_names.QN):
      assert prefix.is_simple(), 'ANF transform should have cleaned this up'
      prefix = prefix.ssf()
    return '%' + self.ctx.namer.new_symbol(prefix, set())

  def _op_def(self, op_name):
    return op_def_registry.get(op_name)

  def visit_block(self, block):
    return [self.visit(item) for item in block]

  def visit_Pass(self, node):
    if self.symbol_table.in_scf_scope():
      self._emit_with_loc('\nscf.yield', node)
    else:
      self._emit_with_loc('\ntfr.return', node)
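
  # The helpers above lower Python values into MLIR SSA form; for example,
  # _index_to_I64 emits '%casted = index_cast %v : index to i64' and
  # _value_to_tensor wraps a scalar as
  # '%cst = "tfr.constant_tensor"(%v) : (i64) -> !tfr.tensor'
  # (SSA names are illustrative).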

  def visit_Attribute(self, node):
    node_type = self._get_inferred_type(node, None)
    if isinstance(node.value, ast.Name):
      if node.value.id == 'ag__':
        # Some variables are assigned with the 'ag__.xxx' method; we should
        # handle them following the autograph conventions.
        return (node.attr, TFRTypes.AG_BUILTIN_FUNC)

      if node_type == TFRTypes.TF_RAW_OP:
        # This branch is used when it is inside tensorflow
        return (node.attr, TFRTypes.TF_RAW_OP)

      if node_type == TFRTypes.ATTR:
        attr = self._ssa_name('attr')
        tfr_type = _TF_DTYPE_TO_TFR.get(node.attr)
        self._emit_with_loc(
            '\n{} = tfr.constant {} -> !tfr.attr'.format(attr, tfr_type), node)
        return (attr, TFRTypes.ATTR)

      value, _ = self.visit(node.value)
      tensor_type = self._get_inferred_type(node.value, None)
      # TODO(fengliuai): use node_type once the type inference is fixed.
      if node_type == TFRTypes.SHAPE:
        print('TODO: use "node_type"')
      if node.attr == 'shape' and tensor_type == TFRTypes.TENSOR:
        ssa_value = self._ssa_name('shape')
        self._emit_with_loc(
            '\n{} = tfr.get_shape {} -> !shape.shape'.format(ssa_value, value),
            node)
        return (ssa_value, TFRTypes.SHAPE)

    if isinstance(node.value, ast.Attribute):
      if isinstance(node.value.value, ast.Name):
        if node.value.value.id == 'tf' and node.value.attr == 'raw_ops':
          return (node.attr, TFRTypes.TF_RAW_OP)

      value, ty = self.visit(node.value)
      # TODO(fengliuai): use node_type once the type inference is fixed.
      if node_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
        print('TODO: use "node_type"')
      if ty == TFRTypes.SHAPE and node.attr == 'as_list':
        return (value, TFRTypes.TF_TENSOR_SHAPE_FUNC)

    raise NotImplementedError('Attribute kind not recognized.')
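
  # For an expression like 'x.shape' where x is a tensor, visit_Attribute
  # emits something like '%shape = tfr.get_shape %x -> !shape.shape' and
  # returns (ssa_value, TFRTypes.SHAPE).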

  def visit_Assign(self, node):
    values = self.visit(node.value)
    if isinstance(node.targets[0], ast.Tuple):
      targets = [elt.id for elt in node.targets[0].elts]
    elif isinstance(node.targets[0], ast.Name):
      targets = [node.targets[0].id]
    else:
      raise NotImplementedError('Assignment target type not recognized.')

    if isinstance(values, list):
      if isinstance(node.value, ast.Call):
        expected = tuple(t for n, t in values)
        if len(values) == 1:
          expected = expected[0]
      elif isinstance(node.value, ast.Tuple):
        expected = tuple(t for n, t in values)
      else:
        raise ValueError('unknown assignment target node', node.value)
      ty = self._get_inferred_type(node.value, expected)

      if len(targets) == len(values):
        # TODO(mdan): This should already be a tuple.
        ty_ = (ty,) if len(values) == 1 else ty
        for key, value, t in zip(targets, values, ty_):
          ssa_value, _ = value
          self.symbol_table.insert_symbol(key, ssa_value, t)
      elif len(values) == 1:
        name, tys = values[0]
        if ty == TFRTypes.TENSOR_LIST:
          # Assign a single tensor_list to multiple variables.
          for idx, key in enumerate(targets):
            idx_name = self._ssa_name('idx')
            self._emit_with_loc(
                '\n{} = constant {} : index'.format(idx_name, idx), node)
            elt_name = self._ssa_name('elt')
            self.emit('\n{} = tfr.get_element {}[{}]'.format(
                elt_name, name, idx_name))
            self._emit_with_loc(' : (!tfr.tensor_list, index) -> !tfr.tensor',
                                node)
            self.symbol_table.insert_symbol(key, elt_name, TFRTypes.TENSOR)
        else:
          # Assign a single value to multiple targets. This single value is
          # usually a function return, so the return type should be in the
          # tuple of the value.
          for idx, key in enumerate(targets):
            ssa_name = '{}#{}'.format(name, idx)
            ssa_type = tys[idx]
            self.symbol_table.insert_symbol(key, ssa_name, ssa_type)
      elif len(targets) == 1:
        ssa_names = [n for n, _ in values]
        self.symbol_table.insert_symbol(targets[0], ssa_names, ty)
      return

    ty = self._get_inferred_type(node.value, values[1])
    self.symbol_table.insert_symbol(targets[0], values[0], ty)

  def _emit_binary_op(self, op, lhs, lhs_ty, rhs, rhs_ty):
    assert lhs_ty, rhs_ty
    if isinstance(op, ast.Sub):
      code = 'sub'
    elif isinstance(op, ast.Add):
      code = 'add'
    elif isinstance(op, ast.Mult):
      code = 'mul'
    elif isinstance(op, ast.Div):
      code = 'div'
    else:
      raise NotImplementedError('BinOp operator not recognized: ' + str(op))

    if lhs_ty == TFRTypes.I64 or lhs_ty == TFRTypes.I32:
      suffix = 'i'
    elif lhs_ty == TFRTypes.F32:
      suffix = 'f'
    else:
      raise NotImplementedError('BinOp operand type not recognized: ' +
                                str(op))

    ret = self._ssa_name(code)
    self._emit_with_loc(
        '\n{} = {}{} {}, {} : {}'.format(ret, code, suffix, lhs, rhs, lhs_ty),
        op)
    return ret, lhs_ty

  def visit_AugAssign(self, node):
    lhs, lhs_ty = self.visit(node.target)
    rhs, rhs_ty = self.visit(node.value)
    ret, ret_ty = self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty)
    self.symbol_table.insert_symbol(node.target.id, ret, ret_ty)

  def visit_BinOp(self, node):
    lhs, lhs_ty = self.visit(node.left)
    rhs, rhs_ty = self.visit(node.right)
    return self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty)

  def visit_BoolOp(self, node):
    values = [self.visit(value) for value in node.values]
    # TODO(fengliuai): Handle more ast node types.
    if isinstance(node.op, ast.Or):
      raise NotImplementedError('Or operator not recognized')
    elif isinstance(node.op, ast.And):
      raise NotImplementedError('And operator not recognized')
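
  # _emit_binary_op above lowers Python arithmetic to standard-dialect ops;
  # e.g. 'a + b' on i64 operands becomes '%add = addi %a, %b : i64'.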

  def visit_Call(self, node):
    func_name, func_type = self.visit(node.func)
    func_type = self._get_inferred_type(node.func, func_type)
    if func_type == TFRTypes.AG_BUILTIN_FUNC:
      if func_name == 'if_stmt':
        cond, _ = self.visit(node.args[0])
        body, _ = self.visit(node.args[1])
        orelse, _ = self.visit(node.args[2])
        get_state, _ = self.visit(node.args[3])
        nouts = int(node.args[6].value)
        out_symbols = []
        # The out symbols are just a Tuple of names
        for out in node.args[5].elts[:nouts]:
          val, ty = self.symbol_table.lookup(out.value)
          out_symbols.append(out.value)
        return self._visit_if_stmt(cond, body, orelse, get_state, out_symbols,
                                   node)
      elif func_name == 'for_stmt':
        range_ = self._visit_iter(node.args[0])
        body, _ = self.visit(node.args[2])
        get_state, _ = self.visit(node.args[3])
        loop_carried = [out.value for out in node.args[5].elts]
        # TODO(fengliuai): opt is not used here.
        return self._visit_for_stmt(range_, body, get_state, loop_carried,
                                    node)
      elif func_name == 'Undefined':
        val = self._ssa_name(node.args[0].value)
        return (val, TFRTypes.AG_UNDEFINED_VAL)
      elif func_name == 'UndefinedReturnValue':
        val = self._ssa_name('return_val')
        return (val, TFRTypes.AG_UNDEFINED_VAL)

    if func_type == TFRTypes.TF_RAW_OP:
      return self._visit_tf_op(func_name, node.args, node.keywords, node)

    if func_type == TFRTypes.TFR_BUILTIN_FUNC:
      return self._visit_tfr_builtins(func_name, node.args, node)

    if func_type == types.FunctionType:
      return self._visit_tf_op(func_name, node.args, node.keywords, node)

    if func_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
      return (func_name, TFRTypes.TF_TENSOR_SHAPE_LIST)

    if func_type == TFRTypes.PY_BUILTIN_FUNC:
      if func_name == 'len':
        arg, ty = self.visit(node.args[0])
        ty = self._get_inferred_type(node.args[0], ty)
        if ty == TFRTypes.TF_TENSOR_SHAPE_LIST:
          len_value = self._ssa_name('len')
          self._emit_with_loc(
              '\n{} = shape.rank {} : !shape.shape -> !shape.size'.format(
                  len_value, arg), node)
          size_value = self._ssa_name('len_size')
          self._emit_with_loc(
              '\n{} = shape.size_to_index {} : !shape.size'.format(
                  size_value, len_value), node)
        elif ty == TFRTypes.TENSOR_LIST:
          size_value = self._ssa_name('len')
          self._emit_with_loc(
              '\n{} = tfr.get_length {} -> index'.format(size_value, arg),
              node)
        return (size_value, TFRTypes.INDEX)

    raise NotImplementedError('call operator not recognized: {} {}'.format(
        func_name, func_type))

  def visit_Compare(self, node):
    lhs, lhs_ty = self.visit(node.left)
    for op, right in zip(node.ops, node.comparators):
      rhs, rhs_ty = self.visit(right)
      if isinstance(op, ast.Eq):
        pred = 'eq'
      elif isinstance(op, ast.Lt):
        pred = 'ult'
      elif isinstance(op, ast.LtE):
        pred = 'ule'
      elif isinstance(op, ast.Gt):
        pred = 'ugt'
      elif isinstance(op, ast.GtE):
        pred = 'uge'
      elif isinstance(op, ast.NotEq):
        pred = 'ne'
      else:
        raise NotImplementedError('Compare operator not recognized')

      ret = self._ssa_name(pred)
      if lhs_ty == TFRTypes.ATTR:
        self._emit_with_loc(
            '\n{} = tfr.equal {}, {} -> i1'.format(ret, lhs, rhs), node)
      else:
        if lhs_ty == TFRTypes.I64:
          code = 'cmpi'
        elif lhs_ty == TFRTypes.F32:
          code = 'cmpf'
        elif lhs_ty == TFRTypes.INDEX:
          code = 'cmpi'
          # TODO(fengliuai): the reverse type inference should solve the issue.
          rhs, _ = self._i64_to_index(rhs, rhs_ty)
        else:
          raise NotImplementedError('Compare operand type not recognized')
        self._emit_with_loc(
            '\n{} = {} "{}", {}, {} : {}'.format(ret, code, pred, lhs, rhs,
                                                 lhs_ty), node)

    return ret, TFRTypes.I1

  def visit_Constant(self, node):
    cst_name = self._ssa_name('cst')
    if node.value is None:
      cst_ty = TFRTypes.NONE
    elif isinstance(node.value, bool):
      cst_ty = self._get_inferred_type(node)
      cst_val = str(node.value).lower()
      self._emit_with_loc('\n{} = constant {}'.format(cst_name, cst_val), node)
    else:
      cst_ty = self._get_inferred_type(node)
      cst_val = node.value
      if cst_ty == TFRTypes.ATTR:
        self._emit_with_loc(
            '\n{} = tfr.constant "{}" -> {}'.format(cst_name, cst_val, cst_ty),
            node)
      else:
        self._emit_with_loc(
            '\n{} = constant {} : {}'.format(cst_name, cst_val, cst_ty), node)
    return cst_name, cst_ty

  def visit_FunctionDef(self, node):
    op_def, derived_attrs = self._op_defs.lookup(node.name, node, True)
    if op_def is None:
      # Nested function. Insert it into the symbol table for looking up later.
      self.symbol_table.insert_symbol(node.name, node, None)
      return
    op_name = op_def.name
    if self.symbol_table.lookup(op_name):
      raise LookupError('Composition has not been registered for op: ' +
                        op_name)
    else:
      self.symbol_table.insert_symbol(node.name, None, None)

    self.symbol_table.enter_scope()
    self.emit('\ntfr.func @tf__{0}('.format(_camel_to_snake(op_name)))

    arg_list = []
    idx = 0
    max_idx = len(op_def.input_arg) + len(op_def.attr)
    for arg in node.args.args:
      arg_name = self._ssa_name(anno.getanno(arg, anno.Basic.QN))
      arg_type = anno.getanno(arg, anno.Static.TYPES)[0]

      arg_attr = ''
      if idx >= len(op_def.input_arg):
        attr_def = op_def.attr[idx - len(op_def.input_arg)]
        # skip the derived attributes
        while attr_def.name in derived_attrs and (idx + 1) < max_idx:
          idx += 1
          attr_def = op_def.attr[idx - len(op_def.input_arg)]
        if idx >= max_idx:
          raise ValueError('Argument is not defined in OpDef: ' + arg_name)

        arg_attr += '{{tfr.name="{}"'.format(attr_def.name)
        if attr_def.HasField('default_value'):
          default_val = _get_val_from_proto(arg_type, attr_def.default_value)
          arg_attr += ',tfr.default={}'.format(default_val)
        arg_attr += '}'

      idx += 1
      arg_str = '{}: {}{}'.format(arg_name, arg_type, arg_attr)
      arg_list.append(arg_str)
      self.symbol_table.insert_symbol(arg.id, arg_name, arg_type)

    ret_type_list = []
    for ret_def in op_def.output_arg:
      if ret_def.number_attr or ret_def.type_list_attr:
        ret_type_list.append(str(TFRTypes.TENSOR_LIST))
      else:
        ret_type_list.append(str(TFRTypes.TENSOR))

    self.emit('{}) -> ({}) {{'.format(', '.join(arg_list),
                                      ', '.join(ret_type_list)))
    self.visit_block(node.body)
    self._emit_with_loc('\n}', node)
    self.symbol_table.exit_scope()

  def visit_arguments(self, node):
    # TODO(fengliuai): return the types and names in order. We need to order
    # the arguments to match the assumption in the TFR dialect.
    raise NotImplementedError('arguments not supported.')

  def visit_Lambda(self, node):
    raise NotImplementedError('Lambda not supported.')

  def _get_mlir_ssa_values(self, name_prefix, out_types):
    """Creates MLIR-convention SSA values."""
    out_ssa_values = []
    if not out_types:
      return '', out_ssa_values

    out_name = self._ssa_name(name_prefix)
    if len(out_types) == 1:
      out_name_suffix = ''
      out_ssa_values.append(out_name)
    else:
      # For multiple returns, MLIR uses '%s:i' when they are defined and
      # '%s#i' when they are used.
      out_name_suffix = ':{}'.format(len(out_types))
      for idx, _ in enumerate(out_types):
        out_ssa_values.append('{}#{}'.format(out_name, idx))

    return '{}{}'.format(out_name, out_name_suffix), out_ssa_values

  def _visit_if_stmt(self, cond, body_def, orelse_def, get_state, out_symbols,
                     node):
    self.emit('\n')
    ret_str, ret_ssa_values = self._get_mlir_ssa_values(
        'if_stmt', [TFRTypes.TENSOR] * len(out_symbols))
    if ret_ssa_values:
      self.emit(ret_str + ' = ')

    out_types = []
    for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
      out_types.append(str(TFRTypes.TENSOR))

    self.emit('scf.if {} -> ({}) {{'.format(cond, ', '.join(out_types)))
    # Create a new scope in case the local variables are leaked.
    self.symbol_table.enter_scope(scf_scope=True)
    self.visit_block(body_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()

    self.emit('\n} else {')

    # Create a new scope in case the local variables are leaked.
    self.symbol_table.enter_scope(scf_scope=True)
    self.visit_block(orelse_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()

    # add ssa values to the symbol table
    for symbol, ssa_value in zip(out_symbols, ret_ssa_values):
      self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR)

    self._emit_with_loc('\n}', node)
    return list(zip(ret_ssa_values, out_types))

  def _visit_iter(self, node):
    if isinstance(node, ast.Call):
      f_name = anno.getanno(node.func, anno.Basic.QN)
      if f_name == QN('range'):
        args = [self.visit(arg) for arg in node.args]
        begin = None
        step = None
        end = None
        if len(args) == 1:
          end, end_ty = args[0]
        elif len(args) == 2:
          begin, begin_ty = args[0]
          end, end_ty = args[1]
        elif len(args) == 3:
          begin, begin_ty = args[0]
          end, end_ty = args[1]
          step, step_ty = args[2]

        if begin is None:
          begin = self._ssa_name('begin')
          self._emit_with_loc('\n{} = constant 0 : index'.format(begin), node)
        elif begin_ty != TFRTypes.INDEX:
          begin_ = self._ssa_name('begin')
          self._emit_with_loc(
              '\n{} = index_cast {} : {} to index'.format(
                  begin_, begin, begin_ty), node)
          begin = begin_

        if end_ty != TFRTypes.INDEX:
          end_ = self._ssa_name('end')
          self._emit_with_loc(
              '\n{} = index_cast {} : {} to index'.format(end_, end, end_ty),
              node)
          end = end_

        if step is None:
          step = self._ssa_name('step')
          self._emit_with_loc('\n{} = constant 1 : index'.format(step), node)
        elif step_ty != TFRTypes.INDEX:
          step_ = self._ssa_name('step')
          self._emit_with_loc(
              '\n{} = index_cast {} : {} to index'.format(step_, step,
                                                          step_ty), node)
          step = step_

        return begin, end, step

    raise NotImplementedError('Iterator entity not supported: ' + str(node))
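
  # _visit_iter normalizes range(...) arguments into index-typed
  # (begin, end, step) SSA values: missing bounds become constants
  # ('%begin = constant 0 : index', '%step = constant 1 : index') and
  # non-index operands are wrapped in an 'index_cast'.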

  def _visit_for_stmt(self, range_, body_def, get_state, loop_carried, node):
    self.emit('\n')
    ret_str, ret_ssa_values = self._get_mlir_ssa_values(
        'for_stmt', [TFRTypes.TENSOR] * len(loop_carried))
    if ret_ssa_values:
      self.emit(ret_str + ' = ')

    # Before entering the loop, we use the original ssa values as the initial
    # values to the loop iteration arguments. We also create new ssa values as
    # the returns of the scf for statements. The symbol table needs to be
    # updated to these new ssa values before it enters the scope of the loop.
    out_types = []
    init_values = []
    for symbol, ssa_value in zip(loop_carried, ret_ssa_values):
      init, ty = self.symbol_table.lookup(symbol)
      self.symbol_table.insert_symbol(symbol, ssa_value, ty)
      out_types.append(str(ty))
      init_values.append((init, ty))

    # Create a new scope in case the local variables are leaked.
    self.symbol_table.enter_scope(scf_scope=True)

    # Create the iteration variable with index type
    assert len(body_def.args.args) == 1
    it_name = body_def.args.args[0].id
    it = self._ssa_name(it_name)
    self.symbol_table.insert_symbol(it_name, it, TFRTypes.INDEX)

    self.emit('scf.for {} = {} to {} step {} '.format(it, range_[0], range_[1],
                                                      range_[2]))
    if loop_carried:
      iter_args = []
      for symbol, init in zip(loop_carried, init_values):
        # create new ssa values for the loop carried variables
        it_arg = self._ssa_name('it_arg')
        self.symbol_table.insert_symbol(symbol, it_arg, init[1])
        iter_args.append('{} = {}'.format(it_arg, init[0]))
      self.emit('iter_args({}) '.format(', '.join(iter_args)))
      self.emit('-> ({}) {{'.format(', '.join(out_types)))
    else:
      self.emit(' {')
    self.visit_block(body_def.body)
    self.visit_block(get_state.body)
    self.symbol_table.exit_scope()
    self._emit_with_loc('\n}', node)
    return list(zip(ret_ssa_values, out_types))

  def _emit_default_constant_from_proto(self, attr_def):
    """Emits an MLIR constant statement from the AttrDef's default value."""
    name = self._ssa_name('cst')
    cst_ty = _get_type_from_proto(None, attr_def)
    try:
      cst_val = _get_val_from_proto(cst_ty, attr_def.default_value)
    except AttributeError:
      raise AttributeError(
          f'attribute "{attr_def.name}" does not have default_value')
    if cst_ty == TFRTypes.ATTR:
      self._emit_with_loc('\n{} = tfr.constant {} -> {}'.format(
          name, cst_val, cst_ty))
    elif cst_ty == TFRTypes.I1:
      self._emit_with_loc('\n{} = constant {}'.format(name, cst_val))
    else:
      self._emit_with_loc('\n{} = constant {} : {}'.format(
          name, cst_val, cst_ty))
    return name, cst_ty

  def visit_keyword(self, node):
    return node.arg, self.visit(node.value)
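
  # In _visit_tfr_builtins below, the '_tfr_' name prefix maps to the 'tfr.'
  # dialect namespace; e.g. a call to '_tfr_quant_raw_data' lowers to the
  # 'tfr.quant_raw_data' operation.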

  def _visit_tfr_builtins(self, op_name, args, node):
    arg_strs = []
    arg_tys = []
    for arg in args:
      value, ty = self.visit(arg)
      arg_strs.append(value)
      arg_tys.append(ty)
    tfr_op_name = 'tfr.' + op_name[5:]
    ret_tys = (
        TFR_BUILTINS[op_name](*arg_tys)
        if callable(TFR_BUILTINS[op_name]) else TFR_BUILTINS[op_name])
    # Convert the tfr builtin returns to a list.
    if isinstance(ret_tys, tuple):
      ret_tys = list(ret_tys)
    else:
      ret_tys = [ret_tys]

    ret_str, ret_ssa_values = self._get_mlir_ssa_values(op_name, ret_tys)

    arg_str = ', '.join(arg_strs)
    arg_ty_str = ', '.join(str(ty) for ty in arg_tys)
    ret_ty_str = ', '.join(str(ty) for ty in ret_tys)
    self._emit_with_loc('\n{} = {}({}) : ({}) -> ({})'.format(
        ret_str, tfr_op_name, arg_str, arg_ty_str, ret_ty_str), node)
    return list(zip(ret_ssa_values, ret_tys))

  def _visit_tf_op(self, op_name, args, keywords, node):
    op_def, derived_attrs = self._op_defs.lookup(op_name)
    ret_tys = [_get_type_from_proto(arg) for arg in op_def.output_arg]

    ret_str, ret_ssa_values = self._get_mlir_ssa_values(op_name, ret_tys)

    arg_strs = []
    ty_strs = []
    for arg in args:
      value, ty = self.visit(arg)
      arg_strs.append(value)
      ty_strs.append(str(ty))

    input_args = [arg for arg in op_def.input_arg]
    attrs_no_default = [
        attr for attr in op_def.attr
        if not attr.HasField('default_value') and attr.name not in derived_attrs
    ]
    attrs_with_default = [
        attr for attr in op_def.attr
        if attr.HasField('default_value') and attr.name not in derived_attrs
    ]

    kw_args = {}
    for arg in keywords:
      value, (ssa_name, ty) = self.visit(arg)
      ty = self._get_inferred_type(arg.value, ty)

      # TODO(fengliuai): implement the "rename_to" for the customization in
      # tensorflow/core/api_def/base_api/*
      if value == 'axis':
        value = 'split_dim'

      kw_args[value] = (ssa_name, ty)

    # tensor arguments and attribute arguments
    ordered_args = input_args + attrs_no_default + attrs_with_default
    for attr_def in ordered_args[len(args):]:
      if attr_def.name in kw_args:
        value, ty = kw_args[attr_def.name]
        if attr_def in input_args:
          if ty in _attribute_types:
            # Attribute-typed values can't be passed to a TF op call directly,
            # so materialize them as constant tensors first.
            value, ty = self._value_to_tensor(value, ty, node)
          if ty is TFRTypes.TENSOR_LIST and not _require_tensor_list(attr_def):
            value, ty = self._pack_tensor_list(value)
      else:
        value, ty = self._emit_default_constant_from_proto(attr_def)
      arg_strs.append(value)
      ty_strs.append(str(ty))

    if ret_ssa_values:
      self.emit('\n{} = '.format(ret_str))

    self.emit('tfr.call @tf__{}('.format(_camel_to_snake(op_name)))
    arg_str = ', '.join(arg_strs)
    arg_ty_str = ', '.join(ty_strs)
    ret_ty_str = ', '.join([str(ty) for ty in ret_tys])
    self._emit_with_loc(
        '{}) : ({}) -> ({})'.format(arg_str, arg_ty_str, ret_ty_str), node)
    return list(zip(ret_ssa_values, ret_tys))

  def visit_If(self, node):
    raise NotImplementedError('If not supported.')

  def visit_Name(self, node):
    val_and_lookup_type = self.symbol_table.lookup(node.id)
    if val_and_lookup_type:
      (val, lookup_type) = val_and_lookup_type
    elif node.id in TFR_BUILTINS:
      val = node.id
      lookup_type = anno.getanno(node, anno.Static.TYPES, types.FunctionType)
    else:
      op_def, _ = self._op_defs.lookup(node.id)
      val = op_def.name
      lookup_type = anno.getanno(node, anno.Static.TYPES, types.FunctionType)
    type_ = self._get_inferred_type(node, lookup_type)
    return val, type_

  def visit_Return(self, node):
    values = self.visit(node.value)
    if self.symbol_table.in_scf_scope():
      self.emit('\nscf.yield ')
    else:
      self.emit('\ntfr.return ')
    if not values:
      return

    if isinstance(values, list):
      vals, tys = zip(*values)
    else:
      vals = values[0]
      tys = values[1]

    if isinstance(tys, list) or isinstance(tys, tuple):
      tys = [str(t) for t in tys]
      self._emit_with_loc('{} : {}'.format(', '.join(vals), ', '.join(tys)),
                          node)
    elif tys != TFRTypes.NONE:
      # TODO(fengliuai): scf region yield uses this branch. Fix it.
      self._emit_with_loc('{} : {}'.format(vals, tys), node)

  def visit_Subscript(self, node):
    val, ty = self.visit(node.value)
    type_ = self._get_inferred_type(node.value, ty)

    # TODO(fengliuai): Here we hardcode the handling of node.slice to get the
    # index type. Use the visit method once the type inference is done.
    # slice_val, slice_ty = self.visit(node.slice)
    s = node.slice
    if not isinstance(s, (ast.Tuple, ast.Slice)):
      if isinstance(s, ast.Constant):
        # TODO(fengliuai): promote to an assignment
        idx_val = self._ssa_name('cst')
        self._emit_with_loc(
            '\n{} = constant {} : index'.format(idx_val, s.value), node)
      else:
        idx_val, _ = self.visit(s)
    else:
      raise NotImplementedError('non-index slice not supported.')

    elt = self._ssa_name('elt')
    if type_ == TFRTypes.TENSOR_LIST:
      self.emit('\n{} = tfr.get_element {}[{}] '.format(elt, val, idx_val))
      self._emit_with_loc(': (!tfr.tensor_list, index) -> !tfr.tensor', node)
      return (elt, TFRTypes.TENSOR)
    elif type_ == TFRTypes.TF_TENSOR_SHAPE_LIST:
      size_ = self._ssa_name('size')
      self.emit('\n{} = shape.get_extent {}, {}'.format(size_, val, idx_val))
      self._emit_with_loc(': !shape.shape, index -> !shape.size', node)
      self._emit_with_loc(
          '\n{} = shape.size_to_index {} : !shape.size'.format(elt, size_),
          node)
      return (elt, TFRTypes.INDEX)

  def visit_List(self, node):
    out_type = self._get_inferred_type(node)
    vals = []
    tys = []
    for elt in node.elts:
      val, ty = self.visit(elt)
      ty = self._get_inferred_type(elt, ty)
      if ty in _attribute_types and out_type == TFRTypes.TENSOR_LIST:
        # This list is a tensor list, so cast all the input values to tensors.
        val, ty = self._value_to_tensor(val, ty, node)
      else:
        # We shouldn't use the index type to build the list, because the list
        # will be used as an attribute.
        val, ty = self._index_to_I64(val, ty)
      vals.append(val)
      tys.append(str(ty))

    list_val = self._ssa_name('list')
    self.emit('\n{} = "tfr.build_list"({})'.format(list_val, ', '.join(vals)))
    self._emit_with_loc(' : ({}) -> {}'.format(', '.join(tys), out_type), node)
    return (list_val, out_type)

  def visit_Tuple(self, node):
    return [self.visit(elt) for elt in node.elts]

  def visit_UnaryOp(self, node):
    value, ty = self.visit(node.operand)
    if isinstance(node.op, ast.USub):
      zero_value = self._ssa_name('zero')
      ssa_value = self._ssa_name('cst')
      if ty == TFRTypes.I32 or ty == TFRTypes.I64:
        self._emit_with_loc(
            '\n{} = constant 0 : {}'.format(zero_value, ty), node)
        self._emit_with_loc(
            '\n{} = subi {}, {} : {}'.format(ssa_value, zero_value, value, ty),
            node)
      elif ty == TFRTypes.F32:
        self._emit_with_loc(
            '\n{} = constant 0.0 : {}'.format(zero_value, ty), node)
        self._emit_with_loc(
            '\n{} = subf {}, {} : {}'.format(ssa_value, zero_value, value, ty),
            node)
      else:
        raise NotImplementedError('USub type not recognized: ' + str(ty))
      return ssa_value, ty
    raise NotImplementedError('USub operator not recognized')

  def visit_For(self, node):
    raise NotImplementedError('For operator not recognized')

  def visit_While(self, node):
    raise NotImplementedError('While operator not recognized')

  def visit_Try(self, node):
    # Only handles the body of the try statement.
    self.visit_block(node.body)


def _apply_py_to_tf_passes(node, ctx):
  """Applies transformations from PyToTF to match tf.function tracing."""
  # TODO(fengliuai): we don't know which passes are required, thus we evaluate
  # each one when the corresponding node is handled.
  # copied from PyToTF.transform_ast
  node = return_statements.transform(node, ctx, False)
  node = control_flow.transform(node, ctx)
  return node


class TfrGen(transpiler.GenericTranspiler):
  """Transforms Python objects into TFR MLIR source code."""

  def __init__(self, op_defs):
    self._op_defs = op_defs

  def transform_ast(self, node, ctx):
    node = _apply_py_to_tf_passes(node, ctx)
    # TODO(mdan): Enable this.
    # node = anf.transform(node, ctx)

    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx)
    node = reaching_definitions.resolve(node, ctx, graphs)
    node = reaching_fndefs.resolve(node, ctx, graphs)
    node = type_inference.resolve(node, ctx, graphs,
                                  TFRTypeResolver(self._op_defs))

    mlir_generator = TFRGen(ctx, self._op_defs)
    mlir_generator.visit(node)
    return mlir_generator.code_buffer


def tfr_gen(func, op_defs):
  """Parses a function and emits the TFR functions."""
  mlir_code, _ = TfrGen(op_defs).transform(func, None)
  assert tfr.verify(mlir_code), 'mlir code not verified: {}'.format(mlir_code)
  return mlir_code


def tfr_funcs_gen_from_module(source, op_defs, method_prefix=None,
                              op_libraries=None):
  """Parses the input source module and emits the TFR functions."""

  # Load the op library so the op is added to the op registry. This is
  # required when the op cc_library couldn't be statically linked in open
  # source.
  # This is a no-op if the op shared library couldn't be found in the same
  # directory as the op Python API.
  # TODO(fengliuai): make the .so file path configurable.
  if op_libraries:
    prefix_len = len('gen_')
    for m in op_libraries:
      lib_dir = os.path.dirname(m.__file__)
      lib_name = os.path.basename(m.__file__)[prefix_len:].replace('.py', '.so')
      lib_path = os.path.join(lib_dir, lib_name)
      if os.path.exists(lib_path):
        logging.info('load file: ' + lib_path)
        load_library.load_op_library(lib_path)
  else:
    # The op library is generated from the source module, so we load all the
    # .so files in the directory.
    lib_dir = os.path.dirname(source.__file__)
    for lib_name in os.listdir(lib_dir):
      if lib_name.endswith('.so'):
        lib_path = os.path.join(lib_dir, lib_name)
        logging.info('load file: ' + lib_path)
        load_library.load_op_library(lib_path)

  py_funcs = [
      func
      for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction)
      if not method_prefix or name.startswith(method_prefix)
  ]
  # Sort the methods by line number, to make sure the definitions are
  # processed before their usages.
  # TODO(fengliuai): Use the type inference resolver to recursively process
  # any functions called.
  py_funcs = sorted(py_funcs, key=lambda x: x.__code__.co_firstlineno)
  mlir_funcs = [tfr_gen(func, op_defs) for func in py_funcs]

  return mlir_funcs


def tfr_gen_from_module(source, method_prefix=None, op_libraries=None,
                        op_defs=OpDefCache()):
  """Parses the input source module and emits the TFR and external functions."""
  mlir_funcs = tfr_funcs_gen_from_module(
      source, op_defs, method_prefix, op_libraries)

  return '\n'.join(mlir_funcs + op_defs.mlir_external_funcs())
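
# Example driver (hypothetical module 'my_composite_ops' containing
# @Composite-decorated functions):
#   import my_composite_ops
#   mlir = tfr_gen_from_module(my_composite_ops, method_prefix='composite_')
#   print(mlir)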