# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""mlir_gen: Generate mlir code from python code."""

# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast as ast
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos
from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
import tensorflow.python.tf_program.pywrap_tfd as tfp
from tensorflow.python.types import core


class SymbolTable(object):
  """Symbol table for python code.

  Holds a stack of scopes. Each scope is a dict with a 'symbols' map
  (name -> MLIR value) and a 'types' map (name -> MLIR type). Lookups search
  from the innermost scope outwards and return None when a name is unknown.
  """

  def __init__(self):
    self.symbols = []
    self.enter_scope()

  def enter_scope(self):
    """Pushes a new innermost scope (used for functions and op regions)."""
    self.symbols.append({'types': {}, 'symbols': {}})
    self.curr_table = self.symbols[-1]

  def insert_symbol(self, name, value):
    """Binds `name` to `value` and its type in the current scope.

    Returns:
      `value`, unchanged.
    """
    self.curr_table['symbols'][name] = value
    self.curr_table['types'][name] = value.getType()
    return value

  def insert_type(self, name, type_):
    """Records only the type of `name` in the current scope."""
    self.curr_table['types'][name] = type_

  def exit_scope(self):
    """Pops the innermost scope."""
    self.symbols.pop()
    self.curr_table = self.symbols[-1]

  def _lookup_in(self, kind, name):
    """Searches the `kind` map ('symbols' or 'types') innermost-first."""
    for table in reversed(self.symbols):
      if name in table[kind]:
        return table[kind][name]
    return None

  def lookup(self, name):
    """Returns the value bound to `name`, or None if it is not bound."""
    return self._lookup_in('symbols', name)

  def lookup_type(self, name):
    """Returns the type recorded for `name`, or None if there is none."""
    return self._lookup_in('types', name)

  def __repr__(self):
    return '\n'.join(
        ' ' * idx * 2 + str(table) for idx, table in enumerate(self.symbols))


class ProcessType(ast.NodeVisitor):
  """Visits a node and returns the processed type.

  Currently only annotations are visited. `int` maps to i32, `bool` to i1 and
  `core.Tensor` to an unranked tensor of i32. Other names are resolved through
  the entity's namespace; unknown names yield None.
  """

  def __init__(self, prog, ctx):
    self.prog = prog
    self.ctx = ctx

  def visit_Attribute(self, node):
    # Supported: core.Tensor
    value = self.visit(node.value)
    if value is None or not hasattr(value, node.attr):
      raise AttributeError(str(type(value)) + ' has no attribute ' + node.attr)
    attr = getattr(value, node.attr)

    if attr is core.Tensor:
      return tfp.UnrankedTensorType.get(tfp.IntegerType.get(self.prog.ctx, 32))
    return attr

  def visit_Name(self, node):
    if node.id == 'int':
      return tfp.IntegerType.get(self.prog.ctx, 32)
    if node.id == 'bool':
      return tfp.IntegerType.get(self.prog.ctx, 1)
    if node.id in self.ctx.info.namespace:
      return self.ctx.info.namespace[node.id]
    return None


class MLIRGen(ast.NodeVisitor):
  """Visits the AST and generates MLIR code.

  Requires the liveness and reaching-definitions analyses (and the other
  annotations applied in `mlir_gen_internal`) to have been run on the AST.
  """

  # Python comparison operator -> TF op class used to lower it.
  _COMPARE_OPS = {
      ast.Eq: tfp.Tf_EqualOp,
      ast.Lt: tfp.Tf_LessOp,
      ast.LtE: tfp.Tf_LessEqualOp,
      ast.Gt: tfp.Tf_GreaterOp,
      ast.GtE: tfp.Tf_GreaterEqualOp,
      ast.NotEq: tfp.Tf_NotEqualOp,
  }

  def __init__(self, ctx):
    self.ctx = ctx
    self.symbol_table = SymbolTable()
    self.prog = tfp.TFProgram()
    self.opbuilder = None

  def visit_block(self, block):
    return [self.visit(item) for item in block]

  def process_type(self, node):
    return ProcessType(self.prog, self.ctx).visit(node)

  def visit_Assign(self, node):
    value = self.visit(node.value)
    if isinstance(value, tuple):
      # If it is a tuple of values, assign one to each in targets
      # TODO: This currently is assuming that all elts in targets[0] are Name
      # objects. This might not be always True.
      for key, val in zip(node.targets[0].elts, value):
        self.symbol_table.insert_symbol(key.id, val)
    else:
      self.symbol_table.insert_symbol(node.targets[0].id, value)

  def visit_BinOp(self, node):
    left = self.visit(node.left)
    right = self.visit(node.right)
    opb = self.opbuilder
    if isinstance(node.op, ast.Sub):
      return tfp.Tf_SubOp.create(opb, opb.getUnknownLoc(), left,
                                 right).getResult(0)
    if isinstance(node.op, ast.Add):
      return tfp.Tf_AddV2Op.create(opb, opb.getUnknownLoc(), left,
                                   right).getResult(0)
    # Fail loudly instead of returning None, which would only surface later.
    raise NotImplementedError('BinOp operator not recognized: ' +
                              type(node.op).__name__)

  def visit_BoolOp(self, node):
    values = [self.visit(value) for value in node.values]
    opb = self.opbuilder
    if isinstance(node.op, ast.Or):
      return tfp.OrOp.create(opb, opb.getUnknownLoc(), values).getResult(0)
    if isinstance(node.op, ast.And):
      return tfp.AndOp.create(opb, opb.getUnknownLoc(), values).getResult(0)
    raise NotImplementedError('BoolOp operator not recognized: ' +
                              type(node.op).__name__)

  def visit_Call(self, node):
    """Lowers a call to a `tf.LegacyCall` of the callee function.

    Returns the single result value, or a tuple of results when the callee
    has zero or several results.
    """
    func = self.visit(node.func)
    args = [self.visit(arg) for arg in node.args]
    callop = tfp.Tf_LegacyCallOp.create(self.opbuilder,
                                        self.opbuilder.getUnknownLoc(),
                                        func.getType().getResults(), args,
                                        func.getName())
    if callop.getNumResults() == 1:
      return callop.getResult(0)
    return tuple(callop.getResult(idx) for idx in range(callop.getNumResults()))

  def visit_Compare(self, node):
    """Lowers a (possibly chained) comparison.

    Python evaluates `a < b < c` as `a < b and b < c`, so each adjacent pair is
    compared and, for chains, the partial results are combined with AndOp.

    Raises:
      NotImplementedError: for an unsupported comparison operator.
    """
    opb = self.opbuilder
    left = self.visit(node.left)
    results = []
    for op, comparator in zip(node.ops, node.comparators):
      op_class = self._COMPARE_OPS.get(type(op))
      if op_class is None:
        raise NotImplementedError('CompareOp operator not recognized')
      right = self.visit(comparator)
      results.append(
          op_class.create(opb, opb.getUnknownLoc(), left, right).getResult(0))
      # The right operand is the left operand of the next link in the chain.
      left = right
    if len(results) == 1:
      return results[0]
    return tfp.AndOp.create(opb, opb.getUnknownLoc(), results).getResult(0)

  def visit_Constant(self, node):
    # Only integer constants are lowered (to i32). Other constants, such as
    # docstrings, yield None so that expression statements are ignored.
    # NOTE(review): bool is a subclass of int, so True/False become i32 here.
    opb = self.opbuilder
    value = None
    if isinstance(node.value, int):
      value = tfp.Tf_ConstOp.create(
          opb, opb.getUnknownLoc(),
          tfp.IntegerAttr.get(
              tfp.IntegerType.get(self.prog.ctx, 32), node.value)).getResult(0)
    return value

  def visit_FunctionDef(self, node):
    """Adds a function to the program and emits its body.

    Argument and return types come from the annotations. The function is
    registered under its Python name in the current scope; its arguments are
    bound in a new scope that is popped once the body has been emitted.
    """
    # Cache the current builder
    cache_builder = self.opbuilder
    inputs, outputs = [], []

    for arg in node.args.args:
      inputs.append(self.process_type(arg.annotation))

    if node.returns:
      outputs = [self.process_type(node.returns)]

    currfunc = self.prog.add_function(
        self.ctx.namer.new_symbol(node.name, []),
        self.prog.get_function_type(inputs, outputs))

    # Add the function to symbol table and enter new scope
    self.symbol_table.insert_symbol(node.name, currfunc)
    self.symbol_table.enter_scope()

    # Add arguments to symbol table
    for arg, value in zip(node.args.args, currfunc.getArguments()):
      self.symbol_table.insert_symbol(arg.id, value)
    self.opbuilder = tfp.OpBuilder(currfunc.getBody())

    self.visit_block(node.body)
    self.symbol_table.exit_scope()
    self.opbuilder = cache_builder

  def _emit_region(self, region, stmts, result_names):
    """Emits `stmts` into `region` and terminates it with a ReturnOp.

    The ReturnOp yields the current value of each name in `result_names`. A
    fresh scope keeps values generated inside the region out of the enclosing
    symbol table. Leaves `self.opbuilder` pointing at `region`.
    """
    self.opbuilder = tfp.OpBuilder(region)
    self.symbol_table.enter_scope()
    self.visit_block(stmts)
    retvals = [self.symbol_table.lookup(str(name)) for name in result_names]
    tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), retvals)
    self.symbol_table.exit_scope()

  def visit_If(self, node):
    """Lowers an `if` statement to a `tfp.IfOp`.

    Every variable modified in either branch becomes a result of the IfOp.
    Each branch returns that variable's value at the end of the branch,
    falling back to the value visible from the enclosing scope. Afterwards the
    variables are rebound to the IfOp results.
    """
    cond = self.visit(node.test)

    # Create ifop
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    modified_in_cond = list(body_scope.modified | orelse_scope.modified)
    outputs = [
        self.symbol_table.lookup_type(str(var)) for var in modified_in_cond
    ]
    ifop = tfp.IfOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), cond,
                           outputs)

    cache_builder = self.opbuilder
    self._emit_region(ifop.getRegion(0), node.body, modified_in_cond)
    self._emit_region(ifop.getRegion(1), node.orelse, modified_in_cond)

    # Reset builder and enter return values in symbol table
    self.opbuilder = cache_builder
    for idx, var in enumerate(modified_in_cond):
      self.symbol_table.insert_symbol(str(var), ifop.getResult(idx))

    if ifop.getNumResults() == 1:
      return ifop.getResult(0)

    return tuple(ifop.getResult(i) for i in range(ifop.getNumResults()))

  def visit_Name(self, node):
    value = self.symbol_table.lookup(node.id)
    if value is None:
      raise NotImplementedError('Symbol not found: ' + node.id)
    return value

  def visit_Return(self, node):
    opb = self.opbuilder
    if node.value is None:
      # Bare `return`: nothing is returned.
      return tfp.ReturnOp.create(opb, opb.getUnknownLoc(), [])
    value = self.visit(node.value)
    if isinstance(value, tuple):
      # For more than one return values
      return tfp.ReturnOp.create(opb, opb.getUnknownLoc(), list(value))
    return tfp.ReturnOp.create(opb, opb.getUnknownLoc(), [value])

  def visit_Tuple(self, node):
    return tuple(self.visit(elt) for elt in node.elts)

  def visit_UnaryOp(self, node):
    operand = self.visit(node.operand)
    if isinstance(node.op, ast.USub):
      return tfp.Tf_NegOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
                                 operand).getResult(0)
    raise NotImplementedError('UnaryOp operator not recognized: ' +
                              type(node.op).__name__)

  def _get_basic_loop_vars(self, modified, live_in, live_out):
    # [This is directly from
    # tensorflow/python/autograph/converters/control_flow.py]
    # The loop variables corresponding to simple symbols (e.g. `x`).
    basic_loop_vars = []
    for s in modified:
      if s.is_composite():
        # TODO: Raise an error when this happens for a TF loop.
        continue
      # Variables not live into or out of the loop are considered local to the
      # loop.
      if s not in live_in and s not in live_out:
        continue
      basic_loop_vars.append(s)
    return frozenset(basic_loop_vars)

  def _get_composite_loop_vars(self, modified, live_in):
    # [This is directly from
    # tensorflow/python/autograph/converters/control_flow.py]
    # The loop variables corresponding to composite symbols (e.g. `self.x`).
    composite_loop_vars = []
    for s in modified:
      if not s.is_composite():
        continue
      # Mutations made to objects created inside the loop will appear as writes
      # to composite symbols. Because these mutations appear as modifications
      # made to composite symbols, we check whether the composite's parent is
      # actually live into the loop.
      # Example:
      #   while cond:
      #     x = Foo()
      #     x.foo = 2 * x.foo  # x.foo is live into the loop, but x is not.
      #
      # Note that some parents might not be symbols - for example, in x['foo'],
      # 'foo' is a parent, but it's a literal, not a symbol. We don't check the
      # liveness of literals.
      support_set_symbols = tuple(
          sss for sss in s.support_set if sss.is_symbol())
      if not all(sss in live_in for sss in support_set_symbols):
        continue
      composite_loop_vars.append(s)
    return frozenset(composite_loop_vars)

  def _get_loop_vars(self, node, modified):
    # [This is directly from python/autograph/converters/control_flow.py]
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
    reserved_symbols = body_scope.referenced

    basic_loop_vars = self._get_basic_loop_vars(modified, live_in, live_out)
    composite_loop_vars = self._get_composite_loop_vars(modified, live_in)
    loop_vars = tuple(basic_loop_vars | composite_loop_vars)

    # Variable that are used or defined inside the loop, but not defined
    # before entering the loop. Only simple variables must be defined. The
    # composite ones will be implicitly checked at runtime.
    undefined_lives = basic_loop_vars - defined_in

    return loop_vars, reserved_symbols, undefined_lives

  def _bind_region_args(self, region, names, types):
    """Adds one block argument per loop variable to `region` and binds it."""
    for name, type_ in zip(names, types):
      self.symbol_table.insert_symbol(str(name),
                                      region.front().addArgument(type_))

  def visit_While(self, node):
    """Lowers a `while` loop to a `tfp.WhileOp`.

    Region 0 computes the condition and region 1 the body. Both receive the
    loop variables as block arguments. Afterwards the loop variables are
    rebound to the WhileOp results.
    """
    # Create a new WhileOp
    # `inputs` are initial values for loop variables
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    loop_vars, _, _ = self._get_loop_vars(node, body_scope.modified)
    inputs = [self.symbol_table.lookup(str(name)) for name in loop_vars]
    types = [input_.getType() for input_ in inputs]
    while_op = tfp.WhileOp.create(self.opbuilder,
                                  self.opbuilder.getUnknownLoc(), inputs, types)

    # cache the current builder
    cache_builder = self.opbuilder

    # Process cond
    self.symbol_table.enter_scope()
    self._bind_region_args(while_op.getRegion(0), loop_vars, types)
    self.opbuilder = tfp.OpBuilder(while_op.getRegion(0))
    tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
                        [self.visit(node.test)])
    self.symbol_table.exit_scope()

    # Process body
    self.symbol_table.enter_scope()
    self._bind_region_args(while_op.getRegion(1), loop_vars, types)
    self.opbuilder = tfp.OpBuilder(while_op.getRegion(1))
    self.visit_block(node.body)
    tfp.ReturnOp.create(
        self.opbuilder, self.opbuilder.getUnknownLoc(),
        [self.symbol_table.lookup(str(name)) for name in loop_vars])
    self.symbol_table.exit_scope()

    # Enter new values as symbols
    for idx, var in enumerate(loop_vars):
      self.symbol_table.insert_symbol(str(var), while_op.getResult(idx))

    # Restore builder
    self.opbuilder = cache_builder


def mlir_gen_internal(node, entity_info):
  """Returns mlir module for unprocessed node `node`.

  Runs the static analyses MLIRGen depends on (qualified names, activity,
  reaching definitions, reaching function definitions, liveness), then
  generates code.
  """
  namer = naming.Namer({})
  graphs = cfg.build(node)
  ctx = transformer.Context(entity_info, namer, None)
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx)
  node = reaching_definitions.resolve(node, ctx, graphs)
  node = reaching_fndefs.resolve(node, ctx, graphs)
  node = liveness.resolve(node, ctx, graphs)
  mlir_generator = MLIRGen(ctx)
  mlir_generator.visit(node)
  return mlir_generator.prog


def mlir_gen(func):
  """Parse a function and return TFProgram."""
  node, source = parser.parse_entity(func, future_features=())
  entity_info = transformer.EntityInfo(
      name=func.__name__,
      source_code=source,
      source_file=None,
      future_features=(),
      namespace=inspect_utils.getnamespace(func))
  return mlir_gen_internal(node, entity_info)


def mlir_gen_from_source(source=None, src_file=None):
  """Parses Python source and returns a TFProgram.

  Args:
    source: Python source code as a string. Used when not None.
    src_file: Path of a file to read the source from when `source` is None.

  Returns:
    The generated `tfp.TFProgram`.

  Raises:
    ValueError: if neither `source` nor `src_file` is provided.
  """
  if source is None:
    if src_file is None:
      raise ValueError('Either source or src_file must be provided.')
    # Read through a context manager so the file handle is always closed.
    with open(src_file) as f:
      source = f.read()
  node = ast.parse(source)
  entity_info = transformer.EntityInfo(
      name='mlir_module',
      source_code=source,
      source_file=None,
      future_features=(),
      namespace={})
  return mlir_gen_internal(node, entity_info)