# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper conversion to trt_graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from functools import partial  # pylint: disable=g-importing-member
import os
import platform
import tempfile

import six as _six

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.compiler.tensorrt import utils as trt_utils
from tensorflow.python.eager import context
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export

if platform.system() == "Windows":
  raise RuntimeError("Windows platform is not supported")
# Lazily load the op, since it's not available in cpu-only builds. Importing
# it at the top would cause tests that import TF-TRT to fail when they're
# built and run without CUDA/GPU.
gen_trt_ops = LazyLoader(
    "gen_trt_ops", globals(),
    "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")

_pywrap_py_utils = LazyLoader(
    "_pywrap_py_utils", globals(),
    "tensorflow.compiler.tf2tensorrt._pywrap_py_utils")
# Register TRT ops in python, so that when users import this module they can
# execute a TRT-converted graph without calling any of the methods in this
# module.
#
# This will call register_op_list() in
# tensorflow/python/framework/op_def_registry.py, but it doesn't register
# the op or the op kernel in the C++ runtime.
try:
  gen_trt_ops.trt_engine_op  # pylint: disable=pointless-statement
except AttributeError:
  pass


def _to_bytes(s):
  """Encode s if it is a sequence of chars."""
  if isinstance(s, _six.text_type):
    return s.encode("utf-8", errors="surrogateescape")
  return s


def _to_string(s):
  """Decode s if it is a sequence of bytes."""
  if isinstance(s, _six.binary_type):
    return s.decode("utf-8")
  return s


class TrtPrecisionMode(object):
  FP32 = "FP32"
  FP16 = "FP16"
  INT8 = "INT8"

  @staticmethod
  def supported_precision_modes():
    precisions = [
        TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8
    ]
    return precisions + [p.lower() for p in precisions]


# Use a large enough number as the default max_workspace_size for TRT
# engines, so that they can produce reasonable performance results with the
# default settings.
DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30


@tf_export("experimental.tensorrt.ConversionParams", v1=[])
class TrtConversionParams(
    collections.namedtuple("TrtConversionParams", [
        "max_workspace_size_bytes", "precision_mode", "minimum_segment_size",
        "maximum_cached_engines", "use_calibration", "allow_build_at_runtime"
    ])):
  """Parameters that are used for TF-TRT conversion.

  Fields:
    max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
      engine can use at execution time. This corresponds to the
      'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
    precision_mode: one of the strings in
      TrtPrecisionMode.supported_precision_modes().
    minimum_segment_size: the minimum number of nodes required for a subgraph
      to be replaced by TRTEngineOp.
    maximum_cached_engines: max number of cached TRT engines for dynamic TRT
      ops. Created TRT engines for a dynamic dimension are cached. This is the
      maximum number of engines that can be cached. If the number of cached
      engines is already at max but none of them supports the input shapes,
      the TRTEngineOp will fall back to run the original TF subgraph that
      corresponds to the TRTEngineOp.
    use_calibration: this argument is ignored if precision_mode is not INT8.
      If set to True, a calibration graph will be created to calibrate the
      missing ranges. The calibration graph must be converted to an inference
      graph by running calibration with calibrate(). If set to False,
      quantization nodes will be expected for every tensor in the graph
      (excluding those which will be fused). If a range is missing, an error
      will occur. Please note that accuracy may be negatively affected if
      there is a mismatch between which tensors TRT quantizes and which
      tensors were trained with fake quantization.
    allow_build_at_runtime: whether to build TensorRT engines during runtime.
      If no TensorRT engine can be found in cache that can handle the given
      inputs during runtime, then a new TensorRT engine is built at runtime if
      allow_build_at_runtime=True, and otherwise native TF is used.
  """

  def __new__(cls,
              max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
              precision_mode=TrtPrecisionMode.FP32,
              minimum_segment_size=3,
              maximum_cached_engines=1,
              use_calibration=True,
              allow_build_at_runtime=True):
    return super(TrtConversionParams,
                 cls).__new__(cls, max_workspace_size_bytes, precision_mode,
                              minimum_segment_size, maximum_cached_engines,
                              use_calibration, allow_build_at_runtime)


DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams()
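
# A brief sketch of how these parameters are typically constructed (a usage
# illustration only; the FP16 choice and the cache size below are arbitrary
# example values, not recommendations):
#
#   params = TrtConversionParams(
#       precision_mode=TrtPrecisionMode.FP16,
#       max_workspace_size_bytes=1 << 30,
#       maximum_cached_engines=4)
#
# Since TrtConversionParams is a namedtuple, an existing instance can also be
# adjusted with _replace(), e.g.
# DEFAULT_TRT_CONVERSION_PARAMS._replace(minimum_segment_size=5).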

_TRT_ENGINE_OP_NAME = "TRTEngineOp"


def _check_conversion_params(conversion_params, is_v2=False):
  """Validate the provided TrtConversionParams.

  Args:
    conversion_params: a TrtConversionParams instance.
    is_v2: whether we're getting a RewriterConfig for TF 2.0.

  Raises:
    TypeError: if any of the parameters are of unexpected type.
    ValueError: if any of the parameters are of unexpected value.
  """
  supported_precision_modes = TrtPrecisionMode.supported_precision_modes()
  if conversion_params.precision_mode not in supported_precision_modes:
    raise ValueError(
        ("precision mode '{}' is not supported. "
         "It should be one of {}").format(conversion_params.precision_mode,
                                          supported_precision_modes))


def _check_trt_version_compatibility():
  """Check compatibility of TensorRT version.

  Raises:
    RuntimeError: if the TensorRT library version is incompatible.
  """
  linked_version = _pywrap_py_utils.get_linked_tensorrt_version()
  loaded_version = _pywrap_py_utils.get_loaded_tensorrt_version()
  assert isinstance(linked_version, tuple)
  assert isinstance(loaded_version, tuple)
  assert len(linked_version) == 3
  assert len(loaded_version) == 3
  tf_logging.info("Linked TensorRT version: %s" % str(linked_version))
  tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version))
  if loaded_version < linked_version:
    tf_logging.error(
        "Loaded TensorRT %s but linked TensorFlow against TensorRT %s. " %
        (".".join(str(x) for x in loaded_version), ".".join(
            str(x) for x in linked_version)) +
        "TensorRT does not support forward compatibility. " +
        "It is also required to use the same major version of TensorRT " +
        "during compilation and runtime.")
    raise RuntimeError("Incompatible TensorRT versions")
  if loaded_version[0] > linked_version[0]:
    tf_logging.error(
        "Loaded TensorRT %s but linked TensorFlow against TensorRT %s. " %
        (".".join(str(x) for x in loaded_version), ".".join(
            str(x) for x in linked_version)) +
        "It is required to use the same major version " +
        "of TensorRT during compilation and runtime.")
    raise RuntimeError("Incompatible TensorRT major version")
  if loaded_version != linked_version:
    tf_logging.info(
        "Loaded TensorRT %s and linked TensorFlow against TensorRT %s. " %
        (".".join(str(x) for x in loaded_version), ".".join(
            str(x) for x in linked_version)) +
        "This is supported because TensorRT minor/patch upgrades are "
        "backward compatible.")


def _get_tensorrt_rewriter_config(conversion_params,
                                  is_dynamic_op=None,
                                  max_batch_size=None,
                                  is_v2=False,
                                  disable_non_trt_optimizers=False,
                                  use_implicit_batch=True):
  """Returns a RewriterConfig proto for TRT transformation.

  Args:
    conversion_params: a TrtConversionParams instance.
    is_dynamic_op: whether to use dynamic engines.
    max_batch_size: maximum batch size for static engines.
    is_v2: whether we're getting a RewriterConfig for TF 2.0.
    disable_non_trt_optimizers: Turn off all default Grappler optimizers.
    use_implicit_batch: Whether to use implicit batch or explicit batch.

  Returns:
    A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.

  Raises:
    TypeError: if any of the parameters are of unexpected type.
    ValueError: if any of the parameters are of unexpected value.
  """
  _check_conversion_params(conversion_params, is_v2=is_v2)
  if is_v2 and is_dynamic_op is not None and not is_dynamic_op:
    raise ValueError("is_dynamic_op is either None or True for TF2")
  if not is_v2 and is_dynamic_op is None:
    raise ValueError("is_dynamic_op can't be None for TF1")

  if (is_dynamic_op is None or is_dynamic_op) and max_batch_size is not None:
    raise ValueError("max_batch_size has to be None for TF2"
                     " or when is_dynamic_op == True in TF1")
  if is_dynamic_op is not None and not is_dynamic_op and not isinstance(
      max_batch_size, int):
    raise ValueError(
        "max_batch_size has to be an integer for is_dynamic_op==False in TF1")
  rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
  # Disable the Grappler Remapper to avoid remapping into fused OPs that may
  # not be beneficial to TF-TRT and that are not supported by TF-TRT.
  rewriter_config_with_trt.remapping = False

  if not disable_non_trt_optimizers:
    # Layout optimizer may add Const nodes followed by Reshape nodes, thus we
    # need to run constant folding again.
    rewriter_config_with_trt.optimizers.extend(
        ["constfold", "layout", "constfold"])

  rewriter_config_with_trt.meta_optimizer_iterations = (
      rewriter_config_pb2.RewriterConfig.ONE)
  optimizer = rewriter_config_with_trt.custom_optimizers.add()

  if not disable_non_trt_optimizers:
    # Add a constfold optimizer to clean up the unused Const nodes.
    rewriter_config_with_trt.custom_optimizers.add().name = "constfold"

  optimizer.name = "TensorRTOptimizer"
  optimizer.parameter_map[
      "minimum_segment_size"].i = conversion_params.minimum_segment_size
  optimizer.parameter_map["max_workspace_size_bytes"].i = (
      conversion_params.max_workspace_size_bytes)
  optimizer.parameter_map["precision_mode"].s = _to_bytes(
      conversion_params.precision_mode)
  optimizer.parameter_map[
      "maximum_cached_engines"].i = conversion_params.maximum_cached_engines
  optimizer.parameter_map[
      "use_calibration"].b = conversion_params.use_calibration
  optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
  optimizer.parameter_map[
      "allow_build_at_runtime"].b = conversion_params.allow_build_at_runtime
  if max_batch_size is not None:
    optimizer.parameter_map["max_batch_size"].i = max_batch_size
  optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch

  # Disabling optimizers should happen after defining the TF-TRT grappler
  # pass, otherwise the template can overwrite the disablement.
  if disable_non_trt_optimizers:
    trt_utils.disable_non_trt_optimizers_in_rewriter_config(
        rewriter_config_with_trt)

  return rewriter_config_with_trt

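# For illustration, the returned proto is typically embedded into a session
# ConfigProto before Grappler is invoked, mirroring what _run_conversion()
# does below (a sketch, not an additional public API):
#
#   rewriter_config = _get_tensorrt_rewriter_config(
#       conversion_params=DEFAULT_TRT_CONVERSION_PARAMS, is_dynamic_op=True)
#   session_config = config_pb2.ConfigProto()
#   session_config.graph_options.rewrite_options.CopyFrom(rewriter_config)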

@deprecation.deprecated(
    None, "You shouldn't need a rewriter_config with the current TF-TRT APIs.")
def get_tensorrt_rewriter_config(conversion_params,
                                 is_dynamic_op=None,
                                 max_batch_size=None,
                                 is_v2=False,
                                 disable_non_trt_optimizers=False):
  return _get_tensorrt_rewriter_config(conversion_params, is_dynamic_op,
                                       max_batch_size, is_v2,
                                       disable_non_trt_optimizers)


# Remove all scope prefixes in the node name. In TF 2.0, the same concrete
# function can be initialized multiple times with different prefixes, and
# this will result in the same TRTEngineOp being initialized multiple times
# with different caches and duplicate TRT engines.
# TODO(laigd): this may be caused by the fact that TRTEngineOp is not
# stateful, need to investigate.
# TODO(laigd): we rely on the fact that all functions are fully inlined
# before the TF-TRT optimizer is called, as otherwise it may generate the
# same name when optimizing a different function graph. Fix this.
def _get_canonical_engine_name(name):
  return name.split("/")[-1]

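# For example, a node named "StatefulPartitionedCall/TRTEngineOp_0" (a
# hypothetical name, for illustration) canonicalizes to "TRTEngineOp_0", so
# the same engine resource is addressed regardless of the enclosing scope.
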

class TrtGraphConverter(object):
  """A converter for TF-TRT transformation for TF 1.x GraphDef/SavedModels.

  To run the conversion without quantization calibration (e.g. for FP32/FP16
  precision modes):

  ```python
  converter = TrtGraphConverter(
      input_saved_model_dir="my_dir",
      precision_mode=TrtPrecisionMode.FP16)
  converted_graph_def = converter.convert()
  converter.save(output_saved_model_dir)
  ```

  To run the conversion with quantization calibration:

  ```python
  converter = TrtGraphConverter(
      input_saved_model_dir="my_dir",
      precision_mode=TrtPrecisionMode.INT8)
  converter.convert()

  # Run calibration 10 times.
  converted_graph_def = converter.calibrate(
      fetch_names=['output:0'],
      num_runs=10,
      feed_dict_fn=lambda: {'input:0': my_next_data()})

  converter.save(output_saved_model_dir)
  ```
  """

  def __init__(self,
               input_saved_model_dir=None,
               input_saved_model_tags=None,
               input_saved_model_signature_key=None,
               input_graph_def=None,
               nodes_denylist=None,
               max_batch_size=1,
               max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
               precision_mode=TrtPrecisionMode.FP32,
               minimum_segment_size=3,
               is_dynamic_op=False,
               maximum_cached_engines=1,
               use_calibration=True):
    """Initializes the converter.

    Args:
      input_saved_model_dir: the directory to load the SavedModel which
        contains the input graph to transform. Used only when input_graph_def
        is None.
      input_saved_model_tags: list of tags to load the SavedModel.
      input_saved_model_signature_key: the key of the signature to optimize the
        graph for.
      input_graph_def: a GraphDef object containing a model to be transformed.
        If set to None, the graph will be read from the SavedModel loaded from
        input_saved_model_dir.
      nodes_denylist: list of node names to prevent the converter from touching.
      max_batch_size: max size for the input batch.
      max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
        engine can use at execution time. This corresponds to the
        'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
      precision_mode: one of TrtPrecisionMode.supported_precision_modes().
      minimum_segment_size: the minimum number of nodes required for a subgraph
        to be replaced by TRTEngineOp.
      is_dynamic_op: whether to generate dynamic TRT ops which will build the
        TRT network and engine at run time.
      maximum_cached_engines: max number of cached TRT engines in dynamic TRT
        ops. If the number of cached engines is already at max but none of them
        can serve the input, the TRTEngineOp will fall back to run the TF
        function based on which the TRTEngineOp is created.
      use_calibration: this argument is ignored if precision_mode is not INT8.
        If set to True, a calibration graph will be created to calibrate the
        missing ranges. The calibration graph must be converted to an inference
        graph by running calibration with calibrate(). If set to False,
        quantization nodes will be expected for every tensor in the graph
        (excluding those which will be fused). If a range is missing, an error
        will occur. Please note that accuracy may be negatively affected if
        there is a mismatch between which tensors TRT quantizes and which
        tensors were trained with fake quantization.

    Raises:
      ValueError: if the combination of the parameters is invalid.
      RuntimeError: if this class is used in TF 2.0.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          "Please use tf.experimental.tensorrt.Converter in TF 2.0.")

    if input_graph_def and input_saved_model_dir:
      raise ValueError(
          "Can only specify one of input_graph_def and input_saved_model_dir")
    if not input_graph_def and not input_saved_model_dir:
      raise ValueError("Must specify one of input_graph_def and "
                       "input_saved_model_dir")
    _check_trt_version_compatibility()

    self._input_graph_def = input_graph_def
    self._nodes_denylist = nodes_denylist

    self._input_saved_model_dir = input_saved_model_dir
    self._converted = False
    self._grappler_meta_graph_def = None

    self._input_saved_model_tags = (
        input_saved_model_tags or [tag_constants.SERVING])
    self._input_saved_model_signature_key = (
        input_saved_model_signature_key or
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)

    # For calibration usage.
    self._calibration_graph = None
    self._calibration_data_collected = False
    self._need_calibration = (
        precision_mode == TrtPrecisionMode.INT8 and use_calibration)
    if self._need_calibration and not is_dynamic_op:
      tf_logging.warn(
          "INT8 precision mode with calibration is supported with "
          "dynamic TRT ops only. Disregarding is_dynamic_op parameter.")
      is_dynamic_op = True

    self._is_dynamic_op = is_dynamic_op
    if is_dynamic_op:
      self._max_batch_size = None
      if max_batch_size is not None:
        tf_logging.warn("When is_dynamic_op==True max_batch_size should be "
                        "None")
    else:
      if not isinstance(max_batch_size, int):
        raise ValueError("When is_dynamic_op==False max_batch_size should be "
                         "an integer")
      self._max_batch_size = max_batch_size

    self._conversion_params = TrtConversionParams(
        max_workspace_size_bytes=max_workspace_size_bytes,
        precision_mode=precision_mode,
        minimum_segment_size=minimum_segment_size,
        maximum_cached_engines=maximum_cached_engines,
        use_calibration=use_calibration,
        allow_build_at_runtime=True)
    _check_conversion_params(self._conversion_params)

    self._test_only_disable_non_trt_optimizers = False

  def _run_conversion(self):
    """Run Grappler's OptimizeGraph() tool to convert the graph."""
    # Create custom ConfigProto for Grappler.
    grappler_session_config = config_pb2.ConfigProto()
    custom_rewriter_config = _get_tensorrt_rewriter_config(
        conversion_params=self._conversion_params,
        is_dynamic_op=self._is_dynamic_op,
        max_batch_size=self._max_batch_size,
        disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers,
        use_implicit_batch=True)
    grappler_session_config.graph_options.rewrite_options.CopyFrom(
        custom_rewriter_config)

    # Run Grappler.
    self._converted_graph_def = tf_optimizer.OptimizeGraph(
        grappler_session_config,
        self._grappler_meta_graph_def,
        graph_id=b"tf_graph")
    self._converted = True

  def _add_nodes_denylist(self):
    if self._nodes_denylist:
      collection_def = self._grappler_meta_graph_def.collection_def["train_op"]
      denylist = collection_def.node_list.value
      for i in self._nodes_denylist:
        if isinstance(i, ops.Tensor):
          denylist.append(_to_bytes(i.name))
        else:
          denylist.append(_to_bytes(i))

  def _convert_graph_def(self):
    """Convert the input GraphDef."""
    graph = ops.Graph()
    with graph.as_default():
      importer.import_graph_def(self._input_graph_def, name="")
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
    self._add_nodes_denylist()

    self._run_conversion()

  def _collections_to_keep(self, collection_keys):
    # TODO(laigd): currently we use the collection key to filter out
    # collections that depend on variable ops, but this may miss some
    # other user-defined collections. A better way would be to use
    # CollectionDef::NodeList for the filtering.
    collections_to_remove = (
        ops.GraphKeys._VARIABLE_COLLECTIONS + [
            ops.GraphKeys.TRAIN_OP, ops.GraphKeys.WHILE_CONTEXT,
            ops.GraphKeys.COND_CONTEXT
        ])
    return [key for key in collection_keys if key not in collections_to_remove]

  def _convert_saved_model(self):
    """Convert the input SavedModel."""
    graph = ops.Graph()
    with session.Session(graph=graph) as sess:
      input_meta_graph_def = loader.load(sess, self._input_saved_model_tags,
                                         self._input_saved_model_dir)
      input_signature_def = input_meta_graph_def.signature_def[
          self._input_saved_model_signature_key]

      def _gather_names(tensor_info):
        """Get the node names from a TensorInfo."""
        return {tensor_info[key].name.split(":")[0] for key in tensor_info}

      # Get the input and output node names from the SignatureDef.
      output_node_names = _gather_names(input_signature_def.inputs).union(
          _gather_names(input_signature_def.outputs))

      # Preserve the nodes in the kept collections.
      for collection_key in self._collections_to_keep(
          input_meta_graph_def.collection_def):
        for op in sess.graph.get_collection(collection_key):
          if isinstance(op, ops.Operation):
            output_node_names.add(op.name.split(":")[0])

      # Freeze the variables in the SavedModel graph and copy the frozen
      # graph over.
      frozen_graph_def = graph_util.convert_variables_to_constants(
          sess, sess.graph.as_graph_def(add_shapes=True),
          list(output_node_names))
      self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
      self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)

      # Copy the collections that are not variables.
      for collection_key in self._collections_to_keep(
          input_meta_graph_def.collection_def):
        self._grappler_meta_graph_def.collection_def[collection_key].CopyFrom(
            input_meta_graph_def.collection_def[collection_key])

      self._add_nodes_denylist()

      # Copy other information.
      self._grappler_meta_graph_def.meta_info_def.CopyFrom(
          input_meta_graph_def.meta_info_def)
      self._grappler_meta_graph_def.signature_def[
          self._input_saved_model_signature_key].CopyFrom(input_signature_def)
      # TODO(laigd): maybe add back AssetFileDef.

    self._run_conversion()

  def convert(self):
    """Run the TF-TRT conversion.

    Returns:
      The converted GraphDef for TF 1.x.
    """
    assert not self._converted
    if self._input_graph_def:
      self._convert_graph_def()
    else:
      self._convert_saved_model()
    return self._converted_graph_def

  def calibrate(self,
                fetch_names,
                num_runs,
                feed_dict_fn=None,
                input_map_fn=None):
    """Run the calibration and return the calibrated GraphDef.

    Args:
      fetch_names: a list of output tensor names to fetch during calibration.
      num_runs: number of runs of the graph during calibration.
      feed_dict_fn: a function that returns a dictionary mapping input names (as
        strings) in the GraphDef to be calibrated to values (e.g. Python list,
        numpy arrays, etc). One and only one of `feed_dict_fn` and
        `input_map_fn` should be specified.
      input_map_fn: a function that returns a dictionary mapping input names (as
        strings) in the GraphDef to be calibrated to Tensor objects. The values
        of the named input tensors in the GraphDef to be calibrated will be
        re-mapped to the respective `Tensor` values during calibration. One and
        only one of `feed_dict_fn` and `input_map_fn` should be specified.

    Raises:
      ValueError: if the input combination is invalid.
      RuntimeError: if this method is called in eager mode.

    Returns:
      The GraphDef after the calibration.
    """
    assert self._converted
    assert self._need_calibration
    assert not self._calibration_data_collected

    if (feed_dict_fn and input_map_fn) or (not feed_dict_fn and
                                           not input_map_fn):
      raise ValueError(
          "Should specify one and only one of feed_dict_fn and input_map_fn.")

    if input_map_fn:
      for k, v in input_map_fn().items():
        if not isinstance(k, str):
          raise ValueError("Keys of input_map_fn must be of type str")
        if not isinstance(v, ops.Tensor):
          raise ValueError("Values of input_map_fn must be of type tf.Tensor")

    self._calibration_graph = ops.Graph()
    with self._calibration_graph.as_default():
      fetches = importer.import_graph_def(
          self._converted_graph_def,
          input_map=input_map_fn() if input_map_fn else None,
          return_elements=fetch_names,
          name="")

    calibrate_rewriter_cfg = rewriter_config_pb2.RewriterConfig()
    if self._test_only_disable_non_trt_optimizers:
      trt_utils.disable_non_trt_optimizers_in_rewriter_config(
          calibrate_rewriter_cfg)

    # Set allow_soft_placement=True to run the graph for calibration so that
    # OPs that are supported by TensorRT but don't have a GPU implementation
    # are allowed to execute on CPU.
    calibrate_config = config_pb2.ConfigProto(
        allow_soft_placement=True,
        graph_options=config_pb2.GraphOptions(
            rewrite_options=calibrate_rewriter_cfg))

    with session.Session(
        graph=self._calibration_graph,
        config=calibrate_config) as calibration_sess:
      for _ in range(num_runs):
        calibration_sess.run(
            fetches, feed_dict=feed_dict_fn() if feed_dict_fn else None)

      # Maps device name to the corresponding get_calibration_data.
      #
      # TODO(laigd): a better way would be to use calibration_sess to list
      # all the devices, add one get_calibration_data for each device, and
      # fetch each such op for every resource until it's found. This can work
      # even when the device of the TRTEngineOp is empty or not fully
      # specified.
      device_to_get_resource_op_map = {}

      with self._calibration_graph.as_default():
        resource_name_input = array_ops.placeholder(dtypes.string)

        for node in self._converted_graph_def.node:
          if node.op == _TRT_ENGINE_OP_NAME:
            # Adds the get_calibration_data op for the device if not done
            # before. We only add one such op for each device.
            # TODO(laigd): what if the device is empty?
            if node.device not in device_to_get_resource_op_map:
              with self._calibration_graph.device(node.device):
                serialized_resources_output = (
                    gen_trt_ops.get_calibration_data_op(resource_name_input))
              device_to_get_resource_op_map[node.device] = (
                  serialized_resources_output)

            # Get the calibration resource.
            calibration_result = calibration_sess.run(
                device_to_get_resource_op_map[node.device],
                feed_dict={
                    resource_name_input: _get_canonical_engine_name(node.name)
                })
            node.attr["calibration_data"].s = calibration_result

      self._calibration_data_collected = True

    return self._converted_graph_def

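  # The class docstring shows calibration with feed_dict_fn; the input_map_fn
  # variant below is an equivalent sketch that wires a Tensor input pipeline
  # into the calibration graph (the names "input:0"/"output:0" and
  # make_input_pipeline_tensor() are hypothetical placeholders, for
  # illustration only):
  #
  #   converter.calibrate(
  #       fetch_names=["output:0"],
  #       num_runs=10,
  #       input_map_fn=lambda: {"input:0": make_input_pipeline_tensor()})
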
  def save(self, output_saved_model_dir):
    """Save the converted graph as a SavedModel.

    Args:
      output_saved_model_dir: construct a SavedModel using the converted
        GraphDef and save it to the specified directory. This option only works
        when the input graph is loaded from a SavedModel, i.e. when
        input_saved_model_dir is specified and input_graph_def is None in
        __init__().

    Raises:
      ValueError: if the input to the converter is a GraphDef instead of a
        SavedModel.
    """
    assert self._converted
    if self._need_calibration:
      assert self._calibration_data_collected
    if self._input_graph_def:
      raise ValueError(
          "Not able to save to a SavedModel since input is a GraphDef")

    def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
      """Restores collections that we need to keep."""
      scope = ""
      for key in collection_keys:
        collection_def = src_meta_graph_def.collection_def[key]
        kind = collection_def.WhichOneof("kind")
        if kind is None:
          tf_logging.error(
              "Cannot identify data type for collection %s. Skipping.", key)
          continue
        from_proto = ops.get_from_proto_function(key)
        if from_proto and kind == "bytes_list":
          proto_type = ops.get_collection_proto_type(key)
          # It is assumed that there are no variable keys in the collections.
          for value in collection_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            try:
              new_value = from_proto(proto, import_scope=scope)
            except Exception:  # pylint: disable=broad-except
              continue
            dest_graph.add_to_collection(key, new_value)
        else:
          field = getattr(collection_def, kind)
          if kind == "node_list":
            for value in field.value:
              name = ops.prepend_name_scope(value, scope)
              # Since the graph has been optimized, the node may no longer
              # exist.
              try:
                col_op = dest_graph.as_graph_element(name)
              except (TypeError, ValueError, KeyError):
                continue
              dest_graph.add_to_collection(key, col_op)
          elif kind == "int64_list":
            # NOTE(opensource): This force conversion is to work around the
            # fact that Python2 distinguishes between int and long, while
            # Python3 has only int.
            for value in field.value:
              dest_graph.add_to_collection(key, int(value))
          else:
            for value in field.value:
              dest_graph.add_to_collection(key,
                                           ops.prepend_name_scope(value, scope))

    # Write the transformed graphdef as SavedModel.
    saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
    with ops.Graph().as_default():
      importer.import_graph_def(self._converted_graph_def, name="")
      _restore_collections(
          ops.get_default_graph(), self._grappler_meta_graph_def,
          self._collections_to_keep(
              self._grappler_meta_graph_def.collection_def))
      # We don't use any specific converter here.
      with session.Session() as sess:
        saved_model_builder.add_meta_graph_and_variables(
            sess,
            self._input_saved_model_tags,
            signature_def_map=self._grappler_meta_graph_def.signature_def)
    # Ignore other meta graphs from the input SavedModel.
    saved_model_builder.save()


def _get_resource_handle(name, device):
  with ops.device(device):
    return gen_trt_ops.create_trt_resource_handle(resource_name=name)


class _TRTEngineResourceDeleter(tracking.CapturableResourceDeleter):
  """Resource deleter for destroying TRT engine cache resource."""

  def __init__(self, resource_name, device):
    super(_TRTEngineResourceDeleter, self).__init__()
    self._resource_name = resource_name
    self._device = device

  def destroy_resource(self):
    handle = _get_resource_handle(self._resource_name, self._device)
    with ops.device(self._device):
      gen_resource_variable_ops.destroy_resource_op(
          handle, ignore_lookup_error=True)


class _TRTEngineResource(tracking.TrackableResource):
  """Class to track the serialized engines resource."""

  def __init__(self,
               resource_name,
               filename,
               maximum_cached_engines,
               device="GPU"):
    super(_TRTEngineResource, self).__init__(
        device=device, deleter=_TRTEngineResourceDeleter(resource_name, device))
    self._resource_name = resource_name
    # Track the serialized engine file in the SavedModel.
    self._filename = self._track_trackable(
        tracking.Asset(filename), "_serialized_trt_resource_filename")
    self._maximum_cached_engines = maximum_cached_engines

  def _create_resource(self):
    return _get_resource_handle(self._resource_name, self._resource_device)

  def _initialize(self):
    gen_trt_ops.initialize_trt_resource(
        self.resource_handle,
        self._filename,
        max_cached_engines_count=self._maximum_cached_engines)


@tf_export("experimental.tensorrt.Converter", v1=[])
class TrtGraphConverterV2(object):
  """An offline converter for TF-TRT transformation for TF 2.0 SavedModels.

  Currently this is not available on the Windows platform.

  There are several ways to run the conversion:

  1. FP32/FP16 precision

     ```python
     params = tf.experimental.tensorrt.ConversionParams(
         precision_mode='FP16')
     converter = tf.experimental.tensorrt.Converter(
         input_saved_model_dir="my_dir", conversion_params=params)
     converter.convert()
     converter.save(output_saved_model_dir)
     ```

     In this case, no TRT engines will be built or saved in the converted
     SavedModel. But if input data is available during conversion, we can still
     build and save the TRT engines to reduce the cost during inference (see
     option 2 below).

  2. FP32/FP16 precision with pre-built engines

     ```python
     params = tf.experimental.tensorrt.ConversionParams(
         precision_mode='FP16',
         # Set this to a large enough number so it can cache all the engines.
         maximum_cached_engines=16)
     converter = tf.experimental.tensorrt.Converter(
         input_saved_model_dir="my_dir", conversion_params=params)
     converter.convert()

     # Define a generator function that yields input data, and use it to execute
     # the graph to build TRT engines.
     # With TensorRT 5.1, different engines will be built (and saved later) for
     # different input shapes to the TRTEngineOp.
     def my_input_fn():
       for _ in range(num_runs):
         inp1, inp2 = ...
         yield inp1, inp2

     converter.build(input_fn=my_input_fn)  # Generate corresponding TRT engines
     converter.save(output_saved_model_dir)  # Generated engines will be saved.
     ```

     In this way, one engine will be built/saved for each unique input shape of
     the TRTEngineOp. This is good for applications that cannot afford building
     engines during inference but have access to input data that is similar to
     the one used in production (for example, that has the same input shapes).
     Also, the generated TRT engines are platform dependent, so we need to run
     `build()` in an environment that is similar to production (e.g. with the
     same type of GPU).

  3. INT8 precision and calibration with pre-built engines

     ```python
     params = tf.experimental.tensorrt.ConversionParams(
         precision_mode='INT8',
         # Currently only one INT8 engine is supported in this mode.
         maximum_cached_engines=1,
         use_calibration=True)
     converter = tf.experimental.tensorrt.Converter(
         input_saved_model_dir="my_dir", conversion_params=params)

     # Define a generator function that yields input data, and run INT8
     # calibration with the data. All input data should have the same shape.
     # At the end of convert(), the calibration stats (e.g. range information)
     # will be saved and can be used to generate more TRT engines with different
     # shapes. Also, one TRT engine will be generated (with the same shape as
     # the calibration data) to be saved later.
     def my_calibration_input_fn():
       for _ in range(num_runs):
         inp1, inp2 = ...
         yield inp1, inp2

     converter.convert(calibration_input_fn=my_calibration_input_fn)

     # (Optional) Generate more TRT engines offline (same as the previous
     # option), to avoid the cost of generating them during inference.
     def my_input_fn():
       for _ in range(num_runs):
         inp1, inp2 = ...
         yield inp1, inp2
     converter.build(input_fn=my_input_fn)

     # Save the converted model and the engines.
     converter.save(output_saved_model_dir)
     ```
  """

  def __init__(self,
               input_saved_model_dir=None,
               input_saved_model_tags=None,
               input_saved_model_signature_key=None,
               conversion_params=None):
    """Initialize the converter.

    Args:
      input_saved_model_dir: the directory to load the SavedModel which
        contains the input graph to transform.
      input_saved_model_tags: list of tags to load the SavedModel.
      input_saved_model_signature_key: the key of the signature to optimize the
        graph for.
      conversion_params: a TrtConversionParams instance.

    Raises:
      ValueError: if the combination of the parameters is invalid.
    """
    assert context.executing_eagerly()
    if conversion_params is None:
      conversion_params = TrtConversionParams()

    _check_trt_version_compatibility()
    _check_conversion_params(conversion_params, is_v2=True)

    self._conversion_params = conversion_params
    self._input_saved_model_dir = input_saved_model_dir
    self._input_saved_model_tags = (
        input_saved_model_tags or [tag_constants.SERVING])
    self._input_saved_model_signature_key = (
        input_saved_model_signature_key or
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)

    self._need_calibration = (
        conversion_params.precision_mode == TrtPrecisionMode.INT8 and
        conversion_params.use_calibration)

    self._converted = False
    self._build_called_once = False

    # Fields to support TF-TRT testing; they shouldn't be used for any other
    # purpose.
    self._test_only_disable_non_trt_optimizers = False
    self._test_only_use_implicit_batch = True

  def _need_trt_profiles(self):
    return not self._test_only_use_implicit_batch

  def _run_conversion(self, meta_graph_def):
    """Run Grappler's OptimizeGraph() tool to convert the graph.

    Args:
      meta_graph_def: the MetaGraphDef instance to run the optimizations on.

    Returns:
      The optimized GraphDef.
    """
    grappler_session_config = config_pb2.ConfigProto()
    custom_rewriter_config = _get_tensorrt_rewriter_config(
        conversion_params=self._conversion_params,
        is_dynamic_op=True,
        max_batch_size=None,
        disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers,
        use_implicit_batch=self._test_only_use_implicit_batch)
    grappler_session_config.graph_options.rewrite_options.CopyFrom(
        custom_rewriter_config)
    return tf_optimizer.OptimizeGraph(
        grappler_session_config, meta_graph_def, graph_id=b"tf_graph")

  def _for_each_trt_node(self, graph_def, fn):
    """Helper method to manipulate all TRTEngineOps in a GraphDef."""
    for node in graph_def.node:
      if node.op == _TRT_ENGINE_OP_NAME:
        fn(node)
    for func in graph_def.library.function:
      for node in func.node_def:
        if node.op == _TRT_ENGINE_OP_NAME:
          fn(node)

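  # For example (an illustrative sketch only), collecting the names of all
  # TRTEngineOps in a converted graph could look like:
  #
  #   trt_op_names = []
  #   self._for_each_trt_node(self._converted_graph_def,
  #                           lambda node: trt_op_names.append(node.name))
  #   tf_logging.info("Found %d TRTEngineOps" % len(trt_op_names))
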
  def _rebuild_func(self, func):
    """Rebuild function from graph_def."""
    rebuilt_func = wrap_function.function_from_graph_def(
        self._converted_graph_def, [tensor.name for tensor in func.inputs],
        [tensor.name for tensor in func.outputs])
    rebuilt_func.graph.structured_outputs = nest.pack_sequence_as(
        func.graph.structured_outputs, rebuilt_func.graph.structured_outputs)
    # Copy structured input signature from original function (used during
    # serialization).
    rebuilt_func.graph.structured_input_signature = (
        func.structured_input_signature)
    return rebuilt_func

  # TODO(laigd): provide a utility function to optimize a ConcreteFunction and
  # use it here (b/124792963).
  def convert(self, calibration_input_fn=None):
    """Convert the input SavedModel in 2.0 format.

    Args:
      calibration_input_fn: a generator function that yields input data as a
        list or tuple, which will be used to execute the converted signature for
        calibration. All the returned input data should have the same shape.
        Example: `def input_fn(): yield input1, input2, input3`

    Raises:
      ValueError: if the input combination is invalid.

    Returns:
      The TF-TRT converted Function.
    """
    assert not self._converted

    if self._need_calibration and not calibration_input_fn:
      raise ValueError("Should specify calibration_input_fn because INT8 "
                       "calibration is needed")
    if not self._need_calibration and calibration_input_fn:
      raise ValueError("Should not specify calibration_input_fn because INT8 "
                       "calibration is not needed")

    self._saved_model = load.load(self._input_saved_model_dir,
                                  self._input_saved_model_tags)
    func = self._saved_model.signatures[self._input_saved_model_signature_key]
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
    grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for array in frozen_func.inputs + frozen_func.outputs:
      fetch_collection.node_list.value.append(array.name)
    grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run TRT optimizer in Grappler to convert the graph.
    self._converted_graph_def = self._run_conversion(grappler_meta_graph_def)
    self._converted_func = wrap_function.function_from_graph_def(
        self._converted_graph_def,
        [tensor.name for tensor in frozen_func.inputs],
        [tensor.name for tensor in frozen_func.outputs])
    # Reconstruct the output signatures using the ones from the original
    # model.
    self._converted_func.graph.structured_outputs = nest.pack_sequence_as(
        func.graph.structured_outputs,
        self._converted_func.graph.structured_outputs)
    # Copy structured input signature from original function (used during
    # serialization).
    self._converted_func.graph.structured_input_signature = (
        func.structured_input_signature)

    if self._need_calibration:
      for inp in calibration_input_fn():
        self._converted_func(*map(ops.convert_to_tensor, inp))

      def _save_calibration_table(node):
        calibration_table = gen_trt_ops.get_calibration_data_op(
            _get_canonical_engine_name(node.name))
        node.attr["calibration_data"].s = calibration_table.numpy()

      self._for_each_trt_node(self._converted_graph_def,
                              _save_calibration_table)

      # Rebuild the function since calibration has changed the graph.
      self._converted_func = self._rebuild_func(self._converted_func)

    self._converted = True
    return self._converted_func

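  # A minimal INT8 calibration sketch (illustration only; the shape and the
  # random numpy inputs below are hypothetical example values):
  #
  #   import numpy as np
  #   def my_calibration_input_fn():
  #     for _ in range(10):
  #       yield (np.random.normal(
  #           size=(1, 224, 224, 3)).astype(np.float32),)
  #   converter.convert(calibration_input_fn=my_calibration_input_fn)
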
  def build(self, input_fn):
    """Run inference with converted graph in order to build TensorRT engines.

    Args:
      input_fn: a generator function that yields input data as a list or tuple,
        which will be used to execute the converted signature to generate TRT
        engines. Example:
        `def input_fn():
             # Let's assume a network with 2 input tensors. We generate 3 sets
             # of dummy input data:
             input_shapes = [[(1, 16), (2, 16)], # 1st input list
                             [(2, 32), (4, 32)], # 2nd list of two tensors
                             [(4, 32), (8, 32)]] # 3rd input list
             for shapes in input_shapes:
                 # return a list of input tensors
                 yield [np.zeros(x).astype(np.float32) for x in shapes]`
    Raises:
      NotImplementedError: if build() was already called.
      RuntimeError: if input_fn is None.
    """
    if self._build_called_once:
      raise NotImplementedError("build() is already called. It is not "
                                "supported to call build() more than once.")
    if not input_fn:
      raise RuntimeError("input_fn is None. Method build() needs input_fn "
                         "to be specified in order to build TensorRT engines")

    def _set_profile_generation_mode(value, node):
      node.attr["_profile_generation_mode"].b = value

    if self._need_trt_profiles():
      # Enable profile generation.
      self._for_each_trt_node(self._converted_graph_def,
                              partial(_set_profile_generation_mode, True))
      # Profile generation is enabled using the _profile_generation_mode
      # attribute of the TRTEngineOps. We need to rebuild the function to
      # change this attribute.
      func = self._rebuild_func(self._converted_func)
    else:
      func = self._converted_func

    first_input = None
    # Run inference:
    #   Builds TRT engines if self._need_trt_profiles is False.
    #   Builds TRT optimization profiles if self._need_trt_profiles is True.
    for inp in input_fn():
      if first_input is None:
        first_input = inp
      func(*map(ops.convert_to_tensor, inp))

    if self._need_trt_profiles():
      # Disable profile generation.
      self._for_each_trt_node(self._converted_graph_def,
                              partial(_set_profile_generation_mode, False))
      # Use the first input in explicit batch mode to build TensorRT engines
      # after generating all the profiles. Any of the inputs could be used
      # here, since the shape of this input does not determine the engine;
      # the shapes collected in the profiles do.
      self._converted_func(*map(ops.convert_to_tensor, first_input))

    self._build_called_once = True

  def save(self, output_saved_model_dir):
    """Save the converted SavedModel.

    Args:
      output_saved_model_dir: directory to save the converted SavedModel.
    """
    assert self._converted

    if self._need_trt_profiles() and not self._build_called_once:
      raise NotImplementedError(
          "build() was not called. Explicit batch mode "
          "(use_implicit_batch=False) requires generating TensorRT "
          "optimization profiles, which is done by calling build().")

    # Serialize the TRT engines in the cache if any, and create trackable
    # resource to track them.
    engine_asset_dir = tempfile.mkdtemp()
    resource_map = {}

    def _serialize_and_track_engine(node):
      """Serialize TRT engines in the cache and track them."""
      # Don't dump the same cache twice.
      canonical_engine_name = _get_canonical_engine_name(node.name)
      if canonical_engine_name in resource_map:
        return

      filename = os.path.join(engine_asset_dir,
                              "trt-serialized-engine." + canonical_engine_name)

      try:
        gen_trt_ops.serialize_trt_resource(
            resource_name=canonical_engine_name,
            filename=filename,
            delete_resource=True)
      except errors.NotFoundError:
        tf_logging.info("Could not find %s in TF-TRT cache. "
                        "This can happen if build() is not called, "
                        "which means TensorRT engines will be built "
                        "and cached at runtime." % canonical_engine_name)
        return

      # TODO(laigd): add an option for the user to choose the device.
      resource_map[canonical_engine_name] = _TRTEngineResource(
          canonical_engine_name, filename,
          self._conversion_params.maximum_cached_engines)

    self._for_each_trt_node(self._converted_graph_def,
                            _serialize_and_track_engine)
    self._saved_model.trt_engine_resources = resource_map

    # Rewrite the signature map using the optimized ConcreteFunction.
    signatures = {
        key: value for key, value in self._saved_model.signatures.items()
    }

    # Set allow_build_at_runtime=False if asked by user.
    #
    # This attribute is set here because build() needs it to be True in order
    # to build engines.
    if not self._conversion_params.allow_build_at_runtime:

      def _reset_allow_build_at_runtime(node):
        node.attr["allow_build_at_runtime"].b = False

      self._for_each_trt_node(self._converted_graph_def,
                              _reset_allow_build_at_runtime)
      # Rebuild the function since a node attribute changed above.
      reset_converted_func = wrap_function.function_from_graph_def(
          self._converted_graph_def,
          [tensor.name for tensor in self._converted_func.inputs],
          [tensor.name for tensor in self._converted_func.outputs])
      reset_converted_func.graph.structured_outputs = nest.pack_sequence_as(
          self._converted_func.graph.structured_outputs,
          reset_converted_func.graph.structured_outputs)
      reset_converted_func.graph.structured_input_signature = (
          self._converted_func.structured_input_signature)
      self._converted_func = reset_converted_func

    signatures[self._input_saved_model_signature_key] = self._converted_func
    save.save(self._saved_model, output_saved_model_dir, signatures)


# TODO(laigd): use TrtConversionParams here.
def create_inference_graph(
    input_graph_def,
    outputs,
    max_batch_size=1,
    max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
    precision_mode=TrtPrecisionMode.FP32,
    minimum_segment_size=3,
    is_dynamic_op=False,
    maximum_cached_engines=1,
    input_saved_model_dir=None,
    input_saved_model_tags=None,
    input_saved_model_signature_key=None,
    output_saved_model_dir=None):
  """Python wrapper for the TRT transformation.

  Args:
    input_graph_def: a GraphDef object containing a model to be transformed. If
      set to None, the graph will be read from the SavedModel loaded from
      input_saved_model_dir.
    outputs: list of tensors or node names for the model outputs. Only used when
      input_graph_def is not None.
    max_batch_size: max size for the input batch.
    max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
      engine can use at execution time. This corresponds to the 'workspaceSize'
      parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
    precision_mode: one of TrtPrecisionMode.supported_precision_modes().
    minimum_segment_size: the minimum number of nodes required for a subgraph to
      be replaced by TRTEngineOp.
    is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
      network and engine at run time.
    maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
      If the number of cached engines is already at max but none of them can
      serve the input, the TRTEngineOp will fall back to run the TF function
      based on which the TRTEngineOp is created.
    input_saved_model_dir: the directory to load the SavedModel which contains
      the input graph to transform. Used only when input_graph_def is None.
    input_saved_model_tags: list of tags to load the SavedModel.
    input_saved_model_signature_key: the key of the signature to optimize the
      graph for.
    output_saved_model_dir: if not None, construct a SavedModel using the
      returned GraphDef and save it to the specified directory. This option only
      works when the input graph is loaded from a SavedModel, i.e. when
      input_saved_model_dir is specified and input_graph_def is None.

  Returns:
    A GraphDef transformed from input_graph_def (or the SavedModel graph def
    loaded from input_saved_model_dir, if input_graph_def is not present), where
    all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
    function is added for each of the subgraphs.

    If is_dynamic_op is True, each TRTEngineOp will contain a serialized
    subgraph GraphDef, which will be converted to a TRT engine at execution time
    and the TRT engine will be cached for future usage. A new TRT engine will be
    created each time when none of the cached engines match the input shapes. If
    it fails to execute the TRT engine or the number of cached engines reaches
    maximum_cached_engines, the op will fall back to call the corresponding TF
    function.

    If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
    engine created from the corresponding subgraph. No more engines will be
    created on the fly, and the op will fall back to call the corresponding TF
    function when it fails to execute the engine.

  Raises:
    ValueError: if the combination of the parameters is invalid.
  """
  trt_converter = TrtGraphConverter(
      input_saved_model_dir=input_saved_model_dir,
      input_saved_model_tags=input_saved_model_tags,
      input_saved_model_signature_key=input_saved_model_signature_key,
      input_graph_def=input_graph_def,
      nodes_denylist=outputs,
      max_batch_size=max_batch_size,
      max_workspace_size_bytes=max_workspace_size_bytes,
      precision_mode=precision_mode,
      minimum_segment_size=minimum_segment_size,
      is_dynamic_op=is_dynamic_op,
      maximum_cached_engines=maximum_cached_engines,
      use_calibration=False)
  converted_graph_def = trt_converter.convert()
  if output_saved_model_dir:
    trt_converter.save(output_saved_model_dir)
  return converted_graph_def
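
# A quick usage sketch of this legacy TF 1.x entry point (illustration only;
# the directories and the output node name "logits" below are hypothetical):
#
#   trt_graph_def = create_inference_graph(
#       input_graph_def=None,
#       outputs=["logits"],
#       input_saved_model_dir="my_model",
#       output_saved_model_dir="my_model_trt",
#       precision_mode=TrtPrecisionMode.FP16)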