• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper of TRTEngineOp."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import threading
22
23import platform
24from tensorflow.python.framework import errors
25
# Serializes the check-and-load sequence in load_trt_ops() across threads.
_module_lock = threading.Lock()
# Handle returned by load_library.load_op_library() for libtftrt.so; it stays
# None until load_trt_ops() has loaded the library successfully.
_tf_trt_so = None
28
29
def load_trt_ops():
  """Load the TF-TRT op libraries if they have not been loaded already.

  The first successful call imports the generated Python wrapper for
  `trt_engine_op`. It then loads `libtftrt.so` through
  `load_library.load_op_library` and caches the returned handle in the
  module-level `_tf_trt_so`. Once that handle is set, later calls return
  immediately. The check-and-load sequence runs while holding `_module_lock`,
  so concurrent callers do not load the library twice. If a call fails,
  `_tf_trt_so` is left unset, so a later call will try again.

  Raises:
    RuntimeError: if running on Windows, which is not supported.
    ImportError: if the generated TF-TRT op wrappers cannot be imported
      (per the printed message, the binary was not built with CUDA or
      TensorRT enabled).
    errors.NotFoundError: if `load_op_library` cannot load `libtftrt.so`
      (per the printed message, typically because TensorRT is not installed
      or its path is not in LD_LIBRARY_PATH).
  """
  global _tf_trt_so

  if platform.system() == "Windows":
    raise RuntimeError("Windows platforms are not supported")

  with _module_lock:
    # Already loaded by an earlier call; nothing more to do.
    if _tf_trt_so is not None:
      return

    try:
      # pylint: disable=g-import-not-at-top,unused-variable
      # This will call register_op_list() in
      # tensorflow/python/framework/op_def_registry.py, but it doesn't register
      # the op or the op kernel in C++ runtime.
      from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op
      # pylint: enable=g-import-not-at-top,unused-variable
    except ImportError:
      print("**** Failed to import TF-TRT ops. This is because the binary was "
            "not built with CUDA or TensorRT enabled. ****")
      # Bare `raise` re-raises with the original traceback intact; `raise e`
      # would discard it on Python 2, which this module still targets.
      raise

    # These imports are outside the try below because they are not expected
    # to raise errors.NotFoundError; only the load call is.
    # pylint: disable=g-import-not-at-top
    from tensorflow.python.framework import load_library
    from tensorflow.python.platform import resource_loader
    # pylint: enable=g-import-not-at-top

    try:
      # Loading the shared object will cause registration of the op and the op
      # kernel if we link TF-TRT dynamically.
      _tf_trt_so = load_library.load_op_library(
          resource_loader.get_path_to_datafile("libtftrt.so"))
    except errors.NotFoundError:
      no_trt_message = (
          "**** Failed to initialize TensorRT. This is either because the "
          "TensorRT installation path is not in LD_LIBRARY_PATH, or because "
          "you do not have it installed. If not installed, please go to "
          "https://developer.nvidia.com/tensorrt to download and install "
          "TensorRT ****")
      print(no_trt_message)
      raise
72