1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Helper utilities for AOT compilation.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import copy 23import os 24import pipes 25import re 26import shlex 27 28import six 29 30from tensorflow.core.protobuf import config_pb2 31from tensorflow.core.protobuf import meta_graph_pb2 32from tensorflow.python.client import session 33from tensorflow.python.framework import graph_util 34from tensorflow.python.framework import ops as ops_lib 35from tensorflow.python.framework import tensor_shape 36from tensorflow.python.framework import versions 37from tensorflow.python.grappler import tf_optimizer 38from tensorflow.python.lib.io import file_io 39from tensorflow.python.ops import array_ops 40from tensorflow.python.platform import sysconfig as sysconfig_lib 41from tensorflow.python.platform import test 42from tensorflow.python.platform import tf_logging as logging 43from tensorflow.python.training import saver as saver_lib 44 45try: 46 from tensorflow.python import _pywrap_tfcompile # pylint: disable=g-import-not-at-top 47except ImportError as e: 48 _pywrap_tfcompile_import_error = ImportError( 49 'Unable to import _pywrap_tfcompile; you must build TensorFlow ' 50 'with XLA. You may need to build tensorflow with flag ' 51 '--define=with_xla_support=true. Original error: {}'.format(str(e))) 52else: 53 _pywrap_tfcompile_import_error = None 54 55 56_READ_ONLY_VARIABLE_OPS = ( 57 'ReadVariableOp', 58 'IsVariableInitializedOp', 59 'ResourceGather', 60 'ResourceGatherNd', 61 'VariableShape', 62) 63 64_PASS_THROUGH_VARIABLE_OPS = ('Identity', 'IdentityN') 65 66 67def _shlex_quote(s): 68 if six.PY2: 69 return pipes.quote(s) 70 else: 71 return shlex.quote(s) 72 73 74def _sysconfig_module(): 75 """Load tf.sysconfig if available and working (i.e., inside a pip package).""" 76 try: 77 _ = sysconfig_lib.get_include() 78 except (ImportError, ValueError): 79 # ValueError may come from saved_model_cli_test trying to enable 80 # eager mode twice. 81 return None 82 return sysconfig_lib 83 84 85def _parse_tensor_name(name): 86 """Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0).""" 87 if ':' in name and not name.endswith(':'): 88 node_name = name[:name.rfind(':')] 89 output_slot = int(name[name.rfind(':') + 1:]) 90 return node_name, output_slot 91 else: 92 return name, None 93 94 95_XLA_MAKEFILE_TEMPLATE = """ 96INC = -I{tensorflow_includes} 97LIB = -L{compiled_dir} 98CXXFLAGS = {cxx_flags} 99""" 100 101 102def _xla_makefile_string(output_prefix): 103 """Returns a Makefile string with variables for using XLA binary object files. 104 105 Attempts to identify the right include header paths when run from either 106 an installed TensorFlow pip package, or from bazel run. 107 108 Args: 109 output_prefix: A string containing the output prefix for the XLA AOT 110 compiled header + object files. 111 112 Returns: 113 A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`. 114 """ 115 sysconfig = _sysconfig_module() 116 output_dir, _ = os.path.split(output_prefix) 117 if sysconfig: 118 tensorflow_includes = _shlex_quote(sysconfig.get_include()) 119 else: 120 # Try hard to find the real source directory if this is a local bazel run. 121 if os.path.islink(__file__): 122 this_file = __file__ 123 while os.path.islink(this_file): 124 this_file = os.readlink(this_file) 125 base = os.path.realpath( 126 os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3))) 127 else: 128 try: 129 base = test.test_src_dir_path('') 130 except KeyError: # Can't find TEST_SRCDIR in environment path. 131 base = os.path.realpath( 132 os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3))) 133 expected_header = os.path.join( 134 base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h') 135 if not os.path.exists(expected_header): 136 logging.error( 137 'Could not find includes path. Missing file: {}' 138 .format(expected_header)) 139 tensorflow_includes = base 140 141 return _XLA_MAKEFILE_TEMPLATE.format( 142 tensorflow_includes=tensorflow_includes, 143 compiled_dir=_shlex_quote(output_dir), 144 cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format( 145 versions.CXX11_ABI_FLAG)) 146 147 148def _get_variable_nodes_from_graph_def(graph_def): 149 """Get the list of Variable nodes from `graph_def`. 150 151 Args: 152 graph_def: An instance of `GraphDef`. This GraphDef *must* 153 have already been optimized by Grappler. In particular, function 154 inlining must have already happened. 155 156 Returns: 157 A dict mapping string names of variables to tuples `(node_def, modified)`, 158 where `node_def` is the `NodeDef` corresponding to variable, and `modified` 159 is a python bool describing whether the variable is modified during runtime. 160 """ 161 variables = [n for n in graph_def.node if n.op == 'VarHandleOp'] 162 variable_name_map = dict((n.name, n) for n in variables) 163 child_map = collections.defaultdict(lambda: []) 164 for n in graph_def.node: 165 for inp in n.input: 166 if not inp.startswith('^'): 167 child_map[inp].append(n) 168 variables = {} 169 for (v_name, v_node) in variable_name_map.items(): 170 queue = list(child_map[v_name]) 171 processed = set([]) 172 while queue: 173 n_current = queue.pop() 174 if n_current.name in processed: 175 continue 176 processed.add(n_current.name) 177 if n_current.op in _PASS_THROUGH_VARIABLE_OPS: 178 children = child_map.get(n_current.name, []) 179 queue.extend(children) 180 elif n_current.op not in _READ_ONLY_VARIABLE_OPS: 181 variables[v_name] = (v_node, True) 182 queue = [] 183 if v_name not in variables: 184 variables[v_name] = (v_node, False) 185 186 return variables 187 188 189def _prune_removed_feed_nodes(signature_def, graph_def): 190 """Identify the inputs in the signature no longer in graph_def, prune them. 191 192 Args: 193 signature_def: A `SignatureDef` instance. 194 graph_def: A `GraphDef` instance. 195 196 Returns: 197 A new pruned `SignatureDef`. 198 """ 199 node_names = set([n.name for n in graph_def.node]) 200 new_signature_def = meta_graph_pb2.SignatureDef() 201 new_signature_def.CopyFrom(signature_def) 202 for (k, v) in signature_def.inputs.items(): 203 tensor_name, _ = _parse_tensor_name(v.name) 204 if tensor_name not in node_names: 205 logging.warn( 206 'Signature input key \'{}\', tensor name \'{}\', has been pruned ' 207 'while freezing the graph. Removing it from the compiled signatures.' 208 .format(k, tensor_name)) 209 del new_signature_def.inputs[k] 210 return new_signature_def 211 212 213def aot_compile_cpu_meta_graph_def(checkpoint_path, 214 meta_graph_def, 215 output_prefix, 216 signature_def_key, 217 cpp_class, 218 target_triple, 219 target_cpu, 220 variables_to_feed=(), 221 multithreading=False): 222 """Compile a `MetaGraphDef` to header+object files in `output_prefix`. 223 224 Use XLA AOT (`tfcompile`) to convert the given meta graph and 225 signature into a header + object files. Also create an include makefile 226 that helps identify the appropriate necessary include and library paths 227 to incorporate these files into your C++ program. 228 229 The graph is always optimized with grappler, and optionally (by default) 230 variables are frozen as constants, before compilation happens. 231 232 If the `freeze_graph` is `True`, all variables are embedded as constants 233 into the graph and binary objects. If it is `False`, then the variable 234 values become inputs and outputs of the compiled class and the C++ 235 caller must set these values manually. 236 237 Args: 238 checkpoint_path: Python string. Path to checkpoints/variables. 239 meta_graph_def: Instance of `MetaGraphDef`. 240 output_prefix: Python string. Path prefix for outputs. 241 signature_def_key: String, the signature_def to use in the SavedModel. 242 cpp_class: String, Name of output C++ class. 243 target_triple: String, LLVM target triple. 244 target_cpu: String, LLVM target cpu name. 245 variables_to_feed: A list of strings, the variables that will be fed by the 246 user; these won't be frozen. If `None`, then we will extract all the 247 variables in the graph and mark them as to-feed. The default behavior is 248 an empty tuple: all variables must be frozen. 249 multithreading: Whether to enable multithreading in the compiled 250 computation. Note that if using this option, the resulting object files 251 may have external dependencies on multithreading libraries like nsync. 252 253 Raises: 254 RuntimeError: If tensorflow was not built with XLA. 255 ImportError: If tensorflow was built with XLA but there was another 256 issue importing the tfcompile python wrapper. 257 ValueError: If `meta_graph_def.signature_def[signature_def_key]` is 258 missing or has empty outputs. 259 """ 260 if _pywrap_tfcompile_import_error: 261 raise _pywrap_tfcompile_import_error # pylint: disable=raising-bad-type 262 263 else: 264 # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap 265 # so that we can set these directly instead of relying on env vars. 266 xla_flags = os.environ.get('XLA_FLAGS') 267 if not xla_flags: 268 xla_flags = '--xla_cpu_multi_thread_eigen={}'.format( 269 'true' if multithreading else 'false') 270 else: 271 xla_flags += ' --xla_cpu_multi_thread_eigen={}'.format( 272 'true' if multithreading else 'false') 273 os.environ['XLA_FLAGS'] = xla_flags 274 275 signature_def_map = meta_graph_def.signature_def 276 if signature_def_key not in signature_def_map: 277 raise ValueError( 278 'Unable to find signature_def key \'{}\' in signature def map. ' 279 'Available keys: {}'.format( 280 signature_def_key, 281 list(signature_def_map.keys()))) 282 signature_def = signature_def_map[signature_def_key] 283 if not signature_def.outputs: 284 raise ValueError( 285 'Signature key {} must have outputs, but saw none:\n{}'.format( 286 signature_def_key, str(signature_def))) 287 288 temp_dir = test.get_temp_dir() 289 file_io.recursive_create_dir(temp_dir) 290 if logging.get_verbosity() >= logging.INFO: 291 original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb') 292 with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer: 293 graph_writer.write(meta_graph_def.graph_def.SerializeToString()) 294 295 # This updates graph_def in place. 296 _replace_input_placeholders_with_default_values( 297 meta_graph_def.graph_def, signature_def) 298 299 graph_def = _optimize_graph(meta_graph_def, signature_def) 300 301 all_variables = _get_variable_nodes_from_graph_def(graph_def) 302 if variables_to_feed is None: 303 variable_nodes_to_feed = list(all_variables.values()) 304 else: 305 not_in_graph = set(variables_to_feed).difference(list(all_variables)) 306 if not_in_graph: 307 raise ValueError( 308 'Asked to feed variables that were not found in graph: {}. ' 309 'Variables contained in the graph: {}'.format( 310 not_in_graph, list(all_variables))) 311 variable_nodes_to_feed = [ 312 all_variables[name] for name in variables_to_feed 313 ] 314 315 if logging.get_verbosity() >= logging.INFO: 316 prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb') 317 with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer: 318 graph_writer.write(graph_def.SerializeToString()) 319 320 # Load the Variables so that we can freeze the graph. 321 with session.Session(graph=ops_lib.Graph()) as sess: 322 restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True) 323 if restorer is not None: 324 restorer.restore(sess, checkpoint_path) 325 graph_def.CopyFrom( 326 graph_util.convert_variables_to_constants( 327 sess, 328 graph_def, 329 output_node_names=[ 330 _parse_tensor_name(n.name)[0] 331 for n in signature_def.outputs.values() 332 ], 333 variable_names_blacklist=[ 334 n.name for n, _ in variable_nodes_to_feed 335 ], 336 )) 337 338 signature_def = _prune_removed_feed_nodes(signature_def, graph_def) 339 340 frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb') 341 config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt') 342 logging.info('Writing graph def to: {}'.format(frozen_graph_def_location)) 343 with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer: 344 graph_writer.write(graph_def.SerializeToString()) 345 config = _signature_to_tf2xla_config( 346 signature_def, variable_nodes_to_feed=variable_nodes_to_feed) 347 logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location)) 348 with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer: 349 config_writer.write(str(config)) 350 351 output_dir = os.path.dirname(output_prefix) 352 file_io.recursive_create_dir(output_dir) 353 354 entry_point = re.sub( 355 '[^0-9a-zA-Z]+', '_', 356 '__xla_' + output_prefix + '__' + cpp_class) 357 358 logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir)) 359 360 makefile_inc_location = '{}_makefile.inc'.format(output_prefix) 361 with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer: 362 makefile_writer.write(_xla_makefile_string(output_prefix)) 363 364 output_prefix = _shlex_quote(output_prefix) 365 366 _pywrap_tfcompile.Compile( 367 graph=frozen_graph_def_location, 368 config=config_pbtxt_location, 369 cpp_class=cpp_class, 370 target_triple=target_triple, 371 target_cpu=target_cpu, 372 entry_point=entry_point, 373 out_function_object='{}.o'.format(output_prefix), 374 out_header='{}.h'.format(output_prefix), 375 out_metadata_object='{}_metadata.o'.format(output_prefix), 376 gen_name_to_index=True, 377 # ProgramShape isn't uniquefied by entry_point. 378 gen_program_shape=False) 379 380 381def _optimize_graph(meta_graph_def, signature_def): 382 """Optimize `meta_graph_def` using grappler. Returns a `GraphDef`.""" 383 # We need to add a collection called 'train_op' so that grappler 384 # knows what the outputs are. 385 new_meta_graph_def = copy.deepcopy(meta_graph_def) 386 fetch_collection = meta_graph_pb2.CollectionDef() 387 for tensor_info in ( 388 list(signature_def.inputs.values()) + 389 list(signature_def.outputs.values())): 390 fetch_collection.node_list.value.append(tensor_info.name) 391 392 new_meta_graph_def.collection_def['train_op'].CopyFrom(fetch_collection) 393 394 config = config_pb2.ConfigProto() 395 rewrite_options = config.graph_options.rewrite_options 396 rewrite_options.min_graph_nodes = -1 # do not skip small graphs 397 return tf_optimizer.OptimizeGraph(config, new_meta_graph_def) 398 399 400def _replace_input_placeholders_with_default_values(graph_def, signature_def): 401 """Replace graphdef's `tf.placeholder` input ops with all-zero constants.""" 402 name_to_node_map = dict((n.name, n) for n in graph_def.node) 403 processed_nodes = set([]) 404 for name, input_ in signature_def.inputs.items(): 405 tensor_name, _ = _parse_tensor_name(input_.name) 406 if tensor_name in processed_nodes: 407 continue 408 processed_nodes.add(tensor_name) 409 if tensor_name not in name_to_node_map: 410 raise RuntimeError( 411 'Unable to find input signature tensor \'{}\' in optimized GraphDef. ' 412 'Graph nodes are: {}'.format(tensor_name, 413 list(name_to_node_map.keys()))) 414 node = name_to_node_map[tensor_name] 415 if node.op not in ('Placeholder', 'PlaceholderV2'): 416 logging.info( 417 'Tried to convert SavedModel input node \'{}\' from a placeholder, ' 418 'but it doesn\'t look like a placeholder: {}'.format(tensor_name, 419 node)) 420 continue 421 shape = tensor_shape.TensorShape(input_.tensor_shape) 422 if not shape.is_fully_defined(): 423 raise ValueError( 424 'Expected fully defined input shape for signature_def \'{}\', ' 425 'tensor name: \'{}\'; but shape is: {}.' 426 .format(name, tensor_name, shape)) 427 temp_graph = ops_lib.Graph() 428 with temp_graph.as_default(): 429 const = array_ops.zeros( 430 shape, dtype=input_.dtype, name=tensor_name) 431 node.CopyFrom(const.op.node_def) 432 # Sometimes zeros() also creates additional nodes 433 for op in temp_graph.get_operations(): 434 if op.name == const.op.name: # We just inserted this one. 435 continue 436 graph_def.node.append(op.node_def) 437 name_to_node_map[op.node_def.name] = op.node_def 438 439 440def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed): 441 """Convert `signature_def` to tf2xla config. Returns a `tf2xla.Config` proto. 442 443 Args: 444 signature_def: Instance of `SignatureDef`. 445 variable_nodes_to_feed: List of tuples of form `(node_def, modified)` 446 corresponding to VarHandleOp, and a boolean `modified` that describes 447 whether the variable was modified during execution. 448 449 Returns: 450 An instance of `tf2xla.Config` proto. 451 452 Raises: 453 RuntimeError: If TensorFlow was not compiled with XLA. 454 """ 455 from tensorflow.compiler.tf2xla import tf2xla_pb2 # pylint: disable=g-import-not-at-top 456 457 config = tf2xla_pb2.Config() 458 tensor_id = tf2xla_pb2.TensorId 459 460 for name, input_ in signature_def.inputs.items(): 461 name = name.replace('/', '_') 462 name = 'feed_{}'.format(name) 463 (node_name, output_index) = _parse_tensor_name(input_.name) 464 output_index = int(output_index) 465 config.feed.append( 466 tf2xla_pb2.Feed( 467 id=tensor_id(node_name=node_name, output_index=output_index), 468 name=name, 469 type=input_.dtype, 470 shape=input_.tensor_shape)) 471 for name, output_ in signature_def.outputs.items(): 472 name = name.replace('/', '_') 473 name = 'fetch_{}'.format(name) 474 (node_name, output_index) = _parse_tensor_name(output_.name) 475 output_index = int(output_index) 476 config.fetch.append( 477 tf2xla_pb2.Fetch( 478 id=tensor_id(node_name=node_name, output_index=output_index), 479 name=name, 480 type=output_.dtype, 481 shape=output_.tensor_shape)) 482 for (node, modified) in variable_nodes_to_feed: 483 name = node.name.replace('/', '_') 484 name = 'param_{}'.format(name) 485 config.variable.append( 486 tf2xla_pb2.Variable( 487 node_name=node.name, 488 name=name, 489 type=node.attr['dtype'].type, 490 shape=node.attr['shape'].shape, 491 readonly=not modified)) 492 493 return config 494