"""Repository rule for CUDA autoconfiguration.

`cuda_configure` depends on the following environment variables:

  * `TF_NEED_CUDA`: Whether to enable building with CUDA.
  * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path.
  * `GCC_HOST_COMPILER_PREFIX`: The host compiler prefix directory. It is also
    used as the linker bin path unless clang is downloaded. Default is
    `/usr/bin`.
  * `TF_CUDA_CLANG`: Whether to use clang as a cuda compiler.
  * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for
    both host and device code compilation if TF_CUDA_CLANG is 1.
  * `TF_SYSROOT`: The sysroot to use when compiling.
  * `TF_DOWNLOAD_CLANG`: Whether to download a recent release of clang
    compiler and use it to build tensorflow. When this option is set
    CLANG_CUDA_COMPILER_PATH is ignored.
  * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is
    `/usr/local/cuda,/usr/`.
  * `CUDA_TOOLKIT_PATH` (deprecated): The path to the CUDA toolkit. Default is
    `/usr/local/cuda`.
  * `TF_CUDA_VERSION`: The version of the CUDA toolkit. If this is blank, then
    use the system default.
  * `TF_CUDNN_VERSION`: The version of the cuDNN library.
  * `CUDNN_INSTALL_PATH` (deprecated): The path to the cuDNN library. Default is
    `/usr/local/cuda`.
  * `TF_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default is
    `3.5,5.2`.
  * `PYTHON_BIN_PATH`: The python binary path.
"""

load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
load(
    "@bazel_tools//tools/cpp:lib_cc_configure.bzl",
    "escape_string",
    "get_env_var",
)
load(
    "@bazel_tools//tools/cpp:windows_cc_configure.bzl",
    "find_msvc_tool",
    "find_vc_path",
    "setup_vc_env_vars",
)

_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
_CLANG_CUDA_COMPILER_PATH = "CLANG_CUDA_COMPILER_PATH"
_TF_SYSROOT = "TF_SYSROOT"
_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
_TF_CUDA_VERSION = "TF_CUDA_VERSION"
_TF_CUDNN_VERSION = "TF_CUDNN_VERSION"
_CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH"
_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES"
_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO"
_TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"

_DEFAULT_CUDA_COMPUTE_CAPABILITIES = ["3.5", "5.2"]

def to_list_of_strings(elements):
    """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'.

    This is to be used to put a list of strings into the bzl file templates
    so it gets interpreted as list of strings in Starlark.

    Args:
      elements: list of string elements. Elements are not escaped, so callers
        must pass values that are already safe inside double quotes.

    Returns:
      single string of elements wrapped in quotes separated by a comma.
    """
    quoted_strings = ["\"" + element + "\"" for element in elements]
    return ", ".join(quoted_strings)

def verify_build_defines(params):
    """Verify all variables that crosstool/BUILD.tpl expects are substituted.

    Fails the configuration (via auto_configure_fail) listing every expected
    "%{name}" key that is absent from `params`.

    Args:
      params: dict of variables that will be passed to the BUILD.tpl template.
    """
    missing = []
    for param in [
        "cxx_builtin_include_directories",
        "extra_no_canonical_prefixes_flags",
        "host_compiler_path",
        "host_compiler_prefix",
        "host_compiler_warnings",
        "linker_bin_path",
        "compiler_deps",
        "msvc_cl_path",
        "msvc_env_include",
        "msvc_env_lib",
        "msvc_env_path",
        "msvc_env_tmp",
        "msvc_lib_path",
        "msvc_link_path",
        "msvc_ml_path",
        "unfiltered_compile_flags",
        "win_compiler_deps",
    ]:
        if ("%{" + param + "}") not in params:
            missing.append(param)

    if missing:
        auto_configure_fail(
            "BUILD.tpl template is missing these variables: " +
            str(missing) +
            ".\nWe only got: " +
            str(params) +
            ".",
        )

def _get_python_bin(repository_ctx):
    """Gets the python bin path.

    Uses $PYTHON_BIN_PATH when set; otherwise looks up `python`
    (`python.exe` on Windows) on PATH. Fails the configuration if neither
    yields a binary.
    """
    python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
    if python_bin != None:
        return python_bin
    python_bin_name = "python.exe" if _is_windows(repository_ctx) else "python"
    python_bin_path = repository_ctx.which(python_bin_name)
    if python_bin_path != None:
        return str(python_bin_path)
    auto_configure_fail(
        "Cannot find python in PATH, please make sure " +
        "python is installed and add its directory in PATH, or --define " +
        "%s='/something/else'.\nPATH=%s" % (
            _PYTHON_BIN_PATH,
            repository_ctx.os.environ.get("PATH", ""),
        ),
    )

def _get_escaped_windows_tmp_dir(repository_ctx):
    """Returns %TMP% (default C:\\Windows\\Temp) with backslashes doubled and escaped."""
    return escape_string(
        get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
            "\\",
            "\\\\",
        ),
    )

def _get_nvcc_tmp_dir_for_windows(repository_ctx):
    """Return the Windows tmp directory for nvcc to generate intermediate source files."""
    return _get_escaped_windows_tmp_dir(repository_ctx) + "\\\\nvcc_inter_files_tmp_dir"

def _get_nvcc_tmp_dir_for_unix(repository_ctx):
    """Return the UNIX tmp directory for nvcc to generate intermediate source files."""
    escaped_tmp_dir = escape_string(
        get_env_var(repository_ctx, "TMPDIR", "/tmp"),
    )
    return escaped_tmp_dir + "/nvcc_inter_files_tmp_dir"

def _find_msvc_tool_path(repository_ctx, vc_path, tool):
    """Returns the path of MSVC tool `tool` under `vc_path`, using '/' separators."""
    return find_msvc_tool(repository_ctx, vc_path, tool).replace("\\", "/")

def _get_msvc_compiler(repository_ctx):
    """Returns the path to MSVC's cl.exe, using '/' separators."""
    vc_path = find_vc_path(repository_ctx)
    return _find_msvc_tool_path(repository_ctx, vc_path, "cl.exe")

def _get_win_cuda_defines(repository_ctx):
    """Return CROSSTOOL defines for Windows.

    Args:
      repository_ctx: The repository context.

    Returns:
      A dict mapping "%{...}" template variables to values. On non-Windows
      hosts only the MSVC variables are returned, set to placeholder values.
    """

    # If we are not on Windows, return fake values for Windows specific fields.
    # This ensures the CROSSTOOL file parser is happy.
    if not _is_windows(repository_ctx):
        return {
            "%{msvc_env_tmp}": "msvc_not_used",
            "%{msvc_env_path}": "msvc_not_used",
            "%{msvc_env_include}": "msvc_not_used",
            "%{msvc_env_lib}": "msvc_not_used",
            "%{msvc_cl_path}": "msvc_not_used",
            "%{msvc_ml_path}": "msvc_not_used",
            "%{msvc_link_path}": "msvc_not_used",
            "%{msvc_lib_path}": "msvc_not_used",
        }

    vc_path = find_vc_path(repository_ctx)
    if not vc_path:
        # auto_configure_fail aborts the rule, so nothing below runs here.
        auto_configure_fail(
            "Visual C++ build tools not found on your machine. " +
            "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using",
        )

    env = setup_vc_env_vars(repository_ctx, vc_path)
    escaped_paths = escape_string(env["PATH"])
    escaped_include_paths = escape_string(env["INCLUDE"])
    escaped_lib_paths = escape_string(env["LIB"])
    escaped_tmp_dir = _get_escaped_windows_tmp_dir(repository_ctx)

    # NOTE(review): %{msvc_cl_path} is the python interpreter, not cl.exe;
    # presumably the CROSSTOOL drives a python wrapper script
    # (crosstool/windows/msvc_wrapper_for_nvcc.py) — confirm against BUILD.tpl.
    msvc_cl_path = _get_python_bin(repository_ctx)
    msvc_ml_path = _find_msvc_tool_path(repository_ctx, vc_path, "ml64.exe")
    msvc_link_path = _find_msvc_tool_path(repository_ctx, vc_path, "link.exe")
    msvc_lib_path = _find_msvc_tool_path(repository_ctx, vc_path, "lib.exe")

    # nvcc will generate some temporary source files under %{nvcc_tmp_dir}
    # The generated files are guaranteed to have unique name, so they can share
    # the same tmp directory
    escaped_cxx_include_directories = [
        _get_nvcc_tmp_dir_for_windows(repository_ctx),
    ]
    for path in escaped_include_paths.split(";"):
        if path:
            escaped_cxx_include_directories.append(path)

    return {
        "%{msvc_env_tmp}": escaped_tmp_dir,
        "%{msvc_env_path}": escaped_paths,
        "%{msvc_env_include}": escaped_include_paths,
        "%{msvc_env_lib}": escaped_lib_paths,
        "%{msvc_cl_path}": msvc_cl_path,
        "%{msvc_ml_path}": msvc_ml_path,
        "%{msvc_link_path}": msvc_link_path,
        "%{msvc_lib_path}": msvc_lib_path,
        "%{cxx_builtin_include_directories}": to_list_of_strings(
            escaped_cxx_include_directories,
        ),
    }

# TODO(dzc): Once these functions have been factored out of Bazel's
# cc_configure.bzl, load them from @bazel_tools instead.
# BEGIN cc_configure common functions.
def find_cc(repository_ctx):
    """Find the C++ compiler.

    On Windows this is MSVC's cl.exe. Otherwise it is clang when
    TF_CUDA_CLANG=1 (the downloaded clang when TF_DOWNLOAD_CLANG=1 as well),
    else gcc; the name may be overridden by CLANG_CUDA_COMPILER_PATH or
    GCC_HOST_COMPILER_PATH respectively.

    Returns:
      An absolute path string, "extra_tools/bin/clang" for a downloaded clang,
      or the path returned by repository_ctx.which() for a name on PATH.
    """
    if _is_windows(repository_ctx):
        return _get_msvc_compiler(repository_ctx)

    if _use_cuda_clang(repository_ctx):
        target_cc_name = "clang"
        cc_path_envvar = _CLANG_CUDA_COMPILER_PATH
        if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
            return "extra_tools/bin/clang"
    else:
        target_cc_name = "gcc"
        cc_path_envvar = _GCC_HOST_COMPILER_PATH
    cc_name = target_cc_name

    if cc_path_envvar in repository_ctx.os.environ:
        cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
        if cc_name_from_env:
            cc_name = cc_name_from_env
    if cc_name.startswith("/"):
        # Absolute path, maybe we should make this supported by our which function.
        return cc_name
    cc = repository_ctx.which(cc_name)
    if cc == None:
        auto_configure_fail(("Cannot find {}, either correct your path or set the {}" +
                             " environment variable").format(target_cc_name, cc_path_envvar))
    return cc

_INC_DIR_MARKER_BEGIN = "#include <...>"

# OSX add " (framework directory)" at the end of line, strip it.
_OSX_FRAMEWORK_SUFFIX = " (framework directory)"
_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)

def _cxx_inc_convert(path):
    """Cleans one include path printed by `cc -E -v`.

    Strips surrounding whitespace and the macOS " (framework directory)"
    suffix.
    """
    path = path.strip()
    if path.endswith(_OSX_FRAMEWORK_SUFFIX):
        path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
    return path

def _normalize_include_path(repository_ctx, path):
    """Normalizes include paths before writing them to the crosstool.

    If path points inside the 'crosstool' folder of the repository, a relative
    path is returned.
    If path points outside the 'crosstool' folder, an absolute path is returned.
    """
    path = str(repository_ctx.path(path))
    crosstool_folder = str(repository_ctx.path(".").get_child("crosstool"))

    if path.startswith(crosstool_folder):
        # We drop the path to "$REPO/crosstool" and a trailing path separator.
        return path[len(crosstool_folder) + 1:]
    return path

def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot):
    """Compute the list of default C or C++ include directories.

    Runs `cc -E -x<lang> - -v` (plus `--sysroot` when given) and parses the
    indented lines following the "#include <...>" marker in stderr.

    Returns:
      A list of normalized include paths, or [] if the marker is not found.
    """
    if lang_is_cpp:
        lang = "c++"
    else:
        lang = "c"
    sysroot = []
    if tf_sysroot:
        sysroot += ["--sysroot", tf_sysroot]
    result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"] +
                                    sysroot)
    index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
    if index1 == -1:
        return []
    index1 = result.stderr.find("\n", index1)
    if index1 == -1:
        return []

    # The search list ends at the last line that starts with a space.
    index2 = result.stderr.rfind("\n ")
    if index2 == -1 or index2 < index1:
        return []
    index2 = result.stderr.find("\n", index2 + 1)
    if index2 == -1:
        inc_dirs = result.stderr[index1 + 1:].strip()
    else:
        inc_dirs = result.stderr[index1 + 1:index2].strip()

    return [
        _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
        for p in inc_dirs.split("\n")
    ]

def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot):
    """Compute the list of default C and C++ include directories."""

    # For some reason `clang -xc` sometimes returns include paths that are
    # different from the ones from `clang -xc++`. (Symlink and a dir)
    # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
    includes_cpp = _get_cxx_inc_directories_impl(
        repository_ctx,
        cc,
        True,
        tf_sysroot,
    )
    includes_c = _get_cxx_inc_directories_impl(
        repository_ctx,
        cc,
        False,
        tf_sysroot,
    )

    return includes_cpp + [
        inc
        for inc in includes_c
        if inc not in includes_cpp
    ]

def auto_configure_fail(msg):
    """Output failure message when cuda configuration fails."""
    red = "\033[0;31m"
    no_color = "\033[0m"
    fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))

# END cc_configure common functions (see TODO above).

def _cuda_include_path(repository_ctx, cuda_config):
    """Generates the list of cuda include directories.

    Runs `nvcc -v` and, if it reports a `_TARGET_DIR_`, adds
    `<cuda_toolkit_path>/<target dir>/include`; `<cuda_toolkit_path>/include`
    is always added last.

    NOTE(review): nvcc is invoked on /dev/null, which assumes a Unix-like host.

    Args:
      repository_ctx: The repository context.
      cuda_config: The CUDA config as returned by _get_cuda_config.

    Returns:
      A list of include directory path strings.
    """
    nvcc_path = repository_ctx.path("%s/bin/nvcc%s" % (
        cuda_config.cuda_toolkit_path,
        ".exe" if cuda_config.cpu_value == "Windows" else "",
    ))
    result = repository_ctx.execute([
        nvcc_path,
        "-v",
        "/dev/null",
        "-o",
        "/dev/null",
    ])
    target_dir = ""
    for one_line in result.stderr.splitlines():
        if one_line.startswith("#$ _TARGET_DIR_="):
            target_dir = (
                cuda_config.cuda_toolkit_path + "/" + one_line.replace(
                    "#$ _TARGET_DIR_=",
                    "",
                ) + "/include"
            )
    inc_entries = []
    if target_dir != "":
        inc_entries.append(target_dir)
    inc_entries.append(cuda_config.cuda_toolkit_path + "/include")
    return inc_entries

def enable_cuda(repository_ctx):
    """Returns whether to build with CUDA support.

    Returns int(TF_NEED_CUDA); an unset or blank variable yields 0.
    """
    value = repository_ctx.os.environ.get("TF_NEED_CUDA", "").strip()
    return int(value) if value else 0

def matches_version(environ_version, detected_version):
    """Checks whether the user-specified version matches the detected version.

    This function performs a weak matching so that if the user specifies only
    the major or major and minor versions, the versions are still considered
    matching if the version parts match. To illustrate:

        environ_version  detected_version  result
        -----------------------------------------
        5.1.3            5.1.3             True
        5.1              5.1.3             True
        5                5.1               True
        5.1.3            5.1               False
        5.2.3            5.1.3             False

    Args:
      environ_version: The version specified by the user via environment
        variables.
      detected_version: The version autodetected from the CUDA installation on
        the system.

    Returns: True if user-specified version matches detected version and False
      otherwise.
    """
    environ_version_parts = environ_version.split(".")
    detected_version_parts = detected_version.split(".")

    # A shorter detected version can never match: the slice is then shorter
    # than environ_version_parts, so the comparison is False.
    return detected_version_parts[:len(environ_version_parts)] == environ_version_parts

_NVCC_VERSION_PREFIX = "Cuda compilation tools, release "

_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"

def compute_capabilities(repository_ctx):
    """Returns a list of strings representing cuda compute capabilities.

    Reads the comma-separated TF_CUDA_COMPUTE_CAPABILITIES (e.g. "3.5,7.0").
    Whitespace around each entry is ignored. Returns
    _DEFAULT_CUDA_COMPUTE_CAPABILITIES when the variable is unset, and fails
    the configuration on any entry not of the form <digits>.<digits>.
    """
    if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ:
        return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
    capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES]
    capabilities = [entry.strip() for entry in capabilities_str.split(",")]
    for capability in capabilities:
        # Workaround for Skylark's lack of support for regex. This check should
        # be equivalent to checking:
        #     if re.match("[0-9]+.[0-9]+", capability) == None:
        parts = capability.split(".")
        if len(parts) != 2 or not parts[0].isdigit() or not parts[1].isdigit():
            auto_configure_fail("Invalid compute capability: %s" % capability)
    return capabilities

def get_cpu_value(repository_ctx):
    """Returns the name of the host operating system.

    Args:
      repository_ctx: The repository context.

    Returns:
      "Darwin" for macOS, "Windows" for Windows, otherwise the output of
      `uname -s` (e.g. "Linux").
    """
    os_name = repository_ctx.os.name.lower()
    if os_name.startswith("mac os"):
        return "Darwin"
    if os_name.find("windows") != -1:
        return "Windows"
    result = repository_ctx.execute(["uname", "-s"])
    return result.stdout.strip()

def _is_windows(repository_ctx):
    """Returns true if the host operating system is windows."""
    return repository_ctx.os.name.lower().find("windows") >= 0

def lib_name(base_name, cpu_value, version = None, static = False):
    """Constructs the platform-specific name of a library.

    Args:
      base_name: The name of the library, such as "cudart"
      cpu_value: The name of the host operating system.
      version: The version of the library (ignored on Windows and for static
        libraries).
      static: True if the library is static, False if it is a shared object.

    Returns:
      The platform-specific name of the library.
    """
    version = "" if not version else "." + version
    if cpu_value in ("Linux", "FreeBSD"):
        if static:
            return "lib%s.a" % base_name
        return "lib%s.so%s" % (base_name, version)
    elif cpu_value == "Windows":
        return "%s.lib" % base_name
    elif cpu_value == "Darwin":
        if static:
            return "lib%s.a" % base_name
        return "lib%s%s.dylib" % (base_name, version)
    else:
        auto_configure_fail("Invalid cpu_value: %s" % cpu_value)

def find_lib(repository_ctx, paths, check_soname = True):
    """Finds a library among a list of potential paths.

    Args:
      repository_ctx: The repository context.
      paths: List of path strings to inspect, in order of preference.
      check_soname: If true (and objdump is available on a non-Windows host),
        only accept a file whose SONAME equals its basename.

    Returns:
      The path object of the first existing (and, if checked, SONAME-matching)
      library. Fails the configuration if none qualifies.
    """
    objdump = repository_ctx.which("objdump")
    mismatches = []
    for path in [repository_ctx.path(p) for p in paths]:
        if not path.exists:
            continue
        if check_soname and objdump != None and not _is_windows(repository_ctx):
            output = repository_ctx.execute([objdump, "-p", str(path)]).stdout
            soname_lines = [line for line in output.splitlines() if "SONAME" in line]
            sonames = [line.strip().split(" ")[-1] for line in soname_lines]
            if path.basename not in sonames:
                mismatches.append(str(path))
                continue
        return path
    if mismatches:
        auto_configure_fail(
            "None of the libraries match their SONAME: " + ", ".join(mismatches),
        )
    auto_configure_fail("No library found under: " + ", ".join(paths))

def _find_cuda_lib(
        lib,
        repository_ctx,
        cpu_value,
        basedir,
        version,
        static = False):
    """Finds the given CUDA or cuDNN library on the system.

    Args:
      lib: The name of the library, such as "cudart"
      repository_ctx: The repository context.
      cpu_value: The name of the host operating system.
      basedir: The install directory of CUDA or cuDNN.
      version: The version of the library.
      static: True if static library, False if shared object.

    Returns:
      Returns the path to the library.
    """
    file_name = lib_name(lib, cpu_value, version, static)

    # SONAME is only checked for versioned shared objects.
    return find_lib(
        repository_ctx,
        ["%s/%s" % (basedir, file_name)],
        check_soname = version and not static,
    )

def _find_libs(repository_ctx, cuda_config):
    """Returns the CUDA and cuDNN libraries on the system.

    Args:
      repository_ctx: The repository context.
      cuda_config: The CUDA config as returned by _get_cuda_config

    Returns:
      Map of library names to the path objects returned by find_lib.
    """
    cpu_value = cuda_config.cpu_value

    # The unversioned driver library ("cuda") is taken from the stubs
    # subdirectory except on Windows.
    stub_dir = "" if _is_windows(repository_ctx) else "/stubs"
    return {
        "cuda": _find_cuda_lib(
            "cuda",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"] + stub_dir,
            None,
        ),
        "cudart": _find_cuda_lib(
            "cudart",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cuda_version,
        ),
        "cudart_static": _find_cuda_lib(
            "cudart_static",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cuda_version,
            static = True,
        ),
        "cublas": _find_cuda_lib(
            "cublas",
            repository_ctx,
            cpu_value,
            cuda_config.config["cublas_library_dir"],
            cuda_config.cuda_lib_version,
        ),
        "cusolver": _find_cuda_lib(
            "cusolver",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cuda_lib_version,
        ),
        "curand": _find_cuda_lib(
            "curand",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cuda_lib_version,
        ),
        "cufft": _find_cuda_lib(
            "cufft",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cuda_lib_version,
        ),
        "cudnn": _find_cuda_lib(
            "cudnn",
            repository_ctx,
            cpu_value,
            cuda_config.config["cudnn_library_dir"],
            cuda_config.cudnn_version,
        ),
        "cupti": _find_cuda_lib(
            "cupti",
            repository_ctx,
            cpu_value,
            cuda_config.config["cupti_library_dir"],
            cuda_config.cuda_version,
        ),
        "cusparse": _find_cuda_lib(
            "cusparse",
            repository_ctx,
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cuda_lib_version,
        ),
    }

def _cudart_static_linkopt(cpu_value):
    """Returns additional platform-specific linkopts for cudart."""
    return "" if cpu_value == "Darwin" else "\"-lrt\","
# TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl,
# and nccl_configure.bzl.
def find_cuda_config(repository_ctx, cuda_libraries):
    """Returns CUDA config dictionary from running find_cuda_config.py.

    The script's stdout is parsed as one "key: value" pair per line.

    Args:
      repository_ctx: The repository context.
      cuda_libraries: List of library names passed to the script as
        command-line arguments (e.g. ["cuda", "cudnn"]).

    Returns:
      A dict mapping config keys (e.g. "cuda_toolkit_path") to string values.
    """
    exec_result = repository_ctx.execute([
        _get_python_bin(repository_ctx),
        repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py")),
    ] + cuda_libraries)
    if exec_result.return_code:
        auto_configure_fail("Failed to run find_cuda_config.py: %s" % exec_result.stderr)

    # Parse the dict from stdout. Split only on the first ": " so values that
    # themselves contain ": " survive intact, and ignore blank lines.
    config = {}
    for line in exec_result.stdout.splitlines():
        if not line.strip():
            continue
        key, separator, value = line.partition(": ")
        if not separator:
            auto_configure_fail(
                "Unexpected output line from find_cuda_config.py: %s" % line,
            )
        config[key] = value
    return config

def _get_cuda_config(repository_ctx):
    """Detects and returns information about the CUDA installation on the system.

    Args:
      repository_ctx: The repository context.

    Returns:
      A struct containing the following fields:
        cuda_toolkit_path: The CUDA toolkit installation directory.
        cuda_version: The version of CUDA on the system ("<major>.<minor>",
          or "64_<major><minor>" on Windows).
        cudnn_version: The version of cuDNN on the system ("64_" prefixed on
          Windows).
        cuda_lib_version: The version suffix for cuBLAS, cuFFT, cuSOLVER, etc.
        compute_capabilities: A list of the system's CUDA compute capabilities.
        cpu_value: The name of the host operating system.
        config: The raw dict returned by find_cuda_config.
    """
    config = find_cuda_config(repository_ctx, ["cuda", "cudnn"])
    cpu_value = get_cpu_value(repository_ctx)
    toolkit_path = config["cuda_toolkit_path"]

    is_windows = _is_windows(repository_ctx)
    cuda_version_parts = config["cuda_version"].split(".")
    cuda_major = cuda_version_parts[0]
    cuda_minor = cuda_version_parts[1]

    cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor)
    cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"]

    # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
    # It changed from 'x.y' to just 'x' in CUDA 10.1.
    if (int(cuda_major), int(cuda_minor)) >= (10, 1):
        cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major
    else:
        cuda_lib_version = cuda_version

    return struct(
        cuda_toolkit_path = toolkit_path,
        cuda_version = cuda_version,
        cudnn_version = cudnn_version,
        cuda_lib_version = cuda_lib_version,
        compute_capabilities = compute_capabilities(repository_ctx),
        cpu_value = cpu_value,
        config = config,
    )

def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Expands //third_party/gpus/<tpl>.tpl into the repository.

    `tpl` uses ":" as a package separator (e.g. "cuda:BUILD"); the output path
    defaults to `tpl` with ":" replaced by "/".
    """
    if not out:
        out = tpl.replace(":", "/")
    repository_ctx.template(
        out,
        Label("//third_party/gpus/%s.tpl" % tpl),
        substitutions,
    )

def _file(repository_ctx, label):
    """Copies //third_party/gpus/<label>.tpl into the repository unchanged."""
    repository_ctx.template(
        label.replace(":", "/"),
        Label("//third_party/gpus/%s.tpl" % label),
        {},
    )

_DUMMY_CROSSTOOL_BZL_FILE = """
def error_gpu_disabled():
  fail("ERROR: Building with --config=cuda but TensorFlow is not configured " +
       "to build with GPU support. Please re-run ./configure and enter 'Y' " +
       "at the prompt to build with GPU support.")

  native.genrule(
      name = "error_gen_crosstool",
      outs = ["CROSSTOOL"],
      cmd = "echo 'Should not be run.' && exit 1",
  )

  native.filegroup(
      name = "crosstool",
      srcs = [":CROSSTOOL"],
      output_licenses = ["unencumbered"],
  )
"""

_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")

error_gpu_disabled()
"""

def _create_dummy_repository(repository_ctx):
    """Populates the repository with placeholders for a non-CUDA build.

    Writes cuda/ BUILD and bzl files with empty CUDA settings, empty stand-in
    headers and libraries, and a crosstool/ package that fails with an
    actionable message if --config=cuda is used anyway.
    """
    cpu_value = get_cpu_value(repository_ctx)

    # Set up BUILD file for cuda/.
    _tpl(
        repository_ctx,
        "cuda:build_defs.bzl",
        {
            "%{cuda_is_configured}": "False",
            "%{cuda_extra_copts}": "[]",
        },
    )
    _tpl(
        repository_ctx,
        "cuda:BUILD",
        {
            "%{cuda_driver_lib}": lib_name("cuda", cpu_value),
            "%{cudart_static_lib}": lib_name(
                "cudart_static",
                cpu_value,
                static = True,
            ),
            "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
            "%{cudart_lib}": lib_name("cudart", cpu_value),
            "%{cublas_lib}": lib_name("cublas", cpu_value),
            "%{cusolver_lib}": lib_name("cusolver", cpu_value),
            "%{cudnn_lib}": lib_name("cudnn", cpu_value),
            "%{cufft_lib}": lib_name("cufft", cpu_value),
            "%{curand_lib}": lib_name("curand", cpu_value),
            "%{cupti_lib}": lib_name("cupti", cpu_value),
            "%{cusparse_lib}": lib_name("cusparse", cpu_value),
            "%{copy_rules}": """
filegroup(name="cuda-include")
filegroup(name="cublas-include")
filegroup(name="cudnn-include")
""",
        },
    )

    # Create dummy files for the CUDA toolkit since they are still required by
    # tensorflow/core/platform/default/build_config:cuda.
    repository_ctx.file("cuda/cuda/include/cuda.h")
    repository_ctx.file("cuda/cuda/include/cublas.h")
    repository_ctx.file("cuda/cuda/include/cudnn.h")
    repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h")
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cuda", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudart", cpu_value))
    repository_ctx.file(
        "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value),
    )
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cufft", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cupti", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusparse", cpu_value))

    # Set up cuda_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    _tpl(
        repository_ctx,
        "cuda:cuda_config.h",
        {
            "%{cuda_version}": "",
            "%{cuda_lib_version}": "",
            "%{cudnn_version}": "",
            "%{cuda_compute_capabilities}": ",".join([
                "CudaVersion(\"%s\")" % c
                for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES
            ]),
            "%{cuda_toolkit_path}": "",
        },
        "cuda/cuda/cuda_config.h",
    )

    # If cuda_configure is not configured to build with GPU support, and the user
    # attempts to build with --config=cuda, add a dummy build rule to intercept
    # this and fail with an actionable error message.
    repository_ctx.file(
        "crosstool/error_gpu_disabled.bzl",
        _DUMMY_CROSSTOOL_BZL_FILE,
    )
    repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)

def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Executes an arbitrary shell command.

    Fails the configuration if the command writes anything to stderr, or if
    stdout is empty and `empty_stdout_fine` is False.

    Args:
      repository_ctx: the repository_ctx object
      cmdline: list of strings, the command to execute
      error_msg: string, a summary of the error if the command fails
      error_details: string, details about the error or steps to fix it
      empty_stdout_fine: bool, if True, an empty stdout result is fine,
        otherwise it's an error

    Returns:
      the result of repository_ctx.execute(cmdline)
    """
    result = repository_ctx.execute(cmdline)
    if result.stderr or not (empty_stdout_fine or result.stdout):
        auto_configure_fail(
            "\n".join([
                error_msg.strip() if error_msg else "Repository command failed",
                result.stderr.strip(),
                error_details if error_details else "",
            ]),
        )
    return result

def _norm_path(path):
    """Returns a path with '/' and remove the trailing slash.

    An empty path is returned unchanged.
    """
    path = path.replace("\\", "/")
    if path.endswith("/"):
        path = path[:-1]
    return path

def make_copy_files_rule(repository_ctx, name, srcs, outs):
    """Returns a rule to copy a set of files.

    Args:
      repository_ctx: The repository context (unused).
      name: Name of the generated genrule.
      srcs: Absolute source file paths.
      outs: Output paths, paired positionally with `srcs`.

    Returns:
      The genrule definition as a string of Starlark source.
    """
    cmds = []

    # Copy files.
    for src, out in zip(srcs, outs):
        cmds.append('cp -f "%s" "$(location %s)"' % (src, out))
    outs = [('        "%s",' % out) for out in outs]
    return """genrule(
    name = "%s",
    outs = [
%s
    ],
    cmd = \"""%s \""",
)""" % (name, "\n".join(outs), " && \\\n".join(cmds))

def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir):
    """Returns a rule to recursively copy a directory.

    Args:
      repository_ctx: The repository context.
      name: Name of the generated genrule.
      src_dir: Directory to copy from.
      out_dir: Output directory, relative to the package.

    Returns:
      The genrule definition as a string of Starlark source.
    """
    src_dir = _norm_path(src_dir)
    out_dir = _norm_path(out_dir)
    outs = _read_dir(repository_ctx, src_dir)

    # Only rewrite the leading src_dir prefix; a later occurrence of the same
    # text inside a file path must be kept.
    outs = [('        "%s",' % out.replace(src_dir, out_dir, 1)) for out in outs]

    # '@D' already contains the relative path for a single file, see
    # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables
    out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)"
    return """genrule(
    name = "%s",
    outs = [
%s
    ],
    cmd = \"""cp -rLf "%s/." "%s/" \""",
)""" % (name, "\n".join(outs), src_dir, out_dir)

def _read_dir(repository_ctx, src_dir):
    """Returns a sorted list of all files in a directory.

    Finds all files inside a directory, traversing subfolders and following
    symlinks. Each entry is a full path; on Windows, backslashes are replaced
    by forward slashes.
    """
    if _is_windows(repository_ctx):
        src_dir = src_dir.replace("/", "\\")
        find_result = _execute(
            repository_ctx,
            ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
            empty_stdout_fine = True,
        )

        # src_files will be used in genrule.outs where the paths must
        # use forward slashes.
        result = find_result.stdout.replace("\\", "/")
    else:
        find_result = _execute(
            repository_ctx,
            ["find", src_dir, "-follow", "-type", "f"],
            empty_stdout_fine = True,
        )
        result = find_result.stdout
    return sorted(result.splitlines())

def _flag_enabled(repository_ctx, flag_name):
    """Returns True iff environment variable `flag_name` is exactly "1" (after strip)."""
    if flag_name in repository_ctx.os.environ:
        value = repository_ctx.os.environ[flag_name].strip()
        return value == "1"
    return False

def _use_cuda_clang(repository_ctx):
    """Returns True if TF_CUDA_CLANG=1."""
    return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")

def _tf_sysroot(repository_ctx):
    """Returns $TF_SYSROOT, or "" when unset."""
    if _TF_SYSROOT in repository_ctx.os.environ:
        return repository_ctx.os.environ[_TF_SYSROOT]
    return ""

def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
    """Returns an `if_cuda_clang([...])` expression with --cuda-gpu-arch flags."""
    capability_flags = [
        "--cuda-gpu-arch=sm_" + cap.replace(".", "")
        for cap in compute_capabilities
    ]

    # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc
    # TODO(csigg): Make this consistent with cuda clang and pass unconditionally.
    return "if_cuda_clang(%s)" % str(capability_flags)

def _create_local_cuda_repository(repository_ctx):
    """Creates the repository containing files set up to build with CUDA."""
    cuda_config = _get_cuda_config(repository_ctx)

    cuda_include_path = cuda_config.config["cuda_include_dir"]
    cublas_include_path = cuda_config.config["cublas_include_dir"]
    cudnn_header_dir = cuda_config.config["cudnn_include_dir"]
    cupti_header_dir = cuda_config.config["cupti_include_dir"]
    nvvm_libdevice_dir = cuda_config.config["nvvm_library_dir"]

    # Create genrule to copy files from the installed CUDA toolkit into execroot.
947 copy_rules = [ 948 make_copy_dir_rule( 949 repository_ctx, 950 name = "cuda-include", 951 src_dir = cuda_include_path, 952 out_dir = "cuda/include", 953 ), 954 make_copy_dir_rule( 955 repository_ctx, 956 name = "cuda-nvvm", 957 src_dir = nvvm_libdevice_dir, 958 out_dir = "cuda/nvvm/libdevice", 959 ), 960 make_copy_dir_rule( 961 repository_ctx, 962 name = "cuda-extras", 963 src_dir = cupti_header_dir, 964 out_dir = "cuda/extras/CUPTI/include", 965 ), 966 ] 967 968 copy_rules.append(make_copy_files_rule( 969 repository_ctx, 970 name = "cublas-include", 971 srcs = [ 972 cublas_include_path + "/cublas.h", 973 cublas_include_path + "/cublas_v2.h", 974 cublas_include_path + "/cublas_api.h", 975 ], 976 outs = [ 977 "cublas/include/cublas.h", 978 "cublas/include/cublas_v2.h", 979 "cublas/include/cublas_api.h", 980 ], 981 )) 982 983 cuda_libs = _find_libs(repository_ctx, cuda_config) 984 cuda_lib_srcs = [] 985 cuda_lib_outs = [] 986 for path in cuda_libs.values(): 987 cuda_lib_srcs.append(str(path)) 988 cuda_lib_outs.append("cuda/lib/" + path.basename) 989 copy_rules.append(make_copy_files_rule( 990 repository_ctx, 991 name = "cuda-lib", 992 srcs = cuda_lib_srcs, 993 outs = cuda_lib_outs, 994 )) 995 996 # copy files mentioned in third_party/nccl/build_defs.bzl.tpl 997 copy_rules.append(make_copy_files_rule( 998 repository_ctx, 999 name = "cuda-bin", 1000 srcs = [ 1001 cuda_config.cuda_toolkit_path + "/bin/" + "crt/link.stub", 1002 cuda_config.cuda_toolkit_path + "/bin/" + "nvlink", 1003 cuda_config.cuda_toolkit_path + "/bin/" + "fatbinary", 1004 cuda_config.cuda_toolkit_path + "/bin/" + "bin2c", 1005 ], 1006 outs = [ 1007 "cuda/bin/" + "crt/link.stub", 1008 "cuda/bin/" + "nvlink", 1009 "cuda/bin/" + "fatbinary", 1010 "cuda/bin/" + "bin2c", 1011 ], 1012 )) 1013 1014 copy_rules.append(make_copy_files_rule( 1015 repository_ctx, 1016 name = "cudnn-include", 1017 srcs = [cudnn_header_dir + "/cudnn.h"], 1018 outs = ["cudnn/include/cudnn.h"], 1019 )) 1020 1021 # Set up BUILD 
file for cuda/ 1022 _tpl( 1023 repository_ctx, 1024 "cuda:build_defs.bzl", 1025 { 1026 "%{cuda_is_configured}": "True", 1027 "%{cuda_extra_copts}": _compute_cuda_extra_copts( 1028 repository_ctx, 1029 cuda_config.compute_capabilities, 1030 ), 1031 }, 1032 ) 1033 _tpl( 1034 repository_ctx, 1035 "cuda:BUILD.windows" if _is_windows(repository_ctx) else "cuda:BUILD", 1036 { 1037 "%{cuda_driver_lib}": cuda_libs["cuda"].basename, 1038 "%{cudart_static_lib}": cuda_libs["cudart_static"].basename, 1039 "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value), 1040 "%{cudart_lib}": cuda_libs["cudart"].basename, 1041 "%{cublas_lib}": cuda_libs["cublas"].basename, 1042 "%{cusolver_lib}": cuda_libs["cusolver"].basename, 1043 "%{cudnn_lib}": cuda_libs["cudnn"].basename, 1044 "%{cufft_lib}": cuda_libs["cufft"].basename, 1045 "%{curand_lib}": cuda_libs["curand"].basename, 1046 "%{cupti_lib}": cuda_libs["cupti"].basename, 1047 "%{cusparse_lib}": cuda_libs["cusparse"].basename, 1048 "%{copy_rules}": "\n".join(copy_rules), 1049 }, 1050 "cuda/BUILD", 1051 ) 1052 1053 is_cuda_clang = _use_cuda_clang(repository_ctx) 1054 tf_sysroot = _tf_sysroot(repository_ctx) 1055 1056 should_download_clang = is_cuda_clang and _flag_enabled( 1057 repository_ctx, 1058 _TF_DOWNLOAD_CLANG, 1059 ) 1060 if should_download_clang: 1061 download_clang(repository_ctx, "crosstool/extra_tools") 1062 1063 # Set up crosstool/ 1064 cc = find_cc(repository_ctx) 1065 cc_fullpath = cc if not should_download_clang else "crosstool/" + cc 1066 1067 host_compiler_includes = get_cxx_inc_directories( 1068 repository_ctx, 1069 cc_fullpath, 1070 tf_sysroot, 1071 ) 1072 cuda_defines = {} 1073 cuda_defines["%{builtin_sysroot}"] = tf_sysroot 1074 cuda_defines["%{cuda_toolkit_path}"] = "" 1075 if is_cuda_clang: 1076 cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"] 1077 1078 host_compiler_prefix = "/usr/bin" 1079 if _GCC_HOST_COMPILER_PREFIX in repository_ctx.os.environ: 1080 
host_compiler_prefix = repository_ctx.os.environ[_GCC_HOST_COMPILER_PREFIX].strip() 1081 cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix 1082 1083 # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see 1084 # https://github.com/bazelbuild/bazel/issues/760). 1085 # However, this stops our custom clang toolchain from picking the provided 1086 # LLD linker, so we're only adding '-B/usr/bin' when using non-downloaded 1087 # toolchain. 1088 # TODO: when bazel stops adding '-B/usr/bin' by default, remove this 1089 # flag from the CROSSTOOL completely (see 1090 # https://github.com/bazelbuild/bazel/issues/5634) 1091 if should_download_clang: 1092 cuda_defines["%{linker_bin_path}"] = "" 1093 else: 1094 cuda_defines["%{linker_bin_path}"] = host_compiler_prefix 1095 1096 cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" 1097 cuda_defines["%{unfiltered_compile_flags}"] = "" 1098 if is_cuda_clang: 1099 cuda_defines["%{host_compiler_path}"] = str(cc) 1100 cuda_defines["%{host_compiler_warnings}"] = """ 1101 # Some parts of the codebase set -Werror and hit this warning, so 1102 # switch it off for now. 1103 "-Wno-invalid-partial-specialization" 1104 """ 1105 cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(host_compiler_includes) 1106 cuda_defines["%{compiler_deps}"] = ":empty" 1107 cuda_defines["%{win_compiler_deps}"] = ":empty" 1108 repository_ctx.file( 1109 "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", 1110 "", 1111 ) 1112 repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "") 1113 else: 1114 cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" 1115 cuda_defines["%{host_compiler_warnings}"] = "" 1116 1117 # nvcc has the system include paths built in and will automatically 1118 # search them; we cannot work around that, so we add the relevant cuda 1119 # system paths to the allowed compiler specific include paths. 
1120 cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings( 1121 host_compiler_includes + _cuda_include_path( 1122 repository_ctx, 1123 cuda_config, 1124 ) + [cupti_header_dir, cudnn_header_dir], 1125 ) 1126 1127 # For gcc, do not canonicalize system header paths; some versions of gcc 1128 # pick the shortest possible path for system includes when creating the 1129 # .d file - given that includes that are prefixed with "../" multiple 1130 # time quickly grow longer than the root of the tree, this can lead to 1131 # bazel's header check failing. 1132 cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\"" 1133 1134 nvcc_path = str( 1135 repository_ctx.path("%s/nvcc%s" % ( 1136 cuda_config.config["cuda_binary_dir"], 1137 ".exe" if _is_windows(repository_ctx) else "", 1138 )), 1139 ) 1140 cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc" 1141 cuda_defines["%{win_compiler_deps}"] = ":windows_msvc_wrapper_files" 1142 1143 wrapper_defines = { 1144 "%{cpu_compiler}": str(cc), 1145 "%{cuda_version}": cuda_config.cuda_version, 1146 "%{nvcc_path}": nvcc_path, 1147 "%{gcc_host_compiler_path}": str(cc), 1148 "%{cuda_compute_capabilities}": ", ".join( 1149 ["\"%s\"" % c for c in cuda_config.compute_capabilities], 1150 ), 1151 "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx), 1152 } 1153 _tpl( 1154 repository_ctx, 1155 "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc", 1156 wrapper_defines, 1157 ) 1158 _tpl( 1159 repository_ctx, 1160 "crosstool:windows/msvc_wrapper_for_nvcc.py", 1161 wrapper_defines, 1162 ) 1163 1164 cuda_defines.update(_get_win_cuda_defines(repository_ctx)) 1165 1166 verify_build_defines(cuda_defines) 1167 1168 # Only expand template variables in the BUILD file 1169 _tpl(repository_ctx, "crosstool:BUILD", cuda_defines) 1170 1171 # No templating of cc_toolchain_config - use attributes and templatize the 1172 # BUILD file. 
1173 _file(repository_ctx, "crosstool:cc_toolchain_config.bzl") 1174 1175 # Set up cuda_config.h, which is used by 1176 # tensorflow/stream_executor/dso_loader.cc. 1177 _tpl( 1178 repository_ctx, 1179 "cuda:cuda_config.h", 1180 { 1181 "%{cuda_version}": cuda_config.cuda_version, 1182 "%{cuda_lib_version}": cuda_config.cuda_lib_version, 1183 "%{cudnn_version}": cuda_config.cudnn_version, 1184 "%{cuda_compute_capabilities}": ", ".join([ 1185 "CudaVersion(\"%s\")" % c 1186 for c in cuda_config.compute_capabilities 1187 ]), 1188 "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path, 1189 }, 1190 "cuda/cuda/cuda_config.h", 1191 ) 1192 1193def _create_remote_cuda_repository(repository_ctx, remote_config_repo): 1194 """Creates pointers to a remotely configured repo set up to build with CUDA.""" 1195 _tpl( 1196 repository_ctx, 1197 "cuda:build_defs.bzl", 1198 { 1199 "%{cuda_is_configured}": "True", 1200 "%{cuda_extra_copts}": _compute_cuda_extra_copts( 1201 repository_ctx, 1202 compute_capabilities(repository_ctx), 1203 ), 1204 }, 1205 ) 1206 repository_ctx.template( 1207 "cuda/BUILD", 1208 Label(remote_config_repo + "/cuda:BUILD"), 1209 {}, 1210 ) 1211 repository_ctx.template( 1212 "cuda/build_defs.bzl", 1213 Label(remote_config_repo + "/cuda:build_defs.bzl"), 1214 {}, 1215 ) 1216 repository_ctx.template( 1217 "cuda/cuda/cuda_config.h", 1218 Label(remote_config_repo + "/cuda:cuda/cuda_config.h"), 1219 {}, 1220 ) 1221 1222def _cuda_autoconf_impl(repository_ctx): 1223 """Implementation of the cuda_autoconf repository rule.""" 1224 if not enable_cuda(repository_ctx): 1225 _create_dummy_repository(repository_ctx) 1226 elif _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ: 1227 if (_TF_CUDA_VERSION not in repository_ctx.os.environ or 1228 _TF_CUDNN_VERSION not in repository_ctx.os.environ): 1229 auto_configure_fail("%s and %s must also be set if %s is specified" % 1230 (_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO)) 1231 _create_remote_cuda_repository( 1232 
repository_ctx, 1233 repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO], 1234 ) 1235 else: 1236 _create_local_cuda_repository(repository_ctx) 1237 1238cuda_configure = repository_rule( 1239 implementation = _cuda_autoconf_impl, 1240 environ = [ 1241 _GCC_HOST_COMPILER_PATH, 1242 _GCC_HOST_COMPILER_PREFIX, 1243 _CLANG_CUDA_COMPILER_PATH, 1244 "TF_NEED_CUDA", 1245 "TF_CUDA_CLANG", 1246 _TF_DOWNLOAD_CLANG, 1247 _CUDA_TOOLKIT_PATH, 1248 _CUDNN_INSTALL_PATH, 1249 _TF_CUDA_VERSION, 1250 _TF_CUDNN_VERSION, 1251 _TF_CUDA_COMPUTE_CAPABILITIES, 1252 _TF_CUDA_CONFIG_REPO, 1253 "NVVMIR_LIBRARY_DIR", 1254 _PYTHON_BIN_PATH, 1255 "TMP", 1256 "TMPDIR", 1257 "TF_CUDA_PATHS", 1258 ], 1259) 1260 1261"""Detects and configures the local CUDA toolchain. 1262 1263Add the following to your WORKSPACE FILE: 1264 1265```python 1266cuda_configure(name = "local_config_cuda") 1267``` 1268 1269Args: 1270 name: A unique name for this workspace rule. 1271""" 1272