# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to test TF-TensorRT integration."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.training.tracking import tracking


class TrtConvertTest(test_util.TensorFlowTestCase):
  """Class to test Tensorflow-TensorRT integration python API."""

  # Use a small max_workspace_size for tests so they don't consume too much GPU
  # memory.
  _TRT_MAX_WORKSPACE_SIZE_BYTES = 2 << 20

  def _SkipTestIfTensorRTDisabled(self):
    """Skips the current test when is_tensorrt_enabled() returns False.

    Skipping (instead of returning early) keeps a test that never exercised
    any TF-TRT code from being reported as passed.
    """
    if not is_tensorrt_enabled():
      self.skipTest("TensorRT is not enabled.")

  def testGetTensorrtRewriterConfig(self):
    """Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
    self._SkipTestIfTensorRTDisabled()
    rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
        rewriter_config_template=None,
        max_batch_size=128,
        max_workspace_size_bytes=1234,
        precision_mode="INT8",
        minimum_segment_size=10,
        is_dynamic_op=True,
        maximum_cached_engines=2,
        cached_engine_batches=[1, 128])
    self.assertEqual(["constfold", "layout", "constfold"],
                     rewriter_cfg.optimizers)
    self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE,
                     rewriter_cfg.meta_optimizer_iterations)

    # Exactly one custom optimizer named "TensorRTOptimizer" must be present.
    trt_optimizer = None
    for optimizer in rewriter_cfg.custom_optimizers:
      if optimizer.name == "TensorRTOptimizer":
        self.assertIsNone(trt_optimizer)
        trt_optimizer = optimizer
    self.assertIsNotNone(trt_optimizer)

    # Every argument passed above must show up in the optimizer's parameters.
    for key in [
        "minimum_segment_size", "max_batch_size", "is_dynamic_op",
        "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines",
        "cached_engine_batches"
    ]:
      self.assertIn(key, trt_optimizer.parameter_map)
    self.assertEqual(10, trt_optimizer.parameter_map["minimum_segment_size"].i)
    self.assertEqual(128, trt_optimizer.parameter_map["max_batch_size"].i)
    self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b)
    self.assertEqual(1234,
                     trt_optimizer.parameter_map["max_workspace_size_bytes"].i)
    self.assertEqual(
        trt_convert._to_bytes("INT8"),
        trt_optimizer.parameter_map["precision_mode"].s)
    self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i)
    self.assertEqual(
        [1, 128], trt_optimizer.parameter_map["cached_engine_batches"].list.i)

  def _GetConfigProto(self):
    """Get ConfigProto for session creation."""
    config = config_pb2.ConfigProto(
        gpu_options=config_pb2.GPUOptions(allow_growth=True))
    return config

  def _GetGraph(self):
    """Get the graph for testing.

    Returns:
      A tuple (graph, variable "v1", placeholder "input", identity "output").
    """
    # The graph computes (input+1)^2, it looks like:
    #
    # input (Placeholder)  v1 (Variable)
    #       |   \           /
    #        \   +
    #         \ / \
    #          *   |
    #           \ /
    #            +
    #            |
    #         output (Identity)
    g = ops.Graph()
    with g.as_default():
      with g.device("/GPU:0"):
        inp = array_ops.placeholder(
            dtype=dtypes.float32, shape=[None, 1, 1], name="input")
        var = variables.VariableV1([[[1.0]]],
                                   dtype=dtypes.float32,
                                   name="v1",
                                   use_resource=False)
        add = inp + var.value()
        mul = inp * add
        add = mul + add
        out = array_ops.identity(add, name="output")
    return g, var, inp, out

  def _GetGraphDef(self):
    """Get the graph def for testing.

    Builds the test graph, initializes its variable, converts variables to
    constants, and checks the resulting node names/ops before returning.

    Returns:
      The GraphDef with variables converted to constants.
    """
    g, var, _, _ = self._GetGraph()
    with self.session(graph=g, config=self._GetConfigProto()) as sess:
      sess.run(var.initializer)
      graph_def = graph_util.convert_variables_to_constants(
          sess, g.as_graph_def(add_shapes=True), ["output"])
    node_name_to_op = {node.name: node.op for node in graph_def.node}
    self.assertEqual({
        "v1": "Const",
        "v1/read": "Identity",
        "input": "Placeholder",
        "add": "Add",
        "mul": "Mul",
        "add_1": "Add",
        "output": "Identity"
    }, node_name_to_op)
    return graph_def

  def _WriteInputSavedModel(self, input_saved_model_dir):
    """Write the saved model as an input for testing.

    Args:
      input_saved_model_dir: directory the SavedModel is written to. The model
        is tagged SERVING and exposes the signature "mypredict" mapping
        "myinput" to "myoutput".
    """
    g, var, inp, out = self._GetGraph()
    signature_def = signature_def_utils.build_signature_def(
        inputs={"myinput": utils.build_tensor_info(inp)},
        outputs={"myoutput": utils.build_tensor_info(out)},
        method_name=signature_constants.PREDICT_METHOD_NAME)
    saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
    with self.session(graph=g, config=self._GetConfigProto()) as sess:
      sess.run(var.initializer)
      saved_model_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING],
          signature_def_map={"mypredict": signature_def})
    saved_model_builder.save()

  def _ConvertGraph(self,
                    input_saved_model_dir=None,
                    output_saved_model_dir=None,
                    need_calibration=False,
                    max_batch_size=1,
                    minimum_segment_size=3,
                    is_dynamic_op=False,
                    maximum_cached_engines=1,
                    use_function_backup=False):
    """Helper method to convert a GraphDef or SavedModel using TF-TRT.

    Args:
      input_saved_model_dir: SavedModel directory to convert. When falsy, the
        GraphDef from _GetGraphDef() is converted instead.
      output_saved_model_dir: if not None, the converted model is saved here.
      need_calibration: if True, convert with INT8 precision and run
        calibration; otherwise convert with FP32 precision.
      max_batch_size: forwarded to TrtGraphConverter.
      minimum_segment_size: forwarded to TrtGraphConverter.
      is_dynamic_op: forwarded to TrtGraphConverter.
      maximum_cached_engines: forwarded to TrtGraphConverter.
      use_function_backup: forwarded to TrtGraphConverter.

    Returns:
      The converted (and, if requested, calibrated) GraphDef.
    """
    converter = trt_convert.TrtGraphConverter(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_signature_key="mypredict",
        input_graph_def=None if input_saved_model_dir else self._GetGraphDef(),
        nodes_blacklist=["output"],
        session_config=self._GetConfigProto(),
        max_batch_size=max_batch_size,
        max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
        precision_mode=(trt_convert.TrtPrecisionMode.INT8 if need_calibration
                        else trt_convert.TrtPrecisionMode.FP32),
        minimum_segment_size=minimum_segment_size,
        is_dynamic_op=is_dynamic_op,
        maximum_cached_engines=maximum_cached_engines,
        use_function_backup=use_function_backup)
    conversion_result = converter.convert()

    # When executing eagerly the GraphDef is read from the result's `.graph`;
    # otherwise the result is used as the GraphDef directly.
    if context.executing_eagerly():
      output_graph_def = conversion_result.graph.as_graph_def()
    else:
      output_graph_def = conversion_result

    if need_calibration:

      class CalibrationData(object):
        """Feeds 1, 2, 3, ... to "input:0" on successive calls to next()."""

        def __init__(self):
          self._data = 0

        def next(self):
          self._data += 1
          return {"input:0": [[[self._data]]]}

      output_graph_def = converter.calibrate(
          fetch_names=["output:0"],
          num_runs=10,
          feed_dict_fn=CalibrationData().next)

    if output_saved_model_dir is not None:
      converter.save(output_saved_model_dir=output_saved_model_dir)
    return output_graph_def

  def _TestTrtGraphConverter(self,
                             input_saved_model_dir=None,
                             output_saved_model_dir=None,
                             need_calibration=False,
                             is_dynamic_op=False):
    """General method to test trt_convert.TrtGraphConverter().

    Converts the input, then verifies the returned GraphDef and, when
    `output_saved_model_dir` is given, the GraphDef stored in the output
    SavedModel. With `need_calibration`, also checks that every TRTEngineOp
    carries calibration data and that the graph computes (input+1)^2.

    Args:
      input_saved_model_dir: forwarded to _ConvertGraph().
      output_saved_model_dir: forwarded to _ConvertGraph(); when truthy the
        saved model is loaded back and verified as well.
      need_calibration: forwarded to _ConvertGraph(); it is also passed as
        `use_function_backup`.
      is_dynamic_op: forwarded to _ConvertGraph().
    """
    output_graph_def = self._ConvertGraph(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        need_calibration=need_calibration,
        is_dynamic_op=is_dynamic_op,
        use_function_backup=need_calibration)
    graph_defs_to_verify = [output_graph_def]

    if output_saved_model_dir:
      if context.executing_eagerly():
        root = load.load(output_saved_model_dir)
        saved_model_graph_def = root.signatures[
            signature_constants
            .DEFAULT_SERVING_SIGNATURE_DEF_KEY].graph.as_graph_def()
      else:
        saved_model_graph_def = saved_model_utils.get_meta_graph_def(
            output_saved_model_dir, tag_constants.SERVING).graph_def
      self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
      graph_defs_to_verify.append(saved_model_graph_def)

    for graph_def in graph_defs_to_verify:
      node_name_to_op = {node.name: node.op for node in graph_def.node}
      if context.executing_eagerly():
        # In V2 the actual graph could be inside a function.
        for func in graph_def.library.function:
          node_name_to_op.update({node.name: node.op for node in func.node_def})
        self.assertIn("TRTEngineOp_0", node_name_to_op)
        self.assertEqual("TRTEngineOp", node_name_to_op["TRTEngineOp_0"])
      else:
        self.assertEqual({
            "input": "Placeholder",
            "TRTEngineOp_0": "TRTEngineOp",
            "output": "Identity"
        }, node_name_to_op)

      if need_calibration:
        trt_engine_nodes = [
            node for node in graph_def.node if node.op == "TRTEngineOp"
        ]
        self.assertNotEmpty(trt_engine_nodes)
        for node in trt_engine_nodes:
          self.assertNotEmpty(node.attr["calibration_data"].s)
        # Run the calibrated graph.
        # TODO(laigd): consider having some input where the answer is different.
        with ops.Graph().as_default():
          importer.import_graph_def(graph_def, name="")
          with self.session(config=self._GetConfigProto()) as sess:
            for test_data in range(10):
              self.assertEqual((test_data + 1.0)**2,
                               sess.run(
                                   "output:0",
                                   feed_dict={"input:0": [[[test_data]]]}))

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_BasicConversion(self):
    """Test case for trt_convert.TrtGraphConverter()."""
    self._SkipTestIfTensorRTDisabled()

    tmp_dir = self.get_temp_dir()
    input_saved_model_dir = os.path.join(tmp_dir, "in_dir1")
    self._WriteInputSavedModel(input_saved_model_dir)

    for need_calibration in [False, True]:
      # Use GraphDef as input.
      # NOTE(review): need_calibration is not forwarded here, so the GraphDef
      # case always runs without calibration -- confirm whether intentional.
      self._TestTrtGraphConverter()

      # Use SavedModel as input.
      output_saved_model_dir = os.path.join(
          tmp_dir, "out_dir1%s" % ("_int8" if need_calibration else ""))
      self._TestTrtGraphConverter(
          input_saved_model_dir=input_saved_model_dir,
          output_saved_model_dir=output_saved_model_dir,
          need_calibration=need_calibration)

  @test_util.run_v2_only
  def testTrtGraphConverter_BasicConversion_v2(self):
    """Test case for trt_convert.TrtGraphConverter()."""
    self._SkipTestIfTensorRTDisabled()

    # TODO(laigd): we need to use ops like conv2d so Grappler can infer the
    # shapes (at least rank) of the tensors, so we're able to build a TRT
    # engine in dynamic mode. Currently shape information is not propagated
    # from ConcreteFunction to GraphDef, need to investigate and fix it.
    class SimpleModel(tracking.AutoTrackable):
      """Trackable model whose `run` applies a conv2d followed by identity."""

      def __init__(self):
        self.v = None

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 24, 24, 2], dtype=dtypes.float32)
      ])
      def run(self, inp):
        # The filter variable is created lazily on the first call.
        if self.v is None:
          self.v = variables.Variable([[[[1., 0.5, 4., 6., 0.5, 1.],
                                         [1., 0.5, 1., 1., 0.5, 1.]]]])
        conv = gen_nn_ops.conv2d(
            input=inp, filter=self.v, strides=[1, 2, 2, 1], padding="SAME")
        identity = array_ops.identity(conv)
        return identity

    tmp_dir = self.get_temp_dir()
    input_saved_model_dir = os.path.join(tmp_dir, "in_dir1_v2")
    root = SimpleModel()
    save.save(root, input_saved_model_dir)

    # Convert the SavedModel and verify the result.
    output_saved_model_dir = os.path.join(tmp_dir, "out_dir1_v2")
    self._TestTrtGraphConverter(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        is_dynamic_op=True)

  def _TestRun(self,
               sess,
               batch_size,
               use_function_backup=False,
               expect_engine_is_run=True):
    """Runs "output:0" in `sess` and checks the result or the expected error.

    Feeds `batch_size` copies of [[[1.0]]] to "input:0" and expects each output
    element to be [[[4.0]]].

    Args:
      sess: session holding the converted graph.
      batch_size: number of input elements to feed.
      use_function_backup: whether the graph was converted with the function
        backup enabled.
      expect_engine_is_run: whether the TRT engine is expected to run for this
        batch size.
    """
    try:
      result = sess.run(
          "output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
    except errors.OpError as e:
      # This should happen only when fallback path is disabled and TRT engine
      # fails to run.
      self.assertTrue(not use_function_backup and not expect_engine_is_run)
      self.assertIn("Fallback path is disabled, for TRTEngineOp_0", str(e))
    else:
      self.assertAllEqual([[[4.0]]] * batch_size, result)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_MinimumSegmentSize(self):
    """Checks that no TRTEngineOp is created with minimum_segment_size=5."""
    self._SkipTestIfTensorRTDisabled()
    output_graph_def = self._ConvertGraph(minimum_segment_size=5)
    node_name_to_op = {node.name: node.op for node in output_graph_def.node}
    self.assertEqual({
        "v1/read": "Const",
        "input": "Placeholder",
        "add": "Add",
        "mul": "Mul",
        "add_1": "Add",
        "output": "Identity"
    }, node_name_to_op)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_DynamicOp(self):
    """Runs a dynamic-op conversion with several batch sizes, no fallback."""
    self._SkipTestIfTensorRTDisabled()

    tmp_dir = self.get_temp_dir()
    input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
    output_saved_model_dir = os.path.join(tmp_dir, "out_dir2")
    self._WriteInputSavedModel(input_saved_model_dir)
    output_graph_def = self._ConvertGraph(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        is_dynamic_op=True,
        maximum_cached_engines=2,
        use_function_backup=False)  # Disallow fallback.

    # Test the output GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        # Run with batch size 1, a new engine is created and cached.
        self._TestRun(sess, 1)
        # Run with batch size 2, a new engine is created and cached.
        self._TestRun(sess, 2)
        # Run with batch size 3, since the number of cached engines has reached
        # the max, it should evict an old engine and create a new one.
        self._TestRun(sess, 3)

    # Test the output SavedModel
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        # Run with batch size 1, a new engine is created and cached.
        self._TestRun(sess, 1)
        # Run with batch size 2, a new engine is created and cached.
        self._TestRun(sess, 2)
        # Run with batch size 3, since the number of cached engines has reached
        # the max, it should evict an old engine and create a new one.
        self._TestRun(sess, 3)

  def _TestStaticOp(self, use_function_backup):
    """Runs a static-op conversion within and beyond max_batch_size.

    Args:
      use_function_backup: forwarded to _ConvertGraph() and _TestRun().
    """
    self._SkipTestIfTensorRTDisabled()

    tmp_dir = self.get_temp_dir()
    input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
    output_saved_model_dir = os.path.join(tmp_dir, "out_dir3")
    self._WriteInputSavedModel(input_saved_model_dir)
    output_graph_def = self._ConvertGraph(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        maximum_cached_engines=2,  # This is noop, added just for testing.
        use_function_backup=use_function_backup)

    # Test the output GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        # Run with batch size 1, the default engine embedded in the graphdef
        # will be used.
        self._TestRun(
            sess,
            1,
            use_function_backup=use_function_backup,
            expect_engine_is_run=True)
        # Run with batch size 2, which exceeds the max_batch_size, it should
        # try to fall back to TF function.
        self._TestRun(
            sess,
            2,
            use_function_backup=use_function_backup,
            expect_engine_is_run=False)

    # Test the output SavedModel
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        # Run with batch size 1, the default engine embedded in the graphdef
        # will be used.
        self._TestRun(
            sess,
            1,
            use_function_backup=use_function_backup,
            expect_engine_is_run=True)
        # Run with batch size 2, which exceeds the max_batch_size, it should
        # try to fall back to TF function.
        self._TestRun(
            sess,
            2,
            use_function_backup=use_function_backup,
            expect_engine_is_run=False)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_StaticOp_NoFallback(self):
    """Static-op conversion with the function backup disabled."""
    self._TestStaticOp(use_function_backup=False)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_StaticOp_WithFallback(self):
    """Static-op conversion with the function backup enabled."""
    self._TestStaticOp(use_function_backup=True)


if __name__ == "__main__":
  test.main()