# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper utilities for AOT compilation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import os
import pipes
import re
import shlex

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import sysconfig as sysconfig_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_lib

# The tfcompile python wrapper only exists when TensorFlow was built with XLA
# support.  Record the failure (rather than raising here) so importing this
# module always succeeds; the error is raised lazily from
# `aot_compile_cpu_meta_graph_def` when compilation is actually requested.
try:
  from tensorflow.python import _pywrap_tfcompile  # pylint: disable=g-import-not-at-top
except ImportError as e:
  _pywrap_tfcompile_import_error = ImportError(
      'Unable to import _pywrap_tfcompile; you must build TensorFlow '
      'with XLA. You may need to build tensorflow with flag '
      '--define=with_xla_support=true. Original error: {}'.format(str(e)))
else:
  _pywrap_tfcompile_import_error = None


# Ops that read a resource variable but never modify its value.  A variable
# all of whose (transitive, via the pass-through ops below) consumers are in
# this set is treated as read-only in `_get_variable_nodes_from_graph_def`.
_READ_ONLY_VARIABLE_OPS = (
    'ReadVariableOp',
    'IsVariableInitializedOp',
    'ResourceGather',
    'ResourceGatherNd',
    'VariableShape',
)

# Ops that forward a resource handle unchanged; the read-only analysis
# continues the traversal through these to their own consumers.
_PASS_THROUGH_VARIABLE_OPS = ('Identity', 'IdentityN')


def _shlex_quote(s):
  """Shell-quote string `s` using the API available on this Python version."""
  if six.PY2:
    return pipes.quote(s)
  else:
    return shlex.quote(s)


def _sysconfig_module():
  """Load tf.sysconfig if available and working (i.e., inside a pip package)."""
  try:
    _ = sysconfig_lib.get_include()
  except ImportError:
    return None
  return sysconfig_lib


def _parse_tensor_name(name):
  """Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0).

  If `name` has no output-slot suffix (no ':' or a trailing ':'), the slot
  in the returned tuple is `None`.
  """
  if ':' in name and not name.endswith(':'):
    node_name = name[:name.rfind(':')]
    output_slot = int(name[name.rfind(':') + 1:])
    return node_name, output_slot
  else:
    return name, None


# Skeleton of the `<output_prefix>_makefile.inc` file emitted next to the AOT
# artifacts; filled in by `_xla_makefile_string`.
_XLA_MAKEFILE_TEMPLATE = """
INC = -I{tensorflow_includes}
LIB = -L{compiled_dir}
CXXFLAGS = {cxx_flags}
"""


def _xla_makefile_string(output_prefix):
  """Returns a Makefile string with variables for using XLA binary object files.

  Attempts to identify the right include header paths when run from either
  an installed TensorFlow pip package, or from bazel run.

  Args:
    output_prefix: A string containing the output prefix for the XLA AOT
      compiled header + object files.

  Returns:
    A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`.
  """
  sysconfig = _sysconfig_module()
  output_dir, _ = os.path.split(output_prefix)
  if sysconfig:
    # Running from an installed pip package: sysconfig knows the include dir.
    tensorflow_includes = _shlex_quote(sysconfig.get_include())
  else:
    # Try hard to find the real source directory if this is a local bazel run.
    if os.path.islink(__file__):
      # Follow symlink chains (bazel runfiles trees) back to the real file.
      this_file = __file__
      while os.path.islink(this_file):
        this_file = os.readlink(this_file)
      # Three parents up from this file is the repository root
      # (tensorflow/python/tools/<this file> -> repo root).
      base = os.path.realpath(
          os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
    else:
      try:
        base = test.test_src_dir_path('')
      except KeyError:  # Can't find TEST_SRCDIR in environment path.
        base = os.path.realpath(
            os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3)))
    # Sanity-check the guessed root by probing for a known XLA header.
    expected_header = os.path.join(
        base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
    if not os.path.exists(expected_header):
      logging.error(
          'Could not find includes path. Missing file: {}'
          .format(expected_header))
    tensorflow_includes = base

  return _XLA_MAKEFILE_TEMPLATE.format(
      tensorflow_includes=tensorflow_includes,
      compiled_dir=_shlex_quote(output_dir),
      cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format(
          versions.CXX11_ABI_FLAG))


def _get_variable_nodes_from_graph_def(graph_def):
  """Get the list of Variable nodes from `graph_def`.

  Args:
    graph_def: An instance of `GraphDef`. This GraphDef *must*
      have already been optimized by Grappler. In particular, function
      inlining must have already happened.

  Returns:
    A dict mapping string names of variables to tuples `(node_def, modified)`,
    where `node_def` is the `NodeDef` corresponding to variable, and `modified`
    is a python bool describing whether the variable is modified during runtime.
  """
  variables = [n for n in graph_def.node if n.op == 'VarHandleOp']
  variable_name_map = dict((n.name, n) for n in variables)
  # Map from node name to the nodes consuming one of its outputs.  Control
  # inputs (prefixed with '^') don't carry the resource handle, so skip them.
  child_map = collections.defaultdict(lambda: [])
  for n in graph_def.node:
    for inp in n.input:
      if not inp.startswith('^'):
        child_map[inp].append(n)
  variables = {}
  # For each VarHandleOp, walk its consumer graph (through pass-through ops)
  # until a non-read-only consumer is found; that marks the variable modified.
  for (v_name, v_node) in variable_name_map.items():
    queue = list(child_map[v_name])
    processed = set([])
    while queue:
      n_current = queue.pop()
      if n_current.name in processed:
        continue
      processed.add(n_current.name)
      if n_current.op in _PASS_THROUGH_VARIABLE_OPS:
        children = child_map.get(n_current.name, [])
        queue.extend(children)
      elif n_current.op not in _READ_ONLY_VARIABLE_OPS:
        # Found a mutating consumer: record and stop traversing this variable.
        variables[v_name] = (v_node, True)
        queue = []
    if v_name not in variables:
      # Every reachable consumer was read-only (or there were none).
      variables[v_name] = (v_node, False)

  return variables


def _prune_removed_feed_nodes(signature_def, graph_def):
  """Identify the inputs in the signature no longer in graph_def, prune them.

  Args:
    signature_def: A `SignatureDef` instance.
    graph_def: A `GraphDef` instance.

  Returns:
    A new pruned `SignatureDef`.
  """
  node_names = set([n.name for n in graph_def.node])
  new_signature_def = meta_graph_pb2.SignatureDef()
  new_signature_def.CopyFrom(signature_def)
  for (k, v) in signature_def.inputs.items():
    tensor_name, _ = _parse_tensor_name(v.name)
    if tensor_name not in node_names:
      logging.warn(
          'Signature input key \'{}\', tensor name \'{}\', has been pruned '
          'while freezing the graph. Removing it from the compiled signatures.'
          .format(k, tensor_name))
      del new_signature_def.inputs[k]
  return new_signature_def


def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   target_cpu,
                                   variables_to_feed=(),
                                   multithreading=False):
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files. Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  The graph is always optimized with grappler, and optionally (by default)
  variables are frozen as constants, before compilation happens.

  Variables *not* listed in `variables_to_feed` are embedded as constants
  into the graph and binary objects.  Variables that *are* listed (or all
  of them, if `variables_to_feed is None`) become inputs and outputs of
  the compiled class, and the C++ caller must set these values manually.

  Args:
    checkpoint_path: Python string. Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string. Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    target_cpu: String, LLVM target cpu name.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen. If `None`, then we will extract all the
      variables in the graph and mark them as to-feed. The default behavior is
      an empty tuple: all variables must be frozen.
    multithreading: Whether to enable multithreading in the compiled
      computation. Note that if using this option, the resulting object files
      may have external dependencies on multithreading libraries like nsync.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

  else:
    # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
    # so that we can set these directly instead of relying on env vars.
    # NOTE: this mutates the process-wide XLA_FLAGS environment variable.
    xla_flags = os.environ.get('XLA_FLAGS')
    if not xla_flags:
      xla_flags = '--xla_cpu_multi_thread_eigen={}'.format(
          'true' if multithreading else 'false')
    else:
      xla_flags += ',--xla_cpu_multi_thread_eigen={}'.format(
          'true' if multithreading else 'false')
    os.environ['XLA_FLAGS'] = xla_flags

  signature_def_map = meta_graph_def.signature_def
  if signature_def_key not in signature_def_map:
    raise ValueError(
        'Unable to find signature_def key \'{}\' in signature def map. '
        'Available keys: {}'.format(
            signature_def_key,
            list(signature_def_map.keys())))
  signature_def = signature_def_map[signature_def_key]
  if not signature_def.outputs:
    raise ValueError(
        'Signature key {} must have outputs, but saw none:\n{}'.format(
            signature_def_key, str(signature_def)))

  temp_dir = test.get_temp_dir()
  file_io.recursive_create_dir(temp_dir)
  # At INFO verbosity and above, dump intermediate graphs for debugging.
  if logging.get_verbosity() >= logging.INFO:
    original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # This updates graph_def in place.
  _replace_input_placeholders_with_default_values(
      meta_graph_def.graph_def, signature_def)

  graph_def = _optimize_graph(meta_graph_def, signature_def)

  all_variables = _get_variable_nodes_from_graph_def(graph_def)
  if variables_to_feed is None:
    # None means: feed (don't freeze) every variable in the graph.
    variable_nodes_to_feed = list(all_variables.values())
  else:
    not_in_graph = set(variables_to_feed).difference(list(all_variables))
    if not_in_graph:
      raise ValueError(
          'Asked to feed variables that were not found in graph: {}. '
          'Variables contained in the graph: {}'.format(
              not_in_graph, list(all_variables)))
    variable_nodes_to_feed = [
        all_variables[name] for name in variables_to_feed
    ]

  if logging.get_verbosity() >= logging.INFO:
    prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(graph_def.SerializeToString())

  # Load the Variables so that we can freeze the graph.
  with session.Session(graph=ops_lib.Graph()) as sess:
    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
    if restorer is not None:
      restorer.restore(sess, checkpoint_path)
    # Freeze all variables except those listed in `variable_nodes_to_feed`
    # (passed via the blacklist), replacing them with their checkpoint values.
    graph_def.CopyFrom(
        graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            output_node_names=[
                _parse_tensor_name(n.name)[0]
                for n in signature_def.outputs.values()
            ],
            variable_names_blacklist=[
                n.name for n, _ in variable_nodes_to_feed
            ],
        ))

  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

  frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
  config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
    graph_writer.write(graph_def.SerializeToString())
  config = _signature_to_tf2xla_config(
      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
    config_writer.write(str(config))

  output_dir = os.path.dirname(output_prefix)
  file_io.recursive_create_dir(output_dir)

  # Sanitize the prefix + class into a valid, unique C symbol name for the
  # generated entry point.
  entry_point = re.sub(
      '[^0-9a-zA-Z]+', '_',
      '__xla_' + output_prefix + '__' + cpp_class)

  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
    makefile_writer.write(_xla_makefile_string(output_prefix))

  output_prefix = _shlex_quote(output_prefix)

  _pywrap_tfcompile.Compile(
      graph=frozen_graph_def_location,
      config=config_pbtxt_location,
      cpp_class=cpp_class,
      target_triple=target_triple,
      target_cpu=target_cpu,
      entry_point=entry_point,
      out_function_object='{}.o'.format(output_prefix),
      out_header='{}.h'.format(output_prefix),
      out_metadata_object='{}_metadata.o'.format(output_prefix),
      gen_name_to_index=True,
      # ProgramShape isn't uniquefied by entry_point.
      gen_program_shape=False)


def _optimize_graph(meta_graph_def, signature_def):
  """Optimize `meta_graph_def` using grappler. Returns a `GraphDef`."""
  # We need to add a collection called 'train_op' so that grappler
  # knows what the outputs are.
  new_meta_graph_def = copy.deepcopy(meta_graph_def)
  fetch_collection = meta_graph_pb2.CollectionDef()
  # Include inputs as well so grappler does not prune the feed nodes.
  for tensor_info in (
      list(signature_def.inputs.values()) +
      list(signature_def.outputs.values())):
    fetch_collection.node_list.value.append(tensor_info.name)

  new_meta_graph_def.collection_def['train_op'].CopyFrom(fetch_collection)

  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.min_graph_nodes = -1  # do not skip small graphs
  return tf_optimizer.OptimizeGraph(config, new_meta_graph_def)


def _replace_input_placeholders_with_default_values(graph_def, signature_def):
  """Replace graphdef's `tf.placeholder` input ops with all-zero constants.

  Mutates `graph_def` in place.  Requires every signature input to have a
  fully defined shape (tfcompile needs static shapes).
  """
  name_to_node_map = dict((n.name, n) for n in graph_def.node)
  processed_nodes = set([])
  for name, input_ in signature_def.inputs.items():
    tensor_name, _ = _parse_tensor_name(input_.name)
    if tensor_name in processed_nodes:
      continue
    processed_nodes.add(tensor_name)
    if tensor_name not in name_to_node_map:
      raise RuntimeError(
          'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
          'Graph nodes are: {}'.format(tensor_name,
                                       list(name_to_node_map.keys())))
    node = name_to_node_map[tensor_name]
    if node.op not in ('Placeholder', 'PlaceholderV2'):
      logging.info(
          'Tried to convert SavedModel input node \'{}\' from a placeholder, '
          'but it doesn\'t look like a placeholder: {}'.format(tensor_name,
                                                               node))
      continue
    shape = tensor_shape.TensorShape(input_.tensor_shape)
    if not shape.is_fully_defined():
      raise ValueError(
          'Expected fully defined input shape for signature_def \'{}\', '
          'tensor name: \'{}\'; but shape is: {}.'
          .format(name, tensor_name, shape))
    # Build the zero constant in a scratch graph, then splice its NodeDef(s)
    # into `graph_def`, overwriting the placeholder node in place.
    temp_graph = ops_lib.Graph()
    with temp_graph.as_default():
      const = array_ops.zeros(
          shape, dtype=input_.dtype, name=tensor_name)
    node.CopyFrom(const.op.node_def)
    # Sometimes zeros() also creates additional nodes
    for op in temp_graph.get_operations():
      if op.name == const.op.name:  # We just inserted this one.
        continue
      graph_def.node.append(op.node_def)
      name_to_node_map[op.node_def.name] = op.node_def


def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
  """Convert `signature_def` to tf2xla config. Returns a `tf2xla.Config` proto.

  Args:
    signature_def: Instance of `SignatureDef`.
    variable_nodes_to_feed: List of tuples of form `(node_def, modified)`
      corresponding to VarHandleOp, and a boolean `modified` that describes
      whether the variable was modified during execution.

  Returns:
    An instance of `tf2xla.Config` proto.

  Raises:
    RuntimeError: If TensorFlow was not compiled with XLA.
  """
  from tensorflow.compiler.tf2xla import tf2xla_pb2  # pylint: disable=g-import-not-at-top

  config = tf2xla_pb2.Config()
  tensor_id = tf2xla_pb2.TensorId

  # Feed/fetch names must be valid identifiers; '/' in signature keys is the
  # only character expected here that needs replacing.
  # NOTE(review): `output_index` assumes signature tensor names always carry a
  # ':<slot>' suffix; `_parse_tensor_name` returns None otherwise — confirm
  # upstream always normalizes names before this point.
  for name, input_ in signature_def.inputs.items():
    name = name.replace('/', '_')
    name = 'feed_{}'.format(name)
    (node_name, output_index) = _parse_tensor_name(input_.name)
    output_index = int(output_index)
    config.feed.append(
        tf2xla_pb2.Feed(
            id=tensor_id(node_name=node_name, output_index=output_index),
            name=name,
            type=input_.dtype,
            shape=input_.tensor_shape))
  for name, output_ in signature_def.outputs.items():
    name = name.replace('/', '_')
    name = 'fetch_{}'.format(name)
    (node_name, output_index) = _parse_tensor_name(output_.name)
    output_index = int(output_index)
    config.fetch.append(
        tf2xla_pb2.Fetch(
            id=tensor_id(node_name=node_name, output_index=output_index),
            name=name,
            type=output_.dtype,
            shape=output_.tensor_shape))
  # Variables left unfrozen become named parameters of the compiled class;
  # `readonly` is the inverse of the `modified` analysis result.
  for (node, modified) in variable_nodes_to_feed:
    name = node.name.replace('/', '_')
    name = 'param_{}'.format(name)
    config.variable.append(
        tf2xla_pb2.Variable(
            node_name=node.name,
            name=name,
            type=node.attr['dtype'].type,
            shape=node.attr['shape'].shape,
            readonly=not modified))

  return config