# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads, converts, calibrates, and runs sample models."""

import abc
import collections
import functools
import itertools
import tempfile
import time
from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Union

from absl import logging
import numpy as np

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes as tf_dtypes
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops as framework_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.saved_model import load as saved_model_load
from tensorflow.python.saved_model import loader as saved_model_loader
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants


# pylint: disable=bad-whitespace
### Helper Functions
def _remove_graph_sequence_number(name: str) -> str:
  return name.split(":")[0]


def _get_concrete_tensor_shape(
    tensor_shape: tensor_shape_pb2.TensorShapeProto,
    batch_size: Optional[int] = None) -> Sequence[int]:
  """Gets a concrete tensor shape without dynamic dimensions."""
  if tensor_shape.unknown_rank:
    raise ValueError("Cannot generate random tensors for an unknown rank!")
  shape = [dim.size for dim in tensor_shape.dim]
  if not shape:
    raise ValueError("The tensor cannot have a rank of 0!")
  if shape[0] < 0:
    if batch_size is None or batch_size <= 0:
      raise ValueError("Must provide a valid batch size "
                       "as the tensor has a dynamic batch size!")
    shape[0] = batch_size
  if any(dim < 0 for dim in shape):
    raise ValueError("Cannot have dynamic dimensions except for batch size!")
  return shape
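# Illustrative example for _get_concrete_tensor_shape (an assumed shape, not
# executed here): a TensorShapeProto of [-1, 224, 224, 3] with batch_size=8
# resolves to [8, 224, 224, 3]; a dynamic dimension anywhere other than the
# batch dimension raises ValueError.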


def _generate_random_tensor_ops(shape: Sequence[int], dtype: tf_dtypes.DType,
                                name: str) -> framework_ops.Tensor:
  # Generate a random tensor in float32/int32 and cast it to the target data
  # type, as random_ops does not support all data types.
  random_dtype = tf_dtypes.float32 if dtype.is_floating else tf_dtypes.int32
  # tf.bool doesn't have a `max` attribute.
  dtype_max = 1 if dtype == tf_dtypes.bool else dtype.max
  return math_ops.cast(
      random_ops.random_uniform(
          shape=shape,
          dtype=random_dtype,
          # Limit the maximum value to 255 to simulate pixel values and avoid
          # generating large numbers that cause overflows.
          maxval=min(dtype_max, random_dtype.max, 255)),
      dtype=dtype,
      name=name)


def _generate_random_tensor_v1(tensor_info: meta_graph_pb2.TensorInfo,
                               batch_size: Optional[int] = None) -> np.ndarray:
  """Generates a random tensor based on the data type and tensor shape."""
  dtype = tf_dtypes.as_dtype(tensor_info.dtype)
  shape = _get_concrete_tensor_shape(tensor_info.tensor_shape, batch_size)
  with framework_ops.Graph().as_default() as graph, session.Session(
      graph=graph):
    return _generate_random_tensor_ops(
        shape=shape,
        dtype=dtype,
        name=_remove_graph_sequence_number(tensor_info.name)).eval()


def _generate_random_tensor_v2(
    tensor: framework_ops.Tensor,
    batch_size: Optional[int] = None) -> framework_ops.Tensor:
  """Generates a random tensor based on the data type and tensor shape."""
  shape = _get_concrete_tensor_shape(tensor.shape.as_proto(), batch_size)
  return _generate_random_tensor_ops(
      shape=shape, dtype=tensor.dtype, name=tensor.name)


# Models are repeatedly loaded for different TensorRT conversion settings.
# Caching the loaded models reduces repeated I/O.
@functools.lru_cache()
def load_meta_graph(
    saved_model_dir: str, saved_model_tags: Sequence[str],
    saved_model_signature_key: str) -> meta_graph_pb2.MetaGraphDef:
  """Loads a `tf.MetaGraphDef` in TF1."""
  with framework_ops.Graph().as_default() as graph, session.Session(
      graph=graph) as sess:
    meta_graph = saved_model_loader.load(
        sess=sess,
        export_dir=saved_model_dir,
        tags=saved_model_tags,
    )
    output_node_names = [
        _remove_graph_sequence_number(tensor.name) for tensor in
        meta_graph.signature_def[saved_model_signature_key].outputs.values()
    ]
    graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, meta_graph.graph_def, output_node_names))
    meta_graph.graph_def.CopyFrom(graph_def)
  return meta_graph


@functools.lru_cache()
def load_graph_func(saved_model_dir: str, saved_model_tags: Sequence[str],
                    saved_model_signature_key: str):
  """Loads a graph function in TF2."""
  imported = saved_model_load.load(
      export_dir=saved_model_dir, tags=saved_model_tags)
  graph_func = imported.signatures[saved_model_signature_key]
  return convert_to_constants.convert_variables_to_constants_v2(graph_func)

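# Note: both loaders are memoized with `functools.lru_cache`, so every argument
# must be hashable; in particular `saved_model_tags` should be a tuple of tag
# strings (as stored by `ModelConfig` below) rather than a list. Illustrative
# call (hypothetical SavedModel path, not executed here):
#
#   graph_func = load_graph_func(
#       saved_model_dir="/tmp/my_saved_model",
#       saved_model_tags=(tag_constants.SERVING,),
#       saved_model_signature_key=(
#           signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY))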

### Test Classes
class ModelConfig(
    collections.namedtuple("ModelConfig", [
        "saved_model_dir", "saved_model_tags", "saved_model_signature_key",
        "default_batch_size"
    ])):
  """Configurations for test models."""

  def __new__(cls,
              saved_model_dir: str,
              saved_model_tags: Sequence[str] = (tag_constants.SERVING,),
              saved_model_signature_key: str = (
                  signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY),
              default_batch_size: int = 1):
    return super(ModelConfig,
                 cls).__new__(cls, saved_model_dir, saved_model_tags,
                              saved_model_signature_key, default_batch_size)

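# Illustrative ModelConfig construction (hypothetical SavedModel path, not
# executed here); the tags and signature key default to the standard serving
# tag and signature:
#
#   config = ModelConfig(
#       saved_model_dir="/tmp/my_saved_model", default_batch_size=8)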

class TestResult(
    collections.namedtuple("TestResult", [
        "model_config", "enable_gpu", "output_names", "output_tensors",
        "model_latency", "trt_convert_params"
    ])):
  """Configuration and results for a single model test."""

  def __new__(cls,
              model_config: ModelConfig,
              enable_gpu: bool,
              output_names: Sequence[str],
              output_tensors: Sequence[np.ndarray],
              model_latency: List[float],
              trt_convert_params: Optional[trt.TrtConversionParams] = None):
    return super(TestResult,
                 cls).__new__(cls, model_config, enable_gpu, output_names,
                              output_tensors, model_latency, trt_convert_params)


class TestResultCollection(
    collections.namedtuple("TestResultCollection", [
        "test_name", "model_config", "cpu_base_result", "gpu_base_result",
        "trt_results"
    ])):
  """Configuration and results for a series of model tests."""

  def __new__(cls,
              test_name: str,
              model_config: ModelConfig,
              cpu_base_result: TestResult,
              gpu_base_result: TestResult,
              trt_results: Sequence[TestResult] = tuple()):
    return super(TestResultCollection,
                 cls).__new__(cls, test_name, model_config, cpu_base_result,
                              gpu_base_result, trt_results)

  @property
  def results(self) -> Iterable[TestResult]:
    return filter(
        lambda x: x is not None,
        itertools.chain([self.cpu_base_result, self.gpu_base_result],
                        self.trt_results))


class _ModelHandlerBase(metaclass=abc.ABCMeta):
  """Base class for running a model."""

  def __init__(self, model_config: ModelConfig):
    self._model_config = model_config

  def __str__(self) -> str:
    return str(self._model_config)

  def __repr__(self) -> str:
    return "{}({})".format(self.__class__.__name__, str(self))

  @property
  def model_config(self) -> ModelConfig:
    return self._model_config

  @property
  def input_tensor_names(self) -> Sequence[str]:
    """Names of input tensors."""

  @property
  def output_tensor_names(self) -> Sequence[str]:
    """Names of output tensors."""

  @abc.abstractmethod
  def generate_random_inputs(
      self,
      batch_size: Optional[int] = None
  ) -> Mapping[str, Union[np.ndarray, framework_ops.Tensor]]:
    """Generates a mapping from names to input tensors."""

  @abc.abstractmethod
  def run(self,
          inputs=None,
          warmup_iterations: int = 10,
          benchmark_iterations: int = 100,
          enable_gpu: bool = True) -> TestResult:
    """Runs the model with provided or randomly generated input tensors.

    Args:
      inputs: Mapping from names to input ndarrays in TF1, or a sequence of
        tensors in TF2. If `None`, randomly generated inputs will be used
        instead.
      warmup_iterations: Number of inferences to warm up the runtime.
      benchmark_iterations: Number of inferences to measure the latency.
      enable_gpu: Whether running on a GPU is allowed.

    Returns:
      `TestResult` summarizing latency and numerics information.
    """


class ModelHandlerV1(_ModelHandlerBase):
  """Runs a model in TF1."""

  @property
  def meta_graph(self) -> meta_graph_pb2.MetaGraphDef:
    return load_meta_graph(
        saved_model_dir=self.model_config.saved_model_dir,
        saved_model_tags=self.model_config.saved_model_tags,
        saved_model_signature_key=self.model_config.saved_model_signature_key)

  @property
  def input_tensor_info(self) -> Mapping[str, meta_graph_pb2.TensorInfo]:
    return self.meta_graph.signature_def[
        self.model_config.saved_model_signature_key].inputs

  @property
  def output_tensor_info(self) -> Mapping[str, meta_graph_pb2.TensorInfo]:
    return self.meta_graph.signature_def[
        self.model_config.saved_model_signature_key].outputs

  @property
  def input_tensor_names(self) -> Sequence[str]:
    return [info.name for info in self.input_tensor_info.values()]

  @property
  def output_tensor_names(self) -> Sequence[str]:
    return [info.name for info in self.output_tensor_info.values()]

  def generate_random_inputs(self,
                             batch_size: Optional[int] = None
                            ) -> Mapping[str, np.ndarray]:
    batch_size = batch_size or self.model_config.default_batch_size
    return {
        tensor_info.name: _generate_random_tensor_v1(tensor_info, batch_size)
        for tensor_info in self.input_tensor_info.values()
    }

  def run(self,
          inputs: Optional[Mapping[str, np.ndarray]] = None,
          warmup_iterations=10,
          benchmark_iterations=100,
          enable_gpu=True) -> TestResult:
    inputs = inputs or self.generate_random_inputs()
    config_proto = None
    if not enable_gpu:
      config_proto = config_pb2.ConfigProto(device_count={"CPU": 1, "GPU": 0})
    logging.info("Running model inference!")
    with framework_ops.Graph().as_default():
      with session.Session(config=config_proto) as sess:
        importer.import_graph_def(self.meta_graph.graph_def, name="")
        try:
          output_tensor_names = self.output_tensor_names
          for _ in range(warmup_iterations):
            sess.run(fetches=output_tensor_names, feed_dict=inputs)
          latency = []
          for _ in range(benchmark_iterations):
            before = time.time()
            outputs = sess.run(fetches=output_tensor_names, feed_dict=inputs)
            latency.append(time.time() - before)
        except Exception as exc:
          raise RuntimeError("Failed to run model inference! "
                             "Model information: {}".format(str(self))) from exc
    return TestResult(
        model_config=self.model_config,
        enable_gpu=enable_gpu,
        model_latency=latency,
        output_names=self.output_tensor_names,
        output_tensors=outputs)

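# Illustrative ModelHandlerV1 usage (hypothetical SavedModel path, not executed
# here). In TF1, custom inputs are a feed_dict keyed by tensor names such as
# "input:0":
#
#   handler = ModelHandlerV1(ModelConfig("/tmp/my_tf1_saved_model"))
#   result = handler.run(warmup_iterations=5, benchmark_iterations=50)
#   logging.info("median latency: %f s", np.median(result.model_latency))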

class ModelHandlerV2(_ModelHandlerBase):
  """Runs a model in TF2."""

  @property
  def graph_func(self):
    graph_func = load_graph_func(
        saved_model_dir=self.model_config.saved_model_dir,
        saved_model_tags=self.model_config.saved_model_tags,
        saved_model_signature_key=self.model_config.saved_model_signature_key)
    return convert_to_constants.convert_variables_to_constants_v2(graph_func)

  @property
  def input_tensor_names(self):
    return [tensor.name for tensor in self.graph_func.inputs]

  @property
  def output_tensor_names(self):
    return [tensor.name for tensor in self.graph_func.outputs]

  def generate_random_inputs(self,
                             batch_size: Optional[int] = None
                            ) -> Sequence[framework_ops.Tensor]:
    batch_size = batch_size or self.model_config.default_batch_size
    return [
        _generate_random_tensor_v2(tensor, batch_size)
        for tensor in self.graph_func.inputs
    ]

  def run(self,
          inputs: Optional[Sequence[framework_ops.Tensor]] = None,
          warmup_iterations=10,
          benchmark_iterations=100,
          enable_gpu=True) -> TestResult:
    inputs = inputs or self.generate_random_inputs()
    try:
      device = "/device:gpu:0" if enable_gpu else "/device:cpu:0"
      with framework_ops.device(device):
        for _ in range(warmup_iterations):
          self.graph_func(*inputs)
        latency = []
        for _ in range(benchmark_iterations):
          before = time.time()
          outputs = self.graph_func(*inputs)
          latency.append(time.time() - before)
    except Exception as exc:
      raise RuntimeError("Failed to run model inference! "
                         "Model information: {}".format(str(self))) from exc
    return TestResult(
        model_config=self.model_config,
        enable_gpu=enable_gpu,
        model_latency=latency,
        output_names=self.output_tensor_names,
        output_tensors=outputs)

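# Illustrative ModelHandlerV2 usage (hypothetical SavedModel path, not executed
# here). In TF2, custom inputs are a sequence of tensors ordered like
# `graph_func.inputs`:
#
#   handler = ModelHandlerV2(ModelConfig("/tmp/my_tf2_saved_model"))
#   cpu_result = handler.run(enable_gpu=False)
#   gpu_result = handler.run(enable_gpu=True)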

class _TrtModelHandlerBase(_ModelHandlerBase):
  """Base class for converting and running a model."""

  def __init__(
      self,
      model_config: ModelConfig,
      trt_convert_params: trt.TrtConversionParams,
  ):
    super(_TrtModelHandlerBase, self).__init__(model_config)
    self._trt_convert_params = trt_convert_params

    self._converter = self._create_converter(trt_convert_params)
    self._conversion_is_saved = False

  @abc.abstractmethod
  def _create_converter(self, trt_convert_params: trt.TrtConversionParams):
    """Creates a converter for the corresponding TF version."""

  @abc.abstractmethod
  def _check_conversion(self, conversion_output):
    """Checks if conversion output has any TensorRT engines."""

  def _check_contains_trt_engine(self, graph_def: graph_pb2.GraphDef):
    if "TRTEngineOp" not in [node.op for node in graph_def.node]:
      raise RuntimeError("Failed to convert to TensorRT! "
                         "Model Information: {}".format(str(self)))

  def __str__(self) -> str:
    base = super(_TrtModelHandlerBase, self).__str__()
    return "{}, TrtConversionParams: {}".format(base,
                                                str(self._trt_convert_params))

  @property
  def trt_convert_params(self) -> trt.TrtConversionParams:
    return self._trt_convert_params

  @abc.abstractmethod
  def convert(self,
              calibration_inputs: Optional[Mapping[str, np.ndarray]] = None,
              num_runs=1) -> None:
    """Converts the model with TensorRT and calibrates if using INT8 precision mode.

    Args:
      calibration_inputs: Mapping from input names to ndarrays in TF1, or a
        sequence of tensors in TF2. Used as calibration data.
      num_runs: Number of calibration runs.
    """

  def save(self,
           output_saved_model_dir: Optional[str] = None,
           overwrite=True) -> None:
    """Saves a TensorRT converted model."""
    if self._conversion_is_saved and not overwrite:
      return
    output_saved_model_dir = output_saved_model_dir or tempfile.mkdtemp()
    logging.info("Saving TensorRT model to %s!", output_saved_model_dir)
    self._converter.save(output_saved_model_dir)
    self._model_config = self.model_config._replace(
        saved_model_dir=output_saved_model_dir)
    self._conversion_is_saved = True


class TrtModelHandlerV1(_TrtModelHandlerBase, ModelHandlerV1):
  """Converts a TF1 model with TensorRT and runs the converted model."""

  def _create_converter(self, trt_convert_params: trt.TrtConversionParams):
    conversion_nodes_denylist = self.output_tensor_names
    return trt.TrtGraphConverter(
        input_saved_model_dir=self.model_config.saved_model_dir,
        input_saved_model_tags=self.model_config.saved_model_tags,
        input_saved_model_signature_key=(
            self.model_config.saved_model_signature_key),
        nodes_denylist=conversion_nodes_denylist,
        max_workspace_size_bytes=trt_convert_params.max_workspace_size_bytes,
        precision_mode=trt_convert_params.precision_mode,
        minimum_segment_size=trt_convert_params.minimum_segment_size,
        maximum_cached_engines=trt_convert_params.maximum_cached_engines,
        use_calibration=trt_convert_params.use_calibration,
        max_batch_size=self.model_config.default_batch_size,
        is_dynamic_op=False,
    )

  _check_conversion = _TrtModelHandlerBase._check_contains_trt_engine

  def convert(self,
              calibration_inputs: Optional[Mapping[str, np.ndarray]] = None,
              num_runs=1) -> None:
    logging.info("Converting with TensorRT!")
    self._check_conversion(self._converter.convert())

    if (self.trt_convert_params.precision_mode == trt.TrtPrecisionMode.INT8 and
        self.trt_convert_params.use_calibration):
      logging.info("Calibrating with TensorRT!")
      if not calibration_inputs:
        raise ValueError("Must provide calibration data "
                         "when using TensorRT calibration!")
      try:
        self._converter.calibrate(
            fetch_names=self.output_tensor_names,
            num_runs=num_runs,
            feed_dict_fn=lambda: calibration_inputs)
      except Exception as exc:
        raise RuntimeError("Failed to calibrate! "
                           "Model Information: {}".format(str(self))) from exc

  def run(self,
          inputs: Optional[Mapping[str, np.ndarray]] = None,
          warmup_iterations=10,
          benchmark_iterations=100) -> TestResult:
    self.save(overwrite=False)
    self._check_conversion(self.meta_graph.graph_def)
    logging.info("Running with TensorRT!")
    test_result = ModelHandlerV1.run(
        self, inputs, warmup_iterations, benchmark_iterations, enable_gpu=True)
    return test_result._replace(trt_convert_params=self._trt_convert_params)


class TrtModelHandlerV2(_TrtModelHandlerBase, ModelHandlerV2):
  """Converts a TF2 model with TensorRT and runs the converted model."""

  def _create_converter(self, trt_convert_params: trt.TrtConversionParams):
    return trt.TrtGraphConverterV2(
        input_saved_model_dir=self.model_config.saved_model_dir,
        input_saved_model_tags=self.model_config.saved_model_tags,
        input_saved_model_signature_key=(
            self.model_config.saved_model_signature_key),
        conversion_params=trt_convert_params)

  def _check_conversion(self, graph_func):
    graph_def = graph_func.graph.as_graph_def()
    self._check_contains_trt_engine(graph_def)

  def convert(self,
              calibration_inputs: Optional[Sequence[
                  framework_ops.Tensor]] = None,
              num_runs=1) -> None:
    logging.info("Converting with TensorRT!")

    calibration_input_fn = None
    if (self.trt_convert_params.precision_mode == trt.TrtPrecisionMode.INT8 and
        self.trt_convert_params.use_calibration):
      logging.info("Calibrating with TensorRT during conversion!")
      if not calibration_inputs:
        raise ValueError("Must provide calibration data "
                         "when using TensorRT calibration!")

      def gets_calibration_input():
        for _ in range(num_runs):
          yield calibration_inputs

      calibration_input_fn = gets_calibration_input

    self._check_conversion(self._converter.convert(calibration_input_fn))

  def run(self,
          inputs: Optional[Sequence[framework_ops.Tensor]] = None,
          warmup_iterations=10,
          benchmark_iterations=100) -> TestResult:
    self.save(overwrite=False)
    self._check_conversion(self.graph_func)
    logging.info("Running with TensorRT!")
    test_result = ModelHandlerV2.run(
        self, inputs, warmup_iterations, benchmark_iterations, enable_gpu=True)
    return test_result._replace(trt_convert_params=self._trt_convert_params)

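# Illustrative TrtModelHandlerV2 usage (hypothetical SavedModel path and
# conversion parameters, not executed here). INT8 with `use_calibration`
# additionally requires `calibration_inputs` to be passed to `convert`:
#
#   params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
#       precision_mode=trt.TrtPrecisionMode.FP16)
#   handler = TrtModelHandlerV2(
#       ModelConfig("/tmp/my_tf2_saved_model"), trt_convert_params=params)
#   handler.convert()
#   result = handler.run(warmup_iterations=5, benchmark_iterations=50)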

class _ModelHandlerManagerBase(metaclass=abc.ABCMeta):
  """Manages a series of ModelHandlers for aggregated testing/benchmarking."""

  def __init__(
      self, name: str, model_config: ModelConfig,
      default_trt_convert_params: trt.TrtConversionParams,
      trt_convert_params_updater: Callable[[trt.TrtConversionParams],
                                           Iterable[trt.TrtConversionParams]]):
    self._ori_model = self.model_handler_cls(model_config)
    self._trt_models = []
    for trt_convert_params in trt_convert_params_updater(
        default_trt_convert_params):
      trt_model = self.trt_model_handler_cls(
          model_config, trt_convert_params=trt_convert_params)
      self._trt_models.append(trt_model)

    self._name = name
    self._result_collection = None

  def __str__(self) -> str:
    return "Input Model: {}".format(str(self._ori_model))

  def __repr__(self) -> str:
    return "{}({})".format(self.__class__.__name__, str(self))

  @property
  @classmethod
  @abc.abstractmethod
  def model_handler_cls(cls):
    """The model handler class: ModelHandlerV1 or ModelHandlerV2."""

  @property
  @classmethod
  @abc.abstractmethod
  def trt_model_handler_cls(cls):
    """The TensorRT model handler class: TrtModelHandlerV1 or TrtModelHandlerV2."""

  @property
  def name(self) -> str:
    return self._name

  @property
  def model_config(self) -> ModelConfig:
    return self._ori_model.model_config

  def generate_random_inputs(self, batch_size: Optional[int] = None):
    return self._ori_model.generate_random_inputs(batch_size)

  def convert(self, calibration_inputs=None, num_runs=1) -> None:
    """Converts models with TensorRT and calibrates if using INT8 precision mode.

    Args:
      calibration_inputs: Mapping from input names to ndarrays in TF1, or a
        sequence of tensors in TF2. Used as calibration data.
      num_runs: Number of calibration runs.
    """
    for trt_model in self._trt_models:
      trt_model.convert(calibration_inputs, num_runs)

  def run(self,
          inputs=None,
          warmup_iterations: int = 10,
          benchmark_iterations: int = 100) -> TestResultCollection:
    """Runs model inference with provided or randomly generated input tensors.

    Args:
      inputs: Mapping from names to input ndarrays in TF1, or a sequence of
        tensors in TF2. If `None`, randomly generated input tensors will be used
        instead.
      warmup_iterations: Number of inferences to warm up the runtime.
      benchmark_iterations: Number of inferences to measure the latency.

    Returns:
      `TestResultCollection` summarizing latency and numerics information for
      different TensorRT conversion settings.
    """
    inputs = inputs or self.generate_random_inputs()

    def run_model(model, **kwargs):
      return model.run(inputs, warmup_iterations, benchmark_iterations,
                       **kwargs)

    # Some models include operations that can only run on GPU.
    try:
      cpu_base_result = run_model(self._ori_model, enable_gpu=False)
    except RuntimeError as err:
      logging.info("%s cannot run on CPU. Reason: %s.",
                   self._ori_model.model_config, err)
      cpu_base_result = None
    gpu_base_result = run_model(self._ori_model, enable_gpu=True)
    trt_results = list(map(run_model, self._trt_models))

    return TestResultCollection(
        test_name=self._name,
        model_config=self.model_config,
        cpu_base_result=cpu_base_result,
        gpu_base_result=gpu_base_result,
        trt_results=trt_results)


class ModelHandlerManagerV1(_ModelHandlerManagerBase):
  """Manages a series of ModelHandlers for aggregated testing/benchmarking in TF1."""

  model_handler_cls = ModelHandlerV1
  trt_model_handler_cls = TrtModelHandlerV1


class ModelHandlerManagerV2(_ModelHandlerManagerBase):
  """Manages a series of ModelHandlers for aggregated testing/benchmarking in TF2."""

  model_handler_cls = ModelHandlerV2
  trt_model_handler_cls = TrtModelHandlerV2

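# Illustrative end-to-end usage of ModelHandlerManagerV2 (hypothetical
# SavedModel path and parameter updater, not executed here):
#
#   def updater(params):
#     for precision in (trt.TrtPrecisionMode.FP32, trt.TrtPrecisionMode.FP16):
#       yield params._replace(precision_mode=precision)
#
#   manager = ModelHandlerManagerV2(
#       name="sample_benchmark",
#       model_config=ModelConfig("/tmp/my_tf2_saved_model"),
#       default_trt_convert_params=trt.DEFAULT_TRT_CONVERSION_PARAMS,
#       trt_convert_params_updater=updater)
#   manager.convert()
#   collection = manager.run(warmup_iterations=5, benchmark_iterations=50)
#   for result in collection.results:
#     logging.info("median latency: %f s", np.median(result.model_latency))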