# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to test TF-TensorRT integration."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import errno
import gc
import itertools
import os
import re
import shutil
import tempfile
import warnings

import numpy as np
import six

from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt import utils as trt_utils
from tensorflow.python.eager import def_function
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest

TfTrtIntegrationTestParams = collections.namedtuple(
    "TfTrtIntegrationTestParams",
    [
        # A function that creates the TF graph for testing.
        "graph_fn",
        # A list of specifications for input tensors.
        "input_specs",
        # A list of specifications for output tensors.
        "output_specs",
        # A list of lists of input shapes. Each shape must match the
        # corresponding element in `input_specs`.
        "input_dims",
        # A list of lists of expected output shapes. Each shape must match the
        # corresponding element in `output_specs`.
        "expected_output_dims"
    ])

RunParams = collections.namedtuple(
    "RunParams",
    [
        # Whether to run the conversion online with RewriterConfig, or offline
        # with TrtGraphConverter.
        "convert_online",
        "precision_mode",
        "dynamic_engine",
        "use_calibration",
        "test_name",
        # Is this test for TF 2.0?
        "is_v2",
    ])

FP32 = "FP32"
FP16 = "FP16"
INT8 = "INT8"
PRECISION_MODES = [FP32, FP16, INT8]


def IsQuantizationMode(mode):
  return mode == "INT8"


def IsQuantizationWithCalibration(params):
  return IsQuantizationMode(params.precision_mode) and params.use_calibration


def IsQuantizationWithoutCalibration(params):
  return IsQuantizationMode(
      params.precision_mode) and not params.use_calibration


class GraphState(object):
  ORIGINAL = 0
  CALIBRATE = 1
  INFERENCE = 2


class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
  """Class to test Tensorflow-TensorRT integration."""

  @property
  def trt_incompatible_op(self):
    return math_ops.erfc

  @property
  def precision_modes(self):
    return ["FP32", "FP16", "INT8"]

  # str is bytes in py2, but unicode in py3.
  def _ToUnicode(self, s):
    if six.PY2:
      if isinstance(s, unicode):
        return s
      return s.decode("utf-8")
    else:
      if isinstance(s, str):
        return s
      return s.decode("utf-8")

  def _ToBytes(self, s):
    if six.PY2:
      if isinstance(s, unicode):
        return s.encode("utf-8")
      return s
    else:
      if isinstance(s, str):
        return s.encode("utf-8")
      return s

  def _ToString(self, s):
    if six.PY2:
      if isinstance(s, unicode):
        return s.encode("utf-8")
      return s
    else:
      if isinstance(s, str):
        return s
      return s.decode("utf-8")

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(TfTrtIntegrationTestBase, self).__init__(methodName)
    self._trt_test_params = None
    self._disable_non_trt_optimizers = False
    self._use_implicit_batch = True
    self._profile_strategy = "Unknown"

  def setUp(self):
    """Setup method."""
    super(TfTrtIntegrationTestBase, self).setUp()
    warnings.simplefilter("always")

    if not is_tensorrt_enabled():
      self.skipTest("Test requires TensorRT")

  def _GetTensorSpec(self, shape, mask, dtype, name):
    # Set dimension i to None if mask[i] == False.
    assert len(shape) == len(mask)
    new_shape = [s if m else None for s, m in zip(shape, mask)]
    return tensor_spec.TensorSpec(new_shape, dtype, name)

  def BuildParams(self, graph_fn, dtype, input_shapes, output_shapes):
    """Build test parameters.

    The input_shapes and output_shapes arguments are known (static) shapes that
    can be used to generate test data. To define the model, we also specify
    corresponding input/output TensorSpecs. These are defined using the shape
    arguments. For each input tensor we define:

    input_spec = [None] + input_shape[1:]

    and similarly for output shapes. This means that we leave the first (batch)
    dimension unknown; the rest is just copied from the shapes arg.

    Args:
      graph_fn: The function to build the graph.
      dtype: The element type.
      input_shapes: The input shapes.
      output_shapes: The output shapes.

    Returns:
      The test parameters.
    """

    input_mask = [[False] + [True] * (len(shape) - 1) for shape in input_shapes]
    output_mask = [
        [False] + [True] * (len(shape) - 1) for shape in output_shapes
    ]

    return self.BuildParamsWithMask(graph_fn, dtype, input_shapes,
                                    output_shapes, input_mask, output_mask, [],
                                    [])
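
  # A minimal usage sketch (hypothetical; not called by the tests): for a model
  # with one [8, 12] input and one [8, 12] output, BuildParams leaves the batch
  # dimension unknown, so the resulting input spec is
  # TensorSpec([None, 12], name="input_0").
  def _ExampleBuildParams(self):
    return self.BuildParams(
        graph_fn=lambda x: math_ops.multiply(x, x, name="output_0"),
        dtype=np.float32,
        input_shapes=[[8, 12]],
        output_shapes=[[8, 12]])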

  def BuildParamsWithMask(self, graph_fn, dtype, input_shapes, output_shapes,
                          input_mask, output_mask, extra_inputs, extra_outputs):
    """Build test parameters with static or dynamic input shapes.

    To define dynamic shapes, give a boolean mask that describes which
    dimensions to treat as known. The values in input_mask are interpreted the
    following way:
    - True: known dim (use the corresponding value from input_shapes)
    - False: unknown dim (replace the corresponding value from input_shapes
             with None)
    For example, to define the first two dimensions with unknown size, use
    input_shapes=[[1,2,1,8]], input_mask=[[False, False, True, True]].

    Args:
      graph_fn: The function to build the graph.
      dtype: The element type.
      input_shapes: The input shapes.
      output_shapes: The output shapes.
      input_mask: The input shape masks.
      output_mask: The output shape masks.
      extra_inputs: A list of additional input shapes.
      extra_outputs: A list of additional output shapes.

    Returns:
      The test parameters.
    """

    def _ValidateShapes(shapes):
      # Make sure all the shapes are fully specified.
      for shape in shapes:
        assert all(shape)

    _ValidateShapes(input_shapes)
    _ValidateShapes(output_shapes)

    assert len(input_mask) == len(input_shapes)
    assert len(output_mask) == len(output_shapes)
    for extra_in_shape, extra_out_shape in zip(extra_inputs, extra_outputs):
      assert len(input_shapes) == len(extra_in_shape)
      assert len(output_shapes) == len(extra_out_shape)

    return TfTrtIntegrationTestParams(
        graph_fn=graph_fn,
        input_specs=[
            self._GetTensorSpec(shape, mask, dtype, "input_%d" % i)
            for i, (shape, mask) in enumerate(zip(input_shapes, input_mask))
        ],
        output_specs=[
            self._GetTensorSpec(shape, mask, dtype, "output_%d" % i)
            for i, (shape, mask) in enumerate(zip(output_shapes, output_mask))
        ],
        input_dims=[input_shapes] + extra_inputs,
        expected_output_dims=[output_shapes] + extra_outputs)
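
  # Sketch of the dynamic-shape case described in the docstring above
  # (hypothetical; not called by the tests): the first two dimensions of a
  # [1, 2, 1, 8] input are treated as unknown.
  def _ExampleBuildParamsWithMask(self):
    return self.BuildParamsWithMask(
        graph_fn=lambda x: math_ops.abs(x, name="output_0"),
        dtype=np.float32,
        input_shapes=[[1, 2, 1, 8]],
        output_shapes=[[1, 2, 1, 8]],
        input_mask=[[False, False, True, True]],
        output_mask=[[False, False, True, True]],
        extra_inputs=[],
        extra_outputs=[])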

  def DisableNonTrtOptimizers(self):
    self._disable_non_trt_optimizers = True

  def SetDynamicShapeModeAndProfileStrategy(self, profile_strategy="Range"):
    self._use_implicit_batch = False
    self._profile_strategy = profile_strategy

  def GetParams(self):
    """Returns a TfTrtIntegrationTestParams for the test."""
    raise NotImplementedError()

  def GetConversionParams(self, run_params):
    """Returns a TrtConversionParams for test."""
    conversion_params = trt_convert.TrtConversionParams(
        # We use the minimum of all the batch sizes, so when multiple different
        # input shapes are provided it'll always create new engines in the
        # cache, and we can therefore test the cache behavior.
        max_workspace_size_bytes=1 << 25,
        precision_mode=run_params.precision_mode,
        minimum_segment_size=2,
        maximum_cached_engines=1,
        use_calibration=run_params.use_calibration)
    return conversion_params
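
  # Subclasses typically override GetConversionParams to tweak a single field.
  # A sketch (assuming the namedtuple-style `_replace` helper that the TF-TRT
  # tests commonly use):
  #
  #   def GetConversionParams(self, run_params):
  #     return super(MyTest, self).GetConversionParams(run_params)._replace(
  #         maximum_cached_engines=2)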

  def GetMaxBatchSize(self, run_params):
    """Returns the max_batch_size that the converter should use for tests."""
    if run_params.dynamic_engine:
      return None
    batch_list = []
    for dims_list in self._GetParamsCached().input_dims:
      assert dims_list
      # Each list of shapes should have the same batch size.
      input_batches = [dims[0] for dims in dims_list]
      assert max(input_batches) == min(input_batches)
      batch_list.append(input_batches[0])
    return max(batch_list)

  def ShouldRunTest(self, run_params):
    """Whether to run the test."""
    # Ensure use_calibration=True in case of INT8 precision.
    return (run_params.use_calibration or not IsQuantizationMode(
        run_params.precision_mode)), "test either calibration or non-INT8"

  def ExpectedEnginesToBuild(self, run_params):
    """Returns the expected engines to build, implemented by subclass."""
    raise NotImplementedError()
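
  # A subclass implementation returns either a sequence of expected engine
  # names, or a dict mapping each engine name to the original nodes it should
  # contain, in which case the connections are verified too. A sketch with
  # hypothetical node names:
  #
  #   def ExpectedEnginesToBuild(self, run_params):
  #     return {"TRTEngineOp_0": ["c1", "add", "mul"]}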

  def ExpectedMaxBatchSizes(self, run_params):
    """Returns the expected maximum batch sizes of the built engines."""
    return None

  def ExpectedAbsoluteTolerance(self, run_params):
    """The absolute tolerance to compare floating point results."""
    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02

  def ExpectedRelativeTolerance(self, run_params):
    """The relative tolerance to compare floating point results."""
    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02

  def _GetParamsCached(self):
    if self._trt_test_params is None:
      self._trt_test_params = self.GetParams()
    return self._trt_test_params

  def _GetGPUOptions(self):
    gpu_options = config_pb2.GPUOptions()
    gpu_options.allow_growth = True
    return gpu_options

  def _GetConfigProto(self, run_params, graph_state):
    """Get config proto based on specific settings."""
    conversion_params = self.GetConversionParams(run_params)
    max_batch_size = self.GetMaxBatchSize(run_params)

    if graph_state == GraphState.INFERENCE and run_params.convert_online:
      rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
          conversion_params,
          is_dynamic_op=run_params.dynamic_engine,
          max_batch_size=max_batch_size,
          disable_non_trt_optimizers=self._disable_non_trt_optimizers)
    else:
      rewriter_cfg = rewriter_config_pb2.RewriterConfig()
      if self._disable_non_trt_optimizers:
        trt_utils.disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg)

    config = config_pb2.ConfigProto(
        gpu_options=self._GetGPUOptions(),
        graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))
    return config

  def _GetFeedNames(self):
    params = self._GetParamsCached()
    # Construct the feed tensor names by appending :0 to the node names.
    return [spec.name + ":0" for spec in params.input_specs]

  def _GetFetchNames(self):
    params = self._GetParamsCached()
    # Construct the fetch tensor names by appending :0 to the node names.
    return [spec.name + ":0" for spec in params.output_specs]

  def _GetFeedDict(self, inputs_data):
    return {name: data for name, data in zip(self._GetFeedNames(), inputs_data)}

  def _RunGraphV1(self, saved_model_dir, inputs_data, config, num_runs=2):
    """Run the given graphdef multiple times using the TF 1.x runtime."""
    params = self._GetParamsCached()
    fetches = self._GetFetchNames()
    g = ops.Graph()
    with g.as_default():
      with self.session(graph=g, config=config, use_gpu=True) as sess:
        loader.load(sess, [tag_constants.SERVING], saved_model_dir)
        vals = []
        # Run for each set of input shapes.
        for expected_shapes, current_input_data in zip(
            params.expected_output_dims, inputs_data):
          val = None
          for _ in range(num_runs):
            new_val = sess.run(fetches, self._GetFeedDict(current_input_data))
            self.assertEqual(len(expected_shapes), len(new_val))
            for expected_shape, actual_val in zip(expected_shapes, new_val):
              self.assertEqual(list(expected_shape), list(actual_val.shape))
            if val is not None:
              # Some ops may have nondeterministic output, e.g. Conv2D may use
              # the Winograd algorithm, so we set atol/rtol to be larger than
              # 1.e-06.
              self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05)
            val = new_val
          vals.append(val)
        return vals

  def _RunGraphV2(self, saved_model_dir, inputs_data, graph_state, num_runs=2):
    """Run the given graphdef multiple times using the TF 2.0 runtime."""
    params = self._GetParamsCached()
    root = load.load(saved_model_dir)
    func = root.signatures[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    results = []
    for expected_shapes, current_input_data in zip(params.expected_output_dims,
                                                   inputs_data):
      val = None
      for _ in range(num_runs):
        feed_dict = {
            params.input_specs[i].name: current_input_data[i]
            for i in range(len(params.input_specs))
        }
        new_val = func(**feed_dict)
        assert isinstance(new_val, dict)
        # The keys of the output map are always of the form output_i.
        new_val = [new_val[key] for key in sorted(new_val)]
        # Each element is an eager Tensor, and accessing individual elements is
        # very expensive, so we convert them to a numpy array first.
        new_val = [v.numpy() for v in new_val]
        self.assertEqual(len(expected_shapes), len(new_val))
        for expected_shape, actual_val in zip(expected_shapes, new_val):
          self.assertEqual(list(expected_shape), list(actual_val.shape))
        if val is not None:
          # Some ops may have nondeterministic output, e.g. Conv2D may use
          # the Winograd algorithm, so we set atol/rtol to be larger than
          # 1.e-06.
          self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05)
        val = new_val
      results.append(val)

    return results

  def _RunGraph(self,
                run_params,
                saved_model_dir,
                inputs_data,
                graph_state,
                num_runs=2):
    params = self._GetParamsCached()
    for data in inputs_data:
      assert len(params.input_specs) == len(data)

    if run_params.is_v2:
      results = self._RunGraphV2(saved_model_dir, inputs_data, graph_state,
                                 num_runs)
      gc.collect()  # Force GC to destroy the TRT engine cache.
      return results

    # The default config for tf.session is None. Create a config with
    # TensorRTOptimizer enabled to support convert_online for inference.
    config = None
    # TODO(b/170220818): use the default session config to run inference
    #   graphs for the offline conversion case after fixing the bug.
    if graph_state == GraphState.INFERENCE:
      config = self._GetConfigProto(run_params, GraphState.INFERENCE)
    return self._RunGraphV1(saved_model_dir, inputs_data, config, num_runs)

  def _CreateConverter(self, run_params, saved_model_dir, conversion_params):
    """Returns a TrtGraphConverter."""
    if run_params.is_v2:
      converter_v2 = trt_convert.TrtGraphConverterV2(
          input_saved_model_dir=saved_model_dir,
          conversion_params=conversion_params,
          use_dynamic_shape=not self._use_implicit_batch,
          dynamic_shape_profile_strategy=self._profile_strategy)
      if self._disable_non_trt_optimizers:
        converter_v2._test_only_disable_non_trt_optimizers = True  # pylint: disable=protected-access
      return converter_v2

    converter_v1 = trt_convert.TrtGraphConverter(
        input_saved_model_dir=saved_model_dir,
        max_batch_size=self.GetMaxBatchSize(run_params),
        max_workspace_size_bytes=conversion_params.max_workspace_size_bytes,
        precision_mode=conversion_params.precision_mode,
        minimum_segment_size=conversion_params.minimum_segment_size,
        is_dynamic_op=run_params.dynamic_engine,
        maximum_cached_engines=conversion_params.maximum_cached_engines,
        use_calibration=conversion_params.use_calibration)
    if self._disable_non_trt_optimizers:
      converter_v1._test_only_disable_non_trt_optimizers = True  # pylint: disable=protected-access
    return converter_v1

  def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data):
    """Returns the directory of the TRT-converted SavedModel in INT8 mode."""
    conversion_params = self.GetConversionParams(run_params)
    logging.info(conversion_params)
    assert conversion_params.precision_mode == "INT8"
    assert run_params.dynamic_engine
    assert conversion_params.maximum_cached_engines == 1
    assert conversion_params.use_calibration

    # We only support calibrating a single engine.
    # TODO(aaroey): fix this.
    assert len(inputs_data) == 1

    converter = self._CreateConverter(run_params, saved_model_dir,
                                      conversion_params)
    int8_gdef = converter.convert()
    self._VerifyGraphDef(run_params, saved_model_dir, int8_gdef,
                         GraphState.CALIBRATE)

    converter.calibrate(
        fetch_names=self._GetFetchNames(),
        num_runs=5,
        feed_dict_fn=lambda: self._GetFeedDict(inputs_data[0]))
    trt_saved_model_dir = self._GetSavedModelDir(run_params,
                                                 GraphState.CALIBRATE)
    converter.save(trt_saved_model_dir)
    return trt_saved_model_dir

  def _ShouldConverterBuild(self, run_params):
    return True

  def _GetInferGraph(self, run_params, saved_model_dir):
    """Returns the directory of the TRT-converted SavedModel."""
    conversion_params = self.GetConversionParams(run_params)
    logging.info(conversion_params)

    converter = self._CreateConverter(run_params, saved_model_dir,
                                      conversion_params)
    converter.convert()

    if not self._use_implicit_batch and self._ShouldConverterBuild(run_params):
      logging.info("Using build mode")

      def _BuildInputFn():
        for shapes in self._GetParamsCached().input_dims:
          yield [np.zeros(x).astype(np.float32) for x in shapes]

      converter.build(input_fn=_BuildInputFn)

    trt_saved_model_dir = self._GetSavedModelDir(run_params,
                                                 GraphState.INFERENCE)
    converter.save(trt_saved_model_dir)
    return trt_saved_model_dir

  def _GetGraphStateLabel(self, graph_state):
    if graph_state == GraphState.ORIGINAL:
      return "Original"
    elif graph_state == GraphState.CALIBRATE:
      return "CalibEngine"
    elif graph_state == GraphState.INFERENCE:
      return "InferEngine"
    else:
      return "UnknownState"

  def _WriteGraph(self, run_params, gdef, graph_state):
    temp_dir = os.getenv("TRT_TEST_TMPDIR")
    if not temp_dir:
      return

    graph_name = (
        self.__class__.__name__ + "_" + run_params.test_name + "_" +
        self._GetGraphStateLabel(graph_state) + ".pbtxt")
    logging.info("Writing graph to %s/%s", temp_dir, graph_name)
    graph_io.write_graph(gdef, temp_dir, graph_name)

  # Removes the prefix(es) of function name(s).
  # The input value can be a string or a sequence of strings.
  def _Canonicalize(self, value):
    if isinstance(value, str):
      return self._ToString(value.split("/")[-1])
    elif isinstance(value, collections.abc.Iterable):
      return set(self._Canonicalize(nm) for nm in value)
    else:
      raise TypeError(
          "'_Canonicalize' can only be used on strings or sequences of "
          "strings!")

  # Removes the graph sequence number prefix from the name(s) only if the
  # name(s) have a prefix TRTEngineOp_n_. When expecting_prefix is true,
  # asserts that such a prefix exists.
  # The input value can be a string or a sequence of strings.
  def _RemoveGraphSequenceNumberImpl(self, value, expecting_prefix):
    if isinstance(value, str):
      match = re.search(r"TRTEngineOp_\d+_", value)
      has_prefix = match and value.startswith(match.group(0))
      assert (not expecting_prefix) or has_prefix
      if has_prefix:
        parts = value.split("_", maxsplit=2)
        assert len(parts) == 3
        return parts[0] + "_" + parts[2]
      return value
    elif isinstance(value, collections.abc.Iterable):
      return set(
          self._RemoveGraphSequenceNumberImpl(nm, expecting_prefix)
          for nm in value)
    else:
      raise TypeError(
          "'_RemoveGraphSequenceNumberImpl' can only be used on strings "
          "or sequences of strings!")

  def _RemoveGraphSequenceNumber(self, name):
    return self._RemoveGraphSequenceNumberImpl(name, True)

  def _MayRemoveGraphSequenceNumber(self, name):
    return self._RemoveGraphSequenceNumberImpl(name, False)
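
  # Worked example (values hypothetical): a node named
  # "import/TRTEngineOp_3_1" canonicalizes to "TRTEngineOp_3_1", and removing
  # the graph sequence number maps "TRTEngineOp_3_1" to "TRTEngineOp_1":
  #
  #   self._Canonicalize("import/TRTEngineOp_3_1")        # "TRTEngineOp_3_1"
  #   self._RemoveGraphSequenceNumber("TRTEngineOp_3_1")  # "TRTEngineOp_1"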

  def _VerifyConnections(self, expected_engines, original_gdef, converted_gdef):
    old_to_new_node_map = {
        self._ToString(node.name): self._ToString(node.name)
        for node in original_gdef.node
    }
    for engine_name, node_names in expected_engines.items():
      for node_name in node_names:
        old_to_new_node_map[node_name] = engine_name
    name_to_node_map = {
        self._ToString(node.name): node for node in original_gdef.node
    }

    def _InputName(inp):
      inp = self._ToString(inp)
      prefix = ""
      if inp[0] == "^":
        prefix = "^"
        inp = inp[1:]
      parts = inp.split(":")
      if len(parts) > 1 and parts[-1].isdigit():
        inp = inp[:-len(parts[-1]) - 1]
      return (prefix, inp)

    # Compute the expected mapping from each node to its input nodes.
    expected_input_map = {}
    removed_const_nodes = set([
        self._ToString(node.name)
        for node in original_gdef.node
        if node.op == "Const"
    ])
    for node in original_gdef.node:
      name_str = self._ToString(node.name)
      target_node_name = old_to_new_node_map[name_str]
      is_engine_op = (target_node_name != name_str)
      if target_node_name not in expected_input_map:
        expected_input_map[target_node_name] = set()
      input_set = expected_input_map[target_node_name]
      for inp in node.input:
        (prefix, inp_name) = _InputName(inp)
        mapped_input = old_to_new_node_map[inp_name]
        # Add the input only if it's outside the segment (note that it could be
        # in a different engine).
        if not is_engine_op or (mapped_input != target_node_name and
                                name_to_node_map[inp_name].op != "Const"):
          input_set.add(prefix + mapped_input)
          if mapped_input in removed_const_nodes:
            removed_const_nodes.remove(mapped_input)
    # Remove const nodes that have no outputs.
    expected_input_map = {
        k: v
        for k, v in expected_input_map.items()
        if k not in removed_const_nodes
    }

    # Compute the actual mapping from each node to its input nodes. If a cast
    # op doesn't exist in the original graph, we replace the use of the cast op
    # with the input of the op. This allows the verification to handle the case
    # where the TF-TRT bridge splits a cast op into a chain of two cast ops.
    new_cast_op_name_to_node_map = {
        node.name: node
        for node in converted_gdef.node
        if (node.name not in old_to_new_node_map and node.op == "Cast")
    }
    actual_input_map = {}
    for node in converted_gdef.node:
      name_str = node.name
      # Only nodes from the original graph or TRTEngineOp nodes are added as
      # keys to the map.
      if node.op == "TRTEngineOp":
        name_str = self._RemoveGraphSequenceNumber(name_str)
      elif name_str not in old_to_new_node_map:
        continue
      actual_input_map[name_str] = set()
      input_set = actual_input_map[name_str]
      for inp in node.input:
        (prefix, node_name) = _InputName(inp)
        node_name = self._MayRemoveGraphSequenceNumber(node_name)
        if node_name in new_cast_op_name_to_node_map:
          (prefix, node_name) = _InputName(
              new_cast_op_name_to_node_map[node_name].input[0])
        input_set.add(prefix + node_name)

    self.assertEqual(
        expected_input_map,
        actual_input_map,
        msg="\nexpected:\n%s\nvs actual:\n%s" %
        (sorted(expected_input_map.items()), sorted(actual_input_map.items())))

  def _VerifyMaxBatchSizeAnnotations(
      self,
      expected_engines,
      original_gdef,
      converted_gdef,
      default_max_batch_size,
      expected_max_batch_sizes=None,
  ):
    """Verifies the max batch size annotations in the original and converted GraphDef.

    Args:
      expected_engines: A sequence of engine names.
      original_gdef: GraphDef. The graph def before TensorRT conversion.
      converted_gdef: GraphDef. The graph def after TensorRT conversion.
      default_max_batch_size: The default maximum batch size to use if no node
        inside a segment is annotated with a customized max batch size. This
        value is None when the graph is converted to TF-TRT with dynamic
        engines.
      expected_max_batch_sizes: Optional. A sequence of max batch sizes for all
        the engines, or `None` to skip checking the engine max batch sizes.
    """
    if isinstance(expected_max_batch_sizes, collections.abc.Collection):
      self.assertEqual(len(expected_max_batch_sizes), len(expected_engines))
    else:
      self.assertIsNone(
          expected_max_batch_sizes,
          "'expected_max_batch_sizes' shall only be a sequence "
          "of integers or `None`.")

    def _ChainAllNodes(graph_def):
      return itertools.chain(
          graph_def.node,
          itertools.chain(
              *[func.node_def for func in graph_def.library.function]))

    old_name_to_node_map = {
        self._ToString(node.name): node
        for node in _ChainAllNodes(original_gdef)
    }
    new_name_to_func_map = {
        self._ToString(func.signature.name): func
        for func in converted_gdef.library.function
    }

    def _DetectStaticBatchSize(node_def):
      """Returns the static batch size of an operation or None.

      It is incorrect to use the output shapes to find the batch size of an
      operation, as the segmenter actually uses the input shapes. However, it
      is a simplification and works for most cases for the test purposes.

      Args:
        node_def: `tf.NodeDef`. The target node for analysis.

      Returns:
        If all the outputs of the node have the same static batch size, returns
        the int value for the batch size. Otherwise returns None.
      """
      shapes = node_def.attr["_output_shapes"].list.shape
      batch_size = set(
          list(s.dim)[0].size if len(s.dim) >= 2 else None for s in shapes)
      if len(batch_size) == 1 and list(batch_size)[0] >= 1:
        return list(batch_size)[0]
      return None

    name_to_engines_map = {}
    actual_max_batch_sizes = []
    for node in _ChainAllNodes(converted_gdef):
      if node.op == "TRTEngineOp":
        engine = node
        engine_name = self._RemoveGraphSequenceNumber(
            self._Canonicalize(self._ToString(engine.name)))
        self.assertIn(engine_name, expected_engines)
        name_to_engines_map[engine_name] = engine
        # The input nodes shall not carry annotations that conflict with the
        # engine's maximum batch size annotation, i.e. each input node either
        # has no annotation or the same annotation. If the engine is annotated
        # with a non-default maximum batch size, then at least one input node
        # shall have the same annotation to be the source.
        self.assertIn("max_batch_size", node.attr)
        engine_max_batch_size = node.attr["max_batch_size"].i
        self.assertIsInstance(engine_max_batch_size, int)
        actual_max_batch_sizes.append(engine_max_batch_size)
        seg_func = node.attr["segment_func"].func
        self.assertIsNotNone(seg_func)
        self.assertIn(seg_func.name, new_name_to_func_map)
        seg_func_def = new_name_to_func_map[seg_func.name]
        logging.info("Segment function name: %s. Including %d nodes.",
                     seg_func.name, len(seg_func_def.node_def))
        node_max_batch_size_all_none = True
        # Use the native segment to search for replaced nodes.
        for alternative_node in seg_func_def.node_def:
          node_name = self._Canonicalize(self._ToString(alternative_node.name))
          if node_name not in old_name_to_node_map:
            continue
          original_node = old_name_to_node_map[node_name]
          node_max_batch_size = None
          if "_tftrt_op_max_batch_size" in original_node.attr:
            node_max_batch_size = original_node.attr[
                "_tftrt_op_max_batch_size"].i
          elif (original_node.op != "Const" and
                alternative_node.op != "Const" and
                "_output_shapes" in original_node.attr):
            node_max_batch_size = _DetectStaticBatchSize(original_node)
          logging.info(
              "'{%s}(%s)'s max batch size annotation is %s. "
              "'{%s}'s max batch size is %s.", node_name, original_node.op,
              str(node_max_batch_size), engine_name, str(engine_max_batch_size))
          node_max_batch_size_all_none &= node_max_batch_size is None
          self.assertTrue(engine_max_batch_size == node_max_batch_size or
                          node_max_batch_size is None)
        logging.info("'{%s}'s max batch size is %d.", engine_name,
                     engine_max_batch_size)
        self.assertTrue(default_max_batch_size is None or
                        engine_max_batch_size == default_max_batch_size or
                        not node_max_batch_size_all_none)

    self.assertCountEqual(expected_engines, tuple(name_to_engines_map.keys()))
    if expected_max_batch_sizes is not None:
      self.assertCountEqual(expected_max_batch_sizes, actual_max_batch_sizes)

  def _GetGraphDef(self, run_params, gdef_or_saved_model_dir):
    if isinstance(gdef_or_saved_model_dir, str):
      if run_params.is_v2:
        root = load.load(gdef_or_saved_model_dir)
        func = root.signatures[
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        gdef = func.graph.as_graph_def()
        # Manually unref the loaded saved model and force GC to destroy the TRT
        # engine cache after load(). There is currently a reference cycle in 2.0
        # which prevents auto deletion of the resource.
        # TODO(laigd): fix this.
        del func
        del root
        gc.collect()
        return gdef
      return saved_model_utils.get_meta_graph_def(
          gdef_or_saved_model_dir, tag_constants.SERVING).graph_def
    assert isinstance(gdef_or_saved_model_dir, graph_pb2.GraphDef)
    return gdef_or_saved_model_dir

  def _VerifyGraphDefV1(self, run_params, original_gdef, gdef_to_verify,
                        graph_state):
    expected_engines = self.ExpectedEnginesToBuild(run_params)
    num_engines = 0
    functions = [f.signature.name for f in gdef_to_verify.library.function]
    for node in gdef_to_verify.node:
      if node.op == "TRTEngineOp":
        logging.info("Found TRTEngineOp: " + node.name)
        num_engines += 1
        segment_funcdef_name = node.attr["segment_func"].func.name
        function_name = node.name + "_native_segment"
        is_dynamic_engine = not node.attr["static_engine"].b
        self.assertNotEmpty(segment_funcdef_name, node.name)
        self.assertIn(function_name, functions)
        if (not IsQuantizationWithCalibration(run_params) and
            not is_dynamic_engine):
          self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
        self.assertIn(
            self._RemoveGraphSequenceNumber(node.name), expected_engines)
        if IsQuantizationWithoutCalibration(run_params):
          # TODO(bixia): Refine this check by inspecting nodes in the engine.
          if self._ToBytes("INT8") != node.attr["precision_mode"].s:
            self.assertEqual(
                self._ToBytes("FP16"), node.attr["precision_mode"].s, node.name)
        else:
          self.assertEqual(
              self._ToBytes(run_params.precision_mode),
              node.attr["precision_mode"].s, node.name)

        self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
                         node.name)
        self.assertEqual(node.attr["use_calibration"].b,
                         run_params.use_calibration, node.name)

        has_calibration_data = len(node.attr["calibration_data"].s)
        if (IsQuantizationWithCalibration(run_params) and
            graph_state == GraphState.INFERENCE):
          self.assertTrue(has_calibration_data, node.name)
        else:
          self.assertFalse(has_calibration_data, node.name)
    if graph_state == GraphState.ORIGINAL:
      self.assertEqual(0, num_engines)
    else:
      self.assertEqual(num_engines, len(expected_engines))
      if isinstance(expected_engines, dict):
        self._VerifyConnections(expected_engines, original_gdef, gdef_to_verify)
      self._VerifyMaxBatchSizeAnnotations(
          expected_engines=expected_engines,
          original_gdef=original_gdef,
          converted_gdef=gdef_to_verify,
          expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params),
          default_max_batch_size=self.GetMaxBatchSize(run_params))

  def _VerifyGraphDefV2(self, run_params, original_gdef, gdef_to_verify,
                        graph_state):
    if graph_state == GraphState.ORIGINAL:
      return
    expected_engines = self.ExpectedEnginesToBuild(run_params)
    all_op_names = [node.name for node in gdef_to_verify.node]
    trt_op_names = []
    for func in gdef_to_verify.library.function:
      if not re.search(r"TRTEngineOp_\d+_\d+_native_segment",
                       func.signature.name):
        for node in func.node_def:
          all_op_names.append(node.name)
          if node.op == "TRTEngineOp":
            trt_op_names.append(node.name)
            if not self._use_implicit_batch:
              self.assertEqual(
                  self._ToString(node.attr["profile_strategy"].s).lower(),
                  self._profile_strategy.lower())

    all_op_names = self._Canonicalize(all_op_names)
    trt_op_names = self._RemoveGraphSequenceNumber(
        self._Canonicalize(trt_op_names))

    if isinstance(expected_engines, dict):
      # For simplicity we don't verify the connections inside the engine in
      # 2.0, but we still make sure that the converted ops are gone from the
      # graph.
      unexpected_names = set(nest.flatten(expected_engines.values()))
      self.assertEmpty(
          [name for name in unexpected_names if name in all_op_names])
      expected_engines = set(expected_engines.keys())

    self.assertEqual(set(expected_engines), trt_op_names)

  def _VerifyGraphDef(self, run_params, original_gdef_or_saved_model_dir,
                      gdef_or_saved_model_dir_to_verify, graph_state):
    original_gdef = self._GetGraphDef(run_params,
                                      original_gdef_or_saved_model_dir)
    gdef_to_verify = self._GetGraphDef(run_params,
                                       gdef_or_saved_model_dir_to_verify)
    self._WriteGraph(run_params, gdef_to_verify, graph_state)
    if run_params.is_v2:
      self._VerifyGraphDefV2(run_params, original_gdef, gdef_to_verify,
                             graph_state)
    else:
      self._VerifyGraphDefV1(run_params, original_gdef, gdef_to_verify,
                             graph_state)

  def _GetSavedModelDir(self, run_params, graph_state):
    test_tmpdir = os.getenv("TRT_TEST_TMPDIR")
    if test_tmpdir:
      saved_model_dir = os.path.join(
          test_tmpdir, self.__class__.__name__ + "_" + run_params.test_name +
          "_" + self._GetGraphStateLabel(graph_state))
      try:
        # For TF 1.x we need to make sure the output directory doesn't exist
        # before exporting the saved model.
        shutil.rmtree(saved_model_dir)
      except OSError as e:
        if e.errno != errno.ENOENT:
          raise
      return saved_model_dir
    return tempfile.mkdtemp(dir=self.get_temp_dir())

  def _MakeSavedModelV1(self, run_params):
    """Write the saved model as an input for testing."""
    params = self._GetParamsCached()
    g = ops.Graph()
    with g.as_default():
      inputs = []
      for spec in params.input_specs:
        inp = array_ops.placeholder(
            dtype=spec.dtype, shape=spec.shape, name=spec.name)
        inputs.append(inp)
      outputs = params.graph_fn(*inputs)
      if not isinstance(outputs, list) and not isinstance(outputs, tuple):
        outputs = [outputs]

    signature_def = signature_def_utils.build_signature_def(
        inputs={inp.op.name: utils.build_tensor_info(inp) for inp in inputs},
        outputs={out.op.name: utils.build_tensor_info(out) for out in outputs},
        method_name=signature_constants.PREDICT_METHOD_NAME)

    saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL)
    saved_model_builder = builder.SavedModelBuilder(saved_model_dir)
    with self.session(
        graph=g, config=self._GetConfigProto(run_params,
                                             GraphState.ORIGINAL)) as sess:
      saved_model_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING],
          signature_def_map={
              signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                  signature_def
          })
    saved_model_builder.save()
    return saved_model_dir

  def _MakeSavedModelV2(self, run_params):
    params = self._GetParamsCached()
    root = tracking.AutoTrackable()
    root.run = def_function.function(
        params.graph_fn, input_signature=params.input_specs)
    saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL)
    logging.info("Saving input SavedModel to %s", saved_model_dir)
    save.save(root, saved_model_dir,
              {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: root.run})
    return saved_model_dir

  def _MakeSavedModel(self, run_params):
    if run_params.is_v2:
      return self._MakeSavedModelV2(run_params)
    return self._MakeSavedModelV1(run_params)

  def RunTest(self, run_params):
    with trace.Trace(run_params.test_name):
      should_run, reason_for_skipping = self.ShouldRunTest(run_params)
      if not should_run:
        return self.skipTest(reason_for_skipping)

      saved_model_dir = self._MakeSavedModel(run_params)

      np.random.seed(12345)  # Fix the seed so the test is deterministic.
      inputs_data = []
      input_specs = self._GetParamsCached().input_specs
      for dim_list in self._GetParamsCached().input_dims:
        assert len(input_specs) == len(dim_list)
        current_input_data = []
        for spec, np_shape in zip(input_specs, dim_list):
          np_dtype = spec.dtype.as_numpy_dtype()
          # Multiply the input by some constant to avoid an all-zeros input for
          # integer types.
          scale = 10.0 if np.issubdtype(np_dtype, np.integer) else 1.0
          # TODO(laigd): add debug options. E.g. we can set the input data to be
          # continuous natural numbers:
          # seq = np.arange(np.prod(np_shape))
          # seq.resize(np_shape)
          # current_inputs_data.append(scale * seq.astype(np_dtype))
          data = (scale * np.random.random_sample(np_shape)).astype(np_dtype)
          if run_params.is_v2:
            with ops.device("/GPU:0"):
              data = ops.convert_to_tensor(data)
          current_input_data.append(data)
        inputs_data.append(current_input_data)

      # Verify the original graph.
      self._VerifyGraphDef(run_params, saved_model_dir, saved_model_dir,
                           GraphState.ORIGINAL)

      # Run the original graph without TensorRT to get the reference result.
      logging.info("Running original graph w/o TensorRT\n")
      ref_result = self._RunGraph(
          run_params,
          saved_model_dir,
          inputs_data,
          GraphState.ORIGINAL,
          num_runs=1)

      # Run calibration if necessary.
      if IsQuantizationWithCalibration(run_params):
        infer_saved_model_dir = self._GetCalibratedInferGraph(
            run_params, saved_model_dir, inputs_data)
        self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir,
                             GraphState.INFERENCE)
      elif not run_params.convert_online:
        infer_saved_model_dir = self._GetInferGraph(run_params, saved_model_dir)
        self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir,
                             GraphState.INFERENCE)
      else:
        infer_saved_model_dir = saved_model_dir

      # Run the inference graph, either using the converted graph or the
      # original graph with convert_online == True.
      logging.info("Running final inference graph\n")
      result = self._RunGraph(run_params, infer_saved_model_dir, inputs_data,
                              GraphState.INFERENCE)
      self.assertAllClose(
          ref_result,
          result,
          atol=self.ExpectedAbsoluteTolerance(run_params),
          rtol=self.ExpectedRelativeTolerance(run_params))

  def testIdempotence(self):
    # Test that applying the TensorRT optimizer or the offline conversion
    # tools multiple times to the same graph results in the same graph.
    #
    # TODO(aaroey): implement this.
    pass


def _GetTestConfigsV1():
  """Returns the config combinations to run the test."""
  convert_online, convert_offline = True, False
  dynamic_engine, static_engine = True, False
  use_calibration, no_calibration = True, False

  # Add all possible test cases and let the derived test class decide
  # whether to run specific ones with ShouldRunTest().
  #
  # Note: INT8 without calibration behaves like FP32/FP16.
  opts = list(
      itertools.product([FP32, FP16, INT8], [convert_online, convert_offline],
                        [dynamic_engine, static_engine], [no_calibration]))
  # We always run calibration with the offline tool.
  # TODO(aaroey): static calibration engine is not supported yet.
  opts.append((INT8, convert_offline, dynamic_engine, use_calibration))
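  # At this point `opts` holds 13 tuples of
  # (precision_mode, convert_online, dynamic_engine, use_calibration): the 12
  # product combinations, e.g. (FP32, True, True, False), plus the single
  # calibration case (INT8, False, True, True).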
  return opts


def _GetTestConfigsV2():
  """Returns the config combinations to run the test."""
  convert_offline = False
  # TODO(laigd): add support for static_engine.
  dynamic_engine = True
  # TODO(laigd): add support for calibration.
  no_calibration = False

  # Add all possible test cases and let the derived test class decide
  # whether to run specific ones with ShouldRunTest().
  #
  # Note:
  # - In TF 2.0 the conversion always produces dynamic engines, and we don't
  #   test the offline mode here.
  # - For simplicity we don't test online conversion, which requires setting
  #   the Grappler config in the default eager context.
  # - INT8 without calibration behaves like FP32/FP16.
  opts = list(
      itertools.product([FP32, FP16, INT8], [convert_offline], [dynamic_engine],
                        [no_calibration]))
  # We always run calibration with the offline tool.
  # TODO(aaroey): INT8+calibration is not supported yet in V2.
  # opts.append((INT8, convert_offline, dynamic_engine, use_calibration))
  return opts


def _GetTest(run_params):
  """Gets a single test method based on the parameters."""

  def _Test(self):
    logging.info(
        "Running test %s with parameters: convert_online=%s, "
        "precision_mode=%s, dynamic_engine=%s", run_params.test_name,
        run_params.convert_online, run_params.precision_mode,
        run_params.dynamic_engine)
    self.RunTest(run_params)

  return _Test


def _AddTestsFor(test_class, is_v2):
  """Adds test methods to TfTrtIntegrationTestBase for a specific TF version."""
  opts = _GetTestConfigsV2() if is_v2 else _GetTestConfigsV1()
  for (precision_mode, convert_online, dynamic_engine, use_calibration) in opts:
    conversion = "OnlineConversion" if convert_online else "OfflineConversion"
    engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine"
    calibration_type = "UseCalibration" if use_calibration else "NoCalibration"
    test_name = "%s_%s_%s_%s_%s" % ("testTfTrtV2" if is_v2 else "testTfTrt",
                                    conversion, engine_type, precision_mode,
                                    calibration_type)
    run_params = RunParams(
        convert_online=convert_online,
        precision_mode=precision_mode,
        dynamic_engine=dynamic_engine,
        test_name=test_name,
        use_calibration=use_calibration,
        is_v2=is_v2)
    if is_v2:
      setattr(test_class, test_name,
              test_util.run_v2_only(_GetTest(run_params)))
    else:
      setattr(test_class, test_name,
              test_util.run_v1_only("", _GetTest(run_params)))


def _AddTests(test_class):
  """Adds test methods to TfTrtIntegrationTestBase."""
  _AddTestsFor(test_class, is_v2=False)
  _AddTestsFor(test_class, is_v2=True)


if is_tensorrt_enabled():
  os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
  _AddTests(TfTrtIntegrationTestBase)
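
# Typical usage from a derived test file, kept as a sketch in comments since
# the class, node, and import names here are hypothetical:
#
#   class SimpleSingleEngineTest(TfTrtIntegrationTestBase):
#
#     def GraphFn(self, inp):
#       c = constant_op.constant(1.0, name="c")
#       mul = math_ops.multiply(inp, c, name="mul")
#       add = math_ops.add(mul, c, name="add")
#       return array_ops.identity(add, name="output_0")
#
#     def GetParams(self):
#       return self.BuildParams(self.GraphFn, np.float32, [[2, 10]], [[2, 10]])
#
#     def ExpectedEnginesToBuild(self, run_params):
#       return {"TRTEngineOp_0": ["c", "mul", "add"]}
#
#   if __name__ == "__main__":
#     test.main()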