1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helper utilities for AOT compilation."""
16
17import collections
18import copy
19import os
20import re
21import shlex
22from typing import List, Tuple
23
24from tensorflow.core.protobuf import config_pb2
25from tensorflow.core.protobuf import meta_graph_pb2
26from tensorflow.python.client import session
27from tensorflow.python.framework import graph_util
28from tensorflow.python.framework import ops as ops_lib
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import versions
31from tensorflow.python.grappler import tf_optimizer
32from tensorflow.python.lib.io import file_io
33from tensorflow.python.ops import array_ops
34from tensorflow.python.platform import sysconfig as sysconfig_lib
35from tensorflow.python.platform import test
36from tensorflow.python.platform import tf_logging as logging
37from tensorflow.python.training import saver as saver_lib
38
39try:
40  from tensorflow.python import _pywrap_tfcompile  # pylint: disable=g-import-not-at-top
41except ImportError as e:
42  _pywrap_tfcompile_import_error = ImportError(
43      'Unable to import _pywrap_tfcompile; you must build TensorFlow '
44      'with XLA.  You may need to build tensorflow with flag '
45      '--define=with_xla_support=true.  Original error: {}'.format(str(e)))
46else:
47  _pywrap_tfcompile_import_error = None
48
49
# Resource-variable ops that only read a variable's value.  If every
# (transitive) consumer of a VarHandleOp is one of these, the variable is
# treated as never modified at runtime (see
# `_get_variable_nodes_from_graph_def`).
_READ_ONLY_VARIABLE_OPS = (
    'ReadVariableOp',
    'IsVariableInitializedOp',
    'ResourceGather',
    'ResourceGatherNd',
    'VariableShape',
)

# Ops that forward a resource handle unchanged; the variable-usage traversal
# continues through these to the ops that actually consume the handle.
_PASS_THROUGH_VARIABLE_OPS = ('Identity', 'IdentityN')
59
60
61def _shlex_quote(s):
62  return shlex.quote(s)
63
64
def _sysconfig_module():
  """Load tf.sysconfig if available and working (i.e., inside a pip package)."""
  try:
    # get_include() only works when TensorFlow is an installed pip package.
    sysconfig_lib.get_include()
  except (ImportError, ValueError):
    # ValueError may come from saved_model_cli_test trying to enable
    # eager mode twice.
    return None
  else:
    return sysconfig_lib
74
75
76def _parse_tensor_name(name):
77  """Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0)."""
78  if ':' in name and not name.endswith(':'):
79    node_name = name[:name.rfind(':')]
80    output_slot = int(name[name.rfind(':') + 1:])
81    return node_name, output_slot
82  else:
83    return name, None
84
85
# Makefile fragment emitted next to the AOT artifacts: include path for
# TensorFlow headers, library path for the compiled objects, and the C++ ABI
# flag.  Filled in by `_xla_makefile_string`.
_XLA_MAKEFILE_TEMPLATE = """
INC = -I{tensorflow_includes}
LIB = -L{compiled_dir}
CXXFLAGS = {cxx_flags}
"""
91
92
def _xla_makefile_string(output_prefix):
  """Returns a Makefile string with variables for using XLA binary object files.

  Attempts to identify the right include header paths when run from either
  an installed TensorFlow pip package, or from bazel run.

  Args:
    output_prefix: A string containing the output prefix for the XLA AOT
      compiled header + object files.

  Returns:
    A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`.
  """
  sysconfig = _sysconfig_module()
  output_dir, _ = os.path.split(output_prefix)
  if sysconfig:
    # Inside a pip package: sysconfig knows the installed include directory.
    tensorflow_includes = _shlex_quote(sysconfig.get_include())
  else:
    # Try hard to find the real source directory if this is a local bazel run.
    if os.path.islink(__file__):
      # Follow the (possibly chained) symlinks bazel creates back to the
      # real source file.
      this_file = __file__
      while os.path.islink(this_file):
        this_file = os.readlink(this_file)
      # Three directories above this module is the repository root.
      base = os.path.realpath(
          os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
    else:
      try:
        base = test.test_src_dir_path('')
      except KeyError:  # Can't find TEST_SRCDIR in environment path.
        base = os.path.realpath(
            os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3)))
    # Sanity-check the guessed include root by probing for a well-known
    # tf2xla header; log (but don't fail) if it's missing.
    expected_header = os.path.join(
        base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
    if not os.path.exists(expected_header):
      logging.error(
          'Could not find includes path.  Missing file: {}'
          .format(expected_header))
    tensorflow_includes = base

  # Use the same C++ ABI flag TensorFlow was built with so the AOT objects
  # link cleanly against the user's code.
  return _XLA_MAKEFILE_TEMPLATE.format(
      tensorflow_includes=tensorflow_includes,
      compiled_dir=_shlex_quote(output_dir),
      cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format(
          versions.CXX11_ABI_FLAG))
137
138
139def _get_variable_nodes_from_graph_def(graph_def):
140  """Get the list of Variable nodes from `graph_def`.
141
142  Args:
143    graph_def: An instance of `GraphDef`.  This GraphDef *must*
144      have already been optimized by Grappler.  In particular, function
145      inlining must have already happened.
146
147  Returns:
148    A dict mapping string names of variables to tuples `(node_def, modified)`,
149    where `node_def` is the `NodeDef` corresponding to variable, and `modified`
150    is a python bool describing whether the variable is modified during runtime.
151  """
152  variables = [n for n in graph_def.node if n.op == 'VarHandleOp']
153  variable_name_map = dict((n.name, n) for n in variables)
154  child_map = collections.defaultdict(lambda: [])
155  for n in graph_def.node:
156    for inp in n.input:
157      if not inp.startswith('^'):
158        child_map[inp].append(n)
159  variables = {}
160  for (v_name, v_node) in variable_name_map.items():
161    queue = list(child_map[v_name])
162    processed = set([])
163    while queue:
164      n_current = queue.pop()
165      if n_current.name in processed:
166        continue
167      processed.add(n_current.name)
168      if n_current.op in _PASS_THROUGH_VARIABLE_OPS:
169        children = child_map.get(n_current.name, [])
170        queue.extend(children)
171      elif n_current.op not in _READ_ONLY_VARIABLE_OPS:
172        variables[v_name] = (v_node, True)
173        queue = []
174    if v_name not in variables:
175      variables[v_name] = (v_node, False)
176
177  return variables
178
179
def _prune_removed_feed_nodes(signature_def, graph_def):
  """Identify the inputs in the signature no longer in graph_def, prune them.

  Args:
    signature_def: A `SignatureDef` instance.
    graph_def: A `GraphDef` instance.

  Returns:
    A new pruned `SignatureDef`.
  """
  node_names = {n.name for n in graph_def.node}
  new_signature_def = meta_graph_pb2.SignatureDef()
  new_signature_def.CopyFrom(signature_def)
  for (k, v) in signature_def.inputs.items():
    tensor_name, _ = _parse_tensor_name(v.name)
    if tensor_name not in node_names:
      # Grappler/freezing may have removed this input node entirely; drop it
      # so the compiled signature matches the surviving graph.
      # (`logging.warning`: `warn` is the deprecated alias.)
      logging.warning(
          'Signature input key \'{}\', tensor name \'{}\', has been pruned '
          'while freezing the graph.  Removing it from the compiled signatures.'
          .format(k, tensor_name))
      del new_signature_def.inputs[k]
  return new_signature_def
202
203
def freeze_model(checkpoint_path: str,
                 meta_graph_def: meta_graph_pb2.MetaGraphDef,
                 output_prefix: str, signature_def_key: str,
                 variables_to_feed: List[str]) -> Tuple[str, str]:
  """Freeze a `MetaGraphDef` in preparation for `tfcompile`.

  The graph is always optimized with grappler, and optionally (by default)
  variables are frozen as constants, before compilation happens.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen.  If `None`, then we will extract all the
      variables in the graph and mark them as to-feed.  The default behavior is
      an empty tuple: all variables must be frozen.
  Returns:
    a pair containing the path to the frozen model and the path to the config.
  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

  signature_def_map = meta_graph_def.signature_def
  if signature_def_key not in signature_def_map:
    raise ValueError(
        f"Unable to find signature_def_key '{signature_def_key}' in signature "
        'def map of `meta_graph_def`. Available keys: '
        f'{list(signature_def_map.keys())}')
  signature_def = signature_def_map[signature_def_key]
  if not signature_def.outputs:
    raise ValueError(
        f'Signature key {signature_def_key} must have outputs, but saw none:\n'
        f'{str(signature_def)}')

  file_io.recursive_create_dir(output_prefix)
  # At INFO verbosity, dump the unmodified graph for debugging comparisons.
  if logging.get_verbosity() >= logging.INFO:
    original_graph_def_location = os.path.join(output_prefix,
                                               'original_graph.pb')
    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # This updates graph_def in place.
  _replace_input_placeholders_with_default_values(
      meta_graph_def.graph_def, signature_def)

  # Run grappler; the result is a standalone GraphDef (functions inlined).
  graph_def = _optimize_graph(meta_graph_def, signature_def)

  all_variables = _get_variable_nodes_from_graph_def(graph_def)
  if variables_to_feed is None:
    # Feed (i.e., do not freeze) every variable found in the graph.
    variable_nodes_to_feed = list(all_variables.values())
  else:
    not_in_graph = set(variables_to_feed).difference(list(all_variables))
    if not_in_graph:
      raise ValueError('Asked to feed variables that were not found in graph: '
                       f'{not_in_graph}. Variables contained in the graph: '
                       f'{list(all_variables)}')
    variable_nodes_to_feed = [
        all_variables[name] for name in variables_to_feed
    ]

  # At INFO verbosity, also dump the grappler-optimized (pre-freeze) graph.
  if logging.get_verbosity() >= logging.INFO:
    prefrozen_graph_def_location = os.path.join(output_prefix,
                                                'prefrozen_graph.pb')
    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(graph_def.SerializeToString())

  # Load the Variables so that we can freeze the graph.
  with session.Session(graph=ops_lib.Graph()) as sess:
    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
    if restorer is not None:
      restorer.restore(sess, checkpoint_path)
    # Replace every variable *not* in variable_nodes_to_feed with a constant
    # holding its checkpointed value; fed variables stay as graph inputs.
    graph_def.CopyFrom(
        graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            output_node_names=[
                _parse_tensor_name(n.name)[0]
                for n in signature_def.outputs.values()
            ],
            variable_names_blacklist=[
                n.name for n, _ in variable_nodes_to_feed
            ],
        ))

  # Freezing may have pruned some signature inputs out of the graph entirely.
  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

  frozen_graph_def_location = os.path.join(output_prefix, 'frozen_graph.pb')
  config_pbtxt_location = os.path.join(output_prefix, 'config.pbtxt')
  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
    graph_writer.write(graph_def.SerializeToString())
  config = _signature_to_tf2xla_config(
      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
    config_writer.write(str(config))
  return frozen_graph_def_location, config_pbtxt_location
309
310
def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   target_cpu,
                                   variables_to_feed=(),
                                   multithreading=False):
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files.  Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  Freezing a graph entails restoring the checkpoint and replacing any inputs
  and variables with constants. If values are fed, those are used, else inputs
  are replaced with default all-zero constants. Finally, the graph is pruned
  and then optimized with grappler.

  Variables named in `variables_to_feed` are *not* frozen: their values become
  inputs and outputs of the compiled class and the C++ caller must set these
  values manually.  All other variables are embedded as constants into the
  graph and binary objects.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    target_cpu: String, LLVM target cpu name.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen.  If `None`, then we will extract all the
      variables in the graph and mark them as to-feed.  The default behavior is
      an empty tuple: all variables must be frozen.
    multithreading: Whether to enable multithreading in the compiled
      computation.  Note that if using this option, the resulting object files
      may have external dependencies on multithreading libraries like nsync.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

  else:
    # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
    # so that we can set these directly instead of relying on env vars.
    # Append to any pre-existing XLA_FLAGS rather than clobbering them.
    xla_flags = os.environ.get('XLA_FLAGS')
    if not xla_flags:
      xla_flags = '--xla_cpu_multi_thread_eigen={}'.format(
          'true' if multithreading else 'false')
    else:
      xla_flags += ' --xla_cpu_multi_thread_eigen={}'.format(
          'true' if multithreading else 'false')
    os.environ['XLA_FLAGS'] = xla_flags

  # Freeze the model into a temporary directory; only the final AOT artifacts
  # are written under `output_prefix`.
  temp_dir = test.get_temp_dir()
  file_io.recursive_create_dir(temp_dir)
  frozen_graph_def_location, config_pbtxt_location = freeze_model(
      checkpoint_path=checkpoint_path,
      meta_graph_def=meta_graph_def,
      output_prefix=temp_dir,
      signature_def_key=signature_def_key,
      variables_to_feed=variables_to_feed)
  output_dir = os.path.dirname(output_prefix)
  file_io.recursive_create_dir(output_dir)

  # Derive a C-identifier-safe entry point symbol that is unique per output
  # location + class name.
  entry_point = re.sub(
      '[^0-9a-zA-Z]+', '_',
      '__xla_' + output_prefix + '__' + cpp_class)

  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
    makefile_writer.write(_xla_makefile_string(output_prefix))

  output_prefix = _shlex_quote(output_prefix)

  _pywrap_tfcompile.Compile(
      graph=frozen_graph_def_location,
      config=config_pbtxt_location,
      cpp_class=cpp_class,
      target_triple=target_triple,
      target_cpu=target_cpu,
      entry_point=entry_point,
      out_function_object='{}.o'.format(output_prefix),
      out_header='{}.h'.format(output_prefix),
      out_metadata_object='{}_metadata.o'.format(output_prefix),
      gen_name_to_index=True,
      # ProgramShape isn't uniquefied by entry_point.
      gen_program_shape=False)
411
412
def _optimize_graph(meta_graph_def, signature_def):
  """Optimize `meta_graph_def` using grappler.  Returns a `GraphDef`."""
  # Grappler reads its fetch set from the 'train_op' collection, so work on a
  # deep copy and register every signature input/output tensor there.
  optimization_input = copy.deepcopy(meta_graph_def)
  fetch_collection = meta_graph_pb2.CollectionDef()
  tensor_infos = list(signature_def.inputs.values())
  tensor_infos.extend(signature_def.outputs.values())
  for tensor_info in tensor_infos:
    fetch_collection.node_list.value.append(tensor_info.name)
  optimization_input.collection_def['train_op'].CopyFrom(fetch_collection)

  config = config_pb2.ConfigProto()
  # Do not skip small graphs.
  config.graph_options.rewrite_options.min_graph_nodes = -1
  return tf_optimizer.OptimizeGraph(config, optimization_input)
430
431
def _replace_input_placeholders_with_default_values(graph_def, signature_def):
  """Replace graphdef's `tf.placeholder` input ops with all-zero constants.

  Mutates `graph_def` in place.

  Args:
    graph_def: A `GraphDef` proto whose signature input placeholders are
      rewritten.
    signature_def: A `SignatureDef` naming the inputs to replace.

  Raises:
    RuntimeError: If an input signature tensor cannot be found in `graph_def`.
    ValueError: If an input's shape is not fully defined.
  """
  name_to_node_map = dict((n.name, n) for n in graph_def.node)
  processed_nodes = set([])
  for name, input_ in signature_def.inputs.items():
    tensor_name, _ = _parse_tensor_name(input_.name)
    # Multiple signature keys may alias the same node; rewrite it only once.
    if tensor_name in processed_nodes:
      continue
    processed_nodes.add(tensor_name)
    if tensor_name not in name_to_node_map:
      raise RuntimeError(
          f"Unable to find input signature tensor '{tensor_name}' in optimized "
          f'GraphDef. Graph nodes are: {list(name_to_node_map.keys())}')
    node = name_to_node_map[tensor_name]
    if node.op not in ('Placeholder', 'PlaceholderV2'):
      # Not a placeholder (e.g. already a constant); leave it untouched.
      logging.info(
          'Tried to convert SavedModel input node \'{}\' from a placeholder, '
          'but it doesn\'t look like a placeholder: {}'.format(tensor_name,
                                                               node))
      continue
    # A fully defined static shape is required to materialize the zeros const.
    shape = tensor_shape.TensorShape(input_.tensor_shape)
    if not shape.is_fully_defined():
      raise ValueError(
          f"Expected fully defined input shape for signature_def '{name}', "
          f"tensor name: '{tensor_name}'; but shape is: {shape}.")
    # Build the zeros constant in a scratch graph, then overwrite the
    # placeholder's NodeDef with it in place (preserving the node name).
    temp_graph = ops_lib.Graph()
    with temp_graph.as_default():
      const = array_ops.zeros(
          shape, dtype=input_.dtype, name=tensor_name)
    node.CopyFrom(const.op.node_def)
    # Sometimes zeros() also creates additional nodes
    for op in temp_graph.get_operations():
      if op.name == const.op.name:  # We just inserted this one.
        continue
      graph_def.node.append(op.node_def)
      name_to_node_map[op.node_def.name] = op.node_def
468
469
def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
  """Convert `signature_def` to tf2xla config.  Returns a `tf2xla.Config` proto.

  Args:
    signature_def: Instance of `SignatureDef`.
    variable_nodes_to_feed: List of tuples of form `(node_def, modified)`
      corresponding to VarHandleOp, and a boolean `modified` that describes
      whether the variable was modified during execution.

  Returns:
    An instance of `tf2xla.Config` proto.

  Raises:
    RuntimeError: If TensorFlow was not compiled with XLA.
  """
  # Imported lazily so a non-XLA build only fails when AOT compilation is
  # actually requested.
  from tensorflow.compiler.tf2xla import tf2xla_pb2  # pylint: disable=g-import-not-at-top

  config = tf2xla_pb2.Config()
  tensor_id = tf2xla_pb2.TensorId

  # Signature inputs become feeds named 'feed_<sanitized signature key>'.
  for name, input_ in signature_def.inputs.items():
    name = name.replace('/', '_')
    name = 'feed_{}'.format(name)
    (node_name, output_index) = _parse_tensor_name(input_.name)
    # NOTE(review): assumes signature tensor names always carry an explicit
    # ':<slot>' suffix; a bare node name would make output_index None here
    # and int(None) raise.  Confirm upstream always emits the suffix.
    output_index = int(output_index)
    config.feed.append(
        tf2xla_pb2.Feed(
            id=tensor_id(node_name=node_name, output_index=output_index),
            name=name,
            type=input_.dtype,
            shape=input_.tensor_shape))
  # Signature outputs become fetches named 'fetch_<sanitized signature key>'.
  for name, output_ in signature_def.outputs.items():
    name = name.replace('/', '_')
    name = 'fetch_{}'.format(name)
    (node_name, output_index) = _parse_tensor_name(output_.name)
    output_index = int(output_index)
    config.fetch.append(
        tf2xla_pb2.Fetch(
            id=tensor_id(node_name=node_name, output_index=output_index),
            name=name,
            type=output_.dtype,
            shape=output_.tensor_shape))
  # Fed variables become parameters named 'param_<sanitized node name>'.
  # Variables that are never modified are marked readonly.
  for (node, modified) in variable_nodes_to_feed:
    name = node.name.replace('/', '_')
    name = 'param_{}'.format(name)
    config.variable.append(
        tf2xla_pb2.Variable(
            node_name=node.name,
            name=name,
            type=node.attr['dtype'].type,
            shape=node.attr['shape'].shape,
            readonly=not modified))

  return config
524