1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities to test TF-TensorRT integration.""" 16 17import collections 18import errno 19import gc 20import itertools 21import os 22import re 23import shutil 24import tempfile 25import warnings 26 27import numpy as np 28 29from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled 30from tensorflow.core.framework import graph_pb2 31from tensorflow.core.protobuf import config_pb2 32from tensorflow.core.protobuf import rewriter_config_pb2 33from tensorflow.python.compiler.tensorrt import trt_convert 34from tensorflow.python.compiler.tensorrt import utils as trt_utils 35from tensorflow.python.eager import def_function 36from tensorflow.python.framework import graph_io 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import tensor_spec 39from tensorflow.python.framework import test_util 40from tensorflow.python.ops import array_ops 41from tensorflow.python.ops import math_ops 42from tensorflow.python.platform import tf_logging as logging 43from tensorflow.python.profiler import trace 44from tensorflow.python.saved_model import builder 45from tensorflow.python.saved_model import load 46from tensorflow.python.saved_model import loader 47from tensorflow.python.saved_model import save 48from tensorflow.python.saved_model import signature_constants 49from tensorflow.python.saved_model import signature_def_utils 50from tensorflow.python.saved_model import tag_constants 51from tensorflow.python.saved_model import utils 52from tensorflow.python.tools import saved_model_utils 53from tensorflow.python.trackable import autotrackable 54from tensorflow.python.util import nest 55 56logging.get_logger().propagate = False 57 58TfTrtIntegrationTestParams = collections.namedtuple( 59 "TfTrtIntegrationTestParams", 60 [ 61 # A function that creates the TF graph for testing. 62 "graph_fn", 63 # A list of specifications for input tensors. 64 "input_specs", 65 # A list of specifications for output tensors. 66 "output_specs", 67 # A list of list of input shapes. Each shape must match the 68 # corresponding element in `input_specs`. 69 "input_dims", 70 # A list of list of expected output shapes. Each shape must match the 71 # corresponding element in `output_specs`. 72 "expected_output_dims" 73 ]) 74 75RunParams = collections.namedtuple( 76 "RunParams", 77 [ 78 # Whether to run the conversion online with RewriterConfig, or offline 79 # with TrtGraphConverter. 80 "convert_online", 81 "precision_mode", 82 "dynamic_engine", 83 "use_calibration", 84 "test_name", 85 # Is this test for TF 2.0? 86 "is_v2", 87 "dynamic_shape", 88 ]) 89 90FP32 = "FP32" 91FP16 = "FP16" 92INT8 = "INT8" 93PRECISION_MODES = [FP32, FP16, INT8] 94 95 96def IsQuantizationMode(mode): 97 return mode == "INT8" 98 99 100def IsQuantizationWithCalibration(params): 101 return IsQuantizationMode(params.precision_mode) and params.use_calibration 102 103 104def IsQuantizationWithoutCalibration(params): 105 return IsQuantizationMode( 106 params.precision_mode) and not params.use_calibration 107 108 109class GraphState(object): 110 ORIGINAL = 0 111 CALIBRATE = 1 112 INFERENCE = 2 113 114 115class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): 116 """Class to test Tensorflow-TensorRT integration.""" 117 118 @property 119 def trt_incompatible_op(self): 120 return math_ops.erfc 121 122 @property 123 def trt_incompatible_binary_op(self): 124 return math_ops.igamma 125 126 @property 127 def precision_modes(self): 128 return ["FP32", "FP16", "INT8"] 129 130 # str is bytes in py2, but unicode in py3. 131 def _ToUnicode(self, s): 132 if isinstance(s, str): 133 return s 134 return s.decode("utf-8") 135 136 def _ToBytes(self, s): 137 if isinstance(s, str): 138 return s.encode("utf-8") 139 return s 140 141 def _ToString(self, s): 142 if isinstance(s, str): 143 return s 144 return s.decode("utf-8") 145 146 def __init__(self, methodName="runTest"): # pylint: disable=invalid-name 147 super(TfTrtIntegrationTestBase, self).__init__(methodName) 148 self._trt_test_params = None 149 self._disable_non_trt_optimizers = False 150 self._profile_strategy = "ImplicitBatchModeCompatible" 151 152 def setUp(self): 153 """Setup method.""" 154 super().setUp() 155 warnings.simplefilter("always") 156 157 if not is_tensorrt_enabled(): 158 self.skipTest("Test requires TensorRT") 159 160 def tearDown(self): 161 """Making sure to clean artifact.""" 162 idx = 0 163 while gc.garbage: 164 gc.collect() # Force GC to destroy the TRT engine cache. 165 idx += 1 166 if idx >= 10: # After 10 iterations, break to avoid infinite collect. 167 break 168 169 def _GetTensorSpec(self, shape, mask, dtype, name): 170 # Set dimension i to None if mask[i] == False 171 assert len(shape) == len(mask), ( 172 f"len(shape): {len(shape)} == len(mask): {len(mask)}") 173 174 new_shape = [s if m else None for s, m in zip(shape, mask)] 175 return tensor_spec.TensorSpec(new_shape, dtype, name) 176 177 def BuildParams(self, graph_fn, dtype, input_shapes, output_shapes): 178 """Build test parameters. 179 180 The input_shapes and output_shapes arguments are known (static) shapes that 181 can be used to generate test data. To define the model, we also specify 182 corresponding input/output TensorSpecs. These are defined using the shape 183 arguments. For each input tensor we define: 184 185 input_spec = [None] + input_shape[1:] 186 187 and similarly for output shapes. This means that we leave the first (batch) 188 dimension unknown, the rest is just copied from the shapes arg. 189 190 Args: 191 graph_fn: The function to build the graph. 192 dtype: The element type. 193 input_shapes: The input shapes. 194 output_shapes: The output shapes. 195 196 Returns: 197 The test parameters. 198 """ 199 200 input_mask = [[False] + [True] * (len(shape) - 1) for shape in input_shapes] 201 output_mask = [[False] + [True] * (len(shape) - 1) if shape else [] 202 for shape in output_shapes] 203 204 return self.BuildParamsWithMask(graph_fn, dtype, input_shapes, 205 output_shapes, input_mask, output_mask, [], 206 []) 207 208 def BuildParamsWithMask(self, graph_fn, dtype, input_shapes, output_shapes, 209 input_mask, output_mask, extra_inputs, extra_outputs): 210 """Build test parameters with static or dynamic input shapes. 211 212 To define dynamic shapes give a boolean mask that describes which 213 dimensions to treat as known. The values in input_mask are interpreted the 214 following way: 215 - True: known dim (use the corresponding value from input_shapes) 216 - False: unknown dim (replace the corresponding value from input_shapes 217 with None) 218 For example, to define the first two dimension with unknown size use 219 input_shapes=[[1,2,1,8]], input_mask=[[False, False, True, True]]. 220 221 Args: 222 graph_fn: The function to build the graph. 223 dtype: The element type. 224 input_shapes: The input shapes. 225 output_shapes: The output shapes. 226 input_mask: The input shape masks. 227 output_mask: the output shape masks. 228 extra_inputs: list of additional input shapes 229 extra_outputs: list of additional outputs shapes 230 231 Returns: 232 The test parameters. 233 """ 234 235 def _ValidateShapes(shapes): 236 # Make sure all the shapes are fully specified. 237 for shape in shapes: 238 assert all(shape), f"Shape unspecified: {shape}" 239 240 _ValidateShapes(input_shapes) 241 _ValidateShapes(output_shapes) 242 243 assert len(input_mask) == len(input_shapes), ( 244 f"Inconsistent input_mask and input_shapes: len({input_mask}) != " 245 f"len({input_shapes}).") 246 assert len(output_mask) == len(output_shapes), ( 247 f"Inconsistent output_mask and output_shapes: len({output_mask}) != " 248 f"len({output_shapes}).") 249 for extra_in_shape, extra_out_shape in zip(extra_inputs, extra_outputs): 250 assert len(input_shapes) == len(extra_in_shape), ( 251 f"Inconsistent input_shapes and extra_in_shape: len({input_shapes}) " 252 f"!= len({extra_in_shape}).") 253 assert len(output_shapes) == len(extra_out_shape), ( 254 f"Inconsistent output_shapes and extra_out_shape: " 255 f"len({output_shapes}) != len({extra_out_shape}).") 256 257 return TfTrtIntegrationTestParams( 258 graph_fn=graph_fn, 259 input_specs=[ 260 self._GetTensorSpec(shape, mask, dtype, "input_%d" % i) 261 for i, (shape, mask) in enumerate(zip(input_shapes, input_mask)) 262 ], 263 output_specs=[ 264 self._GetTensorSpec(shape, mask, dtype, "output_%d" % i) 265 for i, (shape, mask) in enumerate(zip(output_shapes, output_mask)) 266 ], 267 input_dims=[input_shapes] + extra_inputs, 268 expected_output_dims=[output_shapes] + extra_outputs) 269 270 def DisableNonTrtOptimizers(self): 271 self._disable_non_trt_optimizers = True 272 273 def GetParams(self): 274 """Returns a TfTrtIntegrationTestParams for the test.""" 275 raise NotImplementedError() 276 277 def GetConversionParams(self, run_params): 278 """Returns a TrtConversionParams for test.""" 279 conversion_params = trt_convert.TrtConversionParams( 280 # We use the minimum of all the batch sizes, so when multiple different 281 # input shapes are provided it'll always create new engines in the 282 # cache, and we can therefore test the cache behavior. 283 max_workspace_size_bytes=( 284 trt_convert.DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES), 285 precision_mode=run_params.precision_mode, 286 minimum_segment_size=2, 287 maximum_cached_engines=1, 288 use_calibration=run_params.use_calibration) 289 return conversion_params 290 291 def GetMaxBatchSize(self, run_params): 292 """Returns the max_batch_size that the converter should use for tests.""" 293 if run_params.dynamic_engine: 294 return None 295 batch_list = [] 296 for dims_list in self._GetParamsCached().input_dims: 297 assert dims_list, f"Expect non-empty `dim_list` but got: {dims_list}" 298 # Each list of shapes should have same batch size. 299 input_batches = [dims[0] for dims in dims_list] 300 assert max(input_batches) == min(input_batches), ( 301 f"Inconsistent batch_size: max({input_batches}) != " 302 f"min({input_batches}).") 303 batch_list.append(input_batches[0]) 304 return max(batch_list) 305 306 def ShouldRunTest(self, run_params): 307 """Whether to run the test.""" 308 # Ensure use_calibration=True in case of INT8 precision 309 return (run_params.use_calibration or not IsQuantizationMode( 310 run_params.precision_mode)), "test either calibration or non-INT8" 311 312 def ExpectedEnginesToBuild(self, run_params): 313 """Returns the expected engines to build, implemented by subclass.""" 314 raise NotImplementedError() 315 316 def ExpectedConnections(self, run_params): 317 """Returns the expected edges or an empty dict to skip the check.""" 318 return {} 319 320 def ExpectedMaxBatchSizes(self, run_params): 321 """Returns the expected maximum batch sizes of the build engines.""" 322 return None 323 324 def ExpectedAbsoluteTolerance(self, run_params): 325 """The absolute tolerance to compare floating point results.""" 326 return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02 327 328 def ExpectedRelativeTolerance(self, run_params): 329 """The relative tolerance to compare floating point results.""" 330 return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02 331 332 def _GetParamsCached(self): 333 if self._trt_test_params is None: 334 self._trt_test_params = self.GetParams() 335 return self._trt_test_params 336 337 def _GetGPUOptions(self): 338 gpu_options = config_pb2.GPUOptions() 339 gpu_options.allow_growth = True 340 return gpu_options 341 342 def _GetConfigProto(self, run_params, graph_state): 343 """Get config proto based on specific settings.""" 344 conversion_params = self.GetConversionParams(run_params) 345 max_batch_size = self.GetMaxBatchSize(run_params) 346 347 if graph_state == GraphState.INFERENCE and run_params.convert_online: 348 rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( 349 conversion_params, 350 is_dynamic_op=run_params.dynamic_engine, 351 max_batch_size=max_batch_size, 352 disable_non_trt_optimizers=self._disable_non_trt_optimizers) 353 else: 354 rewriter_cfg = rewriter_config_pb2.RewriterConfig() 355 if self._disable_non_trt_optimizers: 356 trt_utils.disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg) 357 358 config = config_pb2.ConfigProto( 359 gpu_options=self._GetGPUOptions(), 360 graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg)) 361 return config 362 363 def _GetFeedNames(self): 364 params = self._GetParamsCached() 365 # Construct the feeds tensor names by appending :0 to the node names. 366 return [spec.name + ":0" for spec in params.input_specs] 367 368 def _GetFetchNames(self): 369 params = self._GetParamsCached() 370 # Construct the fetches tensor names by appending :0 to the node names. 371 return [spec.name + ":0" for spec in params.output_specs] 372 373 def _GetFeedDict(self, inputs_data): 374 return {name: data for name, data in zip(self._GetFeedNames(), inputs_data)} 375 376 def _RunGraphV1(self, saved_model_dir, inputs_data, config, num_runs=2): 377 """Run given graphdef multiple times using TF 1.x runtime.""" 378 params = self._GetParamsCached() 379 fetches = self._GetFetchNames() 380 g = ops.Graph() 381 with g.as_default(): 382 with self.session(graph=g, config=config, use_gpu=True) as sess: 383 loader.load(sess, [tag_constants.SERVING], saved_model_dir) 384 vals = [] 385 # Run for each input(s) shape 386 for expected_shapes, current_input_data in zip( 387 params.expected_output_dims, inputs_data): 388 val = None 389 for _ in range(num_runs): 390 new_val = sess.run(fetches, self._GetFeedDict(current_input_data)) 391 self.assertEqual(len(expected_shapes), len(new_val)) 392 for expected_shape, actual_val in zip(expected_shapes, new_val): 393 self.assertEqual(list(expected_shape), list(actual_val.shape)) 394 if val is not None: 395 # Some ops may have nondeterministic output. E.g. Conv2D may use 396 # winograd algorithm. So we set atol/rtol be larger than 1.e-06. 397 self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05) 398 val = new_val 399 vals.append(val) 400 return vals 401 402 def _RunGraphV2(self, saved_model_dir, inputs_data, graph_state, num_runs=2): 403 """Run given graphdef multiple times using TF 2.0 runtime.""" 404 params = self._GetParamsCached() 405 root = load.load(saved_model_dir) 406 func = root.signatures[ 407 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 408 results = [] 409 for expected_shapes, current_input_data in zip(params.expected_output_dims, 410 inputs_data): 411 val = None 412 for _ in range(num_runs): 413 feed_dict = { 414 params.input_specs[i].name: current_input_data[i] 415 for i in range(len(params.input_specs)) 416 } 417 new_val = func(**feed_dict) 418 assert isinstance(new_val, dict), ( 419 f"Invalid type for `new_val`, expected `dict`. Got: {type(new_val)}." 420 ) 421 # The key of the output map is always like output_i. 422 new_val = [new_val[key] for key in sorted(new_val)] 423 # Each element is an eager Tensor, and accessing individual elements is 424 # very expensive, so we convert them to a numpy array first. 425 new_val = [v.numpy() for v in new_val] 426 self.assertEqual(len(expected_shapes), len(new_val)) 427 for expected_shape, actual_val in zip(expected_shapes, new_val): 428 self.assertEqual(list(expected_shape), list(actual_val.shape)) 429 if val is not None: 430 # Some ops may have nondeterministic output. E.g. Conv2D may use 431 # winograd algorithm. So we set atol/rtol be larger than 1.e-06. 432 self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05) 433 val = new_val 434 results.append(val) 435 436 return results 437 438 def _RunGraph(self, 439 run_params, 440 saved_model_dir, 441 inputs_data, 442 graph_state, 443 num_runs=2): 444 params = self._GetParamsCached() 445 for data in inputs_data: 446 assert len(params.input_specs) == len(data), ( 447 f"Inconsistent params.input_specs and data: " 448 f"len({params.input_specs}) != len({data}).") 449 450 if run_params.is_v2: 451 results = self._RunGraphV2(saved_model_dir, inputs_data, graph_state, 452 num_runs) 453 gc.collect() # Force GC to destroy the TRT engine cache. 454 return results 455 456 # The default config for tf.session is None. Create a config with 457 # TensorRTOptimizer enabled to support convert_online for inference. 458 config = None 459 # TODO(b/170220818): use the default session config to run inferenence 460 # graphs for the offline conversion case after fixing the bug. 461 if graph_state == GraphState.INFERENCE: 462 config = self._GetConfigProto(run_params, GraphState.INFERENCE) 463 return self._RunGraphV1(saved_model_dir, inputs_data, config, num_runs) 464 465 def _CreateConverter(self, run_params, saved_model_dir, conversion_params): 466 """Returns a TrtGraphConverter.""" 467 if run_params.is_v2: 468 converter_v2 = trt_convert.TrtGraphConverterV2( 469 input_saved_model_dir=saved_model_dir, 470 use_dynamic_shape=run_params.dynamic_shape, 471 dynamic_shape_profile_strategy=self._profile_strategy, 472 **conversion_params._asdict()) 473 if self._disable_non_trt_optimizers: 474 converter_v2._test_only_disable_non_trt_optimizers = True # pylint: disable=protected-access 475 return converter_v2 476 477 converter_v1 = trt_convert.TrtGraphConverter( 478 input_saved_model_dir=saved_model_dir, 479 max_batch_size=self.GetMaxBatchSize(run_params), 480 max_workspace_size_bytes=conversion_params.max_workspace_size_bytes, 481 precision_mode=conversion_params.precision_mode, 482 minimum_segment_size=conversion_params.minimum_segment_size, 483 is_dynamic_op=run_params.dynamic_engine, 484 maximum_cached_engines=conversion_params.maximum_cached_engines, 485 use_calibration=conversion_params.use_calibration) 486 if self._disable_non_trt_optimizers: 487 converter_v1._test_only_disable_non_trt_optimizers = True # pylint: disable=protected-access 488 return converter_v1 489 490 def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data): 491 """Return trt converted graphdef in INT8 mode.""" 492 conversion_params = self.GetConversionParams(run_params) 493 logging.info(conversion_params) 494 assert conversion_params.precision_mode == "INT8", ( 495 f"Incorrect precision mode, expected INT8 but got: " 496 f"{conversion_params.precision_mode}.") 497 assert run_params.dynamic_engine, "dynamic_engine parameter must be True." 498 assert conversion_params.maximum_cached_engines == 1, ( 499 f"maximum_cached_engines: {conversion_params.maximum_cached_engines} == 1" 500 ) 501 assert conversion_params.use_calibration, "use_calibration must be True." 502 503 # We only support calibrating single engine. 504 # TODO(aaroey): fix this. 505 assert len(inputs_data) == 1, (f"len(inputs_data): {len(inputs_data)} == 1") 506 507 converter = self._CreateConverter(run_params, saved_model_dir, 508 conversion_params) 509 if run_params.is_v2: 510 511 def CalibrationInputFn(): 512 for data_tensors in inputs_data: 513 yield data_tensors 514 515 converter.convert(calibration_input_fn=CalibrationInputFn) 516 else: 517 int8_gdef = converter.convert() 518 self._VerifyGraphDef(run_params, saved_model_dir, int8_gdef, 519 GraphState.CALIBRATE) 520 521 converter.calibrate( 522 fetch_names=self._GetFetchNames(), 523 num_runs=5, 524 feed_dict_fn=lambda: self._GetFeedDict(inputs_data[0])) 525 526 if run_params.dynamic_shape and self._ShouldConverterBuild(run_params): 527 logging.info("Using build mode") 528 529 def _BuildInputFn(): 530 for shapes in self._GetParamsCached().input_dims: 531 yield [ 532 array_ops.zeros(x, dtype=spec.dtype) 533 for (x, spec) in zip(shapes, 534 self._GetParamsCached().input_specs) 535 ] 536 537 converter.build(input_fn=_BuildInputFn) 538 539 trt_saved_model_dir = self._GetSavedModelDir(run_params, 540 GraphState.CALIBRATE) 541 converter.save(trt_saved_model_dir) 542 return trt_saved_model_dir 543 544 def _ShouldConverterBuild(self, run_params): 545 return True 546 547 def _GetInferGraph(self, run_params, saved_model_dir): 548 """Return trt converted graphdef.""" 549 conversion_params = self.GetConversionParams(run_params) 550 logging.info(conversion_params) 551 552 converter = self._CreateConverter(run_params, saved_model_dir, 553 conversion_params) 554 converter.convert() 555 556 if run_params.is_v2: 557 try: 558 line_length = max(160, os.get_terminal_size().columns) 559 except OSError: 560 line_length = 160 561 converter.summary(line_length=line_length, detailed=True) 562 563 if run_params.dynamic_shape and self._ShouldConverterBuild(run_params): 564 logging.info("Using build mode") 565 566 def _BuildInputFn(): 567 for shapes in self._GetParamsCached().input_dims: 568 yield [ 569 array_ops.zeros(x, dtype=spec.dtype) 570 for (x, spec) in zip(shapes, 571 self._GetParamsCached().input_specs) 572 ] 573 574 converter.build(input_fn=_BuildInputFn) 575 576 trt_saved_model_dir = self._GetSavedModelDir(run_params, 577 GraphState.INFERENCE) 578 converter.save(trt_saved_model_dir) 579 return trt_saved_model_dir 580 581 def _GetGraphStateLabel(self, graph_state): 582 if graph_state == GraphState.ORIGINAL: 583 return "Original" 584 elif graph_state == GraphState.CALIBRATE: 585 return "CalibEngine" 586 elif graph_state == GraphState.INFERENCE: 587 return "InferEngine" 588 else: 589 return "UnknownState" 590 591 def _WriteGraph(self, run_params, gdef, graph_state): 592 temp_dir = os.getenv("TRT_TEST_TMPDIR") 593 if not temp_dir: 594 return 595 596 graph_name = ( 597 self.__class__.__name__ + "_" + run_params.test_name + "_" + 598 self._GetGraphStateLabel(graph_state) + ".pbtxt") 599 logging.info("Writing graph to %s/%s", temp_dir, graph_name) 600 graph_io.write_graph(gdef, temp_dir, graph_name) 601 602 # Removes the prefix(s) of function name(s). 603 # The input value can be a string or a sequence of string. 604 def _Canonicalize(self, value): 605 if isinstance(value, str): 606 return self._ToString(value.split("/")[-1]) 607 elif isinstance(value, collections.abc.Iterable): 608 return set(self._Canonicalize(nm) for nm in value) 609 else: 610 raise TypeError( 611 "'_Canonicalize' can only be used on strings or sequence of strings!") 612 613 # Removes the graph sequence number prefix from the name(s) only if the 614 # name(s) has a prefix TRTEngineOp_n_. When expecting_prefix is true, asserts 615 # such a prefix exists. 616 # The input value can be a string or a sequence of string. 617 def _RemoveGraphSequenceNumberImpl(self, value, expecting_prefix): 618 if isinstance(value, str): 619 match = re.search(r"TRTEngineOp_\d{3,}_", value) 620 has_prefix = match and value.startswith(match.group(0)) 621 assert (not expecting_prefix) or has_prefix, ( 622 f"Expect (not expecting_prefix) or has_prefix but got: " 623 f"- expecting_prefix = {expecting_prefix}\n- has_prefix = {has_prefix}" 624 ) 625 if has_prefix: 626 parts = value.split("_", maxsplit=2) 627 assert len(parts) == 3, ( 628 f"Incorrect `parts` of length == 3, but got: len({parts}).") 629 return parts[0] + "_" + parts[2] 630 return value 631 elif isinstance(value, collections.abc.Iterable): 632 return set( 633 self._RemoveGraphSequenceNumberImpl(nm, expecting_prefix) 634 for nm in value) 635 else: 636 raise TypeError( 637 "'_RemoveGraphSequenceNumberImpl' can only be used on strings " 638 "or sequence of strings!") 639 640 def _RemoveGraphSequenceNumber(self, name): 641 return self._RemoveGraphSequenceNumberImpl(name, True) 642 643 def _MayRemoveGraphSequenceNumber(self, name): 644 return self._RemoveGraphSequenceNumberImpl(name, False) 645 646 def _VerifyConnections(self, expected_engines, expected_input_map, 647 original_gdef, converted_gdef): 648 """Checks that the converted graph contains the expected connections.""" 649 old_to_new_node_map = { 650 self._ToString(node.name): self._ToString(node.name) 651 for node in original_gdef.node 652 } 653 for engine_name, node_names in expected_engines.items(): 654 for node_name in node_names: 655 old_to_new_node_map[node_name] = engine_name 656 657 def _InputName(inp): 658 inp = self._ToString(inp) 659 prefix = "" 660 if inp[0] == "^": 661 prefix = "^" 662 inp = inp[1:] 663 parts = inp.split(":") 664 if len(parts) > 1 and parts[-1].isdigit(): 665 inp = inp[:-len(parts[-1]) - 1] 666 return (prefix, inp) 667 668 # Compute the actual mapping from each node to its input nodes. If a cast 669 # op doesn't exist in the original graph, we replace the use of the cast op 670 # with the input of the op. This allows the verification to handle the case 671 # where the TF-TRT bridge splits a cast op into a chain of two cast ops. 672 new_cast_op_name_to_node_map = { 673 node.name: node 674 for node in converted_gdef.node 675 if (node.name not in old_to_new_node_map and node.op == "Cast") 676 } 677 actual_input_map = {} 678 for node in converted_gdef.node: 679 name_str = node.name 680 # Only nodes from the original graph or TRTEngineOp nodes are added as 681 # keys to the map. 682 if node.op == "TRTEngineOp": 683 name_str = self._RemoveGraphSequenceNumber(name_str) 684 elif name_str not in old_to_new_node_map: 685 continue 686 actual_input_map[name_str] = set() 687 input_set = actual_input_map[name_str] 688 for inp in node.input: 689 (prefix, node_name) = _InputName(inp) 690 node_name = self._MayRemoveGraphSequenceNumber(node_name) 691 if node_name in new_cast_op_name_to_node_map: 692 (prefix, node_name) = _InputName( 693 new_cast_op_name_to_node_map[node_name].input[0]) 694 input_set.add(prefix + node_name) 695 696 self.assertEqual( 697 expected_input_map, 698 actual_input_map, 699 msg="\nexpected:\n%s\nvs actual:\n%s" % 700 (sorted(expected_input_map.items()), sorted(actual_input_map.items()))) 701 702 def _VerifyMaxBatchSizeAnnotations( 703 self, 704 expected_engines, 705 original_gdef, 706 converted_gdef, 707 default_max_batch_size, 708 expected_max_batch_sizes=None, 709 ): 710 """Verifies the max batch size annotations in the original and converted GraphDef. 711 712 Args: 713 expected_engines: A sequence of engines names. 714 original_gdef: GraphDef. The graph def before TensorRT conversion. 715 converted_gdef: GraphDef. The graph def after TensorRT conversion. 716 default_max_batch_size: The default maximum batch size to use if no node 717 inside a segment is annoted with a customized max batch size. This value 718 is None when the graph is converted to TF-TRT with dynamic engines. 719 expected_max_batch_sizes: Optional. A sequence of max batch sizes for all 720 the engines. `None` if does not check enforce max batch sizes. 721 """ 722 if isinstance(expected_max_batch_sizes, collections.abc.Collection): 723 self.assertEqual(len(expected_max_batch_sizes), len(expected_engines)) 724 else: 725 self.assertIsNone( 726 expected_max_batch_sizes, 727 "'expected_max_batch_sizes' shall only be a sequence " 728 "of integers or `None`.") 729 730 def _ChainAllNodes(graph_def): 731 return itertools.chain( 732 graph_def.node, 733 itertools.chain( 734 *[func.node_def for func in graph_def.library.function])) 735 736 old_name_to_node_map = { 737 self._ToString(node.name): node 738 for node in _ChainAllNodes(original_gdef) 739 } 740 new_name_to_func_map = { 741 self._ToString(func.signature.name): func 742 for func in converted_gdef.library.function 743 } 744 745 def _DetectStaticBatchSize(node_def): 746 """Returns the static batch size of an operation or None. 747 748 It is incorrect to use the output shapes to find the batch size of an 749 operation, as the segmenter actually uses the input shapes. However, it is 750 a simplication and works for most of the cases for the test purposes. 751 752 Args: 753 node_def: `tf.NodeDef`. The target node for analysis. 754 755 Returns: 756 If all the outputs of the node have the same static batch size, returns 757 the int value for the batch size. Otherwise returns None. 758 """ 759 shapes = node_def.attr["_output_shapes"].list.shape 760 batch_size = set( 761 list(s.dim)[0].size if len(s.dim) >= 2 else None for s in shapes) 762 if len(batch_size) == 1 and list(batch_size)[0] >= 1: 763 return list(batch_size)[0] 764 return None 765 766 name_to_engines_map = {} 767 actual_max_batch_sizes = [] 768 for node in _ChainAllNodes(converted_gdef): 769 if node.op == "TRTEngineOp": 770 engine = node 771 engine_name = self._RemoveGraphSequenceNumber( 772 self._Canonicalize(self._ToString(engine.name))) 773 self.assertIn(engine_name, expected_engines) 774 name_to_engines_map[engine_name] = engine 775 # The input nodes shall not have the conflicting annotation (no 776 # annotation or the same annotation) with the maximum batch size 777 # annotation. If the engine has maximum batch size annotation as the 778 # non-default maximum batch size, then at least one input node shall 779 # have the same annotation to be the source. 780 self.assertIn("max_batch_size", node.attr) 781 engine_max_batch_size = node.attr["max_batch_size"].i 782 self.assertIsInstance(engine_max_batch_size, int) 783 actual_max_batch_sizes.append(engine_max_batch_size) 784 seg_func = node.attr["segment_func"].func 785 self.assertIsNotNone(seg_func) 786 self.assertIn(seg_func.name, new_name_to_func_map) 787 seg_func_def = new_name_to_func_map[seg_func.name] 788 logging.info("Segment function name: %s. Including %d nodes.", 789 seg_func.name, len(seg_func_def.node_def)) 790 node_max_batch_size_all_none = True 791 # Use the native segment to search for replaced nodes 792 for alternative_node in seg_func_def.node_def: 793 node_name = self._Canonicalize(self._ToString(alternative_node.name)) 794 if node_name not in old_name_to_node_map: 795 continue 796 original_node = old_name_to_node_map[node_name] 797 node_max_batch_size = None 798 if "_tftrt_op_max_batch_size" in original_node.attr: 799 node_max_batch_size = original_node.attr[ 800 "_tftrt_op_max_batch_size"].i 801 elif (original_node.op != "Const" and 802 alternative_node.op != "Const" and 803 "_output_shapes" in original_node.attr): 804 node_max_batch_size = _DetectStaticBatchSize(original_node) 805 logging.info( 806 "'{%s}(%s)'s max batch size annotation is %s. " 807 "'{%s}'s max batch size is %s.", node_name, original_node.op, 808 str(node_max_batch_size), engine_name, str(engine_max_batch_size)) 809 node_max_batch_size_all_none &= node_max_batch_size is None 810 self.assertTrue(engine_max_batch_size == node_max_batch_size or 811 node_max_batch_size is None) 812 logging.info("'{%s}'s max batch size is %d.", engine_name, 813 engine_max_batch_size) 814 self.assertTrue(default_max_batch_size is None or 815 engine_max_batch_size == default_max_batch_size or 816 not node_max_batch_size_all_none) 817 818 self.assertCountEqual(expected_engines, tuple(name_to_engines_map.keys())) 819 if expected_max_batch_sizes is not None: 820 self.assertCountEqual(expected_max_batch_sizes, actual_max_batch_sizes) 821 822 def _GetGraphDef(self, run_params, gdef_or_saved_model_dir): 823 if isinstance(gdef_or_saved_model_dir, str): 824 if run_params.is_v2: 825 root = load.load(gdef_or_saved_model_dir) 826 func = root.signatures[ 827 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 828 gdef = func.graph.as_graph_def() 829 # Manually unref the loaded saved model and force GC to destroy the TRT 830 # engine cache after load(). There is currently a reference cycle in 2.0 831 # which prevents auto deletion of the resource. 832 # TODO(laigd): fix this. 833 del func 834 del root 835 gc.collect() 836 return gdef 837 return saved_model_utils.get_meta_graph_def( 838 gdef_or_saved_model_dir, tag_constants.SERVING).graph_def 839 assert isinstance(gdef_or_saved_model_dir, graph_pb2.GraphDef), ( 840 f"Incorrect `gdef_or_saved_model_dir` type, expected " 841 f"`graph_pb2.GraphDef`, but got: {type(gdef_or_saved_model_dir)}.") 842 return gdef_or_saved_model_dir 843 844 def _VerifyGraphDefV1(self, run_params, original_gdef, gdef_to_verify, 845 graph_state): 846 expected_engines = self.ExpectedEnginesToBuild(run_params) 847 num_engines = 0 848 functions = [f.signature.name for f in gdef_to_verify.library.function] 849 for node in gdef_to_verify.node: 850 if node.op == "TRTEngineOp": 851 logging.info("Found TRTEngineOp: " + node.name) 852 num_engines += 1 853 segment_funcdef_name = node.attr["segment_func"].func.name 854 function_name = node.name + "_native_segment" 855 is_dynamic_engine = not node.attr["static_engine"].b 856 self.assertNotEmpty(segment_funcdef_name, node.name) 857 self.assertIn(function_name, functions) 858 if (not IsQuantizationWithCalibration(run_params) and 859 not is_dynamic_engine): 860 self.assertTrue(len(node.attr["serialized_segment"].s), node.name) 861 self.assertIn( 862 self._RemoveGraphSequenceNumber(node.name), expected_engines) 863 if IsQuantizationWithoutCalibration(run_params): 864 # TODO(bixia): Refine this check by inspecting nodes in the engine. 865 if self._ToBytes("INT8") != node.attr["precision_mode"].s: 866 self.assertEqual( 867 self._ToBytes("FP16"), node.attr["precision_mode"].s, node.name) 868 else: 869 self.assertEqual( 870 self._ToBytes(run_params.precision_mode), 871 node.attr["precision_mode"].s, node.name) 872 873 self.assertEqual(run_params.dynamic_engine, is_dynamic_engine, 874 node.name) 875 self.assertEqual(node.attr["use_calibration"].b, 876 run_params.use_calibration, node.name) 877 878 has_calibration_data = len(node.attr["calibration_data"].s) 879 if (IsQuantizationWithCalibration(run_params) and 880 graph_state == GraphState.INFERENCE): 881 self.assertTrue(has_calibration_data, node.name) 882 else: 883 self.assertFalse(has_calibration_data, node.name) 884 if graph_state == GraphState.ORIGINAL: 885 self.assertEqual(0, num_engines) 886 else: 887 self.assertEqual(num_engines, len(expected_engines)) 888 expected_connections = self.ExpectedConnections(run_params) 889 if expected_connections: 890 self._VerifyConnections(expected_engines, expected_connections, 891 original_gdef, gdef_to_verify) 892 self._VerifyMaxBatchSizeAnnotations( 893 expected_engines=expected_engines, 894 original_gdef=original_gdef, 895 converted_gdef=gdef_to_verify, 896 expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params), 897 default_max_batch_size=self.GetMaxBatchSize(run_params)) 898 899 def _VerifyGraphDefV2(self, run_params, original_gdef, gdef_to_verify, 900 graph_state): 901 if graph_state == GraphState.ORIGINAL: 902 return 903 expected_engines = self.ExpectedEnginesToBuild(run_params) 904 all_op_names = [node.name for node in gdef_to_verify.node] 905 trt_op_names = [] 906 for func in gdef_to_verify.library.function: 907 if not re.search(r"TRTEngineOp_\d{3,}_\d{3,}_native_segment", 908 func.signature.name): 909 for node in func.node_def: 910 all_op_names.append(node.name) 911 if node.op == "TRTEngineOp": 912 trt_op_names.append(node.name) 913 if run_params.dynamic_shape: 914 self.assertEqual( 915 self._ToString(node.attr["profile_strategy"].s).lower(), 916 self._profile_strategy.lower()) 917 918 all_op_names = self._Canonicalize(all_op_names) 919 trt_op_names = self._RemoveGraphSequenceNumber( 920 self._Canonicalize(trt_op_names)) 921 922 if isinstance(expected_engines, dict): 923 # For simplicity we don't verify the connections inside the engine in 924 # 2.0, but we still make sure that the converted ops are gone from the 925 # graph. 926 unexpected_names = set(nest.flatten(expected_engines.values())) 927 self.assertEmpty( 928 [name for name in unexpected_names if name in all_op_names]) 929 expected_engines = set(expected_engines.keys()) 930 931 self.assertEqual(set(expected_engines), trt_op_names) 932 933 def _VerifyGraphDef(self, run_params, original_gdef_or_saved_model_dir, 934 gdef_or_saved_model_dir_to_verify, graph_state): 935 original_gdef = self._GetGraphDef(run_params, 936 original_gdef_or_saved_model_dir) 937 gdef_to_verify = self._GetGraphDef(run_params, 938 gdef_or_saved_model_dir_to_verify) 939 self._WriteGraph(run_params, gdef_to_verify, graph_state) 940 if run_params.is_v2: 941 self._VerifyGraphDefV2(run_params, original_gdef, gdef_to_verify, 942 graph_state) 943 else: 944 self._VerifyGraphDefV1(run_params, original_gdef, gdef_to_verify, 945 graph_state) 946 947 def _GetSavedModelDir(self, run_params, graph_state): 948 test_tmpdir = os.getenv("TRT_TEST_TMPDIR") 949 if test_tmpdir: 950 saved_model_dir = os.path.join( 951 test_tmpdir, self.__class__.__name__ + "_" + run_params.test_name + 952 "_" + self._GetGraphStateLabel(graph_state)) 953 try: 954 # For TF 1.x we need to make sure the output directory doesn't exist 955 # before exporting the saved model. 956 shutil.rmtree(saved_model_dir) 957 except OSError as e: 958 if e.errno != errno.ENOENT: 959 raise 960 return saved_model_dir 961 return tempfile.mkdtemp(dir=self.get_temp_dir()) 962 963 def _MakeSavedModelV1(self, run_params): 964 """Write the saved model as an input for testing.""" 965 params = self._GetParamsCached() 966 g = ops.Graph() 967 with g.as_default(): 968 inputs = [] 969 for spec in params.input_specs: 970 inp = array_ops.placeholder( 971 dtype=spec.dtype, shape=spec.shape, name=spec.name) 972 inputs.append(inp) 973 outputs = params.graph_fn(*inputs) 974 if not isinstance(outputs, list) and not isinstance(outputs, tuple): 975 outputs = [outputs] 976 977 signature_def = signature_def_utils.build_signature_def( 978 inputs={inp.op.name: utils.build_tensor_info(inp) for inp in inputs}, 979 outputs={out.op.name: utils.build_tensor_info(out) for out in outputs}, 980 method_name=signature_constants.PREDICT_METHOD_NAME) 981 982 saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL) 983 saved_model_builder = builder.SavedModelBuilder(saved_model_dir) 984 with self.session( 985 graph=g, config=self._GetConfigProto(run_params, 986 GraphState.ORIGINAL)) as sess: 987 saved_model_builder.add_meta_graph_and_variables( 988 sess, [tag_constants.SERVING], 989 signature_def_map={ 990 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 991 signature_def 992 }) 993 saved_model_builder.save() 994 return saved_model_dir 995 996 def _MakeSavedModelV2(self, run_params): 997 params = self._GetParamsCached() 998 root = autotrackable.AutoTrackable() 999 root.run = def_function.function( 1000 params.graph_fn, input_signature=params.input_specs) 1001 saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL) 1002 logging.info("Saving input SavedModel to %s", saved_model_dir) 1003 save.save(root, saved_model_dir, 1004 {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: root.run}) 1005 return saved_model_dir 1006 1007 def _MakeSavedModel(self, run_params): 1008 if run_params.is_v2: 1009 return self._MakeSavedModelV2(run_params) 1010 return self._MakeSavedModelV1(run_params) 1011 1012 def RunTest(self, run_params): 1013 with trace.Trace(run_params.test_name): 1014 should_run, reason_for_skipping = self.ShouldRunTest(run_params) 1015 if not should_run: 1016 return self.skipTest(reason_for_skipping) 1017 1018 saved_model_dir = self._MakeSavedModel(run_params) 1019 1020 np.random.seed(12345) # Fix the seed so the test is deterministic. 1021 inputs_data = [] 1022 input_specs = self._GetParamsCached().input_specs 1023 for dim_list in self._GetParamsCached().input_dims: 1024 assert len(input_specs) == len(dim_list), ( 1025 f"Inconsistent input_specs and dim_list: len({input_specs}) != " 1026 f"len({dim_list}).") 1027 current_input_data = [] 1028 for spec, np_shape in zip(input_specs, dim_list): 1029 np_dtype = spec.dtype.as_numpy_dtype() 1030 # Multiply the input by some constant to avoid all zeros input for 1031 # integer types. 1032 scale = 10.0 if np.issubdtype(np_dtype, np.integer) else 1.0 1033 # TODO(laigd): add debug options. E.g. we can set the input data to be 1034 # continuous natural numbers: 1035 # seq = np.arange(np.prod(np_shape)) 1036 # seq.resize(np_shape) 1037 # current_inputs_data.append(scale * seq.astype(np_dtype)) 1038 data = (scale * np.random.random_sample(np_shape)).astype(np_dtype) 1039 if run_params.is_v2: 1040 with ops.device("/GPU:0"): 1041 data = ops.convert_to_tensor(data) 1042 current_input_data.append(data) 1043 inputs_data.append(current_input_data) 1044 1045 # Verify the original graph. 1046 self._VerifyGraphDef(run_params, saved_model_dir, saved_model_dir, 1047 GraphState.ORIGINAL) 1048 1049 # Run the original graph without TensorRT to get the reference result. 1050 logging.info("Running original graph w/o TensorRT\n") 1051 ref_result = self._RunGraph( 1052 run_params, 1053 saved_model_dir, 1054 inputs_data, 1055 GraphState.ORIGINAL, 1056 num_runs=1) 1057 1058 # Run calibration if necessary. 1059 if IsQuantizationWithCalibration(run_params): 1060 infer_saved_model_dir = self._GetCalibratedInferGraph( 1061 run_params, saved_model_dir, inputs_data) 1062 self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir, 1063 GraphState.INFERENCE) 1064 elif not run_params.convert_online: 1065 infer_saved_model_dir = self._GetInferGraph(run_params, saved_model_dir) 1066 self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir, 1067 GraphState.INFERENCE) 1068 else: 1069 infer_saved_model_dir = saved_model_dir 1070 1071 # Run the inference graph, either using the converted graph or the 1072 # original graph with convert_online == True. 1073 logging.info("Running final inference graph\n") 1074 result = self._RunGraph(run_params, infer_saved_model_dir, inputs_data, 1075 GraphState.INFERENCE) 1076 self.assertAllClose( 1077 ref_result, 1078 result, 1079 atol=self.ExpectedAbsoluteTolerance(run_params), 1080 rtol=self.ExpectedRelativeTolerance(run_params)) 1081 1082 def testIdempotence(self): 1083 # Test that applying tensorrt optimizer or offline conversion tools multiple 1084 # times to the same graph will result in same graph. 1085 # 1086 # TODO(aaroey): implement this. 1087 pass 1088 1089 1090def _GetTestConfigsV1(): 1091 """Returns the config combinations to run the test.""" 1092 convert_online, convert_offline = True, False 1093 dynamic_engine, static_engine = True, False 1094 use_calibration, no_calibration = True, False 1095 implicit_batch = False 1096 1097 # Add all possible test cases and let the derived test class to decide 1098 # whether to run specific ones with ShouldRunTest(). 1099 # 1100 # Note: INT8 without calibration behaves like FP32/FP16. 1101 opts = list( 1102 itertools.product([FP32, FP16, INT8], [convert_online, convert_offline], 1103 [dynamic_engine, static_engine], [no_calibration], 1104 [implicit_batch])) 1105 # We always run calibration with offline tool. 1106 # TODO(aaroey): static calibration engine is not supported yet. 1107 opts.append( 1108 (INT8, convert_offline, dynamic_engine, use_calibration, implicit_batch)) 1109 return opts 1110 1111 1112def _GetTestConfigsV2(): 1113 """Returns the config combinations to run the test.""" 1114 convert_offline = False 1115 # TODO(laigd): add support for static_engine. 1116 dynamic_engine = True 1117 # TODO(laigd): add support for calibration. 1118 no_calibration = False 1119 use_calibration = True 1120 1121 # Add all possible test cases and let the derived test class to decide 1122 # whether to run specific ones with ShouldRunTest(). 1123 # 1124 # Note: 1125 # - In TF2.0 the conversion always produce dynamic engine, and we don't test 1126 # the offline mode here. 1127 # - For simplicity we don't test online conversion which requires setting the 1128 # Grappler config in default eager context. 1129 # - INT8 without calibration behaves like FP32/FP16. 1130 opts = list( 1131 itertools.product([FP32, FP16], [convert_offline], [dynamic_engine], 1132 [no_calibration], [False, True])) 1133 # We always run calibration with offline tool. 1134 opts.append((INT8, convert_offline, dynamic_engine, use_calibration, False)) 1135 opts.append((INT8, convert_offline, dynamic_engine, use_calibration, True)) 1136 return opts 1137 1138 1139def _GetTest(run_params): 1140 """Gets a single test method based on the parameters.""" 1141 1142 def _Test(self): 1143 logging.info(f"Running test `{run_params.test_name}` with parameters: " 1144 f"convert_online={run_params.convert_online}, " 1145 f"precision_mode={run_params.precision_mode}, " 1146 f"dynamic_engine={run_params.dynamic_engine}, " 1147 f"dynamic_shape={run_params.dynamic_shape}") 1148 self.RunTest(run_params) 1149 1150 return _Test 1151 1152 1153def _AddTestsFor(test_class, is_v2): 1154 """Adds test methods to TfTrtIntegrationTestBase for specific TF version.""" 1155 opts = _GetTestConfigsV2() if is_v2 else _GetTestConfigsV1() 1156 for (precision_mode, convert_online, dynamic_engine, use_calibration, 1157 dynamic_shape) in opts: 1158 conversion = "OnlineConversion" if convert_online else "OfflineConversion" 1159 engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine" 1160 calibration_type = "UseCalibration" if use_calibration else "NoCalibration" 1161 dynamic_shape_type = "DynamicShape" if dynamic_shape else "ImplicitBatch" 1162 test_name = "%s_%s_%s_%s_%s_%s" % ("testTfTrtV2" if is_v2 else "testTfTrt", 1163 conversion, engine_type, precision_mode, 1164 calibration_type, dynamic_shape_type) 1165 run_params = RunParams( 1166 convert_online=convert_online, 1167 precision_mode=precision_mode, 1168 dynamic_engine=dynamic_engine, 1169 test_name=test_name, 1170 use_calibration=use_calibration, 1171 is_v2=is_v2, 1172 dynamic_shape=dynamic_shape) 1173 if is_v2: 1174 setattr(test_class, test_name, 1175 test_util.run_v2_only(_GetTest(run_params))) 1176 else: 1177 setattr(test_class, test_name, 1178 test_util.run_v1_only("", _GetTest(run_params))) 1179 1180 1181def _AddTests(test_class): 1182 """Adds test methods to TfTrtIntegrationTestBase.""" 1183 _AddTestsFor(test_class, is_v2=False) 1184 _AddTestsFor(test_class, is_v2=True) 1185 1186 1187if is_tensorrt_enabled(): 1188 os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False" 1189 _AddTests(TfTrtIntegrationTestBase) 1190