1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities to test TF-TensorRT integration.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import errno 23import gc 24import itertools 25import os 26import re 27import shutil 28import tempfile 29import warnings 30 31import numpy as np 32import six 33 34from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import get_linked_tensorrt_version 35from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled 36from tensorflow.core.framework import graph_pb2 37from tensorflow.core.protobuf import config_pb2 38from tensorflow.core.protobuf import rewriter_config_pb2 39from tensorflow.python.compiler.tensorrt import trt_convert 40from tensorflow.python.compiler.tensorrt import utils as trt_utils 41from tensorflow.python.eager import def_function 42from tensorflow.python.framework import graph_io 43from tensorflow.python.framework import ops 44from tensorflow.python.framework import tensor_spec 45from tensorflow.python.framework import test_util 46from tensorflow.python.ops import array_ops 47from tensorflow.python.ops import math_ops 48from tensorflow.python.platform import tf_logging as logging 49from tensorflow.python.profiler import trace 50from tensorflow.python.saved_model import builder 51from tensorflow.python.saved_model import load 52from tensorflow.python.saved_model import loader 53from tensorflow.python.saved_model import save 54from tensorflow.python.saved_model import signature_constants 55from tensorflow.python.saved_model import signature_def_utils 56from tensorflow.python.saved_model import tag_constants 57from tensorflow.python.saved_model import utils 58from tensorflow.python.tools import saved_model_utils 59from tensorflow.python.training.tracking import tracking 60from tensorflow.python.util import nest 61 62TfTrtIntegrationTestParams = collections.namedtuple( 63 "TfTrtIntegrationTestParams", 64 [ 65 # A function that creates the TF graph for testing. 66 "graph_fn", 67 # A list of specifications for input tensors. 68 "input_specs", 69 # A list of specifications for output tensors. 70 "output_specs", 71 # A list of list of input shapes. Each shape must match the 72 # corresponding element in `input_specs`. 73 "input_dims", 74 # A list of list of expected output shapes. Each shape must match the 75 # corresponding element in `output_specs`. 76 "expected_output_dims" 77 ]) 78 79RunParams = collections.namedtuple( 80 "RunParams", 81 [ 82 # Whether to run the conversion online with RewriterConfig, or offline 83 # with TrtGraphConverter. 84 "convert_online", 85 "precision_mode", 86 "dynamic_engine", 87 "use_calibration", 88 "test_name", 89 # Is this test for TF 2.0? 90 "is_v2", 91 ]) 92 93FP32 = "FP32" 94FP16 = "FP16" 95INT8 = "INT8" 96PRECISION_MODES = [FP32, FP16, INT8] 97 98 99def IsQuantizationMode(mode): 100 return mode == "INT8" 101 102 103def IsQuantizationWithCalibration(params): 104 return IsQuantizationMode(params.precision_mode) and params.use_calibration 105 106 107def IsTensorRTVersionGreaterEqual(major, minor=0, patch=0): 108 ver = get_linked_tensorrt_version() 109 return ver[0] > major or (ver[0] == major and ver[1] > minor) or ( 110 ver[0] == major and ver[1] == minor and ver[2] >= patch) 111 112 113class GraphState(object): 114 ORIGINAL = 0 115 CALIBRATE = 1 116 INFERENCE = 2 117 118 119class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): 120 """Class to test Tensorflow-TensorRT integration.""" 121 122 @property 123 def trt_incompatible_op(self): 124 return math_ops.erf 125 126 @property 127 def precision_modes(self): 128 return ["FP32", "FP16", "INT8"] 129 130 # str is bytes in py2, but unicode in py3. 131 def _ToUnicode(self, s): 132 if six.PY2: 133 if isinstance(s, unicode): 134 return s 135 return s.decode("utf-8") 136 else: 137 if isinstance(s, str): 138 return s 139 return s.decode("utf-8") 140 141 def _ToBytes(self, s): 142 if six.PY2: 143 if isinstance(s, unicode): 144 return s.encode("utf-8") 145 return s 146 else: 147 if isinstance(s, str): 148 return s.encode("utf-8") 149 return s 150 151 def _ToString(self, s): 152 if six.PY2: 153 if isinstance(s, unicode): 154 return s.encode("utf-8") 155 return s 156 else: 157 if isinstance(s, str): 158 return s 159 return s.decode("utf-8") 160 161 def __init__(self, methodName="runTest"): # pylint: disable=invalid-name 162 super(TfTrtIntegrationTestBase, self).__init__(methodName) 163 self._trt_test_params = None 164 self._disable_non_trt_optimizers = False 165 self._use_implicit_batch = True 166 167 def setUp(self): 168 """Setup method.""" 169 super(TfTrtIntegrationTestBase, self).setUp() 170 warnings.simplefilter("always") 171 172 def _GetTensorSpec(self, shape, mask, dtype, name): 173 # Set dimension i to None if mask[i] == False 174 assert len(shape) == len(mask) 175 new_shape = [s if m else None for s, m in zip(shape, mask)] 176 return tensor_spec.TensorSpec(new_shape, dtype, name) 177 178 def BuildParams(self, graph_fn, dtype, input_shapes, output_shapes): 179 """Build test parameters. 180 181 The input_shapes and output_shapes arguments are known (static) shapes that 182 can be used to generate test data. To define the model, we also specify 183 corresponding input/output TensoSpecs. These are defined using the shape 184 arguments. For each input tensor we define: 185 186 input_spec = [None] + input_shape[1:] 187 188 and similarly for output shapes. This means that we leave the first (batch) 189 dimension unknown, the rest is just copied from the shapes arg. 190 191 Args: 192 graph_fn: The function to build the graph. 193 dtype: The element type. 194 input_shapes: The input shapes. 195 output_shapes: The output shapes. 196 197 Returns: 198 The test parameters. 199 """ 200 201 input_mask = [[False] + [True] * (len(shape) - 1) for shape in input_shapes] 202 output_mask = [ 203 [False] + [True] * (len(shape) - 1) for shape in output_shapes 204 ] 205 206 return self.BuildParamsWithMask(graph_fn, dtype, input_shapes, 207 output_shapes, input_mask, output_mask, [], 208 []) 209 210 def BuildParamsWithMask(self, graph_fn, dtype, input_shapes, output_shapes, 211 input_mask, output_mask, extra_inputs, extra_outputs): 212 """Build test parameters with static or dynamic input shapes. 213 214 To define dynamic shapes give a boolean mask that describes which 215 dimensions to treat as known. The values in input_mask are interpreted the 216 following way: 217 - True: known dim (use the corresponding value from input_shapes) 218 - False: unknown dim (replace the corresponding value from input_shapes 219 with None) 220 For example, to define the first two dimension with unknown size use 221 input_shapes=[[1,2,1,8]], input_mask=[[False, False, True, True]]. 222 223 Args: 224 graph_fn: The function to build the graph. 225 dtype: The element type. 226 input_shapes: The input shapes. 227 output_shapes: The output shapes. 228 input_mask: The input shape masks. 229 output_mask: the output shape masks. 230 extra_inputs: list of additional input shapes 231 extra_outputs: list of additional outputs shapes 232 233 Returns: 234 The test parameters. 235 """ 236 237 def _ValidateShapes(shapes): 238 # Make sure all the shapes are fully specified. 239 for shape in shapes: 240 assert all(shape) 241 242 _ValidateShapes(input_shapes) 243 _ValidateShapes(output_shapes) 244 245 assert len(input_mask) == len(input_shapes) 246 assert len(output_mask) == len(output_shapes) 247 for extra_in_shape, extra_out_shape in zip(extra_inputs, extra_outputs): 248 assert len(input_shapes) == len(extra_in_shape) 249 assert len(output_shapes) == len(extra_out_shape) 250 251 return TfTrtIntegrationTestParams( 252 graph_fn=graph_fn, 253 input_specs=[ 254 self._GetTensorSpec(shape, mask, dtype, "input_%d" % i) 255 for i, (shape, mask) in enumerate(zip(input_shapes, input_mask)) 256 ], 257 output_specs=[ 258 self._GetTensorSpec(shape, mask, dtype, "output_%d" % i) 259 for i, (shape, mask) in enumerate(zip(output_shapes, output_mask)) 260 ], 261 input_dims=[input_shapes] + extra_inputs, 262 expected_output_dims=[output_shapes] + extra_outputs) 263 264 def DisableNonTrtOptimizers(self): 265 self._disable_non_trt_optimizers = True 266 267 def DisableImplicitBatchMode(self): 268 self._use_implicit_batch = False 269 270 def GetParams(self): 271 """Returns a TfTrtIntegrationTestParams for the test.""" 272 raise NotImplementedError() 273 274 def GetConversionParams(self, run_params): 275 """Returns a TrtConversionParams for test.""" 276 conversion_params = trt_convert.TrtConversionParams( 277 # We use the minimum of all the batch sizes, so when multiple different 278 # input shapes are provided it'll always create new engines in the 279 # cache, and we can therefore test the cache behavior. 280 max_workspace_size_bytes=1 << 25, 281 precision_mode=run_params.precision_mode, 282 minimum_segment_size=2, 283 maximum_cached_engines=1, 284 use_calibration=run_params.use_calibration) 285 return conversion_params 286 287 def GetMaxBatchSize(self, run_params): 288 """Returns the max_batch_size that the converter should use for tests.""" 289 if run_params.dynamic_engine: 290 return None 291 batch_list = [] 292 for dims_list in self._GetParamsCached().input_dims: 293 assert dims_list 294 # Each list of shapes should have same batch size. 295 input_batches = [dims[0] for dims in dims_list] 296 assert max(input_batches) == min(input_batches) 297 batch_list.append(input_batches[0]) 298 return max(batch_list) 299 300 def ShouldRunTest(self, run_params): 301 """Whether to run the test.""" 302 # Ensure use_calibration=True in case of INT8 precision 303 return (run_params.use_calibration or not IsQuantizationMode( 304 run_params.precision_mode)), "test either calibration or non-INT8" 305 306 def ExpectedEnginesToBuild(self, run_params): 307 """Returns the expected engines to build, implemented by subclass.""" 308 raise NotImplementedError() 309 310 def ExpectedMaxBatchSizes(self, run_params): 311 """Returns the expected maximum batch sizes of the build engines.""" 312 return None 313 314 def ExpectedAbsoluteTolerance(self, run_params): 315 """The absolute tolerance to compare floating point results.""" 316 return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02 317 318 def ExpectedRelativeTolerance(self, run_params): 319 """The relative tolerance to compare floating point results.""" 320 return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02 321 322 def _GetParamsCached(self): 323 if self._trt_test_params is None: 324 self._trt_test_params = self.GetParams() 325 return self._trt_test_params 326 327 def _GetGPUOptions(self): 328 gpu_options = config_pb2.GPUOptions() 329 gpu_options.allow_growth = True 330 return gpu_options 331 332 def _GetConfigProto(self, run_params, graph_state): 333 """Get config proto based on specific settings.""" 334 conversion_params = self.GetConversionParams(run_params) 335 max_batch_size = self.GetMaxBatchSize(run_params) 336 337 if graph_state == GraphState.INFERENCE and run_params.convert_online: 338 rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( 339 conversion_params, 340 is_dynamic_op=run_params.dynamic_engine, 341 max_batch_size=max_batch_size, 342 disable_non_trt_optimizers=self._disable_non_trt_optimizers) 343 else: 344 rewriter_cfg = rewriter_config_pb2.RewriterConfig() 345 if self._disable_non_trt_optimizers: 346 trt_utils.disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg) 347 348 config = config_pb2.ConfigProto( 349 gpu_options=self._GetGPUOptions(), 350 graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg)) 351 return config 352 353 def _GetFeedNames(self): 354 params = self._GetParamsCached() 355 # Construct the feeds tensor names by appending :0 to the node names. 356 return [spec.name + ":0" for spec in params.input_specs] 357 358 def _GetFetchNames(self): 359 params = self._GetParamsCached() 360 # Construct the fetches tensor names by appending :0 to the node names. 361 return [spec.name + ":0" for spec in params.output_specs] 362 363 def _GetFeedDict(self, inputs_data): 364 return {name: data for name, data in zip(self._GetFeedNames(), inputs_data)} 365 366 def _RunGraphV1(self, saved_model_dir, inputs_data, config, num_runs=2): 367 """Run given graphdef multiple times using TF 1.x runtime.""" 368 params = self._GetParamsCached() 369 fetches = self._GetFetchNames() 370 g = ops.Graph() 371 with g.as_default(): 372 with self.session(graph=g, config=config, use_gpu=True) as sess: 373 loader.load(sess, [tag_constants.SERVING], saved_model_dir) 374 vals = [] 375 # Run for each input(s) shape 376 for expected_shapes, current_input_data in zip( 377 params.expected_output_dims, inputs_data): 378 val = None 379 for _ in range(num_runs): 380 new_val = sess.run(fetches, self._GetFeedDict(current_input_data)) 381 self.assertEqual(len(expected_shapes), len(new_val)) 382 for expected_shape, actual_val in zip(expected_shapes, new_val): 383 self.assertEqual(list(expected_shape), list(actual_val.shape)) 384 if val is not None: 385 # Some ops may have nondeterministic output. E.g. Conv2D may use 386 # winograd algorithm. So we set atol/rtol be larger than 1.e-06. 387 self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05) 388 val = new_val 389 vals.append(val) 390 return vals 391 392 def _RunGraphV2(self, saved_model_dir, inputs_data, graph_state, num_runs=2): 393 """Run given graphdef multiple times using TF 2.0 runtime.""" 394 params = self._GetParamsCached() 395 root = load.load(saved_model_dir) 396 func = root.signatures[ 397 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 398 results = [] 399 for expected_shapes, current_input_data in zip(params.expected_output_dims, 400 inputs_data): 401 val = None 402 for _ in range(num_runs): 403 feed_dict = { 404 params.input_specs[i].name: current_input_data[i] 405 for i in range(len(params.input_specs)) 406 } 407 new_val = func(**feed_dict) 408 assert isinstance(new_val, dict) 409 # The key of the output map is always like output_i. 410 new_val = [new_val[key] for key in sorted(new_val)] 411 # Each element is an eager Tensor, and accessing individual elements is 412 # very expensive, so we convert them to a numpy array first. 413 new_val = [v.numpy() for v in new_val] 414 self.assertEqual(len(expected_shapes), len(new_val)) 415 for expected_shape, actual_val in zip(expected_shapes, new_val): 416 self.assertEqual(list(expected_shape), list(actual_val.shape)) 417 if val is not None: 418 # Some ops may have nondeterministic output. E.g. Conv2D may use 419 # winograd algorithm. So we set atol/rtol be larger than 1.e-06. 420 self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05) 421 val = new_val 422 results.append(val) 423 424 return results 425 426 def _RunGraph(self, 427 run_params, 428 saved_model_dir, 429 inputs_data, 430 graph_state, 431 num_runs=2): 432 params = self._GetParamsCached() 433 for data in inputs_data: 434 assert len(params.input_specs) == len(data) 435 436 if run_params.is_v2: 437 results = self._RunGraphV2(saved_model_dir, inputs_data, graph_state, 438 num_runs) 439 gc.collect() # Force GC to destroy the TRT engine cache. 440 return results 441 442 # The default config for tf.session is None. Create a config with 443 # TensorRTOptimizer enabled to support convert_online for inference. 444 config = None 445 # TODO(b/170220818): use the default session config to run inferenence 446 # graphs for the offline conversion case after fixing the bug. 447 if graph_state == GraphState.INFERENCE: 448 config = self._GetConfigProto(run_params, GraphState.INFERENCE) 449 return self._RunGraphV1(saved_model_dir, inputs_data, config, num_runs) 450 451 def _CreateConverter(self, run_params, saved_model_dir, conversion_params): 452 """Returns a TrtGraphConverter.""" 453 if run_params.is_v2: 454 converter_v2 = trt_convert.TrtGraphConverterV2( 455 input_saved_model_dir=saved_model_dir, 456 conversion_params=conversion_params) 457 if self._disable_non_trt_optimizers: 458 converter_v2._test_only_disable_non_trt_optimizers = True # pylint: disable=protected-access 459 if not self._use_implicit_batch: 460 converter_v2._test_only_use_implicit_batch = False # pylint: disable=protected-access 461 return converter_v2 462 463 converter_v1 = trt_convert.TrtGraphConverter( 464 input_saved_model_dir=saved_model_dir, 465 max_batch_size=self.GetMaxBatchSize(run_params), 466 max_workspace_size_bytes=conversion_params.max_workspace_size_bytes, 467 precision_mode=conversion_params.precision_mode, 468 minimum_segment_size=conversion_params.minimum_segment_size, 469 is_dynamic_op=run_params.dynamic_engine, 470 maximum_cached_engines=conversion_params.maximum_cached_engines, 471 use_calibration=conversion_params.use_calibration) 472 if self._disable_non_trt_optimizers: 473 converter_v1._test_only_disable_non_trt_optimizers = True # pylint: disable=protected-access 474 return converter_v1 475 476 def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data): 477 """Return trt converted graphdef in INT8 mode.""" 478 conversion_params = self.GetConversionParams(run_params) 479 logging.info(conversion_params) 480 assert conversion_params.precision_mode == "INT8" 481 assert run_params.dynamic_engine 482 assert conversion_params.maximum_cached_engines == 1 483 assert conversion_params.use_calibration 484 485 # We only support calibrating single engine. 486 # TODO(aaroey): fix this. 487 assert len(inputs_data) == 1 488 489 converter = self._CreateConverter(run_params, saved_model_dir, 490 conversion_params) 491 int8_gdef = converter.convert() 492 self._VerifyGraphDef(run_params, saved_model_dir, int8_gdef, 493 GraphState.CALIBRATE) 494 495 converter.calibrate( 496 fetch_names=self._GetFetchNames(), 497 num_runs=5, 498 feed_dict_fn=lambda: self._GetFeedDict(inputs_data[0])) 499 trt_saved_model_dir = self._GetSavedModelDir(run_params, 500 GraphState.CALIBRATE) 501 converter.save(trt_saved_model_dir) 502 return trt_saved_model_dir 503 504 def _GetInferGraph(self, run_params, saved_model_dir): 505 """Return trt converted graphdef.""" 506 conversion_params = self.GetConversionParams(run_params) 507 logging.info(conversion_params) 508 509 converter = self._CreateConverter(run_params, saved_model_dir, 510 conversion_params) 511 converter.convert() 512 513 if not self._use_implicit_batch: 514 logging.info("Using build mode") 515 516 def _BuildInputFn(): 517 for shapes in self._GetParamsCached().input_dims: 518 yield [np.zeros(x).astype(np.float32) for x in shapes] 519 520 converter.build(input_fn=_BuildInputFn) 521 522 trt_saved_model_dir = self._GetSavedModelDir(run_params, 523 GraphState.INFERENCE) 524 converter.save(trt_saved_model_dir) 525 return trt_saved_model_dir 526 527 def _GetGraphStateLabel(self, graph_state): 528 if graph_state == GraphState.ORIGINAL: 529 return "Original" 530 elif graph_state == GraphState.CALIBRATE: 531 return "CalibEngine" 532 elif graph_state == GraphState.INFERENCE: 533 return "InferEngine" 534 else: 535 return "UnknownState" 536 537 def _WriteGraph(self, run_params, gdef, graph_state): 538 temp_dir = os.getenv("TRT_TEST_TMPDIR") 539 if not temp_dir: 540 return 541 542 graph_name = ( 543 self.__class__.__name__ + "_" + run_params.test_name + "_" + 544 self._GetGraphStateLabel(graph_state) + ".pbtxt") 545 logging.info("Writing graph to %s/%s", temp_dir, graph_name) 546 graph_io.write_graph(gdef, temp_dir, graph_name) 547 548 # Removes the prefix(s) of function name(s). 549 # The input value can be a string or a sequence of string. 550 def _Canonicalize(self, value): 551 if isinstance(value, str): 552 return self._ToString(value.split("/")[-1]) 553 elif isinstance(value, collections.abc.Iterable): 554 return set(self._Canonicalize(nm) for nm in value) 555 else: 556 raise TypeError( 557 "'_Canonicalize' can only be used on strings or sequence of strings!") 558 559 # Removes the graph sequence number prefix from the name(s) only if the 560 # name(s) has a prefix TRTEngineOp_n_. When expecting_prefix is true, asserts 561 # such a prefix exists. 562 # The input value can be a string or a sequence of string. 563 def _RemoveGraphSequenceNumberImpl(self, value, expecting_prefix): 564 if isinstance(value, str): 565 match = re.search(r"TRTEngineOp_\d+_", value) 566 has_prefix = match and value.startswith(match.group(0)) 567 assert (not expecting_prefix) or has_prefix 568 if has_prefix: 569 parts = value.split("_", maxsplit=2) 570 assert len(parts) == 3 571 return parts[0] + "_" + parts[2] 572 return value 573 elif isinstance(value, collections.abc.Iterable): 574 return set( 575 self._RemoveGraphSequenceNumberImpl(nm, expecting_prefix) 576 for nm in value) 577 else: 578 raise TypeError( 579 "'_RemoveGraphSequenceNumberImpl' can only be used on strings " 580 "or sequence of strings!") 581 582 def _RemoveGraphSequenceNumber(self, name): 583 return self._RemoveGraphSequenceNumberImpl(name, True) 584 585 def _MayRemoveGraphSequenceNumber(self, name): 586 return self._RemoveGraphSequenceNumberImpl(name, False) 587 588 def _VerifyConnections(self, expected_engines, original_gdef, converted_gdef): 589 old_to_new_node_map = { 590 self._ToString(node.name): self._ToString(node.name) 591 for node in original_gdef.node 592 } 593 for engine_name, node_names in expected_engines.items(): 594 for node_name in node_names: 595 old_to_new_node_map[node_name] = engine_name 596 name_to_node_map = { 597 self._ToString(node.name): node for node in original_gdef.node 598 } 599 600 def _InputName(inp): 601 inp = self._ToString(inp) 602 prefix = "" 603 if inp[0] == "^": 604 prefix = "^" 605 inp = inp[1:] 606 parts = inp.split(":") 607 if len(parts) > 1 and parts[-1].isdigit(): 608 inp = inp[:-len(parts[-1]) - 1] 609 return (prefix, inp) 610 611 # Compute the expected mapping from each node to its input nodes. 612 expected_input_map = {} 613 removed_const_nodes = set([ 614 self._ToString(node.name) 615 for node in original_gdef.node 616 if node.op == "Const" 617 ]) 618 for node in original_gdef.node: 619 name_str = self._ToString(node.name) 620 target_node_name = old_to_new_node_map[name_str] 621 is_engine_op = (target_node_name != name_str) 622 if target_node_name not in expected_input_map: 623 expected_input_map[target_node_name] = set() 624 input_set = expected_input_map[target_node_name] 625 for inp in node.input: 626 (prefix, inp_name) = _InputName(inp) 627 mapped_input = old_to_new_node_map[inp_name] 628 # Add the input only if it's outside the segment (note that it could be 629 # in a different engine). 630 if not is_engine_op or (mapped_input != target_node_name and 631 name_to_node_map[inp_name].op != "Const"): 632 input_set.add(prefix + mapped_input) 633 if mapped_input in removed_const_nodes: 634 removed_const_nodes.remove(mapped_input) 635 # Remove const nodes that have no outputs. 636 expected_input_map = { 637 k: v 638 for k, v in expected_input_map.items() 639 if k not in removed_const_nodes 640 } 641 642 # Compute the actual mapping from each node to its input nodes. If a cast 643 # op doesn't exist in the original graph, we replace the use of the cast op 644 # with the input of the op. This allows the verification to handle the case 645 # where the TF-TRT bridge splits a cast op into a chain of two cast ops. 646 new_cast_op_name_to_node_map = { 647 node.name: node 648 for node in converted_gdef.node 649 if (node.name not in old_to_new_node_map and node.op == "Cast") 650 } 651 actual_input_map = {} 652 for node in converted_gdef.node: 653 name_str = node.name 654 # Only nodes from the original graph or TRTEngineOp nodes are added as 655 # keys to the map. 656 if node.op == "TRTEngineOp": 657 name_str = self._RemoveGraphSequenceNumber(name_str) 658 elif name_str not in old_to_new_node_map: 659 continue 660 actual_input_map[name_str] = set() 661 input_set = actual_input_map[name_str] 662 for inp in node.input: 663 (prefix, node_name) = _InputName(inp) 664 node_name = self._MayRemoveGraphSequenceNumber(node_name) 665 if node_name in new_cast_op_name_to_node_map: 666 (prefix, node_name) = _InputName( 667 new_cast_op_name_to_node_map[node_name].input[0]) 668 input_set.add(prefix + node_name) 669 670 self.assertEqual( 671 expected_input_map, 672 actual_input_map, 673 msg="\nexpected:\n%s\nvs actual:\n%s" % 674 (sorted(expected_input_map.items()), sorted(actual_input_map.items()))) 675 676 def _VerifyMaxBatchSizeAnnotations( 677 self, 678 expected_engines, 679 original_gdef, 680 converted_gdef, 681 default_max_batch_size, 682 expected_max_batch_sizes=None, 683 ): 684 """Verifies the max batch size annotations in the original and converted GraphDef. 685 686 Args: 687 expected_engines: A sequence of engines names. 688 original_gdef: GraphDef. The graph def before TensorRT conversion. 689 converted_gdef: GraphDef. The graph def after TensorRT conversion. 690 default_max_batch_size: The default maximum batch size to use if no node 691 inside a segment is annoted with a customized max batch size. This value 692 is None when the graph is converted to TF-TRT with dynamic engines. 693 expected_max_batch_sizes: Optional. A sequence of max batch sizes for all 694 the engines. `None` if does not check enforce max batch sizes. 695 """ 696 if isinstance(expected_max_batch_sizes, collections.abc.Collection): 697 self.assertEqual(len(expected_max_batch_sizes), len(expected_engines)) 698 else: 699 self.assertIsNone( 700 expected_max_batch_sizes, 701 "'expected_max_batch_sizes' shall only be a sequence " 702 "of integers or `None`.") 703 704 def _ChainAllNodes(graph_def): 705 return itertools.chain( 706 graph_def.node, 707 itertools.chain( 708 *[func.node_def for func in graph_def.library.function])) 709 710 old_name_to_node_map = { 711 self._ToString(node.name): node 712 for node in _ChainAllNodes(original_gdef) 713 } 714 new_name_to_func_map = { 715 self._ToString(func.signature.name): func 716 for func in converted_gdef.library.function 717 } 718 719 def _DetectStaticBatchSize(node_def): 720 """Returns the static batch size of an operation or None. 721 722 It is incorrect to use the output shapes to find the batch size of an 723 operation, as the segmenter actually uses the input shapes. However, it is 724 a simplication and works for most of the cases for the test purposes. 725 726 Args: 727 node_def: `tf.NodeDef`. The target node for analysis. 728 729 Returns: 730 If all the outputs of the node have the same static batch size, returns 731 the int value for the batch size. Otherwise returns None. 732 """ 733 shapes = node_def.attr["_output_shapes"].list.shape 734 batch_size = set( 735 list(s.dim)[0].size if len(s.dim) >= 2 else None for s in shapes) 736 if len(batch_size) == 1 and list(batch_size)[0] >= 1: 737 return list(batch_size)[0] 738 return None 739 740 name_to_engines_map = {} 741 actual_max_batch_sizes = [] 742 for node in _ChainAllNodes(converted_gdef): 743 if node.op == "TRTEngineOp": 744 engine = node 745 engine_name = self._RemoveGraphSequenceNumber( 746 self._Canonicalize(self._ToString(engine.name))) 747 self.assertIn(engine_name, expected_engines) 748 name_to_engines_map[engine_name] = engine 749 # The input nodes shall not have the conflicting annotation (no 750 # annotation or the same annotation) with the maximum batch size 751 # annotation. If the engine has maximum batch size annotation as the 752 # non-default maximum batch size, then at least one input node shall 753 # have the same annotation to be the source. 754 self.assertIn("max_batch_size", node.attr) 755 engine_max_batch_size = node.attr["max_batch_size"].i 756 self.assertIsInstance(engine_max_batch_size, int) 757 actual_max_batch_sizes.append(engine_max_batch_size) 758 seg_func = node.attr["segment_func"].func 759 self.assertIsNotNone(seg_func) 760 self.assertIn(seg_func.name, new_name_to_func_map) 761 seg_func_def = new_name_to_func_map[seg_func.name] 762 logging.info("Segment function name: %s. Including %d nodes.", 763 seg_func.name, len(seg_func_def.node_def)) 764 node_max_batch_size_all_none = True 765 # Use the native segment to search for replaced nodes 766 for alternative_node in seg_func_def.node_def: 767 node_name = self._Canonicalize(self._ToString(alternative_node.name)) 768 if node_name not in old_name_to_node_map: 769 continue 770 original_node = old_name_to_node_map[node_name] 771 node_max_batch_size = None 772 if "_tftrt_op_max_batch_size" in original_node.attr: 773 node_max_batch_size = original_node.attr[ 774 "_tftrt_op_max_batch_size"].i 775 elif (original_node.op != "Const" and 776 alternative_node.op != "Const" and 777 "_output_shapes" in original_node.attr): 778 node_max_batch_size = _DetectStaticBatchSize(original_node) 779 logging.info( 780 "'{%s}(%s)'s max batch size annotation is %s. " 781 "'{%s}'s max batch size is %s.", node_name, original_node.op, 782 str(node_max_batch_size), engine_name, str(engine_max_batch_size)) 783 node_max_batch_size_all_none &= node_max_batch_size is None 784 self.assertTrue(engine_max_batch_size == node_max_batch_size or 785 node_max_batch_size is None) 786 logging.info("'{%s}'s max batch size is %d.", engine_name, 787 engine_max_batch_size) 788 self.assertTrue(default_max_batch_size is None or 789 engine_max_batch_size == default_max_batch_size or 790 not node_max_batch_size_all_none) 791 792 self.assertCountEqual(expected_engines, tuple(name_to_engines_map.keys())) 793 if expected_max_batch_sizes is not None: 794 self.assertCountEqual(expected_max_batch_sizes, actual_max_batch_sizes) 795 796 def _GetGraphDef(self, run_params, gdef_or_saved_model_dir): 797 if isinstance(gdef_or_saved_model_dir, str): 798 if run_params.is_v2: 799 root = load.load(gdef_or_saved_model_dir) 800 func = root.signatures[ 801 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 802 gdef = func.graph.as_graph_def() 803 # Manually unref the loaded saved model and force GC to destroy the TRT 804 # engine cache after load(). There is currently a reference cycle in 2.0 805 # which prevents auto deletion of the resource. 806 # TODO(laigd): fix this. 807 del func 808 del root 809 gc.collect() 810 return gdef 811 return saved_model_utils.get_meta_graph_def( 812 gdef_or_saved_model_dir, tag_constants.SERVING).graph_def 813 assert isinstance(gdef_or_saved_model_dir, graph_pb2.GraphDef) 814 return gdef_or_saved_model_dir 815 816 def _VerifyGraphDefV1(self, run_params, original_gdef, gdef_to_verify, 817 graph_state): 818 expected_engines = self.ExpectedEnginesToBuild(run_params) 819 num_engines = 0 820 functions = [f.signature.name for f in gdef_to_verify.library.function] 821 for node in gdef_to_verify.node: 822 if node.op == "TRTEngineOp": 823 logging.info("Found TRTEngineOp: " + node.name) 824 num_engines += 1 825 segment_funcdef_name = node.attr["segment_func"].func.name 826 function_name = node.name + "_native_segment" 827 is_dynamic_engine = not node.attr["static_engine"].b 828 self.assertNotEmpty(segment_funcdef_name, node.name) 829 self.assertIn(function_name, functions) 830 if not IsQuantizationWithCalibration and not is_dynamic_engine: 831 self.assertTrue(len(node.attr["serialized_segment"].s), node.name) 832 self.assertIn( 833 self._RemoveGraphSequenceNumber(node.name), expected_engines) 834 self.assertEqual( 835 self._ToBytes(run_params.precision_mode), 836 node.attr["precision_mode"].s, node.name) 837 838 self.assertEqual(run_params.dynamic_engine, is_dynamic_engine, 839 node.name) 840 self.assertEqual(node.attr["use_calibration"].b, 841 run_params.use_calibration, node.name) 842 843 has_calibration_data = len(node.attr["calibration_data"].s) 844 if (IsQuantizationWithCalibration(run_params) and 845 graph_state == GraphState.INFERENCE): 846 self.assertTrue(has_calibration_data, node.name) 847 else: 848 self.assertFalse(has_calibration_data, node.name) 849 if graph_state == GraphState.ORIGINAL: 850 self.assertEqual(0, num_engines) 851 else: 852 self.assertEqual(num_engines, len(expected_engines)) 853 if isinstance(expected_engines, dict): 854 self._VerifyConnections(expected_engines, original_gdef, gdef_to_verify) 855 self._VerifyMaxBatchSizeAnnotations( 856 expected_engines=expected_engines, 857 original_gdef=original_gdef, 858 converted_gdef=gdef_to_verify, 859 expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params), 860 default_max_batch_size=self.GetMaxBatchSize(run_params)) 861 862 def _VerifyGraphDefV2(self, run_params, original_gdef, gdef_to_verify, 863 graph_state): 864 if graph_state == GraphState.ORIGINAL: 865 return 866 expected_engines = self.ExpectedEnginesToBuild(run_params) 867 all_op_names = [node.name for node in gdef_to_verify.node] 868 trt_op_names = [ 869 node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp" 870 ] 871 for func in gdef_to_verify.library.function: 872 if not re.search(r"TRTEngineOp_\d+_\d+_native_segment", 873 func.signature.name): 874 for node in func.node_def: 875 all_op_names.append(node.name) 876 if node.op == "TRTEngineOp": 877 trt_op_names.append(node.name) 878 879 all_op_names = self._Canonicalize(all_op_names) 880 trt_op_names = self._RemoveGraphSequenceNumber( 881 self._Canonicalize(trt_op_names)) 882 883 if isinstance(expected_engines, dict): 884 # For simplicity we don't verify the connections inside the engine in 885 # 2.0, but we still make sure that the converted ops are gone from the 886 # graph. 887 unexpected_names = set(nest.flatten(expected_engines.values())) 888 self.assertEmpty( 889 [name for name in unexpected_names if name in all_op_names]) 890 expected_engines = set(expected_engines.keys()) 891 892 self.assertEqual(set(expected_engines), trt_op_names) 893 894 def _VerifyGraphDef(self, run_params, original_gdef_or_saved_model_dir, 895 gdef_or_saved_model_dir_to_verify, graph_state): 896 original_gdef = self._GetGraphDef(run_params, 897 original_gdef_or_saved_model_dir) 898 gdef_to_verify = self._GetGraphDef(run_params, 899 gdef_or_saved_model_dir_to_verify) 900 self._WriteGraph(run_params, gdef_to_verify, graph_state) 901 if run_params.is_v2: 902 self._VerifyGraphDefV2(run_params, original_gdef, gdef_to_verify, 903 graph_state) 904 else: 905 self._VerifyGraphDefV1(run_params, original_gdef, gdef_to_verify, 906 graph_state) 907 908 def _GetSavedModelDir(self, run_params, graph_state): 909 test_tmpdir = os.getenv("TRT_TEST_TMPDIR") 910 if test_tmpdir: 911 saved_model_dir = os.path.join( 912 test_tmpdir, self.__class__.__name__ + "_" + run_params.test_name + 913 "_" + self._GetGraphStateLabel(graph_state)) 914 try: 915 # For TF 1.x we need to make sure the output directory doesn't exist 916 # before exporting the saved model. 917 shutil.rmtree(saved_model_dir) 918 except OSError as e: 919 if e.errno != errno.ENOENT: 920 raise 921 return saved_model_dir 922 return tempfile.mkdtemp(dir=self.get_temp_dir()) 923 924 def _MakeSavedModelV1(self, run_params): 925 """Write the saved model as an input for testing.""" 926 params = self._GetParamsCached() 927 g = ops.Graph() 928 with g.as_default(): 929 inputs = [] 930 for spec in params.input_specs: 931 inp = array_ops.placeholder( 932 dtype=spec.dtype, shape=spec.shape, name=spec.name) 933 inputs.append(inp) 934 outputs = params.graph_fn(*inputs) 935 if not isinstance(outputs, list) and not isinstance(outputs, tuple): 936 outputs = [outputs] 937 938 signature_def = signature_def_utils.build_signature_def( 939 inputs={inp.op.name: utils.build_tensor_info(inp) for inp in inputs}, 940 outputs={out.op.name: utils.build_tensor_info(out) for out in outputs}, 941 method_name=signature_constants.PREDICT_METHOD_NAME) 942 943 saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL) 944 saved_model_builder = builder.SavedModelBuilder(saved_model_dir) 945 with self.session( 946 graph=g, config=self._GetConfigProto(run_params, 947 GraphState.ORIGINAL)) as sess: 948 saved_model_builder.add_meta_graph_and_variables( 949 sess, [tag_constants.SERVING], 950 signature_def_map={ 951 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 952 signature_def 953 }) 954 saved_model_builder.save() 955 return saved_model_dir 956 957 def _MakeSavedModelV2(self, run_params): 958 params = self._GetParamsCached() 959 root = tracking.AutoTrackable() 960 root.run = def_function.function( 961 params.graph_fn, input_signature=params.input_specs) 962 saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL) 963 logging.info("Saving input SavedModel to %s", saved_model_dir) 964 save.save(root, saved_model_dir, 965 {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: root.run}) 966 return saved_model_dir 967 968 def _MakeSavedModel(self, run_params): 969 if run_params.is_v2: 970 return self._MakeSavedModelV2(run_params) 971 return self._MakeSavedModelV1(run_params) 972 973 def RunTest(self, run_params): 974 with trace.Trace(run_params.test_name): 975 should_run, reason_for_skipping = self.ShouldRunTest(run_params) 976 if not should_run: 977 return self.skipTest(reason_for_skipping) 978 979 saved_model_dir = self._MakeSavedModel(run_params) 980 981 np.random.seed(12345) # Fix the seed so the test is deterministic. 982 inputs_data = [] 983 input_specs = self._GetParamsCached().input_specs 984 for dim_list in self._GetParamsCached().input_dims: 985 assert len(input_specs) == len(dim_list) 986 current_input_data = [] 987 for spec, np_shape in zip(input_specs, dim_list): 988 np_dtype = spec.dtype.as_numpy_dtype() 989 # Multiply the input by some constant to avoid all zeros input for 990 # integer types. 991 scale = 10.0 if np.issubdtype(np_dtype, np.integer) else 1.0 992 # TODO(laigd): add debug options. E.g. we can set the input data to be 993 # continuous natural numbers: 994 # seq = np.arange(np.prod(np_shape)) 995 # seq.resize(np_shape) 996 # current_inputs_data.append(scale * seq.astype(np_dtype)) 997 data = (scale * np.random.random_sample(np_shape)).astype(np_dtype) 998 if run_params.is_v2: 999 with ops.device("/GPU:0"): 1000 data = ops.convert_to_tensor(data) 1001 current_input_data.append(data) 1002 inputs_data.append(current_input_data) 1003 1004 # Verify the original graph. 1005 self._VerifyGraphDef(run_params, saved_model_dir, saved_model_dir, 1006 GraphState.ORIGINAL) 1007 1008 # Run the original graph without TensorRT to get the reference result. 1009 logging.info("Running original graph w/o TensorRT\n") 1010 ref_result = self._RunGraph( 1011 run_params, 1012 saved_model_dir, 1013 inputs_data, 1014 GraphState.ORIGINAL, 1015 num_runs=1) 1016 1017 # Run calibration if necessary. 1018 if IsQuantizationWithCalibration(run_params): 1019 infer_saved_model_dir = self._GetCalibratedInferGraph( 1020 run_params, saved_model_dir, inputs_data) 1021 self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir, 1022 GraphState.INFERENCE) 1023 elif not run_params.convert_online: 1024 infer_saved_model_dir = self._GetInferGraph(run_params, saved_model_dir) 1025 self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir, 1026 GraphState.INFERENCE) 1027 else: 1028 infer_saved_model_dir = saved_model_dir 1029 1030 # Run the inference graph, either using the converted graph or the 1031 # original graph with convert_online == True. 1032 logging.info("Running final inference graph\n") 1033 result = self._RunGraph(run_params, infer_saved_model_dir, inputs_data, 1034 GraphState.INFERENCE) 1035 self.assertAllClose( 1036 ref_result, 1037 result, 1038 atol=self.ExpectedAbsoluteTolerance(run_params), 1039 rtol=self.ExpectedRelativeTolerance(run_params)) 1040 1041 def testIdempotence(self): 1042 # Test that applying tensorrt optimizer or offline conversion tools multiple 1043 # times to the same graph will result in same graph. 1044 # 1045 # TODO(aaroey): implement this. 1046 pass 1047 1048 1049def _GetTestConfigsV1(): 1050 """Returns the config combinations to run the test.""" 1051 convert_online, convert_offline = True, False 1052 dynamic_engine, static_engine = True, False 1053 use_calibration, no_calibration = True, False 1054 1055 # Add all possible test cases and let the derived test class to decide 1056 # whether to run specific ones with ShouldRunTest(). 1057 # 1058 # Note: INT8 without calibration behaves like FP32/FP16. 1059 opts = list( 1060 itertools.product([FP32, FP16, INT8], [convert_online, convert_offline], 1061 [dynamic_engine, static_engine], [no_calibration])) 1062 # We always run calibration with offline tool. 1063 # TODO(aaroey): static calibration engine is not supported yet. 1064 opts.append((INT8, convert_offline, dynamic_engine, use_calibration)) 1065 return opts 1066 1067 1068def _GetTestConfigsV2(): 1069 """Returns the config combinations to run the test.""" 1070 convert_offline = False 1071 # TODO(laigd): add support for static_engine. 1072 dynamic_engine = True 1073 # TODO(laigd): add support for calibration. 1074 no_calibration = False 1075 1076 # Add all possible test cases and let the derived test class to decide 1077 # whether to run specific ones with ShouldRunTest(). 1078 # 1079 # Note: 1080 # - In TF2.0 the conversion always produce dynamic engine, and we don't test 1081 # the offline mode here. 1082 # - For simplicity we don't test online conversion which requires setting the 1083 # Grappler config in default eager context. 1084 # - INT8 without calibration behaves like FP32/FP16. 1085 opts = list( 1086 itertools.product([FP32, FP16, INT8], [convert_offline], [dynamic_engine], 1087 [no_calibration])) 1088 # We always run calibration with offline tool. 1089 # TODO(aaroey): INT8+calibration is not supported yet in V2. 1090 # opts.append((INT8, convert_offline, dynamic_engine, use_calibration)) 1091 return opts 1092 1093 1094def _GetTest(run_params): 1095 """Gets a single test method based on the parameters.""" 1096 1097 def _Test(self): 1098 logging.info( 1099 "Running test %s with parameters: convert_online=%s, " 1100 "precision_mode=%s, dynamic_engine=%s", run_params.test_name, 1101 run_params.convert_online, run_params.precision_mode, 1102 run_params.dynamic_engine) 1103 self.RunTest(run_params) 1104 1105 return _Test 1106 1107 1108def _AddTestsFor(test_class, is_v2): 1109 """Adds test methods to TfTrtIntegrationTestBase for specific TF version.""" 1110 opts = _GetTestConfigsV2() if is_v2 else _GetTestConfigsV1() 1111 for (precision_mode, convert_online, dynamic_engine, use_calibration) in opts: 1112 conversion = "OnlineConversion" if convert_online else "OfflineConversion" 1113 engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine" 1114 calibration_type = "UseCalibration" if use_calibration else "NoCalibration" 1115 test_name = "%s_%s_%s_%s_%s" % ("testTfTrtV2" if is_v2 else "testTfTrt", 1116 conversion, engine_type, precision_mode, 1117 calibration_type) 1118 run_params = RunParams( 1119 convert_online=convert_online, 1120 precision_mode=precision_mode, 1121 dynamic_engine=dynamic_engine, 1122 test_name=test_name, 1123 use_calibration=use_calibration, 1124 is_v2=is_v2) 1125 if is_v2: 1126 setattr(test_class, test_name, 1127 test_util.run_v2_only(_GetTest(run_params))) 1128 else: 1129 setattr(test_class, test_name, 1130 test_util.run_v1_only("", _GetTest(run_params))) 1131 1132 1133def _AddTests(test_class): 1134 """Adds test methods to TfTrtIntegrationTestBase.""" 1135 _AddTestsFor(test_class, is_v2=False) 1136 _AddTestsFor(test_class, is_v2=True) 1137 1138 1139if is_tensorrt_enabled(): 1140 os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False" 1141 _AddTests(TfTrtIntegrationTestBase) 1142