# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to test TF-TensorRT integration."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gc
import os
import re
import tempfile

from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
from tensorflow.compiler.tf2tensorrt.utils.trt_engine_instance_pb2 import TRTEngineInstance  # pylint: disable=g-importing-member
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util.lazy_loader import LazyLoader

_SAVED_MODEL_SIGNATURE_KEY = "mypredict"

gen_trt_ops = LazyLoader(
    "gen_trt_ops", globals(),
    "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")


class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Class to test Tensorflow-TensorRT integration python API."""

  # Use a small max_workspace_size for tests so they don't consume too much GPU
  # memory.
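  # (2 << 20 bytes is 2 MiB.)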
  _TRT_MAX_WORKSPACE_SIZE_BYTES = 2 << 20

  def mkdtemp(self):
    return tempfile.mkdtemp(dir=self.get_temp_dir())

  def testTRTEngineInstanceAvailable(self):
    # test if we can access the TRTEngineInstance protobuf
    assert hasattr(TRTEngineInstance(), "serialized_engine")

  def _GetConfigProto(self, rewriter_config=None):
    """Get ConfigProto for session creation."""
    config = config_pb2.ConfigProto(
        gpu_options=config_pb2.GPUOptions(allow_growth=True))
    if rewriter_config:
      config.graph_options.rewrite_options.CopyFrom(rewriter_config)
    return config

  @classmethod
  def _GetGraph(cls, inp1, inp2, var):
    """Get the graph for testing."""
    # The graph computes: inp1^2 + inp1*var + inp1 + inp2 + var
    add = inp1 + var
    mul = inp1 * add
    add = mul + add
    add = add + inp2
    out = array_ops.identity(add, name="output")
    return out

  def _GetModelForV2(self):

    class SimpleModel(tracking.AutoTrackable):

      def __init__(self):
        self.v = None

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
          tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)
      ])
      def run(self, inp1, inp2):
        if self.v is None:
          self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
        return TrtConvertTest._GetGraph(inp1, inp2, self.v)

    return SimpleModel()

  def _GetGraphForV1(self, device):

    def _GraphFn():
      inp1 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input1")
      inp2 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input2")
      var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
      out = TrtConvertTest._GetGraph(inp1, inp2, var)
      return g, var, inp1, inp2, out

    g = ops.Graph()
    with g.as_default():
      if device:
        with g.device(device):
          return _GraphFn()
      return _GraphFn()

  def _GetGraphDefForV1(self, device):
    """Get the graph def for testing."""
    g, var, _, _, _ = self._GetGraphForV1(device)
    with self.session(graph=g, config=self._GetConfigProto()) as sess:
      sess.run(var.initializer)
      graph_def = graph_util.convert_variables_to_constants(
          sess, g.as_graph_def(add_shapes=True), ["output"])
    node_name_to_op = {node.name: node.op for node in graph_def.node}
    self.assertEqual(
        {
            "v1": "Const",
            "add/ReadVariableOp": "Identity",
            "input1": "Placeholder",
            "input2": "Placeholder",
            "add": "AddV2",
            "mul": "Mul",
            "add_1": "AddV2",
            "add_2": "AddV2",
            "output": "Identity"
        }, node_name_to_op)
    return graph_def

  def _WriteInputSavedModelForV1(self, input_saved_model_dir, device):
    """Write the saved model as an input for testing."""
    g, var, inp1, inp2, out = self._GetGraphForV1(device)
    signature_def = signature_def_utils.build_signature_def(
        inputs={
            "myinput1": utils.build_tensor_info(inp1),
            "myinput2": utils.build_tensor_info(inp2)
        },
        outputs={"myoutput": utils.build_tensor_info(out)},
        method_name=signature_constants.PREDICT_METHOD_NAME)
    saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
    with self.session(graph=g, config=self._GetConfigProto()) as sess:
      sess.run(var.initializer)
      saved_model_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING],
          signature_def_map={_SAVED_MODEL_SIGNATURE_KEY: signature_def})
    saved_model_builder.save()

  def _ConvertGraphV1(self,
                      output_saved_model_dir=None,
                      need_calibration=False,
                      max_batch_size=1,
                      minimum_segment_size=3,
                      is_dynamic_op=False,
                      maximum_cached_engines=1,
                      device=None):
    """Helper method to convert a GraphDef or SavedModel using TF-TRT."""
    input_saved_model_dir = None
    if output_saved_model_dir:
      input_saved_model_dir = self.mkdtemp()
      self._WriteInputSavedModelForV1(input_saved_model_dir, device)

    # Calibration requires dynamic_op.
    if need_calibration:
      is_dynamic_op = True

    # For dynamic_op, the converter requires the unused max_batch_size=None.
    if is_dynamic_op:
      max_batch_size = None

    converter = trt_convert.TrtGraphConverter(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
        input_graph_def=None
        if input_saved_model_dir else self._GetGraphDefForV1(device),
        nodes_denylist=None if input_saved_model_dir else ["output"],
        max_batch_size=max_batch_size,
        max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
        precision_mode=(trt_convert.TrtPrecisionMode.INT8 if need_calibration
                        else trt_convert.TrtPrecisionMode.FP32),
        minimum_segment_size=minimum_segment_size,
        is_dynamic_op=is_dynamic_op,
        maximum_cached_engines=maximum_cached_engines)
    output_graph_def = converter.convert()

    if need_calibration:

      class CalibrationData(object):

        def __init__(self):
          self._data = 0

        def next(self):
          self._data += 1
          return {"input1:0": [[[self._data]]], "input2:0": [[[self._data]]]}

      output_graph_def = converter.calibrate(
          fetch_names=["output:0"],
          num_runs=10,
          feed_dict_fn=CalibrationData().next)

    if output_saved_model_dir is not None:
      converter.save(output_saved_model_dir=output_saved_model_dir)
    return output_graph_def

  # Remove the graph sequence number prefix from the name only if the name has
  # a prefix TRTEngineOp_n_.
  def _MayRemoveGraphSequenceNumber(self, name):
    prefix = re.search(r"TRTEngineOp_\d+_", name)
    if prefix and name.startswith(prefix.group(0)):
      parts = name.split("_", maxsplit=2)
      assert len(parts) == 3
      return parts[0] + "_" + parts[2]
    return name

  # Return the unique TRTEngineOp in the given graph def.
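  # Asserts that the converted graph contains exactly one such node.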
  def _GetUniqueTRTEngineOp(self, graph_def):
    trt_engine_nodes = [
        node for node in graph_def.node if node.op == "TRTEngineOp"
    ]
    assert len(trt_engine_nodes) == 1
    return trt_engine_nodes[0]

  def _TestTrtGraphConverter(self,
                             device,
                             output_saved_model_dir=None,
                             need_calibration=False,
                             is_dynamic_op=False):
    """General method to test trt_convert.TrtGraphConverter()."""
    output_graph_def = self._ConvertGraphV1(
        output_saved_model_dir=output_saved_model_dir,
        need_calibration=need_calibration,
        is_dynamic_op=is_dynamic_op,
        device=device)
    graph_defs_to_verify = [output_graph_def]

    if output_saved_model_dir:
      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
          output_saved_model_dir, tag_constants.SERVING).graph_def
      self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
      graph_defs_to_verify.append(saved_model_graph_def)

    for graph_def in graph_defs_to_verify:
      node_name_to_op = {
          self._MayRemoveGraphSequenceNumber(node.name): node.op
          for node in graph_def.node
      }
      if device is not None and device.startswith("/CPU:"):
        self.assertEqual(
            {
                "add": "AddV2",
                "add/ReadVariableOp": "Const",
                "add_1": "AddV2",
                "add_2": "AddV2",
                "input1": "Placeholder",
                "input2": "Placeholder",
                "mul": "Mul",
                "output": "Identity"
            }, node_name_to_op)
      else:
        self.assertEqual(
            {
                "input1": "Placeholder",
                "input2": "Placeholder",
                "TRTEngineOp_0": "TRTEngineOp",
                "output": "Identity"
            }, node_name_to_op)

      if need_calibration:
        trt_engine_nodes = [
            node for node in graph_def.node if node.op == "TRTEngineOp"
        ]
        if device is not None and device.startswith("/CPU:"):
          self.assertEmpty(trt_engine_nodes)
          return

        self.assertNotEmpty(trt_engine_nodes)
        for node in trt_engine_nodes:
          self.assertTrue(len(node.attr["calibration_data"].s))
        # Run the calibrated graph.
        # TODO(laigd): consider having some input where the answer is
        # different.
        with ops.Graph().as_default():
          importer.import_graph_def(graph_def, name="")
          with self.session(config=self._GetConfigProto()) as sess:
            for test_data in range(10):
              self.assertEqual((test_data + 1.0)**2 + test_data,
                               sess.run(
                                   "output:0",
                                   feed_dict={
                                       "input1:0": [[[test_data]]],
                                       "input2:0": [[[test_data]]]
                                   }))

  @parameterized.named_parameters([
      ("NoDeviceAssignment", None),
      ("GPU", "/GPU:0"),
      ("CPU", "/CPU:0"),
  ])
  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_OfflineConversion(self, device):
    """Test case for trt_convert.TrtGraphConverter()."""
    if not is_tensorrt_enabled():
      return

    for need_calibration in [False, True]:
      # Use GraphDef as input.
      self._TestTrtGraphConverter(device)

      # Use SavedModel as input.
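      # Unlike the GraphDef path above, this path exercises calibration when
      # need_calibration is True.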
      self._TestTrtGraphConverter(
          device,
          output_saved_model_dir=self.mkdtemp(),
          need_calibration=need_calibration)

  @parameterized.named_parameters([
      ("NoDeviceAssignment", None),
      ("GPU", "/device:GPU:0"),
      ("CPU", "/device:CPU:0"),
  ])
  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_OnlineConversion(self, device):
    """Test case for TF-TRT conversion using Grappler directly."""
    if not is_tensorrt_enabled():
      return

    conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode=trt_convert.TrtPrecisionMode.FP32)
    config = self._GetConfigProto(
        rewriter_config=trt_convert.get_tensorrt_rewriter_config(
            conversion_params,
            is_dynamic_op=False,
            max_batch_size=1,
            is_v2=False))

    with ops.Graph().as_default():
      # Online conversion requires a frozen graph, so we reuse inp1 as the var
      # argument.
      inp1 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input1")
      inp2 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input2")
      if device:
        with ops.device(device):
          TrtConvertTest._GetGraph(inp1, inp2, inp1)
      else:
        TrtConvertTest._GetGraph(inp1, inp2, inp1)
      with self.session(config=config) as sess:
        self._TestRun(sess, batch_size=1)

  def _CreateConverterV2(
      self,
      input_saved_model_dir,
      input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
      max_workspace_size_bytes=10 << 20,  # Use a smaller workspace.
      precision_mode=trt_convert.TrtPrecisionMode.FP32,
      maximum_cached_engines=2):
    return trt_convert.TrtGraphConverterV2(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_signature_key=input_saved_model_signature_key,
        conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
            max_workspace_size_bytes=max_workspace_size_bytes,
            precision_mode=precision_mode,
            maximum_cached_engines=maximum_cached_engines))

  def _CheckTrtOps(self, concrete_func, check_fn=None):
    graph_def = concrete_func.graph.as_graph_def()
    trt_op_names = []
    for node in graph_def.node:
      if node.op == "TRTEngineOp":
        trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
        if check_fn:
          check_fn(node)
    for func in graph_def.library.function:
      for node in func.node_def:
        if node.op == "TRTEngineOp":
          trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
          if check_fn:
            check_fn(node)
    self.assertEqual(1, len(trt_op_names))
    self.assertIn("TRTEngineOp_0", trt_op_names[0])

  def _RandomInput(self, shape, dtype=np.float32):
    inp1 = np.random.random_sample(shape).astype(dtype)
    inp2 = np.random.random_sample(shape).astype(dtype)
    return inp1, inp2

  @test_util.run_v2_only
  def testTrtGraphConverter_DynamicConversion_v2(self):
    """Test case for trt_convert.TrtGraphConverter()."""
    if not is_tensorrt_enabled():
      return

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    expected_output = root.run(np_input1, np_input2)
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(input_saved_model_dir)
    converter.convert()

    # Verify the converted GraphDef and ConcreteFunction.
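    # _CheckTrtOps asserts that the converted segment collapsed into exactly
    # one TRTEngineOp.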
    self._CheckTrtOps(converter._converted_func)  # pylint: disable=protected-access

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    # Save the converted model without any TRT engine cache.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)
    unexpected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)
    self.assertFalse(os.path.exists(unexpected_asset_file))

    # Run the converted function to populate the engine cache.
    def _InputFn():
      yield np_input1, np_input2

    converter.build(input_fn=_InputFn)

    # Save the converted model again with serialized engine cache.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)
    expected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)
    self.assertTrue(os.path.exists(expected_asset_file))
    self.assertTrue(os.path.getsize(expected_asset_file))

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.

    # Load and verify the converted model.
    #
    # TODO(laigd): the name of the new input_signature of the
    # `root_with_trt.run` function is empty string (originally was None),
    # investigate why.
    root_with_trt = load.load(output_saved_model_dir)
    # TODO(laigd): `root_with_trt.run` is still using the original graph
    # without trt. Consider changing that.
    # self._CheckTrtOps(root_with_trt.run.get_concrete_function())
    converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
    self._CheckTrtOps(converted_signature)
    output_with_trt = converted_signature(
        inp1=ops.convert_to_tensor(np_input1),
        inp2=ops.convert_to_tensor(np_input2))
    # The output of running the converted signature is a dict due to
    # compatibility reasons with V1 SavedModel signature mechanism.
    self.assertAllClose(
        expected_output,
        list(output_with_trt.values())[0],
        atol=1e-6,
        rtol=1e-6)

    del root_with_trt
    gc.collect()  # Force GC to destroy the TRT engine cache.

  @test_util.run_v2_only
  def testTrtGraphConverter_Int8Conversion_v2(self):
    if not is_tensorrt_enabled():
      return

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    expected_output = root.run(np_input1, np_input2)
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(
        input_saved_model_dir,
        precision_mode=trt_convert.TrtPrecisionMode.INT8,
        maximum_cached_engines=3)

    # Convert and perform INT8 calibration.
    def _CalibrationInputFn():
      yield np_input1, np_input2

    converter.convert(calibration_input_fn=_CalibrationInputFn)

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    def _CheckFn(node):
      self.assertTrue(len(node.attr["calibration_data"].s), node.name)

    # Verify the converted GraphDef.
    self._CheckTrtOps(converter._converted_func, _CheckFn)  # pylint: disable=protected-access

    # Build another engine with a different batch size.
    def _InputFn():
      yield self._RandomInput([5, 1, 1])

    converter.build(input_fn=_InputFn)

    # Save the converted model.
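    # The cache is expected to hold two engines at this point: one from INT8
    # calibration (batch size 4) and one from build() (batch size 5).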
    # TODO(laigd): check that it should contain two engines.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)
    expected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)
    self.assertTrue(os.path.exists(expected_asset_file))
    self.assertTrue(os.path.getsize(expected_asset_file))

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.

    # Load and verify the converted model.
    root_with_trt = load.load(output_saved_model_dir)
    converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
    self._CheckTrtOps(converted_signature, _CheckFn)
    output_with_trt = converted_signature(
        inp1=ops.convert_to_tensor(np_input1),
        inp2=ops.convert_to_tensor(np_input2))
    self.assertEqual(1, len(output_with_trt))
    # The output of running the converted signature is a dict due to
    # compatibility reasons with V1 SavedModel signature mechanism.
    self.assertAllClose(
        expected_output,
        list(output_with_trt.values())[0],
        atol=1e-6,
        rtol=1e-6)

    # Run with an input of different batch size. It should build a new engine
    # using the calibration table.
    # TODO(laigd): check that it should contain three engines.
    np_input1, np_input2 = self._RandomInput([6, 1, 1])
    converted_signature(
        inp1=ops.convert_to_tensor(np_input1),
        inp2=ops.convert_to_tensor(np_input2))

    del root_with_trt
    gc.collect()  # Force GC to destroy the TRT engine cache.

  @test_util.run_v2_only
  def testTrtGraphConverter_DestroyEngineCache(self):
    """Test case for trt_convert.TrtGraphConverter()."""
    if not is_tensorrt_enabled():
      return

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(input_saved_model_dir)
    converter.convert()

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    def _InputFn():
      yield np_input1, np_input2

    converter.build(input_fn=_InputFn)  # Populate the TRT engine cache.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)

    def _DestroyCache():
      with ops.device("GPU:0"):
        handle = gen_trt_ops.create_trt_resource_handle(
            resource_name=trt_engine_name)
        gen_resource_variable_ops.destroy_resource_op(
            handle, ignore_lookup_error=False)

    with self.assertRaisesRegex(errors.NotFoundError,
                                r"Resource .* does not exist."):
      _DestroyCache()

    # Load the converted model and make sure the engine cache is populated by
    # default.
    root = load.load(output_saved_model_dir)
    _DestroyCache()
    with self.assertRaisesRegex(errors.NotFoundError,
                                r"Resource .* does not exist."):
      _DestroyCache()

    # Load the converted model again and make sure the engine cache is
    # destroyed when the model goes out of scope.
    root = load.load(output_saved_model_dir)
    del root
    gc.collect()  # Force GC to destroy the TRT engine cache.
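    # The engine resource should have gone away with the model, so destroying
    # it again must fail.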
    with self.assertRaisesRegex(errors.NotFoundError,
                                r"Resource .* does not exist."):
      _DestroyCache()

  def _CompareSavedModel(self, model_class):
    signature_key = "serving_default"

    def _GetModelPaths(model_class):
      input_saved_model_dir = self.mkdtemp()
      root = model_class()
      save.save(root, input_saved_model_dir)

      converter = self._CreateConverterV2(
          input_saved_model_dir, input_saved_model_signature_key=signature_key)
      converter.convert()
      output_saved_model_dir = self.mkdtemp()
      converter.save(output_saved_model_dir)
      return input_saved_model_dir, output_saved_model_dir

    def _GetSignatureDef(export_dir):
      saved_model_proto = loader_impl.parse_saved_model(export_dir)
      self.assertEqual(1, len(saved_model_proto.meta_graphs))
      meta_graph = saved_model_proto.meta_graphs[0]
      self.assertIn(signature_key, meta_graph.signature_def)
      return meta_graph.signature_def[signature_key]

    def _CompareSignatureDef(original_def, converted_def, is_input):
      endpoints = original_def.inputs if is_input else original_def.outputs
      converted_endpoints = (
          converted_def.inputs if is_input else converted_def.outputs)
      self.assertEqual(set(endpoints.keys()), set(converted_endpoints.keys()))
      for key in endpoints:
        original_input = endpoints[key]
        converted_input = converted_endpoints[key]
        self.assertEqual(original_input.name, converted_input.name)
        self.assertEqual(original_input.dtype, converted_input.dtype)
        self.assertEqual(
            tensor_shape.TensorShape(original_input.tensor_shape).as_list(),
            tensor_shape.TensorShape(converted_input.tensor_shape).as_list())

    def _GetStructuredOutputs(export_dir):
      root = load.load(export_dir)
      return root.signatures[signature_key].structured_outputs

    saved_model_path, converted_saved_model_path = _GetModelPaths(model_class)
    original_def = _GetSignatureDef(saved_model_path)
    converted_def = _GetSignatureDef(converted_saved_model_path)
    self.assertEqual(original_def.method_name, converted_def.method_name)
    _CompareSignatureDef(original_def, converted_def, True)
    _CompareSignatureDef(original_def, converted_def, False)

    self.assertEqual(
        _GetStructuredOutputs(saved_model_path),
        _GetStructuredOutputs(converted_saved_model_path))

  @test_util.run_v2_only
  def testRetainSignatureInfo_NoInputs(self):
    if not is_tensorrt_enabled():
      return

    class _Model(tracking.AutoTrackable):

      @def_function.function(input_signature=[])
      def run(self):
        return array_ops.constant(1.0)

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_OneInput(self):
    if not is_tensorrt_enabled():
      return

    class _Model(tracking.AutoTrackable):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32)
      ])
      def run(self, inp):
        return inp + inp * inp

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_TwoInputs(self):
    if not is_tensorrt_enabled():
      return

    class _Model(tracking.AutoTrackable):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32),
          tensor_spec.TensorSpec(shape=[None, 2], dtype=dtypes.float32)
      ])
      def run(self, inp1, inp2):
        return inp1 + inp2 * inp2

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_OneOutputSignatureKey(self):
    if not is_tensorrt_enabled():
      return

    class _Model(tracking.AutoTrackable):

      @def_function.function(input_signature=[])
      def run(self):
        return {"my_output": array_ops.constant(1.0)}

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_TwoOutputSignatureKeys(self):
    if not is_tensorrt_enabled():
      return

    class _Model(tracking.AutoTrackable):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32)
      ])
      def run(self, inp):
        # Here the keys are not ordered lexicographically on purpose.
        return {
            "output_b": array_ops.constant(1.0),
            "output_a": inp + inp * inp
        }

    self._CompareSavedModel(_Model)

  def _TestRun(self, sess, batch_size):
    result = sess.run(
        "output:0",
        feed_dict={
            "input1:0": [[[1.0]]] * batch_size,
            "input2:0": [[[1.0]]] * batch_size
        })
    self.assertAllEqual([[[5.0]]] * batch_size, result)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_MinimumSegmentSize(self):
    if not is_tensorrt_enabled():
      return
    output_graph_def = self._ConvertGraphV1(minimum_segment_size=7)
    node_name_to_op = {node.name: node.op for node in output_graph_def.node}
    self.assertEqual(
        {
            "add/ReadVariableOp": "Const",
            "input1": "Placeholder",
            "input2": "Placeholder",
            "add": "AddV2",
            "mul": "Mul",
            "add_1": "AddV2",
            "add_2": "AddV2",
            "output": "Identity"
        }, node_name_to_op)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_DynamicOp(self):
    if not is_tensorrt_enabled():
      return

    output_saved_model_dir = self.mkdtemp()
    output_graph_def = self._ConvertGraphV1(
        output_saved_model_dir=output_saved_model_dir,
        is_dynamic_op=True,
        maximum_cached_engines=2)

    # Test the output GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        # Run with batch size 1, a new engine is created and cached.
        self._TestRun(sess, 1)
        # Run with batch size 2, a new engine is created and cached.
        self._TestRun(sess, 2)
        # Run with batch size 3, since the number of cached engines has
        # reached the max, it should evict an old engine and create a new one.
        self._TestRun(sess, 3)

    # Test the output SavedModel.
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        # Run with batch size 1, a new engine is created and cached.
        self._TestRun(sess, 1)
        # Run with batch size 2, a new engine is created and cached.
        self._TestRun(sess, 2)
        # Run with batch size 3, since the number of cached engines has
        # reached the max, it should evict an old engine and create a new one.
        self._TestRun(sess, 3)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_StaticOp(self):
    if not is_tensorrt_enabled():
      return

    output_saved_model_dir = self.mkdtemp()
    output_graph_def = self._ConvertGraphV1(
        output_saved_model_dir=output_saved_model_dir, maximum_cached_engines=1)

    # Test the output GraphDef.
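    # With the default is_dynamic_op=False, a single static engine was built
    # at conversion time and embedded in the GraphDef.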
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        # Run with batch size 1, the default engine embedded in the graphdef
        # will be used.
        self._TestRun(sess, 1)
        # Run with batch size 2, which exceeds the max_batch_size; it should
        # fall back to the native TF function.
        self._TestRun(sess, 2)

    # Test the output SavedModel.
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        # Run with batch size 1, the default engine embedded in the graphdef
        # will be used.
        self._TestRun(sess, 1)
        # Run with batch size 2, which exceeds the max_batch_size; it should
        # fall back to the native TF function.
        self._TestRun(sess, 2)

  @test_util.run_v2_only
  def testTrtGraphConverter_AllowEngineNativeSegmentExecution(self):
    if not is_tensorrt_enabled():
      return

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    def _InputFn():
      yield np_input1, np_input2

    # Run TRT conversion and request an unreasonably large workspace.
    converter = self._CreateConverterV2(
        input_saved_model_dir, max_workspace_size_bytes=10 << 40)
    converter.convert()

    os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
    with self.assertRaisesRegex(
        errors.AbortedError,
        r"User disallowed engine native segment execution"):
      converter.build(input_fn=_InputFn)

    os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
    converter.build(input_fn=_InputFn)

  @test_util.run_v2_only
  def testBackwardCompatibility(self):
    """Load and execute a model that was saved in TF2.0."""
    if not is_tensorrt_enabled():
      return

    model_dir = test.test_src_dir_path(
        "python/compiler/tensorrt/test/testdata/tftrt_2.0_saved_model")
    saved_model_loaded = load.load(model_dir, tags=[tag_constants.SERVING])
    graph_func = saved_model_loaded.signatures[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    output = graph_func(input1=np_input1, input2=np_input2)["output_0"]

    self.assertEqual(output.shape, (4, 1, 1))
    self.assertAllClose(
        np.asarray([5.0, 5.0, 5.0, 5.0]).reshape([4, 1, 1]), output)


if __name__ == "__main__":
  test.main()