1"""Repository rule for CUDA autoconfiguration.
2
3`cuda_configure` depends on the following environment variables:
4
5  * `TF_NEED_CUDA`: Whether to enable building with CUDA.
6  * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
7  * `TF_CUDA_CLANG`: Whether to use clang as a cuda compiler.
8  * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for
9    both host and device code compilation if TF_CUDA_CLANG is 1.
10  * `TF_SYSROOT`: The sysroot to use when compiling.
11  * `TF_DOWNLOAD_CLANG`: Whether to download a recent release of clang
12    compiler and use it to build tensorflow. When this option is set
13    CLANG_CUDA_COMPILER_PATH is ignored.
14  * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is
15    `/usr/local/cuda,usr/`.
16  * `CUDA_TOOLKIT_PATH` (deprecated): The path to the CUDA toolkit. Default is
17    `/usr/local/cuda`.
18  * `TF_CUDA_VERSION`: The version of the CUDA toolkit. If this is blank, then
19    use the system default.
20  * `TF_CUDNN_VERSION`: The version of the cuDNN library.
21  * `CUDNN_INSTALL_PATH` (deprecated): The path to the cuDNN library. Default is
22    `/usr/local/cuda`.
23  * `TF_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default is
24    `3.5,5.2`.
25  * `PYTHON_BIN_PATH`: The python binary path
26"""
27
28load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
29load(
30    "@bazel_tools//tools/cpp:lib_cc_configure.bzl",
31    "escape_string",
32    "get_env_var",
33)
34load(
35    "@bazel_tools//tools/cpp:windows_cc_configure.bzl",
36    "find_msvc_tool",
37    "find_vc_path",
38    "setup_vc_env_vars",
39)
40load(
41    "//third_party/remote_config:common.bzl",
42    "config_repo_label",
43    "err_out",
44    "execute",
45    "get_bash_bin",
46    "get_cpu_value",
47    "get_host_environ",
48    "get_python_bin",
49    "is_windows",
50    "raw_exec",
51    "read_dir",
52    "realpath",
53    "which",
54)
55
56_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
57_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
58_CLANG_CUDA_COMPILER_PATH = "CLANG_CUDA_COMPILER_PATH"
59_TF_SYSROOT = "TF_SYSROOT"
60_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
61_TF_CUDA_VERSION = "TF_CUDA_VERSION"
62_TF_CUDNN_VERSION = "TF_CUDNN_VERSION"
63_CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH"
64_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES"
65_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO"
66_TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
67_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"
68
69def to_list_of_strings(elements):
70    """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'.
71
72    This is to be used to put a list of strings into the bzl file templates
73    so it gets interpreted as list of strings in Starlark.
74
75    Args:
76      elements: list of string elements
77
78    Returns:
79      single string of elements wrapped in quotes separated by a comma."""
80    quoted_strings = ["\"" + element + "\"" for element in elements]
81    return ", ".join(quoted_strings)
82
83def verify_build_defines(params):
84    """Verify all variables that crosstool/BUILD.tpl expects are substituted.
85
86    Args:
87      params: dict of variables that will be passed to the BUILD.tpl template.
88    """
89    missing = []
90    for param in [
91        "cxx_builtin_include_directories",
92        "extra_no_canonical_prefixes_flags",
93        "host_compiler_path",
94        "host_compiler_prefix",
95        "host_compiler_warnings",
96        "linker_bin_path",
97        "compiler_deps",
98        "msvc_cl_path",
99        "msvc_env_include",
100        "msvc_env_lib",
101        "msvc_env_path",
102        "msvc_env_tmp",
103        "msvc_lib_path",
104        "msvc_link_path",
105        "msvc_ml_path",
106        "unfiltered_compile_flags",
107        "win_compiler_deps",
108    ]:
109        if ("%{" + param + "}") not in params:
110            missing.append(param)
111
112    if missing:
113        auto_configure_fail(
114            "BUILD.tpl template is missing these variables: " +
115            str(missing) +
116            ".\nWe only got: " +
117            str(params) +
118            ".",
119        )
120
121def _get_nvcc_tmp_dir_for_windows(repository_ctx):
122    """Return the Windows tmp directory for nvcc to generate intermediate source files."""
123    escaped_tmp_dir = escape_string(
124        get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
125            "\\",
126            "\\\\",
127        ),
128    )
129    return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir"
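# For illustration: if TMP is set to C:\Users\me\tmp (a hypothetical value), the
# helper above returns "C:\\Users\\me\\tmp\\nvcc_inter_files_tmp_dir", i.e. the
# escaped TMP directory with a dedicated nvcc subdirectory appended.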
130
131def _get_msvc_compiler(repository_ctx):
132    vc_path = find_vc_path(repository_ctx)
133    return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/")
134
135def _get_win_cuda_defines(repository_ctx):
136    """Return CROSSTOOL defines for Windows"""
137
138    # If we are not on Windows, return fake values for Windows-specific fields.
139    # This ensures the CROSSTOOL file parser is happy.
140    if not is_windows(repository_ctx):
141        return {
142            "%{msvc_env_tmp}": "msvc_not_used",
143            "%{msvc_env_path}": "msvc_not_used",
144            "%{msvc_env_include}": "msvc_not_used",
145            "%{msvc_env_lib}": "msvc_not_used",
146            "%{msvc_cl_path}": "msvc_not_used",
147            "%{msvc_ml_path}": "msvc_not_used",
148            "%{msvc_link_path}": "msvc_not_used",
149            "%{msvc_lib_path}": "msvc_not_used",
150        }
151
152    vc_path = find_vc_path(repository_ctx)
153    if not vc_path:
154        auto_configure_fail(
155            "Visual C++ build tools not found on your machine. " +
156            "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using",
157        )
158        return {}
159
160    env = setup_vc_env_vars(repository_ctx, vc_path)
161    escaped_paths = escape_string(env["PATH"])
162    escaped_include_paths = escape_string(env["INCLUDE"])
163    escaped_lib_paths = escape_string(env["LIB"])
164    escaped_tmp_dir = escape_string(
165        get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
166            "\\",
167            "\\\\",
168        ),
169    )
170
171    msvc_cl_path = get_python_bin(repository_ctx)
172    msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace(
173        "\\",
174        "/",
175    )
176    msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace(
177        "\\",
178        "/",
179    )
180    msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace(
181        "\\",
182        "/",
183    )
184
185    # nvcc will generate some temporary source files under %{nvcc_tmp_dir}
186    # The generated files are guaranteed to have unique names, so they can share
187    # the same tmp directory.
188    escaped_cxx_include_directories = [
189        _get_nvcc_tmp_dir_for_windows(repository_ctx),
190        "C:\\\\botcode\\\\w",
191    ]
192    for path in escaped_include_paths.split(";"):
193        if path:
194            escaped_cxx_include_directories.append(path)
195
196    return {
197        "%{msvc_env_tmp}": escaped_tmp_dir,
198        "%{msvc_env_path}": escaped_paths,
199        "%{msvc_env_include}": escaped_include_paths,
200        "%{msvc_env_lib}": escaped_lib_paths,
201        "%{msvc_cl_path}": msvc_cl_path,
202        "%{msvc_ml_path}": msvc_ml_path,
203        "%{msvc_link_path}": msvc_link_path,
204        "%{msvc_lib_path}": msvc_lib_path,
205        "%{cxx_builtin_include_directories}": to_list_of_strings(
206            escaped_cxx_include_directories,
207        ),
208    }
209
210# TODO(dzc): Once these functions have been factored out of Bazel's
211# cc_configure.bzl, load them from @bazel_tools instead.
212# BEGIN cc_configure common functions.
213def find_cc(repository_ctx):
214    """Find the C++ compiler."""
215    if is_windows(repository_ctx):
216        return _get_msvc_compiler(repository_ctx)
217
218    if _use_cuda_clang(repository_ctx):
219        target_cc_name = "clang"
220        cc_path_envvar = _CLANG_CUDA_COMPILER_PATH
221        if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
222            return "extra_tools/bin/clang"
223    else:
224        target_cc_name = "gcc"
225        cc_path_envvar = _GCC_HOST_COMPILER_PATH
226    cc_name = target_cc_name
227
228    cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
229    if cc_name_from_env:
230        cc_name = cc_name_from_env
231    if cc_name.startswith("/"):
232        # Absolute path, maybe we should make this supported by our which function.
233        return cc_name
234    cc = which(repository_ctx, cc_name)
235    if cc == None:
236        fail(("Cannot find {}, either correct your path or set the {}" +
237              " environment variable").format(target_cc_name, cc_path_envvar))
238    return cc
239
240_INC_DIR_MARKER_BEGIN = "#include <...>"
241
242# OSX adds " (framework directory)" at the end of the line; strip it.
243_OSX_FRAMEWORK_SUFFIX = " (framework directory)"
244_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)
245
246def _cxx_inc_convert(path):
247    """Convert a path returned by `cc -E -xc++` into a complete path."""
248    path = path.strip()
249    if path.endswith(_OSX_FRAMEWORK_SUFFIX):
250        path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
251    return path
252
253def _normalize_include_path(repository_ctx, path):
254    """Normalizes include paths before writing them to the crosstool.
255
256      If path points inside the 'crosstool' folder of the repository, a relative
257      path is returned.
258      If path points outside the 'crosstool' folder, an absolute path is returned.
259      """
260    path = str(repository_ctx.path(path))
261    crosstool_folder = str(repository_ctx.path(".").get_child("crosstool"))
262
263    if path.startswith(crosstool_folder):
264        # We drop the path to "$REPO/crosstool" and a trailing path separator.
265        return path[len(crosstool_folder) + 1:]
266    return path
267
268def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot):
269    """Compute the list of default C or C++ include directories."""
270    if lang_is_cpp:
271        lang = "c++"
272    else:
273        lang = "c"
274    sysroot = []
275    if tf_sysroot:
276        sysroot += ["--sysroot", tf_sysroot]
277    result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] +
278                                      sysroot)
279    stderr = err_out(result)
280    index1 = stderr.find(_INC_DIR_MARKER_BEGIN)
281    if index1 == -1:
282        return []
283    index1 = stderr.find("\n", index1)
284    if index1 == -1:
285        return []
286    index2 = stderr.rfind("\n ")
287    if index2 == -1 or index2 < index1:
288        return []
289    index2 = stderr.find("\n", index2 + 1)
290    if index2 == -1:
291        inc_dirs = stderr[index1 + 1:]
292    else:
293        inc_dirs = stderr[index1 + 1:index2].strip()
294
295    return [
296        _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
297        for p in inc_dirs.split("\n")
298    ]
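# Illustrative input for the parser above (hypothetical paths): given compiler
# stderr containing
#   #include <...> search starts here:
#    /usr/lib/gcc/x86_64-linux-gnu/9/include
#    /usr/include
#   End of search list.
# the function returns ["/usr/lib/gcc/x86_64-linux-gnu/9/include", "/usr/include"],
# with each entry passed through _normalize_include_path.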
299
300def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot):
301    """Compute the list of default C and C++ include directories."""
302
303    # For some reason `clang -xc` sometimes returns include paths that are
304    # different from the ones from `clang -xc++` (e.g. a symlink and its target dir),
305    # so we run the compiler with both `-xc` and `-xc++` and merge the resulting lists.
306    includes_cpp = _get_cxx_inc_directories_impl(
307        repository_ctx,
308        cc,
309        True,
310        tf_sysroot,
311    )
312    includes_c = _get_cxx_inc_directories_impl(
313        repository_ctx,
314        cc,
315        False,
316        tf_sysroot,
317    )
318
319    return includes_cpp + [
320        inc
321        for inc in includes_c
322        if inc not in includes_cpp
323    ]
324
325def auto_configure_fail(msg):
326    """Output failure message when cuda configuration fails."""
327    red = "\033[0;31m"
328    no_color = "\033[0m"
329    fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))
330
331# END cc_configure common functions (see TODO above).
332
333def _cuda_include_path(repository_ctx, cuda_config):
334    """Returns a list of CUDA toolkit include directories.
335
336      Args:
337        repository_ctx: The repository context.
338        cuda_config: The CUDA config as returned by _get_cuda_config.
339
340      Returns:
341        A list of include directories of the installed CUDA toolkit.
342      """
343    nvcc_path = repository_ctx.path("%s/bin/nvcc%s" % (
344        cuda_config.cuda_toolkit_path,
345        ".exe" if cuda_config.cpu_value == "Windows" else "",
346    ))
347
348    # The expected exit code of this command is non-zero. Bazel remote execution
349    # only caches commands with zero exit code. So force a zero exit code.
350    cmd = "%s -v /dev/null -o /dev/null ; [ $? -eq 1 ]" % str(nvcc_path)
351    result = raw_exec(repository_ctx, [get_bash_bin(repository_ctx), "-c", cmd])
352    target_dir = ""
353    for one_line in err_out(result).splitlines():
354        if one_line.startswith("#$ _TARGET_DIR_="):
355            target_dir = (
356                cuda_config.cuda_toolkit_path + "/" + one_line.replace(
357                    "#$ _TARGET_DIR_=",
358                    "",
359                ) + "/include"
360            )
361    inc_entries = []
362    if target_dir != "":
363        inc_entries.append(realpath(repository_ctx, target_dir))
364    inc_entries.append(realpath(repository_ctx, cuda_config.cuda_toolkit_path + "/include"))
365    return inc_entries
366
367def enable_cuda(repository_ctx):
368    """Returns whether to build with CUDA support."""
369    return int(get_host_environ(repository_ctx, "TF_NEED_CUDA", False))
370
371def matches_version(environ_version, detected_version):
372    """Checks whether the user-specified version matches the detected version.
373
374      This function performs a weak matching so that if the user specifies
375      only the major version, or only the major and minor versions, the
376      versions are still considered matching as long as the specified
377      version parts match the detected ones.
378      To illustrate:
379
380          environ_version  detected_version  result
381          -----------------------------------------
382          5.1.3            5.1.3             True
383          5.1              5.1.3             True
384          5                5.1               True
385          5.1.3            5.1               False
386          5.2.3            5.1.3             False
387
388      Args:
389        environ_version: The version specified by the user via environment
390          variables.
391        detected_version: The version autodetected from the CUDA installation on
392          the system.
393      Returns: True if user-specified version matches detected version and False
394        otherwise.
395    """
396    environ_version_parts = environ_version.split(".")
397    detected_version_parts = detected_version.split(".")
398    if len(detected_version_parts) < len(environ_version_parts):
399        return False
400    for i, part in enumerate(detected_version_parts):
401        if i >= len(environ_version_parts):
402            break
403        if part != environ_version_parts[i]:
404            return False
405    return True
406
407_NVCC_VERSION_PREFIX = "Cuda compilation tools, release "
408
409_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
410
411def compute_capabilities(repository_ctx):
412    """Returns a list of strings representing cuda compute capabilities.
413
414    Args:
415      repository_ctx: the repo rule's context.
416    Returns: list of cuda architectures to compile for. 'compute_xy' refers to
417      both PTX and SASS, 'sm_xy' refers to SASS only.
418    """
419    capabilities = get_host_environ(
420        repository_ctx,
421        _TF_CUDA_COMPUTE_CAPABILITIES,
422        "compute_35,compute_52",
423    ).split(",")
424
425    # Map old 'x.y' capabilities to 'compute_xy'.
426    for i, capability in enumerate(capabilities):
427        parts = capability.split(".")
428        if len(parts) != 2:
429            continue
430        capabilities[i] = "compute_%s%s" % (parts[0], parts[1])
431
432    # Make list unique
433    capabilities = dict(zip(capabilities, capabilities)).keys()
434
435    # Validate capabilities.
436    for capability in capabilities:
437        if not capability.startswith(("compute_", "sm_")):
438            auto_configure_fail("Invalid compute capability: %s" % capability)
439        for prefix in ["compute_", "sm_"]:
440            if not capability.startswith(prefix):
441                continue
442            if len(capability) == len(prefix) + 2 and capability[-2:].isdigit():
443                continue
444            auto_configure_fail("Invalid compute capability: %s" % capability)
445
446    return capabilities
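# Illustrative example: with TF_CUDA_COMPUTE_CAPABILITIES="7.0,sm_75,compute_80"
# (a hypothetical setting), the function above returns
# ["compute_70", "sm_75", "compute_80"]: legacy 'x.y' entries are rewritten to
# 'compute_xy', while entries that already carry a 'compute_' or 'sm_' prefix
# are kept as-is.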
447
448def lib_name(base_name, cpu_value, version = None, static = False):
449    """Constructs the platform-specific name of a library.
450
451      Args:
452        base_name: The name of the library, such as "cudart"
453        cpu_value: The name of the host operating system.
454        version: The version of the library.
455        static: True if the library is static, or False if it is a shared object.
456
457      Returns:
458        The platform-specific name of the library.
459      """
460    version = "" if not version else "." + version
461    if cpu_value in ("Linux", "FreeBSD"):
462        if static:
463            return "lib%s.a" % base_name
464        return "lib%s.so%s" % (base_name, version)
465    elif cpu_value == "Windows":
466        return "%s.lib" % base_name
467    elif cpu_value == "Darwin":
468        if static:
469            return "lib%s.a" % base_name
470        return "lib%s%s.dylib" % (base_name, version)
471    else:
472        auto_configure_fail("Invalid cpu_value: %s" % cpu_value)
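# For example, following the branches above: lib_name("cudart", "Linux", "11.0")
# returns "libcudart.so.11.0", lib_name("cudart", "Windows") returns "cudart.lib",
# and lib_name("cudart_static", "Linux", static = True) returns "libcudart_static.a".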
473
474def _lib_path(lib, cpu_value, basedir, version, static):
475    file_name = lib_name(lib, cpu_value, version, static)
476    return "%s/%s" % (basedir, file_name)
477
478def _should_check_soname(version, static):
479    return version and not static
480
481def _check_cuda_lib_params(lib, cpu_value, basedir, version, static = False):
482    return (
483        _lib_path(lib, cpu_value, basedir, version, static),
484        _should_check_soname(version, static),
485    )
486
487def _check_cuda_libs(repository_ctx, script_path, libs):
488    python_bin = get_python_bin(repository_ctx)
489    contents = repository_ctx.read(script_path).splitlines()
490
491    cmd = "from os import linesep;"
492    cmd += "f = open('script.py', 'w');"
493    for line in contents:
494        cmd += "f.write('%s' + linesep);" % line
495    cmd += "f.close();"
496    cmd += "from os import system;"
497    args = " ".join(["\"" + path + "\" " + str(check) for path, check in libs])
498    cmd += "system('%s script.py %s');" % (python_bin, args)
499
500    all_paths = [path for path, _ in libs]
501    checked_paths = execute(repository_ctx, [python_bin, "-c", cmd]).stdout.splitlines()
502
503    # Filter out empty lines from splitting on '\r\n' on Windows
504    checked_paths = [path for path in checked_paths if len(path) > 0]
505    if all_paths != checked_paths:
506        auto_configure_fail("Error with installed CUDA libs. Expected '%s'. Actual '%s'." % (all_paths, checked_paths))
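# Illustrative sketch of the command assembled above (path is hypothetical): for
# libs = [("/usr/local/cuda/lib64/libcudart.so.11.0", True)], `args` becomes
# '"/usr/local/cuda/lib64/libcudart.so.11.0" True', so the embedded script ends up
# being run as: <python_bin> script.py "/usr/local/cuda/lib64/libcudart.so.11.0" True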
507
508def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
509    """Returns the CUDA and cuDNN libraries on the system.
510
511      Also verifies that the libraries actually exist on the system.
512
513      Args:
514        repository_ctx: The repository context.
515        check_cuda_libs_script: The path to a script verifying that the cuda
516          libraries exist on the system.
517        cuda_config: The CUDA config as returned by _get_cuda_config
518
519      Returns:
520        Map of library names to structs of filename and path.
521      """
522    cpu_value = cuda_config.cpu_value
523    stub_dir = "" if is_windows(repository_ctx) else "/stubs"
524
525    check_cuda_libs_params = {
526        "cuda": _check_cuda_lib_params(
527            "cuda",
528            cpu_value,
529            cuda_config.config["cuda_library_dir"] + stub_dir,
530            version = None,
531            static = False,
532        ),
533        "cudart": _check_cuda_lib_params(
534            "cudart",
535            cpu_value,
536            cuda_config.config["cuda_library_dir"],
537            cuda_config.cudart_version,
538            static = False,
539        ),
540        "cudart_static": _check_cuda_lib_params(
541            "cudart_static",
542            cpu_value,
543            cuda_config.config["cuda_library_dir"],
544            cuda_config.cudart_version,
545            static = True,
546        ),
547        "cublas": _check_cuda_lib_params(
548            "cublas",
549            cpu_value,
550            cuda_config.config["cublas_library_dir"],
551            cuda_config.cublas_version,
552            static = False,
553        ),
554        "cublasLt": _check_cuda_lib_params(
555            "cublasLt",
556            cpu_value,
557            cuda_config.config["cublas_library_dir"],
558            cuda_config.cublas_version,
559            static = False,
560        ),
561        "cusolver": _check_cuda_lib_params(
562            "cusolver",
563            cpu_value,
564            cuda_config.config["cusolver_library_dir"],
565            cuda_config.cusolver_version,
566            static = False,
567        ),
568        "curand": _check_cuda_lib_params(
569            "curand",
570            cpu_value,
571            cuda_config.config["curand_library_dir"],
572            cuda_config.curand_version,
573            static = False,
574        ),
575        "cufft": _check_cuda_lib_params(
576            "cufft",
577            cpu_value,
578            cuda_config.config["cufft_library_dir"],
579            cuda_config.cufft_version,
580            static = False,
581        ),
582        "cudnn": _check_cuda_lib_params(
583            "cudnn",
584            cpu_value,
585            cuda_config.config["cudnn_library_dir"],
586            cuda_config.cudnn_version,
587            static = False,
588        ),
589        "cupti": _check_cuda_lib_params(
590            "cupti",
591            cpu_value,
592            cuda_config.config["cupti_library_dir"],
593            cuda_config.cuda_version,
594            static = False,
595        ),
596        "cusparse": _check_cuda_lib_params(
597            "cusparse",
598            cpu_value,
599            cuda_config.config["cusparse_library_dir"],
600            cuda_config.cusparse_version,
601            static = False,
602        ),
603    }
604
605    # Verify that the libs actually exist at their locations.
606    _check_cuda_libs(repository_ctx, check_cuda_libs_script, check_cuda_libs_params.values())
607
608    paths = {filename: v[0] for (filename, v) in check_cuda_libs_params.items()}
609    return paths
610
611def _cudart_static_linkopt(cpu_value):
612    """Returns additional platform-specific linkopts for cudart."""
613    return "" if cpu_value == "Darwin" else "\"-lrt\","
614
615def _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries):
616    python_bin = get_python_bin(repository_ctx)
617
618    # If used with remote execution then repository_ctx.execute() can't
619    # access files from the source tree. A trick is to read the contents
620    # of the file in Starlark and embed them as part of the command. In
621    # this case the trick is not sufficient as the find_cuda_config.py
622    # script has more than 8192 characters. 8192 is the command length
623    # limit of cmd.exe on Windows. Thus we additionally need to compress
624    # the contents locally and decompress them as part of the execute().
625    compressed_contents = repository_ctx.read(script_path)
626    decompress_and_execute_cmd = (
627        "from zlib import decompress;" +
628        "from base64 import b64decode;" +
629        "from os import system;" +
630        "script = decompress(b64decode('%s'));" % compressed_contents +
631        "f = open('script.py', 'wb');" +
632        "f.write(script);" +
633        "f.close();" +
634        "system('\"%s\" script.py %s');" % (python_bin, " ".join(cuda_libraries))
635    )
636
637    return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])
638
639# TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl,
640# and nccl_configure.bzl.
641def find_cuda_config(repository_ctx, script_path, cuda_libraries):
642    """Returns CUDA config dictionary from running find_cuda_config.py"""
643    exec_result = _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries)
644    if exec_result.return_code:
645        auto_configure_fail("Failed to run find_cuda_config.py: %s" % err_out(exec_result))
646
647    # Parse the dict from stdout.
648    return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()])
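# Illustrative example (hypothetical values): if find_cuda_config.py prints
#   cuda_version: 11.2
#   cuda_toolkit_path: /usr/local/cuda
# the function above returns
# {"cuda_version": "11.2", "cuda_toolkit_path": "/usr/local/cuda"}.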
649
650def _get_cuda_config(repository_ctx, find_cuda_config_script):
651    """Detects and returns information about the CUDA installation on the system.
652
653      Args:
654        repository_ctx: The repository context.
655        find_cuda_config_script: The path to the compressed find_cuda_config.py script.

656      Returns:
657        A struct containing the following fields:
658          cuda_toolkit_path: The CUDA toolkit installation directory.
659          cublas_version: The version of cuBLAS on the system.
660          cuda_version: The version of CUDA on the system.
661          cudart_version: The CUDA runtime version on the system.
662          cudnn_version: The version of cuDNN on the system.
663          compute_capabilities: A list of the system's CUDA compute capabilities.
664          cpu_value: The name of the host operating system.
665      """
666    config = find_cuda_config(repository_ctx, find_cuda_config_script, ["cuda", "cudnn"])
667    cpu_value = get_cpu_value(repository_ctx)
668    toolkit_path = config["cuda_toolkit_path"]
669
670    is_windows = cpu_value == "Windows"
671    cuda_version = config["cuda_version"].split(".")
672    cuda_major = cuda_version[0]
673    cuda_minor = cuda_version[1]
674
675    cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor)
676    cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"]
677
678    if int(cuda_major) >= 11:
679        # The libcudart soname in CUDA 11.x is versioned as 11.0 for backward compatibility.
680        if int(cuda_major) == 11:
681            cudart_version = "64_110" if is_windows else "11.0"
682        else:
683            cudart_version = ("64_%s" if is_windows else "%s") % cuda_major
684        cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0]
685        cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0]
686        curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0]
687        cufft_version = ("64_%s" if is_windows else "%s") % config["cufft_version"].split(".")[0]
688        cusparse_version = ("64_%s" if is_windows else "%s") % config["cusparse_version"].split(".")[0]
689    elif (int(cuda_major), int(cuda_minor)) >= (10, 1):
690        # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
691        # It changed from 'x.y' to just 'x' in CUDA 10.1.
692        cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major
693        cudart_version = cuda_version
694        cublas_version = cuda_lib_version
695        cusolver_version = cuda_lib_version
696        curand_version = cuda_lib_version
697        cufft_version = cuda_lib_version
698        cusparse_version = cuda_lib_version
699    else:
700        cudart_version = cuda_version
701        cublas_version = cuda_version
702        cusolver_version = cuda_version
703        curand_version = cuda_version
704        cufft_version = cuda_version
705        cusparse_version = cuda_version
706
707    return struct(
708        cuda_toolkit_path = toolkit_path,
709        cuda_version = cuda_version,
710        cuda_version_major = cuda_major,
711        cudart_version = cudart_version,
712        cublas_version = cublas_version,
713        cusolver_version = cusolver_version,
714        curand_version = curand_version,
715        cufft_version = cufft_version,
716        cusparse_version = cusparse_version,
717        cudnn_version = cudnn_version,
718        compute_capabilities = compute_capabilities(repository_ctx),
719        cpu_value = cpu_value,
720        config = config,
721    )
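# Illustrative example (hypothetical Linux system with CUDA 11.2): cuda_version is
# "11.2", cudart_version is "11.0" (the backward-compatible 11.x soname), and each
# remaining library version collapses to that library's own major version as
# reported by find_cuda_config.py (these majors need not equal the CUDA major).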
722
723def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
724    if not out:
725        out = tpl.replace(":", "/")
726    repository_ctx.template(
727        out,
728        Label("//third_party/gpus/%s.tpl" % tpl),
729        substitutions,
730    )
731
732def _file(repository_ctx, label):
733    repository_ctx.template(
734        label.replace(":", "/"),
735        Label("//third_party/gpus/%s.tpl" % label),
736        {},
737    )
738
739_DUMMY_CROSSTOOL_BZL_FILE = """
740def error_gpu_disabled():
741  fail("ERROR: Building with --config=cuda but TensorFlow is not configured " +
742       "to build with GPU support. Please re-run ./configure and enter 'Y' " +
743       "at the prompt to build with GPU support.")
744
745  native.genrule(
746      name = "error_gen_crosstool",
747      outs = ["CROSSTOOL"],
748      cmd = "echo 'Should not be run.' && exit 1",
749  )
750
751  native.filegroup(
752      name = "crosstool",
753      srcs = [":CROSSTOOL"],
754      output_licenses = ["unencumbered"],
755  )
756"""
757
758_DUMMY_CROSSTOOL_BUILD_FILE = """
759load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")
760
761error_gpu_disabled()
762"""
763
764def _create_dummy_repository(repository_ctx):
765    cpu_value = get_cpu_value(repository_ctx)
766
767    # Set up BUILD file for cuda/.
768    _tpl(
769        repository_ctx,
770        "cuda:build_defs.bzl",
771        {
772            "%{cuda_is_configured}": "False",
773            "%{cuda_extra_copts}": "[]",
774            "%{cuda_gpu_architectures}": "[]",
775        },
776    )
777    _tpl(
778        repository_ctx,
779        "cuda:BUILD",
780        {
781            "%{cuda_driver_lib}": lib_name("cuda", cpu_value),
782            "%{cudart_static_lib}": lib_name(
783                "cudart_static",
784                cpu_value,
785                static = True,
786            ),
787            "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
788            "%{cudart_lib}": lib_name("cudart", cpu_value),
789            "%{cublas_lib}": lib_name("cublas", cpu_value),
790            "%{cublasLt_lib}": lib_name("cublasLt", cpu_value),
791            "%{cusolver_lib}": lib_name("cusolver", cpu_value),
792            "%{cudnn_lib}": lib_name("cudnn", cpu_value),
793            "%{cufft_lib}": lib_name("cufft", cpu_value),
794            "%{curand_lib}": lib_name("curand", cpu_value),
795            "%{cupti_lib}": lib_name("cupti", cpu_value),
796            "%{cusparse_lib}": lib_name("cusparse", cpu_value),
797            "%{cub_actual}": ":cuda_headers",
798            "%{copy_rules}": """
799filegroup(name="cuda-include")
800filegroup(name="cublas-include")
801filegroup(name="cusolver-include")
802filegroup(name="cufft-include")
803filegroup(name="cusparse-include")
804filegroup(name="curand-include")
805filegroup(name="cudnn-include")
806""",
807        },
808    )
809
810    # Create dummy files for the CUDA toolkit since they are still required by
811    # tensorflow/core/platform/default/build_config:cuda.
812    repository_ctx.file("cuda/cuda/include/cuda.h")
813    repository_ctx.file("cuda/cuda/include/cublas.h")
814    repository_ctx.file("cuda/cuda/include/cudnn.h")
815    repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h")
816    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cuda", cpu_value))
817    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudart", cpu_value))
818    repository_ctx.file(
819        "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value),
820    )
821    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value))
822    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value))
823    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value))
824    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value))
825    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value))
826    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cufft", cpu_value))
827    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cupti", cpu_value))
828    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusparse", cpu_value))
829
830    # Set up cuda_config.h, which is used by
831    # tensorflow/stream_executor/dso_loader.cc.
832    _tpl(
833        repository_ctx,
834        "cuda:cuda_config.h",
835        {
836            "%{cuda_version}": "",
837            "%{cudart_version}": "",
838            "%{cublas_version}": "",
839            "%{cusolver_version}": "",
840            "%{curand_version}": "",
841            "%{cufft_version}": "",
842            "%{cusparse_version}": "",
843            "%{cudnn_version}": "",
844            "%{cuda_toolkit_path}": "",
845        },
846        "cuda/cuda/cuda_config.h",
847    )
848
849    # Set up cuda_config.py, which is used by gen_build_info to provide
850    # static build environment info to the API
851    _tpl(
852        repository_ctx,
853        "cuda:cuda_config.py",
854        _py_tmpl_dict({}),
855        "cuda/cuda/cuda_config.py",
856    )
857
858    # If cuda_configure is not configured to build with GPU support, and the user
859    # attempts to build with --config=cuda, add a dummy build rule to intercept
860    # this and fail with an actionable error message.
861    repository_ctx.file(
862        "crosstool/error_gpu_disabled.bzl",
863        _DUMMY_CROSSTOOL_BZL_FILE,
864    )
865    repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
866
867def _norm_path(path):
868    """Returns the path with '/' separators and without a trailing slash."""
869    path = path.replace("\\", "/")
870    if path[-1] == "/":
871        path = path[:-1]
872    return path
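# For example: _norm_path("C:\\cuda\\include\\") (i.e. the path C:\cuda\include\)
# returns "C:/cuda/include".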
873
874def make_copy_files_rule(repository_ctx, name, srcs, outs):
875    """Returns a rule to copy a set of files."""
876    cmds = []
877
878    # Copy files.
879    for src, out in zip(srcs, outs):
880        cmds.append('cp -f "%s" "$(location %s)"' % (src, out))
881    outs = [('        "%s",' % out) for out in outs]
882    return """genrule(
883    name = "%s",
884    outs = [
885%s
886    ],
887    cmd = \"""%s \""",
888)""" % (name, "\n".join(outs), " && \\\n".join(cmds))
889
890def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir, exceptions = None):
891    """Returns a rule to recursively copy a directory.
892    If exceptions is not None, it must be a list of files or directories in
893    'src_dir'; these will be excluded from copying.
894    """
895    src_dir = _norm_path(src_dir)
896    out_dir = _norm_path(out_dir)
897    outs = read_dir(repository_ctx, src_dir)
898    post_cmd = ""
899    if exceptions != None:
900        outs = [x for x in outs if not any([
901            x.startswith(src_dir + "/" + y)
902            for y in exceptions
903        ])]
904    outs = [('        "%s",' % out.replace(src_dir, out_dir)) for out in outs]
905
906    # '@D' already contains the relative path for a single file, see
907    # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables
908    out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)"
909    if exceptions != None:
910        for x in exceptions:
911            post_cmd += " ; rm -fR " + out_dir + "/" + x
912    return """genrule(
913    name = "%s",
914    outs = [
915%s
916    ],
917    cmd = \"""cp -rLf "%s/." "%s/" %s\""",
918)""" % (name, "\n".join(outs), src_dir, out_dir, post_cmd)
919
920def _flag_enabled(repository_ctx, flag_name):
921    return get_host_environ(repository_ctx, flag_name) == "1"
922
923def _use_cuda_clang(repository_ctx):
924    return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
925
926def _tf_sysroot(repository_ctx):
927    return get_host_environ(repository_ctx, _TF_SYSROOT, "")
928
929def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
930    copts = []
931    for capability in compute_capabilities:
932        if capability.startswith("compute_"):
933            capability = capability.replace("compute_", "sm_")
934            copts.append("--cuda-include-ptx=%s" % capability)
935        copts.append("--cuda-gpu-arch=%s" % capability)
936
937    return str(copts)
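# For example: for compute_capabilities = ["compute_70", "sm_75"], the function
# above returns the string
# '["--cuda-include-ptx=sm_70", "--cuda-gpu-arch=sm_70", "--cuda-gpu-arch=sm_75"]'.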
938
939def _tpl_path(repository_ctx, filename):
940    return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % filename))
941
942def _basename(repository_ctx, path_str):
943    """Returns the basename of a path of type string.
944
945    This method is different from path.basename in that it also works if
946    the host platform is different from the execution platform
947    e.g. Linux -> Windows.
948    """
949
950    num_chars = len(path_str)
951    is_win = is_windows(repository_ctx)
952    for i in range(num_chars):
953        r_i = num_chars - 1 - i
954        if (is_win and path_str[r_i] == "\\") or path_str[r_i] == "/":
955            return path_str[r_i + 1:]
956    return path_str
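# For example: _basename(repository_ctx, "/usr/local/cuda/lib64/libcudart.so")
# returns "libcudart.so"; on a Windows host, "C:\cuda\bin\nvcc.exe" yields
# "nvcc.exe".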
957
958def _create_local_cuda_repository(repository_ctx):
959    """Creates the repository containing files set up to build with CUDA."""
960
961    # Resolve all labels before doing any real work. Resolving causes the
962    # function to be restarted with all previous state being lost. This
963    # can easily lead to an O(n^2) runtime in the number of labels.
964    # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
965    tpl_paths = {filename: _tpl_path(repository_ctx, filename) for filename in [
966        "cuda:build_defs.bzl",
967        "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
968        "crosstool:windows/msvc_wrapper_for_nvcc.py",
969        "crosstool:BUILD",
970        "crosstool:cc_toolchain_config.bzl",
971        "cuda:cuda_config.h",
972        "cuda:cuda_config.py",
973    ]}
974    tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD")
975    find_cuda_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
976
977    cuda_config = _get_cuda_config(repository_ctx, find_cuda_config_script)
978
979    cuda_include_path = cuda_config.config["cuda_include_dir"]
980    cublas_include_path = cuda_config.config["cublas_include_dir"]
981    cudnn_header_dir = cuda_config.config["cudnn_include_dir"]
982    cupti_header_dir = cuda_config.config["cupti_include_dir"]
983    nvvm_libdevice_dir = cuda_config.config["nvvm_library_dir"]
984
985    # Create genrule to copy files from the installed CUDA toolkit into execroot.
986    copy_rules = [
987        make_copy_dir_rule(
988            repository_ctx,
989            name = "cuda-include",
990            src_dir = cuda_include_path,
991            out_dir = "cuda/include",
992        ),
993        make_copy_dir_rule(
994            repository_ctx,
995            name = "cuda-nvvm",
996            src_dir = nvvm_libdevice_dir,
997            out_dir = "cuda/nvvm/libdevice",
998        ),
999        make_copy_dir_rule(
1000            repository_ctx,
1001            name = "cuda-extras",
1002            src_dir = cupti_header_dir,
1003            out_dir = "cuda/extras/CUPTI/include",
1004        ),
1005    ]
1006
1007    copy_rules.append(make_copy_files_rule(
1008        repository_ctx,
1009        name = "cublas-include",
1010        srcs = [
1011            cublas_include_path + "/cublas.h",
1012            cublas_include_path + "/cublas_v2.h",
1013            cublas_include_path + "/cublas_api.h",
1014            cublas_include_path + "/cublasLt.h",
1015        ],
1016        outs = [
1017            "cublas/include/cublas.h",
1018            "cublas/include/cublas_v2.h",
1019            "cublas/include/cublas_api.h",
1020            "cublas/include/cublasLt.h",
1021        ],
1022    ))
1023
1024    cusolver_include_path = cuda_config.config["cusolver_include_dir"]
1025    copy_rules.append(make_copy_files_rule(
1026        repository_ctx,
1027        name = "cusolver-include",
1028        srcs = [
1029            cusolver_include_path + "/cusolver_common.h",
1030            cusolver_include_path + "/cusolverDn.h",
1031        ],
1032        outs = [
1033            "cusolver/include/cusolver_common.h",
1034            "cusolver/include/cusolverDn.h",
1035        ],
1036    ))
1037
1038    cufft_include_path = cuda_config.config["cufft_include_dir"]
1039    copy_rules.append(make_copy_files_rule(
1040        repository_ctx,
1041        name = "cufft-include",
1042        srcs = [
1043            cufft_include_path + "/cufft.h",
1044        ],
1045        outs = [
1046            "cufft/include/cufft.h",
1047        ],
1048    ))
1049
1050    cusparse_include_path = cuda_config.config["cusparse_include_dir"]
1051    copy_rules.append(make_copy_files_rule(
1052        repository_ctx,
1053        name = "cusparse-include",
1054        srcs = [
1055            cusparse_include_path + "/cusparse.h",
1056        ],
1057        outs = [
1058            "cusparse/include/cusparse.h",
1059        ],
1060    ))
1061
1062    curand_include_path = cuda_config.config["curand_include_dir"]
1063    copy_rules.append(make_copy_files_rule(
1064        repository_ctx,
1065        name = "curand-include",
1066        srcs = [
1067            curand_include_path + "/curand.h",
1068        ],
1069        outs = [
1070            "curand/include/curand.h",
1071        ],
1072    ))
1073
1074    check_cuda_libs_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:check_cuda_libs.py"))
1075    cuda_libs = _find_libs(repository_ctx, check_cuda_libs_script, cuda_config)
1076    cuda_lib_srcs = []
1077    cuda_lib_outs = []
1078    for path in cuda_libs.values():
1079        cuda_lib_srcs.append(path)
1080        cuda_lib_outs.append("cuda/lib/" + _basename(repository_ctx, path))
1081    copy_rules.append(make_copy_files_rule(
1082        repository_ctx,
1083        name = "cuda-lib",
1084        srcs = cuda_lib_srcs,
1085        outs = cuda_lib_outs,
1086    ))
1087
1088    # copy files mentioned in third_party/nccl/build_defs.bzl.tpl
1089    file_ext = ".exe" if is_windows(repository_ctx) else ""
1090    bin_files = (
1091        ["crt/link.stub"] +
1092        [f + file_ext for f in ["bin2c", "fatbinary", "nvlink", "nvprune"]]
1093    )
1094    copy_rules.append(make_copy_files_rule(
1095        repository_ctx,
1096        name = "cuda-bin",
1097        srcs = [cuda_config.cuda_toolkit_path + "/bin/" + f for f in bin_files],
1098        outs = ["cuda/bin/" + f for f in bin_files],
1099    ))
1100
1101    # Select the headers based on the cuDNN version (strip '64_' for Windows).
1102    cudnn_headers = ["cudnn.h"]
1103    if cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "8":
1104        cudnn_headers += [
1105            "cudnn_backend.h",
1106            "cudnn_adv_infer.h",
1107            "cudnn_adv_train.h",
1108            "cudnn_cnn_infer.h",
1109            "cudnn_cnn_train.h",
1110            "cudnn_ops_infer.h",
1111            "cudnn_ops_train.h",
1112            "cudnn_version.h",
1113        ]
1114
1115    cudnn_srcs = []
1116    cudnn_outs = []
1117    for header in cudnn_headers:
1118        cudnn_srcs.append(cudnn_header_dir + "/" + header)
1119        cudnn_outs.append("cudnn/include/" + header)
1120
1121    copy_rules.append(make_copy_files_rule(
1122        repository_ctx,
1123        name = "cudnn-include",
1124        srcs = cudnn_srcs,
1125        outs = cudnn_outs,
1126    ))
1127
1128    # Set up BUILD file for cuda/
1129    repository_ctx.template(
1130        "cuda/build_defs.bzl",
1131        tpl_paths["cuda:build_defs.bzl"],
1132        {
1133            "%{cuda_is_configured}": "True",
1134            "%{cuda_extra_copts}": _compute_cuda_extra_copts(
1135                repository_ctx,
1136                cuda_config.compute_capabilities,
1137            ),
1138            "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities),
1139        },
1140    )
1141
1142    cub_actual = "@cub_archive//:cub"
1143    if int(cuda_config.cuda_version_major) >= 11:
1144        cub_actual = ":cuda_headers"
1145
1146    repository_ctx.template(
1147        "cuda/BUILD",
1148        tpl_paths["cuda:BUILD"],
1149        {
1150            "%{cuda_driver_lib}": _basename(repository_ctx, cuda_libs["cuda"]),
1151            "%{cudart_static_lib}": _basename(repository_ctx, cuda_libs["cudart_static"]),
1152            "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value),
1153            "%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]),
1154            "%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]),
1155            "%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]),
1156            "%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]),
1157            "%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]),
1158            "%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]),
1159            "%{curand_lib}": _basename(repository_ctx, cuda_libs["curand"]),
1160            "%{cupti_lib}": _basename(repository_ctx, cuda_libs["cupti"]),
1161            "%{cusparse_lib}": _basename(repository_ctx, cuda_libs["cusparse"]),
1162            "%{cub_actual}": cub_actual,
1163            "%{copy_rules}": "\n".join(copy_rules),
1164        },
1165    )
1166
1167    is_cuda_clang = _use_cuda_clang(repository_ctx)
1168    tf_sysroot = _tf_sysroot(repository_ctx)
1169
1170    should_download_clang = is_cuda_clang and _flag_enabled(
1171        repository_ctx,
1172        _TF_DOWNLOAD_CLANG,
1173    )
1174    if should_download_clang:
1175        download_clang(repository_ctx, "crosstool/extra_tools")
1176
1177    # Set up crosstool/
1178    cc = find_cc(repository_ctx)
1179    cc_fullpath = cc if not should_download_clang else "crosstool/" + cc
1180
1181    host_compiler_includes = get_cxx_inc_directories(
1182        repository_ctx,
1183        cc_fullpath,
1184        tf_sysroot,
1185    )
1186    cuda_defines = {}
1187    cuda_defines["%{builtin_sysroot}"] = tf_sysroot
1188    cuda_defines["%{cuda_toolkit_path}"] = ""
1189    cuda_defines["%{compiler}"] = "unknown"
1190    if is_cuda_clang:
1191        cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"]
1192        cuda_defines["%{compiler}"] = "clang"
1193
1194    host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX)
1195    if not host_compiler_prefix:
1196        host_compiler_prefix = "/usr/bin"
1197
1198    cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix
1199
1200    # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
1201    # https://github.com/bazelbuild/bazel/issues/760).
1202    # However, this stops our custom clang toolchain from picking the provided
1203    # LLD linker, so we only add '-B/usr/bin' when using a non-downloaded
1204    # toolchain.
1205    # TODO: when bazel stops adding '-B/usr/bin' by default, remove this
1206    #       flag from the CROSSTOOL completely (see
1207    #       https://github.com/bazelbuild/bazel/issues/5634)
1208    if should_download_clang:
1209        cuda_defines["%{linker_bin_path}"] = ""
1210    else:
1211        cuda_defines["%{linker_bin_path}"] = host_compiler_prefix
1212
1213    cuda_defines["%{extra_no_canonical_prefixes_flags}"] = ""
1214    cuda_defines["%{unfiltered_compile_flags}"] = ""
1215    if is_cuda_clang:
1216        cuda_defines["%{host_compiler_path}"] = str(cc)
1217        cuda_defines["%{host_compiler_warnings}"] = """
1218        # Some parts of the codebase set -Werror and hit this warning, so
1219        # switch it off for now.
1220        "-Wno-invalid-partial-specialization"
1221    """
1222        cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(host_compiler_includes)
1223        cuda_defines["%{compiler_deps}"] = ":empty"
1224        cuda_defines["%{win_compiler_deps}"] = ":empty"
1225        repository_ctx.file(
1226            "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
1227            "",
1228        )
1229        repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "")
1230    else:
1231        cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
1232        cuda_defines["%{host_compiler_warnings}"] = ""
1233
1234        # nvcc has the system include paths built in and will automatically
1235        # search them; we cannot work around that, so we add the relevant cuda
1236        # system paths to the allowed compiler-specific include paths.
1237        cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(
1238            host_compiler_includes + _cuda_include_path(
1239                repository_ctx,
1240                cuda_config,
1241            ) + [cupti_header_dir, cudnn_header_dir],
1242        )
1243
1244        # For gcc, do not canonicalize system header paths; some versions of gcc
1245        # pick the shortest possible path for system includes when creating the
1246        # .d file - given that includes that are prefixed with "../" multiple
1247        # times quickly grow longer than the root of the tree, this can lead to
1248        # bazel's header check failing.
1249        cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\""
1250
1251        file_ext = ".exe" if is_windows(repository_ctx) else ""
1252        nvcc_path = "%s/nvcc%s" % (cuda_config.config["cuda_binary_dir"], file_ext)
1253        cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc"
1254        cuda_defines["%{win_compiler_deps}"] = ":windows_msvc_wrapper_files"
1255
1256        wrapper_defines = {
1257            "%{cpu_compiler}": str(cc),
1258            "%{cuda_version}": cuda_config.cuda_version,
1259            "%{nvcc_path}": nvcc_path,
1260            "%{gcc_host_compiler_path}": str(cc),
1261            "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx),
1262        }
1263        repository_ctx.template(
1264            "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
1265            tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"],
1266            wrapper_defines,
1267        )
1268        repository_ctx.template(
1269            "crosstool/windows/msvc_wrapper_for_nvcc.py",
1270            tpl_paths["crosstool:windows/msvc_wrapper_for_nvcc.py"],
1271            wrapper_defines,
1272        )
1273
1274    cuda_defines.update(_get_win_cuda_defines(repository_ctx))
1275
1276    verify_build_defines(cuda_defines)
1277
1278    # Only expand template variables in the BUILD file
1279    repository_ctx.template(
1280        "crosstool/BUILD",
1281        tpl_paths["crosstool:BUILD"],
1282        cuda_defines,
1283    )
1284
1285    # No templating of cc_toolchain_config - use attributes and templatize the
1286    # BUILD file.
1287    repository_ctx.template(
1288        "crosstool/cc_toolchain_config.bzl",
1289        tpl_paths["crosstool:cc_toolchain_config.bzl"],
1290        {},
1291    )
1292
1293    # Set up cuda_config.h, which is used by
1294    # tensorflow/stream_executor/dso_loader.cc.
1295    repository_ctx.template(
1296        "cuda/cuda/cuda_config.h",
1297        tpl_paths["cuda:cuda_config.h"],
1298        {
1299            "%{cuda_version}": cuda_config.cuda_version,
1300            "%{cudart_version}": cuda_config.cudart_version,
1301            "%{cublas_version}": cuda_config.cublas_version,
1302            "%{cusolver_version}": cuda_config.cusolver_version,
1303            "%{curand_version}": cuda_config.curand_version,
1304            "%{cufft_version}": cuda_config.cufft_version,
1305            "%{cusparse_version}": cuda_config.cusparse_version,
1306            "%{cudnn_version}": cuda_config.cudnn_version,
1307            "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
1308        },
1309    )
1310
1311    # Set up cuda_config.py, which is used by gen_build_info to provide
1312    # static build environment info to the API
1313    repository_ctx.template(
1314        "cuda/cuda/cuda_config.py",
1315        tpl_paths["cuda:cuda_config.py"],
1316        _py_tmpl_dict({
1317            "cuda_version": cuda_config.cuda_version,
1318            "cudnn_version": cuda_config.cudnn_version,
1319            "cuda_compute_capabilities": cuda_config.compute_capabilities,
1320            "cpu_compiler": str(cc),
1321        }),
1322    )
1323
1324def _py_tmpl_dict(d):
1325    return {"%{cuda_config}": str(d)}
1326
1327def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
1328    """Creates pointers to a remotely configured repo set up to build with CUDA."""
1329    _tpl(
1330        repository_ctx,
1331        "cuda:build_defs.bzl",
1332        {
1333            "%{cuda_is_configured}": "True",
1334            "%{cuda_extra_copts}": _compute_cuda_extra_copts(
1335                repository_ctx,
1336                compute_capabilities(repository_ctx),
1337            ),
1338        },
1339    )
1340    repository_ctx.template(
1341        "cuda/BUILD",
1342        config_repo_label(remote_config_repo, "cuda:BUILD"),
1343        {},
1344    )
1345    repository_ctx.template(
1346        "cuda/build_defs.bzl",
1347        config_repo_label(remote_config_repo, "cuda:build_defs.bzl"),
1348        {},
1349    )
1350    repository_ctx.template(
1351        "cuda/cuda/cuda_config.h",
1352        config_repo_label(remote_config_repo, "cuda:cuda/cuda_config.h"),
1353        {},
1354    )
1355    repository_ctx.template(
1356        "cuda/cuda/cuda_config.py",
1357        config_repo_label(remote_config_repo, "cuda:cuda/cuda_config.py"),
1358        _py_tmpl_dict({}),
1359    )
1360
1361    repository_ctx.template(
1362        "crosstool/BUILD",
1363        config_repo_label(remote_config_repo, "crosstool:BUILD"),
1364        {},
1365    )
1366
1367    repository_ctx.template(
1368        "crosstool/cc_toolchain_config.bzl",
1369        config_repo_label(remote_config_repo, "crosstool:cc_toolchain_config.bzl"),
1370        {},
1371    )
1372
1373    repository_ctx.template(
1374        "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
1375        config_repo_label(remote_config_repo, "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"),
1376        {},
1377    )
1378
1379def _cuda_autoconf_impl(repository_ctx):
1380    """Implementation of the cuda_autoconf repository rule."""
1381    if not enable_cuda(repository_ctx):
1382        _create_dummy_repository(repository_ctx)
1383    elif get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO) != None:
1384        has_cuda_version = get_host_environ(repository_ctx, _TF_CUDA_VERSION) != None
1385        has_cudnn_version = get_host_environ(repository_ctx, _TF_CUDNN_VERSION) != None
1386        if not has_cuda_version or not has_cudnn_version:
1387            auto_configure_fail("%s and %s must also be set if %s is specified" %
1388                                (_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO))
1389        _create_remote_cuda_repository(
1390            repository_ctx,
1391            get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO),
1392        )
1393    else:
1394        _create_local_cuda_repository(repository_ctx)
1395
1396# For @bazel_tools//tools/cpp:windows_cc_configure.bzl
1397_MSVC_ENVVARS = [
1398    "BAZEL_VC",
1399    "BAZEL_VC_FULL_VERSION",
1400    "BAZEL_VS",
1401    "BAZEL_WINSDK_FULL_VERSION",
1402    "VS90COMNTOOLS",
1403    "VS100COMNTOOLS",
1404    "VS110COMNTOOLS",
1405    "VS120COMNTOOLS",
1406    "VS140COMNTOOLS",
1407    "VS150COMNTOOLS",
1408    "VS160COMNTOOLS",
1409]
1410
1411_ENVIRONS = [
1412    _GCC_HOST_COMPILER_PATH,
1413    _GCC_HOST_COMPILER_PREFIX,
1414    _CLANG_CUDA_COMPILER_PATH,
1415    "TF_NEED_CUDA",
1416    "TF_CUDA_CLANG",
1417    _TF_DOWNLOAD_CLANG,
1418    _CUDA_TOOLKIT_PATH,
1419    _CUDNN_INSTALL_PATH,
1420    _TF_CUDA_VERSION,
1421    _TF_CUDNN_VERSION,
1422    _TF_CUDA_COMPUTE_CAPABILITIES,
1423    "NVVMIR_LIBRARY_DIR",
1424    _PYTHON_BIN_PATH,
1425    "TMP",
1426    "TMPDIR",
1427    "TF_CUDA_PATHS",
1428] + _MSVC_ENVVARS
1429
1430remote_cuda_configure = repository_rule(
1431    implementation = _create_local_cuda_repository,
1432    environ = _ENVIRONS,
1433    remotable = True,
1434    attrs = {
1435        "environ": attr.string_dict(),
1436    },
1437)
1438
1439cuda_configure = repository_rule(
1440    implementation = _cuda_autoconf_impl,
1441    environ = _ENVIRONS + [_TF_CUDA_CONFIG_REPO],
1442)
1443"""Detects and configures the local CUDA toolchain.
1444
1445Add the following to your WORKSPACE file:
1446
1447```python
1448cuda_configure(name = "local_config_cuda")
1449```
1450
1451Args:
1452  name: A unique name for this workspace rule.
1453"""
1454