• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""mlir_gen: Generate mlir code from python code."""
16
17# pylint: disable=invalid-name
18# pylint: disable=missing-function-docstring
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import gast as ast
25from tensorflow.python.autograph.pyct import anno
26from tensorflow.python.autograph.pyct import cfg
27from tensorflow.python.autograph.pyct import inspect_utils
28from tensorflow.python.autograph.pyct import naming
29from tensorflow.python.autograph.pyct import parser
30from tensorflow.python.autograph.pyct import qual_names
31from tensorflow.python.autograph.pyct import transformer
32from tensorflow.python.autograph.pyct.static_analysis import activity
33from tensorflow.python.autograph.pyct.static_analysis import annos
34from tensorflow.python.autograph.pyct.static_analysis import liveness
35from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
36from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
37import tensorflow.python.tf_program.pywrap_tfd as tfp
38from tensorflow.python.types import core
39
40
class SymbolTable(object):
  """Stack of lexical scopes mapping Python names to MLIR values and types.

  `self.symbols` holds one dict per scope, outermost first. Each scope has a
  'symbols' map (name -> value) and a 'types' map (name -> type); the two maps
  are searched independently. `self.curr_table` is the innermost scope, which
  is where new entries are written.
  """

  def __init__(self):
    self.symbols = []
    # Every table starts out with a single (outermost) scope.
    self.enter_scope()

  def enter_scope(self):
    """Push a new, empty scope and make it the current one."""
    scope = {'types': {}, 'symbols': {}}
    self.symbols.append(scope)
    self.curr_table = scope

  def insert_symbol(self, name, value):
    """Bind `name` to `value` in the current scope and return `value`.

    The type recorded for `name` is whatever `value.getType()` returns.
    """
    current = self.curr_table
    current['symbols'][name] = value
    current['types'][name] = value.getType()
    return value

  def insert_type(self, name, type_):
    """Record `type_` as the type of `name` in the current scope only."""
    self.curr_table['types'][name] = type_

  def exit_scope(self):
    """Discard the current scope; its parent becomes current again."""
    self.symbols.pop()
    self.curr_table = self.symbols[-1]

  def _find(self, table_key, name):
    """Search the `table_key` map of each scope, innermost first, for `name`.

    Returns the first entry found, or None when no scope defines `name`.
    """
    for scope in reversed(self.symbols):
      entries = scope[table_key]
      if name in entries:
        return entries[name]
    return None

  def lookup(self, name):
    """Return the value bound to `name` in the nearest scope, or None."""
    return self._find('symbols', name)

  def lookup_type(self, name):
    """Return the type recorded for `name` in the nearest scope, or None."""
    return self._find('types', name)

  def __repr__(self):
    # One line per scope, indented two spaces per nesting level.
    return '\n'.join('  ' * depth + str(scope)
                     for depth, scope in enumerate(self.symbols))
85
86
class ProcessType(ast.NodeVisitor):
  """Translate a type-annotation AST into an MLIR type.

  Only annotation nodes are expected. Supported forms:
    * `int`  -> 32-bit integer type.
    * `bool` -> 1-bit integer type.
    * an attribute chain resolving to `core.Tensor` -> unranked tensor of
      32-bit integers.
  Any other name is looked up in `ctx.info.namespace`, so that attribute
  chains such as `core.Tensor` can be followed from the user's namespace.
  """

  def __init__(self, prog, ctx):
    # `prog.ctx` is the MLIR context used to build types; `ctx.info.namespace`
    # is used to resolve names that are not builtin types.
    self.prog = prog
    self.ctx = ctx

  def visit_Attribute(self, node):
    """Resolve `base.attr`, mapping `core.Tensor` to an unranked i32 tensor.

    Returns:
      The MLIR tensor type for `core.Tensor`, otherwise the resolved attribute
      object itself (so longer chains can keep resolving).

    Raises:
      AttributeError: if the base resolves to None or lacks `node.attr`.
    """
    # Supported: core.Tensor
    value = self.visit(node.value)
    if value is None or not hasattr(value, node.attr):
      raise AttributeError(str(type(value)) + ' has no attribute ' + node.attr)
    attr = getattr(value, node.attr)

    # Identity check: `attr` can be any object reachable from user code, and
    # `==` would run its (possibly array-valued or raising) __eq__.
    if attr is core.Tensor:
      return tfp.UnrankedTensorType.get(tfp.IntegerType.get(self.prog.ctx, 32))
    return attr

  def visit_Name(self, node):
    """Map `int`/`bool` to MLIR integer types; otherwise use the namespace.

    Returns:
      An MLIR integer type for `int`/`bool`, the namespace entry for other
      known names, or None for names that are not in the namespace.
    """
    if node.id == 'int':
      return tfp.IntegerType.get(self.prog.ctx, 32)
    if node.id == 'bool':
      return tfp.IntegerType.get(self.prog.ctx, 1)
    if node.id in self.ctx.info.namespace:
      return self.ctx.info.namespace[node.id]
    return None
113
114
class MLIRGen(ast.NodeVisitor):
  """Visit a Python AST and emit the corresponding MLIR program.

  The AST must already carry static-analysis annotations: activity scopes
  (`NodeAnno.BODY_SCOPE` / `NodeAnno.ORELSE_SCOPE`), reaching definitions
  (`anno.Static.DEFINED_VARS_IN`) and liveness (`anno.Static.LIVE_VARS_IN`,
  `anno.Static.LIVE_VARS_OUT`); `mlir_gen_internal` runs these passes.

  Expression visitors return the MLIR value they produce, or a tuple when an
  expression yields several values. Assignments bind values to names in
  `self.symbol_table`. Ops are created through `self.opbuilder`, which points
  at the function body or control-flow region currently being filled. The
  generated program accumulates in `self.prog`.
  """

  def __init__(self, ctx):
    self.ctx = ctx
    self.symbol_table = SymbolTable()
    self.prog = tfp.TFProgram()
    # No builder until visit_FunctionDef opens a function body.
    self.opbuilder = None

  def visit_block(self, block):
    """Visit each statement of `block` in order and return their results."""
    return [self.visit(item) for item in block]

  def process_type(self, node):
    """Translate the annotation AST `node` into an MLIR type via ProcessType."""
    return ProcessType(self.prog, self.ctx).visit(node)

  def visit_Assign(self, node):
    """Bind the value(s) of the right-hand side to the target name(s)."""
    value = self.visit(node.value)
    if isinstance(value, tuple):
      # If it is a tuple of values, assign one to each in targets
      # TODO: This currently is assuming that all elts in targets[0] are Name
      # objects. This might not be always True.
      for key, val in zip(node.targets[0].elts, value):
        self.symbol_table.insert_symbol(key.id, val)
    else:
      self.symbol_table.insert_symbol(node.targets[0].id, value)

  def visit_BinOp(self, node):
    """Lower `a - b` to tf.Sub and `a + b` to tf.AddV2.

    Raises:
      NotImplementedError: for any other binary operator.
    """
    left = self.visit(node.left)
    right = self.visit(node.right)
    opb = self.opbuilder
    if isinstance(node.op, ast.Sub):
      return tfp.Tf_SubOp.create(opb, opb.getUnknownLoc(), left,
                                 right).getResult(0)
    if isinstance(node.op, ast.Add):
      return tfp.Tf_AddV2Op.create(opb, opb.getUnknownLoc(), left,
                                   right).getResult(0)
    # Fail here rather than returning None, which would only surface later as
    # an unrelated error (e.g. `None.getType()` when the result is assigned).
    raise NotImplementedError('BinOp operator not recognized: ' +
                              type(node.op).__name__)

  def visit_BoolOp(self, node):
    """Lower `a or b ...` / `a and b ...` to an OrOp / AndOp over all operands.

    Raises:
      NotImplementedError: for an unrecognized boolean operator.
    """
    values = [self.visit(value) for value in node.values]
    opb = self.opbuilder
    if isinstance(node.op, ast.Or):
      return tfp.OrOp.create(opb, opb.getUnknownLoc(), values).getResult(0)
    if isinstance(node.op, ast.And):
      return tfp.AndOp.create(opb, opb.getUnknownLoc(), values).getResult(0)
    raise NotImplementedError('BoolOp operator not recognized: ' +
                              type(node.op).__name__)

  def visit_Call(self, node):
    """Lower a call to a previously defined function into tf.LegacyCall.

    Returns:
      The single result value, or a tuple of all results otherwise.
    """
    func = self.visit(node.func)
    args = [self.visit(arg) for arg in node.args]
    callop = tfp.Tf_LegacyCallOp.create(self.opbuilder,
                                        self.opbuilder.getUnknownLoc(),
                                        func.getType().getResults(), args,
                                        func.getName())
    if callop.getNumResults() == 1:
      # Fetch the result with getResult, as every other visitor does, rather
      # than indexing the op wrapper.
      return callop.getResult(0)
    return tuple(callop.getResult(idx) for idx in range(callop.getNumResults()))

  def _create_compare(self, op, left, right):
    """Create the tf comparison op for AST operator `op`; return its result.

    Raises:
      NotImplementedError: for an unsupported comparison operator.
    """
    opb = self.opbuilder
    if isinstance(op, ast.Eq):
      op_class = tfp.Tf_EqualOp
    elif isinstance(op, ast.Lt):
      op_class = tfp.Tf_LessOp
    elif isinstance(op, ast.LtE):
      op_class = tfp.Tf_LessEqualOp
    elif isinstance(op, ast.Gt):
      op_class = tfp.Tf_GreaterOp
    elif isinstance(op, ast.GtE):
      op_class = tfp.Tf_GreaterEqualOp
    elif isinstance(op, ast.NotEq):
      op_class = tfp.Tf_NotEqualOp
    else:
      raise NotImplementedError('CompareOp operator not recognized')
    return op_class.create(opb, opb.getUnknownLoc(), left, right).getResult(0)

  def visit_Compare(self, node):
    """Lower a (possibly chained) comparison.

    Python evaluates `a < b < c` as `a < b and b < c`, so each operator
    compares adjacent operands and the partial results are combined with an
    AndOp. A single comparison returns its result directly.

    Raises:
      NotImplementedError: for an unsupported comparison operator.
    """
    left = self.visit(node.left)
    results = []
    for op, comparator in zip(node.ops, node.comparators):
      right = self.visit(comparator)
      results.append(self._create_compare(op, left, right))
      # The next comparison starts from this operand, not from the boolean
      # result of the comparison just emitted.
      left = right
    if len(results) == 1:
      return results[0]
    return tfp.AndOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
                            results).getResult(0)

  def visit_Constant(self, node):
    """Lower an int or bool literal to tf.Const; other constants yield None.

    `bool` is checked before `int` because it is an `int` subclass; it gets
    the same 1-bit integer type ProcessType assigns to `bool` annotations.
    Other constants (e.g. a docstring expression statement) return None.
    """
    opb = self.opbuilder
    if isinstance(node.value, bool):
      width = 1
    elif isinstance(node.value, int):
      width = 32
    else:
      return None
    return tfp.Tf_ConstOp.create(
        opb, opb.getUnknownLoc(),
        tfp.IntegerAttr.get(
            tfp.IntegerType.get(self.prog.ctx, width),
            int(node.value))).getResult(0)

  def visit_FunctionDef(self, node):
    """Emit an MLIR function for `node` and lower its body into it.

    Argument and return types come from the annotations via process_type.
    The Python name is bound to the new function in the enclosing scope; the
    body is lowered in a fresh scope where each argument name is bound to the
    matching function argument.
    """
    # Cache the current builder; it is restored after the body is emitted.
    cache_builder = self.opbuilder
    inputs, outputs = [], []

    for arg in node.args.args:
      inputs.append(self.process_type(arg.annotation))

    if node.returns:
      outputs = [self.process_type(node.returns)]

    # The MLIR symbol name comes from ctx.namer; callers find the function
    # through the symbol table under its Python name.
    currfunc = self.prog.add_function(
        self.ctx.namer.new_symbol(node.name, []),
        self.prog.get_function_type(inputs, outputs))

    # Add the function to symbol table and enter new scope
    self.symbol_table.insert_symbol(node.name, currfunc)
    self.symbol_table.enter_scope()

    # Add arguments to symbol table
    for arg, value in zip(node.args.args, currfunc.getArguments()):
      self.symbol_table.insert_symbol(arg.id, value)
    self.opbuilder = tfp.OpBuilder(currfunc.getBody())

    self.visit_block(node.body)
    self.symbol_table.exit_scope()
    self.opbuilder = cache_builder

  def visit_If(self, node):
    """Lower an `if` statement to an IfOp.

    Every variable modified in either branch becomes a result of the IfOp.
    Each branch region ends with a ReturnOp yielding those variables' current
    values, and afterwards the names are rebound to the IfOp's results.

    Returns:
      The single IfOp result, or a tuple of all results otherwise.
    """
    cond = self.visit(node.test)

    # Create ifop
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    # NOTE(review): a name first assigned in only one branch has no value in
    # the other branch (lookup returns None); that case is not handled here.
    modified_in_cond = list(body_scope.modified | orelse_scope.modified)
    outputs = [
        self.symbol_table.lookup_type(str(var)) for var in modified_in_cond
    ]
    ifop = tfp.IfOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), cond,
                           outputs)

    # Cache the builder
    cache_builder = self.opbuilder

    # Visit body
    self.opbuilder = tfp.OpBuilder(ifop.getRegion(0))
    # Enter scope to avoid values generated inside the region to come in symbol
    # table
    self.symbol_table.enter_scope()
    for stmt in node.body:
      self.visit(stmt)
    retvals = [
        self.symbol_table.lookup(str(varname)) for varname in modified_in_cond
    ]
    tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), retvals)
    self.symbol_table.exit_scope()

    # Visit orelse
    self.opbuilder = tfp.OpBuilder(ifop.getRegion(1))
    self.symbol_table.enter_scope()
    for stmt in node.orelse:
      self.visit(stmt)
    retvals = [
        self.symbol_table.lookup(str(varname)) for varname in modified_in_cond
    ]
    tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(), retvals)
    self.symbol_table.exit_scope()

    # Reset builder and enter return values in symbol table
    self.opbuilder = cache_builder
    for idx, var in enumerate(modified_in_cond):
      self.symbol_table.insert_symbol(str(var), ifop.getResult(idx))

    if ifop.getNumResults() == 1:
      return ifop.getResult(0)

    return tuple(ifop.getResult(i) for i in range(ifop.getNumResults()))

  def visit_Name(self, node):
    """Return the value bound to `node.id` in the nearest enclosing scope.

    Raises:
      NotImplementedError: if `node.id` is not bound in any scope.
    """
    value = self.symbol_table.lookup(node.id)
    # Compare with None: a bound value may define its own truthiness.
    if value is None:
      raise NotImplementedError('Symbol not found: ' + node.id)
    return value

  def visit_Return(self, node):
    """Emit a ReturnOp for `return`, `return x` or `return x, y, ...`."""
    opb = self.opbuilder
    if node.value is None:
      # Bare `return`: there is no expression to visit and nothing to yield.
      return tfp.ReturnOp.create(opb, opb.getUnknownLoc(), [])
    value = self.visit(node.value)
    if isinstance(value, tuple):
      # For more than one return values
      return tfp.ReturnOp.create(opb, opb.getUnknownLoc(), list(value))
    return tfp.ReturnOp.create(opb, opb.getUnknownLoc(), [value])

  def visit_Tuple(self, node):
    """Return a tuple of the values of the tuple's elements."""
    return tuple(self.visit(elt) for elt in node.elts)

  def visit_UnaryOp(self, node):
    """Lower unary minus to tf.Neg.

    Raises:
      NotImplementedError: for any other unary operator (e.g. `not`, `~`).
    """
    operand = self.visit(node.operand)
    if isinstance(node.op, ast.USub):
      return tfp.Tf_NegOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
                                 operand).getResult(0)
    raise NotImplementedError('UnaryOp operator not recognized: ' +
                              type(node.op).__name__)

  def _get_basic_loop_vars(self, modified, live_in, live_out):
    """Return the simple (non-composite) loop variables as a frozenset.

    A modified simple symbol is a loop variable only if it is live into or out
    of the loop; otherwise it is treated as local to the loop.
    """
    # [This is directly from
    # tensorflow/python/autograph/converters/control_flow.py]
    # The loop variables corresponding to simple symbols (e.g. `x`).
    basic_loop_vars = []
    for s in modified:
      if s.is_composite():
        # TODO: Raise an error when this happens for a TF loop.
        continue
      # Variables not live into or out of the loop are considered local to the
      # loop.
      if s not in live_in and s not in live_out:
        continue
      basic_loop_vars.append(s)
    return frozenset(basic_loop_vars)

  def _get_composite_loop_vars(self, modified, live_in):
    """Return the composite loop variables (e.g. `self.x`) as a frozenset."""
    # [This is directly from
    # tensorflow/python/autograph/converters/control_flow.py]
    # The loop variables corresponding to composite symbols (e.g. `self.x`).
    composite_loop_vars = []
    for s in modified:
      if not s.is_composite():
        continue
      # Mutations made to objects created inside the loop will appear as writes
      # to composite symbols. Because these mutations appear as modifications
      # made to composite symbols, we check whether the composite's parent is
      # actually live into the loop.
      # Example:
      #   while cond:
      #     x = Foo()
      #     x.foo = 2 * x.foo  # x.foo is live into the loop, but x is not.
      #
      # Note that some parents might not be symbols - for example, in x['foo'],
      # 'foo' is a parent, but it's a literal, not a symbol. We don't check the
      # liveness of literals.
      support_set_symbols = tuple(
          sss for sss in s.support_set if sss.is_symbol())
      if not all(sss in live_in for sss in support_set_symbols):
        continue
      composite_loop_vars.append(s)
    return frozenset(composite_loop_vars)

  def _get_loop_vars(self, node, modified):
    """Classify the loop variables of loop `node`.

    Returns:
      A tuple `(loop_vars, reserved_symbols, undefined_lives)`: all loop
      variables, the symbols referenced in the loop body, and the simple loop
      variables that are not defined on entry to the loop.
    """
    # [This is directly from python/autograph/converters/control_flow.py]
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
    reserved_symbols = body_scope.referenced

    basic_loop_vars = self._get_basic_loop_vars(modified, live_in, live_out)
    composite_loop_vars = self._get_composite_loop_vars(modified, live_in)
    loop_vars = tuple(basic_loop_vars | composite_loop_vars)

    # Variable that are used or defined inside the loop, but not defined
    # before entering the loop. Only simple variables must be defined. The
    # composite ones will be implicitly checked at runtime.
    undefined_lives = basic_loop_vars - defined_in

    return loop_vars, reserved_symbols, undefined_lives

  def visit_While(self, node):
    """Lower a `while` loop to a WhileOp.

    Loop variables (see _get_loop_vars) are threaded through the op. Their
    current values are the op's inputs. The condition region (0) and the body
    region (1) each receive them as block arguments. After the loop the names
    are rebound to the op's results.
    """
    # Create a new WhileOp
    # `inputs` are initial values for loop variables
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    loop_vars, _, _ = self._get_loop_vars(node, body_scope.modified)
    inputs = [self.symbol_table.lookup(str(name)) for name in loop_vars]
    types = [input_.getType() for input_ in inputs]
    while_op = tfp.WhileOp.create(self.opbuilder,
                                  self.opbuilder.getUnknownLoc(), inputs, types)

    # cache the current builder
    cache_builder = self.opbuilder

    # Process cond
    self.symbol_table.enter_scope()
    for input_, type_ in zip(loop_vars, types):
      self.symbol_table.insert_symbol(
          str(input_),
          while_op.getRegion(0).front().addArgument(type_))
    self.opbuilder = tfp.OpBuilder(while_op.getRegion(0))
    tfp.ReturnOp.create(self.opbuilder, self.opbuilder.getUnknownLoc(),
                        [self.visit(node.test)])
    self.symbol_table.exit_scope()

    # Process body
    self.symbol_table.enter_scope()
    for input_, type_ in zip(loop_vars, types):
      self.symbol_table.insert_symbol(
          str(input_),
          while_op.getRegion(1).front().addArgument(type_))
    self.opbuilder = tfp.OpBuilder(while_op.getRegion(1))
    self.visit_block(node.body)
    tfp.ReturnOp.create(
        self.opbuilder, self.opbuilder.getUnknownLoc(),
        [self.symbol_table.lookup(str(name)) for name in loop_vars])
    self.symbol_table.exit_scope()

    # Enter new values as symbols
    for idx, var in enumerate(loop_vars):
      self.symbol_table.insert_symbol(str(var), while_op.getResult(idx))

    # Restore builder
    self.opbuilder = cache_builder
415
416
def mlir_gen_internal(node, entity_info):
  """Annotate the raw AST `node` and lower it to a TFProgram.

  The static analyses MLIRGen relies on are applied first: qualified names,
  activity, then reaching definitions, reaching function definitions and
  liveness, the latter three over the control-flow graphs built from `node`.
  """
  name_generator = naming.Namer({})
  cfg_graphs = cfg.build(node)
  context = transformer.Context(entity_info, name_generator, None)
  annotated = qual_names.resolve(node)
  annotated = activity.resolve(annotated, context)
  # These passes share the same (node, ctx, graphs) signature and must run in
  # this order.
  for analysis in (reaching_definitions, reaching_fndefs, liveness):
    annotated = analysis.resolve(annotated, context, cfg_graphs)
  generator = MLIRGen(context)
  generator.visit(annotated)
  return generator.prog
430
431
def mlir_gen(func):
  """Lower the Python function `func` to a TFProgram.

  The source of `func` is parsed with autograph's parser, and the function's
  namespace is recorded on the EntityInfo passed to `mlir_gen_internal`.
  """
  tree, parsed_source = parser.parse_entity(func, future_features=())
  func_namespace = inspect_utils.getnamespace(func)
  info = transformer.EntityInfo(
      namespace=func_namespace,
      future_features=(),
      source_file=None,
      source_code=parsed_source,
      name=func.__name__)
  return mlir_gen_internal(tree, info)
442
443
def mlir_gen_from_source(source=None, src_file=None):
  """Lower Python source code, given as a string or a file path, to a TFProgram.

  Args:
    source: Python source code as a string. Takes precedence over `src_file`.
    src_file: Path of a Python source file; read only when `source` is None.

  Returns:
    The TFProgram produced by `mlir_gen_internal` for the parsed module.

  Raises:
    ValueError: If neither `source` nor `src_file` is provided.
  """
  if source is None:
    if src_file is None:
      raise ValueError('Either source or src_file must be provided.')
    import tokenize  # pylint: disable=g-import-not-at-top
    # tokenize.open decodes using the file's PEP 263 coding cookie (UTF-8 by
    # default) instead of the locale encoding, and the `with` closes the file.
    with tokenize.open(src_file) as f:
      source = f.read()
  node = ast.parse(source)
  entity_info = transformer.EntityInfo(
      name='mlir_module',
      source_code=source,
      source_file=None,
      future_features=(),
      namespace={})
  return mlir_gen_internal(node, entity_info)
457