1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helper utilities for AOT compilation."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import os
24import pipes
25import re
26import shlex
27
28import six
29
30from tensorflow.core.protobuf import config_pb2
31from tensorflow.core.protobuf import meta_graph_pb2
32from tensorflow.python.client import session
33from tensorflow.python.framework import graph_util
34from tensorflow.python.framework import ops as ops_lib
35from tensorflow.python.framework import tensor_shape
36from tensorflow.python.framework import versions
37from tensorflow.python.grappler import tf_optimizer
38from tensorflow.python.lib.io import file_io
39from tensorflow.python.ops import array_ops
40from tensorflow.python.platform import sysconfig as sysconfig_lib
41from tensorflow.python.platform import test
42from tensorflow.python.platform import tf_logging as logging
43from tensorflow.python.training import saver as saver_lib
44
# Import the tfcompile wrapper lazily: it only exists when TensorFlow was
# built with XLA support.  Instead of failing at module-import time, record
# the error and re-raise it later (see `aot_compile_cpu_meta_graph_def`), so
# importing this module never requires an XLA-enabled build.
try:
  from tensorflow.python import _pywrap_tfcompile  # pylint: disable=g-import-not-at-top
except ImportError as e:
  _pywrap_tfcompile_import_error = ImportError(
      'Unable to import _pywrap_tfcompile; you must build TensorFlow '
      'with XLA.  You may need to build tensorflow with flag '
      '--define=with_xla_support=true.  Original error: {}'.format(str(e)))
else:
  _pywrap_tfcompile_import_error = None
54
55
# Resource-variable consumer ops that only *read* the variable.  Used by
# `_get_variable_nodes_from_graph_def`: a variable whose downstream consumers
# are all in this set is considered unmodified; any other consumer op is
# treated as a potential write.
_READ_ONLY_VARIABLE_OPS = (
    'ReadVariableOp',
    'IsVariableInitializedOp',
    'ResourceGather',
    'ResourceGatherNd',
    'VariableShape',
)

# Ops that pass a resource handle through unchanged; the traversal in
# `_get_variable_nodes_from_graph_def` continues through these to their
# downstream consumers.
_PASS_THROUGH_VARIABLE_OPS = ('Identity', 'IdentityN')
65
66
def _shlex_quote(s):
  """Return `s` quoted for safe interpolation into a shell command line."""
  # Python 2 has no `shlex.quote`; `pipes.quote` is its equivalent there.
  if not six.PY2:
    return shlex.quote(s)
  return pipes.quote(s)
72
73
def _sysconfig_module():
  """Return `tf.sysconfig` when usable (i.e., inside a pip package), else None.

  `get_include()` raises ImportError when TensorFlow is running from a bazel
  source tree rather than an installed package; in that case the module
  cannot supply include paths and we signal that with `None`.
  """
  try:
    sysconfig_lib.get_include()
  except ImportError:
    return None
  else:
    return sysconfig_lib
81
82
83def _parse_tensor_name(name):
84  """Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0)."""
85  if ':' in name and not name.endswith(':'):
86    node_name = name[:name.rfind(':')]
87    output_slot = int(name[name.rfind(':') + 1:])
88    return node_name, output_slot
89  else:
90    return name, None
91
92
# Makefile fragment written next to the AOT artifacts; each placeholder is
# filled in by `_xla_makefile_string` (include path, object-file dir, and the
# CXX11 ABI flag the objects were built with).
_XLA_MAKEFILE_TEMPLATE = """
INC = -I{tensorflow_includes}
LIB = -L{compiled_dir}
CXXFLAGS = {cxx_flags}
"""
98
99
def _xla_makefile_string(output_prefix):
  """Returns a Makefile string with variables for using XLA binary object files.

  Attempts to identify the right include header paths when run from either
  an installed TensorFlow pip package, or from bazel run.

  Args:
    output_prefix: A string containing the output prefix for the XLA AOT
      compiled header + object files.

  Returns:
    A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`.
  """
  sysconfig = _sysconfig_module()
  output_dir, _ = os.path.split(output_prefix)
  if sysconfig:
    # Installed pip package: sysconfig knows the include directory directly.
    tensorflow_includes = _shlex_quote(sysconfig.get_include())
  else:
    # Try hard to find the real source directory if this is a local bazel run.
    if os.path.islink(__file__):
      this_file = __file__
      # Follow symlink chains (e.g. bazel runfiles trees) to the real file.
      while os.path.islink(this_file):
        this_file = os.readlink(this_file)
      # Walk three directories up from this file's location; presumably that
      # lands at the repository root -- TODO confirm the depth is right for
      # every install layout.
      base = os.path.realpath(
          os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
    else:
      try:
        base = test.test_src_dir_path('')
      except KeyError:  # Can't find TEST_SRCDIR in environment path.
        base = os.path.realpath(
            os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3)))
    # Sanity check: the tf2xla runtime header must exist under `base` for the
    # generated Makefile to be usable.
    expected_header = os.path.join(
        base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
    if not os.path.exists(expected_header):
      logging.error(
          'Could not find includes path.  Missing file: {}'
          .format(expected_header))
    tensorflow_includes = base

  return _XLA_MAKEFILE_TEMPLATE.format(
      tensorflow_includes=tensorflow_includes,
      compiled_dir=_shlex_quote(output_dir),
      cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format(
          versions.CXX11_ABI_FLAG))
144
145
def _get_variable_nodes_from_graph_def(graph_def):
  """Find the Variable nodes in `graph_def` and whether each is modified.

  Args:
    graph_def: An instance of `GraphDef`.  This GraphDef *must*
      have already been optimized by Grappler.  In particular, function
      inlining must have already happened.

  Returns:
    A dict mapping string names of variables to tuples `(node_def, modified)`,
    where `node_def` is the `NodeDef` corresponding to variable, and `modified`
    is a python bool describing whether the variable is modified during runtime.
  """
  handle_nodes = {n.name: n for n in graph_def.node if n.op == 'VarHandleOp'}

  # Map each node name to the nodes consuming its (non-control) outputs.
  consumers = collections.defaultdict(list)
  for node in graph_def.node:
    for input_name in node.input:
      # Control edges ('^name') never carry the resource handle itself.
      if not input_name.startswith('^'):
        consumers[input_name].append(node)

  results = {}
  for var_name, var_node in handle_nodes.items():
    # Depth-first walk downstream of the handle.  Pass-through ops forward
    # the handle; read-only ops are harmless; anything else counts as a
    # potential write and ends the search early.
    pending = list(consumers[var_name])
    seen = set()
    modified = False
    while pending and not modified:
      node = pending.pop()
      if node.name in seen:
        continue
      seen.add(node.name)
      if node.op in _PASS_THROUGH_VARIABLE_OPS:
        pending.extend(consumers.get(node.name, []))
      elif node.op not in _READ_ONLY_VARIABLE_OPS:
        modified = True
    results[var_name] = (var_node, modified)

  return results
185
186
def _prune_removed_feed_nodes(signature_def, graph_def):
  """Identify the inputs in the signature no longer in graph_def, prune them.

  Args:
    signature_def: A `SignatureDef` instance.
    graph_def: A `GraphDef` instance.

  Returns:
    A new pruned `SignatureDef`.
  """
  node_names = {n.name for n in graph_def.node}
  new_signature_def = meta_graph_pb2.SignatureDef()
  new_signature_def.CopyFrom(signature_def)
  for (k, v) in signature_def.inputs.items():
    tensor_name, _ = _parse_tensor_name(v.name)
    if tensor_name not in node_names:
      # Freezing/optimization may remove an input node entirely (e.g. a
      # placeholder folded into a constant); drop it from the compiled
      # signature so tfcompile does not expect a feed for it.
      # `warning` replaces the deprecated `warn` alias; lazy %-args avoid
      # formatting when the message is filtered out.
      logging.warning(
          'Signature input key \'%s\', tensor name \'%s\', has been pruned '
          'while freezing the graph.  Removing it from the compiled '
          'signatures.', k, tensor_name)
      del new_signature_def.inputs[k]
  return new_signature_def
209
210
def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   target_cpu,
                                   variables_to_feed=(),
                                   multithreading=False):
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files.  Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  The graph is always optimized with grappler before compilation.  Variables
  *not* listed in `variables_to_feed` are embedded as constants into the
  graph and binary objects; variables that are listed instead become inputs
  and outputs of the compiled class, and the C++ caller must set their
  values manually.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.  Note: its `graph_def` is
      modified in place (input placeholders are replaced with constants).
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    target_cpu: String, LLVM target cpu name.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen.  If `None`, then we will extract all the
      variables in the graph and mark them as to-feed.  The default behavior is
      an empty tuple: all variables must be frozen.
    multithreading: Whether to enable multithreading in the compiled
      computation.  Note that if using this option, the resulting object files
      may have external dependencies on multithreading libraries like nsync.

  Raises:
    ImportError: If tensorflow was not built with XLA, or there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
  if _pywrap_tfcompile_import_error:
    # Re-raise the import failure recorded at module load time.
    raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

  else:
    # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
    # so that we can set these directly instead of relying on env vars.
    xla_flags = os.environ.get('XLA_FLAGS')
    if not xla_flags:
      xla_flags = '--xla_cpu_multi_thread_eigen={}'.format(
          'true' if multithreading else 'false')
    else:
      # NOTE(review): XLA_FLAGS entries are typically whitespace-separated;
      # confirm the ',' join here is parsed as intended by the flag parser.
      xla_flags += ',--xla_cpu_multi_thread_eigen={}'.format(
          'true' if multithreading else 'false')
    os.environ['XLA_FLAGS'] = xla_flags

  signature_def_map = meta_graph_def.signature_def
  if signature_def_key not in signature_def_map:
    raise ValueError(
        'Unable to find signature_def key \'{}\' in signature def map.  '
        'Available keys: {}'.format(
            signature_def_key,
            list(signature_def_map.keys())))
  signature_def = signature_def_map[signature_def_key]
  if not signature_def.outputs:
    raise ValueError(
        'Signature key {} must have outputs, but saw none:\n{}'.format(
            signature_def_key, str(signature_def)))

  temp_dir = test.get_temp_dir()
  file_io.recursive_create_dir(temp_dir)
  # At INFO verbosity, dump intermediate graphs to temp files for debugging.
  if logging.get_verbosity() >= logging.INFO:
    original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # This updates graph_def in place.
  _replace_input_placeholders_with_default_values(
      meta_graph_def.graph_def, signature_def)

  graph_def = _optimize_graph(meta_graph_def, signature_def)

  all_variables = _get_variable_nodes_from_graph_def(graph_def)
  if variables_to_feed is None:
    # Feed every variable found in the optimized graph.
    variable_nodes_to_feed = list(all_variables.values())
  else:
    not_in_graph = set(variables_to_feed).difference(list(all_variables))
    if not_in_graph:
      raise ValueError(
          'Asked to feed variables that were not found in graph: {}.  '
          'Variables contained in the graph: {}'.format(
              not_in_graph, list(all_variables)))
    variable_nodes_to_feed = [
        all_variables[name] for name in variables_to_feed
    ]

  if logging.get_verbosity() >= logging.INFO:
    prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(graph_def.SerializeToString())

  # Load the Variables so that we can freeze the graph.
  with session.Session(graph=ops_lib.Graph()) as sess:
    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
    if restorer is not None:
      restorer.restore(sess, checkpoint_path)
    # Convert every variable not in the blacklist into a graph constant;
    # blacklisted (to-feed) variables stay as nodes and become parameters of
    # the compiled class.
    graph_def.CopyFrom(
        graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            output_node_names=[
                _parse_tensor_name(n.name)[0]
                for n in signature_def.outputs.values()
            ],
            variable_names_blacklist=[
                n.name for n, _ in variable_nodes_to_feed
            ],
        ))

  # Freezing may have removed signature input nodes; prune them from the
  # signature so tfcompile does not expect feeds for them.
  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

  frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
  config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
    graph_writer.write(graph_def.SerializeToString())
  config = _signature_to_tf2xla_config(
      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
    config_writer.write(str(config))

  output_dir = os.path.dirname(output_prefix)
  file_io.recursive_create_dir(output_dir)

  # The entry point must be a valid, unique C symbol name, so squash every
  # non-alphanumeric run in the prefix+class into underscores.
  entry_point = re.sub(
      '[^0-9a-zA-Z]+', '_',
      '__xla_' + output_prefix + '__' + cpp_class)

  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
    makefile_writer.write(_xla_makefile_string(output_prefix))

  output_prefix = _shlex_quote(output_prefix)

  _pywrap_tfcompile.Compile(
      graph=frozen_graph_def_location,
      config=config_pbtxt_location,
      cpp_class=cpp_class,
      target_triple=target_triple,
      target_cpu=target_cpu,
      entry_point=entry_point,
      out_function_object='{}.o'.format(output_prefix),
      out_header='{}.h'.format(output_prefix),
      out_metadata_object='{}_metadata.o'.format(output_prefix),
      gen_name_to_index=True,
      # ProgramShape isn't uniquefied by entry_point.
      gen_program_shape=False)
377
378
def _optimize_graph(meta_graph_def, signature_def):
  """Optimize `meta_graph_def` using grappler.  Returns a `GraphDef`."""
  # Grappler determines which nodes to keep from the 'train_op' collection,
  # so record every signature input/output tensor there on a deep copy.
  optimizable = copy.deepcopy(meta_graph_def)
  fetches = meta_graph_pb2.CollectionDef()
  tensor_infos = list(signature_def.inputs.values()) + list(
      signature_def.outputs.values())
  for tensor_info in tensor_infos:
    fetches.node_list.value.append(tensor_info.name)
  optimizable.collection_def['train_op'].CopyFrom(fetches)

  config = config_pb2.ConfigProto()
  # A value of -1 tells grappler not to skip graphs it considers too small.
  config.graph_options.rewrite_options.min_graph_nodes = -1
  return tf_optimizer.OptimizeGraph(config, optimizable)
396
397
def _replace_input_placeholders_with_default_values(graph_def, signature_def):
  """Replace graphdef's `tf.placeholder` input ops with all-zero constants.

  Mutates `graph_def` in place: each signature input that resolves to a
  `Placeholder`/`PlaceholderV2` node is rewritten into a `zeros` constant of
  the fully-defined shape recorded in the signature.  Any auxiliary nodes
  created by `zeros` are appended to the graph as well.

  Args:
    graph_def: A `GraphDef` instance; modified in place.
    signature_def: A `SignatureDef` whose inputs name the placeholder tensors.

  Raises:
    RuntimeError: If a signature input tensor cannot be found in `graph_def`.
    ValueError: If a signature input's shape is not fully defined.
  """
  name_to_node_map = dict((n.name, n) for n in graph_def.node)
  processed_nodes = set([])
  for name, input_ in signature_def.inputs.items():
    tensor_name, _ = _parse_tensor_name(input_.name)
    # Multiple signature keys may alias the same node; rewrite it only once.
    if tensor_name in processed_nodes:
      continue
    processed_nodes.add(tensor_name)
    if tensor_name not in name_to_node_map:
      raise RuntimeError(
          'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
          'Graph nodes are: {}'.format(tensor_name,
                                       list(name_to_node_map.keys())))
    node = name_to_node_map[tensor_name]
    if node.op not in ('Placeholder', 'PlaceholderV2'):
      # Best-effort: leave non-placeholder inputs untouched rather than fail.
      logging.info(
          'Tried to convert SavedModel input node \'{}\' from a placeholder, '
          'but it doesn\'t look like a placeholder: {}'.format(tensor_name,
                                                               node))
      continue
    shape = tensor_shape.TensorShape(input_.tensor_shape)
    # A fully-defined shape is required to materialize a concrete constant.
    if not shape.is_fully_defined():
      raise ValueError(
          'Expected fully defined input shape for signature_def \'{}\', '
          'tensor name: \'{}\'; but shape is: {}.'
          .format(name, tensor_name, shape))
    # Build the replacement constant in a scratch graph so its NodeDef can be
    # copied over the placeholder's NodeDef in place.
    temp_graph = ops_lib.Graph()
    with temp_graph.as_default():
      const = array_ops.zeros(
          shape, dtype=input_.dtype, name=tensor_name)
    node.CopyFrom(const.op.node_def)
    # Sometimes zeros() also creates additional nodes
    for op in temp_graph.get_operations():
      if op.name == const.op.name:  # We just inserted this one.
        continue
      graph_def.node.append(op.node_def)
      name_to_node_map[op.node_def.name] = op.node_def
436
437
def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
  """Convert `signature_def` to tf2xla config.  Returns a `tf2xla.Config` proto.

  Args:
    signature_def: Instance of `SignatureDef`.
    variable_nodes_to_feed: List of tuples of form `(node_def, modified)`
      corresponding to VarHandleOp, and a boolean `modified` that describes
      whether the variable was modified during execution.

  Returns:
    An instance of `tf2xla.Config` proto.

  Raises:
    RuntimeError: If TensorFlow was not compiled with XLA.
  """
  from tensorflow.compiler.tf2xla import tf2xla_pb2  # pylint: disable=g-import-not-at-top

  config = tf2xla_pb2.Config()
  tensor_id = tf2xla_pb2.TensorId

  # Feed/fetch/param names use '_' in place of '/' so they are usable as
  # generated identifiers.
  for key, input_ in signature_def.inputs.items():
    node_name, output_index = _parse_tensor_name(input_.name)
    config.feed.append(
        tf2xla_pb2.Feed(
            id=tensor_id(
                node_name=node_name, output_index=int(output_index)),
            name='feed_{}'.format(key.replace('/', '_')),
            type=input_.dtype,
            shape=input_.tensor_shape))

  for key, output_ in signature_def.outputs.items():
    node_name, output_index = _parse_tensor_name(output_.name)
    config.fetch.append(
        tf2xla_pb2.Fetch(
            id=tensor_id(
                node_name=node_name, output_index=int(output_index)),
            name='fetch_{}'.format(key.replace('/', '_')),
            type=output_.dtype,
            shape=output_.tensor_shape))

  for node, modified in variable_nodes_to_feed:
    config.variable.append(
        tf2xla_pb2.Variable(
            node_name=node.name,
            name='param_{}'.format(node.name.replace('/', '_')),
            type=node.attr['dtype'].type,
            shape=node.attr['shape'].shape,
            readonly=not modified))

  return config
492