• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities to test TF-TensorRT integration."""
16
17import collections
18import errno
19import gc
20import itertools
21import os
22import re
23import shutil
24import tempfile
25import warnings
26
27import numpy as np
28
29from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
30from tensorflow.core.framework import graph_pb2
31from tensorflow.core.protobuf import config_pb2
32from tensorflow.core.protobuf import rewriter_config_pb2
33from tensorflow.python.compiler.tensorrt import trt_convert
34from tensorflow.python.compiler.tensorrt import utils as trt_utils
35from tensorflow.python.eager import def_function
36from tensorflow.python.framework import graph_io
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_spec
39from tensorflow.python.framework import test_util
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import math_ops
42from tensorflow.python.platform import tf_logging as logging
43from tensorflow.python.profiler import trace
44from tensorflow.python.saved_model import builder
45from tensorflow.python.saved_model import load
46from tensorflow.python.saved_model import loader
47from tensorflow.python.saved_model import save
48from tensorflow.python.saved_model import signature_constants
49from tensorflow.python.saved_model import signature_def_utils
50from tensorflow.python.saved_model import tag_constants
51from tensorflow.python.saved_model import utils
52from tensorflow.python.tools import saved_model_utils
53from tensorflow.python.trackable import autotrackable
54from tensorflow.python.util import nest
55
56logging.get_logger().propagate = False
57
58TfTrtIntegrationTestParams = collections.namedtuple(
59    "TfTrtIntegrationTestParams",
60    [
61        # A function that creates the TF graph for testing.
62        "graph_fn",
63        # A list of specifications for input tensors.
64        "input_specs",
65        # A list of specifications for output tensors.
66        "output_specs",
67        # A list of list of input shapes. Each shape must match the
68        # corresponding element in `input_specs`.
69        "input_dims",
70        # A list of list of expected output shapes. Each shape must match the
71        # corresponding element in `output_specs`.
72        "expected_output_dims"
73    ])
74
75RunParams = collections.namedtuple(
76    "RunParams",
77    [
78        # Whether to run the conversion online with RewriterConfig, or offline
79        # with TrtGraphConverter.
80        "convert_online",
81        "precision_mode",
82        "dynamic_engine",
83        "use_calibration",
84        "test_name",
85        # Is this test for TF 2.0?
86        "is_v2",
87        "dynamic_shape",
88    ])
89
90FP32 = "FP32"
91FP16 = "FP16"
92INT8 = "INT8"
93PRECISION_MODES = [FP32, FP16, INT8]
94
95
96def IsQuantizationMode(mode):
97  return mode == "INT8"
98
99
100def IsQuantizationWithCalibration(params):
101  return IsQuantizationMode(params.precision_mode) and params.use_calibration
102
103
104def IsQuantizationWithoutCalibration(params):
105  return IsQuantizationMode(
106      params.precision_mode) and not params.use_calibration
107
108
109class GraphState(object):
110  ORIGINAL = 0
111  CALIBRATE = 1
112  INFERENCE = 2
113
114
115class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
116  """Class to test Tensorflow-TensorRT integration."""
117
118  @property
119  def trt_incompatible_op(self):
120    return math_ops.erfc
121
122  @property
123  def trt_incompatible_binary_op(self):
124    return math_ops.igamma
125
126  @property
127  def precision_modes(self):
128    return ["FP32", "FP16", "INT8"]
129
130  # str is bytes in py2, but unicode in py3.
131  def _ToUnicode(self, s):
132    if isinstance(s, str):
133      return s
134    return s.decode("utf-8")
135
136  def _ToBytes(self, s):
137    if isinstance(s, str):
138      return s.encode("utf-8")
139    return s
140
141  def _ToString(self, s):
142    if isinstance(s, str):
143      return s
144    return s.decode("utf-8")
145
146  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
147    super(TfTrtIntegrationTestBase, self).__init__(methodName)
148    self._trt_test_params = None
149    self._disable_non_trt_optimizers = False
150    self._profile_strategy = "ImplicitBatchModeCompatible"
151
152  def setUp(self):
153    """Setup method."""
154    super().setUp()
155    warnings.simplefilter("always")
156
157    if not is_tensorrt_enabled():
158      self.skipTest("Test requires TensorRT")
159
160  def tearDown(self):
161    """Making sure to clean artifact."""
162    idx = 0
163    while gc.garbage:
164      gc.collect()  # Force GC to destroy the TRT engine cache.
165      idx += 1
166      if idx >= 10:  # After 10 iterations, break to avoid infinite collect.
167        break
168
169  def _GetTensorSpec(self, shape, mask, dtype, name):
170    # Set dimension i to None if mask[i] == False
171    assert len(shape) == len(mask), (
172        f"len(shape): {len(shape)} == len(mask): {len(mask)}")
173
174    new_shape = [s if m else None for s, m in zip(shape, mask)]
175    return tensor_spec.TensorSpec(new_shape, dtype, name)
176
177  def BuildParams(self, graph_fn, dtype, input_shapes, output_shapes):
178    """Build test parameters.
179
180    The input_shapes and output_shapes arguments are known (static) shapes that
181    can be used to generate test data. To define the model, we also specify
182    corresponding input/output TensorSpecs. These are defined using the shape
183    arguments. For each input tensor we define:
184
185    input_spec = [None] + input_shape[1:]
186
187    and similarly for output shapes. This means that we leave the first (batch)
188    dimension unknown, the rest is just copied from the shapes arg.
189
190    Args:
191      graph_fn: The function to build the graph.
192      dtype: The element type.
193      input_shapes: The input shapes.
194      output_shapes: The output shapes.
195
196    Returns:
197      The test parameters.
198    """
199
200    input_mask = [[False] + [True] * (len(shape) - 1) for shape in input_shapes]
201    output_mask = [[False] + [True] * (len(shape) - 1) if shape else []
202                   for shape in output_shapes]
203
204    return self.BuildParamsWithMask(graph_fn, dtype, input_shapes,
205                                    output_shapes, input_mask, output_mask, [],
206                                    [])
207
208  def BuildParamsWithMask(self, graph_fn, dtype, input_shapes, output_shapes,
209                          input_mask, output_mask, extra_inputs, extra_outputs):
210    """Build test parameters with static or dynamic input shapes.
211
212    To define dynamic shapes give a boolean mask that describes which
213    dimensions to treat as known. The values in input_mask are interpreted the
214    following way:
215    - True: known dim (use the corresponding value from input_shapes)
216    - False: unknown dim (replace the corresponding value from input_shapes
217             with None)
218    For example, to define the first two dimension with unknown size use
219    input_shapes=[[1,2,1,8]], input_mask=[[False, False, True, True]].
220
221    Args:
222      graph_fn: The function to build the graph.
223      dtype: The element type.
224      input_shapes: The input shapes.
225      output_shapes: The output shapes.
226      input_mask: The input shape masks.
227      output_mask: the output shape masks.
228      extra_inputs: list of additional input shapes
229      extra_outputs: list of additional outputs shapes
230
231    Returns:
232      The test parameters.
233    """
234
235    def _ValidateShapes(shapes):
236      # Make sure all the shapes are fully specified.
237      for shape in shapes:
238        assert all(shape), f"Shape unspecified: {shape}"
239
240    _ValidateShapes(input_shapes)
241    _ValidateShapes(output_shapes)
242
243    assert len(input_mask) == len(input_shapes), (
244        f"Inconsistent input_mask and input_shapes: len({input_mask}) != "
245        f"len({input_shapes}).")
246    assert len(output_mask) == len(output_shapes), (
247        f"Inconsistent output_mask and output_shapes: len({output_mask}) != "
248        f"len({output_shapes}).")
249    for extra_in_shape, extra_out_shape in zip(extra_inputs, extra_outputs):
250      assert len(input_shapes) == len(extra_in_shape), (
251          f"Inconsistent input_shapes and extra_in_shape: len({input_shapes}) "
252          f"!= len({extra_in_shape}).")
253      assert len(output_shapes) == len(extra_out_shape), (
254          f"Inconsistent output_shapes and extra_out_shape: "
255          f"len({output_shapes}) != len({extra_out_shape}).")
256
257    return TfTrtIntegrationTestParams(
258        graph_fn=graph_fn,
259        input_specs=[
260            self._GetTensorSpec(shape, mask, dtype, "input_%d" % i)
261            for i, (shape, mask) in enumerate(zip(input_shapes, input_mask))
262        ],
263        output_specs=[
264            self._GetTensorSpec(shape, mask, dtype, "output_%d" % i)
265            for i, (shape, mask) in enumerate(zip(output_shapes, output_mask))
266        ],
267        input_dims=[input_shapes] + extra_inputs,
268        expected_output_dims=[output_shapes] + extra_outputs)
269
270  def DisableNonTrtOptimizers(self):
271    self._disable_non_trt_optimizers = True
272
273  def GetParams(self):
274    """Returns a TfTrtIntegrationTestParams for the test."""
275    raise NotImplementedError()
276
277  def GetConversionParams(self, run_params):
278    """Returns a TrtConversionParams for test."""
279    conversion_params = trt_convert.TrtConversionParams(
280        # We use the minimum of all the batch sizes, so when multiple different
281        # input shapes are provided it'll always create new engines in the
282        # cache, and we can therefore test the cache behavior.
283        max_workspace_size_bytes=(
284            trt_convert.DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES),
285        precision_mode=run_params.precision_mode,
286        minimum_segment_size=2,
287        maximum_cached_engines=1,
288        use_calibration=run_params.use_calibration)
289    return conversion_params
290
291  def GetMaxBatchSize(self, run_params):
292    """Returns the max_batch_size that the converter should use for tests."""
293    if run_params.dynamic_engine:
294      return None
295    batch_list = []
296    for dims_list in self._GetParamsCached().input_dims:
297      assert dims_list, f"Expect non-empty `dim_list` but got: {dims_list}"
298      # Each list of shapes should have same batch size.
299      input_batches = [dims[0] for dims in dims_list]
300      assert max(input_batches) == min(input_batches), (
301          f"Inconsistent batch_size: max({input_batches}) != "
302          f"min({input_batches}).")
303      batch_list.append(input_batches[0])
304    return max(batch_list)
305
306  def ShouldRunTest(self, run_params):
307    """Whether to run the test."""
308    # Ensure use_calibration=True in case of INT8 precision
309    return (run_params.use_calibration or not IsQuantizationMode(
310        run_params.precision_mode)), "test either calibration or non-INT8"
311
312  def ExpectedEnginesToBuild(self, run_params):
313    """Returns the expected engines to build, implemented by subclass."""
314    raise NotImplementedError()
315
316  def ExpectedConnections(self, run_params):
317    """Returns the expected edges or an empty dict to skip the check."""
318    return {}
319
320  def ExpectedMaxBatchSizes(self, run_params):
321    """Returns the expected maximum batch sizes of the build engines."""
322    return None
323
324  def ExpectedAbsoluteTolerance(self, run_params):
325    """The absolute tolerance to compare floating point results."""
326    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02
327
328  def ExpectedRelativeTolerance(self, run_params):
329    """The relative tolerance to compare floating point results."""
330    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02
331
332  def _GetParamsCached(self):
333    if self._trt_test_params is None:
334      self._trt_test_params = self.GetParams()
335    return self._trt_test_params
336
337  def _GetGPUOptions(self):
338    gpu_options = config_pb2.GPUOptions()
339    gpu_options.allow_growth = True
340    return gpu_options
341
342  def _GetConfigProto(self, run_params, graph_state):
343    """Get config proto based on specific settings."""
344    conversion_params = self.GetConversionParams(run_params)
345    max_batch_size = self.GetMaxBatchSize(run_params)
346
347    if graph_state == GraphState.INFERENCE and run_params.convert_online:
348      rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
349          conversion_params,
350          is_dynamic_op=run_params.dynamic_engine,
351          max_batch_size=max_batch_size,
352          disable_non_trt_optimizers=self._disable_non_trt_optimizers)
353    else:
354      rewriter_cfg = rewriter_config_pb2.RewriterConfig()
355      if self._disable_non_trt_optimizers:
356        trt_utils.disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg)
357
358    config = config_pb2.ConfigProto(
359        gpu_options=self._GetGPUOptions(),
360        graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))
361    return config
362
363  def _GetFeedNames(self):
364    params = self._GetParamsCached()
365    # Construct the feeds tensor names by appending :0 to the node names.
366    return [spec.name + ":0" for spec in params.input_specs]
367
368  def _GetFetchNames(self):
369    params = self._GetParamsCached()
370    # Construct the fetches tensor names by appending :0 to the node names.
371    return [spec.name + ":0" for spec in params.output_specs]
372
373  def _GetFeedDict(self, inputs_data):
374    return {name: data for name, data in zip(self._GetFeedNames(), inputs_data)}
375
376  def _RunGraphV1(self, saved_model_dir, inputs_data, config, num_runs=2):
377    """Run given graphdef multiple times using TF 1.x runtime."""
378    params = self._GetParamsCached()
379    fetches = self._GetFetchNames()
380    g = ops.Graph()
381    with g.as_default():
382      with self.session(graph=g, config=config, use_gpu=True) as sess:
383        loader.load(sess, [tag_constants.SERVING], saved_model_dir)
384        vals = []
385        # Run for each input(s) shape
386        for expected_shapes, current_input_data in zip(
387            params.expected_output_dims, inputs_data):
388          val = None
389          for _ in range(num_runs):
390            new_val = sess.run(fetches, self._GetFeedDict(current_input_data))
391            self.assertEqual(len(expected_shapes), len(new_val))
392            for expected_shape, actual_val in zip(expected_shapes, new_val):
393              self.assertEqual(list(expected_shape), list(actual_val.shape))
394            if val is not None:
395              # Some ops may have nondeterministic output. E.g. Conv2D may use
396              # winograd algorithm. So we set atol/rtol be larger than 1.e-06.
397              self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05)
398            val = new_val
399          vals.append(val)
400        return vals
401
402  def _RunGraphV2(self, saved_model_dir, inputs_data, graph_state, num_runs=2):
403    """Run given graphdef multiple times using TF 2.0 runtime."""
404    params = self._GetParamsCached()
405    root = load.load(saved_model_dir)
406    func = root.signatures[
407        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
408    results = []
409    for expected_shapes, current_input_data in zip(params.expected_output_dims,
410                                                   inputs_data):
411      val = None
412      for _ in range(num_runs):
413        feed_dict = {
414            params.input_specs[i].name: current_input_data[i]
415            for i in range(len(params.input_specs))
416        }
417        new_val = func(**feed_dict)
418        assert isinstance(new_val, dict), (
419            f"Invalid type for `new_val`, expected `dict`. Got: {type(new_val)}."
420        )
421        # The key of the output map is always like output_i.
422        new_val = [new_val[key] for key in sorted(new_val)]
423        # Each element is an eager Tensor, and accessing individual elements is
424        # very expensive, so we convert them to a numpy array first.
425        new_val = [v.numpy() for v in new_val]
426        self.assertEqual(len(expected_shapes), len(new_val))
427        for expected_shape, actual_val in zip(expected_shapes, new_val):
428          self.assertEqual(list(expected_shape), list(actual_val.shape))
429        if val is not None:
430          # Some ops may have nondeterministic output. E.g. Conv2D may use
431          # winograd algorithm. So we set atol/rtol be larger than 1.e-06.
432          self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05)
433        val = new_val
434      results.append(val)
435
436    return results
437
438  def _RunGraph(self,
439                run_params,
440                saved_model_dir,
441                inputs_data,
442                graph_state,
443                num_runs=2):
444    params = self._GetParamsCached()
445    for data in inputs_data:
446      assert len(params.input_specs) == len(data), (
447          f"Inconsistent params.input_specs and data: "
448          f"len({params.input_specs}) != len({data}).")
449
450    if run_params.is_v2:
451      results = self._RunGraphV2(saved_model_dir, inputs_data, graph_state,
452                                 num_runs)
453      gc.collect()  # Force GC to destroy the TRT engine cache.
454      return results
455
456    # The default config for tf.session is None. Create a config with
457    # TensorRTOptimizer enabled to support convert_online for inference.
458    config = None
459    # TODO(b/170220818): use the default session config to run inferenence
460    #   graphs for the offline conversion case after fixing the bug.
461    if graph_state == GraphState.INFERENCE:
462      config = self._GetConfigProto(run_params, GraphState.INFERENCE)
463    return self._RunGraphV1(saved_model_dir, inputs_data, config, num_runs)
464
465  def _CreateConverter(self, run_params, saved_model_dir, conversion_params):
466    """Returns a TrtGraphConverter."""
467    if run_params.is_v2:
468      converter_v2 = trt_convert.TrtGraphConverterV2(
469          input_saved_model_dir=saved_model_dir,
470          use_dynamic_shape=run_params.dynamic_shape,
471          dynamic_shape_profile_strategy=self._profile_strategy,
472          **conversion_params._asdict())
473      if self._disable_non_trt_optimizers:
474        converter_v2._test_only_disable_non_trt_optimizers = True  # pylint: disable=protected-access
475      return converter_v2
476
477    converter_v1 = trt_convert.TrtGraphConverter(
478        input_saved_model_dir=saved_model_dir,
479        max_batch_size=self.GetMaxBatchSize(run_params),
480        max_workspace_size_bytes=conversion_params.max_workspace_size_bytes,
481        precision_mode=conversion_params.precision_mode,
482        minimum_segment_size=conversion_params.minimum_segment_size,
483        is_dynamic_op=run_params.dynamic_engine,
484        maximum_cached_engines=conversion_params.maximum_cached_engines,
485        use_calibration=conversion_params.use_calibration)
486    if self._disable_non_trt_optimizers:
487      converter_v1._test_only_disable_non_trt_optimizers = True  # pylint: disable=protected-access
488    return converter_v1
489
490  def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data):
491    """Return trt converted graphdef in INT8 mode."""
492    conversion_params = self.GetConversionParams(run_params)
493    logging.info(conversion_params)
494    assert conversion_params.precision_mode == "INT8", (
495        f"Incorrect precision mode, expected INT8 but got: "
496        f"{conversion_params.precision_mode}.")
497    assert run_params.dynamic_engine, "dynamic_engine parameter must be True."
498    assert conversion_params.maximum_cached_engines == 1, (
499        f"maximum_cached_engines: {conversion_params.maximum_cached_engines} == 1"
500    )
501    assert conversion_params.use_calibration, "use_calibration must be True."
502
503    # We only support calibrating single engine.
504    # TODO(aaroey): fix this.
505    assert len(inputs_data) == 1, (f"len(inputs_data): {len(inputs_data)} == 1")
506
507    converter = self._CreateConverter(run_params, saved_model_dir,
508                                      conversion_params)
509    if run_params.is_v2:
510
511      def CalibrationInputFn():
512        for data_tensors in inputs_data:
513          yield data_tensors
514
515      converter.convert(calibration_input_fn=CalibrationInputFn)
516    else:
517      int8_gdef = converter.convert()
518      self._VerifyGraphDef(run_params, saved_model_dir, int8_gdef,
519                           GraphState.CALIBRATE)
520
521      converter.calibrate(
522          fetch_names=self._GetFetchNames(),
523          num_runs=5,
524          feed_dict_fn=lambda: self._GetFeedDict(inputs_data[0]))
525
526    if run_params.dynamic_shape and self._ShouldConverterBuild(run_params):
527      logging.info("Using build mode")
528
529      def _BuildInputFn():
530        for shapes in self._GetParamsCached().input_dims:
531          yield [
532              array_ops.zeros(x, dtype=spec.dtype)
533              for (x, spec) in zip(shapes,
534                                   self._GetParamsCached().input_specs)
535          ]
536
537      converter.build(input_fn=_BuildInputFn)
538
539    trt_saved_model_dir = self._GetSavedModelDir(run_params,
540                                                 GraphState.CALIBRATE)
541    converter.save(trt_saved_model_dir)
542    return trt_saved_model_dir
543
544  def _ShouldConverterBuild(self, run_params):
545    return True
546
547  def _GetInferGraph(self, run_params, saved_model_dir):
548    """Return trt converted graphdef."""
549    conversion_params = self.GetConversionParams(run_params)
550    logging.info(conversion_params)
551
552    converter = self._CreateConverter(run_params, saved_model_dir,
553                                      conversion_params)
554    converter.convert()
555
556    if run_params.is_v2:
557      try:
558        line_length = max(160, os.get_terminal_size().columns)
559      except OSError:
560        line_length = 160
561      converter.summary(line_length=line_length, detailed=True)
562
563    if run_params.dynamic_shape and self._ShouldConverterBuild(run_params):
564      logging.info("Using build mode")
565
566      def _BuildInputFn():
567        for shapes in self._GetParamsCached().input_dims:
568          yield [
569              array_ops.zeros(x, dtype=spec.dtype)
570              for (x, spec) in zip(shapes,
571                                   self._GetParamsCached().input_specs)
572          ]
573
574      converter.build(input_fn=_BuildInputFn)
575
576    trt_saved_model_dir = self._GetSavedModelDir(run_params,
577                                                 GraphState.INFERENCE)
578    converter.save(trt_saved_model_dir)
579    return trt_saved_model_dir
580
581  def _GetGraphStateLabel(self, graph_state):
582    if graph_state == GraphState.ORIGINAL:
583      return "Original"
584    elif graph_state == GraphState.CALIBRATE:
585      return "CalibEngine"
586    elif graph_state == GraphState.INFERENCE:
587      return "InferEngine"
588    else:
589      return "UnknownState"
590
591  def _WriteGraph(self, run_params, gdef, graph_state):
592    temp_dir = os.getenv("TRT_TEST_TMPDIR")
593    if not temp_dir:
594      return
595
596    graph_name = (
597        self.__class__.__name__ + "_" + run_params.test_name + "_" +
598        self._GetGraphStateLabel(graph_state) + ".pbtxt")
599    logging.info("Writing graph to %s/%s", temp_dir, graph_name)
600    graph_io.write_graph(gdef, temp_dir, graph_name)
601
602  # Removes the prefix(s) of function name(s).
603  # The input value can be a string or a sequence of string.
604  def _Canonicalize(self, value):
605    if isinstance(value, str):
606      return self._ToString(value.split("/")[-1])
607    elif isinstance(value, collections.abc.Iterable):
608      return set(self._Canonicalize(nm) for nm in value)
609    else:
610      raise TypeError(
611          "'_Canonicalize' can only be used on strings or sequence of strings!")
612
613  # Removes the graph sequence number prefix from the name(s) only if the
614  # name(s) has a prefix TRTEngineOp_n_. When expecting_prefix is true, asserts
615  # such a prefix exists.
616  # The input value can be a string or a sequence of string.
617  def _RemoveGraphSequenceNumberImpl(self, value, expecting_prefix):
618    if isinstance(value, str):
619      match = re.search(r"TRTEngineOp_\d{3,}_", value)
620      has_prefix = match and value.startswith(match.group(0))
621      assert (not expecting_prefix) or has_prefix, (
622          f"Expect (not expecting_prefix) or has_prefix but got: "
623          f"- expecting_prefix = {expecting_prefix}\n- has_prefix = {has_prefix}"
624      )
625      if has_prefix:
626        parts = value.split("_", maxsplit=2)
627        assert len(parts) == 3, (
628            f"Incorrect `parts` of length == 3, but got: len({parts}).")
629        return parts[0] + "_" + parts[2]
630      return value
631    elif isinstance(value, collections.abc.Iterable):
632      return set(
633          self._RemoveGraphSequenceNumberImpl(nm, expecting_prefix)
634          for nm in value)
635    else:
636      raise TypeError(
637          "'_RemoveGraphSequenceNumberImpl' can only be used on strings "
638          "or sequence of strings!")
639
640  def _RemoveGraphSequenceNumber(self, name):
641    return self._RemoveGraphSequenceNumberImpl(name, True)
642
643  def _MayRemoveGraphSequenceNumber(self, name):
644    return self._RemoveGraphSequenceNumberImpl(name, False)
645
646  def _VerifyConnections(self, expected_engines, expected_input_map,
647                         original_gdef, converted_gdef):
648    """Checks that the converted graph contains the expected connections."""
649    old_to_new_node_map = {
650        self._ToString(node.name): self._ToString(node.name)
651        for node in original_gdef.node
652    }
653    for engine_name, node_names in expected_engines.items():
654      for node_name in node_names:
655        old_to_new_node_map[node_name] = engine_name
656
657    def _InputName(inp):
658      inp = self._ToString(inp)
659      prefix = ""
660      if inp[0] == "^":
661        prefix = "^"
662        inp = inp[1:]
663      parts = inp.split(":")
664      if len(parts) > 1 and parts[-1].isdigit():
665        inp = inp[:-len(parts[-1]) - 1]
666      return (prefix, inp)
667
668    # Compute the actual mapping from each node to its input nodes. If a cast
669    # op doesn't exist in the original graph, we replace the use of the cast op
670    # with the input of the op. This allows the verification to handle the case
671    # where the TF-TRT bridge splits a cast op into a chain of two cast ops.
672    new_cast_op_name_to_node_map = {
673        node.name: node
674        for node in converted_gdef.node
675        if (node.name not in old_to_new_node_map and node.op == "Cast")
676    }
677    actual_input_map = {}
678    for node in converted_gdef.node:
679      name_str = node.name
680      # Only nodes from the original graph or TRTEngineOp nodes are added as
681      # keys to the map.
682      if node.op == "TRTEngineOp":
683        name_str = self._RemoveGraphSequenceNumber(name_str)
684      elif name_str not in old_to_new_node_map:
685        continue
686      actual_input_map[name_str] = set()
687      input_set = actual_input_map[name_str]
688      for inp in node.input:
689        (prefix, node_name) = _InputName(inp)
690        node_name = self._MayRemoveGraphSequenceNumber(node_name)
691        if node_name in new_cast_op_name_to_node_map:
692          (prefix, node_name) = _InputName(
693              new_cast_op_name_to_node_map[node_name].input[0])
694        input_set.add(prefix + node_name)
695
696    self.assertEqual(
697        expected_input_map,
698        actual_input_map,
699        msg="\nexpected:\n%s\nvs actual:\n%s" %
700        (sorted(expected_input_map.items()), sorted(actual_input_map.items())))
701
702  def _VerifyMaxBatchSizeAnnotations(
703      self,
704      expected_engines,
705      original_gdef,
706      converted_gdef,
707      default_max_batch_size,
708      expected_max_batch_sizes=None,
709  ):
710    """Verifies the max batch size annotations in the original and converted GraphDef.
711
712    Args:
713      expected_engines: A sequence of engines names.
714      original_gdef: GraphDef. The graph def before TensorRT conversion.
715      converted_gdef: GraphDef. The graph def after TensorRT conversion.
716      default_max_batch_size: The default maximum batch size to use if no node
717        inside a segment is annoted with a customized max batch size. This value
718        is None when the graph is converted to TF-TRT with dynamic engines.
719      expected_max_batch_sizes: Optional. A sequence of max batch sizes for all
720        the engines. `None` if does not check enforce max batch sizes.
721    """
722    if isinstance(expected_max_batch_sizes, collections.abc.Collection):
723      self.assertEqual(len(expected_max_batch_sizes), len(expected_engines))
724    else:
725      self.assertIsNone(
726          expected_max_batch_sizes,
727          "'expected_max_batch_sizes' shall only be a sequence "
728          "of integers or `None`.")
729
730    def _ChainAllNodes(graph_def):
731      return itertools.chain(
732          graph_def.node,
733          itertools.chain(
734              *[func.node_def for func in graph_def.library.function]))
735
736    old_name_to_node_map = {
737        self._ToString(node.name): node
738        for node in _ChainAllNodes(original_gdef)
739    }
740    new_name_to_func_map = {
741        self._ToString(func.signature.name): func
742        for func in converted_gdef.library.function
743    }
744
745    def _DetectStaticBatchSize(node_def):
746      """Returns the static batch size of an operation or None.
747
748      It is incorrect to use the output shapes to find the batch size of an
749      operation, as the segmenter actually uses the input shapes. However, it is
750      a simplication and works for most of the cases for the test purposes.
751
752      Args:
753        node_def: `tf.NodeDef`. The target node for analysis.
754
755      Returns:
756        If all the outputs of the node have the same static batch size, returns
757        the int value for the batch size. Otherwise returns None.
758      """
759      shapes = node_def.attr["_output_shapes"].list.shape
760      batch_size = set(
761          list(s.dim)[0].size if len(s.dim) >= 2 else None for s in shapes)
762      if len(batch_size) == 1 and list(batch_size)[0] >= 1:
763        return list(batch_size)[0]
764      return None
765
766    name_to_engines_map = {}
767    actual_max_batch_sizes = []
768    for node in _ChainAllNodes(converted_gdef):
769      if node.op == "TRTEngineOp":
770        engine = node
771        engine_name = self._RemoveGraphSequenceNumber(
772            self._Canonicalize(self._ToString(engine.name)))
773        self.assertIn(engine_name, expected_engines)
774        name_to_engines_map[engine_name] = engine
775        # The input nodes shall not have the conflicting annotation (no
776        # annotation or the same annotation) with the maximum batch size
777        # annotation. If the engine has maximum batch size annotation as the
778        # non-default maximum batch size, then at least one input node shall
779        # have the same annotation to be the source.
780        self.assertIn("max_batch_size", node.attr)
781        engine_max_batch_size = node.attr["max_batch_size"].i
782        self.assertIsInstance(engine_max_batch_size, int)
783        actual_max_batch_sizes.append(engine_max_batch_size)
784        seg_func = node.attr["segment_func"].func
785        self.assertIsNotNone(seg_func)
786        self.assertIn(seg_func.name, new_name_to_func_map)
787        seg_func_def = new_name_to_func_map[seg_func.name]
788        logging.info("Segment function name: %s. Including %d nodes.",
789                     seg_func.name, len(seg_func_def.node_def))
790        node_max_batch_size_all_none = True
791        # Use the native segment to search for replaced nodes
792        for alternative_node in seg_func_def.node_def:
793          node_name = self._Canonicalize(self._ToString(alternative_node.name))
794          if node_name not in old_name_to_node_map:
795            continue
796          original_node = old_name_to_node_map[node_name]
797          node_max_batch_size = None
798          if "_tftrt_op_max_batch_size" in original_node.attr:
799            node_max_batch_size = original_node.attr[
800                "_tftrt_op_max_batch_size"].i
801          elif (original_node.op != "Const" and
802                alternative_node.op != "Const" and
803                "_output_shapes" in original_node.attr):
804            node_max_batch_size = _DetectStaticBatchSize(original_node)
805          logging.info(
806              "'{%s}(%s)'s max batch size annotation is %s. "
807              "'{%s}'s max batch size is %s.", node_name, original_node.op,
808              str(node_max_batch_size), engine_name, str(engine_max_batch_size))
809          node_max_batch_size_all_none &= node_max_batch_size is None
810          self.assertTrue(engine_max_batch_size == node_max_batch_size or
811                          node_max_batch_size is None)
812        logging.info("'{%s}'s max batch size is %d.", engine_name,
813                     engine_max_batch_size)
814        self.assertTrue(default_max_batch_size is None or
815                        engine_max_batch_size == default_max_batch_size or
816                        not node_max_batch_size_all_none)
817
818    self.assertCountEqual(expected_engines, tuple(name_to_engines_map.keys()))
819    if expected_max_batch_sizes is not None:
820      self.assertCountEqual(expected_max_batch_sizes, actual_max_batch_sizes)
821
822  def _GetGraphDef(self, run_params, gdef_or_saved_model_dir):
823    if isinstance(gdef_or_saved_model_dir, str):
824      if run_params.is_v2:
825        root = load.load(gdef_or_saved_model_dir)
826        func = root.signatures[
827            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
828        gdef = func.graph.as_graph_def()
829        # Manually unref the loaded saved model and force GC to destroy the TRT
830        # engine cache after load(). There is currently a reference cycle in 2.0
831        # which prevents auto deletion of the resource.
832        # TODO(laigd): fix this.
833        del func
834        del root
835        gc.collect()
836        return gdef
837      return saved_model_utils.get_meta_graph_def(
838          gdef_or_saved_model_dir, tag_constants.SERVING).graph_def
839    assert isinstance(gdef_or_saved_model_dir, graph_pb2.GraphDef), (
840        f"Incorrect `gdef_or_saved_model_dir` type, expected "
841        f"`graph_pb2.GraphDef`, but got: {type(gdef_or_saved_model_dir)}.")
842    return gdef_or_saved_model_dir
843
844  def _VerifyGraphDefV1(self, run_params, original_gdef, gdef_to_verify,
845                        graph_state):
846    expected_engines = self.ExpectedEnginesToBuild(run_params)
847    num_engines = 0
848    functions = [f.signature.name for f in gdef_to_verify.library.function]
849    for node in gdef_to_verify.node:
850      if node.op == "TRTEngineOp":
851        logging.info("Found TRTEngineOp: " + node.name)
852        num_engines += 1
853        segment_funcdef_name = node.attr["segment_func"].func.name
854        function_name = node.name + "_native_segment"
855        is_dynamic_engine = not node.attr["static_engine"].b
856        self.assertNotEmpty(segment_funcdef_name, node.name)
857        self.assertIn(function_name, functions)
858        if (not IsQuantizationWithCalibration(run_params) and
859            not is_dynamic_engine):
860          self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
861        self.assertIn(
862            self._RemoveGraphSequenceNumber(node.name), expected_engines)
863        if IsQuantizationWithoutCalibration(run_params):
864          # TODO(bixia): Refine this check by inspecting nodes in the engine.
865          if self._ToBytes("INT8") != node.attr["precision_mode"].s:
866            self.assertEqual(
867                self._ToBytes("FP16"), node.attr["precision_mode"].s, node.name)
868        else:
869          self.assertEqual(
870              self._ToBytes(run_params.precision_mode),
871              node.attr["precision_mode"].s, node.name)
872
873        self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
874                         node.name)
875        self.assertEqual(node.attr["use_calibration"].b,
876                         run_params.use_calibration, node.name)
877
878        has_calibration_data = len(node.attr["calibration_data"].s)
879        if (IsQuantizationWithCalibration(run_params) and
880            graph_state == GraphState.INFERENCE):
881          self.assertTrue(has_calibration_data, node.name)
882        else:
883          self.assertFalse(has_calibration_data, node.name)
884    if graph_state == GraphState.ORIGINAL:
885      self.assertEqual(0, num_engines)
886    else:
887      self.assertEqual(num_engines, len(expected_engines))
888      expected_connections = self.ExpectedConnections(run_params)
889      if expected_connections:
890        self._VerifyConnections(expected_engines, expected_connections,
891                                original_gdef, gdef_to_verify)
892      self._VerifyMaxBatchSizeAnnotations(
893          expected_engines=expected_engines,
894          original_gdef=original_gdef,
895          converted_gdef=gdef_to_verify,
896          expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params),
897          default_max_batch_size=self.GetMaxBatchSize(run_params))
898
899  def _VerifyGraphDefV2(self, run_params, original_gdef, gdef_to_verify,
900                        graph_state):
901    if graph_state == GraphState.ORIGINAL:
902      return
903    expected_engines = self.ExpectedEnginesToBuild(run_params)
904    all_op_names = [node.name for node in gdef_to_verify.node]
905    trt_op_names = []
906    for func in gdef_to_verify.library.function:
907      if not re.search(r"TRTEngineOp_\d{3,}_\d{3,}_native_segment",
908                       func.signature.name):
909        for node in func.node_def:
910          all_op_names.append(node.name)
911          if node.op == "TRTEngineOp":
912            trt_op_names.append(node.name)
913            if run_params.dynamic_shape:
914              self.assertEqual(
915                  self._ToString(node.attr["profile_strategy"].s).lower(),
916                  self._profile_strategy.lower())
917
918    all_op_names = self._Canonicalize(all_op_names)
919    trt_op_names = self._RemoveGraphSequenceNumber(
920        self._Canonicalize(trt_op_names))
921
922    if isinstance(expected_engines, dict):
923      # For simplicity we don't verify the connections inside the engine in
924      # 2.0, but we still make sure that the converted ops are gone from the
925      # graph.
926      unexpected_names = set(nest.flatten(expected_engines.values()))
927      self.assertEmpty(
928          [name for name in unexpected_names if name in all_op_names])
929      expected_engines = set(expected_engines.keys())
930
931    self.assertEqual(set(expected_engines), trt_op_names)
932
933  def _VerifyGraphDef(self, run_params, original_gdef_or_saved_model_dir,
934                      gdef_or_saved_model_dir_to_verify, graph_state):
935    original_gdef = self._GetGraphDef(run_params,
936                                      original_gdef_or_saved_model_dir)
937    gdef_to_verify = self._GetGraphDef(run_params,
938                                       gdef_or_saved_model_dir_to_verify)
939    self._WriteGraph(run_params, gdef_to_verify, graph_state)
940    if run_params.is_v2:
941      self._VerifyGraphDefV2(run_params, original_gdef, gdef_to_verify,
942                             graph_state)
943    else:
944      self._VerifyGraphDefV1(run_params, original_gdef, gdef_to_verify,
945                             graph_state)
946
947  def _GetSavedModelDir(self, run_params, graph_state):
948    test_tmpdir = os.getenv("TRT_TEST_TMPDIR")
949    if test_tmpdir:
950      saved_model_dir = os.path.join(
951          test_tmpdir, self.__class__.__name__ + "_" + run_params.test_name +
952          "_" + self._GetGraphStateLabel(graph_state))
953      try:
954        # For TF 1.x we need to make sure the output directory doesn't exist
955        # before exporting the saved model.
956        shutil.rmtree(saved_model_dir)
957      except OSError as e:
958        if e.errno != errno.ENOENT:
959          raise
960      return saved_model_dir
961    return tempfile.mkdtemp(dir=self.get_temp_dir())
962
963  def _MakeSavedModelV1(self, run_params):
964    """Write the saved model as an input for testing."""
965    params = self._GetParamsCached()
966    g = ops.Graph()
967    with g.as_default():
968      inputs = []
969      for spec in params.input_specs:
970        inp = array_ops.placeholder(
971            dtype=spec.dtype, shape=spec.shape, name=spec.name)
972        inputs.append(inp)
973      outputs = params.graph_fn(*inputs)
974      if not isinstance(outputs, list) and not isinstance(outputs, tuple):
975        outputs = [outputs]
976
977    signature_def = signature_def_utils.build_signature_def(
978        inputs={inp.op.name: utils.build_tensor_info(inp) for inp in inputs},
979        outputs={out.op.name: utils.build_tensor_info(out) for out in outputs},
980        method_name=signature_constants.PREDICT_METHOD_NAME)
981
982    saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL)
983    saved_model_builder = builder.SavedModelBuilder(saved_model_dir)
984    with self.session(
985        graph=g, config=self._GetConfigProto(run_params,
986                                             GraphState.ORIGINAL)) as sess:
987      saved_model_builder.add_meta_graph_and_variables(
988          sess, [tag_constants.SERVING],
989          signature_def_map={
990              signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
991                  signature_def
992          })
993    saved_model_builder.save()
994    return saved_model_dir
995
996  def _MakeSavedModelV2(self, run_params):
997    params = self._GetParamsCached()
998    root = autotrackable.AutoTrackable()
999    root.run = def_function.function(
1000        params.graph_fn, input_signature=params.input_specs)
1001    saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL)
1002    logging.info("Saving input SavedModel to %s", saved_model_dir)
1003    save.save(root, saved_model_dir,
1004              {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: root.run})
1005    return saved_model_dir
1006
1007  def _MakeSavedModel(self, run_params):
1008    if run_params.is_v2:
1009      return self._MakeSavedModelV2(run_params)
1010    return self._MakeSavedModelV1(run_params)
1011
1012  def RunTest(self, run_params):
1013    with trace.Trace(run_params.test_name):
1014      should_run, reason_for_skipping = self.ShouldRunTest(run_params)
1015      if not should_run:
1016        return self.skipTest(reason_for_skipping)
1017
1018      saved_model_dir = self._MakeSavedModel(run_params)
1019
1020      np.random.seed(12345)  # Fix the seed so the test is deterministic.
1021      inputs_data = []
1022      input_specs = self._GetParamsCached().input_specs
1023      for dim_list in self._GetParamsCached().input_dims:
1024        assert len(input_specs) == len(dim_list), (
1025            f"Inconsistent input_specs and dim_list: len({input_specs}) != "
1026            f"len({dim_list}).")
1027        current_input_data = []
1028        for spec, np_shape in zip(input_specs, dim_list):
1029          np_dtype = spec.dtype.as_numpy_dtype()
1030          # Multiply the input by some constant to avoid all zeros input for
1031          # integer types.
1032          scale = 10.0 if np.issubdtype(np_dtype, np.integer) else 1.0
1033          # TODO(laigd): add debug options. E.g. we can set the input data to be
1034          # continuous natural numbers:
1035          # seq = np.arange(np.prod(np_shape))
1036          # seq.resize(np_shape)
1037          # current_inputs_data.append(scale * seq.astype(np_dtype))
1038          data = (scale * np.random.random_sample(np_shape)).astype(np_dtype)
1039          if run_params.is_v2:
1040            with ops.device("/GPU:0"):
1041              data = ops.convert_to_tensor(data)
1042          current_input_data.append(data)
1043        inputs_data.append(current_input_data)
1044
1045      # Verify the original graph.
1046      self._VerifyGraphDef(run_params, saved_model_dir, saved_model_dir,
1047                           GraphState.ORIGINAL)
1048
1049      # Run the original graph without TensorRT to get the reference result.
1050      logging.info("Running original graph w/o TensorRT\n")
1051      ref_result = self._RunGraph(
1052          run_params,
1053          saved_model_dir,
1054          inputs_data,
1055          GraphState.ORIGINAL,
1056          num_runs=1)
1057
1058      # Run calibration if necessary.
1059      if IsQuantizationWithCalibration(run_params):
1060        infer_saved_model_dir = self._GetCalibratedInferGraph(
1061            run_params, saved_model_dir, inputs_data)
1062        self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir,
1063                             GraphState.INFERENCE)
1064      elif not run_params.convert_online:
1065        infer_saved_model_dir = self._GetInferGraph(run_params, saved_model_dir)
1066        self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir,
1067                             GraphState.INFERENCE)
1068      else:
1069        infer_saved_model_dir = saved_model_dir
1070
1071      # Run the inference graph, either using the converted graph or the
1072      # original graph with convert_online == True.
1073      logging.info("Running final inference graph\n")
1074      result = self._RunGraph(run_params, infer_saved_model_dir, inputs_data,
1075                              GraphState.INFERENCE)
1076      self.assertAllClose(
1077          ref_result,
1078          result,
1079          atol=self.ExpectedAbsoluteTolerance(run_params),
1080          rtol=self.ExpectedRelativeTolerance(run_params))
1081
1082  def testIdempotence(self):
1083    # Test that applying tensorrt optimizer or offline conversion tools multiple
1084    # times to the same graph will result in same graph.
1085    #
1086    # TODO(aaroey): implement this.
1087    pass
1088
1089
1090def _GetTestConfigsV1():
1091  """Returns the config combinations to run the test."""
1092  convert_online, convert_offline = True, False
1093  dynamic_engine, static_engine = True, False
1094  use_calibration, no_calibration = True, False
1095  implicit_batch = False
1096
1097  # Add all possible test cases and let the derived test class to decide
1098  # whether to run specific ones with ShouldRunTest().
1099  #
1100  # Note: INT8 without calibration behaves like FP32/FP16.
1101  opts = list(
1102      itertools.product([FP32, FP16, INT8], [convert_online, convert_offline],
1103                        [dynamic_engine, static_engine], [no_calibration],
1104                        [implicit_batch]))
1105  # We always run calibration with offline tool.
1106  # TODO(aaroey): static calibration engine is not supported yet.
1107  opts.append(
1108      (INT8, convert_offline, dynamic_engine, use_calibration, implicit_batch))
1109  return opts
1110
1111
1112def _GetTestConfigsV2():
1113  """Returns the config combinations to run the test."""
1114  convert_offline = False
1115  # TODO(laigd): add support for static_engine.
1116  dynamic_engine = True
1117  # TODO(laigd): add support for calibration.
1118  no_calibration = False
1119  use_calibration = True
1120
1121  # Add all possible test cases and let the derived test class to decide
1122  # whether to run specific ones with ShouldRunTest().
1123  #
1124  # Note:
1125  # - In TF2.0 the conversion always produce dynamic engine, and we don't test
1126  #   the offline mode here.
1127  # - For simplicity we don't test online conversion which requires setting the
1128  #   Grappler config in default eager context.
1129  # - INT8 without calibration behaves like FP32/FP16.
1130  opts = list(
1131      itertools.product([FP32, FP16], [convert_offline], [dynamic_engine],
1132                        [no_calibration], [False, True]))
1133  # We always run calibration with offline tool.
1134  opts.append((INT8, convert_offline, dynamic_engine, use_calibration, False))
1135  opts.append((INT8, convert_offline, dynamic_engine, use_calibration, True))
1136  return opts
1137
1138
1139def _GetTest(run_params):
1140  """Gets a single test method based on the parameters."""
1141
1142  def _Test(self):
1143    logging.info(f"Running test `{run_params.test_name}` with parameters: "
1144                 f"convert_online={run_params.convert_online}, "
1145                 f"precision_mode={run_params.precision_mode}, "
1146                 f"dynamic_engine={run_params.dynamic_engine}, "
1147                 f"dynamic_shape={run_params.dynamic_shape}")
1148    self.RunTest(run_params)
1149
1150  return _Test
1151
1152
1153def _AddTestsFor(test_class, is_v2):
1154  """Adds test methods to TfTrtIntegrationTestBase for specific TF version."""
1155  opts = _GetTestConfigsV2() if is_v2 else _GetTestConfigsV1()
1156  for (precision_mode, convert_online, dynamic_engine, use_calibration,
1157       dynamic_shape) in opts:
1158    conversion = "OnlineConversion" if convert_online else "OfflineConversion"
1159    engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine"
1160    calibration_type = "UseCalibration" if use_calibration else "NoCalibration"
1161    dynamic_shape_type = "DynamicShape" if dynamic_shape else "ImplicitBatch"
1162    test_name = "%s_%s_%s_%s_%s_%s" % ("testTfTrtV2" if is_v2 else "testTfTrt",
1163                                       conversion, engine_type, precision_mode,
1164                                       calibration_type, dynamic_shape_type)
1165    run_params = RunParams(
1166        convert_online=convert_online,
1167        precision_mode=precision_mode,
1168        dynamic_engine=dynamic_engine,
1169        test_name=test_name,
1170        use_calibration=use_calibration,
1171        is_v2=is_v2,
1172        dynamic_shape=dynamic_shape)
1173    if is_v2:
1174      setattr(test_class, test_name,
1175              test_util.run_v2_only(_GetTest(run_params)))
1176    else:
1177      setattr(test_class, test_name,
1178              test_util.run_v1_only("", _GetTest(run_params)))
1179
1180
1181def _AddTests(test_class):
1182  """Adds test methods to TfTrtIntegrationTestBase."""
1183  _AddTestsFor(test_class, is_v2=False)
1184  _AddTestsFor(test_class, is_v2=True)
1185
1186
1187if is_tensorrt_enabled():
1188  os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
1189  _AddTests(TfTrtIntegrationTestBase)
1190