# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to test TF-TensorRT integration."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import errno
import gc
import itertools
import os
import re
import shutil
import tempfile
import warnings

import numpy as np
import six

from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import get_linked_tensorrt_version
from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt import utils as trt_utils
from tensorflow.python.eager import def_function
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest

TfTrtIntegrationTestParams = collections.namedtuple(
    "TfTrtIntegrationTestParams",
    [
        # A function that creates the TF graph for testing.
        "graph_fn",
        # A list of specifications for input tensors.
        "input_specs",
        # A list of specifications for output tensors.
        "output_specs",
        # A list of lists of input shapes. Each shape must match the
        # corresponding element in `input_specs`.
        "input_dims",
        # A list of lists of expected output shapes. Each shape must match the
        # corresponding element in `output_specs`.
        "expected_output_dims"
    ])

RunParams = collections.namedtuple(
    "RunParams",
    [
        # Whether to run the conversion online with RewriterConfig, or offline
        # with TrtGraphConverter.
        "convert_online",
        "precision_mode",
        "dynamic_engine",
        "use_calibration",
        "test_name",
        # Is this test for TF 2.0?
        "is_v2",
    ])

FP32 = "FP32"
FP16 = "FP16"
INT8 = "INT8"
PRECISION_MODES = [FP32, FP16, INT8]


def IsQuantizationMode(mode):
  return mode == "INT8"


def IsQuantizationWithCalibration(params):
  return IsQuantizationMode(params.precision_mode) and params.use_calibration


def IsTensorRTVersionGreaterEqual(major, minor=0, patch=0):
  ver = get_linked_tensorrt_version()
  return ver[0] > major or (ver[0] == major and ver[1] > minor) or (
      ver[0] == major and ver[1] == minor and ver[2] >= patch)
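
# Illustrative (comment only, not from the original source): with a linked
# TensorRT version of (7, 1, 3), IsTensorRTVersionGreaterEqual(7),
# IsTensorRTVersionGreaterEqual(7, 1) and IsTensorRTVersionGreaterEqual(7, 1, 3)
# all return True, while IsTensorRTVersionGreaterEqual(7, 2) returns False;
# the comparison is lexicographic on the (major, minor, patch) tuple.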


class GraphState(object):
  ORIGINAL = 0
  CALIBRATE = 1
  INFERENCE = 2


class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
  """Class to test Tensorflow-TensorRT integration."""

  @property
  def trt_incompatible_op(self):
    return math_ops.erf

  @property
  def precision_modes(self):
    return ["FP32", "FP16", "INT8"]

  # str is bytes in py2, but unicode in py3.
  def _ToUnicode(self, s):
    if six.PY2:
      if isinstance(s, unicode):
        return s
      return s.decode("utf-8")
    else:
      if isinstance(s, str):
        return s
      return s.decode("utf-8")

  def _ToBytes(self, s):
    if six.PY2:
      if isinstance(s, unicode):
        return s.encode("utf-8")
      return s
    else:
      if isinstance(s, str):
        return s.encode("utf-8")
      return s

  def _ToString(self, s):
    if six.PY2:
      if isinstance(s, unicode):
        return s.encode("utf-8")
      return s
    else:
      if isinstance(s, str):
        return s
      return s.decode("utf-8")

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(TfTrtIntegrationTestBase, self).__init__(methodName)
    self._trt_test_params = None
    self._disable_non_trt_optimizers = False
    self._use_implicit_batch = True

  def setUp(self):
    """Setup method."""
    super(TfTrtIntegrationTestBase, self).setUp()
    warnings.simplefilter("always")

  def _GetTensorSpec(self, shape, mask, dtype, name):
    # Set dimension i to None if mask[i] == False.
    assert len(shape) == len(mask)
    new_shape = [s if m else None for s, m in zip(shape, mask)]
    return tensor_spec.TensorSpec(new_shape, dtype, name)
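
  # Illustrative (comment only): _GetTensorSpec([8, 2, 4], [False, True, True],
  # dtype, "input_0") returns a TensorSpec with shape [None, 2, 4], i.e. the
  # masked-out batch dimension is left unknown.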

  def BuildParams(self, graph_fn, dtype, input_shapes, output_shapes):
    """Build test parameters.

    The input_shapes and output_shapes arguments are known (static) shapes that
    can be used to generate test data. To define the model, we also specify
    corresponding input/output TensorSpecs. These are defined using the shape
    arguments. For each input tensor we define:

    input_spec = [None] + input_shape[1:]

    and similarly for output shapes. This means that we leave the first (batch)
    dimension unknown; the rest is just copied from the shapes arg.

    Args:
      graph_fn: The function to build the graph.
      dtype: The element type.
      input_shapes: The input shapes.
      output_shapes: The output shapes.

    Returns:
      The test parameters.
    """

    input_mask = [[False] + [True] * (len(shape) - 1) for shape in input_shapes]
    output_mask = [
        [False] + [True] * (len(shape) - 1) for shape in output_shapes
    ]

    return self.BuildParamsWithMask(graph_fn, dtype, input_shapes,
                                    output_shapes, input_mask, output_mask, [],
                                    [])
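
  # A hedged usage sketch (hypothetical subclass, not part of this file;
  # assumes `dtypes` is imported from tensorflow.python.framework):
  #
  #   class SimpleSingleEngineTest(TfTrtIntegrationTestBase):
  #
  #     def GraphFn(self, inp):
  #       # Two TRT-convertible ops, satisfying minimum_segment_size=2.
  #       mul = inp * inp
  #       return array_ops.identity(mul, name="output_0")
  #
  #     def GetParams(self):
  #       return self.BuildParams(self.GraphFn, dtypes.float32,
  #                               [[2, 8, 8, 3]], [[2, 8, 8, 3]])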

  def BuildParamsWithMask(self, graph_fn, dtype, input_shapes, output_shapes,
                          input_mask, output_mask, extra_inputs, extra_outputs):
    """Build test parameters with static or dynamic input shapes.

    To define dynamic shapes give a boolean mask that describes which
    dimensions to treat as known. The values in input_mask are interpreted the
    following way:
    - True: known dim (use the corresponding value from input_shapes)
    - False: unknown dim (replace the corresponding value from input_shapes
             with None)
    For example, to define the first two dimensions with unknown size use
    input_shapes=[[1,2,1,8]], input_mask=[[False, False, True, True]].

    Args:
      graph_fn: The function to build the graph.
      dtype: The element type.
      input_shapes: The input shapes.
      output_shapes: The output shapes.
      input_mask: The input shape masks.
      output_mask: The output shape masks.
      extra_inputs: A list of additional input shapes.
      extra_outputs: A list of additional output shapes.

    Returns:
      The test parameters.
    """

    def _ValidateShapes(shapes):
      # Make sure all the shapes are fully specified.
      for shape in shapes:
        assert all(shape)

    _ValidateShapes(input_shapes)
    _ValidateShapes(output_shapes)

    assert len(input_mask) == len(input_shapes)
    assert len(output_mask) == len(output_shapes)
    for extra_in_shape, extra_out_shape in zip(extra_inputs, extra_outputs):
      assert len(input_shapes) == len(extra_in_shape)
      assert len(output_shapes) == len(extra_out_shape)

    return TfTrtIntegrationTestParams(
        graph_fn=graph_fn,
        input_specs=[
            self._GetTensorSpec(shape, mask, dtype, "input_%d" % i)
            for i, (shape, mask) in enumerate(zip(input_shapes, input_mask))
        ],
        output_specs=[
            self._GetTensorSpec(shape, mask, dtype, "output_%d" % i)
            for i, (shape, mask) in enumerate(zip(output_shapes, output_mask))
        ],
        input_dims=[input_shapes] + extra_inputs,
        expected_output_dims=[output_shapes] + extra_outputs)
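
  # Dynamic-shape sketch (illustrative, hypothetical values): one input whose
  # first two dimensions are unknown, plus one extra set of input/output
  # shapes so inference also runs with a second shape:
  #
  #   return self.BuildParamsWithMask(
  #       self.GraphFn, dtypes.float32,
  #       input_shapes=[[1, 2, 1, 8]], output_shapes=[[1, 2, 1, 8]],
  #       input_mask=[[False, False, True, True]],
  #       output_mask=[[False, False, True, True]],
  #       extra_inputs=[[[2, 3, 1, 8]]], extra_outputs=[[[2, 3, 1, 8]]])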

  def DisableNonTrtOptimizers(self):
    self._disable_non_trt_optimizers = True

  def DisableImplicitBatchMode(self):
    self._use_implicit_batch = False

  def GetParams(self):
    """Returns a TfTrtIntegrationTestParams for the test."""
    raise NotImplementedError()

  def GetConversionParams(self, run_params):
    """Returns a TrtConversionParams for the test."""
    conversion_params = trt_convert.TrtConversionParams(
        # We use the minimum of all the batch sizes, so when multiple different
        # input shapes are provided it'll always create new engines in the
        # cache, and we can therefore test the cache behavior.
        max_workspace_size_bytes=1 << 25,
        precision_mode=run_params.precision_mode,
        minimum_segment_size=2,
        maximum_cached_engines=1,
        use_calibration=run_params.use_calibration)
    return conversion_params
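
  # Override sketch (hypothetical subclass `MyTest`; assumes
  # TrtConversionParams is a namedtuple in this TF version, so _replace is
  # available):
  #
  #   def GetConversionParams(self, run_params):
  #     return super(MyTest, self).GetConversionParams(run_params)._replace(
  #         minimum_segment_size=3)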

  def GetMaxBatchSize(self, run_params):
    """Returns the max_batch_size that the converter should use for tests."""
    if run_params.dynamic_engine:
      return None
    batch_list = []
    for dims_list in self._GetParamsCached().input_dims:
      assert dims_list
      # Each list of shapes should have the same batch size.
      input_batches = [dims[0] for dims in dims_list]
      assert max(input_batches) == min(input_batches)
      batch_list.append(input_batches[0])
    return max(batch_list)
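
  # Example (illustrative): with
  # input_dims=[[[2, 8], [2, 3]], [[4, 8], [4, 3]]] the per-run batch sizes
  # are 2 and 4, so a static engine is built with max_batch_size=4; for
  # dynamic engines None is returned instead.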

  def ShouldRunTest(self, run_params):
    """Whether to run the test."""
    # Ensure use_calibration=True in case of INT8 precision.
    return (run_params.use_calibration or not IsQuantizationMode(
        run_params.precision_mode)), "test either calibration or non-INT8"
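
  # Override sketch (hypothetical subclass `MyTest`), chaining onto this check
  # to narrow the run matrix further:
  #
  #   def ShouldRunTest(self, run_params):
  #     should_run, reason = super(MyTest, self).ShouldRunTest(run_params)
  #     return should_run and run_params.is_v2, reason + " and TF 2.x only"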

  def ExpectedEnginesToBuild(self, run_params):
    """Returns the expected engines to build, implemented by subclass."""
    raise NotImplementedError()

  def ExpectedMaxBatchSizes(self, run_params):
    """Returns the expected maximum batch sizes of the built engines."""
    return None

  def ExpectedAbsoluteTolerance(self, run_params):
    """The absolute tolerance to compare floating point results."""
    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02

  def ExpectedRelativeTolerance(self, run_params):
    """The relative tolerance to compare floating point results."""
    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02

  def _GetParamsCached(self):
    if self._trt_test_params is None:
      self._trt_test_params = self.GetParams()
    return self._trt_test_params

  def _GetGPUOptions(self):
    gpu_options = config_pb2.GPUOptions()
    gpu_options.allow_growth = True
    return gpu_options

  def _GetConfigProto(self, run_params, graph_state):
    """Get config proto based on specific settings."""
    conversion_params = self.GetConversionParams(run_params)
    max_batch_size = self.GetMaxBatchSize(run_params)

    if graph_state == GraphState.INFERENCE and run_params.convert_online:
      rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
          conversion_params,
          is_dynamic_op=run_params.dynamic_engine,
          max_batch_size=max_batch_size,
          disable_non_trt_optimizers=self._disable_non_trt_optimizers)
    else:
      rewriter_cfg = rewriter_config_pb2.RewriterConfig()
      if self._disable_non_trt_optimizers:
        trt_utils.disable_non_trt_optimizers_in_rewriter_config(rewriter_cfg)

    config = config_pb2.ConfigProto(
        gpu_options=self._GetGPUOptions(),
        graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))
    return config

  def _GetFeedNames(self):
    params = self._GetParamsCached()
    # Construct the feed tensor names by appending :0 to the node names.
    return [spec.name + ":0" for spec in params.input_specs]

  def _GetFetchNames(self):
    params = self._GetParamsCached()
    # Construct the fetch tensor names by appending :0 to the node names.
    return [spec.name + ":0" for spec in params.output_specs]

  def _GetFeedDict(self, inputs_data):
    return {name: data for name, data in zip(self._GetFeedNames(), inputs_data)}
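
  # Illustrative: with input_specs named "input_0" and "input_1",
  # _GetFeedDict(data) returns {"input_0:0": data[0], "input_1:0": data[1]},
  # keyed by tensor names rather than node names.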

  def _RunGraphV1(self, saved_model_dir, inputs_data, config, num_runs=2):
    """Run given graphdef multiple times using TF 1.x runtime."""
    params = self._GetParamsCached()
    fetches = self._GetFetchNames()
    g = ops.Graph()
    with g.as_default():
      with self.session(graph=g, config=config, use_gpu=True) as sess:
        loader.load(sess, [tag_constants.SERVING], saved_model_dir)
        vals = []
        # Run once for each set of input shapes.
        for expected_shapes, current_input_data in zip(
            params.expected_output_dims, inputs_data):
          val = None
          for _ in range(num_runs):
            new_val = sess.run(fetches, self._GetFeedDict(current_input_data))
            self.assertEqual(len(expected_shapes), len(new_val))
            for expected_shape, actual_val in zip(expected_shapes, new_val):
              self.assertEqual(list(expected_shape), list(actual_val.shape))
            if val is not None:
              # Some ops may have nondeterministic output, e.g. Conv2D may use
              # the Winograd algorithm, so we set atol/rtol larger than 1.e-06.
              self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05)
            val = new_val
          vals.append(val)
        return vals

  def _RunGraphV2(self, saved_model_dir, inputs_data, graph_state, num_runs=2):
    """Run given graphdef multiple times using TF 2.0 runtime."""
    params = self._GetParamsCached()
    root = load.load(saved_model_dir)
    func = root.signatures[
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    results = []
    for expected_shapes, current_input_data in zip(params.expected_output_dims,
                                                   inputs_data):
      val = None
      for _ in range(num_runs):
        feed_dict = {
            params.input_specs[i].name: current_input_data[i]
            for i in range(len(params.input_specs))
        }
        new_val = func(**feed_dict)
        assert isinstance(new_val, dict)
        # The keys of the output map are of the form output_i.
        new_val = [new_val[key] for key in sorted(new_val)]
        # Each element is an eager Tensor, and accessing individual elements is
        # very expensive, so we convert them to a numpy array first.
        new_val = [v.numpy() for v in new_val]
        self.assertEqual(len(expected_shapes), len(new_val))
        for expected_shape, actual_val in zip(expected_shapes, new_val):
          self.assertEqual(list(expected_shape), list(actual_val.shape))
        if val is not None:
          # Some ops may have nondeterministic output, e.g. Conv2D may use
          # the Winograd algorithm, so we set atol/rtol larger than 1.e-06.
          self.assertAllClose(val, new_val, atol=1.e-05, rtol=1.e-05)
        val = new_val
      results.append(val)

    return results

  def _RunGraph(self,
                run_params,
                saved_model_dir,
                inputs_data,
                graph_state,
                num_runs=2):
    params = self._GetParamsCached()
    for data in inputs_data:
      assert len(params.input_specs) == len(data)

    if run_params.is_v2:
      results = self._RunGraphV2(saved_model_dir, inputs_data, graph_state,
                                 num_runs)
      gc.collect()  # Force GC to destroy the TRT engine cache.
      return results

    # The default config for tf.session is None. Create a config with
    # TensorRTOptimizer enabled to support convert_online for inference.
    config = None
    # TODO(b/170220818): use the default session config to run inference
    #   graphs for the offline conversion case after fixing the bug.
    if graph_state == GraphState.INFERENCE:
      config = self._GetConfigProto(run_params, GraphState.INFERENCE)
    return self._RunGraphV1(saved_model_dir, inputs_data, config, num_runs)

  def _CreateConverter(self, run_params, saved_model_dir, conversion_params):
    """Returns a TrtGraphConverter."""
    if run_params.is_v2:
      converter_v2 = trt_convert.TrtGraphConverterV2(
          input_saved_model_dir=saved_model_dir,
          conversion_params=conversion_params)
      if self._disable_non_trt_optimizers:
        converter_v2._test_only_disable_non_trt_optimizers = True  # pylint: disable=protected-access
      if not self._use_implicit_batch:
        converter_v2._test_only_use_implicit_batch = False  # pylint: disable=protected-access
      return converter_v2

    converter_v1 = trt_convert.TrtGraphConverter(
        input_saved_model_dir=saved_model_dir,
        max_batch_size=self.GetMaxBatchSize(run_params),
        max_workspace_size_bytes=conversion_params.max_workspace_size_bytes,
        precision_mode=conversion_params.precision_mode,
        minimum_segment_size=conversion_params.minimum_segment_size,
        is_dynamic_op=run_params.dynamic_engine,
        maximum_cached_engines=conversion_params.maximum_cached_engines,
        use_calibration=conversion_params.use_calibration)
    if self._disable_non_trt_optimizers:
      converter_v1._test_only_disable_non_trt_optimizers = True  # pylint: disable=protected-access
    return converter_v1

  def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data):
    """Return the directory of the TRT-converted SavedModel in INT8 mode."""
    conversion_params = self.GetConversionParams(run_params)
    logging.info(conversion_params)
    assert conversion_params.precision_mode == "INT8"
    assert run_params.dynamic_engine
    assert conversion_params.maximum_cached_engines == 1
    assert conversion_params.use_calibration

    # We only support calibrating a single engine.
    # TODO(aaroey): fix this.
    assert len(inputs_data) == 1

    converter = self._CreateConverter(run_params, saved_model_dir,
                                      conversion_params)
    int8_gdef = converter.convert()
    self._VerifyGraphDef(run_params, saved_model_dir, int8_gdef,
                         GraphState.CALIBRATE)

    converter.calibrate(
        fetch_names=self._GetFetchNames(),
        num_runs=5,
        feed_dict_fn=lambda: self._GetFeedDict(inputs_data[0]))
    trt_saved_model_dir = self._GetSavedModelDir(run_params,
                                                 GraphState.CALIBRATE)
    converter.save(trt_saved_model_dir)
    return trt_saved_model_dir

  def _GetInferGraph(self, run_params, saved_model_dir):
    """Return the directory of the TRT-converted SavedModel."""
    conversion_params = self.GetConversionParams(run_params)
    logging.info(conversion_params)

    converter = self._CreateConverter(run_params, saved_model_dir,
                                      conversion_params)
    converter.convert()

    if not self._use_implicit_batch:
      logging.info("Using build mode")

      def _BuildInputFn():
        for shapes in self._GetParamsCached().input_dims:
          yield [np.zeros(x).astype(np.float32) for x in shapes]

      converter.build(input_fn=_BuildInputFn)

    trt_saved_model_dir = self._GetSavedModelDir(run_params,
                                                 GraphState.INFERENCE)
    converter.save(trt_saved_model_dir)
    return trt_saved_model_dir

  def _GetGraphStateLabel(self, graph_state):
    if graph_state == GraphState.ORIGINAL:
      return "Original"
    elif graph_state == GraphState.CALIBRATE:
      return "CalibEngine"
    elif graph_state == GraphState.INFERENCE:
      return "InferEngine"
    else:
      return "UnknownState"

  def _WriteGraph(self, run_params, gdef, graph_state):
    temp_dir = os.getenv("TRT_TEST_TMPDIR")
    if not temp_dir:
      return

    graph_name = (
        self.__class__.__name__ + "_" + run_params.test_name + "_" +
        self._GetGraphStateLabel(graph_state) + ".pbtxt")
    logging.info("Writing graph to %s/%s", temp_dir, graph_name)
    graph_io.write_graph(gdef, temp_dir, graph_name)

  # Removes the name scope prefix(es) from function name(s).
  # The input value can be a string or a sequence of strings.
  def _Canonicalize(self, value):
    if isinstance(value, str):
      return self._ToString(value.split("/")[-1])
    elif isinstance(value, collections.abc.Iterable):
      return set(self._Canonicalize(nm) for nm in value)
    else:
      raise TypeError(
          "'_Canonicalize' can only be used on strings or sequences of "
          "strings!")
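
  # Illustrative: _Canonicalize("import/while/TRTEngineOp_0") returns
  # "TRTEngineOp_0", and for a sequence of names it returns the set of
  # canonicalized names.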

  # Removes the graph sequence number prefix from the name(s) only if the
  # name(s) have a prefix TRTEngineOp_n_. When expecting_prefix is True,
  # asserts such a prefix exists.
  # The input value can be a string or a sequence of strings.
  def _RemoveGraphSequenceNumberImpl(self, value, expecting_prefix):
    if isinstance(value, str):
      match = re.search(r"TRTEngineOp_\d+_", value)
      has_prefix = match and value.startswith(match.group(0))
      assert (not expecting_prefix) or has_prefix
      if has_prefix:
        parts = value.split("_", maxsplit=2)
        assert len(parts) == 3
        return parts[0] + "_" + parts[2]
      return value
    elif isinstance(value, collections.abc.Iterable):
      return set(
          self._RemoveGraphSequenceNumberImpl(nm, expecting_prefix)
          for nm in value)
    else:
      raise TypeError(
          "'_RemoveGraphSequenceNumberImpl' can only be used on strings "
          "or sequences of strings!")

  def _RemoveGraphSequenceNumber(self, name):
    return self._RemoveGraphSequenceNumberImpl(name, True)

  def _MayRemoveGraphSequenceNumber(self, name):
    return self._RemoveGraphSequenceNumberImpl(name, False)
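
  # Illustrative: _RemoveGraphSequenceNumber("TRTEngineOp_3_0") returns
  # "TRTEngineOp_0", stripping the per-graph sequence number so that expected
  # engine names stay stable, while _MayRemoveGraphSequenceNumber leaves names
  # without a TRTEngineOp_n_ prefix untouched.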

  def _VerifyConnections(self, expected_engines, original_gdef, converted_gdef):
    old_to_new_node_map = {
        self._ToString(node.name): self._ToString(node.name)
        for node in original_gdef.node
    }
    for engine_name, node_names in expected_engines.items():
      for node_name in node_names:
        old_to_new_node_map[node_name] = engine_name
    name_to_node_map = {
        self._ToString(node.name): node for node in original_gdef.node
    }

    def _InputName(inp):
      inp = self._ToString(inp)
      prefix = ""
      if inp[0] == "^":
        prefix = "^"
        inp = inp[1:]
      parts = inp.split(":")
      if len(parts) > 1 and parts[-1].isdigit():
        inp = inp[:-len(parts[-1]) - 1]
      return (prefix, inp)
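
    # Illustrative: _InputName("^foo/bar:1") returns ("^", "foo/bar"),
    # separating the control-dependency marker and dropping the output index.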

    # Compute the expected mapping from each node to its input nodes.
    expected_input_map = {}
    removed_const_nodes = set([
        self._ToString(node.name)
        for node in original_gdef.node
        if node.op == "Const"
    ])
    for node in original_gdef.node:
      name_str = self._ToString(node.name)
      target_node_name = old_to_new_node_map[name_str]
      is_engine_op = (target_node_name != name_str)
      if target_node_name not in expected_input_map:
        expected_input_map[target_node_name] = set()
      input_set = expected_input_map[target_node_name]
      for inp in node.input:
        (prefix, inp_name) = _InputName(inp)
        mapped_input = old_to_new_node_map[inp_name]
        # Add the input only if it's outside the segment (note that it could be
        # in a different engine).
        if not is_engine_op or (mapped_input != target_node_name and
                                name_to_node_map[inp_name].op != "Const"):
          input_set.add(prefix + mapped_input)
          if mapped_input in removed_const_nodes:
            removed_const_nodes.remove(mapped_input)
    # Remove const nodes that have no outputs.
    expected_input_map = {
        k: v
        for k, v in expected_input_map.items()
        if k not in removed_const_nodes
    }

    # Compute the actual mapping from each node to its input nodes. If a cast
    # op doesn't exist in the original graph, we replace the use of the cast op
    # with the input of the op. This allows the verification to handle the case
    # where the TF-TRT bridge splits a cast op into a chain of two cast ops.
    new_cast_op_name_to_node_map = {
        node.name: node
        for node in converted_gdef.node
        if (node.name not in old_to_new_node_map and node.op == "Cast")
    }
    actual_input_map = {}
    for node in converted_gdef.node:
      name_str = node.name
      # Only nodes from the original graph or TRTEngineOp nodes are added as
      # keys to the map.
      if node.op == "TRTEngineOp":
        name_str = self._RemoveGraphSequenceNumber(name_str)
      elif name_str not in old_to_new_node_map:
        continue
      actual_input_map[name_str] = set()
      input_set = actual_input_map[name_str]
      for inp in node.input:
        (prefix, node_name) = _InputName(inp)
        node_name = self._MayRemoveGraphSequenceNumber(node_name)
        if node_name in new_cast_op_name_to_node_map:
          (prefix, node_name) = _InputName(
              new_cast_op_name_to_node_map[node_name].input[0])
        input_set.add(prefix + node_name)

    self.assertEqual(
        expected_input_map,
        actual_input_map,
        msg="\nexpected:\n%s\nvs actual:\n%s" %
        (sorted(expected_input_map.items()), sorted(actual_input_map.items())))

  def _VerifyMaxBatchSizeAnnotations(
      self,
      expected_engines,
      original_gdef,
      converted_gdef,
      default_max_batch_size,
      expected_max_batch_sizes=None,
  ):
    """Verifies the max batch size annotations in the original and converted GraphDef.

    Args:
      expected_engines: A sequence of engine names.
      original_gdef: GraphDef. The graph def before TensorRT conversion.
      converted_gdef: GraphDef. The graph def after TensorRT conversion.
      default_max_batch_size: The default maximum batch size to use if no node
        inside a segment is annotated with a customized max batch size. This
        value is None when the graph is converted to TF-TRT with dynamic
        engines.
      expected_max_batch_sizes: Optional. A sequence of max batch sizes for all
        the engines. `None` if the test does not enforce the max batch sizes.
    """
    if isinstance(expected_max_batch_sizes, collections.abc.Collection):
      self.assertEqual(len(expected_max_batch_sizes), len(expected_engines))
    else:
      self.assertIsNone(
          expected_max_batch_sizes,
          "'expected_max_batch_sizes' shall only be a sequence "
          "of integers or `None`.")

    def _ChainAllNodes(graph_def):
      return itertools.chain(
          graph_def.node,
          itertools.chain(
              *[func.node_def for func in graph_def.library.function]))

    old_name_to_node_map = {
        self._ToString(node.name): node
        for node in _ChainAllNodes(original_gdef)
    }
    new_name_to_func_map = {
        self._ToString(func.signature.name): func
        for func in converted_gdef.library.function
    }

    def _DetectStaticBatchSize(node_def):
      """Returns the static batch size of an operation or None.

      It is incorrect to use the output shapes to find the batch size of an
      operation, as the segmenter actually uses the input shapes. However, it
      is a simplification that works for most of the cases for the test
      purposes.

      Args:
        node_def: `tf.NodeDef`. The target node for analysis.

      Returns:
        If all the outputs of the node have the same static batch size, returns
        the int value for the batch size. Otherwise returns None.
      """
      shapes = node_def.attr["_output_shapes"].list.shape
      batch_size = set(
          list(s.dim)[0].size if len(s.dim) >= 2 else None for s in shapes)
      if len(batch_size) == 1 and list(batch_size)[0] >= 1:
        return list(batch_size)[0]
      return None
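
    # Illustrative: a node whose "_output_shapes" are all [8, 32, 32, 3] yields
    # a static batch size of 8; mixed or unknown (-1) leading dimensions yield
    # None.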

    name_to_engines_map = {}
    actual_max_batch_sizes = []
    for node in _ChainAllNodes(converted_gdef):
      if node.op == "TRTEngineOp":
        engine = node
        engine_name = self._RemoveGraphSequenceNumber(
            self._Canonicalize(self._ToString(engine.name)))
        self.assertIn(engine_name, expected_engines)
        name_to_engines_map[engine_name] = engine
        # The input nodes shall not have annotations that conflict with the
        # engine's maximum batch size annotation (i.e. each input node either
        # has no annotation or the same annotation). If the engine is annotated
        # with a non-default maximum batch size, then at least one input node
        # shall have the same annotation to be the source.
        self.assertIn("max_batch_size", node.attr)
        engine_max_batch_size = node.attr["max_batch_size"].i
        self.assertIsInstance(engine_max_batch_size, int)
        actual_max_batch_sizes.append(engine_max_batch_size)
        seg_func = node.attr["segment_func"].func
        self.assertIsNotNone(seg_func)
        self.assertIn(seg_func.name, new_name_to_func_map)
        seg_func_def = new_name_to_func_map[seg_func.name]
        logging.info("Segment function name: %s. Including %d nodes.",
                     seg_func.name, len(seg_func_def.node_def))
        node_max_batch_size_all_none = True
        # Use the native segment to search for replaced nodes.
        for alternative_node in seg_func_def.node_def:
          node_name = self._Canonicalize(self._ToString(alternative_node.name))
          if node_name not in old_name_to_node_map:
            continue
          original_node = old_name_to_node_map[node_name]
          node_max_batch_size = None
          if "_tftrt_op_max_batch_size" in original_node.attr:
            node_max_batch_size = original_node.attr[
                "_tftrt_op_max_batch_size"].i
          elif (original_node.op != "Const" and
                alternative_node.op != "Const" and
                "_output_shapes" in original_node.attr):
            node_max_batch_size = _DetectStaticBatchSize(original_node)
          logging.info(
              "'{%s}(%s)'s max batch size annotation is %s. "
              "'{%s}'s max batch size is %s.", node_name, original_node.op,
              str(node_max_batch_size), engine_name, str(engine_max_batch_size))
          node_max_batch_size_all_none &= node_max_batch_size is None
          self.assertTrue(engine_max_batch_size == node_max_batch_size or
                          node_max_batch_size is None)
        logging.info("'{%s}'s max batch size is %d.", engine_name,
                     engine_max_batch_size)
        self.assertTrue(default_max_batch_size is None or
                        engine_max_batch_size == default_max_batch_size or
                        not node_max_batch_size_all_none)

    self.assertCountEqual(expected_engines, tuple(name_to_engines_map.keys()))
    if expected_max_batch_sizes is not None:
      self.assertCountEqual(expected_max_batch_sizes, actual_max_batch_sizes)

  def _GetGraphDef(self, run_params, gdef_or_saved_model_dir):
    if isinstance(gdef_or_saved_model_dir, str):
      if run_params.is_v2:
        root = load.load(gdef_or_saved_model_dir)
        func = root.signatures[
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        gdef = func.graph.as_graph_def()
        # Manually unref the loaded saved model and force GC to destroy the TRT
        # engine cache after load(). There is currently a reference cycle in 2.0
        # which prevents auto deletion of the resource.
        # TODO(laigd): fix this.
        del func
        del root
        gc.collect()
        return gdef
      return saved_model_utils.get_meta_graph_def(
          gdef_or_saved_model_dir, tag_constants.SERVING).graph_def
    assert isinstance(gdef_or_saved_model_dir, graph_pb2.GraphDef)
    return gdef_or_saved_model_dir

  def _VerifyGraphDefV1(self, run_params, original_gdef, gdef_to_verify,
                        graph_state):
    expected_engines = self.ExpectedEnginesToBuild(run_params)
    num_engines = 0
    functions = [f.signature.name for f in gdef_to_verify.library.function]
    for node in gdef_to_verify.node:
      if node.op == "TRTEngineOp":
        logging.info("Found TRTEngineOp: " + node.name)
        num_engines += 1
        segment_funcdef_name = node.attr["segment_func"].func.name
        function_name = node.name + "_native_segment"
        is_dynamic_engine = not node.attr["static_engine"].b
        self.assertNotEmpty(segment_funcdef_name, node.name)
        self.assertIn(function_name, functions)
        if (not IsQuantizationWithCalibration(run_params) and
            not is_dynamic_engine):
          self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
        self.assertIn(
            self._RemoveGraphSequenceNumber(node.name), expected_engines)
        self.assertEqual(
            self._ToBytes(run_params.precision_mode),
            node.attr["precision_mode"].s, node.name)

        self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
                         node.name)
        self.assertEqual(node.attr["use_calibration"].b,
                         run_params.use_calibration, node.name)

        has_calibration_data = len(node.attr["calibration_data"].s)
        if (IsQuantizationWithCalibration(run_params) and
            graph_state == GraphState.INFERENCE):
          self.assertTrue(has_calibration_data, node.name)
        else:
          self.assertFalse(has_calibration_data, node.name)
    if graph_state == GraphState.ORIGINAL:
      self.assertEqual(0, num_engines)
    else:
      self.assertEqual(num_engines, len(expected_engines))
      if isinstance(expected_engines, dict):
        self._VerifyConnections(expected_engines, original_gdef, gdef_to_verify)
      self._VerifyMaxBatchSizeAnnotations(
          expected_engines=expected_engines,
          original_gdef=original_gdef,
          converted_gdef=gdef_to_verify,
          expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params),
          default_max_batch_size=self.GetMaxBatchSize(run_params))

  def _VerifyGraphDefV2(self, run_params, original_gdef, gdef_to_verify,
                        graph_state):
    if graph_state == GraphState.ORIGINAL:
      return
    expected_engines = self.ExpectedEnginesToBuild(run_params)
    all_op_names = [node.name for node in gdef_to_verify.node]
    trt_op_names = [
        node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp"
    ]
    for func in gdef_to_verify.library.function:
      if not re.search(r"TRTEngineOp_\d+_\d+_native_segment",
                       func.signature.name):
        for node in func.node_def:
          all_op_names.append(node.name)
          if node.op == "TRTEngineOp":
            trt_op_names.append(node.name)

    all_op_names = self._Canonicalize(all_op_names)
    trt_op_names = self._RemoveGraphSequenceNumber(
        self._Canonicalize(trt_op_names))

    if isinstance(expected_engines, dict):
      # For simplicity we don't verify the connections inside the engine in
      # 2.0, but we still make sure that the converted ops are gone from the
      # graph.
      unexpected_names = set(nest.flatten(expected_engines.values()))
      self.assertEmpty(
          [name for name in unexpected_names if name in all_op_names])
      expected_engines = set(expected_engines.keys())

    self.assertEqual(set(expected_engines), trt_op_names)

  def _VerifyGraphDef(self, run_params, original_gdef_or_saved_model_dir,
                      gdef_or_saved_model_dir_to_verify, graph_state):
    original_gdef = self._GetGraphDef(run_params,
                                      original_gdef_or_saved_model_dir)
    gdef_to_verify = self._GetGraphDef(run_params,
                                       gdef_or_saved_model_dir_to_verify)
    self._WriteGraph(run_params, gdef_to_verify, graph_state)
    if run_params.is_v2:
      self._VerifyGraphDefV2(run_params, original_gdef, gdef_to_verify,
                             graph_state)
    else:
      self._VerifyGraphDefV1(run_params, original_gdef, gdef_to_verify,
                             graph_state)

  def _GetSavedModelDir(self, run_params, graph_state):
    test_tmpdir = os.getenv("TRT_TEST_TMPDIR")
    if test_tmpdir:
      saved_model_dir = os.path.join(
          test_tmpdir, self.__class__.__name__ + "_" + run_params.test_name +
          "_" + self._GetGraphStateLabel(graph_state))
      try:
        # For TF 1.x we need to make sure the output directory doesn't exist
        # before exporting the saved model.
        shutil.rmtree(saved_model_dir)
      except OSError as e:
        if e.errno != errno.ENOENT:
          raise
      return saved_model_dir
    return tempfile.mkdtemp(dir=self.get_temp_dir())

  def _MakeSavedModelV1(self, run_params):
    """Write the saved model as an input for testing."""
    params = self._GetParamsCached()
    g = ops.Graph()
    with g.as_default():
      inputs = []
      for spec in params.input_specs:
        inp = array_ops.placeholder(
            dtype=spec.dtype, shape=spec.shape, name=spec.name)
        inputs.append(inp)
      outputs = params.graph_fn(*inputs)
      if not isinstance(outputs, list) and not isinstance(outputs, tuple):
        outputs = [outputs]

    signature_def = signature_def_utils.build_signature_def(
        inputs={inp.op.name: utils.build_tensor_info(inp) for inp in inputs},
        outputs={out.op.name: utils.build_tensor_info(out) for out in outputs},
        method_name=signature_constants.PREDICT_METHOD_NAME)

    saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL)
    saved_model_builder = builder.SavedModelBuilder(saved_model_dir)
    with self.session(
        graph=g, config=self._GetConfigProto(run_params,
                                             GraphState.ORIGINAL)) as sess:
      saved_model_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING],
          signature_def_map={
              signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                  signature_def
          })
    saved_model_builder.save()
    return saved_model_dir

  def _MakeSavedModelV2(self, run_params):
    params = self._GetParamsCached()
    root = tracking.AutoTrackable()
    root.run = def_function.function(
        params.graph_fn, input_signature=params.input_specs)
    saved_model_dir = self._GetSavedModelDir(run_params, GraphState.ORIGINAL)
    logging.info("Saving input SavedModel to %s", saved_model_dir)
    save.save(root, saved_model_dir,
              {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: root.run})
    return saved_model_dir

  def _MakeSavedModel(self, run_params):
    if run_params.is_v2:
      return self._MakeSavedModelV2(run_params)
    return self._MakeSavedModelV1(run_params)

  def RunTest(self, run_params):
    with trace.Trace(run_params.test_name):
      should_run, reason_for_skipping = self.ShouldRunTest(run_params)
      if not should_run:
        return self.skipTest(reason_for_skipping)

      saved_model_dir = self._MakeSavedModel(run_params)

      np.random.seed(12345)  # Fix the seed so the test is deterministic.
      inputs_data = []
      input_specs = self._GetParamsCached().input_specs
      for dim_list in self._GetParamsCached().input_dims:
        assert len(input_specs) == len(dim_list)
        current_input_data = []
        for spec, np_shape in zip(input_specs, dim_list):
          np_dtype = spec.dtype.as_numpy_dtype()
          # Multiply the input by some constant to avoid all-zeros input for
          # integer types.
          scale = 10.0 if np.issubdtype(np_dtype, np.integer) else 1.0
          # TODO(laigd): add debug options. E.g. we can set the input data to be
          # continuous natural numbers:
          # seq = np.arange(np.prod(np_shape))
          # seq.resize(np_shape)
          # current_input_data.append(scale * seq.astype(np_dtype))
          data = (scale * np.random.random_sample(np_shape)).astype(np_dtype)
          if run_params.is_v2:
            with ops.device("/GPU:0"):
              data = ops.convert_to_tensor(data)
          current_input_data.append(data)
        inputs_data.append(current_input_data)

      # Verify the original graph.
      self._VerifyGraphDef(run_params, saved_model_dir, saved_model_dir,
                           GraphState.ORIGINAL)

      # Run the original graph without TensorRT to get the reference result.
      logging.info("Running original graph w/o TensorRT\n")
      ref_result = self._RunGraph(
          run_params,
          saved_model_dir,
          inputs_data,
          GraphState.ORIGINAL,
          num_runs=1)

      # Run calibration if necessary.
      if IsQuantizationWithCalibration(run_params):
        infer_saved_model_dir = self._GetCalibratedInferGraph(
            run_params, saved_model_dir, inputs_data)
        self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir,
                             GraphState.INFERENCE)
      elif not run_params.convert_online:
        infer_saved_model_dir = self._GetInferGraph(run_params, saved_model_dir)
        self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir,
                             GraphState.INFERENCE)
      else:
        infer_saved_model_dir = saved_model_dir

      # Run the inference graph, either using the converted graph or the
      # original graph with convert_online == True.
      logging.info("Running final inference graph\n")
      result = self._RunGraph(run_params, infer_saved_model_dir, inputs_data,
                              GraphState.INFERENCE)
      self.assertAllClose(
          ref_result,
          result,
          atol=self.ExpectedAbsoluteTolerance(run_params),
          rtol=self.ExpectedRelativeTolerance(run_params))

  def testIdempotence(self):
    # Tests that applying the TensorRT optimizer or the offline conversion
    # tool multiple times to the same graph results in the same graph.
    #
    # TODO(aaroey): implement this.
    pass


def _GetTestConfigsV1():
  """Returns the config combinations to run the test."""
  convert_online, convert_offline = True, False
  dynamic_engine, static_engine = True, False
  use_calibration, no_calibration = True, False

  # Add all possible test cases and let the derived test class decide
  # whether to run specific ones with ShouldRunTest().
  #
  # Note: INT8 without calibration behaves like FP32/FP16.
  opts = list(
      itertools.product([FP32, FP16, INT8], [convert_online, convert_offline],
                        [dynamic_engine, static_engine], [no_calibration]))
  # We always run calibration with the offline tool.
  # TODO(aaroey): static calibration engine is not supported yet.
  opts.append((INT8, convert_offline, dynamic_engine, use_calibration))
  return opts
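
# Illustrative count: the product above yields 3 precisions x 2 conversion
# modes x 2 engine types = 12 no-calibration configs; the appended
# INT8/offline/dynamic calibration entry brings the total to 13.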


def _GetTestConfigsV2():
  """Returns the config combinations to run the test."""
  convert_offline = False
  # TODO(laigd): add support for static_engine.
  dynamic_engine = True
  # TODO(laigd): add support for calibration.
  no_calibration = False

  # Add all possible test cases and let the derived test class decide
  # whether to run specific ones with ShouldRunTest().
  #
  # Note:
  # - In TF 2.0 the conversion always produces dynamic engines, so we don't
  #   test static engines here.
  # - For simplicity we don't test online conversion, which requires setting
  #   the Grappler config in the default eager context.
  # - INT8 without calibration behaves like FP32/FP16.
  opts = list(
      itertools.product([FP32, FP16, INT8], [convert_offline], [dynamic_engine],
                        [no_calibration]))
  # We always run calibration with the offline tool.
  # TODO(aaroey): INT8+calibration is not supported yet in V2.
  # opts.append((INT8, convert_offline, dynamic_engine, use_calibration))
  return opts


def _GetTest(run_params):
  """Gets a single test method based on the parameters."""

  def _Test(self):
    logging.info(
        "Running test %s with parameters: convert_online=%s, "
        "precision_mode=%s, dynamic_engine=%s", run_params.test_name,
        run_params.convert_online, run_params.precision_mode,
        run_params.dynamic_engine)
    self.RunTest(run_params)

  return _Test


def _AddTestsFor(test_class, is_v2):
  """Adds test methods to TfTrtIntegrationTestBase for a specific TF version."""
  opts = _GetTestConfigsV2() if is_v2 else _GetTestConfigsV1()
  for (precision_mode, convert_online, dynamic_engine, use_calibration) in opts:
    conversion = "OnlineConversion" if convert_online else "OfflineConversion"
    engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine"
    calibration_type = "UseCalibration" if use_calibration else "NoCalibration"
    test_name = "%s_%s_%s_%s_%s" % ("testTfTrtV2" if is_v2 else "testTfTrt",
                                    conversion, engine_type, precision_mode,
                                    calibration_type)
    run_params = RunParams(
        convert_online=convert_online,
        precision_mode=precision_mode,
        dynamic_engine=dynamic_engine,
        test_name=test_name,
        use_calibration=use_calibration,
        is_v2=is_v2)
    if is_v2:
      setattr(test_class, test_name,
              test_util.run_v2_only(_GetTest(run_params)))
    else:
      setattr(test_class, test_name,
              test_util.run_v1_only("", _GetTest(run_params)))
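
# Illustrative: one generated V1 method name is
# "testTfTrt_OfflineConversion_DynamicEngine_FP32_NoCalibration"; the V2
# counterparts start with "testTfTrtV2_".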


def _AddTests(test_class):
  """Adds test methods to TfTrtIntegrationTestBase."""
  _AddTestsFor(test_class, is_v2=False)
  _AddTestsFor(test_class, is_v2=True)


if is_tensorrt_enabled():
  os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
  _AddTests(TfTrtIntegrationTestBase)