• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Repository rule for CUDA autoconfiguration.
2
3`cuda_configure` depends on the following environment variables:
4
5  * `TF_NEED_CUDA`: Whether to enable building with CUDA.
6  * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
7  * `TF_CUDA_CLANG`: Whether to use clang as a cuda compiler.
8  * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for
9    both host and device code compilation if TF_CUDA_CLANG is 1.
10  * `TF_SYSROOT`: The sysroot to use when compiling.
11  * `TF_DOWNLOAD_CLANG`: Whether to download a recent release of clang
12    compiler and use it to build tensorflow. When this option is set
13    CLANG_CUDA_COMPILER_PATH is ignored.
14  * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is
    `/usr/local/cuda,/usr/`.
16  * `CUDA_TOOLKIT_PATH` (deprecated): The path to the CUDA toolkit. Default is
17    `/usr/local/cuda`.
18  * `TF_CUDA_VERSION`: The version of the CUDA toolkit. If this is blank, then
19    use the system default.
20  * `TF_CUDNN_VERSION`: The version of the cuDNN library.
21  * `CUDNN_INSTALL_PATH` (deprecated): The path to the cuDNN library. Default is
22    `/usr/local/cuda`.
23  * `TF_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default is
24    `3.5,5.2`.
25  * `PYTHON_BIN_PATH`: The python binary path
26"""
27
28load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
29load(
30    "@bazel_tools//tools/cpp:lib_cc_configure.bzl",
31    "escape_string",
32    "get_env_var",
33)
34load(
35    "@bazel_tools//tools/cpp:windows_cc_configure.bzl",
36    "find_msvc_tool",
37    "find_vc_path",
38    "setup_vc_env_vars",
39)
40
# Names of the environment variables this file works with. They are used as
# keys when reading repository_ctx.os.environ.
_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
_CLANG_CUDA_COMPILER_PATH = "CLANG_CUDA_COMPILER_PATH"
_TF_SYSROOT = "TF_SYSROOT"
_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
_TF_CUDA_VERSION = "TF_CUDA_VERSION"
_TF_CUDNN_VERSION = "TF_CUDNN_VERSION"
_CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH"
_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES"
_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO"
_TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"

# Compute capabilities returned by compute_capabilities() when
# TF_CUDA_COMPUTE_CAPABILITIES is not present in the environment.
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = ["3.5", "5.2"]
55
def to_list_of_strings(elements):
    """Renders a list of strings as comma-separated, double-quoted items.

    For example, ["a", "b", "c"] becomes the single string '"a", "b", "c"'.
    The result is meant to be substituted into bzl file templates so that it
    is read back as the items of a Starlark list of strings.

    Args:
      elements: list of strings to quote and join.

    Returns:
      One string with every element wrapped in double quotes, joined by ", ".
    """
    quoted = []
    for item in elements:
        quoted.append('"' + item + '"')
    return ", ".join(quoted)
69
def verify_build_defines(params):
    """Fails configuration if a substitution expected by crosstool/BUILD.tpl is absent.

    Args:
      params: dict of variables that will be passed to the BUILD.tpl template,
        keyed by "%{name}" placeholders.
    """
    expected = [
        "cxx_builtin_include_directories",
        "extra_no_canonical_prefixes_flags",
        "host_compiler_path",
        "host_compiler_prefix",
        "host_compiler_warnings",
        "linker_bin_path",
        "compiler_deps",
        "msvc_cl_path",
        "msvc_env_include",
        "msvc_env_lib",
        "msvc_env_path",
        "msvc_env_tmp",
        "msvc_lib_path",
        "msvc_link_path",
        "msvc_ml_path",
        "unfiltered_compile_flags",
        "win_compiler_deps",
    ]

    # Keep the names whose "%{name}" placeholder has no entry in params.
    missing = [name for name in expected if ("%{" + name + "}") not in params]
    if not missing:
        return

    auto_configure_fail(
        "BUILD.tpl template is missing these variables: " +
        str(missing) +
        ".\nWe only got: " +
        str(params) +
        ".",
    )
107
def _get_python_bin(repository_ctx):
    """Returns the path of the python interpreter to use.

    The PYTHON_BIN_PATH environment variable takes precedence when it is set.
    Otherwise "python" ("python.exe" on Windows) is looked up with
    repository_ctx.which, and configuration fails if that finds nothing.

    Args:
      repository_ctx: The repository context.

    Returns:
      The interpreter path as a string.
    """
    from_env = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
    if from_env != None:
        return from_env

    if _is_windows(repository_ctx):
        binary_name = "python.exe"
    else:
        binary_name = "python"

    located = repository_ctx.which(binary_name)
    if located == None:
        auto_configure_fail(
            "Cannot find python in PATH, please make sure " +
            "python is installed and add its directory in PATH, or --define " +
            "%s='/something/else'.\nPATH=%s" % (
                _PYTHON_BIN_PATH,
                repository_ctx.os.environ.get("PATH", ""),
            ),
        )
        return None
    return str(located)
125
def _get_nvcc_tmp_dir_for_windows(repository_ctx):
    """Returns the Windows directory where nvcc writes intermediate source files.

    The base directory is read from the TMP environment variable, with a
    hard-coded Windows temp directory as the fallback. Backslashes are doubled
    before the value is passed through escape_string.
    """
    tmp_dir = get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp")
    doubled = tmp_dir.replace("\\", "\\\\")
    return escape_string(doubled) + "\\\\nvcc_inter_files_tmp_dir"
135
def _get_nvcc_tmp_dir_for_unix(repository_ctx):
    """Returns the UNIX directory where nvcc writes intermediate source files.

    The base directory is read from the TMPDIR environment variable, falling
    back to /tmp, and is passed through escape_string.
    """
    tmp_dir = get_env_var(repository_ctx, "TMPDIR", "/tmp")
    return escape_string(tmp_dir) + "/nvcc_inter_files_tmp_dir"
142
def _get_msvc_compiler(repository_ctx):
    """Returns the path to cl.exe with backslashes turned into forward slashes."""
    cl_path = find_msvc_tool(
        repository_ctx,
        find_vc_path(repository_ctx),
        "cl.exe",
    )
    return cl_path.replace("\\", "/")
146
def _get_win_cuda_defines(repository_ctx):
    """Returns the CROSSTOOL template substitutions that are specific to Windows.

    Args:
      repository_ctx: The repository context.

    Returns:
      A dict mapping "%{msvc_*}" placeholders to their values. On non-Windows
      hosts every value is the literal "msvc_not_used". On Windows the dict
      additionally contains "%{cxx_builtin_include_directories}".
    """

    # If we are not on Windows, return fake values for Windows specific fields.
    # This ensures the CROSSTOOL file parser is happy.
    if not _is_windows(repository_ctx):
        return {
            "%{msvc_env_tmp}": "msvc_not_used",
            "%{msvc_env_path}": "msvc_not_used",
            "%{msvc_env_include}": "msvc_not_used",
            "%{msvc_env_lib}": "msvc_not_used",
            "%{msvc_cl_path}": "msvc_not_used",
            "%{msvc_ml_path}": "msvc_not_used",
            "%{msvc_link_path}": "msvc_not_used",
            "%{msvc_lib_path}": "msvc_not_used",
        }

    vc_path = find_vc_path(repository_ctx)
    if not vc_path:
        # The two sentences are separated by a space so the message reads
        # correctly once concatenated.
        auto_configure_fail(
            "Visual C++ build tools not found on your machine. " +
            "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using",
        )
        return {}

    env = setup_vc_env_vars(repository_ctx, vc_path)
    escaped_paths = escape_string(env["PATH"])
    escaped_include_paths = escape_string(env["INCLUDE"])
    escaped_lib_paths = escape_string(env["LIB"])
    escaped_tmp_dir = escape_string(
        get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
            "\\",
            "\\\\",
        ),
    )

    # NOTE(review): the cl path is set to the python binary rather than to
    # cl.exe - presumably the compiler is driven through a python wrapper
    # script; confirm against the crosstool template.
    msvc_cl_path = _get_python_bin(repository_ctx)
    msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace(
        "\\",
        "/",
    )
    msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace(
        "\\",
        "/",
    )
    msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace(
        "\\",
        "/",
    )

    # nvcc will generate some temporary source files under %{nvcc_tmp_dir}
    # The generated files are guaranteed to have unique name, so they can share
    # the same tmp directory
    escaped_cxx_include_directories = [
        _get_nvcc_tmp_dir_for_windows(repository_ctx),
    ]

    # INCLUDE is a ';'-separated list; empty entries are dropped.
    for path in escaped_include_paths.split(";"):
        if path:
            escaped_cxx_include_directories.append(path)

    return {
        "%{msvc_env_tmp}": escaped_tmp_dir,
        "%{msvc_env_path}": escaped_paths,
        "%{msvc_env_include}": escaped_include_paths,
        "%{msvc_env_lib}": escaped_lib_paths,
        "%{msvc_cl_path}": msvc_cl_path,
        "%{msvc_ml_path}": msvc_ml_path,
        "%{msvc_link_path}": msvc_link_path,
        "%{msvc_lib_path}": msvc_lib_path,
        "%{cxx_builtin_include_directories}": to_list_of_strings(
            escaped_cxx_include_directories,
        ),
    }
220
221# TODO(dzc): Once these functions have been factored out of Bazel's
222# cc_configure.bzl, load them from @bazel_tools instead.
223# BEGIN cc_configure common functions.
def find_cc(repository_ctx):
    """Locates the C++ compiler.

    On Windows the MSVC compiler is returned. Otherwise clang or gcc is
    chosen depending on _use_cuda_clang, an override is read from
    CLANG_CUDA_COMPILER_PATH or GCC_HOST_COMPILER_PATH respectively, and a
    name that does not start with "/" is resolved with repository_ctx.which.

    Args:
      repository_ctx: The repository context.

    Returns:
      The compiler path, or "extra_tools/bin/clang" when clang is used and
      TF_DOWNLOAD_CLANG is enabled.
    """
    if _is_windows(repository_ctx):
        return _get_msvc_compiler(repository_ctx)

    use_clang = _use_cuda_clang(repository_ctx)
    if use_clang and _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
        return "extra_tools/bin/clang"

    default_name = "clang" if use_clang else "gcc"
    env_key = _CLANG_CUDA_COMPILER_PATH if use_clang else _GCC_HOST_COMPILER_PATH

    # A blank or missing override falls back to the default compiler name.
    candidate = repository_ctx.os.environ.get(env_key, "").strip() or default_name

    if candidate.startswith("/"):
        # Absolute path, maybe we should make this supported by our which function.
        return candidate

    resolved = repository_ctx.which(candidate)
    if resolved == None:
        fail(("Cannot find {}, either correct your path or set the {}" +
              " environment variable").format(default_name, env_key))
    return resolved
251
# Text searched for in the compiler's stderr; the include directories are
# read from the lines that follow it.
_INC_DIR_MARKER_BEGIN = "#include <...>"

# OSX adds " (framework directory)" at the end of the line; it is stripped
# from each reported path.
_OSX_FRAMEWORK_SUFFIX = " (framework directory)"
_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)
257
def _cxx_inc_convert(path):
    """Cleans up one include-path line taken from the compiler's output.

    Surrounding whitespace is removed, and so is a trailing
    " (framework directory)" suffix when present.
    """
    cleaned = path.strip()
    if not cleaned.endswith(_OSX_FRAMEWORK_SUFFIX):
        return cleaned
    return cleaned[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
264
def _normalize_include_path(repository_ctx, path):
    """Normalizes an include path before it is written to the crosstool.

    Args:
      repository_ctx: The repository context.
      path: The include path to normalize.

    Returns:
      A path relative to the repository's 'crosstool' folder when the path
      starts with that folder, and the absolute path otherwise.
    """
    absolute = str(repository_ctx.path(path))
    crosstool_root = str(repository_ctx.path(".").get_child("crosstool"))

    if not absolute.startswith(crosstool_root):
        return absolute

    # Drop the "$REPO/crosstool" prefix together with the separator after it.
    return absolute[len(crosstool_root) + 1:]
279
def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot):
    """Compute the list of default C or C++ include directories.

    Runs the compiler with "-E -x<lang> - -v" and extracts the directory list
    that follows the "#include <...>" marker in its stderr.

    Args:
      repository_ctx: The repository context.
      cc: The compiler to run.
      lang_is_cpp: True to query for C++ ("c++"), False for C ("c").
      tf_sysroot: Optional sysroot; when set, "--sysroot <value>" is appended
        to the command line.

    Returns:
      A list of normalized include directories, or an empty list when the
      expected markers are not found in the compiler output.
    """
    if lang_is_cpp:
        lang = "c++"
    else:
        lang = "c"
    sysroot = []
    if tf_sysroot:
        sysroot += ["--sysroot", tf_sysroot]
    result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"] +
                                    sysroot)

    # index1: end of the line that contains the "#include <...>" marker.
    index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
    if index1 == -1:
        return []
    index1 = result.stderr.find("\n", index1)
    if index1 == -1:
        return []

    # index2: start of the last line that begins with a space, which is taken
    # to be the last include directory entry.
    index2 = result.stderr.rfind("\n ")
    if index2 == -1 or index2 < index1:
        return []

    # Move index2 to the end of that last entry; if there is no further
    # newline the list runs to the end of the output.
    index2 = result.stderr.find("\n", index2 + 1)
    if index2 == -1:
        inc_dirs = result.stderr[index1 + 1:]
    else:
        inc_dirs = result.stderr[index1 + 1:index2].strip()

    return [
        _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
        for p in inc_dirs.split("\n")
    ]
310
def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot):
    """Computes the combined list of default C and C++ include directories.

    Args:
      repository_ctx: The repository context.
      cc: The compiler to query.
      tf_sysroot: Optional sysroot forwarded to the compiler.

    Returns:
      The C++ include directories followed by the C include directories that
      are not already among the C++ ones.
    """

    # For some reason `clang -xc` sometimes returns include paths that are
    # different from the ones from `clang -xc++`. (Symlink and a dir)
    # So the compiler is run for both languages and the results are merged.
    cpp_dirs = _get_cxx_inc_directories_impl(repository_ctx, cc, True, tf_sysroot)
    c_dirs = _get_cxx_inc_directories_impl(repository_ctx, cc, False, tf_sysroot)

    merged = list(cpp_dirs)
    for directory in c_dirs:
        if directory not in cpp_dirs:
            merged.append(directory)
    return merged
335
def auto_configure_fail(msg):
    """Aborts with `msg` prefixed by a red "Cuda Configuration Error:" header."""
    header = "\033[0;31m" + "Cuda Configuration Error:" + "\033[0m"
    fail("\n%s %s\n" % (header, msg))
341
342# END cc_configure common functions (see TODO above).
343
def _cuda_include_path(repository_ctx, cuda_config):
    """Returns the CUDA include directories.

    Runs nvcc in verbose mode and looks for a "#$ _TARGET_DIR_=" line in its
    stderr. When one is found, <toolkit>/<target dir>/include is listed ahead
    of <toolkit>/include.

    Args:
      repository_ctx: The repository context.
      cuda_config: A struct providing cuda_toolkit_path and cpu_value.

    Returns:
      A list of include directory paths.
    """
    toolkit = cuda_config.cuda_toolkit_path
    exe_suffix = ".exe" if cuda_config.cpu_value == "Windows" else ""
    nvcc = repository_ctx.path("%s/bin/nvcc%s" % (toolkit, exe_suffix))
    nvcc_output = repository_ctx.execute([nvcc, "-v", "/dev/null", "-o", "/dev/null"])

    marker = "#$ _TARGET_DIR_="
    target_include = ""
    for line in nvcc_output.stderr.splitlines():
        if line.startswith(marker):
            # When several lines match, the last one wins.
            target_include = toolkit + "/" + line.replace(marker, "") + "/include"

    default_include = toolkit + "/include"
    if target_include != "":
        return [target_include, default_include]
    return [default_include]
379
def enable_cuda(repository_ctx):
    """Returns the TF_NEED_CUDA setting converted with int().

    When the variable is not set, the default False is converted instead.
    """
    need_cuda = repository_ctx.os.environ.get("TF_NEED_CUDA", False)
    return int(need_cuda)
383
def matches_version(environ_version, detected_version):
    """Checks whether the user-specified version matches the detected version.

    The match is weak: the user may give only the major, or the major and
    minor, version, and only the parts that were given are compared.

        environ_version  detected_version  result
        -----------------------------------------
        5.1.3            5.1.3             True
        5.1              5.1.3             True
        5                5.1               True
        5.1.3            5.1               False
        5.2.3            5.1.3             False

    Args:
      environ_version: The version specified by the user via environment
        variables.
      detected_version: The version autodetected from the CUDA installation
        on the system.

    Returns:
      True if the user-specified version matches the detected version and
      False otherwise.
    """
    wanted = environ_version.split(".")
    found = detected_version.split(".")

    # A request more specific than what was detected can never match.
    if len(found) < len(wanted):
        return False

    # Compare only as many leading parts as the user supplied.
    return found[:len(wanted)] == wanted
419
# NOTE(review): neither constant is referenced in the part of the file
# reviewed here. Presumably they mark the release line in nvcc's version
# output and the major-version define in the cuDNN header - confirm they are
# still used.
_NVCC_VERSION_PREFIX = "Cuda compilation tools, release "

_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
423
def compute_capabilities(repository_ctx):
    """Returns a list of strings representing cuda compute capabilities.

    The list is read from the comma-separated TF_CUDA_COMPUTE_CAPABILITIES
    environment variable, or is the built-in default when that variable is
    not set. Configuration fails on an entry that is not "<digits>.<digits>".
    """
    environ = repository_ctx.os.environ
    if _TF_CUDA_COMPUTE_CAPABILITIES not in environ:
        return _DEFAULT_CUDA_COMPUTE_CAPABILITIES

    requested = environ[_TF_CUDA_COMPUTE_CAPABILITIES].split(",")
    for entry in requested:
        # Starlark has no regex support, so the two dot-separated halves are
        # checked by hand.
        pieces = entry.split(".")
        well_formed = len(pieces) == 2 and pieces[0].isdigit() and pieces[1].isdigit()
        if not well_formed:
            auto_configure_fail("Invalid compute capability: %s" % entry)
    return requested
438
def get_cpu_value(repository_ctx):
    """Returns the name of the host operating system.

    Args:
      repository_ctx: The repository context.

    Returns:
      "Darwin" when the OS name starts with "mac os", "Windows" when it
      contains "windows", and otherwise the stripped output of `uname -s`.
    """
    host = repository_ctx.os.name.lower()
    if host.startswith("mac os"):
        return "Darwin"
    if "windows" in host:
        return "Windows"
    return repository_ctx.execute(["uname", "-s"]).stdout.strip()
455
def _is_windows(repository_ctx):
    """Returns True when the lower-cased host OS name contains "windows"."""
    return "windows" in repository_ctx.os.name.lower()
459
def lib_name(base_name, cpu_value, version = None, static = False):
    """Constructs the platform-specific file name of a library.

    Args:
      base_name: The name of the library, such as "cudart".
      cpu_value: The name of the host operating system.
      version: The version of the library; only used in shared library names.
      static: True if the library is static, False if it is a shared object.
        Ignored on Windows.

    Returns:
      The platform-specific name of the library. Configuration fails for an
      unrecognized cpu_value.
    """
    suffix = ("." + version) if version else ""

    if cpu_value == "Windows":
        return "%s.lib" % base_name

    if cpu_value in ("Linux", "FreeBSD"):
        shared = "lib%s.so%s" % (base_name, suffix)
    elif cpu_value == "Darwin":
        shared = "lib%s%s.dylib" % (base_name, suffix)
    else:
        auto_configure_fail("Invalid cpu_value: %s" % cpu_value)
        return None

    # Static archives share one naming scheme and carry no version.
    return ("lib%s.a" % base_name) if static else shared
485
def find_lib(repository_ctx, paths, check_soname = True):
    """Finds a library among a list of potential paths.

    Args:
      repository_ctx: The repository context.
      paths: List of paths to inspect.
      check_soname: When true, and objdump is available on a non-Windows
        host, a candidate is only accepted if one of its SONAME entries
        equals its file name.

    Returns:
      The first path in paths that exists and passes the SONAME check.
      Configuration fails when no path qualifies.
    """
    objdump = repository_ctx.which("objdump")
    soname_mismatches = []
    for candidate in [repository_ctx.path(p) for p in paths]:
        if not candidate.exists:
            continue
        verify = check_soname and objdump != None and not _is_windows(repository_ctx)
        if verify:
            dump = repository_ctx.execute([objdump, "-p", str(candidate)]).stdout

            # The SONAME value is the last space-separated token on its line.
            sonames = [
                line.strip().split(" ")[-1]
                for line in dump.splitlines()
                if "SONAME" in line
            ]
            if candidate.basename not in sonames:
                soname_mismatches.append(str(candidate))
                continue
        return candidate

    if soname_mismatches:
        auto_configure_fail(
            "None of the libraries match their SONAME: " + ", ".join(soname_mismatches),
        )
    auto_configure_fail("No library found under: " + ", ".join(paths))
514
def _find_cuda_lib(
        lib,
        repository_ctx,
        cpu_value,
        basedir,
        version,
        static = False):
    """Finds the given CUDA or cuDNN library on the system.

    Args:
      lib: The name of the library, such as "cudart".
      repository_ctx: The repository context.
      cpu_value: The name of the host operating system.
      basedir: The directory expected to contain the library.
      version: The version of the library.
      static: True if static library, False if shared object.

    Returns:
      The path to the library, as returned by find_lib.
    """
    candidate = "%s/%s" % (basedir, lib_name(lib, cpu_value, version, static))

    # The SONAME is only checked for versioned shared libraries.
    return find_lib(
        repository_ctx,
        [candidate],
        check_soname = version and not static,
    )
541
def _find_libs(repository_ctx, cuda_config):
    """Returns the CUDA and cuDNN libraries on the system.

    Args:
      repository_ctx: The repository context.
      cuda_config: The CUDA config as returned by _get_cuda_config.

    Returns:
      A dict mapping each library name to the value returned by
      _find_cuda_lib for it.
    """
    cpu_value = cuda_config.cpu_value
    dirs = cuda_config.config
    stub_suffix = "" if _is_windows(repository_ctx) else "/stubs"
    toolkit_version = cuda_config.cuda_version
    lib_version = cuda_config.cuda_lib_version

    # (library, config key of its directory, directory suffix, version, static)
    specs = [
        ("cuda", "cuda_library_dir", stub_suffix, None, False),
        ("cudart", "cuda_library_dir", "", toolkit_version, False),
        ("cudart_static", "cuda_library_dir", "", toolkit_version, True),
        ("cublas", "cublas_library_dir", "", lib_version, False),
        ("cusolver", "cuda_library_dir", "", lib_version, False),
        ("curand", "cuda_library_dir", "", lib_version, False),
        ("cufft", "cuda_library_dir", "", lib_version, False),
        ("cudnn", "cudnn_library_dir", "", cuda_config.cudnn_version, False),
        ("cupti", "cupti_library_dir", "", toolkit_version, False),
        ("cusparse", "cuda_library_dir", "", lib_version, False),
    ]

    found = {}
    for name, dir_key, suffix, version, static in specs:
        found[name] = _find_cuda_lib(
            name,
            repository_ctx,
            cpu_value,
            dirs[dir_key] + suffix,
            version,
            static = static,
        )
    return found
627
def _cudart_static_linkopt(cpu_value):
    """Returns additional platform-specific linkopts for cudart.

    Darwin gets no extra option; every other host gets the quoted "-lrt"
    entry followed by a comma.
    """
    if cpu_value == "Darwin":
        return ""
    return "\"-lrt\","
631
632# TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl,
633# and nccl_configure.bzl.
def find_cuda_config(repository_ctx, cuda_libraries):
    """Returns CUDA config dictionary from running find_cuda_config.py.

    Args:
      repository_ctx: The repository context.
      cuda_libraries: List of library names passed as arguments to the script.

    Returns:
      A dict built from the script's stdout, which is expected to hold one
      "key: value" pair per line. Configuration fails when the script exits
      with a non-zero return code.
    """
    exec_result = repository_ctx.execute([
        _get_python_bin(repository_ctx),
        repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py")),
    ] + cuda_libraries)
    if exec_result.return_code:
        auto_configure_fail("Failed to run find_cuda_config.py: %s" % exec_result.stderr)

    # Parse the dict from stdout. Split on the first ": " only, so that a
    # value which itself contains ": " stays intact instead of producing a
    # 3-element tuple that dict() rejects.
    return dict([tuple(x.split(": ", 1)) for x in exec_result.stdout.splitlines()])
645
def _get_cuda_config(repository_ctx):
    """Detects and returns information about the CUDA installation on the system.

    Args:
      repository_ctx: The repository context.

    Returns:
      A struct containing the following fields:
        cuda_toolkit_path: The CUDA toolkit installation directory.
        cuda_version: The CUDA version as used in library file names
          ("<major>.<minor>", or "64_<major><minor>" on Windows).
        cudnn_version: The cuDNN version (prefixed with "64_" on Windows).
        cuda_lib_version: The version used for libraries like cuBLAS, cuFFT
          and cuSOLVER.
        compute_capabilities: The list returned by compute_capabilities().
        cpu_value: The name of the host operating system.
        config: The raw dict returned by find_cuda_config().
    """
    config = find_cuda_config(repository_ctx, ["cuda", "cudnn"])
    host_os = get_cpu_value(repository_ctx)
    toolkit = config["cuda_toolkit_path"]
    on_windows = _is_windows(repository_ctx)

    version_parts = config["cuda_version"].split(".")
    major = version_parts[0]
    minor = version_parts[1]

    if on_windows:
        cuda_version = "64_%s%s" % (major, minor)
        cudnn_version = "64_%s" % config["cudnn_version"]
    else:
        cuda_version = "%s.%s" % (major, minor)
        cudnn_version = "%s" % config["cudnn_version"]

    # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
    # It changed from 'x.y' to just 'x' in CUDA 10.1.
    major_only = (int(major), int(minor)) >= (10, 1)
    if not major_only:
        cuda_lib_version = cuda_version
    elif on_windows:
        cuda_lib_version = "64_%s" % major
    else:
        cuda_lib_version = "%s" % major

    return struct(
        cuda_toolkit_path = toolkit,
        cuda_version = cuda_version,
        cudnn_version = cudnn_version,
        cuda_lib_version = cuda_lib_version,
        compute_capabilities = compute_capabilities(repository_ctx),
        cpu_value = host_os,
        config = config,
    )
689
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Instantiates //third_party/gpus/<tpl>.tpl with the given substitutions.

    Args:
      repository_ctx: The repository context.
      tpl: Template name; ":" is replaced by "/" to form the default output path.
      substitutions: Dict of placeholder to replacement text.
      out: Optional output path overriding the default.
    """
    destination = out or tpl.replace(":", "/")
    template = Label("//third_party/gpus/%s.tpl" % tpl)
    repository_ctx.template(destination, template, substitutions)
698
def _file(repository_ctx, label):
    """Instantiates //third_party/gpus/<label>.tpl without any substitutions.

    The output path is `label` with ":" replaced by "/".
    """
    destination = label.replace(":", "/")
    template = Label("//third_party/gpus/%s.tpl" % label)
    repository_ctx.template(destination, template, {})
705
# Contents written to crosstool/error_gpu_disabled.bzl by
# _create_dummy_repository. The macro calls fail() first, with a message
# telling the user to re-run ./configure with GPU support.
_DUMMY_CROSSTOOL_BZL_FILE = """
def error_gpu_disabled():
  fail("ERROR: Building with --config=cuda but TensorFlow is not configured " +
       "to build with GPU support. Please re-run ./configure and enter 'Y' " +
       "at the prompt to build with GPU support.")

  native.genrule(
      name = "error_gen_crosstool",
      outs = ["CROSSTOOL"],
      cmd = "echo 'Should not be run.' && exit 1",
  )

  native.filegroup(
      name = "crosstool",
      srcs = [":CROSSTOOL"],
      output_licenses = ["unencumbered"],
  )
"""

# Contents written to crosstool/BUILD by _create_dummy_repository; it loads
# and invokes the error_gpu_disabled macro.
_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")

error_gpu_disabled()
"""
730
def _create_dummy_repository(repository_ctx):
    """Populates the repository with placeholders for a build without CUDA.

    Writes the cuda build_defs.bzl and BUILD files from their templates,
    creates empty header and library files, writes cuda_config.h with empty
    version fields, and installs a crosstool package whose only purpose is to
    fail with an actionable message.

    Args:
      repository_ctx: The repository context.
    """
    cpu_value = get_cpu_value(repository_ctx)

    # Set up BUILD file for cuda/.
    _tpl(
        repository_ctx,
        "cuda:build_defs.bzl",
        {
            "%{cuda_is_configured}": "False",
            "%{cuda_extra_copts}": "[]",
        },
    )
    _tpl(
        repository_ctx,
        "cuda:BUILD",
        {
            "%{cuda_driver_lib}": lib_name("cuda", cpu_value),
            "%{cudart_static_lib}": lib_name(
                "cudart_static",
                cpu_value,
                static = True,
            ),
            "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
            "%{cudart_lib}": lib_name("cudart", cpu_value),
            "%{cublas_lib}": lib_name("cublas", cpu_value),
            "%{cusolver_lib}": lib_name("cusolver", cpu_value),
            "%{cudnn_lib}": lib_name("cudnn", cpu_value),
            "%{cufft_lib}": lib_name("cufft", cpu_value),
            "%{curand_lib}": lib_name("curand", cpu_value),
            "%{cupti_lib}": lib_name("cupti", cpu_value),
            "%{cusparse_lib}": lib_name("cusparse", cpu_value),
            "%{copy_rules}": """
filegroup(name="cuda-include")
filegroup(name="cublas-include")
filegroup(name="cudnn-include")
""",
        },
    )

    # Create dummy files for the CUDA toolkit since they are still required by
    # tensorflow/core/platform/default/build_config:cuda.
    repository_ctx.file("cuda/cuda/include/cuda.h")
    repository_ctx.file("cuda/cuda/include/cublas.h")
    repository_ctx.file("cuda/cuda/include/cudnn.h")
    repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h")
    for dummy_lib in [
        lib_name("cuda", cpu_value),
        lib_name("cudart", cpu_value),
        # Use static = True so the dummy file carries the same name that
        # %{cudart_static_lib} is given in the BUILD substitutions above.
        lib_name("cudart_static", cpu_value, static = True),
        lib_name("cublas", cpu_value),
        lib_name("cusolver", cpu_value),
        lib_name("cudnn", cpu_value),
        lib_name("curand", cpu_value),
        lib_name("cufft", cpu_value),
        lib_name("cupti", cpu_value),
        lib_name("cusparse", cpu_value),
    ]:
        repository_ctx.file("cuda/cuda/lib/%s" % dummy_lib)

    # Set up cuda_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    _tpl(
        repository_ctx,
        "cuda:cuda_config.h",
        {
            "%{cuda_version}": "",
            "%{cuda_lib_version}": "",
            "%{cudnn_version}": "",
            "%{cuda_compute_capabilities}": ",".join([
                "CudaVersion(\"%s\")" % c
                for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES
            ]),
            "%{cuda_toolkit_path}": "",
        },
        "cuda/cuda/cuda_config.h",
    )

    # If cuda_configure is not configured to build with GPU support, and the user
    # attempts to build with --config=cuda, add a dummy build rule to intercept
    # this and fail with an actionable error message.
    repository_ctx.file(
        "crosstool/error_gpu_disabled.bzl",
        _DUMMY_CROSSTOOL_BZL_FILE,
    )
    repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
815
def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Executes a command and fails configuration if it looks unsuccessful.

    The command counts as failed when it wrote anything to stderr, or when
    its stdout is empty and empty_stdout_fine is False.

    Args:
      repository_ctx: the repository_ctx object
      cmdline: list of strings, the command to execute
      error_msg: string, a summary of the error if the command fails
      error_details: string, details about the error or steps to fix it
      empty_stdout_fine: bool, if True, an empty stdout result is fine,
        otherwise it's an error

    Returns:
      The result of repository_ctx.execute(cmdline).
    """
    result = repository_ctx.execute(cmdline)
    stdout_ok = empty_stdout_fine or result.stdout
    if not result.stderr and stdout_ok:
        return result

    summary = error_msg.strip() if error_msg else "Repository command failed"
    details = error_details if error_details else ""
    auto_configure_fail("\n".join([summary, result.stderr.strip(), details]))
    return result
843
def _norm_path(path):
    """Returns the path with '/' separators and without a trailing slash.

      Args:
        path: string, a path that may use backslashes or forward slashes
          as separators; may be empty

      Return: the path with every backslash replaced by '/' and a single
        trailing '/' removed if present. An empty path is returned unchanged.
    """
    path = path.replace("\\", "/")

    # Use endswith() instead of path[-1]: indexing an empty string is an
    # error, while endswith() is simply False for it.
    if path.endswith("/"):
        path = path[:-1]
    return path
850
def make_copy_files_rule(repository_ctx, name, srcs, outs):
    """Returns the text of a genrule that copies a set of files.

      Args:
        repository_ctx: the repository_ctx object (not used here)
        name: string, the name of the generated genrule
        srcs: list of strings, the files to copy; paired by position with
          outs
        outs: list of strings, the outputs declared by the genrule

      Return: a string holding the genrule definition; its command runs one
        `cp -f` per (src, out) pair, chained with `&&`.
    """
    copy_cmds = [
        'cp -f "%s" "$(location %s)"' % (src, out)
        for src, out in zip(srcs, outs)
    ]
    out_lines = ['        "%s",' % out for out in outs]
    return """genrule(
    name = "%s",
    outs = [
%s
    ],
    cmd = \"""%s \""",
)""" % (name, "\n".join(out_lines), " && \\\n".join(copy_cmds))
866
def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir):
    """Returns the text of a genrule that recursively copies a directory.

      Args:
        repository_ctx: the repository_ctx object, used to list src_dir
        name: string, the name of the generated genrule
        src_dir: string, the directory to copy from
        out_dir: string, the directory (relative to the genrule output
          location) to copy into

      Return: a string holding the genrule definition. Its outs are the
        files found under src_dir with src_dir replaced by out_dir.
    """
    src_root = _norm_path(src_dir)
    dst_root = _norm_path(out_dir)
    out_lines = [
        '        "%s",' % found.replace(src_root, dst_root)
        for found in _read_dir(repository_ctx, src_root)
    ]

    # '@D' already contains the relative path for a single file, see
    # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables
    if len(out_lines) > 1:
        copy_target = "$(@D)/%s" % dst_root
    else:
        copy_target = "$(@D)"

    return """genrule(
    name = "%s",
    outs = [
%s
    ],
    cmd = \"""cp -rLf "%s/." "%s/" \""",
)""" % (name, "\n".join(out_lines), src_root, copy_target)
884
def _read_dir(repository_ctx, src_dir):
    """Returns a sorted list of all files in a directory.

      Lists the files below src_dir, descending into subfolders. On Windows
      this runs `dir /b /s /a-d` through cmd.exe; elsewhere it runs
      `find -follow -type f`, which follows symlinks. Each output line
      becomes one list entry, and backslashes in the Windows output are
      converted to forward slashes.

      Args:
        repository_ctx: the repository_ctx object
        src_dir: string, the directory to list

      Return: sorted list of strings, one per line of the listing output;
        empty if the command printed nothing.
    """
    on_windows = _is_windows(repository_ctx)
    if on_windows:
        cmdline = [
            "cmd.exe",
            "/c",
            "dir",
            src_dir.replace("/", "\\"),
            "/b",
            "/s",
            "/a-d",
        ]
    else:
        cmdline = ["find", src_dir, "-follow", "-type", "f"]

    listing = _execute(
        repository_ctx,
        cmdline,
        empty_stdout_fine = True,
    ).stdout

    if on_windows:
        # The listing ends up in genrule.outs, where paths must use forward
        # slashes.
        listing = listing.replace("\\", "/")
    return sorted(listing.splitlines())
911
def _flag_enabled(repository_ctx, flag_name):
    """Returns True iff the environment variable is set to "1".

      Args:
        repository_ctx: the repository_ctx object
        flag_name: string, name of the environment variable to check

      Return: True when the variable exists and its value, with surrounding
        whitespace stripped, equals "1"; False otherwise.
    """
    raw_value = repository_ctx.os.environ.get(flag_name, "")
    return raw_value.strip() == "1"
917
def _use_cuda_clang(repository_ctx):
    """Returns whether the TF_CUDA_CLANG flag is enabled in the environment."""
    clang_requested = _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
    return clang_requested
920
def _tf_sysroot(repository_ctx):
    """Returns the value of the TF_SYSROOT variable, or "" if it is unset."""
    return repository_ctx.os.environ.get(_TF_SYSROOT, "")
925
def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
    """Returns the `if_cuda_clang(...)` expression holding the GPU arch flags.

      Args:
        repository_ctx: the repository_ctx object (not used here)
        compute_capabilities: list of strings such as "3.5"; the dots are
          dropped to form the `sm_XY` architecture names

      Return: a string of the form `if_cuda_clang([...])` wrapping the list
        of `--cuda-gpu-arch=sm_XY` flags.
    """
    arch_flags = []
    for capability in compute_capabilities:
        arch_flags.append("--cuda-gpu-arch=sm_" + capability.replace(".", ""))

    # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc
    # TODO(csigg): Make this consistent with cuda clang and pass unconditionally.
    return "if_cuda_clang(" + str(arch_flags) + ")"
935
def _create_local_cuda_repository(repository_ctx):
    """Creates the repository containing files set up to build with CUDA.

    Steps, in order:
      1. Read the CUDA configuration and generate genrules that copy the
         headers, libraries and binaries of the installed toolkit.
      2. Expand cuda:build_defs.bzl and the cuda BUILD template.
      3. Pick the host compiler (optionally downloading clang) and expand
         the crosstool BUILD file plus, for nvcc, the compiler wrappers.
      4. Expand cuda_config.h.

      Args:
        repository_ctx: the repository_ctx object
    """
    cuda_config = _get_cuda_config(repository_ctx)
    config = cuda_config.config

    toolkit_include_dir = config["cuda_include_dir"]
    cublas_include_dir = config["cublas_include_dir"]
    cudnn_include_dir = config["cudnn_include_dir"]
    cupti_include_dir = config["cupti_include_dir"]
    libdevice_dir = config["nvvm_library_dir"]

    # Create genrule to copy files from the installed CUDA toolkit into execroot.
    copy_rules = []
    for rule_name, source_dir, target_dir in [
        ("cuda-include", toolkit_include_dir, "cuda/include"),
        ("cuda-nvvm", libdevice_dir, "cuda/nvvm/libdevice"),
        ("cuda-extras", cupti_include_dir, "cuda/extras/CUPTI/include"),
    ]:
        copy_rules.append(make_copy_dir_rule(
            repository_ctx,
            name = rule_name,
            src_dir = source_dir,
            out_dir = target_dir,
        ))

    cublas_headers = ["cublas.h", "cublas_v2.h", "cublas_api.h"]
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cublas-include",
        srcs = [cublas_include_dir + "/" + header for header in cublas_headers],
        outs = ["cublas/include/" + header for header in cublas_headers],
    ))

    cuda_libs = _find_libs(repository_ctx, cuda_config)
    lib_paths = cuda_libs.values()
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cuda-lib",
        srcs = [str(lib) for lib in lib_paths],
        outs = ["cuda/lib/" + lib.basename for lib in lib_paths],
    ))

    # copy files mentioned in third_party/nccl/build_defs.bzl.tpl
    bin_tools = ["crt/link.stub", "nvlink", "fatbinary", "bin2c"]
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cuda-bin",
        srcs = [
            cuda_config.cuda_toolkit_path + "/bin/" + tool
            for tool in bin_tools
        ],
        outs = ["cuda/bin/" + tool for tool in bin_tools],
    ))

    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cudnn-include",
        srcs = [cudnn_include_dir + "/cudnn.h"],
        outs = ["cudnn/include/cudnn.h"],
    ))

    # Set up BUILD file for cuda/
    extra_copts = _compute_cuda_extra_copts(
        repository_ctx,
        cuda_config.compute_capabilities,
    )
    _tpl(
        repository_ctx,
        "cuda:build_defs.bzl",
        {
            "%{cuda_is_configured}": "True",
            "%{cuda_extra_copts}": extra_copts,
        },
    )

    if _is_windows(repository_ctx):
        build_template = "cuda:BUILD.windows"
    else:
        build_template = "cuda:BUILD"
    _tpl(
        repository_ctx,
        build_template,
        {
            "%{cuda_driver_lib}": cuda_libs["cuda"].basename,
            "%{cudart_static_lib}": cuda_libs["cudart_static"].basename,
            "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value),
            "%{cudart_lib}": cuda_libs["cudart"].basename,
            "%{cublas_lib}": cuda_libs["cublas"].basename,
            "%{cusolver_lib}": cuda_libs["cusolver"].basename,
            "%{cudnn_lib}": cuda_libs["cudnn"].basename,
            "%{cufft_lib}": cuda_libs["cufft"].basename,
            "%{curand_lib}": cuda_libs["curand"].basename,
            "%{cupti_lib}": cuda_libs["cupti"].basename,
            "%{cusparse_lib}": cuda_libs["cusparse"].basename,
            "%{copy_rules}": "\n".join(copy_rules),
        },
        "cuda/BUILD",
    )

    use_clang = _use_cuda_clang(repository_ctx)
    sysroot = _tf_sysroot(repository_ctx)

    fetch_clang = use_clang and _flag_enabled(
        repository_ctx,
        _TF_DOWNLOAD_CLANG,
    )
    if fetch_clang:
        download_clang(repository_ctx, "crosstool/extra_tools")

    # Set up crosstool/
    cc = find_cc(repository_ctx)
    if fetch_clang:
        cc_fullpath = "crosstool/" + cc
    else:
        cc_fullpath = cc

    host_includes = get_cxx_inc_directories(
        repository_ctx,
        cc_fullpath,
        sysroot,
    )

    gcc_prefix = repository_ctx.os.environ.get(
        _GCC_HOST_COMPILER_PREFIX,
        "/usr/bin",
    ).strip()

    # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
    # https://github.com/bazelbuild/bazel/issues/760).
    # However, this stops our custom clang toolchain from picking the provided
    # LLD linker, so we're only adding '-B/usr/bin' when using non-downloaded
    # toolchain.
    # TODO: when bazel stops adding '-B/usr/bin' by default, remove this
    #       flag from the CROSSTOOL completely (see
    #       https://github.com/bazelbuild/bazel/issues/5634)
    linker_bin_path = "" if fetch_clang else gcc_prefix

    cuda_defines = {
        "%{builtin_sysroot}": sysroot,
        "%{cuda_toolkit_path}": config["cuda_toolkit_path"] if use_clang else "",
        "%{host_compiler_prefix}": gcc_prefix,
        "%{linker_bin_path}": linker_bin_path,
        "%{extra_no_canonical_prefixes_flags}": "",
        "%{unfiltered_compile_flags}": "",
    }

    # Warning flags substituted into the crosstool BUILD file when clang is
    # the CUDA compiler; the text is inserted verbatim.
    clang_host_warnings = """
        # Some parts of the codebase set -Werror and hit this warning, so
        # switch it off for now.
        "-Wno-invalid-partial-specialization"
    """

    if use_clang:
        cuda_defines.update({
            "%{host_compiler_path}": str(cc),
            "%{host_compiler_warnings}": clang_host_warnings,
            "%{cxx_builtin_include_directories}": to_list_of_strings(host_includes),
            "%{compiler_deps}": ":empty",
            "%{win_compiler_deps}": ":empty",
        })

        # The nvcc wrappers are created as empty files in this mode.
        repository_ctx.file(
            "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
            "",
        )
        repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "")
    else:
        # nvcc has the system include paths built in and will automatically
        # search them; we cannot work around that, so we add the relevant cuda
        # system paths to the allowed compiler specific include paths.
        nvcc_include_dirs = host_includes + _cuda_include_path(
            repository_ctx,
            cuda_config,
        ) + [cupti_include_dir, cudnn_include_dir]

        cuda_defines.update({
            "%{host_compiler_path}": "clang/bin/crosstool_wrapper_driver_is_not_gcc",
            "%{host_compiler_warnings}": "",
            "%{cxx_builtin_include_directories}": to_list_of_strings(nvcc_include_dirs),
            # For gcc, do not canonicalize system header paths; some versions
            # of gcc pick the shortest possible path for system includes when
            # creating the .d file - given that includes that are prefixed
            # with "../" multiple time quickly grow longer than the root of
            # the tree, this can lead to bazel's header check failing.
            "%{extra_no_canonical_prefixes_flags}": "\"-fno-canonical-system-headers\"",
            "%{compiler_deps}": ":crosstool_wrapper_driver_is_not_gcc",
            "%{win_compiler_deps}": ":windows_msvc_wrapper_files",
        })

        nvcc_path = str(
            repository_ctx.path("%s/nvcc%s" % (
                config["cuda_binary_dir"],
                ".exe" if _is_windows(repository_ctx) else "",
            )),
        )
        host_cc = str(cc)
        quoted_capabilities = [
            "\"%s\"" % capability
            for capability in cuda_config.compute_capabilities
        ]
        wrapper_defines = {
            "%{cpu_compiler}": host_cc,
            "%{cuda_version}": cuda_config.cuda_version,
            "%{nvcc_path}": nvcc_path,
            "%{gcc_host_compiler_path}": host_cc,
            "%{cuda_compute_capabilities}": ", ".join(quoted_capabilities),
            "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx),
        }
        for wrapper_template in [
            "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
            "crosstool:windows/msvc_wrapper_for_nvcc.py",
        ]:
            _tpl(repository_ctx, wrapper_template, wrapper_defines)

    cuda_defines.update(_get_win_cuda_defines(repository_ctx))

    verify_build_defines(cuda_defines)

    # Only expand template variables in the BUILD file
    _tpl(repository_ctx, "crosstool:BUILD", cuda_defines)

    # No templating of cc_toolchain_config - use attributes and templatize the
    # BUILD file.
    _file(repository_ctx, "crosstool:cc_toolchain_config.bzl")

    # Set up cuda_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    capability_entries = [
        "CudaVersion(\"%s\")" % capability
        for capability in cuda_config.compute_capabilities
    ]
    _tpl(
        repository_ctx,
        "cuda:cuda_config.h",
        {
            "%{cuda_version}": cuda_config.cuda_version,
            "%{cuda_lib_version}": cuda_config.cuda_lib_version,
            "%{cudnn_version}": cuda_config.cudnn_version,
            "%{cuda_compute_capabilities}": ", ".join(capability_entries),
            "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
        },
        "cuda/cuda/cuda_config.h",
    )
1192
def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
    """Creates pointers to a remotely configured repo set up to build with CUDA.

    First expands the local cuda:build_defs.bzl template, then copies
    cuda/BUILD, cuda/build_defs.bzl and cuda/cuda/cuda_config.h from the
    remote configuration repository without substitutions.

      Args:
        repository_ctx: the repository_ctx object
        remote_config_repo: string, label prefix of the remote repository;
          "/cuda:<file>" is appended to it to address each file
    """
    extra_copts = _compute_cuda_extra_copts(
        repository_ctx,
        compute_capabilities(repository_ctx),
    )
    _tpl(
        repository_ctx,
        "cuda:build_defs.bzl",
        {
            "%{cuda_is_configured}": "True",
            "%{cuda_extra_copts}": extra_copts,
        },
    )

    # NOTE(review): cuda/build_defs.bzl is also written by the loop below;
    # presumably the remote copy replaces the file expanded above - confirm
    # against _tpl's default output path.
    remote_files = [
        ("cuda/BUILD", "/cuda:BUILD"),
        ("cuda/build_defs.bzl", "/cuda:build_defs.bzl"),
        ("cuda/cuda/cuda_config.h", "/cuda:cuda/cuda_config.h"),
    ]
    for out_path, remote_target in remote_files:
        repository_ctx.template(
            out_path,
            Label(remote_config_repo + remote_target),
            {},
        )
1221
def _cuda_autoconf_impl(repository_ctx):
    """Implementation of the cuda_autoconf repository rule.

    Creates one of three repositories:
      * a dummy repository when CUDA is not enabled,
      * a repository pointing at a remote configuration when the
        TF_CUDA_CONFIG_REPO variable is set (the CUDA and cuDNN version
        variables must then be set as well),
      * otherwise a repository configured from the local CUDA installation.

      Args:
        repository_ctx: the repository_ctx object
    """
    if not enable_cuda(repository_ctx):
        _create_dummy_repository(repository_ctx)
        return

    env = repository_ctx.os.environ
    if _TF_CUDA_CONFIG_REPO not in env:
        _create_local_cuda_repository(repository_ctx)
        return

    versions_given = _TF_CUDA_VERSION in env and _TF_CUDNN_VERSION in env
    if not versions_given:
        auto_configure_fail(
            "%s and %s must also be set if %s is specified" %
            (_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO),
        )
    _create_remote_cuda_repository(repository_ctx, env[_TF_CUDA_CONFIG_REPO])
1237
# Repository rule that detects and configures the local CUDA toolchain.
cuda_configure = repository_rule(
    implementation = _cuda_autoconf_impl,
    # Environment variables this rule depends on. Bazel re-runs the rule when
    # one of them changes, so every variable read during configuration has to
    # be listed. _TF_SYSROOT is read by _tf_sysroot() and therefore belongs
    # here as well.
    environ = [
        _GCC_HOST_COMPILER_PATH,
        _GCC_HOST_COMPILER_PREFIX,
        _CLANG_CUDA_COMPILER_PATH,
        _TF_SYSROOT,
        "TF_NEED_CUDA",
        "TF_CUDA_CLANG",
        _TF_DOWNLOAD_CLANG,
        _CUDA_TOOLKIT_PATH,
        _CUDNN_INSTALL_PATH,
        _TF_CUDA_VERSION,
        _TF_CUDNN_VERSION,
        _TF_CUDA_COMPUTE_CAPABILITIES,
        _TF_CUDA_CONFIG_REPO,
        "NVVMIR_LIBRARY_DIR",
        _PYTHON_BIN_PATH,
        "TMP",
        "TMPDIR",
        "TF_CUDA_PATHS",
    ],
)
1260
1261"""Detects and configures the local CUDA toolchain.
1262
1263Add the following to your WORKSPACE FILE:
1264
1265```python
1266cuda_configure(name = "local_config_cuda")
1267```
1268
1269Args:
1270  name: A unique name for this workspace rule.
1271"""
1272