# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to test TF-TensorRT integration."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gc
import os
import re
import tempfile

from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
from tensorflow.compiler.tf2tensorrt.utils.trt_engine_instance_pb2 import TRTEngineInstance  # pylint: disable=g-importing-member
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util.lazy_loader import LazyLoader

_SAVED_MODEL_SIGNATURE_KEY = "mypredict"

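# LazyLoader defers the actual import of gen_trt_ops until its first use,
# which (as far as the test setup suggests) lets this module be imported even
# in builds where the TensorRT ops are not linked in.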
gen_trt_ops = LazyLoader(
    "gen_trt_ops", globals(),
    "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")


class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Class to test the TensorFlow-TensorRT integration Python API."""

  # Use a small max_workspace_size for tests so they don't consume too much GPU
  # memory.
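  # (2 << 20 bytes is 2 MiB.)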
  _TRT_MAX_WORKSPACE_SIZE_BYTES = 2 << 20

  def mkdtemp(self):
    return tempfile.mkdtemp(dir=self.get_temp_dir())

  def testTRTEngineInstanceAvailable(self):
    # Test that the TRTEngineInstance protobuf is accessible.
    assert hasattr(TRTEngineInstance(), "serialized_engine")

  def _GetConfigProto(self, rewriter_config=None):
    """Get ConfigProto for session creation."""
    config = config_pb2.ConfigProto(
        gpu_options=config_pb2.GPUOptions(allow_growth=True))
    if rewriter_config:
      config.graph_options.rewrite_options.CopyFrom(rewriter_config)
    return config

  @classmethod
  def _GetGraph(cls, inp1, inp2, var):
    """Get the graph for testing."""
    # The graph computes: inp1^2 + inp1*var + inp1 + inp2 + var
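    # With inp1 == inp2 == 1 and var == 1 this evaluates to
    # 1 + 1 + 1 + 1 + 1 = 5, which is the value _TestRun asserts below.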
    add = inp1 + var
    mul = inp1 * add
    add = mul + add
    add = add + inp2
    out = array_ops.identity(add, name="output")
    return out

  def _GetModelForV2(self):

    class SimpleModel(tracking.AutoTrackable):

      def __init__(self):
        self.v = None

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
          tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)
      ])
      def run(self, inp1, inp2):
        if self.v is None:
          self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
        return TrtConvertTest._GetGraph(inp1, inp2, self.v)

    return SimpleModel()

  def _GetGraphForV1(self, device):

    def _GraphFn():
      inp1 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input1")
      inp2 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input2")
      var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
      out = TrtConvertTest._GetGraph(inp1, inp2, var)
      return g, var, inp1, inp2, out

    g = ops.Graph()
    with g.as_default():
      if device:
        with g.device(device):
          return _GraphFn()
      return _GraphFn()

  def _GetGraphDefForV1(self, device):
    """Get the graph def for testing."""
    g, var, _, _, _ = self._GetGraphForV1(device)
    with self.session(graph=g, config=self._GetConfigProto()) as sess:
      sess.run(var.initializer)
      graph_def = graph_util.convert_variables_to_constants(
          sess, g.as_graph_def(add_shapes=True), ["output"])
    node_name_to_op = {node.name: node.op for node in graph_def.node}
    self.assertEqual(
        {
            "v1": "Const",
            "add/ReadVariableOp": "Identity",
            "input1": "Placeholder",
            "input2": "Placeholder",
            "add": "AddV2",
            "mul": "Mul",
            "add_1": "AddV2",
            "add_2": "AddV2",
            "output": "Identity"
        }, node_name_to_op)
    return graph_def

  def _WriteInputSavedModelForV1(self, input_saved_model_dir, device):
    """Write the saved model as an input for testing."""
    g, var, inp1, inp2, out = self._GetGraphForV1(device)
    signature_def = signature_def_utils.build_signature_def(
        inputs={
            "myinput1": utils.build_tensor_info(inp1),
            "myinput2": utils.build_tensor_info(inp2)
        },
        outputs={"myoutput": utils.build_tensor_info(out)},
        method_name=signature_constants.PREDICT_METHOD_NAME)
    saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
    with self.session(graph=g, config=self._GetConfigProto()) as sess:
      sess.run(var.initializer)
      saved_model_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING],
          signature_def_map={_SAVED_MODEL_SIGNATURE_KEY: signature_def})
    saved_model_builder.save()

  def _ConvertGraphV1(self,
                      output_saved_model_dir=None,
                      need_calibration=False,
                      max_batch_size=1,
                      minimum_segment_size=3,
                      is_dynamic_op=False,
                      maximum_cached_engines=1,
                      device=None):
    """Helper method to convert a GraphDef or SavedModel using TF-TRT."""
    input_saved_model_dir = None
    if output_saved_model_dir:
      input_saved_model_dir = self.mkdtemp()
      self._WriteInputSavedModelForV1(input_saved_model_dir, device)

    # Calibration requires dynamic_op.
    if need_calibration:
      is_dynamic_op = True

    # For dynamic_op, the converter requires max_batch_size to be None (the
    # value is unused).
    if is_dynamic_op:
      max_batch_size = None

    converter = trt_convert.TrtGraphConverter(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
        input_graph_def=None
        if input_saved_model_dir else self._GetGraphDefForV1(device),
        nodes_denylist=None if input_saved_model_dir else ["output"],
        max_batch_size=max_batch_size,
        max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
        precision_mode=(trt_convert.TrtPrecisionMode.INT8 if need_calibration
                        else trt_convert.TrtPrecisionMode.FP32),
        minimum_segment_size=minimum_segment_size,
        is_dynamic_op=is_dynamic_op,
        maximum_cached_engines=maximum_cached_engines)
    output_graph_def = converter.convert()

    if need_calibration:

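      # Each call to next() returns fresh feed values (1, 2, 3, ...), so the
      # ten calibration runs below observe a range of input magnitudes.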
      class CalibrationData(object):

        def __init__(self):
          self._data = 0

        def next(self):
          self._data += 1
          return {"input1:0": [[[self._data]]], "input2:0": [[[self._data]]]}

      output_graph_def = converter.calibrate(
          fetch_names=["output:0"],
          num_runs=10,
          feed_dict_fn=CalibrationData().next)

    if output_saved_model_dir is not None:
      converter.save(output_saved_model_dir=output_saved_model_dir)
    return output_graph_def

  # Remove the graph sequence number prefix from the name, but only if the
  # name has a prefix of the form TRTEngineOp_<n>_.
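  # For example, "TRTEngineOp_3_native_segment" becomes
  # "TRTEngineOp_native_segment", while a name with no trailing underscore
  # after the number, such as "TRTEngineOp_3", is returned unchanged.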
  def _MayRemoveGraphSequenceNumber(self, name):
    prefix = re.search(r"TRTEngineOp_\d+_", name)
    if prefix and name.startswith(prefix.group(0)):
      parts = name.split("_", maxsplit=2)
      assert len(parts) == 3
      return parts[0] + "_" + parts[2]
    return name

  # Return the unique TRTEngineOp in the given graph def.
  def _GetUniqueTRTEngineOp(self, graph_def):
    trt_engine_nodes = [
        node for node in graph_def.node if node.op == "TRTEngineOp"
    ]
    assert len(trt_engine_nodes) == 1
    return trt_engine_nodes[0]

  def _TestTrtGraphConverter(self,
                             device,
                             output_saved_model_dir=None,
                             need_calibration=False,
                             is_dynamic_op=False):
    """General method to test trt_convert.TrtGraphConverter()."""
    output_graph_def = self._ConvertGraphV1(
        output_saved_model_dir=output_saved_model_dir,
        need_calibration=need_calibration,
        is_dynamic_op=is_dynamic_op,
        device=device)
    graph_defs_to_verify = [output_graph_def]

    if output_saved_model_dir:
      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
          output_saved_model_dir, tag_constants.SERVING).graph_def
      self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
      graph_defs_to_verify.append(saved_model_graph_def)

    for graph_def in graph_defs_to_verify:
      node_name_to_op = {
          self._MayRemoveGraphSequenceNumber(node.name): node.op
          for node in graph_def.node
      }
      if device is not None and device.startswith("/CPU:"):
        self.assertEqual(
            {
                "add": "AddV2",
                "add/ReadVariableOp": "Const",
                "add_1": "AddV2",
                "add_2": "AddV2",
                "input1": "Placeholder",
                "input2": "Placeholder",
                "mul": "Mul",
                "output": "Identity"
            }, node_name_to_op)
      else:
        self.assertEqual(
            {
                "input1": "Placeholder",
                "input2": "Placeholder",
                "TRTEngineOp_0": "TRTEngineOp",
                "output": "Identity"
            }, node_name_to_op)

      if need_calibration:
        trt_engine_nodes = [
            node for node in graph_def.node if node.op == "TRTEngineOp"
        ]
        if device is not None and device.startswith("/CPU:"):
          self.assertEmpty(trt_engine_nodes)
          return

        self.assertNotEmpty(trt_engine_nodes)
        for node in trt_engine_nodes:
          self.assertTrue(len(node.attr["calibration_data"].s))
        # Run the calibrated graph.
        # TODO(laigd): consider having some input where the answer is different.
        with ops.Graph().as_default():
          importer.import_graph_def(graph_def, name="")
          with self.session(config=self._GetConfigProto()) as sess:
            for test_data in range(10):
              self.assertEqual((test_data + 1.0)**2 + test_data,
                               sess.run(
                                   "output:0",
                                   feed_dict={
                                       "input1:0": [[[test_data]]],
                                       "input2:0": [[[test_data]]]
                                   }))

  @parameterized.named_parameters([
      ("NoDeviceAssignment", None),
      ("GPU", "/GPU:0"),
      ("CPU", "/CPU:0"),
  ])
  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_OfflineConversion(self, device):
    """Test case for trt_convert.TrtGraphConverter()."""
    if not is_tensorrt_enabled():
      return

    for need_calibration in [False, True]:
      # Use GraphDef as input.
      self._TestTrtGraphConverter(device)

      # Use SavedModel as input.
      self._TestTrtGraphConverter(
          device,
          output_saved_model_dir=self.mkdtemp(),
          need_calibration=need_calibration)

  @parameterized.named_parameters([
      ("NoDeviceAssignment", None),
      ("GPU", "/device:GPU:0"),
      ("CPU", "/device:CPU:0"),
  ])
  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_OnlineConversion(self, device):
    """Test case for TF-TRT conversion using Grappler directly."""
    if not is_tensorrt_enabled():
      return

    conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode=trt_convert.TrtPrecisionMode.FP32)
    config = self._GetConfigProto(
        rewriter_config=trt_convert.get_tensorrt_rewriter_config(
            conversion_params,
            is_dynamic_op=False,
            max_batch_size=1,
            is_v2=False))

    with ops.Graph().as_default():
      # Online conversion requires a frozen graph, so we reuse inp1 as the var
      # argument.
      inp1 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input1")
      inp2 = array_ops.placeholder(
          dtype=dtypes.float32, shape=[None, 1, 1], name="input2")
      if device:
        with ops.device(device):
          TrtConvertTest._GetGraph(inp1, inp2, inp1)
      else:
        TrtConvertTest._GetGraph(inp1, inp2, inp1)
      with self.session(config=config) as sess:
        self._TestRun(sess, batch_size=1)

  def _CreateConverterV2(
      self,
      input_saved_model_dir,
      input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
      max_workspace_size_bytes=10 << 20,  # Use a smaller workspace.
      precision_mode=trt_convert.TrtPrecisionMode.FP32,
      maximum_cached_engines=2):
    return trt_convert.TrtGraphConverterV2(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_signature_key=input_saved_model_signature_key,
        conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
            max_workspace_size_bytes=max_workspace_size_bytes,
            precision_mode=precision_mode,
            maximum_cached_engines=maximum_cached_engines))
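
  # The V2 tests below all follow the converter's three-step workflow,
  # roughly like this (the directory and input_fn names here are only
  # illustrative):
  #
  #   converter = trt_convert.TrtGraphConverterV2(
  #       input_saved_model_dir=saved_model_dir)
  #   converter.convert()                 # rewrite graph with TRTEngineOp nodes
  #   converter.build(input_fn=input_fn)  # optionally pre-build engines
  #   converter.save(output_dir)          # serialize, with any cached engines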

  def _CheckTrtOps(self, concrete_func, check_fn=None):
    graph_def = concrete_func.graph.as_graph_def()
    trt_op_names = []
    for node in graph_def.node:
      if node.op == "TRTEngineOp":
        trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
        if check_fn:
          check_fn(node)
    for func in graph_def.library.function:
      for node in func.node_def:
        if node.op == "TRTEngineOp":
          trt_op_names.append(self._MayRemoveGraphSequenceNumber(node.name))
          if check_fn:
            check_fn(node)
    self.assertEqual(1, len(trt_op_names))
    self.assertIn("TRTEngineOp_0", trt_op_names[0])

  def _RandomInput(self, shape, dtype=np.float32):
    inp1 = np.random.random_sample(shape).astype(dtype)
    inp2 = np.random.random_sample(shape).astype(dtype)
    return inp1, inp2

  @test_util.run_v2_only
  def testTrtGraphConverter_DynamicConversion_v2(self):
    """Test case for trt_convert.TrtGraphConverterV2()."""
    if not is_tensorrt_enabled():
      return

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    expected_output = root.run(np_input1, np_input2)
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(input_saved_model_dir)
    converter.convert()

    # Verify the converted GraphDef and ConcreteFunction.
    self._CheckTrtOps(converter._converted_func)  # pylint: disable=protected-access

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    # Save the converted model without any TRT engine cache.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)
    unexpected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)
    self.assertFalse(os.path.exists(unexpected_asset_file))

    # Run the converted function to populate the engine cache.
    def _InputFn():
      yield np_input1, np_input2

    converter.build(input_fn=_InputFn)

    # Save the converted model again, with the serialized engine cache.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)
    expected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)
    self.assertTrue(os.path.exists(expected_asset_file))
    self.assertTrue(os.path.getsize(expected_asset_file))

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.

    # Load and verify the converted model.
    #
    # TODO(laigd): the name of the new input_signature of the
    # `root_with_trt.run` function is empty string (originally was None),
    # investigate why.
    root_with_trt = load.load(output_saved_model_dir)
    # TODO(laigd): `root_with_trt.run` is still using the original graph without
    # trt. Consider changing that.
    # self._CheckTrtOps(root_with_trt.run.get_concrete_function())
    converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
    self._CheckTrtOps(converted_signature)
    output_with_trt = converted_signature(
        inp1=ops.convert_to_tensor(np_input1),
        inp2=ops.convert_to_tensor(np_input2))
    # The output of running the converted signature is a dict, for
    # compatibility with the V1 SavedModel signature mechanism.
    self.assertAllClose(
        expected_output,
        list(output_with_trt.values())[0],
        atol=1e-6,
        rtol=1e-6)

    del root_with_trt
    gc.collect()  # Force GC to destroy the TRT engine cache.

  @test_util.run_v2_only
  def testTrtGraphConverter_Int8Conversion_v2(self):
    if not is_tensorrt_enabled():
      return

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    expected_output = root.run(np_input1, np_input2)
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(
        input_saved_model_dir,
        precision_mode=trt_convert.TrtPrecisionMode.INT8,
        maximum_cached_engines=3)

    # Convert and perform INT8 calibration.
    def _CalibrationInputFn():
      yield np_input1, np_input2

    converter.convert(calibration_input_fn=_CalibrationInputFn)

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    def _CheckFn(node):
      self.assertTrue(len(node.attr["calibration_data"].s), node.name)

    # Verify the converted GraphDef.
    self._CheckTrtOps(converter._converted_func, _CheckFn)  # pylint: disable=protected-access

    # Build another engine with a different batch size.
    def _InputFn():
      yield self._RandomInput([5, 1, 1])

    converter.build(input_fn=_InputFn)

    # Save the converted model.
    # TODO(laigd): check that it should contain two engines.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)
    expected_asset_file = os.path.join(
        output_saved_model_dir,
        "assets/trt-serialized-engine." + trt_engine_name)
    self.assertTrue(os.path.exists(expected_asset_file))
    self.assertTrue(os.path.getsize(expected_asset_file))

    del converter
    gc.collect()  # Force GC to destroy the TRT engine cache.

    # Load and verify the converted model.
    root_with_trt = load.load(output_saved_model_dir)
    converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
    self._CheckTrtOps(converted_signature, _CheckFn)
    output_with_trt = converted_signature(
        inp1=ops.convert_to_tensor(np_input1),
        inp2=ops.convert_to_tensor(np_input2))
    self.assertEqual(1, len(output_with_trt))
    # The output of running the converted signature is a dict, for
    # compatibility with the V1 SavedModel signature mechanism.
    self.assertAllClose(
        expected_output,
        list(output_with_trt.values())[0],
        atol=1e-6,
        rtol=1e-6)

    # Run with an input of a different batch size. This should build a new
    # engine using the calibration table.
    # TODO(laigd): check that it should contain three engines.
    np_input1, np_input2 = self._RandomInput([6, 1, 1])
    converted_signature(
        inp1=ops.convert_to_tensor(np_input1),
        inp2=ops.convert_to_tensor(np_input2))

    del root_with_trt
    gc.collect()  # Force GC to destroy the TRT engine cache.

  @test_util.run_v2_only
  def testTrtGraphConverter_DestroyEngineCache(self):
    """Test case for trt_convert.TrtGraphConverterV2()."""
    if not is_tensorrt_enabled():
      return

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    # Run TRT conversion.
    converter = self._CreateConverterV2(input_saved_model_dir)
    converter.convert()

    trt_engine_name = self._GetUniqueTRTEngineOp(
        converter._converted_graph_def).name

    def _InputFn():
      yield np_input1, np_input2

    converter.build(input_fn=_InputFn)  # Populate the TRT engine cache.
    output_saved_model_dir = self.mkdtemp()
    converter.save(output_saved_model_dir)

    def _DestroyCache():
      with ops.device("GPU:0"):
        handle = gen_trt_ops.create_trt_resource_handle(
            resource_name=trt_engine_name)
        gen_resource_variable_ops.destroy_resource_op(
            handle, ignore_lookup_error=False)

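    # With ignore_lookup_error=False, destroy_resource_op raises NotFoundError
    # if the engine cache resource does not exist, so _DestroyCache doubles as
    # an existence check in the assertions below.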
    with self.assertRaisesRegex(errors.NotFoundError,
                                r"Resource .* does not exist."):
      _DestroyCache()

    # Load the converted model and make sure the engine cache is populated by
    # default.
    root = load.load(output_saved_model_dir)
    _DestroyCache()
    with self.assertRaisesRegex(errors.NotFoundError,
                                r"Resource .* does not exist."):
      _DestroyCache()

    # Load the converted model again and make sure the engine cache is
    # destroyed when the model goes out of scope.
    root = load.load(output_saved_model_dir)
    del root
    gc.collect()  # Force GC to destroy the TRT engine cache.
    with self.assertRaisesRegex(errors.NotFoundError,
                                r"Resource .* does not exist."):
      _DestroyCache()

  def _CompareSavedModel(self, model_class):
    signature_key = "serving_default"

    def _GetModelPaths(model_class):
      input_saved_model_dir = self.mkdtemp()
      root = model_class()
      save.save(root, input_saved_model_dir)

      converter = self._CreateConverterV2(
          input_saved_model_dir, input_saved_model_signature_key=signature_key)
      converter.convert()
      output_saved_model_dir = self.mkdtemp()
      converter.save(output_saved_model_dir)
      return input_saved_model_dir, output_saved_model_dir

    def _GetSignatureDef(export_dir):
      saved_model_proto = loader_impl.parse_saved_model(export_dir)
      self.assertEqual(1, len(saved_model_proto.meta_graphs))
      meta_graph = saved_model_proto.meta_graphs[0]
      self.assertIn(signature_key, meta_graph.signature_def)
      return meta_graph.signature_def[signature_key]

    def _CompareSignatureDef(original_def, converted_def, is_input):
      endpoints = original_def.inputs if is_input else original_def.outputs
      converted_endpoints = (
          converted_def.inputs if is_input else converted_def.outputs)
      self.assertEqual(set(endpoints.keys()), set(converted_endpoints.keys()))
      for key in endpoints:
        original_input = endpoints[key]
        converted_input = converted_endpoints[key]
        self.assertEqual(original_input.name, converted_input.name)
        self.assertEqual(original_input.dtype, converted_input.dtype)
        self.assertEqual(
            tensor_shape.TensorShape(original_input.tensor_shape).as_list(),
            tensor_shape.TensorShape(converted_input.tensor_shape).as_list())

    def _GetStructuredOutputs(export_dir):
      root = load.load(export_dir)
      return root.signatures[signature_key].structured_outputs

    saved_model_path, converted_saved_model_path = _GetModelPaths(model_class)
    original_def = _GetSignatureDef(saved_model_path)
    converted_def = _GetSignatureDef(converted_saved_model_path)
    self.assertEqual(original_def.method_name, converted_def.method_name)
    _CompareSignatureDef(original_def, converted_def, True)
    _CompareSignatureDef(original_def, converted_def, False)

    self.assertEqual(
        _GetStructuredOutputs(saved_model_path),
        _GetStructuredOutputs(converted_saved_model_path))

  @test_util.run_v2_only
  def testRetainSignatureInfo_NoInputs(self):
    if not is_tensorrt_enabled():
      return

    class _Model(tracking.AutoTrackable):

      @def_function.function(input_signature=[])
      def run(self):
        return array_ops.constant(1.0)

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_OneInput(self):
    if not is_tensorrt_enabled():
      return

    class _Model(tracking.AutoTrackable):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32)
      ])
      def run(self, inp):
        return inp + inp * inp

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_TwoInputs(self):
    if not is_tensorrt_enabled():
      return

    class _Model(tracking.AutoTrackable):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32),
          tensor_spec.TensorSpec(shape=[None, 2], dtype=dtypes.float32)
      ])
      def run(self, inp1, inp2):
        return inp1 + inp2 * inp2

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_OneOutputSignatureKey(self):
    if not is_tensorrt_enabled():
      return

    class _Model(tracking.AutoTrackable):

      @def_function.function(input_signature=[])
      def run(self):
        return {"my_output": array_ops.constant(1.0)}

    self._CompareSavedModel(_Model)

  @test_util.run_v2_only
  def testRetainSignatureInfo_TwoOutputSignatureKeys(self):
    if not is_tensorrt_enabled():
      return

    class _Model(tracking.AutoTrackable):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.float32)
      ])
      def run(self, inp):
        # Here the keys are not ordered lexicographically on purpose.
        return {
            "output_b": array_ops.constant(1.0),
            "output_a": inp + inp * inp
        }

    self._CompareSavedModel(_Model)

  def _TestRun(self, sess, batch_size):
    result = sess.run(
        "output:0",
        feed_dict={
            "input1:0": [[[1.0]]] * batch_size,
            "input2:0": [[[1.0]]] * batch_size
        })
    self.assertAllEqual([[[5.0]]] * batch_size, result)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_MinimumSegmentSize(self):
    if not is_tensorrt_enabled():
      return
    output_graph_def = self._ConvertGraphV1(minimum_segment_size=7)
    node_name_to_op = {node.name: node.op for node in output_graph_def.node}
    self.assertEqual(
        {
            "add/ReadVariableOp": "Const",
            "input1": "Placeholder",
            "input2": "Placeholder",
            "add": "AddV2",
            "mul": "Mul",
            "add_1": "AddV2",
            "add_2": "AddV2",
            "output": "Identity"
        }, node_name_to_op)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_DynamicOp(self):
    if not is_tensorrt_enabled():
      return

    output_saved_model_dir = self.mkdtemp()
    output_graph_def = self._ConvertGraphV1(
        output_saved_model_dir=output_saved_model_dir,
        is_dynamic_op=True,
        maximum_cached_engines=2)

    # Test the output GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        # Run with batch size 1; a new engine is created and cached.
        self._TestRun(sess, 1)
        # Run with batch size 2; a new engine is created and cached.
        self._TestRun(sess, 2)
        # Run with batch size 3; since the number of cached engines has reached
        # the max, it should evict an old engine and create a new one.
        self._TestRun(sess, 3)

    # Test the output SavedModel.
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        # Run with batch size 1; a new engine is created and cached.
        self._TestRun(sess, 1)
        # Run with batch size 2; a new engine is created and cached.
        self._TestRun(sess, 2)
        # Run with batch size 3; since the number of cached engines has reached
        # the max, it should evict an old engine and create a new one.
        self._TestRun(sess, 3)

  @test_util.deprecated_graph_mode_only
  def testTrtGraphConverter_StaticOp(self):
    if not is_tensorrt_enabled():
      return

    output_saved_model_dir = self.mkdtemp()
    output_graph_def = self._ConvertGraphV1(
        output_saved_model_dir=output_saved_model_dir, maximum_cached_engines=1)

    # Test the output GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        # Run with batch size 1; the default engine embedded in the graphdef
        # will be used.
        self._TestRun(sess, 1)
        # Run with batch size 2, which exceeds the max_batch_size; it should
        # fall back to the native TF function.
        self._TestRun(sess, 2)

    # Test the output SavedModel.
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        # Run with batch size 1; the default engine embedded in the graphdef
        # will be used.
        self._TestRun(sess, 1)
        # Run with batch size 2, which exceeds the max_batch_size; it should
        # fall back to the native TF function.
        self._TestRun(sess, 2)

  @test_util.run_v2_only
  def testTrtGraphConverter_AllowEngineNativeSegmentExecution(self):
    if not is_tensorrt_enabled():
      return

    np_input1, np_input2 = self._RandomInput([4, 1, 1])

    # Create a model and save it.
    input_saved_model_dir = self.mkdtemp()
    root = self._GetModelForV2()
    save.save(root, input_saved_model_dir,
              {_SAVED_MODEL_SIGNATURE_KEY: root.run})

    def _InputFn():
      yield np_input1, np_input2

    # Run TRT conversion and request an unreasonably large workspace.
    converter = self._CreateConverterV2(
        input_saved_model_dir, max_workspace_size_bytes=10 << 40)
    converter.convert()

    os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
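    # With native segment execution disallowed, a failed engine build (here
    # provoked by the oversized workspace request above) surfaces as an
    # AbortedError instead of silently falling back to the native TF segment.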
    with self.assertRaisesRegex(
        errors.AbortedError,
        r"User disallowed engine native segment execution"):
      converter.build(input_fn=_InputFn)

    os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
    converter.build(input_fn=_InputFn)

  @test_util.run_v2_only
  def testBackwardCompatibility(self):
    """Load and execute a model that was saved in TF 2.0."""
    if not is_tensorrt_enabled():
      return

    model_dir = test.test_src_dir_path(
        "python/compiler/tensorrt/test/testdata/tftrt_2.0_saved_model")
    saved_model_loaded = load.load(model_dir, tags=[tag_constants.SERVING])
    graph_func = saved_model_loaded.signatures[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

    np_input1 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    np_input2 = ops.convert_to_tensor(np.ones([4, 1, 1]).astype(np.float32))
    output = graph_func(input1=np_input1, input2=np_input2)["output_0"]

    self.assertEqual(output.shape, (4, 1, 1))
    self.assertAllClose(
        np.asarray([5.0, 5.0, 5.0, 5.0]).reshape([4, 1, 1]), output)


if __name__ == "__main__":
  test.main()
900