# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to test TF-TensorRT integration."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
24from tensorflow.core.framework import graph_pb2
25from tensorflow.core.protobuf import config_pb2
26from tensorflow.core.protobuf import rewriter_config_pb2
27from tensorflow.python.compiler.tensorrt import trt_convert
28from tensorflow.python.eager import context
29from tensorflow.python.eager import def_function
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import errors
32from tensorflow.python.framework import graph_util
33from tensorflow.python.framework import importer
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import tensor_spec
36from tensorflow.python.framework import test_util
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import gen_nn_ops
39from tensorflow.python.ops import variables
40from tensorflow.python.platform import test
41from tensorflow.python.saved_model import builder
42from tensorflow.python.saved_model import loader
43from tensorflow.python.saved_model import signature_constants
44from tensorflow.python.saved_model import signature_def_utils
45from tensorflow.python.saved_model import tag_constants
46from tensorflow.python.saved_model import utils
47from tensorflow.python.tools import saved_model_utils
48from tensorflow.python.saved_model import load
49from tensorflow.python.saved_model import save
50from tensorflow.python.training.tracking import tracking
51
52
class TrtConvertTest(test_util.TensorFlowTestCase):
  """Class to test Tensorflow-TensorRT integration python API."""

  # Use a small max_workspace_size for tests so they don't consume too much GPU
  # memory.
  _TRT_MAX_WORKSPACE_SIZE_BYTES = 2 << 20

  def testGetTensorrtRewriterConfig(self):
    """Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
    if not is_tensorrt_enabled():
      return
    rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
        rewriter_config_template=None,
        max_batch_size=128,
        max_workspace_size_bytes=1234,
        precision_mode="INT8",
        minimum_segment_size=10,
        is_dynamic_op=True,
        maximum_cached_engines=2,
        cached_engine_batches=[1, 128])
    self.assertEqual(["constfold", "layout", "constfold"],
                     rewriter_cfg.optimizers)
    self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE,
                     rewriter_cfg.meta_optimizer_iterations)
    # Exactly one TensorRTOptimizer must be present among the custom
    # optimizers; collect it and verify all its parameters round-trip.
    trt_optimizer = None
    for optimizer in rewriter_cfg.custom_optimizers:
      if optimizer.name == "TensorRTOptimizer":
        self.assertIsNone(trt_optimizer)
        trt_optimizer = optimizer
    self.assertIsNotNone(trt_optimizer)
    for key in [
        "minimum_segment_size", "max_batch_size", "is_dynamic_op",
        "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines",
        "cached_engine_batches"
    ]:
      self.assertIn(key, trt_optimizer.parameter_map)
    self.assertEqual(10, trt_optimizer.parameter_map["minimum_segment_size"].i)
    self.assertEqual(128, trt_optimizer.parameter_map["max_batch_size"].i)
    self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b)
    self.assertEqual(1234,
                     trt_optimizer.parameter_map["max_workspace_size_bytes"].i)
    self.assertEqual(
        trt_convert._to_bytes("INT8"),
        trt_optimizer.parameter_map["precision_mode"].s)
    self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i)
    self.assertEqual(
        [1, 128], trt_optimizer.parameter_map["cached_engine_batches"].list.i)

  def _GetConfigProto(self):
    """Get ConfigProto for session creation."""
    config = config_pb2.ConfigProto(
        gpu_options=config_pb2.GPUOptions(allow_growth=True))
    return config

  def _GetGraph(self):
    """Get the graph for testing.

    Returns:
      A tuple (graph, variable, input_tensor, output_tensor).
    """
    # The graph computes (input+1)^2, it looks like:
    #
    # input (Placeholder)  v1 (Variable)
    #               |   \ /
    #                \   +
    #                 \ / \
    #                  *   |
    #                   \ /
    #                    +
    #                    |
    #                 output (Identity)
    g = ops.Graph()
    with g.as_default():
      with g.device("/GPU:0"):
        inp = array_ops.placeholder(
            dtype=dtypes.float32, shape=[None, 1, 1], name="input")
        # use_resource=False so convert_variables_to_constants can freeze it.
        var = variables.VariableV1([[[1.0]]],
                                   dtype=dtypes.float32,
                                   name="v1",
                                   use_resource=False)
        add = inp + var.value()
        mul = inp * add
        add = mul + add
        out = array_ops.identity(add, name="output")
    return g, var, inp, out

  def _GetGraphDef(self):
    """Get the frozen GraphDef (variables folded to constants) for testing."""
    g, var, _, _ = self._GetGraph()
    with self.session(graph=g, config=self._GetConfigProto()) as sess:
      sess.run(var.initializer)
      graph_def = graph_util.convert_variables_to_constants(
          sess, g.as_graph_def(add_shapes=True), ["output"])
    node_name_to_op = {node.name: node.op for node in graph_def.node}
    self.assertEqual({
        "v1": "Const",
        "v1/read": "Identity",
        "input": "Placeholder",
        "add": "Add",
        "mul": "Mul",
        "add_1": "Add",
        "output": "Identity"
    }, node_name_to_op)
    return graph_def

  def _WriteInputSavedModel(self, input_saved_model_dir):
    """Write the saved model as an input for testing.

    Args:
      input_saved_model_dir: directory to write the SavedModel to.
    """
    g, var, inp, out = self._GetGraph()
    signature_def = signature_def_utils.build_signature_def(
        inputs={"myinput": utils.build_tensor_info(inp)},
        outputs={"myoutput": utils.build_tensor_info(out)},
        method_name=signature_constants.PREDICT_METHOD_NAME)
    saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
    with self.session(graph=g, config=self._GetConfigProto()) as sess:
      sess.run(var.initializer)
      saved_model_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING],
          signature_def_map={"mypredict": signature_def})
    saved_model_builder.save()

  def _ConvertGraph(self,
                    input_saved_model_dir=None,
                    output_saved_model_dir=None,
                    need_calibration=False,
                    max_batch_size=1,
                    minimum_segment_size=3,
                    is_dynamic_op=False,
                    maximum_cached_engines=1,
                    use_function_backup=False):
    """Helper method to convert a GraphDef or SavedModel using TF-TRT.

    Args:
      input_saved_model_dir: if set, convert that SavedModel; otherwise a
        frozen GraphDef from _GetGraphDef() is used.
      output_saved_model_dir: if set, also save the converted model here.
      need_calibration: whether to run INT8 calibration (graph mode only).
      max_batch_size: max batch size for the TRT engines.
      minimum_segment_size: minimum number of nodes per TRT segment.
      is_dynamic_op: whether to build engines at runtime.
      maximum_cached_engines: max number of cached engines per TRTEngineOp.
      use_function_backup: whether to keep a TF function as fallback.

    Returns:
      The converted GraphDef.
    """
    converter = trt_convert.TrtGraphConverter(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_signature_key="mypredict",
        input_graph_def=None if input_saved_model_dir else self._GetGraphDef(),
        nodes_blacklist=["output"],
        session_config=self._GetConfigProto(),
        max_batch_size=max_batch_size,
        max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
        precision_mode=(trt_convert.TrtPrecisionMode.INT8 if need_calibration
                        else trt_convert.TrtPrecisionMode.FP32),
        minimum_segment_size=minimum_segment_size,
        is_dynamic_op=is_dynamic_op,
        maximum_cached_engines=maximum_cached_engines,
        use_function_backup=use_function_backup)
    conversion_result = converter.convert()

    if context.executing_eagerly():
      # In V2 convert() returns a ConcreteFunction.
      output_graph_def = conversion_result.graph.as_graph_def()
    else:
      output_graph_def = conversion_result

      if need_calibration:

        class CalibrationData(object):
          """Feeds an incrementing scalar as calibration input."""

          def __init__(self):
            self._data = 0

          def next(self):
            self._data += 1
            return {"input:0": [[[self._data]]]}

        output_graph_def = converter.calibrate(
            fetch_names=["output:0"],
            num_runs=10,
            feed_dict_fn=CalibrationData().next)

    if output_saved_model_dir is not None:
      converter.save(output_saved_model_dir=output_saved_model_dir)
    return output_graph_def

  def _TestTrtGraphConverter(self,
                             input_saved_model_dir=None,
                             output_saved_model_dir=None,
                             need_calibration=False,
                             is_dynamic_op=False):
    """General method to test trt_convert.TrtGraphConverter()."""
    output_graph_def = self._ConvertGraph(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        need_calibration=need_calibration,
        is_dynamic_op=is_dynamic_op,
        # Calibration requires the TF function fallback to execute the graph.
        use_function_backup=need_calibration)
    graph_defs_to_verify = [output_graph_def]

    if output_saved_model_dir:
      # Additionally verify the graph stored in the output SavedModel.
      if context.executing_eagerly():
        root = load.load(output_saved_model_dir)
        saved_model_graph_def = root.signatures[
            signature_constants
            .DEFAULT_SERVING_SIGNATURE_DEF_KEY].graph.as_graph_def()
      else:
        saved_model_graph_def = saved_model_utils.get_meta_graph_def(
            output_saved_model_dir, tag_constants.SERVING).graph_def
      self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
      graph_defs_to_verify.append(saved_model_graph_def)

    for graph_def in graph_defs_to_verify:
      node_name_to_op = {node.name: node.op for node in graph_def.node}
      if context.executing_eagerly():
        # In V2 the actual graph could be inside a function.
        for func in graph_def.library.function:
          node_name_to_op.update({node.name: node.op for node in func.node_def})
        self.assertIn("TRTEngineOp_0", node_name_to_op)
        self.assertEqual("TRTEngineOp", node_name_to_op["TRTEngineOp_0"])
      else:
        self.assertEqual({
            "input": "Placeholder",
            "TRTEngineOp_0": "TRTEngineOp",
            "output": "Identity"
        }, node_name_to_op)

      if need_calibration:
        trt_engine_nodes = [
            node for node in graph_def.node if node.op == "TRTEngineOp"
        ]
        self.assertNotEmpty(trt_engine_nodes)
        for node in trt_engine_nodes:
          # Calibration must have attached the calibration table.
          self.assertNotEmpty(node.attr["calibration_data"].s)
        # Run the calibrated graph.
        # TODO(laigd): consider having some input where the answer is different.
        with ops.Graph().as_default():
          importer.import_graph_def(graph_def, name="")
          with self.session(config=self._GetConfigProto()) as sess:
            for test_data in range(10):
              # The graph computes (input+1)^2.
              self.assertEqual((test_data + 1.0)**2,
                               sess.run(
                                   "output:0",
                                   feed_dict={"input:0": [[[test_data]]]}))

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_BasicConversion(self):
    """Test case for trt_convert.TrtGraphConverter()."""
    if not is_tensorrt_enabled():
      return

    tmp_dir = self.get_temp_dir()
    input_saved_model_dir = os.path.join(tmp_dir, "in_dir1")
    self._WriteInputSavedModel(input_saved_model_dir)

    # Use GraphDef as input. This conversion doesn't depend on
    # need_calibration, so run it once here instead of redundantly inside the
    # loop below.
    self._TestTrtGraphConverter()

    for need_calibration in [False, True]:
      # Use SavedModel as input.
      output_saved_model_dir = os.path.join(
          tmp_dir, "out_dir1%s" % ("_int8" if need_calibration else ""))
      self._TestTrtGraphConverter(
          input_saved_model_dir=input_saved_model_dir,
          output_saved_model_dir=output_saved_model_dir,
          need_calibration=need_calibration)

  @test_util.run_v2_only
  def testTrtGraphConverter_BasicConversion_v2(self):
    """Test case for trt_convert.TrtGraphConverter()."""
    if not is_tensorrt_enabled():
      return

    # TODO(laigd): we need to use ops like conv2d so Grappler can infer the
    # shapes (at least rank) of the tensors, so we're able to build an TRT
    # engine in dynamic mode. Currently shape information is not propagate from
    # ConcreteFunction to GraphDef, need to investigate and fix it.
    class SimpleModel(tracking.AutoTrackable):
      """A single conv2d wrapped in a tf.function, used as SavedModel input."""

      def __init__(self):
        self.v = None

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 24, 24, 2], dtype=dtypes.float32)
      ])
      def run(self, inp):
        # Create the variable lazily on first trace.
        if self.v is None:
          self.v = variables.Variable([[[[1., 0.5, 4., 6., 0.5, 1.],
                                         [1., 0.5, 1., 1., 0.5, 1.]]]])
        conv = gen_nn_ops.conv2d(
            input=inp, filter=self.v, strides=[1, 2, 2, 1], padding="SAME")
        identity = array_ops.identity(conv)
        return identity

    tmp_dir = self.get_temp_dir()
    input_saved_model_dir = os.path.join(tmp_dir, "in_dir1_v2")
    root = SimpleModel()
    save.save(root, input_saved_model_dir)

    # Convert the SavedModel and verify the result.
    output_saved_model_dir = os.path.join(tmp_dir, "out_dir1_v2")
    self._TestTrtGraphConverter(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        is_dynamic_op=True)

  def _TestRun(self,
               sess,
               batch_size,
               use_function_backup=False,
               expect_engine_is_run=True):
    """Run the converted graph and check the result (or the expected error).

    Args:
      sess: the session to run in, with the converted graph loaded.
      batch_size: batch size of the feed.
      use_function_backup: whether the conversion kept a TF fallback function.
      expect_engine_is_run: whether the TRT engine is expected to handle this
        batch size.
    """
    try:
      result = sess.run(
          "output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
      # The graph computes (input+1)^2 == 4.0 for input 1.0.
      self.assertAllEqual([[[4.0]]] * batch_size, result)
    except errors.OpError as e:
      # This should happen only when fallback path is disabled and TRT engine
      # fails to run.
      self.assertTrue(not use_function_backup and not expect_engine_is_run)
      self.assertIn("Fallback path is disabled, for TRTEngineOp_0", str(e))

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_MinimumSegmentSize(self):
    """Segments smaller than minimum_segment_size must not be converted."""
    if not is_tensorrt_enabled():
      return
    output_graph_def = self._ConvertGraph(minimum_segment_size=5)
    node_name_to_op = {node.name: node.op for node in output_graph_def.node}
    self.assertEqual({
        "v1/read": "Const",
        "input": "Placeholder",
        "add": "Add",
        "mul": "Mul",
        "add_1": "Add",
        "output": "Identity"
    }, node_name_to_op)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_DynamicOp(self):
    """Dynamic engines are built per batch size and cached with eviction."""
    if not is_tensorrt_enabled():
      return

    tmp_dir = self.get_temp_dir()
    input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
    output_saved_model_dir = os.path.join(tmp_dir, "out_dir2")
    self._WriteInputSavedModel(input_saved_model_dir)
    output_graph_def = self._ConvertGraph(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        is_dynamic_op=True,
        maximum_cached_engines=2,
        use_function_backup=False)  # Disallow fallback.

    # Test the output GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        # Run with batch size 1, a new engine is created and cached.
        self._TestRun(sess, 1)
        # Run with batch size 2, a new engine is created and cached.
        self._TestRun(sess, 2)
        # Run with batch size 3, since the number of cached engines has reached
        # the max, it should evict an old engine and create a new one.
        self._TestRun(sess, 3)

    # Test the output SavedModel
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        # Run with batch size 1, a new engine is created and cached.
        self._TestRun(sess, 1)
        # Run with batch size 2, a new engine is created and cached.
        self._TestRun(sess, 2)
        # Run with batch size 3, since the number of cached engines has reached
        # the max, it should evict an old engine and create a new one.
        self._TestRun(sess, 3)

  def _TestStaticOp(self, use_function_backup):
    """Test a static (pre-built) engine, with or without the TF fallback."""
    if not is_tensorrt_enabled():
      return

    tmp_dir = self.get_temp_dir()
    input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
    output_saved_model_dir = os.path.join(tmp_dir, "out_dir3")
    self._WriteInputSavedModel(input_saved_model_dir)
    output_graph_def = self._ConvertGraph(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        maximum_cached_engines=2,  # This is noop, added just for testing.
        use_function_backup=use_function_backup)

    # Test the output GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        # Run with batch size 1, the default engine embedded in the graphdef
        # will be used.
        self._TestRun(
            sess,
            1,
            use_function_backup=use_function_backup,
            expect_engine_is_run=True)
        # Run with batch size 2, which exceed the max_batch_size, it should try
        # to fall back to TF function.
        self._TestRun(
            sess,
            2,
            use_function_backup=use_function_backup,
            expect_engine_is_run=False)

    # Test the output SavedModel
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        # Run with batch size 1, the default engine embedded in the graphdef
        # will be used.
        self._TestRun(
            sess,
            1,
            use_function_backup=use_function_backup,
            expect_engine_is_run=True)
        # Run with batch size 2, which exceed the max_batch_size, it should try
        # to fall back to TF function.
        self._TestRun(
            sess,
            2,
            use_function_backup=use_function_backup,
            expect_engine_is_run=False)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_StaticOp_NoFallback(self):
    self._TestStaticOp(use_function_backup=False)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_StaticOp_WithFallback(self):
    self._TestStaticOp(use_function_backup=True)
469
470
if __name__ == "__main__":
  # Run all test cases defined in this module.
  test.main()
473