"""Repository rule for TensorRT configuration.

`tensorrt_configure` depends on (and re-runs when any of these change) the
following environment variables:

  * `TF_NEED_TENSORRT`: Set to a non-zero integer to configure a local TensorRT
    installation. When unset, empty, or `0`, a dummy repository is created.
  * `TF_TENSORRT_VERSION`: The TensorRT libnvinfer version.
  * `TENSORRT_INSTALL_PATH`: The installation path of the TensorRT library.
  * `TF_TENSORRT_STATIC_PATH`: Optional directory containing the TensorRT
    static libraries. When set (and non-empty), they are copied into the
    repository as well.
  * `TF_TENSORRT_CONFIG_REPO`: Optional pre-configured repository (passed to
    `config_repo_label`) whose generated files are forwarded instead of running
    local detection.
  * `TF_CUDA_PATHS`: Also watched; presumably consumed by `find_cuda_config`
    when locating TensorRT.
"""

load(
    "//third_party/gpus:cuda_configure.bzl",
    "find_cuda_config",
    "lib_name",
    "make_copy_files_rule",
)
load(
    "//third_party/remote_config:common.bzl",
    "config_repo_label",
    "get_cpu_value",
    "get_host_environ",
)

_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
_TF_TENSORRT_STATIC_PATH = "TF_TENSORRT_STATIC_PATH"
_TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
_TF_NEED_TENSORRT = "TF_NEED_TENSORRT"

_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"]
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]
_TF_TENSORRT_HEADERS_V6 = [
    "NvInfer.h",
    "NvUtils.h",
    "NvInferPlugin.h",
    "NvInferVersion.h",
    "NvInferRuntime.h",
    "NvInferRuntimeCommon.h",
    "NvInferPluginUtils.h",
]
_TF_TENSORRT_HEADERS_V8 = [
    "NvInfer.h",
    "NvInferLegacyDims.h",
    "NvInferImpl.h",
    "NvUtils.h",
    "NvInferPlugin.h",
    "NvInferVersion.h",
    "NvInferRuntime.h",
    "NvInferRuntimeCommon.h",
    "NvInferPluginUtils.h",
]

_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"

# Files forwarded verbatim from a pre-configured remote config repository,
# in the order they are generated.
_REMOTE_CONFIG_FILES = [
    "BUILD",
    "build_defs.bzl",
    "tensorrt/include/tensorrt_config.h",
    "tensorrt/tensorrt_config.py",
    "LICENSE",
]

def _at_least_version(actual_version, required_version):
    """Returns whether dotted version `actual_version` >= `required_version`.

    Every dot-separated component must parse as an integer. The component
    lists are compared left to right.

    Args:
      actual_version: Version string to test, e.g. "8.2.1".
      required_version: Minimum version string, e.g. "8".

    Returns:
      True if `actual_version` is at least `required_version`.
    """
    actual = [int(v) for v in actual_version.split(".")]
    required = [int(v) for v in required_version.split(".")]
    return actual >= required

def _get_tensorrt_headers(tensorrt_version):
    """Returns the list of header file names to copy for `tensorrt_version`."""
    if _at_least_version(tensorrt_version, "8"):
        return _TF_TENSORRT_HEADERS_V8
    if _at_least_version(tensorrt_version, "6"):
        return _TF_TENSORRT_HEADERS_V6
    return _TF_TENSORRT_HEADERS

def _tpl_path(repository_ctx, filename):
    """Returns the path of template `//third_party/tensorrt:<filename>.tpl`."""
    return repository_ctx.path(Label("//third_party/tensorrt:%s.tpl" % filename))

def _tpl(repository_ctx, tpl, substitutions):
    """Expands template `<tpl>.tpl` into output file `tpl`.

    Args:
      repository_ctx: The repository context.
      tpl: Output file name; the template source is `<tpl>.tpl`.
      substitutions: Dict of placeholder -> replacement strings.
    """
    repository_ctx.template(
        tpl,
        _tpl_path(repository_ctx, tpl),
        substitutions,
    )

def _create_dummy_repository(repository_ctx):
    """Create a dummy TensorRT repository."""
    _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_false"})
    _tpl(repository_ctx, "BUILD", {
        "%{copy_rules}": "",
        "\":tensorrt_include\"": "",
        "\":tensorrt_lib\"": "",
    })
    _tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", {
        "%{tensorrt_version}": "",
    })

    # Copy license file in non-remote build.
    repository_ctx.template(
        "LICENSE",
        Label("//third_party/tensorrt:LICENSE"),
        {},
    )

    # Set up tensorrt_config.py, which is used by gen_build_info to provide
    # build environment info to the API
    _tpl(
        repository_ctx,
        "tensorrt/tensorrt_config.py",
        _py_tmpl_dict({}),
    )

def enable_tensorrt(repository_ctx):
    """Returns whether to build with TensorRT support.

    Args:
      repository_ctx: The repository context.

    Returns:
      The integer value of `TF_NEED_TENSORRT`, or 0 when it is unset or empty.
      A non-zero value means TensorRT support is enabled.
    """
    value = get_host_environ(repository_ctx, _TF_NEED_TENSORRT, None)
    if not value:
        # Treat an empty string like "unset"; int("") would fail.
        return 0
    return int(value)

def _get_tensorrt_static_path(repository_ctx):
    """Returns the path for TensorRT static libraries, or None if unset."""
    return get_host_environ(repository_ctx, _TF_TENSORRT_STATIC_PATH, None)

def _create_local_tensorrt_repository(repository_ctx):
    """Creates a TensorRT repository from a local installation.

    Locates TensorRT via `find_cuda_config`, then:
      * copies the shared libraries and headers into the repository;
      * optionally copies the static libraries;
      * generates `BUILD`, `build_defs.bzl`, `LICENSE`,
        `tensorrt/include/tensorrt_config.h` and `tensorrt/tensorrt_config.py`.

    Args:
      repository_ctx: The repository context.
    """

    # Resolve all labels before doing any real work. Resolving causes the
    # function to be restarted with all previous state being lost. This
    # can easily lead to a O(n^2) runtime in the number of labels.
    # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
    find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
    license_path = repository_ctx.path(Label("//third_party/tensorrt:LICENSE"))
    tpl_paths = {
        "build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
        "BUILD": _tpl_path(repository_ctx, "BUILD"),
        "tensorrt/include/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt/include/tensorrt_config.h"),
        "tensorrt/tensorrt_config.py": _tpl_path(repository_ctx, "tensorrt/tensorrt_config.py"),
    }

    config = find_cuda_config(repository_ctx, find_cuda_config_path, ["tensorrt"])
    trt_version = config["tensorrt_version"]
    cpu_value = get_cpu_value(repository_ctx)

    # Copy the library and header files.
    libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]

    library_dir = config["tensorrt_library_dir"] + "/"
    headers = _get_tensorrt_headers(trt_version)
    include_dir = config["tensorrt_include_dir"] + "/"
    copy_rules = [
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_lib",
            srcs = [library_dir + library for library in libraries],
            outs = ["tensorrt/lib/" + library for library in libraries],
        ),
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_include",
            srcs = [include_dir + header for header in headers],
            outs = ["tensorrt/include/" + header for header in headers],
        ),
    ]

    # Only add static libraries for a non-empty path. Previously an empty
    # value skipped computing `static_libraries` but still entered the copy
    # branch (`!= None`), failing on an unassigned local.
    tensorrt_static_path = _get_tensorrt_static_path(repository_ctx)
    if tensorrt_static_path:
        static_dir = tensorrt_static_path + "/"
        if _at_least_version(trt_version, "8"):
            raw_static_library_names = _TF_TENSORRT_LIBS
        else:
            # For TensorRT < 8 the static link set also includes these libraries.
            raw_static_library_names = _TF_TENSORRT_LIBS + ["nvrtc", "myelin_compiler", "myelin_executor", "myelin_pattern_library", "myelin_pattern_runtime"]
        static_library_names = ["%s_static" % name for name in raw_static_library_names]
        static_libraries = [lib_name(lib, cpu_value, trt_version, static = True) for lib in static_library_names]
        copy_rules = copy_rules + [
            make_copy_files_rule(
                repository_ctx,
                name = "tensorrt_static_lib",
                srcs = [static_dir + library for library in static_libraries],
                outs = ["tensorrt/lib/" + library for library in static_libraries],
            ),
        ]

    # Set up config file.
    repository_ctx.template(
        "build_defs.bzl",
        tpl_paths["build_defs.bzl"],
        {"%{if_tensorrt}": "if_true"},
    )

    # Set up BUILD file.
    repository_ctx.template(
        "BUILD",
        tpl_paths["BUILD"],
        {"%{copy_rules}": "\n".join(copy_rules)},
    )

    # Copy license file in non-remote build. Uses the path resolved up front
    # so that no label resolution happens after find_cuda_config has run.
    repository_ctx.template(
        "LICENSE",
        license_path,
        {},
    )

    # Set up tensorrt_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    repository_ctx.template(
        "tensorrt/include/tensorrt_config.h",
        tpl_paths["tensorrt/include/tensorrt_config.h"],
        {"%{tensorrt_version}": trt_version},
    )

    # Set up tensorrt_config.py, which is used by gen_build_info to provide
    # build environment info to the API
    repository_ctx.template(
        "tensorrt/tensorrt_config.py",
        tpl_paths["tensorrt/tensorrt_config.py"],
        _py_tmpl_dict({
            "tensorrt_version": trt_version,
        }),
    )

def _py_tmpl_dict(d):
    """Returns the substitution dict for the tensorrt_config.py template.

    Args:
      d: Dict of config values; it is embedded via its `str()` form.
    """
    return {"%{tensorrt_config}": str(d)}

def _tensorrt_configure_impl(repository_ctx):
    """Implementation of the tensorrt_configure repository rule.

    Behavior depends on the environment:
      * if `TF_TENSORRT_CONFIG_REPO` is set, forwards its generated files;
      * otherwise, if TensorRT is disabled, creates a dummy repository;
      * otherwise, configures a local TensorRT installation.
    """
    remote_config_repo = get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO)
    if remote_config_repo != None:
        # Forward to the pre-configured remote repository.
        for filename in _REMOTE_CONFIG_FILES:
            repository_ctx.template(
                filename,
                config_repo_label(remote_config_repo, ":" + filename),
                {},
            )
        return

    if not enable_tensorrt(repository_ctx):
        _create_dummy_repository(repository_ctx)
        return

    _create_local_tensorrt_repository(repository_ctx)

_ENVIRONS = [
    _TENSORRT_INSTALL_PATH,
    _TF_TENSORRT_VERSION,
    _TF_NEED_TENSORRT,
    _TF_TENSORRT_STATIC_PATH,
    "TF_CUDA_PATHS",
]

remote_tensorrt_configure = repository_rule(
    implementation = _create_local_tensorrt_repository,
    environ = _ENVIRONS,
    remotable = True,
    attrs = {
        "environ": attr.string_dict(),
    },
)

tensorrt_configure = repository_rule(
    implementation = _tensorrt_configure_impl,
    environ = _ENVIRONS + [_TF_TENSORRT_CONFIG_REPO],
)
"""Detects and configures the local TensorRT installation.

Add the following to your WORKSPACE FILE:

```python
tensorrt_configure(name = "local_config_tensorrt")
```

Args:
  name: A unique name for this workspace rule.
"""