1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Exposes the Python wrapper conversion to trt_graph.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22from functools import partial # pylint: disable=g-importing-member 23import os 24import platform 25import tempfile 26 27import six as _six 28 29from tensorflow.core.protobuf import config_pb2 30from tensorflow.core.protobuf import meta_graph_pb2 31from tensorflow.core.protobuf import rewriter_config_pb2 32from tensorflow.python.client import session 33from tensorflow.python.compiler.tensorrt import utils as trt_utils 34from tensorflow.python.eager import context 35from tensorflow.python.eager import wrap_function 36from tensorflow.python.framework import convert_to_constants 37from tensorflow.python.framework import dtypes 38from tensorflow.python.framework import errors 39from tensorflow.python.framework import graph_util 40from tensorflow.python.framework import importer 41from tensorflow.python.framework import ops 42from tensorflow.python.grappler import tf_optimizer 43from tensorflow.python.ops import array_ops 44from tensorflow.python.ops import gen_resource_variable_ops 45from tensorflow.python.platform import tf_logging 46from tensorflow.python.saved_model import builder 47from tensorflow.python.saved_model import load 48from tensorflow.python.saved_model import loader 49from tensorflow.python.saved_model import save 50from tensorflow.python.saved_model import signature_constants 51from tensorflow.python.saved_model import tag_constants 52from tensorflow.python.training import saver 53from tensorflow.python.training.tracking import tracking 54from tensorflow.python.util import deprecation 55from tensorflow.python.util import nest 56from tensorflow.python.util.lazy_loader import LazyLoader 57from tensorflow.python.util.tf_export import tf_export 58 59if platform.system() == "Windows": 60 raise RuntimeError("Windows platform is not supported") 61 62# Lazily load the op, since it's not available in cpu-only builds. Importing 63# this at top will cause tests that imports TF-TRT fail when they're built 64# and run without CUDA/GPU. 65gen_trt_ops = LazyLoader( 66 "gen_trt_ops", globals(), 67 "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops") 68 69_pywrap_py_utils = LazyLoader( 70 "_pywrap_py_utils", globals(), 71 "tensorflow.compiler.tf2tensorrt._pywrap_py_utils") 72 73# Register TRT ops in python, so that when users import this module they can 74# execute a TRT-converted graph without calling any of the methods in this 75# module. 76# 77# This will call register_op_list() in 78# tensorflow/python/framework/op_def_registry.py, but it doesn't register 79# the op or the op kernel in C++ runtime. 80try: 81 gen_trt_ops.trt_engine_op # pylint: disable=pointless-statement 82except AttributeError: 83 pass 84 85 86def _to_bytes(s): 87 """Encode s if it is a sequence of chars.""" 88 if isinstance(s, _six.text_type): 89 return s.encode("utf-8", errors="surrogateescape") 90 return s 91 92 93def _to_string(s): 94 """Decode s if it is a sequence of bytes.""" 95 if isinstance(s, _six.binary_type): 96 return s.decode("utf-8") 97 return s 98 99 100class TrtPrecisionMode(object): 101 FP32 = "FP32" 102 FP16 = "FP16" 103 INT8 = "INT8" 104 105 @staticmethod 106 def supported_precision_modes(): 107 precisions = [ 108 TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8 109 ] 110 return precisions + [p.lower() for p in precisions] 111 112 113# Use a large enough number as the default max_workspace_size for TRT engines, 114# so it can produce reasonable performance results with the default. 115DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30 116 117 118@tf_export("experimental.tensorrt.ConversionParams", v1=[]) 119class TrtConversionParams( 120 collections.namedtuple("TrtConversionParams", [ 121 "max_workspace_size_bytes", "precision_mode", "minimum_segment_size", 122 "maximum_cached_engines", "use_calibration", "allow_build_at_runtime" 123 ])): 124 """Parameters that are used for TF-TRT conversion. 125 126 Fields: 127 max_workspace_size_bytes: the maximum GPU temporary memory which the TRT 128 engine can use at execution time. This corresponds to the 129 'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). 130 precision_mode: one the strings in 131 TrtPrecisionMode.supported_precision_modes(). 132 minimum_segment_size: the minimum number of nodes required for a subgraph 133 to be replaced by TRTEngineOp. 134 maximum_cached_engines: max number of cached TRT engines for dynamic TRT 135 ops. Created TRT engines for a dynamic dimension are cached. This is the 136 maximum number of engines that can be cached. If the number of cached 137 engines is already at max but none of them supports the input shapes, 138 the TRTEngineOp will fall back to run the original TF subgraph that 139 corresponds to the TRTEngineOp. 140 use_calibration: this argument is ignored if precision_mode is not INT8. 141 If set to True, a calibration graph will be created to calibrate the 142 missing ranges. The calibration graph must be converted to an inference 143 graph by running calibration with calibrate(). If set to False, 144 quantization nodes will be expected for every tensor in the graph 145 (excluding those which will be fused). If a range is missing, an error 146 will occur. Please note that accuracy may be negatively affected if 147 there is a mismatch between which tensors TRT quantizes and which 148 tensors were trained with fake quantization. 149 allow_build_at_runtime: whether to build TensorRT engines during runtime. 150 If no TensorRT engine can be found in cache that can handle the given 151 inputs during runtime, then a new TensorRT engine is built at runtime if 152 allow_build_at_runtime=True, and otherwise native TF is used. 153 """ 154 155 def __new__(cls, 156 max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, 157 precision_mode=TrtPrecisionMode.FP32, 158 minimum_segment_size=3, 159 maximum_cached_engines=1, 160 use_calibration=True, 161 allow_build_at_runtime=True): 162 return super(TrtConversionParams, 163 cls).__new__(cls, max_workspace_size_bytes, precision_mode, 164 minimum_segment_size, maximum_cached_engines, 165 use_calibration, allow_build_at_runtime) 166 167 168DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams() 169 170_TRT_ENGINE_OP_NAME = "TRTEngineOp" 171 172 173def _check_conversion_params(conversion_params, is_v2=False): 174 """Validate the provided TrtConversionParams. 175 176 Args: 177 conversion_params: a TrtConversionParams instance. 178 is_v2: whether we're getting a RewriterConfig for TF 2.0. 179 180 Raises: 181 TypeError: if any of the parameters are of unexpected type. 182 ValueError: if any of the parameters are of unexpected value. 183 """ 184 supported_precision_modes = TrtPrecisionMode.supported_precision_modes() 185 if conversion_params.precision_mode not in supported_precision_modes: 186 raise ValueError( 187 ("precision mode '{}' is not supported." 188 "It should be one of {}").format(conversion_params.precision_mode, 189 supported_precision_modes)) 190 191 192def _check_trt_version_compatibility(): 193 """Check compatibility of TensorRT version. 194 195 Raises: 196 RuntimeError: if the TensorRT library version is incompatible. 197 """ 198 linked_version = _pywrap_py_utils.get_linked_tensorrt_version() 199 loaded_version = _pywrap_py_utils.get_loaded_tensorrt_version() 200 assert isinstance(linked_version, tuple) 201 assert isinstance(loaded_version, tuple) 202 assert len(linked_version) == 3 203 assert len(loaded_version) == 3 204 tf_logging.info("Linked TensorRT version: %s" % str(linked_version)) 205 tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version)) 206 if loaded_version < linked_version: 207 tf_logging.error( 208 "Loaded TensorRT %s but linked TensorFlow against TensorRT %s. " % 209 (".".join(str(x) for x in loaded_version), ".".join( 210 str(x) for x in linked_version)) + 211 "TensorRT does not support forward compatibility. " + 212 "It is also required to use the same major version of TensorRT " + 213 "during compilation and runtime.") 214 raise RuntimeError("Incompatible TensorRT versions") 215 if loaded_version[0] > linked_version[0]: 216 tf_logging.error( 217 "Loaded TensorRT %s but linked TensorFlow against TensorRT %s. " % 218 (".".join(str(x) for x in loaded_version), ".".join( 219 str(x) for x in linked_version)) + 220 "It is required to use the same major version " + 221 "of TensorRT during compilation and runtime.") 222 raise RuntimeError("Incompatible TensorRT major version") 223 if loaded_version != linked_version: 224 tf_logging.info( 225 "Loaded TensorRT %s and linked TensorFlow against TensorRT %s. " % 226 (".".join(str(x) for x in loaded_version), ".".join( 227 str(x) for x in linked_version)) + 228 "This is supported because TensorRT " + 229 " minor/patch upgrades are backward compatible") 230 231 232def _get_tensorrt_rewriter_config(conversion_params, 233 is_dynamic_op=None, 234 max_batch_size=None, 235 is_v2=False, 236 disable_non_trt_optimizers=False, 237 use_implicit_batch=True): 238 """Returns a RewriterConfig proto for TRT transformation. 239 240 Args: 241 conversion_params: a TrtConversionParams instance. 242 is_dynamic_op: whether to use dynamic engines. 243 max_batch_size: maximum batch size for static engines. 244 is_v2: whether we're getting a RewriterConfig for TF 2.0. 245 disable_non_trt_optimizers: Turn off all default Grappler optimizers. 246 use_implicit_batch: Whether to use implicit batch or explicit batch. 247 248 Returns: 249 A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler. 250 251 Raises: 252 TypeError: if any of the parameters are of unexpected type. 253 ValueError: if any of the parameters are of unexpected value. 254 """ 255 _check_conversion_params(conversion_params, is_v2=is_v2) 256 if is_v2 and is_dynamic_op is not None and not is_dynamic_op: 257 raise ValueError("is_dynamic_op is either None or True for TF2") 258 if not is_v2 and is_dynamic_op is None: 259 raise ValueError("is_dynamic_op can't be None for TF1") 260 261 if (is_dynamic_op is None or is_dynamic_op) and max_batch_size is not None: 262 raise ValueError("max_batch_size has to be None for TF2" 263 " or when is_dynamic_op == True in TF1") 264 if is_dynamic_op is not None and not is_dynamic_op and not isinstance( 265 max_batch_size, int): 266 raise ValueError( 267 "max_batch_size has to be an integer for is_dynamic_op==False in TF1") 268 rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig() 269 # Disable Grappler Remapper to avoid that fused OPs that may not be 270 # beneficial to TF-TRT and are not supported by TF-TRT. 271 rewriter_config_with_trt.remapping = False 272 273 if not disable_non_trt_optimizers: 274 # Layout optimizer may add Const nodes followed by Reshape nodes, thus we 275 # need to run constant folding again. 276 rewriter_config_with_trt.optimizers.extend( 277 ["constfold", "layout", "constfold"]) 278 279 rewriter_config_with_trt.meta_optimizer_iterations = ( 280 rewriter_config_pb2.RewriterConfig.ONE) 281 optimizer = rewriter_config_with_trt.custom_optimizers.add() 282 283 if not disable_non_trt_optimizers: 284 # Add a constfold optimizer to cleanup the unused Const nodes. 285 rewriter_config_with_trt.custom_optimizers.add().name = "constfold" 286 287 optimizer.name = "TensorRTOptimizer" 288 optimizer.parameter_map[ 289 "minimum_segment_size"].i = conversion_params.minimum_segment_size 290 optimizer.parameter_map["max_workspace_size_bytes"].i = ( 291 conversion_params.max_workspace_size_bytes) 292 optimizer.parameter_map["precision_mode"].s = _to_bytes( 293 conversion_params.precision_mode) 294 optimizer.parameter_map[ 295 "maximum_cached_engines"].i = conversion_params.maximum_cached_engines 296 optimizer.parameter_map[ 297 "use_calibration"].b = conversion_params.use_calibration 298 optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op 299 optimizer.parameter_map[ 300 "allow_build_at_runtime"].b = conversion_params.allow_build_at_runtime 301 if max_batch_size is not None: 302 optimizer.parameter_map["max_batch_size"].i = max_batch_size 303 optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch 304 305 # Disabling optimizers should happen after defining the TF-TRT grappler pass 306 # otherwise the template can overwrite the disablement. 307 if disable_non_trt_optimizers: 308 trt_utils.disable_non_trt_optimizers_in_rewriter_config( 309 rewriter_config_with_trt) 310 311 return rewriter_config_with_trt 312 313 314@deprecation.deprecated( 315 None, "You shouldn't need a rewriter_config with the current TF-TRT APIs.") 316def get_tensorrt_rewriter_config(conversion_params, 317 is_dynamic_op=None, 318 max_batch_size=None, 319 is_v2=False, 320 disable_non_trt_optimizers=False): 321 return _get_tensorrt_rewriter_config(conversion_params, is_dynamic_op, 322 max_batch_size, is_v2, 323 disable_non_trt_optimizers) 324 325 326# Remove all scope prefixes in the node name. In TF 2.0, the same concrete 327# function can be initialized multiple times with different prefixes, and 328# this will result in the same TRTEngineOp being initialized multiple times 329# with different cache and duplicate TRT engines. 330# TODO(laigd): this may be caused by the fact that TRTEngineOp is not 331# stateful, need to investigate. 332# TODO(laigd): we rely on the fact that all functions are fully inlined 333# before TF-TRT optimizer is called, as otherwise it may generate the same 334# name when optimizing a different function graph. Fix this. 335def _get_canonical_engine_name(name): 336 return name.split("/")[-1] 337 338 339class TrtGraphConverter(object): 340 """A converter for TF-TRT transformation for TF 1.x GraphDef/SavedModels. 341 342 To run the conversion without quantization calibration (e.g. for FP32/FP16 343 precision modes): 344 345 ```python 346 converter = TrtGraphConverter( 347 input_saved_model_dir="my_dir", 348 precision_mode=TrtPrecisionMode.FP16) 349 converted_graph_def = converter.convert() 350 converter.save(output_saved_model_dir) 351 ``` 352 353 To run the conversion with quantization calibration: 354 355 ```python 356 converter = TrtGraphConverter( 357 input_saved_model_dir="my_dir", 358 precision_mode=TrtPrecisionMode.INT8) 359 converter.convert() 360 361 # Run calibration 10 times. 362 converted_graph_def = converter.calibrate( 363 fetch_names=['output:0'], 364 num_runs=10, 365 feed_dict_fn=lambda: {'input:0': my_next_data()}) 366 367 converter.save(output_saved_model_dir) 368 ``` 369 """ 370 371 def __init__(self, 372 input_saved_model_dir=None, 373 input_saved_model_tags=None, 374 input_saved_model_signature_key=None, 375 input_graph_def=None, 376 nodes_denylist=None, 377 max_batch_size=1, 378 max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, 379 precision_mode=TrtPrecisionMode.FP32, 380 minimum_segment_size=3, 381 is_dynamic_op=False, 382 maximum_cached_engines=1, 383 use_calibration=True): 384 """Initializes the converter. 385 386 Args: 387 input_saved_model_dir: the directory to load the SavedModel which contains 388 the input graph to transforms. Used only when input_graph_def is None. 389 input_saved_model_tags: list of tags to load the SavedModel. 390 input_saved_model_signature_key: the key of the signature to optimize the 391 graph for. 392 input_graph_def: a GraphDef object containing a model to be transformed. 393 If set to None, the graph will be read from the SavedModel loaded from 394 input_saved_model_dir. 395 nodes_denylist: list of node names to prevent the converter from touching. 396 max_batch_size: max size for the input batch. 397 max_workspace_size_bytes: the maximum GPU temporary memory which the TRT 398 engine can use at execution time. This corresponds to the 399 'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). 400 precision_mode: one of TrtPrecisionMode.supported_precision_modes(). 401 minimum_segment_size: the minimum number of nodes required for a subgraph 402 to be replaced by TRTEngineOp. 403 is_dynamic_op: whether to generate dynamic TRT ops which will build the 404 TRT network and engine at run time. 405 maximum_cached_engines: max number of cached TRT engines in dynamic TRT 406 ops. If the number of cached engines is already at max but none of them 407 can serve the input, the TRTEngineOp will fall back to run the TF 408 function based on which the TRTEngineOp is created. 409 use_calibration: this argument is ignored if precision_mode is not INT8. 410 If set to True, a calibration graph will be created to calibrate the 411 missing ranges. The calibration graph must be converted to an inference 412 graph by running calibration with calibrate(). If set to False, 413 quantization nodes will be expected for every tensor in the graph 414 (excluding those which will be fused). If a range is missing, an error 415 will occur. Please note that accuracy may be negatively affected if 416 there is a mismatch between which tensors TRT quantizes and which 417 tensors were trained with fake quantization. 418 419 Raises: 420 ValueError: if the combination of the parameters is invalid. 421 RuntimeError: if this class is used in TF 2.0. 422 """ 423 if context.executing_eagerly(): 424 raise RuntimeError( 425 "Please use tf.experimental.tensorrt.Converter in TF 2.0.") 426 427 if input_graph_def and input_saved_model_dir: 428 raise ValueError( 429 "Can only specify one of input_graph_def and input_saved_model_dir") 430 if not input_graph_def and not input_saved_model_dir: 431 raise ValueError("Must specify one of input_graph_def and " 432 "input_saved_model_dir") 433 _check_trt_version_compatibility() 434 435 self._input_graph_def = input_graph_def 436 self._nodes_denylist = nodes_denylist 437 438 self._input_saved_model_dir = input_saved_model_dir 439 self._converted = False 440 self._grappler_meta_graph_def = None 441 442 self._input_saved_model_tags = ( 443 input_saved_model_tags or [tag_constants.SERVING]) 444 self._input_saved_model_signature_key = ( 445 input_saved_model_signature_key or 446 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) 447 448 # For calibration usage. 449 self._calibration_graph = None 450 self._calibration_data_collected = False 451 self._need_calibration = ( 452 precision_mode == TrtPrecisionMode.INT8 and use_calibration) 453 if self._need_calibration and not is_dynamic_op: 454 tf_logging.warn( 455 "INT8 precision mode with calibration is supported with " 456 "dynamic TRT ops only. Disregarding is_dynamic_op parameter.") 457 is_dynamic_op = True 458 459 self._is_dynamic_op = is_dynamic_op 460 if is_dynamic_op: 461 self._max_batch_size = None 462 if max_batch_size is not None: 463 tf_logging.warn("When is_dynamic_op==True max_batch_size should be " 464 "None") 465 else: 466 if not isinstance(max_batch_size, int): 467 raise ValueError("When is_dynamic_op==False max_batch_size should be " 468 "an integer") 469 self._max_batch_size = max_batch_size 470 471 self._conversion_params = TrtConversionParams( 472 max_workspace_size_bytes=max_workspace_size_bytes, 473 precision_mode=precision_mode, 474 minimum_segment_size=minimum_segment_size, 475 maximum_cached_engines=maximum_cached_engines, 476 use_calibration=use_calibration, 477 allow_build_at_runtime=True) 478 _check_conversion_params(self._conversion_params) 479 480 self._test_only_disable_non_trt_optimizers = False 481 482 def _run_conversion(self): 483 """Run Grappler's OptimizeGraph() tool to convert the graph.""" 484 # Create custom ConfigProto for Grappler. 485 grappler_session_config = config_pb2.ConfigProto() 486 custom_rewriter_config = _get_tensorrt_rewriter_config( 487 conversion_params=self._conversion_params, 488 is_dynamic_op=self._is_dynamic_op, 489 max_batch_size=self._max_batch_size, 490 disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers, 491 use_implicit_batch=True) 492 grappler_session_config.graph_options.rewrite_options.CopyFrom( 493 custom_rewriter_config) 494 495 # Run Grappler. 496 self._converted_graph_def = tf_optimizer.OptimizeGraph( 497 grappler_session_config, 498 self._grappler_meta_graph_def, 499 graph_id=b"tf_graph") 500 self._converted = True 501 502 def _add_nodes_denylist(self): 503 if self._nodes_denylist: 504 collection_def = self._grappler_meta_graph_def.collection_def["train_op"] 505 denylist = collection_def.node_list.value 506 for i in self._nodes_denylist: 507 if isinstance(i, ops.Tensor): 508 denylist.append(_to_bytes(i.name)) 509 else: 510 denylist.append(_to_bytes(i)) 511 512 def _convert_graph_def(self): 513 """Convert the input GraphDef.""" 514 graph = ops.Graph() 515 with graph.as_default(): 516 importer.import_graph_def(self._input_graph_def, name="") 517 self._grappler_meta_graph_def = saver.export_meta_graph( 518 graph_def=graph.as_graph_def(add_shapes=True), graph=graph) 519 self._add_nodes_denylist() 520 521 self._run_conversion() 522 523 def _collections_to_keep(self, collection_keys): 524 # TODO(laigd): currently we use the collection key to filter out 525 # collections that depend on variable ops, but this may miss some 526 # other user-defined collections. A better way would be to use 527 # CollectionDef::NodeList for the filtering. 528 collections_to_remove = ( 529 ops.GraphKeys._VARIABLE_COLLECTIONS + [ 530 ops.GraphKeys.TRAIN_OP, ops.GraphKeys.WHILE_CONTEXT, 531 ops.GraphKeys.COND_CONTEXT 532 ]) 533 return [key for key in collection_keys if key not in collections_to_remove] 534 535 def _convert_saved_model(self): 536 """Convert the input SavedModel.""" 537 graph = ops.Graph() 538 with session.Session(graph=graph) as sess: 539 input_meta_graph_def = loader.load(sess, self._input_saved_model_tags, 540 self._input_saved_model_dir) 541 input_signature_def = input_meta_graph_def.signature_def[ 542 self._input_saved_model_signature_key] 543 544 def _gather_names(tensor_info): 545 """Get the node names from a TensorInfo.""" 546 return {tensor_info[key].name.split(":")[0] for key in tensor_info} 547 548 # Get input and outputs from all SignatureDef. 549 output_node_names = _gather_names(input_signature_def.inputs).union( 550 _gather_names(input_signature_def.outputs)) 551 552 # Preserve nodes in collection 553 for collection_key in self._collections_to_keep( 554 input_meta_graph_def.collection_def): 555 for op in sess.graph.get_collection(collection_key): 556 if isinstance(op, ops.Operation): 557 output_node_names.add(op.name.split(":")[0]) 558 559 # Freeze the variables in the SavedModel graph and copy the frozen 560 # graph over. 561 frozen_graph_def = graph_util.convert_variables_to_constants( 562 sess, sess.graph.as_graph_def(add_shapes=True), 563 list(output_node_names)) 564 self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef() 565 self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def) 566 567 # Copy the collections that are not variables. 568 for collection_key in self._collections_to_keep( 569 input_meta_graph_def.collection_def): 570 self._grappler_meta_graph_def.collection_def[collection_key].CopyFrom( 571 input_meta_graph_def.collection_def[collection_key]) 572 573 self._add_nodes_denylist() 574 575 # Copy other information. 576 self._grappler_meta_graph_def.meta_info_def.CopyFrom( 577 input_meta_graph_def.meta_info_def) 578 self._grappler_meta_graph_def.signature_def[ 579 self._input_saved_model_signature_key].CopyFrom(input_signature_def) 580 # TODO(laigd): maybe add back AssetFileDef. 581 582 self._run_conversion() 583 584 def convert(self): 585 """Run the TF-TRT conversion. 586 587 Returns: 588 The converted GraphDef for TF 1.x. 589 """ 590 assert not self._converted 591 if self._input_graph_def: 592 self._convert_graph_def() 593 else: 594 self._convert_saved_model() 595 return self._converted_graph_def 596 597 def calibrate(self, 598 fetch_names, 599 num_runs, 600 feed_dict_fn=None, 601 input_map_fn=None): 602 """Run the calibration and return the calibrated GraphDef. 603 604 Args: 605 fetch_names: a list of output tensor name to fetch during calibration. 606 num_runs: number of runs of the graph during calibration. 607 feed_dict_fn: a function that returns a dictionary mapping input names (as 608 strings) in the GraphDef to be calibrated to values (e.g. Python list, 609 numpy arrays, etc). One and only one of `feed_dict_fn` and 610 `input_map_fn` should be specified. 611 input_map_fn: a function that returns a dictionary mapping input names (as 612 strings) in the GraphDef to be calibrated to Tensor objects. The values 613 of the named input tensors in the GraphDef to be calibrated will be 614 re-mapped to the respective `Tensor` values during calibration. One and 615 only one of `feed_dict_fn` and `input_map_fn` should be specified. 616 617 Raises: 618 ValueError: if the input combination is invalid. 619 RuntimeError: if this method is called in eager mode. 620 621 Returns: 622 The GraphDef after the calibration. 623 """ 624 assert self._converted 625 assert self._need_calibration 626 assert not self._calibration_data_collected 627 628 if (feed_dict_fn and input_map_fn) or (not feed_dict_fn and 629 not input_map_fn): 630 raise ValueError( 631 "Should specify one and only one of feed_dict_fn and input_map_fn.") 632 633 if input_map_fn: 634 for k, v in input_map_fn().items(): 635 if not isinstance(k, str): 636 raise ValueError("Keys of input_map_fn must be of type str") 637 if not isinstance(v, ops.Tensor): 638 raise ValueError("Values of input_map_fn must be of type tf.Tensor") 639 640 self._calibration_graph = ops.Graph() 641 with self._calibration_graph.as_default(): 642 fetches = importer.import_graph_def( 643 self._converted_graph_def, 644 input_map=input_map_fn() if input_map_fn else None, 645 return_elements=fetch_names, 646 name="") 647 648 calibrate_rewriter_cfg = rewriter_config_pb2.RewriterConfig() 649 if self._test_only_disable_non_trt_optimizers: 650 trt_utils.disable_non_trt_optimizers_in_rewriter_config( 651 calibrate_rewriter_cfg) 652 653 # Set allow_soft_placement=True to run the graph for calibration so that 654 # OPs supported by TensorRT but don't have a GPU implementation are allowed 655 # to execute on CPU. 656 calibrate_config = config_pb2.ConfigProto( 657 allow_soft_placement=True, 658 graph_options=config_pb2.GraphOptions( 659 rewrite_options=calibrate_rewriter_cfg)) 660 661 with session.Session( 662 graph=self._calibration_graph, 663 config=calibrate_config) as calibration_sess: 664 for _ in range(num_runs): 665 calibration_sess.run( 666 fetches, feed_dict=feed_dict_fn() if feed_dict_fn else None) 667 668 # Maps device name to the corresponding get_calibration_data. 669 # 670 # TODO(laigd): a better way would be to use calibration_sess to list 671 # all the devices, add one get_calibration_data for each device, and 672 # fetch each such op for every resource until its found. This can work 673 # even when the device of the TRTEngineOp is empty or not fully specified. 674 device_to_get_resource_op_map = {} 675 676 with self._calibration_graph.as_default(): 677 resource_name_input = array_ops.placeholder(dtypes.string) 678 679 for node in self._converted_graph_def.node: 680 if node.op == _TRT_ENGINE_OP_NAME: 681 # Adds the get_calibration_data op for the device if not done 682 # before. We only add one such op for each device. 683 # TODO(laigd): What if the device is empty????? 684 if node.device not in device_to_get_resource_op_map: 685 with self._calibration_graph.device(node.device): 686 serialized_resources_output = ( 687 gen_trt_ops.get_calibration_data_op(resource_name_input)) 688 device_to_get_resource_op_map[node.device] = ( 689 serialized_resources_output) 690 691 # Get the calibration resource. 692 calibration_result = calibration_sess.run( 693 device_to_get_resource_op_map[node.device], 694 feed_dict={ 695 resource_name_input: _get_canonical_engine_name(node.name) 696 }) 697 node.attr["calibration_data"].s = calibration_result 698 699 self._calibration_data_collected = True 700 701 return self._converted_graph_def 702 703 def save(self, output_saved_model_dir): 704 """Save the converted graph as a SavedModel. 705 706 Args: 707 output_saved_model_dir: construct a SavedModel using the converted 708 GraphDef and save it to the specified directory. This option only works 709 when the input graph is loaded from a SavedModel, i.e. when 710 input_saved_model_dir is specified and input_graph_def is None in 711 __init__(). 712 713 Raises: 714 ValueError: if the input to the converter is a GraphDef instead of a 715 SavedModel. 716 """ 717 assert self._converted 718 if self._need_calibration: 719 assert self._calibration_data_collected 720 if self._input_graph_def: 721 raise ValueError( 722 "Not able to save to a SavedModel since input is a GraphDef") 723 724 def _restore_collections(dest_graph, src_meta_graph_def, collection_keys): 725 """Restores collections that we need to keep.""" 726 scope = "" 727 for key in collection_keys: 728 collection_def = src_meta_graph_def.collection_def[key] 729 kind = collection_def.WhichOneof("kind") 730 if kind is None: 731 tf_logging.error( 732 "Cannot identify data type for collection %s. Skipping.", key) 733 continue 734 from_proto = ops.get_from_proto_function(key) 735 if from_proto and kind == "bytes_list": 736 proto_type = ops.get_collection_proto_type(key) 737 # It is assumed that there are no Variables Keys in collections 738 for value in collection_def.bytes_list.value: 739 proto = proto_type() 740 proto.ParseFromString(value) 741 try: 742 new_value = from_proto(proto, import_scope=scope) 743 except: 744 continue 745 dest_graph.add_to_collection(key, new_value) 746 else: 747 field = getattr(collection_def, kind) 748 if kind == "node_list": 749 for value in field.value: 750 name = ops.prepend_name_scope(value, scope) 751 # Since the graph has been optimized, the node may no longer 752 # exists 753 try: 754 col_op = dest_graph.as_graph_element(name) 755 except (TypeError, ValueError, KeyError): 756 continue 757 dest_graph.add_to_collection(key, col_op) 758 elif kind == "int64_list": 759 # NOTE(opensource): This force conversion is to work around the 760 # fact that Python2 distinguishes between int and long, while 761 # Python3 has only int. 762 for value in field.value: 763 dest_graph.add_to_collection(key, int(value)) 764 else: 765 for value in field.value: 766 dest_graph.add_to_collection(key, 767 ops.prepend_name_scope(value, scope)) 768 769 # Write the transformed graphdef as SavedModel. 770 saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir) 771 with ops.Graph().as_default(): 772 importer.import_graph_def(self._converted_graph_def, name="") 773 _restore_collections( 774 ops.get_default_graph(), self._grappler_meta_graph_def, 775 self._collections_to_keep( 776 self._grappler_meta_graph_def.collection_def)) 777 # We don't use any specific converter here. 778 with session.Session() as sess: 779 saved_model_builder.add_meta_graph_and_variables( 780 sess, 781 self._input_saved_model_tags, 782 signature_def_map=self._grappler_meta_graph_def.signature_def) 783 # Ignore other meta graphs from the input SavedModel. 784 saved_model_builder.save() 785 786 787def _get_resource_handle(name, device): 788 with ops.device(device): 789 return gen_trt_ops.create_trt_resource_handle(resource_name=name) 790 791 792class _TRTEngineResourceDeleter(tracking.CapturableResourceDeleter): 793 """Resource deleter for destroying TRT engine cache resource.""" 794 795 def __init__(self, resource_name, device): 796 super(_TRTEngineResourceDeleter, self).__init__() 797 self._resource_name = resource_name 798 self._device = device 799 800 def destroy_resource(self): 801 handle = _get_resource_handle(self._resource_name, self._device) 802 with ops.device(self._device): 803 gen_resource_variable_ops.destroy_resource_op( 804 handle, ignore_lookup_error=True) 805 806 807class _TRTEngineResource(tracking.TrackableResource): 808 """Class to track the serialized engines resource.""" 809 810 def __init__(self, 811 resource_name, 812 filename, 813 maximum_cached_engines, 814 device="GPU"): 815 super(_TRTEngineResource, self).__init__( 816 device=device, deleter=_TRTEngineResourceDeleter(resource_name, device)) 817 self._resource_name = resource_name 818 # Track the serialized engine file in the SavedModel. 819 self._filename = self._track_trackable( 820 tracking.Asset(filename), "_serialized_trt_resource_filename") 821 self._maximum_cached_engines = maximum_cached_engines 822 823 def _create_resource(self): 824 return _get_resource_handle(self._resource_name, self._resource_device) 825 826 def _initialize(self): 827 gen_trt_ops.initialize_trt_resource( 828 self.resource_handle, 829 self._filename, 830 max_cached_engines_count=self._maximum_cached_engines) 831 832 833@tf_export("experimental.tensorrt.Converter", v1=[]) 834class TrtGraphConverterV2(object): 835 """An offline converter for TF-TRT transformation for TF 2.0 SavedModels. 836 837 Currently this is not available on Windows platform. 838 839 There are several ways to run the conversion: 840 841 1. FP32/FP16 precision 842 843 ```python 844 params = tf.experimental.tensorrt.ConversionParams( 845 precision_mode='FP16') 846 converter = tf.experimental.tensorrt.Converter( 847 input_saved_model_dir="my_dir", conversion_params=params) 848 converter.convert() 849 converter.save(output_saved_model_dir) 850 ``` 851 852 In this case, no TRT engines will be built or saved in the converted 853 SavedModel. But if input data is available during conversion, we can still 854 build and save the TRT engines to reduce the cost during inference (see 855 option 2 below). 856 857 2. FP32/FP16 precision with pre-built engines 858 859 ```python 860 params = tf.experimental.tensorrt.ConversionParams( 861 precision_mode='FP16', 862 # Set this to a large enough number so it can cache all the engines. 863 maximum_cached_engines=16) 864 converter = tf.experimental.tensorrt.Converter( 865 input_saved_model_dir="my_dir", conversion_params=params) 866 converter.convert() 867 868 # Define a generator function that yields input data, and use it to execute 869 # the graph to build TRT engines. 870 # With TensorRT 5.1, different engines will be built (and saved later) for 871 # different input shapes to the TRTEngineOp. 872 def my_input_fn(): 873 for _ in range(num_runs): 874 inp1, inp2 = ... 875 yield inp1, inp2 876 877 converter.build(input_fn=my_input_fn) # Generate corresponding TRT engines 878 converter.save(output_saved_model_dir) # Generated engines will be saved. 879 ``` 880 881 In this way, one engine will be built/saved for each unique input shapes of 882 the TRTEngineOp. This is good for applications that cannot afford building 883 engines during inference but have access to input data that is similar to 884 the one used in production (for example, that has the same input shapes). 885 Also, the generated TRT engines is platform dependent, so we need to run 886 `build()` in an environment that is similar to production (e.g. with 887 same type of GPU). 888 889 3. INT8 precision and calibration with pre-built engines 890 891 ```python 892 params = tf.experimental.tensorrt.ConversionParams( 893 precision_mode='INT8', 894 # Currently only one INT8 engine is supported in this mode. 895 maximum_cached_engines=1, 896 use_calibration=True) 897 converter = tf.experimental.tensorrt.Converter( 898 input_saved_model_dir="my_dir", conversion_params=params) 899 900 # Define a generator function that yields input data, and run INT8 901 # calibration with the data. All input data should have the same shape. 902 # At the end of convert(), the calibration stats (e.g. range information) 903 # will be saved and can be used to generate more TRT engines with different 904 # shapes. Also, one TRT engine will be generated (with the same shape as 905 # the calibration data) for save later. 906 def my_calibration_input_fn(): 907 for _ in range(num_runs): 908 inp1, inp2 = ... 909 yield inp1, inp2 910 911 converter.convert(calibration_input_fn=my_calibration_input_fn) 912 913 # (Optional) Generate more TRT engines offline (same as the previous 914 # option), to avoid the cost of generating them during inference. 915 def my_input_fn(): 916 for _ in range(num_runs): 917 inp1, inp2 = ... 918 yield inp1, inp2 919 converter.build(input_fn=my_input_fn) 920 921 # Save the TRT engine and the engines. 922 converter.save(output_saved_model_dir) 923 ``` 924 """ 925 926 def __init__(self, 927 input_saved_model_dir=None, 928 input_saved_model_tags=None, 929 input_saved_model_signature_key=None, 930 conversion_params=None): 931 """Initialize the converter. 932 933 Args: 934 input_saved_model_dir: the directory to load the SavedModel which contains 935 the input graph to transforms. Used only when input_graph_def is None. 936 input_saved_model_tags: list of tags to load the SavedModel. 937 input_saved_model_signature_key: the key of the signature to optimize the 938 graph for. 939 conversion_params: a TrtConversionParams instance. 940 941 Raises: 942 ValueError: if the combination of the parameters is invalid. 943 """ 944 assert context.executing_eagerly() 945 if conversion_params is None: 946 conversion_params = TrtConversionParams() 947 948 _check_trt_version_compatibility() 949 _check_conversion_params(conversion_params, is_v2=True) 950 951 self._conversion_params = conversion_params 952 self._input_saved_model_dir = input_saved_model_dir 953 self._input_saved_model_tags = ( 954 input_saved_model_tags or [tag_constants.SERVING]) 955 self._input_saved_model_signature_key = ( 956 input_saved_model_signature_key or 957 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) 958 959 self._need_calibration = ( 960 conversion_params.precision_mode == TrtPrecisionMode.INT8 and 961 conversion_params.use_calibration) 962 963 self._converted = False 964 self._build_called_once = False 965 966 # Fields to support TF-TRT testing and shouldn't be used for other purpose. 967 self._test_only_disable_non_trt_optimizers = False 968 self._test_only_use_implicit_batch = True 969 970 def _need_trt_profiles(self): 971 return not self._test_only_use_implicit_batch 972 973 def _run_conversion(self, meta_graph_def): 974 """Run Grappler's OptimizeGraph() tool to convert the graph. 975 976 Args: 977 meta_graph_def: the MetaGraphDef instance to run the optimizations on. 978 979 Returns: 980 The optimized GraphDef. 981 """ 982 grappler_session_config = config_pb2.ConfigProto() 983 custom_rewriter_config = _get_tensorrt_rewriter_config( 984 conversion_params=self._conversion_params, 985 is_dynamic_op=True, 986 max_batch_size=None, 987 disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers, 988 use_implicit_batch=self._test_only_use_implicit_batch) 989 grappler_session_config.graph_options.rewrite_options.CopyFrom( 990 custom_rewriter_config) 991 return tf_optimizer.OptimizeGraph( 992 grappler_session_config, meta_graph_def, graph_id=b"tf_graph") 993 994 def _for_each_trt_node(self, graph_def, fn): 995 """Helper method to manipulate all TRTEngineOps in a GraphDef.""" 996 for node in graph_def.node: 997 if node.op == _TRT_ENGINE_OP_NAME: 998 fn(node) 999 for func in graph_def.library.function: 1000 for node in func.node_def: 1001 if node.op == _TRT_ENGINE_OP_NAME: 1002 fn(node) 1003 1004 def _rebuild_func(self, func): 1005 """Rebuild function from graph_def.""" 1006 rebuilt_func = wrap_function.function_from_graph_def( 1007 self._converted_graph_def, [tensor.name for tensor in func.inputs], 1008 [tensor.name for tensor in func.outputs]) 1009 rebuilt_func.graph.structured_outputs = nest.pack_sequence_as( 1010 func.graph.structured_outputs, rebuilt_func.graph.structured_outputs) 1011 # Copy structured input signature from original function (used during 1012 # serialization) 1013 rebuilt_func.graph.structured_input_signature = ( 1014 func.structured_input_signature) 1015 return rebuilt_func 1016 1017 # TODO(laigd): provide a utility function to optimize a ConcreteFunction and 1018 # use it here (b/124792963). 1019 def convert(self, calibration_input_fn=None): 1020 """Convert the input SavedModel in 2.0 format. 1021 1022 Args: 1023 calibration_input_fn: a generator function that yields input data as a 1024 list or tuple, which will be used to execute the converted signature for 1025 calibration. All the returned input data should have the same shape. 1026 Example: `def input_fn(): yield input1, input2, input3` 1027 1028 Raises: 1029 ValueError: if the input combination is invalid. 1030 1031 Returns: 1032 The TF-TRT converted Function. 1033 """ 1034 assert not self._converted 1035 1036 if (self._need_calibration and not calibration_input_fn): 1037 raise ValueError("Should specify calibration_input_fn because INT8 " 1038 "calibration is needed") 1039 if (not self._need_calibration and calibration_input_fn): 1040 raise ValueError("Should not specify calibration_input_fn because INT8 " 1041 "calibration is not needed") 1042 1043 self._saved_model = load.load(self._input_saved_model_dir, 1044 self._input_saved_model_tags) 1045 func = self._saved_model.signatures[self._input_saved_model_signature_key] 1046 frozen_func = convert_to_constants.convert_variables_to_constants_v2(func) 1047 grappler_meta_graph_def = saver.export_meta_graph( 1048 graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph) 1049 1050 # Add a collection 'train_op' so that Grappler knows the outputs. 1051 fetch_collection = meta_graph_pb2.CollectionDef() 1052 for array in frozen_func.inputs + frozen_func.outputs: 1053 fetch_collection.node_list.value.append(array.name) 1054 grappler_meta_graph_def.collection_def["train_op"].CopyFrom( 1055 fetch_collection) 1056 1057 # Run TRT optimizer in Grappler to convert the graph. 1058 self._converted_graph_def = self._run_conversion(grappler_meta_graph_def) 1059 self._converted_func = wrap_function.function_from_graph_def( 1060 self._converted_graph_def, 1061 [tensor.name for tensor in frozen_func.inputs], 1062 [tensor.name for tensor in frozen_func.outputs]) 1063 # Reconstruct the output signatures using the ones from original model. 1064 self._converted_func.graph.structured_outputs = nest.pack_sequence_as( 1065 func.graph.structured_outputs, 1066 self._converted_func.graph.structured_outputs) 1067 # Copy structured input signature from original function (used during 1068 # serialization) 1069 self._converted_func.graph.structured_input_signature = ( 1070 func.structured_input_signature) 1071 1072 if self._need_calibration: 1073 for inp in calibration_input_fn(): 1074 self._converted_func(*map(ops.convert_to_tensor, inp)) 1075 1076 def _save_calibration_table(node): 1077 calibration_table = gen_trt_ops.get_calibration_data_op( 1078 _get_canonical_engine_name(node.name)) 1079 node.attr["calibration_data"].s = calibration_table.numpy() 1080 1081 self._for_each_trt_node(self._converted_graph_def, 1082 _save_calibration_table) 1083 1084 # Rebuild the function since calibration has changed the graph. 1085 self._converted_func = self._rebuild_func(self._converted_func) 1086 1087 self._converted = True 1088 return self._converted_func 1089 1090 def build(self, input_fn): 1091 """Run inference with converted graph in order to build TensorRT engines. 1092 1093 Args: 1094 input_fn: a generator function that yields input data as a list or tuple, 1095 which will be used to execute the converted signature to generate TRT 1096 engines. Example: 1097 `def input_fn(): 1098 # Let's assume a network with 2 input tensors. We generate 3 sets 1099 # of dummy input data: 1100 input_shapes = [[(1, 16), (2, 16)], # 1st input list 1101 [(2, 32), (4, 32)], # 2nd list of two tensors 1102 [(4, 32), (8, 32)]] # 3rd input list 1103 for shapes in input_shapes: 1104 # return a list of input tensors 1105 yield [np.zeros(x).astype(np.float32) for x in shapes]` 1106 Raises: 1107 NotImplementedError: build() is already called. 1108 RuntimeError: the input_fx is None. 1109 """ 1110 if self._build_called_once: 1111 raise NotImplementedError("build() is already called. It is not " 1112 "supported to call build() more than once.") 1113 if not input_fn: 1114 raise RuntimeError("input_fn is None. Method build() needs input_fn " 1115 "to be specified in order to build TensorRT engines") 1116 1117 def _set_profile_generation_mode(value, node): 1118 node.attr["_profile_generation_mode"].b = value 1119 1120 if self._need_trt_profiles(): 1121 # Enable profile generation. 1122 self._for_each_trt_node(self._converted_graph_def, 1123 partial(_set_profile_generation_mode, True)) 1124 # Profile generation is enabled using the _profile_generation_mode 1125 # attribute of the TRTEngineOps. We need to rebuild the function to 1126 # change this attribute. 1127 func = self._rebuild_func(self._converted_func) 1128 else: 1129 func = self._converted_func 1130 1131 first_input = None 1132 # Run inference: 1133 # Builds TRT engines if self._need_trt_profiles is False. 1134 # Builds TRT optimization profiles if self._need_trt_profiles is True. 1135 for inp in input_fn(): 1136 if not first_input: 1137 first_input = inp 1138 func(*map(ops.convert_to_tensor, inp)) 1139 1140 if self._need_trt_profiles(): 1141 # Disable profile generation. 1142 self._for_each_trt_node(self._converted_graph_def, 1143 partial(_set_profile_generation_mode, False)) 1144 # Use the first input in explicit batch mode to build TensorRT engines 1145 # after generating all the profiles. The first input is used but any of 1146 # the inputs can be used because the shape of this input does not 1147 # determine the engine and instead the shapes collected in profiles 1148 # determine the engine. 1149 self._converted_func(*map(ops.convert_to_tensor, first_input)) 1150 1151 self._build_called_once = True 1152 1153 def save(self, output_saved_model_dir): 1154 """Save the converted SavedModel. 1155 1156 Args: 1157 output_saved_model_dir: directory to saved the converted SavedModel. 1158 """ 1159 assert self._converted 1160 1161 if self._need_trt_profiles() and not self._build_called_once: 1162 raise NotImplementedError( 1163 "build() is not called . Explicit batch mode " 1164 "(use_implicit_batch=False) requires generating TensorRT optimization" 1165 " profiles which is done by calling build().") 1166 1167 # Serialize the TRT engines in the cache if any, and create trackable 1168 # resource to track them. 1169 engine_asset_dir = tempfile.mkdtemp() 1170 resource_map = {} 1171 1172 def _serialize_and_track_engine(node): 1173 """Serialize TRT engines in the cache and track them.""" 1174 # Don't dump the same cache twice. 1175 canonical_engine_name = _get_canonical_engine_name(node.name) 1176 if canonical_engine_name in resource_map: 1177 return 1178 1179 filename = os.path.join(engine_asset_dir, 1180 "trt-serialized-engine." + canonical_engine_name) 1181 1182 try: 1183 gen_trt_ops.serialize_trt_resource( 1184 resource_name=canonical_engine_name, 1185 filename=filename, 1186 delete_resource=True) 1187 except errors.NotFoundError: 1188 tf_logging.info("Could not find %s in TF-TRT cache. " 1189 "This can happen if build() is not called, " 1190 "which means TensorRT engines will be built " 1191 "and cached at runtime." % canonical_engine_name) 1192 return 1193 1194 # TODO(laigd): add an option for the user to choose the device. 1195 resource_map[canonical_engine_name] = _TRTEngineResource( 1196 canonical_engine_name, filename, 1197 self._conversion_params.maximum_cached_engines) 1198 1199 self._for_each_trt_node(self._converted_graph_def, 1200 _serialize_and_track_engine) 1201 self._saved_model.trt_engine_resources = resource_map 1202 1203 # Rewrite the signature map using the optimized ConcreteFunction. 1204 signatures = { 1205 key: value for key, value in self._saved_model.signatures.items() 1206 } 1207 1208 # Set allow_build_at_runtime=False if asked by user. 1209 # 1210 # This attribute is set here because build() needs it to be True in order to 1211 # build engines. 1212 if not self._conversion_params.allow_build_at_runtime: 1213 1214 def _reset_allow_build_at_runtime(node): 1215 node.attr["allow_build_at_runtime"].b = False 1216 1217 self._for_each_trt_node(self._converted_graph_def, 1218 _reset_allow_build_at_runtime) 1219 # Rebuild the function since a node attribute changed above 1220 reset_converted_func = wrap_function.function_from_graph_def( 1221 self._converted_graph_def, 1222 [tensor.name for tensor in self._converted_func.inputs], 1223 [tensor.name for tensor in self._converted_func.outputs]) 1224 reset_converted_func.graph.structured_outputs = nest.pack_sequence_as( 1225 self._converted_func.graph.structured_outputs, 1226 reset_converted_func.graph.structured_outputs) 1227 reset_converted_func.graph.strucutred_input_signature = ( 1228 self._converted_func.structured_input_signature) 1229 self._converted_func = reset_converted_func 1230 1231 signatures[self._input_saved_model_signature_key] = self._converted_func 1232 save.save(self._saved_model, output_saved_model_dir, signatures) 1233 1234 1235# TODO(laigd): use TrtConversionParams here. 1236def create_inference_graph( 1237 input_graph_def, 1238 outputs, 1239 max_batch_size=1, 1240 max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, 1241 precision_mode=TrtPrecisionMode.FP32, 1242 minimum_segment_size=3, 1243 is_dynamic_op=False, 1244 maximum_cached_engines=1, 1245 input_saved_model_dir=None, 1246 input_saved_model_tags=None, 1247 input_saved_model_signature_key=None, 1248 output_saved_model_dir=None): 1249 """Python wrapper for the TRT transformation. 1250 1251 Args: 1252 input_graph_def: a GraphDef object containing a model to be transformed. If 1253 set to None, the graph will be read from the SavedModel loaded from 1254 input_saved_model_dir. 1255 outputs: list of tensors or node names for the model outputs. Only used when 1256 input_graph_def is not None. 1257 max_batch_size: max size for the input batch. 1258 max_workspace_size_bytes: the maximum GPU temporary memory which the TRT 1259 engine can use at execution time. This corresponds to the 'workspaceSize' 1260 parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). 1261 precision_mode: one of TrtPrecisionMode.supported_precision_modes(). 1262 minimum_segment_size: the minimum number of nodes required for a subgraph to 1263 be replaced by TRTEngineOp. 1264 is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT 1265 network and engine at run time. 1266 maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops. 1267 If the number of cached engines is already at max but none of them can 1268 serve the input, the TRTEngineOp will fall back to run the TF function 1269 based on which the TRTEngineOp is created. 1270 input_saved_model_dir: the directory to load the SavedModel which contains 1271 the input graph to transforms. Used only when input_graph_def is None. 1272 input_saved_model_tags: list of tags to load the SavedModel. 1273 input_saved_model_signature_key: the key of the signature to optimize the 1274 graph for. 1275 output_saved_model_dir: if not None, construct a SavedModel using the 1276 returned GraphDef and save it to the specified directory. This option only 1277 works when the input graph is loaded from a SavedModel, i.e. when 1278 input_saved_model_dir is specified and input_graph_def is None. 1279 1280 Returns: 1281 A GraphDef transformed from input_graph_def (or the SavedModel graph def 1282 loaded from input_saved_model_dir, if input_graph_def is not present), where 1283 all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF 1284 function is added for each of the subgraphs. 1285 1286 If is_dynamic_op is True, each TRTEngineOp will contain a serialized 1287 subgraph GraphDef, which will be converted to a TRT engine at execution time 1288 and the TRT engine will be cached for future usage. A new TRT engine will be 1289 created each time when none of the cached engines match the input shapes. If 1290 it fails to execute the TRT engine or the number of cached engines reaches 1291 maximum_cached_engines, the op will fall back to call the corresponding TF 1292 function. 1293 1294 If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT 1295 engine created from the corresponding subgraph. No more engines will be 1296 created on the fly, and the op will fall back to call the corresponding TF 1297 function when it fails to execute the engine. 1298 1299 Raises: 1300 ValueError: if the combination of the parameters is invalid. 1301 """ 1302 trt_converter = TrtGraphConverter( 1303 input_saved_model_dir=input_saved_model_dir, 1304 input_saved_model_tags=input_saved_model_tags, 1305 input_saved_model_signature_key=input_saved_model_signature_key, 1306 input_graph_def=input_graph_def, 1307 nodes_denylist=outputs, 1308 max_batch_size=max_batch_size, 1309 max_workspace_size_bytes=max_workspace_size_bytes, 1310 precision_mode=precision_mode, 1311 minimum_segment_size=minimum_segment_size, 1312 is_dynamic_op=is_dynamic_op, 1313 maximum_cached_engines=maximum_cached_engines, 1314 use_calibration=False) 1315 converted_graph_def = trt_converter.convert() 1316 if output_saved_model_dir: 1317 trt_converter.save(output_saved_model_dir) 1318 return converted_graph_def 1319