# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to test TF-TensorRT integration."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import errno
import gc
import itertools
import os
import re
import shutil
import tempfile
import warnings

import numpy as np
import six

from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt import utils as trt_utils
from tensorflow.python.eager import def_function
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest

TfTrtIntegrationTestParams = collections.namedtuple(
    "TfTrtIntegrationTestParams",
    [
        # A function that creates the TF graph for testing.
        "graph_fn",
        # A list of specifications for input tensors.
        "input_specs",
        # A list of specifications for output tensors.
        "output_specs",
        # A list of lists of input shapes. Each shape must match the
        # corresponding element in `input_specs`.
        "input_dims",
        # A list of lists of expected output shapes. Each shape must match the
        # corresponding element in `output_specs`.
        "expected_output_dims"
    ])

RunParams = collections.namedtuple(
    "RunParams",
    [
        # Whether to run the conversion online with RewriterConfig, or offline
        # with TrtGraphConverter.
        "convert_online",
        "precision_mode",
        "dynamic_engine",
        "use_calibration",
        "test_name",
        # Is this test for TF 2.0?
        "is_v2",
    ])
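
# For illustration, the test methods generated by _AddTestsFor() below are
# parameterized with RunParams instances such as (one example combination,
# not a fixed set):
#
#   RunParams(convert_online=False, precision_mode="FP32",
#             dynamic_engine=True, use_calibration=False,
#             test_name="testTfTrt_OfflineConversion_DynamicEngine_FP32_"
#                       "NoCalibration",
#             is_v2=False)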
89 "is_v2", 90 ]) 91 92FP32 = "FP32" 93FP16 = "FP16" 94INT8 = "INT8" 95PRECISION_MODES = [FP32, FP16, INT8] 96 97 98def IsQuantizationMode(mode): 99 return mode == "INT8" 100 101 102def IsQuantizationWithCalibration(params): 103 return IsQuantizationMode(params.precision_mode) and params.use_calibration 104 105 106def IsQuantizationWithoutCalibration(params): 107 return IsQuantizationMode( 108 params.precision_mode) and not params.use_calibration 109 110 111class GraphState(object): 112 ORIGINAL = 0 113 CALIBRATE = 1 114 INFERENCE = 2 115 116 117class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): 118 """Class to test Tensorflow-TensorRT integration.""" 119 120 @property 121 def trt_incompatible_op(self): 122 return math_ops.erfc 123 124 @property 125 def precision_modes(self): 126 return ["FP32", "FP16", "INT8"] 127 128 # str is bytes in py2, but unicode in py3. 129 def _ToUnicode(self, s): 130 if six.PY2: 131 if isinstance(s, unicode): 132 return s 133 return s.decode("utf-8") 134 else: 135 if isinstance(s, str): 136 return s 137 return s.decode("utf-8") 138 139 def _ToBytes(self, s): 140 if six.PY2: 141 if isinstance(s, unicode): 142 return s.encode("utf-8") 143 return s 144 else: 145 if isinstance(s, str): 146 return s.encode("utf-8") 147 return s 148 149 def _ToString(self, s): 150 if six.PY2: 151 if isinstance(s, unicode): 152 return s.encode("utf-8") 153 return s 154 else: 155 if isinstance(s, str): 156 return s 157 return s.decode("utf-8") 158 159 def __init__(self, methodName="runTest"): # pylint: disable=invalid-name 160 super(TfTrtIntegrationTestBase, self).__init__(methodName) 161 self._trt_test_params = None 162 self._disable_non_trt_optimizers = False 163 self._use_implicit_batch = True 164 self._profile_strategy = "Unknown" 165 166 def setUp(self): 167 """Setup method.""" 168 super(TfTrtIntegrationTestBase, self).setUp() 169 warnings.simplefilter("always") 170 171 if not is_tensorrt_enabled(): 172 self.skipTest("Test requires TensorRT") 173 174 def _GetTensorSpec(self, shape, mask, dtype, name): 175 # Set dimension i to None if mask[i] == False 176 assert len(shape) == len(mask) 177 new_shape = [s if m else None for s, m in zip(shape, mask)] 178 return tensor_spec.TensorSpec(new_shape, dtype, name) 179 180 def BuildParams(self, graph_fn, dtype, input_shapes, output_shapes): 181 """Build test parameters. 182 183 The input_shapes and output_shapes arguments are known (static) shapes that 184 can be used to generate test data. To define the model, we also specify 185 corresponding input/output TensoSpecs. These are defined using the shape 186 arguments. For each input tensor we define: 187 188 input_spec = [None] + input_shape[1:] 189 190 and similarly for output shapes. This means that we leave the first (batch) 191 dimension unknown, the rest is just copied from the shapes arg. 192 193 Args: 194 graph_fn: The function to build the graph. 195 dtype: The element type. 196 input_shapes: The input shapes. 197 output_shapes: The output shapes. 198 199 Returns: 200 The test parameters. 
201 """ 202 203 input_mask = [[False] + [True] * (len(shape) - 1) for shape in input_shapes] 204 output_mask = [ 205 [False] + [True] * (len(shape) - 1) for shape in output_shapes 206 ] 207 208 return self.BuildParamsWithMask(graph_fn, dtype, input_shapes, 209 output_shapes, input_mask, output_mask, [], 210 []) 211 212 def BuildParamsWithMask(self, graph_fn, dtype, input_shapes, output_shapes, 213 input_mask, output_mask, extra_inputs, extra_outputs): 214 """Build test parameters with static or dynamic input shapes. 215 216 To define dynamic shapes give a boolean mask that describes which 217 dimensions to treat as known. The values in input_mask are interpreted the 218 following way: 219 - True: known dim (use the corresponding value from input_shapes) 220 - False: unknown dim (replace the corresponding value from input_shapes 221 with None) 222 For example, to define the first two dimension with unknown size use 223 input_shapes=[[1,2,1,8]], input_mask=[[False, False, True, True]]. 224 225 Args: 226 graph_fn: The function to build the graph. 227 dtype: The element type. 228 input_shapes: The input shapes. 229 output_shapes: The output shapes. 230 input_mask: The input shape masks. 231 output_mask: the output shape masks. 232 extra_inputs: list of additional input shapes 233 extra_outputs: list of additional outputs shapes 234 235 Returns: 236 The test parameters. 237 """ 238 239 def _ValidateShapes(shapes): 240 # Make sure all the shapes are fully specified. 241 for shape in shapes: 242 assert all(shape) 243 244 _ValidateShapes(input_shapes) 245 _ValidateShapes(output_shapes) 246 247 assert len(input_mask) == len(input_shapes) 248 assert len(output_mask) == len(output_shapes) 249 for extra_in_shape, extra_out_shape in zip(extra_inputs, extra_outputs): 250 assert len(input_shapes) == len(extra_in_shape) 251 assert len(output_shapes) == len(extra_out_shape) 252 253 return TfTrtIntegrationTestParams( 254 graph_fn=graph_fn, 255 input_specs=[ 256 self._GetTensorSpec(shape, mask, dtype, "input_%d" % i) 257 for i, (shape, mask) in enumerate(zip(input_shapes, input_mask)) 258 ], 259 output_specs=[ 260 self._GetTensorSpec(shape, mask, dtype, "output_%d" % i) 261 for i, (shape, mask) in enumerate(zip(output_shapes, output_mask)) 262 ], 263 input_dims=[input_shapes] + extra_inputs, 264 expected_output_dims=[output_shapes] + extra_outputs) 265 266 def DisableNonTrtOptimizers(self): 267 self._disable_non_trt_optimizers = True 268 269 def SetDynamicShapeModeAndProfileStrategy(self, profile_strategy="Range"): 270 self._use_implicit_batch = False 271 self._profile_strategy = profile_strategy 272 273 def GetParams(self): 274 """Returns a TfTrtIntegrationTestParams for the test.""" 275 raise NotImplementedError() 276 277 def GetConversionParams(self, run_params): 278 """Returns a TrtConversionParams for test.""" 279 conversion_params = trt_convert.TrtConversionParams( 280 # We use the minimum of all the batch sizes, so when multiple different 281 # input shapes are provided it'll always create new engines in the 282 # cache, and we can therefore test the cache behavior. 

  def GetMaxBatchSize(self, run_params):
    """Returns the max_batch_size that the converter should use for tests."""
    if run_params.dynamic_engine:
      return None
    batch_list = []
    for dims_list in self._GetParamsCached().input_dims:
      assert dims_list
      # Each list of shapes should have the same batch size.
      input_batches = [dims[0] for dims in dims_list]
      assert max(input_batches) == min(input_batches)
      batch_list.append(input_batches[0])
    return max(batch_list)

  def ShouldRunTest(self, run_params):
    """Whether to run the test."""
    # Ensure use_calibration=True in case of INT8 precision.
    return (run_params.use_calibration or not IsQuantizationMode(
        run_params.precision_mode)), "test either calibration or non-INT8"

  def ExpectedEnginesToBuild(self, run_params):
    """Returns the expected engines to build, implemented by subclass."""
    raise NotImplementedError()

  def ExpectedMaxBatchSizes(self, run_params):
    """Returns the expected maximum batch sizes of the built engines."""
    return None

  def ExpectedAbsoluteTolerance(self, run_params):
    """The absolute tolerance to compare floating point results."""
    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02

  def ExpectedRelativeTolerance(self, run_params):
    """The relative tolerance to compare floating point results."""
    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02

  def _GetParamsCached(self):
    if self._trt_test_params is None:
      self._trt_test_params = self.GetParams()
    return self._trt_test_params

  def _GetGPUOptions(self):
    gpu_options = config_pb2.GPUOptions()
    gpu_options.allow_growth = True
    return gpu_options

  def _GetConfigProto(self, run_params, graph_state):
    """Get config proto based on specific settings."""
    conversion_params = self.GetConversionParams(run_params)
    max_batch_size = self.GetMaxBatchSize(run_params)

    if graph_state == GraphState.INFERENCE and run_params.convert_online:
      rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
          conversion_params,
          is_dynamic_op=run_params.dynamic_engine,
          max_batch_size=max_batch_size,
          disable_non_trt_optimizers=self._disable_non_trt_optimizers)
    else:
      rewriter_cfg = rewriter_config_pb2.RewriterConfig()
      if self._disable_non_trt_optimizers:
        trt_utils.disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg)

    config = config_pb2.ConfigProto(
        gpu_options=self._GetGPUOptions(),
        graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))
    return config

  def _GetFeedNames(self):
    params = self._GetParamsCached()
    # Construct the feed tensor names by appending :0 to the node names.
    return [spec.name + ":0" for spec in params.input_specs]

  def _GetFetchNames(self):
    params = self._GetParamsCached()
    # Construct the fetch tensor names by appending :0 to the node names.
    return [spec.name + ":0" for spec in params.output_specs]

  def _GetFeedDict(self, inputs_data):
    return {name: data for name, data in zip(self._GetFeedNames(), inputs_data)}
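
  # Since _GetTensorSpec names the inputs "input_0", "input_1", ... and the
  # outputs "output_0", "output_1", ..., the feed and fetch names produced
  # above look like:
  #
  #   self._GetFeedNames()   # e.g. ["input_0:0", "input_1:0"]
  #   self._GetFetchNames()  # e.g. ["output_0:0"]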

  def _RunGraphV1(self, saved_model_dir, inputs_data, config, num_runs=2):
    """Run the given graphdef multiple times using the TF 1.x runtime."""
    params = self._GetParamsCached()
    fetches = self._GetFetchNames()
    g = ops.Graph()
    with g.as_default():
      with self.session(graph=g, config=config, use_gpu=True) as sess:
        loader.load(sess, [tag_constants.SERVING], saved_model_dir)
        vals = []
        # Run once for each list of input shapes.
        for expected_shapes, current_input_data in zip(
            params.expected_output_dims, inputs_data):
          val = None
          for _ in range(num_runs):
            new_val = sess.run(fetches, self._GetFeedDict(current_input_data))
            self.assertEqual(len(expected_shapes), len(new_val))
            for expected_shape, actual_val in zip(expected_shapes, new_val):
              self.assertEqual(list(expected_shape), list(actual_val.shape))
            if val is not None:
              # Some ops may have nondeterministic output. E.g. Conv2D may use
              # the winograd algorithm. So we set atol/rtol to be larger than
              # 1.e-06.
              self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05)
            val = new_val
          vals.append(val)
        return vals

  def _RunGraphV2(self, saved_model_dir, inputs_data, graph_state, num_runs=2):
    """Run the given graphdef multiple times using the TF 2.0 runtime."""
    params = self._GetParamsCached()
    root = load.load(saved_model_dir)
    func = root.signatures[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    results = []
    for expected_shapes, current_input_data in zip(params.expected_output_dims,
                                                   inputs_data):
      val = None
      for _ in range(num_runs):
        feed_dict = {
            params.input_specs[i].name: current_input_data[i]
            for i in range(len(params.input_specs))
        }
        new_val = func(**feed_dict)
        assert isinstance(new_val, dict)
        # The keys of the output map always look like output_i.
        new_val = [new_val[key] for key in sorted(new_val)]
        # Each element is an eager Tensor, and accessing individual elements is
        # very expensive, so we convert them to a numpy array first.
        new_val = [v.numpy() for v in new_val]
        self.assertEqual(len(expected_shapes), len(new_val))
        for expected_shape, actual_val in zip(expected_shapes, new_val):
          self.assertEqual(list(expected_shape), list(actual_val.shape))
        if val is not None:
          # Some ops may have nondeterministic output. E.g. Conv2D may use
          # the winograd algorithm. So we set atol/rtol to be larger than
          # 1.e-06.
          self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05)
        val = new_val
      results.append(val)

    return results

  def _RunGraph(self,
                run_params,
                saved_model_dir,
                inputs_data,
                graph_state,
                num_runs=2):
    params = self._GetParamsCached()
    for data in inputs_data:
      assert len(params.input_specs) == len(data)

    if run_params.is_v2:
      results = self._RunGraphV2(saved_model_dir, inputs_data, graph_state,
                                 num_runs)
      gc.collect()  # Force GC to destroy the TRT engine cache.
      return results

    # The default config for tf.session is None. Create a config with
    # TensorRTOptimizer enabled to support convert_online for inference.
    config = None
    # TODO(b/170220818): use the default session config to run inference
    # graphs for the offline conversion case after fixing the bug.
    if graph_state == GraphState.INFERENCE:
      config = self._GetConfigProto(run_params, GraphState.INFERENCE)
    return self._RunGraphV1(saved_model_dir, inputs_data, config, num_runs)
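
  # The inputs_data argument above is a list with one entry per element of
  # params.input_dims, and each entry holds one numpy array (or eager tensor
  # in TF 2.x) per input spec. For instance, with two inputs and a single
  # list of shapes it might look like:
  #
  #   inputs_data = [[np.ones([13, 24, 24, 2], np.float32),
  #                   np.ones([13, 24], np.float32)]]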

  def _CreateConverter(self, run_params, saved_model_dir, conversion_params):
    """Returns a TrtGraphConverter."""
    if run_params.is_v2:
      converter_v2 = trt_convert.TrtGraphConverterV2(
          input_saved_model_dir=saved_model_dir,
          conversion_params=conversion_params,
          use_dynamic_shape=not self._use_implicit_batch,
          dynamic_shape_profile_strategy=self._profile_strategy)
      if self._disable_non_trt_optimizers:
        converter_v2._test_only_disable_non_trt_optimizers = True  # pylint: disable=protected-access
      return converter_v2

    converter_v1 = trt_convert.TrtGraphConverter(
        input_saved_model_dir=saved_model_dir,
        max_batch_size=self.GetMaxBatchSize(run_params),
        max_workspace_size_bytes=conversion_params.max_workspace_size_bytes,
        precision_mode=conversion_params.precision_mode,
        minimum_segment_size=conversion_params.minimum_segment_size,
        is_dynamic_op=run_params.dynamic_engine,
        maximum_cached_engines=conversion_params.maximum_cached_engines,
        use_calibration=conversion_params.use_calibration)
    if self._disable_non_trt_optimizers:
      converter_v1._test_only_disable_non_trt_optimizers = True  # pylint: disable=protected-access
    return converter_v1

  def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data):
    """Return the TRT converted graphdef in INT8 mode."""
    conversion_params = self.GetConversionParams(run_params)
    logging.info(conversion_params)
    assert conversion_params.precision_mode == "INT8"
    assert run_params.dynamic_engine
    assert conversion_params.maximum_cached_engines == 1
    assert conversion_params.use_calibration

    # We only support calibrating a single engine.
    # TODO(aaroey): fix this.
    assert len(inputs_data) == 1

    converter = self._CreateConverter(run_params, saved_model_dir,
                                      conversion_params)
    int8_gdef = converter.convert()
    self._VerifyGraphDef(run_params, saved_model_dir, int8_gdef,
                         GraphState.CALIBRATE)

    converter.calibrate(
        fetch_names=self._GetFetchNames(),
        num_runs=5,
        feed_dict_fn=lambda: self._GetFeedDict(inputs_data[0]))
    trt_saved_model_dir = self._GetSavedModelDir(run_params,
                                                 GraphState.CALIBRATE)
    converter.save(trt_saved_model_dir)
    return trt_saved_model_dir

  def _ShouldConverterBuild(self, run_params):
    return True

  def _GetInferGraph(self, run_params, saved_model_dir):
    """Return the TRT converted graphdef."""
    conversion_params = self.GetConversionParams(run_params)
    logging.info(conversion_params)

    converter = self._CreateConverter(run_params, saved_model_dir,
                                      conversion_params)
    converter.convert()

    if not self._use_implicit_batch and self._ShouldConverterBuild(run_params):
      logging.info("Using build mode")

      def _BuildInputFn():
        for shapes in self._GetParamsCached().input_dims:
          yield [np.zeros(x).astype(np.float32) for x in shapes]

      converter.build(input_fn=_BuildInputFn)

    trt_saved_model_dir = self._GetSavedModelDir(run_params,
                                                 GraphState.INFERENCE)
    converter.save(trt_saved_model_dir)
    return trt_saved_model_dir

  def _GetGraphStateLabel(self, graph_state):
    if graph_state == GraphState.ORIGINAL:
      return "Original"
    elif graph_state == GraphState.CALIBRATE:
      return "CalibEngine"
    elif graph_state == GraphState.INFERENCE:
      return "InferEngine"
    else:
      return "UnknownState"

  def _WriteGraph(self, run_params, gdef, graph_state):
    temp_dir = os.getenv("TRT_TEST_TMPDIR")
    if not temp_dir:
      return

    graph_name = (
        self.__class__.__name__ + "_" + run_params.test_name + "_" +
        self._GetGraphStateLabel(graph_state) + ".pbtxt")
    logging.info("Writing graph to %s/%s", temp_dir, graph_name)
    graph_io.write_graph(gdef, temp_dir, graph_name)

  # Removes the prefix(es) of the function name(s).
  # The input value can be a string or a sequence of strings.
  def _Canonicalize(self, value):
    if isinstance(value, str):
      return self._ToString(value.split("/")[-1])
    elif isinstance(value, collections.abc.Iterable):
      return set(self._Canonicalize(nm) for nm in value)
    else:
      raise TypeError(
          "'_Canonicalize' can only be used on strings or sequences of "
          "strings!")

  # Removes the graph sequence number prefix from the name(s) only if the
  # name(s) have a prefix TRTEngineOp_n_. When expecting_prefix is true,
  # asserts that such a prefix exists.
  # The input value can be a string or a sequence of strings.
  def _RemoveGraphSequenceNumberImpl(self, value, expecting_prefix):
    if isinstance(value, str):
      match = re.search(r"TRTEngineOp_\d+_", value)
      has_prefix = match and value.startswith(match.group(0))
      assert (not expecting_prefix) or has_prefix
      if has_prefix:
        parts = value.split("_", maxsplit=2)
        assert len(parts) == 3
        return parts[0] + "_" + parts[2]
      return value
    elif isinstance(value, collections.abc.Iterable):
      return set(
          self._RemoveGraphSequenceNumberImpl(nm, expecting_prefix)
          for nm in value)
    else:
      raise TypeError(
          "'_RemoveGraphSequenceNumberImpl' can only be used on strings "
          "or sequences of strings!")

  def _RemoveGraphSequenceNumber(self, name):
    return self._RemoveGraphSequenceNumberImpl(name, True)

  def _MayRemoveGraphSequenceNumber(self, name):
    return self._RemoveGraphSequenceNumberImpl(name, False)
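
  # For example, for a TF2-style engine name carrying a graph sequence
  # number, such as "import/TRTEngineOp_0_1":
  #
  #   self._Canonicalize("import/TRTEngineOp_0_1")        # "TRTEngineOp_0_1"
  #   self._RemoveGraphSequenceNumber("TRTEngineOp_0_1")  # "TRTEngineOp_1"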

  def _VerifyConnections(self, expected_engines, original_gdef, converted_gdef):
    old_to_new_node_map = {
        self._ToString(node.name): self._ToString(node.name)
        for node in original_gdef.node
    }
    for engine_name, node_names in expected_engines.items():
      for node_name in node_names:
        old_to_new_node_map[node_name] = engine_name
    name_to_node_map = {
        self._ToString(node.name): node for node in original_gdef.node
    }

    def _InputName(inp):
      inp = self._ToString(inp)
      prefix = ""
      if inp[0] == "^":
        prefix = "^"
        inp = inp[1:]
      parts = inp.split(":")
      if len(parts) > 1 and parts[-1].isdigit():
        inp = inp[:-len(parts[-1]) - 1]
      return (prefix, inp)

    # Compute the expected mapping from each node to its input nodes.
    expected_input_map = {}
    removed_const_nodes = set([
        self._ToString(node.name)
        for node in original_gdef.node
        if node.op == "Const"
    ])
    for node in original_gdef.node:
      name_str = self._ToString(node.name)
      target_node_name = old_to_new_node_map[name_str]
      is_engine_op = (target_node_name != name_str)
      if target_node_name not in expected_input_map:
        expected_input_map[target_node_name] = set()
      input_set = expected_input_map[target_node_name]
      for inp in node.input:
        (prefix, inp_name) = _InputName(inp)
        mapped_input = old_to_new_node_map[inp_name]
        # Add the input only if it's outside the segment (note that it could
        # be in a different engine).
        if not is_engine_op or (mapped_input != target_node_name and
                                name_to_node_map[inp_name].op != "Const"):
          input_set.add(prefix + mapped_input)
          if mapped_input in removed_const_nodes:
            removed_const_nodes.remove(mapped_input)
    # Remove const nodes that have no outputs.
    expected_input_map = {
        k: v
        for k, v in expected_input_map.items()
        if k not in removed_const_nodes
    }

    # Compute the actual mapping from each node to its input nodes. If a cast
    # op doesn't exist in the original graph, we replace the use of the cast
    # op with the input of the op. This allows the verification to handle the
    # case where the TF-TRT bridge splits a cast op into a chain of two cast
    # ops.
    new_cast_op_name_to_node_map = {
        node.name: node
        for node in converted_gdef.node
        if (node.name not in old_to_new_node_map and node.op == "Cast")
    }
    actual_input_map = {}
    for node in converted_gdef.node:
      name_str = node.name
      # Only nodes from the original graph or TRTEngineOp nodes are added as
      # keys to the map.
      if node.op == "TRTEngineOp":
        name_str = self._RemoveGraphSequenceNumber(name_str)
      elif name_str not in old_to_new_node_map:
        continue
      actual_input_map[name_str] = set()
      input_set = actual_input_map[name_str]
      for inp in node.input:
        (prefix, node_name) = _InputName(inp)
        node_name = self._MayRemoveGraphSequenceNumber(node_name)
        if node_name in new_cast_op_name_to_node_map:
          (prefix, node_name) = _InputName(
              new_cast_op_name_to_node_map[node_name].input[0])
        input_set.add(prefix + node_name)

    self.assertEqual(
        expected_input_map,
        actual_input_map,
        msg="\nexpected:\n%s\nvs actual:\n%s" %
        (sorted(expected_input_map.items()), sorted(actual_input_map.items())))
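
  # When ExpectedEnginesToBuild returns a dict, it maps each expected engine
  # name to the original-graph node names that engine should absorb, e.g.
  # (hypothetical node names):
  #
  #   {"TRTEngineOp_0": ["weights", "conv", "bias_add", "relu"]}
  #
  # and _VerifyConnections above checks the rewired graph edges against it.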
662 if node.op == "TRTEngineOp": 663 name_str = self._RemoveGraphSequenceNumber(name_str) 664 elif name_str not in old_to_new_node_map: 665 continue 666 actual_input_map[name_str] = set() 667 input_set = actual_input_map[name_str] 668 for inp in node.input: 669 (prefix, node_name) = _InputName(inp) 670 node_name = self._MayRemoveGraphSequenceNumber(node_name) 671 if node_name in new_cast_op_name_to_node_map: 672 (prefix, node_name) = _InputName( 673 new_cast_op_name_to_node_map[node_name].input[0]) 674 input_set.add(prefix + node_name) 675 676 self.assertEqual( 677 expected_input_map, 678 actual_input_map, 679 msg="\nexpected:\n%s\nvs actual:\n%s" % 680 (sorted(expected_input_map.items()), sorted(actual_input_map.items()))) 681 682 def _VerifyMaxBatchSizeAnnotations( 683 self, 684 expected_engines, 685 original_gdef, 686 converted_gdef, 687 default_max_batch_size, 688 expected_max_batch_sizes=None, 689 ): 690 """Verifies the max batch size annotations in the original and converted GraphDef. 691 692 Args: 693 expected_engines: A sequence of engines names. 694 original_gdef: GraphDef. The graph def before TensorRT conversion. 695 converted_gdef: GraphDef. The graph def after TensorRT conversion. 696 default_max_batch_size: The default maximum batch size to use if no node 697 inside a segment is annoted with a customized max batch size. This value 698 is None when the graph is converted to TF-TRT with dynamic engines. 699 expected_max_batch_sizes: Optional. A sequence of max batch sizes for all 700 the engines. `None` if does not check enforce max batch sizes. 701 """ 702 if isinstance(expected_max_batch_sizes, collections.abc.Collection): 703 self.assertEqual(len(expected_max_batch_sizes), len(expected_engines)) 704 else: 705 self.assertIsNone( 706 expected_max_batch_sizes, 707 "'expected_max_batch_sizes' shall only be a sequence " 708 "of integers or `None`.") 709 710 def _ChainAllNodes(graph_def): 711 return itertools.chain( 712 graph_def.node, 713 itertools.chain( 714 *[func.node_def for func in graph_def.library.function])) 715 716 old_name_to_node_map = { 717 self._ToString(node.name): node 718 for node in _ChainAllNodes(original_gdef) 719 } 720 new_name_to_func_map = { 721 self._ToString(func.signature.name): func 722 for func in converted_gdef.library.function 723 } 724 725 def _DetectStaticBatchSize(node_def): 726 """Returns the static batch size of an operation or None. 727 728 It is incorrect to use the output shapes to find the batch size of an 729 operation, as the segmenter actually uses the input shapes. However, it is 730 a simplication and works for most of the cases for the test purposes. 731 732 Args: 733 node_def: `tf.NodeDef`. The target node for analysis. 734 735 Returns: 736 If all the outputs of the node have the same static batch size, returns 737 the int value for the batch size. Otherwise returns None. 
738 """ 739 shapes = node_def.attr["_output_shapes"].list.shape 740 batch_size = set( 741 list(s.dim)[0].size if len(s.dim) >= 2 else None for s in shapes) 742 if len(batch_size) == 1 and list(batch_size)[0] >= 1: 743 return list(batch_size)[0] 744 return None 745 746 name_to_engines_map = {} 747 actual_max_batch_sizes = [] 748 for node in _ChainAllNodes(converted_gdef): 749 if node.op == "TRTEngineOp": 750 engine = node 751 engine_name = self._RemoveGraphSequenceNumber( 752 self._Canonicalize(self._ToString(engine.name))) 753 self.assertIn(engine_name, expected_engines) 754 name_to_engines_map[engine_name] = engine 755 # The input nodes shall not have the conflicting annotation (no 756 # annotation or the same annotation) with the maximum batch size 757 # annotation. If the engine has maximum batch size annotation as the 758 # non-default maximum batch size, then at least one input node shall 759 # have the same annotation to be the source. 760 self.assertIn("max_batch_size", node.attr) 761 engine_max_batch_size = node.attr["max_batch_size"].i 762 self.assertIsInstance(engine_max_batch_size, int) 763 actual_max_batch_sizes.append(engine_max_batch_size) 764 seg_func = node.attr["segment_func"].func 765 self.assertIsNotNone(seg_func) 766 self.assertIn(seg_func.name, new_name_to_func_map) 767 seg_func_def = new_name_to_func_map[seg_func.name] 768 logging.info("Segment function name: %s. Including %d nodes.", 769 seg_func.name, len(seg_func_def.node_def)) 770 node_max_batch_size_all_none = True 771 # Use the native segment to search for replaced nodes 772 for alternative_node in seg_func_def.node_def: 773 node_name = self._Canonicalize(self._ToString(alternative_node.name)) 774 if node_name not in old_name_to_node_map: 775 continue 776 original_node = old_name_to_node_map[node_name] 777 node_max_batch_size = None 778 if "_tftrt_op_max_batch_size" in original_node.attr: 779 node_max_batch_size = original_node.attr[ 780 "_tftrt_op_max_batch_size"].i 781 elif (original_node.op != "Const" and 782 alternative_node.op != "Const" and 783 "_output_shapes" in original_node.attr): 784 node_max_batch_size = _DetectStaticBatchSize(original_node) 785 logging.info( 786 "'{%s}(%s)'s max batch size annotation is %s. " 787 "'{%s}'s max batch size is %s.", node_name, original_node.op, 788 str(node_max_batch_size), engine_name, str(engine_max_batch_size)) 789 node_max_batch_size_all_none &= node_max_batch_size is None 790 self.assertTrue(engine_max_batch_size == node_max_batch_size or 791 node_max_batch_size is None) 792 logging.info("'{%s}'s max batch size is %d.", engine_name, 793 engine_max_batch_size) 794 self.assertTrue(default_max_batch_size is None or 795 engine_max_batch_size == default_max_batch_size or 796 not node_max_batch_size_all_none) 797 798 self.assertCountEqual(expected_engines, tuple(name_to_engines_map.keys())) 799 if expected_max_batch_sizes is not None: 800 self.assertCountEqual(expected_max_batch_sizes, actual_max_batch_sizes) 801 802 def _GetGraphDef(self, run_params, gdef_or_saved_model_dir): 803 if isinstance(gdef_or_saved_model_dir, str): 804 if run_params.is_v2: 805 root = load.load(gdef_or_saved_model_dir) 806 func = root.signatures[ 807 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 808 gdef = func.graph.as_graph_def() 809 # Manually unref the loaded saved model and force GC to destroy the TRT 810 # engine cache after load(). There is currently a reference cycle in 2.0 811 # which prevents auto deletion of the resource. 812 # TODO(laigd): fix this. 
        del func
        del root
        gc.collect()
        return gdef
      return saved_model_utils.get_meta_graph_def(
          gdef_or_saved_model_dir, tag_constants.SERVING).graph_def
    assert isinstance(gdef_or_saved_model_dir, graph_pb2.GraphDef)
    return gdef_or_saved_model_dir

  def _VerifyGraphDefV1(self, run_params, original_gdef, gdef_to_verify,
                        graph_state):
    expected_engines = self.ExpectedEnginesToBuild(run_params)
    num_engines = 0
    functions = [f.signature.name for f in gdef_to_verify.library.function]
    for node in gdef_to_verify.node:
      if node.op == "TRTEngineOp":
        logging.info("Found TRTEngineOp: " + node.name)
        num_engines += 1
        segment_funcdef_name = node.attr["segment_func"].func.name
        function_name = node.name + "_native_segment"
        is_dynamic_engine = not node.attr["static_engine"].b
        self.assertNotEmpty(segment_funcdef_name, node.name)
        self.assertIn(function_name, functions)
        if (not IsQuantizationWithCalibration(run_params) and
            not is_dynamic_engine):
          self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
        self.assertIn(
            self._RemoveGraphSequenceNumber(node.name), expected_engines)
        if IsQuantizationWithoutCalibration(run_params):
          # TODO(bixia): Refine this check by inspecting nodes in the engine.
          if self._ToBytes("INT8") != node.attr["precision_mode"].s:
            self.assertEqual(
                self._ToBytes("FP16"), node.attr["precision_mode"].s,
                node.name)
        else:
          self.assertEqual(
              self._ToBytes(run_params.precision_mode),
              node.attr["precision_mode"].s, node.name)

        self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
                         node.name)
        self.assertEqual(node.attr["use_calibration"].b,
                         run_params.use_calibration, node.name)

        has_calibration_data = len(node.attr["calibration_data"].s)
        if (IsQuantizationWithCalibration(run_params) and
            graph_state == GraphState.INFERENCE):
          self.assertTrue(has_calibration_data, node.name)
        else:
          self.assertFalse(has_calibration_data, node.name)
    if graph_state == GraphState.ORIGINAL:
      self.assertEqual(0, num_engines)
    else:
      self.assertEqual(num_engines, len(expected_engines))
      if isinstance(expected_engines, dict):
        self._VerifyConnections(expected_engines, original_gdef,
                                gdef_to_verify)
      self._VerifyMaxBatchSizeAnnotations(
          expected_engines=expected_engines,
          original_gdef=original_gdef,
          converted_gdef=gdef_to_verify,
          expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params),
          default_max_batch_size=self.GetMaxBatchSize(run_params))

  def _VerifyGraphDefV2(self, run_params, original_gdef, gdef_to_verify,
                        graph_state):
    if graph_state == GraphState.ORIGINAL:
      return
    expected_engines = self.ExpectedEnginesToBuild(run_params)
    all_op_names = [node.name for node in gdef_to_verify.node]
    trt_op_names = []
    for func in gdef_to_verify.library.function:
      if not re.search(r"TRTEngineOp_\d+_\d+_native_segment",
                       func.signature.name):
        for node in func.node_def:
          all_op_names.append(node.name)
          if node.op == "TRTEngineOp":
            trt_op_names.append(node.name)
            if not self._use_implicit_batch:
              self.assertEqual(
                  self._ToString(node.attr["profile_strategy"].s).lower(),
                  self._profile_strategy.lower())

    all_op_names = self._Canonicalize(all_op_names)
    trt_op_names = self._RemoveGraphSequenceNumber(
        self._Canonicalize(trt_op_names))

    if isinstance(expected_engines, dict):
      # For simplicity we don't verify the connections inside the engine in
      # 2.0, but we still make sure that the converted ops are gone from the
      # graph.
      unexpected_names = set(nest.flatten(expected_engines.values()))
      self.assertEmpty(
          [name for name in unexpected_names if name in all_op_names])
      expected_engines = set(expected_engines.keys())

    self.assertEqual(set(expected_engines), trt_op_names)

  def _VerifyGraphDef(self, run_params, original_gdef_or_saved_model_dir,
                      gdef_or_saved_model_dir_to_verify, graph_state):
    original_gdef = self._GetGraphDef(run_params,
                                      original_gdef_or_saved_model_dir)
    gdef_to_verify = self._GetGraphDef(run_params,
                                       gdef_or_saved_model_dir_to_verify)
    self._WriteGraph(run_params, gdef_to_verify, graph_state)
    if run_params.is_v2:
      self._VerifyGraphDefV2(run_params, original_gdef, gdef_to_verify,
                             graph_state)
    else:
      self._VerifyGraphDefV1(run_params, original_gdef, gdef_to_verify,
                             graph_state)

  def _GetSavedModelDir(self, run_params, graph_state):
    test_tmpdir = os.getenv("TRT_TEST_TMPDIR")
    if test_tmpdir:
      saved_model_dir = os.path.join(
          test_tmpdir, self.__class__.__name__ + "_" + run_params.test_name +
          "_" + self._GetGraphStateLabel(graph_state))
      try:
        # For TF 1.x we need to make sure the output directory doesn't exist
        # before exporting the saved model.
        shutil.rmtree(saved_model_dir)
      except OSError as e:
        if e.errno != errno.ENOENT:
          raise
      return saved_model_dir
    return tempfile.mkdtemp(dir=self.get_temp_dir())

  def _MakeSavedModelV1(self, run_params):
    """Write the saved model as an input for testing."""
    params = self._GetParamsCached()
    g = ops.Graph()
    with g.as_default():
      inputs = []
      for spec in params.input_specs:
        inp = array_ops.placeholder(
            dtype=spec.dtype, shape=spec.shape, name=spec.name)
        inputs.append(inp)
      outputs = params.graph_fn(*inputs)
      if not isinstance(outputs, list) and not isinstance(outputs, tuple):
        outputs = [outputs]

    signature_def = signature_def_utils.build_signature_def(
        inputs={inp.op.name: utils.build_tensor_info(inp) for inp in inputs},
        outputs={out.op.name: utils.build_tensor_info(out) for out in outputs},
        method_name=signature_constants.PREDICT_METHOD_NAME)

    saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL)
    saved_model_builder = builder.SavedModelBuilder(saved_model_dir)
    with self.session(
        graph=g, config=self._GetConfigProto(run_params,
                                             GraphState.ORIGINAL)) as sess:
      saved_model_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING],
          signature_def_map={
              signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                  signature_def
          })
    saved_model_builder.save()
    return saved_model_dir

  def _MakeSavedModelV2(self, run_params):
    params = self._GetParamsCached()
    root = tracking.AutoTrackable()
    root.run = def_function.function(
        params.graph_fn, input_signature=params.input_specs)
    saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL)
    logging.info("Saving input SavedModel to %s", saved_model_dir)
    save.save(root, saved_model_dir,
              {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: root.run})
    return saved_model_dir

  def _MakeSavedModel(self, run_params):
    if run_params.is_v2:
      return self._MakeSavedModelV2(run_params)
    return self._MakeSavedModelV1(run_params)

  def RunTest(self, run_params):
    with trace.Trace(run_params.test_name):
      should_run, reason_for_skipping = self.ShouldRunTest(run_params)
      if not should_run:
        return self.skipTest(reason_for_skipping)

      saved_model_dir = self._MakeSavedModel(run_params)

      np.random.seed(12345)  # Fix the seed so the test is deterministic.
      inputs_data = []
      input_specs = self._GetParamsCached().input_specs
      for dim_list in self._GetParamsCached().input_dims:
        assert len(input_specs) == len(dim_list)
        current_input_data = []
        for spec, np_shape in zip(input_specs, dim_list):
          np_dtype = spec.dtype.as_numpy_dtype()
          # Multiply the input by some constant to avoid all zeros input for
          # integer types.
          scale = 10.0 if np.issubdtype(np_dtype, np.integer) else 1.0
          # TODO(laigd): add debug options. E.g. we can set the input data to
          # be continuous natural numbers:
          # seq = np.arange(np.prod(np_shape))
          # seq.resize(np_shape)
          # current_input_data.append(scale * seq.astype(np_dtype))
          data = (scale * np.random.random_sample(np_shape)).astype(np_dtype)
          if run_params.is_v2:
            with ops.device("/GPU:0"):
              data = ops.convert_to_tensor(data)
          current_input_data.append(data)
        inputs_data.append(current_input_data)

      # Verify the original graph.
      self._VerifyGraphDef(run_params, saved_model_dir, saved_model_dir,
                           GraphState.ORIGINAL)

      # Run the original graph without TensorRT to get the reference result.
      logging.info("Running original graph w/o TensorRT\n")
      ref_result = self._RunGraph(
          run_params,
          saved_model_dir,
          inputs_data,
          GraphState.ORIGINAL,
          num_runs=1)

      # Run calibration if necessary.
      if IsQuantizationWithCalibration(run_params):
        infer_saved_model_dir = self._GetCalibratedInferGraph(
            run_params, saved_model_dir, inputs_data)
        self._VerifyGraphDef(run_params, saved_model_dir,
                             infer_saved_model_dir, GraphState.INFERENCE)
      elif not run_params.convert_online:
        infer_saved_model_dir = self._GetInferGraph(run_params,
                                                    saved_model_dir)
        self._VerifyGraphDef(run_params, saved_model_dir,
                             infer_saved_model_dir, GraphState.INFERENCE)
      else:
        infer_saved_model_dir = saved_model_dir

      # Run the inference graph, either using the converted graph or the
      # original graph with convert_online == True.
      logging.info("Running final inference graph\n")
      result = self._RunGraph(run_params, infer_saved_model_dir, inputs_data,
                              GraphState.INFERENCE)
      self.assertAllClose(
          ref_result,
          result,
          atol=self.ExpectedAbsoluteTolerance(run_params),
          rtol=self.ExpectedRelativeTolerance(run_params))

  def testIdempotence(self):
    # Test that applying the tensorrt optimizer or offline conversion tools
    # multiple times to the same graph results in the same graph.
    #
    # TODO(aaroey): implement this.
    pass


def _GetTestConfigsV1():
  """Returns the config combinations to run the test."""
  convert_online, convert_offline = True, False
  dynamic_engine, static_engine = True, False
  use_calibration, no_calibration = True, False

  # Add all possible test cases and let the derived test class decide
  # whether to run specific ones with ShouldRunTest().
  #
  # Note: INT8 without calibration behaves like FP32/FP16.
  opts = list(
      itertools.product([FP32, FP16, INT8], [convert_online, convert_offline],
                        [dynamic_engine, static_engine], [no_calibration]))
  # We always run calibration with the offline tool.
  # TODO(aaroey): static calibration engine is not supported yet.
  opts.append((INT8, convert_offline, dynamic_engine, use_calibration))
  return opts
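

# For example, _GetTestConfigsV1() yields the 3 (precision) x 2 (online/
# offline) x 2 (dynamic/static) = 12 non-calibrated combinations, plus the
# single (INT8, offline, dynamic, calibrated) entry: 13 configs in total.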


def _GetTestConfigsV2():
  """Returns the config combinations to run the test."""
  convert_offline = False
  # TODO(laigd): add support for static_engine.
  dynamic_engine = True
  # TODO(laigd): add support for calibration.
  no_calibration = False

  # Add all possible test cases and let the derived test class decide
  # whether to run specific ones with ShouldRunTest().
  #
  # Note:
  # - In TF2.0 the conversion always produces dynamic engines, and we don't
  #   test the offline mode here.
  # - For simplicity we don't test online conversion which requires setting
  #   the Grappler config in the default eager context.
  # - INT8 without calibration behaves like FP32/FP16.
  opts = list(
      itertools.product([FP32, FP16, INT8], [convert_offline],
                        [dynamic_engine], [no_calibration]))
  # We always run calibration with the offline tool.
  # TODO(aaroey): INT8+calibration is not supported yet in V2.
  # opts.append((INT8, convert_offline, dynamic_engine, use_calibration))
  return opts


def _GetTest(run_params):
  """Gets a single test method based on the parameters."""

  def _Test(self):
    logging.info(
        "Running test %s with parameters: convert_online=%s, "
        "precision_mode=%s, dynamic_engine=%s", run_params.test_name,
        run_params.convert_online, run_params.precision_mode,
        run_params.dynamic_engine)
    self.RunTest(run_params)

  return _Test


def _AddTestsFor(test_class, is_v2):
  """Adds test methods to TfTrtIntegrationTestBase for the given TF version."""
  opts = _GetTestConfigsV2() if is_v2 else _GetTestConfigsV1()
  for (precision_mode, convert_online, dynamic_engine, use_calibration) in opts:
    conversion = "OnlineConversion" if convert_online else "OfflineConversion"
    engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine"
    calibration_type = "UseCalibration" if use_calibration else "NoCalibration"
    test_name = "%s_%s_%s_%s_%s" % ("testTfTrtV2" if is_v2 else "testTfTrt",
                                    conversion, engine_type, precision_mode,
                                    calibration_type)
    run_params = RunParams(
        convert_online=convert_online,
        precision_mode=precision_mode,
        dynamic_engine=dynamic_engine,
        test_name=test_name,
        use_calibration=use_calibration,
        is_v2=is_v2)
    if is_v2:
      setattr(test_class, test_name,
              test_util.run_v2_only(_GetTest(run_params)))
    else:
      setattr(test_class, test_name,
              test_util.run_v1_only("", _GetTest(run_params)))


def _AddTests(test_class):
  """Adds test methods to TfTrtIntegrationTestBase."""
  _AddTestsFor(test_class, is_v2=False)
  _AddTestsFor(test_class, is_v2=True)


if is_tensorrt_enabled():
  os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
  _AddTests(TfTrtIntegrationTestBase)