# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper utilities for AOT compilation."""

import collections
import copy
import os
import re
import shlex
from typing import List, Tuple

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import sysconfig as sysconfig_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_lib

# The tfcompile wrapper is only present when TensorFlow was built with XLA
# support; defer the failure until someone actually tries to compile.
try:
  from tensorflow.python import _pywrap_tfcompile  # pylint: disable=g-import-not-at-top
except ImportError as e:
  _pywrap_tfcompile_import_error = ImportError(
      'Unable to import _pywrap_tfcompile; you must build TensorFlow '
      'with XLA. You may need to build tensorflow with flag '
      '--define=with_xla_support=true. Original error: {}'.format(str(e)))
else:
  _pywrap_tfcompile_import_error = None


# Ops that read a resource variable but never modify it.
_READ_ONLY_VARIABLE_OPS = (
    'ReadVariableOp',
    'IsVariableInitializedOp',
    'ResourceGather',
    'ResourceGatherNd',
    'VariableShape',
)

# Ops that merely forward a resource handle; traversal must continue through
# them to find the real consumers.
_PASS_THROUGH_VARIABLE_OPS = ('Identity', 'IdentityN')


def _shlex_quote(s):
  """Shell-quote `s` so it is safe to embed in Makefiles / command lines."""
  return shlex.quote(s)


def _sysconfig_module():
  """Load tf.sysconfig if available and working (i.e., inside a pip package)."""
  try:
    _ = sysconfig_lib.get_include()
  except (ImportError, ValueError):
    # ValueError may come from saved_model_cli_test trying to enable
    # eager mode twice.
    return None
  return sysconfig_lib


def _parse_tensor_name(name):
  """Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0).

  If `name` carries no output slot (no ':' or a trailing ':'), the slot is
  returned as `None`.
  """
  if ':' in name and not name.endswith(':'):
    node_name = name[:name.rfind(':')]
    output_slot = int(name[name.rfind(':') + 1:])
    return node_name, output_slot
  else:
    return name, None


_XLA_MAKEFILE_TEMPLATE = """
INC = -I{tensorflow_includes}
LIB = -L{compiled_dir}
CXXFLAGS = {cxx_flags}
"""


def _xla_makefile_string(output_prefix):
  """Returns a Makefile string with variables for using XLA binary object files.

  Attempts to identify the right include header paths when run from either
  an installed TensorFlow pip package, or from bazel run.

  Args:
    output_prefix: A string containing the output prefix for the XLA AOT
      compiled header + object files.

  Returns:
    A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`.
  """
  sysconfig = _sysconfig_module()
  output_dir, _ = os.path.split(output_prefix)
  if sysconfig:
    tensorflow_includes = _shlex_quote(sysconfig.get_include())
  else:
    # Try hard to find the real source directory if this is a local bazel run.
    if os.path.islink(__file__):
      # Follow symlink chains (bazel runfiles) back to the real file.
      this_file = __file__
      while os.path.islink(this_file):
        this_file = os.readlink(this_file)
      base = os.path.realpath(
          os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
    else:
      try:
        base = test.test_src_dir_path('')
      except KeyError:  # Can't find TEST_SRCDIR in environment path.
        base = os.path.realpath(
            os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3)))
    expected_header = os.path.join(
        base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
    if not os.path.exists(expected_header):
      logging.error(
          'Could not find includes path. Missing file: {}'
          .format(expected_header))
    tensorflow_includes = base

  return _XLA_MAKEFILE_TEMPLATE.format(
      tensorflow_includes=tensorflow_includes,
      compiled_dir=_shlex_quote(output_dir),
      cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format(
          versions.CXX11_ABI_FLAG))


def _get_variable_nodes_from_graph_def(graph_def):
  """Get the list of Variable nodes from `graph_def`.

  Args:
    graph_def: An instance of `GraphDef`. This GraphDef *must*
      have already been optimized by Grappler. In particular, function
      inlining must have already happened.

  Returns:
    A dict mapping string names of variables to tuples `(node_def, modified)`,
    where `node_def` is the `NodeDef` corresponding to variable, and `modified`
    is a python bool describing whether the variable is modified during runtime.
  """
  variable_name_map = {n.name: n for n in graph_def.node
                       if n.op == 'VarHandleOp'}
  # Map each node name to the nodes that consume its (data) outputs; control
  # inputs (prefixed with '^') are not real data edges and are skipped.
  child_map = collections.defaultdict(list)
  for n in graph_def.node:
    for inp in n.input:
      if not inp.startswith('^'):
        child_map[inp].append(n)
  variables = {}
  for (v_name, v_node) in variable_name_map.items():
    # BFS from each variable handle through pass-through ops; if any consumer
    # is not a known read-only op, the variable is considered modified.
    queue = list(child_map[v_name])
    processed = set()
    while queue:
      n_current = queue.pop()
      if n_current.name in processed:
        continue
      processed.add(n_current.name)
      if n_current.op in _PASS_THROUGH_VARIABLE_OPS:
        queue.extend(child_map.get(n_current.name, []))
      elif n_current.op not in _READ_ONLY_VARIABLE_OPS:
        variables[v_name] = (v_node, True)
        break  # One writer is enough; stop traversing.
    if v_name not in variables:
      variables[v_name] = (v_node, False)

  return variables


def _prune_removed_feed_nodes(signature_def, graph_def):
  """Identify the inputs in the signature no longer in graph_def, prune them.

  Args:
    signature_def: A `SignatureDef` instance.
    graph_def: A `GraphDef` instance.

  Returns:
    A new pruned `SignatureDef`.
  """
  node_names = {n.name for n in graph_def.node}
  new_signature_def = meta_graph_pb2.SignatureDef()
  new_signature_def.CopyFrom(signature_def)
  for (k, v) in signature_def.inputs.items():
    tensor_name, _ = _parse_tensor_name(v.name)
    if tensor_name not in node_names:
      logging.warn(
          'Signature input key \'{}\', tensor name \'{}\', has been pruned '
          'while freezing the graph. Removing it from the compiled signatures.'
          .format(k, tensor_name))
      del new_signature_def.inputs[k]
  return new_signature_def


def freeze_model(checkpoint_path: str,
                 meta_graph_def: meta_graph_pb2.MetaGraphDef,
                 output_prefix: str, signature_def_key: str,
                 variables_to_feed: List[str]) -> Tuple[str, str]:
  """Freeze a `MetaGraphDef` in preparation for tfcompile`.

  The graph is always optimized with grappler, and optionally (by default)
  variables are frozen as constants, before compilation happens.

  Args:
    checkpoint_path: Python string. Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string. Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen. If `None`, then we will extract all the
      variables in the graph and mark them as to-feed. The default behavior is
      an empty tuple: all variables must be frozen.
  Returns:
    a pair containing the path to the frozen model and the path to the config.
  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

  signature_def_map = meta_graph_def.signature_def
  if signature_def_key not in signature_def_map:
    raise ValueError(
        f"Unable to find signature_def_key '{signature_def_key}' in signature "
        'def map of `meta_graph_def`. Available keys: '
        f'{list(signature_def_map.keys())}')
  signature_def = signature_def_map[signature_def_key]
  if not signature_def.outputs:
    raise ValueError(
        f'Signature key {signature_def_key} must have outputs, but saw none:\n'
        f'{str(signature_def)}')

  file_io.recursive_create_dir(output_prefix)
  if logging.get_verbosity() >= logging.INFO:
    original_graph_def_location = os.path.join(output_prefix,
                                               'original_graph.pb')
    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # This updates graph_def in place.
  _replace_input_placeholders_with_default_values(
      meta_graph_def.graph_def, signature_def)

  graph_def = _optimize_graph(meta_graph_def, signature_def)

  all_variables = _get_variable_nodes_from_graph_def(graph_def)
  if variables_to_feed is None:
    variable_nodes_to_feed = list(all_variables.values())
  else:
    not_in_graph = set(variables_to_feed).difference(list(all_variables))
    if not_in_graph:
      raise ValueError('Asked to feed variables that were not found in graph: '
                       f'{not_in_graph}. Variables contained in the graph: '
                       f'{list(all_variables)}')
    variable_nodes_to_feed = [
        all_variables[name] for name in variables_to_feed
    ]

  if logging.get_verbosity() >= logging.INFO:
    prefrozen_graph_def_location = os.path.join(output_prefix,
                                                'prefrozen_graph.pb')
    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(graph_def.SerializeToString())

  # Load the Variables so that we can freeze the graph.
  with session.Session(graph=ops_lib.Graph()) as sess:
    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
    if restorer is not None:
      restorer.restore(sess, checkpoint_path)
    graph_def.CopyFrom(
        graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            output_node_names=[
                _parse_tensor_name(n.name)[0]
                for n in signature_def.outputs.values()
            ],
            # Variables the user will feed must survive freezing.
            variable_names_blacklist=[
                n.name for n, _ in variable_nodes_to_feed
            ],
        ))

  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

  frozen_graph_def_location = os.path.join(output_prefix, 'frozen_graph.pb')
  config_pbtxt_location = os.path.join(output_prefix, 'config.pbtxt')
  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
    graph_writer.write(graph_def.SerializeToString())
  config = _signature_to_tf2xla_config(
      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
    config_writer.write(str(config))
  return frozen_graph_def_location, config_pbtxt_location


def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   target_cpu,
                                   variables_to_feed=(),
                                   multithreading=False):
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files. Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  Freezing a graph entails restoring the checkpoint and replacing any inputs
  and variables with constants. If values are fed, those are used, else inputs
  are replaced with default all-zero constants. Finally, the graph is pruned
  and then optimized with grappler.

  Variables named in `variables_to_feed` (or all variables, if it is `None`)
  are kept as inputs/outputs of the compiled class and the C++ caller must set
  their values manually; all other variables are embedded as constants into
  the graph and binary objects.

  Args:
    checkpoint_path: Python string. Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string. Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    target_cpu: String, LLVM target cpu name.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen. If `None`, then we will extract all the
      variables in the graph and mark them as to-feed. The default behavior is
      an empty tuple: all variables must be frozen.
    multithreading: Whether to enable multithreading in the compiled
      computation. Note that if using this option, the resulting object files
      may have external dependencies on multithreading libraries like nsync.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

  # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
  # so that we can set these directly instead of relying on env vars.
  multithreading_flag = '--xla_cpu_multi_thread_eigen={}'.format(
      'true' if multithreading else 'false')
  xla_flags = os.environ.get('XLA_FLAGS')
  if not xla_flags:
    xla_flags = multithreading_flag
  else:
    xla_flags += ' ' + multithreading_flag
  os.environ['XLA_FLAGS'] = xla_flags

  # Freeze into a temp dir; only the final AOT artifacts go to output_prefix.
  temp_dir = test.get_temp_dir()
  file_io.recursive_create_dir(temp_dir)
  frozen_graph_def_location, config_pbtxt_location = freeze_model(
      checkpoint_path=checkpoint_path,
      meta_graph_def=meta_graph_def,
      output_prefix=temp_dir,
      signature_def_key=signature_def_key,
      variables_to_feed=variables_to_feed)
  output_dir = os.path.dirname(output_prefix)
  file_io.recursive_create_dir(output_dir)

  # Derive a unique C symbol name from the output path and class name.
  entry_point = re.sub(
      '[^0-9a-zA-Z]+', '_',
      '__xla_' + output_prefix + '__' + cpp_class)

  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
    makefile_writer.write(_xla_makefile_string(output_prefix))

  output_prefix = _shlex_quote(output_prefix)

  _pywrap_tfcompile.Compile(
      graph=frozen_graph_def_location,
      config=config_pbtxt_location,
      cpp_class=cpp_class,
      target_triple=target_triple,
      target_cpu=target_cpu,
      entry_point=entry_point,
      out_function_object='{}.o'.format(output_prefix),
      out_header='{}.h'.format(output_prefix),
      out_metadata_object='{}_metadata.o'.format(output_prefix),
      gen_name_to_index=True,
      # ProgramShape isn't uniquefied by entry_point.
      gen_program_shape=False)


def _optimize_graph(meta_graph_def, signature_def):
  """Optimize `meta_graph_def` using grappler. Returns a `GraphDef`."""
  # We need to add a collection called 'train_op' so that grappler
  # knows what the outputs are.
  new_meta_graph_def = copy.deepcopy(meta_graph_def)
  fetch_collection = meta_graph_pb2.CollectionDef()
  for tensor_info in (
      list(signature_def.inputs.values()) +
      list(signature_def.outputs.values())):
    fetch_collection.node_list.value.append(tensor_info.name)

  new_meta_graph_def.collection_def['train_op'].CopyFrom(fetch_collection)

  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.min_graph_nodes = -1  # do not skip small graphs
  return tf_optimizer.OptimizeGraph(config, new_meta_graph_def)


def _replace_input_placeholders_with_default_values(graph_def, signature_def):
  """Replace graphdef's `tf.placeholder` input ops with all-zero constants."""
  name_to_node_map = {n.name: n for n in graph_def.node}
  processed_nodes = set()
  for name, input_ in signature_def.inputs.items():
    tensor_name, _ = _parse_tensor_name(input_.name)
    if tensor_name in processed_nodes:
      continue
    processed_nodes.add(tensor_name)
    if tensor_name not in name_to_node_map:
      raise RuntimeError(
          f"Unable to find input signature tensor '{tensor_name}' in optimized "
          f'GraphDef. Graph nodes are: {list(name_to_node_map.keys())}')
    node = name_to_node_map[tensor_name]
    if node.op not in ('Placeholder', 'PlaceholderV2'):
      logging.info(
          'Tried to convert SavedModel input node \'{}\' from a placeholder, '
          'but it doesn\'t look like a placeholder: {}'.format(tensor_name,
                                                               node))
      continue
    shape = tensor_shape.TensorShape(input_.tensor_shape)
    if not shape.is_fully_defined():
      raise ValueError(
          f"Expected fully defined input shape for signature_def '{name}', "
          f"tensor name: '{tensor_name}'; but shape is: {shape}.")
    # Build the zero constant in a scratch graph, then splice its NodeDef(s)
    # into graph_def in place of the placeholder.
    temp_graph = ops_lib.Graph()
    with temp_graph.as_default():
      const = array_ops.zeros(
          shape, dtype=input_.dtype, name=tensor_name)
    node.CopyFrom(const.op.node_def)
    # Sometimes zeros() also creates additional nodes
    for op in temp_graph.get_operations():
      if op.name == const.op.name:  # We just inserted this one.
        continue
      graph_def.node.append(op.node_def)
      name_to_node_map[op.node_def.name] = op.node_def


def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
  """Convert `signature_def` to tf2xla config. Returns a `tf2xla.Config` proto.

  Args:
    signature_def: Instance of `SignatureDef`.
    variable_nodes_to_feed: List of tuples of form `(node_def, modified)`
      corresponding to VarHandleOp, and a boolean `modified` that describes
      whether the variable was modified during execution.

  Returns:
    An instance of `tf2xla.Config` proto.

  Raises:
    RuntimeError: If TensorFlow was not compiled with XLA.
  """
  from tensorflow.compiler.tf2xla import tf2xla_pb2  # pylint: disable=g-import-not-at-top

  config = tf2xla_pb2.Config()
  tensor_id = tf2xla_pb2.TensorId

  for name, input_ in signature_def.inputs.items():
    name = name.replace('/', '_')
    name = 'feed_{}'.format(name)
    (node_name, output_index) = _parse_tensor_name(input_.name)
    output_index = int(output_index)
    config.feed.append(
        tf2xla_pb2.Feed(
            id=tensor_id(node_name=node_name, output_index=output_index),
            name=name,
            type=input_.dtype,
            shape=input_.tensor_shape))
  for name, output_ in signature_def.outputs.items():
    name = name.replace('/', '_')
    name = 'fetch_{}'.format(name)
    (node_name, output_index) = _parse_tensor_name(output_.name)
    output_index = int(output_index)
    config.fetch.append(
        tf2xla_pb2.Fetch(
            id=tensor_id(node_name=node_name, output_index=output_index),
            name=name,
            type=output_.dtype,
            shape=output_.tensor_shape))
  for (node, modified) in variable_nodes_to_feed:
    name = node.name.replace('/', '_')
    name = 'param_{}'.format(name)
    config.variable.append(
        tf2xla_pb2.Variable(
            node_name=node.name,
            name=name,
            type=node.attr['dtype'].type,
            shape=node.attr['shape'].shape,
            readonly=not modified))

  return config