# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper utilities for AOT compilation."""
16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import os
import pipes
import re
import shlex

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import sysconfig as sysconfig_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_lib
44
# The tfcompile wrapper is only available when TensorFlow was built with XLA
# support.  Defer the ImportError instead of raising it at module import time,
# so that importing this module always succeeds; the stored error is re-raised
# later, inside `aot_compile_cpu_meta_graph_def`, if compilation is attempted.
try:
  from tensorflow.python import _pywrap_tfcompile  # pylint: disable=g-import-not-at-top
except ImportError as e:
  _pywrap_tfcompile_import_error = ImportError(
      'Unable to import _pywrap_tfcompile; you must build TensorFlow '
      'with XLA.  You may need to build tensorflow with flag '
      '--define=with_xla_support=true.  Original error: {}'.format(str(e)))
else:
  _pywrap_tfcompile_import_error = None
54
55
# Resource-variable ops that only ever *read* a variable's value.  A variable
# whose handle reaches only these ops (possibly through the pass-through ops
# below) is classified as unmodified in `_get_variable_nodes_from_graph_def`.
_READ_ONLY_VARIABLE_OPS = (
    'ReadVariableOp',
    'IsVariableInitializedOp',
    'ResourceGather',
    'ResourceGatherNd',
    'VariableShape',
)

# Ops that forward a resource handle unchanged; the variable-usage traversal
# continues through their consumers.
_PASS_THROUGH_VARIABLE_OPS = ('Identity', 'IdentityN')
65
66
67def _shlex_quote(s):
68  if six.PY2:
69    return pipes.quote(s)
70  else:
71    return shlex.quote(s)
72
73
def _sysconfig_module():
  """Return tf.sysconfig when usable (i.e. inside a pip package), else None."""
  # Probe get_include(); a failure means we are not running from an installed
  # pip package.  ValueError may come from saved_model_cli_test trying to
  # enable eager mode twice.
  try:
    sysconfig_lib.get_include()
  except (ImportError, ValueError):
    return None
  return sysconfig_lib
83
84
85def _parse_tensor_name(name):
86  """Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0)."""
87  if ':' in name and not name.endswith(':'):
88    node_name = name[:name.rfind(':')]
89    output_slot = int(name[name.rfind(':') + 1:])
90    return node_name, output_slot
91  else:
92    return name, None
93
94
# Template for the generated `<prefix>_makefile.inc`: the header search path,
# the library search path (where the compiled .o files live), and the C++
# flags needed to match TensorFlow's C++ ABI.
_XLA_MAKEFILE_TEMPLATE = """
INC = -I{tensorflow_includes}
LIB = -L{compiled_dir}
CXXFLAGS = {cxx_flags}
"""
100
101
def _xla_makefile_string(output_prefix):
  """Returns a Makefile string with variables for using XLA binary object files.

  Attempts to identify the right include header paths when run from either
  an installed TensorFlow pip package, or from bazel run.

  Args:
    output_prefix: A string containing the output prefix for the XLA AOT
      compiled header + object files.

  Returns:
    A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`.
  """
  sysconfig = _sysconfig_module()
  output_dir, _ = os.path.split(output_prefix)
  if sysconfig:
    # Installed pip package: sysconfig knows the include directory.
    tensorflow_includes = _shlex_quote(sysconfig.get_include())
  else:
    # Try hard to find the real source directory if this is a local bazel run.
    if os.path.islink(__file__):
      # Follow the symlink chain to the real file, then walk three directory
      # levels up to reach the repository root.
      this_file = __file__
      while os.path.islink(this_file):
        this_file = os.readlink(this_file)
      base = os.path.realpath(
          os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
    else:
      try:
        base = test.test_src_dir_path('')
      except KeyError:  # Can't find TEST_SRCDIR in environment path.
        base = os.path.realpath(
            os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3)))
    # Sanity-check the candidate root by looking for a known XLA header;
    # log (but do not fail) if it is missing.
    expected_header = os.path.join(
        base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
    if not os.path.exists(expected_header):
      logging.error(
          'Could not find includes path.  Missing file: {}'
          .format(expected_header))
    tensorflow_includes = base

  return _XLA_MAKEFILE_TEMPLATE.format(
      tensorflow_includes=tensorflow_includes,
      compiled_dir=_shlex_quote(output_dir),
      cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format(
          versions.CXX11_ABI_FLAG))
146
147
def _get_variable_nodes_from_graph_def(graph_def):
  """Get the list of Variable nodes from `graph_def`.

  Args:
    graph_def: An instance of `GraphDef`.  This GraphDef *must*
      have already been optimized by Grappler.  In particular, function
      inlining must have already happened.

  Returns:
    A dict mapping string names of variables to tuples `(node_def, modified)`,
    where `node_def` is the `NodeDef` corresponding to variable, and `modified`
    is a python bool describing whether the variable is modified during runtime.
  """
  var_handle_map = {n.name: n for n in graph_def.node if n.op == 'VarHandleOp'}

  # Map each (non-control) input name to the nodes that consume it.
  consumers = collections.defaultdict(list)
  for node in graph_def.node:
    for input_name in node.input:
      if not input_name.startswith('^'):
        consumers[input_name].append(node)

  result = {}
  for var_name, var_node in var_handle_map.items():
    # Walk the consumers reachable from the variable handle; the variable
    # counts as modified as soon as any reachable op is neither read-only
    # nor a pass-through.
    modified = False
    pending = collections.deque(consumers[var_name])
    seen = set()
    while pending and not modified:
      node = pending.pop()
      if node.name in seen:
        continue
      seen.add(node.name)
      if node.op in _PASS_THROUGH_VARIABLE_OPS:
        pending.extend(consumers.get(node.name, []))
      elif node.op not in _READ_ONLY_VARIABLE_OPS:
        modified = True
    result[var_name] = (var_node, modified)

  return result
187
188
def _prune_removed_feed_nodes(signature_def, graph_def):
  """Identify the inputs in the signature no longer in graph_def, prune them.

  Args:
    signature_def: A `SignatureDef` instance.
    graph_def: A `GraphDef` instance.

  Returns:
    A new pruned `SignatureDef`.
  """
  surviving_nodes = frozenset(n.name for n in graph_def.node)
  pruned = meta_graph_pb2.SignatureDef()
  pruned.CopyFrom(signature_def)
  # Iterate the *original* inputs so deletion from the copy is safe.
  for key, tensor_info in signature_def.inputs.items():
    node_name, _ = _parse_tensor_name(tensor_info.name)
    if node_name in surviving_nodes:
      continue
    logging.warn(
        'Signature input key \'{}\', tensor name \'{}\', has been pruned '
        'while freezing the graph.  Removing it from the compiled signatures.'
        .format(key, node_name))
    del pruned.inputs[key]
  return pruned
211
212
def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   target_cpu,
                                   variables_to_feed=(),
                                   multithreading=False):
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files.  Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  The graph is always optimized with grappler, and optionally (by default)
  variables are frozen as constants, before compilation happens.

  Variables not listed in `variables_to_feed` are embedded as constants
  into the graph and binary objects.  Variables that are listed (or all
  variables, when `variables_to_feed is None`) become inputs and outputs
  of the compiled class and the C++ caller must set these values manually.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    target_cpu: String, LLVM target cpu name.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen.  If `None`, then we will extract all the
      variables in the graph and mark them as to-feed.  The default behavior is
      an empty tuple: all variables must be frozen.
    multithreading: Whether to enable multithreading in the compiled
      computation.  Note that if using this option, the resulting object files
      may have external dependencies on multithreading libraries like nsync.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

  else:
    # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
    # so that we can set these directly instead of relying on env vars.
    xla_flags = os.environ.get('XLA_FLAGS')
    if not xla_flags:
      xla_flags = '--xla_cpu_multi_thread_eigen={}'.format(
          'true' if multithreading else 'false')
    else:
      xla_flags += ' --xla_cpu_multi_thread_eigen={}'.format(
          'true' if multithreading else 'false')
    os.environ['XLA_FLAGS'] = xla_flags

  # Validate that the requested signature exists and has outputs.
  signature_def_map = meta_graph_def.signature_def
  if signature_def_key not in signature_def_map:
    raise ValueError(
        'Unable to find signature_def key \'{}\' in signature def map.  '
        'Available keys: {}'.format(
            signature_def_key,
            list(signature_def_map.keys())))
  signature_def = signature_def_map[signature_def_key]
  if not signature_def.outputs:
    raise ValueError(
        'Signature key {} must have outputs, but saw none:\n{}'.format(
            signature_def_key, str(signature_def)))

  temp_dir = test.get_temp_dir()
  file_io.recursive_create_dir(temp_dir)
  # With verbose logging enabled, dump the unmodified graph for debugging.
  if logging.get_verbosity() >= logging.INFO:
    original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # This updates graph_def in place.
  _replace_input_placeholders_with_default_values(
      meta_graph_def.graph_def, signature_def)

  graph_def = _optimize_graph(meta_graph_def, signature_def)

  all_variables = _get_variable_nodes_from_graph_def(graph_def)
  if variables_to_feed is None:
    # Feed every variable found in the grappler-optimized graph.
    variable_nodes_to_feed = list(all_variables.values())
  else:
    not_in_graph = set(variables_to_feed).difference(list(all_variables))
    if not_in_graph:
      raise ValueError(
          'Asked to feed variables that were not found in graph: {}.  '
          'Variables contained in the graph: {}'.format(
              not_in_graph, list(all_variables)))
    variable_nodes_to_feed = [
        all_variables[name] for name in variables_to_feed
    ]

  # With verbose logging enabled, also dump the pre-freeze optimized graph.
  if logging.get_verbosity() >= logging.INFO:
    prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(graph_def.SerializeToString())

  # Load the Variables so that we can freeze the graph.
  with session.Session(graph=ops_lib.Graph()) as sess:
    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
    if restorer is not None:
      restorer.restore(sess, checkpoint_path)
    # Fold every variable that is not explicitly fed into the graph as a
    # constant; fed variables are excluded via the blacklist.
    graph_def.CopyFrom(
        graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            output_node_names=[
                _parse_tensor_name(n.name)[0]
                for n in signature_def.outputs.values()
            ],
            variable_names_blacklist=[
                n.name for n, _ in variable_nodes_to_feed
            ],
        ))

  # Freezing may have removed some signature inputs; drop them from the
  # signature too so the tf2xla config stays consistent with the graph.
  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

  frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
  config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
    graph_writer.write(graph_def.SerializeToString())
  config = _signature_to_tf2xla_config(
      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
    config_writer.write(str(config))

  output_dir = os.path.dirname(output_prefix)
  file_io.recursive_create_dir(output_dir)

  # The entry-point symbol must be a valid C identifier, so collapse every
  # run of non-alphanumeric characters to an underscore.
  entry_point = re.sub(
      '[^0-9a-zA-Z]+', '_',
      '__xla_' + output_prefix + '__' + cpp_class)

  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
    makefile_writer.write(_xla_makefile_string(output_prefix))

  output_prefix = _shlex_quote(output_prefix)

  _pywrap_tfcompile.Compile(
      graph=frozen_graph_def_location,
      config=config_pbtxt_location,
      cpp_class=cpp_class,
      target_triple=target_triple,
      target_cpu=target_cpu,
      entry_point=entry_point,
      out_function_object='{}.o'.format(output_prefix),
      out_header='{}.h'.format(output_prefix),
      out_metadata_object='{}_metadata.o'.format(output_prefix),
      gen_name_to_index=True,
      # ProgramShape isn't uniquefied by entry_point.
      gen_program_shape=False)
379
380
def _optimize_graph(meta_graph_def, signature_def):
  """Optimize `meta_graph_def` using grappler.  Returns a `GraphDef`."""
  # Grappler discovers the graph's fetches via the 'train_op' collection,
  # so work on a deep copy and register every signature input/output there.
  optimizable = copy.deepcopy(meta_graph_def)
  fetches = meta_graph_pb2.CollectionDef()
  tensor_infos = list(signature_def.inputs.values()) + list(
      signature_def.outputs.values())
  for tensor_info in tensor_infos:
    fetches.node_list.value.append(tensor_info.name)
  optimizable.collection_def['train_op'].CopyFrom(fetches)

  config = config_pb2.ConfigProto()
  # Never skip optimization, even for very small graphs.
  config.graph_options.rewrite_options.min_graph_nodes = -1
  return tf_optimizer.OptimizeGraph(config, optimizable)
398
399
def _replace_input_placeholders_with_default_values(graph_def, signature_def):
  """Replace graphdef's `tf.placeholder` input ops with all-zero constants.

  Mutates `graph_def` in place: each signature input node that is a
  Placeholder is overwritten with an all-zeros constant of the signature's
  declared (fully defined) shape and dtype.

  Args:
    graph_def: A `GraphDef` proto; modified in place.
    signature_def: A `SignatureDef` whose inputs name the placeholders.

  Raises:
    RuntimeError: If a signature input is not present in `graph_def`.
    ValueError: If a signature input's shape is not fully defined.
  """
  name_to_node_map = dict((n.name, n) for n in graph_def.node)
  processed_nodes = set([])
  for name, input_ in signature_def.inputs.items():
    tensor_name, _ = _parse_tensor_name(input_.name)
    # Several signature keys may point at the same node; handle it once.
    if tensor_name in processed_nodes:
      continue
    processed_nodes.add(tensor_name)
    if tensor_name not in name_to_node_map:
      raise RuntimeError(
          'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
          'Graph nodes are: {}'.format(tensor_name,
                                       list(name_to_node_map.keys())))
    node = name_to_node_map[tensor_name]
    if node.op not in ('Placeholder', 'PlaceholderV2'):
      logging.info(
          'Tried to convert SavedModel input node \'{}\' from a placeholder, '
          'but it doesn\'t look like a placeholder: {}'.format(tensor_name,
                                                               node))
      continue
    shape = tensor_shape.TensorShape(input_.tensor_shape)
    if not shape.is_fully_defined():
      raise ValueError(
          'Expected fully defined input shape for signature_def \'{}\', '
          'tensor name: \'{}\'; but shape is: {}.'
          .format(name, tensor_name, shape))
    # Build the replacement constant in a scratch graph, then copy its
    # NodeDef over the placeholder's (keeping the same node name).
    temp_graph = ops_lib.Graph()
    with temp_graph.as_default():
      const = array_ops.zeros(
          shape, dtype=input_.dtype, name=tensor_name)
    node.CopyFrom(const.op.node_def)
    # Sometimes zeros() also creates additional nodes
    for op in temp_graph.get_operations():
      if op.name == const.op.name:  # We just inserted this one.
        continue
      graph_def.node.append(op.node_def)
      name_to_node_map[op.node_def.name] = op.node_def
438
439
def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
  """Convert `signature_def` to tf2xla config.  Returns a `tf2xla.Config` proto.

  Args:
    signature_def: Instance of `SignatureDef`.
    variable_nodes_to_feed: List of tuples of form `(node_def, modified)`
      corresponding to VarHandleOp, and a boolean `modified` that describes
      whether the variable was modified during execution.

  Returns:
    An instance of `tf2xla.Config` proto.

  Raises:
    RuntimeError: If TensorFlow was not compiled with XLA.
  """
  from tensorflow.compiler.tf2xla import tf2xla_pb2  # pylint: disable=g-import-not-at-top

  config = tf2xla_pb2.Config()
  tensor_id = tf2xla_pb2.TensorId

  # Feed names are derived from the signature keys, sanitized for C++
  # identifier use ('/' -> '_') and prefixed 'feed_'.
  for name, input_ in signature_def.inputs.items():
    name = name.replace('/', '_')
    name = 'feed_{}'.format(name)
    (node_name, output_index) = _parse_tensor_name(input_.name)
    # NOTE(review): _parse_tensor_name returns None when the tensor name has
    # no ':N' suffix, which would make int(None) raise TypeError here.
    # Presumably signature tensor names always carry an output slot — confirm.
    output_index = int(output_index)
    config.feed.append(
        tf2xla_pb2.Feed(
            id=tensor_id(node_name=node_name, output_index=output_index),
            name=name,
            type=input_.dtype,
            shape=input_.tensor_shape))
  # Fetch names get the 'fetch_' prefix, same sanitization.
  for name, output_ in signature_def.outputs.items():
    name = name.replace('/', '_')
    name = 'fetch_{}'.format(name)
    (node_name, output_index) = _parse_tensor_name(output_.name)
    output_index = int(output_index)
    config.fetch.append(
        tf2xla_pb2.Fetch(
            id=tensor_id(node_name=node_name, output_index=output_index),
            name=name,
            type=output_.dtype,
            shape=output_.tensor_shape))
  # Fed variables get the 'param_' prefix; a variable never modified at
  # runtime is marked readonly so tf2xla can avoid emitting a write-back.
  for (node, modified) in variable_nodes_to_feed:
    name = node.name.replace('/', '_')
    name = 'param_{}'.format(name)
    config.variable.append(
        tf2xla_pb2.Variable(
            node_name=node.name,
            name=name,
            type=node.attr['dtype'].type,
            shape=node.attr['shape'].shape,
            readonly=not modified))

  return config
494