# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15"""Exposes the Python wrapper conversion to trt_graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six as _six
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver


def _to_bytes(s):
  """Encode s if it is a sequence of chars."""
  if isinstance(s, _six.text_type):
    return s.encode("utf-8", errors="surrogateescape")
  return s


def _to_string(s):
  """Decode s if it is a sequence of bytes."""
  if isinstance(s, _six.binary_type):
    return s.decode("utf-8")
  return s


class GraphConverter(object):
  """Base class for offline converters to optimize SavedModels/GraphDefs.

  A `GraphConverter` object encapsulates the environment to convert (optimize) a
  TensorFlow SavedModel or GraphDef.

  To create a custom GraphConverter:

  ```python
  class MyGraphConverter(GraphConverter):
    ...

    def get_rewriter_config(self, rewriter_config_template=None):
      my_rewriter_config = ...
      return my_rewriter_config
  ```

  Then to run the conversion without quantization calibration:

  ```python
  my_converter = MyGraphConverter(input_saved_model_dir="my_dir")
  converted_graph_def = my_converter.convert()
  my_converter.save(output_saved_model_dir)  # Optional
  ```

  To run the conversion with quantization calibration:

  ```python
  my_converter = MyGraphConverter(input_saved_model_dir="my_dir")
  my_converter.convert()

  # Run calibration 10 times.
  converted_graph_def = my_converter.calibrate(
      fetch_names=['output:0'],
      num_runs=10,
      feed_dict_fn=lambda: {'input:0': my_next_data()})

  my_converter.save(output_saved_model_dir)  # Optional
  ```
  """

  # TODO(laigd): clean up the parameters.
  def __init__(self,
               input_saved_model_dir=None,
               input_saved_model_tags=None,
               input_saved_model_signature_key=None,
               input_graph_def=None,
               nodes_blacklist=None,
               session_config=None):
    """Initialize the converter.

    Args:
      input_saved_model_dir: the directory to load the SavedModel which
        contains the input graph to transform. Used only when input_graph_def
        is None.
      input_saved_model_tags: list of tags to load the SavedModel.
      input_saved_model_signature_key: the key of the signature to optimize the
        graph for.
      input_graph_def: a GraphDef object containing a model to be transformed.
        If set to None, the graph will be read from the SavedModel loaded from
        input_saved_model_dir.
      nodes_blacklist: list of node names to prevent the converter from
        touching. Only used when input_graph_def is not None.
      session_config: the ConfigProto used to create a Session. It's also used
        as a template to create a RewriterConfig for conversion. If not
        specified, a default ConfigProto will be used.
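
    For example, a sketch of passing a custom session_config, using the
    hypothetical MyGraphConverter subclass from the class docstring:

    ```python
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    converter = MyGraphConverter(
        input_saved_model_dir="my_dir", session_config=session_config)
    ```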

    Raises:
      ValueError: if the combination of the parameters is invalid.
    """
    if context.executing_eagerly():
      if input_graph_def or not input_saved_model_dir:
        raise ValueError(
            "TF 2.0 only supports conversion of SavedModel, please specify "
            "input_saved_model_dir as input.")
    else:
      if input_graph_def and input_saved_model_dir:
        raise ValueError(
            "Can only specify one of input_graph_def and input_saved_model_dir")
      if not input_graph_def and not input_saved_model_dir:
        raise ValueError("Must specify one of input_graph_def and "
                         "input_saved_model_dir")

      self._input_graph_def = input_graph_def
      self._nodes_blacklist = nodes_blacklist

    self._input_saved_model_dir = input_saved_model_dir
    self._converted = False
    self._grappler_meta_graph_def = None

    self._input_saved_model_tags = (
        input_saved_model_tags or [tag_constants.SERVING])
    self._input_saved_model_signature_key = (
        input_saved_model_signature_key or
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
    self._session_config = session_config or config_pb2.ConfigProto()

    # For calibration usage.
    self._calibration_graph = None
    self._calibration_sess = None
    self._calibration_data_collected = False
  def get_rewriter_config(self, rewriter_config_template=None):
    """Returns a RewriterConfig proto for TRT transformation.

    Args:
      rewriter_config_template: a template RewriterConfig proto used to create a
        RewriterConfig for the conversion. The implementation should not modify
        the template. If None, it will use a default one.

    Returns:
      A RewriterConfig proto which will be used to run the conversion using
      Grappler.
    """
    raise NotImplementedError("get_rewriter_config")

  def _run_conversion(self):
    """Run Grappler's OptimizeGraph() tool to convert the graph."""
    # Create custom ConfigProto for Grappler.
    grappler_session_config = config_pb2.ConfigProto()
    grappler_session_config.CopyFrom(self._session_config)
    rewriter_config = None
    if (grappler_session_config.HasField("graph_options") and
        grappler_session_config.graph_options.HasField("rewrite_options")):
      rewriter_config = grappler_session_config.graph_options.rewrite_options
    custom_rewriter_config = self.get_rewriter_config(rewriter_config)
    grappler_session_config.graph_options.rewrite_options.CopyFrom(
        custom_rewriter_config)

    # Run Grappler.
    self._converted_graph_def = tf_optimizer.OptimizeGraph(
        grappler_session_config,
        self._grappler_meta_graph_def,
        graph_id=b"tf_graph")
    self._converted = True

  def _add_nodes_blacklist(self):
    if self._nodes_blacklist:
      collection_def = self._grappler_meta_graph_def.collection_def["train_op"]
      blacklist = collection_def.node_list.value
      for i in self._nodes_blacklist:
        if isinstance(i, ops.Tensor):
          blacklist.append(_to_bytes(i.name))
        else:
          blacklist.append(_to_bytes(i))

  def _convert_graph_def(self):
    """Convert the input GraphDef."""
    graph = ops.Graph()
    with graph.as_default():
      importer.import_graph_def(self._input_graph_def, name="")
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
    self._add_nodes_blacklist()

    self._run_conversion()

  def _convert_saved_model(self):
    """Convert the input SavedModel."""
    graph = ops.Graph()
    with session.Session(graph=graph, config=self._session_config) as sess:
      input_meta_graph_def = loader.load(sess, self._input_saved_model_tags,
                                         self._input_saved_model_dir)
      input_signature_def = input_meta_graph_def.signature_def[
          self._input_saved_model_signature_key]

      def _gather_names(tensor_info):
227        """Get the node names from a TensorInfo."""
228        return set([tensor_info[key].name.split(":")[0] for key in tensor_info])
229
230      # Get input and outputs from all SignatureDef.
      output_node_names = _gather_names(input_signature_def.inputs).union(
          _gather_names(input_signature_def.outputs))

      # Freeze the variables in the SavedModel graph and copy the frozen
      # graph over.
      frozen_graph_def = graph_util.convert_variables_to_constants(
          sess, sess.graph.as_graph_def(add_shapes=True),
          list(output_node_names))
      self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
      self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)

      # Copy the collections that are not variables.
      for key in input_meta_graph_def.collection_def:
        # TODO(laigd): currently we use the collection key to filter out
        # collections that depend on variable ops, but this may miss some
        # other user-defined collections. A better way would be to use
        # CollectionDef::NodeList for the filtering.
        if key not in [
            "variables", "local_variables", "model_variables",
            "trainable_variables", "train_op", "table_initializer"
        ]:
          self._grappler_meta_graph_def.collection_def[key].CopyFrom(
              input_meta_graph_def.collection_def[key])

      self._add_nodes_blacklist()

      # Copy other information.
      self._grappler_meta_graph_def.meta_info_def.CopyFrom(
          input_meta_graph_def.meta_info_def)
      self._grappler_meta_graph_def.signature_def[
          self._input_saved_model_signature_key].CopyFrom(input_signature_def)
      # TODO(laigd): maybe add back AssetFileDef.

    self._run_conversion()

  # TODO(laigd): provide a utility function to optimize a ConcreteFunction and
  # use it here (b/124792963).
  def _convert_saved_model_v2(self):
    """Convert the input SavedModel in 2.0 format."""
    self._saved_model = load.load(self._input_saved_model_dir,
                                  self._input_saved_model_tags)
    func = self._saved_model.signatures[self._input_saved_model_signature_key]
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for array in func.inputs + func.outputs:
      fetch_collection.node_list.value.append(array.name)
    self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run TRT optimizer in Grappler to convert the graph.
    self._run_conversion()

    def _get_tensor(graph, tensors):
      new_tensors = []
      for tensor in tensors:
        new_tensor = graph.get_tensor_by_name(tensor.name)
        new_tensor.set_shape(tensor.shape)
        new_tensors.append(new_tensor)
      return new_tensors

    # TODO(laigd): do we need to use different name e.g. "trt_func_graph"?
    converted_graph = func_graph.FuncGraph(func.graph.name)
    with converted_graph.as_default():
      importer.import_graph_def(self._converted_graph_def, name="")

    converted_graph.inputs = _get_tensor(converted_graph, func.graph.inputs)
    converted_graph.outputs = _get_tensor(converted_graph, func.graph.outputs)
    converted_graph.structured_outputs = func.graph.structured_outputs
    converted_graph.structured_input_signature = (
        func.graph.structured_input_signature)

    # pylint: disable=protected-access
    # TODO(laigd): should we set up the signature as well?
    self._converted_func = function.ConcreteFunction(
        converted_graph, attrs=None, signature=None)
    self._converted_func.add_to_graph()
    self._converted_func._arg_keywords = func._arg_keywords
    self._converted_func._num_positional_args = func._num_positional_args
    self._converted_func._captured_inputs = func._captured_inputs
    self._converted_func.graph.variables = func.graph.variables
    # pylint: enable=protected-access

  def convert(self):
    """Run the conversion.

    Returns:
      The converted GraphDef for TF 1.x, or the converted ConcreteFunction in TF
      2.0+.
    """
    assert not self._converted

    if context.executing_eagerly():
      self._convert_saved_model_v2()
      return self._converted_func
    else:
      if self._input_graph_def:
        self._convert_graph_def()
      else:
        self._convert_saved_model()
      return self._converted_graph_def

  def calibrate(self,
                fetch_names,
                num_runs,
                feed_dict_fn=None,
                input_map_fn=None):
    """Run the calibration and return the calibrated GraphDef.

    Args:
      fetch_names: a list of output tensor names to fetch during calibration.
      num_runs: number of runs of the graph during calibration.
      feed_dict_fn: a function that returns a dictionary mapping input names (as
        strings) in the GraphDef to be calibrated to values (e.g. Python list,
        numpy arrays, etc). One and only one of `feed_dict_fn` and
        `input_map_fn` should be specified.
      input_map_fn: a function that returns a dictionary mapping input names (as
        strings) in the GraphDef to be calibrated to Tensor objects. The values
        of the named input tensors in the GraphDef to be calibrated will be
        re-mapped to the respective `Tensor` values during calibration. One and
        only one of `feed_dict_fn` and `input_map_fn` should be specified.
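
    For example, a sketch with hypothetical tensor names "input:0" and
    "output:0" (substitute your model's real names):

    ```python
    # Feed numpy data directly:
    converter.calibrate(
        fetch_names=["output:0"],
        num_runs=10,
        feed_dict_fn=lambda: {"input:0": my_next_data()})

    # Or remap the graph input to a Tensor; the function is called inside the
    # calibration graph, so it may build ops such as an iterator's get_next():
    converter.calibrate(
        fetch_names=["output:0"],
        num_runs=10,
        input_map_fn=lambda: {"input:0": my_iterator.get_next()})
    ```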

    Raises:
      ValueError: if the input combination is invalid.
      RuntimeError: if this method is called in eager mode.

    Returns:
      The GraphDef after the calibration.
    """
    assert self._converted
    assert not self._calibration_sess

    if context.executing_eagerly():
      raise RuntimeError("Calibration for TF 2.0 is not supported yet.")

    if (feed_dict_fn and input_map_fn) or (not feed_dict_fn and
                                           not input_map_fn):
      raise ValueError(
          "Should specify one and only one of feed_dict_fn and input_map_fn.")

    self._calibration_graph = ops.Graph()
    with self._calibration_graph.as_default():
      fetches = importer.import_graph_def(
          self._converted_graph_def,
          input_map=input_map_fn() if input_map_fn else None,
          return_elements=fetch_names,
          name="")
    self._calibration_sess = session.Session(
        graph=self._calibration_graph, config=self._session_config)

    for _ in range(num_runs):
      self._calibration_sess.run(
          fetches, feed_dict=feed_dict_fn() if feed_dict_fn else None)

    self.finalize_calibration()
    return self._converted_graph_def

  def finalize_calibration(self):
    """Clean up calibration resources and finalize the calibration.

    Implementations need to close self._calibration_sess before returning.
    """
    raise NotImplementedError("finalize_calibration")

  def save(self, output_saved_model_dir):
    """Save the converted graph as a SavedModel.

    Args:
      output_saved_model_dir: the directory in which to save a SavedModel built
        from the converted GraphDef. This option only works when the input
        graph is loaded from a SavedModel, i.e. when input_saved_model_dir is
        specified and input_graph_def is None in __init__().

    Raises:
      ValueError: if the input to the converter is a GraphDef instead of a
        SavedModel.
    """
    assert self._converted

    if context.executing_eagerly():
      # Rewrite the signature map using the optimized ConcreteFunction.
      signatures = {
          key: value for key, value in self._saved_model.signatures.items()
      }
      signatures[self._input_saved_model_signature_key] = self._converted_func
      save.save(self._saved_model, output_saved_model_dir, signatures)
    else:
      if self._input_graph_def:
        raise ValueError(
            "Not able to save to a SavedModel since input is a GraphDef")

      # Write the transformed graphdef as SavedModel.
      saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
      with ops.Graph().as_default():
        importer.import_graph_def(self._converted_graph_def, name="")
        # We don't use any specific converter here.
        with session.Session(config=self._session_config) as sess:
          saved_model_builder.add_meta_graph_and_variables(
              sess,
              self._input_saved_model_tags,
              signature_def_map=self._grappler_meta_graph_def.signature_def)
      # Ignore other meta graphs from the input SavedModel.
      saved_model_builder.save()


class TrtPrecisionMode(object):
  FP32 = "FP32"
  FP16 = "FP16"
  INT8 = "INT8"

  @staticmethod
  def supported_precision_modes():
    return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8]

# Use a large enough default max_workspace_size (1 GiB) for TRT engines, so
# they can produce reasonable performance results out of the box.
DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30


class TrtGraphConverter(GraphConverter):
456  """A GraphConverter for TRT transformation."""

  _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF_TRT_Calibration"

  @classmethod
  def get_tensorrt_rewriter_config(
      cls,
      rewriter_config_template=None,
      max_batch_size=1,
      max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
      precision_mode=TrtPrecisionMode.FP32,
      minimum_segment_size=3,
      is_dynamic_op=False,
      maximum_cached_engines=1,
      cached_engine_batches=None,
      use_calibration=True,
      use_function_backup=True):
    """Returns a RewriterConfig proto for TRT transformation.

    Args:
      rewriter_config_template: a template RewriterConfig proto used to create a
        TRT-enabled RewriterConfig. If None, it will use a default one.
      max_batch_size: max size for the input batch.
      max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
        engine can use at execution time. This corresponds to the
        'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
      precision_mode: one of TrtPrecisionMode.supported_precision_modes().
      minimum_segment_size: the minimum number of nodes required for a subgraph
        to be replaced by TRTEngineOp.
      is_dynamic_op: whether to generate dynamic TRT ops which will build the
        TRT network and engine at run time.
      maximum_cached_engines: max number of cached TRT engines in dynamic TRT
        ops. If the number of cached engines is already at max but none of them
        can serve the input, the TRTEngineOp will fall back to run the TF
        function based on which the TRTEngineOp is created.
      cached_engine_batches: a list of batch sizes used to create cached
        engines, only used when is_dynamic_op is True. The length of the list
        should be <= maximum_cached_engines, and the dynamic TRT op will use
        this list to determine the batch sizes of the cached engines, instead of
        making the decision on the fly. This is useful when we know the most
        common batch size(s) the application is going to generate.
      use_calibration: this argument is ignored if precision_mode is not INT8.
        If set to True, a calibration graph will be created to calibrate the
        missing ranges. The calibration graph must be converted to an inference
        graph by running calibration with calibrate(). If set to False,
        quantization nodes will be expected for every tensor in the graph
        (excluding those which will be fused). If a range is missing, an error
        will occur. Please note that accuracy may be negatively affected if
        there is a mismatch between which tensors TRT quantizes and which
        tensors were trained with fake quantization.
      use_function_backup: if set to True, it will create a FunctionDef for each
        subgraph that is converted to TRT op, and if TRT ops fail to execute at
        runtime, it'll invoke that function as a fallback.

    Returns:
      A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.

    Raises:
      TypeError: if any of the parameters are of unexpected type.
      ValueError: if any of the parameters are of unexpected value.
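
    For example, a sketch of plugging the returned proto into a ConfigProto,
    mirroring what _run_conversion() does internally:

    ```python
    rewriter_config = TrtGraphConverter.get_tensorrt_rewriter_config(
        precision_mode=TrtPrecisionMode.FP16, is_dynamic_op=True)
    session_config = config_pb2.ConfigProto()
    session_config.graph_options.rewrite_options.CopyFrom(rewriter_config)
    ```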
516    """
517    # Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
518    # even if it cannot find TensorRT library.
519    trt_ops.load_trt_ops()
520    # pylint: disable=g-import-not-at-top,unused-import,line-too-long,unused-variable
521    # Import a random symbol to trigger loading of TRT library.
522    from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
523    # pylint: enable=g-import-not-at-top,unused-import,line-too-long,unused-variable

    if rewriter_config_template is not None and not isinstance(
        rewriter_config_template, rewriter_config_pb2.RewriterConfig):
      raise TypeError(
          "rewriter_config_template should be a RewriterConfig proto.")

    rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
    if rewriter_config_template is None:
      # Layout optimizer may add Const nodes followed by Reshape nodes, thus we
      # need to run constant folding again.
      rewriter_config_with_trt.optimizers.extend(
          ["constfold", "layout", "constfold"])
      rewriter_config_with_trt.meta_optimizer_iterations = (
          rewriter_config_pb2.RewriterConfig.ONE)
    else:
      rewriter_config_with_trt.CopyFrom(rewriter_config_template)

    optimizer = rewriter_config_with_trt.custom_optimizers.add()
    optimizer.name = "TensorRTOptimizer"
    optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
    optimizer.parameter_map["max_batch_size"].i = max_batch_size
    optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
    optimizer.parameter_map[
        "max_workspace_size_bytes"].i = max_workspace_size_bytes
    optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
    optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
    if cached_engine_batches:
      optimizer.parameter_map["cached_engine_batches"].list.i.extend(
          cached_engine_batches)
    optimizer.parameter_map["use_calibration"].b = use_calibration
    optimizer.parameter_map["use_function_backup"].b = use_function_backup
    return rewriter_config_with_trt

  def __init__(self,
               input_saved_model_dir=None,
               input_saved_model_tags=None,
               input_saved_model_signature_key=None,
               input_graph_def=None,
               nodes_blacklist=None,
               session_config=None,
               max_batch_size=1,
               max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
               precision_mode=TrtPrecisionMode.FP32,
               minimum_segment_size=3,
               is_dynamic_op=False,
               maximum_cached_engines=1,
               cached_engine_batches=None,
               use_calibration=True,
               use_function_backup=True):
    """Initialize the converter.

    Args:
      input_saved_model_dir: the directory to load the SavedModel which
        contains the input graph to transform. Used only when input_graph_def
        is None.
      input_saved_model_tags: list of tags to load the SavedModel.
      input_saved_model_signature_key: the key of the signature to optimize the
        graph for.
      input_graph_def: a GraphDef object containing a model to be transformed.
        If set to None, the graph will be read from the SavedModel loaded from
        input_saved_model_dir.
      nodes_blacklist: list of node names to prevent the converter from
        touching. Only used when input_graph_def is not None.
      session_config: the ConfigProto used to create a Session. It's also used
        as a template to create a TRT-enabled ConfigProto for conversion. If not
        specified, a default ConfigProto will be used.
      max_batch_size: max size for the input batch.
      max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
        engine can use at execution time. This corresponds to the
        'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
      precision_mode: one of TrtPrecisionMode.supported_precision_modes().
      minimum_segment_size: the minimum number of nodes required for a subgraph
        to be replaced by TRTEngineOp.
      is_dynamic_op: whether to generate dynamic TRT ops which will build the
        TRT network and engine at run time.
      maximum_cached_engines: max number of cached TRT engines in dynamic TRT
        ops. If the number of cached engines is already at max but none of them
        can serve the input, the TRTEngineOp will fall back to run the TF
        function based on which the TRTEngineOp is created.
      cached_engine_batches: a list of batch sizes used to create cached
        engines, only used when is_dynamic_op is True. The length of the list
        should be <= maximum_cached_engines, and the dynamic TRT op will use
        this list to determine the batch sizes of the cached engines, instead of
        making the decision on the fly. This is useful when we know the most
        common batch size(s) the application is going to generate.
      use_calibration: this argument is ignored if precision_mode is not INT8.
        If set to True, a calibration graph will be created to calibrate the
        missing ranges. The calibration graph must be converted to an inference
        graph by running calibration with calibrate(). If set to False,
        quantization nodes will be expected for every tensor in the graph
        (excluding those which will be fused). If a range is missing, an error
        will occur. Please note that accuracy may be negatively affected if
        there is a mismatch between which tensors TRT quantizes and which
        tensors were trained with fake quantization.
      use_function_backup: if set to True, it will create a FunctionDef for each
        subgraph that is converted to TRT op, and if TRT ops fail to execute at
        runtime, it'll invoke that function as a fallback.

    Raises:
      ValueError: if the combination of the parameters is invalid.
      RuntimeError: if the TensorRT library version is incompatible.
    """
    super(TrtGraphConverter, self).__init__(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_tags=input_saved_model_tags,
        input_saved_model_signature_key=input_saved_model_signature_key,
        input_graph_def=input_graph_def,
        nodes_blacklist=nodes_blacklist,
        session_config=session_config)

    # TODO(laigd): move all the validations below to
    # get_tensorrt_rewriter_config().

    # Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't
    # complain even if it cannot find the TensorRT library.
    trt_ops.load_trt_ops()
    # pylint: disable=g-import-not-at-top,line-too-long
    from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
    from tensorflow.python.compiler.tensorrt.wrap_conversion import get_loaded_tensorrt_version
    # pylint: enable=g-import-not-at-top,line-too-long

    # Check compatibility of TensorRT version.
    compiled_version = get_linked_tensorrt_version()
    loaded_version = get_loaded_tensorrt_version()
    tf_logging.info("Linked TensorRT version: %s" % str(compiled_version))
    tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version))
    version_mismatch = False
    if loaded_version[0] < compiled_version[0]:
      tf_logging.error(
          "TensorRT version mismatch. Tensorflow was compiled against " +
          "TensorRT %s but library loaded from environment is TensorRT %s" %
          (".".join([str(x) for x in compiled_version]),
           ".".join([str(x) for x in loaded_version])) +
          ". Please make sure that the correct version of TensorRT is "
          "available in the system and added to ldconfig or LD_LIBRARY_PATH.")
      raise RuntimeError("Incompatible TensorRT library version")
    for i in zip(loaded_version, compiled_version):
      if i[0] != i[1]:
        tf_logging.warn("TensorRT version mismatch. Compiled against version " +
                        "%s, but loaded %s. Things may not work." %
                        (".".join([str(x) for x in compiled_version]),
                         ".".join([str(x) for x in loaded_version])))
        version_mismatch = True
        break
    if not version_mismatch:
      tf_logging.info("Running against TensorRT version %s" %
                      ".".join([str(x) for x in loaded_version]))

    # Check input arguments.
    supported_precision_modes = TrtPrecisionMode.supported_precision_modes()
    if precision_mode not in supported_precision_modes:
      raise ValueError(("precision mode '{}' is not supported. "
                        "It should be one of {}").format(
                            precision_mode, supported_precision_modes))

    if cached_engine_batches:
      if not isinstance(cached_engine_batches, list):
        raise TypeError("cached_engine_batches should be a list.")
      if len(cached_engine_batches) > maximum_cached_engines:
        raise ValueError("cached_engine_batches should not contain more than "
                         "maximum_cached_engines items.")

    self._need_calibration = (
        precision_mode == TrtPrecisionMode.INT8 and use_calibration)
    self._use_function_backup = use_function_backup

    # TODO(laigd): consider providing a mechanism to remove the fallback path
    # after calibration is done.
    if self._need_calibration and not use_function_backup:
      raise ValueError(
          "Calibration requires enabling fallback to TF function execution.")

    # TODO(laigd):
    # - Get rid of is_dynamic_op option, it should always be True, and it should
    #   accept N shapes as input.
    # - Verify in int8 mode that maximum_cached_engines and
    #   cached_engine_batches are set appropriately.
    # - If it fails to build the int8 engine it should return an error.
    self._max_batch_size = max_batch_size
    self._max_workspace_size_bytes = max_workspace_size_bytes
    self._precision_mode = precision_mode
    self._minimum_segment_size = minimum_segment_size
    self._is_dynamic_op = is_dynamic_op
    self._maximum_cached_engines = maximum_cached_engines
    self._cached_engine_batches = cached_engine_batches

  def get_rewriter_config(self, rewriter_config_template=None):
    return TrtGraphConverter.get_tensorrt_rewriter_config(
        rewriter_config_template,
        max_batch_size=self._max_batch_size,
        max_workspace_size_bytes=self._max_workspace_size_bytes,
        precision_mode=self._precision_mode,
        minimum_segment_size=self._minimum_segment_size,
        is_dynamic_op=self._is_dynamic_op,
        maximum_cached_engines=self._maximum_cached_engines,
        cached_engine_batches=self._cached_engine_batches,
        use_calibration=self._need_calibration,
        use_function_backup=self._use_function_backup)

  def finalize_calibration(self):
    assert self._need_calibration
    assert self._converted
    assert not self._calibration_data_collected

    # Lazily load the op, since it's not available in cpu-only builds. Importing
    # it at the top will cause tests that import TF-TRT to fail when they're
    # built and run without CUDA/GPU.
    # pylint: disable=g-import-not-at-top,line-too-long
    from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import get_serialized_resource_op
    # pylint: enable=g-import-not-at-top,line-too-long

    # TODO(laigd): a better way would be to use self._calibration_sess to list
    # all the devices, add one get_serialized_resource_op for each device, and
    # fetch each such op for every resource until it's found. This can work
    # even when the device of the TRTEngineOp is empty or not fully specified.

    # Maps device name to the corresponding get_serialized_resource_op.
    device_to_get_resource_op_map = {}

    with self._calibration_graph.as_default():
      container_input = array_ops.placeholder(dtypes.string)
      resource_name_input = array_ops.placeholder(dtypes.string)

      for node in self._converted_graph_def.node:
        if node.op == "TRTEngineOp":
          # Adds the get_serialized_resource_op for the device if not done
          # before. We only add one such op for each device.
          # TODO(laigd): What if the device is empty?????
          if node.device not in device_to_get_resource_op_map:
            with self._calibration_graph.device(node.device):
              serialized_resources_output = (
                  get_serialized_resource_op(container_input,
                                             resource_name_input))
            device_to_get_resource_op_map[node.device] = (
                serialized_resources_output)

          # Get the calibration resource.
          calibration_result = self._calibration_sess.run(
              device_to_get_resource_op_map[node.device],
              feed_dict={
                  container_input:
                      TrtGraphConverter
                      ._TRT_CALIBRATION_RESOURCE_CONTAINER_NAME,
                  resource_name_input:
                      node.name
              })
          node.attr["calibration_data"].s = calibration_result

    self._calibration_data_collected = True
    self._calibration_sess.close()

  def save(self, output_saved_model_dir):
    """Save the converted graph as a SavedModel."""
    if self._need_calibration:
      assert self._calibration_data_collected
    super(TrtGraphConverter, self).save(output_saved_model_dir)


def create_inference_graph(
    input_graph_def,
    outputs,
    max_batch_size=1,
    max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
    precision_mode=TrtPrecisionMode.FP32,
    minimum_segment_size=3,
    is_dynamic_op=False,
    maximum_cached_engines=1,
    cached_engine_batches=None,
    input_saved_model_dir=None,
    input_saved_model_tags=None,
    input_saved_model_signature_key=None,
    output_saved_model_dir=None,
    session_config=None):
796  """Python wrapper for the TRT transformation.
797
798  Args:
799    input_graph_def: a GraphDef object containing a model to be transformed. If
800      set to None, the graph will be read from the SavedModel loaded from
801      input_saved_model_dir.
802    outputs: list of tensors or node names for the model outputs. Only used when
803      input_graph_def is not None.
804    max_batch_size: max size for the input batch.
805    max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
806      engine can use at execution time. This corresponds to the 'workspaceSize'
807      parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
808    precision_mode: one of TrtPrecisionMode.supported_precision_modes().
809    minimum_segment_size: the minimum number of nodes required for a subgraph to
810      be replaced by TRTEngineOp.
811    is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
812      network and engine at run time.
813    maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
814      If the number of cached engines is already at max but none of them can
815      serve the input, the TRTEngineOp will fall back to run the TF function
816      based on which the TRTEngineOp is created.
817    cached_engine_batches: a list of batch sizes used to create cached engines,
818      only used when is_dynamic_op is True. The length of the list should be <=
819      maximum_cached_engines, and the dynamic TRT op will use this list to
820      determine the batch sizes of the cached engines, instead of making the
821      decision on the fly. This is useful when we know the most common batch
822      size(s) the application is going to generate.
    input_saved_model_dir: the directory to load the SavedModel which contains
      the input graph to transform. Used only when input_graph_def is None.
    input_saved_model_tags: list of tags to load the SavedModel.
    input_saved_model_signature_key: the key of the signature to optimize the
      graph for.
    output_saved_model_dir: if not None, construct a SavedModel using the
      returned GraphDef and save it to the specified directory. This option only
      works when the input graph is loaded from a SavedModel, i.e. when
      input_saved_model_dir is specified and input_graph_def is None.
    session_config: the ConfigProto used to create a Session. It's also used as
      a template to create a TRT-enabled ConfigProto for conversion. If not
      specified, a default ConfigProto will be used.

  Returns:
    A GraphDef transformed from input_graph_def (or the SavedModel graph def
    loaded from input_saved_model_dir, if input_graph_def is not present), where
    all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
    function is added for each of the subgraphs.

    If is_dynamic_op is True, each TRTEngineOp will contain a serialized
    subgraph GraphDef, which will be converted to a TRT engine at execution time
    and the TRT engine will be cached for future usage. A new TRT engine will be
    created each time when none of the cached engines match the input shapes. If
    it fails to execute the TRT engine or the number of cached engines reaches
    maximum_cached_engines, the op will fall back to call the corresponding TF
    function.

    If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
    engine created from the corresponding subgraph. No more engines will be
    created on the fly, and the op will fall back to call the corresponding TF
    function when it fails to execute the engine.

  Raises:
    ValueError: if the combination of the parameters is invalid.
  """
  trt_converter = TrtGraphConverter(
      input_saved_model_dir=input_saved_model_dir,
      input_saved_model_tags=input_saved_model_tags,
      input_saved_model_signature_key=input_saved_model_signature_key,
      input_graph_def=input_graph_def,
      nodes_blacklist=outputs,
      session_config=session_config,
      max_batch_size=max_batch_size,
      max_workspace_size_bytes=max_workspace_size_bytes,
      precision_mode=precision_mode,
      minimum_segment_size=minimum_segment_size,
      is_dynamic_op=is_dynamic_op,
      maximum_cached_engines=maximum_cached_engines,
      cached_engine_batches=cached_engine_batches,
      use_calibration=False)
  converted_graph_def = trt_converter.convert()
  if output_saved_model_dir:
    trt_converter.save(output_saved_model_dir)
  return converted_graph_def