# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads, converts, calibrates, and runs sample models."""

import abc
import collections
import functools
import itertools
import tempfile
import time
from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Union

from absl import logging
import numpy as np

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes as tf_dtypes
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops as framework_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.saved_model import load as saved_model_load
from tensorflow.python.saved_model import loader as saved_model_loader
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants


# pylint: disable=bad-whitespace
### Helper Functions
def _remove_graph_sequence_number(name: str) -> str:
  return name.split(":")[0]


def _get_concrete_tensor_shape(
    tensor_shape: tensor_shape_pb2.TensorShapeProto,
    batch_size: Optional[int] = None) -> Sequence[int]:
  """Gets a concrete tensor shape without dynamic dimensions."""
  if tensor_shape.unknown_rank:
    raise ValueError("Cannot generate random tensors for an unknown rank!")
  shape = [dim.size for dim in tensor_shape.dim]
  if not shape:
    raise ValueError("The tensor cannot have a rank of 0!")
  if shape[0] < 0:
    if batch_size is None or batch_size <= 0:
      raise ValueError("Must provide a valid batch size "
                       "as the tensor has a dynamic batch size!")
    shape[0] = batch_size
  if any(filter(lambda x: x < 0, shape)):
    raise ValueError("Cannot have dynamic dimensions except for batch size!")
  return shape


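# For example, a signature input declared with shape [-1, 224, 224, 3] has a
# dynamic batch dimension: with batch_size=8 the helper above returns the
# concrete shape [8, 224, 224, 3], while any other dynamic dimension (for
# example [-1, -1, 3]) raises a ValueError.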
def _generate_random_tensor_ops(shape: Sequence[int], dtype: tf_dtypes.DType,
                                name: str) -> framework_ops.Tensor:
  # Generates a random tensor in float32/int32 and casts it to the requested
  # datatype, as random_ops doesn't support all the datatypes.
  random_dtype = tf_dtypes.float32 if dtype.is_floating else tf_dtypes.int32
  # tf.bool doesn't have a `max` attribute.
  dtype_max = 1 if dtype == tf_dtypes.bool else dtype.max
  return math_ops.cast(
      random_ops.random_uniform(
          shape=shape,
          dtype=random_dtype,
          # Limits the maximum value to 255 to simulate pixel values and avoid
          # overflows caused by generating large numbers.
          maxval=min(dtype_max, random_dtype.max, 255)),
      dtype=dtype,
      name=name)


def _generate_random_tensor_v1(tensor_info: meta_graph_pb2.TensorInfo,
                               batch_size: Optional[int] = None) -> np.ndarray:
  """Generates a random tensor based on the data type and tensor shape."""
  dtype = tf_dtypes.as_dtype(tensor_info.dtype)
  shape = _get_concrete_tensor_shape(tensor_info.tensor_shape, batch_size)
  with framework_ops.Graph().as_default() as graph, session.Session(
      graph=graph):
    return _generate_random_tensor_ops(
        shape=shape,
        dtype=dtype,
        name=_remove_graph_sequence_number(tensor_info.name)).eval()


def _generate_random_tensor_v2(
    tensor: framework_ops.Tensor,
    batch_size: Optional[int] = None) -> framework_ops.Tensor:
  """Generates a random tensor based on the data type and tensor shape."""
  shape = _get_concrete_tensor_shape(tensor.shape.as_proto(), batch_size)
  return _generate_random_tensor_ops(
      shape=shape, dtype=tensor.dtype, name=tensor.name)


# Models are repeatedly loaded for different TensorRT conversion settings.
# Using a cache can reduce I/O.
@functools.lru_cache()
def load_meta_graph(
    saved_model_dir: str, saved_model_tags: str,
    saved_model_signature_key: str) -> meta_graph_pb2.MetaGraphDef:
  """Loads a `tf.MetaGraphDef` in TF1."""
  with framework_ops.Graph().as_default() as graph, session.Session(
      graph=graph) as sess:
    meta_graph = saved_model_loader.load(
        sess=sess,
        export_dir=saved_model_dir,
        tags=saved_model_tags,
    )
    output_node_names = [
        _remove_graph_sequence_number(tensor.name) for tensor in
        meta_graph.signature_def[saved_model_signature_key].outputs.values()
    ]
    graph_def = (
        convert_to_constants.convert_variables_to_constants_from_session_graph(
            sess, meta_graph.graph_def, output_node_names))
    meta_graph.graph_def.CopyFrom(graph_def)
    return meta_graph


@functools.lru_cache()
def load_graph_func(saved_model_dir: str, saved_model_tags: str,
                    saved_model_signature_key: str):
  """Loads a graph function in TF2."""
  imported = saved_model_load.load(
      export_dir=saved_model_dir, tags=saved_model_tags)
  graph_func = imported.signatures[saved_model_signature_key]
  return convert_to_constants.convert_variables_to_constants_v2(graph_func)


### Test Classes
class ModelConfig(
    collections.namedtuple("ModelConfig", [
        "saved_model_dir", "saved_model_tags", "saved_model_signature_key",
        "default_batch_size"
    ])):
  """Configurations for test models."""

  def __new__(cls,
              saved_model_dir: str,
              saved_model_tags: Sequence[str] = (tag_constants.SERVING,),
              saved_model_signature_key: str = (
                  signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY),
              default_batch_size: int = 1):
    return super(ModelConfig,
                 cls).__new__(cls, saved_model_dir, saved_model_tags,
                              saved_model_signature_key, default_batch_size)


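# An illustrative configuration; the SavedModel path below is a placeholder,
# not a model shipped with these tests:
#
#   config = ModelConfig(
#       saved_model_dir="/tmp/mnist_saved_model",
#       saved_model_tags=(tag_constants.SERVING,),
#       default_batch_size=8)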
class TestResult(
    collections.namedtuple("TestResult", [
        "model_config", "enable_gpu", "output_names", "output_tensors",
        "model_latency", "trt_convert_params"
    ])):
  """Configuration and results for a single model testing."""

  def __new__(cls,
              model_config: ModelConfig,
              enable_gpu: bool,
              output_names: Sequence[str],
              output_tensors: Sequence[np.ndarray],
              model_latency: List[float],
              trt_convert_params: Optional[trt.TrtConversionParams] = None):
    return super(TestResult,
                 cls).__new__(cls, model_config, enable_gpu, output_names,
                              output_tensors, model_latency,
                              trt_convert_params)


class TestResultCollection(
    collections.namedtuple("TestResultCollection", [
        "test_name", "model_config", "cpu_base_result", "gpu_base_result",
        "trt_results"
    ])):
  """Configuration and results for a series of model testing."""

  def __new__(cls,
              test_name: str,
              model_config: ModelConfig,
              cpu_base_result: TestResult,
              gpu_base_result: TestResult,
              trt_results: Sequence[TestResult] = tuple()):
    return super(TestResultCollection,
                 cls).__new__(cls, test_name, model_config, cpu_base_result,
                              gpu_base_result, trt_results)

  @property
  def results(self) -> Iterable[TestResult]:
    return filter(
        lambda x: x is not None,
        itertools.chain([self.cpu_base_result, self.gpu_base_result],
                        self.trt_results))


class _ModelHandlerBase(metaclass=abc.ABCMeta):
  """Base class for running a model."""

  def __init__(self, model_config: ModelConfig):
    self._model_config = model_config

  def __str__(self) -> str:
    return str(self._model_config)

  def __repr__(self) -> str:
    return "{}({})".format(self.__class__.__name__, str(self))

  @property
  def model_config(self) -> ModelConfig:
    return self._model_config

  @property
  def input_tensor_names(self) -> Sequence[str]:
    """Names of input tensors."""

  @property
  def output_tensor_names(self) -> Sequence[str]:
    """Names of output tensors."""

  @abc.abstractmethod
  def generate_random_inputs(
      self,
      batch_size: Optional[int] = None
  ) -> Mapping[str, Union[np.ndarray, framework_ops.Tensor]]:
    """Generates mapping from names to input tensors."""

  @abc.abstractmethod
  def run(self,
          inputs=None,
          warmup_iterations: int = 10,
          benchmark_iterations: int = 100,
          enable_gpu: bool = True) -> TestResult:
    """Runs the model with provided or randomly generated input tensors.

    Args:
      inputs: Mapping from names to input ndarrays in TF1, or a sequence of
        tensors in TF2. If `None`, randomly generated inputs will be used
        instead.
      warmup_iterations: Number of inferences to warm up the runtime.
      benchmark_iterations: Number of inferences to measure the latency.
      enable_gpu: Whether it is allowed to use GPU or not.

    Returns:
      `TestResult` summarizing latency and numerics information.
    """


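# The concrete handlers below follow the same pattern. An illustrative sketch,
# assuming `config` is a ModelConfig for an existing TF2 SavedModel:
#
#   handler = ModelHandlerV2(config)
#   inputs = handler.generate_random_inputs(batch_size=8)
#   result = handler.run(inputs, warmup_iterations=5, benchmark_iterations=50)
#   logging.info("Mean latency: %s", np.mean(result.model_latency))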
class ModelHandlerV1(_ModelHandlerBase):
  """Runs a model in TF1."""

  @property
  def meta_graph(self) -> meta_graph_pb2.MetaGraphDef:
    return load_meta_graph(
        saved_model_dir=self.model_config.saved_model_dir,
        saved_model_tags=self.model_config.saved_model_tags,
        saved_model_signature_key=self.model_config.saved_model_signature_key)

  @property
  def input_tensor_info(self) -> Mapping[str, meta_graph_pb2.TensorInfo]:
    return self.meta_graph.signature_def[
        self.model_config.saved_model_signature_key].inputs

  @property
  def output_tensor_info(self) -> Mapping[str, meta_graph_pb2.TensorInfo]:
    return self.meta_graph.signature_def[
        self.model_config.saved_model_signature_key].outputs

  @property
  def input_tensor_names(self) -> Sequence[str]:
    return [info.name for info in self.input_tensor_info.values()]

  @property
  def output_tensor_names(self) -> Sequence[str]:
    return [info.name for info in self.output_tensor_info.values()]

  def generate_random_inputs(self,
                             batch_size: Optional[int] = None
                            ) -> Mapping[str, np.ndarray]:
    batch_size = batch_size or self.model_config.default_batch_size
    return {
        tensor_info.name: _generate_random_tensor_v1(tensor_info, batch_size)
        for tensor_info in self.input_tensor_info.values()
    }

  def run(self,
          inputs: Optional[Mapping[str, np.ndarray]] = None,
          warmup_iterations=10,
          benchmark_iterations=100,
          enable_gpu=True) -> TestResult:
    inputs = inputs or self.generate_random_inputs()
    config_proto = None
    if not enable_gpu:
      config_proto = config_pb2.ConfigProto(device_count={"CPU": 1, "GPU": 0})
    logging.info("Running model inference!")
    with framework_ops.Graph().as_default():
      with session.Session(config=config_proto) as sess:
        importer.import_graph_def(self.meta_graph.graph_def, name="")
        try:
          output_tensor_names = self.output_tensor_names
          for _ in range(warmup_iterations):
            sess.run(fetches=output_tensor_names, feed_dict=inputs)
          latency = []
          for _ in range(benchmark_iterations):
            before = time.time()
            outputs = sess.run(fetches=output_tensor_names, feed_dict=inputs)
            latency.append(time.time() - before)
        except Exception as exc:
          raise RuntimeError("Failed to run model inference! "
                             "Model information: {}".format(str(self))) from exc
        return TestResult(
            model_config=self.model_config,
            enable_gpu=enable_gpu,
            model_latency=latency,
            output_names=self.output_tensor_names,
            output_tensors=outputs)


class ModelHandlerV2(_ModelHandlerBase):
  """Runs a model in TF2."""

  @property
  def graph_func(self):
    graph_func = load_graph_func(
        saved_model_dir=self.model_config.saved_model_dir,
        saved_model_tags=self.model_config.saved_model_tags,
        saved_model_signature_key=self.model_config.saved_model_signature_key)
    return convert_to_constants.convert_variables_to_constants_v2(graph_func)

  @property
  def input_tensor_names(self):
    return [tensor.name for tensor in self.graph_func.inputs]

  @property
  def output_tensor_names(self):
    return [tensor.name for tensor in self.graph_func.outputs]

  def generate_random_inputs(self,
                             batch_size: Optional[int] = None
                            ) -> Sequence[framework_ops.Tensor]:
    batch_size = batch_size or self.model_config.default_batch_size
    return [
        _generate_random_tensor_v2(tensor, batch_size)
        for tensor in self.graph_func.inputs
    ]

  def run(self,
          inputs: Optional[Sequence[framework_ops.Tensor]] = None,
          warmup_iterations=10,
          benchmark_iterations=100,
          enable_gpu=True) -> TestResult:
    inputs = inputs or self.generate_random_inputs()
    try:
      device = "/device:gpu:0" if enable_gpu else "/device:cpu:0"
      with framework_ops.device(device):
        for _ in range(warmup_iterations):
          self.graph_func(*inputs)
        latency = []
        for _ in range(benchmark_iterations):
          before = time.time()
          outputs = self.graph_func(*inputs)
          latency.append(time.time() - before)
    except Exception as exc:
      raise RuntimeError("Failed to run model inference! "
                         "Model information: {}".format(str(self))) from exc
    return TestResult(
        model_config=self.model_config,
        enable_gpu=enable_gpu,
        model_latency=latency,
        output_names=self.output_tensor_names,
        output_tensors=outputs)


class _TrtModelHandlerBase(_ModelHandlerBase):
  """Base class for converting and running a model."""

  def __init__(
      self,
      model_config: ModelConfig,
      trt_convert_params: trt.TrtConversionParams,
  ):
    super(_TrtModelHandlerBase, self).__init__(model_config)
    self._trt_convert_params = trt_convert_params

    self._converter = self._create_converter(trt_convert_params)
    self._conversion_is_saved = False

  @abc.abstractmethod
  def _create_converter(self, trt_convert_params: trt.TrtConversionParams):
    """Creates a converter for the corresponding TF version."""

  @abc.abstractmethod
  def _check_conversion(self, conversion_output):
    """Checks if conversion output has any TensorRT engines."""

  def _check_contains_trt_engine(self, graph_def: graph_pb2.GraphDef):
    if "TRTEngineOp" not in [node.op for node in graph_def.node]:
      raise RuntimeError("Failed to convert to TensorRT! "
                         "Model Information: {}".format(str(self)))

  def __str__(self) -> str:
    base = super(_TrtModelHandlerBase, self).__str__()
    return "{}, TrtConversionParams: {}".format(base,
                                                str(self._trt_convert_params))

  @property
  def trt_convert_params(self) -> trt.TrtConversionParams:
    return self._trt_convert_params

  @abc.abstractmethod
  def convert(self,
              calibration_inputs: Optional[Mapping[str, np.ndarray]] = None,
              num_runs=1) -> None:
    """Converts the model with TensorRT and calibrates if using INT8 precision mode.

    Args:
      calibration_inputs: Mapping from input names to ndarrays in TF1, or a
        sequence of tensors in TF2, used as calibration data.
      num_runs: Number of calibration runs.
    """

  def save(self,
           output_saved_model_dir: Optional[str] = None,
           overwrite=True) -> None:
    """Saves a TensorRT converted model."""
    if self._conversion_is_saved and not overwrite:
      return
    output_saved_model_dir = output_saved_model_dir or tempfile.mkdtemp()
    logging.info("Saving TensorRT model to %s!", output_saved_model_dir)
    self._converter.save(output_saved_model_dir)
    self._model_config = self.model_config._replace(
        saved_model_dir=output_saved_model_dir)
    self._conversion_is_saved = True


class TrtModelHandlerV1(_TrtModelHandlerBase, ModelHandlerV1):
  """Converts a TF1 model with TensorRT and runs the converted model."""

  def _create_converter(self, trt_convert_params: trt.TrtConversionParams):
    conversion_nodes_denylist = self.output_tensor_names
    return trt.TrtGraphConverter(
        input_saved_model_dir=self.model_config.saved_model_dir,
        input_saved_model_tags=self.model_config.saved_model_tags,
        input_saved_model_signature_key=(
            self.model_config.saved_model_signature_key),
        nodes_denylist=conversion_nodes_denylist,
        max_workspace_size_bytes=trt_convert_params.max_workspace_size_bytes,
        precision_mode=trt_convert_params.precision_mode,
        minimum_segment_size=trt_convert_params.minimum_segment_size,
        maximum_cached_engines=trt_convert_params.maximum_cached_engines,
        use_calibration=trt_convert_params.use_calibration,
        max_batch_size=self.model_config.default_batch_size,
        is_dynamic_op=False,
    )

  _check_conversion = _TrtModelHandlerBase._check_contains_trt_engine

  def convert(self,
              calibration_inputs: Optional[Mapping[str, np.ndarray]] = None,
              num_runs=1) -> None:
    logging.info("Converting with TensorRT!")
    self._check_conversion(self._converter.convert())

    if (self.trt_convert_params.precision_mode == trt.TrtPrecisionMode.INT8 and
        self.trt_convert_params.use_calibration):
      logging.info("Calibrating with TensorRT!")
      if not calibration_inputs:
        raise ValueError("Must provide calibration data "
                         "when using TensorRT calibration!")
      try:
        self._converter.calibrate(
            fetch_names=self.output_tensor_names,
            num_runs=num_runs,
            feed_dict_fn=lambda: calibration_inputs)
      except Exception as exc:
        raise RuntimeError("Failed to calibrate! "
                           "Model Information: {}".format(str(self))) from exc

  def run(self,
          inputs: Optional[Mapping[str, np.ndarray]] = None,
          warmup_iterations=10,
          benchmark_iterations=100) -> TestResult:
    self.save(overwrite=False)
    self._check_conversion(self.meta_graph.graph_def)
    logging.info("Running with TensorRT!")
    test_result = ModelHandlerV1.run(
        self, inputs, warmup_iterations, benchmark_iterations, enable_gpu=True)
    return test_result._replace(trt_convert_params=self._trt_convert_params)


class TrtModelHandlerV2(_TrtModelHandlerBase, ModelHandlerV2):
  """Converts a TF2 model with TensorRT and runs the converted model."""

  def _create_converter(self, trt_convert_params: trt.TrtConversionParams):
    return trt.TrtGraphConverterV2(
        input_saved_model_dir=self.model_config.saved_model_dir,
        input_saved_model_tags=self.model_config.saved_model_tags,
        input_saved_model_signature_key=(
            self.model_config.saved_model_signature_key),
        conversion_params=trt_convert_params)

  def _check_conversion(self, graph_func):
    graph_def = graph_func.graph.as_graph_def()
    self._check_contains_trt_engine(graph_def)

  def convert(self,
              calibration_inputs: Optional[Sequence[
                  framework_ops.Tensor]] = None,
              num_runs=1) -> None:
    logging.info("Converting with TensorRT!")

    calibration_input_fn = None
    if (self.trt_convert_params.precision_mode == trt.TrtPrecisionMode.INT8 and
        self.trt_convert_params.use_calibration):
      logging.info("Calibrating with TensorRT at the same time!")
      if not calibration_inputs:
        raise ValueError("Must provide calibration data "
                         "when using TensorRT calibration!")

      def gets_calibration_input():
        for _ in range(num_runs):
          yield calibration_inputs

      calibration_input_fn = gets_calibration_input

    self._check_conversion(self._converter.convert(calibration_input_fn))

  def run(self,
          inputs: Optional[Sequence[framework_ops.Tensor]] = None,
          warmup_iterations=10,
          benchmark_iterations=100) -> TestResult:
    self.save(overwrite=False)
    self._check_conversion(self.graph_func)
    logging.info("Running with TensorRT!")
    test_result = ModelHandlerV2.run(
        self, inputs, warmup_iterations, benchmark_iterations, enable_gpu=True)
    return test_result._replace(trt_convert_params=self._trt_convert_params)


class _ModelHandlerManagerBase(metaclass=abc.ABCMeta):
  """Manages a series of ModelHandlers for aggregated testing/benchmarking."""

  def __init__(
      self, name: str, model_config: ModelConfig,
      default_trt_convert_params: trt.TrtConversionParams,
      trt_convert_params_updater: Callable[[trt.TrtConversionParams],
                                           Iterable[trt.TrtConversionParams]]):
    self._ori_model = self.model_handler_cls(model_config)
    self._trt_models = []
    for trt_convert_params in trt_convert_params_updater(
        default_trt_convert_params):
      trt_model = self.trt_model_handler_cls(
          model_config, trt_convert_params=trt_convert_params)
      self._trt_models.append(trt_model)

    self._name = name
    self._result_collection = None

  def __str__(self) -> str:
    return "Input Model: {}".format(str(self._ori_model))

  def __repr__(self) -> str:
    return "{}({})".format(self.__class__.__name__, str(self))

  @property
  @classmethod
  @abc.abstractmethod
  def model_handler_cls(cls):
    """The model handler class: ModelHandlerV1 or ModelHandlerV2."""

  @property
  @classmethod
  @abc.abstractmethod
  def trt_model_handler_cls(cls):
    """The TensorRT model handler class: TrtModelHandlerV1 or TrtModelHandlerV2."""

  @property
  def name(self) -> str:
    return self._name

  @property
  def model_config(self) -> ModelConfig:
    return self._ori_model.model_config

  def generate_random_inputs(self, batch_size: Optional[int] = None):
    return self._ori_model.generate_random_inputs(batch_size)

  def convert(self, calibration_inputs=None, num_runs=1) -> None:
    """Converts models with TensorRT and calibrates if using INT8 precision mode.

    Args:
      calibration_inputs: Mapping from input names to ndarrays in TF1, or a
        sequence of tensors in TF2, used as calibration data.
      num_runs: Number of calibration runs.
    """
    for trt_model in self._trt_models:
      trt_model.convert(calibration_inputs, num_runs)

  def run(self,
          inputs=None,
          warmup_iterations: int = 10,
          benchmark_iterations: int = 100) -> TestResultCollection:
    """Runs model inference with provided or randomly generated input tensors.

    Args:
      inputs: Mapping from names to input ndarrays in TF1, or a sequence of
        tensors in TF2. If `None`, randomly generated input tensors will be
        used instead.
      warmup_iterations: Number of inferences to warm up the runtime.
      benchmark_iterations: Number of inferences to measure the latency.

    Returns:
      `TestResultCollection` summarizing latency and numerics information for
      different TensorRT conversion settings.
    """
    inputs = inputs or self.generate_random_inputs()

    def run_model(model, **kwargs):
      return model.run(inputs, warmup_iterations, benchmark_iterations,
                       **kwargs)

    # Some models include operations that can only run on GPU.
    try:
      cpu_base_result = run_model(self._ori_model, enable_gpu=False)
    except RuntimeError as err:
      logging.info("%s cannot run on CPU. Reason: %s.",
                   self._ori_model.model_config, err)
      cpu_base_result = None
    gpu_base_result = run_model(self._ori_model, enable_gpu=True)
    trt_results = list(map(run_model, self._trt_models))

    return TestResultCollection(
        test_name=self._name,
        model_config=self.model_config,
        cpu_base_result=cpu_base_result,
        gpu_base_result=gpu_base_result,
        trt_results=trt_results)


class ModelHandlerManagerV1(_ModelHandlerManagerBase):
  """Manages a series of ModelHandlers for aggregated testing/benchmarking in TF1."""

  model_handler_cls = ModelHandlerV1
  trt_model_handler_cls = TrtModelHandlerV1


class ModelHandlerManagerV2(_ModelHandlerManagerBase):
  """Manages a series of ModelHandlers for aggregated testing/benchmarking in TF2."""

  model_handler_cls = ModelHandlerV2
  trt_model_handler_cls = TrtModelHandlerV2

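
# The sketch below is illustrative only and is not exercised by the tests in
# this package: it shows how a manager can sweep TensorRT conversion settings
# for a TF2 SavedModel. The SavedModel path argument and the chosen precision
# modes are placeholder assumptions, not values required by this module.
def _example_trt_benchmark_v2(saved_model_dir: str) -> TestResultCollection:
  """Example: benchmarks a TF2 SavedModel under FP32 and FP16 conversion."""

  def updater(params: trt.TrtConversionParams):
    # Yields one conversion configuration per precision mode to compare.
    for precision_mode in (trt.TrtPrecisionMode.FP32,
                           trt.TrtPrecisionMode.FP16):
      yield params._replace(precision_mode=precision_mode)

  manager = ModelHandlerManagerV2(
      name="example_trt_benchmark",
      model_config=ModelConfig(saved_model_dir, default_batch_size=4),
      default_trt_convert_params=trt.DEFAULT_TRT_CONVERSION_PARAMS,
      trt_convert_params_updater=updater)
  # FP32/FP16 do not need INT8 calibration data, so no inputs are passed here.
  manager.convert()
  return manager.run(warmup_iterations=5, benchmark_iterations=50)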