• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Repository rule for TensorRT configuration.
2
3`tensorrt_configure` depends on the following environment variables:
4
5  * `TF_TENSORRT_VERSION`: The TensorRT libnvinfer version.
6  * `TENSORRT_INSTALL_PATH`: The installation path of the TensorRT library.
7"""
8
9load(
10    "//third_party/gpus:cuda_configure.bzl",
11    "find_cuda_config",
12    "lib_name",
13    "make_copy_files_rule",
14)
15load(
16    "//third_party/remote_config:common.bzl",
17    "config_repo_label",
18    "get_cpu_value",
19    "get_host_environ",
20)
21
# Names of the environment variables consulted by this repository rule.
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
_TF_TENSORRT_STATIC_PATH = "TF_TENSORRT_STATIC_PATH"
_TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
_TF_NEED_TENSORRT = "TF_NEED_TENSORRT"

# Shared libraries expected in every TensorRT installation.
_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"]

# Public headers to copy, by TensorRT version (pre-6 baseline, 6/7, and 8+).
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]
_TF_TENSORRT_HEADERS_V6 = _TF_TENSORRT_HEADERS + [
    "NvInferVersion.h",
    "NvInferRuntime.h",
    "NvInferRuntimeCommon.h",
    "NvInferPluginUtils.h",
]
_TF_TENSORRT_HEADERS_V8 = [
    "NvInfer.h",
    "NvInferLegacyDims.h",
    "NvInferImpl.h",
    "NvUtils.h",
    "NvInferPlugin.h",
    "NvInferVersion.h",
    "NvInferRuntime.h",
    "NvInferRuntimeCommon.h",
    "NvInferPluginUtils.h",
]

# Macro prefixes for the SONAME version defines in NvInferVersion.h.
# NOTE(review): not referenced anywhere in this file — possibly vestigial.
_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"
54
def _at_least_version(actual_version, required_version):
    """Returns True if dotted `actual_version` >= `required_version`.

    Versions are compared numerically, component by component
    (lexicographic tuple comparison of the parsed integers).
    """
    actual_parts = tuple([int(part) for part in actual_version.split(".")])
    required_parts = tuple([int(part) for part in required_version.split(".")])
    return actual_parts >= required_parts
59
def _get_tensorrt_headers(tensorrt_version):
    """Returns the header list matching the given TensorRT version string."""
    # Check newest version threshold first; fall through to the pre-6 list.
    version_tiers = [
        ("8", _TF_TENSORRT_HEADERS_V8),
        ("6", _TF_TENSORRT_HEADERS_V6),
    ]
    for minimum_version, headers in version_tiers:
        if _at_least_version(tensorrt_version, minimum_version):
            return headers
    return _TF_TENSORRT_HEADERS
66
def _tpl_path(repository_ctx, filename):
    """Resolves the path of template `filename` under //third_party/tensorrt."""
    template_label = Label("//third_party/tensorrt:{}.tpl".format(filename))
    return repository_ctx.path(template_label)
69
def _tpl(repository_ctx, tpl, substitutions):
    """Instantiates template `tpl` into the repository, applying `substitutions`."""
    source = _tpl_path(repository_ctx, tpl)
    repository_ctx.template(tpl, source, substitutions)
76
def _create_dummy_repository(repository_ctx):
    """Create a dummy TensorRT repository."""
    # Disable TensorRT in build_defs.bzl and strip the copy rules and the
    # include/lib targets out of the BUILD file.
    _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_false"})
    _tpl(repository_ctx, "BUILD", {
        "%{copy_rules}": "",
        "\":tensorrt_include\"": "",
        "\":tensorrt_lib\"": "",
    })

    # No TensorRT available, so the version macro is left empty.
    _tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", {"%{tensorrt_version}": ""})

    # Copy license file in non-remote build.
    repository_ctx.template("LICENSE", Label("//third_party/tensorrt:LICENSE"), {})

    # Set up tensorrt_config.py, which is used by gen_build_info to provide
    # build environment info to the API; an empty dict means "no TensorRT".
    _tpl(repository_ctx, "tensorrt/tensorrt_config.py", _py_tmpl_dict({}))
103
def enable_tensorrt(repository_ctx):
    """Returns whether to build with TensorRT support."""
    # TF_NEED_TENSORRT is expected to hold "0" or "1". When unset, the
    # default False is returned and int(False) == 0, i.e. disabled.
    flag = get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False)
    return int(flag)
107
def _get_tensorrt_static_path(repository_ctx):
    """Returns the TensorRT static library path, or None when unset."""
    static_path = get_host_environ(repository_ctx, _TF_TENSORRT_STATIC_PATH, None)
    return static_path
111
def _create_local_tensorrt_repository(repository_ctx):
    """Configures a TensorRT repository from the local installation.

    Locates the TensorRT headers and libraries via find_cuda_config, emits
    copy rules for them (plus the static libraries when
    TF_TENSORRT_STATIC_PATH is set), and instantiates the BUILD and config
    templates.
    """

    # Resolve all labels before doing any real work. Resolving causes the
    # function to be restarted with all previous state being lost. This
    # can easily lead to a O(n^2) runtime in the number of labels.
    # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
    find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
    tpl_paths = {
        "build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
        "BUILD": _tpl_path(repository_ctx, "BUILD"),
        "tensorrt/include/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt/include/tensorrt_config.h"),
        "tensorrt/tensorrt_config.py": _tpl_path(repository_ctx, "tensorrt/tensorrt_config.py"),
    }

    config = find_cuda_config(repository_ctx, find_cuda_config_path, ["tensorrt"])
    trt_version = config["tensorrt_version"]
    cpu_value = get_cpu_value(repository_ctx)

    # Copy the library and header files.
    libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]

    library_dir = config["tensorrt_library_dir"] + "/"
    headers = _get_tensorrt_headers(trt_version)
    include_dir = config["tensorrt_include_dir"] + "/"
    copy_rules = [
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_lib",
            srcs = [library_dir + library for library in libraries],
            outs = ["tensorrt/lib/" + library for library in libraries],
        ),
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_include",
            srcs = [include_dir + header for header in headers],
            outs = ["tensorrt/include/" + header for header in headers],
        ),
    ]

    tensorrt_static_path = _get_tensorrt_static_path(repository_ctx)
    if tensorrt_static_path:
        tensorrt_static_path = tensorrt_static_path + "/"
        if _at_least_version(trt_version, "8"):
            raw_static_library_names = _TF_TENSORRT_LIBS
        else:
            # TensorRT < 8 additionally needs nvrtc and the myelin libraries
            # when linking statically.
            raw_static_library_names = _TF_TENSORRT_LIBS + ["nvrtc", "myelin_compiler", "myelin_executor", "myelin_pattern_library", "myelin_pattern_runtime"]
        static_library_names = ["%s_static" % name for name in raw_static_library_names]
        static_libraries = [lib_name(lib, cpu_value, trt_version, static = True) for lib in static_library_names]

        # Fix: dropped a redundant `if tensorrt_static_path != None` re-check
        # here -- this branch is only reached when the path is truthy, and it
        # was just reassigned to a non-empty string.
        copy_rules.append(make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_static_lib",
            srcs = [tensorrt_static_path + library for library in static_libraries],
            outs = ["tensorrt/lib/" + library for library in static_libraries],
        ))

    # Set up config file.
    repository_ctx.template(
        "build_defs.bzl",
        tpl_paths["build_defs.bzl"],
        {"%{if_tensorrt}": "if_true"},
    )

    # Set up BUILD file.
    repository_ctx.template(
        "BUILD",
        tpl_paths["BUILD"],
        {"%{copy_rules}": "\n".join(copy_rules)},
    )

    # Copy license file in non-remote build.
    repository_ctx.template(
        "LICENSE",
        Label("//third_party/tensorrt:LICENSE"),
        {},
    )

    # Set up tensorrt_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    repository_ctx.template(
        "tensorrt/include/tensorrt_config.h",
        tpl_paths["tensorrt/include/tensorrt_config.h"],
        {"%{tensorrt_version}": trt_version},
    )

    # Set up tensorrt_config.py, which is used by gen_build_info to provide
    # build environment info to the API
    repository_ctx.template(
        "tensorrt/tensorrt_config.py",
        tpl_paths["tensorrt/tensorrt_config.py"],
        _py_tmpl_dict({
            "tensorrt_version": trt_version,
        }),
    )
207
def _py_tmpl_dict(config):
    """Wraps `config` as the substitution map for the tensorrt_config.py template."""
    return {"%{tensorrt_config}": str(config)}
210
def _tensorrt_configure_impl(repository_ctx):
    """Implementation of the tensorrt_configure repository rule."""

    # Fix: read TF_TENSORRT_CONFIG_REPO once through get_host_environ instead
    # of checking via get_host_environ but then re-reading the raw
    # repository_ctx.os.environ -- both lookups now honor the same
    # environment source.
    remote_config_repo = get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO)
    if remote_config_repo != None:
        # Forward to the pre-configured remote repository: each generated
        # file is copied verbatim from the corresponding remote target.
        for out in [
            "BUILD",
            "build_defs.bzl",
            "tensorrt/include/tensorrt_config.h",
            "tensorrt/tensorrt_config.py",
            "LICENSE",
        ]:
            repository_ctx.template(out, config_repo_label(remote_config_repo, ":" + out), {})
        return

    if not enable_tensorrt(repository_ctx):
        _create_dummy_repository(repository_ctx)
        return

    _create_local_tensorrt_repository(repository_ctx)
245
# Environment variables whose changes invalidate the configured repository.
_ENVIRONS = [
    _TENSORRT_INSTALL_PATH,
    _TF_TENSORRT_VERSION,
    _TF_NEED_TENSORRT,
    _TF_TENSORRT_STATIC_PATH,
    "TF_CUDA_PATHS",
]

# Remotable variant: always configures from the local TensorRT installation;
# the environment may be overridden through the `environ` attribute.
remote_tensorrt_configure = repository_rule(
    implementation = _create_local_tensorrt_repository,
    environ = _ENVIRONS,
    remotable = True,
    attrs = {"environ": attr.string_dict()},
)

# Main entry point: forwards to a pre-configured remote repo when
# TF_TENSORRT_CONFIG_REPO is set, otherwise configures locally (or emits a
# dummy repo when TensorRT is disabled).
tensorrt_configure = repository_rule(
    implementation = _tensorrt_configure_impl,
    environ = _ENVIRONS + [_TF_TENSORRT_CONFIG_REPO],
)
267"""Detects and configures the local CUDA toolchain.
268
269Add the following to your WORKSPACE FILE:
270
271```python
272tensorrt_configure(name = "local_config_tensorrt")
273```
274
275Args:
276  name: A unique name for this workspace rule.
277"""
278